Deeplearning4j多层感知线性

栏目: 编程工具 · 发布时间: 5年前

内容简介:Deeplearning4j是一个java编写的深度学习商业框架,可以通过他们提供的API快速搭建出神经网络。根据官网的下载的demo中的MLPClassiferLinear(多层感知线性分类器),训练Titanic数据,求生存分析。MLPClassiferLinear
编辑推荐:

本文来自于jianshu,文章介绍了神经网络的搭建并训练,构造评估对象,用测试集做评估的相关内容。

Deeplearning4j是一个 java 编写的深度学习商业框架,可以通过他们提供的API快速搭建出神经网络。根据官网的下载的demo中的MLPClassiferLinear(多层感知线性分类器),训练Titanic数据,求生存分析。

MLPClassiferLinear

Deeplearning4j多层感知线性

idea中该类所在的位置

Titanic数据,数据集来自kaggle训练集:

Deeplearning4j多层感知线性

泰坦尼克数据

测试集:

Deeplearning4j多层感知线性

泰坦尼克数据

然后依葫芦画瓢,照着MLP的样子打代码。

Deeplearning4j多层感知线性

神经网络参数

一开始时设置神经网络的参数,根据官网介绍这些参数设置和最终训练的结果有比较大的关联。batchSize表示每一步抓取的数据量。

Deeplearning4j多层感知线性

参数设置

要搞清楚这里的参数是什么意思得先看神经网络的原理图

Deeplearning4j多层感知线性

图片来自百度百科

可以看到信息经过一层一层的处理最终变为输出,每一层都有一个输入(除了第一层)都有一个输出(除了最后一层),也就是说第一层我们输入原始的数据,然后经过N层的隐藏层得到结果。numInput表示原始数据的维数,例如在titanic中将数据处理为[sex,survived]则数据为1维,[sex,parch,survived]则数据为二维的(即我们用几个数据预测survived这个变量,f(x)=y,f(x,z)=y,可以这么理解)。然后这n维数据经过一定的组合变成numHiddenNodes维,距离f(x,z)=y->f(x1,x2,z1,z2...)=y,最终输出结果。numOutput表示y的取值可能性,在本例中suvived即为生或死,表示为0或1所以为2。

Deeplearning4j多层感知线性

加载数据

参数设置好了后可以加载数据了,但是这里的数据不能是上面展示的原始数据,需要对数据进行处理。因为神经网络接受的输入是向量类型的,即不可以出现字符,要把所有的信息都转换为数字,比如男,女可以表示为0,1等。deeplearning4j自己提供了data2vec工具类,在机器学习中数据的质量对结果有着格外的影响,好的数据处理能让预测结果更准确。这里我处理数据的方式采用的方法不科学且粗犷不可取。为了方便我直接把所有出现字符的行列都删除,最终得到的数据如下

Deeplearning4j多层感知线性

训练集

第一列是suvived,后面的依次为pclass,sibSp,parch

Deeplearning4j多层感知线性

测试集

kaggle提供的测试即未提供suvived列,故不能用来评估,我直接把数据裁取好后在第一列中简单粗暴的加上分类标签,随机的0,1。

Deeplearning4j多层感知线性

加载数据

构造方法参数的意思,lableIndex表示在你的数据中标签列的索引号(本例中即survived,就是要预测的那一列)numPossibleLabels表示标签列可能的值的个数(本例中即生或死,为2)

Deeplearning4j多层感知线性

搭建网络并训练

然后可以搭建神经网络,代码如上每需要一层则用.layer(...)添加,layer即英文中神经网络层的意思。上图中搭建的是一个两层的网络,第一层接受numInputs的向量为参数然后输出numHiddenNodes的向量为输出,第二层接受上一层的输出为输入并且输出结果numOutputs。然后创建模型对象并且用训练集训练模型。最终得到model对象。

这样一个神经网络就训练好啦!

但是我们并不知道这个网络的效果如何,这时测试数据集及登场了。(原本应该是测试数据集是一组正确的数据,用来评估模型的预测准确率的,但这里测试数据集的数据并不是正确的而是自己杜撰的)

Deeplearning4j多层感知线性

构造评估对象,并用测试集做评估

写好后,点击运行。过一会,就能预测输出结果了。

Deeplearning4j多层感知线性

模型评估结果

可以看到评估指标中又正确率,召回率等数据,其结果均为50%左右。也就是我们这模型可能和瞎猜的结果差不多。这个应该是数据处理的问题,用demo中的数据跑精确率可以达到99%以上。

deeplearning4j给广大java程序源提供一个很好进一步接触人工智能的机会。有兴趣的花可以自己扒数据,尝试这它预测诸如股票、天气等数据,看看自己训练的网络的准确性、实用性。

ps.未来应该是一个人工智能的时代,越来越多的框架让我们能更简洁的接触利用人工智能,这是很好的时代普通人一可以用这些看似高大上(其实也高达上)的东西做一些自己的想法。官网的demo还有许多其他例子,有时间可以拿出来研究。


以上所述就是小编给大家介绍的《Deeplearning4j多层感知线性》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

MySQL权威指南

MySQL权威指南

Randy Jay Yarger / 林琪、朱涛江 / 中国电力出版社 / 2003-11-1 / 49.00元

为一种开源数据库,MySQL已经成为最流行的服务器软件包之一。开发人员在其数据库引擎中提供了丰富的特性(只需很少的内存和CPU支持)。 因此,众多Linux和Unix服务器(以及一些Windows服务器)都采用MySQL作为其数据库引擎。由于MySQL作为Web站点后端时速度特别快而且相当方便,所有在目前流行的一个词LAMP(表示Linux、Apache、MySQL和Perl、Python或......一起来看看 《MySQL权威指南》 这本书的介绍吧!

JSON 在线解析
JSON 在线解析

在线 JSON 格式化工具

随机密码生成器
随机密码生成器

多种字符组合密码

URL 编码/解码
URL 编码/解码

URL 编码/解码