SplitFed Learning Explained
- SplitFed Learning is a distributed training paradigm that partitions a global neural network between client-side and server-side submodels to enhance computation efficiency and ensure data privacy.
- It integrates the principles of Split Learning and Federated Learning, using layer-wise partitioning and FedAvg aggregation to reduce communication costs and manage client heterogeneity.
- Techniques like the MU-SplitFed algorithm enable up to 2× acceleration and improved straggler resilience by incorporating local zeroth-order updates on the split server.
SplitFed Learning (SFL) is a distributed training paradigm that fuses the model parallelism of Split Learning (SL) with the scalability and data privacy of Federated Learning (FL). By decoupling model training through layer-wise partitioning and enabling FedAvg-style aggregation across multiple clients, SFL significantly reduces on-device computation and communication bottlenecks while maintaining strong privacy guarantees. SFL is deployed in diverse domains including resource-constrained edge networks and privacy-sensitive applications such as healthcare, and provides theoretical and practical advancements in straggler resilience, communication efficiency, heterogeneity mitigation, and convergence guarantees.
1. Core Architecture and Protocol
SFL partitions a global neural network at a designated "cut" layer. The client-side submodel (layers 1 up to the cut layer) is held by each client, while the server-side submodel (the layers after the cut, up to the output) is held by the split server. Clients receive the current global client-side parameters, compute forward embeddings, and transmit these activations (smashed data) to the split server. The split server computes the loss and backpropagates partial gradients to each client, which completes the backward update on its local client-side submodel. Both the split server and the Fed server then aggregate their respective model weights using FedAvg. This loop, iterating over multiple communication rounds, leverages both FL's parallel client synchronization and SL's computational offloading (Thapa et al., 2020).
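The forward split described above can be sketched as follows. This is a minimal illustration in plain NumPy, not the reference implementation: the toy dense layers, the layer count, and the cut index are all hypothetical.

```python
import numpy as np

def make_layer(w):
    """A toy dense layer with tanh activation (hypothetical stand-in for real layers)."""
    return lambda x: np.tanh(x @ w)

# Global model as a list of layers; the "cut" layer splits it into two submodels.
rng = np.random.default_rng(0)
layers = [make_layer(rng.standard_normal((4, 4)) * 0.1) for _ in range(5)]
cut = 2  # client holds layers[:cut], split server holds layers[cut:]

def client_forward(x, client_layers):
    """Client-side forward pass producing the 'smashed data' (activations)."""
    for layer in client_layers:
        x = layer(x)
    return x  # this tensor, not raw data, is transmitted to the split server

def server_forward(smashed, server_layers):
    """Server-side forward pass completing the network from the received activations."""
    h = smashed
    for layer in server_layers:
        h = layer(h)
    return h

x = rng.standard_normal((2, 4))          # a client's local mini-batch (never leaves the client)
smashed = client_forward(x, layers[:cut])
out = server_forward(smashed, layers[cut:])
```

Only `smashed` crosses the network in the forward direction; the corresponding partial gradients flow back across the same boundary during backpropagation.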
| Component | Role | Typical Data Exchanged |
|---|---|---|
| Clients | Compute client-side forward/backward passes; hold local data | Forward activations, backward gradients, w_c updates |
| Split Server | Computes server-side forward/backward passes; holds server-side submodel | Incoming activations, outgoing backprop gradients, w_s aggregates |
| Fed Server | Aggregates client-side weights with FedAvg | w_c model portions |
This design reduces on-device compute and per-round communication (activations are exchanged rather than full model weights) while keeping raw data local to each client.
2. Straggler Resilience and Unbalanced SFL Algorithms
A key challenge in synchronous SFL is the straggler effect: every global round waits for the slowest client/server communication. The MU-SplitFed algorithm addresses this by decoupling server-side progress from client delays. The split server is permitted to perform τ local zeroth-order (SPSA) updates per received client embedding. The protocol can be summarized as:
- Each client samples a random perturbation and sends the forward embeddings needed for the two-point zeroth-order estimate (evaluated at the perturbed and unperturbed client-side weights).
- The split server performs τ local ZO steps per client per round, then computes the client-side ZO gradients.
- Each client finishes local updates and the Fed server aggregates w_c, while the split server aggregates w_s.
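The server-side zeroth-order updates in the steps above can be sketched with a two-point SPSA gradient estimator. This is an illustrative toy, not the MU-SplitFed implementation: the linear server head, the squared-error loss, and the step sizes are all assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)

def server_loss(w_s, smashed, y):
    """Toy server-side objective: a linear head with squared error (hypothetical)."""
    pred = smashed @ w_s
    return np.mean((pred - y) ** 2)

def spsa_grad(w_s, smashed, y, mu=1e-3):
    """Two-point SPSA zeroth-order gradient estimate along a random direction u."""
    u = rng.standard_normal(w_s.shape)
    f_plus = server_loss(w_s + mu * u, smashed, y)
    f_minus = server_loss(w_s - mu * u, smashed, y)
    return (f_plus - f_minus) / (2 * mu) * u

def local_zo_steps(w_s, smashed, y, tau=5, lr=0.05):
    """tau local ZO updates on the server-side weights for one received embedding,
    letting the split server make progress without waiting on slow clients."""
    for _ in range(tau):
        w_s = w_s - lr * spsa_grad(w_s, smashed, y)
    return w_s

smashed = rng.standard_normal((8, 3))                 # received client activations
y = smashed @ np.array([1.0, -2.0, 0.5])              # synthetic targets
w0 = np.zeros(3)
w_tau = local_zo_steps(w0, smashed, y, tau=20)
```

Because SPSA needs only loss evaluations (no server-side backprop through to the client), these τ steps can proceed while straggling clients are still in transit.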
Theoretically, MU-SplitFed retains convergence guarantees for nonconvex objectives, with the number of required communication rounds reduced linearly in the number of local ZO steps τ. Experiments demonstrate up to 2× acceleration over synchronous SFL baselines and superior robustness under heterogeneous straggler scenarios. The trade-off is that an overly large τ increases the variance of the ZO gradient estimates.