Use a federated learning strategy#
Welcome to the next part of the federated learning tutorial. In the previous part of this tutorial (part 1), we introduced federated learning with PyTorch and Flower.
In this notebook, we'll begin to customize the federated learning system we built in the introductory notebook (again, using Flower and PyTorch).
Star Flower on GitHub and join the Flower community on Slack to connect, ask questions, and get help: Join Slack. We'd love to hear from you in the #introductions channel! And if anything is unclear, head over to the #questions channel.
Let's move beyond FedAvg with Flower strategies!
Preparation#
Before we begin with the actual code, let's make sure that we have everything we need.
Installing dependencies#
First, we install the necessary packages:
[ ]:
!pip install -q flwr[simulation] torch torchvision
Now that we have all dependencies installed, we can import everything we need for this tutorial:
[ ]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
import flwr as fl
DEVICE = torch.device("cpu") # Try "cuda" to train on GPU
print(
f"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}"
)
It is possible to switch to a runtime that has GPU acceleration enabled (on Google Colab: Runtime > Change runtime type > Hardware accelerator: GPU > Save). Note, however, that Google Colab is not always able to offer GPU acceleration. If you see an error related to GPU availability in one of the following sections, consider switching back to CPU-based execution by setting DEVICE = torch.device("cpu"). If the runtime has GPU acceleration enabled, you should see the output Training on cuda, otherwise it'll say Training on cpu.
Data loading#
Let's now load the CIFAR-10 training and test set, partition the training set into ten smaller datasets (each split into a training and a validation set), and wrap each one in its own DataLoader. We introduce a new parameter num_clients, which allows us to call load_datasets with different numbers of clients.
[ ]:
NUM_CLIENTS = 10


def load_datasets(num_clients: int):
    # Download and transform CIFAR-10 (train and test)
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
    testset = CIFAR10("./dataset", train=False, download=True, transform=transform)
    # Split training set into `num_clients` partitions to simulate different local datasets
    partition_size = len(trainset) // num_clients
    lengths = [partition_size] * num_clients
    datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))
    # Split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []
    for ds in datasets:
        len_val = len(ds) // 10  # 10 % validation set
        len_train = len(ds) - len_val
        lengths = [len_train, len_val]
        ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))
        trainloaders.append(DataLoader(ds_train, batch_size=32, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=32))
    testloader = DataLoader(testset, batch_size=32)
    return trainloaders, valloaders, testloader


trainloaders, valloaders, testloader = load_datasets(NUM_CLIENTS)
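As a quick sanity check on the partitioning arithmetic above, the following sketch reproduces the size computation from load_datasets without downloading any data. The helper name partition_sizes is hypothetical, and the figure of 50,000 images is the standard size of the CIFAR-10 training set.

```python
from typing import Tuple


def partition_sizes(num_examples: int, num_clients: int) -> Tuple[int, int]:
    """Return (train, val) example counts per client, mirroring load_datasets."""
    partition_size = num_examples // num_clients  # equal-sized partitions
    len_val = partition_size // 10  # 10% of each partition is held out
    len_train = partition_size - len_val
    return len_train, len_val


CIFAR10_TRAIN_SIZE = 50_000  # size of the CIFAR-10 training set

print(partition_sizes(CIFAR10_TRAIN_SIZE, 10))  # (4500, 500)
```

Note that random_split expects integer lengths to sum to the dataset size, so this scheme assumes that num_clients divides the number of training examples evenly.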
Model training/evaluation#
Let's continue with the usual model definition (including set_parameters and get_parameters), training and test functions:
[ ]:
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def get_parameters(net) -> List[np.ndarray]:
    # Export all state_dict tensors as NumPy arrays (keys are dropped, order is kept)
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net, parameters: List[np.ndarray]):
    # Re-attach the state_dict keys by position and load the values into the model
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)  # Reuse outputs instead of a second forward pass
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss.item()  # Accumulate a float, not a tensor attached to the graph
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy
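get_parameters discards the state_dict keys and set_parameters restores them purely by position, so both sides must enumerate the layers in the same order. The following sketch illustrates that pairing with plain Python lists standing in for tensors; to_list and from_list are hypothetical stand-ins, not Flower or PyTorch functions.

```python
from collections import OrderedDict
from typing import Dict, List


def to_list(state: Dict[str, List[float]]) -> List[List[float]]:
    """Drop the keys and keep only the values, in insertion order."""
    return [values for _, values in state.items()]


def from_list(keys: List[str], parameters: List[List[float]]) -> "OrderedDict[str, List[float]]":
    """Re-attach keys to values by position, as zip does in set_parameters."""
    return OrderedDict(zip(keys, parameters))


state = OrderedDict([("fc.weight", [0.1, 0.2]), ("fc.bias", [0.3])])
restored = from_list(list(state.keys()), to_list(state))
print(restored == state)  # True
```

Because the pairing is positional, a client and server that build different model architectures would silently mismatch layers or fail in load_state_dict, which is why every participant instantiates the same Net.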
Flower client#
To implement the Flower client, we (again) create a subclass of flwr.client.NumPyClient and implement the three methods get_parameters, fit, and evaluate. Here, we also pass the cid to the client and use it to log additional details:
[ ]:
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        print(f"[Client {self.cid}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}


def client_fn(cid) -> FlowerClient:
    net = Net().to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerClient(cid, net, trainloader, valloader)
Strategy customization#
So far, everything should look familiar if you've worked through the introductory notebook. With that, we're ready to introduce a number of new features.
Server-side parameter initialization#
Flower, by default, initializes the global model by asking one random client for the initial parameters. In many cases, we want more control over parameter initialization though. Flower therefore allows you to directly pass the initial parameters to the Strategy:
[ ]:
# Create an instance of the model and get the parameters
params = get_parameters(Net())

# Pass parameters to the Strategy for server-side parameter initialization
strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.3,
    fraction_evaluate=0.3,
    min_fit_clients=3,
    min_evaluate_clients=3,
    min_available_clients=NUM_CLIENTS,
    initial_parameters=fl.common.ndarrays_to_parameters(params),
)

# Specify client resources if you need GPU (defaults to 1 CPU and 0 GPU)
client_resources = None
if DEVICE.type == "cuda":
    client_resources = {"num_gpus": 1}

# Start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=3),  # Just three rounds
    strategy=strategy,
    client_resources=client_resources,
)
Passing initial_parameters to the FedAvg strategy prevents Flower from asking one of the clients for the initial parameters. If we look closely, we can see that the logs do not show any calls to the FlowerClient.get_parameters method.
Starting with a customized strategy#
We've seen the function start_simulation before. It accepts a number of arguments, amongst them the client_fn used to create FlowerClient instances, the number of clients to simulate (num_clients), the number of rounds (num_rounds), and the strategy.
The strategy encapsulates the federated learning approach/algorithm, for example, FedAvg or FedAdagrad. Let's try to use a different strategy this time:
[ ]:
# Create FedAdagrad strategy
strategy = fl.server.strategy.FedAdagrad(
    fraction_fit=0.3,
    fraction_evaluate=0.3,
    min_fit_clients=3,
    min_evaluate_clients=3,
    min_available_clients=NUM_CLIENTS,
    initial_parameters=fl.common.ndarrays_to_parameters(get_parameters(Net())),
)

# Start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=3),  # Just three rounds
    strategy=strategy,
    client_resources=client_resources,
)
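To make concrete what a strategy computes at aggregation time, here is a minimal sketch of the example-weighted averaging at the heart of FedAvg, using plain lists of floats instead of NumPy arrays. The function name and data layout are assumptions for illustration; Flower's built-in strategies operate on NDArrays and add client sampling and failure handling, and adaptive strategies such as FedAdagrad additionally apply a server-side optimizer step on top of the averaged update.

```python
from typing import List, Tuple


def weighted_average_parameters(
    results: List[Tuple[List[float], int]],
) -> List[float]:
    """Average parameter vectors, weighting each by its number of examples."""
    total_examples = sum(num_examples for _, num_examples in results)
    num_params = len(results[0][0])
    aggregated = [0.0] * num_params
    for params, num_examples in results:
        weight = num_examples / total_examples
        for i, value in enumerate(params):
            aggregated[i] += weight * value
    return aggregated


# Two hypothetical clients: the second holds three times as much data
updates = [([1.0, 0.0], 100), ([5.0, 4.0], 300)]
print(weighted_average_parameters(updates))  # [4.0, 3.0]
```

The second client dominates the result because it contributes 300 of the 400 examples, which is the intended effect of weighting by dataset size.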
Server-side parameter evaluation#
Flower can evaluate the aggregated model on the server-side or on the client-side. Client-side and server-side evaluation are similar in some ways, but different in others.
Centralized Evaluation (or server-side evaluation) is conceptually simple: it works the same way that evaluation in centralized machine learning does. If there is a server-side dataset that can be used for evaluation purposes, then that's great. We can evaluate the newly aggregated model after each round of training without having to send the model to clients. We're also fortunate in the sense that our entire evaluation dataset is available at all times.
Federated Evaluation (or client-side evaluation) is more complex, but also more powerful: it doesn't require a centralized dataset and allows us to evaluate models over a larger set of data, which often yields more realistic evaluation results. In fact, many scenarios require us to use Federated Evaluation if we want to get representative evaluation results at all. But this power comes at a cost: once we start to evaluate on the client side, we should be aware that our evaluation dataset can change over consecutive rounds of learning if those clients are not always available. Moreover, the dataset held by each client can also change over consecutive rounds. This can lead to evaluation results that are not stable, so even if we did not change the model, we'd see our evaluation results fluctuate over consecutive rounds.
We've seen how federated evaluation works on the client side (i.e., by implementing the evaluate method in FlowerClient). Now let's see how we can evaluate aggregated model parameters on the server-side:
[ ]:
# The `evaluate` function will be called by Flower after every round
def evaluate(
    server_round: int,
    parameters: fl.common.NDArrays,
    config: Dict[str, fl.common.Scalar],
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
    net = Net().to(DEVICE)
    valloader = valloaders[0]
    set_parameters(net, parameters)  # Update model with the latest parameters
    loss, accuracy = test(net, valloader)
    print(f"Server-side evaluation loss {loss} / accuracy {accuracy}")
    return loss, {"accuracy": accuracy}
[ ]:
strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.3,
    fraction_evaluate=0.3,
    min_fit_clients=3,
    min_evaluate_clients=3,
    min_available_clients=NUM_CLIENTS,
    initial_parameters=fl.common.ndarrays_to_parameters(get_parameters(Net())),
    evaluate_fn=evaluate,  # Pass the evaluation function
)

fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=3),  # Just three rounds
    strategy=strategy,
    client_resources=client_resources,
)
Sending/receiving arbitrary values to/from clients#
In some situations, we want to configure client-side execution (training, evaluation) from the server-side. One example of that is the server asking the clients to train for a certain number of local epochs. Flower provides a way to send configuration values from the server to the clients using a dictionary. Let's look at an example where the clients receive values from the server through the config parameter in fit (config is also available in evaluate). The fit method receives the configuration dictionary through the config parameter and can then read values from this dictionary. In this example, it reads server_round and local_epochs and uses those values to improve the logging and configure the number of local training epochs:
[ ]:
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        # Read values from config
        server_round = config["server_round"]
        local_epochs = config["local_epochs"]
        # Use values provided by the config
        print(f"[Client {self.cid}, round {server_round}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=local_epochs)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}


def client_fn(cid) -> FlowerClient:
    net = Net().to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerClient(cid, net, trainloader, valloader)
So how can we send this config dictionary from server to clients? The built-in Flower Strategies provide a way to do this, and it works similarly to the way server-side evaluation works. We provide a function to the strategy, and the strategy calls this function for every round of federated learning:
[ ]:
def fit_config(server_round: int):
    """Return training configuration dict for each round.

    Perform two rounds of training with one local epoch, increase to two local
    epochs afterwards.
    """
    config = {
        "server_round": server_round,  # The current round of federated learning
        "local_epochs": 1 if server_round < 3 else 2,  # One epoch in rounds 1 and 2, two afterwards
    }
    return config
Next, we'll just pass this function to the FedAvg strategy before starting the simulation:
[ ]:
strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.3,
    fraction_evaluate=0.3,
    min_fit_clients=3,
    min_evaluate_clients=3,
    min_available_clients=NUM_CLIENTS,
    initial_parameters=fl.common.ndarrays_to_parameters(get_parameters(Net())),
    evaluate_fn=evaluate,
    on_fit_config_fn=fit_config,  # Pass the fit_config function
)

fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=3),  # Just three rounds
    strategy=strategy,
    client_resources=client_resources,
)
As we can see, the client logs now include the current round of federated learning (which they read from the config dictionary). We can also configure local training to run for one epoch during the first and second round of federated learning, and then for two epochs during the third round.
Clients can also return arbitrary values to the server. To do so, they return a dictionary from fit and/or evaluate. We have seen and used this concept throughout this notebook without mentioning it explicitly: our FlowerClient returns a dictionary containing a custom key/value pair as the third return value in evaluate.
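The server receives these dictionaries together with each client's example count. One common way to combine them is an example-weighted average, sketched below in plain Python. Recent Flower versions accept such a callable through the evaluate_metrics_aggregation_fn argument of the built-in strategies; treat that wiring as an assumption to verify against the version you have installed. The sample values are hypothetical.

```python
from typing import Dict, List, Tuple


def weighted_average(metrics: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    """Combine per-client accuracies, weighting each by its example count."""
    total_examples = sum(num_examples for num_examples, _ in metrics)
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_sum / total_examples}


# Hypothetical results from two clients: (num_examples, metrics dict)
client_metrics = [(10, {"accuracy": 0.5}), (30, {"accuracy": 0.9})]
aggregated = weighted_average(client_metrics)
print(aggregated)
```

Weighting by example count keeps a client with very little validation data from skewing the aggregated accuracy.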
Scaling federated learning#
As a last step in this notebook, let's see how we can use Flower to experiment with a large number of clients.
[ ]:
NUM_CLIENTS = 1000
trainloaders, valloaders, testloader = load_datasets(NUM_CLIENTS)
We now have 1000 partitions, each holding 45 training and 5 validation examples. Given that the number of training examples on each client is quite small, we should probably train the model a bit longer, so we configure the clients to perform 3 local training epochs. We should also adjust the fraction of clients selected for training during each round (we don't want all 1000 clients participating in every round), so we adjust fraction_fit to 0.025, which means that only 2.5% of available clients (so 25 clients) will be selected for training each round:
[ ]:
def fit_config(server_round: int):
    config = {
        "server_round": server_round,
        "local_epochs": 3,
    }
    return config


strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.025,  # Train on 25 clients (each round)
    fraction_evaluate=0.05,  # Evaluate on 50 clients (each round)
    min_fit_clients=20,
    min_evaluate_clients=40,
    min_available_clients=NUM_CLIENTS,
    initial_parameters=fl.common.ndarrays_to_parameters(get_parameters(Net())),
    on_fit_config_fn=fit_config,
)

fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=3),  # Just three rounds
    strategy=strategy,
    client_resources=client_resources,
)
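The client counts in the strategy comments above follow from simple arithmetic. The sketch below assumes the strategy samples the larger of fraction × available clients and the configured minimum, which is our understanding of how FedAvg behaves; the helper name num_sampled_clients is hypothetical.

```python
def num_sampled_clients(num_available: int, fraction: float, min_clients: int) -> int:
    """Number of clients selected in a round under the assumption above."""
    return max(int(num_available * fraction), min_clients)


print(num_sampled_clients(1000, 0.025, 20))  # 25
print(num_sampled_clients(1000, 0.05, 40))  # 50
print(num_sampled_clients(10, 0.3, 3))  # 3
```

With these settings the minimums only act as a floor: they would take effect if fewer clients were available, for example 20 training clients when only 500 are connected.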
Recap#
In this notebook, we've seen how we can gradually enhance our system by customizing the strategy, initializing parameters on the server side, choosing a different strategy, and evaluating models on the server-side. That's quite a bit of flexibility with so little code, right?
In the later sections, we've seen how we can communicate arbitrary values between server and clients to fully customize client-side execution. With that capability, we built a large-scale Federated Learning simulation using the Flower Virtual Client Engine and ran an experiment involving 1000 clients in the same workload - all in a Jupyter Notebook!
Next steps#
Before you continue, make sure to join the Flower community on Slack: Join Slack
There's a dedicated #questions channel if you need help, but we'd also love to hear who you are in #introductions!
The Flower Federated Learning Tutorial - Part 3 shows how to build a fully custom Strategy from scratch.