Save and load model checkpoints
Flower does not automatically save model updates on the server side. This guide describes the steps to save (and load) model checkpoints in Flower.
Model checkpointing
Model updates can be persisted on the server side by customizing Strategy methods. Implementing a custom strategy is always an option, but in many cases it may be more convenient to simply customize an existing strategy. The code below defines a new SaveModelStrategy which customizes the existing built-in FedAvg strategy. In particular, it customizes aggregate_fit by calling aggregate_fit in the base class (FedAvg). It then saves the returned (aggregated) parameters before handing them back to the caller (i.e., the server):
from typing import Dict, List, Optional, Tuple, Union

import flwr as fl
import numpy as np
from flwr.common import FitRes, Parameters, Scalar
from flwr.server.client_proxy import ClientProxy


class SaveModelStrategy(fl.server.strategy.FedAvg):
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)

        if aggregated_parameters is not None:
            # Convert `Parameters` to `List[np.ndarray]`
            aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters)

            # Save aggregated_ndarrays to disk
            print(f"Saving round {server_round} aggregated_ndarrays...")
            np.savez(f"round-{server_round}-weights.npz", *aggregated_ndarrays)

        return aggregated_parameters, aggregated_metrics


# Create strategy and run server
strategy = SaveModelStrategy(
    # (same arguments as FedAvg here)
)
fl.server.start_server(strategy=strategy)
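If you later need to read such a checkpoint back, the arrays can be restored with np.load. A minimal sketch, assuming a checkpoint file round-3-weights.npz written by the strategy above:

import flwr as fl
import numpy as np

# np.savez stores positional arrays under the keys "arr_0", "arr_1", ...
npz = np.load("round-3-weights.npz")
ndarrays = [npz[f"arr_{i}"] for i in range(len(npz.files))]

# Convert back to a Flower `Parameters` object if needed
parameters = fl.common.ndarrays_to_parameters(ndarrays)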
Save and load PyTorch checkpoints
Similar to the previous example, but with a few extra steps, we'll show how to store a PyTorch checkpoint using the torch.save function. First, aggregate_fit returns a Parameters object that has to be transformed into a list of NumPy ndarrays. These are then transformed into a PyTorch state_dict following the OrderedDict class structure.
from collections import OrderedDict

import torch

# `cifar.Net()` and `DEVICE` come from the accompanying PyTorch example code
net = cifar.Net().to(DEVICE)


class SaveModelStrategy(fl.server.strategy.FedAvg):
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate model weights using weighted average and store checkpoint"""
        # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)

        if aggregated_parameters is not None:
            print(f"Saving round {server_round} aggregated_parameters...")

            # Convert `Parameters` to `List[np.ndarray]`
            aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters)

            # Convert `List[np.ndarray]` to PyTorch `state_dict`
            params_dict = zip(net.state_dict().keys(), aggregated_ndarrays)
            state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
            net.load_state_dict(state_dict, strict=True)

            # Save the model to disk
            torch.save(net.state_dict(), f"model_round_{server_round}.pth")

        return aggregated_parameters, aggregated_metrics
To load your progress, simply append the following lines to your code. Note that this will iterate over all saved checkpoints and load the latest one:
import glob
import os

# Find the most recently written checkpoint and load it into the model
list_of_files = [fname for fname in glob.glob("./model_round_*")]
latest_round_file = max(list_of_files, key=os.path.getctime)
print("Loading pre-trained model from: ", latest_round_file)
state_dict = torch.load(latest_round_file)
net.load_state_dict(state_dict)

# Convert the model weights back into a Flower `Parameters` object
state_dict_ndarrays = [v.cpu().numpy() for v in net.state_dict().values()]
parameters = fl.common.ndarrays_to_parameters(state_dict_ndarrays)
Return/use this object of type Parameters wherever necessary, such as in the initial_parameters when defining a Strategy.
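For example, the loaded parameters can serve as the starting point for the next training run. A minimal sketch, reusing the SaveModelStrategy defined above:

# Start the next run from the loaded checkpoint
strategy = SaveModelStrategy(
    initial_parameters=parameters,
    # (same arguments as FedAvg here)
)
fl.server.start_server(strategy=strategy)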