Open in Colab

使用联邦学习策略#

欢迎来到联邦学习教程的下一部分。在本教程的前几部分,我们介绍了使用 PyTorch 和 Flower 进行联邦学习(`第 1 部分 <https://flower.ai/docs/framework/tutorial-get-started-with-flower-pytorch.html>`___)。

在本笔记中,我们将开始定制在入门笔记中构建的联邦学习系统(再次使用 FlowerPyTorch)。

Star Flower on GitHub ⭐️ 并加入 Slack 上的 Flower 社区,进行交流、提问并获得帮助: 加入 Slack <https://flower.ai/join-slack>`__ 🌼 我们希望在 #introductions 频道听到您的声音!如果有任何不清楚的地方,请访问 #questions 频道。

让我们超越 FedAvg,采用Flower策略!

准备工作#

在开始实际代码之前,让我们先确保我们已经准备好了所需的一切。

安装依赖项#

首先,我们安装必要的软件包:

[ ]:
!pip install -q flwr[simulation] torch torchvision

现在我们已经安装了所有依赖项,可以导入本教程所需的所有内容:

[ ]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10

import flwr as fl

DEVICE = torch.device("cpu")  # Try "cuda" to train on GPU
print(
    f"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}"
)

可以切换到已启用 GPU 加速的运行时(在 Google Colab 上: 运行时 > 更改运行时类型 > 硬件加速: GPU > 保存``)。但请注意,Google Colab 并非总能提供 GPU 加速。如果在以下部分中看到与 GPU 可用性相关的错误,请考虑通过设置 DEVICE = torch.device("cpu") 切回基于 CPU 的执行。如果运行时已启用 GPU 加速,你应该会看到输出``Training on cuda``,否则会显示``Training on cpu``。

数据加载#

现在,让我们加载 CIFAR-10 训练集和测试集,将它们分割成 10 个较小的数据集(每个数据集又分为训练集和验证集),并将所有数据都封装在各自的 DataLoader 中。我们引入了一个新参数 num_clients,它允许我们使用不同数量的客户端调用 load_datasets

[ ]:
NUM_CLIENTS = 10


def load_datasets(num_clients: int):
    # Download and transform CIFAR-10 (train and test)
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
    testset = CIFAR10("./dataset", train=False, download=True, transform=transform)

    # Split training set into `num_clients` partitions to simulate different local datasets
    partition_size = len(trainset) // num_clients
    lengths = [partition_size] * num_clients
    datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))

    # Split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []
    for ds in datasets:
        len_val = len(ds) // 10  # 10 % validation set
        len_train = len(ds) - len_val
        lengths = [len_train, len_val]
        ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))
        trainloaders.append(DataLoader(ds_train, batch_size=32, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=32))
    testloader = DataLoader(testset, batch_size=32)
    return trainloaders, valloaders, testloader


trainloaders, valloaders, testloader = load_datasets(NUM_CLIENTS)

模型培训/评估#

让我们继续使用常见的模型定义(包括 set_parametersget_parameters)、训练和测试函数:

[ ]:
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

Flower 客户端#

为了实现 Flower 客户端,我们(再次)创建了 flwr.client.NumPyClient 的子类,并实现了 get_parametersfitevaluate``三个方法。在这里,我们还将 ``cid 传递给客户端,并使用它记录其他详细信息:

[ ]:
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        print(f"[Client {self.cid}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}


def client_fn(cid) -> FlowerClient:
    net = Net().to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerClient(cid, net, trainloader, valloader)

策略定制#

到目前为止,如果您已经阅读过入门笔记本,那么一切都应该很熟悉了。接下来,我们将介绍一些新功能。

服务器端参数 初始化#

默认情况下,Flower 会通过向一个随机客户端询问初始参数来初始化全局模型。但在许多情况下,我们需要对参数初始化进行更多控制。因此,Flower 允许您直接将初始参数传递给策略:

[ ]:
# Create an instance of the model and get the parameters
params = get_parameters(Net())

# Pass parameters to the Strategy for server-side parameter initialization
strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.3,
    fraction_evaluate=0.3,
    min_fit_clients=3,
    min_evaluate_clients=3,
    min_available_clients=NUM_CLIENTS,
    initial_parameters=fl.common.ndarrays_to_parameters(params),
)

