当前位置: 首页>>代码示例>>Python>>正文


Python config.device方法代码示例

本文整理汇总了Python中config.device方法的典型用法代码示例。如果您正苦于以下问题：Python config.device方法的具体用法？Python config.device怎么用？Python config.device使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块config的用法示例。


在下文中一共展示了config.device方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: forward

# 需要导入模块: import config [as 别名]
# 或者: from config import device [as 别名]
def forward(self, event, control=None, hidden=None):
        """Advance the model by a single time step.

        ``event`` must be a 2-D tensor whose first dimension is 1; its second
        dimension is taken as the batch size. When ``control`` is omitted, a
        zero control tensor is substituted and the extra one-element flag fed
        to the network is set to one; otherwise the flag is zero and
        ``control`` must have shape ``(1, batch_size, self.control_dim)``.

        Returns:
            ``(output, hidden)`` -- ``self.output_fc`` applied to the
            flattened GRU hidden state, and the new hidden state itself.
        """
        assert len(event.shape) == 2
        assert event.shape[0] == 1
        n_batch = event.shape[1]
        embedded = self.event_embedding(event)

        if control is not None:
            flag = torch.zeros(1, n_batch, 1).to(device)
            assert control.shape == (1, n_batch, self.control_dim)
        else:
            # No conditioning supplied: raise the flag and feed zeros instead.
            flag = torch.ones(1, n_batch, 1).to(device)
            control = torch.zeros(1, n_batch, self.control_dim).to(device)

        features = torch.cat([embedded, flag, control], -1)
        gru_in = self.concat_input_fc_activation(self.concat_input_fc(features))

        _, hidden = self.gru(gru_in, hidden)
        # Bring the batch dimension first, then merge the remaining
        # dimensions of the hidden state into one vector per batch item.
        flat = hidden.permute(1, 0, 2).contiguous().view(n_batch, -1)
        output = self.output_fc(flat.unsqueeze(0))
        return output, hidden
开发者ID:djosix,项目名称:Performance-RNN-PyTorch,代码行数:26,代码来源:model.py

示例2: load_session

# 需要导入模块: import config [as 别名]
# 或者: from config import device [as 别名]
def load_session():
    """Build the model and optimizer, restoring a saved session if possible.

    Reads the module-level ``sess_path``, ``model_config``, ``device``,
    ``learning_rate`` and ``reset_optimizer``. When the session file loads,
    its ``model_config`` (if present and different) replaces the module-level
    one and the saved model state is restored; the optimizer state is
    restored as well unless ``reset_optimizer`` is set. When loading raises
    an ``Exception``, a fresh model and optimizer are created instead.

    Returns:
        A ``(model, optimizer)`` tuple.
    """
    global sess_path, model_config, device, learning_rate, reset_optimizer
    try:
        # map_location remaps the saved tensors onto the configured device, so
        # a session written on another device does not fail to load and get
        # silently treated as a new session.
        sess = torch.load(sess_path, map_location=device)
        if 'model_config' in sess and sess['model_config'] != model_config:
            model_config = sess['model_config']
            print('Use session config instead:')
            print(utils.dict2params(model_config))
        model_state = sess['model_state']
        optimizer_state = sess['model_optimizer_state']
        print('Session is loaded from', sess_path)
        sess_loaded = True
    except Exception:
        # Best-effort by design: a missing or unreadable session just means a
        # fresh start. Unlike a bare ``except``, this lets KeyboardInterrupt
        # and SystemExit propagate.
        print('New session')
        sess_loaded = False
    model = PerformanceRNN(**model_config).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    if sess_loaded:
        model.load_state_dict(model_state)
        if not reset_optimizer:
            optimizer.load_state_dict(optimizer_state)
    return model, optimizer
开发者ID:djosix,项目名称:Performance-RNN-PyTorch,代码行数:24,代码来源:train.py

示例3: forward

