Deep Learning in Practice: An Analysis of CIFAR Dataset Preprocessing



This article was first published on the WeChat official account "算法与编程之美"; follow the account for more articles in this series.

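The CIFAR-10 dataset is stored on a remote server as the compressed archive cifar-10-python.tar.gz. Keras's get_file() function downloads the archive and extracts it, which produces the following directory: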

cifar-10-batches-py
├── batches.meta
├── data_batch_1
├── data_batch_2
├── data_batch_3
├── data_batch_4
├── data_batch_5
├── readme.html
└── test_batch

Here data_batch_[1..5] hold the training set and test_batch holds the test set.
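As a rough illustration of what the extraction step leaves on disk, the sketch below builds a tiny stand-in archive with the same file names as listed above and unpacks it with the standard library tarfile module. It does not call get_file() and downloads nothing; the helper names make_fake_archive and extract_archive are hypothetical, and the file contents are placeholders.

```python
import io
import os
import tarfile
import tempfile

# File names inside the archive, as listed above.
MEMBERS = (["batches.meta"]
           + ["data_batch_%d" % i for i in range(1, 6)]
           + ["readme.html", "test_batch"])


def make_fake_archive(archive_path, dirname="cifar-10-batches-py"):
    """Write a tiny .tar.gz whose layout mimics the CIFAR-10 archive (placeholder contents)."""
    with tarfile.open(archive_path, "w:gz") as tar:
        for name in MEMBERS:
            payload = b"placeholder"
            info = tarfile.TarInfo(name=dirname + "/" + name)
            info.size = len(payload)
            tar.addfile(info, io.BytesIO(payload))


def extract_archive(archive_path, dest_dir):
    """Extract a .tar.gz into dest_dir and return the sorted relative paths of the files."""
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(dest_dir)
    extracted = []
    for root, _dirs, files in os.walk(dest_dir):
        for fname in files:
            extracted.append(os.path.relpath(os.path.join(root, fname), dest_dir))
    return sorted(extracted)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        archive = os.path.join(tmp, "cifar-10-python.tar.gz")
        make_fake_archive(archive)
        for path in extract_archive(archive, os.path.join(tmp, "out")):
            print(path)
```

With the real archive, the same directory is what get_file(..., untar=True) returns as path in the loader code.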

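# Imports as in Keras 2.x; module paths differ in later releases.
import os

import numpy as np
from keras import backend as K
from keras.datasets.cifar import load_batch
from keras.utils.data_utils import get_file


def load_data():
    """Loads CIFAR10 dataset.

    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    dirname = 'cifar-10-batches-py'
    origin = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
    # Download the archive (cached after the first call) and extract it.
    path = get_file(dirname, origin=origin, untar=True)

    num_train_samples = 50000

    # Preallocate the training arrays, then fill them one batch at a time.
    x_train = np.empty((num_train_samples, 3, 32, 32), dtype='uint8')
    y_train = np.empty((num_train_samples,), dtype='uint8')

    for i in range(1, 6):
        fpath = os.path.join(path, 'data_batch_' + str(i))
        (x_train[(i - 1) * 10000: i * 10000, :, :, :],
         y_train[(i - 1) * 10000: i * 10000]) = load_batch(fpath)

    fpath = os.path.join(path, 'test_batch')
    x_test, y_test = load_batch(fpath)

    # Labels become column vectors of shape (N, 1).
    y_train = np.reshape(y_train, (len(y_train), 1))
    y_test = np.reshape(y_test, (len(y_test), 1))

    # The files store images as (N, C, H, W); move channels last if required.
    if K.image_data_format() == 'channels_last':
        x_train = x_train.transpose(0, 2, 3, 1)
        x_test = x_test.transpose(0, 2, 3, 1)

    return (x_train, y_train), (x_test, y_test)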
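The two array manipulations in load_data() can be tried on synthetic data: writing each batch into a slice of a preallocated array, and moving the channel axis with transpose(0, 2, 3, 1). The sketch below uses small made-up sizes (5 batches of 4 images rather than 10,000) and a hypothetical fake_batch() helper in place of load_batch(); it assumes NumPy is installed.

```python
import numpy as np

BATCH_SIZE = 4      # illustrative only; the real batches hold 10000 images each
NUM_BATCHES = 5


def fake_batch(i):
    """Hypothetical stand-in for load_batch(): every pixel and label equals the batch index."""
    data = np.full((BATCH_SIZE, 3, 32, 32), i, dtype="uint8")
    labels = [i] * BATCH_SIZE
    return data, labels


def assemble():
    """Fill preallocated arrays batch by batch, as load_data() does."""
    n = BATCH_SIZE * NUM_BATCHES
    x = np.empty((n, 3, 32, 32), dtype="uint8")
    y = np.empty((n,), dtype="uint8")
    for i in range(1, NUM_BATCHES + 1):
        start, stop = (i - 1) * BATCH_SIZE, i * BATCH_SIZE
        x[start:stop], y[start:stop] = fake_batch(i)
    y = np.reshape(y, (len(y), 1))        # labels become a column vector
    return x, y


x, y = assemble()
x_last = x.transpose(0, 2, 3, 1)          # (N, C, H, W) -> (N, H, W, C)
print(x.shape, x_last.shape, y.shape)
# (20, 3, 32, 32) (20, 32, 32, 3) (20, 1)
```

Note that transpose only reorders the axes: the value at x[n, c, h, w] is the same as the value at x_last[n, h, w, c].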

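Each data_batch_i file holds 10,000 records of the CIFAR-10 training set, serialized with pickle. Calling pickle.load() reads the file and deserializes it back into the original dict, which contains the keys 'data' and 'labels' for the image data and the labels respectively.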
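To make the deserialization step concrete, here is a small round trip on a made-up batch. It assumes Python 3 and uses the standard pickle module; the dictionary contents are fabricated for illustration and only mimic the byte-string keys that loading the real files with encoding='bytes' produces. The decode_keys helper is hypothetical.

```python
import io
import pickle


def decode_keys(d):
    """Return a copy of d whose bytes keys are decoded to str."""
    return {k.decode("utf8"): v for k, v in d.items()}


# A made-up two-image batch; real batches hold 10000 rows of 3072 bytes each.
fake_batch = {
    b"data": [bytes(3072), bytes(3072)],
    b"labels": [3, 8],
}

buffer = io.BytesIO()
pickle.dump(fake_batch, buffer)                   # serialize
buffer.seek(0)
loaded = pickle.load(buffer, encoding="bytes")    # deserialize
batch = decode_keys(loaded)

print(sorted(batch))       # ['data', 'labels']
print(batch["labels"])     # [3, 8]
```

Decoding the keys once up front means the rest of the code can index the dictionary with ordinary strings such as 'data' and 'labels'.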

# Imports as in Keras 2.x, which used six for Python 2/3 compatibility.
import sys

from six.moves import cPickle


def load_batch(fpath, label_key='labels'):
    """Internal utility for parsing CIFAR data.

    # Arguments
        fpath: path to the file to parse.
        label_key: key for label data in the retrieved
            dictionary.

    # Returns
        A tuple `(data, labels)`.
    """
    with open(fpath, 'rb') as f:
        if sys.version_info < (3,):
            d = cPickle.load(f)
        else:
            # On Python 3 the keys load as bytes, so decode them to str.
            d = cPickle.load(f, encoding='bytes')
            # decode utf8
            d_decoded = {}
            for k, v in d.items():
                d_decoded[k.decode('utf8')] = v
            d = d_decoded
    data = d['data']
    labels = d[label_key]

    # Each row is a flat image; reshape to (N, channels, height, width).
    data = data.reshape(data.shape[0], 3, 32, 32)
    return data, labels
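The final reshape relies on how each 3072-value row is laid out. According to the dataset's own description, the first 1024 values are the red channel, the next 1024 the green, and the last 1024 the blue, each stored row by row. The sketch below checks that indexing on a synthetic row; pixel_from_row is a hypothetical helper and NumPy is assumed to be installed.

```python
import numpy as np


def pixel_from_row(row, channel, r, c):
    """Look up one value directly in a flat 3072-element row (channel-major, then row-major)."""
    return row[channel * 1024 + r * 32 + c]


# One synthetic image whose values are their flat positions (mod 256 to fit uint8).
flat = (np.arange(3072) % 256).astype("uint8").reshape(1, 3072)
images = flat.reshape(flat.shape[0], 3, 32, 32)

# The reshaped array and the manual index agree for any (channel, row, column).
print(images.shape)
print(images[0, 1, 2, 5] == pixel_from_row(flat[0], 1, 2, 5))
```

Because the reshape does not move any data, it is only correct when the file really uses this channel-major layout, which is why the shape (3, 32, 32) comes before any transpose to channels-last.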

where2 go team

   
