本文整理汇总了Python中trainer.Trainer方法的典型用法代码示例。如果您正苦于以下问题:Python trainer.Trainer方法的具体用法?Python trainer.Trainer怎么用?Python trainer.Trainer使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块trainer的用法示例。
在下文中一共展示了trainer.Trainer方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: main
# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(args):
    """Load a YAML config onto ``args``, initialise distributed training and run.

    Every top-level key of the YAML file at ``args.config`` is copied onto
    ``args`` as an attribute (overwriting existing attributes). If the config
    does not provide ``exp_path``, the directory of the config file is used.

    Args:
        args: namespace-like object exposing at least ``config`` (path to the
            YAML file) and ``launcher`` (forwarded to ``dist_init``).
    """
    with open(args.config) as f:
        # FullLoader is only available from PyYAML 5.1. The original check
        # parsed the *result of a comparison* (a bool) instead of comparing
        # parsed versions, so it never tested the version at all.
        if version.parse(yaml.__version__) >= version.parse("5.1"):
            config = yaml.load(f, Loader=yaml.FullLoader)
        else:
            # NOTE(review): loader-less yaml.load can build arbitrary Python
            # objects; only use this with trusted config files.
            config = yaml.load(f)

    # An empty YAML document loads as None; treat it as an empty mapping.
    for k, v in (config or {}).items():
        setattr(args, k, v)

    # Default the experiment directory to the one holding the config file.
    if not hasattr(args, 'exp_path'):
        args.exp_path = os.path.dirname(args.config)

    # Force the 'spawn' start method before initialising the NCCL backend.
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    dist_init(args.launcher, backend='nccl')

    trainer = Trainer(args)
    trainer.run()
示例2: main
# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(_):
    """Prepare output dirs, seed RNGs, then train or test (the argument is unused).

    Settings come from the module-level ``config``. When ``config.is_train``
    is falsy, ``config.load_path`` must be set or an Exception is raised
    (after the trainer has been built and the config saved).
    """
    # Output directories must exist before anything is saved into them.
    prepare_dirs(config)

    seed = config.random_seed
    rng = np.random.RandomState(seed)
    tf.set_random_seed(seed)

    trainer = Trainer(config, rng)
    save_config(config.model_dir, config)

    if config.is_train:
        trainer.train()
        return

    if not config.load_path:
        raise Exception(
            "[!] You should specify `load_path` to "
            "load a pretrained model")
    trainer.test()
示例3: main
# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(is_debug):
    """Train the face GAN on the ``cardio_dance_512`` dataset.

    Args:
        is_debug: when truthy, the Trainer is built with ``log_every=1`` and
            ``save_every=1``; otherwise the Trainer's own defaults apply.
    """
    dataset_dir = '../datasets/cardio_dance_512'
    pose_name = '../datasets/cardio_dance_512/poses.npy'
    ckpt_dir = './checkpoints/dance_test_new_down2_res6'
    log_dir = './logs/dance_test_new_down2_res6'
    batch_size = 64
    start_batch = 0

    cache_path = os.path.join(dataset_dir, 'local.db')
    image_folder = dataset.ImageFolderDataset(dataset_dir, cache=cache_path)
    # crop_size: 48 for 512-pixel frames, 96 for HD frames.
    face_dataset = dataset.FaceCropDataset(
        image_folder, pose_name, image_transforms, crop_size=48)
    data_loader = DataLoader(
        face_dataset, batch_size=batch_size, drop_last=True, num_workers=4, shuffle=True)

    # load_models returns the networks plus the batch index passed on to
    # train() (presumably the resume point found in ckpt_dir — verify).
    generator, discriminator, batch_num = load_models(ckpt_dir, start_batch)

    debug_kwargs = {'log_every': 1, 'save_every': 1} if is_debug else {}
    trainer = Trainer(ckpt_dir, log_dir, face_dataset, data_loader, **debug_kwargs)
    trainer.train(generator, discriminator, batch_num)
