Build a strategy from scratch
Welcome to the third part of the Flower federated learning tutorial. In previous parts of this tutorial, we introduced federated learning with PyTorch and Flower (part 1) and we learned how strategies can be used to customize the execution on both the server and the clients (part 2).
In this notebook, we'll continue to customize the federated learning system we built previously by creating a custom version of FedAvg (again, using Flower and PyTorch).
Star Flower on GitHub and join the Flower community on Slack to connect, ask questions, and get help. We'd love to hear from you in the #introductions channel! And if anything is unclear, head over to the #questions channel.
Let's build a new Strategy from scratch!
Preparation
Before we begin with the actual code, let's make sure that we have everything we need.
Installing dependencies
First, we install the necessary packages:
!pip install -q flwr[simulation] torch torchvision
Now that we have all dependencies installed, we can import everything we need for this tutorial:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
import flwr as fl
DEVICE = torch.device("cpu") # Try "cuda" to train on GPU
print(
    f"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}"
)
It is possible to switch to a runtime that has GPU acceleration enabled (on Google Colab: Runtime > Change runtime type > Hardware accelerator: GPU > Save). Note, however, that Google Colab is not always able to offer GPU acceleration. If you see an error related to GPU availability in one of the following sections, consider switching back to CPU-based execution by setting DEVICE = torch.device("cpu"). If the runtime has GPU acceleration enabled and you set DEVICE = torch.device("cuda"), you should see the output Training on cuda; otherwise it'll say Training on cpu.
Data loading
Let's now load the CIFAR-10 training and test sets, partition the training set into ten smaller datasets (each split into a training and a validation set), and wrap each of them in its own DataLoader. The test set is kept whole and wrapped in a single DataLoader. We introduce a new parameter num_clients, which allows us to call load_datasets with different numbers of clients.
NUM_CLIENTS = 10
def load_datasets(num_clients: int):
    # Download and transform CIFAR-10 (train and test)
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
    testset = CIFAR10("./dataset", train=False, download=True, transform=transform)
    # Split training set into `num_clients` partitions to simulate different local datasets
    partition_size = len(trainset) // num_clients
    lengths = [partition_size] * num_clients
    # random_split needs lengths that sum to len(trainset): give leftover samples to the last partition
    lengths[-1] += len(trainset) % num_clients
    datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))
    # Split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []
    for ds in datasets:
        len_val = len(ds) // 10  # 10 % validation set
        len_train = len(ds) - len_val
        lengths = [len_train, len_val]
        ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))
        trainloaders.append(DataLoader(ds_train, batch_size=32, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=32))
    testloader = DataLoader(testset, batch_size=32)
    return trainloaders, valloaders, testloader
trainloaders, valloaders, testloader = load_datasets(NUM_CLIENTS)
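To make the split sizes concrete, here is a small standard-library sketch of the length arithmetic behind load_datasets, without downloading anything. It assumes the CIFAR-10 training set has 50,000 images, and it hands any leftover samples to the last partition so the lengths always sum to the dataset size (random_split raises an error when integer lengths do not add up). The helper name split_lengths is hypothetical.

```python
from typing import List, Tuple


def split_lengths(num_samples: int, num_clients: int) -> List[Tuple[int, int]]:
    """Return the (train, val) sizes each client ends up with."""
    partition_size = num_samples // num_clients
    lengths = [partition_size] * num_clients
    # Leftover samples (when num_clients does not divide num_samples) go to the last client
    lengths[-1] += num_samples % num_clients
    sizes = []
    for length in lengths:
        len_val = length // 10  # 10 % validation set
        sizes.append((length - len_val, len_val))
    return sizes


# Ten clients: every partition has 5,000 images, split into 4,500 train / 500 val
ten = split_lengths(50_000, 10)
# Three clients: 16,666 + 16,666 + 16,668 = 50,000
three = split_lengths(50_000, 3)
```

With ten clients all partitions are the same size, which is why every client later reports the same number of examples.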
Model training/evaluation
Let's continue with the usual model definition (including set_parameters and get_parameters), training and test functions:
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]
def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)
def train(net, trainloader, epochs: int):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)  # reuse the forward pass instead of running it twice
            loss.backward()
            optimizer.step()
            # Metrics: .item() detaches from the graph; scale the batch mean by the batch size
            epoch_loss += loss.item() * labels.size(0)
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)  # average loss per sample
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")
def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            # criterion returns the batch mean; scale by batch size for a per-sample average below
            loss += criterion(outputs, labels).item() * labels.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy
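Why does fc1 expect 16 * 5 * 5 inputs, and how much data does get_parameters hand to Flower each round? The sketch below answers both with plain arithmetic on the layer definitions of Net (no PyTorch calls). It assumes 32x32 CIFAR-10 inputs and uses the standard output-size formula for convolution and pooling layers without dilation; the helper conv_output_size is illustrative.

```python
def conv_output_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


# CIFAR-10 images are 32x32; follow Net.forward layer by layer
size = 32
size = conv_output_size(size, kernel=5)            # conv1: 32 -> 28
size = conv_output_size(size, kernel=2, stride=2)  # pool:  28 -> 14
size = conv_output_size(size, kernel=5)            # conv2: 14 -> 10
size = conv_output_size(size, kernel=2, stride=2)  # pool:  10 -> 5
flat_features = 16 * size * size                   # 16 channels * 5 * 5

# (weight count, bias count) per layer, in the order they appear in state_dict
layers = {
    "conv1": (3 * 6 * 5 * 5, 6),
    "conv2": (6 * 16 * 5 * 5, 16),
    "fc1": (flat_features * 120, 120),
    "fc2": (120 * 84, 84),
    "fc3": (84 * 10, 10),
}
num_arrays = 2 * len(layers)  # one weight and one bias array per layer
num_params = sum(w + b for w, b in layers.values())
print(flat_features, num_arrays, num_params)  # 400 10 62006
```

So the list returned by get_parameters should contain 10 NumPy arrays with 62,006 values in total. You can cross-check the count in PyTorch with sum(p.numel() for p in Net().parameters()).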
Flower client
To implement the Flower client, we (again) create a subclass of flwr.client.NumPyClient and implement the three methods get_parameters, fit, and evaluate. Here, we also pass the cid to the client and use it to log additional details:
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader
    def get_parameters(self, config):
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)
    def fit(self, parameters, config):
        print(f"[Client {self.cid}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        # Report the number of training examples (not batches); strategies use it as the aggregation weight
        return get_parameters(self.net), len(self.trainloader.dataset), {}
    def evaluate(self, parameters, config):
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}
def client_fn(cid) -> FlowerClient:
    net = Net().to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerClient(cid, net, trainloader, valloader)
Let's test what we have so far before we continue:
# Specify client resources if you need GPU (defaults to 1 CPU and 0 GPU)
client_resources = None
if DEVICE.type == "cuda":
    client_resources = {"num_gpus": 1}
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=2,
    config=fl.server.ServerConfig(num_rounds=3),
    client_resources=client_resources,
)
Build a Strategy from scratch
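Let's now write a strategy from scratch by subclassing the abstract base class Strategy and implementing all of its methods. The result is essentially a simplified version of FedAvg, with one change: configure_fit passes a higher learning rate (and potentially other hyperparameters) to the optimizer of half of the sampled clients. We keep the sampling of the clients as it is in FedAvg and only change the configuration dictionary (one of the FitIns attributes) that each sampled client receives.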
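The two pieces of logic that matter most in FedCustom's configure_fit, the sample-size rule from num_fit_clients and the learning-rate split, can be sketched in plain Python first. In this sketch, client IDs are stand-in strings (the real method works on the ClientProxy objects returned by client_manager.sample), and the function names are illustrative, not part of Flower.

```python
from typing import Dict, List, Tuple


def num_fit_clients(
    num_available: int,
    fraction_fit: float = 1.0,
    min_fit_clients: int = 2,
    min_available_clients: int = 2,
) -> Tuple[int, int]:
    """Sample size and minimum number of available clients, as in FedCustom."""
    num_clients = int(num_available * fraction_fit)
    return max(num_clients, min_fit_clients), min_available_clients


def assign_fit_configs(client_ids: List[str]) -> List[Tuple[str, Dict[str, float]]]:
    """First half of the clients gets the standard lr, the rest a higher lr."""
    half_clients = len(client_ids) // 2
    standard_config = {"lr": 0.001}
    higher_lr_config = {"lr": 0.003}
    return [
        (cid, standard_config if idx < half_clients else higher_lr_config)
        for idx, cid in enumerate(client_ids)
    ]


# 5 available clients, fraction_fit=0.5: int(2.5) = 2, so 2 clients are sampled
sample_size, _ = num_fit_clients(num_available=5, fraction_fit=0.5)
# With an odd number of clients, the higher-lr group gets the extra client
lrs = [config["lr"] for _, config in assign_fit_configs(["0", "1", "2", "3", "4"])]
```

Because half_clients rounds down, the higher learning rate goes to at least as many clients as the standard one; with the two simulated clients used here, each group has exactly one client.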
from typing import Callable, Union
from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg
class FedCustom(fl.server.strategy.Strategy):
    def __init__(
        self,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
    ) -> None:
        super().__init__()
        self.fraction_fit = fraction_fit
        self.fraction_evaluate = fraction_evaluate
        self.min_fit_clients = min_fit_clients
        self.min_evaluate_clients = min_evaluate_clients
        self.min_available_clients = min_available_clients
    def __repr__(self) -> str:
        return "FedCustom"
    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize global model parameters."""
        net = Net()
        ndarrays = get_parameters(net)
        return fl.common.ndarrays_to_parameters(ndarrays)
    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training."""
        # Sample clients
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )
        # Create custom configs: the first half of the sampled clients gets the
        # standard learning rate, the remaining clients get a higher one
        n_clients = len(clients)
        half_clients = n_clients // 2
        standard_config = {"lr": 0.001}
        higher_lr_config = {"lr": 0.003}
        fit_configurations = []
        for idx, client in enumerate(clients):
            if idx < half_clients:
                fit_configurations.append((client, FitIns(parameters, standard_config)))
            else:
                fit_configurations.append(
                    (client, FitIns(parameters, higher_lr_config))
                )
        return fit_configurations
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using weighted average."""
        # Nothing to aggregate if no client returned a result this round
        if not results:
            return None, {}
        weights_results = [
            (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]
        parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))
        metrics_aggregated = {}
        return parameters_aggregated, metrics_aggregated
    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation."""
        if self.fraction_evaluate == 0.0:
            return []
        config = {}
        evaluate_ins = EvaluateIns(parameters, config)
        # Sample clients
        sample_size, min_num_clients = self.num_evaluation_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )
        # Return client/config pairs
        return [(client, evaluate_ins) for client in clients]
    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using weighted average."""
        if not results:
            return None, {}
        loss_aggregated = weighted_loss_avg(
            [
                (evaluate_res.num_examples, evaluate_res.loss)
                for _, evaluate_res in results
            ]
        )
        metrics_aggregated = {}
        return loss_aggregated, metrics_aggregated
    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate global model parameters using an evaluation function."""
        # Let's assume we won't perform the global model evaluation on the server side.
        return None
    def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Return sample size and required number of clients."""
        num_clients = int(num_available_clients * self.fraction_fit)
        return max(num_clients, self.min_fit_clients), self.min_available_clients
    def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Use a fraction of available clients for evaluation."""
        num_clients = int(num_available_clients * self.fraction_evaluate)
        return max(num_clients, self.min_evaluate_clients), self.min_available_clients
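aggregate_fit and aggregate_evaluate delegate the actual math to aggregate and weighted_loss_avg from flwr.server.strategy.aggregate. Both compute averages weighted by the number of examples each client reports, so whatever count a client returns from fit or evaluate determines its influence. The following is an independent NumPy re-implementation of that idea for illustration, not Flower's source code; the function names weighted_average and weighted_loss are hypothetical.

```python
from typing import List, Tuple

import numpy as np


def weighted_average(results: List[Tuple[List[np.ndarray], int]]) -> List[np.ndarray]:
    """Average each layer across clients, weighting by num_examples.

    Each entry is (list of layer arrays, num_examples), the same pairs
    that aggregate_fit builds before calling aggregate.
    """
    total_examples = sum(num_examples for _, num_examples in results)
    num_layers = len(results[0][0])
    return [
        sum(layers[i] * num_examples for layers, num_examples in results) / total_examples
        for i in range(num_layers)
    ]


def weighted_loss(results: List[Tuple[int, float]]) -> float:
    """Loss averaged over (num_examples, loss) pairs, as passed to weighted_loss_avg."""
    total_examples = sum(num_examples for num_examples, _ in results)
    return sum(num_examples * loss for num_examples, loss in results) / total_examples


# Two clients with a two-layer "model"; client_b has three times as many examples
client_a = ([np.array([1.0, 1.0]), np.array([0.0])], 100)
client_b = ([np.array([3.0, 3.0]), np.array([4.0])], 300)
averaged = weighted_average([client_a, client_b])  # layer 0 -> [2.5, 2.5], layer 1 -> [3.0]
loss = weighted_loss([(100, 0.5), (300, 1.0)])      # (50 + 300) / 400 = 0.875
```

In this tutorial every client has the same partition size, so the weights are equal and the result reduces to a plain mean; the weighting only starts to matter once partitions differ in size.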
The only thing left is to use the newly created custom strategy FedCustom when starting the experiment:
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=2,
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=FedCustom(),  # <-- pass the new strategy here
    client_resources=client_resources,
)
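Note that FlowerClient.fit prints the config it receives but then calls train, which creates torch.optim.Adam(net.parameters()) with PyTorch's default learning rate of 0.001. The lr values chosen by FedCustom therefore show up in the logs but are not applied yet. A minimal sketch of one way to close that gap, assuming train is extended with an lr argument (a hypothetical change, not part of the code above), is to read the value defensively from the config dictionary. Scalar here is a local stand-in for the value types Flower allows in configs, and lr_from_config is an illustrative name.

```python
from typing import Dict, Union

# Local stand-in for flwr.common.Scalar: the value types allowed in config dicts
Scalar = Union[bool, bytes, float, int, str]

ADAM_DEFAULT_LR = 0.001  # PyTorch's default learning rate for torch.optim.Adam


def lr_from_config(config: Dict[str, Scalar], default: float = ADAM_DEFAULT_LR) -> float:
    """Return the learning rate sent by the server, or a default if none was sent."""
    value = config.get("lr", default)
    # bool is a subclass of int, so reject it explicitly along with non-numeric types
    if isinstance(value, (bool, bytes, str)):
        raise TypeError(f"expected a numeric 'lr' in config, got {value!r}")
    return float(value)


# Hypothetical integration (requires adding an `lr` parameter to train):
#
#     def fit(self, parameters, config):
#         set_parameters(self.net, parameters)
#         train(self.net, self.trainloader, epochs=1, lr=lr_from_config(config))
#         ...
#
# and inside train: optimizer = torch.optim.Adam(net.parameters(), lr=lr)
```

Falling back to a default keeps the client usable with strategies that send an empty config, such as the plain run before FedCustom was introduced.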
Recap
In this notebook, we've seen how to implement a custom strategy. A custom strategy enables granular control over client node configuration, result aggregation, and more. To define a custom strategy, you only have to override the abstract methods of the (abstract) base class Strategy. To make custom strategies even more powerful, you can pass custom functions to the constructor of your new class (__init__) and then call these functions whenever needed.
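The pattern of passing functions to __init__ can be sketched without Flower: the class below stores a user-supplied callable and calls it while building each client's config. It is a simplified, hypothetical stand-in for a strategy (ConfigurableStrategySketch, lr_fn and its (server_round, client_index, num_clients) signature are all illustrative). Flower's built-in FedAvg follows the same idea with its on_fit_config_fn argument, which receives the server round and returns a config dictionary.

```python
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical callable type: (server_round, client_index, num_clients) -> learning rate
LrFn = Callable[[int, int, int], float]


def default_lr_fn(server_round: int, idx: int, num_clients: int) -> float:
    """Same rule as FedCustom.configure_fit: the second half gets a higher lr."""
    return 0.001 if idx < num_clients // 2 else 0.003


class ConfigurableStrategySketch:
    """Stores a user-supplied function and calls it when configuring fit."""

    def __init__(self, lr_fn: Optional[LrFn] = None) -> None:
        self.lr_fn = lr_fn if lr_fn is not None else default_lr_fn

    def configure_fit(
        self, server_round: int, client_ids: List[str]
    ) -> List[Tuple[str, Dict[str, float]]]:
        n = len(client_ids)
        return [
            (cid, {"lr": self.lr_fn(server_round, idx, n)})
            for idx, cid in enumerate(client_ids)
        ]


# Default behaviour reproduces FedCustom's split
split = ConfigurableStrategySketch().configure_fit(1, ["0", "1", "2"])
# Swapping in a different rule needs no subclassing: decay lr for everyone per round
decaying = ConfigurableStrategySketch(lr_fn=lambda rnd, idx, n: 0.01 / rnd)
round_two = decaying.configure_fit(2, ["0", "1"])
```

In a real Strategy subclass, the same stored callable would be invoked inside configure_fit before wrapping each config in FitIns.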
Next steps
Before you continue, make sure to join the Flower community on Slack. There's a dedicated #questions channel if you need help, but we'd also love to hear who you are in #introductions!
The Flower Federated Learning Tutorial - Part 4 introduces Client, the flexible API underlying NumPyClient.