當前位置: 首頁>>代碼示例>>Python>>正文


Python trainer.Trainer方法代碼示例

本文整理匯總了Python中trainer.Trainer方法的典型用法代碼示例。如果您正苦於以下問題：Python trainer.Trainer方法的具體用法？Python trainer.Trainer怎麼用？Python trainer.Trainer使用的例子？那麼，這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊trainer的用法示例。


在下文中一共展示了trainer.Trainer方法的15個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: main

# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(args):
    """Load a YAML config into ``args``, initialise distributed mode and train.

    Every top-level key of the YAML file at ``args.config`` is copied onto
    ``args`` as an attribute, overwriting any existing attribute of that name.
    If the config does not provide ``exp_path``, the directory containing the
    config file is used instead.

    Args:
        args: Parsed command-line namespace. Must provide ``config`` (path to
            the YAML file) and ``launcher`` (forwarded to ``dist_init``).
    """
    with open(args.config) as f:
        # ``yaml.FullLoader`` only exists from PyYAML 5.1 on. Parse BOTH sides
        # before comparing; comparing raw strings and parsing the resulting
        # bool does not test the version at all.
        if version.parse(yaml.__version__) >= version.parse("5.1"):
            config = yaml.load(f, Loader=yaml.FullLoader)
        else:
            # NOTE(review): legacy ``yaml.load`` without a Loader can construct
            # arbitrary Python objects; only acceptable for trusted config files.
            config = yaml.load(f)

    for k, v in config.items():
        setattr(args, k, v)

    # Default the experiment directory to the folder holding the config file.
    if not hasattr(args, 'exp_path'):
        args.exp_path = os.path.dirname(args.config)

    # Force the 'spawn' start method before distributed init (presumably
    # because worker processes use CUDA, which does not work with 'fork').
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    dist_init(args.launcher, backend='nccl')

    trainer = Trainer(args)
    trainer.run()
開發者ID:XiaohangZhan,項目名稱:conditional-motion-propagation,代碼行數:24,代碼來源:main.py

示例2: main

# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(_):
    """Build a Trainer from the module-level ``config`` and train or test it.

    The positional argument is unused.

    Raises:
        Exception: in test mode (``config.is_train`` false) when
            ``config.load_path`` is empty.
    """
    # Output directories must exist before the config is saved into them.
    prepare_dirs(config)

    # Seed NumPy (through a dedicated RandomState) and TensorFlow identically.
    seed = config.random_seed
    rng = np.random.RandomState(seed)
    tf.set_random_seed(seed)

    trainer = Trainer(config, rng)
    save_config(config.model_dir, config)

    if config.is_train:
        trainer.train()
        return

    if not config.load_path:
        raise Exception(
            "[!] You should specify `load_path` to "
            "load a pretrained model")
    trainer.test()
開發者ID:youngjoo-epfl,項目名稱:gconvRNN,代碼行數:22,代碼來源:gcrn_main.py

示例3: main

# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(is_debug):
    """Train the face GAN on the cardio-dance dataset.

    Args:
        is_debug: When true, the Trainer is created with ``log_every=1`` and
            ``save_every=1``; otherwise the Trainer's own defaults apply.
    """
    # Fixed experiment paths and hyper-parameters.
    dataset_dir = '../datasets/cardio_dance_512'
    pose_name = '../datasets/cardio_dance_512/poses.npy'
    ckpt_dir = './checkpoints/dance_test_new_down2_res6'
    log_dir = './logs/dance_test_new_down2_res6'
    start_batch = 0
    batch_size = 64

    cache_path = os.path.join(dataset_dir, 'local.db')
    image_folder = dataset.ImageFolderDataset(dataset_dir, cache=cache_path)
    # crop_size: 48 for 512-frame, 96 for HD frame
    face_dataset = dataset.FaceCropDataset(image_folder, pose_name, image_transforms, crop_size=48)
    data_loader = DataLoader(face_dataset, batch_size=batch_size,
                             drop_last=True, num_workers=4, shuffle=True)

    # load_models returns the networks plus a batch counter (presumably the
    # batch to resume from if a checkpoint exists in ckpt_dir).
    generator, discriminator, start_batch = load_models(ckpt_dir, start_batch)

    debug_kwargs = {'log_every': 1, 'save_every': 1} if is_debug else {}
    trainer = Trainer(ckpt_dir, log_dir, face_dataset, data_loader, **debug_kwargs)
    trainer.train(generator, discriminator, start_batch)
