# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FedAvg [McMahan et al., 2016] strategy with custom serialization for Android devices.

Paper: https://arxiv.org/abs/1602.05629
"""
from typing import Callable, Dict, List, Optional, Tuple, Union, cast
import numpy as np
from flwr.common import (
EvaluateIns,
EvaluateRes,
FitIns,
FitRes,
NDArray,
NDArrays,
Parameters,
Scalar,
)
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from .aggregate import aggregate, weighted_loss_avg
from .strategy import Strategy
# pylint: disable=line-too-long
class FedAvgAndroid(Strategy):
    """Federated Averaging strategy with a flat float32 wire format.

    Implementation based on https://arxiv.org/abs/1602.05629

    Model parameters are exchanged as the raw bytes of each array, written as
    float32 by ``ndarray_to_bytes`` and read back as flat (1-D) float32 arrays
    by ``bytes_to_ndarray``. Array shapes are not transmitted.

    Parameters
    ----------
    fraction_fit : float
        Fraction of clients used during training. Defaults to 1.0.
    fraction_evaluate : float
        Fraction of clients used during validation. Defaults to 1.0.
        A value of 0.0 disables federated evaluation.
    min_fit_clients : int
        Minimum number of clients used during training. Defaults to 2.
    min_evaluate_clients : int
        Minimum number of clients used during validation. Defaults to 2.
    min_available_clients : int
        Minimum number of total clients in the system. Defaults to 2.
    evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]]
        Optional function used for server-side validation. Defaults to None.
    on_fit_config_fn : Optional[Callable[[int], Dict[str, Scalar]]]
        Function used to configure training. Defaults to None.
    on_evaluate_config_fn : Optional[Callable[[int], Dict[str, Scalar]]]
        Function used to configure validation. Defaults to None.
    accept_failures : bool
        Whether or not to accept rounds containing failures. Defaults to True.
    initial_parameters : Optional[Parameters]
        Initial global model parameters. Defaults to None.
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes
    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, Dict[str, Scalar]],
                Optional[Tuple[float, Dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
    ) -> None:
        super().__init__()
        self.min_fit_clients = min_fit_clients
        self.min_evaluate_clients = min_evaluate_clients
        self.fraction_fit = fraction_fit
        self.fraction_evaluate = fraction_evaluate
        self.min_available_clients = min_available_clients
        self.evaluate_fn = evaluate_fn
        self.on_fit_config_fn = on_fit_config_fn
        self.on_evaluate_config_fn = on_evaluate_config_fn
        self.accept_failures = accept_failures
        self.initial_parameters = initial_parameters

    def __repr__(self) -> str:
        """Compute a string representation of the strategy."""
        # Use the runtime class name so the repr is correct for this class
        # and any subclass.
        return f"{type(self).__name__}(accept_failures={self.accept_failures})"

    def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Return the sample size and the required number of available clients.

        The sample size is ``fraction_fit`` of the available clients, floored,
        but never less than ``min_fit_clients``.
        """
        num_clients = int(num_available_clients * self.fraction_fit)
        return max(num_clients, self.min_fit_clients), self.min_available_clients

    def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Use a fraction of available clients for evaluation.

        Returns the sample size, which is ``fraction_evaluate`` of the
        available clients floored and at least ``min_evaluate_clients``,
        together with ``min_available_clients``.
        """
        num_clients = int(num_available_clients * self.fraction_evaluate)
        return max(num_clients, self.min_evaluate_clients), self.min_available_clients

    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize global model parameters.

        Returns the ``initial_parameters`` given at construction, or None.
        ``client_manager`` is not used.
        """
        initial_parameters = self.initial_parameters
        self.initial_parameters = None  # Don't keep initial parameters in memory
        return initial_parameters

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate model parameters using an evaluation function.

        Returns None when no ``evaluate_fn`` was configured or when it
        returns None. Otherwise returns its ``(loss, metrics)`` pair.
        """
        if self.evaluate_fn is None:
            # No evaluation function provided
            return None
        ndarrays = self.parameters_to_ndarrays(parameters)
        eval_res = self.evaluate_fn(server_round, ndarrays, {})
        if eval_res is None:
            return None
        loss, metrics = eval_res
        return loss, metrics

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training.

        Samples clients according to ``num_fit_clients``. Every sampled client
        receives the same ``FitIns``, which holds the current parameters and
        the config from ``on_fit_config_fn`` (empty if none was given).
        """
        config: Dict[str, Scalar] = {}
        if self.on_fit_config_fn is not None:
            # Custom fit config function provided
            config = self.on_fit_config_fn(server_round)
        fit_ins = FitIns(parameters, config)

        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )
        return [(client, fit_ins) for client in clients]

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation.

        Returns an empty list when ``fraction_evaluate`` is 0.0, which disables
        federated evaluation. Otherwise it samples clients according to
        ``num_evaluation_clients``, and every sampled client receives the same
        ``EvaluateIns``.
        """
        # Checked explicitly because num_evaluation_clients would still return
        # at least min_evaluate_clients for a zero fraction.
        if self.fraction_evaluate == 0.0:
            return []

        config: Dict[str, Scalar] = {}
        if self.on_evaluate_config_fn is not None:
            # Custom evaluation config function provided
            config = self.on_evaluate_config_fn(server_round)
        evaluate_ins = EvaluateIns(parameters, config)

        sample_size, min_num_clients = self.num_evaluation_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )
        return [(client, evaluate_ins) for client in clients]

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using weighted average.

        Each client's update is weighted by its ``num_examples``. Returns
        ``(None, {})`` when there are no results, or when there are failures
        and ``accept_failures`` is False.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}
        # Convert results into (ndarrays, num_examples) pairs for `aggregate`
        weights_results = [
            (self.parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]
        return self.ndarrays_to_parameters(aggregate(weights_results)), {}

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using weighted average.

        Each client's loss is weighted by its ``num_examples``. Returns
        ``(None, {})`` when there are no results, or when there are failures
        and ``accept_failures`` is False.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}
        loss_aggregated = weighted_loss_avg(
            [
                (evaluate_res.num_examples, evaluate_res.loss)
                for _, evaluate_res in results
            ]
        )
        return loss_aggregated, {}

    def ndarrays_to_parameters(self, ndarrays: NDArrays) -> Parameters:
        """Convert NumPy ndarrays to parameters object.

        Each array is serialized with ``ndarray_to_bytes``. The result is tagged
        with ``tensor_type="numpy.nda"``.
        """
        tensors = [self.ndarray_to_bytes(ndarray) for ndarray in ndarrays]
        return Parameters(tensors=tensors, tensor_type="numpy.nda")

    def parameters_to_ndarrays(self, parameters: Parameters) -> NDArrays:
        """Convert parameters object to NumPy ndarrays.

        Each tensor is decoded with ``bytes_to_ndarray``, so the results are
        flat float32 arrays.
        """
        return [self.bytes_to_ndarray(tensor) for tensor in parameters.tensors]

    def ndarray_to_bytes(self, ndarray: NDArray) -> bytes:
        """Serialize NumPy array to bytes.

        The array is first converted to native-byte-order float32. This
        matches ``bytes_to_ndarray``, which always decodes float32. Only the
        element values are written, in C order; the shape is not preserved.
        """
        # asarray is a no-op (no copy) for arrays that are already float32.
        return np.asarray(ndarray, dtype=np.float32).tobytes()

    def bytes_to_ndarray(self, tensor: bytes) -> NDArray:
        """Deserialize NumPy array from bytes.

        Interprets ``tensor`` as a flat float32 buffer. The returned 1-D array
        shares memory with ``tensor`` and is read-only when ``tensor`` is an
        immutable ``bytes`` object.
        """
        ndarray_deserialized = np.frombuffer(tensor, dtype=np.float32)
        return cast(NDArray, ndarray_deserialized)