连载二:PyCon2018|用slim微调PNASNet模型(附源码)

栏目: Python · 发布时间: 7年前

内容简介:有一组照片,分为男人和女人。本案就是让深度学习模型来学习这些样本,并能够找到其中的规律,完成模型的训练。接着可以使用该模型对图片中的人物进行识别,区分其性别是男还是女。

第八届中国 Python 开发者大会PyConChina2018,由PyChina.org发起,由来自CPyUG/TopGeek等社区的30位组织者,近150位志愿者在北京、上海、深圳、杭州、成都等城市举办。致力于推动各类Python相关的技术在互联网、企业应用等领域的研发和应用。

代码医生工作室有幸接受邀请,参加了这次会议的北京站专场。在会上主要分享了《人工智能实战案例分享-图像处理与数值分析》。

会上分享的一些案例主要是来源于《python带我起飞——入门、进阶、商业实战》一书与《深度学习之TensorFlow:入门、原理与进阶实战》一书。另外,还扩充了若干其它案例。在本文作为补充,将会上分享的其它案例以详细的图文方式补充进来,并提供源码。共分为4期连载。

  1. 用slim调用PNASNet模型

  2. 用slim微调PNASNet模型

  3. 用对抗样本攻击PNASNet模型

  4. 恶意域名检测实例

连载二:PyCon2018|用slim微调PNASNet模型(附源码)

通过微调模型实现分辨男女

案例描述

有一组照片,分为男人和女人。

本案就是让深度学习模型来学习这些样本,并能够找到其中的规律,完成模型的训练。接着可以使用该模型对图片中的人物进行识别,区分其性别是男还是女。

本案例中,使用了一个NASNet_A_Mobile的模型来做二次训练。具体过程分为4步:

(1)准备样本;

(2)准备NASNet_A_Mobile网络模型;

(3)编写代码进行二次训练;

(4)使用已经训练好的模型进行测试。

准备样本

通过如下链接下载CelebA数据集:

mmlab.ie.cuhk.edu.hk/projects/Ce…

下载完之后,解压,并手动分出一部分男人与女人的照片。

在本例中,一共用了20000张图片用来训练模型,其中训练样本由8421张男性头像和11599张女性头像构成(在train文件夹下),测试样本由10张男性头像和10张女性头像构成(在val文件夹下)。部分样本数据如图5-1。

连载二:PyCon2018|用slim微调PNASNet模型(附源码)

图5-1 男女数据集样本示例

数据样本整理好后,统一放到data文件夹下。该数据样本同样也可以在随书的配套资源中找到。

代码环境及模型准备

为了使读者能够快速完成该实例,直观上感受到模型的识别能力,可以直接使用本书配套的资源。并将其放到代码的同级目录下即可。

如果想体验下从零开始手动搭建,也可以按照下面的方法准备代码环境及预编译模型。

1. 下载models与部署TensorFlow slim模块

该部分的内容与3.1节完全一样,这里不再详述。

2. 下载NASNet_A_Mobile模型

该部分的内容与3.1节类似。在如图3-2中的倒数第3个模型,找到 “nasnet-a_mobile_04_10_2017.tar.gz”的下载链接。将其下载并解压。

3. 整体代码文件部署结构

本案例是通过4个代码文件来实现的,具体文件及描述如下:

l 5-1 mydataset.py:处理男女图片数据集的代码;

l 5-2 model.py:加载预编译模型NASNet_A_Mobile,并进行微调的代码;

l 5-3 train.py:训练模型的代码;

l 5-4 test.py:测试模型的代码。

部署时,将这4个代码文件与slim库、NASNet_A_Mobile模型、样本一起放到一个文件夹下即可。完整的文件结构如图5-2。

连载二:PyCon2018|用slim微调PNASNet模型(附源码)

图5-2 分辨男女案例的文件结构

代码实现:处理样本数据并生成Dataset对象

本案例中,直接将数据集的相关操作封装到了“5-1 mydataset.py”代码文件里。在该文件中,实现了符合训练与测试使用场景的数据集。在训练模式下,会对数据进行乱序处理;在测试模式下,直接使用顺序数据。两种数据集都是按批次读取。

