当前位置: 首页>>代码示例>>Python>>正文


Python model.Model方法代码示例

本文整理汇总了Python中model.model.Model方法的典型用法代码示例。如果您正苦于以下问题：Python model.Model方法的具体用法？Python model.Model怎么用？Python model.Model使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块model.model的用法示例。


在下文中一共展示了model.Model方法的11个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_save_load_network

# 需要导入模块: from model import model [as 别名]
# 或者: from model.model import Model [as 别名]
def test_save_load_network(self):
        """Check that a saved network checkpoint reloads into identical weights.

        Saves ``self.model``'s network, checks that the checkpoint file and the
        log file exist, loads the checkpoint into a newly built ``Model`` and
        compares every parameter tensor with the original one.
        """
        local_net = Net_arch(self.hp)
        # Keep the loss local rather than overwriting the fixture's
        # ``self.loss_f`` (the loss ``self.model`` was built with).
        loss_f = nn.MSELoss()
        local_model = Model(self.hp, local_net, loss_f)

        self.model.save_network(self.logger)
        save_filename = "%s_%d.pt" % (self.hp.log.name, self.model.step)
        save_path = os.path.join(self.hp.log.chkpt_dir, save_filename)
        # load_network() below is given no path, so it presumably reads the
        # checkpoint location from this config entry.
        self.hp.load.network_chkpt_path = save_path

        # os.path.isfile() is only True for an existing regular file, so a
        # separate os.path.exists() check is redundant.
        assert os.path.isfile(save_path)
        assert os.path.isfile(self.hp.log.log_file_path)

        local_model.load_network(logger=self.logger)
        loaded = list(local_model.net.parameters())
        original = list(self.model.net.parameters())
        # zip() stops at the shorter list; compare the counts so a missing
        # parameter cannot go unnoticed.
        assert len(loaded) == len(original)
        for load, origin in zip(loaded, original):
            # ``==`` broadcasts, so check shapes before comparing values.
            assert load.shape == origin.shape
            assert (load == origin).all()
开发者ID:ryul99,项目名称:pytorch-project-template,代码行数:23,代码来源:model_test.py

示例2: test_save_load_state

# 需要导入模块: from model import model [as 别名]
# 或者: from model.model import Model [as 别名]
def test_save_load_state(self):
        """Check that a saved training state resumes weights, epoch and step.

        Saves ``self.model``'s training state, checks that the state file and
        the log file exist, loads the state into a newly built ``Model`` and
        compares every parameter tensor plus the epoch and step counters.
        """
        local_net = Net_arch(self.hp)
        # Keep the loss local rather than overwriting the fixture's
        # ``self.loss_f`` (the loss ``self.model`` was built with).
        loss_f = nn.MSELoss()
        local_model = Model(self.hp, local_net, loss_f)

        self.model.save_training_state(self.logger)
        save_filename = "%s_%d.state" % (self.hp.log.name, self.model.step)
        save_path = os.path.join(self.hp.log.chkpt_dir, save_filename)
        # load_training_state() below is given no path, so it presumably
        # reads the state location from this config entry.
        self.hp.load.resume_state_path = save_path

        # os.path.isfile() is only True for an existing regular file, so a
        # separate os.path.exists() check is redundant.
        assert os.path.isfile(save_path)
        assert os.path.isfile(self.hp.log.log_file_path)

        local_model.load_training_state(logger=self.logger)
        loaded = list(local_model.net.parameters())
        original = list(self.model.net.parameters())
        # zip() stops at the shorter list; compare the counts so a missing
        # parameter cannot go unnoticed.
        assert len(loaded) == len(original)
        for load, origin in zip(loaded, original):
            # ``==`` broadcasts, so check shapes before comparing values.
            assert load.shape == origin.shape
            assert (load == origin).all()
        assert local_model.epoch == self.model.epoch
        assert local_model.step == self.model.step
开发者ID:ryul99,项目名称:pytorch-project-template,代码行数:25,代码来源:model_test.py

示例3: main

# 需要导入模块: from model import model [as 别名]
# 或者: from model.model import Model [as 别名]
def main(argv=None):
    """Configure a training run and train the model on the 'train' split.

    Args:
        argv: Accepted for entry-point compatibility; not used.
    """
    cfg = Config()
    # Apply the run-specific overrides before display(), so the printed
    # configuration reflects the values actually used.
    overrides = {
        'DATA_DIR': ['/data/'],
        'LOG_DIR': './log/model',
        'MODE': 'training',
        'STEPS_PER_EPOCH_VAL': 180,
    }
    for attr, value in overrides.items():
        setattr(cfg, attr, value)
    cfg.display()

    train_set = Dataset(cfg, 'train')
    net = Model(cfg)
    net.compile()
    # The second argument (a validation dataset, presumably) is left as None.
    net.train(train_set, None)
