Source code for flwr.server.client_manager

# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower ClientManager."""


import random
import threading
from abc import ABC, abstractmethod
from logging import INFO
from typing import Dict, List, Optional

from flwr.common.logger import log

from .client_proxy import ClientProxy
from .criterion import Criterion


class ClientManager(ABC):
    """Interface that every Flower client manager implements.

    A client manager keeps track of ``ClientProxy`` objects for connected
    clients and lets callers register, unregister, wait for, list, and
    sample them. All methods are abstract; concrete subclasses supply the
    behavior.
    """

    @abstractmethod
    def num_available(self) -> int:
        """Count the clients that are currently available.

        Returns
        -------
        num_available : int
            How many clients are available right now.
        """

    @abstractmethod
    def register(self, client: ClientProxy) -> bool:
        """Add a Flower ClientProxy to the managed pool.

        Parameters
        ----------
        client : flwr.server.client_proxy.ClientProxy
            The client to add.

        Returns
        -------
        success : bool
            True when the client was registered. False when the ClientProxy
            was already registered or could not be registered for any other
            reason.
        """

    @abstractmethod
    def unregister(self, client: ClientProxy) -> None:
        """Remove a Flower ClientProxy from the managed pool.

        Implementations must be idempotent: unregistering the same client
        more than once is allowed.

        Parameters
        ----------
        client : flwr.server.client_proxy.ClientProxy
            The client to remove.
        """

    @abstractmethod
    def all(self) -> Dict[str, ClientProxy]:
        """Return every available client.

        Returns
        -------
        clients : Dict[str, ClientProxy]
            Mapping from a string client identifier to its ClientProxy.
        """

    @abstractmethod
    def wait_for(self, num_clients: int, timeout: int) -> bool:
        """Block until at least `num_clients` clients are available.

        Parameters
        ----------
        num_clients : int
            Minimum number of available clients to wait for.
        timeout : int
            Upper bound on how long to wait (seconds in the bundled
            SimpleClientManager; other implementations should match).

        Returns
        -------
        success : bool
            Presumably True if enough clients became available before the
            timeout, False otherwise -- confirm for each implementation.
        """

    @abstractmethod
    def sample(
        self,
        num_clients: int,
        min_num_clients: Optional[int] = None,
        criterion: Optional[Criterion] = None,
    ) -> List[ClientProxy]:
        """Select a number of Flower ClientProxy instances.

        Parameters
        ----------
        num_clients : int
            Number of clients to return.
        min_num_clients : Optional[int]
            Minimum number of clients that should be available before
            sampling. How None is interpreted is up to the implementation.
        criterion : Optional[Criterion]
            Optional filter that restricts which clients may be sampled.

        Returns
        -------
        clients : List[ClientProxy]
            The sampled clients.
        """
