torchvision.datasets
All datasets are subclasses of torch.utils.data.Dataset, i.e., they have __getitem__ and __len__ methods implemented. Hence, they can all be passed to a torch.utils.data.DataLoader, which can load multiple samples in parallel using torch.multiprocessing workers.
For example:
    import torch
    import torchvision

    imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
    data_loader = torch.utils.data.DataLoader(imagenet_data,
                                              batch_size=4,
                                              shuffle=True,
                                              num_workers=4)  # number of worker processes
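The Dataset contract described above can be tried without downloading anything. The following is a minimal sketch, assuming a toy dataset class (SquaresDataset is a hypothetical name, not part of torchvision) that implements only __len__ and __getitem__ and is then batched by a DataLoader:

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i**2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of samples; the DataLoader uses this to plan batches.
        return self.n

    def __getitem__(self, index):
        # Return one (input, target) pair as tensors.
        x = torch.tensor(float(index))
        return x, x ** 2


dataset = SquaresDataset(10)
# num_workers=0 loads in the main process; a positive value starts worker processes.
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)
batches = list(loader)
print(len(batches))    # 10 samples in batches of 4 give 3 batches
print(batches[0][0])   # inputs of the first batch
```

Any torchvision dataset can take the place of SquaresDataset here, since it exposes the same two methods.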
The following datasets are available:
All the datasets have a very similar API. They all share two common arguments, transform and target_transform, which transform the input and the target, respectively.
MNIST
class torchvision.datasets.MNIST(root, train=True, transform=None, target_transform=None, download=False)
MNIST Dataset.
- Parameters
root (string) – Root directory of dataset where MNIST/processed/training.pt and MNIST/processed/test.pt exist.
train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.
download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
Fashion-MNIST
class torchvision.datasets.FashionMNIST(root, train=True, transform=None, target_transform=None, download=False)
Fashion-MNIST Dataset.
- Parameters
root (string) – Root directory of dataset where Fashion-MNIST/processed/training.pt and Fashion-MNIST/processed/test.pt exist.
train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.
download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
KMNIST
class torchvision.datasets.KMNIST(root, train=True, transform=None, target_transform=None, download=False)
Kuzushiji-MNIST Dataset.
- Parameters
root (string) – Root directory of dataset where KMNIST/processed/training.pt and KMNIST/processed/test.pt exist.
train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.
download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
EMNIST
class torchvision.datasets.EMNIST(root, split, **kwargs)
EMNIST Dataset.
- Parameters
root (string) – Root directory of dataset where EMNIST/processed/training.pt and EMNIST/processed/test.pt exist.
split (string) – The dataset has 6 different splits: byclass, bymerge, balanced, letters, digits and mnist. This argument specifies which one to use.
train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.
download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
FakeData
class torchvision.datasets.FakeData(size=1000, image_size=(3, 224, 224), num_classes=10, transform=None, target_transform=None, random_offset=0)
A fake dataset that returns randomly generated images as PIL images.
- Parameters
size (int, optional) – Size of the dataset. Default: 1000 images
image_size (tuple, optional) – Size of the returned images. Default: (3, 224, 224)
num_classes (int, optional) – Number of classes in the dataset. Default: 10
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
random_offset (int) – Offsets the index-based random seed used to generate each image. Default: 0
COCO
Note
These require the COCO API to be installed.
Detection
class torchvision.datasets.CocoDetection(root, annFile, transform=None, target_transform=None, transforms=None)
MS Coco Detection Dataset.
- Parameters
root (string) – Root directory where images are downloaded to.
annFile (string) – Path to json annotation file.
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.ToTensor
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
LSUN
class torchvision.datasets.LSUN(root, classes='train', transform=None, target_transform=None)
LSUN dataset.
- Parameters
root (string) – Root directory for the database files.
classes (string or list) – One of {'train', 'val', 'test'} or a list of categories to load, e.g., ['bedroom_train', 'church_train'].
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
ImageFolder
class torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)
A generic data loader where the images are arranged in this way:
    root/dog/xxx.png
    root/dog/xxy.png
    root/dog/xxz.png
    root/cat/123.png
    root/cat/nsdf3.png
    root/cat/asd932_.png
- Parameters
root (string) – Root directory path.
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
loader (callable, optional) – A function to load an image given its path.
is_valid_file (callable, optional) – A function that takes the path of an image file and checks whether it is a valid file (used to check for corrupt files).
DatasetFolder
class torchvision.datasets.DatasetFolder(root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None)
A generic data loader where the samples are arranged in this way:
    root/class_x/xxx.ext
    root/class_x/xxy.ext
    root/class_x/xxz.ext
    root/class_y/123.ext
    root/class_y/nsdf3.ext
    root/class_y/asd932_.ext
- Parameters
root (string) – Root directory path.
loader (callable) – A function to load a sample given its path.
extensions (tuple[string]) – A list of allowed extensions. extensions and is_valid_file should not both be passed.
transform (callable, optional) – A function/transform that takes in a sample and returns a transformed version. E.g., transforms.RandomCrop for images.
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
is_valid_file (callable, optional) – A function that takes the path of a file and checks whether it is a valid file (used to check for corrupt files). extensions and is_valid_file should not both be passed.
ImageNet
class torchvision.datasets.ImageNet(root, split='train', download=False, **kwargs)
ImageNet 2012 Classification Dataset.
- Parameters
root (string) – Root directory of the ImageNet Dataset.
split (string, optional) – The dataset split; supports train or val.
download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
loader – A function to load an image given its path.
Note
This requires scipy to be installed.
CIFAR
class torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
CIFAR10 Dataset.
- Parameters
root (string) – Root directory of dataset where the directory cifar-10-batches-py exists or will be saved to if download is set to True.
train (bool, optional) – If True, creates dataset from the training set, otherwise from the test set.
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
STL10
class torchvision.datasets.STL10(root, split='train', transform=None, target_transform=None, download=False)
STL10 Dataset.
- Parameters
root (string) – Root directory of dataset where the directory stl10_binary exists.
split (string) – One of {'train', 'test', 'unlabeled', 'train+unlabeled'}. The dataset is selected accordingly.
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
SVHN
class torchvision.datasets.SVHN(root, split='train', transform=None, target_transform=None, download=False)
SVHN Dataset. Note: The SVHN dataset assigns the label 10 to the digit 0. However, in this Dataset, we assign the label 0 to the digit 0 to be compatible with PyTorch loss functions, which expect the class labels to be in the range [0, C-1].
- Parameters
root (string) – Root directory of dataset where the directory SVHN exists.
split (string) – One of {'train', 'test', 'extra'}. The dataset is selected accordingly. 'extra' is the extra training set.
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
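The label convention mentioned in the SVHN note can be illustrated on a handful of made-up labels. This is only a sketch of the remapping idea with NumPy, not the code the class itself runs:

```python
import numpy as np

# Hypothetical raw labels in the original SVHN convention, where 10 means digit 0.
raw_labels = np.array([1, 10, 3, 10, 9])

# Map label 10 to 0 so that all labels lie in [0, C-1] with C = 10.
labels = np.where(raw_labels == 10, 0, raw_labels)
print(labels.tolist())  # [1, 0, 3, 0, 9]
```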
PhotoTour
class torchvision.datasets.PhotoTour(root, name, train=True, transform=None, download=False)
Learning Local Image Descriptors Data Dataset.
- Parameters
root (string) – Root directory where images are.
name (string) – Name of the dataset to load.
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version.
download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
SBU
class torchvision.datasets.SBU(root, transform=None, target_transform=None, download=True)
SBU Captioned Photo Dataset.
- Parameters
root (string) – Root directory of dataset where the tarball SBUCaptionedPhotoDataset.tar.gz exists.
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
Flickr
class torchvision.datasets.Flickr8k(root, ann_file, transform=None, target_transform=None)
Flickr8k Entities Dataset.
- Parameters
root (string) – Root directory where images are downloaded to.
ann_file (string) – Path to annotation file.
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.ToTensor
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
class torchvision.datasets.Flickr30k(root, ann_file, transform=None, target_transform=None)
Flickr30k Entities Dataset.
- Parameters
root (string) – Root directory where images are downloaded to.
ann_file (string) – Path to annotation file.
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.ToTensor
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
VOC
class torchvision.datasets.VOCSegmentation(root, year='2012', image_set='train', download=False, transform=None, target_transform=None, transforms=None)
Pascal VOC Segmentation Dataset.
- Parameters
root (string) – Root directory of the VOC Dataset.
year (string, optional) – The dataset year, supports years 2007 to 2012.
image_set (string, optional) – Select the image_set to use, train, trainval or val.
download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
class torchvision.datasets.VOCDetection(root, year='2012', image_set='train', download=False, transform=None, target_transform=None, transforms=None)
Pascal VOC Detection Dataset.
- Parameters
root (string) – Root directory of the VOC Dataset.
year (string, optional) – The dataset year, supports years 2007 to 2012.
image_set (string, optional) – Select the image_set to use, train, trainval or val.
download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
Cityscapes
Note
Requires the Cityscapes dataset to be downloaded.
class torchvision.datasets.Cityscapes(root, split='train', mode='fine', target_type='instance', transform=None, target_transform=None)
Cityscapes Dataset.
- Parameters
root (string) – Root directory of dataset where the directories leftImg8bit and gtFine or gtCoarse are located.
split (string, optional) – The image split to use: train, test or val if mode="fine", otherwise train, train_extra or val.
mode (string, optional) – The quality mode to use, fine or coarse.
target_type (string or list, optional) – Type of target to use, instance, semantic, polygon or color. Can also be a list to output a tuple with all specified target types.
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
Examples
Get semantic segmentation target
    from torchvision.datasets import Cityscapes

    dataset = Cityscapes('./data/cityscapes', split='train', mode='fine', target_type='semantic')
    img, smnt = dataset[0]
Get multiple targets
    dataset = Cityscapes('./data/cityscapes', split='train', mode='fine', target_type=['instance', 'color', 'polygon'])
    img, (inst, col, poly) = dataset[0]
Validate on the "coarse" set
    dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse', target_type='semantic')
    img, smnt = dataset[0]
SBD
class torchvision.datasets.SBDataset(root, image_set='train', mode='boundaries', download=False, transforms=None)
The SBD currently contains annotations from 11355 images taken from the PASCAL VOC 2011 dataset.
Note
Please note that the train and val splits included with this dataset are different from the splits in the PASCAL VOC dataset. In particular, some "train" images might be part of VOC2012 val. If you are interested in testing on VOC 2012 val, then use image_set='train_noval', which excludes all val images.
Warning
This class needs scipy to load target files from .mat format.
- Parameters
root (string) – Root directory of the Semantic Boundaries Dataset.
image_set (string, optional) – Select the image_set to use, train, val or train_noval. Image set train_noval excludes VOC 2012 val images.
mode (string, optional) – Select target type. Possible values are 'boundaries' or 'segmentation'. In case of 'boundaries', the target is an array of shape [num_classes, H, W], where num_classes=20.
download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
transforms (callable, optional) – A function/transform that takes the input sample and its target as entry and returns a transformed version. The input sample is a PIL image and the target is a numpy array if mode='boundaries' or a PIL image if mode='segmentation'.