本文整理汇总了Python中torch.nn.modules.Module方法的典型用法代码示例。如果您正苦于以下问题:Python modules.Module方法的具体用法?Python modules.Module怎么用?Python modules.Module使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类torch.nn.modules
的用法示例。
在下文中一共展示了modules.Module方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: __init__
# 需要导入模块: from torch.nn import modules [as 别名]
# 或者: from torch.nn.modules import Module [as 别名]
def __init__(self, module_or_grads_list):
if isinstance(module_or_grads_list, Module):
self.module = module_or_grads_list
flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) )
else:
self.module = None
self.grads = []
extract_tensors(module_or_grads_list, self.grads)
示例2: plot_losses
# 需要导入模块: from torch.nn import modules [as 别名]
# 或者: from torch.nn.modules import Module [as 别名]
def plot_losses(
losses: Union[nn.Module, List[nn.Module]],
visdom_server: Optional["visdom.Visdom"] = None,
env: Optional[str] = None,
win: Optional[str] = None,
title: str = "",
) -> Any:
"""Constructs a plot of specified losses as function of y * f(x). The losses
are a list of nn.Module losses. Optionally, the environment, window handle,
and title for the visdom plot can be specified.
"""
if visdom_server is None and visdom_connected():
visdom_server = vis[-1]
# return if we are not connected to visdom server:
if not visdom_server or not visdom_server.check_connection():
print("WARNING: Not connected to visdom. Skipping plotting.")
return
# assertions:
if isinstance(losses, nn.Module):
losses = [losses]
assert type(losses) == list
assert all(isinstance(loss, nn.Module) for loss in losses)
if any(isinstance(loss, UNSUPPORTED_LOSSES) for loss in losses):
raise NotImplementedError("loss function not supported")
# loop over all loss functions:
for idx, loss in enumerate(losses):
# construct scores and targets:
score = torch.arange(-5.0, 5.0, 0.005)
if idx == 0:
loss_val = torch.FloatTensor(score.size(0), len(losses))
if isinstance(loss, REGRESSION_LOSSES):
target = torch.FloatTensor(score.size()).fill_(0.0)
else:
target = torch.LongTensor(score.size()).fill_(1)
# compute loss values:
for n in range(0, score.nelement()):
loss_val[n][idx] = loss(
score.narrow(0, n, 1), target.narrow(0, n, 1)
).item()
# show plot:
title = str(loss) if title == "" else title
legend = [str(loss) for loss in losses]
opts = {"title": title, "xlabel": "Score", "ylabel": "Loss", "legend": legend}
win = visdom_server.line(loss_val, score, env=env, win=win, opts=opts)
return win