当前位置: 首页>>代码示例>>Python>>正文


Python PoseTools.get_vars方法代码示例

本文整理汇总了Python中PoseTools.get_vars方法的典型用法代码示例。如果您正苦于以下问题:Python PoseTools.get_vars方法的具体用法?Python PoseTools.get_vars怎么用?Python PoseTools.get_vars使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在PoseTools的用法示例。


在下文中一共展示了PoseTools.get_vars方法的12个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: create_saver

# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
    def create_saver(self):
        """Assemble the checkpoint bookkeeping dict and store it on ``self.saver``.

        The dict holds three paths under ``self.conf.cachedir`` plus the
        ``tf.train.Saver`` itself:

        * ``'out_file'``        -- ``<expname>_<name>``
        * ``'train_data_file'`` -- ``<expname>_<name>_traindata``, or
          ``self.train_data_name`` when that attribute is not ``None``
        * ``'ckpt_file'``       -- ``<expname>_<name>_ckpt``
        * ``'saver'``           -- saver over ``PoseTools.get_vars(net_name + '/')``

        If ``self.dep_nets`` is truthy, its joint saver is created as well.
        """
        cachedir = self.conf.cachedir
        prefix = self.conf.expname + '_' + self.name

        if self.train_data_name is None:
            train_data_file = os.path.join(cachedir, prefix + '_traindata')
        else:
            train_data_file = os.path.join(cachedir, self.train_data_name)

        self.saver = {
            'out_file': os.path.join(cachedir, prefix),
            'train_data_file': train_data_file,
            'ckpt_file': os.path.join(cachedir, prefix + '_ckpt'),
            'saver': tf.train.Saver(
                var_list=PoseTools.get_vars(self.net_name + '/'),
                max_to_keep=self.conf.maxckpt,
                save_relative_paths=True),
        }

        if self.dep_nets:
            self.dep_nets.create_joint_saver(self.name)
开发者ID:mkabra,项目名称:poseTF,代码行数:27,代码来源:PoseCommon.py

示例2: restore

# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
    def restore(self, sess, do_restore, at_step=-1):
        """Restore this network's variables from a checkpoint, or initialize them.

        Args:
            sess: TensorFlow session used to run the initializer / restore.
            do_restore: when falsy, any existing checkpoint is ignored and
                the variables are freshly initialized.
            at_step: when negative (the default) the latest checkpoint is
                restored; otherwise the first checkpoint listed in the
                checkpoint state whose step number is ``>= at_step``.

        Returns:
            The step to resume from: 0 after a fresh initialization,
            otherwise the restored checkpoint's step number plus one.

        Raises:
            ValueError: if ``at_step`` is non-negative and no listed
                checkpoint has a step number ``>= at_step``.
        """
        saver = self.saver
        name = self.net_name
        out_file = saver['out_file'].replace('\\', '/')
        # The path prefix is escaped so regex metacharacters in it (e.g. '.',
        # '+', '(') are matched literally; the raw string avoids the invalid
        # '\d' escape sequence of a plain string literal.
        step_pattern = re.compile(re.escape(out_file) + r'-(\d*)')

        latest_ckpt = tf.train.get_checkpoint_state(
            self.conf.cachedir, saver['ckpt_file'])
        if not latest_ckpt or not do_restore:
            start_at = 0
            sess.run(tf.variables_initializer(
                PoseTools.get_vars(name)),
                feed_dict=self.fd)
            print("Not loading {:s} variables. Initializing them".format(name))
        else:
            if at_step < 0:
                model_file = latest_ckpt.model_checkpoint_path
            else:
                model_file = ''
                for candidate in latest_ckpt.all_model_checkpoint_paths:
                    step = int(step_pattern.match(candidate).group(1))
                    if step >= at_step:
                        model_file = candidate
                        break
                if not model_file:
                    # Previously an empty path was handed to Saver.restore,
                    # which fails with a much less helpful message.
                    raise ValueError(
                        'No checkpoint at or after step {} found for {}'.format(
                            at_step, name))
            saver['saver'].restore(sess, model_file)
            start_at = int(step_pattern.match(model_file).group(1)) + 1

        if self.dep_nets:
            self.dep_nets.restore_joint(sess, self.name, self.joint, do_restore)

        return start_at
