# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for FederatedDataset."""
import math
import warnings
from typing import Dict, List, Optional, Tuple, Union, cast

from datasets import Dataset, DatasetDict, concatenate_datasets

from flwr_datasets.partitioner import IidPartitioner, Partitioner
from flwr_datasets.resplitter import Resplitter
from flwr_datasets.resplitter.merge_resplitter import MergeResplitter
# Dataset identifiers that `_check_if_dataset_tested` accepts without emitting
# a warning; any other identifier triggers a warning there.
tested_datasets = [
    "mnist",
    "cifar10",
    "fashion_mnist",
    "sasha/dog-food",
    "zh-plus/tiny-imagenet",
    "scikit-learn/adult-census-income",
    "cifar100",
    "svhn",
    "sentiment140",
    "speech_commands",
]
def _instantiate_partitioners(
    partitioners: Dict[str, Union[Partitioner, int]]
) -> Dict[str, Partitioner]:
    """Turn every value of `partitioners` into a `Partitioner` instance.

    Parameters
    ----------
    partitioners : Dict[str, Union[Partitioner, int]]
        Mapping from a dataset split name to either a ready `Partitioner` or an
        int, which is used as `num_partitions` of a new `IidPartitioner`.

    Returns
    -------
    partitioners : Dict[str, Partitioner]
        Mapping from each split name to its `Partitioner` object.

    Raises
    ------
    ValueError
        If `partitioners` is not a dict, or if one of its values is neither a
        `Partitioner` nor an int.
    """
    # Reject a wrong container type up front so the main loop stays flat.
    if not isinstance(partitioners, Dict):
        raise ValueError(
            "Incorrect type of the 'partitioners' encountered. "
            "Expected Dict[str, Union[int, Partitioner]]. "
            f"Given {type(partitioners)}."
        )

    result: Dict[str, Partitioner] = {}
    for split_name, spec in partitioners.items():
        if isinstance(spec, Partitioner):
            # Already instantiated: keep the very same object.
            result[split_name] = spec
            continue
        if isinstance(spec, int):
            # An int is shorthand for that many IID partitions.
            result[split_name] = IidPartitioner(num_partitions=spec)
            continue
        raise ValueError(
            "Incorrect type of the 'partitioners' value encountered. "
            f"Expected Partitioner or int. Given {type(spec)}"
        )
    return result
def _instantiate_resplitter_if_needed(
    resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]]
) -> Optional[Resplitter]:
    """Wrap a non-empty merge-config dict in a `MergeResplitter`.

    Any other value (including `None` and an empty dict) is handed back
    unchanged, only re-typed for the type checker.
    """
    # Falsy values and non-dict objects are passed through as they are.
    if not resplitter or not isinstance(resplitter, Dict):
        return cast(Optional[Resplitter], resplitter)
    return MergeResplitter(merge_config=resplitter)
def _check_if_dataset_tested(dataset: str) -> None:
    """Warn when `dataset` is not listed in `tested_datasets`.

    Nothing is raised: an unlisted dataset only produces a warning.
    """
    if dataset in tested_datasets:
        return
    message = f"The currently tested dataset are {tested_datasets}. Given: {dataset}."
    warnings.warn(message, stacklevel=1)
def divide_dataset(
    dataset: Dataset, division: Union[List[float], Tuple[float, ...], Dict[str, float]]
) -> Union[List[Dataset], DatasetDict]:
    """Divide the dataset according to the `division`.

    The division supports a varying number of splits, which you can name. The
    splits are consecutive, non-overlapping index ranges created from the
    beginning of the dataset.

    Parameters
    ----------
    dataset : Dataset
        Dataset to be divided.
    division : Union[List[float], Tuple[float, ...], Dict[str, float]]
        Configuration specifying how the dataset is divided. Each fraction has to be
        >0 and <=1. They have to sum up to at most 1 (smaller sum is possible).

    Returns
    -------
    divided_dataset : Union[List[Dataset], DatasetDict]
        If `division` is `List` or `Tuple` then `List[Dataset]` is returned else if
        `division` is `Dict` then `DatasetDict` is returned.

    Raises
    ------
    TypeError
        If `division` is not a list, tuple or dict (normally already reported by
        `_check_division_config_correctness`, which also validates the fractions).

    Examples
    --------
    Use `divide_dataset` with division specified as a list.

    >>> from flwr_datasets import FederatedDataset
    >>> from flwr_datasets.utils import divide_dataset
    >>>
    >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})
    >>> partition = fds.load_partition(0)
    >>> division = [0.8, 0.2]
    >>> train, test = divide_dataset(dataset=partition, division=division)

    Use `divide_dataset` with division specified as a dict.

    >>> from flwr_datasets import FederatedDataset
    >>> from flwr_datasets.utils import divide_dataset
    >>>
    >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})
    >>> partition = fds.load_partition(0)
    >>> division = {"train": 0.8, "test": 0.2}
    >>> train_test = divide_dataset(dataset=partition, division=division)
    >>> train, test = train_test["train"], train_test["test"]
    """
    _check_division_config_correctness(division)
    ranges = _create_division_indices_ranges(len(dataset), division)

    if isinstance(division, (list, tuple)):
        return [dataset.select(single_range) for single_range in ranges]

    if isinstance(division, dict):
        # Ranges are produced in the dict's iteration order, so zipping the keys
        # with them pairs each split name with its own index range.
        return DatasetDict(
            {
                split_name: dataset.select(single_range)
                for split_name, single_range in zip(division, ranges)
            }
        )

    raise TypeError(
        f"The type of the `division` should be dict, "
        f"tuple or list but is {type(division)} instead."
    )