这部分的知识在第4章已经有全面的介绍,这里不再详述。完整代码如下:

代码5-1 mydataset

1 import tensorflow as tf
 2 import sys                                      
 3 nets_path = r'slim'                                             #加载环境变量
 4 if nets_path not in sys.path:
 5    sys.path.insert(0,nets_path)
 6 else:
 7     print('already add slim')
 8 from nets.nasnet import nasnet                               #导出nasnet
 9 slim = tf.contrib.slim                                         #slim
10 image_size = nasnet.build_nasnet_mobile.default_image_size     #获得图片输入尺寸 224
11 from preprocessing import preprocessing_factory            #图像处理
12 
13 import os
14 def list_images(directory):
15    """
16    获取所有directory中的所有图片和标签
17    """
18
19    #返回path指定的文件夹包含的文件或文件夹的名字的列表
20    labels = os.listdir(directory)
21    #对标签进行排序,以便训练和验证按照相同的顺序进行
22    labels.sort()
23    #创建文件标签列表
24    files_and_labels = []
25    for label in labels:
26        for f in os.listdir(os.path.join(directory, label)):
27            #转换字符串中所有大写字符为小写再判断
28            if 'jpg' in f.lower() or 'png' in f.lower():
29                #加入列表
30                files_and_labels.append((os.path.join(directory, label, f), label))
31    #理解为解压 把数据路径和标签解压出来
32    filenames, labels = zip(*files_and_labels)
33    #转换为列表 分别储存数据路径和对应标签
34    filenames = list(filenames)
35    labels = list(labels)
36    #列出分类总数 比如两类:['man', 'woman']
37    unique_labels = list(set(labels))
38
39    label_to_int = {}
40    #循环列出数据和数据下标,给每个分类打上标签{'woman': 2, 'man': 1,none:0}
41    for i, label in enumerate(sorted(unique_labels)):
42        label_to_int[label] = i+1
43    print(label,label_to_int[label])
44    #把每个标签化为0 1 这种形式
45    labels = [label_to_int[l] for l in labels]
46    print(labels[:6],labels[-6:])
47    return filenames, labels              #返回储存数据路径和对应转换后的标签
48
49 num_workers = 2                          #定义并行处理数据的线程数量
50
51 #图像批量预处理
52 image_preprocessing_fn = preprocessing_factory.get_preprocessing('nasnet_mobile', is_training=True)
53 image_eval_preprocessing_fn = preprocessing_factory.get_preprocessing('nasnet_mobile', is_training=False)
54
55 def _parse_function(filename, label):      #定义图像解码函数
56    image_string = tf.read_file(filename)
57    image = tf.image.decode_jpeg(image_string, channels=3)          
58    return image, label
59
60 def training_preprocess(image, label):    #定义调整图像大小函数
61    image = image_preprocessing_fn(image, image_size, image_size)
62    return image, label
63
64 def val_preprocess(image, label):       #定义评估图像预处理函数
65    image = image_eval_preprocessing_fn(image, image_size, image_size)
66    return image, label
67
68 #创建带批次的数据集
69 def creat_batched_dataset(filenames, labels,batch_size,isTrain = True):
70
71    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
72
73    dataset = dataset.map(_parse_function, num_parallel_calls=num_workers)    #对图像解码
74
75    if isTrain == True:
76        dataset = dataset.shuffle(buffer_size=len(filenames))                #打乱数据顺序
77        dataset = dataset.map(training_preprocess, num_parallel_calls=num_workers)#调整图像大小
78    else:
79        dataset = dataset.map(val_preprocess,num_parallel_calls=num_workers)    #调整图像大小
80
81    return dataset.batch(batch_size)                                           #返回批次数据
82
83 #根据目录返回数据集
84 def creat_dataset_fromdir(directory,batch_size,isTrain = True):
85    filenames, labels = list_images(directory)
86    num_classes = len(set(labels))
87    dataset = creat_batched_dataset(filenames, labels,batch_size,isTrain)
88    return dataset,num_classes 复制代码

