@author: huangyongye
@creat_date: 2017-03-09
前言: 根據(jù)我本人學(xué)習(xí) TensorFlow 實(shí)現(xiàn) LSTM 的經(jīng)歷,發(fā)現(xiàn)網(wǎng)上雖然也有不少教程,其中很多都是根據(jù)官方給出的例子,用多層 LSTM 來實(shí)現(xiàn) PTBModel 語言模型,比如:
tensorflow筆記:多層LSTM代碼分析
但是感覺這些例子還是太復(fù)雜了,所以這里寫了個(gè)比較簡單的版本,雖然不優(yōu)雅,但是還是比較容易理解。
如果你想了解 LSTM 的原理的話(前提是你已經(jīng)理解了普通 RNN 的原理),可以參考我前面翻譯的博客:
(譯)理解 LSTM 網(wǎng)絡(luò) (Understanding LSTM Networks by colah)
如果你想了解 RNN 原理的話,可以參考 AK 的博客:
The Unreasonable Effectiveness of Recurrent Neural Networks
本例不講原理。通過本例,你可以了解到單層 LSTM 的實(shí)現(xiàn),多層 LSTM 的實(shí)現(xiàn)。輸入輸出數(shù)據(jù)的格式。 RNN 的 dropout layer 的實(shí)現(xiàn)。
# -*- coding:utf-8 -*-import tensorflow as tfimport numpy as npfrom tensorflow.contrib import rnnfrom tensorflow.examples.tutorials.mnist import input_data# 設(shè)置 GPU 按需增長config = tf.ConfigProto()config.gpu_options.allow_growth = Truesess = tf.Session(config=config)# 首先導(dǎo)入數(shù)據(jù),看一下數(shù)據(jù)的形式mnist = input_data.read_data_sets('MNIST_data', one_hot=True)print mnist.train.images.shape
Extracting MNIST_data/train-images-idx3-ubyte.gzExtracting MNIST_data/train-labels-idx1-ubyte.gzExtracting MNIST_data/t10k-images-idx3-ubyte.gzExtracting MNIST_data/t10k-labels-idx1-ubyte.gz(55000, 784)
1. 首先設(shè)置好模型用到的各個(gè)超參數(shù)
lr = 1e-3# 在訓(xùn)練和測(cè)試的時(shí)候,我們想用不同的 batch_size.所以采用占位符的方式batch_size = tf.placeholder(tf.int32) # 注意類型必須為 tf.int32# batch_size = 128# 每個(gè)時(shí)刻的輸入特征是28維的,就是每個(gè)時(shí)刻輸入一行,一行有 28 個(gè)像素input_size = 28# 時(shí)序持續(xù)長度為28,即每做一次預(yù)測(cè),需要先輸入28行timestep_size = 28# 每個(gè)隱含層的節(jié)點(diǎn)數(shù)hidden_size = 256# LSTM layer 的層數(shù)layer_num = 2# 最后輸出分類類別數(shù)量,如果是回歸預(yù)測(cè)的話應(yīng)該是 1class_num = 10_X = tf.placeholder(tf.float32, [None, 784])y = tf.placeholder(tf.float32, [None, class_num])keep_prob = tf.placeholder(tf.float32)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
2. 開始搭建 LSTM 模型,其實(shí)普通 RNNs 模型也一樣
# 把784個(gè)點(diǎn)的字符信息還原成 28 * 28 的圖片# 下面幾個(gè)步驟是實(shí)現(xiàn) RNN / LSTM 的關(guān)鍵##################################################################### **步驟1:RNN 的輸入shape = (batch_size, timestep_size, input_size) X = tf.reshape(_X, [-1, 28, 28])# **步驟2:定義一層 LSTM_cell,只需要說明 hidden_size, 它會(huì)自動(dòng)匹配輸入的 X 的維度lstm_cell = rnn.BasicLSTMCell(num_units=hidden_size, forget_bias=1.0, state_is_tuple=True)# **步驟3:添加 dropout layer, 一般只設(shè)置 output_keep_problstm_cell = rnn.DropoutWrapper(cell=lstm_cell, input_keep_prob=1.0, output_keep_prob=keep_prob)# **步驟4:調(diào)用 MultiRNNCell 來實(shí)現(xiàn)多層 LSTMmlstm_cell = rnn.MultiRNNCell([lstm_cell] * layer_num, state_is_tuple=True)# **步驟5:用全零來初始化stateinit_state = mlstm_cell.zero_state(batch_size, dtype=tf.float32)# **步驟6:方法一,調(diào)用 dynamic_rnn() 來讓我們構(gòu)建好的網(wǎng)絡(luò)運(yùn)行起來,重點(diǎn)看 state = layer_num * [c_state, h_state],我們?nèi)?state[-1][1] 作為最后輸出# outputs, state = tf.nn.dynamic_rnn(mlstm_cell, inputs=X, initial_state=init_state, time_major=False)# h_state = outputs[-1] # 或者 h_state = state[-1][1]# *************** 為了更好的理解 LSTM 工作原理,我們把上面 步驟6 中的函數(shù)自己來實(shí)現(xiàn) ***************# 通過查看文檔你會(huì)發(fā)現(xiàn), RNNCell 都提供了一個(gè) __call__()函數(shù)(見最后附),我們可以用它來展開實(shí)現(xiàn)LSTM按時(shí)間步迭代。# **步驟6:方法二,按時(shí)間步展開計(jì)算outputs = list()state = init_stateh_state_list = list() # 這句非必要,只是為了后面可視化加上來而已with tf.variable_scope('RNN'): for timestep in range(timestep_size): if timestep > 0: tf.get_variable_scope().reuse_variables() # 這里的state保存了每一層 LSTM 的狀態(tài) (cell_output, state) = mlstm_cell(X[:, timestep, :], state) outputs.append(cell_output)h_state = state[-1][1]
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
3. 設(shè)置 loss function 和 優(yōu)化器,展開訓(xùn)練并完成測(cè)試
# 上面 LSTM 部分的輸出會(huì)是一個(gè) [hidden_size] 的tensor,我們要分類的話,還需要接一個(gè) softmax 層# 首先定義 softmax 的連接權(quán)重矩陣和偏置# out_W = tf.placeholder(tf.float32, [hidden_size, class_num], name='out_Weights')# out_bias = tf.placeholder(tf.float32, [class_num], name='out_bias')# 開始訓(xùn)練和測(cè)試W = tf.Variable(tf.truncated_normal([hidden_size, class_num], stddev=0.1), dtype=tf.float32)bias = tf.Variable(tf.constant(0.1,shape=[class_num]), dtype=tf.float32)y_pre = tf.nn.softmax(tf.matmul(h_state, W) + bias)# 損失和評(píng)估函數(shù)cross_entropy = -tf.reduce_mean(y * tf.log(y_pre))train_op = tf.train.AdamOptimizer(lr).minimize(cross_entropy)correct_prediction = tf.equal(tf.argmax(y_pre,1), tf.argmax(y,1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))sess.run(tf.global_variables_initializer())for i in range(2000): _batch_size = 128 batch = mnist.train.next_batch(_batch_size) if (i+1)%200 == 0: train_accuracy = sess.run(accuracy, feed_dict={ _X:batch[0], y: batch[1], keep_prob: 1.0, batch_size: _batch_size}) # 已經(jīng)迭代完成的 epoch 數(shù): mnist.train.epochs_completed print "Iter%d, step %d, training accuracy %g" % ( mnist.train.epochs_completed, (i+1), train_accuracy) sess.run(train_op, feed_dict={_X: batch[0], y: batch[1], keep_prob: 0.5, batch_size: _batch_size})# 計(jì)算測(cè)試數(shù)據(jù)的準(zhǔn)確率print "test accuracy %g"% sess.run(accuracy, feed_dict={ _X: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0, batch_size:mnist.test.images.shape[0]})
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
Iter0, step 200, training accuracy 0.851562Iter0, step 400, training accuracy 0.960938Iter1, step 600, training accuracy 0.984375Iter1, step 800, training accuracy 0.960938Iter2, step 1000, training accuracy 0.984375Iter2, step 1200, training accuracy 0.9375Iter3, step 1400, training accuracy 0.96875Iter3, step 1600, training accuracy 0.984375Iter4, step 1800, training accuracy 0.992188Iter4, step 2000, training accuracy 0.984375test accuracy 0.9858
我們一共只迭代不到5個(gè)epoch,在測(cè)試集上就已經(jīng)達(dá)到了0.9825的準(zhǔn)確率,可以看出來 LSTM 在做這個(gè)字符分類的任務(wù)上還是比較有效的,而且我們最后一次性對(duì) 10000 張測(cè)試圖片進(jìn)行預(yù)測(cè),才占了 725 MiB 的顯存。而我們?cè)谥暗膬蓪?CNNs 網(wǎng)絡(luò)中,預(yù)測(cè) 10000 張圖片一共用了 8721 MiB 的顯存,差了整整 12 倍呀??! 這主要是因?yàn)?RNN/LSTM 網(wǎng)絡(luò)中,每個(gè)時(shí)間步所用的權(quán)值矩陣都是共享的,可以通過前面介紹的 LSTM 的網(wǎng)絡(luò)結(jié)構(gòu)分析一下,整個(gè)網(wǎng)絡(luò)的參數(shù)非常少。
4. 可視化看看 LSTM 的是怎么做分類的
畢竟 LSTM 更多的是用來做時(shí)序相關(guān)的問題,要么是文本,要么是序列預(yù)測(cè)之類的,所以很難像 CNNs 一樣非常直觀地看到每一層中特征的變化。在這里,我想通過可視化的方式來幫助大家理解 LSTM 是怎么樣一步一步地把圖片正確的給分類。
import matplotlib.pyplot as plt
看下面我找了一個(gè)字符 3
print mnist.train.labels[4]
[ 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
我們先來看看這個(gè)字符樣子,上半部分還挺像 2 來的
X3 = mnist.train.images[4]img3 = X3.reshape([28, 28])plt.imshow(img3, cmap='gray')plt.show()
我們看看在分類的時(shí)候,一行一行地輸入,分為各個(gè)類別的概率會(huì)是什么樣子的。
X3.shape = [-1, 784]y_batch = mnist.train.labels[0]y_batch.shape = [-1, class_num]X3_outputs = np.array(sess.run(outputs, feed_dict={ _X: X3, y: y_batch, keep_prob: 1.0, batch_size: 1}))print X3_outputs.shapeX3_outputs.shape = [28, hidden_size]print X3_outputs.shape
(28, 1, 256)(28, 256)
h_W = sess.run(W, feed_dict={ _X:X3, y: y_batch, keep_prob: 1.0, batch_size: 1})h_bias = sess.run(bias, feed_dict={ _X:X3, y: y_batch, keep_prob: 1.0, batch_size: 1})h_bias.shape = [-1, 10]bar_index = range(class_num)for i in xrange(X3_outputs.shape[0]): plt.subplot(7, 4, i+1) X3_h_shate = X3_outputs[i, :].reshape([-1, hidden_size]) pro = sess.run(tf.nn.softmax(tf.matmul(X3_h_shate, h_W) + h_bias)) plt.bar(bar_index, pro[0], width=0.2 , align='center') plt.axis('off')plt.show()
在上面的圖中,為了更清楚地看到線條的變化,我把坐標(biāo)都去了,每一行顯示了 4 個(gè)圖,共有 7 行,表示了一行一行讀取過程中,模型對(duì)字符的識(shí)別??梢钥吹剑谥豢吹角懊娴膸仔邢袼貢r(shí),模型根本認(rèn)不出來是什么字符,隨著看到的像素越來越多,最后就基本確定了它是字符 3.
好了,本次就到這里。有機(jī)會(huì)再寫個(gè)優(yōu)雅一點(diǎn)的例子,哈哈。其實(shí)學(xué)這個(gè) LSTM 還是比較困難的,當(dāng)時(shí)寫 多層 CNNs 也就半天到一天的時(shí)間基本上就沒啥問題了,但是這個(gè)花了我大概整整三四天,而且是在我對(duì)原理已經(jīng)很了解(我自己覺得而已。。。)的情況下,所以學(xué)會(huì)了感覺還是有點(diǎn)小高興的~
17-04-19補(bǔ)充幾個(gè)資料:
- recurrent_network.py 一個(gè)簡單的 tensorflow LSTM 例子。
- Tensorflow下構(gòu)建LSTM模型進(jìn)行序列化標(biāo)注 介紹非常好的一個(gè) NLP 開源項(xiàng)目。(例子中有些函數(shù)可能在新版的 tensorflow 中已經(jīng)更新了,但并不影響理解)
5. 附:BASICLSTM.__call__()
'''code: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py''' def __call__(self, inputs, state, scope=None): """Long short-term memory cell (LSTM).""" with vs.variable_scope(scope or "basic_lstm_cell"): # Parameters of gates are concatenated into one multiply for efficiency. if self._state_is_tuple: c, h = state else: c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1) concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope) # ** 下面四個(gè) tensor,分別是四個(gè) gate 對(duì)應(yīng)的權(quán)重矩陣 # i = input_gate, j = new_input, f = forget_gate, o = output_gate i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) # ** 更新 cell 的狀態(tài): # ** c * sigmoid(f + self._forget_bias) 是保留上一個(gè) timestep 的部分舊信息 # ** sigmoid(i) * self._activation(j) 是有當(dāng)前 timestep 帶來的新信息 new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j)) # ** 新的輸出 new_h = self._activation(new_c) * sigmoid(o) if self._state_is_tuple: new_state = LSTMStateTuple(new_c, new_h) else: new_state = array_ops.concat([new_c, new_h], 1) # ** 在(一般都是) state_is_tuple=True 情況下, new_h=new_state[1] # ** 在上面博文中,就有 cell_output = state[1] return new_h, new_state
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32