Exemple : JAX - Exécuter JAX Federated#
This tutorial will show you how to use Flower to build a federated version of an existing JAX workload. We are using JAX to train a linear regression model on a scikit-learn dataset. We will structure the example similar to our PyTorch - From Centralized To Federated walkthrough. First, we build a centralized training approach based on the Linear Regression with JAX tutorial`. Then, we build upon the centralized training code to run the training in a federated fashion.
Avant de commencer à construire notre exemple JAX, nous devons installer les paquets jax
, jaxlib
, scikit-learn
, et flwr
:
$ pip install jax jaxlib scikit-learn flwr
Régression linéaire avec JAX#
Nous commençons par une brève description du code d’entraînement centralisé basé sur un modèle Régression linéaire
. Si tu veux une explication plus approfondie de ce qui se passe, jette un coup d’œil à la documentation officielle JAX.
Créons un nouveau fichier appelé jax_training.py
avec tous les composants nécessaires pour un apprentissage traditionnel (centralisé) de la régression linéaire. Tout d’abord, les paquets JAX jax
et jaxlib
doivent être importés. En outre, nous devons importer sklearn
puisque nous utilisons make_regression
pour le jeu de données et train_test_split
pour diviser le jeu de données en un jeu d’entraînement et un jeu de test. Tu peux voir que nous n’avons pas encore importé le paquet flwr
pour l’apprentissage fédéré, ce qui sera fait plus tard.
from typing import Dict, List, Tuple, Callable
import jax
import jax.numpy as jnp
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
key = jax.random.PRNGKey(0)
La fonction load_data()
charge les ensembles d’entraînement et de test mentionnés.
def load_data() -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
# create our dataset and start with similar datasets for different clients
X, y = make_regression(n_features=3, random_state=0)
X, X_test, y, y_test = train_test_split(X, y)
return X, y, X_test, y_test
L’architecture du modèle (un modèle Régression linéaire
très simple) est définie dans load_model()
.
def load_model(model_shape) -> Dict:
# model weights
params = {
'b' : jax.random.uniform(key),
'w' : jax.random.uniform(key, model_shape)
}
return params
Nous devons maintenant définir l’entraînement (fonction train()
), qui boucle sur l’ensemble d’entraînement et mesure la perte (fonction loss_fn()
) pour chaque lot d’exemples d’entraînement. La fonction de perte est séparée puisque JAX prend des dérivés avec une fonction grad()
(définie dans la fonction main()
et appelée dans train()
).
def loss_fn(params, X, y) -> Callable:
err = jnp.dot(X, params['w']) + params['b'] - y
return jnp.mean(jnp.square(err)) # mse
def train(params, grad_fn, X, y) -> Tuple[np.array, float, int]:
num_examples = X.shape[0]
for epochs in range(10):
grads = grad_fn(params, X, y)
params = jax.tree_multimap(lambda p, g: p - 0.05 * g, params, grads)
loss = loss_fn(params,X, y)
# if epochs % 10 == 9:
# print(f'For Epoch {epochs} loss {loss}')
return params, loss, num_examples
L’évaluation du modèle est définie dans la fonction evaluation()
. La fonction prend tous les exemples de test et mesure la perte du modèle de régression linéaire.
def evaluation(params, grad_fn, X_test, y_test) -> Tuple[float, int]:
num_examples = X_test.shape[0]
err_test = loss_fn(params, X_test, y_test)
loss_test = jnp.mean(jnp.square(err_test))
# print(f'Test loss {loss_test}')
return loss_test, num_examples
Après avoir défini le chargement des données, l’architecture du modèle, l’entraînement et l’évaluation, nous pouvons tout assembler et entraîner notre modèle à l’aide de JAX. Comme nous l’avons déjà mentionné, la fonction jax.grad()
est définie dans main()
et transmise à train()
.
def main():
X, y, X_test, y_test = load_data()
model_shape = X.shape[1:]
grad_fn = jax.grad(loss_fn)
print("Model Shape", model_shape)
params = load_model(model_shape)
params, loss, num_examples = train(params, grad_fn, X, y)
evaluation(params, grad_fn, X_test, y_test)
if __name__ == "__main__":
main()
Tu peux maintenant exécuter ta charge de travail (centralisée) de régression linéaire JAX :
python3 jax_training.py
Jusqu’à présent, tout cela devrait te sembler assez familier si tu as déjà utilisé JAX. Passons à l’étape suivante et utilisons ce que nous avons construit pour créer un simple système d’apprentissage fédéré composé d’un serveur et de deux clients.
JAX rencontre Flower#
Le concept de fédération d’une charge de travail existante est toujours le même et facile à comprendre. Nous devons démarrer un serveur, puis utiliser le code dans jax_training.py
pour les clients qui sont connectés au serveur.Le serveur envoie les paramètres du modèle aux clients.Les clients exécutent la formation et mettent à jour les paramètres.Les paramètres mis à jour sont renvoyés au serveur, qui fait la moyenne de toutes les mises à jour de paramètres reçues.Ceci décrit un tour du processus d’apprentissage fédéré, et nous répétons cette opération pour plusieurs tours.
Notre exemple consiste en un serveur et deux clients. Commençons par configurer server.py
. Le serveur doit importer le paquet Flower flwr
. Ensuite, nous utilisons la fonction start_server
pour démarrer un serveur et lui demander d’effectuer trois cycles d’apprentissage fédéré.
import flwr as fl
if __name__ == "__main__":
fl.server.start_server(server_address="0.0.0.0:8080", config=fl.server.ServerConfig(num_rounds=3))
Nous pouvons déjà démarrer le serveur :
python3 server.py
Enfin, nous allons définir la logique de notre client dans client.py
et nous appuyer sur la formation JAX définie précédemment dans jax_training.py
. Notre client doit importer flwr
, mais aussi jax
et jaxlib
pour mettre à jour les paramètres de notre modèle JAX :
from typing import Dict, List, Callable, Tuple
import flwr as fl
import numpy as np
import jax
import jax.numpy as jnp
import jax_training
L’implémentation d’un client Flower signifie essentiellement l’implémentation d’une sous-classe de flwr.client.Client
ou flwr.client.NumPyClient
. Notre implémentation sera basée sur flwr.client.NumPyClient
et nous l’appellerons FlowerClient
. NumPyClient
est légèrement plus facile à implémenter que Client
si vous utilisez un framework avec une bonne interopérabilité NumPy (comme JAX) parce qu’il évite une partie du boilerplate qui serait autrement nécessaire. FlowerClient
doit implémenter quatre méthodes, deux méthodes pour obtenir/régler les paramètres du modèle, une méthode pour former le modèle, et une méthode pour tester le modèle :
set_parameters (optional)
règle les paramètres du modèle local reçus du serveur
transforme les paramètres en NumPy
ndarray
’sboucle sur la liste des paramètres du modèle reçus sous forme de NumPy
ndarray
’s (pensez à la liste des couches du réseau neuronal)
get_parameters
récupère les paramètres du modèle et les renvoie sous forme de liste de
ndarray
NumPy (ce qui correspond à ce queflwr.client.NumPyClient
attend)
fit
mettre à jour les paramètres du modèle local avec les paramètres reçus du serveur
entraîne le modèle sur l’ensemble d’apprentissage local
récupère les paramètres du modèle local mis à jour et les renvoie au serveur
évaluer
mettre à jour les paramètres du modèle local avec les paramètres reçus du serveur
évaluer le modèle mis à jour sur l’ensemble de test local
renvoie la perte locale au serveur
La partie la plus difficile consiste à transformer les paramètres du modèle JAX de DeviceArray
en NumPy ndarray
pour les rendre compatibles avec NumPyClient.
Les deux méthodes NumPyClient
fit
et evaluate
utilisent les fonctions train()
et evaluate()
définies précédemment dans jax_training.py
. Ce que nous faisons vraiment ici, c’est que nous indiquons à Flower, par le biais de notre sous-classe NumPyClient
, laquelle de nos fonctions déjà définies doit être appelée pour l’entraînement et l’évaluation. Nous avons inclus des annotations de type pour te donner une meilleure compréhension des types de données qui sont transmis.
class FlowerClient(fl.client.NumPyClient):
"""Flower client implementing using linear regression and JAX."""
def __init__(
self,
params: Dict,
grad_fn: Callable,
train_x: List[np.ndarray],
train_y: List[np.ndarray],
test_x: List[np.ndarray],
test_y: List[np.ndarray],
) -> None:
self.params= params
self.grad_fn = grad_fn
self.train_x = train_x
self.train_y = train_y
self.test_x = test_x
self.test_y = test_y
def get_parameters(self, config) -> Dict:
# Return model parameters as a list of NumPy ndarrays
parameter_value = []
for _, val in self.params.items():
parameter_value.append(np.array(val))
return parameter_value
def set_parameters(self, parameters: List[np.ndarray]) -> Dict:
# Collect model parameters and update the parameters of the local model
value=jnp.ndarray
params_item = list(zip(self.params.keys(),parameters))
for item in params_item:
key = item[0]
value = item[1]
self.params[key] = value
return self.params
def fit(
self, parameters: List[np.ndarray], config: Dict
) -> Tuple[List[np.ndarray], int, Dict]:
# Set model parameters, train model, return updated model parameters
print("Start local training")
self.params = self.set_parameters(parameters)
self.params, loss, num_examples = jax_training.train(self.params, self.grad_fn, self.train_x, self.train_y)
results = {"loss": float(loss)}
print("Training results", results)
return self.get_parameters(config={}), num_examples, results
def evaluate(
self, parameters: List[np.ndarray], config: Dict
) -> Tuple[float, int, Dict]:
# Set model parameters, evaluate the model on a local test dataset, return result
print("Start evaluation")
self.params = self.set_parameters(parameters)
loss, num_examples = jax_training.evaluation(self.params,self.grad_fn, self.test_x, self.test_y)
print("Evaluation accuracy & loss", loss)
return (
float(loss),
num_examples,
{"loss": float(loss)},
)
Après avoir défini le processus de fédération, nous pouvons l’exécuter.
def main() -> None:
"""Load data, start MNISTClient."""
# Load data
train_x, train_y, test_x, test_y = jax_training.load_data()
grad_fn = jax.grad(jax_training.loss_fn)
# Load model (from centralized training) and initialize parameters
model_shape = train_x.shape[1:]
params = jax_training.load_model(model_shape)
# Start Flower client
client = FlowerClient(params, grad_fn, train_x, train_y, test_x, test_y)
fl.client.start_client(server_address="0.0.0.0:8080", client.to_client())
if __name__ == "__main__":
main()
Tu peux maintenant ouvrir deux autres fenêtres de terminal et exécuter les commandes suivantes
python3 client.py
dans chaque fenêtre (assure-toi que le serveur est toujours en cours d’exécution avant de le faire) et tu verras que ton projet JAX exécute l’apprentissage fédéré sur deux clients. Félicitations !
Prochaines étapes#
The source code of this example was improved over time and can be found here: Quickstart JAX. Our example is somewhat over-simplified because both clients load the same dataset.
Tu es maintenant prêt à approfondir ce sujet. Pourquoi ne pas utiliser un modèle plus sophistiqué ou un ensemble de données différent ? Pourquoi ne pas ajouter d’autres clients ?