TensorFlow.js 卷积神经网络手写数字识别

栏目: 编程工具 · 发布时间: 7年前

内容简介:演示开始时需要加载大概调整训练集的大小, 观察测试结果的准确性

https://github-laziji.github.io/digit-recognizer/

演示开始时需要加载大概 100M 的训练数据, 稍等片刻

调整训练集的大小, 观察测试结果的准确性

数据来源

数据来源与 https://www.kaggle.com 中的一道题目 digit-recognizer

题目给出 42000 条训练数据(包含图片和标签)以及 28000 条测试数据(只包含图片)

要求给这些测试数据打上标签 [0-9] 描述该图像显示的是哪个数字, 要尽可能的准确

网站中还有许多其他的机器学习的题目以及数据, 是个很好的练手的地方

实现

TensorFlow 是一个开源的机器学习库, 利用这个库我们可以快速地构建机器学习项目

这里我们使用 TensorFlow.js 来实现识别手写数字

创建模型

卷积神经网络的第一层有两种作用, 它既是输入层也是执行层, 接收 IMAGE_H * IMAGE_W 大小的黑白像素

最后一层是输出层, 有10个输出单元, 代表着 0-9 这十个值的概率分布, 例如 Label=2 , 输出为 [0.02,0.01,0.9,...,0.01]

function createConvModel() {
  const model = tf.sequential();

  model.add(tf.layers.conv2d({
    inputShape: [IMAGE_H, IMAGE_W, 1],
    kernelSize: 3,
    filters: 16,
    activation: 'relu'
  }));

  model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
  model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
  model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
  model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
  model.add(tf.layers.flatten({}));

  model.add(tf.layers.dense({ units: 64, activation: 'relu' }));
  model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));

  return model;
}

训练模型

我们选择适当的优化器和损失函数, 来编译模型

async function train() {

  ui.trainLog('Create model...');
  model = createConvModel();
  
  ui.trainLog('Compile model...');
  const optimizer = 'rmsprop';
  model.compile({
    optimizer,
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });
  const trainData = Data.getTrainData(ui.getTrainNum());
  
  ui.trainLog('Training model...');
  await model.fit(trainData.xs, trainData.labels, {});

  ui.trainLog('Completed!');
  ui.trainCompleted();
}

测试

这里测试一组测试数据, 返回对应的标签, 即十个输出单元中概率最高的下标

function testOne(xs){
  if(!model){
    ui.viewLog('Need to train the model first');
    return;
  }
  ui.viewLog('Testing...');
  let output = model.predict(xs);
  ui.viewLog('Completed!');
  output.print();
  const axis = 1;
  const predictions = output.argMax(axis).dataSync();
  return predictions[0];
}

以上所述就是小编给大家介绍的《TensorFlow.js 卷积神经网络手写数字识别》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

你凭什么做好互联网

你凭什么做好互联网

曹政 / 中国友谊出版公司 / 2016-12 / 42.00元

为什么有人可以预见商机、超越景气,在不确定环境下表现更出色? 在规则之外,做好互联网,还有哪些关键秘诀? 当环境不给机会,你靠什么翻身? 本书为“互联网百晓生”曹政20多年互联网经验的总结,以严谨的逻辑思维分析个人与企业在互联网发展中的一些错误思想及做法,并给出正确解法。 从技术到商业如何实现,每个发展阶段需要匹配哪些能力、分解哪些目标、落实哪些策略都一一点出,并在......一起来看看 《你凭什么做好互联网》 这本书的介绍吧!

HTML 压缩/解压工具
HTML 压缩/解压工具

在线压缩/解压 HTML 代码

随机密码生成器
随机密码生成器

多种字符组合密码

MD5 加密
MD5 加密

MD5 加密工具