Transforms的使用.md

本节主要介绍 torchvision 工具箱中 transforms 模块的基本使用。在深度学习中,transforms 主要用于对图片进行预处理(Pre-processing)和数据增强(Data Augmentation)。

1. Transforms 的结构理解

可以将 transforms 理解为一个工具箱

  • transforms: 整个工具箱。
  • transforms.ToTensor(): 工具箱里的一个具体工具(例如:一把锤子)。
  • tool = transforms.ToTensor(): 我们从工具箱里拿出来的具体工具对象。
  • result = tool(input): 使用这个工具对输入(图片)进行处理,得到结果。

2. 核心方法:ToTensor

ToTensor 是最基础也是最常用的 transform,它的主要作用是将 PIL Imagenumpy.ndarray 转换为 torch.Tensor

实战代码

from torchvision import transforms
from PIL import Image
img_path =r"D:\AProject\PythonProject\hymenoptera_data\train\ants\0013035.jpg"
img = Image.open(img_path)
transform = transforms.ToTensor()
tensor = transform(img)
print(tensor)

4. 常见的Transforms

输入 PIL Image.open()
输出 tensor ToTensor()
作用 narrays cv2.imread()

5. Normalize()使用

计算公式$$output[channel] = \frac{input[channel] - mean[channel]}{std[channel]}$$

transforms.Normalize()

参数说明

mean:每个通道的均值序列的平均值

std: 每个通道的标准差

实战代码

from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
writer = SummaryWriter("logs")
img_path = r"D:\AProject\PythonProject\hymenoptera_data\train\ants\5650366_e22b7e1065.jpg"
img = Image.open(img_path)
transform = transforms.ToTensor()
tensor = transform(img)
writer.add_image("original",tensor,0)
transform_normail = transforms.Normalize([0.5,0.3,0.1], [0.5,0.6,0.5])
tensor_n = transform_normail(tensor)
writer.add_image(
    "transformed",tensor_n,
    0
)
writer.close()

image-20260112215430595

6. Resize()使用

用于将输入图片调整为指定的大小

  • 参数
    • 传入 (H, W):直接强制拉伸/压缩到指定高宽。
    • 传入一个整数 N:将图像的短边缩放到 N,长边按比例缩放。

7. Compose (组合变换) - 重点

这是将多个 Transform 操作串联起来的“容器”。

  • 作用: 像流水线一样,按顺序执行列表中的 Transform 操作。
  • 核心逻辑: 上一个操作的输出(Output),必须匹配下一个操作的输入(Input)。
  • 常见报错原因: 数据类型不匹配。
    • 错误示例: 先做 Normalize (需要Tensor) 再做 ToTensor (输出Tensor)。如果第一步输入是 PIL,Normalize 就会报错。
    • 正确顺序示例: PIL图片 $\rightarrow$ Resize (PIL) $\rightarrow$ ToTensor (变为Tensor) $\rightarrow$ Normalize (处理Tensor)。

8. RandomCrop (随机裁剪)

常用于数据增强(Data Augmentation)。

  • 作用: 在图片中随机裁剪出一块指定大小的区域。
  • 参数: 传入 (H, W) 或一个整数。
  • 意义: 增加数据的多样性,防止过拟合。
import torchvision.transforms as transforms

# 定义一个组合操作 (流水线)
trans_compose = transforms.Compose([
    transforms.Resize((224, 224)),      # 1. 先调整大小 (Input: PIL -> Output: PIL)
    transforms.ToTensor(),              # 2. 转为 Tensor (Input: PIL -> Output: Tensor)
    transforms.Normalize([0.5], [0.5])  # 3. 归一化 (Input: Tensor -> Output: Tensor)
])

# 使用
# img_out = trans_compose(img_in)