本文整理汇总了Python中PoseTools.get_vars方法的典型用法代码示例。如果您正苦于以下问题：Python PoseTools.get_vars方法的具体用法？Python PoseTools.get_vars怎么用？Python PoseTools.get_vars使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块PoseTools的用法示例。
在下文中一共展示了PoseTools.get_vars方法的12个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: create_saver
# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
def create_saver(self):
    """Build this network's checkpoint bookkeeping and store it on ``self.saver``.

    ``self.saver`` becomes a dict with these keys:
        'out_file':        <cachedir>/<expname>_<name>
        'train_data_file': <cachedir>/<expname>_<name>_traindata, or
                           <cachedir>/<train_data_name> when that is set
        'ckpt_file':       <cachedir>/<expname>_<name>_ckpt
        'saver':           a tf.train.Saver over
                           PoseTools.get_vars(net_name + '/'), keeping at most
                           conf.maxckpt checkpoints, with relative paths.

    If ``self.dep_nets`` is truthy, its ``create_joint_saver(self.name)`` is
    then called.
    """
    cache_dir = self.conf.cachedir
    prefix = self.conf.expname + '_' + self.name

    if self.train_data_name is None:
        train_data_file = os.path.join(cache_dir, prefix + '_traindata')
    else:
        train_data_file = os.path.join(cache_dir, self.train_data_name)

    self.saver = {
        'out_file': os.path.join(cache_dir, prefix),
        'train_data_file': train_data_file,
        'ckpt_file': os.path.join(cache_dir, prefix + '_ckpt'),
        'saver': tf.train.Saver(
            var_list=PoseTools.get_vars(self.net_name + '/'),
            max_to_keep=self.conf.maxckpt,
            save_relative_paths=True),
    }

    # NOTE(review): dep_nets is used here as a single object exposing
    # create_joint_saver, whereas initialize_net iterates over it — confirm
    # which shape dep_nets actually has.
    if self.dep_nets:
        self.dep_nets.create_joint_saver(self.name)
示例2: restore
# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
def restore(self, sess, do_restore, at_step=-1):
    """Restore this network from a checkpoint, or initialise it from scratch.

    Args:
        sess: TensorFlow session to run the restore / initialiser in.
        do_restore: if falsy, always initialise instead of restoring.
        at_step: if negative, restore the latest checkpoint.  Otherwise
            restore the first checkpoint listed in the checkpoint state whose
            step is >= at_step.

    Returns:
        The step to resume training from: 0 when the variables were
        initialised, otherwise the restored checkpoint's step + 1.

    Raises:
        ValueError: if a checkpoint path does not look like
            '<out_file>-<step>', or if no checkpoint has step >= at_step.
    """
    saver = self.saver
    name = self.net_name
    # Normalise separators on both sides of the comparison.  Otherwise Windows
    # backslashes would act as regex escapes, or would fail to match.
    out_file = saver['out_file'].replace('\\', '/')
    # re.escape: the path may contain regex metacharacters such as '.'.
    # \d+ rejects an empty step instead of failing later on int('').
    step_re = re.compile(re.escape(out_file) + r'-(\d+)')

    def ckpt_step(path):
        # Extract <step> from a '<out_file>-<step>' checkpoint path.
        match_obj = step_re.match(path.replace('\\', '/'))
        if match_obj is None:
            raise ValueError(
                'Cannot parse training step from checkpoint {!r} '
                '(expected {!r}-<step>)'.format(path, out_file))
        return int(match_obj.group(1))

    # latest_filename is a file *name* inside checkpoint_dir (TF joins the
    # two), so pass only the basename of ckpt_file.
    latest_ckpt = tf.train.get_checkpoint_state(
        self.conf.cachedir, os.path.basename(saver['ckpt_file']))
    if not latest_ckpt or not do_restore:
        start_at = 0
        sess.run(tf.variables_initializer(PoseTools.get_vars(name)),
                 feed_dict=self.fd)
        print("Not loading {:s} variables. Initializing them".format(name))
    else:
        if at_step < 0:
            model_file = latest_ckpt.model_checkpoint_path
        else:
            model_file = next(
                (path for path in latest_ckpt.all_model_checkpoint_paths
                 if ckpt_step(path) >= at_step),
                None)
            if model_file is None:
                raise ValueError(
                    'No checkpoint for {!r} at or after step {}'.format(
                        out_file, at_step))
        saver['saver'].restore(sess, model_file)
        start_at = ckpt_step(model_file) + 1
    # NOTE(review): dep_nets is used as a single object here but iterated in
    # initialize_net — confirm its actual shape.
    if self.dep_nets:
        self.dep_nets.restore_joint(sess, self.name, self.joint, do_restore)
    return start_at
