Abstract: Security and privacy concerns in real-world applications have led to the development of adversarially robust federated models. Previous works mainly target overcoming the adaptation constraints of communication and computation costs. However, a straightforward combination of adversarial training and federated learning can lead to an undesired degradation of robust accuracy at later training stages. We reveal that this phenomenon is attributable to the generated adversarial data exacerbating the data heterogeneity among local clients, which makes the wrapped federated learning perform poorly. To deal with this problem, we introduce an α-slack mechanism to relax the original learning objective of federated adversarial training, and propose a novel framework called Slack Federated Adversarial Training (SFAT) to combat the intensified heterogeneity. By assigning client-wise slack during aggregation, SFAT realizes a weighted aggregation that alleviates the optimization bias induced by local adversarial generation. We further extend to a more general setting that permits clients trained by either standard or adversarial training within a unified framework, and propose SFAT* with a hierarchical aggregation scheme for this scenario. Theoretically, we analyze the convergence of our method to properly relax the learning objective. Experimentally, we verify the rationality and effectiveness of our methods on various benchmark and real-world datasets with different adversarial training and federated optimization methods.
Key Words: Adversarial Robustness, Exacerbated Heterogeneity, Federated Learning. [TPAMI 2025]
We extend the original setting of federated adversarial training to a more general and practical one that permits both adversarial and standard training within a unified framework, which better reflects real-world federated systems where client capabilities and training modes vary. We generalize SFAT to an extended version, termed SFAT*, with a hierarchical aggregation scheme that realizes a discriminative slack. It captures the complex heterogeneity dynamics and avoids conflicting relaxation.
Figure 1. Exacerbated heterogeneity in federated adversarial training.
Figure 2. Over-relaxation in federated standard training.
Python (3.8)
PyTorch (1.7.0 or above)
torchvision
CUDA
NumPy
./SFAT-main
├─ Centralized_AT.py    # Centralized training and evaluation
├─ SFAT.py              # Overall federated training framework
├─ attack_generator.py  # Attack generation
├─ eval_pgd.py          # Evaluation under PGD and other attacks
├─ logger.py            # Log support
├─ models.py            # Model architectures
├─ options.py           # Options and hyperparameters
├─ readme.md
├─ sampling.py          # Data split
├─ update.py            # Local training on each client
└─ utils.py             # Aggregation and other utils
To train a robust federated model, we provide the examples below for using our code:
CUDA_VISIBLE_DEVICES='0' python SFAT.py --dataset=cifar-10 --local_ep=10 --local_bs=32 --iid=0 --epochs=100 --num_users=5 --agg-opt='FedAvg' --agg-center='FedAvg' --out-dir='../output_results_FAT_FedAvg'
CUDA_VISIBLE_DEVICES='1' python SFAT.py --dataset=cifar-10 --local_ep=10 --local_bs=32 --iid=0 --epochs=100 --num_users=5 --agg-opt='FedAvg' --agg-center='SFAT' --pri=1.2 --out-dir='../output_results_SFAT_FedAvg'

Compared with FAT, our proposed SFAT selectively upweights the clients with small adversarial training loss and downweights those with large loss during aggregation, which follows our α-slack mechanism to alleviate the exacerbated heterogeneity. A toy illustration of the weighting is given below.
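As a toy illustration (our assumption of the weighting semantics, not the repository code; see the utils.py sketch further below for the aggregation itself), with five clients and --pri=1.2 the client with the smallest adversarial loss receives a proportionally larger aggregation coefficient:

# Toy sketch (assumed semantics): upweight the client with the smallest
# adversarial training loss by pri, then normalize the coefficients.
adv_losses = [0.91, 0.84, 1.02, 0.77, 0.95]  # per-client adversarial losses
pri = 1.2
best = min(range(len(adv_losses)), key=adv_losses.__getitem__)
coefs = [pri if i == best else 1.0 for i in range(len(adv_losses))]
norm = sum(coefs)
coefs = [c / norm for c in coefs]
print(coefs)  # client 3 (loss 0.77) gets ~0.231 vs. ~0.192 for the others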
Following the conventional federated learning realization, we implement the overall framework of SFAT in SFAT.py, which coordinates the local optimization in update.py and the aggregation functions in utils.py.
In SFAT.py, we obtain the local model from each client and aggregate them into the global model.
# local updates
for idx in idxs_users:
    local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_groups[idx], logger=logger, alg=args.agg_opt, anchor=global_model, anchor_mu=args.mu, local_rank=idx, method=args.train_method)
    ''' '''

# aggregation method
if args.agg_center == 'FedAvg':
    global_weights = average_weights(local_weights)
if args.agg_center == 'SFAT' and args.train_method != 'ST+AT':  # SFAT
    ''' '''
    global_weights = average_weights_alpha(local_weights, idt, idtxnum, args.pri)
if args.agg_center == 'SFAT' and args.train_method == 'ST+AT':  # SFAT*
    ''' '''
    global_weights_1 = average_weights_alpha(local_weights[split_loc:], idt_1, idtxnum, args.pri)
    global_weights_2 = average_weights(local_weights[:split_loc])

In update.py, we realize the local adversarial training on each client and define the LocalUpdate() class; a minimal sketch of such a client is given below.
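For intuition, here is a hypothetical, simplified client performing local PGD-based adversarial training; the constructor signature, optimizer settings, and attack hyperparameters are our assumptions and are reduced from the full interface used above:

import copy
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset

class LocalUpdate(object):
    # Simplified sketch of a client running local adversarial training.
    def __init__(self, args, dataset, idxs, device='cuda'):
        self.args = args
        self.device = device
        self.loader = DataLoader(Subset(dataset, list(idxs)), batch_size=args.local_bs, shuffle=True)

    def pgd_attack(self, model, x, y, eps=8/255, alpha=2/255, steps=10):
        # Craft L-inf PGD adversarial examples around the clean batch.
        x_adv = (x + torch.empty_like(x).uniform_(-eps, eps)).clamp(0, 1)
        for _ in range(steps):
            x_adv = x_adv.detach().requires_grad_(True)
            loss = F.cross_entropy(model(x_adv), y)
            grad = torch.autograd.grad(loss, x_adv)[0]
            x_adv = x_adv + alpha * grad.sign()
            x_adv = torch.min(torch.max(x_adv, x - eps), x + eps).clamp(0, 1)
        return x_adv.detach()

    def update_weights(self, model):
        # Run args.local_ep epochs of adversarial training on the local shard.
        model = copy.deepcopy(model).to(self.device).train()
        opt = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
        losses = []
        for _ in range(self.args.local_ep):
            for x, y in self.loader:
                x, y = x.to(self.device), y.to(self.device)
                x_adv = self.pgd_attack(model, x, y)
                opt.zero_grad()
                loss = F.cross_entropy(model(x_adv), y)
                loss.backward()
                opt.step()
                losses.append(loss.item())
        # Return updated weights and the mean adversarial loss, which SFAT
        # uses to rank clients at aggregation time.
        return model.state_dict(), sum(losses) / len(losses)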
In utils.py, we realize the aggregation methods and define FAT, i.e., average_weights(), and SFAT, i.e., average_weights_alpha(), as well as their unequal versions. For our SFAT, the critical part of the code is as follows, where lw and idx help select the corresponding clients and p is our slack weighting ratio (set via --pri).
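Since the snippet itself is not reproduced here, below is a minimal sketch of what such a slack-weighted aggregation can look like, under our assumptions about the signature (lw: list of local state dicts; idx: indices of upweighted clients; idxnum kept only for signature compatibility; p: the ratio). The repository's utils.py is the authoritative implementation:

import copy

def average_weights_alpha(lw, idx, idxnum, p):
    # Sketch of slack-weighted aggregation (assumed semantics, see utils.py).
    #   lw:     list of local model state dicts
    #   idx:    indices of clients to upweight (small adversarial loss)
    #   idxnum: number of upweighted clients (unused in this sketch)
    #   p:      slack weighting ratio, e.g., --pri=1.2
    coefs = [p if i in idx else 1.0 for i in range(len(lw))]
    norm = sum(coefs)
    w_avg = copy.deepcopy(lw[0])
    for key in w_avg.keys():
        w_avg[key] = lw[0][key] * (coefs[0] / norm)
        for i in range(1, len(lw)):
            w_avg[key] = w_avg[key] + lw[i][key] * (coefs[i] / norm)
    return w_avg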
We realize the data split in sampling.py and use it in utils.py to generate the local data loader for each client. The pre-defined split function can be used as follows to get the local data.
def get_dataset(args):
    ''' '''
    user_groups = cifar_noniid_skew(train_dataset, args.num_users)
    ''' '''
    return train_dataset, test_dataset, user_groups

To choose among different federated optimization methods (e.g., FedAvg, FedProx, Scaffold) and aggregation schemes (e.g., FAT and SFAT) for training the robust federated model, use the parameters defined in options.py:
parser.add_argument('--agg-opt',type=str,default='FedAvg',help='option of on-device learning: FedAvg, FedProx, Scaffold')
parser.add_argument('--agg-center',type=str,default='FedAvg',help='option of aggregation: FedAvg, SFAT')

To evaluate the trained model under various attack methods, we provide eval_pgd.py, which contains different evaluation metrics for natural and robust performance. You can run the following script with your model path to conduct the evaluation:
CUDA_VISIBLE_DEVICES='0' python eval_pgd.py --net [NETWORK STRUCTURE] --dataset [DATASET] --model_path [MODEL PATH]

During training, we also track accuracy via logger.py, which saves the model performance at each epoch.
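For reference, a minimal sketch of PGD-based robust-accuracy evaluation (a simplified stand-in under our assumptions, not the repository's eval_pgd.py):

import torch
import torch.nn.functional as F

def eval_robust(model, loader, eps=8/255, alpha=2/255, steps=20, device='cuda'):
    # Sketch: robust accuracy under L-inf PGD (simplified stand-in).
    model.eval()
    correct, total = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        x_adv = (x + torch.empty_like(x).uniform_(-eps, eps)).clamp(0, 1)
        for _ in range(steps):
            x_adv = x_adv.detach().requires_grad_(True)
            loss = F.cross_entropy(model(x_adv), y)
            grad = torch.autograd.grad(loss, x_adv)[0]
            x_adv = x_adv + alpha * grad.sign()
            x_adv = torch.min(torch.max(x_adv, x - eps), x + eps).clamp(0, 1)
        with torch.no_grad():
            correct += (model(x_adv.detach()).argmax(1) == y).sum().item()
        total += y.size(0)
    return correct / total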
Either the local optimization or the aggregation method can be re-designed on top of our framework, in the corresponding update.py and utils.py parts; for instance, a custom aggregation only needs to follow the same interface, as sketched below.
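As an illustration (hypothetical, not part of the repository), a coordinate-wise-median aggregation that could be dropped in wherever average_weights() is called:

import copy
import torch

def median_weights(local_weights):
    # Hypothetical drop-in replacement for average_weights() in utils.py:
    # coordinate-wise median over the clients' state dicts (casts to float,
    # including integer buffers such as BatchNorm's num_batches_tracked).
    w_agg = copy.deepcopy(local_weights[0])
    for key in w_agg.keys():
        stacked = torch.stack([w[key].float() for w in local_weights], dim=0)
        w_agg[key] = stacked.median(dim=0).values
    return w_agg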

