本文整理汇总了 Python 中 torch.square 方法的典型用法代码示例。如果您正苦于以下问题:Python torch.square 方法的具体用法?Python torch.square 怎么用?Python torch.square 使用的例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 torch 的用法示例。
在下文中一共展示了torch.square方法的6个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: calc_ams
# 需要导入模块: import torch [as 别名]
# 或者: from torch import square [as 别名]
def calc_ams(s:float, b:float, br:float=0, unc_b:float=0) -> float:
    r'''
    Compute Approximate Median Significance (https://arxiv.org/abs/1007.1727)

    Arguments:
        s: signal weight
        b: background weight
        br: background offset bias; only enters the formula used when `unc_b` is zero
        unc_b: fractional systematic uncertainty on background

    Returns:
        Approximate Median Significance if b > 0 and the radicand is positive, else -1
    '''

    # The documented contract is "-1 unless b > 0". Checking only b == 0 let a
    # negative background combined with a positive offset produce a spurious
    # positive significance, so reject every non-positive background up front.
    if b <= 0: return -1
    if not unc_b:
        radicand = 2*((s+b+br)*np.log(1.0+s/(b+br))-s)
    else:
        # Absolute variance of the background implied by the fractional uncertainty
        sigma_b_2 = np.square(unc_b*b)
        first_term = (s+b)*np.log((s+b)*(b+sigma_b_2)/((b**2)+((s+b)*sigma_b_2)))
        second_term = ((b**2)/sigma_b_2)*np.log(1+((sigma_b_2*s)/(b*(b+sigma_b_2))))
        radicand = 2*(first_term-second_term)
    # A NaN radicand fails the `> 0` comparison and therefore also maps to -1
    return np.sqrt(radicand) if radicand > 0 else -1
示例2: calc_ams_torch
# 需要导入模块: import torch [as 别名]
# 或者: from torch import square [as 别名]
def calc_ams_torch(s:Tensor, b:Tensor, br:float=0, unc_b:float=0) -> Tensor:
    r'''
    Compute Approximate Median Significance (https://arxiv.org/abs/1007.1727) using Tensor inputs

    Arguments:
        s: signal weight
        b: background weight
        br: background offset bias; only enters the formula used when `unc_b` is zero
        unc_b: fractional systematic uncertainty on background

    Returns:
        Approximate Median Significance if b > 0 and the radicand is positive, else 1e-18 * s
    '''

    # The fallback is a tiny multiple of `s` rather than a constant, so the
    # returned value is still computed from the input tensor.
    # Reject every non-positive background, matching the documented "b > 0" contract.
    if b <= 0: return 1e-18*s
    if not unc_b:
        radicand = 2*((s+b+br)*torch.log(1.0+s/(b+br))-s)
    else:
        # Absolute variance of the background implied by the fractional uncertainty
        sigma_b_2 = torch.square(unc_b*b)
        first_term = (s+b)*torch.log((s+b)*(b+sigma_b_2)/((b**2)+((s+b)*sigma_b_2)))
        second_term = ((b**2)/sigma_b_2)*torch.log(1+((sigma_b_2*s)/(b*(b+sigma_b_2))))
        radicand = 2*(first_term-second_term)
    # A NaN radicand fails the `> 0` comparison and therefore takes the fallback
    return torch.sqrt(radicand) if radicand > 0 else 1e-18*s
示例3: phi
# 需要导入模块: import torch [as 别名]
# 或者: from torch import square [as 别名]
def phi(r, order):
    """Coordinate-wise nonlinearity used to define the order of the interpolation.

    See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition.

    Args:
        r: input tensor (presumably squared distances, since order 1 takes a
            square root -- confirm against the caller)
        order: interpolation order

    Returns:
        phi_k evaluated coordinate-wise on r, for k = order
    """
    # Create the floor on the same device as `r` so the element-wise max below
    # never mixes devices when `r` lives on an accelerator.
    EPSILON = torch.tensor(1e-10, device=r.device)
    # using EPSILON prevents log(0), sqrt(0), etc.
    # sqrt(0) is well-defined, but its gradient is not
    if order == 1:
        r = torch.max(r, EPSILON)
        r = torch.sqrt(r)
        return r
    elif order == 2:
        # Only the log argument is clamped; the multiplier keeps the raw r
        return 0.5 * r * torch.log(torch.max(r, EPSILON))
    elif order == 4:
        return 0.5 * torch.square(r) * torch.log(torch.max(r, EPSILON))
    elif order % 2 == 0:
        r = torch.max(r, EPSILON)
        return 0.5 * torch.pow(r, 0.5 * order) * torch.log(r)
    else:
        r = torch.max(r, EPSILON)
        return torch.pow(r, 0.5 * order)
