连载二:PyCon2018|用slim微调PNASNet模型(附源码)

栏目: Python · 发布时间: 6年前

内容简介:有一组照片,分为男人和女人。本案就是让深度学习模型来学习这些样本,并能够找到其中的规律,完成模型的训练。接着可以使用该模型对图片中的人物进行识别,区分其性别是男还是女。

第八届中国 Python 开发者大会PyConChina2018,由PyChina.org发起,由来自CPyUG/TopGeek等社区的30位组织者,近150位志愿者在北京、上海、深圳、杭州、成都等城市举办。致力于推动各类Python相关的技术在互联网、企业应用等领域的研发和应用。

代码医生工作室有幸接受邀请,参加了这次会议的北京站专场。在会上主要分享了《人工智能实战案例分享-图像处理与数值分析》。

会上分享的一些案例主要是来源于《python带我起飞——入门、进阶、商业实战》一书与《深度学习之TensorFlow:入门、原理与进阶实战》一书。另外,还扩充了若干其它案例。在本文作为补充,将会上分享的其它案例以详细的图文方式补充进来,并提供源码。共分为4期连载。

  1. 用slim调用PNASNet模型

  2. 用slim微调PNASNet模型

  3. 用对抗样本攻击PNASNet模型

  4. 恶意域名检测实例

连载二:PyCon2018|用slim微调PNASNet模型(附源码)

通过微调模型实现分辨男女

案例描述

有一组照片,分为男人和女人。

本案就是让深度学习模型来学习这些样本,并能够找到其中的规律,完成模型的训练。接着可以使用该模型对图片中的人物进行识别,区分其性别是男还是女。

本案例中,使用了一个NASNet_A_Mobile的模型来做二次训练。具体过程分为4步:

(1)准备样本;

(2)准备NASNet_A_Mobile网络模型;

(3)编写代码进行二次训练;

(4)使用已经训练好的模型进行测试。

准备样本

通过如下链接下载CelebA数据集:

mmlab.ie.cuhk.edu.hk/projects/Ce…

下载完之后,解压,并手动分出一部分男人与女人的照片。

在本例中,一共用了20000张图片用来训练模型,其中训练样本由8421张男性头像和11599张女性头像构成(在train文件夹下),测试样本由10张男性头像和10张女性头像构成(在val文件夹下)。部分样本数据如图5-1。

连载二:PyCon2018|用slim微调PNASNet模型(附源码)

图5-1 男女数据集样本示例

数据样本整理好后,统一放到data文件夹下。该数据样本同样也可以在随书的配套资源中找到。

代码环境及模型准备

为了使读者能够快速完成该实例,直观上感受到模型的识别能力,可以直接使用本书配套的资源。并将其放到代码的同级目录下即可。

如果想体验下从零开始手动搭建,也可以按照下面的方法准备代码环境及预编译模型。

1. 下载models与部署TensorFlow slim模块

该部分的内容与3.1节完全一样,这里不再详述。

2. 下载NASNet_A_Mobile模型

该部分的内容与3.1节类似。在如图3-2中的倒数第3个模型,找到 “nasnet-a_mobile_04_10_2017.tar.gz”的下载链接。将其下载并解压。

3. 整体代码文件部署结构

本案例是通过4个代码文件来实现的,具体文件及描述如下:

l 5-1 mydataset.py:处理男女图片数据集的代码;

l 5-2 model.py:加载预编译模型NASNet_A_Mobile,并进行微调的代码;

l 5-3 train.py:训练模型的代码;

l 5-4 test.py:测试模型的代码。

部署时,将这4个代码文件与slim库、NASNet_A_Mobile模型、样本一起放到一个文件夹下即可。完整的文件结构如图5-2。

连载二:PyCon2018|用slim微调PNASNet模型(附源码)

图5-2 分辨男女案例的文件结构

代码实现:处理样本数据并生成Dataset对象

本案例中,直接将数据集的相关操作封装到了“5-1 mydataset.py”代码文件里。在该文件中,实现了符合训练与测试使用场景的数据集。在训练模式下,会对数据进行乱序处理;在测试模式下,直接使用顺序数据。两种数据集都是按批次读取。

