Code source de flwr.server.strategy.fedxgb_cyclic

# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Federated XGBoost cyclic aggregation strategy."""


from logging import WARNING
from typing import Any, Dict, List, Optional, Tuple, Union, cast

from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar
from flwr.common.logger import log
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy

from .fedavg import FedAvg


class FedXgbCyclic(FedAvg):
    """Configurable FedXgbCyclic strategy implementation.

    Cyclic strategy for federated XGBoost.

    In every round a single client is selected, based on the round number,
    for fit and for evaluate. The model bytes returned by the fit client
    replace the stored global model; no averaging takes place.
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
    def __init__(
        self,
        **kwargs: Any,
    ):
        # Serialized model most recently returned by a client in
        # ``aggregate_fit``; ``None`` until a fit result has been received.
        self.global_model: Optional[bytes] = None
        super().__init__(**kwargs)

    def __repr__(self) -> str:
        """Compute a string representation of the strategy."""
        rep = f"FedXgbCyclic(accept_failures={self.accept_failures})"
        return rep

    @staticmethod
    def _select_cyclic_client(
        server_round: int, clients: List[ClientProxy]
    ) -> List[ClientProxy]:
        """Return a one-element list holding the client for this round.

        The client at index ``(server_round - 1) % len(clients)`` is chosen.
        An empty list is returned when ``clients`` is empty, so callers skip
        the round instead of raising ``ZeroDivisionError``.

        NOTE(review): this only visits clients in a fixed cycle if
        ``client_manager.sample`` returns them in a stable order across
        rounds -- confirm for the client manager in use.
        """
        if not clients:
            return []
        sampled_idx = (server_round - 1) % len(clients)
        return [clients[sampled_idx]]

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Adopt the model returned by the client trained this round.

        No averaging is performed. The last tensor seen while iterating over
        ``results`` replaces ``self.global_model``. With the cyclic sampling
        in ``configure_fit``, ``results`` is expected to hold a single entry.

        Parameters
        ----------
        server_round : int
            Current server round (not used by this strategy).
        results : List[Tuple[ClientProxy, FitRes]]
            Successful fit results from the sampled client(s).
        failures : List[Union[Tuple[ClientProxy, FitRes], BaseException]]
            Fit attempts that failed.

        Returns
        -------
        Tuple[Optional[Parameters], Dict[str, Scalar]]
            The new global model wrapped in ``Parameters`` plus an empty
            metrics dict. Returns ``(None, {})`` when there are no results,
            or when failures occurred and ``accept_failures`` is false.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}

        # Fetch the client model from last round as global model: each
        # tensor overwrites the previous one, so the last one wins.
        for _, fit_res in results:
            for bst in fit_res.parameters.tensors:
                self.global_model = bst

        return (
            Parameters(tensor_type="", tensors=[cast(bytes, self.global_model)]),
            {},
        )

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation metrics with the user-supplied function.

        The loss value returned is always ``0``; client-reported losses are
        not aggregated. Metrics are aggregated only when
        ``evaluate_metrics_aggregation_fn`` is set; otherwise a warning is
        logged in round 1 and an empty metrics dict is returned.

        Returns
        -------
        Tuple[Optional[float], Dict[str, Scalar]]
            ``(0, metrics)`` on success. Returns ``(None, {})`` when there are
            no results, or when failures occurred and ``accept_failures`` is
            false.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated: Dict[str, Scalar] = {}
        if self.evaluate_metrics_aggregation_fn:
            eval_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No evaluate_metrics_aggregation_fn provided")

        return 0, metrics_aggregated

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training.

        Samples clients from ``client_manager`` and then keeps exactly one of
        them, chosen by ``server_round``. Returns an empty list when no
        client could be sampled.
        """
        config: Dict[str, Scalar] = {}
        if self.on_fit_config_fn is not None:
            # Custom fit config function provided
            config = self.on_fit_config_fn(server_round)
        fit_ins = FitIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size,
            min_num_clients=min_num_clients,
        )

        # Sample the clients sequentially given server_round
        sampled_clients = self._select_cyclic_client(server_round, clients)

        # Return client/config pairs
        return [(client, fit_ins) for client in sampled_clients]

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation.

        Returns an empty list when ``fraction_evaluate`` is ``0.0`` or when no
        client could be sampled. Otherwise one sampled client, chosen by
        ``server_round``, is returned.
        """
        # Do not configure federated evaluation if fraction eval is 0.
        if self.fraction_evaluate == 0.0:
            return []

        # Parameters and config
        config: Dict[str, Scalar] = {}
        if self.on_evaluate_config_fn is not None:
            # Custom evaluation config function provided
            config = self.on_evaluate_config_fn(server_round)
        evaluate_ins = EvaluateIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_evaluation_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size,
            min_num_clients=min_num_clients,
        )

        # Sample the clients sequentially given server_round
        sampled_clients = self._select_cyclic_client(server_round, clients)

        # Return client/config pairs
        return [(client, evaluate_ins) for client in sampled_clients]