博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
神经网络优化(三) - 全连接网络基础
阅读量:6278 次
发布时间:2019-06-22

本文共 18020 字,大约阅读时间需要 60 分钟。

本章节的主要目标是 - MNIST 数据集输出手写数字识别准确率

1 MNIST数据集

MNIST数据集共有7万张图片。

提供 28*28 像素点的 0~9 手写数字图片和标签 6 万张用于训练、1 万张用于测试

 

每张图片的 784 个像素点(28*28=784)组成长度为 784 的一维数组作为输入特征;图片中纯黑色像素为0,纯白色像素值为1。

图片的标签以一维数组形式给出,每个元素表示对应分类出现的概率。

例如:

将一张数字手写图片变成长度为长度为784的一维数组 [0.0.0.0.0.231 0.235 0.459 ……0.219 0.0.0.0.] ;

其对应标签为[0.0.0.0.0.0.1.0.0.0],该标签中的索引号为 6 的元素为1,表示数字 6 出现的概率为 100%;

则该图片对应的识别结果为 6 。

1.1 mnist 数据集的引入

from tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets('./data/', one_hot=True)

注解

' ./data/ '  - 数据集的存放路径。

one_hot = True  - 数据集的存取方式;当为 True 时表示以读热码的方式存取。

当read_data_sets() 函数运行时,会检查路径文件内容是否有数据集了,如果没有则会自动下载数据集。

下载完毕后会将 mnist 数据集分为训练集train、验证集validation 和测试集 test 存放 。在终端显示如下内容:

Extracting ./data/train-images-idx3-ubyte.gzExtracting ./data/train-labels-idx1-ubyte.gzExtracting ./data/t10k-images-idx3-ubyte.gzExtracting ./data/t10k-labels-idx1-ubyte.gz

 

1.2 查看各样本的数量

# 查看各样本的子集数量:# 训练集train样本数print("train data size:", mnist.train.mun_examples)# train data size:55000)# 验证集validation样本数print("validation data size:", mnist.validation.mun_examples)# validation data size:5000# 测试集 test 样本数print("test data size:", mnist.test.mun_examples)# test data size:10000

 1.3 查看标签和数据

在mnist数据集中,要想查看训练集中的图片标签,则使用下面函数,以查看第 0 张图片标签为例

 

>>> mnist.train.labels[0]

输出结果

array([0.,0.,0.,0.,0.,0.,1.,0.,0.,0])

1.4 查看mnist数据集中图片像素值

函数train.images() 可以查看mnist数据集图片的像素值,以第 0 张图片为例。

>>> mnist.train.images[0]

输出结果

array([0. ,0. ,0. ,        0. ,0. ,0. ,        0. ,0. ,0. ,        …  …  …])

1.5 取一部分数据喂入神经网络训练

采用的函数为 mnist.train.next_batch() 函数

# 定义一次喂入数据量BATCH_SIZE = 200xs, ys = mnist.train.next_batch(BATCH_SIZE)print("xs shape:", xs.shape)# xs.shape(200,784)print("ys shape:", ys.shape)# ys.shape(200,10)

函数 mnist.train.next_batch()中的参数 BATCH_SIZE,表示随机从训练集中抽取 BATCH_SIZE 个样本输入神经网络,并将样本的像素值和标签分布赋给xs 和 ys 。

在上述代码中,BATCH_SIZE = 200 ,表示一次将 200 个样本的像素值和标签分别赋值给xs 和 ys,故 xs 的形状为(200,784),ys 的形状为(200,10)。

1.6 实现“Mnist 数据集手写数字识别”的常用函数

1)tf.get_collection(' ')函数

该函数表示从 collection() 集合中取出全部变量生成一个列表。

2)tf.add()

表示将参数列表中对应元素相加。

x = tf.constant([[1, 2], [1, 2]])y = tf.constant([[1, 1], [1, 2]])z = tf.add(x, y)print(z)# [[2,3],[2,4]]

3)tf.cast(x, dtype)

表示将参数 x 转换为指定数据类型dtype。

4)tf.equal()

表示对比两个矩阵或向量间的元素;若对应元素相等,则返回 True,若对应元素不相等,则返回 False 。

A = [[1, 3, 4, 5, 6]]B = [[1, 3, 4, 3, 2]]with tf.Session() as sess:    print(sess.run(tf.equal(A, B)))# [[True True True False False]]

5)tf.reduce_mean(x, axis) 

用于求取矩阵或张量指定维度的平均值,axis 不指定具体值,则所有元素中取平均值。

