Source code for flwr.server.server

# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower server."""


import concurrent.futures
import io
import timeit
from logging import INFO, WARN
from typing import Dict, List, Optional, Tuple, Union

from flwr.common import (
    Code,
    DisconnectRes,
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    Parameters,
    ReconnectIns,
    Scalar,
)
from flwr.common.logger import log
from flwr.common.typing import GetParametersIns
from flwr.server.client_manager import ClientManager, SimpleClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.history import History
from flwr.server.strategy import FedAvg, Strategy

from .server_config import ServerConfig

FitResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, FitRes]],
    List[Union[Tuple[ClientProxy, FitRes], BaseException]],
]
EvaluateResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, EvaluateRes]],
    List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
]
ReconnectResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, DisconnectRes]],
    List[Union[Tuple[ClientProxy, DisconnectRes], BaseException]],
]
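
These aliases capture the result-handling convention used throughout this module: each concurrent round yields a list of successful (client, response) pairs plus a list of failures, where a failure is either an exception raised by a client or a (client, response) pair whose status code is not OK. A minimal sketch of consuming such a pair (the helper name is illustrative, not part of this module):

def summarize_fit_outcome(outcome: FitResultsAndFailures) -> None:
    """Illustrative only: print a one-line summary per client."""
    results, failures = outcome
    for client_proxy, fit_res in results:
        # Successful replies always carry Code.OK
        # (see _handle_finished_future_after_fit below)
        print(f"{client_proxy.cid}: trained on {fit_res.num_examples} examples")
    for failure in failures:
        if isinstance(failure, BaseException):
            print(f"client raised: {failure!r}")
        else:
            client_proxy, fit_res = failure
            print(f"{client_proxy.cid}: returned status {fit_res.status.code}")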


class Server:
    """Flower server."""

    def __init__(
        self,
        *,
        client_manager: ClientManager,
        strategy: Optional[Strategy] = None,
    ) -> None:
        self._client_manager: ClientManager = client_manager
        self.parameters: Parameters = Parameters(
            tensors=[], tensor_type="numpy.ndarray"
        )
        self.strategy: Strategy = strategy if strategy is not None else FedAvg()
        self.max_workers: Optional[int] = None

    def set_max_workers(self, max_workers: Optional[int]) -> None:
        """Set the max_workers used by ThreadPoolExecutor."""
        self.max_workers = max_workers

    def set_strategy(self, strategy: Strategy) -> None:
        """Replace server strategy."""
        self.strategy = strategy

    def client_manager(self) -> ClientManager:
        """Return ClientManager."""
        return self._client_manager

    # pylint: disable=too-many-locals
    def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]:
        """Run federated averaging for a number of rounds."""
        history = History()

        # Initialize parameters
        log(INFO, "[INIT]")
        self.parameters = self._get_initial_parameters(server_round=0, timeout=timeout)
        log(INFO, "Evaluating initial global parameters")
        res = self.strategy.evaluate(0, parameters=self.parameters)
        if res is not None:
            log(
                INFO,
                "initial parameters (loss, other metrics): %s, %s",
                res[0],
                res[1],
            )
            history.add_loss_centralized(server_round=0, loss=res[0])
            history.add_metrics_centralized(server_round=0, metrics=res[1])

        # Run federated learning for num_rounds
        start_time = timeit.default_timer()

        for current_round in range(1, num_rounds + 1):
            log(INFO, "")
            log(INFO, "[ROUND %s]", current_round)

            # Train model and replace previous global model
            res_fit = self.fit_round(
                server_round=current_round,
                timeout=timeout,
            )
            if res_fit is not None:
                parameters_prime, fit_metrics, _ = res_fit  # fit_metrics_aggregated
                if parameters_prime:
                    self.parameters = parameters_prime
                history.add_metrics_distributed_fit(
                    server_round=current_round, metrics=fit_metrics
                )

            # Evaluate model using strategy implementation
            res_cen = self.strategy.evaluate(current_round, parameters=self.parameters)
            if res_cen is not None:
                loss_cen, metrics_cen = res_cen
                log(
                    INFO,
                    "fit progress: (%s, %s, %s, %s)",
                    current_round,
                    loss_cen,
                    metrics_cen,
                    timeit.default_timer() - start_time,
                )
                history.add_loss_centralized(server_round=current_round, loss=loss_cen)
                history.add_metrics_centralized(
                    server_round=current_round, metrics=metrics_cen
                )

            # Evaluate model on a sample of available clients
            res_fed = self.evaluate_round(server_round=current_round, timeout=timeout)
            if res_fed is not None:
                loss_fed, evaluate_metrics_fed, _ = res_fed
                if loss_fed is not None:
                    history.add_loss_distributed(
                        server_round=current_round, loss=loss_fed
                    )
                    history.add_metrics_distributed(
                        server_round=current_round, metrics=evaluate_metrics_fed
                    )

        # Bookkeeping
        end_time = timeit.default_timer()
        elapsed = end_time - start_time
        return history, elapsed

    def evaluate_round(
        self,
        server_round: int,
        timeout: Optional[float],
    ) -> Optional[
        Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures]
    ]:
        """Validate current global model on a number of clients."""
        # Get clients and their respective instructions from strategy
        client_instructions = self.strategy.configure_evaluate(
            server_round=server_round,
            parameters=self.parameters,
            client_manager=self._client_manager,
        )
        if not client_instructions:
            log(INFO, "configure_evaluate: no clients selected, skipping evaluation")
            return None
        log(
            INFO,
            "configure_evaluate: strategy sampled %s clients (out of %s)",
            len(client_instructions),
            self._client_manager.num_available(),
        )

        # Collect `evaluate` results from all clients participating in this round
        results, failures = evaluate_clients(
            client_instructions,
            max_workers=self.max_workers,
            timeout=timeout,
            group_id=server_round,
        )
        log(
            INFO,
            "aggregate_evaluate: received %s results and %s failures",
            len(results),
            len(failures),
        )

        # Aggregate the evaluation results
        aggregated_result: Tuple[
            Optional[float],
            Dict[str, Scalar],
        ] = self.strategy.aggregate_evaluate(server_round, results, failures)

        loss_aggregated, metrics_aggregated = aggregated_result
        return loss_aggregated, metrics_aggregated, (results, failures)

    def fit_round(
        self,
        server_round: int,
        timeout: Optional[float],
    ) -> Optional[
        Tuple[Optional[Parameters], Dict[str, Scalar], FitResultsAndFailures]
    ]:
        """Perform a single round of federated averaging."""
        # Get clients and their respective instructions from strategy
        client_instructions = self.strategy.configure_fit(
            server_round=server_round,
            parameters=self.parameters,
            client_manager=self._client_manager,
        )
        if not client_instructions:
            log(INFO, "configure_fit: no clients selected, cancel")
            return None
        log(
            INFO,
            "configure_fit: strategy sampled %s clients (out of %s)",
            len(client_instructions),
            self._client_manager.num_available(),
        )

        # Collect `fit` results from all clients participating in this round
        results, failures = fit_clients(
            client_instructions=client_instructions,
            max_workers=self.max_workers,
            timeout=timeout,
            group_id=server_round,
        )
        log(
            INFO,
            "aggregate_fit: received %s results and %s failures",
            len(results),
            len(failures),
        )

        # Aggregate training results
        aggregated_result: Tuple[
            Optional[Parameters],
            Dict[str, Scalar],
        ] = self.strategy.aggregate_fit(server_round, results, failures)

        parameters_aggregated, metrics_aggregated = aggregated_result
        return parameters_aggregated, metrics_aggregated, (results, failures)

    def disconnect_all_clients(self, timeout: Optional[float]) -> None:
        """Send shutdown signal to all clients."""
        all_clients = self._client_manager.all()
        clients = [all_clients[k] for k in all_clients.keys()]
        instruction = ReconnectIns(seconds=None)
        client_instructions = [(client_proxy, instruction) for client_proxy in clients]
        _ = reconnect_clients(
            client_instructions=client_instructions,
            max_workers=self.max_workers,
            timeout=timeout,
        )

    def _get_initial_parameters(
        self, server_round: int, timeout: Optional[float]
    ) -> Parameters:
        """Get initial parameters from one of the available clients."""
        # Server-side parameter initialization
        parameters: Optional[Parameters] = self.strategy.initialize_parameters(
            client_manager=self._client_manager
        )
        if parameters is not None:
            log(INFO, "Using initial global parameters provided by strategy")
            return parameters

        # Get initial parameters from one of the clients
        log(INFO, "Requesting initial parameters from one random client")
        random_client = self._client_manager.sample(1)[0]
        ins = GetParametersIns(config={})
        get_parameters_res = random_client.get_parameters(
            ins=ins, timeout=timeout, group_id=server_round
        )
        log(INFO, "Received initial parameters from one random client")
        return get_parameters_res.parameters
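
A minimal usage sketch of the class above, assuming clients have already connected and registered with the client manager (in a real deployment this wiring is performed by `flwr.server.start_server`):

# Illustrative sketch, not part of the original module
client_manager = SimpleClientManager()
server = Server(client_manager=client_manager, strategy=FedAvg())
server.set_max_workers(8)  # bound the per-round ThreadPoolExecutor fan-out
history, elapsed = server.fit(num_rounds=3, timeout=60.0)
server.disconnect_all_clients(timeout=60.0)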

def reconnect_clients(
    client_instructions: List[Tuple[ClientProxy, ReconnectIns]],
    max_workers: Optional[int],
    timeout: Optional[float],
) -> ReconnectResultsAndFailures:
    """Instruct clients to disconnect and never reconnect."""
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        submitted_fs = {
            executor.submit(reconnect_client, client_proxy, ins, timeout)
            for client_proxy, ins in client_instructions
        }
        finished_fs, _ = concurrent.futures.wait(
            fs=submitted_fs,
            timeout=None,  # Handled in the respective communication stack
        )

    # Gather results
    results: List[Tuple[ClientProxy, DisconnectRes]] = []
    failures: List[Union[Tuple[ClientProxy, DisconnectRes], BaseException]] = []
    for future in finished_fs:
        failure = future.exception()
        if failure is not None:
            failures.append(failure)
        else:
            result = future.result()
            results.append(result)
    return results, failures


def reconnect_client(
    client: ClientProxy,
    reconnect: ReconnectIns,
    timeout: Optional[float],
) -> Tuple[ClientProxy, DisconnectRes]:
    """Instruct client to disconnect and (optionally) reconnect later."""
    disconnect = client.reconnect(
        reconnect,
        timeout=timeout,
        group_id=None,
    )
    return client, disconnect


def fit_clients(
    client_instructions: List[Tuple[ClientProxy, FitIns]],
    max_workers: Optional[int],
    timeout: Optional[float],
    group_id: int,
) -> FitResultsAndFailures:
    """Refine parameters concurrently on all selected clients."""
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        submitted_fs = {
            executor.submit(fit_client, client_proxy, ins, timeout, group_id)
            for client_proxy, ins in client_instructions
        }
        finished_fs, _ = concurrent.futures.wait(
            fs=submitted_fs,
            timeout=None,  # Handled in the respective communication stack
        )

    # Gather results
    results: List[Tuple[ClientProxy, FitRes]] = []
    failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = []
    for future in finished_fs:
        _handle_finished_future_after_fit(
            future=future, results=results, failures=failures
        )
    return results, failures


def fit_client(
    client: ClientProxy, ins: FitIns, timeout: Optional[float], group_id: int
) -> Tuple[ClientProxy, FitRes]:
    """Refine parameters on a single client."""
    fit_res = client.fit(ins, timeout=timeout, group_id=group_id)
    return client, fit_res


def _handle_finished_future_after_fit(
    future: concurrent.futures.Future,  # type: ignore
    results: List[Tuple[ClientProxy, FitRes]],
    failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> None:
    """Convert finished future into either a result or a failure."""
    # Check if there was an exception
    failure = future.exception()
    if failure is not None:
        failures.append(failure)
        return

    # Successfully received a result from a client
    result: Tuple[ClientProxy, FitRes] = future.result()
    _, res = result

    # Check result status code
    if res.status.code == Code.OK:
        results.append(result)
        return

    # Not successful, client returned a result where the status code is not OK
    failures.append(result)


def evaluate_clients(
    client_instructions: List[Tuple[ClientProxy, EvaluateIns]],
    max_workers: Optional[int],
    timeout: Optional[float],
    group_id: int,
) -> EvaluateResultsAndFailures:
    """Evaluate parameters concurrently on all selected clients."""
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        submitted_fs = {
            executor.submit(evaluate_client, client_proxy, ins, timeout, group_id)
            for client_proxy, ins in client_instructions
        }
        finished_fs, _ = concurrent.futures.wait(
            fs=submitted_fs,
            timeout=None,  # Handled in the respective communication stack
        )

    # Gather results
    results: List[Tuple[ClientProxy, EvaluateRes]] = []
    failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = []
    for future in finished_fs:
        _handle_finished_future_after_evaluate(
            future=future, results=results, failures=failures
        )
    return results, failures


def evaluate_client(
    client: ClientProxy,
    ins: EvaluateIns,
    timeout: Optional[float],
    group_id: int,
) -> Tuple[ClientProxy, EvaluateRes]:
    """Evaluate parameters on a single client."""
    evaluate_res = client.evaluate(ins, timeout=timeout, group_id=group_id)
    return client, evaluate_res


def _handle_finished_future_after_evaluate(
    future: concurrent.futures.Future,  # type: ignore
    results: List[Tuple[ClientProxy, EvaluateRes]],
    failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
) -> None:
    """Convert finished future into either a result or a failure."""
    # Check if there was an exception
    failure = future.exception()
    if failure is not None:
        failures.append(failure)
        return

    # Successfully received a result from a client
    result: Tuple[ClientProxy, EvaluateRes] = future.result()
    _, res = result

    # Check result status code
    if res.status.code == Code.OK:
        results.append(result)
        return

    # Not successful, client returned a result where the status code is not OK
    failures.append(result)


def init_defaults(
    server: Optional[Server],
    config: Optional[ServerConfig],
    strategy: Optional[Strategy],
    client_manager: Optional[ClientManager],
) -> Tuple[Server, ServerConfig]:
    """Create server instance if none was given."""
    if server is None:
        if client_manager is None:
            client_manager = SimpleClientManager()
        if strategy is None:
            strategy = FedAvg()
        server = Server(client_manager=client_manager, strategy=strategy)
    elif strategy is not None:
        log(WARN, "Both server and strategy were provided, ignoring strategy")

    # Set default config values
    if config is None:
        config = ServerConfig()

    return server, config


def run_fl(
    server: Server,
    config: ServerConfig,
) -> History:
    """Train a model on the given server and return the History object."""
    hist, elapsed_time = server.fit(
        num_rounds=config.num_rounds, timeout=config.round_timeout
    )

    log(INFO, "")
    log(INFO, "[SUMMARY]")
    log(INFO, "Run finished %s rounds in %.2fs", config.num_rounds, elapsed_time)
    for idx, line in enumerate(io.StringIO(str(hist))):
        if idx == 0:
            log(INFO, "%s", line.strip("\n"))
        else:
            log(INFO, "\t%s", line.strip("\n"))
    log(INFO, "")

    # Graceful shutdown
    server.disconnect_all_clients(timeout=config.round_timeout)

    return hist
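
For completeness, a sketch of how the two helpers above compose; this is essentially the sequence `flwr.server.start_server` runs internally, with every `None` argument filled in by `init_defaults`:

# Illustrative sketch, not part of the original module
server, config = init_defaults(
    server=None,  # a fresh Server is created
    config=ServerConfig(num_rounds=3, round_timeout=60.0),
    strategy=None,  # falls back to FedAvg()
    client_manager=None,  # falls back to SimpleClientManager()
)
history = run_fl(server=server, config=config)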