当前位置: 首页>>代码示例>>Python>>正文


Python modules.Module方法代码示例

本文整理汇总了Python中torch.nn.modules.Module方法的典型用法代码示例。如果您正苦于以下问题:Python modules.Module方法的具体用法?Python modules.Module怎么用?Python modules.Module使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch.nn.modules的用法示例。


在下文中一共展示了modules.Module方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __init__

# 需要导入模块: from torch.nn import modules [as 别名]
# 或者: from torch.nn.modules import Module [as 别名]
def __init__(self, module_or_grads_list):
        if isinstance(module_or_grads_list, Module):
            self.module = module_or_grads_list
            flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) )

        else:
            self.module = None
            self.grads = []
            extract_tensors(module_or_grads_list, self.grads) 
开发者ID:NVIDIA,项目名称:apex,代码行数:11,代码来源:distributed.py

示例2: plot_losses

# 需要导入模块: from torch.nn import modules [as 别名]
# 或者: from torch.nn.modules import Module [as 别名]
def plot_losses(
    losses: Union[nn.Module, List[nn.Module]],
    visdom_server: Optional["visdom.Visdom"] = None,
    env: Optional[str] = None,
    win: Optional[str] = None,
    title: str = "",
) -> Any:
    """Constructs a plot of specified losses as function of y * f(x). The losses
    are a list of nn.Module losses. Optionally, the environment, window handle,
    and title for the visdom plot can be specified.
    """

    if visdom_server is None and visdom_connected():
        visdom_server = vis[-1]

    # return if we are not connected to visdom server:
    if not visdom_server or not visdom_server.check_connection():
        print("WARNING: Not connected to visdom. Skipping plotting.")
        return

    # assertions:
    if isinstance(losses, nn.Module):
        losses = [losses]
    assert type(losses) == list
    assert all(isinstance(loss, nn.Module) for loss in losses)
    if any(isinstance(loss, UNSUPPORTED_LOSSES) for loss in losses):
        raise NotImplementedError("loss function not supported")

    # loop over all loss functions:
    for idx, loss in enumerate(losses):

        # construct scores and targets:
        score = torch.arange(-5.0, 5.0, 0.005)
        if idx == 0:
            loss_val = torch.FloatTensor(score.size(0), len(losses))
        if isinstance(loss, REGRESSION_LOSSES):
            target = torch.FloatTensor(score.size()).fill_(0.0)
        else:
            target = torch.LongTensor(score.size()).fill_(1)

        # compute loss values:
        for n in range(0, score.nelement()):
            loss_val[n][idx] = loss(
                score.narrow(0, n, 1), target.narrow(0, n, 1)
            ).item()

    # show plot:
    title = str(loss) if title == "" else title
    legend = [str(loss) for loss in losses]
    opts = {"title": title, "xlabel": "Score", "ylabel": "Loss", "legend": legend}
    win = visdom_server.line(loss_val, score, env=env, win=win, opts=opts)
    return win 
开发者ID:facebookresearch,项目名称:ClassyVision,代码行数:54,代码来源:visualize.py


注:本文中的torch.nn.modules.Module方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。