从零开始制定策略#
欢迎来到 Flower 联邦学习教程的第三部分。在本教程的前几部分,我们介绍了 PyTorch 和 Flower 的联邦学习(part 1),并学习了如何使用策略来定制服务器和客户端的执行(part 2)。
在本笔记中,我们将通过创建 FedAvg 的自定义版本(再次使用 Flower 和 PyTorch),继续定制我们之前构建的联邦学习系统。
在 GitHub 上给 Flower 点个 Star ⭐️,并加入 Slack 上的 Flower 社区,进行交流、提问并获得帮助: `加入 Slack <https://flower.ai/join-slack>`__ 🌼 我们希望在
#introductions
频道听到您的声音!如果有任何不清楚的地方,请到 #questions
频道提问。
让我们从头开始构建一个新的``Strategy``!
准备工作#
在开始实际代码之前,让我们先确保我们已经准备好了所需的一切。
安装依赖项#
首先,我们安装必要的软件包:
[ ]:
!pip install -q flwr[simulation] torch torchvision
现在我们已经安装了所有依赖项,可以导入本教程所需的所有内容:
[ ]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
import flwr as fl
# Device that models and batches are moved to throughout this file.
DEVICE = torch.device("cpu")  # Try "cuda" to train on GPU
# Report the selected device together with the PyTorch and Flower versions.
print(
    f"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}"
)
可以切换到已启用 GPU 加速的运行时(在 Google Colab 上: ``运行时 > 更改运行时类型 > 硬件加速: GPU > 保存``)。但请注意,Google Colab 并非总能提供 GPU 加速。如果在以下部分中看到与 GPU 可用性相关的错误,请考虑通过设置 ``DEVICE = torch.device("cpu")``
切回基于 CPU 的执行。如果运行时已启用 GPU 加速,你应该会看到输出 ``Training on cuda``,否则会显示 ``Training on cpu``。
数据加载#
现在,让我们加载 CIFAR-10 训练集和测试集,将它们分割成 10 个较小的数据集(每个数据集又分为训练集和验证集),并将所有数据都封装在各自的 DataLoader
中。我们引入了一个新参数 num_clients
,它允许我们使用不同数量的客户端调用 load_datasets
。
[ ]:
# Number of simulated clients; passed to load_datasets() as the partition count.
NUM_CLIENTS = 10
def load_datasets(num_clients: int):
    """Download CIFAR-10 and build per-client train/validation loaders.

    The training set is split (with a fixed seed) into ``num_clients``
    disjoint partitions. Each partition is split again into a training part
    and a validation part (one tenth of the partition, rounded down), and
    each part is wrapped in a ``DataLoader`` with batch size 32.

    Args:
        num_clients: Number of partitions to create, one per simulated client.

    Returns:
        A tuple ``(trainloaders, valloaders, testloader)``: two lists with one
        ``DataLoader`` per partition, and a single ``DataLoader`` over the
        CIFAR-10 test set.
    """
    # Download and transform CIFAR-10 (train and test)
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
    testset = CIFAR10("./dataset", train=False, download=True, transform=transform)

    # Split training set into `num_clients` partitions to simulate different
    # local datasets. random_split() requires the lengths to add up to
    # len(trainset), so when the size is not divisible by `num_clients` the
    # leftover samples are handed out one each to the first partitions.
    partition_size, remainder = divmod(len(trainset), num_clients)
    lengths = [
        partition_size + (1 if index < remainder else 0)
        for index in range(num_clients)
    ]
    datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))

    # Split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []
    for ds in datasets:
        len_val = len(ds) // 10  # 10 % validation set
        len_train = len(ds) - len_val
        ds_train, ds_val = random_split(
            ds, [len_train, len_val], torch.Generator().manual_seed(42)
        )
        trainloaders.append(DataLoader(ds_train, batch_size=32, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=32))
    testloader = DataLoader(testset, batch_size=32)
    return trainloaders, valloaders, testloader
# Build the loaders once at module level; client_fn() indexes into these lists.
trainloaders, valloaders, testloader = load_datasets(NUM_CLIENTS)
模型训练/评估#
让我们继续使用常见的模型定义(包括 set_parameters 和 get_parameters)、训练和测试函数:
[ ]:
class Net(nn.Module):
    """Small convolutional classifier with ten output scores.

    Two convolution + ReLU + max-pool stages feed three fully connected
    layers. The flatten step assumes 16 feature maps of size 5x5, which is
    what 3x32x32 inputs produce after the two stages.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        # One pooling module is shared by both convolution stages.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw class scores for the batch ``x``."""
        stage_one = self.pool(F.relu(self.conv1(x)))
        stage_two = self.pool(F.relu(self.conv2(stage_one)))
        flat = stage_two.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
def get_parameters(net) -> List[np.ndarray]:
    """Return every ``state_dict`` entry of ``net`` as a NumPy array.

    The arrays are listed in ``state_dict`` order; each tensor is moved to
    the CPU before conversion.
    """
    return [tensor.cpu().numpy() for tensor in net.state_dict().values()]
