官方资源帖!手把手教你在TensorFlow 2.0中实现CycleGAN,推特上百赞

栏目: 数据库 · 发布时间: 5年前

内容简介:Last updated on 2019年7月1日CycleGAN,一个可以将一张图像的特征迁移到另一张图像的酷算法,此前可以完成马变斑马、冬天变夏天、苹果变桔子等一颗赛艇的效果。

Last updated on 2019年7月1日

CycleGAN,一个可以将一张图像的特征迁移到另一张图像的酷算法,此前可以完成马变斑马、冬天变夏天、苹果变桔子等一颗赛艇的效果。

官方资源帖!手把手教你在TensorFlow 2.0中实现CycleGAN,推特上百赞

这行被顶会ICCV收录的研究自提出后,就为图形学等领域的技术人员所用,甚至还成为不少艺术家用来创作的工具。

官方资源帖!手把手教你在TensorFlow 2.0中实现CycleGAN,推特上百赞

也是目前大火的“换脸”技术的老前辈了。

官方资源帖!手把手教你在TensorFlow 2.0中实现CycleGAN,推特上百赞

如果你还没学会这项厉害的研究,那这次一定要抓紧上车了。

现在,TensorFlow开始手把手教你,在TensorFlow 2.0中CycleGAN实现大法。

这个 官方教程贴 几天内收获了满满人气,获得了Google AI工程师、哥伦比亚大学数据科学研究所Josh Gordon的推荐,推特上已近600赞。

官方资源帖!手把手教你在TensorFlow 2.0中实现CycleGAN,推特上百赞

有国外网友称赞太棒,表示很高兴看到TensorFlow 2.0教程中涵盖了最先进的模型。

这份教程全面详细,想学CycleGAN不能错过这个:

详细内容

在TensorFlow 2.0中实现CycleGAN,只要7个步骤就可以了。

1、设置输入Pipeline

安装tensorflow_examples包,用于导入生成器和鉴别器。

!pip install -q git+https: //github.com/tensorflow/examples.git

!pip install -q tensorflow-gpu==2.0.0-beta1
import tensorflow as tf
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix
import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
tfds.disable_progress_bar()
AUTOTUNE = tf.data.experimental.AUTOTUNE

2、输入pipeline

在这个教程中,我们主要学习马到斑马的图像转换,如果想寻找类似的数据集,可以前往:

https://www.tensorflow.org/datasets/datasets#cycle_gan

在CycleGAN论文中也提到,将随机抖动( Jitter )和镜像应用到训练集中,这是避免过度拟合的图像增强技术。

和在Pix2Pix中的操作类似,在随机抖动中吗,图像大小被调整成286×286,然后随机裁剪为256×256。

在随机镜像中吗,图像随机水平翻转,即从左到右进行翻转。

官方资源帖!手把手教你在TensorFlow 2.0中实现CycleGAN,推特上百赞

BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
def random_crop(image):
  cropped_image = tf.image.random_crop(
      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])
  return cropped_image
# normalizing the images to [-1, 1]
def normalize(image):
  image = tf.cast(image, tf.float32)
  image = (image / 127.5) - 1
  return image
def random_jitter(image):
  # resizing to 286 x 286 x 3
  image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  # randomly cropping to 256 x 256 x 3
  image = random_crop(image)
  # random mirroring
  image = tf.image.random_flip_left_right(image)
  return image
def preprocess_image_train(image, label):
  image = random_jitter(image)
  image = normalize(image)
  return image
def preprocess_image_test(image, label):
  image = normalize(image)
  return image
train_horses = train_horses.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)
train_zebras = train_zebras.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)
test_horses = test_horses.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)
test_zebras = test_zebras.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)
sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))
plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)
plt.subplot(122)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)

官方资源帖!手把手教你在TensorFlow 2.0中实现CycleGAN,推特上百赞

官方资源帖!手把手教你在TensorFlow 2.0中实现CycleGAN,推特上百赞

3、导入并重新使用Pix2Pix模型

通过安装tensorflow_examples包,从Pix2Pix中导入生成器和鉴别器。

这个教程中使用的模型体系结构与Pix2Pix中很类似,但也有一些差异,比如Cyclegan使用的是实例规范化而不是批量规范化,比如Cyclegan论文使用的是修改后的resnet生成器等。

我们训练两个生成器(G和F)和两个鉴别器(X和Y)。生成器G架构图像X转换为图像Y,生成器F将图像Y转换为图像X。

鉴别器D_X区分图像X和生成的图像X(F(Y)),辨别器D_Y区分图像Y和生成的图像Y(G(X))。

官方资源帖!手把手教你在TensorFlow 2.0中实现CycleGAN,推特上百赞

OUTPUT_CHANNELS = 3
generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)
to_zebra = generator_g(sample_horse)
to_horse = generator_f(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8
plt.subplot(221)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)
plt.subplot(222)
plt.title('To Zebra')
plt.imshow(to_zebra[0] * 0.5 * contrast + 0.5)
plt.subplot(223)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)
plt.subplot(224)
plt.title('To Horse')
plt.imshow(to_horse[0] * 0.5 * contrast + 0.5)
plt.show()

官方资源帖!手把手教你在TensorFlow 2.0中实现CycleGAN,推特上百赞

plt.figure(figsize=(8, 8))
plt.subplot(121)
plt.title('Is a real zebra?')
plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')
plt.subplot(122)
plt.title('Is a real horse?')
plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')
plt.show()

官方资源帖!手把手教你在TensorFlow 2.0中实现CycleGAN,推特上百赞

4、损失函数

