本文整理汇总了Python中tensorpack.tfutils.get_model_loader函数的典型用法代码示例。如果您正苦于以下问题：Python tfutils.get_model_loader函数的具体用法？Python tfutils.get_model_loader怎么用？Python tfutils.get_model_loader的使用例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该函数所在模块tensorpack.tfutils的用法示例。
在下文中一共展示了tfutils.get_model_loader函数的3个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更好的Python代码示例。
示例1: main
# 需要导入模块: from tensorpack import tfutils [as 别名]
# 或者: from tensorpack.tfutils import get_model_loader [as 别名]
def main():
    """Entry point: train the model on all visible GPUs, then evaluate it.

    Arguments come from ``get_args()``. ``args.batch_size`` is treated as the
    total batch size and divided evenly among the GPUs reported by
    ``get_nr_gpu()``; the per-GPU value is written back into ``args`` before
    ``Model(args)`` is built.

    If ``args.evaluate`` is set, only ``evaluate_wsol(..., interval=False)``
    runs and the process then exits via ``sys.exit()``. Otherwise logs go to
    ``train_log/<args.log_dir>``, the checkpoint registered for
    ``args.arch_name`` in ``_CKPT_NAMES`` is used as the session initializer
    when ``args.use_pretrained_model`` is set, training runs with
    ``SyncMultiGPUTrainerParameterServer``, and
    ``evaluate_wsol(..., interval=True)`` runs afterwards.

    Raises:
        RuntimeError: if no GPU is reported; the batch split and the
            multi-GPU trainer both need at least one.
        ValueError: if ``args.batch_size`` is smaller than the GPU count,
            which would otherwise give every GPU a batch size of 0.
    """
    args = get_args()
    nr_gpu = get_nr_gpu()
    if nr_gpu < 1:
        # Previously this surfaced as an opaque ZeroDivisionError below.
        raise RuntimeError("No GPU detected; at least one GPU is required.")
    if args.batch_size < nr_gpu:
        raise ValueError(
            "batch_size ({}) must be at least the number of GPUs ({})".format(
                args.batch_size, nr_gpu))
    # Global batch size -> per-GPU batch size (any remainder is dropped).
    args.batch_size = args.batch_size // nr_gpu
    model = Model(args)
    if args.evaluate:
        evaluate_wsol(args, model, interval=False)
        sys.exit()
    logger.set_logger_dir(ospj('train_log', args.log_dir))
    config = get_config(model, args)
    if args.use_pretrained_model:
        config.session_init = get_model_loader(_CKPT_NAMES[args.arch_name])
    launch_train_with_config(config,
                             SyncMultiGPUTrainerParameterServer(nr_gpu))
    evaluate_wsol(args, model, interval=True)
示例2: prepare_model
# 需要导入模块: from tensorpack import tfutils [as 别名]
# 或者: from tensorpack.tfutils import get_model_loader [as 别名]
def prepare_model(model_name,
                  use_pretrained,
                  pretrained_model_file_path,
                  data_format="channels_last"):
    """Build an ``ImageNetModel`` around a named network plus an optional loader.

    Args:
        model_name: name passed to ``get_model`` to build the raw network.
        use_pretrained: forwarded to ``get_model`` as ``pretrained``. When true
            and no explicit checkpoint path is given, the raw network's
            ``file_path`` attribute is used as the checkpoint to load.
        pretrained_model_file_path: explicit checkpoint path; an empty value
            or ``None`` means "none given".
        data_format: layout string forwarded to both ``get_model`` and
            ``ImageNetModel``.

    Returns:
        ``(net, session_init)``: the wrapped model, and the session
        initializer returned by ``get_model_loader`` for the checkpoint, or
        ``None`` when no checkpoint is to be loaded.

    Raises:
        FileNotFoundError: if the resolved checkpoint path is not a file.
    """
    raw_net = get_model(
        name=model_name,
        data_format=data_format,
        pretrained=use_pretrained)
    # Use the network's advertised input size when it has one, else 224.
    input_image_size = raw_net.in_size[0] if hasattr(raw_net, "in_size") else 224
    net = ImageNetModel(
        model_lambda=raw_net,
        image_size=input_image_size,
        data_format=data_format)
    if use_pretrained and not pretrained_model_file_path:
        pretrained_model_file_path = raw_net.file_path
    session_init = None
    if pretrained_model_file_path:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not os.path.isfile(pretrained_model_file_path):
            raise FileNotFoundError(
                "Pretrained model file not found: {}".format(pretrained_model_file_path))
        logging.info("Loading model: %s", pretrained_model_file_path)
        session_init = get_model_loader(pretrained_model_file_path)
    return net, session_init
