实例化DCGAN
1 2 3 4 5 6 7 8 9 10 11 12 13 |
# Build the DCGAN model, forwarding every setting from the command-line FLAGS.
dcgan = DCGAN(
    sess,
    input_width=FLAGS.input_width,
    input_height=FLAGS.input_height,
    output_width=FLAGS.output_width,
    output_height=FLAGS.output_height,
    batch_size=FLAGS.batch_size,
    c_dim=FLAGS.c_dim,
    dataset_name=FLAGS.dataset,
    input_fname_pattern=FLAGS.input_fname_pattern,
    is_crop=FLAGS.is_crop,
    checkpoint_dir=FLAGS.checkpoint_dir,
    sample_dir=FLAGS.sample_dir)
batch_size
每次迭代同时送入网络的图像数量;这一批图像共同决定本次迭代中网络参数如何更新
c_dim
输入图像的通道数,彩色图为3
dataset
数据集所在文件夹的名称
is_crop
当输入图像比较大的时候,可以以图像为中心,进行相应尺寸的剪裁
checkpoint_dir
用来保存模型参数(检查点)的目录
DCGAN类解读
初始化
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
def __init__(self, sess, input_height=108, input_width=108, is_crop=True,
             batch_size=64, sample_num=64, output_height=64, output_width=64,
             y_dim=None, z_dim=100, gf_dim=64, df_dim=64,
             gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default',
             input_fname_pattern='*.jpg', checkpoint_dir=None, sample_dir=None):
    """Record the hyper-parameters, create the batch-norm layers, build the graph.

    Args:
      sess: TensorFlow session.
      input_height, input_width: Size of the images read from disk. [108]
      is_crop: Selects whether the image placeholders use the output size
        (True) or the input size (False). [True]
      batch_size: The size of batch. Should be specified before training. [64]
      sample_num: Number of samples fed to the sampler placeholder. [64]
      output_height, output_width: Size of the generated images. [64]
      y_dim: (optional) Dimension of the condition vector y. [None]
      z_dim: (optional) Dimension of the noise vector z. [100]
      gf_dim: (optional) Base number of generator filters. [64]
      df_dim: (optional) Base number of discriminator filters. [64]
      gfc_dim: (optional) Generator fully connected units. [1024]
      dfc_dim: (optional) Discriminator fully connected units. [1024]
      c_dim: (optional) Number of image channels; 1 means grayscale. [3]
      dataset_name: Name of the dataset folder. ['default']
      input_fname_pattern: Glob pattern used to pick image files. ['*.jpg']
      checkpoint_dir: Directory used when restoring a checkpoint. [None]
      sample_dir: Accepted for interface compatibility; not stored here. [None]
    """
    self.sess = sess
    self.is_crop = is_crop
    self.is_grayscale = c_dim == 1

    self.batch_size = batch_size
    self.sample_num = sample_num  # used when testing the generator

    self.input_height, self.input_width = input_height, input_width
    self.output_height, self.output_width = output_height, output_width

    self.y_dim = y_dim
    self.z_dim = z_dim  # dimension of the noise fed to the generator

    # Base feature-map counts for the generator / discriminator.
    self.gf_dim, self.df_dim = gf_dim, df_dim

    # Widths of the fully connected layers.
    self.gfc_dim, self.dfc_dim = gfc_dim, dfc_dim

    # Channels of the final image (3 gives a colour image, e.g. 64*64*3).
    self.c_dim = c_dim

    unconditional = not self.y_dim

    # batch normalization : deals with poor initialization helps gradient flow
    # (applied to a layer's output before its activation function).
    self.d_bn1 = batch_norm(name='d_bn1')
    self.d_bn2 = batch_norm(name='d_bn2')
    if unconditional:
        # The unconditional discriminator has a third normalized layer.
        self.d_bn3 = batch_norm(name='d_bn3')

    self.g_bn0 = batch_norm(name='g_bn0')
    self.g_bn1 = batch_norm(name='g_bn1')
    self.g_bn2 = batch_norm(name='g_bn2')
    if unconditional:
        # The unconditional generator has a fourth normalized layer.
        self.g_bn3 = batch_norm(name='g_bn3')

    self.dataset_name = dataset_name
    self.input_fname_pattern = input_fname_pattern
    self.checkpoint_dir = checkpoint_dir

    self.build_model()
