本文整理匯總了Python中trainer.Trainer方法的典型用法代碼示例。如果您正苦於以下問題：Python trainer.Trainer方法的具體用法？Python trainer.Trainer怎麼用？Python trainer.Trainer使用的例子？那麼，這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊trainer的用法示例。
在下文中一共展示了trainer.Trainer方法的15個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: main
# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(args):
    """Merge a YAML config into ``args``, initialise distributed training and train.

    Every top-level key of the YAML mapping at ``args.config`` is copied onto
    ``args`` as an attribute, overriding any existing attribute of the same name.

    Args:
        args: argparse-style namespace. ``args.config`` is the YAML path;
            ``args.launcher`` (from the command line or the YAML) is passed to
            ``dist_init``.
    """
    with open(args.config) as f:
        # FullLoader only exists in PyYAML >= 5.1; feature-detect it instead of
        # comparing version strings.
        if hasattr(yaml, "FullLoader"):
            config = yaml.load(f, Loader=yaml.FullLoader)
        else:
            # Older PyYAML: the default loader can construct arbitrary Python
            # objects, so only load trusted config files here.
            config = yaml.load(f)
    # An empty YAML file loads as None; treat it as "no overrides".
    for k, v in (config or {}).items():
        setattr(args, k, v)
    # exp path: default to the directory that contains the config file.
    if not hasattr(args, 'exp_path'):
        args.exp_path = os.path.dirname(args.config)
    # dist init: force the 'spawn' start method (presumably so worker processes
    # can use CUDA safely), then initialise the NCCL process group.
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    dist_init(args.launcher, backend='nccl')
    # train
    trainer = Trainer(args)
    trainer.run()
示例2: main
# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(_):
    """Build a Trainer from the module-level ``config`` and train or test it.

    Args:
        _: unused positional argument (presumably supplied by an app runner
            such as ``tf.app.run`` — confirm against the caller).

    Raises:
        Exception: in test mode when ``config.load_path`` is empty.
    """
    # Create output directories before anything is written into them.
    prepare_dirs(config)

    # Seed NumPy (through a dedicated RandomState handed to the trainer)
    # and TensorFlow with the same value.
    seed = config.random_seed
    rng = np.random.RandomState(seed)
    tf.set_random_seed(seed)

    trainer = Trainer(config, rng)
    save_config(config.model_dir, config)

    if config.is_train:
        trainer.train()
        return

    # Testing requires a checkpoint to restore from.
    if not config.load_path:
        raise Exception(
            "[!] You should specify `load_path` to "
            "load a pretrained model")
    trainer.test()
示例3: main
# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(is_debug):
    """Train the face GAN on the 512px cardio-dance dataset.

    Args:
        is_debug: when truthy, the Trainer is created with ``log_every=1`` and
            ``save_every=1``; otherwise the Trainer's own defaults are used.
    """
    # configs
    dataset_dir = '../datasets/cardio_dance_512'
    pose_name = '../datasets/cardio_dance_512/poses.npy'
    ckpt_dir = './checkpoints/dance_test_new_down2_res6'
    log_dir = './logs/dance_test_new_down2_res6'
    batch_size = 64
    start_batch = 0  # replaced by the value load_models() returns

    image_folder = dataset.ImageFolderDataset(
        dataset_dir, cache=os.path.join(dataset_dir, 'local.db'))
    # crop_size: 48 for 512-frame, 96 for HD frame
    face_dataset = dataset.FaceCropDataset(
        image_folder, pose_name, image_transforms, crop_size=48)
    data_loader = DataLoader(face_dataset, batch_size=batch_size,
                             drop_last=True, num_workers=4, shuffle=True)

    generator, discriminator, start_batch = load_models(ckpt_dir, start_batch)

    # Debug runs log and checkpoint on every step.
    extra = {'log_every': 1, 'save_every': 1} if is_debug else {}
    trainer = Trainer(ckpt_dir, log_dir, face_dataset, data_loader, **extra)
    trainer.train(generator, discriminator, start_batch)
