200字
PyTorch-Tutorial 学习记录——Transforms
2026-02-05
2026-02-05

PyTorch Tutorial 03 - Transforms

Transforms — PyTorch Tutorials 2.10.0+cu128 documentation

本文是 PyTorch 官方 Tutorial 中 Transforms 部分的学习记录,主要整理 Transforms 的基本用法和个人理解。

这节内容量较少,是为后面更复杂的预处理打基础

主要讲的是数据预处理(transform/target_transform)API 如何把图片/标签转换成训练所需的 Tensor/格式。

代码

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
"""
先创建长度 10 的全 0 向量
在索引 y 的位置填 1
比如 y=3 → [0,0,0,1,0,0,0,0,0,0]
"""
)

可视化

from torch.utils.data import DataLoader

# 可视化
dataloader = DataLoader(ds, batch_size=64, shuffle=False)
features, labels = next(iter(dataloader))
img = features[0].squeeze()
label = labels[0]
print("features shape:", features.shape)  # [B, C, H, W]
print("img shape:", img.shape)
print("img dtype:", img.dtype)
print("img min/max:", img.min().item(), img.max().item())
print("label:", label, "label argmax:", label.argmax().item())

结果如下:

PyTorch-Tutorial 学习记录——Transforms
作者
若离
发表于
2026-02-05
License
CC BY-NC-SA 4.0

评论