开发者ID:yaojieliu,项目名称:CVPR2019-DeepTreeLearningForZeroShotFaceAntispoofing,代码行数:19,代码来源:train.py

示例4: main

# 需要导入模块: from model import model [as 别名]
# 或者: from model.model import Model [as 别名]
def main():
    """Build the 3D estimation model selected by ``args.set`` and run ``args.phase``.

    Raises:
        ValueError: if ``args.set`` is neither 'gta' nor 'kitti', or if
            ``args.phase`` is neither 'train' nor 'test'.
    """
    torch.set_num_threads(multiprocessing.cpu_count())
    args = parse_args()
    # The dataset decides which Model implementation is used.
    if args.set == 'gta':
        from model.model import Model
    elif args.set == 'kitti':
        from model.model_cen import Model
    else:
        raise ValueError("Model not found")

    model = Model(args.arch,
                  args.roi_name,
                  args.down_ratio,
                  args.roi_kernel)
    model = nn.DataParallel(model)
    model = model.to(args.device)

    if args.phase == 'train':
        run_training(model, args)
    elif args.phase == 'test':
        test_model(model, args)
    else:
        # Fail loudly instead of exiting without doing anything.
        raise ValueError("Unknown phase: %r" % (args.phase,))
开发者ID:ucbdrive,项目名称:3d-vehicle-tracking,代码行数:23,代码来源:mono_3d_estimation.py

示例5: setup_method

# 需要导入模块: from model import model [as 别名]
# 或者: from model.model import Model [as 别名]
def setup_method(self, method):
        """Create a fresh network, cross-entropy loss and Model before each test.

        ``method`` (the test about to run, as passed by pytest's per-test
        setup hook) is accepted but not used.
        """
        super(TestModel, self).setup_method()
        # The loss does not depend on the network, so it can be built first.
        self.loss_f = nn.CrossEntropyLoss()
        self.net = Net_arch(self.hp)
        self.model = Model(self.hp, self.net, self.loss_f)
开发者ID:ryul99,项目名称:pytorch-project-template,代码行数:7,代码来源:model_test.py

示例6: convert_layer_to_tensor

# 需要导入模块: from model import model [as 别名]
# 或者: from model.model import Model [as 别名]
def convert_layer_to_tensor(layer, dtype=None, name=None, as_ref=False):
    """Return ``layer.output`` for a ``Layer`` or ``Model``, else ``NotImplemented``.

    The signature mirrors a TensorFlow tensor-conversion function
    (presumably registered via ``register_tensor_conversion_function``;
    the registration is not shown here). ``dtype``, ``name`` and ``as_ref``
    are accepted but ignored. ``NotImplemented`` tells the caller that this
    function does not handle the given object.
    """
    if isinstance(layer, (Layer, Model)):
        return layer.output
    return NotImplemented
开发者ID:akosiorek,项目名称:hart,代码行数:6,代码来源:__init__.py

示例7: loadModelAndData

# 需要导入模块: from model import model [as 别名]
# 或者: from model.model import Model [as 别名]
def _load_json(path):
    """Return the parsed JSON content of the file at *path*."""
    with open(path) as f:
        return json.load(f)


