当前位置: 首页>>代码示例>>Python>>正文


Python models.create_model方法代码示例

本文整理汇总了Python中models.create_model方法的典型用法代码示例。如果您想了解：Python models.create_model方法的具体用法是什么？models.create_model怎么用？有哪些使用示例？那么这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块models的其他用法示例。


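在下文中一共展示了models.create_model方法的15个代码示例，这些例子默认根据受欢迎程度排序。注意：这些示例来自不同的开源项目，每个项目中的models.create_model都是各自独立实现的，参数签名并不相同（例如create_model(opt)、create_model(False, dataset, arch)、create_model(model_name, hparams)等），不能互相替换。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。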

示例1: main

# 需要导入模块: import models [as 别名]
# 或者: from models import create_model [as 别名]
def main(extra_flags):
    """Set up the environment from FLAGS and run the CNN trainer.

    Args:
      extra_flags: Arguments left over after flag parsing. Only the first
        entry (presumably the program name, as in argv) is accepted.

    Raises:
      ValueError: If anything beyond the first entry is present.
    """
    assert len(extra_flags) >= 1
    unknown_flags = extra_flags[1:]
    if unknown_flags:
        raise ValueError('Received unknown flags: %s' % unknown_flags)

    # Build the parameter set, prepare the environment and persist the
    # parameters next to the training outputs.
    params = parameters.make_params_from_flags()
    deploy.setup_env(params)
    parameters.save_params(params, params.train_dir)

    tf_version = deploy.tensorflow_version_tuple()
    deploy.log_fn('TensorFlow:  %i.%i' % (tf_version[0], tf_version[1]))

    dataset = datasets.create_dataset(params.data_dir, params.data_name, params.data_subset)
    model = models.create_model(params.model, dataset)
    set_model_params(model, params)

    cnn_trainer = deploy.TrainerCNN(dataset, model, params)
    cnn_trainer.print_info()
    cnn_trainer.run()
开发者ID:balancap,项目名称:tf-imagenet,代码行数:27,代码来源:train.py

示例2: test_utils

# 需要导入模块: import models [as 别名]
# 或者: from models import create_model [as 别名]
def test_utils():
    """Exercise distiller's parameter and module lookup helpers on ResNet-20."""
    net = models.create_model(False, 'cifar10', 'resnet20_cifar', parallel=False)
    assert net is not None

    # Looking up an empty parameter name is expected to find nothing.
    assert distiller.model_find_param(net, "") is None

    # A parameter is found by its "non-parallel" name.
    assert distiller.model_find_param(net, "layer1.0.conv1.weight") is not None

    # Locate the module object by name, then check the reverse lookup.
    wanted = "layer1.0.conv1"
    target = next((mod for name, mod in net.named_modules() if name == wanted), None)
    assert target is not None

    assert distiller.model_find_module_name(net, target) == wanted
开发者ID:cornell-zhang,项目名称:dnn-quant-ocs,代码行数:23,代码来源:test_basic.py

示例3: run_test

# 需要导入模块: import models [as 别名]
# 或者: from models import create_model [as 别名]
def run_test(epoch=-1):
    """Run the model over the test set and return the accumulated accuracy.

    Args:
      epoch: Epoch number forwarded to ``Writer.print_acc`` (default -1).

    Returns:
      ``writer.acc`` after every test batch has been counted.
    """
    print('Running Test')
    opt = TestOptions().parse()
    opt.serial_batches = True  # no shuffle: iterate samples in order
    loader = DataLoader(opt)
    model = create_model(opt)
    writer = Writer(opt)

    writer.reset_counter()
    for batch in loader:
        model.set_input(batch)
        n_correct, n_examples = model.test()
        writer.update_counter(n_correct, n_examples)
    writer.print_acc(epoch, writer.acc)
    return writer.acc
开发者ID:ranahanocka,项目名称:MeshCNN,代码行数:17,代码来源:test.py

示例4: load

# 需要导入模块: import models [as 别名]
# 或者: from models import create_model [as 别名]
def load(self, checkpoint_path, model_name='tacotron'):
    """Build the inference graph and restore its weights from a checkpoint.

    Sets ``self.model``, ``self.wav_output`` and ``self.session``.

    Args:
      checkpoint_path: Checkpoint restored into the newly created session.
      model_name: Model name passed to ``create_model`` (default 'tacotron').
    """
    print('Constructing model: %s' % model_name)
    text_ids = tf.placeholder(tf.int32, [1, None], 'inputs')
    ref_mel = tf.placeholder(tf.float32, [1, None, 80], 'reference_mel')
    text_lengths = tf.placeholder(tf.int32, [1], 'input_lengths')
    with tf.variable_scope('model'):
        self.model = create_model(model_name, hparams)
        self.model.initialize(text_ids, text_lengths, reference_mel=ref_mel)
        # linear_outputs[0] is presumably the single batch item (batch dim is 1).
        self.wav_output = audio.inv_spectrogram_tensorflow(self.model.linear_outputs[0])

    print('Loading checkpoint: %s' % checkpoint_path)
    self.session = tf.Session()
    # Initialise all variables, then overwrite them with the saved values.
    self.session.run(tf.global_variables_initializer())
    tf.train.Saver().restore(self.session, checkpoint_path)
开发者ID:yanggeng1995,项目名称:vae_tacotron,代码行数:17,代码来源:synthesizer.py

示例5: main

# 需要导入模块: import models [as 别名]
# 或者: from models import create_model [as 别名]
def main():
    """Translate test images with the trained model and write an HTML report."""
    opt = TestOptions().parse()
    opt.is_flip = False
    opt.batchSize = 1
    data_loader = CreateDataLoader(opt)
    model = create_model(opt)
    web_dir = os.path.join(opt.results_dir, 'test')
    webpage = html.HTML(web_dir, 'task {}'.format(opt.exp_name))

    # Only the first opt.how_many items from the loader are processed.
    for idx, batch in enumerate(islice(data_loader, opt.how_many)):
        print('process input image %3.3d/%3.3d' % (idx, opt.how_many))
        translated = model.translation(batch)
        save_images(webpage, translated, 'image%3.3i' % idx, None, width=opt.fine_size)
    webpage.save()
开发者ID:Xiaoming-Yu,项目名称:DMIT,代码行数:17,代码来源:test.py

示例6: main

# 需要导入模块: import models [as 别名]
# 或者: from models import create_model [as 别名]
def main():
    """Train the model, logging progress and saving checkpoints every epoch.

    Training resumes after ``model.start_epoch`` and runs through epoch
    ``opt.niter + opt.niter_decay``. Visual results, loss printouts and
    generator snapshots are emitted according to the ``*_freq`` options.
    """
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    # Samples per epoch, estimated as batches * batch_size.
    dataset_size = len(data_loader) * opt.batch_size
    visualizer = Visualizer(opt)
    model = create_model(opt)
    start_epoch = model.start_epoch
    # Global sample counter, restored to where a resumed run left off.
    total_steps = start_epoch * dataset_size
    for epoch in range(start_epoch + 1, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        model.update_lr()
        # Forces a visual dump on the first iteration of every epoch.
        save_result = True
        for data in data_loader:
            iter_start_time = time.time()
            total_steps += opt.batch_size
            # Samples processed so far within the current epoch.
            epoch_iter = total_steps - dataset_size * (epoch - 1)
            model.prepare_data(data)
            model.update_model()
            if save_result or total_steps % opt.display_freq == 0:
                save_result = save_result or total_steps % opt.update_html_freq == 0
                visualizer.display_current_results(model.get_current_visuals(), epoch, ncols=1, save_result=save_result)
                save_result = False
            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors()
                t = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors)
        # Wall-clock duration of the epoch.
        print('epoch {} cost time {}'.format(epoch, time.time() - epoch_start_time))
        model.save_ckpt(epoch)
        model.save_generator('latest')
        if epoch % opt.save_epoch_freq == 0:
            print('saving the generator at the end of epoch {}, iters {}'.format(epoch, total_steps))
            model.save_generator(epoch)
开发者ID:Xiaoming-Yu,项目名称:DMIT,代码行数:36,代码来源:train.py

示例7: load

# 需要导入模块: import models [as 别名]
# 或者: from models import create_model [as 别名]
def load(self, checkpoint_path, model_name='tacotron'):
    """Build the inference graph and restore its weights from a checkpoint.

    Sets ``self.model``, ``self.wav_output`` and ``self.session``.

    Args:
      checkpoint_path: Checkpoint restored into the newly created session.
      model_name: Model name passed to ``create_model`` (default 'tacotron').
    """
    print('Constructing model: %s' % model_name)
    text_ids = tf.placeholder(tf.int32, [1, None], 'inputs')
    text_lengths = tf.placeholder(tf.int32, [1], 'input_lengths')
    with tf.variable_scope('model'):
        self.model = create_model(model_name, hparams)
        self.model.initialize(text_ids, text_lengths)
        # linear_outputs[0] is presumably the single batch item (batch dim is 1).
        self.wav_output = audio.inv_spectrogram_tensorflow(self.model.linear_outputs[0])

    print('Loading checkpoint: %s' % checkpoint_path)
    self.session = tf.Session()
    # Initialise all variables, then overwrite them with the saved values.
    self.session.run(tf.global_variables_initializer())
    tf.train.Saver().restore(self.session, checkpoint_path)
开发者ID:youssefsharief,项目名称:arabic-tacotron-tts,代码行数:16,代码来源:synthesizer.py

示例8: main

# 需要导入模块: import models [as 别名]
# 或者: from models import create_model [as 别名]
def main(extra_flags):
    """Evaluate a trained CNN with parameters restored from its training run.

    Args:
      extra_flags: Arguments left over after flag parsing. Only the first
        entry (presumably the program name, as in argv) is accepted.

    Raises:
      ValueError: If anything beyond the first entry is present.
    """
    assert len(extra_flags) >= 1
    unknown_flags = extra_flags[1:]
    if unknown_flags:
        raise ValueError('Received unknown flags: %s' % unknown_flags)

    params = parameters.make_params_from_flags()
    deploy.setup_env(params)
    # Overlay the training parameters (loaded from a json file).
    params = replace_with_train_params(params)

    tf_version = deploy.tensorflow_version_tuple()
    deploy.log_fn('TensorFlow:  %i.%i' % (tf_version[0], tf_version[1]))

    dataset = datasets.create_dataset(params.data_dir, params.data_name, params.data_subset)
    model = models.create_model(params.model, dataset)
    train.set_model_params(model, params)

    # One pass over the eval dataset: examples / global batch size, truncated.
    global_batch = params.batch_size * params.num_gpus
    params = params._replace(num_batches=int(dataset.num_examples_per_epoch() / global_batch))

    cnn_trainer = deploy.TrainerCNN(dataset, model, params)
    cnn_trainer.print_info()
    cnn_trainer.run()
开发者ID:balancap,项目名称:tf-imagenet,代码行数:31,代码来源:eval.py

示例9: create_graph

# 需要导入模块: import models [as 别名]
# 或者: from models import create_model [as 别名]
def create_graph(dataset, arch):
    """Build a SummaryGraph for ``arch`` using a dummy input sized for ``dataset``.

    Args:
        dataset: 'imagenet' (dummy input 1x3x224x224) or 'cifar10'
            (dummy input 1x3x32x32).
        arch: Architecture name passed to ``create_model``.

    Returns:
        A SummaryGraph built from the model and the dummy input, moved to the
        GPU with ``.cuda()``.

    Raises:
        AssertionError: If ``dataset`` is not one of the supported names.
    """
    # Pre-bind so an unsupported dataset reaches the assertion (with its
    # explanatory message) instead of failing with UnboundLocalError.
    dummy_input = None
    if dataset == 'imagenet':
        dummy_input = torch.randn((1, 3, 224, 224), requires_grad=False)
    elif dataset == 'cifar10':
        dummy_input = torch.randn((1, 3, 32, 32))
    assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset)

    model = create_model(False, dataset, arch, parallel=False)
    assert model is not None
    return SummaryGraph(model, dummy_input.cuda())
开发者ID:cornell-zhang,项目名称:dnn-quant-ocs,代码行数:12,代码来源:thinning.py

示例10: test_load

# 需要导入模块: import models [as 别名]
# 或者: from models import create_model [as 别名]
def test_load():
    """Loading the trained dense ResNet-20 checkpoint restores a scheduler and epoch 180."""
    logging.getLogger('simple_example').setLevel(logging.INFO)

    net = create_model(False, 'cifar10', 'resnet20_cifar')
    ckpt_path = '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar'
    net, scheduler, epoch = load_checkpoint(net, ckpt_path)
    assert scheduler is not None
    assert epoch == 180
开发者ID:cornell-zhang,项目名称:dnn-quant-ocs,代码行数:10,代码来源:test_infra.py

示例11: test_load_negative

# 需要导入模块: import models [as 别名]
# 或者: from models import create_model [as 别名]
def test_load_negative():
    """Expect FileNotFoundError when loading a checkpoint from a nonexistent path."""
    bogus_path = 'THIS_IS_AN_ERROR/checkpoint_trained_dense.pth.tar'
    with pytest.raises(FileNotFoundError):
        net = create_model(False, 'cifar10', 'resnet20_cifar')
        net, _scheduler, _epoch = load_checkpoint(net, bogus_path)
开发者ID:cornell-zhang,项目名称:dnn-quant-ocs,代码行数:6,代码来源:test_infra.py

示例12: setup_test

# 需要导入模块: import models [as 别名]
# 或者: from models import create_model [as 别名]
def setup_test(arch, dataset, parallel):
    """Create a model plus one ParameterMasker per named parameter.

    Returns:
        ``(model, zeros_mask_dict)``, where ``zeros_mask_dict`` maps each
        parameter name to a ``distiller.ParameterMasker`` built with that name.
    """
    model = create_model(False, dataset, arch, parallel=parallel)
    assert model is not None

    zeros_mask_dict = {
        name: distiller.ParameterMasker(name)
        for name, _ in model.named_parameters()
    }
    return model, zeros_mask_dict
开发者ID:cornell-zhang,项目名称:dnn-quant-ocs,代码行数:12,代码来源:common.py

示例13: name_test

# 需要导入模块: import models [as 别名]
# 或者: from models import create_model [as 别名]
def name_test(dataset, arch):
    """Check module-name (de)normalisation between a plain and a parallel model.

    The parallel model must report exactly one more module name than the plain
    one. Plain name ``k`` (for k >= 1) is paired with parallel name ``k + 1``.
    Each pair must round-trip through ``normalize_module_name`` and
    ``denormalize_module_name``.
    """
    plain = create_model(False, dataset, arch, parallel=False)
    wrapped = create_model(False, dataset, arch, parallel=True)
    assert plain is not None and wrapped is not None

    plain_names = [name for name, _ in plain.named_modules()]
    wrapped_names = [name for name, _ in wrapped.named_modules()]
    assert plain_names is not None and wrapped_names is not None
    assert len(plain_names) + 1 == len(wrapped_names)

    # The parallel wrapper presumably contributes one extra leading entry,
    # so the lists are offset by one after their first element.
    for name, wrapped_name in zip(plain_names[1:], wrapped_names[2:]):
        normalized = normalize_module_name(wrapped_name)
        assert name == normalized
        logging.debug("{} {} {}".format(wrapped_name, name, normalized))
        assert wrapped_name == denormalize_module_name(wrapped, name)
开发者ID:cornell-zhang,项目名称:dnn-quant-ocs,代码行数:16,代码来源:test_summarygraph.py

示例14: main

# 需要导入模块: import models [as 别名]
# 或者: from models import create_model [as 别名]
def main():
    """Train, or only evaluate, the model configured on the command line.

    Rebinds the module-level ``args`` and ``best_top1``. ``best_top1`` keeps the
    highest ``scores[args.metric]`` seen, and it decides which checkpoint is
    flagged as best. Unless ``args.nopdb`` is set, the function finishes by
    entering the pdb debugger.
    """
    global args, best_top1
    args = parse()
    if not args.no_logger:
        # Presumably mirrors stdout into the cache log file.
        tee.Tee(args.cache + '/log.txt')
    print(vars(args))
    seed(args.manual_seed)

    net, loss_fn, optim = create_model(args)
    if args.resume:
        best_top1 = checkpoints.load(args, net, optim)
    print(net)
    runner = train.Trainer()
    loaders = get_dataset(args)
    train_loader = loaders[0]

    if args.evaluate:
        eval_scores = validate(runner, loaders, net, loss_fn, args)
        checkpoints.score_file(eval_scores, "{}/model_000.txt".format(args.cache))
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            runner.train_sampler.set_epoch(epoch)
        epoch_scores = dict(runner.train(train_loader, net, loss_fn, optim, epoch, args))
        epoch_scores.update(validate(runner, loaders, net, loss_fn, args, epoch))

        current = epoch_scores[args.metric]
        is_best = current > best_top1
        best_top1 = max(current, best_top1)
        checkpoints.save(epoch, args, net, optim, is_best, epoch_scores, args.metric)
    if not args.nopdb:
        pdb.set_trace()
开发者ID:gsig,项目名称:actor-observer,代码行数:35,代码来源:main.py

示例15: create_inference_graph

# 需要导入模块: import models [as 别名]
# 或者: from models import create_model [as 别名]
def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
                           clip_stride_ms, window_size_ms, window_stride_ms,
                           dct_coefficient_count, model_architecture):
  """Build an inference graph from WAV bytes to label probabilities.

  A WAV string fed to the ``wav_data`` placeholder is decoded and turned into
  a spectrogram and then MFCC features. Those features are flattened and run
  through the requested model. The softmax of the logits is exposed as the
  ``labels_softmax`` node.

  Args:
    wanted_words: Comma-separated list of the words to recognize.
    sample_rate: Number of samples per second in the input audio.
    clip_duration_ms: Length, in milliseconds, of each analyzed audio clip.
    clip_stride_ms: How often recognition runs; used by models with a cache.
    window_size_ms: Duration of each time slice used to estimate frequencies.
    window_stride_ms: Distance between consecutive time slices.
    dct_coefficient_count: Number of MFCC (DCT) coefficients per time slice.
    model_architecture: Name of the model architecture to build.
  """
  labels = input_data.prepare_words_list(wanted_words.split(','))
  settings = models.prepare_model_settings(
      len(labels), sample_rate, clip_duration_ms, window_size_ms,
      window_stride_ms, dct_coefficient_count)

  wav_bytes = tf.placeholder(tf.string, [], name='wav_data')
  decoded = contrib_audio.decode_wav(
      wav_bytes,
      desired_channels=1,
      desired_samples=settings['desired_samples'],
      name='decoded_sample_data')
  spectrogram = contrib_audio.audio_spectrogram(
      decoded.audio,
      window_size=settings['window_size_samples'],
      stride=settings['window_stride_samples'],
      magnitude_squared=True)
  mfcc = contrib_audio.mfcc(
      spectrogram,
      decoded.sample_rate,
      dct_coefficient_count=dct_coefficient_count)

  # Flatten (time x frequency) features into a single vector per example.
  flat_size = settings['spectrogram_length'] * settings['dct_coefficient_count']
  flat_input = tf.reshape(mfcc, [-1, flat_size])

  logits = models.create_model(
      flat_input, settings, model_architecture, is_training=False,
      runtime_settings={'clip_stride_ms': clip_stride_ms})

  # Named output node that inference callers fetch.
  tf.nn.softmax(logits, name='labels_softmax')
开发者ID:nesl,项目名称:adversarial_audio,代码行数:54,代码来源:freeze.py


注:本文中的models.create_model方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。