Exemple : PyTorch - De la centralisation à la fédération#
Ce tutoriel te montrera comment utiliser Flower pour construire une version fédérée d’une charge de travail d’apprentissage automatique existante. Nous utilisons PyTorch pour entraîner un réseau neuronal convolutif sur l’ensemble de données CIFAR-10. Tout d’abord, nous présentons cette tâche d’apprentissage automatique avec une approche d’entraînement centralisée basée sur le tutoriel Deep Learning with PyTorch. Ensuite, nous nous appuyons sur le code d’entraînement centralisé pour exécuter l’entraînement de manière fédérée.
Formation centralisée#
Nous commençons par une brève description du code d’entraînement CNN centralisé. Si tu veux une explication plus approfondie de ce qui se passe, jette un coup d’œil au tutoriel officiel PyTorch.
Créons un nouveau fichier appelé cifar.py
avec tous les composants requis pour une formation traditionnelle (centralisée) sur le CIFAR-10. Tout d’abord, tous les paquets requis (tels que torch
et torchvision
) doivent être importés. Tu peux voir que nous n’importons aucun paquet pour l’apprentissage fédéré. Tu peux conserver toutes ces importations telles quelles même lorsque nous ajouterons les composants d’apprentissage fédéré à un moment ultérieur.
from typing import Tuple, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch import Tensor
from torchvision.datasets import CIFAR10
Comme nous l’avons déjà mentionné, nous utiliserons l’ensemble de données CIFAR-10 pour cette charge de travail d’apprentissage automatique. L’architecture du modèle (un réseau neuronal convolutif très simple) est définie dans class Net()
.
class Net(nn.Module):
def __init__(self) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x: Tensor) -> Tensor:
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
La fonction load_data()
charge les ensembles d’entraînement et de test CIFAR-10. La fonction transform
normalise les données après leur chargement.
DATA_ROOT = "~/data/cifar-10"
def load_data() -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, Dict]:
"""Load CIFAR-10 (training and test set)."""
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
trainset = CIFAR10(DATA_ROOT, train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
testset = CIFAR10(DATA_ROOT, train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)
num_examples = {"trainset" : len(trainset), "testset" : len(testset)}
return trainloader, testloader, num_examples
Nous devons maintenant définir la formation (fonction train()
) qui passe en boucle sur l’ensemble de la formation, mesure la perte, la rétropropage, puis effectue une étape d’optimisation pour chaque lot d’exemples de formation.
L’évaluation du modèle est définie dans la fonction test()
. La fonction boucle sur tous les échantillons de test et mesure la perte du modèle en fonction de l’ensemble des données de test.
def train(
net: Net,
trainloader: torch.utils.data.DataLoader,
epochs: int,
device: torch.device,
) -> None:
"""Train the network."""
# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
print(f"Training {epochs} epoch(s) w/ {len(trainloader)} batches each")
# Train the network
for epoch in range(epochs): # loop over the dataset multiple times
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
images, labels = data[0].to(device), data[1].to(device)
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()
if i % 100 == 99: # print every 100 mini-batches
print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
def test(
net: Net,
testloader: torch.utils.data.DataLoader,
device: torch.device,
) -> Tuple[float, float]:
"""Validate the network on the entire test set."""
criterion = nn.CrossEntropyLoss()
correct = 0
total = 0
loss = 0.0
with torch.no_grad():
for data in testloader:
images, labels = data[0].to(device), data[1].to(device)
outputs = net(images)
loss += criterion(outputs, labels).item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = correct / total
return loss, accuracy
Après avoir défini le chargement des données, l’architecture du modèle, la formation et l’évaluation, nous pouvons tout mettre ensemble et former notre CNN sur CIFAR-10.
def main():
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Centralized PyTorch training")
print("Load data")
trainloader, testloader, _ = load_data()
print("Start training")
net=Net().to(DEVICE)
train(net=net, trainloader=trainloader, epochs=2, device=DEVICE)
print("Evaluate model")
loss, accuracy = test(net=net, testloader=testloader, device=DEVICE)
print("Loss: ", loss)
print("Accuracy: ", accuracy)
if __name__ == "__main__":
main()
Tu peux maintenant exécuter ta charge de travail d’apprentissage automatique :
python3 cifar.py
Jusqu’à présent, tout cela devrait te sembler assez familier si tu as déjà utilisé PyTorch. Passons à l’étape suivante et utilisons ce que nous avons construit pour créer un simple système d’apprentissage fédéré composé d’un serveur et de deux clients.
Formation fédérée#
Le projet simple d’apprentissage automatique discuté dans la section précédente entraîne le modèle sur un seul ensemble de données (CIFAR-10), nous appelons cela l’apprentissage centralisé. Ce concept d’apprentissage centralisé, comme le montre la section précédente, est probablement connu de la plupart d’entre vous, et beaucoup d’entre vous l’ont déjà utilisé. Normalement, si tu veux exécuter des charges de travail d’apprentissage automatique de manière fédérée, tu dois alors changer la plupart de ton code et tout mettre en place à partir de zéro, ce qui peut représenter un effort considérable.
Cependant, avec Flower, tu peux faire évoluer ton code préexistant vers une configuration d’apprentissage fédéré sans avoir besoin d’une réécriture majeure.
Le concept est facile à comprendre. Nous devons démarrer un serveur et utiliser le code dans cifar.py
pour les clients qui sont connectés au serveur. Le serveur envoie les paramètres du modèle aux clients. Les clients exécutent la formation et mettent à jour les paramètres. Les paramètres mis à jour sont renvoyés au serveur qui fait la moyenne de toutes les mises à jour de paramètres reçues. Ceci décrit un tour du processus d’apprentissage fédéré et nous répétons cette opération pour plusieurs tours.
Notre exemple consiste en un serveur et deux clients. Commençons par configurer server.py
. Le serveur doit importer le paquet Flower flwr
. Ensuite, nous utilisons la fonction start_server
pour démarrer un serveur et lui demander d’effectuer trois cycles d’apprentissage fédéré.
import flwr as fl
if __name__ == "__main__":
fl.server.start_server(server_address="0.0.0.0:8080", config=fl.server.ServerConfig(num_rounds=3))
Nous pouvons déjà démarrer le serveur :
python3 server.py
Enfin, nous allons définir notre logique client dans client.py
et nous appuyer sur la formation centralisée définie précédemment dans cifar.py
. Notre client doit importer flwr
, mais aussi torch
pour mettre à jour les paramètres de notre modèle PyTorch :
from collections import OrderedDict
from typing import Dict, List, Tuple
import numpy as np
import torch
import cifar
import flwr as fl
DEVICE: str = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Implementing a Flower client basically means implementing a subclass of either flwr.client.Client
or flwr.client.NumPyClient
.
Our implementation will be based on flwr.client.NumPyClient
and we’ll call it CifarClient
.
NumPyClient
is slightly easier to implement than Client
if you use a framework with good NumPy interoperability (like PyTorch or TensorFlow/Keras) because it avoids some of the boilerplate that would otherwise be necessary.
CifarClient
needs to implement four methods, two methods for getting/setting model parameters, one method for training the model, and one method for testing the model:
set_parameters
règle les paramètres du modèle local reçus du serveur
boucle sur la liste des paramètres du modèle reçus sous forme de NumPy
ndarray
’s (pensez à la liste des couches du réseau neuronal)
get_parameters
récupère les paramètres du modèle et les renvoie sous forme de liste de
ndarray
NumPy (ce qui correspond à ce queflwr.client.NumPyClient
attend)
fit
mettre à jour les paramètres du modèle local avec les paramètres reçus du serveur
entraîne le modèle sur l’ensemble d’apprentissage local
récupère les poids du modèle local mis à jour et les renvoie au serveur
évaluer
mettre à jour les paramètres du modèle local avec les paramètres reçus du serveur
évaluer le modèle mis à jour sur l’ensemble de test local
renvoie la perte locale et la précision au serveur
Les deux méthodes NumPyClient
fit
et evaluate
utilisent les fonctions train()
et test()
définies précédemment dans cifar.py
. Ce que nous faisons vraiment ici, c’est que nous indiquons à Flower, par le biais de notre sous-classe NumPyClient
, laquelle de nos fonctions déjà définies doit être appelée pour l’entraînement et l’évaluation. Nous avons inclus des annotations de type pour te donner une meilleure compréhension des types de données qui sont transmis.
class CifarClient(fl.client.NumPyClient):
"""Flower client implementing CIFAR-10 image classification using
PyTorch."""
def __init__(
self,
model: cifar.Net,
trainloader: torch.utils.data.DataLoader,
testloader: torch.utils.data.DataLoader,
num_examples: Dict,
) -> None:
self.model = model
self.trainloader = trainloader
self.testloader = testloader
self.num_examples = num_examples
def get_parameters(self, config) -> List[np.ndarray]:
# Return model parameters as a list of NumPy ndarrays
return [val.cpu().numpy() for _, val in self.model.state_dict().items()]
def set_parameters(self, parameters: List[np.ndarray]) -> None:
# Set model parameters from a list of NumPy ndarrays
params_dict = zip(self.model.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
self.model.load_state_dict(state_dict, strict=True)
def fit(
self, parameters: List[np.ndarray], config: Dict[str, str]
) -> Tuple[List[np.ndarray], int, Dict]:
# Set model parameters, train model, return updated model parameters
self.set_parameters(parameters)
cifar.train(self.model, self.trainloader, epochs=1, device=DEVICE)
return self.get_parameters(config={}), self.num_examples["trainset"], {}
def evaluate(
self, parameters: List[np.ndarray], config: Dict[str, str]
) -> Tuple[float, int, Dict]:
# Set model parameters, evaluate model on local test dataset, return result
self.set_parameters(parameters)
loss, accuracy = cifar.test(self.model, self.testloader, device=DEVICE)
return float(loss), self.num_examples["testset"], {"accuracy": float(accuracy)}
All that’s left to do it to define a function that loads both model and data, creates a CifarClient
, and starts this client.
You load your data and model by using cifar.py
. Start CifarClient
with the function fl.client.start_client()
by pointing it at the same IP address we used in server.py
:
def main() -> None:
"""Load data, start CifarClient."""
# Load model and data
model = cifar.Net()
model.to(DEVICE)
trainloader, testloader, num_examples = cifar.load_data()
# Start client
client = CifarClient(model, trainloader, testloader, num_examples)
fl.client.start_client(server_address="0.0.0.0:8080", client.to_client())
if __name__ == "__main__":
main()
Tu peux maintenant ouvrir deux autres fenêtres de terminal et exécuter les commandes suivantes
python3 client.py
dans chaque fenêtre (assure-toi que le serveur fonctionne avant de le faire) et tu verras ton projet PyTorch (auparavant centralisé) exécuter l’apprentissage fédéré sur deux clients. Félicitations !
Prochaines étapes#
The full source code for this example: PyTorch: From Centralized To Federated (Code). Our example is, of course, somewhat over-simplified because both clients load the exact same dataset, which isn’t realistic. You’re now prepared to explore this topic further. How about using different subsets of CIFAR-10 on each client? How about adding more clients?