Federated Attentive Message Passing (FedAMP)
- FedAMP is a federated learning strategy that uses pairwise adaptive attention to personalize client models on non-IID data.
- It alternates between attentive message passing on the server and local proximal updates on client devices to optimize personalized parameters.
- Empirical evaluations on MNIST, FMNIST, EMNIST, and CIFAR100 show improved accuracy over traditional federated methods, alongside provable convergence guarantees.
Federated Attentive Message Passing (FedAMP) is a federated learning methodology designed to enable client models to collaborate via adaptive, pairwise attention mechanisms, with the specific goal of improving performance on non-IID data distributions. By leveraging attention-inducing communication between models, FedAMP personalizes learned parameters for each client while maximizing the benefits of inter-client similarity. The approach was introduced as a solution to the persistent challenge of non-IID data in cross-silo federated learning, offering provable convergence, practical robustness, and demonstrably superior empirical results when compared to established methods (Huang et al., 2020).
1. Formal Problem Framework and Objective
Let $m$ denote the number of clients, each indexed by $i \in \{1, \dots, m\}$. Each client holds:
- A private dataset $\mathcal{D}_i$ sampled from distribution $P_i$ (non-IID across clients).
- Local model parameters $w_i \in \mathbb{R}^d$.
- A local empirical loss function $F_i : \mathbb{R}^d \to \mathbb{R}$.
The global objective is to learn personalized parameters $w_1, \dots, w_m$ such that each $w_i$ is near-optimal for its own distribution $P_i$, while still exploiting cross-client similarities. This leads to the following aggregate optimization target:

$$\min_{W} \; G(W) := \sum_{i=1}^{m} F_i(w_i) + \lambda \sum_{1 \le i < j \le m} A\!\left(\|w_i - w_j\|^2\right),$$

where $W = (w_1, \dots, w_m)$ collects all client models, $\lambda > 0$ balances personalization/collaboration, and $A(\cdot)$ is a concave, increasing penalty that induces attention. An example is the negative exponential $A(t) = 1 - e^{-t/\sigma}$ with scale $\sigma > 0$.
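To make the objective concrete, here is a minimal NumPy sketch, assuming the negative-exponential penalty above; the names `losses`, `lam`, and `sigma` are illustrative stand-ins for $F_i$, $\lambda$, and $\sigma$, not notation from the original paper:

```python
import numpy as np

def attention_penalty(t, sigma=1.0):
    """Negative-exponential penalty A(t) = 1 - exp(-t / sigma)."""
    return 1.0 - np.exp(-t / sigma)

def fedamp_objective(W, losses, lam=1.0, sigma=1.0):
    """G(W) = sum_i F_i(w_i) + lam * sum_{i<j} A(||w_i - w_j||^2).

    W      : (m, d) array, one personalized parameter vector per client.
    losses : list of m callables; losses[i](w) evaluates F_i at w.
    """
    m = W.shape[0]
    g = sum(losses[i](W[i]) for i in range(m))
    for i in range(m):
        for j in range(i + 1, m):
            g += lam * attention_penalty(np.sum((W[i] - W[j]) ** 2), sigma)
    return g
```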
2. FedAMP Algorithmic Structure
FedAMP implements an alternating incremental-proximal optimization on $G(W)$. Each communication round $k = 1, 2, \dots$ proceeds as follows:
- Message-Passing / Attention Step (Server-Side):
For each client $i$, compute the attentive aggregate:

$$u_i^k = \xi_{i,1} w_1^{k-1} + \xi_{i,2} w_2^{k-1} + \cdots + \xi_{i,m} w_m^{k-1},$$

where $\xi_{i,j} = \alpha_k \lambda A'\!\left(\|w_i^{k-1} - w_j^{k-1}\|^2\right)$ for $j \neq i$, and $\xi_{i,i} = 1 - \sum_{j \neq i} \xi_{i,j}$. The update can also be regarded as a gradient step on the pairwise attention term:

$$U^k = W^{k-1} - \alpha_k \nabla \mathcal{A}(W^{k-1}), \qquad \mathcal{A}(W) := \lambda \sum_{i < j} A\!\left(\|w_i - w_j\|^2\right).$$
- Local Proximal Update (Client-Side):
Each client $i$ solves:

$$w_i^k = \arg\min_{w \in \mathbb{R}^d} \; F_i(w) + \frac{1}{2\alpha_k} \left\|w - u_i^k\right\|^2.$$
In practice this is implemented via a small number of local SGD or Adam steps.
The algorithm iterates these two steps for $k = 1, \dots, K$ communication rounds. Pseudocode matching the above logic is presented in the original work.
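The following NumPy sketch mirrors one round of this logic under the negative-exponential kernel; the proximal subproblem is approximated by a few local gradient steps, and `grad_fn`, `lr`, and `local_steps` are illustrative names rather than the paper's notation:

```python
import numpy as np

def server_attention_step(W, alpha, lam=1.0, sigma=1.0):
    """Message passing: u_i = sum_j xi_{i,j} w_j, with
    xi_{i,j} = alpha * lam * A'(||w_i - w_j||^2) for j != i,
    A'(t) = exp(-t / sigma) / sigma, and xi_{i,i} = 1 - sum_{j != i} xi_{i,j}."""
    sq_dists = np.sum((W[:, None, :] - W[None, :, :]) ** 2, axis=-1)  # (m, m)
    xi = alpha * lam * np.exp(-sq_dists / sigma) / sigma
    np.fill_diagonal(xi, 0.0)
    np.fill_diagonal(xi, 1.0 - xi.sum(axis=1))  # self-weight completes each row to 1
    return xi @ W  # row i is the personalized aggregate u_i

def client_proximal_step(w, u_i, grad_fn, alpha, lr=0.1, local_steps=10):
    """Approximately solve min_w F_i(w) + ||w - u_i||^2 / (2 * alpha)
    with a few local gradient steps; grad_fn(w) returns the gradient of F_i."""
    for _ in range(local_steps):
        w = w - lr * (grad_fn(w) + (w - u_i) / alpha)
    return w
```

One round then reads `U = server_attention_step(W, alpha)` on the server, followed by `W[i] = client_proximal_step(W[i], U[i], grads[i], alpha)` on each client $i$.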
3. Attention Mechanism and Similarity Adaptation
The derivative $A'(\cdot)$ serves as a nonincreasing, nonnegative similarity function:
- Small $\|w_i - w_j\|^2$ yields large $A'\!\left(\|w_i - w_j\|^2\right)$, encouraging strong pairwise collaboration.
- Large $\|w_i - w_j\|^2$ yields small $A'\!\left(\|w_i - w_j\|^2\right)$, limiting influence across dissimilar clients.
A widely used instantiation is the negative exponential (an RBF-style kernel), $A(t) = 1 - e^{-t/\sigma}$, so $A'(t) = \frac{1}{\sigma} e^{-t/\sigma}$. Consequently, $\xi_{i,j} \propto e^{-\|w_i^{k-1} - w_j^{k-1}\|^2 / \sigma}$.
The attention coefficients $\xi_{i,j}$ thus implement a form of adaptive, pairwise, non-linear communication, automatically amplifying within-cluster collaboration on non-IID data.
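A quick numeric illustration (with $\sigma = 1$ chosen arbitrarily) shows how sharply this kernel concentrates attention on nearby models:

```python
import numpy as np

sigma = 1.0  # kernel scale, chosen arbitrarily for illustration
for sq_dist in (0.1, 1.0, 10.0):
    a_prime = np.exp(-sq_dist / sigma) / sigma  # A'(t) = exp(-t / sigma) / sigma
    print(f"||w_i - w_j||^2 = {sq_dist:4}:  A' = {a_prime:.4f}")
# ||w_i - w_j||^2 =  0.1:  A' = 0.9048
# ||w_i - w_j||^2 =  1.0:  A' = 0.3679
# ||w_i - w_j||^2 = 10.0:  A' = 0.0000
```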
4. Theoretical Convergence Analysis
FedAMP offers convergence guarantees for both convex and nonconvex formulations of the objective $G(W)$, under bounded-gradient assumptions:
- Convex Case: If each $F_i$ and $A$ are convex, and (sub)gradients are uniformly bounded, then after $k$ rounds
$$\min_{k' \le k} \; G(W^{k'}) - G(W^\star) \;\le\; \frac{\|W^0 - W^\star\|^2 + C \sum_{j=1}^{k} \alpha_j^2}{2 \sum_{j=1}^{k} \alpha_j},$$
where $C > 0$ depends on the gradient bound. Diminishing step sizes $\alpha_k$ ensuring $\sum_k \alpha_k = \infty$ and $\sum_k \alpha_k^2 < \infty$ yield $G(W^k) \to G(W^\star)$.
- Smooth, Nonconvex Case: If each $F_i$ and $A$ are $L$-smooth, and gradients are uniformly bounded, then
$$\min_{k' \le k} \left\|\nabla G(W^{k'})\right\|^2 \;\le\; \frac{G(W^0) - \inf_W G(W) + C' \sum_{j=1}^{k} \alpha_j^2}{c \sum_{j=1}^{k} \alpha_j}$$
for constants $c, C' > 0$ depending on $L$ and the gradient bound. With diminishing $\alpha_k$ as above, any limit point of $\{W^k\}$ is stationary.
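For instance, the standard diminishing schedule $\alpha_k = c/k$ with a constant $c > 0$ meets both conditions, since $\sum_k c/k$ diverges (harmonic series) while $\sum_k c^2/k^2 = c^2 \pi^2/6 < \infty$.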
The two-step update is interpretable as a proximal-gradient procedure, and analysis leverages established incremental/proximal methods.
5. Heuristic Extension for Deep Neural Models
For high-dimensional parameterizations ($d$ large, as in DNNs), Euclidean distances become less meaningful. The heuristic variant "HeurFedAMP" alters the computation of the attention weights $\xi_{i,j}$:
- Set the self-attention weight $\xi_{i,i}$ to a fixed hyperparameter value in $(0, 1)$.
- For $j \neq i$,
$$\xi_{i,j} = \left(1 - \xi_{i,i}\right) \cdot \frac{e^{\sigma \cos(w_i, w_j)}}{\sum_{h \neq i} e^{\sigma \cos(w_i, w_h)}},$$
where $\cos(w_i, w_j) = \frac{\langle w_i, w_j \rangle}{\|w_i\| \, \|w_j\|}$ is cosine similarity and $\sigma$ a temperature parameter.
This maintains $\sum_j \xi_{i,j} = 1$ while biasing attention toward angular rather than Euclidean closeness, empirically improving performance on DNNs.
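A hedged NumPy sketch of this heuristic attention computation follows; the defaults `xi_self=0.3` and `sigma=10.0` are hypothetical placeholders rather than values from the paper, and at least two clients are assumed:

```python
import numpy as np

def heurfedamp_attention(W, xi_self=0.3, sigma=10.0):
    """Heuristic attention: fixed self-weight xi_{i,i} = xi_self, and for j != i
    a softmax over temperature-scaled cosine similarities. Assumes m >= 2."""
    norms = np.linalg.norm(W, axis=1, keepdims=True)
    cos = (W @ W.T) / (norms * norms.T)           # pairwise cosine similarities
    logits = sigma * cos
    np.fill_diagonal(logits, -np.inf)             # exclude j == i from the softmax
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))  # stable softmax
    xi = (1.0 - xi_self) * exp / exp.sum(axis=1, keepdims=True)
    np.fill_diagonal(xi, xi_self)                 # each row now sums to 1
    return xi
```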
6. Empirical Evaluation and Results
FedAMP and its heuristic extension are evaluated on the MNIST, FMNIST, EMNIST, and CIFAR100 datasets, with client partitions covering IID, pathological non-IID (each client holds only 2 labels), and practical non-IID (clients grouped into 3 clusters with unbalanced samples) settings.
Best mean testing accuracy (BMTA) under the practical non-IID scenario (mean over clients):
| Dataset | FedAvg | FedProx | APFL | FedAMP | HeurFedAMP |
|---|---|---|---|---|---|
| FMNIST | 79.5% | 78.7% | 84.1% | 91.0% | 91.4% |
| EMNIST | N/A | N/A | N/A | 81.2% | 81.5% |
| CIFAR100 | 35.2% | 37.3% | N/A | N/A | 53.3% |
Pairwise attention heatmaps (EMNIST, clients 0–61) reveal that attention coefficients form clear blocks, aligning with ground-truth clusters—FedAMP automatically learns and exploits such latent structure.
7. Practical Guidelines and Implications
Key operational insights include:
- Data regime sensitivity: On IID data, FedAMP reduces to global averaging (like FedAvg); on clustered non-IID data, it amplifies within-cluster collaboration.
- Hyperparameters: $\lambda$ balances personalization/collaboration; the step size $\alpha_k$ should start sufficiently large and then decay so that $\sum_k \alpha_k = \infty$ and $\sum_k \alpha_k^2 < \infty$; the attention kernel scale $\sigma$ must be tuned; the self-attention weight $\xi_{i,i}$ in HeurFedAMP is set to a fixed constant (a hedged configuration sketch follows this list).
- Robustness: The proximal step only requires the clients that are currently available, so client drops are handled naturally; attention down-weights corrupted or noisy clients, conferring resilience to label noise.
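Pulling these knobs together, a minimal configuration sketch might look as follows; every value is hypothetical and should be tuned per dataset rather than read as the paper's settings:

```python
# Hypothetical FedAMP/HeurFedAMP configuration; all values are illustrative.
config = {
    "lam": 1.0,                                 # lambda: personalization vs. collaboration
    "sigma": 1.0,                               # attention kernel scale / temperature
    "alpha_schedule": lambda k: 1.0 / (k + 1),  # diminishing steps: sum diverges,
                                                # sum of squares converges
    "xi_self": 0.3,                             # fixed self-attention weight (HeurFedAMP)
    "local_steps": 5,                           # local SGD/Adam steps per round
    "rounds": 100,                              # communication rounds K
}
```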
FedAMP constitutes a principled, provably convergent, and empirically validated framework for federated learning with adaptive, pairwise, non-linear collaboration, with particular effectiveness on non-IID problems and high-dimensional models (Huang et al., 2020).