示例3: restoreBaseReg
# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
def restoreBaseReg(self,sess,restore):
    """Restore the base-regression network from its checkpoint, or initialise it.

    Args:
        sess: TensorFlow session used for restoring / initialising.
        restore: if falsy, always initialise instead of restoring.

    Returns:
        True when the variables and training data were loaded from disk.
        False when the variables were freshly initialised.

    Side effects:
        Sets self.baseregstartat (0, or checkpoint step + 1) and
        self.baseregtrainData.

    Raises:
        ValueError: if the checkpoint path does not look like
            '<cachedir>/<baseregoutname>-<step>'.
    """
    outfilename = os.path.join(self.conf.cachedir, self.conf.baseregoutname)
    traindatafilename = os.path.join(self.conf.cachedir, self.conf.baseregdataname)
    latest_ckpt = tf.train.get_checkpoint_state(
        self.conf.cachedir, latest_filename=self.conf.baseregckptname)
    if not latest_ckpt or not restore:
        self.baseregstartat = 0
        self.baseregtrainData = {'train_err': [], 'val_err': [], 'step_no': [],
                                 'train_dist': [], 'val_dist': []}
        # tf.initialize_variables is deprecated.  tf.variables_initializer is
        # its documented replacement and is what the other helpers here use.
        # NOTE(review): this initialises get_vars('base') while
        # createBaseRegSaver saves get_vars('regbase') — confirm both select
        # the same variables.
        sess.run(tf.variables_initializer(PoseTools.get_vars('base')))
        print("Not loading base variables. Initializing them")
        return False

    ckpt_path = latest_ckpt.model_checkpoint_path
    self.baseregsaver.restore(sess, ckpt_path)
    # Escape the path, since it may contain '.' or Windows backslashes.
    # Require at least one digit so a malformed name fails with a clear message.
    match_obj = re.match(re.escape(outfilename) + r'-(\d+)', ckpt_path)
    if match_obj is None:
        raise ValueError(
            'Cannot parse training step from checkpoint {!r} '
            '(expected {!r}-<step>)'.format(ckpt_path, outfilename))
    self.baseregstartat = int(match_obj.group(1)) + 1
    # NOTE: pickle.load can execute arbitrary code.  Only load training-data
    # files that this program wrote itself.
    with open(traindatafilename, 'rb') as tdfile:
        inData = pickle.load(tdfile)
    if not isinstance(inData, dict):
        # Non-dict payloads are unpacked as (train_data, saved_conf).
        self.baseregtrainData, loadconf = inData
        print('Parameters that dont match for base:')
        PoseTools.compare_conf(self.conf, loadconf)
    else:
        print("No config was stored for base. Not comparing conf")
        self.baseregtrainData = inData
    print("Loading base variables from %s" % ckpt_path)
    return True
示例4: initialize_net
# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
def initialize_net(self, sess):
    """Initialise this network's variables from scratch, then its dependents'.

    Runs the initialiser for PoseTools.get_vars(self.net_name), feeding
    self.fd, and then calls self.init_td().  Next it calls initialize_net(sess)
    on every entry of self.dep_nets.  Finally it calls
    initialize_remaining_vars(sess).
    """
    scope = self.net_name
    own_vars = PoseTools.get_vars(scope)
    init_op = tf.variables_initializer(own_vars)
    sess.run(init_op, feed_dict=self.fd)
    print("Not loading {:s} variables. Initializing them".format(scope))
    self.init_td()
    for child_net in self.dep_nets:
        child_net.initialize_net(sess)
    initialize_remaining_vars(sess)
