当前位置: 首页>>代码示例>>Python>>正文


Python trainer.Trainer方法代码示例

本文整理汇总了Python中trainer.Trainer（trainer模块中的Trainer类）的典型用法代码示例。如果您想了解trainer.Trainer的具体用法、调用方式或使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在的trainer模块的用法示例。


在下文中一共展示了trainer.Trainer的15个代码示例，这些例子默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的Python代码示例。

示例1: main

# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(args):
    """Load a YAML config onto ``args``, initialise distributed training, and run the Trainer.

    Every top-level key of the YAML file at ``args.config`` is copied onto
    ``args`` as an attribute, overwriting any existing value. If the config
    does not provide ``exp_path``, it defaults to the directory containing
    the config file.

    Args:
        args: namespace-like object with at least ``config`` (path to a YAML
            file) and ``launcher`` (forwarded to ``dist_init``).
    """
    with open(args.config) as f:
        # Compare *parsed* versions. The previous code passed the result of a
        # string comparison to version.parse (misplaced parenthesis) and read
        # ``yaml.version`` instead of ``yaml.__version__``.
        # ``yaml.FullLoader`` only exists from PyYAML 5.1 onwards.
        if version.parse(yaml.__version__) >= version.parse("5.1"):
            config = yaml.load(f, Loader=yaml.FullLoader)
        else:
            # NOTE(review): loader-less yaml.load on old PyYAML can construct
            # arbitrary Python objects; only use with trusted config files.
            config = yaml.load(f)

    for k, v in config.items():
        setattr(args, k, v)

    # Experiment path defaults to the config file's directory.
    if not hasattr(args, 'exp_path'):
        args.exp_path = os.path.dirname(args.config)

    # Distributed init: force the 'spawn' start method before dist_init.
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    dist_init(args.launcher, backend='nccl')

    trainer = Trainer(args)
    trainer.run()
开发者ID:XiaohangZhan,项目名称:conditional-motion-propagation,代码行数:24,代码来源:main.py

示例2: main

# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(_):
    """Seed the RNGs, build the Trainer from the module-level ``config``, then train or test.

    The positional argument is unused.

    Raises:
        Exception: in test mode (``config.is_train`` falsy) when
            ``config.load_path`` is empty.
    """
    # Output directories must exist before anything is saved into them.
    prepare_dirs(config)

    # Seed NumPy (through a dedicated RandomState) and TensorFlow identically.
    seed = config.random_seed
    rng = np.random.RandomState(seed)
    tf.set_random_seed(seed)

    trainer = Trainer(config, rng)
    save_config(config.model_dir, config)

    if config.is_train:
        trainer.train()
        return

    # Testing requires a checkpoint to restore from.
    if not config.load_path:
        raise Exception(
            "[!] You should specify `load_path` to "
            "load a pretrained model")
    trainer.test()
开发者ID:youngjoo-epfl,项目名称:gconvRNN,代码行数:22,代码来源:gcrn_main.py

示例3: main

# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(is_debug):
    """Train the face-crop GAN on the cardio-dance dataset.

    Args:
        is_debug: when truthy, ``log_every=1`` and ``save_every=1`` are passed
            to the Trainer; otherwise the Trainer's own defaults are used.
    """
    # Hard-coded experiment configuration.
    dataset_dir = '../datasets/cardio_dance_512'
    pose_name = '../datasets/cardio_dance_512/poses.npy'
    ckpt_dir = './checkpoints/dance_test_new_down2_res6'
    log_dir = './logs/dance_test_new_down2_res6'
    start_batch = 0
    batch_size = 64

    cache_path = os.path.join(dataset_dir, 'local.db')
    image_folder = dataset.ImageFolderDataset(dataset_dir, cache=cache_path)
    # crop_size: 48 for 512-pixel frames, 96 for HD frames.
    face_dataset = dataset.FaceCropDataset(image_folder, pose_name, image_transforms, crop_size=48)
    data_loader = DataLoader(face_dataset, batch_size=batch_size,
                             drop_last=True, num_workers=4, shuffle=True)

    # NOTE(review): load_models presumably restores from ckpt_dir and returns
    # the batch counter to continue from — confirm against its definition.
    generator, discriminator, batch_num = load_models(ckpt_dir, start_batch)

    debug_kwargs = {'log_every': 1, 'save_every': 1} if is_debug else {}
    trainer = Trainer(ckpt_dir, log_dir, face_dataset, data_loader, **debug_kwargs)
    trainer.train(generator, discriminator, batch_num)