这部分的知识在第4章已经有全面的介绍,这里不再详述。完整代码如下:

代码5-1 mydataset

1 import tensorflow as tf
 2 import sys                                      
 3 nets_path = r'slim'                                             #加载环境变量
 4 if nets_path not in sys.path:
 5    sys.path.insert(0,nets_path)
 6 else:
 7     print('already add slim')
 8 from nets.nasnet import nasnet                               #导出nasnet
 9 slim = tf.contrib.slim                                         #slim
10 image_size = nasnet.build_nasnet_mobile.default_image_size     #获得图片输入尺寸 224
11 from preprocessing import preprocessing_factory            #图像处理
12 
13 import os
14 def list_images(directory):
15    """
16    获取所有directory中的所有图片和标签
17    """
18
19    #返回path指定的文件夹包含的文件或文件夹的名字的列表
20    labels = os.listdir(directory)
21    #对标签进行排序,以便训练和验证按照相同的顺序进行
22    labels.sort()
23    #创建文件标签列表
24    files_and_labels = []
25    for label in labels:
26        for f in os.listdir(os.path.join(directory, label)):
27            #转换字符串中所有大写字符为小写再判断
28            if 'jpg' in f.lower() or 'png' in f.lower():
29                #加入列表
30                files_and_labels.append((os.path.join(directory, label, f), label))
31    #理解为解压 把数据路径和标签解压出来
32    filenames, labels = zip(*files_and_labels)
33    #转换为列表 分别储存数据路径和对应标签
34    filenames = list(filenames)
35    labels = list(labels)
36    #列出分类总数 比如两类:['man', 'woman']
37    unique_labels = list(set(labels))
38
39    label_to_int = {}
40    #循环列出数据和数据下标,给每个分类打上标签{'woman': 2, 'man': 1,none:0}
41    for i, label in enumerate(sorted(unique_labels)):
42        label_to_int[label] = i+1
43    print(label,label_to_int[label])
44    #把每个标签化为0 1 这种形式
45    labels = [label_to_int[l] for l in labels]
46    print(labels[:6],labels[-6:])
47    return filenames, labels              #返回储存数据路径和对应转换后的标签
48
49 num_workers = 2                          #定义并行处理数据的线程数量
50
51 #图像批量预处理
52 image_preprocessing_fn = preprocessing_factory.get_preprocessing('nasnet_mobile', is_training=True)
53 image_eval_preprocessing_fn = preprocessing_factory.get_preprocessing('nasnet_mobile', is_training=False)
54
55 def _parse_function(filename, label):      #定义图像解码函数
56    image_string = tf.read_file(filename)
57    image = tf.image.decode_jpeg(image_string, channels=3)          
58    return image, label
59
60 def training_preprocess(image, label):    #定义调整图像大小函数
61    image = image_preprocessing_fn(image, image_size, image_size)
62    return image, label
63
64 def val_preprocess(image, label):       #定义评估图像预处理函数
65    image = image_eval_preprocessing_fn(image, image_size, image_size)
66    return image, label
67
68 #创建带批次的数据集
69 def creat_batched_dataset(filenames, labels,batch_size,isTrain = True):
70
71    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
72
73    dataset = dataset.map(_parse_function, num_parallel_calls=num_workers)    #对图像解码
74
75    if isTrain == True:
76        dataset = dataset.shuffle(buffer_size=len(filenames))                #打乱数据顺序
77        dataset = dataset.map(training_preprocess, num_parallel_calls=num_workers)#调整图像大小
78    else:
79        dataset = dataset.map(val_preprocess,num_parallel_calls=num_workers)    #调整图像大小
80
81    return dataset.batch(batch_size)                                           #返回批次数据
82
83 #根据目录返回数据集
84 def creat_dataset_fromdir(directory,batch_size,isTrain = True):
85    filenames, labels = list_images(directory)
86    num_classes = len(set(labels))
87    dataset = creat_batched_dataset(filenames, labels,batch_size,isTrain)
88    return dataset,num_classes 复制代码

代码11行,导入了preprocessing_factory函数,该函数是slim模块中封装好的工厂函数,用于生成模型的预处理函数。利用统一封装好的预处理函数,对样本进行操作(代码60、61行),可以提升开发效率,并能够减小出错的可能性。

