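1. Modify the Object Detection API source to support BMP images

The official source only supports three-channel JPG and PNG images. The BMP images used here have a single channel, so the source has to be modified to train on 1-channel BMP data.

1.1. Change how the TFRecord is written

In create_pascal_tf_record.py, change how the image data is stored.

Original code: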

   with tf.gfile.GFile(full_path, 'rb') as fid:
     encoded_jpg = fid.read()
   encoded_jpg_io = io.BytesIO(encoded_jpg)
   image = PIL.Image.open(encoded_jpg_io)
   if image.format != 'JPEG':
     raise ValueError('Image format not JPEG')

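Comment that out and replace it with:

    # Store the raw pixel bytes rather than the encoded file contents.
    # The decoder in 1.2 reshapes them to [height, width, 1], so the
    # image must hold exactly one byte per pixel (8-bit grayscale).
    with PIL.Image.open(full_path) as fid:
      encoded_jpg = fid.tobytes()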
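
The replacement only works if the raw buffer holds exactly height * width bytes. As a minimal sketch, assuming Pillow is installed, the check below forces the image to 8-bit grayscale before taking the bytes. The helper name load_gray_bytes and the convert('L') call are illustrative additions, not part of the original script.

```python
import io

import PIL.Image


def load_gray_bytes(path_or_file):
    """Return (raw_bytes, width, height) for an image forced to 8-bit grayscale."""
    with PIL.Image.open(path_or_file) as img:
        gray = img.convert('L')  # 'L' = 8 bits per pixel, one channel
        width, height = gray.size  # PIL reports size as (width, height)
        return gray.tobytes(), width, height


# Build a tiny 4x3 grayscale BMP in memory to stand in for a real file.
buf = io.BytesIO()
PIL.Image.new('L', (4, 3), color=128).save(buf, format='BMP')
buf.seek(0)

raw, width, height = load_gray_bytes(buf)
print(len(raw), width, height)
```

If the byte count differs from width * height, for example because the BMP is actually stored as RGB, the reshape in the decoder will fail at training time.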

1.2. Change how the TFRecord is read

In data_decoders/tf_example_decoder.py, find the class TfExampleDecoder.

  1. Add the function _read_image:

class TfExampleDecoder(data_decoder.DataDecoder):
  """Tensorflow Example proto decoder."""

  def _read_image(self, keys_to_tensors):
    # Rebuild a [height, width, 1] uint8 image from the raw pixel bytes.
    image_encoded = keys_to_tensors['image/encoded']
    height = keys_to_tensors['image/height']
    width = keys_to_tensors['image/width']
    to_shape = tf.cast(tf.stack([height, width, 1]), tf.int32)
    image = tf.reshape(tf.decode_raw(image_encoded, tf.uint8), to_shape)
    return image

  2. In __init__, change the code below, which controls how the image data is read. Drop the official decoding method and read the image with the _read_image function added above:

Original code:

    self.items_to_handlers = {
         fields.InputDataFields.image: slim_example_decoder.Image(
             image_key='image/encoded', format_key='image/format', channels=3),

Change it to:

    self.items_to_handlers = {
        fields.InputDataFields.image: slim_example_decoder.ItemHandlerCallback(
            keys=['image/encoded', 'image/height', 'image/width'],
            func=self._read_image),

  3. In the same class, change the function def decode(self, tf_example_string_tensor):

Original code:

    tensor_dict[fields.InputDataFields.image].set_shape([None, None, 3])

Change it to:

    tensor_dict[fields.InputDataFields.image].set_shape([None, None, 1])
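
To see what _read_image does without building a TensorFlow graph, here is a rough NumPy counterpart. np.frombuffer stands in for tf.decode_raw and ndarray.reshape for tf.reshape. The function name read_image_np is made up for this illustration.

```python
import numpy as np


def read_image_np(image_encoded, height, width):
    """NumPy analogue of _read_image: raw bytes -> [height, width, 1] uint8."""
    flat = np.frombuffer(image_encoded, dtype=np.uint8)  # like tf.decode_raw
    return flat.reshape(height, width, 1)                # like tf.reshape


# A 3x4 test image whose pixel values are 0..11.
pixels = np.arange(12, dtype=np.uint8).reshape(3, 4)
decoded = read_image_np(pixels.tobytes(), height=3, width=4)
print(decoded.shape)
```

The reshape raises an error as soon as the stored height and width do not match the number of bytes, which is the same failure a record written with the wrong image mode would trigger in the real decoder.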

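1.3. Change the input dimensions of the exported model graph

In exporter.py, find the function _image_tensor_input_placeholder() and change the placeholder's input_shape to (None, None, None, 1). For tf_example input, you do not need to specify this dimension.

Original code: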

   input_shape = (None, None, None, 3)

Change it to:

def _image_tensor_input_placeholder(input_shape=None):
  """Returns input placeholder and a 4-D uint8 image tensor."""
  if input_shape is None:
    input_shape = (None, None, None, 1)  # was (None, None, None, 3)
  input_tensor = tf.placeholder(
      dtype=tf.uint8, shape=input_shape, name='image_tensor')
  return input_tensor, input_tensor
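
After this change, the exported graph expects a 4-D uint8 batch whose last dimension is 1. A small sketch of preparing such a batch from a 2-D grayscale array is shown below. The helper to_image_tensor_batch is hypothetical and only illustrates the shape the image_tensor placeholder would accept.

```python
import numpy as np


def to_image_tensor_batch(gray):
    """Turn an [H, W] grayscale array into a [1, H, W, 1] uint8 batch."""
    gray = np.asarray(gray, dtype=np.uint8)
    if gray.ndim != 2:
        raise ValueError('expected a 2-D grayscale array, got %d-D' % gray.ndim)
    # Add a leading batch axis and a trailing channel axis.
    return gray[np.newaxis, :, :, np.newaxis]


batch = to_image_tensor_batch(np.zeros((480, 640), dtype=np.uint8))
print(batch.shape)
```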

1.4. Modify the pretrained model's weights

Create a file edit_checkpoint.py with the following code:

from tensorflow.python import pywrap_tensorflow
import numpy as np
import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_string('input_path', 'faster_rcnn_resnet101_coco_2017_11_08/model.ckpt',
                    'path of pretrained_checkpoint')
flags.DEFINE_string('output_path', 'ckpts/model.ckpt', 'output checkpoint')
flags.DEFINE_string('feature_extractor', 'resnet_v1_101', 'name of the feature extractor scope in the checkpoint')
flags.DEFINE_integer('num_input_channels', 1, 'number of input channels. Each image, background, diff image requires 3 channels')
flags.DEFINE_string('edit_method', 'reduce', 'spread: divide the checkpoint convolution variable by the number of channels'
                                             ' divided by 3 and clone it to every set of 3 channels. random: initialize'
                                             ' extra channels feature map to random truncated_normal with stddev=0.01'
                                             '. clone: clone the value to new channels. reduce: keep only the first'
                                             ' input channel')
FLAGS = flags.FLAGS

if __name__ == '__main__':
  reader = pywrap_tensorflow.NewCheckpointReader(FLAGS.input_path)
  var_to_shape_map = reader.get_variable_to_shape_map()
  # var_to_edit_names = ['FirstStageFeatureExtractor/{}/conv1/weights'.format(FLAGS.feature_extractor),
  #                      'FirstStageFeatureExtractor/{}/conv1/weights/Momentum'.format(FLAGS.feature_extractor),]
  var_to_edit_names = ['FirstStageFeatureExtractor/{}/conv1/weights'.format(FLAGS.feature_extractor)]
  print('Loading checkpoint...')
  for key in sorted(var_to_shape_map):
    if key not in var_to_edit_names:
      var = tf.Variable(reader.get_tensor(key), name=key, dtype=tf.float32)
    else:
      print("Found variable: {}".format(key))
  vars_to_edit = []
  for name in var_to_edit_names:
    if reader.has_tensor(name):
      vars_to_edit.append(reader.get_tensor(name))
    else:
      raise Exception("{} not found in checkpoint. Check feature extractor name. Exiting.".format(name))
  new_vars = []
  sess = tf.Session()
  for name, var_to_edit in zip(var_to_edit_names, vars_to_edit):
    if FLAGS.edit_method in ['spread', 'clone']:
      checkpoint_num_input_channels = var_to_edit.shape[2]
      if FLAGS.num_input_channels % checkpoint_num_input_channels != 0:
        raise Exception('For spread edit method, num_input_channels must be divisible by num input channels of checkpoint!')
      num_clones = int(FLAGS.num_input_channels / checkpoint_num_input_channels)
      if FLAGS.edit_method == 'spread':
        spreaded_var = var_to_edit / num_clones
      else:
        spreaded_var = var_to_edit
      new_var = np.tile(spreaded_var, [1, 1, num_clones, 1])
      new_vars.append(tf.Variable(new_var, name=name, dtype=tf.float32))
    elif FLAGS.edit_method == 'random':
      random_shape = list(var_to_edit.shape)
      random_shape[2] = FLAGS.num_input_channels - 3
      random_var = tf.truncated_normal(shape=random_shape, stddev=0.01).eval(session=sess)
      new_var = np.concatenate([var_to_edit, random_var], axis=2)
      new_vars.append(tf.Variable(new_var, name=name, dtype=tf.float32))
    elif FLAGS.edit_method == 'reduce':
      print("name ",name,"  shape ",var_to_edit.shape)
      checkpoint_num_input_channels = var_to_edit.shape[2]
      print("checkpoint_num_input_channels.",checkpoint_num_input_channels)
      # print("name ",name,"  Value ",var_to_edit)
      new_var = var_to_edit[:, :, 0:1, :]
      print("new_var ",name,"  shape ",new_var.shape)
      new_vars.append(tf.Variable(new_var, name=name, dtype=tf.float32))

    else:
      raise Exception("Edit method must be spread, clone, random or reduce!")
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver()
  saver.save(sess, FLAGS.output_path)

# Only the .data-00000-of-00001 and .index files are needed. Good to go!
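
The effect of the edit methods on the first convolution kernel can be checked on a plain NumPy array, without loading a checkpoint. The sketch below mirrors the spread, clone and reduce branches of the script. The function edit_first_conv and the random 7x7x3x64 kernel are stand-ins for illustration, not values read from the real model.

```python
import numpy as np


def edit_first_conv(kernel, num_input_channels, method):
    """Adapt a [kh, kw, in_channels, out_channels] kernel to a new channel count."""
    in_channels = kernel.shape[2]
    if method in ('spread', 'clone'):
        if num_input_channels % in_channels != 0:
            raise ValueError('num_input_channels must be a multiple of %d' % in_channels)
        num_clones = num_input_channels // in_channels
        # 'spread' rescales so the summed response stays the same; 'clone' copies as-is.
        base = kernel / num_clones if method == 'spread' else kernel
        return np.tile(base, [1, 1, num_clones, 1])
    if method == 'reduce':
        # Keep only the first input channel, as the script does.
        return kernel[:, :, 0:1, :]
    raise ValueError('unknown edit method: %s' % method)


kernel = np.random.RandomState(0).randn(7, 7, 3, 64).astype(np.float32)
reduced = edit_first_conv(kernel, 1, 'reduce')
cloned = edit_first_conv(kernel, 6, 'clone')
spread = edit_first_conv(kernel, 6, 'spread')
print(reduced.shape, cloned.shape, spread.shape)
```

Note that reduce discards the weights of the other two input channels. Summing or averaging over the channel axis would be another way to get a 1-channel kernel, but that is not what the script above does.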