開發者ID:Lotayou,項目名稱:everybody_dance_now_pytorch,代碼行數:23,代碼來源:main.py

示例4: main

# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(is_debug):
    """Train the face GAN using the ``../data/face`` directory layout.

    Args:
        is_debug: If true, the Trainer is built with ``log_every=1`` and
            ``save_every=1``; otherwise its default intervals are used.
    """
    import os

    # Paths are relative to the current working directory.
    dataset_dir = '../data/face'
    pose_name = '../data/target/pose.npy'
    ckpt_dir = '../checkpoints/face'
    log_dir = '../checkpoints/face/logs'
    initial_batch = 10
    batch_size = 10

    image_folder = dataset.ImageFolderDataset(
        dataset_dir, cache=os.path.join(dataset_dir, 'local.db'))
    # crop_size: 48 for 512-frame, 96 for HD frame
    face_dataset = dataset.FaceCropDataset(
        image_folder, pose_name, image_transforms, crop_size=48)
    data_loader = DataLoader(
        face_dataset, batch_size=batch_size,
        drop_last=True, num_workers=4, shuffle=True)

    # The batch counter returned here replaces the initial value.
    generator, discriminator, batch_num = load_models(ckpt_dir, initial_batch)

    if not is_debug:
        trainer = Trainer(ckpt_dir, log_dir, face_dataset, data_loader)
    else:
        trainer = Trainer(ckpt_dir, log_dir, face_dataset, data_loader,
                          log_every=1, save_every=1)
    trainer.train(generator, discriminator, batch_num)
開發者ID:CUHKSZ-TQL,項目名稱:EverybodyDanceNow_reproduce_pytorch,代碼行數:25,代碼來源:main.py

示例5: train

# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def train(self):
    """Validate the configured components, build a ``Trainer`` and run it.

    The created trainer is stored on ``self.trainer``. If ``self.config`` has
    no ``"name"`` entry, one is derived as ``"<arch type>_<data_loader type>"``.

    Raises:
        AssertionError: if the model, loss or any metric is not callable, the
            config has no ``"trainer"`` section, the data loader is not
            iterable, or the data loader's class count differs from the
            model's configured ``n_class``.
    """
    assert callable(self.model), "model is not callable!!"
    assert callable(self.loss), "loss is not callable!!"
    assert all(callable(met) for met in self.metrics), "metrics is not callable!!"
    assert "trainer" in self.config, "trainer hasn't been configured!!"
    assert isinstance(self.data_loader, Iterable), "data_loader is not iterable!!"

    # The number of classes in the dataset must be the same as the model's output.
    if hasattr(self.data_loader, 'classes'):
        true_classes = len(self.data_loader.classes)
        model_output = self.config['arch']['args']['n_class']
        # Message: "the model has {} classes, but the data actually has {}".
        assert true_classes == model_output, "model分類數為{},可是實際上有{}個類".format(
            model_output, true_classes)

    if "name" not in self.config:
        # str.join takes ONE iterable; passing two strings raised TypeError.
        self.config["name"] = "_".join((self.config["arch"]["type"],
                                        self.config["data_loader"]["type"]))
    self.trainer = Trainer(self.model, self.loss, self.metrics, self.optimizer,
                           resume=self.resume, config=self.config, data_loader=self.data_loader,
                           valid_data_loader=self.valid_data_loader, lr_scheduler=self.lr_scheduler,
                           train_logger=self.train_logger)
    self.trainer.train()
開發者ID:daili0015,項目名稱:ModelFeast,代碼行數:24,代碼來源:classifier.py

示例6: main

# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(config, resume):
    """Build loaders, model and loss from ``config`` and run training.

    Args:
        config: Configuration mapping; read here via ``config['loss']`` and
            ``config['ignore_index']`` and passed to ``get_instance``.
        resume: Forwarded unchanged to ``Trainer(resume=...)``.
    """
    logger = Logger()

    # Training and validation loaders, created in that order.
    train_loader, val_loader = (get_instance(dataloaders, section, config)
                                for section in ('train_loader', 'val_loader'))

    # The model's output size comes from the training dataset.
    n_classes = train_loader.dataset.num_classes
    model = get_instance(models, 'arch', config, n_classes)
    print(f'\n{model}\n')

    loss_cls = getattr(losses, config['loss'])
    loss = loss_cls(ignore_index=config['ignore_index'])

    Trainer(model=model, loss=loss, resume=resume, config=config,
            train_loader=train_loader, val_loader=val_loader,
            train_logger=logger).train()
開發者ID:yassouali,項目名稱:pytorch_segmentation,代碼行數:27,代碼來源:train.py

示例7: main

# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(_):
    """Run TSP training or evaluation driven by the module-level ``config``.

    Unset ``max_enc_length`` / ``max_dec_length`` default to
    ``max_data_length``. NumPy and TensorFlow are seeded with
    ``config.random_seed`` before the Trainer is built.

    Raises:
        Exception: if ``config.task`` does not start with "tsp"
            (case-insensitive), or if testing is requested without
            ``config.load_path``.
    """
    prepare_dirs_and_logger(config)

    if not config.task.lower().startswith('tsp'):
        raise Exception("[!] Task should starts with TSP")

    # Fall back to the data length for any unset encoder/decoder length.
    for attr in ('max_enc_length', 'max_dec_length'):
        if getattr(config, attr) is None:
            setattr(config, attr, config.max_data_length)

    seed = config.random_seed
    rng = np.random.RandomState(seed)
    tf.set_random_seed(seed)

    trainer = Trainer(config, rng)
    save_config(config.model_dir, config)

    if not config.is_train:
        if not config.load_path:
            raise Exception("[!] You should specify `load_path` to load a pretrained model")
        trainer.test()
    else:
        trainer.train()

    tf.logging.info("Run finished.")
開發者ID:devsisters,項目名稱:neural-combinatorial-rl-tensorflow,代碼行數:27,代碼來源:main.py

示例8: main

# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main():
    """Run video testing, or alternate training and testing.

    Reads the module-level ``args`` and ``checkpoint``. When
    ``args.data_test == ['video']`` the model is evaluated with
    ``VideoTester``. Otherwise, if ``checkpoint.ok`` is true, ``Trainer.train``
    and ``Trainer.test`` alternate until ``Trainer.terminate()`` returns true,
    and then ``checkpoint.done()`` is called.
    """
    if args.data_test == ['video']:
        # Imported lazily: only needed on the video-evaluation path.
        from videotester import VideoTester
        # Keep the model instance in a local name so the global ``model``
        # module is not overwritten (which would break a second call).
        _model = model.Model(args, checkpoint)
        VideoTester(args, _model, checkpoint).test()
        return

    if not checkpoint.ok:
        return

    loader = data.Data(args)
    _model = model.Model(args, checkpoint)
    # No loss object is built in test-only mode.
    _loss = loss.Loss(args, checkpoint) if not args.test_only else None
    t = Trainer(args, loader, _model, _loss, checkpoint)
    while not t.terminate():
        t.train()
        t.test()

    checkpoint.done()
開發者ID:thstkdgus35,項目名稱:EDSR-PyTorch,代碼行數:20,代碼來源:main.py

示例9: main

# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(args):
    """Train and/or evaluate a model on the train/dev/test splits.

    All three splits are always loaded. Training runs when ``args.do_train``
    is set. Evaluation on the test split runs when ``args.do_eval`` is set.
    """
    init_logger()
    set_seed(args)
    tokenizer = load_tokenizer(args)

    # Load the three splits in train -> dev -> test order.
    train_dataset, dev_dataset, test_dataset = (
        load_and_cache_examples(args, tokenizer, mode=split)
        for split in ("train", "dev", "test"))

    trainer = Trainer(args, train_dataset, dev_dataset, test_dataset)

    if args.do_train:
        trainer.train()

    if args.do_eval:
        # Presumably reloads the saved model before evaluating.
        trainer.load_model()
        trainer.evaluate("test")