6)tf.argmax(x, axis)

求算在指定axis维度下,参数x中最大值的索引号。

如:在 tf.argmax([1,0,0],1)函数中,axis 为 1,参数 x 为[1,0,0],表示在参数 x的第一个维度取最大值对应的索引号,故返回 0。

7)os.path.join() 

把参数字符按照路径命名规则拼接

import osos.path.join('/hello/', 'good/boy/', 'doiido')# '/hello/good/boy/doiido'

8)str.split() 函数

 字符串.split() 函数是按照指定 “ 拆分符 ” 对字符串进行拆分,并返回拆分列表。

'./model/mnist_model-1001'.split('/')[-1].split('-')[-1]

该示例中进行了两次拆分。

第一次拆分符号为 ‘ / ’,并从列表中提取索引为 [ -1 ] 的元素,也即倒数第一个元素。

第二次拆分符号为 ‘ - ’,并从列表中提取索引为 [ -1 ] 的元素,同样为列表中倒数第一个元素。

最终得到1001值。

9)tf.Graph().as_default()

 将当前图设置为默认图,并返回一个上下文管理器。

该函数一般与 with 关键字搭配使用,应用于将已经定义好的神经网络在计算图中复现。

with tf.Graph().as_default() as g

代码表示在 Graph()内定义的节点加入到计算图 g 中。

2 模型的保存与加载

2.1 模型的保存

在网络神经计算过程中,一般会间隔一定轮数保存一次神经网络模型,同时产生三个文件,

  • .meta文件 - 保存当前图结构
  • .index文件 - 保存当前参数名
  • .data文件 - 保存当前参数

在 Tensorflow 中如下所示:

# 实例化saver对象saver = tf.train.Saver()# 在with结构的for循环中,一定轮数时保存模型到当前会话。with tf.Session() as sess:    for i in range(STEPS):        # 拼接成/MODEL_SAVE_PATH/MODEL_NAME-global_step        if i % 轮数 == 0:            saver.save(sess, os.path.join(MODEL_SAVE_PATH,MODEL_NAME), global_step=global_step)

其中,tf.train.Saver()用来实例化saver对象。

神经网络每循环规定的轮数,将神经网络模型中所有参数等信息保存到指定的路径中,并在存放网络模型的文件夹名称中注明保存模型时的训练轮数

2.2 模型的加载

在测试神经网络效果时,需要将训练好的神经网络模型进行加载,在 Tensorflow 中这样表示:

with tf.Session() as sess:    ckpt = tf.train.get_checkpoint_state("存储路径")    if  ckpt and ckpt.model_checkpoint_path:        saver.restore(sess, ckpt.model_checkpoint_path)

在 with 结构中进行加载已保存的神经网络模型时,若 目标模型 存在于指定路径时,则将该神经网络模型加载到当前会话中。

2.3 滑动平均值的加载

在保存模型时,若模型中采用滑动平均时,则参数中的滑动平均值也会保存在相应文件中。

通过实例化 saver 对象,实现参数滑动平均值的加载,在 Tensorflow 中表示如下:

ema = tf.train.ExponentialMovingAverage("滑动平均基数")ema_restore = ema.variables_to_restore()saver = tf.train.Saver(ema_restore)

2.4 模型的准确率评估方法

在神经网络评估时,一般通过计算在一组数据上的识别准确率,来评估神经网络的效果。

在 Tensorflow 中的表达为:

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

注:

y 表示在一组数据(即 batch_size 个数据)上神经网络模型的预测结果,其形状为 [ batch_size, 10 ] 每一行表示一张图的识别结果。

通过 tf.argmax(y, 1) 函数取出每张图片对应向量中最大值元素对应的索引值,组成长度为输入数据 batch_size 个的一维数组。1 - 仅在第一个维度上进行。

通过 tf.equal() 函数判断预测结果张量和实际张量的每个维度是否相等,若相等则返回 True,不相等则返回 False。

通过 tf.cast() 函数将得到的布尔型数值转化为实数型。

通过 tf.reduce_mean() 函数求平均值。该平均值就是模型在本组数据上的准确率。

3 模块化搭建神经网络八股

3.1 前向传播forward.py

def  forward(x, regularizer):     w=     b=    y=      return y # 实现参数W值def get_weight(shape, regularizer): # 实现参数b值def get_bias(shape):

  前向传播过程中,需要定义神经网络中的参数 w 和 偏置项 b,定义由输入到输出的网络结构。

