當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch TripletMarginWithDistanceLoss用法及代碼示例


本文簡要介紹python語言中 torch.nn.TripletMarginWithDistanceLoss 的用法。

用法:

class torch.nn.TripletMarginWithDistanceLoss(*, distance_function=None, margin=1.0, swap=False, reduction='mean')

參數

  • distance_function(可調用的,可選的) -量化兩個張量的接近度的非負實值函數。如果未指定,將使用nn.PairwiseDistance。默認值:None

  • margin(float,可選的) -一個非負邊距,表示損失為 0 所需的正負距離之間的最小差異。較大的邊距會懲罰負樣本相對於正樣本距離錨點不夠遠的情況。默認值:

  • swap(bool,可選的) -是否使用 V. Balntas、E. Riba 等人在論文 Learning shallow convolutional feature descriptors with triplet losses 中說明的距離交換。如果為 True,並且如果正例比錨點更接近負例,則在損失計算中交換正例和錨點。默認值:False

  • reduction(string,可選的) -指定應用於輸出的(可選)縮減:'none' | 'mean' | 'sum''none' :不應用減少,'mean':輸出的總和將除以輸出中的元素數,'sum':輸出將被求和。默認值:'mean'

在給定輸入張量 (分別表示錨點、正例和負例)和使用的非負實值函數 (“distance function”) 的情況下,創建一個衡量三元組損失的標準計算錨點和正例(“positive distance”)和錨點和負例(“negative distance”)之間的關係。

未減少的損失(即 reduction 設置為 'none' )可以說明為:

其中 是批量大小; 是一個非負實值函數,用於量化兩個張量的接近度,稱為 distance_function 是一個非負邊距,表示損失為 0 所需的正負距離之間的最小差異。輸入張量每個都有 元素,並且可以是距離函數可以處理的任何形狀。

如果 reduction 不是 'none' (默認 'mean' ),則:

另請參見 TripletMarginLoss ,它使用 距離作為距離函數來計算輸入張量的三元組損失。

形狀:
  • 輸入: 其中 表示距離函數支持的任意數量的附加維度。

  • 輸出:如果 reduction'none' 則為形狀為 的張量,否則為標量。

例子:

>>> # Initialize embeddings
>>> embedding = nn.Embedding(1000, 128)
>>> anchor_ids = torch.randint(0, 1000, (1,))
>>> positive_ids = torch.randint(0, 1000, (1,))
>>> negative_ids = torch.randint(0, 1000, (1,))
>>> anchor = embedding(anchor_ids)
>>> positive = embedding(positive_ids)
>>> negative = embedding(negative_ids)
>>>
>>> # Built-in Distance Function
>>> triplet_loss = \
>>>     nn.TripletMarginWithDistanceLoss(distance_function=nn.PairwiseDistance())
>>> output = triplet_loss(anchor, positive, negative)
>>> output.backward()
>>>
>>> # Custom Distance Function
>>> def l_infinity(x1, x2):
>>>     return torch.max(torch.abs(x1 - x2), dim=1).values
>>>
>>> triplet_loss = \
>>>     nn.TripletMarginWithDistanceLoss(distance_function=l_infinity, margin=1.5)
>>> output = triplet_loss(anchor, positive, negative)
>>> output.backward()
>>>
>>> # Custom Distance Function (Lambda)
>>> triplet_loss = \
>>>     nn.TripletMarginWithDistanceLoss(
>>>         distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y))
>>> output = triplet_loss(anchor, positive, negative)
>>> output.backward()
參考:

V. Balntas 等人:學習具有三元組損失的淺層卷積特征說明符:http://www.bmva.org/bmvc/2016/papers/paper119/index.html

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.TripletMarginWithDistanceLoss。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。