# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MergeResplitter class for Flower Datasets."""
import collections
import warnings
from functools import reduce
from typing import Dict, List, Tuple
import datasets
from datasets import Dataset, DatasetDict
[docs]class MergeResplitter:
"""Merge existing splits of the dataset and assign them custom names.
Create new `DatasetDict` with new split names corresponding to the merged existing
splits (e.g. "train", "valid" and "test").
Parameters
----------
merge_config : Dict[str, Tuple[str, ...]]
Dictionary with keys - the desired split names to values - tuples of the current
split names that will be merged together
Examples
--------
Create new `DatasetDict` with a split name "new_train" that is created as a merger
of the "train" and "valid" splits. Keep the "test" split.
>>> # Assuming there is a dataset_dict of type `DatasetDict`
>>> # dataset_dict is {"train": train-data, "valid": valid-data, "test": test-data}
>>> merge_resplitter = MergeResplitter(
>>> merge_config={
>>> "new_train": ("train", "valid"),
>>> "test": ("test", )
>>> }
>>> )
>>> new_dataset_dict = merge_resplitter(dataset_dict)
>>> # new_dataset_dict is
>>> # {"new_train": concatenation of train-data and valid-data, "test": test-data}
"""
def __init__(
self,
merge_config: Dict[str, Tuple[str, ...]],
) -> None:
self._merge_config: Dict[str, Tuple[str, ...]] = merge_config
self._check_duplicate_merge_splits()
def __call__(self, dataset: DatasetDict) -> DatasetDict:
"""Resplit the dataset according to the `merge_config`."""
self._check_correct_keys_in_merge_config(dataset)
return self.resplit(dataset)
[docs] def resplit(self, dataset: DatasetDict) -> DatasetDict:
"""Resplit the dataset according to the `merge_config`."""
resplit_dataset = {}
for divide_to, divided_from__list in self._merge_config.items():
datasets_from_list: List[Dataset] = []
for divide_from in divided_from__list:
datasets_from_list.append(dataset[divide_from])
if len(datasets_from_list) > 1:
resplit_dataset[divide_to] = datasets.concatenate_datasets(
datasets_from_list
)
else:
resplit_dataset[divide_to] = datasets_from_list[0]
return datasets.DatasetDict(resplit_dataset)
def _check_correct_keys_in_merge_config(self, dataset: DatasetDict) -> None:
"""Check if the keys in merge_config are existing dataset splits."""
dataset_keys = dataset.keys()
specified_dataset_keys = self._merge_config.values()
for key_list in specified_dataset_keys:
for key in key_list:
if key not in dataset_keys:
raise ValueError(
f"The given dataset key '{key}' is not present in the given "
f"dataset object. Make sure to use only the keywords that are "
f"available in your dataset."
)
def _check_duplicate_merge_splits(self) -> None:
"""Check if the original splits are duplicated for new splits creation."""
merge_splits = reduce(lambda x, y: x + y, self._merge_config.values())
duplicates = [
item
for item, count in collections.Counter(merge_splits).items()
if count > 1
]
if duplicates:
warnings.warn(
f"More than one desired splits used '{duplicates[0]}' in "
f"`merge_config`. Make sure that is the intended behavior.",
stacklevel=1,
)