PyTorch Basics 1: Data
PyTorch has two primitives to work with data: torch.utils.data.DataLoader and torch.utils.data.Dataset.
- Dataset is an object that stores the samples and their corresponding labels.
  - There are two types of Dataset, map-style and iterable-style: https://pytorch.org/docs/stable/data.html
  - A map-style dataset must implement two methods: __getitem__ and __len__.
- DataLoader wraps a dataset in an iterable.
  - You can then use a for loop to draw batches of data from it.
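To make the map-style contract concrete, here is a minimal sketch. SquaresDataset is a hypothetical toy class invented for illustration (it is not part of PyTorch); it only assumes that implementing __len__ and __getitem__ is enough for DataLoader to batch the samples.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical map-style dataset: sample i is the pair (i, i * i)."""

    def __init__(self, n):
        self.x = torch.arange(n, dtype=torch.float32)
        self.y = self.x ** 2

    def __len__(self):
        # Number of samples; DataLoader uses this to know when to stop.
        return len(self.x)

    def __getitem__(self, idx):
        # Return one (sample, label) pair for the given index.
        return self.x[idx], self.y[idx]


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4)

# Iterating the loader yields batches that stack individual samples.
batch_shapes = [tuple(x.shape) for x, y in loader]
print(batch_shapes)
```

With 10 samples and a batch size of 4, the last batch is smaller than the others because DataLoader keeps the incomplete final batch by default.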
Examples
######### e.g. 1
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
batch_size = 64
# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
# Inspect the first batch only, then stop.
for X, y in test_dataloader:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break
# Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
# Shape of y: torch.Size([64]) torch.int64
######### e.g. 2