快速入门 JAX#
本教程将向您展示如何使用 Flower 构建现有 JAX 的联邦学习版本。我们将使用 JAX 在 scikit-learn 数据集上训练线性回归模型。我们将采用与 PyTorch - 从集中式到联邦式 教程中类似的示例结构。首先,我们根据 JAX 的线性回归 教程构建集中式训练方法。然后,我们在集中式训练代码的基础上以联邦方式运行训练。
在开始构建 JAX 示例之前,我们需要安装软件包 jax
、jaxlib
、scikit-learn
和 flwr
:
$ pip install jax jaxlib scikit-learn flwr
使用 JAX 进行线性回归#
首先,我们将简要介绍基于 Linear Regression
模型的集中式训练代码。如果您想获得更深入的解释,请参阅官方的 JAX 文档。
让我们创建一个名为 jax_training.py
的新文件,其中包含传统(集中式)线性回归训练所需的所有组件。首先,需要导入 JAX 包 jax
和 jaxlib
。此外,我们还需要导入 sklearn
,因为我们使用 make_regression
创建数据集,并使用 train_test_split
将数据集拆分成训练集和测试集。您可以看到,我们还没有导入用于联邦学习的 flwr
软件包,这将在稍后完成。
from typing import Dict, List, Tuple, Callable
import jax
import jax.numpy as jnp
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
key = jax.random.PRNGKey(0)
load_data()
函数会加载上述训练集和测试集。
def load_data() -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
# create our dataset and start with similar datasets for different clients
X, y = make_regression(n_features=3, random_state=0)
X, X_test, y, y_test = train_test_split(X, y)
return X, y, X_test, y_test
模型结构(一个非常简单的 Linear Regression
线性回归模型)在 load_model()
中定义。
def load_model(model_shape) -> Dict:
# model weights
params = {
'b' : jax.random.uniform(key),
'w' : jax.random.uniform(key, model_shape)
}
return params
现在,我们需要定义训练函数( train()
)。它循环遍历训练集,并计算每批训练数据的损失值(函数 loss_fn()
)。由于 JAX 使用 grad()
函数提取导数(在 main()
函数中定义,并在 train()
中调用),因此损失函数是独立的。
def loss_fn(params, X, y) -> Callable:
err = jnp.dot(X, params['w']) + params['b'] - y
return jnp.mean(jnp.square(err)) # mse
def train(params, grad_fn, X, y) -> Tuple[np.array, float, int]:
num_examples = X.shape[0]
for epochs in range(10):
grads = grad_fn(params, X, y)
params = jax.tree_multimap(lambda p, g: p - 0.05 * g, params, grads)
loss = loss_fn(params,X, y)
# if epochs % 10 == 9:
# print(f'For Epoch {epochs} loss {loss}')
return params, loss, num_examples
模型的评估在函数 evaluation()
中定义。该函数获取所有测试数据,并计算线性回归模型的损失值。
def evaluation(params, grad_fn, X_test, y_test) -> Tuple[float, int]:
num_examples = X_test.shape[0]
err_test = loss_fn(params, X_test, y_test)
loss_test = jnp.mean(jnp.square(err_test))
# print(f'Test loss {loss_test}')
return loss_test, num_examples
在定义了数据加载、模型架构、训练和评估之后,我们就可以把这些放在一起,使用 JAX 训练我们的模型了。如前所述,jax.grad()
函数在 main()
中定义,并传递给 train()
。
def main():
X, y, X_test, y_test = load_data()
model_shape = X.shape[1:]
grad_fn = jax.grad(loss_fn)
print("Model Shape", model_shape)
params = load_model(model_shape)
params, loss, num_examples = train(params, grad_fn, X, y)
evaluation(params, grad_fn, X_test, y_test)
if __name__ == "__main__":
main()
现在您可以运行(集中式)JAX 线性回归工作了:
python3 jax_training.py
到目前为止,如果你以前使用过 JAX,就会对这一切感到很熟悉。下一步,让我们利用已构建的代码创建一个简单的联邦学习系统(一个服务器和两个客户端)。
JAX 结合 Flower#
把现有工作联邦化的概念始终是相同的,也很容易理解。我们要启动一个*服务器*,然后对连接到*服务器*的*客户端*运行 :code:`jax_training.py`中的代码。服务器*向客户端发送模型参数,*客户端*运行训练并更新参数。更新后的参数被发回*服务器,然后服务器对所有收到的参数进行平均聚合。以上的描述构成了一轮联邦学习,我们将重复进行多轮学习。
我们的示例包括一个*服务器*和两个*客户端*。让我们先设置 server.py
。*服务器*需要导入 Flower 软件包 flwr
。接下来,我们使用 start_server
函数启动服务器,并让它执行三轮联邦学习。
import flwr as fl
if __name__ == "__main__":
fl.server.start_server(server_address="0.0.0.0:8080", config=fl.server.ServerConfig(num_rounds=3))
我们已经可以启动*服务器*了:
python3 server.py
最后,我们将在 client.py
中定义我们的 client 逻辑,并以之前在 jax_training.py
中定义的 JAX 训练为基础。我们的 client 需要导入 flwr
,还需要导入 jax
和 jaxlib
以更新 JAX 模型的参数:
from typing import Dict, List, Callable, Tuple
import flwr as fl
import numpy as np
import jax
import jax.numpy as jnp
import jax_training
实现一个 Flower *client*基本上意味着去实现一个 flwr.client.Client
或 flwr.client.NumPyClient
的子类。我们的代码实现将基于 flwr.client.NumPyClient
,并将其命名为 FlowerClient
。如果使用具有良好 NumPy 互操作性的框架(如 JAX),NumPyClient
比 Client`更容易实现,因为它避免了一些不必要的操作。:code:`FlowerClient
需要实现四个方法,两个用于获取/设置模型参数,一个用于训练模型,一个用于测试模型:
set_parameters (可选)
在本地模型上设置从服务器接收的模型参数
将参数转换为 NumPy :code:`ndarray`格式
循环遍历以 NumPy
ndarray
形式接收的模型参数列表(可以看作神经网络的列表)
fit
用从服务器接收到的参数更新本地模型的参数
在本地训练集上训练模型
获取更新后的本地模型参数并返回服务器
evaluate
用从服务器接收到的参数更新本地模型的参数
在本地测试集上评估更新后的模型
向服务器返回本地损失值
具有挑战性的部分是将 JAX 模型参数从 DeviceArray
转换为 NumPy ndarray
,使其与 NumPyClient 兼容。
这两个 NumPyClient
方法 fit
和 evaluate
使用了之前在 jax_training.py
中定义的函数 train()
和 evaluate()
。因此,我们在这里要做的就是通过 NumPyClient
子类告知 Flower 在训练和评估时要调用哪些已定义的函数。我们加入了类型注解,以便让您更好地理解传递的数据类型。
class FlowerClient(fl.client.NumPyClient):
"""Flower client implementing using linear regression and JAX."""
def __init__(
self,
params: Dict,
grad_fn: Callable,
train_x: List[np.ndarray],
train_y: List[np.ndarray],
test_x: List[np.ndarray],
test_y: List[np.ndarray],
) -> None:
self.params= params
self.grad_fn = grad_fn
self.train_x = train_x
self.train_y = train_y
self.test_x = test_x
self.test_y = test_y
def get_parameters(self, config) -> Dict:
# Return model parameters as a list of NumPy ndarrays
parameter_value = []
for _, val in self.params.items():
parameter_value.append(np.array(val))
return parameter_value
def set_parameters(self, parameters: List[np.ndarray]) -> Dict:
# Collect model parameters and update the parameters of the local model
value=jnp.ndarray
params_item = list(zip(self.params.keys(),parameters))
for item in params_item:
key = item[0]
value = item[1]
self.params[key] = value
return self.params
def fit(
self, parameters: List[np.ndarray], config: Dict
) -> Tuple[List[np.ndarray], int, Dict]:
# Set model parameters, train model, return updated model parameters
print("Start local training")
self.params = self.set_parameters(parameters)
self.params, loss, num_examples = jax_training.train(self.params, self.grad_fn, self.train_x, self.train_y)
results = {"loss": float(loss)}
print("Training results", results)
return self.get_parameters(config={}), num_examples, results
def evaluate(
self, parameters: List[np.ndarray], config: Dict
) -> Tuple[float, int, Dict]:
# Set model parameters, evaluate the model on a local test dataset, return result
print("Start evaluation")
self.params = self.set_parameters(parameters)
loss, num_examples = jax_training.evaluation(self.params,self.grad_fn, self.test_x, self.test_y)
print("Evaluation accuracy & loss", loss)
return (
float(loss),
num_examples,
{"loss": float(loss)},
)
定义了联邦进程后,我们就可以运行它了。
def main() -> None:
"""Load data, start MNISTClient."""
# Load data
train_x, train_y, test_x, test_y = jax_training.load_data()
grad_fn = jax.grad(jax_training.loss_fn)
# Load model (from centralized training) and initialize parameters
model_shape = train_x.shape[1:]
params = jax_training.load_model(model_shape)
# Start Flower client
client = FlowerClient(params, grad_fn, train_x, train_y, test_x, test_y)
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
if __name__ == "__main__":
main()
就是这样,现在你可以打开另外两个终端窗口,然后运行
python3 client.py
确保服务器仍在运行,然后在每个客户端窗口就能看到你的 JAX 项目在两个客户端上运行联邦学习了。祝贺!
下一步工作#
此示例的源代码经过长期改进,可在此处找到: Quickstart JAX。我们的示例有些过于简单,因为两个客户端都加载了相同的数据集。
现在,您已准备好进行更深一步探索了。例如使用更复杂的模型或使用不同的数据集会如何?增加更多客户端会如何?