示例4: phi
# 需要导入模块: import torch [as 别名]
# 或者: from torch import square [as 别名]
def phi(r, order):
    """Evaluate the polyharmonic-spline basis function of a given order.

    See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition.

    Args:
        r: input tensor, processed element-wise
        order: interpolation order

    Returns:
        phi_k evaluated coordinate-wise on r, for k = order
    """
    # Lower bound that keeps log() and the gradient of sqrt() finite at zero;
    # it is allocated on the same device as the input.
    floor = torch.tensor(1e-10, device=r.device)

    if order == 1:
        return torch.sqrt(torch.max(r, floor))
    if order == 2:
        # Only the log argument is clamped; the multiplier keeps the raw r
        return 0.5 * r * torch.log(torch.max(r, floor))
    if order == 4:
        return 0.5 * torch.square(r) * torch.log(torch.max(r, floor))

    # Generic orders work entirely on the clamped input
    clamped = torch.max(r, floor)
    powered = torch.pow(clamped, 0.5 * order)
    if order % 2 == 0:
        return 0.5 * powered * torch.log(clamped)
    return powered
示例5: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import square [as 别名]
def forward(self, x):
    """Normalise each sample of `x` to zero mean and unit standard deviation.

    Every sample (index along dim 0) is flattened, centred on its own mean and
    divided by its own population standard deviation plus `self.epsilon`.
    When `self.learnable` is truthy the result is then scaled by `self.alpha`
    and shifted by `self.beta` (their shapes must broadcast against the
    flattened `(batch, features)` tensor -- confirm against the constructor).

    Args:
        x: input tensor whose first dimension indexes samples

    Returns:
        Tensor with the same shape as `x`.
    """
    size = x.size()
    x = x.view(x.size(0), -1)
    # keepdim=True leaves the reduced axis as size 1 so the statistics
    # broadcast per row. Without it the mean has shape (batch,), which either
    # fails to expand to (batch, features) or, when batch == features, silently
    # lines up with the wrong axis.
    mean = th.mean(x, 1, keepdim=True)
    center = x - mean
    std = th.sqrt(th.mean(th.square(center), 1, keepdim=True))
    # epsilon guards against division by zero for constant samples
    output = center / (std + self.epsilon)
    if self.learnable:
        output = self.alpha * output + self.beta
    return output.view(size)
示例6: ams_scan_slow
# 需要导入模块: import torch [as 别名]
# 或者: from torch import square [as 别名]
def ams_scan_slow(df:pd.DataFrame, wgt_factor:float=1, br:float=0, syst_unc_b:float=0,
                  use_stat_unc:bool=False, start_cut:float=0.9, min_events:int=10,
                  pred_name:str='pred', targ_name:str='gen_target', wgt_name:str='gen_weight', show_prog:bool=True) -> Tuple[float,float]:
    r'''
    Scan across candidate prediction thresholds and keep the one that maximises the
    Approximate Median Significance (https://arxiv.org/abs/1007.1727).
    Every prediction value in `df` that is at least `start_cut` is tried as a threshold, and the
    signal and background weights of events at or above it are summed afresh for each candidate.
    The statistical uncertainty on the background can optionally be folded into the AMS calculation.

    Arguments:
        df: DataFrame containing prediction data
        wgt_factor: factor to reweight signal and background weights
        br: background offset bias
        syst_unc_b: fractional systematic uncertainty on background
        use_stat_unc: whether to account for the statistical uncertainty on the background
        start_cut: minimum prediction to consider; useful for speeding up scan
        min_events: minimum number of unscaled background events required to pass threshold
        pred_name: column to use as predictions
        targ_name: column to use as truth labels for signal (1) and background (0)
        wgt_name: column to use as weights for signal and background events
        show_prog: whether to display progress and ETA of scan

    Returns:
        maximum AMS (0 if no threshold yields a positive AMS)
        prediction threshold corresponding to maximum AMS (0.0 if none was found)
    '''

    best_ams, best_cut = 0, 0.0
    is_sig = df[targ_name] == 1
    is_bkg = df[targ_name] == 0
    sig_df, bkg_df = df[is_sig], df[is_bkg]
    # Squared once up front; combined in quadrature with 1/n_bkg inside the loop
    syst_var = np.square(syst_unc_b)
    candidate_cuts = df.loc[df[pred_name] >= start_cut, pred_name].values

    for cut in progress_bar(candidate_cuts, display=show_prog, leave=show_prog):
        bkg_wgts = bkg_df.loc[(bkg_df[pred_name] >= cut), wgt_name]
        n_bkg = len(bkg_wgts)
        # Too few raw background events for a trustworthy estimate at this cut
        if n_bkg < min_events: continue

        sig_wgts = sig_df.loc[(sig_df[pred_name] >= cut), wgt_name]
        s = np.sum(sig_wgts)
        b = np.sum(bkg_wgts)
        unc_b = np.sqrt(syst_var+(1/n_bkg)) if use_stat_unc else syst_unc_b

        ams = calc_ams(s*wgt_factor, b*wgt_factor, br, unc_b)
        if ams > best_ams:
            best_ams, best_cut = ams, cut
    return best_ams, best_cut