示例4: main
# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(is_debug):
    """Train the face GAN on ``../data/face`` with poses from ``../data/target``.

    Args:
        is_debug: when truthy, the Trainer is created with ``log_every=1`` and
            ``save_every=1``; otherwise the Trainer's own defaults are used.
    """
    # configs
    import os
    dataset_dir = '../data/face'
    pose_name = '../data/target/pose.npy'
    ckpt_dir = '../checkpoints/face'
    log_dir = '../checkpoints/face/logs'
    batch_size = 10
    start_batch = 10  # handed to load_models(), which returns the batch to resume from

    image_folder = dataset.ImageFolderDataset(
        dataset_dir, cache=os.path.join(dataset_dir, 'local.db'))
    # crop_size: 48 for 512-frame, 96 for HD frame
    face_dataset = dataset.FaceCropDataset(
        image_folder, pose_name, image_transforms, crop_size=48)
    data_loader = DataLoader(face_dataset, batch_size=batch_size,
                             drop_last=True, num_workers=4, shuffle=True)

    generator, discriminator, start_batch = load_models(ckpt_dir, start_batch)

    # Debug runs log and checkpoint on every step.
    extra = {'log_every': 1, 'save_every': 1} if is_debug else {}
    trainer = Trainer(ckpt_dir, log_dir, face_dataset, data_loader, **extra)
    trainer.train(generator, discriminator, start_batch)
示例5: train
# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def train(self):
    """Validate the configured components, build a ``Trainer`` and run it.

    Reads ``model``, ``loss``, ``metrics``, ``optimizer``, ``resume``,
    ``config``, ``data_loader``, ``valid_data_loader``, ``lr_scheduler`` and
    ``train_logger`` from ``self``. Stores the constructed trainer as
    ``self.trainer``. If ``config`` has no ``"name"`` entry, one is derived as
    ``"<arch type>_<data_loader type>"``.

    Raises:
        AssertionError: if model, loss or any metric is not callable; if the
            config has no ``"trainer"`` section; if ``data_loader`` is not
            iterable; or if the loader's ``classes`` count differs from
            ``config['arch']['args']['n_class']``.
    """
    assert callable(self.model), "model is not callable!!"
    assert callable(self.loss), "loss is not callable!!"
    assert all(callable(met) for met in self.metrics), "metrics is not callable!!"
    assert "trainer" in self.config, "trainer hasn't been configured!!"
    assert isinstance(self.data_loader, Iterable), "data_loader is not iterable!!"
    # The number of classes in the dataset must equal the model's output size.
    if hasattr(self.data_loader, 'classes'):
        true_classes = len(self.data_loader.classes)
        model_output = self.config['arch']['args']['n_class']
        # Message: "model has {} classes, but the data actually has {} classes".
        assert true_classes == model_output, "model分類數為{},可是實際上有{}個類".format(
            model_output, true_classes)
    if "name" not in self.config:
        # str.join takes a single iterable of strings.
        self.config["name"] = "_".join([
            self.config["arch"]["type"],
            self.config["data_loader"]["type"],
        ])
    self.trainer = Trainer(self.model, self.loss, self.metrics, self.optimizer,
                           resume=self.resume, config=self.config, data_loader=self.data_loader,
                           valid_data_loader=self.valid_data_loader, lr_scheduler=self.lr_scheduler,
                           train_logger=self.train_logger)
    self.trainer.train()
示例6: main
# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(config, resume):
    """Build loaders, model and loss from ``config`` and train a segmentation model.

    Args:
        config: configuration mapping; ``config['loss']`` names a class in
            ``losses`` and ``config['ignore_index']`` is forwarded to it.
        resume: forwarded to ``Trainer`` unchanged (presumably a checkpoint
            path or ``None`` — confirm against Trainer).
    """
    train_logger = Logger()

    # DATA LOADERS (built in order: train first, then val)
    loaders = {name: get_instance(dataloaders, name, config)
               for name in ('train_loader', 'val_loader')}
    train_loader = loaders['train_loader']
    val_loader = loaders['val_loader']

    # MODEL: output size follows the training dataset's class count.
    n_classes = train_loader.dataset.num_classes
    model = get_instance(models, 'arch', config, n_classes)
    print(f'\n{model}\n')

    # LOSS: looked up by name in the losses module.
    loss_cls = getattr(losses, config['loss'])
    criterion = loss_cls(ignore_index=config['ignore_index'])

    # TRAINING
    trainer = Trainer(model=model, loss=criterion, resume=resume, config=config,
                      train_loader=train_loader, val_loader=val_loader,
                      train_logger=train_logger)
    trainer.train()
示例7: main
# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(_):
    """Train or test a TSP model configured through the module-level ``config``.

    Unset ``max_enc_length`` / ``max_dec_length`` fall back to
    ``config.max_data_length``.

    Raises:
        Exception: if ``config.task`` does not start with "tsp"
            (case-insensitive), or in test mode when ``config.load_path`` is empty.
    """
    prepare_dirs_and_logger(config)

    if not config.task.lower().startswith('tsp'):
        raise Exception("[!] Task should starts with TSP")

    # Default the encoder/decoder lengths to the generic data length.
    for attr in ('max_enc_length', 'max_dec_length'):
        if getattr(config, attr) is None:
            setattr(config, attr, config.max_data_length)

    seed = config.random_seed
    rng = np.random.RandomState(seed)
    tf.set_random_seed(seed)

    trainer = Trainer(config, rng)
    save_config(config.model_dir, config)

    if config.is_train:
        trainer.train()
    elif not config.load_path:
        raise Exception("[!] You should specify `load_path` to load a pretrained model")
    else:
        trainer.test()

    tf.logging.info("Run finished.")