def set_parameters(net, parameters: List[np.ndarray]):
    """Load ``parameters`` into ``net``, matching arrays to state-dict keys by position.

    Args:
        net: Module whose ``state_dict`` keys define the expected order.
        parameters: One array per ``state_dict`` entry, in ``state_dict`` order
            (the layout produced by ``get_parameters``).

    Raises:
        RuntimeError: Raised by ``load_state_dict(strict=True)`` when keys are
            missing (e.g. too few arrays were supplied) or shapes do not match.
    """
    params_dict = zip(net.state_dict().keys(), parameters)
    # torch.as_tensor keeps each array's shape and dtype. The legacy
    # torch.Tensor(...) constructor always yields float32 and does not
    # reliably round-trip 0-d arrays such as BatchNorm's num_batches_tracked.
    state_dict = OrderedDict((key, torch.as_tensor(value)) for key, value in params_dict)
    net.load_state_dict(state_dict, strict=True)
def train(net, trainloader, epochs: int):
    """Train the network on the training set.

    Runs ``epochs`` passes over ``trainloader`` with cross-entropy loss and an
    Adam optimizer at its default settings, printing the loss and accuracy
    after each epoch.

    Args:
        net: Model to train in place; batches are moved to ``DEVICE``.
        trainloader: Loader yielding ``(images, labels)`` batches; its
            ``dataset`` must support ``len()``.
        epochs: Number of passes over ``trainloader``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            # Reuse the outputs computed above rather than running a second
            # forward pass just to obtain the loss.
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics
            # .item() stores a plain float, so the running sum does not keep
            # each batch's autograd graph alive.
            epoch_loss += loss.item()
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        # The sum of per-batch mean losses is divided by the sample count,
        # the same convention test() uses.
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")
def test(net, testloader):
    """Evaluate the network on the entire test set.

    Returns:
        ``(loss, accuracy)``: the summed per-batch cross-entropy divided by
        the number of samples in ``testloader.dataset``, and the fraction of
        correctly classified samples.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    n_correct = 0
    n_seen = 0
    loss_sum = 0.0
    net.eval()
    with torch.no_grad():
        for batch_images, batch_labels in testloader:
            batch_images = batch_images.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)
            scores = net(batch_images)
            loss_sum += loss_fn(scores, batch_labels).item()
            # Index of the highest score per sample is the predicted class.
            predicted = torch.max(scores.data, 1).indices
            n_seen += batch_labels.size(0)
            n_correct += (predicted == batch_labels).sum().item()
    return loss_sum / len(testloader.dataset), n_correct / n_seen
Flower 客户端#
为了实现 Flower 客户端,我们(再次)创建了 ``flwr.client.NumPyClient`` 的子类,并实现了 ``get_parameters``、``fit`` 和 ``evaluate`` 三个方法。在这里,我们还将 ``cid`` 传递给客户端,并使用它记录其他详细信息:
[ ]:
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates one model on local data.

    Each method prints a line tagged with the client id so that the output of
    different clients can be told apart.
    """

    def __init__(self, cid, net, trainloader, valloader):
        """Store the client id, the model and the local train/validation loaders."""
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return the current local model parameters as NumPy arrays."""
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Load ``parameters``, train for one epoch and return the updated model.

        ``config`` is only logged here; it is not forwarded to ``train``.

        Returns:
            ``(parameters, num_examples, metrics)`` with an empty metrics dict.
        """
        print(f"[Client {self.cid}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        # len(DataLoader) counts batches; the value reported here is used as
        # the number of examples, so take the size of the underlying dataset.
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load ``parameters`` and evaluate them on the local validation data.

        Returns:
            ``(loss, num_examples, metrics)`` where metrics holds ``"accuracy"``.
        """
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        # Report the sample count (not the batch count), as in fit().
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}
def client_fn(cid) -> FlowerClient:
    """Create the ``FlowerClient`` for client id ``cid``.

    ``cid`` is converted with ``int()`` and used as the index into the
    module-level ``trainloaders`` and ``valloaders`` lists; the client gets a
    freshly constructed ``Net`` placed on ``DEVICE``.
    """
    model = Net().to(DEVICE)
    partition = int(cid)
    return FlowerClient(
        cid,
        model,
        trainloaders[partition],
        valloaders[partition],
    )
在继续之前,让我们先测试一下目前已经实现的内容:
[ ]:
# Request one GPU per client only when running on CUDA; None leaves the
# simulation defaults in place (1 CPU and 0 GPU).
client_resources = {"num_gpus": 1} if DEVICE.type == "cuda" else None

