FedBox is a toolbox for federated learning research.
Install from source:
git clone https://github.com/zhb2000/fedbox.git
cd fedbox
pip install .
Editable installation (for development):
# For an editable installation, replace 'pip install .' with
pip install --editable .
The module fedbox.utils.data provides several non-IID data splitting setups for federated learning research.
The quantity-based label skew setup was first introduced in Communication-Efficient Learning of Deep Networks from Decentralized Data. In this setup, each client contains samples from only a few classes. It is implemented in the function split_by_label, whose parameter class_per_client controls the number of classes held by each client.
The following code demonstrates how to split the MNIST dataset into 100 clients, where each client contains samples from only 2 classes. We split and save the indices of the samples rather than the samples themselves, which gives us more flexibility. The splitting results are saved to a JSON file.
import json
import numpy as np
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
from fedbox.utils.data import split_by_label
if __name__ == '__main__':
    client_num = 100
    class_per_client = 2
    splitting_file = 'mnist-splitting.json'
    mnist = MNIST(
        root='./data',  # dataset download location
        train=True,
        download=True,
        transform=Compose([ToTensor(), Normalize([0.5], [0.5])])
    )
    mnist_labels = mnist.targets.numpy()
    mnist_indices = np.arange(len(mnist))
    # each tuple holds one client's sample indices and the corresponding labels
    results: list[tuple[np.ndarray, np.ndarray]] = split_by_label(
        mnist_indices,
        mnist_labels,
        client_num,
        class_per_client=class_per_client
    )
    with open(splitting_file, 'w') as file:
        json.dump([indices.tolist() for indices, _ in results], file, indent=4)
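As a quick sanity check (illustrative, not part of fedbox), you can verify that each client ends up with exactly class_per_client classes:
# count the distinct labels held by each client
for indices, _ in results:
    assert len(np.unique(mnist_labels[indices])) == class_per_client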
The following code demonstrates how to read the dataset subset of each client. We read the sample indices of each client from the JSON file, then use Subset to build each client's local dataset.
import json
from torch.utils.data import Subset

def read_clients_dataset(mnist: MNIST, splitting_file: str) -> list[Subset]:
    """Read the dataset subset for each client."""
    with open(splitting_file) as file:
        results: list[list[int]] = json.load(file)
    return [Subset(mnist, indices) for indices in results]
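Each subset can then be wrapped in a standard PyTorch DataLoader for local training. A minimal usage sketch (batch size chosen arbitrarily):
from torch.utils.data import DataLoader

client_datasets = read_clients_dataset(mnist, 'mnist-splitting.json')
# one dataloader per client for local training
client_loaders = [
    DataLoader(dataset, batch_size=32, shuffle=True)
    for dataset in client_datasets
]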
The distribution-based label skew setup was first introduced in Bayesian Nonparametric Federated Learning of Neural Networks. In this setup, each client is allocated a proportion of the samples of each class according to the Dirichlet distribution. It is implemented in the function split_dirichlet_label. The parameter alpha controls the concentration parameter of the Dirichlet distribution: the smaller alpha is, the more skewed the label distributions become.
import json
import numpy as np
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
from fedbox.utils.data import split_dirichlet_label
if __name__ == '__main__':
    client_num = 100
    alpha = 0.1
    splitting_file = 'mnist-splitting.json'
    mnist = MNIST(
        root='./data',  # dataset download location
        train=True,
        download=True,
        transform=Compose([ToTensor(), Normalize([0.5], [0.5])])
    )
    mnist_labels = mnist.targets.numpy()
    mnist_indices = np.arange(len(mnist))
    results: list[tuple[np.ndarray, np.ndarray]] = split_dirichlet_label(
        mnist_indices,
        mnist_labels,
        client_num,
        alpha=alpha
    )
    with open(splitting_file, 'w') as file:
        json.dump([indices.tolist() for indices, _ in results], file, indent=4)
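To build intuition for alpha, you can sample class proportions directly with NumPy (a standalone illustration, independent of fedbox): small values concentrate a client's data on a few classes, while large values approach a uniform split.
import numpy as np

rng = np.random.default_rng(0)
# class proportions of a single client over 10 classes
print(rng.dirichlet([0.1] * 10))    # skewed: a few classes dominate
print(rng.dirichlet([100.0] * 10))  # near-uniform: roughly 0.1 per class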
The quantity skew setup was first introduced in Federated Learning on Non-IID Data Silos: An Experimental Study. In this setup, the size of the local dataset varies across clients: the Dirichlet distribution is used to allocate a different number of samples to each client. Quantity skew is implemented in the function split_dirichlet_quantity.
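A minimal usage sketch, assuming split_dirichlet_quantity takes the sample indices, the number of clients, and the Dirichlet concentration parameter like its siblings above (check the actual signature in fedbox.utils.data before relying on this):
import json
import numpy as np
from fedbox.utils.data import split_dirichlet_quantity

mnist_indices = np.arange(60000)  # e.g. indices of the MNIST training set
# assumed signature, mirroring split_dirichlet_label without per-class labels
clients_indices = split_dirichlet_quantity(mnist_indices, 100, alpha=0.5)
with open('mnist-splitting.json', 'w') as file:
    json.dump([np.asarray(indices).tolist() for indices in clients_indices], file, indent=4)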
Use model_average to perform model averaging. The function supports both simple and weighted model averaging.
from torch.nn import Module
from fedbox.utils.functional import model_average

global_model: Module = ...
local_models: list[Module] = ...
# simple model averaging
global_model.load_state_dict(model_average(local_models))
# weighted model averaging
weights = [0.5, 0.3, 0.2]
global_model.load_state_dict(model_average(local_models, weights))
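In FedAvg-style aggregation, the weights are typically proportional to each client's local dataset size. An illustrative sketch, assuming client_datasets holds each client's local dataset:
# weight each client by its share of the total training data
sizes = [len(dataset) for dataset in client_datasets]
weights = [size / sum(sizes) for size in sizes]
global_model.load_state_dict(model_average(local_models, weights))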
Use model_aggregate to implement custom aggregation operations.
Aggregate two models:
from torch import Tensor
from torch.nn import Module
from fedbox.utils.functional import model_aggregate

global_model: Module = ...
local_model_a: Module = ...
local_model_b: Module = ...

def custom_aggregate(a: Tensor, b: Tensor) -> Tensor:
    return (a + b) / 2

result = model_aggregate(
    custom_aggregate,
    local_model_a, local_model_b
)
global_model.load_state_dict(result)
Aggregate a sequence of models:
from collections.abc import Sequence

global_model: Module = ...
local_models: list[Module] = ...

def custom_aggregate(params: Sequence[Tensor]) -> Tensor:
    return sum(params) / len(params)

result = model_aggregate(custom_aggregate, local_models)
global_model.load_state_dict(result)
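Custom aggregation is useful beyond plain averaging. For example, a coordinate-wise median (an illustrative sketch, not a fedbox built-in) is a common robust aggregation rule:
import torch

def median_aggregate(params: Sequence[Tensor]) -> Tensor:
    # coordinate-wise median across the clients' parameters
    return torch.stack(list(params)).median(dim=0).values

global_model.load_state_dict(model_aggregate(median_aggregate, local_models))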
AverageMeter is a tool for tracking the average value of multiple metrics during training. For example, you can use it to track the average loss value of each epoch.
from fedbox.utils.training import AverageMeter
meter = AverageMeter()
for epoch in range(epochs):
    meter.clear()  # clear the meter at the beginning of each epoch
    for batch in dataloader:
        cls_loss = ...  # classification loss
        reg_loss = ...  # regularization loss
        loss = cls_loss + reg_loss  # total loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # add the loss values to the meter
        meter.add(
            loss=loss.item(),
            cls_loss=cls_loss.item(),
            reg_loss=reg_loss.item()
        )
    # print the average loss values of the epoch
    print(
        f'epoch {epoch}' +
        f', loss: {meter["loss"]:.4f}' +
        f', cls_loss: {meter["cls_loss"]:.4f}' +
        f', reg_loss: {meter["reg_loss"]:.4f}'
    )
Use Recorder to perform early stopping during training.
from fedbox.utils.training import Recorder

stopper = Recorder(higher_better=True, patience=10)
for epoch in range(epochs):
    train(...)
    f1, acc = validate(...)
    # use 'f1' as the early stopping metric, and record the corresponding 'acc' and 'epoch'
    is_best = stopper.update(f1, acc=acc, epoch=epoch)
    print(f'epoch {epoch}, is best: {is_best}, f1: {f1:.4f}, acc: {acc:.4f}')
    if stopper.reach_stop():
        break
# print the final result
print(f'best f1: {stopper.best_metric}, best epoch: {stopper["epoch"]}, acc: {stopper["acc"]}')
Use freeze_model and unfreeze_model to freeze and unfreeze a model's parameters. module_requires_grad_ is a helper function that sets the requires_grad attribute of the model's parameters.
from torch.nn import Module
from fedbox.utils.training import freeze_model, unfreeze_model, module_requires_grad_

model: Module = ...
# freeze the model
freeze_model(model)
# unfreeze the model
unfreeze_model(model)
# set the 'requires_grad' attribute of the model's parameters
module_requires_grad_(model, True)
# equivalent to:
for param in model.parameters():
    param.requires_grad_(True)
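A typical use case is freezing one part of a model, such as a pretrained feature extractor, while keeping the classification head trainable. An illustrative sketch with hypothetical submodules:
import torch
from torch import nn

# hypothetical two-part model: a feature extractor and a classification head
encoder = nn.Linear(784, 128)
head = nn.Linear(128, 10)
model = nn.Sequential(encoder, head)
freeze_model(encoder)  # freeze only the encoder; the head stays trainable
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=0.01)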