实施策略#
策略抽象类可以实现完全定制的策略。策略基本上就是在服务器上运行的联邦学习算法。策略决定如何对客户端进行采样、如何配置客户端进行训练、如何聚合参数更新以及如何评估模型。Flower 提供了一些内置策略,这些策略基于下文所述的相同 API。
:code:`策略 ` 抽象类#
所有策略实现均源自抽象基类 flwr.server.strategy.Strategy
,包括内置实现和第三方实现。这意味着自定义策略实现与内置实现具有完全相同的功能。
策略抽象定义了一些需要实现的抽象方法:
class Strategy(ABC):
"""Abstract base class for server strategy implementations."""
@abstractmethod
def initialize_parameters(
self, client_manager: ClientManager
) -> Optional[Parameters]:
"""Initialize the (global) model parameters."""
@abstractmethod
def configure_fit(
self,
server_round: int,
parameters: Parameters,
client_manager: ClientManager
) -> List[Tuple[ClientProxy, FitIns]]:
"""Configure the next round of training."""
@abstractmethod
def aggregate_fit(
self,
server_round: int,
results: List[Tuple[ClientProxy, FitRes]],
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
"""Aggregate training results."""
@abstractmethod
def configure_evaluate(
self,
server_round: int,
parameters: Parameters,
client_manager: ClientManager
) -> List[Tuple[ClientProxy, EvaluateIns]]:
"""Configure the next round of evaluation."""
@abstractmethod
def aggregate_evaluate(
self,
server_round: int,
results: List[Tuple[ClientProxy, EvaluateRes]],
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> Tuple[Optional[float], Dict[str, Scalar]]:
"""Aggregate evaluation results."""
@abstractmethod
def evaluate(
self, parameters: Parameters
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
"""Evaluate the current model parameters."""
创建一个新策略意味着要实现一个新的 class`(从抽象基类 :code:`Strategy
派生),该类要实现前面显示的抽象方法:
class SotaStrategy(Strategy):
def initialize_parameters(self, client_manager):
# Your implementation here
def configure_fit(self, server_round, parameters, client_manager):
# Your implementation here
def aggregate_fit(self, server_round, results, failures):
# Your implementation here
def configure_evaluate(self, server_round, parameters, client_manager):
# Your implementation here
def aggregate_evaluate(self, server_round, results, failures):
# Your implementation here
def evaluate(self, parameters):
# Your implementation here
Flower 服务器按以下顺序调用这些方法:
下文将详细介绍每种方法。
初始化参数
方法#
initialize_parameters
只调用一次,即在执行开始时。它负责以序列化形式(即 Parameters
对象)提供初始全局模型参数。
内置策略会返回用户提供的初始参数。下面的示例展示了如何将初始参数传递给 FedAvg
:
import flwr as fl
import tensorflow as tf
# Load model for server-side parameter initialization
model = tf.keras.applications.EfficientNetB0(
input_shape=(32, 32, 3), weights=None, classes=10
)
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
# Get model weights as a list of NumPy ndarray's
weights = model.get_weights()
# Serialize ndarrays to `Parameters`
parameters = fl.common.ndarrays_to_parameters(weights)
# Use the serialized parameters as the initial global parameters
strategy = fl.server.strategy.FedAvg(
initial_parameters=parameters,
)
fl.server.start_server(config=fl.server.ServerConfig(num_rounds=3), strategy=strategy)
Flower 服务器将调用 initialize_parameters
,返回传给 initial_parameters
的参数或 None
。如果 initialize_parameters
没有返回任何参数(即 None
),服务器将随机选择一个客户端并要求其提供参数。这只是一个便捷的功能,在实际应用中并不推荐使用,但在原型开发中可能很有用。在实践中,建议始终使用服务器端参数初始化。
备注
服务器端参数初始化是一种强大的机制。例如,它可以用来从先前保存的检查点恢复训练。它也是实现混合方法所需的基本能力,例如,使用联邦学习对预先训练好的模型进行微调。
:code:`configure_fit`方法#
configure_fit
负责配置即将开始的一轮训练。*配置*在这里是什么意思?配置一轮训练意味着选择客户并决定向这些客户发送什么指令。configure_fit
说明了这一点:
@abstractmethod
def configure_fit(
self,
server_round: int,
parameters: Parameters,
client_manager: ClientManager
) -> List[Tuple[ClientProxy, FitIns]]:
"""Configure the next round of training."""
返回值是一个元组列表,每个元组代表将发送到特定客户端的指令。策略实现通常在 configure_fit
中执行以下步骤:
使用
client_manager
随机抽样所有(或部分)可用客户端(每个客户端都表示为ClientProxy
对象)将每个
ClientProxy
与持有当前全局模型parameters
和config
dict 的FitIns
配对
More sophisticated implementations can use configure_fit
to implement custom client selection logic. A client will only participate in a round if the corresponding ClientProxy
is included in the list returned from configure_fit
.
备注
该返回值的结构为用户提供了很大的灵活性。由于指令是按客户端定义的,因此可以向每个客户端发送不同的指令。这使得自定义策略成为可能,例如在不同的客户端上训练不同的模型,或在不同的客户端上使用不同的超参数(通过 config
dict)。
aggregate_fit
方法#
aggregate_fit
负责汇总在 configure_fit
中选择并要求训练的客户端所返回的结果。
@abstractmethod
def aggregate_fit(
self,
server_round: int,
results: List[Tuple[ClientProxy, FitRes]],
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
"""Aggregate training results."""
当然,失败是有可能发生的,因此无法保证服务器会从它发送指令(通过 configure_fit
)的所有客户端获得结果。因此 aggregate_fit
会收到 results
的列表,但也会收到 failures
的列表。
aggregate_fit
返回一个可选的 Parameters
对象和一个聚合度量的字典。Parameters
返回值是可选的,因为 aggregate_fit
可能会认为所提供的结果不足以进行聚合(例如,失败次数过多)。
:code:`configure_evaluate`方法#
configure_evaluate
负责配置下一轮评估。*配置*在这里是什么意思?配置一轮评估意味着选择客户端并决定向这些客户端发送什么指令。configure_evaluate
说明了这一点:
@abstractmethod
def configure_evaluate(
self,
server_round: int,
parameters: Parameters,
client_manager: ClientManager
) -> List[Tuple[ClientProxy, EvaluateIns]]:
"""Configure the next round of evaluation."""
返回值是一个元组列表,每个元组代表将发送到特定客户端的指令。策略实现通常在 configure_evaluate
中执行以下步骤:
使用
client_manager
随机抽样所有(或部分)可用客户端(每个客户端都表示为ClientProxy
对象)将每个
ClientProxy
与持有当前全局模型parameters
和config
dict 的EvaluateIns
配对
More sophisticated implementations can use configure_evaluate
to implement custom client selection logic. A client will only participate in a round if the corresponding ClientProxy
is included in the list returned from configure_evaluate
.
备注
该返回值的结构为用户提供了很大的灵活性。由于指令是按客户端定义的,因此可以向每个客户端发送不同的指令。这使得自定义策略可以在不同客户端上评估不同的模型,或在不同客户端上使用不同的超参数(通过 config
dict)。
aggregate_evaluate
方法#
aggregate_evaluate
负责汇总在 configure_evaluate
中选择并要求评估的客户端返回的结果。
@abstractmethod
def aggregate_evaluate(
self,
server_round: int,
results: List[Tuple[ClientProxy, EvaluateRes]],
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> Tuple[Optional[float], Dict[str, Scalar]]:
"""Aggregate evaluation results."""
当然,失败是有可能发生的,因此无法保证服务器会从它发送指令(通过 configure_evaluate
)的所有客户端获得结果。因此, aggregate_evaluate
会接收 results
的列表,但也会接收 failures
的列表。
aggregate_evaluate
返回一个可选的 float`(损失值)和一个聚合指标字典。:code:`float
返回值是可选的,因为 aggregate_evaluate
可能会认为所提供的结果不足以进行聚合(例如,失败次数过多)。
:code:`evaluate`方法#
evaluate
负责在服务器端评估模型参数。除了 configure_evaluate
/aggregate_evaluate
之外,evaluate
可以使策略同时执行服务器端和客户端(联邦)评估。
@abstractmethod
def evaluate(
self, parameters: Parameters
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
"""Evaluate the current model parameters."""
返回值也是可选的,因为策略可能不需要执行服务器端评估,或者因为用户定义的 evaluate
方法可能无法成功完成(例如,它可能无法加载服务器端评估数据)。