class SimpleClientManager(ClientManager):
    """Provides a pool of available clients.

    Registered clients live in ``self.clients``, keyed by ``ClientProxy.cid``.
    Every mutation of that mapping, and the snapshot taken by ``sample``,
    happens while holding the lock of ``self._cv``. Each membership change
    calls ``notify_all`` so that threads blocked in ``wait_for`` re-check
    their condition.
    """

    def __init__(self) -> None:
        self.clients: Dict[str, ClientProxy] = {}
        # Guards `self.clients` and lets `wait_for` block until enough clients
        # have registered. A `threading.Condition()` created without a lock
        # argument is backed by an RLock, so the owning thread may re-acquire
        # it.
        self._cv = threading.Condition()

    def __len__(self) -> int:
        """Return the number of available clients.

        Returns
        -------
        num_available : int
            The number of currently available clients.
        """
        return len(self.clients)

    def num_available(self) -> int:
        """Return the number of available clients.

        Returns
        -------
        num_available : int
            The number of currently available clients.
        """
        return len(self)

    def wait_for(self, num_clients: int, timeout: int = 86400) -> bool:
        """Wait until at least `num_clients` are available.

        Blocks until the requested number of clients is available or until a
        timeout is reached. Current timeout default: 1 day.

        Parameters
        ----------
        num_clients : int
            The number of clients to wait for.
        timeout : int
            The time in seconds to wait for, defaults to 86400 (24h).

        Returns
        -------
        success : bool
            True if at least `num_clients` clients were registered when the
            wait ended, False if the timeout expired first.
        """
        with self._cv:
            return self._cv.wait_for(
                lambda: len(self.clients) >= num_clients, timeout=timeout
            )

    def register(self, client: ClientProxy) -> bool:
        """Register Flower ClientProxy instance.

        Parameters
        ----------
        client : flwr.server.client_proxy.ClientProxy
            The client to add, stored under its ``cid``.

        Returns
        -------
        success : bool
            Indicating if registration was successful. False if ClientProxy is
            already registered or can not be registered for any reason.
        """
        with self._cv:
            # Check and insert under one lock acquisition so that two
            # concurrent registrations of the same `cid` cannot both succeed.
            if client.cid in self.clients:
                return False
            self.clients[client.cid] = client
            # Wake threads blocked in `wait_for` so they re-check the count.
            self._cv.notify_all()
        return True

    def unregister(self, client: ClientProxy) -> None:
        """Unregister Flower ClientProxy instance.

        This method is idempotent: unregistering a client whose ``cid`` is not
        registered is a no-op.

        Parameters
        ----------
        client : flwr.server.client_proxy.ClientProxy
            The client to remove, looked up by its ``cid``.
        """
        with self._cv:
            if client.cid in self.clients:
                del self.clients[client.cid]
                self._cv.notify_all()

    def all(self) -> Dict[str, ClientProxy]:
        """Return all available clients.

        Returns
        -------
        clients : Dict[str, ClientProxy]
            The manager's internal ``cid -> ClientProxy`` mapping, not a copy.
            It changes as clients register and unregister.
        """
        return self.clients

    def sample(
        self,
        num_clients: int,
        min_num_clients: Optional[int] = None,
        criterion: Optional[Criterion] = None,
    ) -> List[ClientProxy]:
        """Sample a number of Flower ClientProxy instances.

        Blocks (via ``wait_for``) until at least ``min_num_clients`` clients
        are registered. It then optionally filters the registered clients
        with ``criterion.select`` and draws ``num_clients`` of them at random
        without replacement (``random.sample``).

        Parameters
        ----------
        num_clients : int
            Number of clients to return.
        min_num_clients : Optional[int]
            Number of clients to wait for before sampling. Defaults to
            ``num_clients`` when None.
        criterion : Optional[Criterion]
            If given, only clients for which ``criterion.select(client)`` is
            truthy are eligible.

        Returns
        -------
        clients : List[ClientProxy]
            The sampled clients. Returns an empty list, and logs an INFO
            message, if fewer than ``num_clients`` eligible clients are
            available.

        Notes
        -----
        The boolean returned by ``wait_for`` is ignored. If the wait times
        out, sampling proceeds with whatever clients are registered.
        """
        # Block until at least min_num_clients are connected.
        if min_num_clients is None:
            min_num_clients = num_clients
        self.wait_for(min_num_clients)

        # Take a consistent snapshot. Other threads may unregister clients at
        # any time, and indexing `self.clients` with a cid that disappeared
        # after listing the keys would raise KeyError. `criterion.select` is
        # then evaluated outside the lock.
        with self._cv:
            clients = dict(self.clients)

        # Sample clients which meet the criterion
        available_cids = list(clients)
        if criterion is not None:
            available_cids = [
                cid for cid in available_cids if criterion.select(clients[cid])
            ]

        if num_clients > len(available_cids):
            log(
                INFO,
                "Sampling failed: number of available clients"
                " (%s) is less than number of requested clients (%s).",
                len(available_cids),
                num_clients,
            )
            return []

        sampled_cids = random.sample(available_cids, num_clients)
        return [clients[cid] for cid in sampled_cids]