Tensorflow数据读取机制剖析

栏目: 编程工具 · 发布时间: 5年前

内容简介:展示如何将数据输入到计算图中

展示如何将数据输入到计算图中

Tensorflow数据读取机制剖析

Dataset 可以看作是相同类型“元素”的有序列表,在实际使用时,单个元素可以是向量、字符串、图片甚至是tuple或dict。

数据集对象实例化:

dataset=tf.data.Dataset.from_tensor_slice(<data>)

迭代器对象实例化:

iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()

读取结束异常:如果一个 dataset 中的元素被读取完毕,再尝试 sess.run(one_element) 的话,会抛出 tf.errors.OutOfRangeError 异常,这个行为与使用队列方式读取数据是一致的。

高维数据集的使用

tf.data.Dataset.from_tensor_slices 真正作用是切分传入Tensor的第一个维度,生成相应的dataset,即第一维表明数据集中数据的数量,之后切分batch等操作均以第一维为基础。

dataset=tf.data.Dataset.from_tensor_slices(np.random.uniform((5,2)))
iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()
with tf.Session(config=config) as sess:
    try:
        while True:
            print(sess.run(one_element))
    except tf.errors.OutOfRangeError as e:
        print('end~')

输出:

[0.1,0.2]
[0.3,0.2]
[0.1,0.6]
[0.4,0.3]
[0.5,0.2]

tuple组合数据

dataset=tf.data.Dataset.from_tensor_slices((np.array([1.,2.,3.,4.,5.]),
                                            np.random.uniform(size=(5,2))))
iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()
with tf.Session() as sess:
    try:
        while True:
            print(sess.run(one_element))
    except tf.errors.OutOfRangeError:
        print('end~')

输出:

(1.,array(0.1,0.3))
(2.,array(0.2,0.4))
...

数据集处理方法

Dataset 支持一类特殊操作: Transformation 。一个 Dataset 通过 Transformation 变成一个新的 Dataset 。常用的 Transformation

map
batch
shuffle
repeat

其中,

  • mapPython 中的 map 一致,接受一个函数, Dataset 中的每个元素都会作为这个函数的输入,并将函数返回值作为新的 Dataset

    dataset=dataset.map(lambda x:x+1)

    注意: map 函数可以使用 num_parallel_calls 参数并行化

  • batch 就是将多个元素组成batch。

    dataset=tf.data.Dataset.from_tensor_slices(
    {
        'a':np.array([1.,2.,3.,4.,5.]),
        'b':np.random.uniform(size=(5,2))
    })
    ###
    dataset=dataset.batch(2)  # batch_size=2
    ###
    iterator=dataset.make_one_shot_iterator()
    one_element=iterator.get_next()
    with tf.Session() as sess:
        try:
            while True:
                print(one_element)
        except tf.errors.OutOfRangeError:
            print('end~')

    输出:

    {'a':array([1.,2.]),'b':array([[1.,2.],[3.,4.]])}
    {'a':array([3.,4.]),'b':array([[5.,6.],[7.,8.]])}
  • shuffle 的功能是打乱 dataset 中的元素,它有个参数 buffer_size ,表示打乱时使用的 buffer 的大小,不应设置过小,推荐值1000.

    dataset=tf.data.Dataset.from_tensor_slices(
    {
        'a':np.array([1.,2.,3.,4.,5.]),
        'b':np.random.uniform(size=(5,2))
    })
    ###
    dataset=dataset.shuffle(buffer_size=5)
    ###
    iterator=dataset.make_one_shot_iterator()
    one_element=iterator.get_next()
    with tf.Session() as sess:
        try:
            while True:
                print(one_element)
        except tf.errors.OutOfRangeError:
            print('end~')
  • repeat 的功能就是将整个序列重复多次, 主要用来处理机器学习中的 epoch 。假设原先的数据是一个 epoch ,使用 repeat(2) 可以使之变成2个epoch.

    dataset=tf.data.Dataset.from_tensor_slices({
        'a':np.array([1.,2.,3.,4.,5.]),
        'b':np.random.uniform(size=(5,2))
    })
    ###
    dataset=dataset.repeat(2)  # 2epoch
    ###
    # iterator, one_element...

    注意:如果直接调用 repeat() 函数的话,生成的序列会无限重复下去,没有结果,因此不会抛出 tf.errors.OutOfRangeError 异常。

模拟读入磁盘图片及其Label示例

