Federated Learning Evaluation#

There are two main approaches to evaluating models in federated learning systems: centralized (or server-side) evaluation and federated (or client-side) evaluation.

Centralized Evaluation#

Built-in Strategies#

All built-in strategies support centralized evaluation by providing an evaluation function during initialization. An evaluation function is any function that takes the current global model parameters as input and returns evaluation results:

from typing import Dict, Optional, Tuple

import tensorflow as tf

import flwr as fl
from flwr.common import NDArrays, Scalar

def get_evaluate_fn(model):
    """Return an evaluation function for server-side evaluation."""

    # Load data and model here to avoid the overhead of doing it in `evaluate` itself
    (x_train, y_train), _ = tf.keras.datasets.cifar10.load_data()

    # Use the last 5k training examples as a validation set
    x_val, y_val = x_train[45000:50000], y_train[45000:50000]

    # The `evaluate` function will be called after every round
    def evaluate(
        server_round: int, parameters: NDArrays, config: Dict[str, Scalar]
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        model.set_weights(parameters)  # Update model with the latest parameters
        loss, accuracy = model.evaluate(x_val, y_val)
        return loss, {"accuracy": accuracy}

    return evaluate

# Load and compile model for server-side parameter evaluation
model = tf.keras.applications.EfficientNetB0(
    input_shape=(32, 32, 3), weights=None, classes=10
)
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])


# Create strategy
strategy = fl.server.strategy.FedAvg(
    # ... other FedAvg arguments
    evaluate_fn=get_evaluate_fn(model),
)

# Start Flower server for four rounds of federated learning
fl.server.start_server(
    server_address="[::]:8080",
    config=fl.server.ServerConfig(num_rounds=4),
    strategy=strategy,
)

Custom Strategies#

The Strategy abstraction provides a method called evaluate that can be used to evaluate the current global model parameters directly. The server calls evaluate after parameter aggregation and before federated evaluation (see next paragraph).
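The shape of that hook can be sketched as follows. This is a self-contained illustration, not the Flower implementation: the `Strategy` stand-in base class and the `CentralEvalStrategy` name are invented for the sketch; in real code you would override `evaluate` on a subclass of an existing strategy such as `fl.server.strategy.FedAvg`.

```python
from typing import Dict, Optional, Tuple

# Stand-in base class so this sketch runs on its own; a real implementation
# would subclass an existing Flower strategy (e.g., FedAvg) instead.
class Strategy:
    def evaluate(self, server_round, parameters):
        return None  # returning None skips centralized evaluation


class CentralEvalStrategy(Strategy):
    """Illustrative strategy that delegates `evaluate` to a user-supplied function."""

    def __init__(self, evaluate_fn=None):
        self.evaluate_fn = evaluate_fn

    def evaluate(
        self, server_round: int, parameters
    ) -> Optional[Tuple[float, Dict[str, float]]]:
        if self.evaluate_fn is None:
            return None  # no centralized evaluation configured
        # Called once per round: after aggregation, before federated evaluation
        return self.evaluate_fn(server_round, parameters, {})
```

A strategy built this way returns the server-side loss and metrics for each round, or `None` to skip centralized evaluation entirely.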

Federated Evaluation#

Implementing Federated Evaluation#

Client-side evaluation happens in the Client.evaluate method and can be configured from the server side.

class CifarClient(fl.client.NumPyClient):
    def __init__(self, model, x_train, y_train, x_test, y_test):
        self.model = model
        self.x_train, self.y_train = x_train, y_train
        self.x_test, self.y_test = x_test, y_test

    def get_parameters(self, config):
        # ...

    def fit(self, parameters, config):
        # ...

    def evaluate(self, parameters, config):
        """Evaluate parameters on the locally held test set."""

        # Update local model with global parameters
        self.model.set_weights(parameters)

        # Get config values
        steps: int = config["val_steps"]

        # Evaluate global model parameters on the local test data and return results
        loss, accuracy = self.model.evaluate(self.x_test, self.y_test, 32, steps=steps)
        num_examples_test = len(self.x_test)
        return loss, num_examples_test, {"accuracy": accuracy}
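Each client returns its own loss, example count, and metrics dictionary. Built-in strategies aggregate the losses weighted by example count automatically, but custom metrics such as the accuracy above are only aggregated if you supply an aggregation function, e.g. via the strategy's evaluate_metrics_aggregation_fn argument. A minimal sketch of a weighted average over the (num_examples, metrics) pairs:

```python
from typing import Dict, List, Tuple

def weighted_average(metrics: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    """Aggregate client accuracies, weighted by each client's number of examples."""
    total_examples = sum(num_examples for num_examples, _ in metrics)
    weighted_acc = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_acc / total_examples}

# Example: two clients with 10 and 30 test examples
print(weighted_average([(10, {"accuracy": 0.5}), (30, {"accuracy": 0.9})]))
# -> {'accuracy': 0.8}
```

A function with the same shape can be passed as fit_metrics_aggregation_fn to aggregate the metrics returned by fit.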

Configuring Federated Evaluation#

Federated evaluation can be configured from the server side. Built-in strategies support the following arguments:

  • fraction_evaluate: a float defining the fraction of clients that will be selected for evaluation. If fraction_evaluate is set to 0.1 and 100 clients are connected to the server, then 10 clients will be randomly selected for evaluation. If fraction_evaluate is set to 0.0, federated evaluation will be disabled.

  • min_evaluate_clients: an int defining the minimum number of clients to be selected for evaluation. If fraction_evaluate is set to 0.1, min_evaluate_clients is set to 20, and 100 clients are connected to the server, then 20 clients will be selected for evaluation.

  • min_available_clients: an int defining the minimum number of clients that need to be connected to the server before a round of federated evaluation can start. If fewer than min_available_clients are connected, the server will wait until more clients are connected before it continues to sample clients for evaluation.

  • on_evaluate_config_fn: a function that returns a configuration dictionary that will be sent to the selected clients. The function will be called during every round and provides a convenient way to customize client-side evaluation from the server side, for example, to configure the number of validation steps performed.

def evaluate_config(server_round: int):
    """Return evaluation configuration dict for each round.
    Perform five local evaluation steps on each client (i.e., use five
    batches) during rounds, one to three, then increase to ten local
    evaluation steps.
    """
    val_steps = 5 if server_round < 4 else 10
    return {"val_steps": val_steps}

# Create strategy
strategy = fl.server.strategy.FedAvg(
    # ... other FedAvg arguments
    fraction_evaluate=0.2,
    min_evaluate_clients=2,
    min_available_clients=10,
    on_evaluate_config_fn=evaluate_config,
)

# Start Flower server for four rounds of federated learning
fl.server.start_server(
    server_address="[::]:8080",
    config=fl.server.ServerConfig(num_rounds=4),
    strategy=strategy,
)
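To make the interplay of these arguments concrete, here is a small sketch of the sampling arithmetic described above (the helper name is illustrative, not part of the Flower API):

```python
def num_evaluation_clients(
    num_available: int, fraction_evaluate: float, min_evaluate_clients: int
) -> int:
    """Return how many clients a built-in strategy samples for evaluation."""
    if fraction_evaluate == 0.0:
        return 0  # federated evaluation disabled
    # Sample a fraction of the connected clients, but never fewer than the minimum
    return max(int(num_available * fraction_evaluate), min_evaluate_clients)

# With the strategy above (fraction_evaluate=0.2, min_evaluate_clients=2)
# and 100 connected clients, 20 clients are sampled per round:
print(num_evaluation_clients(100, 0.2, 2))  # -> 20
```

Note that min_available_clients gates whether a round starts at all, while fraction_evaluate and min_evaluate_clients determine how many of the connected clients are sampled once it does.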

Evaluating Local Model Updates During Training#

Model parameters can also be evaluated during training. Client.fit can return arbitrary evaluation results as a dictionary:

class CifarClient(fl.client.NumPyClient):
    def __init__(self, model, x_train, y_train, x_test, y_test):
        self.model = model
        self.x_train, self.y_train = x_train, y_train
        self.x_test, self.y_test = x_test, y_test

    def get_parameters(self, config):
        # ...

    def fit(self, parameters, config):
        """Train parameters on the locally held training set."""

        # Update local model parameters
        self.model.set_weights(parameters)

        # Train the model using hyperparameters from config
        history = self.model.fit(
            self.x_train, self.y_train, batch_size=32, epochs=2, validation_split=0.1
        )

        # Return updated model parameters and validation results
        parameters_prime = self.model.get_weights()
        num_examples_train = len(self.x_train)
        results = {
            "loss": history.history["loss"][0],
            "accuracy": history.history["accuracy"][0],
            "val_loss": history.history["val_loss"][0],
            "val_accuracy": history.history["val_accuracy"][0],
        }
        return parameters_prime, num_examples_train, results

    def evaluate(self, parameters, config):
        # ...

Full Code Example#

For a full code example that uses both centralized and federated evaluation, see the *Advanced TensorFlow Example* (the same approach can be applied to workloads implemented in any other framework): https://github.com/adap/flower/tree/main/examples/advanced-tensorflow