Example: FedBN in PyTorch - From Centralized To Federated#
This tutorial will show you how to use Flower to build a federated version of an existing machine learning workload with FedBN, a federated training strategy designed for non-IID data. We are using PyTorch to train a Convolutional Neural Network (with Batch Normalization layers) on the CIFAR-10 dataset. When applying FedBN, only a few changes are needed compared to Example: PyTorch - From Centralized To Federated.
Centralized Training#
All files are revised based on Example: PyTorch - From Centralized To Federated. The only thing to do is to modify the file called cifar.py; the revised part is shown below:
The model architecture defined in class Net() is revised accordingly by adding Batch Normalization layers:
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.bn1 = nn.BatchNorm2d(6)  # Batch Normalization after the first conv layer
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bn2 = nn.BatchNorm2d(16)  # Batch Normalization after the second conv layer
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.bn3 = nn.BatchNorm1d(120)  # Batch Normalization after the first fully connected layer
        self.fc2 = nn.Linear(120, 84)
        self.bn4 = nn.BatchNorm1d(84)  # Batch Normalization after the second fully connected layer
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: Tensor) -> Tensor:
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.bn3(self.fc1(x)))
        x = F.relu(self.bn4(self.fc2(x)))
        x = self.fc3(x)
        return x
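If you want to sanity-check the revised architecture before training, a quick forward pass with a dummy CIFAR-10-sized batch confirms the expected output shape. This is a small illustrative sketch, not part of cifar.py; it assumes the revised Net shown above can be imported from cifar.py:

import torch

from cifar import Net  # assumes the revised cifar.py shown above

net = Net()
net.eval()  # in eval mode, BatchNorm uses its running statistics
dummy_batch = torch.randn(4, 3, 32, 32)  # four fake CIFAR-10 images
with torch.no_grad():
    logits = net(dummy_batch)
print(logits.shape)  # torch.Size([4, 10]) - one score per CIFAR-10 class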
You can now run your machine learning workload:
python3 cifar.py
So far this should all look fairly familiar if you've used PyTorch before. Let's take the next step and use what we've built to create a federated learning system with FedBN that consists of one server and two clients.
Federated Training#
If you have read Example: PyTorch - From Centralized To Federated, the following parts are easy to follow; only the get_parameters and set_parameters functions in client.py need to be revised. If not, please read Example: PyTorch - From Centralized To Federated first.
Our example consists of one *server* and two *clients*. In FedBN, server.py remains unchanged; we can start the server directly:
python3 server.py
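For reference, an unchanged server.py can be as small as the following sketch. This assumes the Flower 1.x API (fl.server.start_server with a ServerConfig) and a placeholder address of "0.0.0.0:8080"; the actual server.py in the companion example may differ:

import flwr as fl

if __name__ == "__main__":
    # Plain FedAvg on the server side; the BN parameters never reach the
    # server because the clients exclude them (see client.py below).
    fl.server.start_server(
        server_address="0.0.0.0:8080",
        config=fl.server.ServerConfig(num_rounds=3),
        strategy=fl.server.strategy.FedAvg(),
    )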
Finally, we will revise our *client* logic by changing get_parameters and set_parameters in client.py: when sending parameters to or receiving them from the server, we exclude the parameters of the batch normalization layers from the list of model parameters.
class CifarClient(fl.client.NumPyClient):
    """Flower client implementing CIFAR-10 image classification using
    PyTorch."""

    ...

    def get_parameters(self, config) -> List[np.ndarray]:
        # Return model parameters as a list of NumPy ndarrays, excluding
        # parameters of BN layers when using FedBN
        return [
            val.cpu().numpy()
            for name, val in self.model.state_dict().items()
            if 'bn' not in name
        ]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        # Set model parameters from a list of NumPy ndarrays
        keys = [k for k in self.model.state_dict().keys() if 'bn' not in k]
        params_dict = zip(keys, parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=False)

    ...
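To see exactly which tensors stay local, you can inspect the state dict of the revised Net: every key containing "bn" (weights, biases, and running statistics) is filtered out by the two methods above, which is also why load_state_dict is called with strict=False. This is a small illustrative snippet, not part of client.py:

from cifar import Net  # assumes the revised cifar.py shown earlier

net = Net()
excluded = [name for name in net.state_dict() if "bn" in name]
shared = [name for name in net.state_dict() if "bn" not in name]
print(f"{len(excluded)} BN tensors stay on the client, {len(shared)} tensors are shared")
# Excluded entries include e.g. bn1.weight, bn1.bias, bn1.running_mean,
# bn1.running_var, and bn1.num_batches_tracked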
You can now open two additional terminal windows and run
python3 client.py
in each window (make sure that the server is still running) and watch your (previously centralized) PyTorch project run federated learning with the FedBN strategy across two clients. Congratulations!
Next Steps#
The full source code for this example can be found at https://github.com/adap/flower/blob/main/examples/pytorch-from-centralized-to-federated. Our example is, of course, somewhat over-simplified because both clients load the exact same dataset, which isn't realistic. You're now prepared to explore this topic further: how about using a different subset of CIFAR-10 on each client, or increasing the number of clients? A starting point for the first idea is sketched below.
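One simple way to give each client its own slice of CIFAR-10 is to partition the training set with torch.utils.data.Subset. The sketch below is only a rough illustration: the load_partition helper and its client_id and num_clients arguments are hypothetical and not part of the example code, and a contiguous split like this still yields roughly IID shards unless you sort or skew the data first.

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset


def load_partition(client_id: int, num_clients: int) -> DataLoader:
    """Load a contiguous, non-overlapping CIFAR-10 shard for one client."""
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform
    )
    shard_size = len(trainset) // num_clients
    indices = range(client_id * shard_size, (client_id + 1) * shard_size)
    return DataLoader(Subset(trainset, indices), batch_size=32, shuffle=True)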