開發者ID:monologg,項目名稱:JointBERT,代碼行數:19,代碼來源:main.py

示例10: main

# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(args):
    """Train and/or evaluate a model; no dev split is used.

    The test split is loaded when training or evaluating. The train split is
    loaded only when training. Evaluation calls ``evaluate("test", "eval")``.
    """
    init_logger()
    set_seed(args)

    tokenizer = load_tokenizer(args)

    # dev split is never loaded; test must be loaded before train.
    dev_dataset = None
    test_dataset = (load_and_cache_examples(args, tokenizer, mode="test")
                    if args.do_train or args.do_eval else None)
    train_dataset = (load_and_cache_examples(args, tokenizer, mode="train")
                     if args.do_train else None)

    trainer = Trainer(args, train_dataset, dev_dataset, test_dataset)

    if args.do_train:
        trainer.train()
    if args.do_eval:
        trainer.load_model()
        trainer.evaluate("test", "eval")
開發者ID:monologg,項目名稱:KoBERT-NER,代碼行數:25,代碼來源:main.py

示例11: main

# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(args):
    """Configure logging, validate the CSV index file, and run training.

    Log records at DEBUG level go both to ``<args.logdir>/logging.txt`` and to
    the console.

    Raises:
        ValueError: if ``args.index_file`` is not an existing file, or if the
            first comma-separated field of its first line is not an existing
            path.
    """
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s',
        filename=os.path.join(args.logdir, 'logging.txt'))

    # Mirror records to the console with a shorter format.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(
        logging.Formatter('%(asctime)s %(levelname)-8s: %(message)s'))
    logging.getLogger().addHandler(console_handler)

    index_path = os.path.realpath(args.index_file)
    if not os.path.isfile(index_path):
        raise ValueError('No such index_file: {}'.format(index_path))
    print("Reading csv file: {}".format(index_path))

    # Sanity-check only the first row: its first field must be an existing path.
    with open(index_path, "r") as f:
        first_input = f.readline().strip().split(',')[0]
    if not os.path.exists(first_input):
        raise ValueError('Input path in csv not exist: {}'.format(first_input))

    trainer.Trainer(index_path, args).fit()
開發者ID:neycyanshi,項目名稱:DDRNet,代碼行數:27,代碼來源:train.py

示例12: main

# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(args):
    """Seed the RNGs, build the dataset splits, then train or test.

    Only ``args.network_type == 'seq2seq'`` with ``args.dataset == 'msrvtt'``
    is supported. In ``'train'`` mode the args are saved and training runs.
    Any other mode is passed to ``Trainer.test``.

    Raises:
        NotImplementedError: for a network type other than ``'seq2seq'``.
        Exception: for an unknown dataset, or when testing without
            ``args.load_path``.
    """
    prepare_dirs(args)

    torch.manual_seed(args.random_seed)
    if args.num_gpu > 0:
        torch.cuda.manual_seed(args.random_seed)

    if args.network_type != 'seq2seq':
        # ``NotImplemented`` is a constant, not an exception class: calling it
        # raised TypeError. Also report the value that was actually checked.
        raise NotImplementedError(f"{args.network_type} is not supported")

    vocab = data.common_loader.Vocab(args.vocab_file, args.max_vocab_size)
    if args.dataset != 'msrvtt':
        raise Exception(f"Unknown dataset: {args.dataset} for the corresponding network type: {args.network_type}")
    dataset = {split: data.common_loader.MSRVTTBatcher(args, split, vocab)
               for split in ('train', 'val', 'test')}

    trainer = Trainer(args, dataset)

    if args.mode == 'train':
        save_args(args)
        trainer.train()
        return

    if not args.load_path:
        raise Exception("[!] You should specify `load_path` to load a pretrained model")
    # Non-train modes are forwarded as-is (presumably the split to evaluate).
    trainer.test(args.mode)
