基于MNIST数据集实现2层神经网络案例实战-大数据ML样本集案例实战

栏目: 数据库 · 发布时间: 5年前

内容简介:版权声明:本套技术专栏是作者(秦凯新)平时工作的总结和升华,通过从真实商业环境抽取案例进行总结和分享,并给出商业应用的调优建议和集群环境容量规划等内容,请持续关注本套博客。QQ邮箱地址:1120746959@qq.com,如有任何学术交流,可随时联系。基本的神经网络案例,在于真正的入门神经网络的构建。版权声明:本套技术专栏是作者(秦凯新)平时工作的总结和升华,通过从真实商业环境抽取案例进行总结和分享,并给出商业应用的调优建议和集群环境容量规划等内容,请持续关注本套博客。QQ邮箱地址:1120746959@

版权声明:本套技术专栏是作者(秦凯新)平时工作的总结和升华,通过从真实商业环境抽取案例进行总结和分享,并给出商业应用的调优建议和集群环境容量规划等内容,请持续关注本套博客。QQ邮箱地址:1120746959@qq.com,如有任何学术交流,可随时联系。

1 神经网络基本结构定义

  • 28*28=784个像素点,第一层神经元256,第二层神经元128
基于MNIST数据集实现2层神经网络案例实战-大数据ML样本集案例实战

2 神经网络构建

  • 变量初始化

    import numpy as np
      import tensorflow as tf
      import matplotlib.pyplot as plt
      import input_data
      mnist = input_data.read_data_sets('data/', one_hot=True)
      Extracting data/train-images-idx3-ubyte.gz
      Extracting data/train-labels-idx1-ubyte.gz
      Extracting data/t10k-images-idx3-ubyte.gz
      Extracting data/t10k-labels-idx1-ubyte.gz
    
      # NETWORK TOPOLOGIES
      #第一层神经元
      n_hidden_1 = 256 
      #第二层神经元
      n_hidden_2 = 128
      #28*28 784像素点
      n_input    = 784 
      # 类别10
      n_classes  = 10  
      
      # INPUTS AND OUTPUTS
      x = tf.placeholder("float", [None, n_input])
      y = tf.placeholder("float", [None, n_classes])
          
      # NETWORK PARAMETERS
      stddev = 0.1
      #初始化
      weights = {
          'w1': tf.Variable(tf.random_normal([n_input, n_hidden_1], stddev=stddev)),
          'w2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2], stddev=stddev)),
          'out': tf.Variable(tf.random_normal([n_hidden_2, n_classes], stddev=stddev))
      }
      #初始化
      biases = {
          'b1': tf.Variable(tf.random_normal([n_hidden_1])),
          'b2': tf.Variable(tf.random_normal([n_hidden_2])),
          'out': tf.Variable(tf.random_normal([n_classes]))
      }
      print ("NETWORK READY")
    复制代码
  • 前向传播(每一层增加激活函数sigmoid,最后一层不加sigmoid)

    def multilayer_perceptron(_X, _weights, _biases):
          layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(_X, _weights['w1']), _biases['b1'])) 
          layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, _weights['w2']), _biases['b2']))
          return (tf.matmul(layer_2, _weights['out']) + _biases['out'])
    复制代码
  • 损失变量和优化器定义

  • softmax_cross_entropy_with_logits交叉熵损失函数(参数pred预测值),reduce_mean除以样本总数。

  • GradientDescentOptimizer采用梯度下降优化求解

    # PREDICTION
      pred = multilayer_perceptron(x, weights, biases)
      
      # LOSS AND OPTIMIZER
      cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y)) 
      optm = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(cost) 
      
      #准确率求解
      corr = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))    
      accr = tf.reduce_mean(tf.cast(corr, "float"))
      
      # INITIALIZER
      init = tf.global_variables_initializer()
      print ("FUNCTIONS READY")
    复制代码
  • 按照Batch迭代

    training_epochs = 20
      batch_size      = 100
      display_step    = 4
      # LAUNCH THE GRAPH
      sess = tf.Session()
      sess.run(init)
      # OPTIMIZE
      for epoch in range(training_epochs):
          avg_cost = 0.
          total_batch = int(mnist.train.num_examples/batch_size)
          
          # ITERATION(按照Batch迭代,每一次迭代100)
          for i in range(total_batch):
              batch_xs, batch_ys = mnist.train.next_batch(batch_size)
              #填充值
              feeds = {x: batch_xs, y: batch_ys}
              #sess.run(模型训练)
              sess.run(optm, feed_dict=feeds)
              avg_cost += sess.run(cost, feed_dict=feeds)
          avg_cost = avg_cost / total_batch
          # DISPLAY
          if (epoch+1) % display_step == 0:
              print ("Epoch: %03d/%03d cost: %.9f" % (epoch, training_epochs, avg_cost))
              feeds = {x: batch_xs, y: batch_ys}
              
              #sess.run(准确率求解)
              train_acc = sess.run(accr, feed_dict=feeds)
              print ("TRAIN ACCURACY: %.3f" % (train_acc))
              feeds = {x: mnist.test.images, y: mnist.test.labels}
              test_acc = sess.run(accr, feed_dict=feeds)
              print ("TEST ACCURACY: %.3f" % (test_acc))
      print ("OPTIMIZATION FINISHED")
    复制代码

3 总结

基本的神经网络案例,在于真正的入门神经网络的构建。

版权声明:本套技术专栏是作者(秦凯新)平时工作的总结和升华,通过从真实商业环境抽取案例进行总结和分享,并给出商业应用的调优建议和集群环境容量规划等内容,请持续关注本套博客。QQ邮箱地址:1120746959@qq.com,如有任何学术交流,可随时联 秦凯新 于深圳 2018120892153


以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Go Web 编程

Go Web 编程

[新加坡]Sau Sheong Chang(郑兆雄) / 黄健宏 / 人民邮电出版社 / 2017-11-22 / 79

《Go Web 编程》原名《Go Web Programming》,原书由新加坡开发者郑兆雄(Sau Sheong Chang)创作、 Manning 出版社出版,人名邮电出版社引进了该书的中文版权,并将其交由黄健宏进行翻译。 《Go Web 编程》一书围绕一个网络论坛 作为例子,教授读者如何使用请求处理器、多路复用器、模板引擎、存储系统等核心组件去构建一个 Go Web 应用,然后在该应用......一起来看看 《Go Web 编程》 这本书的介绍吧!

MD5 加密
MD5 加密

MD5 加密工具

html转js在线工具
html转js在线工具

html转js在线工具

HEX CMYK 转换工具
HEX CMYK 转换工具

HEX CMYK 互转工具