本文参考:pytorch实现简单GAN - 灰信网(软件开发博客聚合)
上文中pytorch代码执行会有问题,这块本文中已经修复!
1、GAN概述
GAN:Generative Adversarial Nets,生成对抗网络。在给定充分的建模能力,两个博弈模型能够通过简单的反向传播来协同训练。
两个模型的角色定位十分鲜明。给定真实数据集Data,G是生成器(Generator),它的任务是生成能以假乱真的假数据。D是判别器(Discriminator),它从真实数据或者G那里获取数据,然后做出判别真假的标记。
理想情况下,D和G都会随着不断训练做得越来越好,直到G基本上成为一个“赝品制造大事”,而D因无法正确区分两种数据分布输给G。
2、数学建模
设真实数据的概率分布为Pdata,生成器生成数据的概率分布为PG。
(1)D的数学描述
规定D的输出代表输入为”真”的概率(在0~1之间),则D的目标是:
若输入是真品,则提高D(x);若输入是赝品,则降低D(x)。
综合起来用数学语言描述如下:
解释:若x服从,则log(D(x))越大越好。若x服从,则log(D(x))越小越好,即log(1-D(x))越大越好。
(2)G的数学描述
对于G来说,它的目标是尽可能提高生成数据被D判别为”真”的概率,数学描述如下:
也即:
(3)全局最优解
生成器生成数据的分布在最优解情况下就等于真实数据的分布。
3、用pytorch实现简单GAN
import numpy as np
import torch.nn as nn
import torch
import matplotlib.pyplot as plt
LR = 0.0001
BATCH_SIZE = 64
DATA_SIZE = 16
IDEA = 5
X = np.linspace(0, 2 * np.pi, DATA_SIZE)
def p_data(x):
f = np.zeros((BATCH_SIZE, DATA_SIZE))
for i in range(BATCH_SIZE):
f[i] = np.sin(x)
return f
G = nn.Sequential(
nn.Linear(IDEA, 64),
nn.ReLU(),
nn.Linear(64, 128),
nn.ReLU(),
nn.Linear(128, DATA_SIZE)
)
D = nn.Sequential(
nn.Linear(DATA_SIZE, 64),
nn.ReLU(),
nn.Linear(64, 256),
nn.ReLU(),
nn.Linear(256, 1),
nn.Sigmoid()
)
D_optimizer = torch.optim.Adam(D.parameters(), lr=LR)
G_optimizer = torch.optim.Adam(G.parameters(), lr=LR)
for step in range(10000):
real = torch.tensor(p_data(X)).float()
idea = torch.randn((BATCH_SIZE, IDEA))
fake = G(idea)
prob_fake = D(fake)
G_loss = torch.mean(torch.log(torch.tensor(1) - prob_fake))
G_optimizer.zero_grad()
G_loss.backward()
G_optimizer.step()
prob_real = D(real)
prob_fake = D(fake.detach())
D_loss = -torch.mean((torch.log(prob_real) + torch.log(torch.tensor(1) - prob_fake)))
D_optimizer.zero_grad()
D_loss.backward(retain_graph=True)
D_optimizer.step()
if step % 100 == 0:
print(prob_real.mean())
print(prob_fake.mean())
print('-----------------------------------------------')
if torch.abs(prob_real.mean() - 0.5) <= 1.e-6:
break
if step % 50 == 0: # plotting
plt.cla()
plt.plot(X, fake.data.numpy()[0], c='red', lw=3, label='Generated painting')
plt.plot(X, real.data.numpy()[0], c='black', lw=1, label='real painting')
plt.text(1, .5, 'the prob of Generated painting is real = %.2f' % prob_fake.data.numpy().mean())
plt.ylim((-1.1, 1.1))
plt.legend(loc='best', fontsize=10)
plt.draw()
plt.pause(0.01)
plt.ioff()
plt.show()