Aggregate evaluation results

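The Flower server does not prescribe a way to aggregate evaluation results, but it lets the user fully customize how results are aggregated.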

Aggregate custom evaluation results

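The same Strategy customization approach can be used to aggregate custom evaluation results coming from individual clients. Clients can return custom metrics to the server by returning a dictionary: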

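import flwr as fl


class CifarClient(fl.client.NumPyClient):

    def get_parameters(self, config):
        # ...
        ...  # placeholder so the elided method body still parses

    def fit(self, parameters, config):
        # ...
        ...  # placeholder so the elided method body still parses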

    def evaluate(self, parameters, config):
        """Evaluate parameters on the locally held test set."""

        # Update local model with global parameters
        self.model.set_weights(parameters)

        # Evaluate global model parameters on the local test data
        loss, accuracy = self.model.evaluate(self.x_test, self.y_test)

        # Return results, including the custom accuracy metric
        num_examples_test = len(self.x_test)
        return loss, num_examples_test, {"accuracy": accuracy}

The server can then use a customized strategy to aggregate the metrics provided in these dictionaries:

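from typing import Dict, List, Optional, Tuple, Union

import flwr as fl
from flwr.common import EvaluateRes, Scalar
from flwr.server.client_proxy import ClientProxy


class AggregateCustomMetricStrategy(fl.server.strategy.FedAvg):
    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]: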
        """Aggregate evaluation accuracy using weighted average."""

        if not results:
            return None, {}

        # Call aggregate_evaluate from base class (FedAvg) to aggregate loss and metrics
        aggregated_loss, aggregated_metrics = super().aggregate_evaluate(server_round, results, failures)

        # Weigh accuracy of each client by number of examples used
        accuracies = [r.metrics["accuracy"] * r.num_examples for _, r in results]
        examples = [r.num_examples for _, r in results]

        # Aggregate and print custom metric
        aggregated_accuracy = sum(accuracies) / sum(examples)
        print(f"Round {server_round} accuracy aggregated from client results: {aggregated_accuracy}")

        # Return aggregated loss and metrics (i.e., aggregated accuracy)
        return aggregated_loss, {"accuracy": aggregated_accuracy}

# Create strategy and run server
strategy = AggregateCustomMetricStrategy(
    # (same arguments as FedAvg here)
)
fl.server.start_server(strategy=strategy)
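
To make the weighting concrete, the following is a minimal, framework-free sketch of the same arithmetic that `aggregate_evaluate` performs on the `accuracy` metric. It does not use Flower at all: the `ClientResult` alias, the `weighted_accuracy` helper, and the three sample results are hypothetical names and made-up numbers chosen for illustration only. The explicit zero-example check is an added assumption; the strategy above does not include it.

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for what each client reports: (num_examples, metrics).
ClientResult = Tuple[int, Dict[str, float]]


def weighted_accuracy(results: List[ClientResult]) -> float:
    """Average client accuracies, weighting each by its number of test examples."""
    total_examples = sum(num_examples for num_examples, _ in results)
    if total_examples == 0:
        # No client reported any examples, so there is nothing to average.
        raise ValueError("cannot aggregate accuracy over zero examples")
    weighted_sum = sum(
        num_examples * metrics["accuracy"] for num_examples, metrics in results
    )
    return weighted_sum / total_examples


# Made-up results for three clients with different test set sizes.
results = [
    (100, {"accuracy": 0.90}),
    (300, {"accuracy": 0.80}),
    (600, {"accuracy": 0.70}),
]

weighted = weighted_accuracy(results)
unweighted = sum(metrics["accuracy"] for _, metrics in results) / len(results)
print(f"weighted: {weighted:.2f}, unweighted: {unweighted:.2f}")
# prints "weighted: 0.75, unweighted: 0.80"
```

With these numbers the weighted value is lower than the plain mean because the client holding the most test examples also has the lowest accuracy. This is why the strategy multiplies each accuracy by `num_examples` before dividing by the total.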