# Run a short simulation (two clients, three rounds) without passing a strategy.
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=2,
    config=fl.server.ServerConfig(num_rounds=3),
    client_resources=client_resources,
)
从零开始构建策略#
让我们重写 ``configure_fit`` 方法,使其向一部分客户端的优化器传递更高的学习率(可能还有其他超参数)。我们将保持 ``FedAvg`` 中的客户端采样方式,然后更改配置字典(``FitIns`` 的属性之一)。
[ ]:
from typing import Callable, Union
from flwr.common import (
EvaluateIns,
EvaluateRes,
FitIns,
FitRes,
MetricsAggregationFn,
NDArrays,
Parameters,
Scalar,
ndarrays_to_parameters,
parameters_to_ndarrays,
)
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg
class FedCustom(fl.server.strategy.Strategy):
    """Weighted-averaging strategy that sends two different learning rates.

    Client sampling and aggregation follow the usual weighted-average scheme.
    The customisation is in ``configure_fit``: the first half of the sampled
    clients receives ``{"lr": 0.001}`` and the rest ``{"lr": 0.003}``.
    """

    def __init__(
        self,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
    ) -> None:
        """Store the sampling settings.

        Args:
            fraction_fit: Fraction of available clients sampled for training.
            fraction_evaluate: Fraction of available clients sampled for
                evaluation; ``0.0`` disables client-side evaluation.
            min_fit_clients: Lower bound on the training sample size.
            min_evaluate_clients: Lower bound on the evaluation sample size.
            min_available_clients: Value passed to the client manager as
                ``min_num_clients`` when sampling.
        """
        super().__init__()
        self.fraction_fit = fraction_fit
        self.fraction_evaluate = fraction_evaluate
        self.min_fit_clients = min_fit_clients
        self.min_evaluate_clients = min_evaluate_clients
        self.min_available_clients = min_available_clients

    def __repr__(self) -> str:
        return "FedCustom"

    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize global model parameters from a freshly constructed ``Net``."""
        net = Net()
        ndarrays = get_parameters(net)
        return ndarrays_to_parameters(ndarrays)

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training.

        Returns one ``(client, FitIns)`` pair per sampled client. With an odd
        number of clients the higher-learning-rate group gets the extra one.
        """
        # Sample clients
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Create custom configs
        # NOTE(review): "lr" only has an effect if the client's fit() reads it
        # from the config — verify against the client implementation.
        half_clients = len(clients) // 2
        standard_config = {"lr": 0.001}
        higher_lr_config = {"lr": 0.003}
        fit_configurations = []
        for idx, client in enumerate(clients):
            config = standard_config if idx < half_clients else higher_lr_config
            fit_configurations.append((client, FitIns(parameters, config)))
        return fit_configurations

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using weighted average.

        Each client's parameters are weighted by its reported ``num_examples``.
        Returns ``(None, {})`` when no client returned a result.
        """
        # Mirror aggregate_evaluate(): with nothing to average, report "no
        # update" instead of aggregating an empty list.
        if not results:
            return None, {}
        weights_results = [
            (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]
        parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))
        return parameters_aggregated, {}

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation.

        Returns an empty list when ``fraction_evaluate`` is ``0.0``; otherwise
        every sampled client receives the same ``EvaluateIns`` with an empty
        config.
        """
        if self.fraction_evaluate == 0.0:
            return []
        evaluate_ins = EvaluateIns(parameters, {})

        # Sample clients
        sample_size, min_num_clients = self.num_evaluation_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Return client/config pairs
        return [(client, evaluate_ins) for client in clients]

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using weighted average.

        Returns ``(None, {})`` when no client returned a result.
        """
        if not results:
            return None, {}
        loss_aggregated = weighted_loss_avg(
            [
                (evaluate_res.num_examples, evaluate_res.loss)
                for _, evaluate_res in results
            ]
        )
        return loss_aggregated, {}

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate global model parameters using an evaluation function."""
        # Let's assume we won't perform the global model evaluation on the server side.
        return None

    def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Return sample size and required number of clients.

        The sample size is ``fraction_fit`` of the available clients
        (truncated), but never less than ``min_fit_clients``.
        """
        num_clients = int(num_available_clients * self.fraction_fit)
        return max(num_clients, self.min_fit_clients), self.min_available_clients

    def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Use a fraction of available clients for evaluation.

        The sample size is ``fraction_evaluate`` of the available clients
        (truncated), but never less than ``min_evaluate_clients``.
        """
        num_clients = int(num_available_clients * self.fraction_evaluate)
        return max(num_clients, self.min_evaluate_clients), self.min_available_clients
剩下的唯一工作就是在启动实验时使用新创建的自定义策略 FedCustom
:
[ ]:
# Same simulation as before (two clients, three rounds), now driven by FedCustom.
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=2,
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=FedCustom(),  # <-- pass the new strategy here
    client_resources=client_resources,
)
回顾#
在本笔记中,我们了解了如何实现自定义策略。自定义策略可以对客户端节点配置、结果聚合等进行细粒度控制。要定义自定义策略,只需覆盖(抽象)基类 ``Strategy``
的抽象方法即可。为使自定义策略更加强大,您可以将自定义函数传递给新类的构造函数(``__init__``),然后在需要时调用这些函数。
接下来的步骤#
在继续之前,请务必加入 Slack 上的 Flower 社区:Join Slack
如果您需要帮助,我们有专门的 #questions
频道,但我们也很乐意在 #introductions
中了解您是谁!
`Flower 联邦学习教程 - 第 4 部分 <https://flower.ai/docs/framework/tutorial-customize-the-client-pytorch.html>`__ 介绍了 ``Client``,它是 ``NumPyClient`` 底层的灵活应用程序接口。