示例3: run
# 需要导入模块: from tensorpack import tfutils [as 别名]
# 或者: from tensorpack.tfutils import get_model_loader [as 别名]
def run(self):
    """Run every configured training phase, chaining checkpoints between phases.

    ``self.training_phase`` is a sequence of per-phase option dicts. All RNGs
    (``random``, NumPy, TensorFlow) are re-seeded with ``self.seed`` before
    each phase.

    With more than one phase, phase ``idx`` writes to
    ``'<self.save_dir>/<idx:02d>/'`` and its ``pretrained_path`` option selects
    the weight initializer:

    * ``-1``: resume from the highest ``global_step`` checkpoint listed in the
      previous phase's ``stats.json`` (via ``SaverRestore``, ignoring
      ``learning_rate``);
    * ``None`` or a missing key: no initializer for this phase;
    * anything else: ``get_model_loader(pretrained_path)``.

    With a single phase, it writes to ``self.save_dir`` directly and the same
    rules apply, except that ``-1`` is rejected.

    Raises:
        ValueError: if ``pretrained_path == -1`` is given for the first (or
            only) phase, where there is no previous phase to resume from.
    """
    def get_last_chkpt_path(prev_phase_dir):
        # Pick the checkpoint with the largest global_step recorded in the
        # previous phase's stats.json.
        stat_file_path = prev_phase_dir + '/stats.json'
        with open(stat_file_path) as stat_file:
            info = json.load(stat_file)
        chkpt_list = [epoch_stat['global_step'] for epoch_stat in info]
        # prev_phase_dir is built with a trailing '/', so plain concatenation.
        return "%smodel-%d.index" % (prev_phase_dir, max(chkpt_list))

    def seed_everything():
        random.seed(self.seed)
        np.random.seed(self.seed)
        tf.random.set_random_seed(self.seed)

    def build_init(pretrained_path, prev_phase_dir):
        # Map a phase's pretrained_path option to a session initializer.
        if pretrained_path is None:
            return None
        if pretrained_path == -1:
            if prev_phase_dir is None:
                raise ValueError(
                    "pretrained_path=-1 (resume from previous phase) is not "
                    "valid for the first training phase")
            return SaverRestore(get_last_chkpt_path(prev_phase_dir),
                                ignore=['learning_rate'])
        return get_model_loader(pretrained_path)

    phase_opts = self.training_phase
    if len(phase_opts) > 1:
        prev_log_dir = None
        for idx, opt in enumerate(phase_opts):
            seed_everything()
            log_dir = '%s/%02d/' % (self.save_dir, idx)
            # Computed fresh for every phase: previously a phase without a
            # pretrained path silently reused the prior phase's initializer.
            init_weights = build_init(opt.get('pretrained_path'), prev_log_dir)
            self.run_once(opt, sess_init=init_weights, save_dir=log_dir)
            prev_log_dir = log_dir
    else:
        seed_everything()
        opt = phase_opts[0]
        init_weights = build_init(opt.get('pretrained_path'), None)
        self.run_once(opt, sess_init=init_weights, save_dir=self.save_dir)
####
####
###########################################################################