示例5: restoreEval
# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
def restoreEval(sess,evalsaver,restore,conf,feed_dict):
    """Restore the 'eval' network from its latest checkpoint, or initialise it.

    Args:
        sess: TensorFlow session to run the restore / initialiser in.
        evalsaver: Saver used to restore the eval variables.
        restore: if falsy, always initialise instead of restoring.
        conf: config providing cachedir, evaloutname and evalckptname.
        feed_dict: feed passed when running the initialiser.

    Returns:
        The step to resume from: 0 when initialised, otherwise the
        checkpoint's step + 1.

    Raises:
        ValueError: if the checkpoint path does not look like
            '<cachedir>/<evaloutname>-<step>'.
    """
    outfilename = os.path.join(conf.cachedir, conf.evaloutname)
    latest_ckpt = tf.train.get_checkpoint_state(
        conf.cachedir, latest_filename=conf.evalckptname)
    if not latest_ckpt or not restore:
        sess.run(tf.variables_initializer(PoseTools.get_vars('eval')),
                 feed_dict=feed_dict)
        print("Not loading Eval variables. Initializing them")
        return 0

    ckpt_path = latest_ckpt.model_checkpoint_path
    evalsaver.restore(sess, ckpt_path)
    # Escape the path, since it may contain '.' or Windows backslashes.
    # \d+ rejects an empty step instead of failing on int('').
    match_obj = re.match(re.escape(outfilename) + r'-(\d+)', ckpt_path)
    if match_obj is None:
        raise ValueError(
            'Cannot parse training step from checkpoint {!r} '
            '(expected {!r}-<step>)'.format(ckpt_path, outfilename))
    evalstartat = int(match_obj.group(1)) + 1
    print("Loading eval variables from %s" % ckpt_path)
    return evalstartat
示例6: restoreGen
# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
def restoreGen(sess,conf,genSaver,restore=True):
    """Restore the 'poseGen' network from its latest checkpoint, or initialise it.

    Args:
        sess: TensorFlow session to run the restore / initialiser in.
        conf: config providing cachedir, genoutname and genckptname.
        genSaver: Saver used to restore the generator variables.
        restore: if falsy, always initialise instead of restoring.

    Returns:
        (didRestore, startat): (False, 0) when initialised, otherwise
        (True, checkpoint step + 1).

    Raises:
        ValueError: if the checkpoint path does not look like
            '<cachedir>/<genoutname>-<step>'.
    """
    outfilename = os.path.join(conf.cachedir, conf.genoutname)
    latest_ckpt = tf.train.get_checkpoint_state(
        conf.cachedir, latest_filename=conf.genckptname)
    if not latest_ckpt or not restore:
        # tf.initialize_variables is deprecated; tf.variables_initializer is
        # its documented replacement.
        sess.run(tf.variables_initializer(PoseTools.get_vars('poseGen')))
        print("Not loading gen variables. Initializing them")
        return False, 0

    ckpt_path = latest_ckpt.model_checkpoint_path
    genSaver.restore(sess, ckpt_path)
    # Escape the path, since it may contain '.' or Windows backslashes.
    # \d+ rejects an empty step instead of failing on int('').
    match_obj = re.match(re.escape(outfilename) + r'-(\d+)', ckpt_path)
    if match_obj is None:
        raise ValueError(
            'Cannot parse training step from checkpoint {!r} '
            '(expected {!r}-<step>)'.format(ckpt_path, outfilename))
    startat = int(match_obj.group(1)) + 1
    print("Loading gen variables from %s" % ckpt_path)
    return True, startat
示例7: create_shape_saver
# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
def create_shape_saver(conf):
    """Return a tf.train.Saver over the variables selected by get_vars('shape').

    At most conf.maxckpt checkpoints are kept.
    """
    shape_vars = PoseTools.get_vars('shape')
    return tf.train.Saver(var_list=shape_vars, max_to_keep=conf.maxckpt)
示例8: createGenSaver
# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
def createGenSaver(conf):
    """Return a tf.train.Saver over the variables selected by get_vars('poseGen').

    At most conf.maxckpt checkpoints are kept.
    """
    gen_vars = PoseTools.get_vars('poseGen')
    return tf.train.Saver(var_list=gen_vars, max_to_keep=conf.maxckpt)
