当前位置: 首页>>代码示例>>Python>>正文


Python tfutils.get_model_loader方法代码示例

本文整理汇总了Python中tensorpack.tfutils.get_model_loader方法的典型用法代码示例。如果您正苦于以下问题:Python tfutils.get_model_loader方法的具体用法?Python tfutils.get_model_loader怎么用?Python tfutils.get_model_loader使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在tensorpack.tfutils的用法示例。


在下文中一共展示了tfutils.get_model_loader方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: from tensorpack import tfutils [as 别名]
# 或者: from tensorpack.tfutils import get_model_loader [as 别名]
def main():
    args = get_args()
    nr_gpu = get_nr_gpu()
    args.batch_size = args.batch_size // nr_gpu

    model = Model(args)

    if args.evaluate:
        evaluate_wsol(args, model, interval=False)
        sys.exit()

    logger.set_logger_dir(ospj('train_log', args.log_dir))
    config = get_config(model, args)

    if args.use_pretrained_model:
        config.session_init = get_model_loader(_CKPT_NAMES[args.arch_name])

    launch_train_with_config(config,
                             SyncMultiGPUTrainerParameterServer(nr_gpu))

    evaluate_wsol(args, model, interval=True) 
开发者ID:junsukchoe,项目名称:ADL,代码行数:23,代码来源:train.py

示例2: prepare_model

# 需要导入模块: from tensorpack import tfutils [as 别名]
# 或者: from tensorpack.tfutils import get_model_loader [as 别名]
def prepare_model(model_name,
                  use_pretrained,
                  pretrained_model_file_path,
                  data_format="channels_last"):
    kwargs = {"pretrained": use_pretrained}

    raw_net = get_model(
        name=model_name,
        data_format=data_format,
        **kwargs)
    input_image_size = raw_net.in_size[0] if hasattr(raw_net, "in_size") else 224

    net = ImageNetModel(
        model_lambda=raw_net,
        image_size=input_image_size,
        data_format=data_format)

    if use_pretrained and not pretrained_model_file_path:
        pretrained_model_file_path = raw_net.file_path

    inputs_desc = None
    if pretrained_model_file_path:
        assert (os.path.isfile(pretrained_model_file_path))
        logging.info("Loading model: {}".format(pretrained_model_file_path))
        inputs_desc = get_model_loader(pretrained_model_file_path)

    return net, inputs_desc 
开发者ID:osmr,项目名称:imgclsmob,代码行数:29,代码来源:utils_tp.py

示例3: run

# 需要导入模块: from tensorpack import tfutils [as 别名]
# 或者: from tensorpack.tfutils import get_model_loader [as 别名]
def run(self):
        def get_last_chkpt_path(prev_phase_dir):
            stat_file_path = prev_phase_dir + '/stats.json'
            with open(stat_file_path) as stat_file:
                info = json.load(stat_file)
            chkpt_list = [epoch_stat['global_step'] for epoch_stat in info]
            last_chkpts_path = "%smodel-%d.index" % (prev_phase_dir, max(chkpt_list))
            return last_chkpts_path

        phase_opts = self.training_phase

        if len(phase_opts) > 1:
            for idx, opt in enumerate(phase_opts):
                random.seed(self.seed)
                np.random.seed(self.seed)
                tf.random.set_random_seed(self.seed)

                log_dir = '%s/%02d/' % (self.save_dir, idx)
                pretrained_path = opt['pretrained_path'] 
                if pretrained_path == -1:
                    pretrained_path = get_last_chkpt_path(prev_log_dir)
                    init_weights = SaverRestore(pretrained_path, ignore=['learning_rate'])
                elif pretrained_path is not None:
                    init_weights = get_model_loader(pretrained_path)
                self.run_once(opt, sess_init=init_weights, save_dir=log_dir)
                prev_log_dir = log_dir
        else:
            random.seed(self.seed)
            np.random.seed(self.seed)
            tf.random.set_random_seed(self.seed)

            opt = phase_opts[0]
            init_weights = None
            if 'pretrained_path' in opt:
                assert opt['pretrained_path'] != -1
                init_weights = get_model_loader(opt['pretrained_path'])
            self.run_once(opt, sess_init=init_weights, save_dir=self.save_dir)

        return
    ####
####

########################################################################### 
开发者ID:vqdang,项目名称:hover_net,代码行数:45,代码来源:train.py


注:本文中的tensorpack.tfutils.get_model_loader方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。