# 需要导入模块: import config [as 别名]
# 或者: from config import device [as 别名]
def forward(self, input, label):
        """Apply an additive angular margin to the target-class cosine logits.

        Both ``input`` and ``self.weight`` are L2-normalised, so their linear
        product is a matrix of cosines. For each row, the entry selected by
        ``label`` is replaced by ``cos(theta + m)`` (subject to the
        ``easy_margin`` / threshold rule below); every other entry keeps its
        plain cosine. The result is multiplied by ``self.s``.

        Args:
            input: feature tensor, one row per sample.
            label: class index per sample; reshaped to a column and cast to
                long before use.

        Returns:
            Tensor with the same shape as the cosine matrix.
        """
        x = F.normalize(input)
        W = F.normalize(self.weight)
        cosine = F.linear(x, W)
        # Rounding can push cosine**2 marginally above 1, which would make the
        # radicand negative and sqrt return NaN; clamp it into [0, 1] first.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # Build the mask on the same device as the logits rather than on a
        # module-level global, so the layer works wherever it has been moved.
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output
开发者ID:LcenArthas,项目名称:CCF-BDCI2019-Multi-person-Face-Recognition-Competition-Baseline,代码行数:18,代码来源:models.py

示例4: init_model

# 需要导入模块: import config [as 别名]
# 或者: from config import device [as 别名]
def init_model(self):
        """Load the pretrained weights requested in ``cfg`` and move models to GPU.

        Depending on ``cfg.oracle_pretrain``, ``cfg.dis_pretrain`` and
        ``cfg.gen_pretrain``, state dicts are loaded into ``self.oracle``,
        ``self.dis`` and ``self.gen``. The oracle checkpoint is created first
        via ``create_oracle()`` when its file does not exist. Finally, when
        ``cfg.CUDA`` is set, all three models are moved to CUDA.
        """
        # Resolve the checkpoint target once so every load agrees: remap to
        # the configured GPU when CUDA is in use, otherwise to the CPU (a
        # hard-coded 'cuda:N' target cannot be used on a CPU-only run).
        map_location = 'cuda:{}'.format(cfg.device) if cfg.CUDA else 'cpu'

        if cfg.oracle_pretrain:
            if not os.path.exists(cfg.oracle_state_dict_path):
                create_oracle()
            self.oracle.load_state_dict(
                torch.load(cfg.oracle_state_dict_path, map_location=map_location))

        if cfg.dis_pretrain:
            self.log.info(
                'Load pretrained discriminator: {}'.format(cfg.pretrained_dis_path))
            self.dis.load_state_dict(
                torch.load(cfg.pretrained_dis_path, map_location=map_location))
        if cfg.gen_pretrain:
            self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path))
            self.gen.load_state_dict(
                torch.load(cfg.pretrained_gen_path, map_location=map_location))

        if cfg.CUDA:
            self.oracle = self.oracle.cuda()
            self.gen = self.gen.cuda()
            self.dis = self.dis.cuda()
开发者ID:williamSYSU,项目名称:TextGAN-PyTorch,代码行数:20,代码来源:instructor.py

示例5: init_model

# 需要导入模块: import config [as 别名]
# 或者: from config import device [as 别名]
def init_model(self):
        """Load the pretrained weights requested in ``cfg`` and move models to GPU.

        Depending on ``cfg.dis_pretrain``, ``cfg.gen_pretrain`` and
        ``cfg.clas_pretrain``, state dicts are loaded into ``self.dis``, each
        of the ``cfg.k_label`` generators in ``self.gen_list`` (one file per
        generator, the index appended to ``cfg.pretrained_gen_path``) and
        ``self.clas``. Finally, when ``cfg.CUDA`` is set, every model is moved
        to CUDA.
        """
        # Resolve the checkpoint target once so every load agrees: remap to
        # the configured GPU when CUDA is in use, otherwise to the CPU (a
        # hard-coded 'cuda:N' target cannot be used on a CPU-only run).
        map_location = 'cuda:%d' % cfg.device if cfg.CUDA else 'cpu'

        if cfg.dis_pretrain:
            self.log.info(
                'Load pretrained discriminator: {}'.format(cfg.pretrained_dis_path))
            self.dis.load_state_dict(
                torch.load(cfg.pretrained_dis_path, map_location=map_location))
        if cfg.gen_pretrain:
            for i in range(cfg.k_label):
                self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path + '%d' % i))
                self.gen_list[i].load_state_dict(
                    torch.load(cfg.pretrained_gen_path + '%d' % i, map_location=map_location))
        if cfg.clas_pretrain:
            self.log.info('Load  pretrained classifier: {}'.format(cfg.pretrained_clas_path))
            self.clas.load_state_dict(
                torch.load(cfg.pretrained_clas_path, map_location=map_location))

        if cfg.CUDA:
            for i in range(cfg.k_label):
                self.gen_list[i] = self.gen_list[i].cuda()
            self.dis = self.dis.cuda()
            self.clas = self.clas.cuda()