示例8: main
# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main():
    """Run video inference, or the alternating train/test loop.

    Uses the module-level names ``args``, ``checkpoint``, ``model``, ``data``
    and ``loss``. In video mode the global ``model`` name is rebound to the
    constructed model instance.
    """
    global model
    if args.data_test == ['video']:
        from videotester import VideoTester
        model = model.Model(args, checkpoint)
        VideoTester(args, model, checkpoint).test()
        return

    # Nothing to do if the checkpoint reports it is not usable.
    if not checkpoint.ok:
        return

    data_source = data.Data(args)
    net = model.Model(args, checkpoint)
    # No loss object is needed when only testing.
    criterion = None if args.test_only else loss.Loss(args, checkpoint)
    runner = Trainer(args, data_source, net, criterion, checkpoint)
    while not runner.terminate():
        runner.train()
        runner.test()
    checkpoint.done()
示例9: main
# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(args):
    """Load train/dev/test splits, then train and/or evaluate on the test split.

    Args:
        args: namespace with at least ``do_train`` and ``do_eval`` flags.
    """
    init_logger()
    set_seed(args)
    tokenizer = load_tokenizer(args)

    # Splits are loaded in order: train, dev, test.
    train_ds, dev_ds, test_ds = (
        load_and_cache_examples(args, tokenizer, mode=split)
        for split in ("train", "dev", "test")
    )

    trainer = Trainer(args, train_ds, dev_ds, test_ds)
    if args.do_train:
        trainer.train()
    if args.do_eval:
        # Evaluate the saved model, not the in-memory one.
        trainer.load_model()
        trainer.evaluate("test")
示例10: main
# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(args):
    """Train and/or evaluate on the test split, depending on ``args`` flags.

    ``args.do_train`` loads the train and test splits and trains.
    ``args.do_eval`` loads the test split, restores the saved model and
    evaluates it. No dev split is loaded; ``None`` is passed in its place.
    """
    init_logger()
    set_seed(args)
    tokenizer = load_tokenizer(args)

    def _load(mode):
        # Thin wrapper so each split is loaded with the same tokenizer.
        return load_and_cache_examples(args, tokenizer, mode=mode)

    needs_test = args.do_train or args.do_eval
    test_dataset = _load("test") if needs_test else None
    train_dataset = _load("train") if args.do_train else None
    dev_dataset = None

    trainer = Trainer(args, train_dataset, dev_dataset, test_dataset)
    if args.do_train:
        trainer.train()
    if args.do_eval:
        trainer.load_model()
        trainer.evaluate("test", "eval")
示例11: main
# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(args):
    """Set up logging, validate the CSV index file and fit the trainer.

    Logs go to ``<args.logdir>/logging.txt`` and are mirrored to the console.
    Only the first line of the index file is inspected; its first
    comma-separated field must be an existing path.

    Raises:
        ValueError: if ``args.index_file`` is not a file, or if the first input
            path listed in it does not exist.
    """
    # File logging at DEBUG level.
    logging.basicConfig(
        filename=os.path.join(args.logdir, 'logging.txt'),
        format='%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s',
        level=logging.DEBUG)
    # Mirror records to the console with a shorter format.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(
        logging.Formatter('%(asctime)s %(levelname)-8s: %(message)s'))
    logging.getLogger().addHandler(console_handler)

    index_path = os.path.realpath(args.index_file)
    if not os.path.isfile(index_path):
        raise ValueError(f'No such index_file: {index_path}')
    print(f"Reading csv file: {index_path}")

    # Sanity-check only the first row's input path.
    with open(index_path, "r") as csv_file:
        first_row = csv_file.readline().strip()
    input_path = first_row.split(',')[0]
    if not os.path.exists(input_path):
        raise ValueError(f'Input path in csv not exist: {input_path}')

    t = trainer.Trainer(index_path, args)
    t.fit()