开发者ID:Lotayou,项目名称:everybody_dance_now_pytorch,代码行数:23,代码来源:main.py

示例4: main

# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(is_debug):
    """Train the face-crop GAN on the data under ``../data``.

    Args:
        is_debug: when truthy, ``log_every=1`` and ``save_every=1`` are passed
            to the Trainer; otherwise the Trainer's own defaults are used.
    """
    import os

    # Hard-coded experiment configuration.
    dataset_dir = '../data/face'
    pose_name = '../data/target/pose.npy'
    ckpt_dir = '../checkpoints/face'
    log_dir = '../checkpoints/face/logs'
    start_batch = 10
    batch_size = 10

    cache_path = os.path.join(dataset_dir, 'local.db')
    image_folder = dataset.ImageFolderDataset(dataset_dir, cache=cache_path)
    # crop_size: 48 for 512-pixel frames, 96 for HD frames.
    face_dataset = dataset.FaceCropDataset(image_folder, pose_name, image_transforms, crop_size=48)
    data_loader = DataLoader(face_dataset, batch_size=batch_size,
                             drop_last=True, num_workers=4, shuffle=True)

    # NOTE(review): load_models presumably restores from ckpt_dir and returns
    # the batch counter to continue from — confirm against its definition.
    generator, discriminator, batch_num = load_models(ckpt_dir, start_batch)

    debug_kwargs = {'log_every': 1, 'save_every': 1} if is_debug else {}
    trainer = Trainer(ckpt_dir, log_dir, face_dataset, data_loader, **debug_kwargs)
    trainer.train(generator, discriminator, batch_num)
开发者ID:CUHKSZ-TQL,项目名称:EverybodyDanceNow_reproduce_pytorch,代码行数:25,代码来源:main.py

示例5: train

# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def train(self):
    """Validate the configured components, build a Trainer, and run training.

    Checks (via ``assert``) that the model, the loss and every metric are
    callable, that ``self.config`` has a ``"trainer"`` section, and that
    ``self.data_loader`` is iterable. If the data loader exposes ``classes``,
    its length must equal ``config['arch']['args']['n_class']``. When
    ``config`` has no ``"name"``, one is derived as
    ``"<arch type>_<data_loader type>"``. The created Trainer is stored on
    ``self.trainer``.

    Raises:
        AssertionError: if any of the checks above fail.
    """
    # NOTE(review): these are asserts, so they are stripped under ``python -O``.
    assert callable(self.model), "model is not callable!!"
    assert callable(self.loss), "loss is not callable!!"
    assert all(callable(met) for met in self.metrics), "metrics is not callable!!"
    assert "trainer" in self.config, "trainer hasn't been configured!!"
    assert isinstance(self.data_loader, Iterable), "data_loader is not iterable!!"

    # The number of classes in the dataset must be the same as the model's output size.
    if hasattr(self.data_loader, 'classes'):
        true_classes = len(self.data_loader.classes)
        model_output = self.config['arch']['args']['n_class']
        # Message (Chinese): "the model has {} classes, but there are actually {} classes".
        assert true_classes == model_output, "model分类数为{},可是实际上有{}个类".format(
            model_output, true_classes)

    if "name" not in self.config:
        # str.join takes a single iterable; passing two positional arguments
        # (as before) raised TypeError.
        self.config["name"] = "_".join((self.config["arch"]["type"],
                                        self.config["data_loader"]["type"]))

    self.trainer = Trainer(self.model, self.loss, self.metrics, self.optimizer,
        resume=self.resume, config=self.config, data_loader=self.data_loader,
        valid_data_loader=self.valid_data_loader, lr_scheduler=self.lr_scheduler,
        train_logger=self.train_logger)
    self.trainer.train()
