围绕GAN的一个小开发
AI2025-05-06
围绕GAN的一个小开发
简介
高二的时候社团设计本来打算从底层搓一个 GAN,但后来训练太慢了就放弃了。现在重新用 PyTorch 实现,还结合了之前想用 Flet 写的一个项目,也算是填上了当年的坑。
参考资料
GAN 基本原理
GAN(Generative Adversarial Network,生成对抗网络)由两个网络组成:
- 生成器(Generator):生成假数据,试图以假乱真;
 - 判别器(Discriminator):判断输入是真数据还是假数据。
 
两者不断对抗与优化,直至生成数据足够逼真。
卷积核常用参数
| 参数 (kernel_size, stride, padding) | 作用说明 | 
|---|---|
| (3,1,1) | 保持输入尺寸不变(常见) | 
| (3,2,1) | 尺寸缩小 2×(常见) | 
| (4,2,1) | 尺寸缩小 2×,但影响范围稍大 | 
| (5,1,2) | 尺寸不变,但感受野更大 | 
| (5,2,2) | 尺寸缩小 2×,感受野更大 | 
项目中采用
(4,2,1)的卷积核,方便根据输入图片大小动态调整卷积层数。
防止训练崩溃的技巧
(1) 降低判别器学习率
1  | optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))  | 
(2) 让 G 训练更多次
1  | for _ in range(2):  | 
(3) 使用 Label Smoothing(标签平滑)
1  | real_labels = torch.full((batch_size,), 0.9)  | 
(4) 添加噪声
1  | real_data += 0.05 * torch.randn_like(real_data)  | 
(5) 使用 WGAN-GP
使用 Wasserstein GAN(WGAN) 或 WGAN-GP 可减少梯度消失问题:
1  | loss = -torch.mean(D(real)) + torch.mean(D(fake)) # WGAN 损失  | 
项目仓库
📦 GitHub Repo: Woor3x/GANForge