當前位置: 首頁>>代碼示例>>Python>>正文


Python tfutils.get_model_loader方法代碼示例

本文整理匯總了Python中tensorpack.tfutils.get_model_loader方法的典型用法代碼示例。如果您正苦於以下問題:Python tfutils.get_model_loader方法的具體用法?Python tfutils.get_model_loader怎麽用?Python tfutils.get_model_loader使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在tensorpack.tfutils的用法示例。


在下文中一共展示了tfutils.get_model_loader方法的3個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: main

# 需要導入模塊: from tensorpack import tfutils [as 別名]
# 或者: from tensorpack.tfutils import get_model_loader [as 別名]
def main():
    args = get_args()
    nr_gpu = get_nr_gpu()
    args.batch_size = args.batch_size // nr_gpu

    model = Model(args)

    if args.evaluate:
        evaluate_wsol(args, model, interval=False)
        sys.exit()

    logger.set_logger_dir(ospj('train_log', args.log_dir))
    config = get_config(model, args)

    if args.use_pretrained_model:
        config.session_init = get_model_loader(_CKPT_NAMES[args.arch_name])

    launch_train_with_config(config,
                             SyncMultiGPUTrainerParameterServer(nr_gpu))

    evaluate_wsol(args, model, interval=True) 
開發者ID:junsukchoe,項目名稱:ADL,代碼行數:23,代碼來源:train.py

示例2: prepare_model

# 需要導入模塊: from tensorpack import tfutils [as 別名]
# 或者: from tensorpack.tfutils import get_model_loader [as 別名]
def prepare_model(model_name,
                  use_pretrained,
                  pretrained_model_file_path,
                  data_format="channels_last"):
    kwargs = {"pretrained": use_pretrained}

    raw_net = get_model(
        name=model_name,
        data_format=data_format,
        **kwargs)
    input_image_size = raw_net.in_size[0] if hasattr(raw_net, "in_size") else 224

    net = ImageNetModel(
        model_lambda=raw_net,
        image_size=input_image_size,
        data_format=data_format)

    if use_pretrained and not pretrained_model_file_path:
        pretrained_model_file_path = raw_net.file_path

    inputs_desc = None
    if pretrained_model_file_path:
        assert (os.path.isfile(pretrained_model_file_path))
        logging.info("Loading model: {}".format(pretrained_model_file_path))
        inputs_desc = get_model_loader(pretrained_model_file_path)

    return net, inputs_desc 
開發者ID:osmr,項目名稱:imgclsmob,代碼行數:29,代碼來源:utils_tp.py

示例3: run

# 需要導入模塊: from tensorpack import tfutils [as 別名]
# 或者: from tensorpack.tfutils import get_model_loader [as 別名]
def run(self):
        def get_last_chkpt_path(prev_phase_dir):
            stat_file_path = prev_phase_dir + '/stats.json'
            with open(stat_file_path) as stat_file:
                info = json.load(stat_file)
            chkpt_list = [epoch_stat['global_step'] for epoch_stat in info]
            last_chkpts_path = "%smodel-%d.index" % (prev_phase_dir, max(chkpt_list))
            return last_chkpts_path

        phase_opts = self.training_phase

        if len(phase_opts) > 1:
            for idx, opt in enumerate(phase_opts):
                random.seed(self.seed)
                np.random.seed(self.seed)
                tf.random.set_random_seed(self.seed)

                log_dir = '%s/%02d/' % (self.save_dir, idx)
                pretrained_path = opt['pretrained_path'] 
                if pretrained_path == -1:
                    pretrained_path = get_last_chkpt_path(prev_log_dir)
                    init_weights = SaverRestore(pretrained_path, ignore=['learning_rate'])
                elif pretrained_path is not None:
                    init_weights = get_model_loader(pretrained_path)
                self.run_once(opt, sess_init=init_weights, save_dir=log_dir)
                prev_log_dir = log_dir
        else:
            random.seed(self.seed)
            np.random.seed(self.seed)
            tf.random.set_random_seed(self.seed)

            opt = phase_opts[0]
            init_weights = None
            if 'pretrained_path' in opt:
                assert opt['pretrained_path'] != -1
                init_weights = get_model_loader(opt['pretrained_path'])
            self.run_once(opt, sess_init=init_weights, save_dir=self.save_dir)

        return
    ####
####

########################################################################### 
開發者ID:vqdang,項目名稱:hover_net,代碼行數:45,代碼來源:train.py


注:本文中的tensorpack.tfutils.get_model_loader方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。