本文簡要介紹python語言中 torch.nn.TripletMarginWithDistanceLoss
的用法。
用法:
class torch.nn.TripletMarginWithDistanceLoss(*, distance_function=None, margin=1.0, swap=False, reduction='mean')
distance_function(可調用的,可選的) -量化兩個張量的接近度的非負實值函數。如果未指定,將使用
nn.PairwiseDistance
。默認值:None
margin(float,可選的) -一個非負邊距,表示損失為 0 所需的正負距離之間的最小差異。較大的邊距會懲罰負樣本相對於正樣本距離錨點不夠遠的情況。默認值: 。
swap(bool,可選的) -是否使用 V. Balntas、E. Riba 等人在論文
Learning shallow convolutional feature descriptors with triplet losses
中說明的距離交換。如果為 True,並且如果正例比錨點更接近負例,則在損失計算中交換正例和錨點。默認值:False
。reduction(string,可選的) -指定應用於輸出的(可選)縮減:
'none'
|'mean'
|'sum'
。'none'
:不應用減少,'mean'
:輸出的總和將除以輸出中的元素數,'sum'
:輸出將被求和。默認值:'mean'
在給定輸入張量 、 和 (分別表示錨點、正例和負例)和使用的非負實值函數 (“distance function”) 的情況下,創建一個衡量三元組損失的標準計算錨點和正例(“positive distance”)和錨點和負例(“negative distance”)之間的關係。
未減少的損失(即
reduction
設置為'none'
)可以說明為:其中
distance_function
; 是一個非負邊距,表示損失為 0 所需的正負距離之間的最小差異。輸入張量每個都有 元素,並且可以是距離函數可以處理的任何形狀。 是批量大小; 是一個非負實值函數,用於量化兩個張量的接近度,稱為如果
reduction
不是'none'
(默認'mean'
),則:另請參見
TripletMarginLoss
,它使用 距離作為距離函數來計算輸入張量的三元組損失。- 形狀:
輸入: 其中 表示距離函數支持的任意數量的附加維度。
輸出:如果
reduction
為'none'
則為形狀為 的張量,否則為標量。
例子:
>>> # Initialize embeddings >>> embedding = nn.Embedding(1000, 128) >>> anchor_ids = torch.randint(0, 1000, (1,)) >>> positive_ids = torch.randint(0, 1000, (1,)) >>> negative_ids = torch.randint(0, 1000, (1,)) >>> anchor = embedding(anchor_ids) >>> positive = embedding(positive_ids) >>> negative = embedding(negative_ids) >>> >>> # Built-in Distance Function >>> triplet_loss = \ >>> nn.TripletMarginWithDistanceLoss(distance_function=nn.PairwiseDistance()) >>> output = triplet_loss(anchor, positive, negative) >>> output.backward() >>> >>> # Custom Distance Function >>> def l_infinity(x1, x2): >>> return torch.max(torch.abs(x1 - x2), dim=1).values >>> >>> triplet_loss = \ >>> nn.TripletMarginWithDistanceLoss(distance_function=l_infinity, margin=1.5) >>> output = triplet_loss(anchor, positive, negative) >>> output.backward() >>> >>> # Custom Distance Function (Lambda) >>> triplet_loss = \ >>> nn.TripletMarginWithDistanceLoss( >>> distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y)) >>> output = triplet_loss(anchor, positive, negative) >>> output.backward()
- 參考:
V. Balntas 等人:學習具有三元組損失的淺層卷積特征說明符:http://www.bmva.org/bmvc/2016/papers/paper119/index.html
參數:
相關用法
- Python PyTorch TripletMarginLoss用法及代碼示例
- Python PyTorch TransformerEncoder用法及代碼示例
- Python PyTorch TransformedDistribution用法及代碼示例
- Python PyTorch Transform用法及代碼示例
- Python PyTorch Transformer用法及代碼示例
- Python PyTorch TransformerDecoderLayer用法及代碼示例
- Python PyTorch TransformerDecoder用法及代碼示例
- Python PyTorch Transformer.forward用法及代碼示例
- Python PyTorch TransformerEncoderLayer用法及代碼示例
- Python PyTorch Tensor.unflatten用法及代碼示例
- Python PyTorch Tensor.register_hook用法及代碼示例
- Python PyTorch TarArchiveLoader用法及代碼示例
- Python PyTorch Tensor.storage_offset用法及代碼示例
- Python PyTorch Tensor.to用法及代碼示例
- Python PyTorch Tensor.sparse_mask用法及代碼示例
- Python PyTorch Timer用法及代碼示例
- Python PyTorch TimeMasking用法及代碼示例
- Python PyTorch Tacotron2TTSBundle.get_text_processor用法及代碼示例
- Python PyTorch Tensor.is_leaf用法及代碼示例
- Python PyTorch Tensor.imag用法及代碼示例
- Python PyTorch Tensor.unfold用法及代碼示例
- Python PyTorch TenCrop用法及代碼示例
- Python PyTorch Tensor.real用法及代碼示例
- Python PyTorch TwRwSparseFeaturesDist用法及代碼示例
- Python PyTorch Tensor.refine_names用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.TripletMarginWithDistanceLoss。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。