開發者ID:ramakanth-pasunuru,項目名稱:video_captioning_rl,代碼行數:34,代碼來源:main.py

示例13: get_trainer

# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def get_trainer(config):
    """Reset TensorFlow's default graph and return a freshly built Trainer.

    The trainer is constructed as ``Trainer(config, config.data_type)``. After
    construction, TF logging verbosity is lowered to ERROR. No random seeds
    are set here.
    """
    print('tf: resetting default graph!')
    tf.reset_default_graph()

    data_type = config.data_type
    print('Using data_type ', data_type)
    built = Trainer(config, data_type)
    print('built trainer successfully')

    tf.logging.set_verbosity(tf.logging.ERROR)
    return built
開發者ID:mkocaoglu,項目名稱:CausalGAN,代碼行數:16,代碼來源:main.py

示例14: main

# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(config):
    """Count classes, build the data loader, then train or test.

    ``config.n_class`` is set to the number of sub-directories of
    ``config.image_path``. When ``config.train`` is true, ``config.model``
    selects the trainer: ``'sagan'`` uses ``Trainer`` and ``'qgan'`` uses
    ``qgan_trainer``. Otherwise a ``Tester`` is run.

    Raises:
        ValueError: if training is requested with an unsupported
            ``config.model``.
    """
    # For fast training: let cuDNN benchmark and pick convolution algorithms.
    cudnn.benchmark = True

    # One class per sub-directory (the trailing '/' makes glob match dirs only).
    config.n_class = len(glob.glob(os.path.join(config.image_path, '*/')))
    print('number class:', config.n_class)

    data_loader = Data_Loader(config.train, config.dataset, config.image_path, config.imsize,
                              config.batch_size, shuf=config.train)

    # Create output directories if they do not exist.
    for root in (config.model_save_path, config.sample_path,
                 config.log_path, config.attn_path):
        make_folder(root, config.version)

    print('config data_loader and build logs folder')

    if not config.train:
        tester = Tester(data_loader.loader(), config)
        tester.test()
        return

    if config.model == 'sagan':
        trainer = Trainer(data_loader.loader(), config)
    elif config.model == 'qgan':
        trainer = qgan_trainer(data_loader.loader(), config)
    else:
        # Without this branch, ``trainer`` was unbound and the call below
        # crashed with an opaque UnboundLocalError.
        raise ValueError(
            "Unsupported model: {!r} (expected 'sagan' or 'qgan')".format(config.model))
    trainer.train()
開發者ID:sxhxliang,項目名稱:BigGAN-pytorch,代碼行數:31,代碼來源:main.py

示例15: main

# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(config):
    """Build every training component from ``config`` and run ``Trainer.train``.

    The data loader (plus its validation split), model, optimizer and LR
    scheduler come from ``config.init_obj``. The loss and metric functions
    are looked up by name in ``module_loss`` and ``module_metric``.
    """
    logger = config.get_logger('train')

    # Data loaders: the validation loader is split off the training loader.
    data_loader = config.init_obj('data_loader', module_data)
    valid_data_loader = data_loader.split_validation()

    # Model architecture, logged for inspection.
    model = config.init_obj('arch', module_arch)
    logger.info(model)

    # Loss and metric functions resolved by name.
    criterion = getattr(module_loss, config['loss'])
    metric_fns = [getattr(module_metric, name) for name in config['metrics']]

    # Optimise only parameters that require gradients. To disable the
    # scheduler, delete every line that mentions lr_scheduler.
    params = (p for p in model.parameters() if p.requires_grad)
    optimizer = config.init_obj('optimizer', torch.optim, params)
    lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)

    trainer = Trainer(model, criterion, metric_fns, optimizer,
                      config=config,
                      data_loader=data_loader,
                      valid_data_loader=valid_data_loader,
                      lr_scheduler=lr_scheduler)
    trainer.train()
開發者ID:victoresque,項目名稱:pytorch-template,代碼行數:30,代碼來源:train.py


注:本文中的trainer.Trainer方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。