开发者ID:mkabra,项目名称:poseTF,代码行数:36,代码来源:PoseCommon.py

示例3: restoreBaseReg

# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
 def restoreBaseReg(self,sess,restore):
     """Restore the base-regression variables from a checkpoint, or initialize them.

     Args:
         sess: TensorFlow session used to run the initializer / restore.
         restore: when falsy, any existing checkpoint is ignored and the
             variables are freshly initialized.

     Returns:
         True if variables were loaded from a checkpoint, False if they
         were initialized instead.

     Side effects:
         Sets ``self.baseregstartat`` (0, or checkpoint step + 1) and
         ``self.baseregtrainData`` (empty history, or the pickled one).
     """
     outfilename = os.path.join(self.conf.cachedir,self.conf.baseregoutname)
     traindatafilename = os.path.join(self.conf.cachedir,self.conf.baseregdataname)
     latest_ckpt = tf.train.get_checkpoint_state(self.conf.cachedir,
                                         latest_filename = self.conf.baseregckptname)
     if not latest_ckpt or not restore:
         self.baseregstartat = 0
         self.baseregtrainData = {'train_err':[], 'val_err':[], 'step_no':[],
                               'train_dist':[], 'val_dist':[] }
         # tf.initialize_variables is the deprecated spelling of
         # tf.variables_initializer.
         # NOTE(review): this initializes scope 'base' -- confirm it covers
         # the variables the base-regression saver actually stores.
         sess.run(tf.variables_initializer(PoseTools.get_vars('base')))
         print("Not loading base variables. Initializing them")
         return False

     ckpt_path = latest_ckpt.model_checkpoint_path
     self.baseregsaver.restore(sess,ckpt_path)
     # Escape the path prefix so regex metacharacters in it are literal, and
     # use a raw string for the digit class ('\d' in a plain string is an
     # invalid escape sequence).
     matchObj = re.match(re.escape(outfilename) + r'-(\d*)',ckpt_path)
     self.baseregstartat = int(matchObj.group(1))+1
     # SECURITY: pickle.load can execute arbitrary code; the training-data
     # file must come from a trusted cache directory.
     with open(traindatafilename,'rb') as tdfile:
         inData = pickle.load(tdfile)
     if not isinstance(inData,dict):
         # Anything that is not a bare dict is unpacked as (train data, conf).
         self.baseregtrainData, loadconf = inData
         print('Parameters that dont match for base:')
         PoseTools.compare_conf(self.conf, loadconf)
     else:
         print("No config was stored for base. Not comparing conf")
         self.baseregtrainData = inData
     print("Loading base variables from %s"%ckpt_path)
     return True
开发者ID:mkabra,项目名称:poseTF,代码行数:29,代码来源:PoseRegression.py

示例4: initialize_net

# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
 def initialize_net(self, sess):
     """Initialize this network's variables, then those of every dependent net.

     Runs the initializer for ``PoseTools.get_vars(self.net_name)`` with
     ``self.fd`` as the feed dict, calls ``self.init_td()``, recurses into
     each entry of ``self.dep_nets`` and finally calls
     ``initialize_remaining_vars(sess)``.
     """
     scope = self.net_name
     init_op = tf.variables_initializer(PoseTools.get_vars(scope))
     sess.run(init_op, feed_dict=self.fd)
     print("Not loading {:s} variables. Initializing them".format(scope))

     self.init_td()
     for dependency in self.dep_nets:
         dependency.initialize_net(sess)
     initialize_remaining_vars(sess)
开发者ID:mkabra,项目名称:poseTF,代码行数:11,代码来源:PoseCommon_dataset.py

示例5: restoreEval

# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
def restoreEval(sess,evalsaver,restore,conf,feed_dict):
    """Restore the 'eval' variables from the latest checkpoint, or initialize them.

    Args:
        sess: TensorFlow session used to run the initializer / restore.
        evalsaver: saver whose ``restore`` method loads the checkpoint.
        restore: when falsy, any existing checkpoint is ignored.
        conf: configuration providing ``cachedir``, ``evaloutname`` and
            ``evalckptname``.
        feed_dict: feed dict passed to ``sess.run`` when initializing.

    Returns:
        The step to resume from: 0 after a fresh initialization, otherwise
        the restored checkpoint's step number plus one.
    """
    outfilename = os.path.join(conf.cachedir,conf.evaloutname)
    latest_ckpt = tf.train.get_checkpoint_state(conf.cachedir,
                                        latest_filename = conf.evalckptname)
    if not latest_ckpt or not restore:
        evalstartat = 0
        sess.run(tf.variables_initializer(PoseTools.get_vars('eval')), feed_dict=feed_dict)
        print("Not loading Eval variables. Initializing them")
    else:
        ckpt_path = latest_ckpt.model_checkpoint_path
        evalsaver.restore(sess,ckpt_path)
        # Escape the path prefix so regex metacharacters in it are literal,
        # and use a raw string for the digit class ('\d' in a plain string
        # is an invalid escape sequence).
        matchObj = re.match(re.escape(outfilename) + r'-(\d*)',ckpt_path)
        evalstartat = int(matchObj.group(1))+1
        print("Loading eval variables from %s"%ckpt_path)
    return  evalstartat
开发者ID:mkabra,项目名称:poseTF,代码行数:16,代码来源:poseEval.py

示例6: restoreGen

# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
def restoreGen(sess,conf,genSaver,restore=True):
    """Restore the 'poseGen' variables from the latest checkpoint, or initialize them.

    Args:
        sess: TensorFlow session used to run the initializer / restore.
        conf: configuration providing ``cachedir``, ``genoutname`` and
            ``genckptname``.
        genSaver: saver whose ``restore`` method loads the checkpoint.
        restore: when falsy, any existing checkpoint is ignored.

    Returns:
        ``(didRestore, startat)``: whether a checkpoint was loaded, and the
        step to resume from (0 after a fresh initialization, otherwise the
        checkpoint's step number plus one).
    """
    outfilename = os.path.join(conf.cachedir,conf.genoutname)
    latest_ckpt = tf.train.get_checkpoint_state(conf.cachedir,
                                        latest_filename = conf.genckptname)
    if not latest_ckpt or not restore:
        startat = 0
        # tf.initialize_variables is the deprecated spelling of
        # tf.variables_initializer.
        sess.run(tf.variables_initializer(PoseTools.get_vars('poseGen')))
        print("Not loading gen variables. Initializing them")
        didRestore = False
    else:
        ckpt_path = latest_ckpt.model_checkpoint_path
        genSaver.restore(sess,ckpt_path)
        # Escape the path prefix so regex metacharacters in it are literal,
        # and use a raw string for the digit class ('\d' in a plain string
        # is an invalid escape sequence).
        matchObj = re.match(re.escape(outfilename) + r'-(\d*)',ckpt_path)
        startat = int(matchObj.group(1))+1
        print("Loading gen variables from %s"%ckpt_path)
        didRestore = True

    return didRestore,startat
开发者ID:mkabra,项目名称:poseTF,代码行数:19,代码来源:poseGen.py

示例7: create_shape_saver

# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
def create_shape_saver(conf):
    """Return a ``tf.train.Saver`` over ``PoseTools.get_vars('shape')``.

    At most ``conf.maxckpt`` checkpoints are kept.
    """
    shape_vars = PoseTools.get_vars('shape')
    return tf.train.Saver(var_list=shape_vars, max_to_keep=conf.maxckpt)
开发者ID:mkabra,项目名称:poseTF,代码行数:5,代码来源:poseShape.py

示例8: createGenSaver

# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
def createGenSaver(conf):
    """Return a ``tf.train.Saver`` over ``PoseTools.get_vars('poseGen')``.

    At most ``conf.maxckpt`` checkpoints are kept.
    """
    gen_vars = PoseTools.get_vars('poseGen')
    return tf.train.Saver(var_list=gen_vars, max_to_keep=conf.maxckpt)
开发者ID:mkabra,项目名称:poseTF,代码行数:5,代码来源:poseGen.py

示例9: Mixture

# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
            value=tf.zeros_like(y_ph))
# y_y = Mixture(cat=cat, components=components_y, value=tf.zeros_like(y_y_ph))
# Note: A bug exists in Mixture which prevents samples from it to have
# a shape of [None]. For now fix it using the value argument, as
# sampling is not necessary for MAP estimation anyways.

#

# There are no latent variables to infer. Thus inference is concerned
# with only training model parameters, which are baked into how we
# specify the neural networks.
n_epoch = 1000

