本文整理汇总了Python中model.BaseModel类的典型用法代码示例。如果您正苦于以下问题：Python model.BaseModel的具体用法是什么？怎么用？有哪些使用的例子？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解BaseModel所在模块model的用法示例。
在下文中一共展示了model.BaseModel类的2个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: pretrain
# 需要导入模块：import model [as 别名]
# 或者：from model import BaseModel [as 别名]
# 另外还需要：import tensorflow as tf（本示例使用了 tf.convert_to_tensor）
def pretrain(model, dataloader):
    """
    Pre-normalize a model (i.e., its PreNormLayer layers) over the given samples.

    Parameters
    ----------
    model : model.BaseModel
        A base model, which may contain some model.PreNormLayer layers.
    dataloader : tf.data.Dataset
        Dataset to use for pre-training the model. Each batch must unpack
        into exactly 10 items; only the first six are forwarded to
        ``model.pre_train`` as the batched state.

    Returns
    -------
    int
        Number of PreNormLayer layers processed, i.e. how many times
        ``model.pre_train_next()`` returned something other than None.
    """
    model.pre_train_init()
    n_layers = 0
    while True:
        # Feed batches until model.pre_train returns a falsy value or the
        # dataloader is exhausted; either way, move on to the next layer.
        for batch in dataloader:
            (c, ei, ev, v, n_cs, n_vs,
             _n_cands, _cands, _best_cands, _cand_scores) = batch
            state = (c, ei, ev, v, n_cs, n_vs)
            keep_going = model.pre_train(state, tf.convert_to_tensor(True))
            if not keep_going:
                break

        # A None result from pre_train_next ends pre-training.
        next_layer = model.pre_train_next()
        if next_layer is None:
            return n_layers
        # Unpacked only to enforce the (layer, name) shape; values unused.
        _layer, _name = next_layer
        n_layers += 1
示例2: main
# 需要导入模块：import model [as 别名]
# 或者：from model import BaseModel [as 别名]
# 注意：本示例中 BaseModel 仅出现在一行被注释掉的代码里，实际使用的模型是 BertForSequenceClassification。
def main():
    """
    Train (unless resuming) a BertForSequenceClassification model and write
    its predictions on the test set.

    Configuration is read from the module-level ``args`` object, which is
    defined outside this function (presumably parsed command-line
    arguments). Fields used: seed, gpu_devices, use_cpu, bert_model,
    resume, load_model, save_dir, input_train, input_dev, input_test,
    max_len and batch_size.

    Side effects: sets CUDA_VISIBLE_DEVICES, seeds the ``random``, NumPy and
    torch RNGs, configures root logging at INFO level, creates
    ``args.save_dir`` if missing, and writes one tab-separated
    "id, prediction" line per test example to MG1833039.txt inside it.

    Raises
    ------
    ValueError
        If ``trainer.predict`` returns fewer ids than predictions. This is
        checked before the output file is written.
    """
    # Restrict the visible GPUs before any torch call below touches CUDA.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    use_gpu = torch.cuda.is_available()
    if args.use_cpu:
        use_gpu = False

    logging.basicConfig(level=logging.INFO)

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    logging.info("Initializing model...")
    # Alternative model (disabled): model = BaseModel(args, use_gpu)
    model = BertForSequenceClassification.from_pretrained(
        args.bert_model,
        cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(-1),
        num_labels=2,
    )
    if args.resume:
        # The model is still on the CPU at this point, so load the
        # checkpoint onto the CPU too; this lets a checkpoint saved from a
        # GPU run be restored on a CPU-only machine. load_state_dict copies
        # the values into the model's own parameters.
        model.load_state_dict(torch.load(args.load_model, map_location='cpu'))
    if use_gpu:
        model = model.cuda()

    # numel() is an exact int; np.prod(p.size()) yields 1.0 (a float) for
    # 0-dim parameters, which would turn the whole sum into a float.
    params = sum(p.numel() for p in model.parameters())
    logging.info("Number of parameters: {}".format(params))

    # Unlike os.mkdir, this also creates missing parent directories.
    os.makedirs(args.save_dir, exist_ok=True)

    train_dataset = BertDataset(args.input_train, "train")
    dev_dataset = BertDataset(args.input_dev, "dev")
    test_dataset = BertDataset(args.input_test, "test")
    train_examples = len(train_dataset)

    train_dataloader = BertDataLoader(
        train_dataset, mode="train", max_len=args.max_len,
        batch_size=args.batch_size, num_workers=4, shuffle=True)
    dev_dataloader = BertDataLoader(
        dev_dataset, mode="dev", max_len=args.max_len,
        batch_size=args.batch_size, num_workers=4, shuffle=False)
    # Test batches use half the training batch size. Clamp to 1 so that
    # batch_size == 1 does not produce a batch size of 0.
    test_dataloader = BertDataLoader(
        test_dataset, mode="test", max_len=args.max_len,
        batch_size=max(1, args.batch_size // 2), num_workers=4, shuffle=False)

    trainer = Trainer(args, model, train_examples, use_gpu)
    # Same truthiness test as the checkpoint-loading branch above.
    if not args.resume:
        logging.info("Beginning training...")
        trainer.train(train_dataloader, dev_dataloader)

    predictions, ids = trainer.predict(test_dataloader)
    # Fail before opening the output file, rather than part-way through
    # writing it.
    if len(ids) < len(predictions):
        raise ValueError(
            "trainer.predict returned {} ids for {} predictions".format(
                len(ids), len(predictions)))

    out_path = os.path.join(args.save_dir, "MG1833039.txt")
    with open(out_path, "w", encoding="utf-8") as f:
        for example_id, prediction in zip(ids, predictions):
            f.write("{}\t{}\n".format(example_id, prediction))

    logging.info("Done!")