

Python pytorch_pretrained_bert.BertAdam Code Examples

This article collects typical usage examples of the pytorch_pretrained_bert.BertAdam method in Python. If you are looking for concrete answers to questions such as what BertAdam does, how to call it, or what real-world usage looks like, the curated code examples below should help. You can also explore further usage examples from the pytorch_pretrained_bert package.


The following presents 4 code examples of the pytorch_pretrained_bert.BertAdam method, sorted by popularity by default.
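Before the examples themselves, here is a minimal, self-contained sketch of how BertAdam is typically constructed. The model, learning rate, warmup fraction, and step count below are illustrative assumptions rather than values taken from the projects listed later; with warmup and t_total set, BertAdam performs linear warmup followed by linear decay of the learning rate.

import torch
from pytorch_pretrained_bert import BertAdam

# Illustrative model and step budget; replace with your own.
model = torch.nn.Linear(768, 2)
num_train_steps = 1000

optimizer = BertAdam(model.parameters(),
                     lr=5e-5,
                     warmup=0.1,            # warm up over the first 10% of steps
                     t_total=num_train_steps)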

Example 1: init_optimizer

# Required imports (the snippet references torch.optim as `optim`):
from torch import optim
from pytorch_pretrained_bert import BertAdam
def init_optimizer(model, config, *args, **params):
    optimizer_type = config.get("train", "optimizer")
    learning_rate = config.getfloat("train", "learning_rate")
    if optimizer_type == "adam":
        optimizer = optim.Adam(model.parameters(), lr=learning_rate,
                               weight_decay=config.getfloat("train", "weight_decay"))
    elif optimizer_type == "sgd":
        optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                              weight_decay=config.getfloat("train", "weight_decay"))
    elif optimizer_type == "bert_adam":
        optimizer = BertAdam(model.parameters(), lr=learning_rate,
                             weight_decay=config.getfloat("train", "weight_decay"))
    else:
        raise NotImplementedError

    return optimizer 
Developer: haoxizhong, Project: pytorch-worker, Lines: 18, Source: optimizer.py
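Judging by the config.get("train", ...) and config.getfloat calls, config is expected to behave like a standard configparser.ConfigParser (or a compatible wrapper). A hypothetical invocation, with illustrative section values, might look like this:

import configparser
import torch

# Hypothetical config; the "train" keys mirror the ones read by init_optimizer.
config = configparser.ConfigParser()
config["train"] = {
    "optimizer": "bert_adam",
    "learning_rate": "5e-5",
    "weight_decay": "0.01",
}

model = torch.nn.Linear(768, 2)
optimizer = init_optimizer(model, config)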

Example 2: test_bert_sched_init

# Required imports (in the 0.6.x releases, ConstantLR and WarmupLinearSchedule
# live in the library's optimization module; this is a unittest.TestCase method):
import torch
from pytorch_pretrained_bert import BertAdam
from pytorch_pretrained_bert.optimization import ConstantLR, WarmupLinearSchedule
def test_bert_sched_init(self):
        m = torch.nn.Linear(50, 50)
        optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
        optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
        optim = BertAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
        # shouldn't fail 
Developer: martiansideofthemoon, Project: squash-generation, Lines: 11, Source: optimization_test.py
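The schedule strings accepted by BertAdam resolve to schedule classes from the library's optimization module; besides None/"none" and the default "warmup_linear" exercised above, 0.6.x also ships variants such as "warmup_constant" and "warmup_cosine". A short sketch (assuming pytorch_pretrained_bert 0.6.x) that prints which schedule object each string resolves to:

import torch
from pytorch_pretrained_bert import BertAdam

m = torch.nn.Linear(50, 50)
for sched in (None, "none", "warmup_linear"):
    opt = BertAdam(m.parameters(), lr=1e-3, warmup=0.1, t_total=1000, schedule=sched)
    # The resolved schedule object is stored per parameter group.
    print(sched, type(opt.param_groups[0]["schedule"]).__name__)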

Example 3: test_adam

# Required imports (assertListAlmostEqual is a helper defined on the test class):
import torch
from pytorch_pretrained_bert import BertAdam
def test_adam(self):
        w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
        target = torch.tensor([0.4, 0.2, -0.5])
        criterion = torch.nn.MSELoss()
        # No warmup, constant schedule, no gradient clipping
        optimizer = BertAdam(params=[w], lr=2e-1,
                             weight_decay=0.0,
                             max_grad_norm=-1)
        for _ in range(100):
            loss = criterion(w, target)
            loss.backward()
            optimizer.step()
            w.grad.detach_()  # Plain tensors have no zero_grad(); clear the gradient manually.
            w.grad.zero_()
        self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) 
Developer: Meelfy, Project: pytorch_pretrained_BERT, Lines: 17, Source: optimization_test.py
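For contrast with the constant-schedule test above, the sketch below adds warmup and a finite t_total and inspects the effective learning rate via optimizer.get_lr(), the helper BertAdam exposes for this purpose; the step counts and printing interval are illustrative.

import torch
from pytorch_pretrained_bert import BertAdam

w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
target = torch.tensor([0.4, 0.2, -0.5])
criterion = torch.nn.MSELoss()

# Linear warmup over the first 10% of 100 steps, then linear decay toward zero.
optimizer = BertAdam(params=[w], lr=2e-1, warmup=0.1, t_total=100,
                     weight_decay=0.0, max_grad_norm=-1)

for step in range(100):
    loss = criterion(w, target)
    loss.backward()
    optimizer.step()
    w.grad.detach_()
    w.grad.zero_()
    if step % 25 == 0:
        print(step, optimizer.get_lr())  # current scheduled learning rate(s)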

Example 4: _setup_optim

# Required import: from pytorch_pretrained_bert import BertAdam
# (in the source project, `Adam` below appears to be BertAdam imported under an
# alias; Adamax, RAdam and the torch LR schedulers are project-level imports)
def _setup_optim(self, optimizer_parameters, state_dict=None, num_train_step=-1):
        if self.config['optimizer'] == 'sgd':
            self.optimizer = optim.SGD(optimizer_parameters, self.config['learning_rate'],
                                       weight_decay=self.config['weight_decay'])

        elif self.config['optimizer'] == 'adamax':
            self.optimizer = Adamax(optimizer_parameters,
                                    self.config['learning_rate'],
                                    warmup=self.config['warmup'],
                                    t_total=num_train_step,
                                    max_grad_norm=self.config['grad_clipping'],
                                    schedule=self.config['warmup_schedule'],
                                    weight_decay=self.config['weight_decay'])
            if self.config.get('have_lr_scheduler', False): self.config['have_lr_scheduler'] = False
        elif self.config['optimizer'] == 'radam':
            self.optimizer = RAdam(optimizer_parameters,
                                    self.config['learning_rate'],
                                    warmup=self.config['warmup'],
                                    t_total=num_train_step,
                                    max_grad_norm=self.config['grad_clipping'],
                                    schedule=self.config['warmup_schedule'],
                                    eps=self.config['adam_eps'],
                                    weight_decay=self.config['weight_decay'])
            if self.config.get('have_lr_scheduler', False): self.config['have_lr_scheduler'] = False
            # The current radam does not support FP16.
            self.config['fp16'] = False
        elif self.config['optimizer'] == 'adam':
            self.optimizer = Adam(optimizer_parameters,
                                  lr=self.config['learning_rate'],
                                  warmup=self.config['warmup'],
                                  t_total=num_train_step,
                                  max_grad_norm=self.config['grad_clipping'],
                                  schedule=self.config['warmup_schedule'],
                                  weight_decay=self.config['weight_decay'])
            if self.config.get('have_lr_scheduler', False): self.config['have_lr_scheduler'] = False
        else:
            raise RuntimeError('Unsupported optimizer: %s' % self.config['optimizer'])

        if state_dict and 'optimizer' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer'])

        if self.config['fp16']:
            try:
                # Declare `amp` global before importing so the module stays visible outside this method.
                global amp
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, optimizer = amp.initialize(self.network, self.optimizer, opt_level=self.config['fp16_opt_level'])
            self.network = model
            self.optimizer = optimizer

        if self.config.get('have_lr_scheduler', False):
            if self.config.get('scheduler_type', 'rop') == 'rop':
                self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', factor=self.config['lr_gamma'], patience=3)
            elif self.config.get('scheduler_type', 'rop') == 'exp':
                self.scheduler = ExponentialLR(self.optimizer, gamma=self.config.get('lr_gamma', 0.95))
            else:
                milestones = [int(step) for step in self.config.get('multi_step_lr', '10,20,30').split(',')]
                self.scheduler = MultiStepLR(self.optimizer, milestones=milestones, gamma=self.config.get('lr_gamma'))
        else:
            self.scheduler = None 
Developer: namisan, Project: mt-dnn, Lines: 63, Source: model.py
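The lookups against self.config above imply a configuration dictionary roughly like the one sketched below; the values are purely illustrative, and in the source project Adam appears to be BertAdam imported under an alias (Adamax and RAdam being BertAdam-style variants), which is why this snippet is listed among the BertAdam examples.

# Hypothetical config for _setup_optim; keys mirror the lookups in the method above.
config = {
    "optimizer": "adam",           # one of 'sgd', 'adamax', 'radam', 'adam'
    "learning_rate": 5e-5,
    "weight_decay": 0.01,
    "warmup": 0.1,
    "warmup_schedule": "warmup_linear",
    "grad_clipping": 1.0,
    "adam_eps": 1e-6,
    "fp16": False,
    "fp16_opt_level": "O1",
    "have_lr_scheduler": True,
    "scheduler_type": "rop",       # 'rop', 'exp', or multi-step otherwise
    "lr_gamma": 0.5,
    "multi_step_lr": "10,20,30",
}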


Note: The pytorch_pretrained_bert.BertAdam examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets are selected from open-source projects, and copyright remains with the original authors; please follow each project's License when distributing or using the code. Do not republish this article without permission.