This article collects typical usage examples of the Python class torch.nn.RNNBase. If you are wondering how to use nn.RNNBase in practice, or are looking for concrete nn.RNNBase examples, the curated code snippets below may help. You can also explore further usage examples from the enclosing module, torch.nn.
Six nn.RNNBase code examples are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
Example 1: __init__
# Required import: from torch import nn [as alias]
# Or: from torch.nn import RNNBase [as alias]
def __init__(self, mode, input_size, hidden_size,
             num_layers=1, bias=True, batch_first=False,
             dropout=0, bidirectional=False, weight_init=None):
    """
    :param mode: RNN mode passed to nn.RNNBase: 'LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU'
    :param input_size: number of expected features in the input
    :param hidden_size: number of features in the hidden state
    :param num_layers: number of stacked recurrent layers
    :param bias: if False, the layer does not use bias weights
    :param batch_first: if True, input and output tensors are (batch, seq, feature)
    :param dropout: dropout probability on the outputs of each layer except the last
    :param bidirectional: if True, the RNN is bidirectional
    :param weight_init: optional callable applied in-place to every parameter tensor
    """
    super(RNNBase, self).__init__(mode, input_size, hidden_size,
                                  num_layers, bias, batch_first, dropout, bidirectional)
    if weight_init is not None:
        for weight in super(RNNBase, self).parameters():
            weight_init(weight)
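This wrapper forwards the standard nn.RNNBase constructor arguments and then applies weight_init to every parameter. A minimal usage sketch, assuming the method above belongs to a class named RNNBase that subclasses nn.RNNBase (the class body is not shown here) and using a hypothetical init_fn:

import torch
from torch import nn

init_fn = lambda w: nn.init.uniform_(w, -0.1, 0.1)  # hypothetical weight_init callable
rnn = RNNBase('LSTM', input_size=16, hidden_size=32, num_layers=2,
              batch_first=True, weight_init=init_fn)
out, (h, c) = rnn(torch.randn(4, 10, 16))  # (batch, seq, feature) because batch_first=True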
Example 2: __init__
# Required import: from torch import nn [as alias]
# Or: from torch.nn import RNNBase [as alias]
def __init__(self, cont_feats:List[str], vecs:List[str], feats_per_vec:List[str],
             depth:int, width:int, bidirectional:bool=False, rnn:nn.RNNBase=nn.RNN,
             do:float=0., act:str='tanh', stateful:bool=False, freeze:bool=False, **kargs):
    super().__init__(cont_feats=cont_feats, vecs=vecs, feats_per_vec=feats_per_vec, row_wise=True, freeze=freeze)
    self.stateful,self.width,self.bidirectional = stateful,width,bidirectional
    p = partial(rnn, input_size=self.n_fpv, hidden_size=width, num_layers=depth, bias=True, batch_first=True, dropout=do, bidirectional=bidirectional)
    try:               self.rnn = p(nonlinearity=act)  # plain nn.RNN accepts a nonlinearity argument
    except TypeError:  self.rnn = p()                  # nn.LSTM and nn.GRU do not, so fall back
    self._init_rnn(width)
    self._map_outputs()
    if self.freeze: self.freeze_layers()
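The try/except works because only plain nn.RNN accepts a nonlinearity keyword; nn.LSTM and nn.GRU raise TypeError when given one. The same selection pattern can be reproduced in isolation (a sketch with made-up sizes, independent of the surrounding class):

from functools import partial
from torch import nn

def build_rnn(rnn=nn.RNN, act='tanh'):
    p = partial(rnn, input_size=8, hidden_size=16, num_layers=2,
                bias=True, batch_first=True, dropout=0., bidirectional=False)
    try:
        return p(nonlinearity=act)  # succeeds for nn.RNN
    except TypeError:
        return p()                  # nn.LSTM / nn.GRU reject 'nonlinearity'

print(type(build_rnn(nn.RNN)).__name__, type(build_rnn(nn.LSTM)).__name__)  # RNN LSTM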
Example 3: orthogonal_rnn_init_
# Required import: from torch import nn [as alias]
# Or: from torch.nn import RNNBase [as alias]
def orthogonal_rnn_init_(cell: nn.RNNBase, gain: float = 1.):
    """
    Orthogonal initialization of recurrent weights.
    RNN parameters contain 3 or 4 matrices in one parameter, so we slice it.
    """
    with torch.no_grad():
        for _, hh, _, _ in cell.all_weights:
            for i in range(0, hh.size(0), cell.hidden_size):
                nn.init.orthogonal_(hh.data[i:i + cell.hidden_size], gain=gain)
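For an nn.LSTM, every weight_hh_l* parameter stacks four hidden_size x hidden_size gate matrices, so the inner loop makes each gate's recurrent matrix orthogonal on its own. A short usage sketch, assuming the function above is in scope:

import torch
from torch import nn

lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2)
orthogonal_rnn_init_(lstm)
w = lstm.weight_hh_l0[:16]  # first gate's recurrent matrix
print(torch.allclose(w @ w.t(), torch.eye(16), atol=1e-5))  # True: the slice is orthogonal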
Example 4: lstm_forget_gate_init_
# Required import: from torch import nn [as alias]
# Or: from torch.nn import RNNBase [as alias]
def lstm_forget_gate_init_(cell: nn.RNNBase, value: float = 1.) -> None:
    """
    Initialize LSTM forget gates with `value`.

    :param cell: LSTM cell
    :param value: initial value, default: 1
    """
    with torch.no_grad():
        for _, _, ih_b, hh_b in cell.all_weights:
            l = len(ih_b)
            # PyTorch orders LSTM biases as (input, forget, cell, output) gates,
            # so [l // 4 : l // 2] is exactly the forget-gate slice.
            ih_b.data[l // 4:l // 2].fill_(value)
            hh_b.data[l // 4:l // 2].fill_(value)
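A usage sketch, assuming the function above is in scope and using an illustrative hidden size:

from torch import nn

lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=1)
lstm_forget_gate_init_(lstm, value=1.)
print(lstm.bias_ih_l0[16:32])  # forget-gate slice, now filled with ones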
Example 5: _setup
# Required import: from torch import nn [as alias]
# Or: from torch.nn import RNNBase [as alias]
def _setup(self):
    # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN
    if issubclass(type(self.module), nn.RNNBase):
        self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition
    for name_w in self.weights:
        print('Applying weight drop of {} to {}'.format(self.dropout, name_w))
        w = getattr(self.module, name_w)
        del self.module._parameters[name_w]
        self.module.register_parameter(name_w + '_raw', nn.Parameter(w.data))
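This is the setup half of the weight-drop (DropConnect) pattern: each targeted weight is re-registered under name_w + '_raw', and the wrapped module's flatten_parameters is stubbed out so cuDNN does not complain about non-contiguous weights. A minimal sketch of the companion step that would rebuild the weights before each forward pass, assuming it lives on the same wrapper class (it is not shown in this source):

import torch.nn.functional as F

def _setweights(self):
    # Rebuild each dropped weight from its raw copy; self.weights and
    # self.dropout are the same fields used by _setup above.
    for name_w in self.weights:
        raw_w = getattr(self.module, name_w + '_raw')
        setattr(self.module, name_w, F.dropout(raw_w, p=self.dropout, training=self.training))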
Example 6: __init__
# Required import: from torch import nn [as alias]
# Or: from torch.nn import RNNBase [as alias]
def __init__(self,
             rnn_type,
             input_size,
             hidden_size,
             num_layers=1,
             bias=True,
             dropout=0.0,
             bidirectional=False):
    """
    Args:
        rnn_type: The type of RNN to use as encoder in the module.
            Must be a class inheriting from torch.nn.RNNBase
            (such as torch.nn.LSTM, for example).
        input_size: The number of expected features in the input of the
            module.
        hidden_size: The number of features in the hidden state of the RNN
            used as encoder by the module.
        num_layers: The number of recurrent layers in the encoder of the
            module. Defaults to 1.
        bias: If False, the encoder does not use bias weights b_ih and
            b_hh. Defaults to True.
        dropout: If non-zero, introduces a dropout layer on the outputs
            of each layer of the encoder except the last one, with dropout
            probability equal to 'dropout'. Defaults to 0.0.
        bidirectional: If True, the encoder of the module is bidirectional.
            Defaults to False.
    """
    assert issubclass(rnn_type, nn.RNNBase), \
        "rnn_type must be a class inheriting from torch.nn.RNNBase"
    super(Seq2SeqEncoder, self).__init__()
    self.rnn_type = rnn_type
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.bias = bias
    self.dropout = dropout
    self.bidirectional = bidirectional
    self._encoder = rnn_type(input_size,
                             hidden_size,
                             num_layers=num_layers,
                             bias=bias,
                             batch_first=True,
                             dropout=dropout,
                             bidirectional=bidirectional)
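Because of the assert, any of nn.RNN, nn.GRU or nn.LSTM can serve as rnn_type. A construction sketch, assuming the surrounding class is named Seq2SeqEncoder as the super() call indicates, with illustrative sizes:

from torch import nn

encoder = Seq2SeqEncoder(rnn_type=nn.LSTM,
                         input_size=300,
                         hidden_size=256,
                         num_layers=1,
                         bidirectional=True)
# The wrapped nn.LSTM uses batch_first=True, so inputs are expected
# as (batch, seq_len, 300) tensors.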