开发者ID:daili0015,项目名称:ModelFeast,代码行数:24,代码来源:classifier.py

示例6: main

# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(config, resume):
    """Build the data loaders, model and loss described by ``config`` and train.

    Args:
        config: parsed configuration; this function reads its ``'loss'`` and
            ``'ignore_index'`` entries directly and passes it to helpers.
        resume: forwarded unchanged to ``Trainer(resume=...)``.
    """
    train_logger = Logger()

    # Data loaders, instantiated from their config sections (train first).
    train_loader, val_loader = [get_instance(dataloaders, section, config)
                                for section in ('train_loader', 'val_loader')]

    # The model's output size comes from the training dataset.
    model = get_instance(models, 'arch', config, train_loader.dataset.num_classes)
    print(f'\n{model}\n')

    loss_cls = getattr(losses, config['loss'])
    loss = loss_cls(ignore_index=config['ignore_index'])

    trainer = Trainer(model=model, loss=loss, resume=resume, config=config,
                      train_loader=train_loader, val_loader=val_loader,
                      train_logger=train_logger)
    trainer.train()
开发者ID:yassouali,项目名称:pytorch_segmentation,代码行数:27,代码来源:train.py

示例7: main

# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(_):
  """Run the experiment configured by the module-level ``config``.

  The task name must start with 'tsp' (case-insensitive). Unset encoder and
  decoder lengths default to ``config.max_data_length``. Trains when
  ``config.is_train`` is truthy, otherwise tests from ``config.load_path``.

  Raises:
    Exception: on an unsupported task, or when testing without ``load_path``.
  """
  prepare_dirs_and_logger(config)

  task = config.task.lower()
  if not task.startswith('tsp'):
    raise Exception("[!] Task should starts with TSP")

  # Fill in missing sequence lengths from the maximum data length.
  for attr in ('max_enc_length', 'max_dec_length'):
    if getattr(config, attr) is None:
      setattr(config, attr, config.max_data_length)

  seed = config.random_seed
  rng = np.random.RandomState(seed)
  tf.set_random_seed(seed)

  trainer = Trainer(config, rng)
  save_config(config.model_dir, config)

  if not config.is_train:
    if not config.load_path:
      raise Exception("[!] You should specify `load_path` to load a pretrained model")
    trainer.test()
  else:
    trainer.train()

  tf.logging.info("Run finished.")
开发者ID:devsisters,项目名称:neural-combinatorial-rl-tensorflow,代码行数:27,代码来源:main.py

示例8: main

# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main():
    """Run video testing, or the alternating train/test loop.

    Reads the module-level ``args`` and ``checkpoint``. When
    ``args.data_test == ['video']`` a ``VideoTester`` is run. Otherwise, if
    ``checkpoint.ok``, training and testing alternate until
    ``Trainer.terminate()`` returns True, and then ``checkpoint.done()`` is
    called. No loss is built when ``args.test_only`` is set.
    """
    if args.data_test == ['video']:
        from videotester import VideoTester
        # Keep the network in a local name. The previous version declared
        # ``global model`` and rebound the imported ``model`` module to a
        # Model instance, clobbering the module for any later caller.
        _model = model.Model(args, checkpoint)
        t = VideoTester(args, _model, checkpoint)
        t.test()
    elif checkpoint.ok:
        loader = data.Data(args)
        _model = model.Model(args, checkpoint)
        _loss = loss.Loss(args, checkpoint) if not args.test_only else None
        t = Trainer(args, loader, _model, _loss, checkpoint)
        while not t.terminate():
            t.train()
            t.test()

        checkpoint.done()
开发者ID:thstkdgus35,项目名称:EDSR-PyTorch,代码行数:20,代码来源:main.py

示例9: main

# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(args):
    """Load the train/dev/test splits, then optionally train and/or evaluate.

    Args:
        args: parsed options; ``do_train`` triggers training and ``do_eval``
            reloads the saved model and evaluates on the test split.
    """
    init_logger()
    set_seed(args)
    tokenizer = load_tokenizer(args)

    # All three splits are always loaded, in train -> dev -> test order.
    train_dataset, dev_dataset, test_dataset = [
        load_and_cache_examples(args, tokenizer, mode=split)
        for split in ("train", "dev", "test")
    ]

    trainer = Trainer(args, train_dataset, dev_dataset, test_dataset)

    if args.do_train:
        trainer.train()

    if args.do_eval:
        trainer.load_model()
        trainer.evaluate("test")
开发者ID:monologg,项目名称:JointBERT,代码行数:19,代码来源:main.py

示例10: main

# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(args):
    """Load only the splits that are needed, then optionally train and/or evaluate.

    The test split is loaded when training or evaluating; the train split is
    loaded only when training. ``dev_dataset`` is never loaded and is always
    passed to the Trainer as ``None``.
    """
    init_logger()
    set_seed(args)

    tokenizer = load_tokenizer(args)

    dev_dataset = None
    needs_test = args.do_train or args.do_eval
    test_dataset = load_and_cache_examples(args, tokenizer, mode="test") if needs_test else None
    train_dataset = load_and_cache_examples(args, tokenizer, mode="train") if args.do_train else None

    trainer = Trainer(args, train_dataset, dev_dataset, test_dataset)

    if args.do_train:
        trainer.train()

    if args.do_eval:
        trainer.load_model()
        trainer.evaluate("test", "eval")
开发者ID:monologg,项目名称:KoBERT-NER,代码行数:25,代码来源:main.py

示例11: main

# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(args):
    """Configure logging, sanity-check the CSV index file, and fit the model.

    Logging goes at DEBUG level both to ``<args.logdir>/logging.txt`` and to
    the console.

    Raises:
        ValueError: if ``args.index_file`` is not an existing file, or if the
            first comma-separated field of its first line is not an existing path.
    """
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s',
        filename=os.path.join(args.logdir, 'logging.txt'))

    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(logging.Formatter('%(asctime)s %(levelname)-8s: %(message)s'))
    logging.getLogger().addHandler(console)

    filename = os.path.realpath(args.index_file)
    if not os.path.isfile(filename):
        raise ValueError('No such index_file: {}'.format(filename))
    print("Reading csv file: {}".format(filename))

    # Only the first row's first column is checked.
    with open(filename, "r") as f:
        first_input = f.readline().strip().split(',')[0]
    if not os.path.exists(first_input):
        raise ValueError('Input path in csv not exist: {}'.format(first_input))

    t = trainer.Trainer(filename, args)
    t.fit()
开发者ID:neycyanshi,项目名称:DDRNet,代码行数:27,代码来源:train.py

示例12: main

# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(args):
    """Seed torch, build the seq2seq datasets, then train or test.

    Only ``args.network_type == 'seq2seq'`` with ``args.dataset == 'msrvtt'``
    is supported. In ``'train'`` mode the args are saved and training runs;
    any other mode tests a checkpoint from ``args.load_path``.

    Raises:
        Exception: if the dataset is unknown for the seq2seq network, or if a
            non-train mode is requested without ``args.load_path``.
        NotImplementedError: if ``args.network_type`` is not supported.
    """
    prepare_dirs(args)

    torch.manual_seed(args.random_seed)
    if args.num_gpu > 0:
        torch.cuda.manual_seed(args.random_seed)

    if args.network_type == 'seq2seq':
        vocab = data.common_loader.Vocab(args.vocab_file, args.max_vocab_size)
        if args.dataset == 'msrvtt':
            dataset = {split: data.common_loader.MSRVTTBatcher(args, split, vocab)
                       for split in ('train', 'val', 'test')}
        else:
            raise Exception(f"Unknown dataset: {args.dataset} for the corresponding network type: {args.network_type}")
    else:
        # ``NotImplemented`` is a constant, not an exception class; calling it
        # raised TypeError. Report the network type, since that is what failed.
        raise NotImplementedError(f"{args.network_type} is not supported")

    trainer = Trainer(args, dataset)

    if args.mode == 'train':
        save_args(args)
        trainer.train()
    elif not args.load_path:
        raise Exception("[!] You should specify `load_path` to load a pretrained model")
    else:
        trainer.test(args.mode)