通过定义函数 get_weight() 实现对参数 w 的设置,包括有参数 w 的形状和是否正则化;

通过定义函数 get_bias() 实现对偏置项 b 的设置。

3.2 反向传播backward.py

def backward(mnist):    x = tf.placeholder(dtype,  shape)    y_ = tf.placeholder(dtype, shape)    # 定义前向传播函数    y = forward()     global_step =     loss =      # < 正则化、指数衰减学习率、滑动平均 > 如果用,则用相关代码插在该处    train_step = tf.train.GradientDescentOptimizer(learning_rate). minimize(loss, global_step=global_step)     # 实例化  saver    saver = tf.train.Saver() with tf.Session() as sess:        # 初始化所有模型参数        tf.initialize_all_variables().run()         # 训练模型        for i in range(STEPS):             sess.run(train_step, feed_dict={x:  , y_: })             if  i % 轮数  == 0:                 print()               saver.save(  )

反向传播过程中,

tf.placeholder(dtype,  shape)函数实现训练样本 x 和样本标签 y_占位,函数参数 dtype 表示数据的类型,shape 表示数据的形状;

y 表示定义的前向传播函数 forward;

loss 表示定义的损失函数,一般为预测值与样本标签的交叉熵(或均方误差)与正则化损失之和;

train_step 表示利用优化算法对模型参数进行优化,

常用优化算法 GradientDescentOptimizer 、AdamOptimizer、MomentumOptimizer 算法,在上述代码中使用的 GradientDescentOptimizer 优化算法。

接着实例化 saver 对象,其中利用 tf.initialize_all_variables().run()函数实例化所有参数模型,利用 sess.run( )函数实现模型的训练优化过程,并每间隔一定轮数保存一次模型。

3.3 正则化、指数衰减学习率、滑动平均方法的设置

在 3.2 章节中留有正则化、指数衰减学习率、滑动平均方法的设置插入位置,具体如何设置则在本章节中体现。

3.3.1 正则化 regularization

当在神经网络训练时增加正则化时,需要在前向传播 (forward.py 文件)和反向传播(backward.py 文件)中增加正则化函数。

首先 在前向传播 forward.py 中加入下段代码

if regularizer != None: tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))

则将正则化的 w 值存入 losses 中。

其次 在反向传播 backward.py 中加入下段代码

ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))cem = tf.reduce_mean(ce)loss = cem + tf.add_n(tf.get_collection('losses'))

代码 tf.nn.sparse_softmax_cross_entropy_with_logits() 表示 softmax() 函数与交叉熵一起使用

同时将 losses 值增加到 loss 中。

3.3.2 指数衰减学习率learning_rate

在训练模型时,使用指数衰减学习率可以使模型在训练的前期快速收敛接近较优解,又可以保证模型在训练后期不会有太大波动。

运用指数衰减学习率,需要在反向传播文件 backward.py 中加入下段代码

learning_rate = tf.train.exponential_decay(    LEARNING_RATE_BASE,    global_step,    LEARNING_RATE_STEP,    LEARNING_RATE_DECAY,    staircase=True    )

3.3.3 滑动平均 ema

在模型训练时引入滑动平均可以使模型在测试数据上表现的更加健壮。

运用滑动平均,需要在反向传播文件 backward.py 中加入下段代码。

ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)ema_op = ema.apply(tf.trainable_variables())with tf.control_dependencies([train_step, ema_op]):    train_op = tf.no_op(name='train')

3.4 测试

通过对测试数据的预测得到准确率,从而判断出训练出的神经网络模型的性能好坏。

当准确率低时,可能原因有模型需要改进,或者是训练数据量太少导致过拟合。

神经网络模型训练后,便可用于测试数据集,验证神经网络的性能;结构如下:

def test(mnist):    with tf.Graph( ).as_default() as g:    # 定义 x y_ y,给出x y_占位    x = tf.placeholder(dtype,shape)    y_ = tf.placeholder(dtype,shape)    # 向前传播预测结果y    y = mnist_forward.forward(x, None)    # 实例化可还原滑动平均值的saver    ema = tf.train.ExponentialMovingAverage(滑动衰减率)     ema_restore = ema.variables_to_restore()     saver = tf.train.Saver(ema_restore)    # 计算正确率    correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))     accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))    while True:        with tf.Session() as sess:            # 加载ckpt模型(加载训练好的模型参数)            ckpt = tf.train.get_checkpoint_state(存储路径)            # 如果已有ckpt模型则恢复            if ckpt and ckpt.model_checkpoint_path:                # 恢复会话                saver.restore(sess, ckpt.model_checkpoint_path)                # 恢复轮数                 global_step= ckpt.model_checkpoint_path.split(‘/’)[-1].split(‘-‘)[-1]                # 计算准确率  {x:测试数据,  y_:测试数据标签 })                accuracy_score = sess.run(accuracy, feed_dict = {x: mnist.test.images, y_mnist.test.labels})                # 打印提示                print(“After %s training step(s), test accuracy = %g”(global_step, accuracy_acore))            # 如果没有模型             else:                # 给出提示                print(“No checkpoint file found”)                return# 定义 main()函数def main():    # 加载测试数据集    mnist = input_data.read_data_sets("./data/", one_hot=True)     # 调用定义好的测试函数 test()    test(mnist) if name == ' main__':    main()

4 手写数字识别准确率输出

该项目共分为 3 个文件

  • mnist_forward.py  前向传播描述网络结构
  • mnist_backward.py 反向传播描述网络优化方法
  • mnist_test.py 测试 复现了计算途中的节点、计算模型在测试集中的准确率

4.1 mnist_forward.py

前向传播过程中,

  • 定义从输入到输出的神经网络架构。
  • 定义网络模型输入层个数、隐藏层节点数、输出层个数。
  • 定义网格参数 w、偏置项 b。 
import tensorflow as tf# 28*28 = 784个像素点,784像素点组成了一个一维数组INPUT_NODE = 784# 输出10个数,每个数对应的索引号出现的概率,实现了10分类OUTPUT_NODE = 10# 定义了隐藏层节点个数LAYER1_NODE = 500# def get_weight(shape, regularizer):    # 在训练神经网络时,随机生成参数 w    w = tf.Variable(tf.truncated_normal(shape,stddev=0.1))    # 若使用正则化,则将每个变量的正则化损失加到总损失集合 losses 中    if regularizer != None: tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))    return wdef get_bias(shape):      b = tf.Variable(tf.zeros(shape))      return b# 描述从输入到输出的数据流 def forward(x, regularizer):    # 第一层、第二层参数是直接输出的。    # 第一层参数    w1 = get_weight([INPUT_NODE, LAYER1_NODE], regularizer)    b1 = get_bias([LAYER1_NODE])    y1 = tf.nn.relu(tf.matmul(x, w1) + b1)    # 第二层参数    w2 = get_weight([LAYER1_NODE, OUTPUT_NODE], regularizer)    b2 = get_bias([OUTPUT_NODE])    y = tf.matmul(y1, w2) + b2    return y

 在上述代码中,规定网格输入节点为 (28*28)784 个,将每张图片的像素值转变为一维向量,

隐藏层节点 500 个,因此由输入层到隐藏层的参数 w1 形状为 [784,500];

输出节点 10 个,因此由隐藏层到输出层的参数 w2 形状为 [500,10];这里的10,代表数字0-9的十分类

参数满足截断正态分布,并使用正则化,将每个参数的正则化损失加到总损失中。


 

对下段文字保留理解

由输入层到隐藏层的偏置 b1 形状为长度为 500的一维数组,由隐藏层到输出层的偏置 b2 形状为长度为 10 的一维数组,初始化值为全 0。前向传播结构第一层为输入 x 与参数 w1 矩阵相乘加上偏置 b1,再经过 relu 函数,得到隐藏层输出 y1。前向传播结构第二层为隐藏层输出 y1 与参数 w2 矩阵相乘加上偏置 b2,得到输出 y。由于输出 y 要经过 softmax 函数,使其符合概率分布,故输出 y 不经过 relu 函数。


 

4.2 mnist_backward.py

反向传播过程是实现利用训练数据集对神经网络模型训练。通过降低损失函数值,实现网络模型参数的优化,从而得到准确率高、泛化能力强的神经网络模型。