代码11行,导入了preprocessing_factory函数,该函数是slim模块中封装好的工厂函数,用于生成模型的预处理函数。利用统一封装好的预处理函数,对样本进行操作(代码60、61行),可以提升开发效率,并能够减小出错的可能性。

工厂函数的知识点,属于Python基础知识,这里不再详述。有兴趣的读者可以参考《python带我起飞——入门、进阶、商业实战》一书的6.10节。

注意:

这里用了一个技巧。仿照原NASNet_A_Mobile模型的分类方法,在对分类标签排号时,将标签为0的分类空出来,男人与女人分别为1和2。

另外在代码42行,用到的变量unique_labels是从集合对象转化过来的。在使用时需要对齐固定顺序,所以使用了sorted函数进行变换。如果没有这句,在下次启动的时候,有可能出现标签序号与名称对应不上的现象。在多次中断,多次训练的场景下,会造成训练结果的混乱。这部分知识在《python带我起飞——入门、进阶、商业实战》的第四章集合部分的内容中,也做了重点的强调。

代码实现:定义微调模型类MyNASNetModel

在微调模型的实现中,统一通过定义类MyNASNetModel来实现。在类MyNASNetModel中,大致可分为2大动作:初始化设置、构建模型。

l 初始化设置:定义好构建模型时所需要的必要参数;

l 构建模型:针对训练、测试、应用的三种情况分别构建不同的网络模型。在训练过程中,还要支持加载预编译模型及微调模型。

实现定义类MyNASNetModel并进行初始化模型设置的代码如下:

代码5-2 model

1 import sys                                      
  2 nets_path = r'slim'                                    #加载环境变量
  3 if nets_path not in sys.path:
  4    sys.path.insert(0,nets_path)
  5 else:
  6    print('already add slim')
  7
  8 import tensorflow as tf
  9 from nets.nasnet import nasnet                       #导出nasnet
 10 slim = tf.contrib.slim 
 11
 12 import os  
 13 mydataset = __import__("5-1  mydataset")
 14 creat_dataset_fromdir = mydataset.creat_dataset_fromdir
 15
 16 class MyNASNetModel(object):
 17    """微调模型类MyNASNetModel
 18    """
 19    def __init__(self, model_path=''):
 20        self.model_path = model_path              #原始模型的路径           复制代码

代码20行为初始化MyNASNetModel类的操作。model_path指的是所要加载的原始预编译模型。该操作只有在训练模式下是有意义的。在测试和应用模式下,可以为空。

构建MyNASNetModel类中的基本模型

在构建模型中,无论是训练、测试还是应用,都需要将最基本的NASNet_A_Mobile模型载入。这里通过定义MyNASNetModel类的MyNASNet方法来实现。具体的实现方式与3.3节的实现基本一致,不同的是3.3节构建的是PNASNet网络结构,这里构建的NASNet_A_Mobile结构。

代码5-2 model(续)

21  def MyNASNet(self,images,is_training):
 22        arg_scope = nasnet.nasnet_mobile_arg_scope()          #获得模型命名空间
 23        with slim.arg_scope(arg_scope):
 24            #构建NASNet Mobile模型
 25            logits, end_points = nasnet.build_nasnet_mobile(images,num_classes = self.num_classes+1, is_training=is_training)
 26
 27        global_step = tf.train.get_or_create_global_step()      #定义记录步数的张量
 28
 29        return logits,end_points,global_step                   #返回有用的张量
复制代码

代码25行中,往num_classes参数里传的值代表分类的个数,在本案例中分为男人和女人,一共两类(即,self.num_classes=2,该值是在后文5.2.8节中,build_model方法被赋值的)。再加上一个None类。于是传入的值为self.num_classes+1。

实现MyNASNetModel类中的微调操作

微调操作是针对训练场景下使用的。通过定义MyNASNetModel类中的FineTuneNASNet方法来实现。微调操作主要是对预编译模型的超参进行选择性恢复。

因为预编译模型NASNet_A_Mobile是在ImgNet上训练的,有1000个分类,而本案例中识别男女的任务只有两个分类。所以最后两个输出层的超参不应该被恢复(由于分类不同,导致超参的个数不同)。在实际使用时,最后两层的参数需要对其初始化,并单独训练即可。