开发者ID:ramakanth-pasunuru,项目名称:video_captioning_rl,代码行数:34,代码来源:main.py

示例13: get_trainer

# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def get_trainer(config):
    """Reset the default TensorFlow graph and return a new Trainer for ``config``.

    The Trainer receives ``config.data_type`` as its second argument. After
    construction, TensorFlow logging verbosity is lowered to ERROR.
    """
    print('tf: resetting default graph!')
    tf.reset_default_graph()

    # Random seeding is currently disabled:
    #tf.set_random_seed(config.random_seed)
    #np.random.seed(22)

    data_type = config.data_type
    print('Using data_type ', data_type)
    built = Trainer(config, data_type)
    print('built trainer successfully')

    tf.logging.set_verbosity(tf.logging.ERROR)
    return built
开发者ID:mkocaoglu,项目名称:CausalGAN,代码行数:16,代码来源:main.py

示例14: main

# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(config):
    """Train (SAGAN or QGAN) or test a GAN on the images under ``config.image_path``.

    Sets ``config.n_class`` to the number of sub-directories of
    ``config.image_path`` and creates the model/sample/log/attention output
    folders before training or testing.

    Raises:
        ValueError: in training mode when ``config.model`` is neither
            ``'sagan'`` nor ``'qgan'``.
    """
    # For fast training: let cuDNN benchmark and pick convolution algorithms.
    cudnn.benchmark = True

    # One class per sub-directory ('*/' only matches directories).
    config.n_class = len(glob.glob(os.path.join(config.image_path, '*/')))
    print('number class:', config.n_class)

    data_loader = Data_Loader(config.train, config.dataset, config.image_path, config.imsize,
                              config.batch_size, shuf=config.train)

    # Create output directories if they do not exist.
    for folder in (config.model_save_path, config.sample_path,
                   config.log_path, config.attn_path):
        make_folder(folder, config.version)

    print('config data_loader and build logs folder')

    if not config.train:
        tester = Tester(data_loader.loader(), config)
        tester.test()
        return

    if config.model == 'sagan':
        trainer = Trainer(data_loader.loader(), config)
    elif config.model == 'qgan':
        trainer = qgan_trainer(data_loader.loader(), config)
    else:
        # Previously fell through and crashed with UnboundLocalError on `trainer`.
        raise ValueError(
            "Unknown model: {!r} (expected 'sagan' or 'qgan')".format(config.model))
    trainer.train()
开发者ID:sxhxliang,项目名称:BigGAN-pytorch,代码行数:31,代码来源:main.py

示例15: main

# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(config):
    """Instantiate every training component from ``config`` and run the Trainer.

    The data loader, model, optimizer and LR scheduler are built through
    ``config.init_obj``; the loss and metrics are looked up by name in
    ``module_loss`` / ``module_metric``.
    """
    logger = config.get_logger('train')

    # Data: the validation loader is split off the training loader.
    data_loader = config.init_obj('data_loader', module_data)
    valid_data_loader = data_loader.split_validation()

    # Model architecture, logged for inspection.
    model = config.init_obj('arch', module_arch)
    logger.info(model)

    # Loss and metric functions, resolved by name.
    criterion = getattr(module_loss, config['loss'])
    metrics = [getattr(module_metric, name) for name in config['metrics']]

    # Only parameters that require gradients are given to the optimizer.
    # To disable the scheduler, delete every line containing lr_scheduler.
    trainable_params = (p for p in model.parameters() if p.requires_grad)
    optimizer = config.init_obj('optimizer', torch.optim, trainable_params)

    lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)

    trainer = Trainer(model, criterion, metrics, optimizer,
                      config=config,
                      data_loader=data_loader,
                      valid_data_loader=valid_data_loader,
                      lr_scheduler=lr_scheduler)
    trainer.train()
开发者ID:victoresque,项目名称:pytorch-template,代码行数:30,代码来源:train.py


注：本文中的trainer.Trainer示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。