import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_dataimport mnist_forwardimport os# 定义每轮喂入多少张图片BATCH_SIZE = 200# 最开始的学习率LEARNING_RATE_BASE = 0.1# 学习率的衰减率LEARNING_RATE_DECAY = 0.99# 正则化系数REGULARIZER = 0.0001# 训练多少轮STEPS = 50000# 滑动平均衰减率MOVING_AVERAGE_DECAY = 0.99# 模型保存路径MODEL_SAVE_PATH="./model/"# 模型保存的文件名MODEL_NAME="mnist_model"def backward(mnist):    # x y_ 用placeholder占位    x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])    y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])    # 调用前向传播的程序,计算y值    y = mnist_forward.forward(x, REGULARIZER)    # 为轮数计数器赋初值,并设定为不可训练    global_step = tf.Variable(0, trainable=False)    # 定义包含正则化的损失函数loss    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))    cem = tf.reduce_mean(ce)    loss = cem + tf.add_n(tf.get_collection('losses'))    # 定义指数衰减学习率    learning_rate = tf.train.exponential_decay(        LEARNING_RATE_BASE,        global_step,        mnist.train.num_examples / BATCH_SIZE,        LEARNING_RATE_DECAY,        staircase=True)    # 定义训练过程    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)    # 定义滑动平均    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)    ema_op = ema.apply(tf.trainable_variables())    with tf.control_dependencies([train_step, ema_op]):        train_op = tf.no_op(name='train')    # 实例化saver    saver = tf.train.Saver()    # 在会话结构中初始化所有变量    with tf.Session() as sess:        init_op = tf.global_variables_initializer()        sess.run(init_op)        for i in range(STEPS):            xs, ys = mnist.train.next_batch(BATCH_SIZE)            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})            if i % 1000 == 0:                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)def main():    mnist = input_data.read_data_sets("./data/", one_hot=True)    backward(mnist)if __name__ == '__main__':    main()

在上述代码中,

  • 01) 定义每轮喂入神经网络的图片数量、初始学习率、学习率衰减率、正则化系数、训练轮数、模型保存路径以及模型保存名称等相关信息。
  • 02) 在反向传播函数中,
  • 03) 读入mnist,
  • 04) 训练数据 x 和标签 y_ 进行placeholder 占位,
  • 05) 调用 mnist_forward.py 文件中 forward() 函数,并设置正则化,
  • 06) 计算训练数据集上的预测结果 y,
  • 07) 给当前计算轮数计数器赋值,并设定为不可训练类型
  • 08) 调用包含所有参数正则化损失的损失函数 loss
  • 09) 设定指数衰减学习率 learning_rate
  • 10) 使用梯度衰减算法对模型优化,降低损失函数,并定义参数的滑动平均
  • 11) with 结构中,实现所有参数初始化,每次喂入 batch_size 组(本代码为 200 组)训练数据和对应标签,循环迭代 steps 轮,并每隔 1000 轮打印出一次损失函数值信息,并将当前会话加保存到指定路径。
  • 12) 通过主函数 main() ,加载指定路径下的训练数据集,并调用规定的 backward() 函数训练模型。

4.3 mnist_test.py

当训练完模型后,给神经网络模型输入测试集验证网络的准确性和泛化性。注意,所用的测试集和训练集是相互独立的。

#coding:utf-8import timeimport tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_dataimport mnist_forwardimport mnist_backward# 规定程序循环间隔是 5 秒TEST_INTERVAL_SECS = 5def test(mnist):    # with 语句复现计算图    with tf.Graph().as_default() as g:        # placeholder 为x y_占位        x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])        y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])        # 前向传播过程计算y的值        y = mnist_forward.forward(x, None)        # 实例化带滑动平均的saver对象,这样所有参数在会话中被加载时将被赋值为各自的滑动平均值        ema = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)        ema_restore = ema.variables_to_restore()        saver = tf.train.Saver(ema_restore)                # 计算准确率        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))        while True:            with tf.Session() as sess:                ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)                # 判单是否有模型                if ckpt and ckpt.model_checkpoint_path:                    saver.restore(sess, ckpt.model_checkpoint_path)                    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]                    accuracy_score = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})                    print("After %s training step(s), test accuracy = %g" % (global_step, accuracy_score))                else:                    print('No checkpoint file found')                    return            time.sleep(TEST_INTERVAL_SECS)def main():    mnist = input_data.read_data_sets("./data/", one_hot=True)    test(mnist)if __name__ == '__main__':    main()

在 with 结构中,加载指定路径下的 ckpt,若模型存在,则加载出模型到当前对话,在测试数据集上进行准确率验证,并打印出当前轮数下的准确率,若模型不存在,则打印出模型不存在的提示,从而 test()函数完成。 通过主函数 main(),加载指定路径下的测试数据集,并调用规定的 test 函数,进行模型在测试集上的准确率验证。

4.4 运行

略。

5 修改代码,实现断点续训