代码5-2 model(续)

30 def FineTuneNASNet(self,is_training):      #实现微调模型的网络操作 
 31        model_path = self.model_path
 32
 33        exclude = ['final_layer','aux_7']      #恢复超参, 除了exclude以外的全部恢复
 34        variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
 35        if is_training == True:
 36            init_fn = slim.assign_from_checkpoint_fn(model_path, variables_to_restore)
 37        else:
 38            init_fn = None
 39
 40        tuning_variables = []             #将没有恢复的超参收集起来,用于微调训练
 41        for v in exclude:
 42            tuning_variables += slim.get_variables(v)
 43
 44        return init_fn, tuning_variables复制代码

代码中,使用了exclude列表,将不需要恢复的网络节点收集起来(代码33行),接着将预训练模型中的超参值赋值给剩下的节点,完成了预训练模型的载入(代码36行)。最后使用了tuning_variables列表,将不需要恢复的网络节点权重收集起来(代码40行),用于微调训练。

注意:

这里介绍个技巧,如何获得exclude中的元素(代码33行):通过额外执行代码tf.global_variables(),将张量图中的节点打印出来。从里面找到最后两层的节点,并将其填入代码中即可。在找到节点后,还可以通过slim.get_variables函数,来检查该名称的节点是否正确。例如,可以通过将slim.get_variables('final_layer')的返回值打印出来,来观察张量图中是否有final_layer节点。这部分的原理可以参考《深度学习之TensorFlow:入门、原理与进阶实战》书中第4章的内容(在第11章也有类似的案例)。

代码实现:实现与训练相关的其他方法

在MyNASNetModel类中,还需要定义与训练操作相关的其他方法,具体如下:

l build_acc_base方法:用于构建评估模型的相关节点;

l load_cpk方法:用于载入及保存模型检查点

l build_model_train方法:用于构建训练模型中的损失函数及优化器等操作节点。

具体代码如下:

代码5-2 model(续)

45 def build_acc_base(self,labels):#定义评估函数
 46        #返回张量中最大值的索引
 47        self.prediction = tf.to_int32(tf.argmax(self.logits, 1))
 48        #计算prediction、labels是否相同 
 49        self.correct_prediction = tf.equal(self.prediction, labels)
 50        #计算平均值
 51        self.accuracy = tf.reduce_mean(tf.to_float(self.correct_prediction))
 52        #将前5个最高正确率的值取出来,计算平均值
 53        self.accuracy_top_5 = tf.reduce_mean(tf.to_float(tf.nn.in_top_k(predictions=self.logits, targets=labels, k=5)))
 54
 55    def load_cpk(self,global_step,sess,begin = 0,saver= None,save_path = None):                                                    #储存和导出模型
 56       if begin == 0:
 57            save_path=r'./train_nasnet'                      #定义检查点路径
 58            if not os.path.exists(save_path):
 59                print("there is not a model path:",save_path)
 60            saver = tf.train.Saver(max_to_keep=1)            #生成saver
 61            return saver,save_path
 62        else:
 63            kpt = tf.train.latest_checkpoint(save_path)    #查找最新的检查点
 64            print("load model:",kpt)
 65            startepo= 0                                    #计步
 66            if kpt!=None:
 67                saver.restore(sess, kpt)                     #还原模型
 68                ind = kpt.find("-")
 69                startepo = int(kpt[ind+1:])
 70                print("global_step=",global_step.eval(),startepo)    
 71            return startepo  
 72
 73    def build_model_train(self,images,
 74           labels,learning_rate1,learning_rate2,is_training):
 75           self.logits,self.end_points, 
 76           self.global_step= self.MyNASNet(images,is_training=is_training)
 77        self.step_init = self.global_step.initializer
 78
 79        self.init_fn,self.tuning_variables = self.FineTuneNASNet(
 80            is_training=is_training)
 81        #定义损失函数
 82       tf.losses.sparse_softmax_cross_entropy(labels=labels, 
 83            logits=self.logits)
 84        loss = tf.losses.get_total_loss()
 85        #定义微调的率退化学习速率
 86        learning_rate1=tf.train.exponential_decay(
 87                 learning_rate=learning_rate1, global_step=self.global_step,
 88                 decay_steps=100, decay_rate=0.5)
 89        #定义联调的率退化学习速率
 90        learning_rate2=tf.train.exponential_decay(
 91             learning_rate=learning_rate2, global_step=self.global_step,
 92             decay_steps=100, decay_rate=0.2)                
 93        last_optimizer = tf.train.AdamOptimizer(learning_rate1) #优化器
 94        full_optimizer = tf.train.AdamOptimizer(learning_rate2)   
 95        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  
 96        with tf.control_dependencies(update_ops):      #更新批量归一化中的参数
 97            #使loss减小方向做优化
 98            self.last_train_op = last_optimizer.minimize(loss, self.global_step,var_list=self.tuning_variables)
 99            self.full_train_op = full_optimizer.minimize(loss, self.global_step)