# Specify client resources if you need GPU (defaults to 1 CPU and 0 GPU)
client_resources = None
if DEVICE.type == "cuda":
    client_resources = {"num_gpus": 1}

# Start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=3),  # Just three rounds
    strategy=strategy,
    client_resources=client_resources,
)

FedAvg 策略传递 initial_parameters 可以防止 Flower 向其中一个客户端询问初始参数。如果我们仔细观察,就会发现日志中没有显示对 FlowerClient.get_parameters 方法的任何调用。

从定制战略开始#

我们以前见过函数 start_simulation。它接受许多参数,其中包括用于创建 FlowerClient 实例的 client_fn、要模拟的客户数量 num_clients、回合数 ``num_rounds``和策略。

该策略封装了联邦学习方法/算法,例如`FedAvg``或`FedAdagrad``。这次让我们尝试使用不同的策略:

[ ]:
# Create FedAdam strategy
strategy = fl.server.strategy.FedAdagrad(
    fraction_fit=0.3,
    fraction_evaluate=0.3,
    min_fit_clients=3,
    min_evaluate_clients=3,
    min_available_clients=NUM_CLIENTS,
    initial_parameters=fl.common.ndarrays_to_parameters(get_parameters(Net())),
)

# Start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=3),  # Just three rounds
    strategy=strategy,
    client_resources=client_resources,
)

服务器端参数**评估**#

Flower 可以在服务器端或客户端评估聚合模型。客户端和服务器端评估在某些方面相似,但也有不同之处。

**集中评估**(或*服务器端评估*)在概念上很简单:它的工作方式与集中式机器学习中的评估方式相同。如果有一个服务器端数据集可用于评估目的,那就太好了。我们可以在每一轮训练后对新聚合的模型进行评估,而无需将模型发送给客户端。我们也很幸运,因为我们的整个评估数据集随时可用。

联邦评估**(或*客户端评估*)更为复杂,但也更为强大:它不需要集中的数据集,允许我们在更大的数据集上对模型进行评估,这通常会产生更真实的评估结果。事实上,如果我们想得到有代表性的评估结果,很多情况下都需要使用**联邦评估。但是,这种能力是有代价的:一旦我们开始在客户端进行评估,我们就应该意识到,如果这些客户端并不总是可用,我们的评估数据集可能会在连续几轮学习中发生变化。此外,每个客户端所拥有的数据集也可能在连续几轮学习中发生变化。这可能会导致评估结果不稳定,因此即使我们不改变模型,也会看到评估结果在连续几轮中波动。

我们已经了解了联邦评估如何在客户端工作(即通过在 FlowerClient 中实现 evaluate 方法)。现在让我们看看如何在服务器端评估聚合模型参数:

[ ]:
# The `evaluate` function will be by Flower called after every round
def evaluate(
    server_round: int,
    parameters: fl.common.NDArrays,
    config: Dict[str, fl.common.Scalar],
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
    net = Net().to(DEVICE)
    valloader = valloaders[0]
    set_parameters(net, parameters)  # Update model with the latest parameters
    loss, accuracy = test(net, valloader)
    print(f"Server-side evaluation loss {loss} / accuracy {accuracy}")
    return loss, {"accuracy": accuracy}
[ ]:
strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.3,
    fraction_evaluate=0.3,
    min_fit_clients=3,
    min_evaluate_clients=3,
    min_available_clients=NUM_CLIENTS,
    initial_parameters=fl.common.ndarrays_to_parameters(get_parameters(Net())),
    evaluate_fn=evaluate,  # Pass the evaluation function
)

fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=3),  # Just three rounds
    strategy=strategy,
    client_resources=client_resources,
)

向/从客户端发送/接收任意值#

在某些情况下,我们希望从服务器端配置客户端的执行(训练、评估)。其中一个例子就是服务器要求客户端训练一定数量的本地遍历。Flower 提供了一种使用字典从服务器向客户端发送配置值的方法。让我们来看一个例子:客户端通过 fit 中的 config 参数从服务器接收配置值(evaluate 中也有 config 参数)。fit 方法通过 config 参数接收配置字典,然后从字典中读取值。在本例中,它读取了 server_roundlocal_epochs,并使用这些值来改进日志记录和配置本地训练遍历的数量:

[ ]:
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        # Read values from config
        server_round = config["server_round"]
        local_epochs = config["local_epochs"]

        # Use values provided by the config
        print(f"[Client {self.cid}, round {server_round}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=local_epochs)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}


def client_fn(cid) -> FlowerClient:
    net = Net().to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerClient(cid, net, trainloader, valloader)

那么,如何将配置字典从服务器发送到客户端呢?内置的 "Flower策略"(Flower Strategies)提供了这样的方法,其工作原理与服务器端评估的工作原理类似。我们为策略提供一个函数,策略会在每一轮联邦学习中调用这个函数:

[ ]:
def fit_config(server_round: int):
    """Return training configuration dict for each round.

    Perform two rounds of training with one local epoch, increase to two local
    epochs afterwards.
    """
    config = {
        "server_round": server_round,  # The current round of federated learning
        "local_epochs": 1 if server_round < 2 else 2,  #
    }
    return config

接下来,我们只需在开始模拟前将此函数传递给 FedAvg 策略即可:

[ ]:
strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.3,
    fraction_evaluate=0.3,
    min_fit_clients=3,
    min_evaluate_clients=3,
    min_available_clients=NUM_CLIENTS,
    initial_parameters=fl.common.ndarrays_to_parameters(get_parameters(Net())),
    evaluate_fn=evaluate,
    on_fit_config_fn=fit_config,  # Pass the fit_config function
)

fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=3),  # Just three rounds
    strategy=strategy,
    client_resources=client_resources,
)

我们可以看到,客户端日志现在包含了当前一轮的联邦学习(从 config 字典中读取)。我们还可以将本地训练配置为在第一轮和第二轮联邦学习期间运行一个遍历,然后在第三轮联邦学习期间运行两个遍历。

客户端还可以向服务器返回任意值。为此,它们会从 fit 和/或 evaluate 返回一个字典。我们在本笔记中看到并使用了这一概念,但并未明确提及:我们的 FlowerClient 返回一个包含自定义键/值对的字典,作为 evaluate 中的第三个返回值。

扩大联邦学习的规模#

作为本笔记的最后一步,让我们看看如何使用 Flower 对大量客户端进行实验。

[ ]:
NUM_CLIENTS = 1000

trainloaders, valloaders, testloader = load_datasets(NUM_CLIENTS)

现在我们有 1000 个分区,每个分区有 45 个训练数据和 5 个验证数据。鉴于每个客户端上的训练示例数量较少,我们可能需要对模型进行更长时间的训练,因此我们将客户端配置为执行 3 个本地训练遍历。我们还应该调整每轮训练中被选中的客户端的比例(我们不希望每轮训练都有 1000 个客户端参与),因此我们将 fraction_fit 调整为 0.05,这意味着每轮训练只选中 5%的可用客户端(即 50 个客户端):

[ ]:
def fit_config(server_round: int):
    config = {
        "server_round": server_round,
        "local_epochs": 3,
    }
    return config


strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.025,  # Train on 25 clients (each round)
    fraction_evaluate=0.05,  # Evaluate on 50 clients (each round)
    min_fit_clients=20,
    min_evaluate_clients=40,
    min_available_clients=NUM_CLIENTS,
    initial_parameters=fl.common.ndarrays_to_parameters(get_parameters(Net())),
    on_fit_config_fn=fit_config,
)

fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=3),  # Just three rounds
    strategy=strategy,
    client_resources=client_resources,
)

回顾#

在本笔记中,我们看到了如何通过自定义策略、在服务器端初始化参数、选择不同的策略以及在服务器端评估模型来逐步增强我们的系统。用这么少的代码就能实现这么大的灵活性,不是吗?

在后面的章节中,我们将看到如何在服务器和客户端之间传递任意值,以完全自定义客户端执行。有了这种能力,我们使用 Flower 虚拟客户端引擎构建了一个大规模的联邦学习模拟,并在 Jupyter Notebook 中进行了一次实验,在相同的工作负载中运行了 1000 个客户端!

接下来的步骤#

在继续之前,请务必加入 Slack 上的 Flower 社区:Join Slack

如果您需要帮助,我们有专门的 #questions 频道,但我们也很乐意在 #introductions 中了解您是谁!

Flower 联邦学习教程 - 第 3 部分 展示了如何从头开始构建完全自定义的 "策略"。


Open in Colab