當前位置: 首頁>>代碼示例>>Python>>正文


Python file_utils.PYTORCH_PRETRAINED_BERT_CACHE屬性代碼示例

本文整理匯總了Python中pytorch_pretrained_bert.file_utils.PYTORCH_PRETRAINED_BERT_CACHE屬性的典型用法代碼示例。如果您正苦於以下問題:Python file_utils.PYTORCH_PRETRAINED_BERT_CACHE屬性的具體用法?Python file_utils.PYTORCH_PRETRAINED_BERT_CACHE怎麼用?Python file_utils.PYTORCH_PRETRAINED_BERT_CACHE使用的例子?那麼,這裏精選的屬性代碼示例或許可以為您提供幫助。您也可以進一步了解該屬性所在pytorch_pretrained_bert.file_utils的用法示例。


在下文中一共展示了file_utils.PYTORCH_PRETRAINED_BERT_CACHE屬性的5個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: __init__

# 需要導入模塊: from pytorch_pretrained_bert import file_utils [as 別名]
# 或者: from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE [as 別名]
def __init__(self, archive_file, model_file=None, use_cuda=False):
    """Prepare the dialogue-act predictor: unpack checkpoints and load BERT.

    Args:
        archive_file: Path to a local zip archive containing a ``checkpoints``
            directory. If it is not an existing file, ``model_file`` is
            resolved through ``cached_path`` and used instead.
        model_file: Fallback location handed to ``cached_path`` when
            ``archive_file`` does not exist locally.
        use_cuda: Place the model on ``'cuda'`` when true, else on ``'cpu'``.

    Raises:
        Exception: If ``archive_file`` is missing and no ``model_file`` is given.
    """
    if not os.path.isfile(archive_file):
        if not model_file:
            raise Exception("No model for DA-predictor is specified!")
        archive_file = cached_path(model_file)

    # Checkpoints are unpacked into the directory that holds this source file.
    model_dir = os.path.dirname(os.path.abspath(__file__))
    if not os.path.exists(os.path.join(model_dir, 'checkpoints')):
        # Context manager closes the archive handle even if extraction fails.
        with zipfile.ZipFile(archive_file, 'r') as archive:
            archive.extractall(model_dir)

    load_dir = os.path.join(model_dir, "checkpoints", "predictor", "save_step_15120")
    if not os.path.exists(load_dir):
        # The step directory may itself ship as a nested ``<load_dir>.zip``.
        with zipfile.ZipFile(f'{load_dir}.zip', 'r') as archive:
            archive.extractall(os.path.dirname(load_dir))

    # NOTE(review): an *uncased* vocabulary combined with do_lower_case=False
    # means input text is not lower-cased before tokenization. Kept as-is
    # because the saved checkpoint was presumably trained this way — confirm.
    self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=False)
    self.max_seq_length = 256
    self.domain = 'restaurant'
    self.model = BertForSequenceClassification.from_pretrained(
        load_dir,
        cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(-1)),
        num_labels=44,
    )
    self.device = 'cuda' if use_cuda else 'cpu'
    self.model.to(self.device)
開發者ID:ConvLab,項目名稱:ConvLab,代碼行數:24,代碼來源:predictor.py

示例2: create_model

# 需要導入模塊: from pytorch_pretrained_bert import file_utils [as 別名]
# 或者: from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE [as 別名]
def create_model(task_type, bert_model_name, bert_load_mode, bert_load_args,
                 all_state,
                 num_labels, device, n_gpu, fp16, local_rank,
                 bert_config_json_path=None):
    """Build a BERT model as selected by ``bert_load_mode`` and stage it.

    Modes:
        * ``"from_pretrained"`` -- built with ``create_from_pretrained`` using
          a cache directory specific to ``local_rank``; ``bert_load_args``,
          ``all_state`` and ``bert_config_json_path`` must all be ``None``.
        * ``"model_only"``, ``"state_model_only"``, ``"state_all"``,
          ``"state_full_model"``, ``"full_model_only"`` -- restored with
          ``load_bert``; ``bert_load_args`` must be ``None``.
        * ``"state_adapter"`` -- restored with ``load_bert_adapter``, which
          additionally receives ``bert_load_args``.

    Returns:
        Whatever ``stage_model`` returns for the built model.

    Raises:
        KeyError: If ``bert_load_mode`` is none of the modes above.
        AssertionError: If a mode's ``None`` precondition is violated (only
            while assertions are enabled, i.e. not under ``python -O``).
    """
    checkpoint_modes = ("model_only", "state_model_only", "state_all",
                        "state_full_model", "full_model_only")
    # Arguments shared by both checkpoint-restoring loaders.
    restore_kwargs = {
        "task_type": task_type,
        "bert_model_name": bert_model_name,
        "bert_load_mode": bert_load_mode,
        "all_state": all_state,
        "num_labels": num_labels,
        "bert_config_json_path": bert_config_json_path,
    }

    if bert_load_mode == "from_pretrained":
        assert bert_load_args is None
        assert all_state is None
        assert bert_config_json_path is None
        # Each distributed rank gets its own cache sub-directory.
        rank_cache_dir = PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(local_rank)
        built = create_from_pretrained(
            task_type=task_type,
            bert_model_name=bert_model_name,
            cache_dir=rank_cache_dir,
            num_labels=num_labels,
        )
    elif bert_load_mode in checkpoint_modes:
        assert bert_load_args is None
        built = load_bert(**restore_kwargs)
    elif bert_load_mode == "state_adapter":
        built = load_bert_adapter(bert_load_args=bert_load_args, **restore_kwargs)
    else:
        raise KeyError(bert_load_mode)

    return stage_model(built, fp16=fp16, device=device, local_rank=local_rank, n_gpu=n_gpu)