100
101        self.build_acc_base(labels)                    #定义评估模型相关指标
102        #写入日志,支持tensorBoard操作
103        tf.summary.scalar('accuracy', self.accuracy)    
104        tf.summary.scalar('accuracy_top_5', self.accuracy_top_5)
105
106        #将收集的所有默认图表并合并
107        self.merged = tf.summary.merge_all()
108        #写入日志文件
109        self.train_writer = tf.summary.FileWriter('./log_dir/train')
110        self.eval_writer = tf.summary.FileWriter('./log_dir/eval')
111        #定义检查点相关变量
112        self.saver,self.save_path = self.load_cpk(self.global_step,None)复制代码

在上面代码中,使用了tf.losses接口来获得loss值。通过调用tf.losses.sparse_softmax_cross_entropy 函数计算具体的loss(见代码82行)。该函数会自动将loss值添加到内部集合ops.GraphKeys.LOSSES中。然后调用tf.losses.get_total_loss函数,将ops.GraphKeys.LOSSES集合中的所有loss值获取,并返回回来(见代码84行)。

在代码96行中,在反向优化时,使用了tf.control_dependencies函数对的批量归一化操作中的均值与方差进行更新。

代码实现:构建模型,用于训练、测试、使用

在MyNASNetModel类中,定义build_model方法用与构建模型的实现。在build_model方法中,通过参数mode来指定模型的具体使用场景。具体代码如下:

代码5-2 model(续)

113 def build_model(self,mode='train',testdata_dir='./data/val',traindata_dir='./data/train', batch_size=32,learning_rate1=0.001,learning_rate2=0.001):
114
115        if mode == 'train':        
116            tf.reset_default_graph()
117            #创建训练数据和测试数据的Dataset数据集
118            dataset,self.num_classes = creat_dataset_fromdir(traindata_dir,batch_size)
119            testdataset,_ = creat_dataset_fromdir(testdata_dir,batch_size,isTrain = False)
120
121            #创建一个可初始化的迭代器
122            iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
123            #读取数据
124            images, labels = iterator.get_next()
125
126            self.train_init_op = iterator.make_initializer(dataset)
127            self.test_init_op = iterator.make_initializer(testdataset)
128
129            self.build_model_train(images, labels,learning_rate1,learning_rate2,is_training=True)
130            self.global_init = tf.global_variables_initializer()    #定义全局初始化op
131            tf.get_default_graph().finalize()                #将后续的图设为只读
132        elif mode == 'test':
133            tf.reset_default_graph()
134
135            #创建测试数据的Dataset数据集
136            testdataset,self.num_classes = creat_dataset_fromdir(testdata_dir,batch_size,isTrain = False)
137
138            #创建一个可初始化的迭代器
139            iterator = tf.data.Iterator.from_structure(testdataset.output_types, testdataset.output_shapes)
140            #读取数据
141            self.images, labels = iterator.get_next()
142
143            self.test_init_op = iterator.make_initializer(testdataset)
144            self.logits,self.end_points, self.global_step= self.MyNASNet(self.images,is_training=False)
145            self.saver,self.save_path = self.load_cpk(self.global_step,None)                  #定义检查点相关变量
146            #评估指标
147            self.build_acc_base(labels)
148            tf.get_default_graph().finalize()            #将后续的图设为只读
149        elif mode == 'eval':
150            tf.reset_default_graph()
151            #创建测试数据的Dataset数据集
152            testdataset,self.num_classes = creat_dataset_fromdir(testdata_dir,batch_size,isTrain = False)
153
154            #创建一个可初始化的迭代器
155            iterator = tf.data.Iterator.from_structure(testdataset.output_types, testdataset.output_shapes)
156            #读取数据
157            self.images, labels = iterator.get_next()
158
159            self.logits,self.end_points, self.global_step= self.MyNASNet(self.images,is_training=False)
160            self.saver,self.save_path = self.load_cpk(self.global_step,None)   #定义检查点相关变量
161            tf.get_default_graph().finalize()                        #将后续的图设为只读复制代码

