當前位置: 首頁>>代碼示例>>Python>>正文


Python nn.RNNBase方法代碼示例

本文整理匯總了Python中torch.nn.RNNBase方法的典型用法代碼示例。如果您正苦於以下問題:Python nn.RNNBase方法的具體用法?Python nn.RNNBase怎麽用?Python nn.RNNBase使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch.nn的用法示例。


在下文中一共展示了nn.RNNBase方法的6個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: __init__

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import RNNBase [as 別名]
def __init__(self,mode, input_size, hidden_size,
                 num_layers=1, bias=True, batch_first=False,
                 dropout=0, bidirectional=False,weight_init=None):
        """

        :param mode:
        :param input_size:
        :param hidden_size:
        :param num_layers:
        :param bias:
        :param batch_first:
        :param dropout:
        :param bidirectional:
        :param weight_init:
        """
        super(RNNBase,self).__init__(mode, input_size, hidden_size,
                 num_layers, bias, batch_first, dropout,bidirectional)

        if weight_init is not None:
            for weight in super(RNNBase, self).parameters():
                weight_init(weight) 
開發者ID:johnolafenwa,項目名稱:TorchFusion,代碼行數:23,代碼來源:layers.py

示例2: __init__

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import RNNBase [as 別名]
def __init__(self, cont_feats:List[str], vecs:List[str], feats_per_vec:List[str],
                 depth:int, width:int, bidirectional:bool=False, rnn:nn.RNNBase=nn.RNN,
                 do:float=0., act:str='tanh', stateful:bool=False, freeze:bool=False, **kargs):
        super().__init__(cont_feats=cont_feats, vecs=vecs, feats_per_vec=feats_per_vec, row_wise=True, freeze=freeze)
        self.stateful,self.width,self.bidirectional = stateful,width,bidirectional
        p = partial(rnn, input_size=self.n_fpv, hidden_size=width, num_layers=depth, bias=True, batch_first=True, dropout=do, bidirectional=bidirectional)
        try:              self.rnn = p(nonlinearity=act)
        except TypeError: self.rnn = p()
        self._init_rnn(width)
        self._map_outputs()
        if self.freeze: self.freeze_layers() 
開發者ID:GilesStrong,項目名稱:lumin,代碼行數:13,代碼來源:head.py

示例3: orthogonal_rnn_init_

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import RNNBase [as 別名]
def orthogonal_rnn_init_(cell: nn.RNNBase, gain: float = 1.):
    """
    Orthogonal initialization of recurrent weights
    RNN parameters contain 3 or 4 matrices in one parameter, so we slice it.
    """
    with torch.no_grad():
        for _, hh, _, _ in cell.all_weights:
            for i in range(0, hh.size(0), cell.hidden_size):
                nn.init.orthogonal_(hh.data[i:i + cell.hidden_size], gain=gain) 
開發者ID:joeynmt,項目名稱:joeynmt,代碼行數:11,代碼來源:initialization.py

示例4: lstm_forget_gate_init_

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import RNNBase [as 別名]
def lstm_forget_gate_init_(cell: nn.RNNBase, value: float = 1.) -> None:
    """
    Initialize LSTM forget gates with `value`.

    :param cell: LSTM cell
    :param value: initial value, default: 1
    """
    with torch.no_grad():
        for _, _, ih_b, hh_b in cell.all_weights:
            l = len(ih_b)
            ih_b.data[l // 4:l // 2].fill_(value)
            hh_b.data[l // 4:l // 2].fill_(value) 
開發者ID:joeynmt,項目名稱:joeynmt,代碼行數:14,代碼來源:initialization.py

示例5: _setup

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import RNNBase [as 別名]
def _setup(self):
        # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN
        if issubclass(type(self.module), nn.RNNBase):
            self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition

        for name_w in self.weights:
            print('Applying weight drop of {} to {}'.format(self.dropout, name_w))
            w = getattr(self.module, name_w)
            del self.module._parameters[name_w]
            self.module.register_parameter(name_w + '_raw', nn.Parameter(w.data)) 
開發者ID:iwangjian,項目名稱:ByteCup2018,代碼行數:12,代碼來源:dropout.py

示例6: __init__

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import RNNBase [as 別名]
def __init__(self,
                 rnn_type,
                 input_size,
                 hidden_size,
                 num_layers=1,
                 bias=True,
                 dropout=0.0,
                 bidirectional=False):
        """
        Args:
            rnn_type: The type of RNN to use as encoder in the module.
                Must be a class inheriting from torch.nn.RNNBase
                (such as torch.nn.LSTM for example).
            input_size: The number of expected features in the input of the
                module.
            hidden_size: The number of features in the hidden state of the RNN
                used as encoder by the module.
            num_layers: The number of recurrent layers in the encoder of the
                module. Defaults to 1.
            bias: If False, the encoder does not use bias weights b_ih and
                b_hh. Defaults to True.
            dropout: If non-zero, introduces a dropout layer on the outputs
                of each layer of the encoder except the last one, with dropout
                probability equal to 'dropout'. Defaults to 0.0.
            bidirectional: If True, the encoder of the module is bidirectional.
                Defaults to False.
        """
        assert issubclass(rnn_type, nn.RNNBase),\
            "rnn_type must be a class inheriting from torch.nn.RNNBase"

        super(Seq2SeqEncoder, self).__init__()

        self.rnn_type = rnn_type
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.dropout = dropout
        self.bidirectional = bidirectional

        self._encoder = rnn_type(input_size,
                                 hidden_size,
                                 num_layers=num_layers,
                                 bias=bias,
                                 batch_first=True,
                                 dropout=dropout,
                                 bidirectional=bidirectional) 
開發者ID:coetaur0,項目名稱:ESIM,代碼行數:49,代碼來源:layers.py


注:本文中的torch.nn.RNNBase方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。