分享web开发知识

注册/登录|最近发布|今日推荐

主页 IT知识网页技术软件开发前端开发代码编程运营维护技术分享教程案例
当前位置:首页 > 运营维护

芝麻HTTP:TensorFlow LSTM MNIST分类

发布时间:2023-09-06 01:45责任编辑:熊小新关键词:暂无标签

本节来介绍一下使用 RNN 的 LSTM 来做 MNIST 分类的方法,RNN 相比 CNN 来说,速度可能会慢,但可以节省更多的内存空间。

初始化 首先我们可以先初始化一些变量,如学习率、节点单元数、RNN 层数等:

learning_rate = 1e-3num_units = 256num_layer = 3input_size = 28time_step = 28total_steps = 2000category_num = 10steps_per_validate = 100steps_per_test = 500batch_size = tf.placeholder(tf.int32, [])keep_prob = tf.placeholder(tf.float32, [])

然后还需要声明一下 MNIST 数据生成器:

import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets(‘MNIST_data/‘, one_hot=True)

接下来常规声明一下输入的数据,输入数据用 x 表示,标注数据用 y_label 表示:

x = tf.placeholder(tf.float32, [None, 784])y_label = tf.placeholder(tf.float32, [None, 10])

这里输入的 x 维度是 [None, 784],代表 batch_size 不确定,输入维度 784,y_label 同理。

接下来我们需要对输入的 x 进行 reshape 操作,因为我们需要将一张图分为多个 time_step 来输入,这样才能构建一个 RNN 序列,所以这里直接将 time_step 设成 28,这样一来 input_size 就变为了 28,batch_size 不变,所以reshape 的结果是一个三维的矩阵:

x_shape = tf.reshape(x, [-1, time_step, input_size])

RNN 层 接下来我们需要构建一个 RNN 模型了,这里我们使用的 RNN Cell 是 LSTMCell,而且要搭建一个三层的 RNN,所以这里还需要用到 MultiRNNCell,它的输入参数是 LSTMCell 的列表。

所以我们可以先声明一个方法用于创建 LSTMCell,方法如下:

def cell(num_units): ???cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=num_units) ???return DropoutWrapper(cell, output_keep_prob=keep_prob)

这里还加入了 Dropout,来减少训练过程中的过拟合。

接下来我们再利用它来构建多层的 RNN:

cells = tf.nn.rnn_cell.MultiRNNCell([cell(num_units) for _ in range(num_layer)])

注意这里使用了 for 循环,每循环一次新生成一个 LSTMCell,而不是直接使用乘法来扩展列表,因为这样会导致 LSTMCell 是同一个对象,导致构建完 MultiRNNCell 之后出现维度不匹配的问题。

接下来我们需要声明一个初始状态:

h0 = cells.zero_state(batch_size, dtype=tf.float32)

然后接下来调用 dynamic_rnn() 方法即可完成模型的构建了:

output, hs = tf.nn.dynamic_rnn(cells, inputs=x_shape, initial_state=h0)

这里 inputs 的输入就是 x 做了 reshape 之后的结果,初始状态通过 initial_state 传入,其返回结果有两个,一个 output 是所有 time_step 的输出结果,赋值为 output,它是三维的,第一维长度等于 batch_size,第二维长度等于 time_step,第三维长度等于 num_units。另一个 hs 是隐含状态,是元组形式,长度即 RNN 的层数 3,每一个元素都包含了 c 和 h,即 LSTM 的两个隐含状态。

这样的话 output 的最终结果可以取最后一个 time_step 的结果,所以可以使用:

output = output[:, -1, :]

或者直接取隐藏状态最后一层的 h 也是相同的:

h = hs[-1].h

在此模型中,二者是等价的。但注意如果用于文本处理,可能由于文本长度不一,而 padding,导致二者不同。

输出层 接下来我们再做一次线性变换和 Softmax 输出结果即可:

# Output Layerw = tf.Variable(tf.truncated_normal([num_units, category_num], stddev=0.1), dtype=tf.float32)b = tf.Variable(tf.constant(0.1, shape=[category_num]), dtype=tf.float32)y = tf.matmul(output, w) + b# Losscross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y_label, logits=y)

这里的 Loss 直接调用了 softmax_cross_entropy_with_logits 先计算了 Softmax,然后计算了交叉熵。

训练和评估 最后再定义训练和评估的流程即可,在训练过程中每隔一定的 step 就输出 Train Accuracy 和 Test Accuracy:

# Traintrain = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cross_entropy)# Predictioncorrection_prediction = tf.equal(tf.argmax(y, axis=1), tf.argmax(y_label, axis=1))accuracy = tf.reduce_mean(tf.cast(correction_prediction, tf.float32))# Trainwith tf.Session() as sess: ???sess.run(tf.global_variables_initializer()) ???for step in range(total_steps + 1): ???????batch_x, batch_y = mnist.train.next_batch(100) ???????sess.run(train, feed_dict={x: batch_x, y_label: batch_y, keep_prob: 0.5, batch_size: batch_x.shape[0]}) ???????# Train Accuracy ???????if step % steps_per_validate == 0: ???????????print(‘Train‘, step, sess.run(accuracy, feed_dict={x: batch_x, y_label: batch_y, keep_prob: 0.5, ??????????????????????????????????????????????????????????????batch_size: batch_x.shape[0]})) ???????# Test Accuracy ???????if step % steps_per_test == 0: ???????????test_x, test_y = mnist.test.images, mnist.test.labels ???????????print(‘Test‘, step, ?????????????????sess.run(accuracy, feed_dict={x: test_x, y_label: test_y, keep_prob: 1, batch_size: test_x.shape[0]}))

运行 直接运行之后,只训练了几轮就可以达到 98% 的准确率:

Train 0 0.27Test 0 0.2223Train 100 0.87Train 200 0.91Train 300 0.94Train 400 0.94Train 500 0.99Test 500 0.9595Train 600 0.95Train 700 0.97Train 800 0.98

可以看出来 LSTM 在做 MNIST 字符分类的任务上还是比较有效的。

芝麻HTTP:TensorFlow LSTM MNIST分类

原文地址:https://www.cnblogs.com/zhimaruanjian/p/8537525.html

知识推荐

我的编程学习网——分享web前端后端开发技术知识。 垃圾信息处理邮箱 tousu563@163.com 网站地图
icp备案号 闽ICP备2023006418号-8 不良信息举报平台 互联网安全管理备案 Copyright 2023 www.wodecom.cn All Rights Reserved