本文整理汇总了Python中torch.max方法的典型用法代码示例。如果您正苦于以下问题:Python torch.max方法的具体用法?Python torch.max怎么用?Python torch.max使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch的用法示例。
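在下文中一共展示了torch.max方法的15个代码示例,这些例子默认根据受欢迎程度排序。注意:其中部分示例(如示例2、3、7、13)实际调用的是Python内置的max/min或NumPy的max/argmax,而不是torch.max;如果真的按照各示例标题所示执行 from torch import max,会遮蔽内置的max并导致这些代码出错。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。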
示例1: greedy_decode
# 需要导入模块: import torch [as 别名]
# 或者: from torch import max [as 别名]
def greedy_decode(self, latent, max_len, start_id):
    """Greedily decode token ids from a latent representation.

    Starting from ``start_id``, repeatedly runs ``self.decode`` on the tokens
    produced so far and appends the highest-scoring token that
    ``self.generator`` assigns to the last position, until the sequence
    (including the start token) has ``max_len`` tokens.

    Args:
        latent: encoder output. The original docstring described it as
            (batch_size, max_src_seq, d_model) -- TODO confirm; it is
            unsqueezed at dim 1 before being passed to ``self.decode``.
        max_len: total sequence length, including the start token.
        start_id: token id used to seed every sequence.

    Returns:
        LongTensor of shape (batch_size, max_len - 1): the generated tokens
        with the seed token stripped.
    """
    batch_size = latent.size(0)
    # Seed every sequence with the start token: (batch_size, 1).
    ys = get_cuda(torch.ones(batch_size, 1).fill_(start_id).long())
    # Loop-invariant: the memory fed to the decoder never changes.
    memory = latent.unsqueeze(1)
    for _ in range(max_len - 1):
        # Causal mask over the ys.size(1) tokens generated so far.
        tgt_mask = to_var(subsequent_mask(ys.size(1)).long())
        out = self.decode(memory, to_var(ys), tgt_mask)
        # Only the last position's scores matter for the next token.
        prob = self.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
    # Drop the seed token.
    return ys[:, 1:]
示例2: crop_images_random
# 需要导入模块: import torch [as 别名]
# 或者: from torch import max [as 别名]
def crop_images_random(path='../images/', scale=0.50, min_size=30):  # from utils.utils import *; crop_images_random()
    """Crop every image in ``path`` to a random square region, in place.

    WARNING: overwrites the original image files.

    For each file that ``cv2.imread`` can read, a side length is drawn with
    ``random.randint(min_size, int(max(min_size, h * scale)))`` (``h`` is the
    image height). The crop is centred on a random point, clipped at the image
    borders (so crops near an edge can be smaller and non-square), and written
    back to the same file.

    Args:
        path: directory to scan; every ``*.*`` entry inside it is processed.
        scale: upper bound of the square side as a fraction of image height.
        min_size: minimum square side in pixels (previously hard-coded to 30).
    """
    for file in tqdm(sorted(glob.glob('%s/*.*' % path))):
        img = cv2.imread(file)  # BGR; None when the file is not a readable image
        if img is None:
            continue
        h, w = img.shape[:2]
        side = random.randint(min_size, int(max(min_size, h * scale)))
        # Random centre; clamp the top-left corner at 0 and the bottom-right
        # corner at the image border.
        xmin = max(0, random.randint(0, w) - side // 2)
        ymin = max(0, random.randint(0, h) - side // 2)
        xmax = min(w, xmin + side)
        ymax = min(h, ymin + side)
        # Crop (this keeps only the selected region; nothing is masked).
        cv2.imwrite(file, img[ymin:ymax, xmin:xmax])
示例3: plot_evolution_results
# 需要导入模块: import torch [as 别名]
# 或者: from torch import max [as 别名]
def plot_evolution_results(hyp):  # from utils.utils import *; plot_evolution_results(hyp)
    """Plot hyperparameter-evolution results stored in ``evolve.txt``.

    For the i-th hyperparameter in ``hyp`` (iteration order), plots column
    ``i + 5`` of ``evolve.txt`` against the fitness returned by ``fitness(x)``.
    The best single result is highlighted, and its value is printed and used
    as the subplot title. The figure is saved to ``evolve.png``.

    Args:
        hyp: mapping of hyperparameter name -> value. Only the keys (and their
            order) are used, to label columns 5, 6, ... of ``evolve.txt``.
    """
    x = np.loadtxt('evolve.txt', ndmin=2)
    f = fitness(x)
    best_row = f.argmax()  # index of the best single result
    best_fitness = f.max()
    n_cols = 5
    # Keep the historical 4x5 grid, but add rows when there are more than
    # 20 hyperparameters (plt.subplot fails if the index exceeds the grid).
    n_rows = max(4, -(-len(hyp) // n_cols))  # ceil division
    fig = plt.figure(figsize=(12, 10))
    matplotlib.rc('font', **{'size': 8})
    for i, k in enumerate(hyp):
        y = x[:, i + 5]
        mu = y[best_row]  # value of this hyperparameter in the best result
        plt.subplot(n_rows, n_cols, i + 1)
        plt.plot(mu, best_fitness, 'o', markersize=10)
        plt.plot(y, f, '.')
        plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9})
        print('%15s: %.3g' % (k, mu))
    fig.tight_layout()
    plt.savefig('evolve.png', dpi=200)
示例4: record_output
# 需要导入模块: import torch [as 别名]
# 或者: from torch import max [as 别名]
def record_output(self, output, output_indices, target, prev_absolutes,
                  next_absolutes, batch_size=1):
    """Update ``self.meter`` with sequence-level accuracy for one batch.

    Args:
        output: 4-D score tensor; predictions are its arg-max over dim 3.
            Dims 1 and 2 are read as sequence length and number of labels,
            so presumably (batch, seq_len, num_label, num_class) -- confirm.
        output_indices: unused; kept for interface compatibility.
        target: 3-D tensor of class indices, compared element-wise with the
            predictions.
        prev_absolutes: unused; kept for interface compatibility.
        next_absolutes: unused; kept for interface compatibility.
        batch_size: weight passed to ``self.meter.update``.

    Records ``[count_correct * 100, correct_per_seq * 100]`` where
    count_correct is the fraction of samples whose every label at every step
    is correct. correct_per_seq is, per sample, the fraction of steps with
    exactly ``num_label - 1`` correct labels, averaged over the batch.
    """
    assert output.dim() == 4
    assert target.dim() == 3
    _, predictions = output.max(3)
    sequence_length = output.size(1)
    num_label = output.size(2)
    # Number of correctly predicted labels per (sample, step).
    correct_alljoint = (target == predictions).float().sum(2)
    # A sample is fully correct when all num_label * sequence_length labels match.
    sum_of_corrects = correct_alljoint.sum(1)
    max_value = num_label * sequence_length
    count_correct = (sum_of_corrects == max_value).float().mean()
    # NOTE(review): this counts steps with *exactly* one wrong label; if
    # "at most one wrong" was intended it should be ``>= num_label - 1``.
    correct_per_seq = ((correct_alljoint == num_label - 1).sum(1).float() /
                       sequence_length).mean()
    self.meter.update(
        torch.Tensor([count_correct * 100, correct_per_seq * 100]),
        batch_size)
示例5: extract_feature
# 需要导入模块: import torch [as 别名]
# 或者: from torch import max [as 别名]
def extract_feature(model, dataloader, save_path, load_from_disk=True, model_path=''):
    """Extract per-sample features, save them as CSV and print accuracy.

    Each saved row is ``model.get_features(inputs)`` for one sample followed
    by that sample's label (as float), written with ``np.savetxt``
    (comma-separated, 6 decimals). The classification accuracy of ``model``
    on the same data is printed.

    Args:
        model: network to use; replaced when ``load_from_disk`` is True.
        dataloader: iterable of ``(inputs, labels)`` batches; the length of
            its ``dataset`` is the accuracy denominator.
        save_path: destination file for the feature matrix.
        load_from_disk: if True, build ``models.Network`` from the global
            ``args`` and load its weights from ``model_path``.
        model_path: checkpoint path used when ``load_from_disk`` is True.
    """
    if load_from_disk:
        model = models.Network(base_net=args.model_name,
                               n_class=args.num_class)
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    # Always move to DEVICE and switch to eval mode, including for a model
    # passed in by the caller (dropout/batch-norm must not be in train mode).
    model = model.to(DEVICE)
    model.eval()
    correct = 0
    rows = []  # concatenated once at the end instead of per batch (was O(n^2))
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            flat_labels = labels.view(-1)
            feas = model.get_features(inputs)
            rows.append(torch.cat((feas, flat_labels.view(-1, 1).float()), dim=1))
            preds = torch.max(model(inputs), 1)[1]
            # Compare 1-D predictions with 1-D labels: comparing against the
            # (N, 1) column view would broadcast to an (N, N) matrix and
            # overcount hits.
            correct += torch.sum(preds == flat_labels.long()).item()
    if rows:
        fea_all = torch.cat(rows, dim=0)
    else:
        fea_all = torch.zeros(0, 1 + model.base_network.output_num())
    n_samples = len(dataloader.dataset)
    test_acc = correct / n_samples if n_samples else 0.0
    np.savetxt(save_path, fea_all.cpu().numpy(), fmt='%.6f', delimiter=',')
    print('Test acc: %f' % test_acc)
    # You may want to classify with 1nn after getting features
示例6: test
# 需要导入模块: import torch [as 别名]
# 或者: from torch import max [as 别名]
def test(model, target_test_loader):
    """Evaluate ``model`` on the target test loader and print its accuracy.

    Every batch is scored with ``model.predict`` under ``torch.no_grad()``.
    The cross-entropy loss goes into a ``utils.AverageMeter``; note that this
    meter is filled but never reported or returned. Arg-max hits are counted,
    and the hit count and accuracy (percent) are printed, labelled with the
    module-level ``source_name`` and ``target_name``.
    """
    model.eval()
    loss_meter = utils.AverageMeter()
    hits = 0
    criterion = torch.nn.CrossEntropyLoss()
    n_samples = len(target_test_loader.dataset)
    with torch.no_grad():
        for batch, labels in target_test_loader:
            batch = batch.to(DEVICE)
            labels = labels.to(DEVICE)
            logits = model.predict(batch)
            batch_loss = criterion(logits, labels)
            loss_meter.update(batch_loss.item())
            predicted = torch.max(logits, 1)[1]
            hits += torch.sum(predicted == labels)
    accuracy = 100. * hits / n_samples
    print('{} --> {}: max correct: {}, accuracy{: .2f}%\n'.format(
        source_name, target_name, hits, accuracy))
示例7: get
# 需要导入模块: import torch [as 别名]
# 或者: from torch import max [as 别名]
def get(self):
    """Return the aggregate of every cost recorded so far.

    ``step`` must have been called at least once; otherwise an
    AssertionError is raised. ``self.type`` selects the aggregation:
    ``'mean'`` gives the mean, ``'max'`` the maximum, and any other value
    the minimum of ``self.coll``.

    Returns:
        A float number of cost (NumPy scalar).
    """
    assert len(self.coll) != 0, 'Please call step before get'
    reducers = {'mean': np.mean, 'max': np.max}
    reduce_fn = reducers.get(self.type, np.min)
    return reduce_fn(self.coll)
示例8: evo_norm
# 需要导入模块: import torch [as 别名]
# 或者: from torch import max [as 别名]
def evo_norm(x, prefix, running_var, v, weight, bias,
             training, momentum, eps=0.1, groups=32):
    """Apply an EvoNorm-style normalization to ``x``.

    Args:
        x: input tensor. Batch variance is reduced over dims (0, 2, 3), so it
            is presumably NCHW -- confirm.
        prefix: ``'b0'`` selects the batch-variance variant; any other value
            selects the ``sigmoid`` / ``group_std`` variant.
        running_var: running-variance buffer. It is updated in place while
            training and read at inference (``'b0'`` only).
        v: nonlinearity parameter. When None, the layer reduces to the affine
            transform ``x * weight + bias``.
        weight: affine scale.
        bias: affine shift.
        training: if True, use the batch variance and update ``running_var``;
            otherwise use ``running_var``.
        momentum: weight of the previous value in the running-variance update.
        eps: added to the variance and forwarded to the std helpers.
        groups: number of groups forwarded to ``group_std``.

    Returns:
        The transformed tensor.
    """
    if prefix == 'b0':
        if training:
            var = torch.var(x, dim=(0, 2, 3), keepdim=True)
            # Detach: an in-place add of a grad-requiring tensor would make the
            # running-stat buffer require grad and carry a grad_fn, chaining
            # autograd graphs across training steps.
            running_var.mul_(momentum)
            running_var.add_((1 - momentum) * var.detach())
        else:
            var = running_var
        if v is not None:
            den = torch.max((var + eps).sqrt(), v * x + instance_std(x, eps))
            x = x / den * weight + bias
        else:
            x = x * weight + bias
    else:
        if v is not None:
            x = x * torch.sigmoid(v * x) / group_std(x, groups, eps) * weight + bias
        else:
            x = x * weight + bias
    return x
示例9: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import max [as 别名]
def forward(self, x):
    """Run ``self.Scale`` at 1.0x, 0.75x and 0.5x input scales and fuse them.

    Bilinear resizing modules are sized from ``x.size()[2]`` (only dim 2 is
    read, so the input is presumably square) and are stored as
    ``self.interp1`` / ``self.interp2`` / ``self.interp3`` on every call.

    Returns:
        A list of four tensors: the 1.0x output, the 0.75x and 0.5x outputs
        resized by ``self.interp3`` to ``outS(input_size)``, and the
        element-wise maximum of those three.
    """
    side = x.size()[2]
    side_75 = int(side * 0.75) + 1
    side_50 = int(side * 0.5) + 1
    # NOTE(review): these parameter-free modules are rebuilt on every call.
    self.interp1 = nn.UpsamplingBilinear2d(size=(side_75, side_75))
    self.interp2 = nn.UpsamplingBilinear2d(size=(side_50, side_50))
    self.interp3 = nn.UpsamplingBilinear2d(size=(outS(side), outS(side)))
    x_75 = self.interp1(x)
    x_50 = self.interp2(x)
    full = self.Scale(x)                         # 1.0x
    resized_75 = self.interp3(self.Scale(x_75))  # 0.75x
    resized_50 = self.interp3(self.Scale(x_50))  # 0.5x
    fused = torch.max(torch.max(full, resized_75), resized_50)
    return [full, resized_75, resized_50, fused]
示例10: do_eval
# 需要导入模块: import torch [as 别名]
# 或者: from torch import max [as 别名]
def do_eval(model, test_loader, cuda):
    """Run ``model`` over ``test_loader``, print accuracy and return predictions.

    Args:
        model: callable network. Its ``is_training`` attribute is set to
            False; ``model.eval()`` is not called here.
        test_loader: iterable of ``(X, y)`` batches. Both the model output
            and ``y`` are squeezed at dim 0, so presumably each batch holds a
            single sequence -- confirm.
        cuda: if True, inputs are moved to the GPU.

    Returns:
        list of predicted class indices (arg-max over dim 1 of the squeezed
        model output), concatenated over all batches.
    """
    model.is_training = False
    predicted = []
    true_label = []
    # Inference only: skip building autograd graphs (saves memory and time).
    with torch.no_grad():
        for X, y in test_loader:
            if cuda:
                X = X.cuda()
            output = model(X).squeeze(0)
            _, output = torch.max(output, 1)
            # .cpu() is a no-op for tensors already on the CPU.
            predicted.extend(output.cpu().numpy().tolist())
            true_label.extend(y.squeeze(0).numpy().tolist())
    print("Acc: %.3f" % accuracy(predicted, true_label))
    return predicted
示例11: shem
# 需要导入模块: import torch [as 别名]
# 或者: from torch import max [as 别名]
def shem(roi_probs_neg, negative_count, ohem_poolsize):
    """
    Stochastic hard example mining. From a set of non-matched predictions,
    take a pool of the highest-scoring ones (worst false positives) of size
    negative_count * ohem_poolsize. Then sample negative_count predictions
    from this pool uniformly at random as negative examples for the loss.

    :param roi_probs_neg: tensor of shape (n_predictions, n_classes). Column 0
        is ignored; a prediction's score is its max over columns 1 onwards.
    :param negative_count: int, number of negatives to return.
    :param ohem_poolsize: int, pool size as a multiple of negative_count.
    :return: 1-D tensor of up to negative_count indices into roi_probs_neg, on
        the same device as roi_probs_neg. If the pool is smaller than expected
        because few negative proposals are available, fewer indices are
        returned without raising an error.
    """
    negative_count = int(negative_count)
    # Sort predictions by their highest foreground score, descending.
    _, order = roi_probs_neg[:, 1:].max(1)[0].sort(descending=True)
    pool_size = int(min(ohem_poolsize * negative_count, order.size(0)))
    pool_indices = order[:pool_size]
    # Permute on the pool's own device (was hard-coded .cuda(), failing on CPU).
    rand_idx = torch.randperm(pool_indices.size(0), device=pool_indices.device)
    return pool_indices[rand_idx[:negative_count]]
示例12: iou
# 需要导入模块: import torch [as 别名]
# 或者: from torch import max [as 别名]
def iou(source: Tensor, other: Tensor) -> Tensor:
    """Pairwise intersection-over-union between two sets of boxes.

    Boxes are ``(left, top, right, bottom)`` along the last dimension.

    Args:
        source: (..., N, 4) boxes.
        other: (..., M, 4) boxes; leading dims must broadcast with source's.

    Returns:
        (..., N, M) tensor whose [..., i, j] entry is the IoU of
        source[..., i, :] and other[..., j, :]. Inputs with fewer than three
        dims get a leading dim of size 1, so two 2-D inputs yield (1, N, M)
        as before. Pairs with zero union area give nan (0 / 0).
    """
    # Preserve the historical (1, N, M) result shape for 2-D inputs.
    while source.dim() < 3:
        source = source.unsqueeze(0)
    while other.dim() < 3:
        other = other.unsqueeze(0)
    # Broadcast (..., N, 1, 4) against (..., 1, M, 4) instead of materialising
    # repeated copies; this also supports any number of leading dims.
    source = source.unsqueeze(dim=-2)
    other = other.unsqueeze(dim=-3)
    source_area = (source[..., 2] - source[..., 0]) * (source[..., 3] - source[..., 1])
    other_area = (other[..., 2] - other[..., 0]) * (other[..., 3] - other[..., 1])
    intersection_left = torch.max(source[..., 0], other[..., 0])
    intersection_top = torch.max(source[..., 1], other[..., 1])
    intersection_right = torch.min(source[..., 2], other[..., 2])
    intersection_bottom = torch.min(source[..., 3], other[..., 3])
    intersection_width = torch.clamp(intersection_right - intersection_left, min=0)
    intersection_height = torch.clamp(intersection_bottom - intersection_top, min=0)
    intersection_area = intersection_width * intersection_height
    return intersection_area / (source_area + other_area - intersection_area)
示例13: _get_waveform_and_window_properties
# 需要导入模块: import torch [as 别名]
# 或者: from torch import max [as 别名]
def _get_waveform_and_window_properties(waveform: Tensor,
                                        channel: int,
                                        sample_frequency: float,
                                        frame_shift: float,
                                        frame_length: float,
                                        round_to_power_of_two: bool,
                                        preemphasis_coefficient: float) -> Tuple[Tensor, int, int, int]:
    r"""Select one channel of ``waveform`` and derive its framing parameters.

    A negative ``channel`` is clamped to 0. ``frame_shift`` and
    ``frame_length`` are converted to sample counts as
    ``int(sample_frequency * value * MILLISECONDS_TO_SECONDS)``. Invalid
    settings raise ``AssertionError``.

    Returns:
        (selected channel, window_shift, window_size, padded_window_size),
        where padded_window_size is ``_next_power_of_2(window_size)`` when
        ``round_to_power_of_two`` is set and window_size otherwise.
    """
    channel = 0 if channel < 0 else channel
    n_channels = waveform.size(0)
    assert channel < n_channels, ('Invalid channel %d for size %d' % (channel, n_channels))
    signal = waveform[channel, :]  # size (n)
    shift_samples = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
    size_samples = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
    if round_to_power_of_two:
        padded_size = _next_power_of_2(size_samples)
    else:
        padded_size = size_samples
    n_samples = len(signal)
    assert 2 <= size_samples <= n_samples, ('choose a window size %d that is [2, %d]' % (size_samples, n_samples))
    assert 0 < shift_samples, '`window_shift` must be greater than 0'
    assert padded_size % 2 == 0, 'the padded `window_size` must be divisible by two.' \
                                 ' use `round_to_power_of_two` or change `frame_length`'
    assert 0. <= preemphasis_coefficient <= 1.0, '`preemphasis_coefficient` must be between [0,1]'
    assert sample_frequency > 0, '`sample_frequency` must be greater than zero'
    return signal, shift_samples, size_samples, padded_size
示例14: decode_greedy
# 需要导入模块: import torch [as 别名]
# 或者: from torch import max [as 别名]
def decode_greedy(self, logits, seq_lens):
    """Greedy (arg-max) decoding of a batch of logits into strings.

    ``logits`` is transposed on dims 0 and 1, and the arg-max over dim 2
    gives the token ids of each row. Each row ``i`` is then converted with
    ``self.convert_to_string(row, seq_lens[i])``.

    Returns:
        list with one ``convert_to_string`` result per row.
    """
    tokens = torch.max(logits.transpose(0, 1), 2)[1]
    return [self.convert_to_string(row, seq_lens[i])
            for i, row in enumerate(tokens)]
示例15: _region_classification
# 需要导入模块: import torch [as 别名]
# 或者: from torch import max [as 别名]
def _region_classification(self, fc7):
cls_score = self.cls_score_net(fc7)
cls_pred = torch.max(cls_score, 1)[1] # the prediction class of each bbox
cls_prob = F.softmax(cls_score)
bbox_pred = self.bbox_pred_net(fc7)
bbox_prob = torch.stack([F.softmax(bbox_pred[:,i]) for i in range(bbox_pred.size(1))], 1)
fuse_prob = cls_prob.mul(bbox_prob)
image_prob = fuse_prob.sum(0,keepdim=True)
self._predictions["cls_pred"] = cls_pred
self._predictions["cls_prob"] = cls_prob
self._predictions["bbox_prob"] = bbox_prob
self._predictions["fuse_prob"] = fuse_prob
self._predictions["image_prob"] = image_prob
return cls_prob, bbox_prob, fuse_prob, image_prob
开发者ID:Sunarker,项目名称:Collaborative-Learning-for-Weakly-Supervised-Object-Detection,代码行数:18,代码来源:network.py