开发者ID:williamSYSU,项目名称:TextGAN-PyTorch,代码行数:20,代码来源:sentigan_instructor.py

示例6: sample

# 需要导入模块: import config [as 别名]
# 或者: from config import device [as 别名]
def sample(self, batch_size, net, target_net, beta):
        """Draw a prioritised batch and refresh the priorities of what was drawn.

        Indices are sampled with probability proportional to the entries of
        ``self.memory_probabiliy``. Each drawn entry then has its stored
        priority overwritten with ``(|td_error| + small_epsilon) ** alpha``,
        using the detached TD error from ``QNet.get_td_error``.

        Returns:
            ``(batch, weights)`` -- the drawn transitions regrouped into a
            ``Transition``, and the weights ``(capacity * p) ** -beta``
            divided by their maximum.
        """
        total = sum(self.memory_probabiliy)
        p = [prob / total for prob in self.memory_probabiliy]

        indexes = np.random.choice(np.arange(len(self.memory)), batch_size, p=p)
        batch = Transition(*zip(*[self.memory[i] for i in indexes]))

        raw_weights = [pow(self.capacity * p[i], -beta) for i in indexes]
        weights = torch.Tensor(raw_weights).to(device)
        weights = weights / weights.max()

        td_error = QNet.get_td_error(
            net, target_net,
            batch.state, batch.next_state, batch.action, batch.reward, batch.mask,
        ).detach()

        # td_error is consumed in draw order, one value per sampled index.
        for pos, mem_idx in enumerate(indexes):
            priority = pow(abs(td_error[pos]) + small_epsilon, alpha)
            self.memory_probabiliy[mem_idx] = priority.item()

        return batch, weights
开发者ID:g6ling,项目名称:Reinforcement-Learning-Pytorch-Cartpole,代码行数:27,代码来源:memory.py

示例7: sample

# 需要导入模块: import config [as 别名]
# 或者: from config import device [as 别名]
def sample(self, batch_size, net, target_net, beta):
        """Sample transitions by priority, then update the sampled priorities.

        The sampling distribution is ``self.memory_probabiliy`` normalised to
        sum to one. After sampling, each chosen slot's priority becomes
        ``(|td_error| + small_epsilon) ** alpha``, where the TD errors come
        from ``QNet.get_td_error`` on the sampled batch.

        Returns:
            ``(batch, weights)`` -- a ``Transition`` built from the sampled
            entries, and ``(capacity * p) ** -beta`` scaled so the largest
            weight is one.
        """
        norm = sum(self.memory_probabiliy)
        p = [value / norm for value in self.memory_probabiliy]

        slots = np.arange(len(self.memory))
        indexes = np.random.choice(slots, batch_size, p=p)

        picked = [self.memory[k] for k in indexes]
        picked_p = [p[k] for k in indexes]
        batch = Transition(*zip(*picked))

        weights = torch.Tensor([pow(self.capacity * q, -beta) for q in picked_p]).to(device)
        weights = weights / weights.max()

        td_error = QNet.get_td_error(net, target_net, batch.state, batch.next_state,
                                     batch.action, batch.reward, batch.mask)

        # The i-th TD error belongs to the i-th sampled index.
        for i, slot in enumerate(indexes):
            self.memory_probabiliy[slot] = pow(abs(td_error[i]) + small_epsilon, alpha).item()

        return batch, weights
开发者ID:g6ling,项目名称:Reinforcement-Learning-Pytorch-Cartpole,代码行数:27,代码来源:memory.py

示例8: sample

# 需要导入模块: import config [as 别名]
# 或者: from config import device [as 别名]
def sample(self, batch_size, net, target_net, beta):
        """Pick a priority-weighted batch and rewrite the picked priorities.

        Entries are chosen with probability proportional to
        ``self.memory_probabiliy``. The detached per-sample values returned
        by ``QNet.get_loss`` are then turned into new priorities,
        ``(|value| + small_epsilon) ** alpha``, for the chosen entries.

        Returns:
            ``(batch, weights)`` -- the chosen transitions as a
            ``Transition``, and ``(capacity * p) ** -beta`` divided by the
            maximum weight.
        """
        denom = sum(self.memory_probabiliy)
        p = [entry / denom for entry in self.memory_probabiliy]

        indexes = np.random.choice(np.arange(len(self.memory)), batch_size, p=p)
        chosen = [self.memory[j] for j in indexes]
        chosen_p = [p[j] for j in indexes]
        batch = Transition(*zip(*chosen))

        weights = torch.Tensor(
            [pow(self.capacity * prob, -beta) for prob in chosen_p]
        ).to(device)
        weights = weights / weights.max()

        td_error = QNet.get_loss(net, target_net, batch.state, batch.next_state,
                                 batch.action, batch.reward, batch.mask)
        td_error = td_error.detach()

        # Values are consumed in the order the indices were drawn.
        for order, j in enumerate(indexes):
            updated = pow(abs(td_error[order]) + small_epsilon, alpha)
            self.memory_probabiliy[j] = updated.item()

        return batch, weights