创建模型
在该部分中,完成了对判别网络和生成网络内部结构的构建
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
def build_model(self):
    """Create placeholders, the generator/discriminator graphs, losses and summaries.

    Sets on ``self``: the placeholders (``y`` when conditional, ``inputs``,
    ``sample_inputs``, ``z``), the network outputs (``G``, ``D``,
    ``D_logits``, ``sampler``, ``D_``, ``D_logits_``), the losses
    (``d_loss_real``, ``d_loss_fake``, ``d_loss``, ``g_loss``), their
    summaries, the per-network variable lists (``d_vars``, ``g_vars``) and
    a ``saver``.
    """

    def _sigmoid_cross_entropy(logits, labels):
        # TensorFlow >= 1.0 names the target argument `labels`; earlier
        # releases call it `targets` and reject the newer keyword.
        try:
            return tf.nn.sigmoid_cross_entropy_with_logits(
                logits=logits, labels=labels)
        except TypeError:
            return tf.nn.sigmoid_cross_entropy_with_logits(
                logits=logits, targets=labels)

    if self.y_dim:
        self.y = tf.placeholder(
            tf.float32, [self.batch_size, self.y_dim], name='y')

    if self.is_crop:
        # After cropping, images have the output size (e.g. 64*64*3).
        image_dims = [self.output_height, self.output_width, self.c_dim]
    else:
        # Fixed: the second entry used to repeat input_height.
        image_dims = [self.input_height, self.input_width, self.c_dim]

    # Real images fed to the discriminator.
    self.inputs = tf.placeholder(
        tf.float32, [self.batch_size] + image_dims, name='real_images')
    self.sample_inputs = tf.placeholder(
        tf.float32, [self.sample_num] + image_dims, name='sample_inputs')

    inputs = self.inputs

    # Noise input; the leading None lets any number of rows be fed.
    self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z')
    self.z_sum = histogram_summary("z", self.z)

    # NOTE: `self.sampler` is reassigned below from the bound method to the
    # result of calling it, so it can only be called once per instance.
    if self.y_dim:
        # Generator driven by the noise vector and the condition y.
        self.G = self.generator(self.z, self.y)
        self.D, self.D_logits = \
            self.discriminator(inputs, self.y, reuse=False)
        self.sampler = self.sampler(self.z, self.y)
        self.D_, self.D_logits_ = \
            self.discriminator(self.G, self.y, reuse=True)
    else:
        self.G = self.generator(self.z)
        # Discriminator on the real inputs.
        self.D, self.D_logits = self.discriminator(inputs)
        self.sampler = self.sampler(self.z)
        # Discriminator on the generated inputs, sharing the same variables.
        self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True)

    self.d_sum = histogram_summary("d", self.D)
    self.d__sum = histogram_summary("d_", self.D_)
    self.G_sum = image_summary("G", self.G)

    # Discriminator: real inputs should be classified as 1 ...
    self.d_loss_real = tf.reduce_mean(
        _sigmoid_cross_entropy(self.D_logits, tf.ones_like(self.D)))
    # ... and generated inputs as 0.
    self.d_loss_fake = tf.reduce_mean(
        _sigmoid_cross_entropy(self.D_logits_, tf.zeros_like(self.D_)))
    # Generator: wants its outputs to be classified as 1.
    self.g_loss = tf.reduce_mean(
        _sigmoid_cross_entropy(self.D_logits_, tf.ones_like(self.D_)))

    self.d_loss_real_sum = scalar_summary("d_loss_real", self.d_loss_real)
    self.d_loss_fake_sum = scalar_summary("d_loss_fake", self.d_loss_fake)

    self.d_loss = self.d_loss_real + self.d_loss_fake

    self.g_loss_sum = scalar_summary("g_loss", self.g_loss)
    self.d_loss_sum = scalar_summary("d_loss", self.d_loss)

    # Split the trainable variables by the 'd_' / 'g_' marker in their
    # names so each optimizer updates only its own network.
    t_vars = tf.trainable_variables()
    self.d_vars = [var for var in t_vars if 'd_' in var.name]
    self.g_vars = [var for var in t_vars if 'g_' in var.name]

    # Saver used for checkpoints.
    self.saver = tf.train.Saver()