在 mnist_backward.py 会话 Session 中添加下段代码,把 ckpt 恢复到当前会话中,下文三行代码可以实现将 ckpt 中的 w 、b 值恢复到会话中, 实现断点续训;这样就不用担心因突然意外而造成的训练失败。

ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)if ckpt and ckpt.model_checkpoint_path:    saver.restore(sess, ckpt.model_checkpoint_path)

则为

import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_dataimport mnist_forwardimport os# 定义每轮喂入多少张图片BATCH_SIZE = 200# 最开始的学习率LEARNING_RATE_BASE = 0.1# 学习率的衰减率LEARNING_RATE_DECAY = 0.99# 正则化系数REGULARIZER = 0.0001# 训练多少轮STEPS = 50000# 滑动平均衰减率MOVING_AVERAGE_DECAY = 0.99# 模型保存路径MODEL_SAVE_PATH="./model/"# 模型保存的文件名MODEL_NAME="mnist_model"def backward(mnist):    # x y_ 用placeholder占位    x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])    y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])    # 调用前向传播的程序,计算y值    y = mnist_forward.forward(x, REGULARIZER)    # 为轮数计数器赋初值,并设定为不可训练    global_step = tf.Variable(0, trainable=False)    # 定义包含正则化的损失函数loss    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))    cem = tf.reduce_mean(ce)    loss = cem + tf.add_n(tf.get_collection('losses'))    # 定义指数衰减学习率    learning_rate = tf.train.exponential_decay(        LEARNING_RATE_BASE,        global_step,        mnist.train.num_examples / BATCH_SIZE,        LEARNING_RATE_DECAY,        staircase=True)    # 定义训练过程    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)    # 定义滑动平均    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)    ema_op = ema.apply(tf.trainable_variables())    with tf.control_dependencies([train_step, ema_op]):        train_op = tf.no_op(name='train')    # 实例化saver    saver = tf.train.Saver()    # 在会话结构中初始化所有变量    with tf.Session() as sess:        init_op = tf.global_variables_initializer()        sess.run(init_op)        # 添加断点续训        ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)        if ckpt and ckpt.model_checkpoint_path:            saver.restore(sess, ckpt.model_checkpoint_path)        for i in range(STEPS):            xs, ys = mnist.train.next_batch(BATCH_SIZE)            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})            if i % 1000 == 0:                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)def main():    mnist = input_data.read_data_sets("./data/", one_hot=True)    backward(mnist)if __name__ == '__main__':    main()

 

注解:

tf.train.get_checkpoint_state(checkpoint_dir,latest_filename=None)

该函数表示如果断点文件夹中包含有效断点状态文件,则返回该文件。

参数说明:

checkpoint_dir:表示存储断点文件的目录

latest_filename=None:断点文件的可选名称,默认为“checkpoint” 2)saver.restore(sess, ckpt.model_checkpoint_path)

该函数表示恢复当前会话,将 ckpt 中的值赋给 w 和 b。

参数说明:

sess:表示当前会话,之前保存的结果将被加载入这个会话

ckpt.model_checkpoint_path:表示模型存储的位置,不需要提供模型的名字,它会去查看 checkpoint 文件,看看最新的是谁,叫做什么。

 

转载地址:http://tgyva.baihongyu.com/

你可能感兴趣的文章
关于再次查看已做的多选题状态逻辑问题
查看>>
动态下拉菜单,非hover
查看>>
政府安全资讯精选 2017年第十六期 工信部发布关于规范互联网信息服务使用域名的通知;俄罗斯拟建立备用DNS;Google打击安卓应用在未经同意情况下收集个人信...
查看>>
简单易懂的谈谈 javascript 中的继承
查看>>
多线程基础知识
查看>>
iOS汇编基础(四)指针和macho文件
查看>>
Laravel 技巧锦集
查看>>
Android 使用 ViewPager+RecyclerView+SmartRefreshLayout 实现顶部图片下拉视差效果
查看>>
Flutter之基础Widget
查看>>
写给0-3岁产品经理的12封信(第08篇)——产品运营能力
查看>>
ArcGIS Engine 符号自动化配置工具实现
查看>>
小程序 · 跳转带参数写法,兼容url的出错
查看>>
flutter error
查看>>
Flask框架从入门到精通之模型数据库配置(十一)
查看>>
10年重新出发
查看>>
2019年-年终总结
查看>>
聊聊elasticsearch的RoutingService
查看>>
让人抓头的Java并发(一) 轻松认识多线程
查看>>
从源码剖析useState的执行过程
查看>>
地包天如何矫正?
查看>>