工厂函数的知识点,属于Python基础知识,这里不再详述。有兴趣的读者可以参考《python带我起飞——入门、进阶、商业实战》一书的6.10节。

注意:

这里用了一个技巧。仿照原NASNet_A_Mobile模型的分类方法,在对分类标签排号时,将标签为0的分类空出来,男人与女人分别为1和2。

另外在代码42行,用到的变量unique_labels是从集合对象转化过来的。在使用时需要对齐固定顺序,所以使用了sorted函数进行变换。如果没有这句,在下次启动的时候,有可能出现标签序号与名称对应不上的现象。在多次中断,多次训练的场景下,会造成训练结果的混乱。这部分知识在《python带我起飞——入门、进阶、商业实战》的第四章集合部分的内容中,也做了重点的强调。

代码实现:定义微调模型类MyNASNetModel

在微调模型的实现中,统一通过定义类MyNASNetModel来实现。在类MyNASNetModel中,大致可分为2大动作:初始化设置、构建模型。

l 初始化设置:定义好构建模型时所需要的必要参数;

l 构建模型:针对训练、测试、应用的三种情况分别构建不同的网络模型。在训练过程中,还要支持加载预编译模型及微调模型。

实现定义类MyNASNetModel并进行初始化模型设置的代码如下:

代码5-2 model

1 import sys                                      
  2 nets_path = r'slim'                                    #加载环境变量
  3 if nets_path not in sys.path:
  4    sys.path.insert(0,nets_path)
  5 else:
  6    print('already add slim')
  7
  8 import tensorflow as tf
  9 from nets.nasnet import nasnet                       #导出nasnet
 10 slim = tf.contrib.slim 
 11
 12 import os  
 13 mydataset = __import__("5-1  mydataset")
 14 creat_dataset_fromdir = mydataset.creat_dataset_fromdir
 15
 16 class MyNASNetModel(object):
 17    """微调模型类MyNASNetModel
 18    """
 19    def __init__(self, model_path=''):
 20        self.model_path = model_path              #原始模型的路径           复制代码

代码20行为初始化MyNASNetModel类的操作。model_path指的是所要加载的原始预编译模型。该操作只有在训练模式下是有意义的。在测试和应用模式下,可以为空。

构建MyNASNetModel类中的基本模型

在构建模型中,无论是训练、测试还是应用,都需要将最基本的NASNet_A_Mobile模型载入。这里通过定义MyNASNetModel类的MyNASNet方法来实现。具体的实现方式与3.3节的实现基本一致,不同的是3.3节构建的是PNASNet网络结构,这里构建的NASNet_A_Mobile结构。

代码5-2 model(续)

21  def MyNASNet(self,images,is_training):
 22        arg_scope = nasnet.nasnet_mobile_arg_scope()          #获得模型命名空间
 23        with slim.arg_scope(arg_scope):
 24            #构建NASNet Mobile模型
 25            logits, end_points = nasnet.build_nasnet_mobile(images,num_classes = self.num_classes+1, is_training=is_training)
 26
 27        global_step = tf.train.get_or_create_global_step()      #定义记录步数的张量
 28
 29        return logits,end_points,global_step                   #返回有用的张量
复制代码

代码25行中,往num_classes参数里传的值代表分类的个数,在本案例中分为男人和女人,一共两类(即,self.num_classes=2,该值是在后文5.2.8节中,build_model方法被赋值的)。再加上一个None类。于是传入的值为self.num_classes+1。

实现MyNASNetModel类中的微调操作

微调操作是针对训练场景下使用的。通过定义MyNASNetModel类中的FineTuneNASNet方法来实现。微调操作主要是对预编译模型的超参进行选择性恢复。

因为预编译模型NASNet_A_Mobile是在ImgNet上训练的,有1000个分类,而本案例中识别男女的任务只有两个分类。所以最后两个输出层的超参不应该被恢复(由于分类不同,导致超参的个数不同)。在实际使用时,最后两层的参数需要对其初始化,并单独训练即可。

代码5-2 model(续)

