當前位置: 首頁>>代碼示例>>Python>>正文


Python nn.TripletMarginLoss方法代碼示例

本文整理匯總了Python中torch.nn.TripletMarginLoss方法的典型用法代碼示例。如果您正苦於以下問題:Python nn.TripletMarginLoss方法的具體用法?Python nn.TripletMarginLoss怎麽用?Python nn.TripletMarginLoss使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch.nn的用法示例。


在下文中一共展示了nn.TripletMarginLoss方法的5個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: __init__

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import TripletMarginLoss [as 別名]
def __init__(self, config, net):
        self.log_dir = config.log_dir
        self.model_dir = config.model_dir
        self.net = net
        self.clock = TrainClock()
        self.device = config.device

        self.use_triplet = config.use_triplet
        self.use_footvel_loss = config.use_footvel_loss

        # set loss function
        self.mse = nn.MSELoss()
        self.tripletloss = nn.TripletMarginLoss(margin=config.triplet_margin)
        self.triplet_weight = config.triplet_weight
        self.foot_idx = config.foot_idx
        self.footvel_loss_weight = config.footvel_loss_weight

        # set optimizer
        self.optimizer = optim.Adam(self.net.parameters(), config.lr)
        self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, 0.99) 
開發者ID:ChrisWu1997,項目名稱:2D-Motion-Retargeting,代碼行數:22,代碼來源:base_agent.py

示例2: setup

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import TripletMarginLoss [as 別名]
def setup(model, opt):

    if opt.criterion == "l1":
        criterion = nn.L1Loss().cuda()
    elif opt.criterion == "mse":
        criterion = nn.MSELoss().cuda()
    elif opt.criterion == "crossentropy":
        criterion = nn.CrossEntropyLoss().cuda()
    elif opt.criterion == "hingeEmbedding":
        criterion = nn.HingeEmbeddingLoss().cuda()
    elif opt.criterion == "tripletmargin":
        criterion = nn.TripletMarginLoss(margin = opt.margin, swap = opt.anchorswap).cuda()

    parameters = filter(lambda p: p.requires_grad, model.parameters())

    if opt.optimType == 'sgd':
        optimizer = optim.SGD(parameters, lr = opt.lr, momentum = opt.momentum, nesterov = opt.nesterov, weight_decay = opt.weightDecay)
    elif opt.optimType == 'adam':
        optimizer = optim.Adam(parameters, lr = opt.maxlr, weight_decay = opt.weightDecay)

    if opt.weight_init:
        utils.weights_init(model, opt)

    return model, criterion, optimizer 
開發者ID:drimpossible,項目名稱:Deep-Expander-Networks,代碼行數:26,代碼來源:__init__.py

示例3: build_criterion

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import TripletMarginLoss [as 別名]
def build_criterion(loss_dict):

    if loss_dict.type == 'CrossEntropyLoss':
        weight = loss_dict.weight
        size_average = loss_dict.size_average
        reduce = loss_dict.reduce
        reduction = loss_dict.reduction

        if loss_dict.use_sigmoid:
            return nn.BCEWithLogitsLoss(
                weight=weight,
                size_average=size_average,
                reduce=reduce,
                reduction=reduction)
        else:
            return nn.CrossEntropyLoss(
                weight=weight,
                size_average=size_average,
                reduce=reduce,
                reduction=reduction)

    elif loss_dict.type == 'TripletLoss':
        return nn.TripletMarginLoss(margin=loss_dict.margin, p=loss_dict.p)

    else:
        raise TypeError('{} cannot be processed'.format(loss_dict.type)) 
開發者ID:open-mmlab,項目名稱:mmfashion,代碼行數:28,代碼來源:utils.py

示例4: __init__

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import TripletMarginLoss [as 別名]
def __init__(self, margin = None):
        super(TripletLoss, self).__init__()
        self.margin = margin
        if self.margin is None:  # use soft-margin
            self.Loss = nn.SoftMarginLoss()
        else:
            self.Loss = nn.TripletMarginLoss(margin = margin, p = 2) 
開發者ID:CoinCheung,項目名稱:triplet-reid-pytorch,代碼行數:9,代碼來源:loss.py

示例5: __init__

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import TripletMarginLoss [as 別名]
def __init__(self, margin=None):
        super(TripletLoss, self).__init__()
        self.margin = margin
        if self.margin is None:  # if no margin assigned, use soft-margin
            self.Loss = nn.SoftMarginLoss()
        else:
            self.Loss = nn.TripletMarginLoss(margin=margin, p=2) 
開發者ID:CoinCheung,項目名稱:pytorch-loss,代碼行數:9,代碼來源:triplet_loss.py


注:本文中的torch.nn.TripletMarginLoss方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。