tensorflow keras 查找中间tensor并构建局部子图

栏目: 数据库 · 发布时间: 5年前

内容简介:在Mask_RCNN项目的示例项目此方法可以读取层的输出,对于输出多于1个tensor的,可以指定get_layer("rpn_class").output[0:2]等确定。但是对于自定义层的中间变量,就没办法获得了,因此需要使用方法二。

在Mask_RCNN项目的示例项目 nucleus 中,stepbystep步骤里面,需要对网络模型的中间变量进行提取和可视化,常见方式有两种:

通过 get_layer方法:

outputs = [
    ("rpn_class", model.keras_model.get_layer("rpn_class").output),
    ("proposals", model.keras_model.get_layer("ROI").output)
    ]

此方法可以读取层的输出,对于输出多于1个tensor的,可以指定get_layer("rpn_class").output[0:2]等确定。

但是对于自定义层的中间变量,就没办法获得了,因此需要使用方法二。

通过 tensor.op.inputs 逐层向上查找

定义一个迭代函数,不断查找

def find_in_tensor(tensor,name,index=0):
    index += 1
    if index >20:
        return
    tensor_parent = tensor.op.inputs
    for each_ptensor in tensor_parent:
        #print(each_ptensor.name)
        if bool(re.fullmatch(name, each_ptensor.name)):
            print('find it!')
            return each_ptensor
        result = find_in_tensor(each_ptensor,name,index)
        if result is not None:
            return result

接着获得某层的输出,调用迭代函数,找到该tensor

pillar = model.keras_model.get_layer("ROI").output
nms_rois = find_in_tensor(pillar,'ROI_3/rpn_non_max_suppression/NonMaxSuppressionV2:0')
outputs.append(('NonMaxSuppression',nms_rois))

最后,调用kf.fuction构建局部图,并运行:

submodel = model.keras_model
outputs = OrderedDict(outputs)
if submodel.uses_learning_phase and not isinstance(K.learning_phase(), int):
    inputs += [K.learning_phase()]
kf = K.function(submodel.inputs, list(outputs.values()))
in_p,ou_p = next(train_generator)
output_all = kf(in_p)

此时打印outputs可以看到类似如下:

OrderedDict([('rpn_class',<tf.Tensor 'rpn_class_3/concat:0' shape=(?, ?, 2) dtype=float32>),
             ('proposals',<tf.Tensor 'ROI_3/packed_2:0' shape=(1, ?, ?) dtype=float32>),
             ('fpn_p2',<tf.Tensor 'fpn_p2_3/BiasAdd:0' shape=(?, 192, 192, 256) dtype=float32>),
             ('fpn_p3',<tf.Tensor 'fpn_p3_3/BiasAdd:0' shape=(?, 96, 96, 256) dtype=float32>),
             ('fpn_p4',<tf.Tensor 'fpn_p4_3/BiasAdd:0' shape=(?, 48, 48, 256) dtype=float32>),
             ('fpn_p6',<tf.Tensor 'fpn_p6_3/MaxPool:0' shape=(?, 12, 12, 256) dtype=float32>),
             ('NonMaxSuppression',<tf.Tensor 'ROI_3/rpn_non_max_suppression/NonMaxSuppressionV2:0' shape=(?,) dtype=int32>)])

大功告成~


以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

标签

标签

Gene Smith / 张军、陈军亮 / 机械工业出版社 / 2012-6 / 59.00元

本书对标记系统这一概念的内涵和外延进行了系统化的、深入浅出的阐述。从什么是标记系统、标记系统有什么价值,到标记系统的架构和与其他分类系统的对比,再到标签的呈现方式和标记系统的实现细节,作者都用通俗易懂的语言进行了阐述,并附有详细的示例和具体的案例研究。本书的每一章都涵盖了标记系统的一个方面,主要内容包括:标记系统的模型、价值、架构,标签的分类、可视化、管理方法,最后介绍标记系统设计方法。本书带领读......一起来看看 《标签》 这本书的介绍吧!

CSS 压缩/解压工具
CSS 压缩/解压工具

在线压缩/解压 CSS 代码

在线进制转换器
在线进制转换器

各进制数互转换器

HSV CMYK 转换工具
HSV CMYK 转换工具

HSV CMYK互换工具