30 def FineTuneNASNet(self,is_training):      #实现微调模型的网络操作 
 31        model_path = self.model_path
 32
 33        exclude = ['final_layer','aux_7']      #恢复超参, 除了exclude以外的全部恢复
 34        variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
 35        if is_training == True:
 36            init_fn = slim.assign_from_checkpoint_fn(model_path, variables_to_restore)
 37        else:
 38            init_fn = None
 39
 40        tuning_variables = []             #将没有恢复的超参收集起来,用于微调训练
 41        for v in exclude:
 42            tuning_variables += slim.get_variables(v)
 43
 44        return init_fn, tuning_variables复制代码

代码中,使用了exclude列表,将不需要恢复的网络节点收集起来(代码33行),接着将预训练模型中的超参值赋值给剩下的节点,完成了预训练模型的载入(代码36行)。最后使用了tuning_variables列表,将不需要恢复的网络节点权重收集起来(代码40行),用于微调训练。

注意:

这里介绍个技巧,如何获得exclude中的元素(代码33行):通过额外执行代码tf.global_variables(),将张量图中的节点打印出来。从里面找到最后两层的节点,并将其填入代码中即可。在找到节点后,还可以通过slim.get_variables函数,来检查该名称的节点是否正确。例如,可以通过将slim.get_variables('final_layer')的返回值打印出来,来观察张量图中是否有final_layer节点。这部分的原理可以参考《深度学习之TensorFlow:入门、原理与进阶实战》书中第4章的内容(在第11章也有类似的案例)。

代码实现:实现与训练相关的其他方法

在MyNASNetModel类中,还需要定义与训练操作相关的其他方法,具体如下:

l build_acc_base方法:用于构建评估模型的相关节点;

l load_cpk方法:用于载入及保存模型检查点

l build_model_train方法:用于构建训练模型中的损失函数及优化器等操作节点。

具体代码如下:

代码5-2 model(续)

45 def build_acc_base(self,labels):#定义评估函数
 46        #返回张量中最大值的索引
 47        self.prediction = tf.to_int32(tf.argmax(self.logits, 1))
 48        #计算prediction、labels是否相同 
 49        self.correct_prediction = tf.equal(self.prediction, labels)
 50        #计算平均值
 51        self.accuracy = tf.reduce_mean(tf.to_float(self.correct_prediction))
 52        #将前5个最高正确率的值取出来,计算平均值
 53        self.accuracy_top_5 = tf.reduce_mean(tf.to_float(tf.nn.in_top_k(predictions=self.logits, targets=labels, k=5)))
 54
 55    def load_cpk(self,global_step,sess,begin = 0,saver= None,save_path = None):                                                    #储存和导出模型
 56       if begin == 0:
 57            save_path=r'./train_nasnet'                      #定义检查点路径
 58            if not os.path.exists(save_path):
 59                print("there is not a model path:",save_path)
 60            saver = tf.train.Saver(max_to_keep=1)            #生成saver
 61            return saver,save_path
 62        else:
 63            kpt = tf.train.latest_checkpoint(save_path)    #查找最新的检查点
 64            print("load model:",kpt)
 65            startepo= 0                                    #计步
 66            if kpt!=None:
 67                saver.restore(sess, kpt)                     #还原模型
 68                ind = kpt.find("-")
 69                startepo = int(kpt[ind+1:])
 70                print("global_step=",global_step.eval(),startepo)    
 71            return startepo  
 72
 73    def build_model_train(self,images,
 74           labels,learning_rate1,learning_rate2,is_training):
 75           self.logits,self.end_points, 
 76           self.global_step= self.MyNASNet(images,is_training=is_training)
 77        self.step_init = self.global_step.initializer
 78
 79        self.init_fn,self.tuning_variables = self.FineTuneNASNet(
 80            is_training=is_training)
 81        #定义损失函数
 82       tf.losses.sparse_softmax_cross_entropy(labels=labels, 
 83            logits=self.logits)
 84        loss = tf.losses.get_total_loss()
 85        #定义微调的率退化学习速率
 86        learning_rate1=tf.train.exponential_decay(
 87                 learning_rate=learning_rate1, global_step=self.global_step,
 88                 decay_steps=100, decay_rate=0.5)
 89        #定义联调的率退化学习速率
 90        learning_rate2=tf.train.exponential_decay(
 91             learning_rate=learning_rate2, global_step=self.global_step,
 92             decay_steps=100, decay_rate=0.2)                
 93        last_optimizer = tf.train.AdamOptimizer(learning_rate1) #优化器
 94        full_optimizer = tf.train.AdamOptimizer(learning_rate2)   
 95        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  
 96        with tf.control_dependencies(update_ops):      #更新批量归一化中的参数
 97            #使loss减小方向做优化
 98            self.last_train_op = last_optimizer.minimize(loss, self.global_step,var_list=self.tuning_variables)
 99            self.full_train_op = full_optimizer.minimize(loss, self.global_step)
