Customize the client#
Welcome to the fourth part of the Flower federated learning tutorial. In the previous parts of this tutorial, we introduced federated learning with PyTorch and Flower (part 1), we learned how strategies can be used to customize the execution on both the server and the clients (part 2), and we built our own custom strategy from scratch (part 3).
Dans ce carnet, nous revisitons NumPyClient` et introduisons une nouvelle classe de base pour construire des clients, simplement appelée Client`. Dans les parties précédentes de ce tutoriel, nous avons basé notre client sur NumPyClient
, une classe de commodité qui facilite le travail avec les bibliothèques d’apprentissage automatique qui ont une bonne interopérabilité NumPy. Avec Client
, nous gagnons beaucoup de flexibilité que nous n’avions pas auparavant, mais nous devrons également faire quelques choses que nous n’avions pas à faire auparavant.
Star Flower on GitHub ⭐️ et rejoignez la communauté Flower sur Slack pour vous connecter, poser des questions et obtenir de l’aide : Join Slack 🌼 Nous serions ravis d’avoir de vos nouvelles dans le canal
#introductions
! Et si quelque chose n’est pas clair, rendez-vous sur le canal#questions
.
Allons plus loin et voyons ce qu’il faut faire pour passer de NumPyClient
à Client
!
Étape 0 : Préparation#
Avant de commencer le code proprement dit, assurons-nous que nous disposons de tout ce dont nous avons besoin.
Installation des dépendances#
Tout d’abord, nous installons les paquets nécessaires :
[ ]:
!pip install -q flwr[simulation] torch torchvision scipy
Maintenant que toutes les dépendances sont installées, nous pouvons importer tout ce dont nous avons besoin pour ce tutoriel :
[ ]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
import flwr as fl
DEVICE = torch.device("cpu") # Try "cuda" to train on GPU
print(
f"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}"
)
Il est possible de passer à un runtime dont l’accélération GPU est activée (sur Google Colab : Runtime > Change runtime type > Hardware acclerator : GPU > Save
). Note cependant que Google Colab n’est pas toujours en mesure de proposer l’accélération GPU. Si tu vois une erreur liée à la disponibilité du GPU dans l’une des sections suivantes, envisage de repasser à une exécution basée sur le CPU en définissant DEVICE = torch.device("cpu")
. Si le runtime a activé l’accélération GPU, tu devrais voir apparaître le résultat Training on cuda
, sinon il dira Training on cpu
.
Chargement des données#
Chargeons maintenant les ensembles d’entraînement et de test CIFAR-10, divisons-les en dix ensembles de données plus petits (chacun divisé en ensemble d’entraînement et de validation) et enveloppons le tout dans leur propre DataLoader
.
[ ]:
NUM_CLIENTS = 10
def load_datasets(num_clients: int):
# Download and transform CIFAR-10 (train and test)
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
testset = CIFAR10("./dataset", train=False, download=True, transform=transform)
# Split training set into `num_clients` partitions to simulate different local datasets
partition_size = len(trainset) // num_clients
lengths = [partition_size] * num_clients
datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))
# Split each partition into train/val and create DataLoader
trainloaders = []
valloaders = []
for ds in datasets:
len_val = len(ds) // 10 # 10 % validation set
len_train = len(ds) - len_val
lengths = [len_train, len_val]
ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))
trainloaders.append(DataLoader(ds_train, batch_size=32, shuffle=True))
valloaders.append(DataLoader(ds_val, batch_size=32))
testloader = DataLoader(testset, batch_size=32)
return trainloaders, valloaders, testloader
trainloaders, valloaders, testloader = load_datasets(NUM_CLIENTS)
Formation/évaluation du modèle#
Continuons avec la définition habituelle du modèle (y compris set_parameters
et get_parameters
), les fonctions d’entraînement et de test :
[ ]:
class Net(nn.Module):
def __init__(self) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def get_parameters(net) -> List[np.ndarray]:
return [val.cpu().numpy() for _, val in net.state_dict().items()]
def set_parameters(net, parameters: List[np.ndarray]):
params_dict = zip(net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)
def train(net, trainloader, epochs: int):
"""Train the network on the training set."""
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters())
net.train()
for epoch in range(epochs):
correct, total, epoch_loss = 0, 0, 0.0
for images, labels in trainloader:
images, labels = images.to(DEVICE), labels.to(DEVICE)
optimizer.zero_grad()
outputs = net(images)
loss = criterion(net(images), labels)
loss.backward()
optimizer.step()
# Metrics
epoch_loss += loss
total += labels.size(0)
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
epoch_loss /= len(trainloader.dataset)
epoch_acc = correct / total
print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")
def test(net, testloader):
"""Evaluate the network on the entire test set."""
criterion = torch.nn.CrossEntropyLoss()
correct, total, loss = 0, 0, 0.0
net.eval()
with torch.no_grad():
for images, labels in testloader:
images, labels = images.to(DEVICE), labels.to(DEVICE)
outputs = net(images)
loss += criterion(outputs, labels).item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
loss /= len(testloader.dataset)
accuracy = correct / total
return loss, accuracy
Étape 1 : Revoir NumPyClient#
Jusqu’à présent, nous avons implémenté notre client en sous-classant flwr.client.NumPyClient
. Les trois méthodes que nous avons implémentées sont get_parameters
, fit
et evaluate
. Enfin, nous enveloppons la création d’instances de cette classe dans une fonction appelée client_fn
:
[ ]:
class FlowerNumPyClient(fl.client.NumPyClient):
def __init__(self, cid, net, trainloader, valloader):
self.cid = cid
self.net = net
self.trainloader = trainloader
self.valloader = valloader
def get_parameters(self, config):
print(f"[Client {self.cid}] get_parameters")
return get_parameters(self.net)
def fit(self, parameters, config):
print(f"[Client {self.cid}] fit, config: {config}")
set_parameters(self.net, parameters)
train(self.net, self.trainloader, epochs=1)
return get_parameters(self.net), len(self.trainloader), {}
def evaluate(self, parameters, config):
print(f"[Client {self.cid}] evaluate, config: {config}")
set_parameters(self.net, parameters)
loss, accuracy = test(self.net, self.valloader)
return float(loss), len(self.valloader), {"accuracy": float(accuracy)}
def numpyclient_fn(cid) -> FlowerNumPyClient:
net = Net().to(DEVICE)
trainloader = trainloaders[int(cid)]
valloader = valloaders[int(cid)]
return FlowerNumPyClient(cid, net, trainloader, valloader)
Nous avons déjà vu cela auparavant, il n’y a rien de nouveau jusqu’à présent. La seule petite différence par rapport au carnet précédent est le nommage, nous avons changé FlowerClient
en FlowerNumPyClient
et client_fn
en numpyclient_fn
. Exécutons-le pour voir la sortie que nous obtenons :
[ ]:
# Specify client resources if you need GPU (defaults to 1 CPU and 0 GPU)
client_resources = None
if DEVICE.type == "cuda":
client_resources = {"num_gpus": 1}
fl.simulation.start_simulation(
client_fn=numpyclient_fn,
num_clients=2,
config=fl.server.ServerConfig(num_rounds=3),
client_resources=client_resources,
)
Cela fonctionne comme prévu, deux clients s’entraînent pour trois tours d’apprentissage fédéré.
Plongeons un peu plus profondément et discutons de la façon dont Flower exécute cette simulation. Chaque fois qu’un client est sélectionné pour effectuer un travail, start_simulation
appelle la fonction numpyclient_fn
pour créer une instance de notre FlowerNumPyClient
(en même temps qu’il charge le modèle et les données).
Mais voici la partie la plus surprenante : Flower n’utilise pas directement l’objet FlowerNumPyClient. Au lieu de cela, il enveloppe l’objet pour le faire ressembler à une sous-classe de flwr.client.Client, et non de flwr.client.NumPyClient. En fait, le noyau de Flower ne sait pas comment gérer les NumPyClient, il sait seulement comment gérer les Client. NumPyClient est juste une abstraction de commodité construite au dessus de Client.
Au lieu de construire par-dessus NumPyClient`, nous pouvons construire directement par-dessus Client`.
Étape 2 : Passer de NumPyClient
à Client
#
Essayons de faire la même chose en utilisant Client
au lieu de NumPyClient
.
[ ]:
from flwr.common import (
Code,
EvaluateIns,
EvaluateRes,
FitIns,
FitRes,
GetParametersIns,
GetParametersRes,
Status,
ndarrays_to_parameters,
parameters_to_ndarrays,
)
class FlowerClient(fl.client.Client):
def __init__(self, cid, net, trainloader, valloader):
self.cid = cid
self.net = net
self.trainloader = trainloader
self.valloader = valloader
def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
print(f"[Client {self.cid}] get_parameters")
# Get parameters as a list of NumPy ndarray's
ndarrays: List[np.ndarray] = get_parameters(self.net)
# Serialize ndarray's into a Parameters object
parameters = ndarrays_to_parameters(ndarrays)
# Build and return response
status = Status(code=Code.OK, message="Success")
return GetParametersRes(
status=status,
parameters=parameters,
)
def fit(self, ins: FitIns) -> FitRes:
print(f"[Client {self.cid}] fit, config: {ins.config}")
# Deserialize parameters to NumPy ndarray's
parameters_original = ins.parameters
ndarrays_original = parameters_to_ndarrays(parameters_original)
# Update local model, train, get updated parameters
set_parameters(self.net, ndarrays_original)
train(self.net, self.trainloader, epochs=1)
ndarrays_updated = get_parameters(self.net)
# Serialize ndarray's into a Parameters object
parameters_updated = ndarrays_to_parameters(ndarrays_updated)
# Build and return response
status = Status(code=Code.OK, message="Success")
return FitRes(
status=status,
parameters=parameters_updated,
num_examples=len(self.trainloader),
metrics={},
)
def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
print(f"[Client {self.cid}] evaluate, config: {ins.config}")
# Deserialize parameters to NumPy ndarray's
parameters_original = ins.parameters
ndarrays_original = parameters_to_ndarrays(parameters_original)
set_parameters(self.net, ndarrays_original)
loss, accuracy = test(self.net, self.valloader)
# return float(loss), len(self.valloader), {"accuracy": float(accuracy)}
# Build and return response
status = Status(code=Code.OK, message="Success")
return EvaluateRes(
status=status,
loss=float(loss),
num_examples=len(self.valloader),
metrics={"accuracy": float(accuracy)},
)
def client_fn(cid) -> FlowerClient:
net = Net().to(DEVICE)
trainloader = trainloaders[int(cid)]
valloader = valloaders[int(cid)]
return FlowerClient(cid, net, trainloader, valloader)
Avant de discuter du code plus en détail, essayons de l’exécuter ! Nous devons nous assurer que notre nouveau client basé sur le Client
fonctionne, n’est-ce pas ?
[ ]:
fl.simulation.start_simulation(
client_fn=client_fn,
num_clients=2,
config=fl.server.ServerConfig(num_rounds=3),
client_resources=client_resources,
)
Voilà, nous utilisons maintenant Client
. Cela ressemble probablement à ce que nous avons fait avec NumPyClient
. Alors quelle est la différence ?
First of all, it’s more code. But why? The difference comes from the fact that Client
expects us to take care of parameter serialization and deserialization. For Flower to be able to send parameters over the network, it eventually needs to turn these parameters into bytes
. Turning parameters (e.g., NumPy ndarray
’s) into raw bytes is called serialization. Turning raw bytes into something more useful (like NumPy ndarray
’s) is called deserialization. Flower needs to do both: it
needs to serialize parameters on the server-side and send them to the client, the client needs to deserialize them to use them for local training, and then serialize the updated parameters again to send them back to the server, which (finally!) deserializes them again in order to aggregate them with the updates received from other clients.
La seule vraie différence entre Client et NumPyClient est que NumPyClient s’occupe de la sérialisation et de la désérialisation pour toi. Il peut le faire parce qu’il s’attend à ce que tu renvoies des paramètres sous forme de NumPy ndarray, et il sait comment les gérer. Cela permet de travailler avec des bibliothèques d’apprentissage automatique qui ont une bonne prise en charge de NumPy (la plupart d’entre elles) en un clin d’œil.
In terms of API, there’s one major difference: all methods in Client take exactly one argument (e.g., FitIns
in Client.fit
) and return exactly one value (e.g., FitRes
in Client.fit
). The methods in NumPyClient
on the other hand have multiple arguments (e.g., parameters
and config
in NumPyClient.fit
) and multiple return values (e.g., parameters
, num_example
, and metrics
in NumPyClient.fit
) if there are multiple things to handle. These *Ins
and
*Res
objects in Client
wrap all the individual values you’re used to from NumPyClient
.
Étape 3 : Sérialisation personnalisée#
Nous allons ici explorer comment mettre en œuvre une sérialisation personnalisée à l’aide d’un exemple simple.
Mais d’abord, qu’est-ce que la sérialisation ? La sérialisation est simplement le processus de conversion d’un objet en octets bruts, et tout aussi important, la désérialisation est le processus de reconversion des octets bruts en objet. Ceci est très utile pour la communication réseau. En effet, sans la sérialisation, tu ne pourrais pas faire passer un objet Python par Internet.
L’apprentissage fédéré s’appuie fortement sur la communication Internet pour la formation en envoyant des objets Python dans les deux sens entre les clients et le serveur, ce qui signifie que la sérialisation est un élément essentiel de l’apprentissage fédéré.
Dans la section suivante, nous allons écrire un exemple de base où, au lieu d’envoyer une version sérialisée de nos ndarray
contenant nos paramètres, nous allons d’abord convertir les ndarray
en matrices éparses, avant de les envoyer. Cette technique peut être utilisée pour économiser de la bande passante, car dans certains cas où les poids d’un modèle sont épars (contenant de nombreuses entrées 0), les convertir en une matrice éparse peut grandement améliorer leur taille en octets.
Nos fonctions de sérialisation/désérialisation personnalisées#
C’est là que la véritable sérialisation/désérialisation se produira, en particulier dans ndarray_to_sparse_bytes
pour la sérialisation et sparse_bytes_to_ndarray
pour la désérialisation.
Notez que nous avons importé la bibliothèque scipy.sparse
afin de convertir nos tableaux.
[ ]:
from io import BytesIO
from typing import cast
import numpy as np
from flwr.common.typing import NDArray, NDArrays, Parameters
def ndarrays_to_sparse_parameters(ndarrays: NDArrays) -> Parameters:
"""Convert NumPy ndarrays to parameters object."""
tensors = [ndarray_to_sparse_bytes(ndarray) for ndarray in ndarrays]
return Parameters(tensors=tensors, tensor_type="numpy.ndarray")
def sparse_parameters_to_ndarrays(parameters: Parameters) -> NDArrays:
"""Convert parameters object to NumPy ndarrays."""
return [sparse_bytes_to_ndarray(tensor) for tensor in parameters.tensors]
def ndarray_to_sparse_bytes(ndarray: NDArray) -> bytes:
"""Serialize NumPy ndarray to bytes."""
bytes_io = BytesIO()
if len(ndarray.shape) > 1:
# We convert our ndarray into a sparse matrix
ndarray = torch.tensor(ndarray).to_sparse_csr()
# And send it byutilizing the sparse matrix attributes
# WARNING: NEVER set allow_pickle to true.
# Reason: loading pickled data can execute arbitrary code
# Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
np.savez(
bytes_io, # type: ignore
crow_indices=ndarray.crow_indices(),
col_indices=ndarray.col_indices(),
values=ndarray.values(),
allow_pickle=False,
)
else:
# WARNING: NEVER set allow_pickle to true.
# Reason: loading pickled data can execute arbitrary code
# Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
np.save(bytes_io, ndarray, allow_pickle=False)
return bytes_io.getvalue()
def sparse_bytes_to_ndarray(tensor: bytes) -> NDArray:
"""Deserialize NumPy ndarray from bytes."""
bytes_io = BytesIO(tensor)
# WARNING: NEVER set allow_pickle to true.
# Reason: loading pickled data can execute arbitrary code
# Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
loader = np.load(bytes_io, allow_pickle=False) # type: ignore
if "crow_indices" in loader:
# We convert our sparse matrix back to a ndarray, using the attributes we sent
ndarray_deserialized = (
torch.sparse_csr_tensor(
crow_indices=loader["crow_indices"],
col_indices=loader["col_indices"],
values=loader["values"],
)
.to_dense()
.numpy()
)
else:
ndarray_deserialized = loader
return cast(NDArray, ndarray_deserialized)
Côté client#
Pour pouvoir sérialiser nos ndarray
en paramètres sparse, il nous suffira d’appeler nos fonctions personnalisées dans notre flwr.client.Client
.
En effet, dans get_parameters
nous devons sérialiser les paramètres que nous avons obtenus de notre réseau en utilisant nos ndarrays_to_sparse_parameters
personnalisés définis ci-dessus.
Dans fit
, nous devons d’abord désérialiser les paramètres provenant du serveur en utilisant notre sparse_parameters_to_ndarrays
personnalisé, puis nous devons sérialiser nos résultats locaux avec ndarrays_to_sparse_parameters
.
Dans evaluate
, nous n’aurons besoin que de désérialiser les paramètres globaux avec notre fonction personnalisée.
[ ]:
from flwr.common import (
Code,
EvaluateIns,
EvaluateRes,
FitIns,
FitRes,
GetParametersIns,
GetParametersRes,
Status,
)
class FlowerClient(fl.client.Client):
def __init__(self, cid, net, trainloader, valloader):
self.cid = cid
self.net = net
self.trainloader = trainloader
self.valloader = valloader
def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
print(f"[Client {self.cid}] get_parameters")
# Get parameters as a list of NumPy ndarray's
ndarrays: List[np.ndarray] = get_parameters(self.net)
# Serialize ndarray's into a Parameters object using our custom function
parameters = ndarrays_to_sparse_parameters(ndarrays)
# Build and return response
status = Status(code=Code.OK, message="Success")
return GetParametersRes(
status=status,
parameters=parameters,
)
def fit(self, ins: FitIns) -> FitRes:
print(f"[Client {self.cid}] fit, config: {ins.config}")
# Deserialize parameters to NumPy ndarray's using our custom function
parameters_original = ins.parameters
ndarrays_original = sparse_parameters_to_ndarrays(parameters_original)
# Update local model, train, get updated parameters
set_parameters(self.net, ndarrays_original)
train(self.net, self.trainloader, epochs=1)
ndarrays_updated = get_parameters(self.net)
# Serialize ndarray's into a Parameters object using our custom function
parameters_updated = ndarrays_to_sparse_parameters(ndarrays_updated)
# Build and return response
status = Status(code=Code.OK, message="Success")
return FitRes(
status=status,
parameters=parameters_updated,
num_examples=len(self.trainloader),
metrics={},
)
def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
print(f"[Client {self.cid}] evaluate, config: {ins.config}")
# Deserialize parameters to NumPy ndarray's using our custom function
parameters_original = ins.parameters
ndarrays_original = sparse_parameters_to_ndarrays(parameters_original)
set_parameters(self.net, ndarrays_original)
loss, accuracy = test(self.net, self.valloader)
# Build and return response
status = Status(code=Code.OK, message="Success")
return EvaluateRes(
status=status,
loss=float(loss),
num_examples=len(self.valloader),
metrics={"accuracy": float(accuracy)},
)
def client_fn(cid) -> FlowerClient:
net = Net().to(DEVICE)
trainloader = trainloaders[int(cid)]
valloader = valloaders[int(cid)]
return FlowerClient(cid, net, trainloader, valloader)
Côté serveur#
Pour cet exemple, nous utiliserons simplement FedAvg
comme stratégie. Pour modifier la sérialisation et la désérialisation ici, il suffit de réimplémenter les fonctions evaluate
et aggregate_fit
de FedAvg
. Les autres fonctions de la stratégie seront héritées de la super-classe FedAvg
.
Comme tu peux le voir, seule une ligne a été modifiée dans evaluate
:
parameters_ndarrays = sparse_parameters_to_ndarrays(parameters)
Et pour aggregate_fit
, nous allons d’abord désérialiser chaque résultat que nous avons reçu :
weights_results = [
(sparse_parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
for _, fit_res in results
]
Puis sérialise le résultat agrégé :
parameters_aggregated = ndarrays_to_sparse_parameters(aggregate(weights_results))
[ ]:
from logging import WARNING
from typing import Callable, Dict, List, Optional, Tuple, Union
from flwr.common import FitRes, MetricsAggregationFn, NDArrays, Parameters, Scalar
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import FedAvg
from flwr.server.strategy.aggregate import aggregate
WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
Setting `min_available_clients` lower than `min_fit_clients` or
`min_evaluate_clients` can cause the server to fail when there are too few clients
connected to the server. `min_available_clients` must be set to a value larger
than or equal to the values of `min_fit_clients` and `min_evaluate_clients`.
"""
class FedSparse(FedAvg):
def __init__(
self,
*,
fraction_fit: float = 1.0,
fraction_evaluate: float = 1.0,
min_fit_clients: int = 2,
min_evaluate_clients: int = 2,
min_available_clients: int = 2,
evaluate_fn: Optional[
Callable[
[int, NDArrays, Dict[str, Scalar]],
Optional[Tuple[float, Dict[str, Scalar]]],
]
] = None,
on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
accept_failures: bool = True,
initial_parameters: Optional[Parameters] = None,
fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
) -> None:
"""Custom FedAvg strategy with sparse matrices.
Parameters
----------
fraction_fit : float, optional
Fraction of clients used during training. Defaults to 0.1.
fraction_evaluate : float, optional
Fraction of clients used during validation. Defaults to 0.1.
min_fit_clients : int, optional
Minimum number of clients used during training. Defaults to 2.
min_evaluate_clients : int, optional
Minimum number of clients used during validation. Defaults to 2.
min_available_clients : int, optional
Minimum number of total clients in the system. Defaults to 2.
evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]]
Optional function used for validation. Defaults to None.
on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional
Function used to configure training. Defaults to None.
on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional
Function used to configure validation. Defaults to None.
accept_failures : bool, optional
Whether or not accept rounds containing failures. Defaults to True.
initial_parameters : Parameters, optional
Initial global model parameters.
"""
if (
min_fit_clients > min_available_clients
or min_evaluate_clients > min_available_clients
):
log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)
super().__init__(
fraction_fit=fraction_fit,
fraction_evaluate=fraction_evaluate,
min_fit_clients=min_fit_clients,
min_evaluate_clients=min_evaluate_clients,
min_available_clients=min_available_clients,
evaluate_fn=evaluate_fn,
on_fit_config_fn=on_fit_config_fn,
on_evaluate_config_fn=on_evaluate_config_fn,
accept_failures=accept_failures,
initial_parameters=initial_parameters,
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
)
def evaluate(
self, server_round: int, parameters: Parameters
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
"""Evaluate model parameters using an evaluation function."""
if self.evaluate_fn is None:
# No evaluation function provided
return None
# We deserialize using our custom method
parameters_ndarrays = sparse_parameters_to_ndarrays(parameters)
eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
if eval_res is None:
return None
loss, metrics = eval_res
return loss, metrics
def aggregate_fit(
self,
server_round: int,
results: List[Tuple[ClientProxy, FitRes]],
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
"""Aggregate fit results using weighted average."""
if not results:
return None, {}
# Do not aggregate if there are failures and failures are not accepted
if not self.accept_failures and failures:
return None, {}
# We deserialize each of the results with our custom method
weights_results = [
(sparse_parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
for _, fit_res in results
]
# We serialize the aggregated result using our custom method
parameters_aggregated = ndarrays_to_sparse_parameters(
aggregate(weights_results)
)
# Aggregate custom metrics if aggregation fn was provided
metrics_aggregated = {}
if self.fit_metrics_aggregation_fn:
fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
elif server_round == 1: # Only log this warning once
log(WARNING, "No fit_metrics_aggregation_fn provided")
return parameters_aggregated, metrics_aggregated
Nous pouvons maintenant exécuter notre exemple de sérialisation personnalisée !
[ ]:
strategy = FedSparse()
fl.simulation.start_simulation(
strategy=strategy,
client_fn=client_fn,
num_clients=2,
config=fl.server.ServerConfig(num_rounds=3),
client_resources=client_resources,
)
Récapitulation#
Dans cette partie du tutoriel, nous avons vu comment construire des clients en sous-classant soit NumPyClient
, soit Client
. NumPyClient
est une abstraction de commodité qui facilite le travail avec les bibliothèques d’apprentissage automatique qui ont une bonne interopérabilité NumPy. Client
est une abstraction plus flexible qui nous permet de faire des choses qui ne sont pas possibles dans NumPyClient
. Pour ce faire, elle nous oblige à gérer nous-mêmes la sérialisation et la désérialisation des paramètres.
Prochaines étapes#
Avant de continuer, n’oublie pas de rejoindre la communauté Flower sur Slack : Join Slack
Il existe un canal dédié aux questions
si vous avez besoin d’aide, mais nous aimerions aussi savoir qui vous êtes dans #introductions
!
C’est la dernière partie du tutoriel Flower (pour l’instant !), félicitations ! Tu es maintenant bien équipé pour comprendre le reste de la documentation. Il y a de nombreux sujets que nous n’avons pas abordés dans le tutoriel, nous te recommandons les ressources suivantes :