Tensorflow数据读取机制剖析

栏目: 编程工具 · 发布时间: 6年前

内容简介:展示如何将数据输入到计算图中

展示如何将数据输入到计算图中

Tensorflow数据读取机制剖析

Dataset 可以看作是相同类型“元素”的有序列表,在实际使用时,单个元素可以是向量、字符串、图片甚至是tuple或dict。

数据集对象实例化:

dataset=tf.data.Dataset.from_tensor_slice(<data>)

迭代器对象实例化:

iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()

读取结束异常:如果一个 dataset 中的元素被读取完毕,再尝试 sess.run(one_element) 的话,会抛出 tf.errors.OutOfRangeError 异常,这个行为与使用队列方式读取数据是一致的。

高维数据集的使用

tf.data.Dataset.from_tensor_slices 真正作用是切分传入Tensor的第一个维度,生成相应的dataset,即第一维表明数据集中数据的数量,之后切分batch等操作均以第一维为基础。

dataset=tf.data.Dataset.from_tensor_slices(np.random.uniform((5,2)))
iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()
with tf.Session(config=config) as sess:
    try:
        while True:
            print(sess.run(one_element))
    except tf.errors.OutOfRangeError as e:
        print('end~')

输出:

[0.1,0.2]
[0.3,0.2]
[0.1,0.6]
[0.4,0.3]
[0.5,0.2]

tuple组合数据

dataset=tf.data.Dataset.from_tensor_slices((np.array([1.,2.,3.,4.,5.]),
                                            np.random.uniform(size=(5,2))))
iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()
with tf.Session() as sess:
    try:
        while True:
            print(sess.run(one_element))
    except tf.errors.OutOfRangeError:
        print('end~')

输出:

(1.,array(0.1,0.3))
(2.,array(0.2,0.4))
...

数据集处理方法

Dataset 支持一类特殊操作: Transformation 。一个 Dataset 通过 Transformation 变成一个新的 Dataset 。常用的 Transformation

map
batch
shuffle
repeat

其中,

  • mapPython 中的 map 一致,接受一个函数, Dataset 中的每个元素都会作为这个函数的输入,并将函数返回值作为新的 Dataset

    dataset=dataset.map(lambda x:x+1)

    注意: map 函数可以使用 num_parallel_calls 参数并行化

  • batch 就是将多个元素组成batch。

    dataset=tf.data.Dataset.from_tensor_slices(
    {
        'a':np.array([1.,2.,3.,4.,5.]),
        'b':np.random.uniform(size=(5,2))
    })
    ###
    dataset=dataset.batch(2)  # batch_size=2
    ###
    iterator=dataset.make_one_shot_iterator()
    one_element=iterator.get_next()
    with tf.Session() as sess:
        try:
            while True:
                print(one_element)
        except tf.errors.OutOfRangeError:
            print('end~')

    输出:

    {'a':array([1.,2.]),'b':array([[1.,2.],[3.,4.]])}
    {'a':array([3.,4.]),'b':array([[5.,6.],[7.,8.]])}
  • shuffle 的功能是打乱 dataset 中的元素,它有个参数 buffer_size ,表示打乱时使用的 buffer 的大小,不应设置过小,推荐值1000.

    dataset=tf.data.Dataset.from_tensor_slices(
    {
        'a':np.array([1.,2.,3.,4.,5.]),
        'b':np.random.uniform(size=(5,2))
    })
    ###
    dataset=dataset.shuffle(buffer_size=5)
    ###
    iterator=dataset.make_one_shot_iterator()
    one_element=iterator.get_next()
    with tf.Session() as sess:
        try:
            while True:
                print(one_element)
        except tf.errors.OutOfRangeError:
            print('end~')
  • repeat 的功能就是将整个序列重复多次, 主要用来处理机器学习中的 epoch 。假设原先的数据是一个 epoch ,使用 repeat(2) 可以使之变成2个epoch.

    dataset=tf.data.Dataset.from_tensor_slices({
        'a':np.array([1.,2.,3.,4.,5.]),
        'b':np.random.uniform(size=(5,2))
    })
    ###
    dataset=dataset.repeat(2)  # 2epoch
    ###
    # iterator, one_element...

    注意:如果直接调用 repeat() 函数的话,生成的序列会无限重复下去,没有结果,因此不会抛出 tf.errors.OutOfRangeError 异常。

模拟读入磁盘图片及其Label示例

