Get started with Flower#
Welcome to the Flower federated learning tutorial!
In this notebook, we'll build a federated learning system using Flower, Flower Datasets and PyTorch. In part 1, we use PyTorch for the model training pipeline and data loading. In part 2, we continue to federate the PyTorch-based pipeline using Flower.
Star Flower on GitHub ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: Join Slack 🌼 We'd love to hear from you in the #introductions channel! And if anything is unclear, head over to the #questions channel.
Let's get started!
Step 0: Preparation#
Before we begin with any actual code, let's make sure that we have everything we need.
Installing dependencies#
First, we install the necessary packages for PyTorch (torch and torchvision), Flower Datasets (flwr-datasets), and Flower (flwr):
[ ]:
!pip install -q flwr[simulation] flwr_datasets[vision] torch torchvision matplotlib
Now that we have all dependencies installed, we can import everything we need for this tutorial:
[ ]:
from collections import OrderedDict
from typing import List, Tuple
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from datasets.utils.logging import disable_progress_bar
from torch.utils.data import DataLoader
import flwr as fl
from flwr.common import Metrics
from flwr_datasets import FederatedDataset
DEVICE = torch.device("cpu") # Try "cuda" to train on GPU
print(
    f"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}"
)
disable_progress_bar()
It is possible to switch to a runtime that has GPU acceleration enabled (on Google Colab: Runtime > Change runtime type > Hardware accelerator: GPU > Save). Note, however, that Google Colab is not always able to offer GPU acceleration. If you see an error related to GPU availability in one of the following sections, consider switching back to CPU-based execution by setting DEVICE = torch.device("cpu"). If the runtime has GPU acceleration enabled, you should see the output Training on cuda, otherwise it'll say Training on cpu.
Loading the data#
Federated learning can be applied to many different types of tasks across different domains. In this tutorial, we introduce federated learning by training a simple convolutional neural network (CNN) on the popular CIFAR-10 dataset. CIFAR-10 can be used to train image classifiers that distinguish between images from ten different classes: "airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", and "truck".
We simulate having multiple datasets from multiple organizations (also called the "cross-silo" setting in federated learning) by splitting the original CIFAR-10 dataset into multiple partitions. Each partition will represent the data from a single organization. We're doing this purely for experimentation purposes; in the real world there's no need for data splitting because each organization already has its own data (so the data is naturally partitioned).
Each organization will act as a client in the federated learning system. So having ten organizations participate in a federation means having ten clients connected to the federated learning server.
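Conceptually, an IID partitioner just splits the 50,000 CIFAR-10 training examples into NUM_CLIENTS disjoint chunks, one per simulated organization. Here is a toy, framework-free illustration of that idea (the helper iid_partitions is our own sketch, not part of flwr-datasets):

```python
# Toy illustration of IID partitioning: split example indices into
# equally sized, disjoint chunks, one per simulated organization.
def iid_partitions(num_examples, num_clients):
    per_client = num_examples // num_clients
    return [
        list(range(cid * per_client, (cid + 1) * per_client))
        for cid in range(num_clients)
    ]


partitions = iid_partitions(50_000, 10)  # CIFAR-10 has 50,000 training images
print(len(partitions), len(partitions[0]))  # 10 partitions of 5,000 examples each
```

flwr-datasets handles this (and far more realistic non-IID partitioning schemes) for us, which is what the next cell uses.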
Let's now create the FederatedDataset abstraction from flwr-datasets that partitions CIFAR-10. We will create a small training and test set for each edge device and wrap each of them in a PyTorch DataLoader:
[ ]:
NUM_CLIENTS = 10
BATCH_SIZE = 32
def load_datasets():
    fds = FederatedDataset(dataset="cifar10", partitioners={"train": NUM_CLIENTS})

    def apply_transforms(batch):
        # Instead of passing transforms to CIFAR10(..., transform=transform)
        # we will use this function to dataset.with_transform(apply_transforms)
        # The transforms object is exactly the same
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        batch["img"] = [transform(img) for img in batch["img"]]
        return batch

    # Create train/val for each partition and wrap it into DataLoader
    trainloaders = []
    valloaders = []
    for partition_id in range(NUM_CLIENTS):
        partition = fds.load_partition(partition_id, "train")
        partition = partition.with_transform(apply_transforms)
        partition = partition.train_test_split(train_size=0.8, seed=42)
        # Shuffle the training data each epoch
        trainloaders.append(
            DataLoader(partition["train"], batch_size=BATCH_SIZE, shuffle=True)
        )
        valloaders.append(DataLoader(partition["test"], batch_size=BATCH_SIZE))
    testset = fds.load_split("test").with_transform(apply_transforms)
    testloader = DataLoader(testset, batch_size=BATCH_SIZE)
    return trainloaders, valloaders, testloader


trainloaders, valloaders, testloader = load_datasets()
We now have a list of ten training sets and ten validation sets (trainloaders and valloaders) representing the data of ten different organizations. Each trainloader/valloader pair contains 4,000 training examples and 1,000 validation examples (an 80/20 split of each 5,000-example partition). There's also a single testloader (we did not split the test set). Again, this is only necessary for building research or educational systems; actual federated learning systems have their data naturally distributed across multiple partitions.
Let's take a look at the first batch of images and labels in the first training set (i.e., trainloaders[0]) before we move on:
[ ]:
batch = next(iter(trainloaders[0]))
images, labels = batch["img"], batch["label"]
# Reshape and convert images to a NumPy array
# matplotlib requires images with the shape (height, width, 3)
images = images.permute(0, 2, 3, 1).numpy()
# Denormalize
images = images / 2 + 0.5
# Create a figure and a grid of subplots
fig, axs = plt.subplots(4, 8, figsize=(12, 6))
# Loop over the images and plot them
for i, ax in enumerate(axs.flat):
    ax.imshow(images[i])
    ax.set_title(trainloaders[0].dataset.features["label"].int2str([labels[i]])[0])
    ax.axis("off")
# Show the plot
fig.tight_layout()
plt.show()
The output above shows a random batch of images from the first trainloader in our list of ten trainloaders. It also prints the labels associated with each image (i.e., one of the ten possible labels we've seen above). If you run the cell again, you should see another batch of images.
Step 1: Centralized Training with PyTorch#
Next, we're going to use PyTorch to define a simple convolutional neural network. This introduction assumes basic familiarity with PyTorch, so it doesn't cover the PyTorch-related aspects in full detail. If you want to dive deeper into PyTorch, we recommend DEEP LEARNING WITH PYTORCH: A 60 MINUTE BLITZ.
Defining the model#
We use the simple CNN described in the PyTorch tutorial:
[ ]:
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
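If you're wondering where the 16 * 5 * 5 input size of fc1 comes from: a 32×32 CIFAR-10 image shrinks to 28×28 after the first 5×5 convolution (no padding, stride 1), to 14×14 after 2×2 pooling, to 10×10 after the second convolution, and to 5×5 after the final pooling, with 16 channels. A quick sketch of that arithmetic:

```python
# Trace the spatial size of a 32x32 CIFAR-10 image through Net's layers
def conv_out(size, kernel):
    # Valid convolution, stride 1, no padding
    return size - kernel + 1


def pool_out(size):
    # 2x2 max pooling with stride 2 halves the spatial size
    return size // 2


size = 32
size = pool_out(conv_out(size, 5))  # conv1 + pool -> 14
size = pool_out(conv_out(size, 5))  # conv2 + pool -> 5
print(16 * size * size)  # 400 flattened features fed into fc1
```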
Let's continue with the usual training and test functions:
[ ]:
def train(net, trainloader, epochs: int, verbose=False):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for batch in trainloader:
            images, labels = batch["img"].to(DEVICE), batch["label"].to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics (use .item() so we don't hold on to the autograd graph)
            epoch_loss += loss.item()
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        if verbose:
            print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")
def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch["img"].to(DEVICE), batch["label"].to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy
Training the model#
We now have all the basic building blocks we need: a dataset, a model, a training function, and a test function. Let's put them together to train the model on the dataset of one of our organizations (trainloaders[0]). This simulates the reality of most machine learning projects today: each organization has its own data and trains models only on this internal data:
[ ]:
trainloader = trainloaders[0]
valloader = valloaders[0]
net = Net().to(DEVICE)
for epoch in range(5):
    train(net, trainloader, 1)
    loss, accuracy = test(net, valloader)
    print(f"Epoch {epoch+1}: validation loss {loss}, accuracy {accuracy}")
loss, accuracy = test(net, testloader)
print(f"Final test set performance:\n\tloss {loss}\n\taccuracy {accuracy}")
Training the simple CNN on our CIFAR-10 split for 5 epochs should result in a test set accuracy of about 41%. That's not good, but it doesn't really matter for the purposes of this tutorial. The intent was just to show a simplistic centralized training pipeline that sets the stage for what comes next: federated learning!
Step 2: Federated Learning with Flower#
Step 1 demonstrated a simple centralized training pipeline. All data was in one place (i.e., a single trainloader and a single valloader). Next, we'll simulate a situation where we have multiple datasets in multiple organizations and where we train a model over these organizations using federated learning.
Updating model parameters#
In federated learning, the server sends the global model parameters to the client, and the client updates the local model with the parameters received from the server. It then trains the model on the local data (which changes the model parameters locally) and sends the updated/changed model parameters back to the server (or, alternatively, it sends just the gradients back to the server, not the full model parameters).
We need two helper functions to update the local model with parameters received from the server and to get the updated model parameters from the local model: set_parameters and get_parameters. The following two functions do just that for the PyTorch model above.
The details of how this works are not really important here (feel free to consult the PyTorch documentation if you want to learn more). In essence, we use state_dict to access PyTorch model parameter tensors. The parameter tensors are then converted to/from a list of NumPy ndarrays (which Flower knows how to serialize/deserialize):
[ ]:
def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]
Implementing a Flower client#
With that out of the way, let's move on to the interesting part. Federated learning systems consist of a server and multiple clients. In Flower, we create clients by implementing subclasses of flwr.client.Client or flwr.client.NumPyClient. We use NumPyClient in this tutorial because it is easier to implement and requires us to write less boilerplate.
To implement the Flower client, we create a subclass of flwr.client.NumPyClient and implement the three methods get_parameters, fit, and evaluate:
- get_parameters: Return the current local model parameters
- fit: Receive model parameters from the server, train the model parameters on the local data, and return the (updated) model parameters to the server
- evaluate: Receive model parameters from the server, evaluate the model parameters on the local data, and return the evaluation result to the server
We mentioned that our clients will use the previously defined PyTorch components for model training and evaluation. Let's see a simple Flower client implementation that brings everything together:
[ ]:
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, net, trainloader, valloader):
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        return get_parameters(self.net)

    def fit(self, parameters, config):
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}
Our class FlowerClient defines how local training/evaluation will be performed and allows Flower to call the local training/evaluation through fit and evaluate. Each instance of FlowerClient represents a single client in our federated learning system. Federated learning systems have multiple clients (otherwise, there's not much to federate), so each client will be represented by its own instance of FlowerClient. If we have, for example, three clients in our workload, then we'd have three instances of FlowerClient. Flower calls FlowerClient.fit on the respective instance when the server selects a particular client for training (and FlowerClient.evaluate for evaluation).
Using the Virtual Client Engine#
In this notebook, we want to simulate a federated learning system with 10 clients on a single machine. This means that the server and all 10 clients will live on a single machine and share resources such as CPU, GPU, and memory. Having 10 clients would mean having 10 instances of FlowerClient
in memory. Doing this on a single machine can quickly exhaust the available memory resources, even if only a subset of these clients participates in a single round of federated learning.
In addition to its regular capabilities, where server and clients run on multiple machines, Flower therefore provides special simulation capabilities that create FlowerClient instances only when they are actually necessary for training or evaluation. To enable the Flower framework to create clients when necessary, we need to implement a function called client_fn that creates a FlowerClient instance on demand. Flower calls client_fn whenever it needs an instance of one particular client to call fit or evaluate (those instances are usually discarded after use, so they should not keep any local state). Clients are identified by a client ID, or cid for short. The cid can be used, for example, to load different local data partitions for different clients, as can be seen below:
[ ]:
def client_fn(cid: str) -> fl.client.Client:
    """Create a Flower client representing a single organization."""

    # Load model
    net = Net().to(DEVICE)

    # Load data (CIFAR-10)
    # Note: each client gets a different trainloader/valloader, so each client
    # will train and evaluate on their own unique data
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]

    # Create a single Flower client representing a single organization
    return FlowerClient(net, trainloader, valloader).to_client()
Starting the training#
We now have the class FlowerClient
which defines client-side training/evaluation and client_fn
which allows Flower to create FlowerClient
instances whenever it needs to call fit
or evaluate
on one particular client. The last step is to start the actual simulation using flwr.simulation.start_simulation
.
The function start_simulation
accepts a number of arguments, amongst them the client_fn
used to create FlowerClient
instances, the number of clients to simulate (num_clients
), the number of federated learning rounds (num_rounds
), and the strategy. The strategy encapsulates the federated learning approach/algorithm, for example, Federated Averaging (FedAvg).
Flower has a number of built-in strategies, but we can also use our own strategy implementations to customize nearly all aspects of the federated learning approach. For this example, we use the built-in FedAvg
implementation and customize it using a few basic parameters. The last step is the actual call to start_simulation
which - you guessed it - starts the simulation:
[ ]:
# Create FedAvg strategy
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,  # Sample 100% of available clients for training
    fraction_evaluate=0.5,  # Sample 50% of available clients for evaluation
    min_fit_clients=10,  # Never sample less than 10 clients for training
    min_evaluate_clients=5,  # Never sample less than 5 clients for evaluation
    min_available_clients=10,  # Wait until all 10 clients are available
)

# Specify the resources each of your clients needs. By default, each
# client will be allocated 1x CPU and 0x GPUs
client_resources = {"num_cpus": 1, "num_gpus": 0.0}
if DEVICE.type == "cuda":
    # Here we are assigning an entire GPU to each client.
    client_resources = {"num_cpus": 1, "num_gpus": 1.0}
    # Refer to our documentation for more details about Flower simulations
    # and how to set up the `client_resources`.

# Start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=strategy,
    client_resources=client_resources,
)
Behind the scenes#
So how does this work? How does Flower execute this simulation?
When we call start_simulation, we tell Flower that there are 10 clients (num_clients=10). Flower then goes ahead and asks the FedAvg strategy to select clients. FedAvg knows that it should select 100% of the available clients (fraction_fit=1.0), so it goes ahead and selects 10 random clients (i.e., 100% of 10).
Flower then asks the selected 10 clients to train the model. When the server receives the model parameter updates from the clients, it hands those updates over to the strategy (FedAvg) for aggregation. The strategy aggregates those updates and returns the new global model, which then gets used in the next round of federated learning.
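The sampling arithmetic described above can be sketched in a few lines. This is a simplified reconstruction of the behavior, not Flower's actual implementation; the helper name sample_clients is our own:

```python
import random


# Simplified sketch of per-round client sampling: take fraction_fit of the
# available clients, but never fewer than min_fit_clients.
def sample_clients(available, fraction_fit, min_fit_clients, seed=None):
    num_to_sample = max(int(len(available) * fraction_fit), min_fit_clients)
    rng = random.Random(seed)
    return rng.sample(available, num_to_sample)


clients = [str(cid) for cid in range(10)]  # ten simulated client IDs
selected = sample_clients(clients, fraction_fit=1.0, min_fit_clients=10)
print(len(selected))  # 10: with fraction_fit=1.0, every client trains each round
```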
Where's the accuracy?#
You may have noticed that all metrics except for losses_distributed
are empty. Where did the {"accuracy": float(accuracy)}
go?
Flower can automatically aggregate losses returned by individual clients, but it cannot do the same for metrics in the generic metrics dictionary (the one with the accuracy
key). Metrics dictionaries can contain very different kinds of metrics and even key/value pairs that are not metrics at all, so the framework does not (and can not) know how to handle these automatically.
As users, we need to tell the framework how to handle/aggregate these custom metrics, and we do so by passing metric aggregation functions to the strategy. The strategy will then call these functions whenever it receives fit or evaluate metrics from clients. The two possible functions are fit_metrics_aggregation_fn
and evaluate_metrics_aggregation_fn
.
Let's create a simple weighted averaging function to aggregate the accuracy metric we return from evaluate:
[ ]:
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    # Multiply accuracy of each client by number of examples used
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]

    # Aggregate and return custom metric (weighted average)
    return {"accuracy": sum(accuracies) / sum(examples)}
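To see the weighting in action, we can call the function by hand with two made-up client results (the numbers below are purely illustrative; the function is repeated here so the snippet runs on its own):

```python
from typing import List, Tuple


# Stand-alone copy of the aggregation function from the cell above
def weighted_average(metrics: List[Tuple[int, dict]]) -> dict:
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}


# A client holding 3x the data pulls the aggregate 3x as hard:
result = weighted_average([(1000, {"accuracy": 0.5}), (3000, {"accuracy": 0.75})])
print(result)  # {'accuracy': 0.6875}, i.e. (1000*0.5 + 3000*0.75) / 4000
```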
The only thing left to do is to tell the strategy to call this function whenever it receives evaluation metric dictionaries from the clients:
[ ]:
# Create FedAvg strategy
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,
    fraction_evaluate=0.5,
    min_fit_clients=10,
    min_evaluate_clients=5,
    min_available_clients=10,
    evaluate_metrics_aggregation_fn=weighted_average,  # <-- pass the metric aggregation function
)

# Start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=strategy,
    client_resources=client_resources,
)
We now have a full system that performs federated training and federated evaluation. It uses the weighted_average
function to aggregate custom evaluation metrics and calculates a single accuracy
metric across all clients on the server side.
The other two categories of metrics (losses_centralized
and metrics_centralized
) are still empty because they only apply when centralized evaluation is being used. Part two of the Flower tutorial will cover centralized evaluation.
Final remarks#
Congratulations, you just trained a convolutional neural network, federated over 10 clients! With that, you understand the basics of federated learning with Flower. The same approach you've seen can be used with other machine learning frameworks (not just PyTorch) and tasks (not just CIFAR-10 image classification), for example NLP with Hugging Face Transformers or speech with SpeechBrain.
In the next notebook, we're going to cover some more advanced concepts. Want to customize your strategy? Initialize parameters on the server side? Or evaluate the aggregated model on the server side? We'll cover all this and more in the next tutorial.
Next steps#
Before you continue, make sure to join the Flower community on Slack: Join Slack
There's a dedicated #questions channel if you need help, but we'd also love to hear who you are in #introductions!
The Flower Federated Learning Tutorial - Part 2 goes into more depth about strategies and all the advanced things you can build with them.