在前面三篇博文中,我们已经完成了三个网络的构建、损失函数的构造
接下来,就是激动人心的训练时刻了
训练预先判别D_pre网络
1 2 3 4 5 6 7 8 9 10 11 12 |
num_pretrain_steps = 1000 #迭代次数 for step in range(num_pretrain_steps): d = (np.random.random(self.batch_size) - 0.5) * 10.0 #生成(-5,5)的数据 labels = norm.pdf(d, loc=self.data.mu, scale=self.data.sigma) #d 4 0.5 方差 pretrain_loss, _ = session.run([self.pre_loss, self.pre_opt], { self.pre_input: np.reshape(d, (self.batch_size, 1)), self.pre_labels: np.reshape(labels, (self.batch_size, 1)) #一维的 }) self.weightsD = session.run(self.d_pre_params) #先训练D_pre,把参数提取出来 # copy weights from pre-training over to new D network for i, v in enumerate(self.d_params): session.run(v.assign(self.weightsD[i])) |
首先确定D_pre的训练次数,在这儿我们设定训练1000次
然后随机生成数据值(采样点),在通过高斯计算,得到相应的label,这个时候得到的数据是真实的数据
然后再把d和label放到之前构造好的D_pre网络模型中(在之前的构造中,已经在D_pre中预先挖好了坑,这个时候我们只要把已知的d和label填坑就行),session.run相当于迭代一次
当迭代1000次完成之后,取出D_pre训练好的模型(即取出w和b参数,作为判别网络D的初始化值)
训练真正的对抗生成网络
先构造出x真实的数据点
1 2 3 4 |
def sample(self, N): samples = np.random.normal(self.mu, self.sigma, N) samples.sort() return samples |
在构造出z随机的噪声点,以此来作为生成网络的输入
1 2 3 |
def sample(self, N): return np.linspace(-self.range, self.range, N) + \ np.random.random(N) * 0.01 |
开始训练
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
for step in range(self.num_steps): #开始迭代 # update discriminator 判别网络 x = self.data.sample(self.batch_size) #已知的高斯初始化 z = self.gen.sample(self.batch_size) #随机的高斯初始化 loss_d, _ = session.run([self.loss_d, self.opt_d], { self.x: np.reshape(x, (self.batch_size, 1)), #两种输入 self.z: np.reshape(z, (self.batch_size, 1)) }) # update generator 生成网络 z = self.gen.sample(self.batch_size) #噪音,随机初始化 loss_g, _ = session.run([self.loss_g, self.opt_g], { self.z: np.reshape(z, (self.batch_size, 1)) #迭代优化 }) |
1 |
将x和z输入判别网络中,先优化判别网络D,在优化生成网络G
总结
1.定义好生成网络、判别网络,并定义好损失函数
2.实际的训练,传入两组数据