100
101        self.build_acc_base(labels)                    #定义评估模型相关指标
102        #写入日志,支持tensorBoard操作
103        tf.summary.scalar('accuracy', self.accuracy)    
104        tf.summary.scalar('accuracy_top_5', self.accuracy_top_5)
105
106        #将收集的所有默认图表并合并
107        self.merged = tf.summary.merge_all()
108        #写入日志文件
109        self.train_writer = tf.summary.FileWriter('./log_dir/train')
110        self.eval_writer = tf.summary.FileWriter('./log_dir/eval')
111        #定义检查点相关变量
112        self.saver,self.save_path = self.load_cpk(self.global_step,None)复制代码

在上面代码中,使用了tf.losses接口来获得loss值。通过调用tf.losses.sparse_softmax_cross_entropy 函数计算具体的loss(见代码82行)。该函数会自动将loss值添加到内部集合ops.GraphKeys.LOSSES中。然后调用tf.losses.get_total_loss函数,将ops.GraphKeys.LOSSES集合中的所有loss值获取,并返回回来(见代码84行)。

在代码96行中,在反向优化时,使用了tf.control_dependencies函数对的批量归一化操作中的均值与方差进行更新。

代码实现:构建模型,用于训练、测试、使用

在MyNASNetModel类中,定义build_model方法用与构建模型的实现。在build_model方法中,通过参数mode来指定模型的具体使用场景。具体代码如下:

代码5-2 model(续)

113 def build_model(self,mode='train',testdata_dir='./data/val',traindata_dir='./data/train', batch_size=32,learning_rate1=0.001,learning_rate2=0.001):
114
115        if mode == 'train':        
116            tf.reset_default_graph()
117            #创建训练数据和测试数据的Dataset数据集
118            dataset,self.num_classes = creat_dataset_fromdir(traindata_dir,batch_size)
119            testdataset,_ = creat_dataset_fromdir(testdata_dir,batch_size,isTrain = False)
120
121            #创建一个可初始化的迭代器
122            iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
123            #读取数据
124            images, labels = iterator.get_next()
125
126            self.train_init_op = iterator.make_initializer(dataset)
127            self.test_init_op = iterator.make_initializer(testdataset)
128
129            self.build_model_train(images, labels,learning_rate1,learning_rate2,is_training=True)
130            self.global_init = tf.global_variables_initializer()    #定义全局初始化op
131            tf.get_default_graph().finalize()                #将后续的图设为只读
132        elif mode == 'test':
133            tf.reset_default_graph()
134
135            #创建测试数据的Dataset数据集
136            testdataset,self.num_classes = creat_dataset_fromdir(testdata_dir,batch_size,isTrain = False)
137
138            #创建一个可初始化的迭代器
139            iterator = tf.data.Iterator.from_structure(testdataset.output_types, testdataset.output_shapes)
140            #读取数据
141            self.images, labels = iterator.get_next()
142
143            self.test_init_op = iterator.make_initializer(testdataset)
144            self.logits,self.end_points, self.global_step= self.MyNASNet(self.images,is_training=False)
145            self.saver,self.save_path = self.load_cpk(self.global_step,None)                  #定义检查点相关变量
146            #评估指标
147            self.build_acc_base(labels)
148            tf.get_default_graph().finalize()            #将后续的图设为只读
149        elif mode == 'eval':
150            tf.reset_default_graph()
151            #创建测试数据的Dataset数据集
152            testdataset,self.num_classes = creat_dataset_fromdir(testdata_dir,batch_size,isTrain = False)
153
154            #创建一个可初始化的迭代器
155            iterator = tf.data.Iterator.from_structure(testdataset.output_types, testdataset.output_shapes)
156            #读取数据
157            self.images, labels = iterator.get_next()
158
159            self.logits,self.end_points, self.global_step= self.MyNASNet(self.images,is_training=False)
160            self.saver,self.save_path = self.load_cpk(self.global_step,None)   #定义检查点相关变量
161            tf.get_default_graph().finalize()                        #将后续的图设为只读复制代码