开发者ID:g6ling,项目名称:Reinforcement-Learning-Pytorch-Cartpole,代码行数:25,代码来源:memory.py

示例9: valid

# 需要导入模块: import config [as 别名]
# 或者: from config import device [as 别名]
def valid(valid_loader, model, logger):
    """Evaluate ``model`` over ``valid_loader`` and return the average loss.

    The model is put in eval mode, every batch is moved to ``device`` and
    scored with ``cal_performance`` under ``torch.no_grad()``. The last and
    running-average loss are logged once at the end.
    """
    model.eval()
    losses = AverageMeter()

    for padded_input, padded_target, input_lengths in tqdm(valid_loader):
        # Put the whole batch on the configured device.
        padded_input = padded_input.to(device)
        padded_target = padded_target.to(device)
        input_lengths = input_lengths.to(device)

        with torch.no_grad():
            pred, gold = model(padded_input, input_lengths, padded_target)
            # Only the loss is tracked; the second return value is unused.
            loss, _ = cal_performance(pred, gold, smoothing=args.label_smoothing)

        losses.update(loss.item())

    summary = '\nValidation Loss {loss.val:.5f} ({loss.avg:.5f})\n'.format(loss=losses)
    logger.info(summary)

    return losses.avg
开发者ID:foamliu,项目名称:Speech-Transformer,代码行数:27,代码来源:train.py

示例10: get_primary_event

# 需要导入模块: import config [as 别名]
# 或者: from config import device [as 别名]
def get_primary_event(self, batch_size):
        """Return a ``(1, batch_size)`` LongTensor on ``device`` filled with ``self.primary_event``."""
        row = [self.primary_event] * batch_size
        return torch.LongTensor([row]).to(device)
开发者ID:djosix,项目名称:Performance-RNN-PyTorch,代码行数:4,代码来源:model.py

示例11: get_image

# 需要导入模块: import config [as 别名]
# 或者: from config import device [as 别名]
def get_image(transformer, filepath, flip=False):
    """Load an image file, optionally mirror it, and return it on ``device``.

    Args:
        transformer: callable applied to the PIL image; its result must
            support ``.to(device)``.
        filepath: path of the image to open.
        flip: when true, the image is mirrored left-to-right before the
            transform is applied.

    Returns:
        The transformed image moved to ``device``.
    """
    # Force three channels so greyscale/palette/RGBA files reach the
    # transform in the same layout as ordinary RGB images.
    img = Image.open(filepath).convert('RGB')
    if flip:
        # ImageOps.flip turns the image upside down; the mirrored copy
        # wanted here is the horizontal flip, which is ImageOps.mirror.
        img = ImageOps.mirror(img)
    img = transformer(img)
    return img.to(device)
开发者ID:foamliu,项目名称:InsightFace-PyTorch,代码行数:8,代码来源:megaface_utils.py

示例12: forward

# 需要导入模块: import config [as 别名]
# 或者: from config import device [as 别名]
def forward(self, input, label):
        """Return scaled cosine logits with an angular margin on the labelled class.

        ``input`` and ``self.weight`` are L2-normalised and multiplied to get
        a cosine matrix. The column selected by ``label`` in each row is
        swapped for ``cos(theta + m)`` (with the ``easy_margin`` / threshold
        fallback applied), all other columns keep the plain cosine, and the
        whole matrix is multiplied by ``self.s``.

        Args:
            input: feature tensor, one row per sample.
            label: class index per sample; reshaped to a column and cast to
                long before use.

        Returns:
            Tensor with the same shape as the cosine matrix.
        """
        x = F.normalize(input)
        W = F.normalize(self.weight)
        cosine = F.linear(x, W)
        # Floating-point error can leave cosine**2 slightly above 1; clamping
        # the radicand keeps sqrt from producing NaN.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # Allocate the mask next to the logits instead of on a module-level
        # global device, so the two always live on the same device.
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output
开发者ID:foamliu,项目名称:InsightFace-PyTorch,代码行数:17,代码来源:models.py

