Transforms的使用.md
Transforms的使用.md
Tinsiag Zhu本节主要介绍 torchvision 工具箱中 transforms 模块的基本使用。在深度学习中,transforms 主要用于对图片进行预处理(Pre-processing)和数据增强(Data Augmentation)。
1. Transforms 的结构理解
可以将 transforms 理解为一个工具箱。
- transforms: 整个工具箱。
- transforms.ToTensor(): 工具箱里的一个具体工具(例如:一把锤子)。
- tool = transforms.ToTensor(): 我们从工具箱里拿出来的具体工具对象。
- result = tool(input): 使用这个工具对输入(图片)进行处理,得到结果。
2. 核心方法:ToTensor
ToTensor 是最基础也是最常用的 transform,它的主要作用是将 PIL Image 或 numpy.ndarray 转换为 torch.Tensor。
实战代码
from torchvision import transforms
from PIL import Image
img_path =r"D:\AProject\PythonProject\hymenoptera_data\train\ants\0013035.jpg"
img = Image.open(img_path)
transform = transforms.ToTensor()
tensor = transform(img)
print(tensor)
4. 常见的Transforms
| 输入 | PIL | Image.open() |
|---|---|---|
| 输出 | tensor | ToTensor() |
| 作用 | narrays | cv2.imread() |
5. Normalize()使用
计算公式$$output[channel] = \frac{input[channel] - mean[channel]}{std[channel]}$$
transforms.Normalize()
参数说明
mean:每个通道的均值序列的平均值
std: 每个通道的标准差
实战代码
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
writer = SummaryWriter("logs")
img_path = r"D:\AProject\PythonProject\hymenoptera_data\train\ants\5650366_e22b7e1065.jpg"
img = Image.open(img_path)
transform = transforms.ToTensor()
tensor = transform(img)
writer.add_image("original",tensor,0)
transform_normail = transforms.Normalize([0.5,0.3,0.1], [0.5,0.6,0.5])
tensor_n = transform_normail(tensor)
writer.add_image(
"transformed",tensor_n,
0
)
writer.close()
6. Resize()使用
用于将输入图片调整为指定的大小
- 参数
- 传入
(H, W):直接强制拉伸/压缩到指定高宽。 - 传入一个整数
N:将图像的短边缩放到 N,长边按比例缩放。
- 传入
7. Compose (组合变换) - 重点
这是将多个 Transform 操作串联起来的“容器”。
- 作用: 像流水线一样,按顺序执行列表中的 Transform 操作。
- 核心逻辑: 上一个操作的输出(Output),必须匹配下一个操作的输入(Input)。
- 常见报错原因: 数据类型不匹配。
- 错误示例: 先做
Normalize(需要Tensor) 再做ToTensor(输出Tensor)。如果第一步输入是 PIL,Normalize就会报错。 - 正确顺序示例: PIL图片 $\rightarrow$
Resize(PIL) $\rightarrow$ToTensor(变为Tensor) $\rightarrow$Normalize(处理Tensor)。
- 错误示例: 先做
8. RandomCrop (随机裁剪)
常用于数据增强(Data Augmentation)。
- 作用: 在图片中随机裁剪出一块指定大小的区域。
- 参数: 传入
(H, W)或一个整数。 - 意义: 增加数据的多样性,防止过拟合。
import torchvision.transforms as transforms
# 定义一个组合操作 (流水线)
trans_compose = transforms.Compose([
transforms.Resize((224, 224)), # 1. 先调整大小 (Input: PIL -> Output: PIL)
transforms.ToTensor(), # 2. 转为 Tensor (Input: PIL -> Output: Tensor)
transforms.Normalize([0.5], [0.5]) # 3. 归一化 (Input: Tensor -> Output: Tensor)
])
# 使用
# img_out = trans_compose(img_in)