def _parse_function(filename,label):  # 接受单个元素,转换为目标
    img_string=tf.read_file(filename)
    img_decoded=tf.image.decode_images(img_string)
    img_resized=tf.image.resize_images(image_decoded,[28,28])
    return image_resized,label

filenames=tf.constant(['data/img1.jpg','data/img2.jpg',...])
labels=tf.constant([1,3,...])
dataset=tf.data.Dataset.from_tensor_slices((filenames,labels))
dataset=dataset.map(_parse_function)  # num_parallel_calls 并行
dataset=dataset.shuffle(buffer_size=1000).batch_size(32).repeat(10)

更多Dataset创建方法

  • tf.data.TextLineDataset() :函数输入一个文件列表,输出一个Dataset。dataset中的每一个元素对应文件中的一行,可以使用该方法读入csv文件。
  • tf.data.FixedLengthRecordDataset() :函数输入一个文件列表和 record_bytes 参数,dataset中每一个元素是文件中固定字节数 record_bytes 的内容,可用来读取二进制保存的文件,如CIFAR10。
  • tf.data.TFRecordDataset() :读取TFRecord文件,dataset中每一个元素是一个TFExample。

更多Iterator创建方法

最简单的创建 Iterator 方法是通过 dataset.make_one_shot_iterator() 创建一个iterator。

除了这种iterator之外,还有更复杂的Iterator:

  • initializable iterator
  • reinitializable iterator
  • feedable iterator

其中,initializable iterator方法要在使用前通过 sess.run() 进行初始化, initializable iterator还可用于读入较大数组。 在使用 tf.data.Dataset.from_tensor_slices(array) 时,实际上发生的事情是将array作为一个 tf.constants 保存到了计算图中,当array很大时,会导致计算图变得很大,给传输保存带来不便,这时可以使用一个 placeholder 取代这里的array,并使用initializable iterator,只在需要时将array传进去,这样即可避免将大数组保存在图里。

features_placeholder=tf.placeholder(<features.dtype>,<features.shape>)
labels_placeholder=tf.placeholder(<labels.dtype>,<labels.shape>)
dataset=tf.data.Dataset.from_tensor_slices((features_placeholder,labels_placeholder))
iterator=dataset.make_initializable_iterator()
next_element=iterator.get_next()
sess.run(iterator.initializer,feed_dict={features_placeholder:features,labels_placeholder:labels})

Tensorflow内部读取机制

Tensorflow数据读取机制剖析

对于文件名队列,使用 tf.train.string_input_producer() 函数, tf.train.string_input_producer() 还有两个重要参数, num_epochesshuffle

内存队列不需要我们建立,只需要使用 reader 对象从文件名队列中读取数据即可,使用 tf.train.start_queue_runners() 函数启动队列,填充两个队列的数据。

with tf.Session() as sess:
    filenames=['A.jpg','B.jpg','C.jpg']
    filename_queue=tf.train.string_input_producer(filenames,shuffle=True,num_epoch=5)
    reader=tf.WholeFileReader()
    key,value=reader.read(filename_queue)
    # tf.train.string_input_producer()定义了一个epoch变量,需要对其进行初始化
    tf.local_variables_initializer().run()
    threads=tf.train.start_queue_runners(sess=sess)
    i=0
    while True:
        i+=1
        image_data=sess.run(value)
        with open('reader/test_%d.jpg'%i,'wb') as f:
            f.write(image_data)

Linux公社的RSS地址https://www.linuxidc.com/rssFeed.aspx

本文永久更新链接地址: https://www.linuxidc.com/Linux/2019-05/158706.htm


以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Vim实用技巧

Vim实用技巧

[英] Drew Neil / 杨源、车文隆 / 人民邮电出版社 / 2014-5-1 / 59.00元

vim是一款功能丰富而强大的文本编辑器,其代码补全、编译及错误跳转等方便编程的功能特别丰富,在程序员中得到非常广泛的使用。vim能够大大提高程序员的工作效率。对于vim高手来说,vim能以与思考同步的速度编辑文本。同时,学习和熟练使用vim又有一定的难度。 《vim实用技巧》为那些想要提升自己的程序员编写,阅读本书是熟练地掌握高超的vim技巧的必由之路。全书共21章,包括121个技巧。每一章......一起来看看 《Vim实用技巧》 这本书的介绍吧!

RGB转16进制工具
RGB转16进制工具

RGB HEX 互转工具

SHA 加密
SHA 加密

SHA 加密工具

XML 在线格式化
XML 在线格式化

在线 XML 格式化压缩工具