代码115行,对mode进行了判断,并按照具体的场景进行构建模型。针对训练、测试、使用的三个场景,构建的步骤几乎一样,具体如下:

(1)清空张量图(代码116、133、150);

(2)生成数据集(代码118、136、152);

(3)定义网络结构(代码129、144、159)。

测试与使用的场景是最相似的。在代码中测试比使用的操作对了个评估节点的生成(代码147)。

注意:

在每个操作分支的最后代码部分都加了代码tf.get_default_graph().finalize()(见代码131、148、161行),这是一个很好的习惯。该代码的功能是把图锁定,之后想要添加任何新的操作都会产生错误。这么做的意图是防止在后面训练或是测试过程中,由于开发人员疏忽,在图中添加额外的图操作。一旦在循环内部加了某个张量的操作,将会使整体性能大大下降。然而这种错误又很难发现。利用锁定图的方法,可以避免这种情况的发生。

代码实现:通过二次迭代来训练微调模型

训练微调模型的操作是在代码文件“5-3 train.py”中单独实现的。与正常的训练方式不同,这里使用了二次迭代的方式:

l 第一次迭代:微调模型,固定预编译模型载入的权重,只训练最后两层;

l 第二次迭代:联调模型,使用更小的学习率,训练全部节点。

先将类MyNASNetModel进行实例化,在调用其build_model方法构建模型,然后使用session开始训练。具体代码如下:

代码5-3 train

162 import tensorflow as tf
163 model = __import__("5-2  model")
164 MyNASNetModel = model.MyNASNetModel
165
166 batch_size = 32
167 train_dir  = 'data/train'
168 val_dir  = 'data/val'
169
170 learning_rate1 = 1e-1                                       #定义两次迭代的学习率
171 learning_rate2 = 1e-3
172
173 mymode = MyNASNetModel(r'nasnet-a_mobile_04_10_2017\model.ckpt')#初始化模型
174 mymode.build_model('train',val_dir,train_dir,batch_size,learning_rate1 ,learning_rate2 )                                                                    #将模型定义载入图中
175
176 num_epochs1 = 20                                           #微调的迭代次数
177 num_epochs2 = 200                                        #联调的迭代次数
178
179 with tf.Session() as sess:
180    sess.run(mymode.global_init)                         #初始全局节点
181
182    step = 0
183    step = mymode.load_cpk(mymode.global_step,sess,1,mymode.saver,mymode.save_path )#载入模型
184    print(step)
185    if step == 0:                                        #微调
186        mymode.init_fn(sess)                                 #载入预编译模型权重
187
188        for epoch in range(num_epochs1):
189
190            print('Starting1 epoch %d / %d' % (epoch + 1, num_epochs1))    #输出进度 
191            #用训练集初始化迭代器
192            sess.run(mymode.train_init_op)                                #数据集从头开始
193            while True:
194                try:
195                    step += 1
196                    #预测,合并图,训练
197                    acc,accuracy_top_5, summary, _ = sess.run([mymode.accuracy, mymode.accuracy_top_5,mymode.merged,mymode.last_train_op])
198
199                    #mymode.train_writer.add_summary(summary, step)#写入日志文件
200                    if step % 100 == 0:
201                        print(f'step: {step} train1 accuracy: {acc},{accuracy_top_5}')
202                except tf.errors.OutOfRangeError:#数据集指针在最后
203                    print("train1:",epoch," ok")
204                    mymode.saver.save(sess, mymode.save_path+"/mynasnet.cpkt",   global_step=mymode.global_step.eval())
205                    break
206
207        sess.run(mymode.step_init)                    #微调结束,计数器从0开始
208
209    #整体训练
210    for epoch in range(num_epochs2):
211        print('Starting2 epoch %d / %d' % (epoch + 1, num_epochs2))
212        sess.run(mymode.train_init_op)
213        while True:
214            try:
215                step += 1
216                #预测,合并图,训练
217                acc, summary, _ = sess.run([mymode.accuracy, mymode.merged, mymode.full_train_op])
218
219                mymode.train_writer.add_summary(summary, step)#写入日志文件
220
221                if step % 100 == 0:
222                    print(f'step: {step} train2 accuracy: {acc}')
223            except tf.errors.OutOfRangeError:
224                print("train2:",epoch," ok")
225                mymode.saver.save(sess, mymode.save_path+"/mynasnet.cpkt",   global_step=mymode.global_step.eval())
226                break复制代码