代码115行,对mode进行了判断,并按照具体的场景进行构建模型。针对训练、测试、使用的三个场景,构建的步骤几乎一样,具体如下:

(1)清空张量图(代码116、133、150);

(2)生成数据集(代码118、136、152);

(3)定义网络结构(代码129、144、159)。

测试与使用的场景是最相似的。在代码中测试比使用的操作对了个评估节点的生成(代码147)。

注意:

在每个操作分支的最后代码部分都加了代码tf.get_default_graph().finalize()(见代码131、148、161行),这是一个很好的习惯。该代码的功能是把图锁定,之后想要添加任何新的操作都会产生错误。这么做的意图是防止在后面训练或是测试过程中,由于开发人员疏忽,在图中添加额外的图操作。一旦在循环内部加了某个张量的操作,将会使整体性能大大下降。然而这种错误又很难发现。利用锁定图的方法,可以避免这种情况的发生。

代码实现:通过二次迭代来训练微调模型

训练微调模型的操作是在代码文件“5-3 train.py”中单独实现的。与正常的训练方式不同,这里使用了二次迭代的方式:

l 第一次迭代:微调模型,固定预编译模型载入的权重,只训练最后两层;

l 第二次迭代:联调模型,使用更小的学习率,训练全部节点。

先将类MyNASNetModel进行实例化,在调用其build_model方法构建模型,然后使用session开始训练。具体代码如下:

代码5-3 train

162 import tensorflow as tf
163 model = __import__("5-2  model")
164 MyNASNetModel = model.MyNASNetModel
165
166 batch_size = 32
167 train_dir  = 'data/train'
168 val_dir  = 'data/val'
169
170 learning_rate1 = 1e-1                                       #定义两次迭代的学习率
171 learning_rate2 = 1e-3
172
173 mymode = MyNASNetModel(r'nasnet-a_mobile_04_10_2017\model.ckpt')#初始化模型
174 mymode.build_model('train',val_dir,train_dir,batch_size,learning_rate1 ,learning_rate2 )                                                                    #将模型定义载入图中
175
176 num_epochs1 = 20                                           #微调的迭代次数
177 num_epochs2 = 200                                        #联调的迭代次数
178
179 with tf.Session() as sess:
180    sess.run(mymode.global_init)                         #初始全局节点
181
182    step = 0
183    step = mymode.load_cpk(mymode.global_step,sess,1,mymode.saver,mymode.save_path )#载入模型
184    print(step)
185    if step == 0:                                        #微调
186        mymode.init_fn(sess)                                 #载入预编译模型权重
187
188        for epoch in range(num_epochs1):
189
190            print('Starting1 epoch %d / %d' % (epoch + 1, num_epochs1))    #输出进度 
191            #用训练集初始化迭代器
192            sess.run(mymode.train_init_op)                                #数据集从头开始
193            while True:
194                try:
195                    step += 1
196                    #预测,合并图,训练
197                    acc,accuracy_top_5, summary, _ = sess.run([mymode.accuracy, mymode.accuracy_top_5,mymode.merged,mymode.last_train_op])
198
199                    #mymode.train_writer.add_summary(summary, step)#写入日志文件
200                    if step % 100 == 0:
201                        print(f'step: {step} train1 accuracy: {acc},{accuracy_top_5}')
202                except tf.errors.OutOfRangeError:#数据集指针在最后
203                    print("train1:",epoch," ok")
204                    mymode.saver.save(sess, mymode.save_path+"/mynasnet.cpkt",   global_step=mymode.global_step.eval())
205                    break
206
207        sess.run(mymode.step_init)                    #微调结束,计数器从0开始
208
209    #整体训练
210    for epoch in range(num_epochs2):
211        print('Starting2 epoch %d / %d' % (epoch + 1, num_epochs2))
212        sess.run(mymode.train_init_op)
213        while True:
214            try:
215                step += 1
216                #预测,合并图,训练
217                acc, summary, _ = sess.run([mymode.accuracy, mymode.merged, mymode.full_train_op])
218
219                mymode.train_writer.add_summary(summary, step)#写入日志文件
220
221                if step % 100 == 0:
222                    print(f'step: {step} train2 accuracy: {acc}')
223            except tf.errors.OutOfRangeError:
224                print("train2:",epoch," ok")
225                mymode.saver.save(sess, mymode.save_path+"/mynasnet.cpkt",   global_step=mymode.global_step.eval())
226                break复制代码