def _create_division_indices_ranges(
dataset_length: int,
division: Union[List[float], Tuple[float, ...], Dict[str, float]],
) -> List[range]:
ranges = []
if isinstance(division, (list, tuple)):
start_idx = 0
end_idx = 0
for fraction in division:
end_idx += int(dataset_length * fraction)
ranges.append(range(start_idx, end_idx))
start_idx = end_idx
elif isinstance(division, dict):
ranges = []
start_idx = 0
end_idx = 0
for fraction in division.values():
end_idx += int(dataset_length * fraction)
ranges.append(range(start_idx, end_idx))
start_idx = end_idx
else:
TypeError(
f"The type of the `division` should be dict, "
f"tuple or list but is {type(division)} instead. "
)
return ranges
def _check_division_config_types_correctness(
division: Union[List[float], Tuple[float, ...], Dict[str, float]]
) -> None:
if isinstance(division, (list, tuple)):
if not all(isinstance(x, float) for x in division):
raise TypeError(
"List or tuple values of `division` must contain only floats, "
"other types are not allowed."
)
elif isinstance(division, dict):
if not all(isinstance(x, float) for x in division.values()):
raise TypeError(
"Dict values of `division` must be only floats, "
"other types are not allowed."
)
else:
raise TypeError("`division` must be a list, tuple, or dict.")
def _check_division_config_values_correctness(
division: Union[List[float], Tuple[float, ...], Dict[str, float]]
) -> None:
if isinstance(division, (list, tuple)):
if not all(0 < x <= 1 for x in division):
raise ValueError(
"All fractions for the division must be greater than 0 and smaller or "
"equal to 1."
)
fraction_sum_from_list_tuple = sum(division)
if fraction_sum_from_list_tuple > 1:
raise ValueError("Sum of fractions for division must not exceed 1.")
if fraction_sum_from_list_tuple < 1:
warnings.warn(
f"Sum of fractions for division is {sum(division)}, which is below 1. "
f"Make sure that's the desired behavior. Some data will not be used "
f"in the current specification.",
stacklevel=1,
)
elif isinstance(division, dict):
values = list(division.values())
if not all(0 < x <= 1 for x in values):
raise ValueError(
"All fractions must be greater than 0 and smaller or equal to 1."
)
if sum(values) > 1:
raise ValueError("Sum of fractions must not exceed 1.")
if sum(values) < 1:
warnings.warn(
f"Sum of fractions in `division` is {values}, which is below 1. "
f"Make sure that's the desired behavior. Some data will not be used "
f"in the current specification.",
stacklevel=1,
)
else:
raise TypeError("`division` must be a list, tuple, or dict.")
def _check_division_config_correctness(
    division: Union[List[float], Tuple[float, ...], Dict[str, float]]
) -> None:
    """Validate `division`: first the element types, then the fraction values.

    Any exception raised by the two underlying checks propagates to the caller.
    """
    # Order matters: the value check compares fractions numerically, so the
    # type check has to run before it.
    validators = (
        _check_division_config_types_correctness,
        _check_division_config_values_correctness,
    )
    for validate in validators:
        validate(division)
def concatenate_divisions(
    partitioner: Partitioner,
    partition_division: Union[List[float], Tuple[float, ...], Dict[str, float]],
    division_id: Union[int, str],
) -> Dataset:
    """Create a dataset by concatenation of all partitions in the same division.

    The divisions are created based on the `partition_division` and accessed based
    on the `division_id`. It can be used to create e.g. centralized dataset from
    federated on-edge test sets.

    Parameters
    ----------
    partitioner : Partitioner
        Partitioner object with assigned dataset.
    partition_division : Union[List[float], Tuple[float, ...], Dict[str, float]]
        Fractions specifying the division of the partitions of a `partitioner`. You can
        think of this as on-edge division of the data into multiple divisions
        (e.g. into train and validation). E.g. [0.8, 0.2] or
        {"partition_train": 0.8, "partition_test": 0.2}.
    division_id : Union[int, str]
        The way to access the division (from a List or DatasetDict). If your
        `partition_division` is specified as a list, then `division_id` represents an
        index to an element in that list. If `partition_division` is passed as a
        `Dict`, then `division_id` is a key of such dictionary.

    Returns
    -------
    concatenated_divisions : Dataset
        A dataset created as concatenation of the divisions from all partitions.

    Raises
    ------
    TypeError
        If `partition_division` is a list or tuple and `division_id` is not an
        int, or if `partition_division` is not a list, tuple or dict.
    ValueError
        If the selected division is empty in every partition, so that the
        concatenated dataset would have length 0.
    """
    _check_division_config_correctness(partition_division)
    divisions: List[Dataset] = []
    zero_len_divisions = 0
    for partition_id in range(partitioner.num_partitions):
        partition = partitioner.load_partition(partition_id)
        if isinstance(partition_division, (list, tuple)):
            if not isinstance(division_id, int):
                raise TypeError(
                    "The `division_id` needs to be an int in case of "
                    "`partition_division` specification as List."
                )
        elif not isinstance(partition_division, dict):
            raise TypeError(
                "The type of partition needs to be List of DatasetDict in this "
                "context."
            )
        # `divide_dataset` returns a list for list/tuple input and a DatasetDict
        # for dict input; both are indexed with `division_id`.
        division = divide_dataset(partition, partition_division)[division_id]
        if len(division) == 0:
            zero_len_divisions += 1
        divisions.append(division)

    if zero_len_divisions == partitioner.num_partitions:
        raise ValueError(
            "The concatenated dataset is of length 0. Please change the "
            "`partition_division` parameter to change this behavior."
        )
    if zero_len_divisions != 0:
        warnings.warn(
            f"{zero_len_divisions} division(s) have length zero.", stacklevel=1
        )
    return concatenate_datasets(divisions)