PyTorch Tutorial 03 - Transforms
Transforms — PyTorch Tutorials 2.10.0+cu128 documentation
本文是 PyTorch 官方 Tutorial 中 Transforms 部分的学习记录,主要整理 Transforms 的基本用法和个人理解。
这节内容量较少,是为后面更复杂的预处理打基础
主要讲的是数据预处理(transform/target_transform)API 如何把图片/标签转换成训练所需的 Tensor/格式。
代码
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
"""
先创建长度 10 的全 0 向量
在索引 y 的位置填 1
比如 y=3 → [0,0,0,1,0,0,0,0,0,0]
"""
)
可视化
from torch.utils.data import DataLoader
# 可视化
dataloader = DataLoader(ds, batch_size=64, shuffle=False)
features, labels = next(iter(dataloader))
img = features[0].squeeze()
label = labels[0]
print("features shape:", features.shape) # [B, C, H, W]
print("img shape:", img.shape)
print("img dtype:", img.dtype)
print("img min/max:", img.min().item(), img.max().item())
print("label:", label, "label argmax:", label.argmax().item())
结果如下:
