Source code for spaic.IO.Dataloader

# -*- coding: utf-8 -*-
"""
Created on 2020/8/12
@project: SPAIC
@filename: Dataloader
@author: Hong Chaofei
@contact: hongchf@gmail.com
@description:
Defines the data-loading module (定义数据导入模块).
"""
from .sampler import *
import numpy as np


# The Dataloader classes below are adapted from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py.
class _BaseDatasetFetcher(object):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        self.dataset = dataset
        self.auto_collation = auto_collation
        self.collate_fn = collate_fn
        self.drop_last = drop_last

    def fetch(self, possibly_batched_index):
        raise NotImplementedError()


class _MapDatasetFetcher(_BaseDatasetFetcher):
    """Fetcher for map-style datasets, i.e. datasets read via ``dataset[idx]``."""

    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        """Read the requested sample(s) and pass them through ``collate_fn``.

        Args:
            possibly_batched_index: an iterable of indices when
                ``auto_collation`` is true, otherwise a single index.

        Returns:
            Whatever ``collate_fn`` returns for the fetched sample(s).
        """
        if not self.auto_collation:
            # One index: hand the single sample straight to collate_fn.
            sample = self.dataset[possibly_batched_index]
            return self.collate_fn(sample)
        # A batch of indices: gather every sample into a list first.
        samples = [self.dataset[i] for i in possibly_batched_index]
        return self.collate_fn(samples)


def default_collate(batch):
    """Stack a batch of ``(data, target)`` samples into two numpy arrays.

    Element 0 of every sample is gathered into the first array and element 1
    into the second, so the data array has shape ``[batch_size, *shape]``.

    Args:
        batch: iterable of indexable samples (each needs items 0 and 1).

    Returns:
        list: ``[data_array, target_array]``.
    """
    # Two separate passes over ``batch``, data first and targets second.
    inputs = np.array([sample[0] for sample in batch])
    labels = np.array([sample[1] for sample in batch])
    return [inputs, labels]
class Dataloader(object):
    """Iterate over a map-style dataset, optionally in (shuffled) batches.

    The ``sampler`` generates a sequence of dataset indices, while the
    ``batch_sampler`` groups the indices produced by the sampler into
    batches, yielding one batch of indices at a time.

    Args:
        dataset: dataset indexed with the values produced by the sampler(s).
        batch_size: samples per batch; ``None`` disables automatic batching
            so each fetch receives a single index.
        shuffle: use ``RandomSampler`` instead of ``SequentialSampler`` when
            no ``sampler`` is given. Mutually exclusive with ``sampler``.
        sampler: custom iterable of indices.
        batch_sampler: custom iterable of index batches. Mutually exclusive
            with ``batch_size``, ``shuffle``, ``sampler`` and ``drop_last``.
        collate_fn: merges fetched samples; defaults to ``default_collate``.
        drop_last: forwarded to ``BatchSampler`` (see the sampler module for
            its exact semantics).

    Raises:
        ValueError: for any mutually exclusive option combination.
    """

    # Set to True at the end of ``__init__``; afterwards ``__setattr__``
    # refuses to change ``batch_size``, ``sampler`` and ``drop_last``.
    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, collate_fn=None, drop_last=False):
        self.dataset = dataset
        # First batch cached by ``try_fetch``; ``None`` until it is called.
        self.data = None
        self.label = None

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if batch_sampler is not None:
            # auto_collation with a custom batch_sampler
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            batch_size = None
            drop_last = False
        elif batch_size is None:
            # no auto_collation
            if drop_last:
                raise ValueError('batch_size=None option disables auto-batching '
                                 'and is mutually exclusive with drop_last')

        if sampler is None:
            if shuffle:
                # Cannot statically verify that dataset is Sized.
                sampler = RandomSampler(dataset)  # type: ignore
            else:
                sampler = SequentialSampler(dataset)

        if batch_size is not None and batch_sampler is None:
            # auto_collation without a custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler

        if collate_fn is None:
            collate_fn = default_collate
        self.collate_fn = collate_fn

        self.__initialized = True

    def __setattr__(self, attr, val):
        # Changing these after construction would desynchronise them from the
        # already-built batch_sampler, so they are frozen once initialised.
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super(Dataloader, self).__setattr__(attr, val)

    def __iter__(self):
        return _SingleProcessDataLoaderIter(self)

    def try_fetch(self):
        """Fetch one batch before the loop so Node objects can use it to build.

        Not called by ``__init__``; callers invoke it explicitly. Stores
        element 0 of the first batch in ``self.data`` and element 1 in
        ``self.label``, then sets ``self.num`` to ``dataset.data['maxNum']``
        when that key exists, otherwise to the number of elements in one
        sample (product of ``self.data.shape[1:]``).

        If the loader yields no batch, the attributes are left untouched
        (``self.num`` stays unset).

        Returns:
            tuple: ``(self.data, self.label)``.
        """
        iterator = iter(_SingleProcessDataLoaderIter(self))
        try:
            first_batch = next(iterator)
        except StopIteration:
            # Empty loader: nothing to cache.
            return self.data, self.label

        self.data = first_batch[0]
        self.label = first_batch[1]
        # NOTE(review): assumes ``dataset.data`` is a mapping — TODO confirm.
        if 'maxNum' in self.dataset.data.keys():
            self.num = self.dataset.data['maxNum']
        else:
            shape = self.data.shape[1:]
            self.num = int(np.prod(shape))
        return self.data, self.label

    @property
    def _auto_collation(self):
        # True when indices are grouped into batches by a batch_sampler.
        return self.batch_sampler is not None

    @property
    def _index_sampler(self):
        # The sampler whose output is passed to the dataset fetcher on each
        # step: the batch_sampler under auto_collation, else the sampler.
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler

    def __len__(self):
        return len(self._index_sampler)
class _BaseDataLoaderIter(object): def __init__(self, loader: Dataloader): self._dataset = loader.dataset self._auto_collation = loader._auto_collation self._drop_last = loader.drop_last self._index_sampler = loader._index_sampler self._sampler_iter = iter(self._index_sampler) self._collate_fn = loader.collate_fn self._num_yielded = 0 def __iter__(self): return self def _next_index(self): return next(self._sampler_iter) # may raise StopIteration def __next__(self): raise NotImplementedError def __len__(self): return len(self.index_sampler) def __getstate__(self): # TODO: add limited pickling support for sharing an iterator # across multiple threads for HOGWILD. # Probably the best way to do this is by moving the sample pushing # to a separate thread and then just sharing the data queue # but signalling the end is tricky without a non-blocking API raise NotImplementedError("{} cannot be pickled", self.__class__.__name__) class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): def __init__(self, loader): super(_SingleProcessDataLoaderIter, self).__init__(loader) self._dataset_fetcher = _MapDatasetFetcher(self._dataset, self._auto_collation, self._collate_fn, self._drop_last) def __next__(self): index = self._next_index() # may raise StopIteration data = self._dataset_fetcher.fetch(index) # may raise StopIteration return data next = __next__