Quickstart PyTorch#

In this tutorial we will learn how to train a Convolutional Neural Network on CIFAR-10 using Flower and PyTorch.

First of all, it is recommended to create a virtual environment and run everything within it.
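For example, one common way to do this is with Python's built-in venv module (shown here for a Unix-like shell; the environment name is arbitrary):

$ python3 -m venv venv
$ source venv/bin/activate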

Our example consists of one *server* and two *clients*, all having the same model.

*Clients* are responsible for updating the model parameters on their local datasets. These parameters are then sent to the *server*, which aggregates them to produce an improved model. Finally, the *server* sends this improved version of the model back to each *client*. A complete cycle of parameter updates is called a *round*.
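To make the aggregation step concrete: Flower's default strategy, FedAvg, averages the clients' parameters, weighting each client by the number of training examples it used. The following is only an illustrative sketch of that idea, not Flower's internal implementation:

import numpy as np

def federated_average(client_params, num_examples):
    """Weighted average of per-client parameters (FedAvg-style), for illustration only."""
    total = sum(num_examples)
    return [
        np.sum([params[i] * n for params, n in zip(client_params, num_examples)], axis=0) / total
        for i in range(len(client_params[0]))
    ]

# Two hypothetical clients, each holding one parameter array trained on 100 and 300 examples
print(federated_average([[np.array([1.0, 2.0])], [np.array([3.0, 4.0])]], [100, 300]))
# -> [array([2.5, 3.5])]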

Now that we have a rough idea of what is going on, let's get started. First, we need to install Flower. You can do this by running:

$ pip install flwr

Since we want to use PyTorch to solve a computer vision task, let's go ahead and install PyTorch and the torchvision library:

$ pip install torch torchvision

Flower Client#

Now that we have all our dependencies installed, let's run a simple distributed training with two clients and one server. Our training procedure and network architecture are based on PyTorch's `Deep Learning with PyTorch <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_.

In a file called :code:`client.py`, import Flower and the PyTorch-related packages:

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

import flwr as fl

In addition, we define the device allocation in PyTorch:

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

We use PyTorch to load CIFAR-10, a popular colored image classification dataset for machine learning. The training and test data are downloaded via the PyTorch :code:`CIFAR10` dataset, normalized, and wrapped in :code:`DataLoader` objects:

def load_data():
    """Load CIFAR-10 (training and test set)."""
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10(".", train=True, download=True, transform=transform)
    testset = CIFAR10(".", train=False, download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
    testloader = DataLoader(testset, batch_size=32)
    num_examples = {"trainset": len(trainset), "testset": len(testset)}
    return trainloader, testloader, num_examples

Define the loss and optimizer with PyTorch. Training is done by looping over the dataset, measuring the corresponding loss, and optimizing it.

def train(net, trainloader, epochs):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    for _ in range(epochs):
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()

Then we define the validation of the machine learning network. We loop over the test set, measuring the loss and accuracy on the test set.

def test(net, testloader):
    """Validate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(DEVICE), data[1].to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    return loss, accuracy

After defining the training and testing of a PyTorch machine learning model, we use these functions in the Flower client.

The Flower clients will use a simple CNN adapted from "PyTorch: A 60 Minute Blitz":

class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Load model and data
net = Net().to(DEVICE)
trainloader, testloader, num_examples = load_data()

After loading the dataset with :code:`load_data()`, we define the Flower interface.

The Flower server interacts with clients through an interface called :code:`Client`. When the server selects a particular client for training, it sends training instructions over the network. The client receives those instructions and calls one of the :code:`Client` methods to run your code (i.e., to train the neural network we defined earlier).

Flower provides a convenience class called :code:`NumPyClient` which makes it easier to implement the :code:`Client` interface when your workload uses PyTorch. Implementing :code:`NumPyClient` usually means defining the following methods (:code:`set_parameters` is optional though):

  1. get_parameters
    • return the model parameters as a list of NumPy ndarrays

  2. set_parameters (optional)
    • update the local model parameters with the parameters received from the server

  3. fit
    • set the local model parameters

    • train the local model

    • return the updated local model parameters

  4. evaluate
    • test the local model

which can be implemented in the following way:

class CifarClient(fl.client.NumPyClient):
    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in net.state_dict().items()]

    def set_parameters(self, parameters):
        params_dict = zip(net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        train(net, trainloader, epochs=1)
        return self.get_parameters(config={}), num_examples["trainset"], {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        loss, accuracy = test(net, testloader)
        return float(loss), num_examples["testset"], {"accuracy": float(accuracy)}

We can now create an instance of our class :code:`CifarClient` and add one line to actually run this client:

fl.client.start_client(server_address="[::]:8080", client=CifarClient().to_client())

That's it for the client. We only have to implement Client or NumPyClient and call fl.client.start_client(). If you implement a client of type NumPyClient you'll need to first call its to_client() method. The string "[::]:8080" tells the client which server to connect to. In our case we can run the server and the client on the same machine, therefore we use "[::]:8080". If we run a truly federated workload with the server and clients running on different machines, all that needs to change is the server_address we point the client at.
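For example, if the server were running on a different machine reachable at 192.0.2.10 (a placeholder address used purely for illustration), the only change on the client side would be:

fl.client.start_client(
    server_address="192.0.2.10:8080", client=CifarClient().to_client()
)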

Flower Server#

For simple workloads we can start a Flower server and leave all the configuration possibilities at their default values. In a file named :code:`server.py`, import Flower and start the server:

import flwr as fl

fl.server.start_server(config=fl.server.ServerConfig(num_rounds=3))
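For more control, :code:`start_server` also accepts a strategy. As a rough sketch (argument names can vary slightly between Flower versions), the default :code:`FedAvg` strategy could be configured explicitly, for example to wait for both clients before every round:

import flwr as fl

strategy = fl.server.strategy.FedAvg(
    min_fit_clients=2,        # sample at least two clients for training
    min_evaluate_clients=2,   # sample at least two clients for evaluation
    min_available_clients=2,  # wait until two clients are connected
)

fl.server.start_server(
    server_address="[::]:8080",
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
)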

Train the model, federated!#

With both client and server ready, we can now run everything and see federated learning in action. FL systems usually have a server and multiple clients. We therefore have to start the server first:

$ python server.py

Once the server is running we can start the clients in different terminals. Open a new terminal and start the first client:

$ python client.py

Open another terminal and start the second client:

$ python client.py

Each client will have its own dataset. You should now see how the training progresses in the very first terminal (the one that started the server):

INFO flower 2021-02-25 14:00:27,227 | app.py:76 | Flower server running (insecure, 3 rounds)
INFO flower 2021-02-25 14:00:27,227 | server.py:72 | Getting initial parameters
INFO flower 2021-02-25 14:01:15,881 | server.py:74 | Evaluating initial parameters
INFO flower 2021-02-25 14:01:15,881 | server.py:87 | [TIME] FL starting
DEBUG flower 2021-02-25 14:01:41,310 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-02-25 14:02:00,256 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-02-25 14:02:00,262 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:02:03,047 | server.py:149 | evaluate received 2 results and 0 failures
DEBUG flower 2021-02-25 14:02:03,049 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-02-25 14:02:23,908 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-02-25 14:02:23,915 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:02:27,120 | server.py:149 | evaluate received 2 results and 0 failures
DEBUG flower 2021-02-25 14:02:27,122 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-02-25 14:03:04,660 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-02-25 14:03:04,671 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:03:09,273 | server.py:149 | evaluate received 2 results and 0 failures
INFO flower 2021-02-25 14:03:09,273 | server.py:122 | [TIME] FL finished in 113.39180790000046
INFO flower 2021-02-25 14:03:09,274 | app.py:109 | app_fit: losses_distributed [(1, 650.9747924804688), (2, 526.2535400390625), (3, 473.76959228515625)]
INFO flower 2021-02-25 14:03:09,274 | app.py:110 | app_fit: accuracies_distributed []
INFO flower 2021-02-25 14:03:09,274 | app.py:111 | app_fit: losses_centralized []
INFO flower 2021-02-25 14:03:09,274 | app.py:112 | app_fit: accuracies_centralized []
DEBUG flower 2021-02-25 14:03:09,276 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:03:11,852 | server.py:149 | evaluate received 2 results and 0 failures
INFO flower 2021-02-25 14:03:11,852 | app.py:121 | app_evaluate: federated loss: 473.76959228515625
INFO flower 2021-02-25 14:03:11,852 | app.py:122 | app_evaluate: results [('ipv6:[::1]:36602', EvaluateRes(loss=351.4906005859375, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.6067})), ('ipv6:[::1]:36604', EvaluateRes(loss=353.92742919921875, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.6005}))]
INFO flower 2021-02-25 14:03:27,514 | app.py:127 | app_evaluate: failures []

Congratulations! You've successfully built and run your first federated learning system. The full `source code <https://github.com/adap/flower/blob/main/examples/quickstart-pytorch/client.py>`_ for this example can be found in :code:`examples/quickstart-pytorch`.