示例4: main
# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(is_debug):
    """Train the face GAN on the ``../data/face`` dataset.

    Args:
        is_debug: when truthy, the Trainer logs and saves every step
            (``log_every=1, save_every=1``); otherwise its defaults are used.
    """
    import os

    dataset_dir = '../data/face'
    pose_name = '../data/target/pose.npy'
    ckpt_dir = '../checkpoints/face'
    log_dir = '../checkpoints/face/logs'
    initial_batch_num = 10
    batch_size = 10

    image_folder = dataset.ImageFolderDataset(
        dataset_dir, cache=os.path.join(dataset_dir, 'local.db'))
    # crop_size: 48 for 512-pixel frames, 96 for HD frames.
    face_dataset = dataset.FaceCropDataset(
        image_folder, pose_name, image_transforms, crop_size=48)
    data_loader = DataLoader(
        face_dataset, batch_size=batch_size, drop_last=True, num_workers=4, shuffle=True)

    # load_models returns the networks plus the batch index that train()
    # continues from (presumably restored from ckpt_dir — verify).
    generator, discriminator, batch_num = load_models(ckpt_dir, initial_batch_num)

    if is_debug:
        extra = dict(log_every=1, save_every=1)
    else:
        extra = {}
    trainer = Trainer(ckpt_dir, log_dir, face_dataset, data_loader, **extra)
    trainer.train(generator, discriminator, batch_num)
