当前位置: 首页>>代码示例>>Python>>正文


Python slim.assign_from_checkpoint_fn方法代码示例

本文整理汇总了 Python 中 tensorflow.contrib.slim.assign_from_checkpoint_fn 方法的典型用法代码示例。如果您想了解 slim.assign_from_checkpoint_fn 方法的具体用法、调用方式或实际使用示例,下面精选的代码示例或许能为您提供帮助。您也可以进一步了解该方法所在模块 tensorflow.contrib.slim 的其他用法示例。


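下文共展示了 slim.assign_from_checkpoint_fn 方法的 15 个代码示例,默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞,您的评价将有助于系统推荐更好的 Python 代码示例。注意:tensorflow.contrib(包括 slim)仅存在于 TensorFlow 1.x,在 TensorFlow 2.x 中已被移除,因此这些示例需要在 TensorFlow 1.x 环境下运行。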

示例1: load_ckpt

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import assign_from_checkpoint_fn [as 别名]
def load_ckpt(ckpt_name, var_scope_name, scope, constructor, input_tensor, label_offset, load_weights, **kwargs):
    """Build a slim classification model and optionally restore its weights.

    Arguments
        ckpt_name       path of the checkpoint to restore from
        var_scope_name  variable scope the model is built under; the model
                        variables of this scope are the ones restored
        scope           arg_scope applied while the model is constructed
        constructor     model-building function, called with input_tensor,
                        num_classes, scope and **kwargs
        input_tensor    input image tensor
        label_offset    added to 1000 to form num_classes (0 -> 1000 classes,
                        1 -> 1001 classes, i.e. with an extra class 0)
        load_weights    when true, restore the weights from ckpt_name
        kwargs          forwarded to constructor (e.g. is_training,
                        create_aux_logits)

    Returns
        (logits, endpoints) exactly as returned by constructor.
    """
    num_classes = 1000 + label_offset
    with slim.arg_scope(scope):
        logits, endpoints = constructor(
            input_tensor,
            num_classes=num_classes,
            scope=var_scope_name,
            **kwargs)

    if load_weights:
        model_vars = slim.get_model_variables(var_scope_name)
        restore = slim.assign_from_checkpoint_fn(ckpt_name, model_vars)
        # Restore into the session returned by K.get_session()
        # (presumably the Keras backend session -- confirm K's import).
        restore(K.get_session())
    return logits, endpoints
开发者ID:sangxia,项目名称:nips-2017-adversarial,代码行数:25,代码来源:model_wrappers.py

示例2: _get_init_fn

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import assign_from_checkpoint_fn [as 别名]
def _get_init_fn(vgg_checkpoint_path="vgg_19.ckpt", exclusions=('vgg_19/fc',)):
    """Return a function that warm-starts the model from a VGG-19 checkpoint.

    Args:
      vgg_checkpoint_path: checkpoint prefix, or a directory whose latest
        checkpoint is used. Defaults to "vgg_19.ckpt".
      exclusions: variable-name prefixes that are NOT restored. Defaults to
        ('vgg_19/fc',), so every 'vgg_19/fc*' variable is skipped.

    Returns:
      The callable built by slim.assign_from_checkpoint_fn with
      ignore_missing_vars=True; call it with a session to assign the
      checkpoint values to every model variable not matching `exclusions`.

    Raises:
      ValueError: if `vgg_checkpoint_path` is a directory with no checkpoint.
    """
    if tf.gfile.IsDirectory(vgg_checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(vgg_checkpoint_path)
        if checkpoint_path is None:
            # latest_checkpoint returns None instead of raising; fail here with
            # a clear message rather than deep inside the restore call.
            raise ValueError('No checkpoint found in %s' % vgg_checkpoint_path)
    else:
        checkpoint_path = vgg_checkpoint_path

    variables_to_restore = []
    for var in slim.get_model_variables():
        tf.logging.info('model_var: %s', var)
        # First exclusion prefix that matches this variable, if any.
        matched = next((prefix for prefix in exclusions
                        if var.op.name.startswith(prefix)), None)
        if matched is None:
            variables_to_restore.append(var)
        else:
            tf.logging.info('exclude:%s', matched)

    tf.logging.info('Fine-tuning from %s', checkpoint_path)
    return slim.assign_from_checkpoint_fn(
        checkpoint_path,
        variables_to_restore,
        ignore_missing_vars=True)
开发者ID:JianqiangRen,项目名称:AAMS,代码行数:26,代码来源:train.py

示例3: get_init_fn

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import assign_from_checkpoint_fn [as 别名]
def get_init_fn(checkpoints_dir, model_name='inception_v1.ckpt',
                checkpoint_exclude_scopes=("InceptionV1/Logits", "InceptionV1/AuxLogits")):
    """Returns a function run by the chief worker to warm-start the training.

    Args:
      checkpoints_dir: directory containing the pretrained checkpoint.
      model_name: checkpoint file name inside `checkpoints_dir`.
      checkpoint_exclude_scopes: variable-name prefixes that are NOT restored;
        defaults to the "InceptionV1/Logits" and "InceptionV1/AuxLogits"
        scopes. Surrounding whitespace on each entry is ignored.

    Returns:
      The callable built by slim.assign_from_checkpoint_fn; call it with a
      session to assign the checkpoint values to the remaining model variables.
    """
    # str.startswith accepts a tuple, so a single call tests every prefix
    # (an empty tuple matches nothing, i.e. nothing is excluded).
    exclusions = tuple(scope.strip() for scope in checkpoint_exclude_scopes)
    variables_to_restore = [
        var for var in slim.get_model_variables()
        if not var.op.name.startswith(exclusions)
    ]
    return slim.assign_from_checkpoint_fn(
        os.path.join(checkpoints_dir, model_name),
        variables_to_restore)
开发者ID:anthonyhu,项目名称:tumblr-emotions,代码行数:22,代码来源:im_model.py

示例4: get_init_fn

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import assign_from_checkpoint_fn [as 别名]
def get_init_fn(self, checkpoint_path,
                checkpoint_exclude_scopes=("InceptionV4/Logits", "InceptionV4/AuxLogits")):
    """Returns a function run by the chief worker to warm-start the training.

    Args:
      checkpoint_path: path of the pretrained Inception-V4 checkpoint.
      checkpoint_exclude_scopes: variable-name prefixes that are NOT restored;
        defaults to the "InceptionV4/Logits" and "InceptionV4/AuxLogits"
        scopes. Surrounding whitespace on each entry is ignored.

    Returns:
      The callable built by slim.assign_from_checkpoint_fn; call it with a
      session to assign the checkpoint values to the remaining model variables.
    """
    # str.startswith accepts a tuple, so a single call tests every prefix
    # (an empty tuple matches nothing, i.e. nothing is excluded).
    exclusions = tuple(scope.strip() for scope in checkpoint_exclude_scopes)
    variables_to_restore = [
        var for var in slim.get_model_variables()
        if not var.op.name.startswith(exclusions)
    ]
    return slim.assign_from_checkpoint_fn(checkpoint_path, variables_to_restore)
开发者ID:LevinJ,项目名称:SSD_tensorflow_VOC,代码行数:21,代码来源:pretrained.py

示例5: get_model_init_fn

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import assign_from_checkpoint_fn [as 别名]
def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_last_layer,
                      last_layers,
                      ignore_missing_vars=False):
    """Build the function that initializes model variables from a checkpoint.

    No initializer is produced (None is returned) when no initial checkpoint
    is given, when `train_logdir` already contains a checkpoint, or when no
    variables remain to be restored.

    Args:
      train_logdir: Log directory for training.
      tf_initial_checkpoint: TensorFlow checkpoint for initialization.
      initialize_last_layer: Initialize last layer or not.
      last_layers: Last layers of the model (excluded from restoring unless
        `initialize_last_layer` is true).
      ignore_missing_vars: Ignore missing variables in the checkpoint.

    Returns:
      Initialization function, or None.
    """
    if tf_initial_checkpoint is None:
        tf.logging.info('Not initializing the model from a checkpoint.')
        return None

    if tf.train.latest_checkpoint(train_logdir):
        tf.logging.info('Ignoring initialization; other checkpoint exists')
        return None

    tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)

    # global_step is never restored; the last layers are skipped as well
    # unless the caller asked for them to be initialized.
    excluded = ['global_step'] + ([] if initialize_last_layer else list(last_layers))
    restorable = slim.get_variables_to_restore(exclude=excluded)

    if not restorable:
        return None
    return slim.assign_from_checkpoint_fn(
        tf_initial_checkpoint,
        restorable,
        ignore_missing_vars=ignore_missing_vars)
开发者ID:sercant,项目名称:mobile-segmentation,代码行数:42,代码来源:train_utils.py

示例6: load_ckpt

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import assign_from_checkpoint_fn [as 别名]
def load_ckpt(ckpt_name, var_scope_name, scope, constructor, input_tensor, label_offset, load_weights, **kwargs):
    """Build a slim model under `scope` and optionally restore its weights.

    Prints `var_scope_name`, then builds the model with
    1000 + label_offset classes. kwargs are forwarded to `constructor`
    (e.g. is_training, create_aux_logits). When `load_weights` is true, the
    model variables under `var_scope_name` are restored from `ckpt_name` into
    the session returned by K.get_session().

    Returns:
      (logits, endpoints) as produced by `constructor`.
    """
    print(var_scope_name)
    with slim.arg_scope(scope):
        outputs = constructor(input_tensor,
                              num_classes=1000 + label_offset,
                              scope=var_scope_name,
                              **kwargs)
    logits, endpoints = outputs
    if not load_weights:
        return logits, endpoints
    restore = slim.assign_from_checkpoint_fn(
        ckpt_name, slim.get_model_variables(var_scope_name))
    restore(K.get_session())
    return logits, endpoints
开发者ID:sangxia,项目名称:nips-2017-adversarial,代码行数:14,代码来源:model_wrappers.py

示例7: main

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import assign_from_checkpoint_fn [as 别名]
def main():
    """Restore the latest YOLO checkpoint and run detection on files.

    Reads the model name and input size from the module-level `config`,
    builds the network, and restores all global variables from the newest
    checkpoint in the log directory. It then calls `detect` on `args.path`
    when that is a file. Otherwise it calls `detect` on every file below
    `args.path` whose lowercase extension is in `args.exts`, showing each
    result with matplotlib.

    Raises:
      FileNotFoundError: if the log directory contains no checkpoint.
    """
    model = config.get('config', 'model')
    yolo = importlib.import_module('model.' + model)
    width = config.getint(model, 'width')
    height = config.getint(model, 'height')
    with tf.Session() as sess:
        image = tf.placeholder(tf.float32, [1, height, width, 3], name='image')
        builder = yolo.Builder(args, config)
        builder(image)
        global_step = tf.contrib.framework.get_or_create_global_step()
        logdir = utils.get_logdir(config)
        model_path = tf.train.latest_checkpoint(logdir)
        if model_path is None:
            # latest_checkpoint returns None instead of raising; without this
            # check the string concatenation below fails with a TypeError.
            raise FileNotFoundError('no checkpoint found in %s' % (logdir,))
        tf.logging.info('load %s', model_path)
        slim.assign_from_checkpoint_fn(model_path, tf.global_variables())(sess)
        tf.logging.info('global_step=%d', sess.run(global_step))
        path = os.path.expanduser(os.path.expandvars(args.path))
        if os.path.isfile(path):
            detect(sess, builder.model, builder.names, image, path)
            plt.show()
            return
        for dirpath, _, filenames in os.walk(path):
            for filename in filenames:
                if os.path.splitext(filename)[-1].lower() not in args.exts:
                    continue
                file_path = os.path.join(dirpath, filename)
                print(file_path)
                detect(sess, builder.model, builder.names, image, file_path)
                plt.show()
开发者ID:ruiminshen,项目名称:yolo-tf,代码行数:28,代码来源:detect.py

示例8: main

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import assign_from_checkpoint_fn [as 别名]
def main():
    """Run live YOLO detection on frames from camera 0 and display them.

    Builds the model named in `config`, restores all global variables from
    the newest checkpoint in the log directory, then repeatedly reads a
    frame, runs the network, and applies non-max suppression. For each box
    whose best class score exceeds `args.threshold`, it draws a rectangle
    and a label. The loop ends only via an exception (e.g.
    KeyboardInterrupt). OpenCV windows and the capture device are released
    on exit.

    Raises:
      FileNotFoundError: if the log directory contains no checkpoint.
      RuntimeError: if a frame cannot be read from the camera.
    """
    model = config.get('config', 'model')
    yolo = importlib.import_module('model.' + model)
    width = config.getint(model, 'width')
    height = config.getint(model, 'height')
    preprocess = getattr(importlib.import_module('detect'), args.preprocess)
    with tf.Session() as sess:
        ph_image = tf.placeholder(tf.float32, [1, height, width, 3], name='ph_image')
        builder = yolo.Builder(args, config)
        builder(ph_image)
        global_step = tf.contrib.framework.get_or_create_global_step()
        logdir = utils.get_logdir(config)
        model_path = tf.train.latest_checkpoint(logdir)
        if model_path is None:
            # latest_checkpoint returns None instead of raising.
            raise FileNotFoundError('no checkpoint found in %s' % (logdir,))
        tf.logging.info('load %s', model_path)
        slim.assign_from_checkpoint_fn(model_path, tf.global_variables())(sess)
        tf.logging.info('global_step=%d', sess.run(global_step))
        tensors = [builder.model.conf, builder.model.xy_min, builder.model.xy_max]
        tensors = [tf.check_numerics(t, t.op.name) for t in tensors]
        cap = cv2.VideoCapture(0)
        try:
            while True:
                ret, image_bgr = cap.read()
                if not ret:
                    # An `assert` would be stripped under `python -O`, leaving
                    # image_bgr as None; fail explicitly instead.
                    raise RuntimeError('failed to read a frame from the camera')
                image_height, image_width, _ = image_bgr.shape
                scale = [image_width / builder.model.cell_width, image_height / builder.model.cell_height]
                image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
                image_std = np.expand_dims(preprocess(cv2.resize(image_rgb, (width, height))).astype(np.float32), 0)
                feed_dict = {ph_image: image_std}
                conf, xy_min, xy_max = sess.run(tensors, feed_dict)
                boxes = utils.postprocess.non_max_suppress(conf[0], xy_min[0], xy_max[0], args.threshold, args.threshold_iou)
                for _conf, _xy_min, _xy_max in boxes:
                    index = np.argmax(_conf)
                    if _conf[index] > args.threshold:
                        # Multiply by image size / cell size to get pixel
                        # coordinates; truncate with plain int() because
                        # np.int was removed in NumPy 1.24.
                        pt_min = tuple(int(v) for v in _xy_min * scale)
                        pt_max = tuple(int(v) for v in _xy_max * scale)
                        cv2.rectangle(image_bgr, pt_min, pt_max, (255, 0, 255), 3)
                        cv2.putText(image_bgr, builder.names[index] + ' (%.1f%%)' % (_conf[index] * 100), pt_min, cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
                cv2.imshow('detection', image_bgr)
                cv2.waitKey(1)
        finally:
            cv2.destroyAllWindows()
            cap.release()
开发者ID:ruiminshen,项目名称:yolo-tf,代码行数:43,代码来源:detect_camera.py

示例9: use_inceptionv4

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import assign_from_checkpoint_fn [as 别名]
def use_inceptionv4(self):
    """Classify one sample image with a pretrained Inception-V4 and show it.

    In a fresh graph, decodes and preprocesses the hard-coded JPEG and runs
    Inception-V4 with 1001 classes. It restores the 'InceptionV4' model
    variables from the hard-coded checkpoint, then passes the class indices
    (sorted by descending probability) and the probabilities to
    self.disp_names. Finally it displays the decoded image with matplotlib.
    """
    size = inception.inception_v4.default_image_size
    img_path = "../../data/misec_images/EnglishCockerSpaniel_simon.jpg"
    checkpoint_path = "../../data/trained_models/inception_v4/inception_v4.ckpt"

    with tf.Graph().as_default():
        raw_bytes = tf.read_file(img_path)
        decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
        preprocessed = inception_preprocessing.preprocess_image(decoded, size, size, is_training=False)
        batch = tf.expand_dims(preprocessed, 0)

        # The default arg scope configures the batch norm parameters.
        with slim.arg_scope(inception.inception_v4_arg_scope()):
            logits, _ = inception.inception_v4(batch, num_classes=1001, is_training=False)
        softmax = tf.nn.softmax(logits)

        restore = slim.assign_from_checkpoint_fn(
            checkpoint_path,
            slim.get_model_variables('InceptionV4'))

        with tf.Session() as sess:
            restore(sess)
            np_image, probs = sess.run([decoded, softmax])
            probs = probs[0, 0:]
            # Stable sort of class indices by descending probability.
            ranking = sorted(range(len(probs)), key=lambda idx: -probs[idx])
            self.disp_names(ranking, probs)

        plt.figure()
        plt.imshow(np_image.astype(np.uint8))
        plt.axis('off')
        plt.title(img_path)
        plt.show()
开发者ID:LevinJ,项目名称:SSD_tensorflow_VOC,代码行数:39,代码来源:pretrained.py

示例10: use_vgg16

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import assign_from_checkpoint_fn [as 别名]
def use_vgg16(self):
    """Classify one sample image with a pretrained VGG-16 and show it.

    In a fresh graph, decodes and preprocesses the hard-coded JPEG and runs
    VGG-16 with 1000 classes (no background class). It restores the 'vgg_16'
    model variables from the hard-coded checkpoint, then passes the class
    indices (sorted by descending probability) to self.disp_names with
    include_background=False. Finally it displays the decoded image with
    matplotlib.
    """
    with tf.Graph().as_default():
        size = vgg.vgg_16.default_image_size
        img_path = "../../data/misec_images/First_Student_IC_school_bus_202076.jpg"
        checkpoint_path = "../../data/trained_models/vgg16/vgg_16.ckpt"

        raw_bytes = tf.read_file(img_path)
        decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
        preprocessed = vgg_preprocessing.preprocess_image(decoded, size, size, is_training=False)
        batch = tf.expand_dims(preprocessed, 0)

        # The default arg scope configures the layer parameters.
        with slim.arg_scope(vgg.vgg_arg_scope()):
            # VGG checkpoints use 1000 classes instead of 1001.
            logits, _ = vgg.vgg_16(batch, num_classes=1000, is_training=False)
            softmax = tf.nn.softmax(logits)

            restore = slim.assign_from_checkpoint_fn(
                checkpoint_path,
                slim.get_model_variables('vgg_16'))

            with tf.Session() as sess:
                restore(sess)
                np_image, probs = sess.run([decoded, softmax])
                probs = probs[0, 0:]
                # Stable sort of class indices by descending probability.
                ranking = sorted(range(len(probs)), key=lambda idx: -probs[idx])
                self.disp_names(ranking, probs, include_background=False)

            plt.figure()
            plt.imshow(np_image.astype(np.uint8))
            plt.axis('off')
            plt.title(img_path)
            plt.show()
开发者ID:LevinJ,项目名称:SSD_tensorflow_VOC,代码行数:37,代码来源:pretrained.py

示例11: __init__

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import assign_from_checkpoint_fn [as 别名]
def __init__(self, model_path, batch_size):
    """Load the pretrained feature extractor from its latest checkpoint.

    Args:
      model_path: experiment directory. Checkpoints are read from
        ``<model_path>/train``, and the directory is also passed as `exp_dir`
        to load_and_save_params to recover the run's flags.
      batch_size: batch size of the input placeholders.

    Raises:
      ValueError: if ``<model_path>/train`` contains no checkpoint.
    """
    self.batch_size = batch_size

    train_dir = os.path.join(model_path, 'train')
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir=train_dir)
    if latest_checkpoint is None:
        # latest_checkpoint returns None instead of raising; report it
        # clearly rather than failing inside os.path.basename.
        raise ValueError('No checkpoint found in %s' % train_dir)
    # The step number is the text after the LAST '-' of the checkpoint
    # prefix, so prefixes that themselves contain dashes still parse.
    step = int(os.path.basename(latest_checkpoint).rsplit('-', 1)[-1])
    flags = Namespace(load_and_save_params(default_params=dict(), exp_dir=model_path))
    image_size = get_image_size(flags.data_dir)

    with tf.Graph().as_default():
        pretrain_images_pl, pretrain_labels_pl = placeholder_inputs(
            batch_size=batch_size, image_size=image_size, scope='inputs/pretrain')
        logits = build_feat_extract_pretrain_graph(pretrain_images_pl, flags, is_training=False)

        self.pretrain_images_pl = pretrain_images_pl
        self.pretrain_labels_pl = pretrain_labels_pl

        init_fn = slim.assign_from_checkpoint_fn(
            latest_checkpoint,
            slim.get_model_variables('Model'))

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)

        # Run init before loading the weights so variables outside 'Model'
        # are initialized too.
        self.sess.run(tf.global_variables_initializer())
        # Load weights
        init_fn(self.sess)

        self.flags = flags
        self.logits = logits
        self.logits_size = self.logits.get_shape().as_list()[-1]
        self.step = step
开发者ID:ElementAI,项目名称:am3,代码行数:37,代码来源:AM3_TADAM.py

示例12: __init__

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import assign_from_checkpoint_fn [as 别名]
def __init__(self, model_path, batch_size):
    """Load the pretrained feature extractor from its latest checkpoint.

    Args:
      model_path: experiment directory. Checkpoints are read from
        ``<model_path>/train``, and the directory is also passed as `exp_dir`
        to load_and_save_params to recover the run's flags.
      batch_size: batch size of the input placeholders.

    Raises:
      ValueError: if ``<model_path>/train`` contains no checkpoint.
    """
    self.batch_size = batch_size

    train_dir = os.path.join(model_path, 'train')
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir=train_dir)
    if latest_checkpoint is None:
        # latest_checkpoint returns None instead of raising; report it
        # clearly rather than failing inside os.path.basename.
        raise ValueError('No checkpoint found in %s' % train_dir)
    # The step number is the text after the LAST '-' of the checkpoint
    # prefix, so prefixes that themselves contain dashes still parse.
    step = int(os.path.basename(latest_checkpoint).rsplit('-', 1)[-1])

    flags = Namespace(load_and_save_params(default_params=dict(), exp_dir=model_path))
    image_size = get_image_size(flags.data_dir)

    with tf.Graph().as_default():
        pretrain_images_pl, pretrain_labels_pl = placeholder_inputs(
            batch_size=batch_size, image_size=image_size, scope='inputs/pretrain')
        logits = build_feat_extract_pretrain_graph(pretrain_images_pl, flags, is_training=False)

        self.pretrain_images_pl = pretrain_images_pl
        self.pretrain_labels_pl = pretrain_labels_pl

        init_fn = slim.assign_from_checkpoint_fn(
            latest_checkpoint,
            slim.get_model_variables('Model'))

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)

        # Run init before loading the weights so variables outside 'Model'
        # are initialized too.
        self.sess.run(tf.global_variables_initializer())
        # Load weights
        init_fn(self.sess)

        self.flags = flags
        self.logits = logits
        self.logits_size = self.logits.get_shape().as_list()[-1]
        self.step = step
开发者ID:ElementAI,项目名称:am3,代码行数:36,代码来源:tadam.py

示例13: train

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import assign_from_checkpoint_fn [as 别名]
def train(self):
    """Fine-tune from a ResNet-v2-50 checkpoint with a stepped learning rate.

    Batches are drawn from the TFRecord files ``self.train_file`` and
    ``self.test_file`` through shuffle-batch queues. After global variable
    initialization, the ``resnet_v2`` model variables are overwritten from
    ``resnet_v2_50.ckpt``. The learning rate is 1e-3 while the step is
    <= 500, 0.5e-3 while it is <= ``max_iteration - 1000``, and 1e-4
    afterwards. Every 10 steps, train/test summaries are written under
    ``self.log_dir`` and both losses are printed. Every 100 steps, a
    checkpoint is saved. A single test batch, fetched once, is used for
    every evaluation.
    """
    img_size = [self.image_height, self.image_width, self.image_depth]
    train_batch = tf.train.shuffle_batch([read_tfrecord(self.train_file, img_size)],
                                         batch_size=self.train_batch_size,
                                         capacity=3000,
                                         num_threads=2,
                                         min_after_dequeue=1000)
    test_batch = tf.train.shuffle_batch([read_tfrecord(self.test_file, img_size)],
                                        batch_size=self.test_batch_size,
                                        capacity=500,
                                        num_threads=2,
                                        min_after_dequeue=300)
    init = tf.global_variables_initializer()
    init_fn = slim.assign_from_checkpoint_fn("resnet_v2_50.ckpt", slim.get_model_variables('resnet_v2'))
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init)
        init_fn(sess)
        train_writer = tf.summary.FileWriter(self.log_dir + "/train", sess.graph)
        test_writer = tf.summary.FileWriter(self.log_dir + "/test", sess.graph)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            inputs_test, outputs_gt_test = build_img_pair(sess.run(test_batch))
            # Evaluation feed for the fixed test batch; it never changes.
            test_feed = {self.input_data: inputs_test, self.gt: outputs_gt_test,
                         self.batch_size: self.test_batch_size}
            for step in range(self.max_iteration):
                inputs_train, outputs_gt_train = build_img_pair(sess.run(train_batch))
                # Stepped learning-rate schedule.
                if step <= 500:
                    learning_rate = 1e-3
                elif step <= self.max_iteration - 1000:
                    learning_rate = 0.5e-3
                else:
                    learning_rate = 1e-4
                self.train_step.run({self.input_data: inputs_train, self.gt: outputs_gt_train,
                                     self.learning_rate: learning_rate,
                                     self.batch_size: self.train_batch_size})
                if step % 10 == 0:
                    train_feed = {self.input_data: inputs_train, self.gt: outputs_gt_train,
                                  self.batch_size: self.train_batch_size}
                    # Write training and test summaries.
                    train_writer.add_summary(sess.run(self.summary, train_feed), step)
                    train_writer.flush()
                    test_writer.add_summary(sess.run(self.summary, test_feed), step)
                    test_writer.flush()
                    # Print training and test loss.
                    train_loss = self.cross_entropy.eval(train_feed)
                    test_loss = self.cross_entropy.eval(test_feed)
                    print("iter step %d trainning batch loss %f" % (step, train_loss))
                    print("iter step %d test loss %f\n" % (step, test_loss))
                if step % 100 == 0:
                    saver.save(sess, self.log_dir + "/model.ckpt", global_step=step)
        finally:
            # Stop the queue-runner threads and close the event files even if
            # training is interrupted by an exception.
            coord.request_stop()
            coord.join(threads)
            train_writer.close()
            test_writer.close()
开发者ID:SaoYan,项目名称:bgsCNN,代码行数:60,代码来源:bgsCNN_v1.py

示例14: train

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import assign_from_checkpoint_fn [as 别名]
def train(self):
    """Fine-tune from a VGG-16 checkpoint with a stepped learning rate.

    Batches are drawn from the TFRecord files ``self.train_file`` and
    ``self.test_file`` through shuffle-batch queues. After global variable
    initialization, the ``vgg_16`` model variables are overwritten from
    ``vgg_16.ckpt``. The learning rate is 1e-4 while the step is <= 500,
    0.5e-4 while it is <= ``max_iteration - 1000``, and 1e-5 afterwards.
    Training steps feed ``is_training=True`` and evaluations feed
    ``is_training=False``. Every 10 steps, train/test summaries are written
    under ``self.log_dir`` and both losses are printed. Every 100 steps, a
    checkpoint is saved. A single test batch, fetched once, is used for
    every evaluation.
    """
    img_size = [self.image_height, self.image_width, self.image_depth]
    train_batch = tf.train.shuffle_batch([read_tfrecord(self.train_file, img_size)],
                                         batch_size=self.train_batch_size,
                                         capacity=2000,
                                         num_threads=2,
                                         min_after_dequeue=1000)
    test_batch = tf.train.shuffle_batch([read_tfrecord(self.test_file, img_size)],
                                        batch_size=self.test_batch_size,
                                        capacity=500,
                                        num_threads=2,
                                        min_after_dequeue=300)
    init = tf.global_variables_initializer()
    init_fn = slim.assign_from_checkpoint_fn("vgg_16.ckpt", slim.get_model_variables('vgg_16'))
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init)
        init_fn(sess)
        train_writer = tf.summary.FileWriter(self.log_dir + "/train", sess.graph)
        test_writer = tf.summary.FileWriter(self.log_dir + "/test", sess.graph)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            inputs_test, outputs_gt_test = build_img_pair(sess.run(test_batch))
            # Evaluation feed for the fixed test batch; it never changes.
            test_feed = {self.input_data: inputs_test, self.gt: outputs_gt_test,
                         self.is_training: False}
            for step in range(self.max_iteration):
                inputs_train, outputs_gt_train = build_img_pair(sess.run(train_batch))
                # Stepped learning-rate schedule.
                if step <= 500:
                    learning_rate = 1e-4
                elif step <= self.max_iteration - 1000:
                    learning_rate = 0.5e-4
                else:
                    learning_rate = 1e-5
                self.train_step.run({self.input_data: inputs_train, self.gt: outputs_gt_train,
                                     self.is_training: True, self.learning_rate: learning_rate})
                if step % 10 == 0:
                    train_feed = {self.input_data: inputs_train, self.gt: outputs_gt_train,
                                  self.is_training: False}
                    # Write training and test summaries.
                    train_writer.add_summary(sess.run(self.summary, train_feed), step)
                    train_writer.flush()
                    test_writer.add_summary(sess.run(self.summary, test_feed), step)
                    test_writer.flush()
                    # Print training and test loss.
                    train_loss = self.cross_entropy.eval(train_feed)
                    test_loss = self.cross_entropy.eval(test_feed)
                    print("iter step %d trainning batch loss %f" % (step, train_loss))
                    print("iter step %d test loss %f\n" % (step, test_loss))
                if step % 100 == 0:
                    saver.save(sess, self.log_dir + "/model.ckpt", global_step=step)
        finally:
            # Stop the queue-runner threads and close the event files even if
            # training is interrupted by an exception.
            coord.request_stop()
            coord.join(threads)
            train_writer.close()
            test_writer.close()
开发者ID:SaoYan,项目名称:bgsCNN,代码行数:60,代码来源:bgsCNN_v4.py

示例15: train

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import assign_from_checkpoint_fn [as 别名]
def train(self):
    """Fine-tune from a VGG-16 checkpoint with a stepped learning rate.

    Batches are drawn from the TFRecord files ``self.train_file`` and
    ``self.test_file`` through shuffle-batch queues. After global variable
    initialization, the ``vgg_16`` model variables are overwritten from
    ``vgg_16.ckpt``. The learning rate is 1e-4 while the step is <= 500,
    0.5e-4 while it is <= ``max_iteration - 1000``, and 1e-5 afterwards.
    Training steps feed ``is_training=True`` and evaluations feed
    ``is_training=False``. Every 10 steps, train/test summaries are written
    under ``self.log_dir`` and both losses are printed. Every 100 steps, a
    checkpoint is saved. A single test batch, fetched once, is used for
    every evaluation.
    """
    img_size = [self.image_height, self.image_width, self.image_depth]
    train_batch = tf.train.shuffle_batch([read_tfrecord(self.train_file, img_size)],
                                         batch_size=self.train_batch_size,
                                         capacity=2000,
                                         num_threads=2,
                                         min_after_dequeue=1000)
    test_batch = tf.train.shuffle_batch([read_tfrecord(self.test_file, img_size)],
                                        batch_size=self.test_batch_size,
                                        capacity=500,
                                        num_threads=2,
                                        min_after_dequeue=300)
    init = tf.global_variables_initializer()
    init_fn = slim.assign_from_checkpoint_fn("vgg_16.ckpt", slim.get_model_variables('vgg_16'))
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init)
        init_fn(sess)
        train_writer = tf.summary.FileWriter(self.log_dir + "/train", sess.graph)
        test_writer = tf.summary.FileWriter(self.log_dir + "/test", sess.graph)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            inputs_test, outputs_gt_test = build_img_pair(sess.run(test_batch))
            # Evaluation feed for the fixed test batch; it never changes.
            test_feed = {self.input_data: inputs_test, self.gt: outputs_gt_test,
                         self.is_training: False}
            for step in range(self.max_iteration):
                inputs_train, outputs_gt_train = build_img_pair(sess.run(train_batch))
                # Stepped learning-rate schedule.
                if step <= 500:
                    learning_rate = 1e-4
                elif step <= self.max_iteration - 1000:
                    learning_rate = 0.5e-4
                else:
                    learning_rate = 1e-5
                self.train_step.run({self.input_data: inputs_train, self.gt: outputs_gt_train,
                                     self.learning_rate: learning_rate, self.is_training: True})
                if step % 10 == 0:
                    train_feed = {self.input_data: inputs_train, self.gt: outputs_gt_train,
                                  self.is_training: False}
                    # Write training and test summaries.
                    train_writer.add_summary(sess.run(self.summary, train_feed), step)
                    train_writer.flush()
                    test_writer.add_summary(sess.run(self.summary, test_feed), step)
                    test_writer.flush()
                    # Print training and test loss.
                    train_loss = self.cross_entropy.eval(train_feed)
                    test_loss = self.cross_entropy.eval(test_feed)
                    print("iter step %d trainning batch loss %f" % (step, train_loss))
                    print("iter step %d test loss %f\n" % (step, test_loss))
                if step % 100 == 0:
                    saver.save(sess, self.log_dir + "/model.ckpt", global_step=step)
        finally:
            # Stop the queue-runner threads and close the event files even if
            # training is interrupted by an exception.
            coord.request_stop()
            coord.join(threads)
            train_writer.close()
            test_writer.close()
开发者ID:SaoYan,项目名称:bgsCNN,代码行数:56,代码来源:bgsCNN_v5.py


注:本文中的tensorflow.contrib.slim.assign_from_checkpoint_fn方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。