TF LSTM: Training and evaluating an LSTM on the MNIST handwritten-digit dataset bundled with TF (occasionally 100% accuracy, cross-entropy loss)
Output
Step 0 accuracy 0.125
Step 20 accuracy 0.6484375
Step 40 accuracy 0.78125
Step 60 accuracy 0.9296875
Step 80 accuracy 0.8671875
Step 100 accuracy 0.90625
Step 120 accuracy 0.8671875
Step 140 accuracy 0.8671875
Step 160 accuracy 0.8671875
Step 180 accuracy 0.921875
Step 200 accuracy 0.890625
Step 220 accuracy 0.953125
Step 240 accuracy 0.921875
Step 260 accuracy 0.9296875
Step 280 accuracy 0.9140625
Step 300 accuracy 0.921875
Step 320 accuracy 0.9609375
Step 340 accuracy 0.953125
Step 360 accuracy 0.984375
Step 380 accuracy 0.921875
Step 400 accuracy 0.9453125
Step 420 accuracy 0.921875
Step 440 accuracy 0.9296875
Step 460 accuracy 0.96875
Step 480 accuracy 0.984375
Step 500 accuracy 0.96875
Step 520 accuracy 0.953125
Step 540 accuracy 0.96875
Step 560 accuracy 0.953125
Step 580 accuracy 0.9921875
Step 600 accuracy 0.984375
Step 620 accuracy 0.953125
Step 640 accuracy 0.953125
Step 660 accuracy 0.9921875
Step 680 accuracy 0.96875
Step 700 accuracy 0.9765625
Step 720 accuracy 0.96875
Step 740 accuracy 0.9921875
Step 760 accuracy 0.984375
Step 780 accuracy 0.953125
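Two features of this log follow directly from the code.

First, the loop condition `step * batch_size < training_iters` allows 782 updates, because 781 × 128 = 99,968 < 100,000. The last step printed by the `step % 20 == 0` check is therefore 780.

Second, each value is the accuracy on the 128-image training batch that was just used for the update, not on the test set. Every value is therefore a multiple of 1/128; for example, 0.9921875 = 127/128.

The plain-Python sketch below reproduces the loop arithmetic and the two quantities the graph defines: the mean softmax cross-entropy minimized as `cost`, and the argmax accuracy that gets printed. The function names and the 4-example batch are hypothetical. This illustrates the math only and is not TensorFlow's implementation.

```python
import math

def softmax_cross_entropy(logits, one_hot):
    """Mean softmax cross-entropy over a batch (the quantity `cost` averages)."""
    total = 0.0
    for row, label in zip(logits, one_hot):
        m = max(row)  # subtract the max before exp for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        # -sum(label * log_softmax(row)) for one example
        total -= sum(t * (v - log_sum_exp) for v, t in zip(row, label))
    return total / len(logits)

def argmax(row):
    """Index of the largest entry (first one on ties)."""
    return max(range(len(row)), key=row.__getitem__)

def accuracy(logits, one_hot):
    """Fraction of examples whose predicted class equals the labelled class."""
    hits = sum(argmax(p) == argmax(t) for p, t in zip(logits, one_hot))
    return hits / len(logits)

# Loop bounds of the original code: while step * batch_size < training_iters
training_iters, batch_size = 100000, 128
n_updates = math.ceil(training_iters / batch_size)  # 782 updates, steps 0..781
logged_steps = list(range(0, n_updates, 20))        # 0, 20, ..., 780

# Hypothetical 4-example batch with 3 classes
logits = [[2.0, 0.5, 0.1], [0.2, 1.5, 0.3], [0.1, 0.2, 3.0], [1.0, 0.9, 0.8]]
labels = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]]
print(accuracy(logits, labels))  # 0.75
print(softmax_cross_entropy(logits, labels))
print(127 / 128)  # 0.9921875, the best value in the log
```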
Design approach
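The model treats each 28×28 image as a sequence. The 28 rows are the time steps (`n_steps`), and the 28 pixels in a row are the features per step (`n_inputs`). Before the LSTM, every row is projected to `n_hidden_units` = 128 features by a single shared matrix, `weights['in']`.

The NumPy sketch below shows only that reshaping and projection. Random arrays stand in for the MNIST batch and the trained weights, and the variable names are illustrative.

```python
import numpy as np

batch_size, n_steps, n_inputs, n_hidden_units = 128, 28, 28, 128

rng = np.random.default_rng(0)
# Stand-in for mnist.train.next_batch: flat 784-pixel images
flat = rng.random((batch_size, n_steps * n_inputs), dtype=np.float32)
w_in = rng.standard_normal((n_inputs, n_hidden_units)).astype(np.float32)
b_in = np.full(n_hidden_units, 0.1, dtype=np.float32)

x = flat.reshape(batch_size, n_steps, n_inputs)   # (128, 28, 28): rows become time steps
rows = x.reshape(-1, n_inputs)                    # (3584, 28): every row of every image
proj = rows @ w_in + b_in                         # (3584, 128): same weights for all rows
x_in = proj.reshape(-1, n_steps, n_hidden_units)  # (128, 28, 128): what dynamic_rnn receives
print(x.shape, rows.shape, proj.shape, x_in.shape)
```

Because the reshape is row-major, `x_in[b, t]` is exactly the projection of row `t` of image `b`. This is why `RNN` can flatten to 2-D for one `tf.matmul` and then reshape back.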
Code
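import tensorflow as tf  # TensorFlow 1.x API (placeholders and sessions)
from tensorflow.examples.tutorials.mnist import input_data

# Download/load MNIST with one-hot labels
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# Hyperparameters
lr = 0.001               # Adam learning rate
training_iters = 100000  # total number of training samples to feed (not steps)
batch_size = 128
n_inputs = 28            # pixels per image row = features per time step
n_steps = 28             # image rows = time steps
n_hidden_units = 128     # LSTM hidden size
n_classes = 10           # digits 0-9

x = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.float32, [None, n_classes])

# Input projection (28 -> 128) and output layer (128 -> 10)
weights = {
    'in': tf.Variable(tf.random_normal([n_inputs, n_hidden_units])),
    'out': tf.Variable(tf.random_normal([n_hidden_units, n_classes])),
}
biases = {
    'in': tf.Variable(tf.constant(0.1, shape=[n_hidden_units, ])),
    'out': tf.Variable(tf.constant(0.1, shape=[n_classes, ])),
}

def RNN(X, weights, biases):
    # (batch, 28, 28) -> (batch*28, 28) so every row goes through the same projection
    X = tf.reshape(X, [-1, n_inputs])
    X_in = tf.matmul(X, weights['in']) + biases['in']
    # Back to (batch, 28, 128) for the LSTM
    X_in = tf.reshape(X_in, [-1, n_steps, n_hidden_units])

    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units, forget_bias=1.0, state_is_tuple=True)
    # Zero initial state; its batch dimension is fixed to batch_size
    init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
    outputs, states = tf.nn.dynamic_rnn(lstm_cell, X_in, initial_state=init_state, time_major=False)

    # (batch, steps, units) -> list of per-step (batch, units) tensors
    # (tf.unpack was renamed tf.unstack in TF 1.0)
    outputs = tf.unstack(tf.transpose(outputs, [1, 0, 2]))
    # Classify from the output of the last time step
    results = tf.matmul(outputs[-1], weights['out']) + biases['out']
    return results

pred = RNN(x, weights, biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
train_op = tf.train.AdamOptimizer(lr).minimize(cost)

correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))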
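init = tf.global_variables_initializer()  # initializer for all tf.Variable objects

with tf.Session() as sess:
    sess.run(init)
    step = 0
    while step * batch_size < training_iters:
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        # Flat 784-pixel vectors -> (batch, 28 time steps, 28 features)
        batch_xs = batch_xs.reshape([batch_size, n_steps, n_inputs])
        sess.run([train_op], feed_dict={x: batch_xs, y: batch_ys})
        if step % 20 == 0:
            # Accuracy on the current training batch (not the test set)
            print('Step', step, 'accuracy', sess.run(accuracy, feed_dict={x: batch_xs, y: batch_ys}))
        step += 1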
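`BasicLSTMCell` with `forget_bias=1.0` adds 1.0 to the forget-gate pre-activation. `dynamic_rnn` then applies the cell once per image row. In `RNN`, `outputs[-1]` is the hidden state after the 28th row, and the output layer maps it to 10 logits.

The NumPy sketch below runs that recurrence by hand under a few assumptions:

- The weights are random.
- A single kernel acts on the concatenated [input, previous h] vector.
- The gates are split in the i, j, f, o order used by TF 1.x `BasicLSTMCell`.

It illustrates the computation and is not a drop-in replacement for the TF graph. The helper names `lstm_step` and `rnn_last_output` are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h, c, kernel, bias, forget_bias=1.0):
    """One LSTM step: x_t is (batch, in); h and c are (batch, units)."""
    gates = np.concatenate([x_t, h], axis=1) @ kernel + bias
    i, j, f, o = np.split(gates, 4, axis=1)  # input, candidate, forget, output
    c_new = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    h_new = np.tanh(c_new) * sigmoid(o)
    return h_new, c_new

def rnn_last_output(x_in, kernel, bias):
    """Unroll over time from a zero state; x_in is batch-major like time_major=False."""
    batch, n_steps, _ = x_in.shape
    units = kernel.shape[1] // 4
    h = np.zeros((batch, units))
    c = np.zeros((batch, units))
    for t in range(n_steps):
        h, c = lstm_step(x_in[:, t, :], h, c, kernel, bias)
    return h  # the last time step's output, i.e. what outputs[-1] holds

rng = np.random.default_rng(1)
batch, n_steps, n_hidden, n_classes = 4, 28, 128, 10  # small batch for the demo
# The cell input already has n_hidden features because of the 'in' projection
x_in = rng.standard_normal((batch, n_steps, n_hidden))
kernel = rng.standard_normal((2 * n_hidden, 4 * n_hidden)) * 0.05
bias = np.zeros(4 * n_hidden)
w_out = rng.standard_normal((n_hidden, n_classes))
b_out = np.full(n_classes, 0.1)

last = rnn_last_output(x_in, kernel, bias)
logits = last @ w_out + b_out  # corresponds to `results` in RNN()
print(last.shape, logits.shape)  # (4, 128) (4, 10)
```

The `+ forget_bias` term starts the forget gate near sigmoid(1) ≈ 0.73 rather than 0.5. Early in training, the cell therefore tends to keep more of its previous state across the 28 rows.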