開發者ID:zphang,項目名稱:bert_on_stilts,代碼行數:42,代碼來源:model_setup.py

示例3: create_model

# 需要導入模塊: from pytorch_pretrained_bert import file_utils [as 別名]
# 或者: from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE [as 別名]
def create_model(bert_model_name, bert_load_mode, bert_load_args,
                 all_state,
                 device, n_gpu, fp16, local_rank,
                 bert_config_json_path=None):
    """Instantiate a BERT model for ``bert_load_mode`` and stage it.

    Modes:
        * ``"from_pretrained"`` -- fresh model via ``create_from_pretrained``
          with a per-``local_rank`` cache directory; ``all_state``,
          ``bert_config_json_path`` and ``bert_load_args`` must be ``None``.
        * ``"model_only"``, ``"state_model_only"``, ``"state_all"``,
          ``"state_full_model"`` -- restored via ``load_bert``;
          ``bert_load_args`` must be ``None``.
        * ``"state_adapter"`` -- not supported by this variant.

    Returns:
        Whatever ``stage_model`` returns for the built model.

    Raises:
        NotImplementedError: For ``"state_adapter"``.
        KeyError: For any other unrecognized mode.
    """
    restore_modes = ("model_only", "state_model_only", "state_all", "state_full_model")

    # Reject unsupported / unknown modes up front.
    if bert_load_mode == "state_adapter":
        raise NotImplementedError("Adapter")
    if bert_load_mode != "from_pretrained" and bert_load_mode not in restore_modes:
        raise KeyError(bert_load_mode)

    if bert_load_mode == "from_pretrained":
        # NOTE: these preconditions are asserts and vanish under ``python -O``.
        assert all_state is None
        assert bert_config_json_path is None
        assert bert_load_args is None
        rank_cache_dir = PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(local_rank)
        built = create_from_pretrained(bert_model_name=bert_model_name, cache_dir=rank_cache_dir)
    else:
        assert bert_load_args is None
        built = load_bert(
            bert_model_name=bert_model_name,
            bert_load_mode=bert_load_mode,
            all_state=all_state,
            bert_config_json_path=bert_config_json_path,
        )

    return stage_model(built, fp16=fp16, device=device, local_rank=local_rank, n_gpu=n_gpu)
開發者ID:zphang,項目名稱:bert_on_stilts,代碼行數:29,代碼來源:model_setup.py

示例4: main