生成网络的构建
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
def generator(self, z, y=None):
    """Build the generator graph under the "generator" variable scope.

    Args:
      z: Noise tensor fed to the first fully connected layer.
      y: Condition tensor; only read when ``self.y_dim`` is set.

    Returns:
      The generated image tensor: ``tanh`` of the last deconvolution in
      the unconditional case, ``sigmoid`` of it in the conditional case.
    """
    with tf.variable_scope("generator"):
        if not self.y_dim:
            # Feature-map size of each layer, from the output backwards.
            # With the default 64x64 output the original annotations give
            # 32, 16, 8 and 4.
            s_h, s_w = self.output_height, self.output_width
            s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
            s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
            s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
            s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)

            # project `z` and reshape: a fully connected layer maps z to
            # gf_dim*8*s_h16*s_w16 values (8192 with the defaults).
            self.z_, self.h0_w, self.h0_b = linear(
                z, self.gf_dim * 8 * s_h16 * s_w16, 'g_h0_lin', with_w=True)

            # Reshape the vector into a feature map (4*4*512 with the
            # defaults); -1 lets the batch dimension be inferred.
            self.h0 = tf.reshape(
                self.z_, [-1, s_h16, s_w16, self.gf_dim * 8])
            h0 = tf.nn.relu(self.g_bn0(self.h0))

            self.h1, self.h1_w, self.h1_b = deconv2d(
                h0, [self.batch_size, s_h8, s_w8, self.gf_dim * 4],
                name='g_h1', with_w=True)
            h1 = tf.nn.relu(self.g_bn1(self.h1))

            h2, self.h2_w, self.h2_b = deconv2d(
                h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2],
                name='g_h2', with_w=True)
            h2 = tf.nn.relu(self.g_bn2(h2))

            h3, self.h3_w, self.h3_b = deconv2d(
                h2, [self.batch_size, s_h2, s_w2, self.gf_dim * 1],
                name='g_h3', with_w=True)
            h3 = tf.nn.relu(self.g_bn3(h3))

            h4, self.h4_w, self.h4_b = deconv2d(
                h3, [self.batch_size, s_h, s_w, self.c_dim],
                name='g_h4', with_w=True)

            return tf.nn.tanh(h4)
        else:
            s_h, s_w = self.output_height, self.output_width
            s_h2, s_h4 = int(s_h / 2), int(s_h / 4)
            s_w2, s_w4 = int(s_w / 2), int(s_w / 4)

            # The condition reshaped so it can be concatenated onto
            # feature maps by conv_cond_concat.
            yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
            z = concat([z, y], 1)

            h0 = tf.nn.relu(
                self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin')))
            h0 = concat([h0, y], 1)

            h1 = tf.nn.relu(self.g_bn1(
                linear(h0, self.gf_dim * 2 * s_h4 * s_w4, 'g_h1_lin')))
            h1 = tf.reshape(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2])
            h1 = conv_cond_concat(h1, yb)

            h2 = tf.nn.relu(self.g_bn2(deconv2d(
                h1, [self.batch_size, s_h2, s_w2, self.gf_dim * 2],
                name='g_h2')))
            h2 = conv_cond_concat(h2, yb)

            return tf.nn.sigmoid(deconv2d(
                h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3'))
全连接层
1 2 3 4 5 6 7 8 9 10 11 12 |
def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
    """Fully connected layer: ``input_ @ Matrix + bias``.

    Args:
      input_: Input tensor; its second static dimension is the input width.
      output_size: Number of output units.
      scope: Variable scope name; "Linear" is used when it is falsy.
      stddev: Standard deviation of the normal weight initializer.
      bias_start: Constant used to initialize the bias.
      with_w: When true, also return the weight matrix and the bias.

    Returns:
      The projected tensor, or ``(projected, matrix, bias)`` if ``with_w``.
    """
    in_features = input_.get_shape().as_list()[1]

    with tf.variable_scope(scope or "Linear"):
        # Weight of shape [in_features, output_size], e.g. 100*8192 for the
        # generator's first layer with default settings.
        weights = tf.get_variable(
            "Matrix", [in_features, output_size], tf.float32,
            tf.random_normal_initializer(stddev=stddev))
        # One bias value per output unit.
        offset = tf.get_variable(
            "bias", [output_size],
            initializer=tf.constant_initializer(bias_start))

        projected = tf.matmul(input_, weights) + offset
        if not with_w:
            return projected
        return projected, weights, offset
反卷积层
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
def deconv2d(input_, output_shape, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
             name="deconv2d", with_w=False):
    """Transposed-convolution layer with a bias term.

    Args:
      input_: Input feature-map tensor.
      output_shape: Full shape of the output; its last entry is the number
        of output channels.
      k_h, k_w: Kernel height and width.
      d_h, d_w: Strides along height and width.
      stddev: Standard deviation of the normal weight initializer.
      name: Variable scope name.
      with_w: When true, also return the kernel and the biases.

    Returns:
      The output tensor, or ``(output, w, biases)`` if ``with_w``.
    """
    out_channels = output_shape[-1]
    strides = [1, d_h, d_w, 1]

    with tf.variable_scope(name):
        # filter : [height, width, output_channels, in_channels]
        kernel = tf.get_variable(
            'w', [k_h, k_w, out_channels, input_.get_shape()[-1]],
            initializer=tf.random_normal_initializer(stddev=stddev))

        try:
            result = tf.nn.conv2d_transpose(
                input_, kernel, output_shape=output_shape, strides=strides)
        except AttributeError:
            # Support for versions of TensorFlow before 0.7.0
            result = tf.nn.deconv2d(
                input_, kernel, output_shape=output_shape, strides=strides)

        # The bias size always follows the number of output channels.
        bias = tf.get_variable(
            'biases', [out_channels],
            initializer=tf.constant_initializer(0.0))
        result = tf.reshape(tf.nn.bias_add(result, bias), result.get_shape())

        if not with_w:
            return result
        return result, kernel, bias
判别网络的构建
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
def discriminator(self, image, y=None, reuse=False):
    """Build the discriminator graph under the "discriminator" variable scope.

    Args:
      image: Image tensor to classify.
      y: Condition tensor; only read when ``self.y_dim`` is set.
      reuse: When true, reuse the variables already created in the scope.

    Returns:
      A pair ``(sigmoid(logits), logits)``.
    """
    with tf.variable_scope("discriminator") as disc_scope:
        if reuse:
            disc_scope.reuse_variables()

        if self.y_dim:
            # Conditional variant: y is concatenated at every stage.
            cond_map = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
            joined = conv_cond_concat(image, cond_map)

            conv0 = conv2d(joined, self.c_dim + self.y_dim, name='d_h0_conv')
            act0 = conv_cond_concat(lrelu(conv0), cond_map)

            conv1 = conv2d(act0, self.df_dim + self.y_dim, name='d_h1_conv')
            act1 = lrelu(self.d_bn1(conv1))
            act1 = tf.reshape(act1, [self.batch_size, -1])
            act1 = concat([act1, y], 1)

            dense2 = linear(act1, self.dfc_dim, 'd_h2_lin')
            act2 = concat([lrelu(self.d_bn2(dense2)), y], 1)

            logits = linear(act2, 1, 'd_h3_lin')
            return tf.nn.sigmoid(logits), logits

        # Unconditional variant: four convolutions, then one linear unit.
        act0 = lrelu(conv2d(image, self.df_dim, name='d_h0_conv'))

        conv1 = conv2d(act0, self.df_dim * 2, name='d_h1_conv')
        act1 = lrelu(self.d_bn1(conv1))

        conv2 = conv2d(act1, self.df_dim * 4, name='d_h2_conv')
        act2 = lrelu(self.d_bn2(conv2))

        conv3 = conv2d(act2, self.df_dim * 8, name='d_h3_conv')
        act3 = lrelu(self.d_bn3(conv3))

        flat = tf.reshape(act3, [self.batch_size, -1])
        logits = linear(flat, 1, 'd_h3_lin')
        return tf.nn.sigmoid(logits), logits
卷积层
1 2 3 4 5 6 7 8 9 10 11 12 |
def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
           name="conv2d"):
    """Strided 'SAME'-padded convolution layer with a bias term.

    Args:
      input_: Input feature-map tensor; its last dimension is the number
        of input channels.
      output_dim: Number of output channels.
      k_h, k_w: Kernel height and width.
      d_h, d_w: Strides along height and width.
      stddev: Standard deviation of the truncated-normal weight initializer.
      name: Variable scope name.

    Returns:
      The convolved tensor with the bias added.
    """
    with tf.variable_scope(name):
        kernel_shape = [k_h, k_w, input_.get_shape()[-1], output_dim]
        kernel = tf.get_variable(
            'w', kernel_shape,
            initializer=tf.truncated_normal_initializer(stddev=stddev))
        features = tf.nn.conv2d(
            input_, kernel, strides=[1, d_h, d_w, 1], padding='SAME')

        bias = tf.get_variable(
            'biases', [output_dim],
            initializer=tf.constant_initializer(0.0))
        biased = tf.nn.bias_add(features, bias)

        return tf.reshape(biased, features.get_shape())
训练DCGAN
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
def train(self, config):
    """Train DCGAN.

    Each iteration updates the discriminator once and the generator twice,
    writes summaries to ``./logs`` and prints the losses. Sample images are
    saved whenever ``counter % 100 == 1`` and a checkpoint whenever
    ``counter % 100 == 2``.

    Args:
      config: Settings object; this method reads ``dataset``,
        ``learning_rate``, ``beta1``, ``epoch``, ``train_size``,
        ``batch_size``, ``sample_dir`` and ``checkpoint_dir`` from it.
    """
    is_mnist = config.dataset == 'mnist'

    def load_images(paths):
        # Read the given files with get_image and stack them into one
        # float32 array; grayscale images get a trailing channel axis.
        images = [
            get_image(path,
                      input_height=self.input_height,
                      input_width=self.input_width,
                      resize_height=self.output_height,
                      resize_width=self.output_width,
                      is_crop=self.is_crop,
                      is_grayscale=self.is_grayscale)
            for path in paths]
        stacked = np.array(images).astype(np.float32)
        if self.is_grayscale:
            stacked = stacked[:, :, :, None]
        return stacked

    if is_mnist:
        data_X, data_y = self.load_mnist()
    else:
        # Collect the image files of the dataset.
        data = glob(os.path.join(
            "./data", config.dataset, self.input_fname_pattern))

    # Optimizers minimizing each loss over its own network's variables.
    # The chained calls are parenthesized: a comment may not follow a
    # backslash line continuation.
    d_optim = (tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1)
               .minimize(self.d_loss, var_list=self.d_vars))
    g_optim = (tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1)
               .minimize(self.g_loss, var_list=self.g_vars))

    # Initialize all global variables.
    try:
        tf.global_variables_initializer().run()
    except AttributeError:
        # TensorFlow releases that lack global_variables_initializer.
        tf.initialize_all_variables().run()

    self.g_sum = merge_summary([self.z_sum, self.d__sum,
                                self.G_sum, self.d_loss_fake_sum,
                                self.g_loss_sum])
    self.d_sum = merge_summary(
        [self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
    self.writer = SummaryWriter("./logs", self.sess.graph)

    # A fixed batch of noise vectors used for every sample image.
    sample_z = np.random.uniform(-1, 1, size=(self.sample_num, self.z_dim))

    if is_mnist:
        sample_inputs = data_X[0:self.sample_num]
        sample_labels = data_y[0:self.sample_num]
    else:
        sample_inputs = load_images(data[0:self.sample_num])

    sample_feed = {self.z: sample_z, self.inputs: sample_inputs}
    if is_mnist:
        sample_feed[self.y] = sample_labels

    def sample_and_save(epoch, idx):
        # Run the sampler on the fixed sample batch, save the image grid
        # and print the losses measured on that batch.
        samples, d_loss, g_loss = self.sess.run(
            [self.sampler, self.d_loss, self.g_loss],
            feed_dict=sample_feed)
        save_images(samples, [8, 8],
                    './{}/train_{:02d}_{:04d}.png'.format(
                        config.sample_dir, epoch, idx))
        print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss))

    counter = 1
    start_time = time.time()

    if self.load(self.checkpoint_dir):
        print(" [*] Load SUCCESS")
    else:
        print(" [!] Load failed...")

    for epoch in xrange(config.epoch):
        if is_mnist:
            batch_idxs = min(len(data_X), config.train_size) // config.batch_size
        else:
            # The file list is re-read at the start of every epoch.
            data = glob(os.path.join(
                "./data", config.dataset, self.input_fname_pattern))
            batch_idxs = min(len(data), config.train_size) // config.batch_size

        for idx in xrange(0, batch_idxs):
            start = idx * config.batch_size
            stop = (idx + 1) * config.batch_size

            if is_mnist:
                batch_images = data_X[start:stop]
                batch_labels = data_y[start:stop]
            else:
                # One batch of images.
                batch_images = load_images(data[start:stop])

            # The matching batch of noise for the generator.
            batch_z = (np.random.uniform(-1, 1, [config.batch_size, self.z_dim])
                       .astype(np.float32))

            # Feeds: the generator side needs z, the real side needs the
            # images, and the discriminator update needs both. The labels
            # are added to each of them for mnist.
            g_feed = {self.z: batch_z}
            real_feed = {self.inputs: batch_images}
            if is_mnist:
                g_feed[self.y] = batch_labels
                real_feed[self.y] = batch_labels
            d_feed = dict(real_feed)
            d_feed.update(g_feed)

            # Update D network
            _, summary_str = self.sess.run([d_optim, self.d_sum],
                                           feed_dict=d_feed)
            self.writer.add_summary(summary_str, counter)

            # Update G network. Run g_optim twice to make sure that d_loss
            # does not go to zero (different from paper)
            for _ in range(2):
                _, summary_str = self.sess.run([g_optim, self.g_sum],
                                               feed_dict=g_feed)
                self.writer.add_summary(summary_str, counter)

            errD_fake = self.d_loss_fake.eval(g_feed)
            errD_real = self.d_loss_real.eval(real_feed)
            errG = self.g_loss.eval(g_feed)

            counter += 1
            print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                  % (epoch, idx, batch_idxs,
                     time.time() - start_time, errD_fake + errD_real, errG))

            # Save sample images every 100 iterations.
            if np.mod(counter, 100) == 1:
                if is_mnist:
                    sample_and_save(epoch, idx)
                else:
                    # Best effort: a failure while sampling must not stop
                    # training, but Ctrl-C and SystemExit still propagate.
                    try:
                        sample_and_save(epoch, idx)
                    except Exception:
                        print("one pic error!...")

            if np.mod(counter, 100) == 2:
                self.save(config.checkpoint_dir, counter)