Source code for mmdet.datasets.samplers.custom_sample_size_sampler
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Iterator, Optional, Sequence, Sized

import torch
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler

from mmdet.registry import DATA_SAMPLERS

from .class_aware_sampler import RandomCycleIter


@DATA_SAMPLERS.register_module()
class CustomSampleSizeSampler(Sampler):
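    """Sampler that draws a custom number of samples from each sub-dataset.

    ``dataset`` is expected to be a concatenated dataset exposing a
    ``datasets`` attribute (e.g. ``ConcatDataset``). For each sub-dataset,
    ``dataset_size`` specifies either ``-1`` (use every sample) or a fixed
    number of samples to draw per epoch; with ``ratio_mode=True`` the value
    is instead interpreted as a fraction of the sub-dataset length.

    Args:
        dataset (Sized): The concatenated dataset to sample from.
        dataset_size (Sequence[int]): Number of samples to draw from each
            sub-dataset, where ``-1`` means "use the whole sub-dataset".
        ratio_mode (bool): Whether to interpret each ``dataset_size`` entry
            as a ratio of the corresponding sub-dataset length. Defaults to
            False.
        seed (int, optional): Random seed shared by all processes. If None,
            a synchronized random seed is used. Defaults to None.
        round_up (bool): Whether to pad the index list so that it is evenly
            divisible by ``world_size``. Defaults to True.
    """
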
def __init__(self,
dataset: Sized,
dataset_size: Sequence[int],
ratio_mode: bool = False,
seed: Optional[int] = None,
round_up: bool = True) -> None:
        assert len(dataset.datasets) == len(dataset_size), \
            'dataset_size must provide one entry per sub-dataset of the ' \
            'concatenated dataset'
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
if seed is None:
seed = sync_random_seed()
self.seed = seed
self.epoch = 0
self.round_up = round_up
total_size = 0
total_size_fake = 0
self.dataset_index = []
self.dataset_cycle_iter = []
new_dataset_size = []
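        # Record each sub-dataset's index range within the concatenated
        # index space; capped sub-datasets additionally get a seeded cycling
        # iterator that reshuffles whenever its indices are exhausted.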
        for sub_dataset, size in zip(dataset.datasets, dataset_size):
            self.dataset_index.append(
                list(
                    range(total_size_fake,
                          len(sub_dataset) + total_size_fake)))
            total_size_fake += len(sub_dataset)
            if size == -1:
                total_size += len(sub_dataset)
                self.dataset_cycle_iter.append(None)
                new_dataset_size.append(-1)
            else:
                if ratio_mode:
                    size = int(size * len(sub_dataset))
                assert size <= len(sub_dataset), \
                    f'dataset size {size} is larger than ' \
                    f'dataset length {len(sub_dataset)}'
                total_size += size
                new_dataset_size.append(size)
                g = torch.Generator()
                g.manual_seed(self.seed)
                self.dataset_cycle_iter.append(
                    RandomCycleIter(self.dataset_index[-1], generator=g))
self.dataset_size = new_dataset_size
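        # Per-rank sample count: with round_up, the total is padded so every
        # rank receives the same number of samples; otherwise, later ranks
        # may receive one sample fewer.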
if self.round_up:
self.num_samples = math.ceil(total_size / world_size)
self.total_size = self.num_samples * self.world_size
else:
self.num_samples = math.ceil((total_size - rank) / world_size)
self.total_size = total_size

    def __iter__(self) -> Iterator[int]:
"""Iterate the indices."""
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
out_index = []
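        # Full-size sub-datasets (-1) contribute all of their indices;
        # capped sub-datasets draw ``data_size`` indices from their cycling
        # iterator.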
for data_size, data_index, cycle_iter in zip(self.dataset_size,
self.dataset_index,
self.dataset_cycle_iter):
if data_size == -1:
out_index += data_index
else:
index = [next(cycle_iter) for _ in range(data_size)]
out_index += index
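        # Deterministically shuffle the merged indices for this epoch.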
index = torch.randperm(len(out_index), generator=g).numpy().tolist()
indices = [out_index[i] for i in index]
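        # Repeat the shuffled list until it covers ``total_size`` so that
        # the epoch divides evenly across ranks.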
if self.round_up:
indices = (
indices *
int(self.total_size / len(indices) + 1))[:self.total_size]
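        # Subsample for this rank: every ``world_size``-th index, offset by
        # ``rank``.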
indices = indices[self.rank:self.total_size:self.world_size]
return iter(indices)

    def __len__(self) -> int:
"""The number of samples in this rank."""
return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for this sampler.

        The epoch is added to the random seed in :meth:`__iter__`, so
        calling this method before each epoch ensures that all replicas use
        a different random ordering per epoch; otherwise, the next iteration
        of this sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
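
A minimal usage sketch, not part of this module: since the class is registered in ``DATA_SAMPLERS``, it can be selected from a dataloader config built around a concatenated dataset. The sub-dataset configs and the cap of 5000 below are illustrative placeholders; MMEngine's runner passes the built dataset to the sampler and calls ``set_epoch`` at the start of each epoch.

# Illustrative config sketch (placeholder values, not defined in this file):
# draw every sample from the first sub-dataset and 5000 samples per epoch
# from the second.
dataset_a = dict(type='CocoDataset', ann_file='train_a.json')  # placeholder
dataset_b = dict(type='CocoDataset', ann_file='train_b.json')  # placeholder
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    sampler=dict(type='CustomSampleSizeSampler', dataset_size=[-1, 5000]),
    dataset=dict(type='ConcatDataset', datasets=[dataset_a, dataset_b]))

With ``ratio_mode=True``, an entry can instead be given as a fraction (e.g. ``0.5`` for half of a sub-dataset), since each entry is multiplied by the sub-dataset length before use.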