Quickstart PyTorch
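In this tutorial we will learn how to train a convolutional neural network on CIFAR10 using Flower and PyTorch.
First of all, we recommend creating a virtual environment and running everything inside it.
Our example consists of one *server* and two *clients*, all having the same model.
*Clients* are responsible for updating the model parameters on their local datasets. These parameters are then sent to the *server*, which aggregates them to produce a better model. Finally, the *server* sends the improved model back to each *client*. One complete cycle of model parameter updates is called a *round*.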
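To make the aggregation step concrete, here is a minimal sketch of one way a server could combine client parameters: an average weighted by each client's number of training examples, in the style of federated averaging. The tutorial does not state which aggregation rule is applied, so the function name `weighted_average_parameters` and the toy numbers are illustrative assumptions, and plain Python lists stand in for model tensors.

```python
from typing import List, Tuple


def weighted_average_parameters(
    results: List[Tuple[List[float], int]],
) -> List[float]:
    """Average flat parameter lists, weighting each client by its example count."""
    total_examples = sum(num_examples for _, num_examples in results)
    num_params = len(results[0][0])
    aggregated = [0.0] * num_params
    for params, num_examples in results:
        weight = num_examples / total_examples
        for i, value in enumerate(params):
            aggregated[i] += weight * value
    return aggregated


# Two hypothetical clients: (parameters after local training, local example count).
client_results = [([1.0, 2.0], 100), ([3.0, 4.0], 300)]
global_params = weighted_average_parameters(client_results)
print(global_params)
```

The client with more examples pulls the result further toward its own parameters, which is the usual motivation for weighting by dataset size.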
Now that we have a rough idea of what is going on, let's get started. We first need to install Flower. You can do this by running:
$ pip install flwr
Since we want to use PyTorch to solve a computer vision task, let's go ahead and install PyTorch and the torchvision library:
$ pip install torch torchvision
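Flower Client
Now that we have all our dependencies installed, let's run a simple distributed training with two clients and one server. Our training procedure and network architecture are based on PyTorch's `Deep Learning with PyTorch <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_ tutorial.
In a file called client.py, import Flower and the PyTorch-related packages: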
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import flwr as fl
In addition, we define the device allocation in PyTorch with:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
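We use PyTorch to load CIFAR10, a popular colored image classification dataset for machine learning. The torchvision :code:`CIFAR10` class downloads the training and test data, the transform normalizes them, and the PyTorch :code:`DataLoader()` serves them in batches.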
def load_data():
    """Load CIFAR-10 (training and test set)."""
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10(".", train=True, download=True, transform=transform)
    testset = CIFAR10(".", train=False, download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
    testloader = DataLoader(testset, batch_size=32)
    num_examples = {"trainset": len(trainset), "testset": len(testset)}
    return trainloader, testloader, num_examples
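As a side note on the transform above: ToTensor() scales pixel values to [0, 1], and Normalize with a mean and standard deviation of 0.5 per channel then computes (x - 0.5) / 0.5, which maps that range to [-1, 1]. The sketch below reproduces the arithmetic on plain floats; the helper name `normalize` is illustrative and not part of torchvision.

```python
def normalize(value: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Apply the per-channel normalization formula: (value - mean) / std."""
    return (value - mean) / std


# Pixel intensities after ToTensor() lie in [0, 1].
samples = [0.0, 0.25, 0.5, 1.0]
normalized = [normalize(v) for v in samples]
print(normalized)
```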
Define the loss and optimizer with PyTorch. Training is done by looping over the dataset, measuring the corresponding loss, and optimizing it.
def train(net, trainloader, epochs):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    for _ in range(epochs):
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()
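For readers curious what optimizer.step() does with momentum=0.9, the following sketch applies a momentum update to a single scalar parameter: a velocity buffer accumulates gradients, and the parameter moves against that buffer scaled by the learning rate. It is intended to mirror the plain (non-Nesterov, zero-dampening) form of SGD with momentum and should be read as a sketch rather than a reference implementation; the function name and the constant toy gradient are assumptions for illustration.

```python
def sgd_momentum_step(param, grad, velocity, lr=0.001, momentum=0.9):
    """Return the updated (param, velocity) after one SGD-with-momentum step."""
    velocity = momentum * velocity + grad  # accumulate gradient history
    param = param - lr * velocity          # step against the accumulated direction
    return param, velocity


param, velocity = 1.0, 0.0
for grad in [0.5, 0.5, 0.5]:  # hypothetical constant gradient
    param, velocity = sgd_momentum_step(param, grad, velocity)
print(param, velocity)
```

With a constant gradient the velocity keeps growing from step to step, so later updates are larger than the first one; that is the acceleration effect momentum is meant to provide.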
Then define the validation of the machine learning network. We loop over the test set and measure the loss and accuracy on it.
def test(net, testloader):
    """Validate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(DEVICE), data[1].to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    return loss, accuracy
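Note that test() adds up one loss value per batch, so the returned loss grows with the number of batches; this helps explain the large loss values in the server log later on. If a per-batch average is easier to interpret, one option is to divide the sum by the number of batches. The sketch below shows the bookkeeping on plain Python numbers; `summarize` and the sample values are hypothetical.

```python
from typing import List, Tuple


def summarize(
    batch_losses: List[float],
    batch_correct: List[int],
    batch_sizes: List[int],
) -> Tuple[float, float, float]:
    """Return (summed loss, mean loss per batch, accuracy)."""
    summed_loss = sum(batch_losses)
    mean_loss = summed_loss / len(batch_losses)
    accuracy = sum(batch_correct) / sum(batch_sizes)
    return summed_loss, mean_loss, accuracy


# Three hypothetical batches of 32 images each.
summed_loss, mean_loss, accuracy = summarize([2.0, 1.5, 1.0], [16, 20, 24], [32, 32, 32])
print(summed_loss, mean_loss, accuracy)
```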
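After defining the training and testing of a PyTorch machine learning model, we use these functions in the Flower client.
The Flower client will use a simple CNN adapted from "PyTorch: A 60 Minute Blitz":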
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Load model and data
net = Net().to(DEVICE)
trainloader, testloader, num_examples = load_data()
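The 16 * 5 * 5 input size of fc1 follows from how each layer shrinks a 32x32 CIFAR10 image: an unpadded convolution with kernel size k reduces each side by k - 1, and a 2x2 max-pool with stride 2 halves it. The sketch below walks through that arithmetic; the helper name `conv_out` is illustrative.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32                     # CIFAR10 images are 32x32
size = conv_out(size, 5)      # conv1: 32 -> 28
size = conv_out(size, 2, 2)   # pool:  28 -> 14
size = conv_out(size, 5)      # conv2: 14 -> 10
size = conv_out(size, 2, 2)   # pool:  10 -> 5
flattened = 16 * size * size  # conv2 produces 16 channels
print(size, flattened)
```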
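After loading the dataset with load_data(), we define the Flower interface.
The Flower server interacts with clients through an interface called Client. When the server selects a particular client for training, it sends training instructions over the network. The client receives those instructions and calls one of the Client methods to run your code (i.e., to train the neural network we defined earlier).
Flower provides a convenience class called NumPyClient, which makes it easier to implement the Client interface when your workload uses PyTorch. Implementing NumPyClient usually means defining the following methods (set_parameters is optional though):
get_parameters
    return the model parameters as a list of NumPy ndarrays
set_parameters (optional)
    update the local model parameters with the parameters received from the server
fit
    set the local model parameters
    train the local model
    return the updated local model parameters
evaluate
    test the local model
These methods can be implemented as follows: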
class CifarClient(fl.client.NumPyClient):
    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in net.state_dict().items()]

    def set_parameters(self, parameters):
        params_dict = zip(net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        train(net, trainloader, epochs=1)
        return self.get_parameters(config={}), num_examples["trainset"], {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        loss, accuracy = test(net, testloader)
        return float(loss), num_examples["testset"], {"accuracy": float(accuracy)}
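The get_parameters and set_parameters pair relies on state_dict() keeping a stable key order: parameters leave the client as a plain list and are matched back to their names by position. The following sketch shows that round trip with plain lists in place of tensors; `to_list` and `from_list` are hypothetical stand-ins for the two methods, and the key names are only examples.

```python
from collections import OrderedDict


def to_list(state):
    """Drop the names and keep the values in insertion order."""
    return [value for _, value in state.items()]


def from_list(keys, parameters):
    """Re-attach names to values by position."""
    return OrderedDict(zip(keys, parameters))


state = OrderedDict([("conv1.weight", [0.1, 0.2]), ("conv1.bias", [0.0])])
params = to_list(state)
restored = from_list(state.keys(), params)
print(restored == state)
```

Because the match is purely positional, server and clients need to build the model the same way so that the keys line up.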
We can now create an instance of our class CifarClient and add one line to actually run this client:
fl.client.start_client(server_address="[::]:8080", client=CifarClient().to_client())
That's it for the client. We only have to implement Client or NumPyClient and call fl.client.start_client(). If you implement a client of type NumPyClient, you first need to call its to_client() method. The string "[::]:8080" tells the client which server to connect to. In our case we can run the server and the client on the same machine, therefore we use "[::]:8080". If we run a truly federated workload with the server and clients running on different machines, all that needs to change is the server_address we point the client at.
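If you expect to switch between a single-machine test setup and separate machines, one option is to read the server address from the command line instead of hard-coding it. The sketch below uses argparse from the standard library; the `--server-address` flag and the `parse_server_address` helper are hypothetical additions to client.py, not something Flower defines, and the second address is a placeholder.

```python
import argparse


def parse_server_address(argv=None) -> str:
    """Return the server address given on the command line, or the local default."""
    parser = argparse.ArgumentParser(description="Flower client (sketch)")
    parser.add_argument(
        "--server-address",
        default="[::]:8080",  # same default as in the tutorial
        help="host:port of the Flower server",
    )
    args = parser.parse_args(argv)
    return args.server_address


# In client.py the result would be passed as server_address to fl.client.start_client().
print(parse_server_address([]))
print(parse_server_address(["--server-address", "192.0.2.10:8080"]))
```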
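Flower Server
For simple workloads we can start a Flower server and leave all the configuration possibilities at their default values. In a file named server.py, import Flower and start the server: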
import flwr as fl
fl.server.start_server(config=fl.server.ServerConfig(num_rounds=3))
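Train the model, federated!
With both client and server ready, we can now run everything and see federated learning in action. FL systems usually have a server and multiple clients. We therefore have to start the server first: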
$ python server.py
Once the server is running we can start the clients in different terminals. Open a new terminal and start the first client:
$ python client.py
Open another terminal and start the second client:
$ python client.py
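Each client will have its own dataset. You should now see how the training progresses in the very first terminal (the one that started the server):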
INFO flower 2021-02-25 14:00:27,227 | app.py:76 | Flower server running (insecure, 3 rounds)
INFO flower 2021-02-25 14:00:27,227 | server.py:72 | Getting initial parameters
INFO flower 2021-02-25 14:01:15,881 | server.py:74 | Evaluating initial parameters
INFO flower 2021-02-25 14:01:15,881 | server.py:87 | [TIME] FL starting
DEBUG flower 2021-02-25 14:01:41,310 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-02-25 14:02:00,256 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-02-25 14:02:00,262 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:02:03,047 | server.py:149 | evaluate received 2 results and 0 failures
DEBUG flower 2021-02-25 14:02:03,049 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-02-25 14:02:23,908 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-02-25 14:02:23,915 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:02:27,120 | server.py:149 | evaluate received 2 results and 0 failures
DEBUG flower 2021-02-25 14:02:27,122 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-02-25 14:03:04,660 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-02-25 14:03:04,671 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:03:09,273 | server.py:149 | evaluate received 2 results and 0 failures
INFO flower 2021-02-25 14:03:09,273 | server.py:122 | [TIME] FL finished in 113.39180790000046
INFO flower 2021-02-25 14:03:09,274 | app.py:109 | app_fit: losses_distributed [(1, 650.9747924804688), (2, 526.2535400390625), (3, 473.76959228515625)]
INFO flower 2021-02-25 14:03:09,274 | app.py:110 | app_fit: accuracies_distributed []
INFO flower 2021-02-25 14:03:09,274 | app.py:111 | app_fit: losses_centralized []
INFO flower 2021-02-25 14:03:09,274 | app.py:112 | app_fit: accuracies_centralized []
DEBUG flower 2021-02-25 14:03:09,276 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:03:11,852 | server.py:149 | evaluate received 2 results and 0 failures
INFO flower 2021-02-25 14:03:11,852 | app.py:121 | app_evaluate: federated loss: 473.76959228515625
INFO flower 2021-02-25 14:03:11,852 | app.py:122 | app_evaluate: results [('ipv6:[::1]:36602', EvaluateRes(loss=351.4906005859375, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.6067})), ('ipv6:[::1]:36604', EvaluateRes(loss=353.92742919921875, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.6005}))]
INFO flower 2021-02-25 14:03:27,514 | app.py:127 | app_evaluate: failures []
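In the log above, each client reports its own accuracy in the metrics dictionary, while the aggregated accuracy lists stay empty. A common way to summarize such per-client metrics is an average weighted by num_examples. The sketch below applies that to the two accuracy values from the app_evaluate line; `weighted_accuracy` is a hypothetical helper, and wiring it into the server is left out here.

```python
from typing import Dict, List, Tuple


def weighted_accuracy(results: List[Tuple[int, Dict[str, float]]]) -> float:
    """Average the per-client accuracy, weighted by each client's example count."""
    total_examples = sum(num_examples for num_examples, _ in results)
    weighted_sum = sum(num_examples * metrics["accuracy"] for num_examples, metrics in results)
    return weighted_sum / total_examples


# (num_examples, metrics) pairs taken from the app_evaluate line in the log.
results = [(10000, {"accuracy": 0.6067}), (10000, {"accuracy": 0.6005})]
overall_accuracy = weighted_accuracy(results)
print(round(overall_accuracy, 4))
```

Since both clients evaluate on the same number of examples here, the weighted average coincides with the simple mean of the two values.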
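Congratulations! You've successfully built and run your first federated learning system. The `full source code <https://github.com/adap/flower/blob/main/examples/quickstart-pytorch/client.py>`_ for this example can be found in examples/quickstart-pytorch.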