def _parse_function(filename,label):  # 接受单个元素,转换为目标
    img_string=tf.read_file(filename)
    img_decoded=tf.image.decode_images(img_string)
    img_resized=tf.image.resize_images(image_decoded,[28,28])
    return image_resized,label

filenames=tf.constant(['data/img1.jpg','data/img2.jpg',...])
labels=tf.constant([1,3,...])
dataset=tf.data.Dataset.from_tensor_slices((filenames,labels))
dataset=dataset.map(_parse_function)  # num_parallel_calls 并行
dataset=dataset.shuffle(buffer_size=1000).batch_size(32).repeat(10)

更多Dataset创建方法

  • tf.data.TextLineDataset() :函数输入一个文件列表,输出一个Dataset。dataset中的每一个元素对应文件中的一行,可以使用该方法读入csv文件。
  • tf.data.FixedLengthRecordDataset() :函数输入一个文件列表和 record_bytes 参数,dataset中每一个元素是文件中固定字节数 record_bytes 的内容,可用来读取二进制保存的文件,如CIFAR10。
  • tf.data.TFRecordDataset() :读取TFRecord文件,dataset中每一个元素是一个TFExample。

更多Iterator创建方法

最简单的创建 Iterator 方法是通过 dataset.make_one_shot_iterator() 创建一个iterator。

除了这种iterator之外,还有更复杂的Iterator:

  • initializable iterator
  • reinitializable iterator
  • feedable iterator

其中,initializable iterator方法要在使用前通过 sess.run() 进行初始化, initializable iterator还可用于读入较大数组。 在使用 tf.data.Dataset.from_tensor_slices(array) 时,实际上发生的事情是将array作为一个 tf.constants 保存到了计算图中,当array很大时,会导致计算图变得很大,给传输保存带来不便,这时可以使用一个 placeholder 取代这里的array,并使用initializable iterator,只在需要时将array传进去,这样即可避免将大数组保存在图里。

features_placeholder=tf.placeholder(<features.dtype>,<features.shape>)
labels_placeholder=tf.placeholder(<labels.dtype>,<labels.shape>)
dataset=tf.data.Dataset.from_tensor_slices((features_placeholder,labels_placeholder))
iterator=dataset.make_initializable_iterator()
next_element=iterator.get_next()
sess.run(iterator.initializer,feed_dict={features_placeholder:features,labels_placeholder:labels})

Tensorflow内部读取机制

Tensorflow数据读取机制剖析

对于文件名队列,使用 tf.train.string_input_producer() 函数, tf.train.string_input_producer() 还有两个重要参数, num_epochesshuffle

内存队列不需要我们建立,只需要使用 reader 对象从文件名队列中读取数据即可,使用 tf.train.start_queue_runners() 函数启动队列,填充两个队列的数据。

with tf.Session() as sess:
    filenames=['A.jpg','B.jpg','C.jpg']
    filename_queue=tf.train.string_input_producer(filenames,shuffle=True,num_epoch=5)
    reader=tf.WholeFileReader()
    key,value=reader.read(filename_queue)
    # tf.train.string_input_producer()定义了一个epoch变量,需要对其进行初始化
    tf.local_variables_initializer().run()
    threads=tf.train.start_queue_runners(sess=sess)
    i=0
    while True:
        i+=1
        image_data=sess.run(value)
        with open('reader/test_%d.jpg'%i,'wb') as f:
            f.write(image_data)

Linux公社的RSS地址https://www.linuxidc.com/rssFeed.aspx

本文永久更新链接地址: https://www.linuxidc.com/Linux/2019-05/158706.htm


以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

复杂网络理论及其应用

复杂网络理论及其应用

汪小帆、李翔、陈关荣 / 清华大学出版社 / 2006 / 45.00元

国内首部复杂网络专著 【图书目录】 第1章 引论 1.1 引言 1.2 复杂网络研究简史 1.3 基本概念 1.4 本书内容简介 参考文献 第2章 网络拓扑基本模型及其性质 2.1 引言 2.2 规则网络 2.3 随机图 2.4 小世界网络模型 2.5 无标度网络模型 ......一起来看看 《复杂网络理论及其应用》 这本书的介绍吧!

JS 压缩/解压工具
JS 压缩/解压工具

在线压缩/解压 JS 代码

HTML 编码/解码
HTML 编码/解码

HTML 编码/解码

html转js在线工具
html转js在线工具

html转js在线工具