Customize the client#
Welcome to the fourth part of the Flower federated learning tutorial. In the previous parts of this tutorial, we introduced federated learning with PyTorch and Flower (part 1), we learned how strategies can be used to customize the execution on both the server and the clients (part 2), and we built our own custom strategy from scratch (part 3).
In this notebook, we revisit NumPyClient and introduce a new base class for building clients, simply named Client. In previous parts of this tutorial, we've based our client on NumPyClient, a convenience class which makes it easy to work with machine learning libraries that have good NumPy interoperability. With Client, we gain a lot of flexibility that we didn't have before, but we'll also have to do a few things that we didn't have to do before.
Star Flower on GitHub ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: Join Slack 🌼 We'd love to hear from you in the #introductions channel! And if anything is unclear, head over to the #questions channel.
Let's go deeper and see what it takes to move from NumPyClient to Client!
Step 0: Preparation#
Before we begin with the actual code, let's make sure that we have everything we need.
Installing dependencies#
First, we install the necessary packages:
[ ]:
!pip install -q flwr[simulation] torch torchvision scipy
Now that we have all dependencies installed, we can import everything we need for this tutorial:
[ ]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
import flwr as fl
DEVICE = torch.device("cpu") # Try "cuda" to train on GPU
print(
f"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}"
)
It is possible to switch to a runtime that has GPU acceleration enabled (on Google Colab: Runtime > Change runtime type > Hardware accelerator: GPU > Save). Note, however, that Google Colab is not always able to offer GPU acceleration. If you see an error related to GPU availability in one of the following sections, consider switching back to CPU-based execution by setting DEVICE = torch.device("cpu"). If the runtime has GPU acceleration enabled, you should see the output Training on cuda, otherwise it'll say Training on cpu.
Data loading#
Let's now load the CIFAR-10 training and test set, partition them into ten smaller datasets (each split into a training and a validation set), and wrap everything in their own DataLoader.
[ ]:
NUM_CLIENTS = 10
def load_datasets(num_clients: int):
# Download and transform CIFAR-10 (train and test)
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
testset = CIFAR10("./dataset", train=False, download=True, transform=transform)
# Split training set into `num_clients` partitions to simulate different local datasets
partition_size = len(trainset) // num_clients
lengths = [partition_size] * num_clients
datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))
# Split each partition into train/val and create DataLoader
trainloaders = []
valloaders = []
for ds in datasets:
len_val = len(ds) // 10 # 10 % validation set
len_train = len(ds) - len_val
lengths = [len_train, len_val]
ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))
trainloaders.append(DataLoader(ds_train, batch_size=32, shuffle=True))
valloaders.append(DataLoader(ds_val, batch_size=32))
testloader = DataLoader(testset, batch_size=32)
return trainloaders, valloaders, testloader
trainloaders, valloaders, testloader = load_datasets(NUM_CLIENTS)
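If you want to verify the partitioning, a quick check of the resulting splits (just a sketch, not required for the rest of the tutorial) could look like this. With 50,000 CIFAR-10 training images and 10 clients, each client should end up with 4,500 training and 500 validation examples:
[ ]:
print(f"Clients: {len(trainloaders)}")
print(f"Training examples for the first client: {len(trainloaders[0].dataset)}")
print(f"Validation examples for the first client: {len(valloaders[0].dataset)}")
print(f"Centralized test examples: {len(testloader.dataset)}")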
Model training/evaluation#
Let's continue with the usual model definition (including set_parameters and get_parameters), training, and test functions:
[ ]:
class Net(nn.Module):
def __init__(self) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def get_parameters(net) -> List[np.ndarray]:
return [val.cpu().numpy() for _, val in net.state_dict().items()]
def set_parameters(net, parameters: List[np.ndarray]):
params_dict = zip(net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)
def train(net, trainloader, epochs: int):
"""Train the network on the training set."""
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters())
net.train()
for epoch in range(epochs):
correct, total, epoch_loss = 0, 0, 0.0
for images, labels in trainloader:
images, labels = images.to(DEVICE), labels.to(DEVICE)
optimizer.zero_grad()
outputs = net(images)
            loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# Metrics
epoch_loss += loss
total += labels.size(0)
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
epoch_loss /= len(trainloader.dataset)
epoch_acc = correct / total
print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")
def test(net, testloader):
"""Evaluate the network on the entire test set."""
criterion = torch.nn.CrossEntropyLoss()
correct, total, loss = 0, 0, 0.0
net.eval()
with torch.no_grad():
for images, labels in testloader:
images, labels = images.to(DEVICE), labels.to(DEVICE)
outputs = net(images)
loss += criterion(outputs, labels).item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
loss /= len(testloader.dataset)
accuracy = correct / total
return loss, accuracy
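If you want to sanity-check these pieces before going federated, a quick centralized run on the first client's partition (optional, just a sketch) looks like this:
[ ]:
net = Net().to(DEVICE)
train(net, trainloaders[0], epochs=1)
loss, accuracy = test(net, valloaders[0])
print(f"Sanity check: loss {loss:.4f}, accuracy {accuracy:.4f}")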
Step 1: Revisiting NumPyClient#
So far, we've implemented our client by subclassing flwr.client.NumPyClient. The three methods we implemented are get_parameters, fit, and evaluate. Finally, we wrap the creation of instances of this class in a function called client_fn:
[ ]:
class FlowerNumPyClient(fl.client.NumPyClient):
def __init__(self, cid, net, trainloader, valloader):
self.cid = cid
self.net = net
self.trainloader = trainloader
self.valloader = valloader
def get_parameters(self, config):
print(f"[Client {self.cid}] get_parameters")
return get_parameters(self.net)
def fit(self, parameters, config):
print(f"[Client {self.cid}] fit, config: {config}")
set_parameters(self.net, parameters)
train(self.net, self.trainloader, epochs=1)
return get_parameters(self.net), len(self.trainloader), {}
def evaluate(self, parameters, config):
print(f"[Client {self.cid}] evaluate, config: {config}")
set_parameters(self.net, parameters)
loss, accuracy = test(self.net, self.valloader)
return float(loss), len(self.valloader), {"accuracy": float(accuracy)}
def numpyclient_fn(cid) -> FlowerNumPyClient:
net = Net().to(DEVICE)
trainloader = trainloaders[int(cid)]
valloader = valloaders[int(cid)]
return FlowerNumPyClient(cid, net, trainloader, valloader)
We've seen this before, there's nothing new so far. The only tiny difference compared to the previous notebook is naming: we've changed FlowerClient to FlowerNumPyClient and client_fn to numpyclient_fn. Let's run it to see the output we get:
[ ]:
# Specify client resources if you need GPU (defaults to 1 CPU and 0 GPU)
client_resources = None
if DEVICE.type == "cuda":
client_resources = {"num_gpus": 1}
fl.simulation.start_simulation(
client_fn=numpyclient_fn,
num_clients=2,
config=fl.server.ServerConfig(num_rounds=3),
client_resources=client_resources,
)
This works as expected: two clients are training for three rounds of federated learning.
Let's dive a little bit deeper and discuss how Flower executes this simulation. Whenever a client is selected to do some work, start_simulation calls the function numpyclient_fn to create an instance of our FlowerNumPyClient (along with loading the model and the data).
But here's the perhaps surprising part: Flower doesn't actually use the FlowerNumPyClient object directly. Instead, it wraps the object to make it look like a subclass of flwr.client.Client, not flwr.client.NumPyClient. In fact, the Flower core framework doesn't know how to handle NumPyClients; it only knows how to handle Clients. NumPyClient is just a convenience abstraction built on top of Client.
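To build intuition, here is a simplified sketch of what such a wrapper could look like. This is an illustration only, not Flower's actual internal code (the real adapter lives inside the framework and handles more cases than fit):
[ ]:
from flwr.common import (
    Code,
    FitIns,
    FitRes,
    Status,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)


class NumPyClientWrapper(fl.client.Client):
    """Illustrative adapter: exposes a NumPyClient through the Client API."""

    def __init__(self, numpy_client: fl.client.NumPyClient):
        self.numpy_client = numpy_client

    def fit(self, ins: FitIns) -> FitRes:
        # Deserialize the incoming Parameters for the NumPyClient ...
        ndarrays = parameters_to_ndarrays(ins.parameters)
        # ... delegate to the user's NumPy-based fit ...
        ndarrays_updated, num_examples, metrics = self.numpy_client.fit(
            ndarrays, ins.config
        )
        # ... and serialize the updated parameters for the server
        return FitRes(
            status=Status(code=Code.OK, message="Success"),
            parameters=ndarrays_to_parameters(ndarrays_updated),
            num_examples=num_examples,
            metrics=metrics,
        )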
Instead of building on top of NumPyClient, we can directly build on top of Client.
Step 2: Moving from NumPyClient to Client#
Let's try to do the same thing using Client instead of NumPyClient.
[ ]:
from flwr.common import (
Code,
EvaluateIns,
EvaluateRes,
FitIns,
FitRes,
GetParametersIns,
GetParametersRes,
Status,
ndarrays_to_parameters,
parameters_to_ndarrays,
)
class FlowerClient(fl.client.Client):
def __init__(self, cid, net, trainloader, valloader):
self.cid = cid
self.net = net
self.trainloader = trainloader
self.valloader = valloader
def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
print(f"[Client {self.cid}] get_parameters")
# Get parameters as a list of NumPy ndarray's
ndarrays: List[np.ndarray] = get_parameters(self.net)
# Serialize ndarray's into a Parameters object
parameters = ndarrays_to_parameters(ndarrays)
# Build and return response
status = Status(code=Code.OK, message="Success")
return GetParametersRes(
status=status,
parameters=parameters,
)
def fit(self, ins: FitIns) -> FitRes:
print(f"[Client {self.cid}] fit, config: {ins.config}")
# Deserialize parameters to NumPy ndarray's
parameters_original = ins.parameters
ndarrays_original = parameters_to_ndarrays(parameters_original)
# Update local model, train, get updated parameters
set_parameters(self.net, ndarrays_original)
train(self.net, self.trainloader, epochs=1)
ndarrays_updated = get_parameters(self.net)
# Serialize ndarray's into a Parameters object
parameters_updated = ndarrays_to_parameters(ndarrays_updated)
# Build and return response
status = Status(code=Code.OK, message="Success")
return FitRes(
status=status,
parameters=parameters_updated,
num_examples=len(self.trainloader),
metrics={},
)
def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
print(f"[Client {self.cid}] evaluate, config: {ins.config}")
# Deserialize parameters to NumPy ndarray's
parameters_original = ins.parameters
ndarrays_original = parameters_to_ndarrays(parameters_original)
set_parameters(self.net, ndarrays_original)
loss, accuracy = test(self.net, self.valloader)
# Build and return response
status = Status(code=Code.OK, message="Success")
return EvaluateRes(
status=status,
loss=float(loss),
num_examples=len(self.valloader),
metrics={"accuracy": float(accuracy)},
)
def client_fn(cid) -> FlowerClient:
net = Net().to(DEVICE)
trainloader = trainloaders[int(cid)]
valloader = valloaders[int(cid)]
return FlowerClient(cid, net, trainloader, valloader)
Before we discuss the code in more detail, let's try to run it! Gotta make sure our new Client-based client works, right?
[ ]:
fl.simulation.start_simulation(
client_fn=client_fn,
num_clients=2,
config=fl.server.ServerConfig(num_rounds=3),
client_resources=client_resources,
)
That's it, we're now using Client. It probably looks similar to what we've done with NumPyClient. So what's the difference?
First of all, it's more code. But why? The difference comes from the fact that Client expects us to take care of parameter serialization and deserialization. For Flower to be able to send parameters over the network, it eventually needs to turn these parameters into bytes. Turning parameters (e.g., NumPy ndarrays) into raw bytes is called serialization. Turning raw bytes into something more useful (like NumPy ndarrays) is called deserialization. Flower needs to do both: it needs to serialize parameters on the server side and send them to the client; the client needs to deserialize them to use them for local training, and then serialize the updated parameters again to send them back to the server, which (finally!) deserializes them again in order to aggregate them with the updates received from other clients.
The only real difference between Client and NumPyClient is that NumPyClient takes care of serialization and deserialization for you. It can do so because it expects you to return parameters as NumPy ndarrays, and it knows how to handle these. This makes working with machine learning libraries that have good NumPy support (most of them) a breeze.
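To make the serialization step concrete, here is a small round trip through the two helpers imported above (ndarrays_to_parameters and parameters_to_ndarrays are part of flwr.common); the check itself is just a sketch:
[ ]:
ndarrays = get_parameters(Net())
parameters = ndarrays_to_parameters(ndarrays)  # serialize: list of ndarrays -> Parameters (bytes)
ndarrays_restored = parameters_to_ndarrays(parameters)  # deserialize: Parameters -> list of ndarrays

assert all(np.array_equal(a, b) for a, b in zip(ndarrays, ndarrays_restored))
print(f"Round trip OK, tensor_type: {parameters.tensor_type}")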
In terms of API, there's one major difference: all methods in Client take exactly one argument (e.g., FitIns in Client.fit) and return exactly one value (e.g., FitRes in Client.fit). The methods in NumPyClient, on the other hand, have multiple arguments (e.g., parameters and config in NumPyClient.fit) and multiple return values (e.g., parameters, num_examples, and metrics in NumPyClient.fit) if there are multiple things to handle. These *Ins and *Res objects in Client wrap all the individual values you're used to from NumPyClient.
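One nice side effect of the single-argument API is that you can exercise a Client locally, without any server, by constructing the *Ins objects yourself. A quick sketch using the client_fn defined above:
[ ]:
client = client_fn("0")
res = client.get_parameters(GetParametersIns(config={}))
print(res.status.code, f"- received {len(res.parameters.tensors)} parameter tensors")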
Step 3: Custom serialization#
Here we will explore how to implement custom serialization with a simple example.
But first, what is serialization? Serialization is just the process of converting an object into raw bytes, and, equally important, deserialization is the process of converting raw bytes back into an object. This is very useful for network communication. Indeed, without serialization, you could not just send a Python object across the internet.
Federated Learning relies heavily on internet communication for training by sending Python objects back and forth between the clients and the server. This means that serialization is an essential part of Federated Learning.
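As a tiny, Flower-independent illustration: NumPy can already serialize a single array into raw bytes and load it back. This standalone sketch mirrors what the custom functions below do internally:
[ ]:
from io import BytesIO

arr = np.array([[1.0, 2.0], [3.0, 4.0]])

# Serialization: object -> raw bytes
buffer = BytesIO()
np.save(buffer, arr, allow_pickle=False)
raw = buffer.getvalue()

# Deserialization: raw bytes -> object
arr_restored = np.load(BytesIO(raw), allow_pickle=False)
assert np.array_equal(arr, arr_restored)
print(f"{arr.nbytes}-byte array serialized into {len(raw)} bytes (incl. header)")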
In the following section, we will write a basic example where, instead of sending a serialized version of our ndarrays containing our parameters, we will first convert the ndarrays into sparse matrices before sending them. This technique can be used to save bandwidth: in cases where the weights of a model are sparse (containing many 0 entries), converting them to a sparse matrix can greatly reduce their size in bytes.
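To get a feel for the potential savings, you can compare the serialized size of a mostly-zero matrix in dense versus CSR form. This is a rough sketch; the exact numbers depend on shape, dtype, and sparsity:
[ ]:
from io import BytesIO

dense = np.zeros((1000, 1000), dtype=np.float32)
dense[::100, ::100] = 1.0  # only 100 non-zero entries

# Dense serialization
dense_io = BytesIO()
np.save(dense_io, dense, allow_pickle=False)

# Sparse (CSR) serialization of the same matrix
sparse = torch.tensor(dense).to_sparse_csr()
sparse_io = BytesIO()
np.savez(
    sparse_io,
    crow_indices=sparse.crow_indices(),
    col_indices=sparse.col_indices(),
    values=sparse.values(),
)

print(f"Dense: {len(dense_io.getvalue())} bytes, sparse: {len(sparse_io.getvalue())} bytes")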
Our custom serialization/deserialization functions#
This is where the real serialization/deserialization will happen, especially in ndarray_to_sparse_bytes for serialization and sparse_bytes_to_ndarray for deserialization.
Note that we use PyTorch's sparse CSR tensor support (to_sparse_csr and sparse_csr_tensor) in order to convert our arrays.
[ ]:
from io import BytesIO
from typing import cast
import numpy as np
from flwr.common.typing import NDArray, NDArrays, Parameters
def ndarrays_to_sparse_parameters(ndarrays: NDArrays) -> Parameters:
"""Convert NumPy ndarrays to parameters object."""
tensors = [ndarray_to_sparse_bytes(ndarray) for ndarray in ndarrays]
return Parameters(tensors=tensors, tensor_type="numpy.ndarray")
def sparse_parameters_to_ndarrays(parameters: Parameters) -> NDArrays:
"""Convert parameters object to NumPy ndarrays."""
return [sparse_bytes_to_ndarray(tensor) for tensor in parameters.tensors]
def ndarray_to_sparse_bytes(ndarray: NDArray) -> bytes:
"""Serialize NumPy ndarray to bytes."""
bytes_io = BytesIO()
if len(ndarray.shape) > 1:
# We convert our ndarray into a sparse matrix
ndarray = torch.tensor(ndarray).to_sparse_csr()
        # And send it by utilizing the sparse matrix attributes
# WARNING: NEVER set allow_pickle to true.
# Reason: loading pickled data can execute arbitrary code
# Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
np.savez(
bytes_io, # type: ignore
crow_indices=ndarray.crow_indices(),
col_indices=ndarray.col_indices(),
values=ndarray.values(),
allow_pickle=False,
)
else:
# WARNING: NEVER set allow_pickle to true.
# Reason: loading pickled data can execute arbitrary code
# Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
np.save(bytes_io, ndarray, allow_pickle=False)
return bytes_io.getvalue()
def sparse_bytes_to_ndarray(tensor: bytes) -> NDArray:
"""Deserialize NumPy ndarray from bytes."""
bytes_io = BytesIO(tensor)
# WARNING: NEVER set allow_pickle to true.
# Reason: loading pickled data can execute arbitrary code
# Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
loader = np.load(bytes_io, allow_pickle=False) # type: ignore
if "crow_indices" in loader:
# We convert our sparse matrix back to a ndarray, using the attributes we sent
ndarray_deserialized = (
torch.sparse_csr_tensor(
crow_indices=loader["crow_indices"],
col_indices=loader["col_indices"],
values=loader["values"],
)
.to_dense()
.numpy()
)
else:
ndarray_deserialized = loader
return cast(NDArray, ndarray_deserialized)
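Before wiring these functions into a client, it's worth checking them in isolation with a quick round trip (again, just a sketch):
[ ]:
original = [
    np.random.randn(4, 4).astype(np.float32),  # 2D array -> sparse CSR path
    np.arange(3, dtype=np.float32),  # 1D array -> plain np.save path
]
params = ndarrays_to_sparse_parameters(original)
restored = sparse_parameters_to_ndarrays(params)
assert all(np.allclose(a, b) for a, b in zip(original, restored))
print("Custom sparse serialization round trip OK")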
Client-side#
To be able to serialize our ndarrays into sparse parameters, we will just have to call our custom functions in our flwr.client.Client.
Indeed, in get_parameters we need to serialize the parameters we got from our network using our custom ndarrays_to_sparse_parameters defined above.
In fit, we first need to deserialize the parameters coming from the server using our custom sparse_parameters_to_ndarrays and then we need to serialize our local results with ndarrays_to_sparse_parameters.
In evaluate, we will only need to deserialize the global parameters with our custom function.
[ ]:
from flwr.common import (
Code,
EvaluateIns,
EvaluateRes,
FitIns,
FitRes,
GetParametersIns,
GetParametersRes,
Status,
)
class FlowerClient(fl.client.Client):
def __init__(self, cid, net, trainloader, valloader):
self.cid = cid
self.net = net
self.trainloader = trainloader
self.valloader = valloader
def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
print(f"[Client {self.cid}] get_parameters")
# Get parameters as a list of NumPy ndarray's
ndarrays: List[np.ndarray] = get_parameters(self.net)
# Serialize ndarray's into a Parameters object using our custom function
parameters = ndarrays_to_sparse_parameters(ndarrays)
# Build and return response
status = Status(code=Code.OK, message="Success")
return GetParametersRes(
status=status,
parameters=parameters,
)
def fit(self, ins: FitIns) -> FitRes:
print(f"[Client {self.cid}] fit, config: {ins.config}")
# Deserialize parameters to NumPy ndarray's using our custom function
parameters_original = ins.parameters
ndarrays_original = sparse_parameters_to_ndarrays(parameters_original)
# Update local model, train, get updated parameters
set_parameters(self.net, ndarrays_original)
train(self.net, self.trainloader, epochs=1)
ndarrays_updated = get_parameters(self.net)
# Serialize ndarray's into a Parameters object using our custom function
parameters_updated = ndarrays_to_sparse_parameters(ndarrays_updated)
# Build and return response
status = Status(code=Code.OK, message="Success")
return FitRes(
status=status,
parameters=parameters_updated,
num_examples=len(self.trainloader),
metrics={},
)
def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
print(f"[Client {self.cid}] evaluate, config: {ins.config}")
# Deserialize parameters to NumPy ndarray's using our custom function
parameters_original = ins.parameters
ndarrays_original = sparse_parameters_to_ndarrays(parameters_original)
set_parameters(self.net, ndarrays_original)
loss, accuracy = test(self.net, self.valloader)
# Build and return response
status = Status(code=Code.OK, message="Success")
return EvaluateRes(
status=status,
loss=float(loss),
num_examples=len(self.valloader),
metrics={"accuracy": float(accuracy)},
)
def client_fn(cid) -> FlowerClient:
net = Net().to(DEVICE)
trainloader = trainloaders[int(cid)]
valloader = valloaders[int(cid)]
return FlowerClient(cid, net, trainloader, valloader)
Server-side#
For this example, we will just use FedAvg as a strategy. To change the serialization and deserialization here, we only need to reimplement the evaluate and aggregate_fit functions of FedAvg. The other functions of the strategy will be inherited from the superclass FedAvg.
As you can see, only one line has changed in evaluate:
parameters_ndarrays = sparse_parameters_to_ndarrays(parameters)
And for aggregate_fit, we will first deserialize every result we received:
weights_results = [
(sparse_parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
for _, fit_res in results
]
And then serialize the aggregated result:
parameters_aggregated = ndarrays_to_sparse_parameters(aggregate(weights_results))
[ ]:
from logging import WARNING
from typing import Callable, Dict, List, Optional, Tuple, Union
from flwr.common import FitRes, MetricsAggregationFn, NDArrays, Parameters, Scalar
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import FedAvg
from flwr.server.strategy.aggregate import aggregate
WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
Setting `min_available_clients` lower than `min_fit_clients` or
`min_evaluate_clients` can cause the server to fail when there are too few clients
connected to the server. `min_available_clients` must be set to a value larger
than or equal to the values of `min_fit_clients` and `min_evaluate_clients`.
"""
class FedSparse(FedAvg):
def __init__(
self,
*,
fraction_fit: float = 1.0,
fraction_evaluate: float = 1.0,
min_fit_clients: int = 2,
min_evaluate_clients: int = 2,
min_available_clients: int = 2,
evaluate_fn: Optional[
Callable[
[int, NDArrays, Dict[str, Scalar]],
Optional[Tuple[float, Dict[str, Scalar]]],
]
] = None,
on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
accept_failures: bool = True,
initial_parameters: Optional[Parameters] = None,
fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
) -> None:
"""Custom FedAvg strategy with sparse matrices.
Parameters
----------
        fraction_fit : float, optional
            Fraction of clients used during training. Defaults to 1.0.
        fraction_evaluate : float, optional
            Fraction of clients used during validation. Defaults to 1.0.
min_fit_clients : int, optional
Minimum number of clients used during training. Defaults to 2.
min_evaluate_clients : int, optional
Minimum number of clients used during validation. Defaults to 2.
min_available_clients : int, optional
Minimum number of total clients in the system. Defaults to 2.
evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]]
Optional function used for validation. Defaults to None.
on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional
Function used to configure training. Defaults to None.
on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional
Function used to configure validation. Defaults to None.
accept_failures : bool, optional
Whether or not accept rounds containing failures. Defaults to True.
initial_parameters : Parameters, optional
Initial global model parameters.
"""
if (
min_fit_clients > min_available_clients
or min_evaluate_clients > min_available_clients
):
log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)
super().__init__(
fraction_fit=fraction_fit,
fraction_evaluate=fraction_evaluate,
min_fit_clients=min_fit_clients,
min_evaluate_clients=min_evaluate_clients,
min_available_clients=min_available_clients,
evaluate_fn=evaluate_fn,
on_fit_config_fn=on_fit_config_fn,
on_evaluate_config_fn=on_evaluate_config_fn,
accept_failures=accept_failures,
initial_parameters=initial_parameters,
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
)
def evaluate(
self, server_round: int, parameters: Parameters
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
"""Evaluate model parameters using an evaluation function."""
if self.evaluate_fn is None:
# No evaluation function provided
return None
# We deserialize using our custom method
parameters_ndarrays = sparse_parameters_to_ndarrays(parameters)
eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
if eval_res is None:
return None
loss, metrics = eval_res
return loss, metrics
def aggregate_fit(
self,
server_round: int,
results: List[Tuple[ClientProxy, FitRes]],
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
"""Aggregate fit results using weighted average."""
if not results:
return None, {}
# Do not aggregate if there are failures and failures are not accepted
if not self.accept_failures and failures:
return None, {}
# We deserialize each of the results with our custom method
weights_results = [
(sparse_parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
for _, fit_res in results
]
# We serialize the aggregated result using our custom method
parameters_aggregated = ndarrays_to_sparse_parameters(
aggregate(weights_results)
)
# Aggregate custom metrics if aggregation fn was provided
metrics_aggregated = {}
if self.fit_metrics_aggregation_fn:
fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
elif server_round == 1: # Only log this warning once
log(WARNING, "No fit_metrics_aggregation_fn provided")
return parameters_aggregated, metrics_aggregated
We can now run our custom serialization example!
[ ]:
strategy = FedSparse()
fl.simulation.start_simulation(
strategy=strategy,
client_fn=client_fn,
num_clients=2,
config=fl.server.ServerConfig(num_rounds=3),
client_resources=client_resources,
)
Recap#
In this part of the tutorial, we've seen how we can build clients by subclassing either NumPyClient or Client. NumPyClient is a convenience abstraction that makes it easier to work with machine learning libraries that have good NumPy interoperability. Client is a more flexible abstraction that allows us to do things that are not possible in NumPyClient. In order to do so, it requires us to handle parameter serialization and deserialization ourselves.
Next steps#
Before you continue, make sure to join the Flower community on Slack: Join Slack
There's a dedicated #questions channel if you need help, but we'd also love to hear who you are in #introductions!
This is the final part of the Flower tutorial (for now!), congratulations! You're now well equipped to understand the rest of the documentation. There are many topics we didn't cover in the tutorial; we recommend the following resources: