ShardPartitioner#
- class ShardPartitioner(num_partitions: int, partition_by: str, num_shards_per_partition: int | None = None, shard_size: int | None = None, keep_incomplete_shard: bool = False, shuffle: bool = True, seed: int | None = 42)[source]#
Bases:
Partitioner
Partitioner based on shards, each typically holding samples of a single unique class.
The algorithm works as follows: the dataset is sorted by label, e.g. [samples with label 1, samples with label 2, …]. Then the shards are created, with each shard of size shard_size if provided, or of a size calculated automatically as shard_size = floor(len(dataset) / (num_partitions * num_shards_per_partition)).
A shard is just a block (chunk) of the dataset that contains shard_size consecutive samples. There might be shards that contain samples associated with more than a single unique label. The first case (remember that the preprocessing step sorts the dataset by label) is when a shard is constructed from samples at the boundary between two classes in the sorted dataset and therefore belongs to different classes, e.g. the “leftover” samples of class 1 followed by samples of class 2. The other scenario in which a shard has samples with more than one unique label is when the shard size is bigger than the number of samples of a certain class.
Each partition is created from num_shards_per_partition shards that are chosen randomly.
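For intuition, here is a minimal sketch of the automatic shard-size calculation described above (an illustration assuming the 60,000-sample MNIST training split, not part of the API):

>>> import math
>>>
>>> # Automatic shard size when only num_shards_per_partition is given,
>>> # assuming the 60,000-sample MNIST training split
>>> num_partitions, num_shards_per_partition = 10, 2
>>> shard_size = math.floor(60_000 / (num_partitions * num_shards_per_partition))
>>> print(shard_size)
3000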
There are a few ways of partitioning data that result in certain properties (depending on the parameter specification):

1) the same number of shards per partition + the same shard size (specify: a) num_shards_per_partition and shard_size; or b) num_shards_per_partition only, in which case shard_size is calculated as floor(len(dataset) / (num_shards_per_partition * num_partitions)))

2) a possibly different number of shards per partition (use nearly all data) + the same shard size (specify: shard_size + keep_incomplete_shard=False)

3) a possibly different number of shards per partition (use all data) + a possibly different shard size (specify: shard_size + keep_incomplete_shard=True)
Algorithm based on the description in Communication-Efficient Learning of Deep Networks from Decentralized Data (https://arxiv.org/abs/1602.05629). This implementation expands on the initial idea by enabling the specification of more hyperparameters, therefore providing more control over how partitions are created. It can also reproduce the division used in the original paper.
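For example, the shard setup from that paper (100 clients, each receiving 2 shards of 300 samples from the label-sorted MNIST training set) can be expressed as follows; this is a sketch, not output-verified:

>>> from flwr_datasets import FederatedDataset
>>> from flwr_datasets.partitioner import ShardPartitioner
>>>
>>> # 100 clients x 2 shards x 300 samples = 60,000 MNIST training samples
>>> partitioner = ShardPartitioner(num_partitions=100, partition_by="label",
>>>                                num_shards_per_partition=2, shard_size=300)
>>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
>>> partition = fds.load_partition(0)  # 600 samples, typically from ~2 classes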
- Parameters:
num_partitions (int) – The total number of partitions that the data will be divided into.
partition_by (str) – Column name of the labels (targets) based on which the sharding is performed.
num_shards_per_partition (Optional[int]) – Number of shards to assign to a single partition. If shard_size is not specified, it is computed automatically from this value.
shard_size (Optional[int]) – Size of a single shard (a partition has one or more shards). If the size is not given, it will be computed automatically.
keep_incomplete_shard (bool) – Whether to keep the last shard, which might be incomplete (smaller than the others). If it is dropped, each shard has equal size. (This does not mean that each partition gets an equal number of shards, which happens only if num_shards % num_partitions == 0.) This parameter has no effect if both num_shards_per_partition and shard_size are specified.
shuffle (bool) – Whether to randomize the order of samples. Shuffling is applied after the samples are assigned to partitions.
seed (Optional[int]) – Seed used for dataset shuffling. It has no effect if shuffle is False (see the determinism sketch right after this list).
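As a small illustrative sketch of the shuffle/seed behavior (assuming seeded shuffling is deterministic, so the same seed yields the same partition contents across runs; first_labels is a hypothetical helper, not part of the API):

>>> from flwr_datasets import FederatedDataset
>>> from flwr_datasets.partitioner import ShardPartitioner
>>>
>>> def first_labels(seed):
>>>     partitioner = ShardPartitioner(num_partitions=10, partition_by="label",
>>>                                    shard_size=1_000, shuffle=True, seed=seed)
>>>     fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
>>>     return fds.load_partition(0)["label"][:5]
>>>
>>> print(first_labels(42) == first_labels(42))  # identical with the same seed
True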
Examples
1) If you need the same number of shards per partition + the same shard size (and you know both of these values):
>>> from flwr_datasets import FederatedDataset
>>> from flwr_datasets.partitioner import ShardPartitioner
>>>
>>> partitioner = ShardPartitioner(num_partitions=10, partition_by="label",
>>>                                num_shards_per_partition=2, shard_size=1_000)
>>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
>>> partition = fds.load_partition(0)
>>> print(partition[0])  # Print the first example
{'image': <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x15F616C50>, 'label': 3}
>>> partition_sizes = [
>>>     len(fds.load_partition(partition_id)) for partition_id in range(10)
>>> ]
>>> print(partition_sizes)
[2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000]
2) If you want to use nearly all the data and do not need the number of shards per partition to be the same:
>>> from flwr_datasets import FederatedDataset
>>> from flwr_datasets.partitioner import ShardPartitioner
>>>
>>> partitioner = ShardPartitioner(num_partitions=9, partition_by="label",
>>>                                shard_size=1_000)
>>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
>>> partition_sizes = [
>>>     len(fds.load_partition(partition_id)) for partition_id in range(9)
>>> ]
>>> print(partition_sizes)
[7000, 7000, 7000, 7000, 7000, 7000, 6000, 6000, 6000]
3) If you want to use all the data:
>>> from flwr_datasets import FederatedDataset
>>> from flwr_datasets.partitioner import ShardPartitioner
>>>
>>> partitioner = ShardPartitioner(num_partitions=10, partition_by="label",
>>>                                shard_size=990, keep_incomplete_shard=True)
>>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
>>> partition_sizes = [
>>>     len(fds.load_partition(partition_id)) for partition_id in range(10)
>>> ]
>>> print(sorted(partition_sizes))
[5550, 5940, 5940, 5940, 5940, 5940, 5940, 5940, 5940, 6930]
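The uneven sizes above follow directly from the shard arithmetic; a short sketch of the reasoning (assuming the 60,000-sample MNIST training split):

>>> # 60,000 samples with shard_size=990 yield 60 full shards plus one
>>> # incomplete shard that is kept because keep_incomplete_shard=True
>>> full_shards, leftover = divmod(60_000, 990)
>>> print(full_shards, leftover)
60 600
>>> # 61 shards over 10 partitions:
>>> #   8 partitions x 6 full shards          -> 6 * 990       = 5940
>>> #   1 partition  x 7 full shards          -> 7 * 990       = 6930
>>> #   1 partition  x 5 full + the leftover  -> 5 * 990 + 600 = 5550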
Methods
is_dataset_assigned() – Check if a dataset has been assigned to the partitioner.
load_partition(partition_id) – Load a partition based on the partition index.
Attributes
dataset – Dataset property.
num_partitions – Total number of partitions.
- property dataset: Dataset#
Dataset property.
- is_dataset_assigned() → bool#
Check if a dataset has been assigned to the partitioner.
This method returns True if a dataset is already set for the partitioner, otherwise, it returns False.
- Returns:
dataset_assigned – True if a dataset is assigned, otherwise False.
- Return type:
bool
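A minimal usage sketch, assuming the dataset property accepts direct assignment of a Hugging Face Dataset (assignment normally happens internally when FederatedDataset loads a partition):

>>> from datasets import load_dataset
>>> from flwr_datasets.partitioner import ShardPartitioner
>>>
>>> partitioner = ShardPartitioner(num_partitions=10, partition_by="label",
>>>                                num_shards_per_partition=2, shard_size=1_000)
>>> print(partitioner.is_dataset_assigned())
False
>>> partitioner.dataset = load_dataset("mnist", split="train")  # assign directly
>>> print(partitioner.is_dataset_assigned())
True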
- load_partition(partition_id: int) → Dataset[source]#
Load a partition based on the partition index.
- Parameters:
partition_id (int) – The index that corresponds to the requested partition.
- Returns:
dataset_partition – A single partition of the dataset.
- Return type:
Dataset
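Because the dataset is sorted by label before sharding, a loaded partition typically concentrates on a few classes; a sketch for inspecting this (building on the first example above):

>>> from collections import Counter
>>> from flwr_datasets import FederatedDataset
>>> from flwr_datasets.partitioner import ShardPartitioner
>>>
>>> partitioner = ShardPartitioner(num_partitions=10, partition_by="label",
>>>                                num_shards_per_partition=2, shard_size=1_000)
>>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
>>> partition = fds.load_partition(0)
>>> # Each partition is built from 2 shards of the label-sorted dataset,
>>> # so it usually contains samples from only a few unique labels
>>> print(Counter(partition["label"]))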
- property num_partitions: int#
Total number of partitions.