实例: PyTorch - 从集中式到联邦式#
本教程将向您展示如何使用 Flower 构建现有机器学习工作的联邦版本。我们使用 PyTorch 在 CIFAR-10 数据集上训练一个卷积神经网络。首先,我们基于 "Deep Learning with PyTorch <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_"教程,采用集中式训练方法介绍了这项机器学习任务。然后,我们在集中式训练代码的基础上以联邦方式运行训练。
集中式训练#
我们首先简要介绍一下集中式 CNN 训练代码。如果您想获得更深入的解释,请参阅 PyTorch 官方教程`PyTorch tutorial <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_。
让我们创建一个名为 cifar.py
的新文件,其中包含 CIFAR-10 传统(集中)培训所需的所有组件。首先,需要导入所有必需的软件包(如 torch
和 torchvision
)。您可以看到,我们没有导入任何用于联邦学习的软件包。即使在以后添加联邦学习组件时,也可以保留所有这些导入。
from typing import Tuple, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch import Tensor
from torchvision.datasets import CIFAR10
如前所述,我们将使用 CIFAR-10 数据集进行机器学习。模型架构(一个非常简单的卷积神经网络)在 class Net()
中定义。
class Net(nn.Module):
def __init__(self) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x: Tensor) -> Tensor:
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
load_data()
函数加载 CIFAR-10 训练集和测试集。加载数据后,:code:`transform`函数对数据进行了归一化处理。
DATA_ROOT = "~/data/cifar-10"
def load_data() -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, Dict]:
"""Load CIFAR-10 (training and test set)."""
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
trainset = CIFAR10(DATA_ROOT, train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
testset = CIFAR10(DATA_ROOT, train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)
num_examples = {"trainset" : len(trainset), "testset" : len(testset)}
return trainloader, testloader, num_examples
现在,我们需要定义训练函数(train()
),该函数在训练集上循环训练,计算损失值并反向传播,然后为每批训练数据在优化器上执行一个优化步骤。
模型的评估在函数 test()
中定义。该函数循环遍历所有测试样本,并计算测试数据集的模型损失值。
def train(
net: Net,
trainloader: torch.utils.data.DataLoader,
epochs: int,
device: torch.device,
) -> None:
"""Train the network."""
# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
print(f"Training {epochs} epoch(s) w/ {len(trainloader)} batches each")
# Train the network
for epoch in range(epochs): # loop over the dataset multiple times
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
images, labels = data[0].to(device), data[1].to(device)
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()
if i % 100 == 99: # print every 100 mini-batches
print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
def test(
net: Net,
testloader: torch.utils.data.DataLoader,
device: torch.device,
) -> Tuple[float, float]:
"""Validate the network on the entire test set."""
criterion = nn.CrossEntropyLoss()
correct = 0
total = 0
loss = 0.0
with torch.no_grad():
for data in testloader:
images, labels = data[0].to(device), data[1].to(device)
outputs = net(images)
loss += criterion(outputs, labels).item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = correct / total
return loss, accuracy
在确定了数据加载、模型架构、训练和评估之后,我们就可以将所有整合在一起,在 CIFAR-10 上训练我们的 CNN。
def main():
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Centralized PyTorch training")
print("Load data")
trainloader, testloader, _ = load_data()
print("Start training")
net=Net().to(DEVICE)
train(net=net, trainloader=trainloader, epochs=2, device=DEVICE)
print("Evaluate model")
loss, accuracy = test(net=net, testloader=testloader, device=DEVICE)
print("Loss: ", loss)
print("Accuracy: ", accuracy)
if __name__ == "__main__":
main()
现在,您可以运行您的机器学习工作了:
python3 cifar.py
到目前为止,如果你以前用过 PyTorch,这一切看起来应该相当熟悉。让我们进行下一步,利用我们所构建的内容创建一个简单联邦学习系统(由一个服务器和两个客户端组成)。
联邦培训#
上一节讨论的简单机器学习项目在单一数据集(CIFAR-10)上训练模型,我们称之为集中学习。如上一节所示,集中学习的概念可能为大多数人所熟知,而且很多人以前都使用过。通常情况下,如果要以联邦方式运行机器学习工作,就必须更改大部分代码,并从头开始设置一切。这可能是一个相当大的工作量。
不过,有了 Flower,您可以轻松地将已有的代码转变成联邦学习的模式,无需进行大量重写。
这个概念很容易理解。我们必须启动一个*服务器*,然后对连接到*服务器*的*客户端*使用 :code:`cifar.py`中的代码。服务器*向客户端发送模型参数,*客户端*运行训练并更新参数。更新后的参数被发回*服务器,然后会对所有收到的参数更新进行平均聚合。以上描述的是一轮联邦学习过程,我们将重复进行多轮学习。
我们的示例包括一个*服务器*和两个*客户端*。让我们先设置 server.py
。*服务器*需要导入 Flower 软件包 flwr
。接下来,我们使用 start_server
函数启动服务器,并让它执行三轮联邦学习。
import flwr as fl
if __name__ == "__main__":
fl.server.start_server(server_address="0.0.0.0:8080", config=fl.server.ServerConfig(num_rounds=3))
我们已经可以启动*服务器*了:
python3 server.py
最后,我们将在 client.py
中定义我们的 client 逻辑,并以之前在 cifar.py
中定义的集中式训练为基础。我们的 client 不仅需要导入 flwr
,还需要导入 torch
,以更新 PyTorch 模型的参数:
from collections import OrderedDict
from typing import Dict, List, Tuple
import numpy as np
import torch
import cifar
import flwr as fl
DEVICE: str = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
实现 Flower *client*基本上意味着实现 flwr.client.Client
或 flwr.client.NumPyClient
的子类。我们的代码实现将基于 flwr.client.NumPyClient
,并将其命名为 CifarClient
。如果使用具有良好 NumPy 互操作性的框架(如 PyTorch 或 TensorFlow/Keras),NumPyClient`的实现比 :code:`Client`略微容易一些,因为它避免了一些不必要的操作。:code:`CifarClient
需要实现四个方法,两个用于获取/设置模型参数,一个用于训练模型,一个用于测试模型:
set_parameters
在本地模型上设置从服务器接收的模型参数
循环遍历以 NumPy
ndarray
形式接收的模型参数列表(可以看作神经网络的列表)
fit
用从服务器接收到的参数更新本地模型的参数
在本地训练集上训练模型
获取更新后的本地模型参数并发送回服务器
evaluate
用从服务器接收到的参数更新本地模型的参数
在本地测试集上评估更新后的模型
向服务器返回本地损失值和精确度
这两个 NumPyClient
中的方法 fit
和 evaluate
使用了之前在 cifar.py
中定义的函数 train()
和 test()
。因此,我们在这里要做的就是通过 NumPyClient
子类告知 Flower 在训练和评估时要调用哪些已定义的函数。我们加入了类型注解,以便让你更好地理解传递的数据类型。
class CifarClient(fl.client.NumPyClient):
"""Flower client implementing CIFAR-10 image classification using
PyTorch."""
def __init__(
self,
model: cifar.Net,
trainloader: torch.utils.data.DataLoader,
testloader: torch.utils.data.DataLoader,
num_examples: Dict,
) -> None:
self.model = model
self.trainloader = trainloader
self.testloader = testloader
self.num_examples = num_examples
def get_parameters(self, config) -> List[np.ndarray]:
# Return model parameters as a list of NumPy ndarrays
return [val.cpu().numpy() for _, val in self.model.state_dict().items()]
def set_parameters(self, parameters: List[np.ndarray]) -> None:
# Set model parameters from a list of NumPy ndarrays
params_dict = zip(self.model.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
self.model.load_state_dict(state_dict, strict=True)
def fit(
self, parameters: List[np.ndarray], config: Dict[str, str]
) -> Tuple[List[np.ndarray], int, Dict]:
# Set model parameters, train model, return updated model parameters
self.set_parameters(parameters)
cifar.train(self.model, self.trainloader, epochs=1, device=DEVICE)
return self.get_parameters(config={}), self.num_examples["trainset"], {}
def evaluate(
self, parameters: List[np.ndarray], config: Dict[str, str]
) -> Tuple[float, int, Dict]:
# Set model parameters, evaluate model on local test dataset, return result
self.set_parameters(parameters)
loss, accuracy = cifar.test(self.model, self.testloader, device=DEVICE)
return float(loss), self.num_examples["testset"], {"accuracy": float(accuracy)}
剩下的就是定义模型和数据加载函数了。创建一个:code:CifarClient`类,并运行这个客服端。您将通过:code:`cifar.py`加载数据和模型。另外,通过:code:`fl.client.start_client()`函数来运行客户端:code:`CifarClient,需要保证IP地址和:code:`server.py`中所使用的一致:
def main() -> None:
"""Load data, start CifarClient."""
# Load model and data
model = cifar.Net()
model.to(DEVICE)
trainloader, testloader, num_examples = cifar.load_data()
# Start client
client = CifarClient(model, trainloader, testloader, num_examples)
fl.client.start_client(server_address="0.0.0.0:8080", client.to_client())
if __name__ == "__main__":
main()
就是这样,现在你可以打开另外两个终端窗口,然后运行
python3 client.py
确保服务器正在运行后,您就能看到您的 PyTorch 项目(之前是集中式的)在两个客户端上运行联邦学习了。祝贺!
下一步工作#
本示例的完整源代码为:PyTorch: 从集中式到联合式。当然,我们的示例有些过于简单,因为两个客户端都加载了完全相同的数据集,这并不真实。现在,您已经准备好进一步探讨这一主题了。比如在每个客户端使用不同的 CIFAR-10 子集会如何?增加更多客户端会如何?