# 需要導入模塊: from pytorch_pretrained_bert import file_utils [as 別名]
# 或者: from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE [as 別名]
def main(output_model_file='./models/bert-base-uncased.bin', load=False, mode='tensors', batch_size=12,
         num_epoch=1, gradient_accumulation_steps=1, lr1=1e-4, lr2=1e-4, alpha=0.2):
    """Train the BERT reader and the CognitiveGNN jointly, then save both.

    Reads ``./hotpot_train_v1.1_refined.json`` and converts each entry into a
    sample bundle. Entries whose conversion raises ``ValueError`` are skipped,
    and their count is reported. The models are optionally restored from
    ``output_model_file``, trained via ``train``, and their state dicts are
    written back to ``output_model_file`` under ``'params1'`` / ``'params2'``.

    Args:
        output_model_file: Checkpoint read when ``load`` is true; always
            overwritten at the end.
        load: Resume from ``output_model_file`` instead of starting from the
            pretrained BERT weights.
        mode, batch_size, num_epoch, gradient_accumulation_steps, lr1, lr2,
        alpha: Forwarded unchanged to ``train``.
    """
    BERT_MODEL = 'bert-base-uncased'  # bert-large is too large for ordinary GPU on task #2
    tokenizer = BertTokenizer.from_pretrained(BERT_MODEL, do_lower_case=True)
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open('./hotpot_train_v1.1_refined.json', 'r', encoding='utf-8') as fin:
        dataset = json.load(fin)

    bundles = []
    skipped = 0
    for data in tqdm(dataset):
        try:
            bundles.append(convert_question_to_samples_bundle(tokenizer, data))
        except ValueError:
            # Entries the converter rejects are dropped on purpose; count them
            # so the lost training data is visible instead of silent.
            skipped += 1
    if skipped:
        print('Skipped {} of {} entries that could not be converted'.format(skipped, len(dataset)))

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    if load:
        print('Loading model from {}'.format(output_model_file))
        # Map tensors to CPU: the state dicts only initialise modules built here
        # on CPU, and this lets a GPU-saved checkpoint resume on a CPU-only host.
        model_state_dict = torch.load(output_model_file, map_location='cpu')
        model1 = BertForMultiHopQuestionAnswering.from_pretrained(BERT_MODEL, state_dict=model_state_dict['params1'])
        model2 = CognitiveGNN(model1.config.hidden_size)
        model2.load_state_dict(model_state_dict['params2'])
    else:
        model1 = BertForMultiHopQuestionAnswering.from_pretrained(
            BERT_MODEL, cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(-1))
        model2 = CognitiveGNN(model1.config.hidden_size)

    print('Start Training... on {} GPUs'.format(torch.cuda.device_count()))
    model1 = torch.nn.DataParallel(model1, device_ids=range(torch.cuda.device_count()))
    model1, model2 = train(bundles, model1=model1, device=device, mode=mode, model2=model2,  # Then pass hyperparams
                           batch_size=batch_size, num_epoch=num_epoch,
                           gradient_accumulation_steps=gradient_accumulation_steps,
                           lr1=lr1, lr2=lr2, alpha=alpha)

    print('Saving model to {}'.format(output_model_file))
    # model1 is expected to still be the DataParallel wrapper; save the inner
    # module so the keys match what from_pretrained(state_dict=...) loads above.
    saved_dict = {
        'params1': model1.module.state_dict(),
        'params2': model2.state_dict(),
    }
    torch.save(saved_dict, output_model_file)
開發者ID:THUDM,項目名稱:CogQA,代碼行數:40,代碼來源:train.py

示例5: main

# 需要導入模塊: from pytorch_pretrained_bert import file_utils [as 別名]
# 或者: from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE [as 別名]
def main():
    """Fine-tune a 2-label BERT classifier and write test-set predictions.

    Driven by the module-level ``args`` namespace. It:

    * seeds the Python, NumPy and torch RNGs;
    * selects GPU or CPU;
    * builds ``BertForSequenceClassification`` with ``num_labels=2``, restoring
      weights from ``args.load_model`` when ``args.resume`` is set;
    * trains unless resuming;
    * writes ``<id>\\t<prediction>`` lines to ``<args.save_dir>/MG1833039.txt``.
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # CUDA reads this variable when it initialises, so set it before any CUDA query.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available() and not args.use_cpu

    logging.basicConfig(level=logging.INFO)

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    logging.info("Initializing model...")
    model = BertForSequenceClassification.from_pretrained(
        args.bert_model,
        cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(-1),
        num_labels=2)

    if args.resume:
        # The model is still on CPU here; map the checkpoint there as well so a
        # GPU-saved file can also be resumed on a CPU-only machine.
        model.load_state_dict(torch.load(args.load_model, map_location='cpu'))

    if use_gpu:
        model = model.cuda()

    params = sum(np.prod(p.size()) for p in model.parameters())
    logging.info("Number of parameters: {}".format(params))

    # Also creates missing parent directories; a no-op if the directory exists.
    os.makedirs(args.save_dir, exist_ok=True)

    train_dataset = BertDataset(args.input_train, "train")
    dev_dataset = BertDataset(args.input_dev, "dev")
    test_dataset = BertDataset(args.input_test, "test")

    train_examples = len(train_dataset)

    train_dataloader = BertDataLoader(train_dataset, mode="train", max_len=args.max_len,
                                      batch_size=args.batch_size, num_workers=4, shuffle=True)
    dev_dataloader = BertDataLoader(dev_dataset, mode="dev", max_len=args.max_len,
                                    batch_size=args.batch_size, num_workers=4, shuffle=False)
    # Test batches are half-size. Clamp to 1 so batch_size=1 doesn't become 0,
    # which the loader presumably rejects.
    test_dataloader = BertDataLoader(test_dataset, mode="test", max_len=args.max_len,
                                     batch_size=max(1, args.batch_size // 2), num_workers=4, shuffle=False)

    trainer = Trainer(args, model, train_examples, use_gpu)

    if not args.resume:
        logging.info("Beginning training...")
        trainer.train(train_dataloader, dev_dataloader)

    predictions, ids = trainer.predict(test_dataloader)

    with open(os.path.join(args.save_dir, "MG1833039.txt"), "w", encoding="utf-8") as f:
        for sample_id, label in zip(ids, predictions):
            f.write("{}\t{}\n".format(sample_id, label))

    logging.info("Done!")
開發者ID:tracy-talent,項目名稱:curriculum,代碼行數:63,代碼來源:main.py


注:本文中的pytorch_pretrained_bert.file_utils.PYTORCH_PRETRAINED_BERT_CACHE屬性示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。