DCGAN 实现手写数字生成

题目

  1. 本次挑战使用的是MNIST手写数字数据集,该数据集包含60,000张28x28的灰度图像,分为10个类别(数字0-9)。此数据集将用于训练你的生成对抗网络。
  2. 你的任务是使用DCGAN模型,对该数据集进行图像生成。具体要求如下:
    1. 数据集下载:请下载MNIST数据集,并确保数据集中包含训练集和测试集。
    2. 数据预处理:将图像数据进行必要的预处理,使其适合于DCGAN模型的训练。
    3. 模型训练:搭建DCGAN模型,并利用训练数据集进行训练,调整模型参数,尝试生成高质量的数字图像。
    4. 模型评估:在训练过程中,监控生成图像的质量,并可视化不同训练阶段生成的图像。

代码

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision.utils as vutils
import os

# --- Hyperparameters ---------------------------------------------------
batch_size = 128   # samples per mini-batch handed to the DataLoader
lr = 2e-4          # learning rate
noise_dim = 100    # number of channels of the latent noise tensor
epochs = 20        # full passes over the dataset
channel_size = 1   # image channels (the transform below outputs 1 channel)

# Use the GPU when CUDA reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Directory that receives the generated sample grids; an existing one is fine.
os.makedirs("output_dcgan", exist_ok=True)

# Preprocessing: force a single channel, convert to a tensor, then
# normalize with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Read the images through ImageFolder and serve them in shuffled batches.
dataset = datasets.ImageFolder(root="data/mnist_jpg", transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


class Generator(nn.Module):
    """Convolutional DCGAN generator producing 28x28 images.

    Three transposed convolutions upsample a ``(batch, noise_dim, 1, 1)``
    noise tensor; the spatial size grows 1 -> 7 -> 14 -> 28.
    """

    def __init__(self, noise_dim, channel_size, feature_maps=64):
        """
        Build the stack of transposed-convolution layers.

        :param noise_dim: number of channels of the input noise tensor
        :param channel_size: number of channels of the generated image
        :param feature_maps: base channel width of the hidden layers; the
            default of 64 reproduces the original 128 -> 64 layout
        """
        super().__init__()
        self.main = nn.Sequential(
            # (batch, noise_dim, 1, 1) -> (batch, 2 * feature_maps, 7, 7)
            nn.ConvTranspose2d(noise_dim, feature_maps * 2, kernel_size=7, stride=1, padding=0),
            nn.BatchNorm2d(feature_maps * 2),
            nn.ReLU(True),

            # (batch, 2 * feature_maps, 7, 7) -> (batch, feature_maps, 14, 14)
            nn.ConvTranspose2d(feature_maps * 2, feature_maps, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(feature_maps),
            nn.ReLU(True),

            # (batch, feature_maps, 14, 14) -> (batch, channel_size, 28, 28)
            nn.ConvTranspose2d(feature_maps, channel_size, kernel_size=4, stride=2, padding=1),
            # Tanh bounds every output value to [-1, 1].
            nn.Tanh(),
        )

    def forward(self, input):
        """
        Run the forward pass.

        :param input: noise tensor of shape (batch, noise_dim, 1, 1)
        :return: generated images of shape (batch, channel_size, 28, 28)
        """
        return self.main(input)


class Discriminator(nn.Module):
    """Convolutional DCGAN discriminator for 28x28 images.

    Two strided convolutions downsample 28 -> 14 -> 7, then a linear layer
    followed by a sigmoid yields one score in [0, 1] per image.
    """

    def __init__(self, channel_size, feature_maps=64):
        """
        Build the convolutional feature extractor and the classifier head.

        :param channel_size: number of channels of the images to judge
        :param feature_maps: base channel width of the conv layers; the
            default of 64 reproduces the original 64 -> 128 layout
        """
        super().__init__()
        self.main = nn.Sequential(
            # (batch, channel_size, 28, 28) -> (batch, feature_maps, 14, 14)
            nn.Conv2d(channel_size, feature_maps, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            # (batch, feature_maps, 14, 14) -> (batch, 2 * feature_maps, 7, 7)
            nn.Conv2d(feature_maps, feature_maps * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(feature_maps * 2),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.flatten = nn.Flatten()

        self.fc = nn.Sequential(
            # Input width is derived from the conv output (2 * feature_maps
            # channels on a 7x7 grid) so it cannot drift from the layers above.
            nn.Linear(feature_maps * 2 * 7 * 7, 1),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """
        Run the forward pass.

        :param input: image batch of shape (batch, channel_size, 28, 28)
        :return: tensor of shape (batch, 1) with sigmoid scores
        """
        x = self.main(input)
        x = self.flatten(x)
        output = self.fc(x)
        return output

# Models, loss function and optimizers
netG = Generator(noise_dim, channel_size).to(device)
netD = Discriminator(channel_size).to(device)

# Binary cross-entropy between the discriminator's output and 0/1 targets.
criterion = nn.BCELoss()

optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))

# Sampled once, before training, so that every epoch's saved grid comes from
# the same latent vectors and the grids can be compared across epochs.
fixed_noise = torch.randn(16, noise_dim, 1, 1, device=device)

# Training loop
for epoch in range(epochs):
    for i, (data, _) in enumerate(dataloader):
        # ---- Discriminator step ----
        # Real images from the dataset are labelled 1.
        netD.zero_grad()
        real_imgs = data.to(device)
        # Size of this particular batch (the final batch may be smaller).
        # Kept under its own name so the batch_size hyperparameter is not
        # overwritten.
        cur_batch = real_imgs.size(0)

        label_real = torch.full((cur_batch, 1), 1.0, device=device)
        output_real = netD(real_imgs)
        lossD_real = criterion(output_real, label_real)

        # Images produced by the generator are labelled 0.
        noise = torch.randn(cur_batch, noise_dim, 1, 1, device=device)
        fake_imgs = netG(noise)
        label_fake = torch.full((cur_batch, 1), 0.0, device=device)
        # detach() keeps the discriminator's backward pass out of the generator.
        output_fake = netD(fake_imgs.detach())
        lossD_fake = criterion(output_fake, label_fake)

        # The discriminator loss is the sum of the real and fake terms.
        lossD = lossD_real + lossD_fake
        lossD.backward()
        optimizerD.step()

        # ---- Generator step ----
        netG.zero_grad()
        # The generator wants its fakes to be judged real, hence target 1.
        label_gen = torch.full((cur_batch, 1), 1.0, device=device)
        # Re-score the (non-detached) fakes with the just-updated discriminator.
        output_gen = netD(fake_imgs)
        lossG = criterion(output_gen, label_gen)
        lossG.backward()
        optimizerG.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch + 1}/{epochs}] Batch {i}/{len(dataloader)} Loss_D: {lossD.item():.4f} Loss_G: {lossG.item():.4f}")

    # Save a grid of generated samples at the end of every epoch.
    # netG is left in training mode here, as in the original script.
    with torch.no_grad():
        fake = netG(fixed_noise)
        vutils.save_image(fake, f"output_dcgan/fake_samples_epoch_{epoch + 1}.png", normalize=True)

结果

30 epochs:
alt text

Contents
  1. 1. 题目
  2. 2. 代码
  3. 3. 结果
|