Tensorflow上手2: Keras的技巧和弊端

栏目: 数据库 · 发布时间: 5年前

内容简介:就在不久前,TF 2.0的预告发布,大家都在讨论着Tensorflow接口的变化,于是我也开始尝试使用Tensorflow版本的Keras.Keras是一个非常易用的框架,提供了更好的神经网络层级Layer的抽象,但是真正实现大规模模型训练时却遇到了一些坑.本文从我个人使用经历中看看Keras的灵巧和缺点,其中有一些坑可能大家会有更好的解决方案,欢迎多多指教.层(Layer)是Keras当中最重要的概念之一,但是Keras本身提供的Layer实现又不如Tensorflow的计算图那么的多样化与方便,所以Ke
Tensorflow上手2: Keras的技巧和弊端

就在不久前,TF 2.0的预告发布,大家都在讨论着Tensorflow接口的变化,于是我也开始尝试使用Tensorflow版本的Keras.Keras是一个非常易用的框架,提供了更好的神经网络层级Layer的抽象,但是真正实现大规模模型训练时却遇到了一些坑.

本文从我个人使用经历中看看Keras的灵巧和缺点,其中有一些坑可能大家会有更好的解决方案,欢迎多多指教.

Lambda层与自定义层

层(Layer)是Keras当中最重要的概念之一,但是Keras本身提供的Layer实现又不如Tensorflow的计算图那么的多样化与方便,所以Keras就通过Lambda层和自定义层对其灵活性进行拓展.

Lambda层的使用最为简单:

from tensorflow.keras.layers import Lambda
model.add(Lambda(lambda x: x**2))

其次是自定义层,相比Lambda层的好处是,自定义可以给Layer增加新的可以训练的参数,这些参数需要在build函数中进行定义,比如说一个自定义的dot product层(代码来自官网):

class MyLayer(Layer):
 def __init__(self, output_dim, **kwargs):
 super(MyLayer, self).__init__(**kwargs)
 self.output_dim = output_dim
def build(self, input_shape):
 self.kernel = self.add_weight(
 name=’kernel’,
 shape=(input_shape[1], self.output_dim)
 initializer=’uniform’,
 trainable=True)
def call(self, x):
 return K.dot(x, self.kernel)
def compute_output_shape(self, input_shape):
 return (input_shape[0], self.output_dim)

通过Lambda层和自定义层的灵活运用,人们可以用Keras写出一个很好的Mask RCNN代码,并且通过Keras提供的可视化函数plot_model,讲网络结构打印出来.

Lambda层和自定义层虽然很灵活,但是真正使用过程中还是会遇到不少坑.

TypeError: can’t pickle _thread.lock objects

有时候在你的网络中有Lambda层,保存到时候会遇到以上错误.这通常是因为Lambda层在进行序列化的时候无法序列化你使用的某一个 Python 非静态函数.

这时候有三种解决办法,第一种是在保存模型的时候选择model.save_weights而非model.save,毕竟大部分时候你不需要原先的训练结构.第二种办法是采用functools.partial等函数将Python函数包装成静态函数.第三种办法时放弃Lambda层,讲其函数和参数包装成一个自定义层.我个人推荐这一方案,原因会在下一个问题揭晓.

推荐使用自定义层

从我个人的使用来看,使用自定义层能够更好的对模型结构信息进行存储,包括每一个自定义层采用的参数等等.在新版的Tensorflow中,可以通过Keras导出Estimator使用的模型,这时候我们需要每一个自定义层使用的参数,这些参数可以通过自定义层的如下函数导出:

class MyLayer(Layer):
 def __init__(self, output_dim, **kwargs):
 super(MyLayer, self).__init__(**kwargs)
 self.output_dim = output_dim
def get_conf(self):
 config = super(MyLayer, self).get_config()
 config.update({‘output_dim’: self.output_dim})
 return config

双输出时一定要用list

Keras的灵活性还在于一个层可以有多个输出,就比方Mask RCNN里面的fpn_classifier_graph,在我们使用自定义层产生多输出的时候,既可以

return out1, out2

也可以

return [out1, out2]

这时候请记住,一定要选择返回list,不然你试试多GPU训练的情况就知道了.

关于Keras模型中的Loss

Keras中对Loss的基本定义是一个输入为y_true和y_pred的函数,但是在特殊情况下,他也可以结合权重进行复杂的运算.

就我个人写代码,阅读代码,阅读博客的经验来看,Kera的自定义loss有很多种写法,非常灵活,但与此同时也会遇到不同的问题.

model.add_loss

抛开最基本的model.compile(loss=…),我第一个尝试的是模仿Mask RCNN,通过model.add_loss函数另模型同时优化多个损失函数.使用add_loss的主要原因是传入的函数不能通过简单的y_true, y_pred进行计算,比方Mask RCNN要同时计算边界的损失和分类的损失,写成两个张量的表达形式很复杂并且不容易扩展.

然而add_loss也有自己的弊端,比方说如果我调用函数model_to_estimator,那么我加入的loss就没有了.至于我为什么要调用这个函数,下文会给予介绍.

layer.loss

通过阅读让Keras更酷一些,我了解到我们还可以通过自定义层添加loss,如下所述:

class MyLayer(Layer):
 def loss(self, y_true, y_pred):
 # do something

在原文中,作者给出的函数可以访问该层的一些参数,这样以来可以更灵活的给自定义层的参数进行一定的约束,但是由于一方面我本人没有试验过,不清楚这样做的利弊,另一方面看起来函数的借口还是不够灵活,而且我也不确定如何将多个不同层的损失叠加到一起,所以不能做更多的评价.

回到model.compile(loss=…)

因为一些特殊的原因,我最后还是尝试了将所有的损失函数通过一个自定义层合并,然后通过compile添加到模型里的做法.写成程序很简单:

model.compile(
 optimizer=’adam’,
 loss=lambda y_true, y_pred: pred)

结语,关于回调和分布式训练

Keras的回调是一个很厉害的功能,可以通过回调生成Tensorboard的Summary,调整训练速率,提前结束训练等等,但是使用不好的话很可能在训练过程中造成内存溢出.

最后,我暂时的放弃了Keras,其实并不是因为别的什么原因,而是在我研究如何利用Keras进行分布式训练的时候,我看到了这么一个推荐方案:

End to end example for multi worker training in tensorflow/ecosystem using Kuberentes templates. This example starts with a Keras model and converts it to an Estimator using the tf.keras.estimator.model_to_estimator API.

那么,还是让我们直接使用Estimator比较方便吧.


以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

部落:一呼百应的力量

部落:一呼百应的力量

高汀 (Godin.S.) / 刘晖 / 中信出版社 / 2009-7 / 26.00元

部落指的是任何一群人,规模可大可小,他们因追随领导、志同道合而相互联系在一起。人类其实数百万年前就有部落的出现,随之还形成了宗教、种族、政治或甚至音乐。 互联网消除了地理隔离,降低了沟通成本并缩短了时间。博客和社交网站都有益于现有的部落扩张,并促进了网络部落的诞生——这些部落的人数从10个到1000万个不等,他们所关注的也许是iPhone,或一场政治运动,或阻止全球变暖的新方法。 那么......一起来看看 《部落:一呼百应的力量》 这本书的介绍吧!

HTML 压缩/解压工具
HTML 压缩/解压工具

在线压缩/解压 HTML 代码

JSON 在线解析
JSON 在线解析

在线 JSON 格式化工具

RGB CMYK 转换工具
RGB CMYK 转换工具

RGB CMYK 互转工具