TONG-H

围绕GAN的一个小开发

AI2025-05-06

围绕GAN的一个小开发


简介

高二的时候社团设计本来打算从底层搓一个 GAN,但后来训练太慢了就放弃了。现在重新用 PyTorch 实现,还结合了之前想用 Flet 写的一个项目,也算是填上了当年的坑。


参考资料


GAN 基本原理

GAN(Generative Adversarial Network,生成对抗网络)由两个网络组成:

  • 生成器(Generator):生成假数据,试图以假乱真;
  • 判别器(Discriminator):判断输入是真数据还是假数据。

两者不断对抗与优化,直至生成数据足够逼真。


卷积核常用参数

参数 (kernel_size, stride, padding) 作用说明
(3,1,1) 保持输入尺寸不变(常见)
(3,2,1) 尺寸缩小 2×(常见)
(4,2,1) 尺寸缩小 2×,但影响范围稍大
(5,1,2) 尺寸不变,但感受野更大
(5,2,2) 尺寸缩小 2×,感受野更大

项目中采用 (4,2,1) 的卷积核,方便根据输入图片大小动态调整卷积层数。


防止训练崩溃的技巧

(1) 降低判别器学习率

1
2
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))

(2) 让 G 训练更多次

1
2
3
for _ in range(2): 
train_generator() # 让 G 训练两次
train_discriminator()

(3) 使用 Label Smoothing(标签平滑)

1
2
real_labels = torch.full((batch_size,), 0.9)
fake_labels = torch.full((batch_size,), 0.0)

(4) 添加噪声

1
real_data += 0.05 * torch.randn_like(real_data)

(5) 使用 WGAN-GP

使用 Wasserstein GAN(WGAN)WGAN-GP 可减少梯度消失问题:

1
loss = -torch.mean(D(real)) + torch.mean(D(fake))  # WGAN 损失

项目仓库

📦 GitHub Repo: Woor3x/GANForge