将以上代码运行后,经过一段时间的训练,可以在本地找到“train_nasnet”文件夹,里面放着的就是训练生成的模型文件。

代码实现:测试模型

测试模型的操作是在代码文件“5-4 test.py”中单独实现的。这里实现了使用测试数据集对现有模型的评估,并且使用单张图片放到模型里进行预测。

1. 定义测试模型所需要的功能函数

首先定义函数check_accuracy实现准确率的计算,接着定义函数check_sex实现男女性别的识别。具体代码如下:

代码5-4 test

227 import tensorflow as tf
228 model = __import__("5-2  model")
229 MyNASNetModel = model.MyNASNetModel
230
231 import sys                                      
232 nets_path = r'slim'                                     #加载环境变量
233 if nets_path not in sys.path:
234    sys.path.insert(0,nets_path)
235 else:
236    print('already add slim')
237
238 from nets.nasnet import nasnet                     #导出nasnet
239 slim = tf.contrib.slim                                 #slim
240 image_size = nasnet.build_nasnet_mobile.default_image_size  #获得图片输入尺寸 224
241
242 import numpy as np
243 from PIL import Image
244
245 batch_size = 32
246 test_dir  = 'data/val'
247
248 def check_accuracy(sess):
249    """
250    测试模型准确率
251    """
252    sess.run(mymode.test_init_op)                  #初始化测试数据集
253    num_correct, num_samples = 0, 0                 #定义正确个数 和 总个数
254    i = 0
255    while True:
256        i+=1
257        print('i',i)
258        try:
259            #计算correct_prediction 获取prediction、labels是否相同 
260            correct_pred,accuracy,logits = sess.run([mymode.correct_prediction,mymode.accuracy,mymode.logits])
261            #累加correct_pred
262            num_correct += correct_pred.sum()
263            num_samples += correct_pred.shape[0]
264            print("accuracy",accuracy,logits)
265
266
267        except tf.errors.OutOfRangeError:          #捕获异常,数据用完自动跳出
268            print('over')
269            break
270
271    acc = float(num_correct) / num_samples         #计算并返回准确率
272    return acc 
273
274
275 def check_sex(imgdir,sess):                        #定义函数识别男女
276    img = Image.open(image_dir)                      #读入图片
277    if "RGB"!=img.mode :                             #检查图片格式
278        img = img.convert("RGB") 
279
280    img = np.asarray(img.resize((image_size,image_size)),     #图像预处理  
281                          dtype=np.float32).reshape(1,image_size,image_size,3)
282    img = 2 *( img / 255.0)-1.0 
283
284    prediction = sess.run(mymode.logits, {mymode.images: img})#传入nasnet输入端中
285    print(prediction)
286
287    pre = prediction.argmax()                    #返回张量中最大值的索引
288    print(pre)
289
290    if pre == 1: img_id = 'man'
291    elif pre == 2: img_id = 'woman'
292    else: img_id = 'None'
293    plt.imshow( np.asarray((img[0]+1)*255/2,np.uint8 )  )
294    plt.show()
295    print(img_id,"--",image_dir)                    #返回类别
296    return pre复制代码

2. 建立会话,进行测试

首先建立会话session,对模型进行测试,接着取2张图片输入模型,进行男女的判断。具体代码如下:

代码5-4 test(续)

