Style Transfer with TensorFlow

Category: Programming Tools · Published: 5 years ago

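This article implements neural style transfer with TensorFlow. The code is written against the TensorFlow 1.x graph API (tf.InteractiveSession, tf.reset_default_graph, tf.train.AdamOptimizer).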

Imports

import os
import sys
import scipy.io
import scipy.misc
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
from PIL import Image
from nst_utils import *
import numpy as np
import tensorflow as tf

%matplotlib inline

Defining the Cost

Content Cost

$$J_{content}(C,G) = \frac{1}{4 \times n_H \times n_W \times n_C}\sum _{ \text{all entries}} (a^{(C)} - a^{(G)})^2\tag{1} $$

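def compute_content_cost(a_C, a_G):
    """
    Compute the content cost, equation (1).

    a_C -- hidden-layer activations of image C, shape (1, n_H, n_W, n_C)
    a_G -- hidden-layer activations of image G, shape (1, n_H, n_W, n_C)
    """

    # Read the static dimensions from a_G
    m, n_H, n_W, n_C = a_G.get_shape().as_list()

    # Unroll each activation volume into an (n_C, n_H*n_W) matrix
    a_C_unrolled = tf.transpose(tf.reshape(a_C, [n_H*n_W, n_C]))
    a_G_unrolled = tf.transpose(tf.reshape(a_G, [n_H*n_W, n_C]))

    # Sum of squared differences, scaled by 1 / (4 * n_H * n_W * n_C)
    J_content = tf.reduce_sum((a_C_unrolled - a_G_unrolled) ** 2) / (4 * n_H * n_W * n_C)

    return J_content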
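To make equation (1) concrete, here is a minimal NumPy sketch of the same computation on toy data. The function name content_cost_np and the activation values are made up for illustration and are not part of the article's code; the sketch also checks that unrolling the activations does not change the result.

```python
import numpy as np

def content_cost_np(a_C, a_G):
    """NumPy version of equation (1) for activations of shape (1, n_H, n_W, n_C)."""
    _, n_H, n_W, n_C = a_G.shape
    return np.sum((a_C - a_G) ** 2) / (4 * n_H * n_W * n_C)

def unroll(a):
    """Unroll (1, n_H, n_W, n_C) activations into an (n_C, n_H*n_W) matrix."""
    return a.reshape(-1, a.shape[-1]).T

# Toy activations (made-up values): a 2x2 feature map with 3 channels
a_C = np.ones((1, 2, 2, 3))
a_G = np.zeros((1, 2, 2, 3))

# All 12 entries differ by 1, so the sum is 12 and the cost is 12 / (4*2*2*3)
cost = content_cost_np(a_C, a_G)
print(cost)  # 0.25

# The squared-difference sum is the same whether or not the tensors are unrolled
unrolled_sum = np.sum((unroll(a_C) - unroll(a_G)) ** 2)
print(unrolled_sum)  # 12.0
```

Because the cost is a plain sum over all entries, the unrolling step only matters for readability here; it becomes essential for the Gram matrix used by the style cost.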

Style Cost

Gram matrix (style matrix)

def gram_matrix(A):
    """
    Compute the Gram matrix.

    A -- matrix of shape (n_C, n_H*n_W)
    """
    # Entry (i, j) is the dot product of channel i and channel j
    GA = tf.matmul(A, tf.transpose(A))

    return GA
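As a small worked example, the NumPy sketch below computes a Gram matrix by hand-checkable numbers. The name gram_matrix_np and the values in A are illustrative assumptions, not taken from the article.

```python
import numpy as np

def gram_matrix_np(A):
    """Gram matrix of A with shape (n_C, n_H*n_W): entry (i, j) is the dot product of rows i and j."""
    return A @ A.T

# Toy unrolled activations (made-up values): 2 channels, 3 spatial positions
A = np.array([[1.0, 2.0, 3.0],
              [0.0, 1.0, 0.0]])

G = gram_matrix_np(A)

# G[0, 0] = 1 + 4 + 9 = 14, G[0, 1] = G[1, 0] = 2, G[1, 1] = 1
print(G.shape)  # (2, 2)
print(G)
```

The result is always an (n_C, n_C) symmetric matrix, regardless of the spatial size of the feature map, which is why it captures channel correlations rather than spatial layout.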

Style cost for a single layer

$$J_{style}^{[l]}(S,G) = \frac{1}{4 \times {n_C}^2 \times (n_H \times n_W)^2} \sum _{i=1}^{n_C}\sum_{j=1}^{n_C}(G^{(S)}_{ij} - G^{(G)}_{ij})^2\tag{2} $$

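def compute_layer_style_cost(a_S, a_G):
    """
    Compute the style cost of one hidden layer, equation (2).

    a_S -- hidden-layer activations of image S, shape (1, n_H, n_W, n_C)
    a_G -- hidden-layer activations of image G, shape (1, n_H, n_W, n_C)
    """

    # Read the static dimensions from a_G
    m, n_H, n_W, n_C = a_G.get_shape().as_list()

    # Unroll each activation volume into an (n_C, n_H*n_W) matrix
    a_S = tf.transpose(tf.reshape(a_S, [n_H*n_W, n_C]))
    a_G = tf.transpose(tf.reshape(a_G, [n_H*n_W, n_C]))

    # Compute the Gram matrices
    GS = gram_matrix(a_S)
    GG = gram_matrix(a_G)

    # Style cost: 4 * (n_C * n_H * n_W)**2 equals 4 * n_C^2 * (n_H * n_W)^2 in equation (2)
    J_style_layer = tf.reduce_sum(tf.square(GS - GG)) / (4 * (n_C * n_H * n_W)**2)

    return J_style_layer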
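The following NumPy sketch mirrors equation (2) on random toy activations. It also illustrates a property that follows from the Gram matrix: shuffling the spatial positions of a feature map leaves its style cost against the original at (numerically) zero, because the Gram matrix sums over positions. The function name and the random inputs are assumptions for illustration only.

```python
import numpy as np

def layer_style_cost_np(a_S, a_G):
    """NumPy version of equation (2) for activations of shape (1, n_H, n_W, n_C)."""
    _, n_H, n_W, n_C = a_G.shape
    S = a_S.reshape(n_H * n_W, n_C).T   # (n_C, n_H*n_W)
    G = a_G.reshape(n_H * n_W, n_C).T
    GS = S @ S.T                         # Gram matrix of the style activations
    GG = G @ G.T                         # Gram matrix of the generated activations
    return np.sum((GS - GG) ** 2) / (4 * n_C**2 * (n_H * n_W) ** 2)

rng = np.random.default_rng(0)
a_S = rng.normal(size=(1, 4, 4, 3))      # made-up style activations
a_G = rng.normal(size=(1, 4, 4, 3))      # made-up generated activations

same = layer_style_cost_np(a_S, a_S)
different = layer_style_cost_np(a_S, a_G)
print(same)            # 0.0
print(different > 0)   # True

# Permute the 16 spatial positions of a_S; the channel statistics are unchanged
perm = rng.permutation(16)
a_S_shuffled = a_S.reshape(16, 3)[perm].reshape(1, 4, 4, 3)
shuffled = layer_style_cost_np(a_S, a_S_shuffled)
print(np.isclose(shuffled, 0.0))  # True
```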

Combining the style cost of multiple hidden layers

$$J_{style}(S,G) = \sum_{l} \lambda^{[l]} J^{[l]}_{style}(S,G)$$

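def compute_style_cost(model, STYLE_LAYERS):
    """
    Compute the overall style cost.

    model -- dictionary mapping layer names to tensors
    STYLE_LAYERS -- list of (layer_name, coeff) pairs:
        - the name of the hidden layer
        - the coefficient (weight) of that layer

    Relies on the global `sess`, with the style image already assigned to model['input'].
    """
    J_style = 0

    for layer_name, coeff in STYLE_LAYERS:
        # Output tensor of the selected layer
        out = model[layer_name]
        # a_S is evaluated now, so it holds the style image's activations as a fixed array
        a_S = sess.run(out)
        # a_G stays symbolic; it is evaluated later, once the generated image is the input
        a_G = out

        J_style_layer = compute_layer_style_cost(a_S, a_G)
        J_style += coeff * J_style_layer

    return J_style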
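The weighted sum itself is simple arithmetic. As a rough illustration, the sketch below combines made-up per-layer costs using five equal coefficients of 0.2 (the same layer names and weights the article uses later in STYLE_LAYERS); the numbers in layer_costs are hypothetical.

```python
# Layer names and coefficients as used later in the article
STYLE_LAYERS = [
    ('conv1_1', 0.2),
    ('conv2_1', 0.2),
    ('conv3_1', 0.2),
    ('conv4_1', 0.2),
    ('conv5_1', 0.2)]

# Hypothetical per-layer style costs, for illustration only
layer_costs = {'conv1_1': 5.0, 'conv2_1': 4.0, 'conv3_1': 3.0,
               'conv4_1': 2.0, 'conv5_1': 1.0}

# J_style = sum over layers of lambda^[l] * J_style^[l]
J_style = sum(coeff * layer_costs[name] for name, coeff in STYLE_LAYERS)

# With equal weights that sum to 1, the result is the mean of the layer costs
print(round(J_style, 6))  # 3.0
```

Raising the coefficient of a shallow layer would emphasize fine textures, while raising that of a deep layer would emphasize larger structures; equal weights treat all five layers alike.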

Total Cost

$$J(G) = \alpha J_{content}(C,G) + \beta J_{style}(S,G)$$

def total_cost(J_content, J_style, alpha=10, beta=40):
    """
    Combine the two costs: alpha weights the content cost, beta weights the style cost.
    """
    J = J_content * alpha + J_style * beta

    return J

Creating the Session and Preparing the Data and Model

Creating the Session

tf.reset_default_graph()
sess = tf.InteractiveSession()

Preparing the data and model

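# Note: scipy.misc.imread was removed in SciPy 1.2; on newer versions,
# np.array(Image.open(path)) is one possible substitute.
content_image = scipy.misc.imread("images/content.jpg")
content_image = reshape_and_normalize_image(content_image)

style_image = scipy.misc.imread("images/picasso")
style_image = reshape_and_normalize_image(style_image)

# Initial generated image: noise generated from the content image
generated_image = generate_noise_image(content_image)

# Load the pretrained VGG-19 model
model = load_vgg_model("../models/imagenet-vgg-verydeep-19.mat")

print(model)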

The printed model is a dictionary that maps each layer name to its tensor. Its keys are:

'input', 'conv1_1', 'conv1_2', 'avgpool1', 'conv2_1', 'conv2_2', 'avgpool2', 'conv3_1', 'conv3_2', 'conv3_3', 'conv3_4', 'avgpool3', 'conv4_1', 'conv4_2', 'conv4_3', 'conv4_4', 'avgpool4', 'conv5_1', 'conv5_2', 'conv5_3', 'conv5_4', 'avgpool5'
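The helpers reshape_and_normalize_image and generate_noise_image are not shown in the article. Purely as an illustration of what such helpers might do, here is a hypothetical NumPy sketch: the function names ending in _sketch, the per-channel means, the noise range, and the noise ratio are all assumptions and may differ from the actual helpers.

```python
import numpy as np

# Assumed per-channel means (a common choice for VGG-style preprocessing)
MEANS = np.array([123.68, 116.779, 103.939]).reshape((1, 1, 1, 3))

def reshape_and_normalize_image_sketch(image):
    """Add a batch dimension and subtract the assumed per-channel means."""
    image = np.reshape(image, (1,) + image.shape)   # (H, W, 3) -> (1, H, W, 3)
    return image - MEANS

def generate_noise_image_sketch(content_image, noise_ratio=0.6, seed=0):
    """Blend uniform noise with the content image (ratio and range are assumptions)."""
    rng = np.random.default_rng(seed)
    noise = rng.uniform(-20, 20, size=content_image.shape)
    return noise * noise_ratio + content_image * (1 - noise_ratio)

# A stand-in for a loaded photo: a flat gray 300x400 RGB image
fake_image = np.full((300, 400, 3), 128, dtype=np.uint8)

content = reshape_and_normalize_image_sketch(fake_image)
generated = generate_noise_image_sketch(content)
print(content.shape, generated.shape)  # (1, 300, 400, 3) (1, 300, 400, 3)
```

Starting from a noisy blend of the content image rather than pure noise is a design choice that tends to let the generated image pick up the content's layout quickly; whether the article's helper does exactly this is not stated.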

The content_image and style_image passed in here are:

[Figure: content_image (left) and style_image (right)]

Computing the Cost

Computing the Content Cost

# Feed content_image into the network
sess.run(model['input'].assign(content_image))

# Select the hidden layer
out = model['conv4_2']

# a_C holds the content image's activations at that layer, evaluated now
a_C = sess.run(out)

# a_G stays symbolic: generated_image is fed in later,
# via sess.run(model['input'].assign(generated_image))
a_G = out

# Compute the Content Cost
J_content = compute_content_cost(a_C, a_G)

Computing the Style Cost

STYLE_LAYERS = [
    ('conv1_1', 0.2),
    ('conv2_1', 0.2),
    ('conv3_1', 0.2),
    ('conv4_1', 0.2),
    ('conv5_1', 0.2)]

# Feed style_image into the network
sess.run(model['input'].assign(style_image))

# Compute the Style Cost
J_style = compute_style_cost(model, STYLE_LAYERS)

Computing the total Cost

J = total_cost(J_content, J_style, alpha=10, beta=40)

Generating the Image

Defining the optimizer and train_step

optimizer = tf.train.AdamOptimizer(2.0)

train_step = optimizer.minimize(J)
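Unlike ordinary training, the quantity being updated here is the input image, since model['input'] is assigned to like a variable while the cost is minimized. The toy NumPy sketch below illustrates that idea under strong simplifications: the "network" is the identity, only the content cost is used, and plain gradient descent stands in for Adam. All names and values are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
target = rng.normal(size=(1, 2, 2, 3))   # stands in for the fixed activations a_C
G = rng.normal(size=target.shape)        # the "image" being optimized
n = target.size                          # n_H * n_W * n_C for a batch of one

def cost(G):
    """Content cost of equation (1) with an identity network."""
    return np.sum((target - G) ** 2) / (4 * n)

def grad(G):
    """Gradient of the cost with respect to the pixels of G."""
    return (G - target) / (2 * n)

initial = cost(G)
learning_rate = 2.0
for _ in range(200):
    # Update the pixels themselves; there are no network weights to train
    G -= learning_rate * grad(G)
final = cost(G)

print(final < initial)  # True
```

In the article's setup the gradient flows back through the fixed VGG layers to the input pixels, and the style term is included as well; the update target is the same.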

Defining the training loop

def model_nn(sess, input_image, num_iterations=200):

    sess.run(tf.global_variables_initializer())

    # Feed input_image into the network
    sess.run(model['input'].assign(input_image))

    # Iterate
    for i in range(num_iterations):

        # One optimization step; this updates the image held in model['input']
        sess.run(train_step)

        # Read back the current generated image
        generated_image = sess.run(model['input'])

        # Log the costs every 20 iterations
        if i % 20 == 0:
            Jt, Jc, Js = sess.run([J, J_content, J_style])
            print("Iteration " + str(i) + " :")
            print("total cost = " + str(Jt))
            print("content cost = " + str(Jc))
            print("style cost = " + str(Js))

            # Save the current image
            save_image("output/" + str(i) + ".png", generated_image)

    # Save the final image
    save_image('output/generated_image.jpg', generated_image)

    return generated_image

Training

model_nn(sess, generated_image)
Iteration 0 :
total cost = 9490227000.0
content cost = 7180.4365
style cost = 237253890.0
Iteration 20 :
total cost = 2201664800.0
content cost = 17532.338
style cost = 55037236.0
Iteration 40 :
total cost = 1026127170.0
content cost = 19366.562
style cost = 25648338.0
Iteration 60 :
total cost = 647172200.0
content cost = 19991.74
style cost = 16174307.0
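As a quick sanity check, the logged values can be compared against the total cost formula with alpha = 10 and beta = 40. The sketch below copies the numbers from the output above; the tolerance is an assumption chosen to allow for the rounding in the printed float32 values.

```python
alpha, beta = 10, 40

# (iteration, total cost, content cost, style cost) as logged above
logged = [
    (0, 9490227000.0, 7180.4365, 237253890.0),
    (20, 2201664800.0, 17532.338, 55037236.0),
    (40, 1026127170.0, 19366.562, 25648338.0),
    (60, 647172200.0, 19991.74, 16174307.0),
]

checks = []
style_shares = []
for i, total, content, style in logged:
    recomputed = alpha * content + beta * style
    rel_error = abs(recomputed - total) / total
    checks.append(rel_error < 1e-5)
    # Fraction of the total cost contributed by the style term
    style_shares.append(beta * style / total)
    print(i, rel_error < 1e-5)   # prints True for every logged iteration
```

The style term accounts for well over 99% of the total at every logged iteration, which is consistent with the log: the optimizer drives the style cost down sharply while the content cost rises from its low starting value.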

The final generated image:

[Figure: the generated image]