示例9: Mixture（代码片段，开头部分被截断）
# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
value=tf.zeros_like(y_ph))
# NOTE(review): this excerpt starts in the middle of a call whose opening is
# not shown (presumably the Mixture that defines `y`, given the commented-out
# `y_y = Mixture(...)` below). Indentation was also lost in this excerpt.
# y_y = Mixture(cat=cat, components=components_y, value=tf.zeros_like(y_y_ph))
# Note: A bug in Mixture prevents its samples from having a shape of [None].
# For now this is worked around with the value argument, since sampling is
# not needed for MAP estimation anyway.
#
# There are no latent variables to infer. Inference therefore only trains
# the model parameters, which are baked into how the neural networks are
# specified.
n_epoch = 1000  # passed as n_iter to inference.initialize below
#
# MAP inference with `y` observed as the placeholder y_ph (mymap is a
# project module; presumably an Edward-style MAP — verify).
inference = mymap.MAP(data={y: y_ph})
# Restrict optimisation to the variables selected by PoseTools.get_vars('nn').
v_list = PoseTools.get_vars('nn')
inference.initialize(var_list=v_list,n_iter=n_epoch)
#
# When myOpt is True, a separate Adam optimiser is built on inference.loss
# with a stepped exponential-decay learning rate.
myOpt = True
starter_learning_rate = 0.001
# Placeholder for the decay schedule's step; presumably fed with the current
# iteration number — confirm against the training loop.
step = tf.placeholder(tf.int64)
# lr = 0.001 * 0.9 ** floor(step / 300), because staircase=True.
learning_rate = tf.train.exponential_decay(starter_learning_rate,
step, 300, 0.9, staircase=True)
# NOTE(review): indentation is lost; presumably only the optimizer line
# belongs to the `if myOpt:` branch.
if myOpt:
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(inference.loss)
sess = tf.InteractiveSession()
# sess = ed.get_session()
示例10: range（代码片段，摘自某个训练方法的中间部分）
# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
# NOTE(review): this excerpt begins inside a method whose signature is not
# shown (mdn_dropout and trainType come from there). Indentation was lost.
# Set the dropout keep-probability and turn off both training-phase flags
# in the feed dict.
self.feed_dict[self.ph['keep_prob']] = mdn_dropout
self.feed_dict[self.ph['phase_train_base']] = False
self.feed_dict[self.ph['phase_train_mdn']] = False
self.trainType = trainType
# Build the graph, open the databases and create the savers.
# y is indexed by class (0 .. n_classes-1) below.
y = self.create_network()
self.openDBs()
self.createBaseSaver()
self.create_saver()
# Divide the base locations by rescale and pool_scale. This presumably maps
# them onto the network's output grid — confirm.
y_label = self.ph['base_locs'] / self.conf.rescale / self.conf.pool_scale
# Observe each class output y[ndx] as the label slice y_label[:, ndx, :].
# The indexing implies y_label has rank 3; the axis meanings are not shown here.
data_dict = {}
for ndx in range(self.conf.n_classes):
data_dict[y[ndx]] = y_label[:, ndx, :]
# MAP inference, restricted to the variables selected by get_vars('mdn').
inference = mymap.MAP(data=data_dict)
inference.initialize(var_list=PoseTools.get_vars('mdn'))
self.loss = inference.loss
starter_learning_rate = 0.00003
# The decay interval scales with batch size: 5000 / 8 * batch_size, i.e.
# 625 * batch_size steps.
decay_steps = 5000 / 8 * self.conf.batch_size
learning_rate = tf.train.exponential_decay(
starter_learning_rate, self.ph['step'], decay_steps, 0.9,
staircase=True)
self.opt = tf.train.AdamOptimizer(
learning_rate=learning_rate).minimize(self.loss)
sess = tf.InteractiveSession()
self.createCursors(sess)
self.updateFeedDict(self.DBType.Train, sess=sess, distort=True)
# NOTE(review): this initialises every global variable in the graph. Any
# values loaded into variables before this point would be overwritten —
# confirm this is intended.
sess.run(tf.global_variables_initializer())
示例11: createBaseRegSaver
# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
def createBaseRegSaver(self):
    """Attach a tf.train.Saver over get_vars('regbase') as self.baseregsaver.

    At most self.conf.maxckpt checkpoints are kept.
    """
    regbase_vars = PoseTools.get_vars('regbase')
    keep = self.conf.maxckpt
    self.baseregsaver = tf.train.Saver(var_list=regbase_vars, max_to_keep=keep)
示例12: createEvalSaver
# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
def createEvalSaver(conf):
    """Return a tf.train.Saver over the variables selected by get_vars('eval').

    At most conf.maxckpt checkpoints are kept.
    """
    eval_vars = PoseTools.get_vars('eval')
    keep = conf.maxckpt
    return tf.train.Saver(var_list=eval_vars, max_to_keep=keep)