Source code for flwr_datasets.partitioner.size_partitioner

# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SizePartitioner class."""


from typing import Callable, Dict, List, Union

import numpy as np

import datasets
from flwr_datasets.partitioner.partitioner import Partitioner


class SizePartitioner(Partitioner):
    """Base class for the deterministic size partitioning based on the `partition_id`.

    The client with `partition_id` has the following relationship regarding the
    number of samples.

    `partition_id_to_size_fn(partition_id)` ~ number of samples for `partition_id`

    If the function doesn't transform the `partition_id` it's a linear correlation
    between the number of samples for the partition and the value of
    `partition_id`. For instance, if the partition ids range from 1 to M, the
    partition with id 1 gets 1 unit of data, client 2 gets 2 units, and so on, up
    to partition M which gets M units.

    Note that the size corresponding to the `partition_id` is deterministic, yet in
    case of different dataset shuffling the assignment of samples to
    `partition_id` will vary.

    The size function is evaluated on the values ``1..num_partitions``, while the
    resulting partitions are stored under the zero-based keys
    ``0..num_partitions - 1``.

    Parameters
    ----------
    num_partitions : int
        The total number of partitions that the data will be divided into.
    partition_id_to_size_fn : Callable
        Function that defines the relationship between partition id and the number
        of samples. It is called once with a NumPy array holding the values
        ``1..num_partitions`` and must return an array of per-partition units.

    Raises
    ------
    ValueError
        If `num_partitions` is not greater than zero.
    """

    def __init__(
        self,
        num_partitions: int,
        partition_id_to_size_fn: Callable,  # type: ignore[type-arg]
    ) -> None:
        super().__init__()
        if num_partitions <= 0:
            raise ValueError("The number of partitions must be greater than zero.")
        self._num_partitions = num_partitions
        self._partition_id_to_size_fn = partition_id_to_size_fn

        self._partition_id_to_size: Dict[int, int] = {}
        self._partition_id_to_indices: Dict[int, List[int]] = {}
        # A flag to perform only a single compute to determine the indices
        self._partition_id_to_indices_determined = False

    def load_partition(self, partition_id: int) -> datasets.Dataset:
        """Load a single partition based on the partition index.

        The number of samples is dependent on the partition `partition_id`.

        Parameters
        ----------
        partition_id : int
            the index that corresponds to the requested partition

        Returns
        -------
        dataset_partition : Dataset
            single dataset partition
        """
        # The partitioning is done lazily - only when the first partition is
        # requested. A single run creates the indices assignments for all the
        # partition indices.
        self._determine_partition_id_to_indices_if_needed()
        return self.dataset.select(self._partition_id_to_indices[partition_id])

    @property
    def num_partitions(self) -> int:
        """Total number of partitions.

        Accessing this property triggers the (one-off) computation of the
        partition sizes and indices, so it may raise the same `ValueError` as
        that computation.
        """
        self._determine_partition_id_to_indices_if_needed()
        return self._num_partitions

    @property
    def partition_id_to_size(self) -> Dict[int, int]:
        """Partition id to the number of samples.

        The mapping is empty until the sizes have been determined.
        """
        return self._partition_id_to_size

    @property
    def partition_id_to_indices(self) -> Dict[int, List[int]]:
        """Partition id to the list of indices.

        The mapping is empty until the indices have been determined.
        """
        return self._partition_id_to_indices

    def _determine_partition_id_to_size(self) -> None:
        """Determine data quantity associated with partition indices.

        Raises
        ------
        ValueError
            If any partition would end up with fewer than one sample.
        """
        data_division_in_units = self._partition_id_to_size_fn(
            np.linspace(start=1, stop=self._num_partitions, num=self._num_partitions)
        )
        total_units: Union[int, float] = data_division_in_units.sum()
        # Normalize the units to get the fraction of the total dataset
        partition_sizes_as_fraction = data_division_in_units / total_units
        # Calculate the number of samples (the integer cast truncates)
        partition_sizes_as_num_of_samples = np.array(
            partition_sizes_as_fraction * len(self.dataset), dtype=np.int64
        )
        # Check if any sample is not allocated because of the truncation above.
        assigned_samples = np.sum(partition_sizes_as_num_of_samples)
        left_unassigned_samples = len(self.dataset) - assigned_samples
        # Any leftover samples go to the LAST partition. That is the largest one
        # only when the size function is increasing in the partition id.
        partition_sizes_as_num_of_samples[-1] += left_unassigned_samples
        # `tolist()` yields built-in `int` values, so the mapping really is a
        # `Dict[int, int]` and NumPy scalar types do not leak to callers.
        sizes = partition_sizes_as_num_of_samples.tolist()
        for idx, partition_size in enumerate(sizes):
            self._partition_id_to_size[idx] = partition_size
        self._check_if_partition_id_to_size_possible()

    def _determine_partition_id_to_indices_if_needed(self) -> None:
        """Create an assignment of indices to the partition indices.

        Each partition receives a contiguous range of dataset indices, in
        partition-id order. The work is done once; later calls return early.
        """
        if self._partition_id_to_indices_determined:
            return
        self._determine_partition_id_to_size()
        total_samples_assigned = 0
        for idx, quantity in self._partition_id_to_size.items():
            self._partition_id_to_indices[idx] = list(
                range(total_samples_assigned, total_samples_assigned + quantity)
            )
            total_samples_assigned += quantity
        # Set only after success so a failed attempt is retried on the next call.
        self._partition_id_to_indices_determined = True

    def _check_if_partition_id_to_size_possible(self) -> None:
        """Verify that every partition has at least one sample.

        Raises
        ------
        ValueError
            If any computed partition size is smaller than one.
        """
        all_positive = all(value >= 1 for value in self.partition_id_to_size.values())
        if not all_positive:
            raise ValueError(
                f"The given specification of the parameter num_partitions"
                f"={self._num_partitions} for the given dataset results "
                f"in the partitions sizes that are not greater than 0."
            )