PyTorch Basics 1: Data

PyTorch provides two primitives for working with data: torch.utils.data.Dataset and torch.utils.data.DataLoader.

  • Dataset stores the samples and their corresponding labels.
  • DataLoader wraps a Dataset in an iterable.
    • You can then use a for loop to draw batches of data from it.
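
To make the split between the two primitives concrete, here is a minimal sketch of a custom Dataset wrapped in a DataLoader. The class name ToyDataset and its synthetic data are hypothetical and only for illustration; the only contract assumed is that a map-style Dataset implements __len__ and __getitem__.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyDataset(Dataset):
    """Hypothetical dataset: 10 samples, 3 features each, labels alternating 0/1."""

    def __init__(self, num_samples=10):
        # Deterministic synthetic features so the example is reproducible.
        self.features = torch.arange(num_samples * 3, dtype=torch.float32).reshape(num_samples, 3)
        self.labels = torch.arange(num_samples) % 2

    def __len__(self):
        # Number of samples in the dataset.
        return len(self.labels)

    def __getitem__(self, idx):
        # Return one (sample, label) pair.
        return self.features[idx], self.labels[idx]


toy_dataset = ToyDataset()
toy_loader = DataLoader(toy_dataset, batch_size=4)

# Iterating the loader yields batches stacked along a new first dimension.
batch_shapes = []
for X, y in toy_loader:
    batch_shapes.append((tuple(X.shape), tuple(y.shape)))
    print(X.shape, y.shape)
```

With 10 samples and a batch size of 4, the loader yields two full batches and one smaller final batch.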

Examples

######### e.g. 1


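from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Download training data from open datasets.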
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)


batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

# Shape of X [N, C, H, W]:  torch.Size([64, 1, 28, 28])
# Shape of y:  torch.Size([64]) torch.int64
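
The loaders above use the default settings. As a rough sketch of two other DataLoader arguments, shuffle and drop_last, the snippet below uses random tensors as a stand-in for FashionMNIST so it runs without a download. The sample count of 150 and the names fake_data, loader, and loader_drop are assumptions made for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for FashionMNIST (assumption): 150 random "images" shaped [1, 28, 28].
fake_images = torch.rand(150, 1, 28, 28)
fake_labels = torch.randint(0, 10, (150,))
fake_data = TensorDataset(fake_images, fake_labels)

batch_size = 64

# shuffle=True reshuffles the sample order at the start of every epoch.
loader = DataLoader(fake_data, batch_size=batch_size, shuffle=True)
num_batches = len(loader)                      # ceil(150 / 64)
batch_sizes = [X.shape[0] for X, y in loader]  # the last batch is smaller

# drop_last=True discards the final incomplete batch.
loader_drop = DataLoader(fake_data, batch_size=batch_size, drop_last=True)
num_full_batches = len(loader_drop)

print(num_batches, batch_sizes, num_full_batches)
```

Shuffling is typically enabled for the training loader only. Dropping the last batch can help when a model or loss assumes a fixed batch size.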



######### e.g. 2



Dayu Yang
Ph.D. Student in Financial Services Analytics

I am excited about NLP, Information Retrieval, and Algorithm Trading. My current research focuses on Conversational AI systems.
