深度学习实战 cifar数据集预处理技术分析

栏目: 数据库 · 发布时间: 5年前

内容简介:本文首发于微信公众号:"算法与编程之美",欢迎关注,及时了解更多此系列文章。cifar数据集是以cifar-10-python.tar.gz的压缩包格式存储在远程服务器,利用keras的get_file()方法下载压缩包并执行解压,解压后得到:

欢迎点击「算法与编程之美」↑关注我们!

本文首发于微信公众号:"算法与编程之美",欢迎关注,及时了解更多此系列文章。

cifar数据集是以cifar-10-python.tar.gz的压缩包格式存储在远程服务器,利用keras的get_file()方法下载压缩包并执行解压,解压后得到:

cifar-10-batches-py

├── batches.meta

├── data_batch_1

├── data_batch_2

├── data_batch_3

├── data_batch_4

├── data_batch_5

├── readme.html

└── test_batch

其中data_batch_[1..5]为训练集数据,test_batch为测试集数据。

def load_data():

"""Loads CIFAR10 dataset.

# Returns

Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.

"""

dirname = 'cifar-10-batches-py'

origin = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'

path = get_file(dirname, origin=origin, untar=True)

num_train_samples = 50000

x_train = np.empty((num_train_samples, 3, 32, 32), dtype='uint8')

y_train = np.empty((num_train_samples,), dtype='uint8')

for i in range(1, 6):

fpath = os.path.join(path, 'data_batch_' + str(i))

(x_train[(i - 1) * 10000: i * 10000, :, :, :],

y_train[(i - 1) * 10000: i * 10000]) = load_batch(fpath)

fpath = os.path.join(path, 'test_batch')

x_test, y_test = load_batch(fpath)

y_train = np.reshape(y_train, (len(y_train), 1))

y_test = np.reshape(y_test, (len(y_test), 1))

if K.image_data_format() == 'channels_last':

x_train = x_train.transpose(0, 2, 3, 1)

x_test = x_test.transpose(0, 2, 3, 1)

return (x_train, y_train), (x_test, y_test)

data_batch_i 存放了cifar的训练集数据,每个文件1万条数据,采用pickle的方式进行序列化数据,利用pickle.load()的方式加载文件并反序列化为之前的dict(),该字典中有’data’和’label’两个key,分别存放了数据和标签。

def load_batch(fpath, label_key='labels'):

"""Internal utility for parsing CIFAR data.

# Arguments

fpath: path the file to parse.

label_key: key for label data in the retrieve

dictionary.

# Returns

A tuple `(data, labels)`.

"""

with open(fpath, 'rb') as f:

if sys.version_info < (3,):

d = cPickle.load(f)

else:

d = cPickle.load(f, encoding='bytes')

# decode utf8

d_decoded = {}

for k, v in d.items():

d_decoded[k.decode('utf8')] = v

d = d_decoded

data = d['data']

labels = d[label_key]

data = data.reshape(data.shape[0], 3, 32, 32)

return data, labels

where2 go 团队

   

微信号:算法与编程之美          

深度学习实战 cifar数据集预处理技术分析

长按识别二维码关注我们!

温馨提示: 点击页面右下角 “写留言”发表评论,期待您的参与!期待您的转发!


以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

分布式服务架构:原理、设计与实战

分布式服务架构:原理、设计与实战

李艳鹏、杨彪 / 电子工业出版社 / 2017-8 / 89.00

《分布式服务架构:原理、设计与实战》全面介绍了分布式服务架构的原理与设计,并结合作者在实施微服务架构过程中的实践经验,总结了保障线上服务健康、可靠的最佳方案,是一本架构级、实战型的重量级著作。 《分布式服务架构:原理、设计与实战》以分布式服务架构的设计与实现为主线,由浅入深地介绍了分布式服务架构的方方面面,主要包括理论和实践两部分。理论上,首先介绍了服务架构的背景,以及从服务化架构到微服务架......一起来看看 《分布式服务架构:原理、设计与实战》 这本书的介绍吧!

JS 压缩/解压工具
JS 压缩/解压工具

在线压缩/解压 JS 代码

XML、JSON 在线转换
XML、JSON 在线转换

在线XML、JSON转换工具

HSV CMYK 转换工具
HSV CMYK 转换工具

HSV CMYK互换工具