在CycleGAN中,因为没有用于训练的成对数据,因此无法保证输入X和目标Y在训练期间是否有意义。因此,为了强制学习正确的映射,CycleGAN中提出了“循环一致性损失”(cycle consistency loss)。

鉴别器和生成器的损失与Pix2Pix中的类似。

LAMBDA = 10
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real, generated):
  real_loss = loss_obj(tf.ones_like(real), real)
  generated_loss = loss_obj(tf.zeros_like(generated), generated)
  total_disc_loss = real_loss + generated_loss
  return total_disc_loss * 0.5
def generator_loss(generated):
  return loss_obj(tf.ones_like(generated), generated)

循环一致性意味着结果接近原始输入。

例如将一个句子和英语翻译成法语,再将其从法语翻译成英语后,结果与原始英文句子相同。

在循环一致性损失中,图像X通过生成器传递C产生的图像Y^,生成的图像Y^通过生成器传递F产生的图像X^,然后计算平均绝对误差X和X^。

前向循环一致性损失为:

官方资源帖!手把手教你在TensorFlow 2.0中实现CycleGAN,推特上百赞

反向循环一致性损失为:

官方资源帖!手把手教你在TensorFlow 2.0中实现CycleGAN,推特上百赞

def calc_cycle_loss(real_image, cycled_image):
  loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
  return LAMBDA * loss1

初始化所有生成器和鉴别器的的优化:

generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

5、检查点

checkpoint_path = "./checkpoints/train"
ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

6、训练

注意:为了使本教程的训练时间合理,本示例模型迭代次数较少(40次,论文中为200次),预测效果可能不如论文准确。

EPOCHS = 40
def generate_images(model, test_input):
  prediction = model(test_input)
  plt.figure(figsize=(12, 12))
  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']
  for i in range(2):
    plt.subplot(1, 2, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

尽管训练起来很复杂,但基本的步骤只有四个,分别为:获取预测、计算损失、使用反向传播计算梯度、将梯度应用于优化程序。

@tf.function
def train_step(real_x, real_y):
  # persistent is set to True because gen_tape and disc_tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as gen_tape, tf.GradientTape(
      persistent=True) as disc_tape:
    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)
    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)
    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)
    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)
    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)
    # Total generator loss = adversarial loss + cycle loss
    total_gen_g_loss = gen_g_loss + calc_cycle_loss(real_x, cycled_x)
    total_gen_f_loss = gen_f_loss + calc_cycle_loss(real_y, cycled_y)
    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
  # Calculate the gradients for generator and discriminator
  generator_g_gradients = gen_tape.gradient(total_gen_g_loss, 
                                            generator_g.trainable_variables)
  generator_f_gradients = gen_tape.gradient(total_gen_f_loss, 
                                            generator_f.trainable_variables)
  discriminator_x_gradients = disc_tape.gradient(
      disc_x_loss, discriminator_x.trainable_variables)
  discriminator_y_gradients = disc_tape.gradient(
      disc_y_loss, discriminator_y.trainable_variables)
  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                             generator_g.trainable_variables))
  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                             generator_f.trainable_variables))
  discriminator_x_optimizer.apply_gradients(
      zip(discriminator_x_gradients,
      discriminator_x.trainable_variables))
  discriminator_y_optimizer.apply_gradients(
      zip(discriminator_y_gradients,
      discriminator_y.trainable_variables))
for epoch in range(EPOCHS):
  start = time.time()
  n = 0
  for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):
    train_step(image_x, image_y)
    if n % 10 == 0:
      print ('.', end='')
    n+=1
  clear_output(wait=True)
  # Using a consistent image (sample_horse) so that the progress of the model
  # is clearly visible.
  generate_images(generator_g, sample_horse)
  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))
  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))

官方资源帖!手把手教你在TensorFlow 2.0中实现CycleGAN,推特上百赞

官方资源帖!手把手教你在TensorFlow 2.0中实现CycleGAN,推特上百赞

7、使用测试集生成图像

# Run the trained model on the test dataset
for inp in test_horses.take(5):
  generate_images(generator_g, inp)

官方资源帖!手把手教你在TensorFlow 2.0中实现CycleGAN,推特上百赞

官方资源帖!手把手教你在TensorFlow 2.0中实现CycleGAN,推特上百赞

官方资源帖!手把手教你在TensorFlow 2.0中实现CycleGAN,推特上百赞

8、进阶学习方向

在上面的教程中,我们学习了如何从Pix2Pix中实现的生成器和鉴别器进一步实现CycleGAN,接下来的学习你可以尝试使用TensorFlow中的其他数据集。

你还可以用更多次的迭代改善结果,或者实现论文中修改的ResNet生成器,进行知识点的进一步巩固。

传送门

https://www.tensorflow.org/beta/tutorials/generative/cyclegan

GitHub地址:

https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/cyclegan.ipynb


以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

构建之法(第三版)

构建之法(第三版)

邹欣 / 人民邮电出版社 / 2017-6 / 69.00元

软件工程牵涉的范围很广, 同时也是一般院校的同学反映比较空洞乏味的课程。 但是,软件工程 的技术对于投身 IT 产业的学生来说是非常重要的。作者有在世界一流软件企业 20 年的一线软件开 发经验,他在数所高校进行了多年的软件工程教学实践,总结出了在 16 周的时间内让同学们通过 “做 中学 (Learning By Doing)” 掌握实用的软件工程技术的教学计划,并得到高校师生的积极反馈。在此 ......一起来看看 《构建之法(第三版)》 这本书的介绍吧!

在线进制转换器
在线进制转换器

各进制数互转换器

SHA 加密
SHA 加密

SHA 加密工具

HEX HSV 转换工具
HEX HSV 转换工具

HEX HSV 互换工具