示例12: main
# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(args):
    """Seed torch, build the dataset splits and train or test a seq2seq model.

    Only ``args.network_type == 'seq2seq'`` with ``args.dataset == 'msrvtt'``
    is supported.

    Raises:
        NotImplementedError: for any other ``args.network_type``.
        Exception: for an unknown ``args.dataset``, or when not in train mode
            and ``args.load_path`` is empty.
    """
    prepare_dirs(args)

    torch.manual_seed(args.random_seed)
    if args.num_gpu > 0:
        torch.cuda.manual_seed(args.random_seed)

    if args.network_type != 'seq2seq':
        # NotImplemented is a constant, not an exception class; calling it
        # raises TypeError, so use NotImplementedError instead.
        raise NotImplementedError(f"Network type {args.network_type} is not supported")

    vocab = data.common_loader.Vocab(args.vocab_file, args.max_vocab_size)
    if args.dataset != 'msrvtt':
        raise Exception(f"Unknown dataset: {args.dataset} for the corresponding network type: {args.network_type}")
    dataset = {split: data.common_loader.MSRVTTBatcher(args, split, vocab)
               for split in ('train', 'val', 'test')}

    trainer = Trainer(args, dataset)
    if args.mode == 'train':
        save_args(args)
        trainer.train()
    elif not args.load_path:
        raise Exception("[!] You should specify `load_path` to load a pretrained model")
    else:
        trainer.test(args.mode)
示例13: get_trainer
# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def get_trainer(config):
    """Reset TensorFlow's default graph and return a freshly built Trainer.

    After construction, TF logging verbosity is lowered to ERROR.

    Args:
        config: configuration object; ``config.data_type`` is passed to
            ``Trainer`` as its second argument.
    """
    print('tf: resetting default graph!')
    tf.reset_default_graph()
    # NOTE: seeding (tf.set_random_seed(config.random_seed) / np.random.seed)
    # is currently disabled, so runs are not reproducible.
    data_type = config.data_type
    print('Using data_type ', data_type)
    built = Trainer(config, data_type)
    print('built trainer successfully')
    tf.logging.set_verbosity(tf.logging.ERROR)
    return built
示例14: main
# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(config):
    """Prepare data and output folders, then train (SAGAN/QGAN) or test.

    ``config.n_class`` is set to the number of subdirectories of
    ``config.image_path`` (presumably one folder per class — confirm).

    Raises:
        ValueError: when training and ``config.model`` is neither ``'sagan'``
            nor ``'qgan'``.
    """
    # For fast training: let cuDNN benchmark and pick convolution algorithms.
    cudnn.benchmark = True
    # A glob pattern ending in '/' matches only directories.
    config.n_class = len(glob.glob(os.path.join(config.image_path, '*/')))
    print('number class:', config.n_class)
    # Data loader (shuffled only when training)
    data_loader = Data_Loader(config.train, config.dataset, config.image_path, config.imsize,
                              config.batch_size, shuf=config.train)
    # Create directories if not exist
    make_folder(config.model_save_path, config.version)
    make_folder(config.sample_path, config.version)
    make_folder(config.log_path, config.version)
    make_folder(config.attn_path, config.version)
    print('config data_loader and build logs folder')
    if config.train:
        if config.model == 'sagan':
            trainer = Trainer(data_loader.loader(), config)
        elif config.model == 'qgan':
            trainer = qgan_trainer(data_loader.loader(), config)
        else:
            # Without this branch `trainer` would be unbound below.
            raise ValueError(
                f"Unknown model {config.model!r}; expected 'sagan' or 'qgan'")
        trainer.train()
    else:
        tester = Tester(data_loader.loader(), config)
        tester.test()
示例15: main
# 需要導入模塊: import trainer [as 別名]
# 或者: from trainer import Trainer [as 別名]
def main(config):
    """Build data, model, loss, metrics, optimizer and LR scheduler from ``config`` and train.

    Args:
        config: configuration object exposing ``get_logger`` and ``init_obj``,
            and subscriptable for ``'loss'`` and ``'metrics'``.
    """
    logger = config.get_logger('train')

    # Data loaders: the validation loader is split off the training loader.
    data_loader = config.init_obj('data_loader', module_data)
    valid_data_loader = data_loader.split_validation()

    # Model architecture, logged for inspection.
    model = config.init_obj('arch', module_arch)
    logger.info(model)

    # Loss and metric functions are looked up by name.
    criterion = getattr(module_loss, config['loss'])
    metrics = [getattr(module_metric, metric_name) for metric_name in config['metrics']]

    # Only parameters that require gradients go to the optimizer.
    # To disable LR scheduling, delete every line that mentions lr_scheduler.
    trainable_params = (p for p in model.parameters() if p.requires_grad)
    optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
    lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)

    trainer = Trainer(
        model, criterion, metrics, optimizer,
        config=config,
        data_loader=data_loader,
        valid_data_loader=valid_data_loader,
        lr_scheduler=lr_scheduler,
    )
    trainer.train()