示例5: train
# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def train(self):
    """Validate the configured components, build a ``Trainer`` and run it.

    Checks that ``self.model``, ``self.loss`` and every metric are callable,
    that ``self.config`` has a ``"trainer"`` section, and that
    ``self.data_loader`` is iterable. If the data loader exposes ``classes``,
    their count must equal ``config['arch']['args']['n_class']``. When the
    config has no ``"name"``, one is derived as
    ``"<arch type>_<data_loader type>"``. The trainer is stored on
    ``self.trainer`` before training starts.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    # NOTE(review): assert-based validation is stripped under ``python -O``;
    # kept because callers may rely on AssertionError.
    assert callable(self.model), "model is not callable!!"
    assert callable(self.loss), "loss is not callable!!"
    assert all(callable(met) for met in self.metrics), "metrics is not callable!!"
    assert "trainer" in self.config, "trainer hasn't been configured!!"
    assert isinstance(self.data_loader, Iterable), "data_loader is not iterable!!"

    # The number of classes in the dataset must match the model's output size.
    if hasattr(self.data_loader, 'classes'):
        true_classes = len(self.data_loader.classes)
        model_output = self.config['arch']['args']['n_class']
        # Message: "the model has {} classes, but there are actually {} classes".
        assert true_classes == model_output, "model分类数为{},可是实际上有{}个类".format(
            model_output, true_classes)

    if "name" not in self.config:
        # str.join takes a single iterable; the original passed two positional
        # arguments, which raised TypeError.
        self.config["name"] = "_".join(
            [self.config["arch"]["type"], self.config["data_loader"]["type"]])

    self.trainer = Trainer(self.model, self.loss, self.metrics, self.optimizer,
                           resume=self.resume, config=self.config, data_loader=self.data_loader,
                           valid_data_loader=self.valid_data_loader, lr_scheduler=self.lr_scheduler,
                           train_logger=self.train_logger)
    self.trainer.train()
示例6: main
# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(config, resume):
    """Build data loaders, model and loss from ``config``, then train.

    Args:
        config: configuration mapping. ``'loss'`` and ``'ignore_index'`` are
            read directly. The mapping is also passed to ``get_instance`` for
            ``'train_loader'``, ``'val_loader'`` and ``'arch'``.
        resume: forwarded unchanged to ``Trainer`` (presumably a checkpoint
            path — confirm).
    """
    train_logger = Logger()

    # Data loaders.
    train_loader = get_instance(dataloaders, 'train_loader', config)
    val_loader = get_instance(dataloaders, 'val_loader', config)

    # The model is sized by the training set's class count.
    num_classes = train_loader.dataset.num_classes
    model = get_instance(models, 'arch', config, num_classes)
    print(f'\n{model}\n')

    # The loss class is looked up by name in the ``losses`` module.
    loss_cls = getattr(losses, config['loss'])
    loss = loss_cls(ignore_index=config['ignore_index'])

    trainer = Trainer(model=model, loss=loss, resume=resume, config=config,
                      train_loader=train_loader, val_loader=val_loader,
                      train_logger=train_logger)
    trainer.train()
示例7: main
# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(_):
    """Run the TSP pointer-network task from the module-level ``config``.

    Rejects tasks whose name does not start with ``tsp`` (case-insensitive).
    Fills unset encoder/decoder lengths from ``config.max_data_length``,
    seeds NumPy and TensorFlow, then trains or tests.
    """
    prepare_dirs_and_logger(config)

    if not config.task.lower().startswith('tsp'):
        raise Exception("[!] Task should starts with TSP")

    # Unset (None) encoder/decoder lengths default to the max data length.
    for attr in ('max_enc_length', 'max_dec_length'):
        if getattr(config, attr) is None:
            setattr(config, attr, config.max_data_length)

    seed = config.random_seed
    rng = np.random.RandomState(seed)
    tf.set_random_seed(seed)

    trainer = Trainer(config, rng)
    save_config(config.model_dir, config)

    if not config.is_train:
        if not config.load_path:
            raise Exception("[!] You should specify `load_path` to load a pretrained model")
        trainer.test()
    else:
        trainer.train()

    tf.logging.info("Run finished.")
示例8: main
# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main():
    """Run video testing, or the train/test loop, depending on ``args.data_test``.

    Uses the module-level ``args`` and ``checkpoint``. The train/test loop
    only runs when ``checkpoint.ok`` is truthy, and ends with
    ``checkpoint.done()``.

    NOTE(review): in the video branch the module-level name ``model``
    (declared ``global``) is rebound from the imported module to a ``Model``
    instance. Confirm nothing afterwards expects ``model`` to be the module.
    """
    global model
    if args.data_test == ['video']:
        from videotester import VideoTester
        model = model.Model(args, checkpoint)
        VideoTester(args, model, checkpoint).test()
        return

    if not checkpoint.ok:
        return

    loader = data.Data(args)
    net = model.Model(args, checkpoint)
    # No loss is built in test-only mode.
    criterion = None if args.test_only else loss.Loss(args, checkpoint)
    t = Trainer(args, loader, net, criterion, checkpoint)
    while not t.terminate():
        t.train()
        t.test()
    checkpoint.done()
示例9: main
# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(args):
    """Load the train/dev/test splits, then optionally train and/or evaluate.

    Evaluation (``args.do_eval``) reloads the saved model and scores it on
    the ``"test"`` split.
    """
    init_logger()
    set_seed(args)

    tokenizer = load_tokenizer(args)
    train_dataset, dev_dataset, test_dataset = [
        load_and_cache_examples(args, tokenizer, mode=split)
        for split in ("train", "dev", "test")
    ]

    trainer = Trainer(args, train_dataset, dev_dataset, test_dataset)
    if args.do_train:
        trainer.train()
    if args.do_eval:
        trainer.load_model()
        trainer.evaluate("test")
示例10: main
# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(args):
    """Load only the splits that are needed, then optionally train and/or evaluate.

    The test split is loaded when training or evaluating, and the train split
    only when training. No dev split is used (it is passed as None).
    """
    init_logger()
    set_seed(args)
    tokenizer = load_tokenizer(args)

    dev_dataset = None
    test_dataset = (load_and_cache_examples(args, tokenizer, mode="test")
                    if args.do_train or args.do_eval else None)
    train_dataset = (load_and_cache_examples(args, tokenizer, mode="train")
                     if args.do_train else None)

    trainer = Trainer(args, train_dataset, dev_dataset, test_dataset)
    if args.do_train:
        trainer.train()
    if args.do_eval:
        trainer.load_model()
        trainer.evaluate("test", "eval")
示例11: main
# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(args):
    """Set up file and console logging, validate the index CSV, then fit a trainer.

    Log records at DEBUG and above go to ``<args.logdir>/logging.txt`` and
    to the console.

    Raises:
        ValueError: if ``args.index_file`` is not a file, or if the first
            comma-separated field of its first line is not an existing path.
    """
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s',
        filename=os.path.join(args.logdir, 'logging.txt'))

    # Mirror records to the console with a shorter format.
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(logging.Formatter('%(asctime)s %(levelname)-8s: %(message)s'))
    logging.getLogger().addHandler(console)

    filename = os.path.realpath(args.index_file)
    if not os.path.isfile(filename):
        raise ValueError('No such index_file: {}'.format(filename))
    print("Reading csv file: {}".format(filename))

    # Only the first line is sanity-checked.
    with open(filename, "r") as f:
        first_line = f.readline().strip()
    input_path = first_line.split(',')[0]
    if not os.path.exists(input_path):
        raise ValueError('Input path in csv not exist: {}'.format(input_path))

    t = trainer.Trainer(filename, args)
    t.fit()
示例12: main
# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(args):
    """Seed RNGs, build datasets for ``args.network_type``, then train or test.

    Only ``network_type == 'seq2seq'`` with ``dataset == 'msrvtt'`` is
    supported. The train/val/test batchers share one vocabulary.

    Raises:
        NotImplementedError: if ``args.network_type`` is not ``'seq2seq'``.
        Exception: for an unknown dataset, or when testing without
            ``args.load_path``.
    """
    prepare_dirs(args)

    torch.manual_seed(args.random_seed)
    if args.num_gpu > 0:
        torch.cuda.manual_seed(args.random_seed)

    if args.network_type != 'seq2seq':
        # ``NotImplemented`` is a constant, not an exception class; calling it
        # raised TypeError. The message also reported the dataset rather than
        # the unsupported network type.
        raise NotImplementedError(f"{args.network_type} is not supported")

    vocab = data.common_loader.Vocab(args.vocab_file, args.max_vocab_size)
    if args.dataset != 'msrvtt':
        raise Exception(f"Unknown dataset: {args.dataset} for the corresponding network type: {args.network_type}")
    dataset = {
        split: data.common_loader.MSRVTTBatcher(args, split, vocab)
        for split in ('train', 'val', 'test')
    }

    trainer = Trainer(args, dataset)
    if args.mode == 'train':
        save_args(args)
        trainer.train()
    elif not args.load_path:
        raise Exception("[!] You should specify `load_path` to load a pretrained model")
    else:
        trainer.test(args.mode)
示例13: get_trainer
# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def get_trainer(config):
    """Reset the default TensorFlow graph and return a freshly built ``Trainer``.

    The trainer is constructed as ``Trainer(config, config.data_type)``. No
    random seeds are set here. TensorFlow logging is then limited to ERROR
    and above.
    """
    print('tf: resetting default graph!')
    tf.reset_default_graph()

    data_type = config.data_type
    print('Using data_type ', data_type)
    built = Trainer(config, data_type)
    print('built trainer successfully')

    tf.logging.set_verbosity(tf.logging.ERROR)
    return built
示例14: main
# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(config):
    """Prepare data and output folders, then train (SAGAN or QGAN) or test.

    ``config.n_class`` is set to the number of subdirectories of
    ``config.image_path``.

    Raises:
        ValueError: when training with a ``config.model`` other than
            ``'sagan'`` or ``'qgan'``.
    """
    # For fast training: let cuDNN benchmark and pick convolution algorithms.
    cudnn.benchmark = True

    # One class per subdirectory; a trailing separator makes glob match directories only.
    config.n_class = len(glob.glob(os.path.join(config.image_path, '*/')))
    print('number class:', config.n_class)

    data_loader = Data_Loader(config.train, config.dataset, config.image_path, config.imsize,
                              config.batch_size, shuf=config.train)

    # Create output directories if they do not exist.
    for folder in (config.model_save_path, config.sample_path, config.log_path, config.attn_path):
        make_folder(folder, config.version)
    print('config data_loader and build logs folder')

    if not config.train:
        tester = Tester(data_loader.loader(), config)
        tester.test()
        return

    if config.model == 'sagan':
        trainer = Trainer(data_loader.loader(), config)
    elif config.model == 'qgan':
        trainer = qgan_trainer(data_loader.loader(), config)
    else:
        # Previously ``trainer`` was left unbound here and ``trainer.train()``
        # failed with an opaque UnboundLocalError.
        raise ValueError(f"Unknown model: {config.model!r} (expected 'sagan' or 'qgan')")
    trainer.train()
示例15: main
# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import Trainer [as 别名]
def main(config):
    """Build every training component from ``config`` and run ``Trainer``.

    ``config`` must provide ``get_logger`` and ``init_obj``, and must be
    indexable by ``'loss'`` and ``'metrics'``.
    """
    logger = config.get_logger('train')

    # Data: the validation loader is split off the training loader.
    data_loader = config.init_obj('data_loader', module_data)
    valid_data_loader = data_loader.split_validation()

    # Model architecture, logged for inspection.
    model = config.init_obj('arch', module_arch)
    logger.info(model)

    # Loss and metric functions are looked up by name.
    criterion = getattr(module_loss, config['loss'])
    metrics = [getattr(module_metric, name) for name in config['metrics']]

    # Only parameters requiring gradients are optimised. To disable the
    # scheduler, delete every line that mentions lr_scheduler.
    trainable_params = (p for p in model.parameters() if p.requires_grad)
    optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
    lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)

    trainer = Trainer(model, criterion, metrics, optimizer,
                      config=config,
                      data_loader=data_loader,
                      valid_data_loader=valid_data_loader,
                      lr_scheduler=lr_scheduler)
    trainer.train()