将以上代码运行后,经过一段时间的训练,可以在本地找到“train_nasnet”文件夹,里面放着的就是训练生成的模型文件。

代码实现:测试模型

测试模型的操作是在代码文件“5-4 test.py”中单独实现的。这里实现了使用测试数据集对现有模型的评估,并且使用单张图片放到模型里进行预测。

1. 定义测试模型所需要的功能函数

首先定义函数check_accuracy实现准确率的计算,接着定义函数check_sex实现男女性别的识别。具体代码如下:

代码5-4 test

227 import tensorflow as tf
228 model = __import__("5-2  model")
229 MyNASNetModel = model.MyNASNetModel
230
231 import sys                                      
232 nets_path = r'slim'                                     #加载环境变量
233 if nets_path not in sys.path:
234    sys.path.insert(0,nets_path)
235 else:
236    print('already add slim')
237
238 from nets.nasnet import nasnet                     #导出nasnet
239 slim = tf.contrib.slim                                 #slim
240 image_size = nasnet.build_nasnet_mobile.default_image_size  #获得图片输入尺寸 224
241
242 import numpy as np
243 from PIL import Image
244
245 batch_size = 32
246 test_dir  = 'data/val'
247
248 def check_accuracy(sess):
249    """
250    测试模型准确率
251    """
252    sess.run(mymode.test_init_op)                  #初始化测试数据集
253    num_correct, num_samples = 0, 0                 #定义正确个数 和 总个数
254    i = 0
255    while True:
256        i+=1
257        print('i',i)
258        try:
259            #计算correct_prediction 获取prediction、labels是否相同 
260            correct_pred,accuracy,logits = sess.run([mymode.correct_prediction,mymode.accuracy,mymode.logits])
261            #累加correct_pred
262            num_correct += correct_pred.sum()
263            num_samples += correct_pred.shape[0]
264            print("accuracy",accuracy,logits)
265
266
267        except tf.errors.OutOfRangeError:          #捕获异常,数据用完自动跳出
268            print('over')
269            break
270
271    acc = float(num_correct) / num_samples         #计算并返回准确率
272    return acc 
273
274
275 def check_sex(imgdir,sess):                        #定义函数识别男女
276    img = Image.open(image_dir)                      #读入图片
277    if "RGB"!=img.mode :                             #检查图片格式
278        img = img.convert("RGB") 
279
280    img = np.asarray(img.resize((image_size,image_size)),     #图像预处理  
281                          dtype=np.float32).reshape(1,image_size,image_size,3)
282    img = 2 *( img / 255.0)-1.0 
283
284    prediction = sess.run(mymode.logits, {mymode.images: img})#传入nasnet输入端中
285    print(prediction)
286
287    pre = prediction.argmax()                    #返回张量中最大值的索引
288    print(pre)
289
290    if pre == 1: img_id = 'man'
291    elif pre == 2: img_id = 'woman'
292    else: img_id = 'None'
293    plt.imshow( np.asarray((img[0]+1)*255/2,np.uint8 )  )
294    plt.show()
295    print(img_id,"--",image_dir)                    #返回类别
296    return pre复制代码