def _recreate_dir(path):
    """Make *path* an empty directory, deleting any existing tree there first."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)


def loadModelAndData(num):
    """Build the model, reset the output directories and load the dialogue splits.

    Reads model options and output paths from the module-level ``args``.

    Args:
        num: iteration passed to ``model.loadModel(iter=num)``; only used
            when ``args.load_param`` is set.

    Returns:
        ``(model, val_dials, test_dials)`` where the last two are the parsed
        contents of ``data/val_dials.json`` and ``data/test_dials.json``.
    """
    # Load dictionaries
    input_lang_index2word = _load_json('data/input_lang.index2word.json')
    input_lang_word2index = _load_json('data/input_lang.word2index.json')
    output_lang_index2word = _load_json('data/output_lang.index2word.json')
    output_lang_word2index = _load_json('data/output_lang.word2index.json')

    # Build the model and optionally reload an existing checkpoint.
    model = Model(args, input_lang_index2word, output_lang_index2word, input_lang_word2index, output_lang_word2index)
    if args.load_param:
        model.loadModel(iter=num)

    # Start from empty output directories; previous results are deleted.
    _recreate_dir(args.decode_output)
    _recreate_dir(args.valid_output)

    # Load validation and test file lists.
    val_dials = _load_json('data/val_dials.json')
    test_dials = _load_json('data/test_dials.json')
    return model, val_dials, test_dials
开发者ID:budzianowski,项目名称:multiwoz,代码行数:39,代码来源:test.py

示例8: main

# 需要导入模块: from model import model [as 别名]
# 或者: from model.model import Model [as 别名]
def main():
  """Recognise the text in one image and print ``<path>: <text>``.

  Command-line arguments:
    --path: image file to recognise (required).
    --checkpoint: model checkpoint (default 'data/model.ckpt').
  """
  parser = argparse.ArgumentParser()
  parser.add_argument('--path', type=str, required=True,
                      help='path to image file.')
  parser.add_argument('--checkpoint', type=str, default='data/model.ckpt',
                      help='path to model checkpoint.')
  args = parser.parse_args()
  # Validate the input before building the graph. parser.error() prints the
  # message and exits; unlike ``assert`` it is not stripped under ``python -O``.
  if not os.path.exists(args.path):
    parser.error('%s does not exist!' % args.path)

  params = {
    'checkpoint': args.checkpoint,
    'dataset':{
      'dataset_dir': 'data',
      'charset_filename': 'charset_size=63.txt',
      'max_sequence_length': 30,
    },
    'beam_width': 1,
    'summary': False
  }
  model = Model(params, ModeKeys.INFER)
  image = tf.placeholder(tf.uint8, (1, 32, 100, 3), name='image')
  predictions, _, _ = model({'image': image}, None)

  # Match the placeholder: RGB, resized to 100x32 (PIL sizes are
  # (width, height)), plus a leading batch axis -> shape (1, 32, 100, 3).
  with Image.open(args.path) as img:
    raw_image = img.convert('RGB').resize((100, 32), Image.BILINEAR)
  raw_image = np.array(raw_image)[None, :]

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())
    # NOTE(review): no explicit restore happens here; presumably Model loads
    # params['checkpoint'] itself. Confirm.
    outputs = sess.run(predictions, feed_dict={image: raw_image})
    text = outputs['predicted_text'][0]
    print('%s: %s' % (args.path, text))
开发者ID:FangShancheng,项目名称:conv-ensemble-str,代码行数:35,代码来源:demo.py

示例9: __init__

# 需要导入模块: from model import model [as 别名]
# 或者: from model.model import Model [as 别名]
def __init__(self, classifier_data):
	"""Initialise the base class and load the classifier network.

	Args:
		classifier_data: provides ``port`` and ``bufsize`` (forwarded to the
			base class) plus ``graph_path`` and ``checkpoint_path`` (used to
			initialise the network in a new TensorFlow session).
	"""
	super().__init__(classifier_data.port, classifier_data.bufsize)
	session = tf.Session()
	self.sess = session
	network = Model()
	self.nn = network
	network.init(classifier_data.graph_path, classifier_data.checkpoint_path, session)
	self.lib = getLib()
开发者ID:BerkeleyLearnVerify,项目名称:VerifAI,代码行数:10,代码来源:classifier.py

示例10: main

# 需要导入模块: from model import model [as 别名]
# 或者: from model.model import Model [as 别名]
def main(args, defaults):
    """Set up logging, then build the OCR Model in a TF session and launch it.

    Args:
        args: raw arguments, handed to ``process_args``.
        defaults: default values, handed to ``process_args``.
    """
    parameters = process_args(args, defaults)

    log_format = '%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s'
    # DEBUG and above are written to the log file...
    logging.basicConfig(level=logging.DEBUG, format=log_format,
                        filename=parameters.log_path)
    # ...and INFO and above are also echoed to the console.
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter(log_format))
    logging.getLogger('').addHandler(console)

    # Options forwarded to Model unchanged under the same keyword names.
    passthrough = (
        'phase', 'visualize', 'data_path', 'data_base_dir', 'output_dir',
        'batch_size', 'initial_learning_rate', 'num_epoch',
        'steps_per_checkpoint', 'target_vocab_size', 'model_dir',
        'target_embedding_size', 'attn_num_hidden', 'attn_num_layers',
        'clip_gradients', 'max_gradient_norm', 'load_model',
    )
    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as sess:
        model_kwargs = {key: getattr(parameters, key) for key in passthrough}
        model_kwargs['valid_target_length'] = float('inf')
        model_kwargs['gpu_id'] = parameters.gpu_id
        model_kwargs['use_gru'] = parameters.use_gru
        model_kwargs['session'] = sess
        model = Model(**model_kwargs)
        model.launch()
开发者ID:da03,项目名称:Attention-OCR,代码行数:38,代码来源:launcher.py

示例11: decode

# 需要导入模块: from model import model [as 别名]
# 或者: from model.model import Model [as 别名]
def decode():
    """Restore the trained model and write one reconstructed wav per input utterance.

    Written for Python 2 (``xrange``, ``except Exception, e``).
    NOTE(review): this snippet ends at ``finally:``; the cleanup body that
    follows it in the original source is not included here.
    """
    tfrecords_list, num_batches = read_list(FLAGS.lists_dir, FLAGS.data_type, FLAGS.batch_size)

    with tf.Graph().as_default():
        # The input pipeline is pinned to the CPU.
        with tf.device('/cpu:0'):
            with tf.name_scope('input'):
                cmvn = np.load(FLAGS.inputs_cmvn)
                # The auxiliary statistics file is derived from the main path.
                # NOTE(review): str.replace rewrites every 'cmvn' in the path,
                # including directory names; confirm paths don't contain it twice.
                cmvn_aux = np.load(FLAGS.inputs_cmvn.replace('cmvn', 'cmvn_aux'))
                # paddedFIFO_batch returns an extra `labels` output only when
                # with_labels is set; otherwise labels is set to None.
                if FLAGS.with_labels:
                    inputs, inputs_cmvn, inputs_cmvn_aux, labels, lengths, lengths_aux = paddedFIFO_batch(tfrecords_list, FLAGS.batch_size,
                        FLAGS.input_size, FLAGS.output_size, cmvn=cmvn, cmvn_aux=cmvn_aux, with_labels=FLAGS.with_labels, 
                        num_enqueuing_threads=1, num_epochs=1, shuffle=False)
                else:
                    inputs, inputs_cmvn, inputs_cmvn_aux, lengths, lengths_aux = paddedFIFO_batch(tfrecords_list, FLAGS.batch_size,
                        FLAGS.input_size, FLAGS.output_size, cmvn=cmvn, cmvn_aux=cmvn_aux, with_labels=FLAGS.with_labels,
                        num_enqueuing_threads=1, num_epochs=1, shuffle=False)
                    labels = None
               
        with tf.name_scope('model'):
            model = Model(FLAGS, inputs, inputs_cmvn, inputs_cmvn_aux, labels, lengths, lengths_aux, infer=True)

        # Local variables are initialised too; presumably the num_epochs=1
        # input queue needs them. Confirm.
        init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        sess = tf.Session()
        sess.run(init)

        # Restore the checkpoint recorded in save_model_dir; exit the process
        # with -1 if none is found.
        checkpoint = tf.train.get_checkpoint_state(FLAGS.save_model_dir)
        if checkpoint and checkpoint.model_checkpoint_path:
            tf.logging.info("Restore best model from " + checkpoint.model_checkpoint_path)
            model.saver.restore(sess, checkpoint.model_checkpoint_path)
        else:
            tf.logging.fatal("Checkpoint is not found, please check the best model save path is correct.")
            sys.exit(-1)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for batch in xrange(num_batches):
                if coord.should_stop():
                    break

                sep, mag_lengths = sess.run([model._sep, model._lengths])
                # Map batch-local index i back to its source file. This relies
                # on the queue preserving tfrecords_list order (shuffle=False).
                # NOTE(review): the loop always runs batch_size times, so if
                # len(tfrecords_list) is not a multiple of batch_size the last
                # batch may index past the end of tfrecords_list or of sep.
                # Check how read_list computes num_batches.
                for i in xrange(FLAGS.batch_size):
                    filename = tfrecords_list[FLAGS.batch_size*batch+i]
                    (_, name) = os.path.split(filename)
                    (uttid, _) = os.path.splitext(name)
                    noisy_file = os.path.join(FLAGS.noisy_dir, uttid + '.wav')
                    # Keep only the first mag_lengths[i] entries along axis 1
                    # (presumably the un-padded frames) of utterance i.
                    enhan_sig, rate = reconstruct(np.squeeze(sep[i,:mag_lengths[i],:]), noisy_file)
                    savepath = os.path.join(FLAGS.rec_dir, uttid + '.wav')
                    wav.write(savepath, rate, enhan_sig)

                if (batch+1) % 100 == 0:
                    tf.logging.info("Number of batch processed: %d." % (batch+1))

        # Any error is handed to the coordinator, which is asked to stop.
        except Exception, e:
            coord.request_stop(e)
        finally: 
开发者ID:xuchenglin28,项目名称:speaker_extraction,代码行数:58,代码来源:decode.py


注：本文中的model.model.Model方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。