1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
| import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader import torchvision.utils as vutils import os
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
batch_size = 128   # images per mini-batch handed out by the DataLoader
lr = 0.0002        # learning rate shared by both Adam optimizers
noise_dim = 100    # length of the latent noise vector fed to the generator
epochs = 20        # number of full passes over the dataset

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Make sure the directory for the generated sample grids exists.
os.makedirs("output_ganpro", exist_ok=True)

# Preprocessing: force a single grayscale channel, convert to a tensor, then
# normalise with mean 0.5 / std 0.5 (one value because there is one channel).
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Images are read from class sub-folders below data/mnist_jpg; the class
# labels produced by ImageFolder are ignored by the training loop.
dataset = datasets.ImageFolder(root='data/mnist_jpg', transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
class Generator(nn.Module):
    """MLP generator that turns a batch of noise vectors into images.

    The network is a stack of Linear -> BatchNorm1d -> LeakyReLU blocks
    followed by a final Linear layer and a Tanh, so every output value lies
    in [-1, 1].
    """

    def __init__(self, noise_dim, img_shape=(1, 28, 28)):
        """
        Build the generator.

        :param noise_dim: dimensionality of the input noise vector
        :param img_shape: shape of a single generated image without the
            batch dimension; defaults to ``(1, 28, 28)``
        """
        super().__init__()
        self.img_shape = tuple(img_shape)

        # Number of values the last Linear layer must emit per image.
        num_pixels = 1
        for dim in self.img_shape:
            num_pixels *= dim

        # Layer order (and therefore the state_dict keys main.0 ... main.10)
        # is kept stable so existing checkpoints still load.
        self.main = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, True),

            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, True),

            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2, True),

            nn.Linear(1024, num_pixels),
            nn.Tanh(),
        )

    def forward(self, input):
        """
        Generate images from noise.

        :param input: noise tensor whose last dimension is ``noise_dim``
        :return: generated images with shape ``(-1, *img_shape)``
        """
        output = self.main(input)
        # Fold the flat per-image vector back into image layout.
        return output.view(-1, *self.img_shape)
class Discriminator(nn.Module):
    """MLP discriminator that scores images as real or generated.

    Each image is flattened and passed through Linear -> LeakyReLU blocks;
    a final Sigmoid yields one value in [0, 1] per image.
    """

    def __init__(self, img_shape=(1, 28, 28)):
        """
        Build the discriminator.

        :param img_shape: shape of a single input image without the batch
            dimension; defaults to ``(1, 28, 28)``
        """
        super().__init__()

        # Number of values per flattened image.
        in_features = 1
        for dim in tuple(img_shape):
            in_features *= dim
        self.in_features = in_features

        # Layer order (and therefore the state_dict keys) is kept stable so
        # existing checkpoints still load.
        self.main = nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """
        Score a batch of images.

        :param input: image tensor whose element count is a multiple of the
            per-image size
        :return: tensor of shape ``(-1, 1)`` with values in [0, 1]
        """
        # reshape (rather than view) also accepts non-contiguous tensors,
        # e.g. transposed or sliced batches, by copying only when required.
        flat = input.reshape(-1, self.in_features)
        return self.main(flat)
# Build both networks and move them to the selected device. The generator is
# constructed before the discriminator, so weight initialisation draws from
# the random number generator in that order.
netG = Generator(noise_dim)
netG = netG.to(device)
netD = Discriminator()
netD = netD.to(device)

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# One Adam optimizer per network, both with the same learning rate and betas.
optimizerD = optim.Adam(params=netD.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerG = optim.Adam(params=netG.parameters(), lr=lr, betas=(0.5, 0.999))
# Draw the preview noise once, before training starts, so the sample grid
# saved after every epoch shows the same latent points and the images can be
# compared from epoch to epoch.
fixed_noise = torch.randn(16, noise_dim, device=device)

for epoch in range(epochs):
    for i, (data, _) in enumerate(dataloader):
        real_images = data.to(device)
        # Size of this particular batch (the last one may be smaller). Kept
        # in its own name so the module-level `batch_size` hyperparameter is
        # not overwritten.
        cur_batch = real_images.size(0)
        if cur_batch < 2:
            # The generator uses BatchNorm1d, which cannot compute batch
            # statistics from a single sample in training mode; skip a
            # trailing one-image batch instead of crashing.
            continue

        # ---- Discriminator update: real images -> 1, generated images -> 0
        netD.zero_grad()
        label_real = torch.full((cur_batch, 1), 1.0, device=device)
        output_real = netD(real_images)
        lossD_real = criterion(output_real, label_real)

        noise = torch.randn(cur_batch, noise_dim, device=device)
        fake_images = netG(noise)
        label_fake = torch.full((cur_batch, 1), 0.0, device=device)
        # detach() keeps this backward pass from reaching the generator.
        output_fake = netD(fake_images.detach())
        lossD_fake = criterion(output_fake, label_fake)

        lossD = lossD_real + lossD_fake
        lossD.backward()
        optimizerD.step()

        # ---- Generator update: make the discriminator output 1 on fakes
        netG.zero_grad()
        label_gen = torch.full((cur_batch, 1), 1.0, device=device)
        output_gen = netD(fake_images)
        lossG = criterion(output_gen, label_gen)
        # This also accumulates gradients in netD; they are cleared by
        # netD.zero_grad() at the start of the next discriminator update.
        lossG.backward()
        optimizerG.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch + 1}/{epochs}] Batch {i}/{len(dataloader)} "
                  f"Loss_D: {lossD.item():.4f} Loss_G: {lossG.item():.4f}")

    # End of epoch: render the fixed noise vectors and save them as a grid.
    # NOTE(review): netG is still in training mode here, so BatchNorm uses
    # the statistics of these 16 samples - confirm whether eval() is wanted.
    with torch.no_grad():
        fake = netG(fixed_noise)
        vutils.save_image(fake, f"output_ganpro/fake_samples_epoch_{epoch + 1}.png", normalize=True)
|