297 mymode = MyNASNetModel()                                     #初始化模型
298 mymode.build_model('test',test_dir )                     #将模型定义载入图中
299
300 with tf.Session() as sess:  
301    #载入模型
302    mymode.load_cpk(mymode.global_step,sess,1,mymode.saver,mymode.save_path )
303
304    #测试模型的准确性
305    val_acc = check_accuracy(sess)
306    print('Val accuracy: %f\n' % val_acc)
307
308    #单张图片测试
309    image_dir = 'tt2t.jpg'                                 #选取测试图片
310    check_sex(image_dir,sess)
311
312    image_dir = test_dir + '\\woman' + '\\000001.jpg'       #选取测试图片
313    check_sex(image_dir,sess)
314
315    image_dir = test_dir + '\\man' + '\\000003.jpg'         #选取测试图片
316    check_sex(image_dir,sess)复制代码

该程序使用的是迭代了100次数据集后的模型文件(如果要效果提高,可以再运行久一点)。代码运行后,输出结果如下。

(1)显示测试集的输出结果:

i 1

accuracy 0.90625 [[-3.813714 1.4075054 1.1485975 ]

[-7.3948846 6.220533 -1.4093535 ]

[-1.9391974 3.048838 0.21784738]

[-3.873174 4.530942 0.43135062]

……

[-3.8561587 2.7012844 -0.3634925 ]

[-4.4860134 4.7661724 -0.67080706]

[-2.9615571 2.8164086 0.71033645]]

i 2

accuracy 0.90625 [[ -6.6900268 -2.373093 6.6710057 ]

[ -4.1005263 0.74619263 4.980012 ]

[ -5.6469827 0.39027584 1.2689826 ]

……

[ -5.8080773 0.9121424 3.4134243 ]

[ -4.242001 0.08483959 4.056322 ]]

i 3

over

Val accuracy: 0.906250

上面显示的是测试集中man和woman文件夹中图片的计算结果。最终模型的准确率为90%。

(2)显示单张图片的运行结果:

[[-4.8022223 1.9008529 1.9379601]]

2

连载二:PyCon2018|用slim微调PNASNet模型(附源码)

图5-3 分辨男女测试图片(a)

woman -- tt2t.jpg

[[-6.181205 -2.9042015 6.1356106]]

2

连载二:PyCon2018|用slim微调PNASNet模型(附源码)

图5-3 分辨男女测试图片(b)

woman -- data/val\woman\000001.jpg

[[-4.896065 1.7791721 1.3118265]]

1

连载二:PyCon2018|用slim微调PNASNet模型(附源码)

图5-3 分辨男女测试图片(c)

man -- data/val\man\000003.jpg

上面显示了3张图片,分别为自选图片、测试数据集中的女人图片、测试数据集中的男人图片,每张图片下面显示了模型识别的结果。可以看到结果与图片内容一致。

结尾

文内代码可以直接运行使用。如果不想手动搭建,还可以下载本文的配套代码。

【代码获取】:关注公众号: xiangyuejiqiren    公众号回复“ pycon2

如果觉得本文有用

可以分享给更多小伙伴

连载二:PyCon2018|用slim微调PNASNet模型(附源码)


以上所述就是小编给大家介绍的《连载二:PyCon2018|用slim微调PNASNet模型(附源码)》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Spark SQL内核剖析

Spark SQL内核剖析

朱锋、张韶全、黄明 / 电子工业出版社 / 2018-8 / 69.00元

Spark SQL 是 Spark 技术体系中较有影响力的应用(Killer application),也是 SQL-on-Hadoop 解决方案 中举足轻重的产品。《Spark SQL内核剖析》由 11 章构成,从源码层面深入介绍 Spark SQL 内部实现机制,以及在实际业务场 景中的开发实践,其中包括 SQL 编译实现、逻辑计划的生成与优化、物理计划的生成与优化、Aggregation 算......一起来看看 《Spark SQL内核剖析》 这本书的介绍吧!

RGB转16进制工具
RGB转16进制工具

RGB HEX 互转工具

在线进制转换器
在线进制转换器

各进制数互转换器

SHA 加密
SHA 加密

SHA 加密工具