2. 建立会话,进行测试

首先建立会话session,对模型进行测试,接着取2张图片输入模型,进行男女的判断。具体代码如下:

代码5-4 test(续)

297 mymode = MyNASNetModel()                                     #初始化模型
298 mymode.build_model('test',test_dir )                     #将模型定义载入图中
299
300 with tf.Session() as sess:  
301    #载入模型
302    mymode.load_cpk(mymode.global_step,sess,1,mymode.saver,mymode.save_path )
303
304    #测试模型的准确性
305    val_acc = check_accuracy(sess)
306    print('Val accuracy: %f\n' % val_acc)
307
308    #单张图片测试
309    image_dir = 'tt2t.jpg'                                 #选取测试图片
310    check_sex(image_dir,sess)
311
312    image_dir = test_dir + '\\woman' + '\\000001.jpg'       #选取测试图片
313    check_sex(image_dir,sess)
314
315    image_dir = test_dir + '\\man' + '\\000003.jpg'         #选取测试图片
316    check_sex(image_dir,sess)复制代码

该程序使用的是迭代了100次数据集后的模型文件(如果要效果提高,可以再运行久一点)。代码运行后,输出结果如下。

(1)显示测试集的输出结果:

i 1

accuracy 0.90625 [[-3.813714 1.4075054 1.1485975 ]

[-7.3948846 6.220533 -1.4093535 ]

[-1.9391974 3.048838 0.21784738]

[-3.873174 4.530942 0.43135062]

……

[-3.8561587 2.7012844 -0.3634925 ]

[-4.4860134 4.7661724 -0.67080706]

[-2.9615571 2.8164086 0.71033645]]

i 2

accuracy 0.90625 [[ -6.6900268 -2.373093 6.6710057 ]

[ -4.1005263 0.74619263 4.980012 ]

[ -5.6469827 0.39027584 1.2689826 ]

……

[ -5.8080773 0.9121424 3.4134243 ]

[ -4.242001 0.08483959 4.056322 ]]

i 3

over

Val accuracy: 0.906250

上面显示的是测试集中man和woman文件夹中图片的计算结果。最终模型的准确率为90%。

(2)显示单张图片的运行结果:

[[-4.8022223 1.9008529 1.9379601]]

2

连载二:PyCon2018|用slim微调PNASNet模型(附源码)

图5-3 分辨男女测试图片(a)

woman -- tt2t.jpg

[[-6.181205 -2.9042015 6.1356106]]

2

连载二:PyCon2018|用slim微调PNASNet模型(附源码)

图5-3 分辨男女测试图片(b)

woman -- data/val\woman\000001.jpg

[[-4.896065 1.7791721 1.3118265]]

1

连载二:PyCon2018|用slim微调PNASNet模型(附源码)

图5-3 分辨男女测试图片(c)

man -- data/val\man\000003.jpg

上面显示了3张图片,分别为自选图片、测试数据集中的女人图片、测试数据集中的男人图片,每张图片下面显示了模型识别的结果。可以看到结果与图片内容一致。

结尾

文内代码可以直接运行使用。如果不想手动搭建,还可以下载本文的配套代码。

【代码获取】:关注公众号: xiangyuejiqiren    公众号回复“ pycon2

如果觉得本文有用

可以分享给更多小伙伴

连载二:PyCon2018|用slim微调PNASNet模型(附源码)


以上所述就是小编给大家介绍的《连载二:PyCon2018|用slim微调PNASNet模型(附源码)》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

HTML5

HTML5

Matthew David / Focal Press / 2010-07-29 / USD 39.95

Implement the powerful new multimedia and interactive capabilities offered by HTML5, including style control tools, illustration tools, video, audio, and rich media solutions. Understand how HTML5 is ......一起来看看 《HTML5》 这本书的介绍吧!

JSON 在线解析
JSON 在线解析

在线 JSON 格式化工具

html转js在线工具
html转js在线工具

html转js在线工具

HEX CMYK 转换工具
HEX CMYK 转换工具

HEX CMYK 互转工具