This module trains the parameters of the LeNet-5 network via backpropagation.
First, define the hyperparameters used during training:
#coding:utf-8
import os
import numpy as np
import tensorflow as tf
import mnist_lenet5_forward

BATCH_SIZE = 50                  # number of examples per training batch
LEARNING_RATE_BASE = 0.005       # base learning rate
LEARNING_RATE_DECAY = 0.99       # learning-rate decay rate
REGULARIZER = 0.0001             # weight of the regularization term
STEPS = 50000                    # number of training iterations
MOVING_AVERAGE_DECAY = 0.99      # moving-average decay rate
MODEL_SAVE_PATH = "./model/"     # path for saving the model
MODEL_NAME = "mnist_model"       # model file name
The backward function implements the training (backpropagation) process:
- create placeholders for the input x and the labels y_
- call the forward-propagation function
- compute the loss, including the regularization term
- apply an exponentially decaying learning rate
- maintain an exponential moving average of the trainable variables
- bind the two training ops train_step and ema_op into a single train_op
- instantiate a saver for saving and restoring variables, and create a session
def backward(mnist):
    # placeholders for the input images x and the true labels y_
    x = tf.placeholder(tf.float32, [
        BATCH_SIZE,
        mnist_lenet5_forward.IMAGE_SIZE,
        mnist_lenet5_forward.IMAGE_SIZE,
        mnist_lenet5_forward.NUM_CHANNELS])
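    # the shape is NHWC ([batch, height, width, channels]), the default
    # layout expected by tf.nn.conv2d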
    y_ = tf.placeholder(tf.float32, [None, mnist_lenet5_forward.OUTPUT_NODE])
    # forward propagation; True marks the training phase, and REGULARIZER
    # enables weight regularization
    y = mnist_lenet5_forward.forward(x, True, REGULARIZER)
    # declare a global step counter, initialized to 0 and excluded from training
    global_step = tf.Variable(0, trainable=False)
    # apply softmax to the network's final-layer output y, then take the
    # cross entropy between that distribution and the true labels
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
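    # the sparse variant expects integer class indices rather than one-hot
    # vectors, so tf.argmax(y_, 1) converts the labels first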
    # average the per-example cross entropies to get the base loss
    cem = tf.reduce_mean(ce)
    # add the regularization losses gathered in the 'losses' collection
    loss = cem + tf.add_n(tf.get_collection('losses'))
    # exponentially decay the learning rate
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples / BATCH_SIZE,
        LEARNING_RATE_DECAY,
        staircase=True)
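    # with staircase=True the rate decays in discrete steps, roughly once per
    # epoch: lr = LEARNING_RATE_BASE * LEARNING_RATE_DECAY ** (global_step // decay_steps);
    # e.g. with MNIST's 55,000 training images and BATCH_SIZE = 50,
    # decay_steps = 1100, so the rate drops from 0.005 to 0.005 * 0.99 = 0.00495
    # after step 1100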
    # build a gradient-descent optimizer with this learning rate; minimize
    # updates the trainable variables to reduce loss and increments global_step
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    # maintain a moving average of the weights; MOVING_AVERAGE_DECAY controls
    # how quickly the averages track the current values
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    ema_op = ema.apply(tf.trainable_variables())
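    # ema.apply creates a shadow copy of each trainable variable and updates it
    # every step as shadow = decay * shadow + (1 - decay) * variable; because
    # global_step is passed in, the effective decay is
    # min(MOVING_AVERAGE_DECAY, (1 + step) / (10 + step)), so the averaging
    # warms up quickly early in training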
    # bind the two training ops train_step and ema_op into the single op train_op
    with tf.control_dependencies([train_step, ema_op]):
        train_op = tf.no_op(name='train')
    # instantiate a saver for saving and restoring variables
    saver = tf.train.Saver()
    # create a session and manage it with a Python context manager
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        # locate the most recently saved model through the checkpoint file and
        # restore it, so training resumes from the latest checkpoint if one exists
        ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        for i in range(STEPS):
            # fetch one batch of training images and labels
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            # reshape the flat input xs into the shape the network expects
            reshaped_xs = np.reshape(xs, (
                BATCH_SIZE,
                mnist_lenet5_forward.IMAGE_SIZE,
                mnist_lenet5_forward.IMAGE_SIZE,
                mnist_lenet5_forward.NUM_CHANNELS))
            # feed in the training images and labels and run one training step
            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: reshaped_xs, y_: ys})
            if i % 100 == 0:
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
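The script would normally end with a small entry point that loads the dataset and calls backward. Here is a minimal sketch, assuming the TF 1.x MNIST loader from tensorflow.examples.tutorials.mnist and a local ./data/ directory (both are assumptions, not shown in the excerpt above):

from tensorflow.examples.tutorials.mnist import input_data

def main():
    # one_hot=True matches the [None, OUTPUT_NODE] shape of the y_ placeholder
    mnist = input_data.read_data_sets("./data/", one_hot=True)
    backward(mnist)

if __name__ == '__main__':
    main()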