Learning PyTorch Basics: 8 - Dataset and DataLoader

8 Dataset and DataLoader

0 Revision

  • Manual data feed
import numpy as np
import torch

xy = np.loadtxt('data/diabetes.csv.gz', delimiter=',', dtype=np.float32)
x_data = torch.from_numpy(xy[:, :-1])   # all columns except the last are features
y_data = torch.from_numpy(xy[:, [-1]])  # last column, kept 2-D, is the label

1 Terminology

  • Epoch: one complete forward and backward pass over all training samples
  • Batch-Size: the number of samples in one batch (one forward/backward pass)
  • Iteration: the number of batches per epoch
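In numbers: with N training samples and batch size B, one epoch consists of ⌈N/B⌉ iterations. A quick check for the diabetes dataset used in these notes (768 samples, batch size 32 as in the loaders below):

```python
import math

n_samples = 768   # rows in the diabetes dataset
batch_size = 32
iterations_per_epoch = math.ceil(n_samples / batch_size)
print(iterations_per_epoch)  # 24
```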
# Training cycle
for epoch in range(training_epochs):
    # Loop over all batches
    for i in range(total_batch):
        ...

2 Dataset & DataLoader

How the Dataset and the DataLoader work together

import torch
from torch.utils.data import Dataset, DataLoader

# Dataset is an abstract class; DataLoader is a concrete helper class
# DiabetesDataset inherits from the abstract Dataset class
class DiabetesDataset(Dataset):
    def __init__(self):
        pass

    # supports indexing: dataset[index]
    def __getitem__(self, index):
        pass

    # returns the size of the dataset
    def __len__(self):
        pass

dataset = DiabetesDataset()
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)
# num_workers: number of subprocesses used to load the data

for epoch in range(100):
    for i, data in enumerate(train_loader, 0):
        ...  # training steps go here
  • Two common ways for a Dataset to read its data:
  1. load the entire dataset into memory in __init__
  2. read only the list of filenames in __init__, then load each file by index in __getitem__
  • On Windows, starting DataLoader worker processes from module-level code raises an error; run the loading code inside an if __name__ == '__main__': block.
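A minimal sketch of the guard pattern, using a toy in-memory dataset (names and sizes here are illustrative, not from the notes):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """10 scalar samples, standing in for a real dataset."""
    def __init__(self):
        self.data = torch.arange(10, dtype=torch.float32)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

def run(num_workers=2):
    # num_workers > 0 is what requires the __main__ guard on Windows
    loader = DataLoader(ToyDataset(), batch_size=4, shuffle=True,
                        num_workers=num_workers)
    return sum(1 for _ in loader)

if __name__ == '__main__':
    # Worker processes are only spawned inside the guard, so re-importing
    # this module (as Windows does for each worker) is safe.
    print(run())  # 3 batches: 4 + 4 + 2 samples
```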

3 Example Implementation

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

dataset = DiabetesDataset('data/diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)

for epoch in range(100):
    for i, data in enumerate(train_loader, 0):
        # 1. Prepare data
        inputs, labels = data
        # 2. Forward
        y_pred = model(inputs)
        loss = criterion(y_pred, labels)
        print(f'Epoch: {epoch} | Batch: {i} | Loss: {loss.item()}')
        # 3. Backward
        optimizer.zero_grad()
        loss.backward()
        # 4. Update
        optimizer.step()
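The loop above references model, criterion, and optimizer without defining them. A self-contained sketch that fills in those pieces, using synthetic data in place of the diabetes file and a logistic-regression model as an assumption (earlier chapters of the notes may define these differently):

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class SyntheticDataset(Dataset):
    """Stand-in for DiabetesDataset: 64 random samples, 8 features, binary labels."""
    def __init__(self):
        rng = np.random.default_rng(0)
        xy = rng.random((64, 9)).astype(np.float32)
        xy[:, -1] = (xy[:, -1] > 0.5).astype(np.float32)  # binarize the label column
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

# Assumed definitions for the names used in the training loop
model = torch.nn.Sequential(torch.nn.Linear(8, 1), torch.nn.Sigmoid())
criterion = torch.nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

train_loader = DataLoader(SyntheticDataset(), batch_size=32, shuffle=True)
for epoch in range(2):
    for i, (inputs, labels) in enumerate(train_loader):
        y_pred = model(inputs)
        loss = criterion(y_pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
print('final loss:', loss.item())
```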

4 The following dataset loaders are available

A Dataset example from torchvision: MNIST

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets

train_dataset = datasets.MNIST(root='data', train=True, transform=transforms.ToTensor(), download=True)
# transform: preprocesses each sample, here converting PIL images to tensors
test_dataset = datasets.MNIST(root='data', train=False, transform=transforms.ToTensor(), download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True, num_workers=2)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False, num_workers=2)

for batch_idx, (inputs, labels) in enumerate(train_loader):
    # 1. Prepare data
    # 2. Forward
    # 3. Backward
    # 4. Update
    ...

5 Exercise

5.1

  • Build DataLoader for Titanic dataset
  • Build a classifier using the DataLoader
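A starting-point skeleton for the first exercise, following the DiabetesDataset pattern above. The column layout (Survived first, then numeric features) is an assumption; the real Titanic CSV has headers and non-numeric columns that need preprocessing first. A tiny synthetic file stands in for the dataset here:

```python
import csv
import os
import tempfile
import torch
from torch.utils.data import Dataset, DataLoader

class TitanicDataset(Dataset):
    """Reads a numeric Titanic-style CSV: label in column 0, features after.
    Column order is a hypothetical layout, not the real file's."""
    def __init__(self, filepath):
        with open(filepath) as f:
            rows = [list(map(float, r)) for r in csv.reader(f)]
        data = torch.tensor(rows, dtype=torch.float32)
        self.x_data = data[:, 1:]   # features, e.g. Pclass, Sex, Age (assumed)
        self.y_data = data[:, [0]]  # label: Survived

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.x_data.shape[0]

# Tiny synthetic file standing in for the Titanic train.csv
tmp = tempfile.NamedTemporaryFile('w', suffix='.csv', delete=False)
tmp.write("1,1,0,38\n0,3,1,22\n1,3,0,26\n0,1,1,35\n")
tmp.close()

dataset = TitanicDataset(tmp.name)
loader = DataLoader(dataset, batch_size=2, shuffle=True)
for x, y in loader:
    print(x.shape, y.shape)  # torch.Size([2, 3]) torch.Size([2, 1])
os.unlink(tmp.name)
```

From here, the classifier from the example implementation above (a linear layer plus sigmoid, trained with BCELoss) can be reused with the input width changed to match the feature count.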

Original post: https://eleco.top/2026/02/24/learn-torch-8-Dataset-and-DataLoader/
Author: Eleco
Published: February 24, 2026