# Set up MAP inference with the model output `y` bound to the observed
# placeholder `y_ph`; only the variables returned by
# PoseTools.get_vars('nn') are handed to the inference object.
inference = mymap.MAP(data={y: y_ph})
v_list = PoseTools.get_vars('nn')
inference.initialize(var_list=v_list,n_iter=n_epoch)

# Hand-rolled optimizer: Adam on inference.loss with a staircase
# exponential learning-rate decay (x0.9 every 300 steps).

myOpt = True
starter_learning_rate = 0.001
# The current step is supplied through this placeholder
# (presumably fed by the training loop -- not visible here).
step = tf.placeholder(tf.int64)
learning_rate = tf.train.exponential_decay(starter_learning_rate,
                                           step, 300, 0.9, staircase=True)

if myOpt:
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(inference.loss)

sess = tf.InteractiveSession()
# sess = ed.get_session()
开发者ID:mkabra,项目名称:poseTF,代码行数:33,代码来源:testEdward_ims.py

示例10: range

# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
# Fragment of a method body (the enclosing `def` is not visible here).
# Seed the feed dict: dropout keep-probability and both phase-train flags
# switched off.
self.feed_dict[self.ph['keep_prob']] = mdn_dropout
self.feed_dict[self.ph['phase_train_base']] = False
self.feed_dict[self.ph['phase_train_mdn']] = False
self.trainType = trainType

# Build the network, open the databases and create the savers.
y = self.create_network()
self.openDBs()
self.createBaseSaver()
self.create_saver()

# Bring the label locations to the network's output scale by dividing by
# conf.rescale and conf.pool_scale.
y_label = self.ph['base_locs'] / self.conf.rescale / self.conf.pool_scale
# Bind each per-class output y[ndx] to the matching slice of the labels.
data_dict = {}
for ndx in range(self.conf.n_classes):
    data_dict[y[ndx]] = y_label[:, ndx, :]
inference = mymap.MAP(data=data_dict)
# Only the variables returned by PoseTools.get_vars('mdn') are handed to
# the inference object.
inference.initialize(var_list=PoseTools.get_vars('mdn'))
self.loss = inference.loss

# Staircase exponential decay (x0.9 every `decay_steps` steps).
# NOTE(review): under Python 3, `5000 / 8` is true division, so
# decay_steps is a float -- confirm that is intended.
starter_learning_rate = 0.00003
decay_steps = 5000 / 8 * self.conf.batch_size
learning_rate = tf.train.exponential_decay(
    starter_learning_rate, self.ph['step'], decay_steps, 0.9,
    staircase=True)

self.opt = tf.train.AdamOptimizer(
    learning_rate=learning_rate).minimize(self.loss)

# Start a session, set up the DB cursors, fill the feed dict from the
# training DB (with distortion enabled) and initialize all global variables.
sess = tf.InteractiveSession()
self.createCursors(sess)
self.updateFeedDict(self.DBType.Train, sess=sess, distort=True)
sess.run(tf.global_variables_initializer())
开发者ID:mkabra,项目名称:poseTF,代码行数:33,代码来源:scratch.py

示例11: createBaseRegSaver

# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
 def createBaseRegSaver(self):
     """Create ``self.baseregsaver`` over ``PoseTools.get_vars('regbase')``.

     At most ``self.conf.maxckpt`` checkpoints are kept.
     """
     regbase_vars = PoseTools.get_vars('regbase')
     self.baseregsaver = tf.train.Saver(var_list=regbase_vars,
                                        max_to_keep=self.conf.maxckpt)
开发者ID:mkabra,项目名称:poseTF,代码行数:5,代码来源:PoseRegression.py

示例12: createEvalSaver

# 需要导入模块: import PoseTools [as 别名]
# 或者: from PoseTools import get_vars [as 别名]
def createEvalSaver(conf):
    """Return a ``tf.train.Saver`` over ``PoseTools.get_vars('eval')``.

    At most ``conf.maxckpt`` checkpoints are kept.
    """
    eval_vars = PoseTools.get_vars('eval')
    return tf.train.Saver(var_list=eval_vars, max_to_keep=conf.maxckpt)
开发者ID:mkabra,项目名称:poseTF,代码行数:6,代码来源:poseEval.py


注:本文中的PoseTools.get_vars方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。