当前位置: 首页>>代码示例>>Python>>正文


Python init.xavier_normal方法代码示例

本文整理汇总了Python中torch.nn.init.xavier_normal方法的典型用法代码示例。如果您正苦于以下问题:Python init.xavier_normal方法的具体用法?Python init.xavier_normal怎么用?Python init.xavier_normal使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch.nn.init的用法示例。


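在下文中一共展示了init.xavier_normal方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。注意:在较新版本的PyTorch中,`init.xavier_normal`(以及`init.normal`、`init.constant`、`init.kaiming_normal`等不带下划线的函数)已被弃用,调用时会发出FutureWarning,应改用对应的原地版本,如`init.xavier_normal_`。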

示例1: __init__

# 需要导入模块: from torch.nn import init [as 别名]
# 或者: from torch.nn.init import xavier_normal [as 别名]
def __init__(self, vocab_size, embed_size, latent_size, decoder_size, decoder_num_layers):
        """Build the VAE: token embedding, encoder, latent projections and decoder.

        :param vocab_size: number of rows of the token embedding table.
        :param embed_size: width of each token embedding vector.
        :param latent_size: encoder output width and input/output width of the
            two latent projections.
        :param decoder_size: forwarded to ``Decoder`` as its third argument.
        :param decoder_num_layers: forwarded to ``Decoder`` as its fourth argument.
        """
        super(VAE, self).__init__()

        self.latent_size = latent_size
        self.vocab_size = vocab_size
        self.embed_size = embed_size

        self.embed = nn.Embedding(self.vocab_size, self.embed_size)
        # xavier_normal_ fills the registered Parameter in place, so there is no
        # need to re-assign self.embed.weight; the un-suffixed `xavier_normal`
        # alias is deprecated and emits a FutureWarning on every call.
        nn.init.xavier_normal_(self.embed.weight)

        self.encoder = Encoder(self.embed_size, self.latent_size)

        # Projections named for the mean / log-variance of the latent code;
        # how they are used is defined in forward() (not shown here).
        self.context_to_mu = nn.Linear(self.latent_size, self.latent_size)
        self.context_to_logvar = nn.Linear(self.latent_size, self.latent_size)

        self.decoder = Decoder(self.vocab_size, self.latent_size, decoder_size, decoder_num_layers, self.embed_size)
开发者ID:kefirski,项目名称:hybrid_rvae,代码行数:18,代码来源:vae.py

示例2: weights_init

# 需要导入模块: from torch.nn import init [as 别名]
# 或者: from torch.nn.init import xavier_normal [as 别名]
def weights_init(init_type='xavier'):
    """Return a callable for ``nn.Module.apply`` that initializes layer parameters.

    Modules whose class name starts with ``Conv`` or ``Linear`` (and that have a
    ``weight``) get their weight initialized according to ``init_type`` and their
    bias, if present, set to 0. Modules whose class name starts with ``Norm`` get
    weight 1 and bias 0 (each only if present).

    :param init_type: one of 'normal', 'xavier', 'kaiming', 'orthogonal' or
        'default' ('default' keeps the layer's own weight initialization but
        still zeroes the bias).
    :returns: a function ``init_fun(module)``.
    :raises ValueError: if ``init_type`` is not one of the supported names.
    """
    weight_initializers = {
        'normal': lambda w: init.normal_(w, 0.0, 0.02),
        'xavier': lambda w: init.xavier_normal_(w, gain=math.sqrt(2)),
        'kaiming': lambda w: init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: init.orthogonal_(w, gain=math.sqrt(2)),
        'default': lambda w: None,
    }
    # Validate up front with a real exception: the former `assert 0` vanished
    # under `python -O` and only fired half-way through model.apply().
    if init_type not in weight_initializers:
        raise ValueError("Unsupported initialization: {}".format(init_type))
    init_weight = weight_initializers[init_type]

    def init_fun(m):
        classname = m.__class__.__name__
        if classname.startswith(('Conv', 'Linear')) and hasattr(m, 'weight'):
            init_weight(m.weight.data)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.startswith('Norm'):
            # NOTE(review): only names *starting* with "Norm" match, so e.g.
            # BatchNorm2d / InstanceNorm2d are not touched here — confirm intended.
            if getattr(m, 'weight', None) is not None:
                init.constant_(m.weight.data, 1.0)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)
    return init_fun
开发者ID:Xiaoming-Yu,项目名称:DMIT,代码行数:26,代码来源:network.py

示例3: __init__

# 需要导入模块: from torch.nn import init [as 别名]
# 或者: from torch.nn.init import xavier_normal [as 别名]
def __init__(self, n_head, d_input, d_model, d_input_v=None, dropout=0.1):
        """Create per-head projection weights and the attention sub-modules.

        :param n_head: number of attention heads.
        :param d_input: feature size of the query/key inputs.
        :param d_model: model width; each head gets ``d_model // n_head`` dims
            for keys and values (any remainder is dropped by floor division).
        :param d_input_v: feature size of the value input; defaults to ``d_input``.
        :param dropout: drop probability for ``self.dropout``.
        """
        super(MultiHeadAttention, self).__init__()

        self.n_head = n_head
        d_k = d_v = d_model // n_head
        self.d_k = d_k
        self.d_v = d_v

        if d_input_v is None:
            d_input_v = d_input

        # One (input_dim, head_dim) projection matrix per head, stacked along
        # dim 0. torch.empty leaves storage uninitialized; it is filled below.
        self.w_qs = nn.Parameter(torch.empty(n_head, d_input, d_k, dtype=torch.float))
        self.w_ks = nn.Parameter(torch.empty(n_head, d_input, d_k, dtype=torch.float))
        self.w_vs = nn.Parameter(torch.empty(n_head, d_input_v, d_v, dtype=torch.float))

        self.attention = DotProductAttention(d_model)
        self.proj = Linear(n_head * d_v, d_model)

        self.dropout = nn.Dropout(dropout)

        # In-place initializers; the un-suffixed `init.xavier_normal` alias is
        # deprecated and warns on every call.
        init.xavier_normal_(self.w_qs)
        init.xavier_normal_(self.w_ks)
        init.xavier_normal_(self.w_vs)
开发者ID:thomas0809,项目名称:GraphIE,代码行数:27,代码来源:attention.py

示例4: weights_init

# 需要导入模块: from torch.nn import init [as 别名]
# 或者: from torch.nn.init import xavier_normal [as 别名]
def weights_init(init_type='xavier'):
    """Return a callable for ``nn.Module.apply`` that initializes Conv/Linear layers.

    Modules whose class name starts with ``Conv`` or ``Linear`` (and that have a
    ``weight``) get their weight initialized according to ``init_type`` and their
    bias, if present, set to 0. All other modules are left unchanged.

    :param init_type: one of 'normal', 'xavier', 'kaiming', 'orthogonal' or
        'default' ('default' keeps the layer's own weight initialization but
        still zeroes the bias).
    :returns: a function ``init_fun(module)``.
    :raises ValueError: if ``init_type`` is not one of the supported names.
    """
    weight_initializers = {
        'normal': lambda w: init.normal_(w, 0.0, 0.02),
        'xavier': lambda w: init.xavier_normal_(w, gain=math.sqrt(2)),
        'kaiming': lambda w: init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: init.orthogonal_(w, gain=math.sqrt(2)),
        'default': lambda w: None,
    }
    # Validate up front with a real exception: the former `assert 0` vanished
    # under `python -O` and only fired half-way through model.apply().
    if init_type not in weight_initializers:
        raise ValueError("Unsupported initialization: {}".format(init_type))
    init_weight = weight_initializers[init_type]

    def init_fun(m):
        classname = m.__class__.__name__
        if classname.startswith(('Conv', 'Linear')) and hasattr(m, 'weight'):
            init_weight(m.weight.data)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)
    return init_fun
开发者ID:Xiaoming-Yu,项目名称:SingleGAN,代码行数:21,代码来源:model.py

示例5: __init__

# 需要导入模块: from torch.nn import init [as 别名]
# 或者: from torch.nn.init import xavier_normal [as 别名]
def __init__(self, dims):
        """
        M2 code replication from the paper
        'Semi-Supervised Learning with Deep Generative Models'
        (Kingma 2014) in PyTorch.

        The "Generative semi-supervised model" is a probabilistic
        model that incorporates label information in both
        inference and generation.

        Initialise a new generative model
        :param dims: ``[x_dim, y_dim, z_dim, h_dim]`` where ``h_dim`` is a list
            of hidden layer sizes (it is reversed for the decoder and its first
            entry sizes the classifier).
        """
        [x_dim, self.y_dim, z_dim, h_dim] = dims
        super(DeepGenerativeModel, self).__init__([x_dim, z_dim, h_dim])

        # Encoder and decoder are rebuilt so that both are conditioned on y
        # (their input widths include y_dim).
        self.encoder = Encoder([x_dim + self.y_dim, h_dim, z_dim])
        self.decoder = Decoder([z_dim + self.y_dim, list(reversed(h_dim)), x_dim])
        self.classifier = Classifier([x_dim, h_dim[0], self.y_dim])

        # Xavier-normal weights and zero biases for every nn.Linear reachable
        # from this module. xavier_normal_ is the in-place, non-deprecated API.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
开发者ID:wohlert,项目名称:semi-supervised-pytorch,代码行数:27,代码来源:dgm.py

示例6: __init__

# 需要导入模块: from torch.nn import init [as 别名]
# 或者: from torch.nn.init import xavier_normal [as 别名]
def __init__(self, dims):
        """
        Variational Autoencoder [Kingma 2013] model
        consisting of an encoder/decoder pair for which
        a variational distribution is fitted to the
        encoder. Also known as the M1 model in [Kingma 2014].

        :param dims: ``[x_dim, z_dim, h_dim]`` where ``h_dim`` is a list of
            hidden layer sizes (reversed for the decoder).
        """
        super(VariationalAutoencoder, self).__init__()

        [x_dim, z_dim, h_dim] = dims
        self.z_dim = z_dim
        self.flow = None

        self.encoder = Encoder([x_dim, h_dim, z_dim])
        self.decoder = Decoder([z_dim, list(reversed(h_dim)), x_dim])
        self.kl_divergence = 0

        # Xavier-normal weights and zero biases for every nn.Linear reachable
        # from this module. xavier_normal_ is the in-place, non-deprecated API.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
开发者ID:wohlert,项目名称:semi-supervised-pytorch,代码行数:26,代码来源:vae.py

示例7: __init__

# 需要导入模块: from torch.nn import init [as 别名]
# 或者: from torch.nn.init import xavier_normal [as 别名]
def __init__(self, inplanes, planes, dilation=1, stride=1, downsample=None):
        """Two ``conv3x3`` layers, a shared in-place ReLU and an optional ``downsample``.

        The forward pass is defined elsewhere. Every ``nn.Conv2d`` reachable from
        this module gets a fan-out normal init, and any ``nn.BatchNorm2d`` (none
        are created here, but ``downsample`` may contain some) gets scale 1 and
        shift 0.
        """
        super(BasicBlock, self).__init__()
        # Registration order (conv1, relu, conv2, downsample) fixes the order in
        # which self.modules() visits layers, and hence the RNG draw order below.
        self.conv1 = conv3x3(inplanes, planes, dilation, stride)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride

        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                kh, kw = layer.kernel_size[0], layer.kernel_size[1]
                fan_out = kh * kw * layer.out_channels
                # std = sqrt(2 / fan_out), i.e. He init in fan-out mode.
                layer.weight.data.normal_(0, math.sqrt(2. / fan_out))
                continue
            if isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()
开发者ID:baowenbo,项目名称:DAIN,代码行数:20,代码来源:BasicBlock.py

示例8: __init__

# 需要导入模块: from torch.nn import init [as 别名]
# 或者: from torch.nn.init import xavier_normal [as 别名]
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        """Create per-head projection weights and the attention sub-modules.

        :param n_head: number of attention heads.
        :param d_model: feature size of the inputs and of the output projection.
        :param d_k: per-head query/key size.
        :param d_v: per-head value size.
        :param dropout: drop probability for ``self.dropout``.
        """
        super(MultiHeadAttention, self).__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # One (d_model, head_dim) projection matrix per head, stacked along
        # dim 0. torch.empty leaves storage uninitialized; it is filled below.
        self.w_qs = nn.Parameter(torch.empty(n_head, d_model, d_k, dtype=torch.float))
        self.w_ks = nn.Parameter(torch.empty(n_head, d_model, d_k, dtype=torch.float))
        self.w_vs = nn.Parameter(torch.empty(n_head, d_model, d_v, dtype=torch.float))

        self.attention = ScaledDotProductAttention(d_model)
        self.layer_norm = LayerNormalization(d_model)
        self.proj = Linear(n_head * d_v, d_model)

        self.dropout = nn.Dropout(dropout)

        # In-place initializers; the un-suffixed `init.xavier_normal` alias is
        # deprecated and warns on every call.
        init.xavier_normal_(self.w_qs)
        init.xavier_normal_(self.w_ks)
        init.xavier_normal_(self.w_vs)
开发者ID:wabyking,项目名称:TextClassificationBenchmark,代码行数:22,代码来源:Transformer.py

示例9: init_weights_xavier

# 需要导入模块: from torch.nn import init [as 别名]
# 或者: from torch.nn.init import xavier_normal [as 别名]
def init_weights_xavier(model):
    """Xavier-normal init for ``nn.Conv2d`` weights and zero init for their bias.

    Intended for ``net.apply(init_weights_xavier)``; any module that is not an
    ``nn.Conv2d`` is left unchanged. Convolutions built with ``bias=False``
    (``model.bias is None``) are supported: only the weight is initialized.
    """
    if isinstance(model, nn.Conv2d):
        init.xavier_normal_(model.weight)
        if model.bias is not None:
            init.constant_(model.bias, 0)
开发者ID:minerva-ml,项目名称:steppy-toolkit,代码行数:6,代码来源:models.py

示例10: weights_init_xavier

# 需要导入模块: from torch.nn import init [as 别名]
# 或者: from torch.nn.init import xavier_normal [as 别名]
def weights_init_xavier(m):
    """Initialize one module in place; intended for ``net.apply(weights_init_xavier)``.

    Matching is by substring of the class name:
    - ``Conv`` / ``Linear``: weight ~ Xavier normal (gain 1).
    - ``BatchNorm``: weight ~ N(1.0, 0.02), bias = 0.
    Modules without a weight tensor are skipped.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if weight is None:
        # Substring matching also hits user-defined containers whose name
        # contains "Conv", and BatchNorm(affine=False); neither has a weight
        # tensor, and accessing one used to raise AttributeError.
        return
    if 'Conv' in classname or 'Linear' in classname:
        init.xavier_normal_(weight.data, gain=1)
    elif 'BatchNorm' in classname:
        init.normal_(weight.data, 1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            init.constant_(m.bias.data, 0.0)
开发者ID:ozan-oktay,项目名称:Attention-Gated-Networks,代码行数:12,代码来源:networks_other.py

示例11: weights_init_xavier

# 需要导入模块: from torch.nn import init [as 别名]
# 或者: from torch.nn.init import xavier_normal [as 别名]
def weights_init_xavier(m):
    """Initialize one module in place; intended for ``net.apply(weights_init_xavier)``.

    Matching is by substring of the class name:
    - ``Conv`` / ``Linear``: weight ~ Xavier normal with gain 0.02.
    - ``BatchNorm2d``: weight ~ N(1.0, 0.02), bias = 0.
    Modules without a weight tensor are skipped.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if weight is None:
        # Substring matching also hits user-defined containers whose name
        # contains "Conv", and BatchNorm2d(affine=False); neither has a weight
        # tensor, and accessing one used to raise AttributeError.
        return
    if 'Conv' in classname or 'Linear' in classname:
        # NOTE(review): gain=0.02 gives very small weights; kept as in the
        # original — confirm it is intended.
        init.xavier_normal_(weight.data, gain=0.02)
    elif 'BatchNorm2d' in classname:
        init.normal_(weight.data, 1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            init.constant_(m.bias.data, 0.0)
开发者ID:joelmoniz,项目名称:DepthNets,代码行数:12,代码来源:networks.py

示例12: weight_initializaton

# 需要导入模块: from torch.nn import init [as 别名]
# 或者: from torch.nn.init import xavier_normal [as 别名]
def weight_initializaton(m):
        """Xavier-normal init for ``nn.Conv2d`` weights and zero init for their bias.

        Intended for ``net.apply(weight_initializaton)``; any module that is not
        an ``nn.Conv2d`` is left unchanged. Convolutions built with
        ``bias=False`` (``m.bias is None``) are supported: only the weight is set.
        """
        if isinstance(m, nn.Conv2d):
            init.xavier_normal_(m.weight)
            if m.bias is not None:
                init.constant_(m.bias, 0)
开发者ID:maxjiang93,项目名称:space_time_pde,代码行数:6,代码来源:unet.py

示例13: weights_init_xavier

# 需要导入模块: from torch.nn import init [as 别名]
# 或者: from torch.nn.init import xavier_normal [as 别名]
def weights_init_xavier(self, m):
        """Xavier-normal (gain 1.0) init of ``m.weight`` for Conv*/Linear* modules.

        Matching is by substring of ``m``'s class name. Biases are not touched,
        and matching modules that have no weight tensor are skipped.
        """
        classname = m.__class__.__name__
        if 'Conv' in classname or 'Linear' in classname:
            weight = getattr(m, 'weight', None)
            # A user-defined container named e.g. "...ConvBlock" has no weight;
            # accessing one used to raise AttributeError.
            if weight is not None:
                init.xavier_normal_(weight.data, gain=1.0)
开发者ID:wtjiang98,项目名称:BeautyGAN_pytorch,代码行数:8,代码来源:solver_makeup.py

示例14: weights_init_xavier

# 需要导入模块: from torch.nn import init [as 别名]
# 或者: from torch.nn.init import xavier_normal [as 别名]
def weights_init_xavier(m):
    """Initialize one module in place; intended for ``net.apply(weights_init_xavier)``.

    Matching is by substring of the class name:
    - ``Conv`` / ``Linear``: weight ~ Xavier normal (gain 1).
    - ``BatchNorm2d``: weight ~ N(1.0, 0.02), bias = 0.
    Modules without a weight tensor are skipped.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if weight is None:
        # Weight-less modules (a "…Conv…" container, BatchNorm2d(affine=False))
        # have nothing to initialize.
        return
    if 'Conv' in classname or 'Linear' in classname:
        init.xavier_normal_(weight.data, gain=1)
    elif 'BatchNorm2d' in classname:
        # The former `init.uniform(w, 1.0, 0.02)` requested a uniform range with
        # low > high, which torch rejects with a RuntimeError; (1.0, 0.02) are
        # the mean/std of the usual N(1, 0.02) BatchNorm scale init.
        init.normal_(weight.data, 1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            init.constant_(m.bias.data, 0.0)
开发者ID:arnabgho,项目名称:iSketchNFill,代码行数:12,代码来源:networks.py

示例15: initialize_weights

# 需要导入模块: from torch.nn import init [as 别名]
# 或者: from torch.nn.init import xavier_normal [as 别名]
def initialize_weights(method='kaiming', *models):
    """Initialize the Conv2d / ConvTranspose2d / Linear layers of one or more models.

    Usage: ``initialize_weights('xavier', net1, net2)``. Because ``method`` comes
    before ``*models``, ``initialize_weights(net)`` would bind ``net`` to
    ``method``; that form is detected and treated as ``('kaiming', net)``.

    :param method: 'kaiming' (He normal, ReLU gain sqrt(2)), 'xavier' or
        'orthogonal' (both with gain sqrt(2)), or 'normal' (N(0, 0.02)).
    :param models: modules whose sub-modules are initialized in place; biases of
        the matching layers are set to 0.
    :raises ValueError: if ``method`` is not one of the names above.
    """
    if isinstance(method, nn.Module):
        # Called as initialize_weights(net, ...): previously a silent no-op.
        models = (method,) + models
        method = 'kaiming'

    weight_inits = {
        # kaiming_normal_'s 2nd positional parameter is the leaky-ReLU slope
        # `a`, not a gain: passing sqrt(2) there shrank the std by ~42%.
        # a=0 already yields the sqrt(2) ReLU gain.
        'kaiming': lambda w: init.kaiming_normal_(w, a=0),
        'xavier': lambda w: init.xavier_normal_(w, gain=np.sqrt(2.0)),
        'orthogonal': lambda w: init.orthogonal_(w, gain=np.sqrt(2.0)),
        'normal': lambda w: init.normal_(w, mean=0, std=0.02),
    }
    if method not in weight_inits:
        raise ValueError("Unsupported initialization method: {}".format(method))
    init_weight = weight_inits[method]

    for model in models:
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                init_weight(module.weight.data)
                if module.bias is not None:
                    init.constant_(module.bias.data, 0)
开发者ID:saeedizadi,项目名称:binseg_pytoch,代码行数:17,代码来源:tools.py


注:本文中的torch.nn.init.xavier_normal方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。