示例13: get_image

# 需要导入模块: import config [as 别名]
# 或者: from config import device [as 别名]
def get_image(filepath, transformer, flip=False):
    """Open ``filepath`` as RGB, optionally flip it left-to-right, transform it and move it to ``device``."""
    picture = Image.open(filepath).convert('RGB')
    picture = picture.transpose(Image.FLIP_LEFT_RIGHT) if flip else picture
    return transformer(picture).to(device)
开发者ID:LcenArthas,项目名称:CCF-BDCI2019-Multi-person-Face-Recognition-Competition-Baseline,代码行数:8,代码来源:megaface_utils.py

示例14: train

# 需要导入模块: import config [as 别名]
# 或者: from config import device [as 别名]
def train(train_loader, model, optimizer, epoch, logger):
    """Run one training epoch and return the average loss.

    For every batch: inputs and labels are cast to float and moved to
    ``device``, the prediction and label are reshaped to
    ``(-1, 1, im_size * im_size)`` and ``(-1, 2, im_size * im_size)``, the
    loss from ``alpha_prediction_loss`` is back-propagated, gradients are
    clipped with ``grad_clip`` and the optimizer takes a step. Progress is
    logged every ``print_freq`` batches.
    """
    model.train()  # train mode

    losses = AverageMeter()

    for step, (img, alpha_label) in enumerate(train_loader):
        flat = im_size * im_size

        # Cast to float and place on the configured device.
        img = img.type(torch.FloatTensor).to(device)
        alpha_label = alpha_label.type(torch.FloatTensor).to(device).reshape((-1, 2, flat))

        # Forward pass, flattened to line up with the label's last axis.
        alpha_out = model(img).reshape((-1, 1, flat))

        loss = alpha_prediction_loss(alpha_out, alpha_label)

        # Backward pass, gradient clipping, then the parameter update.
        optimizer.zero_grad()
        loss.backward()
        clip_gradient(optimizer, grad_clip)
        optimizer.step()

        losses.update(loss.item())

        if step % print_freq != 0:
            continue
        logger.info(
            'Epoch: [{0}][{1}/{2}]\tLoss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                epoch, step, len(train_loader), loss=losses))

    return losses.avg
开发者ID:foamliu,项目名称:Mobile-Image-Matting,代码行数:43,代码来源:train.py

示例15: valid

# 需要导入模块: import config [as 别名]
# 或者: from config import device [as 别名]
def valid(valid_loader, model, logger):
    """Evaluate ``model`` over ``valid_loader`` and return the average loss.

    Every batch is cast to float, moved to ``device`` and scored with
    ``alpha_prediction_loss`` after the prediction and label are reshaped to
    ``(-1, 1, im_size * im_size)`` and ``(-1, 2, im_size * im_size)``. The
    average loss is logged once at the end.

    Args:
        valid_loader: iterable yielding ``(img, alpha_label)`` batches.
        model: network to evaluate; switched to eval mode here.
        logger: object whose ``info`` method receives the summary line.

    Returns:
        ``losses.avg`` -- the running average of the per-batch loss.
    """
    model.eval()  # eval mode (dropout and batchnorm is NOT used)

    losses = AverageMeter()

    # Nothing is back-propagated during validation, so disable autograd:
    # otherwise every forward pass records a graph that is never used.
    with torch.no_grad():
        # Batches
        for img, alpha_label in tqdm(valid_loader):
            # Cast to float and move to the configured device
            img = img.type(torch.FloatTensor).to(device)
            alpha_label = alpha_label.type(torch.FloatTensor).to(device)
            alpha_label = alpha_label.reshape((-1, 2, im_size * im_size))

            # Forward prop.
            alpha_out = model(img)
            alpha_out = alpha_out.reshape((-1, 1, im_size * im_size))

            # Calculate loss
            loss = alpha_prediction_loss(alpha_out, alpha_label)

            # Keep track of metrics
            losses.update(loss.item())

    # Print status
    status = 'Validation: Loss {loss.avg:.4f}\n'.format(loss=losses)
    logger.info(status)

    return losses.avg
开发者ID:foamliu,项目名称:Mobile-Image-Matting,代码行数:30,代码来源:train.py


注:本文中的config.device方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。