1. slim.learning.train实例,训练VGG网络
    1. 1. 介绍
    2. 2. 应用流程:
    3. 3. 示例:训练vgg16模型

slim.learning.train实例,训练VGG网络

1. 介绍

TF-Slim在tensorflow/contrib/slim/python/slim/learning.py提供了一套简单但功能强大的工具。 这些功能包括一个训练函数,可以反复测量损失,计算梯度并将模型保存到磁盘,以及用于操纵梯度的几个便利函数。
调用 slim.learning.create_train_op 和 slim.learning.train 来实现

learning源代码:
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim

2. 应用流程:

g = tf.Graph()

# Create the model and specify the losses...
...

total_loss = slim.losses.get_total_loss()
optimizer = tf.train.GradientDescentOptimizer(learning_rate)

# create_train_op ensures that each time we ask for the loss, the update_ops
# are run and the gradients being computed are applied too.
train_op = slim.learning.create_train_op(total_loss, optimizer)
logdir = ... # Where checkpoints are stored.

slim.learning.train(
    train_op,
    logdir,
    number_of_steps=1000,
    save_summaries_secs=300,
    save_interval_secs=600):

在本例中,slim.learning.train与train_op一起提供,用于计算损失和操作梯度步骤。

  • logdir指定存储检查点和事件文件的目录。
  • number_of_steps我们可以限制采取任何数字的梯度步数。 在这种情况下,我们要求采取1000个步骤。
  • save_summaries_secs = 300表示我们将每隔5分钟计算摘要,
  • save_interval_secs = 600表示我们将每10分钟保存一次模型检查点。

3. 示例:训练vgg16模型

import tensorflow as tf
import tensorflow.contrib.slim.nets as nets

slim = tf.contrib.slim
vgg = nets.vgg

train_log_dir = ...
if not tf.gfile.Exists(train_log_dir):
  tf.gfile.MakeDirs(train_log_dir)

with tf.Graph().as_default():
  # Set up the data loading:
  images, labels = ...

  # Define the model:
  predictions = vgg.vgg_16(images, is_training=True)

  # Specify the loss function:
  slim.losses.softmax_cross_entropy(predictions, labels)

  total_loss = slim.losses.get_total_loss()
  tf.summary.scalar('losses/total_loss', total_loss)

  # Specify the optimization scheme:
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=.001)

  # create_train_op that ensures that when we evaluate it to get the loss,
  # the update_ops are done and the gradient updates are computed.
  train_tensor = slim.learning.create_train_op(total_loss, optimizer)

  # Actually runs training.
  slim.learning.train(train_tensor, train_log_dir)

https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim


技术交流学习,请加QQ微信:631531977
目录