当前位置: 首页>>代码示例>>Python>>正文


Python rnn.ResidualWrapper方法代码示例

本文整理汇总了Python中tensorflow.contrib.rnn.ResidualWrapper方法的典型用法代码示例。如果您正苦于以下问题:Python rnn.ResidualWrapper方法的具体用法?Python rnn.ResidualWrapper怎么用?Python rnn.ResidualWrapper使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在tensorflow.contrib.rnn的用法示例。


在下文中一共展示了rnn.ResidualWrapper方法的5个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _make_cell

# 需要导入模块: from tensorflow.contrib import rnn [as 别名]
# 或者: from tensorflow.contrib.rnn import ResidualWrapper [as 别名]
def _make_cell(self, hidden_size=None):
        """Create a single RNN cell"""
        cell = self.cell_type(hidden_size or self.hidden_size)
        if self.dropout:
            cell = rnn.DropoutWrapper(cell, self.keep_prob)
        if self.residual:
            cell = rnn.ResidualWrapper(cell)
        return cell 
开发者ID:frnsys,项目名称:retrosynthesis_planner,代码行数:10,代码来源:seq2seq.py

示例2: build_single_cell

# 需要导入模块: from tensorflow.contrib import rnn [as 别名]
# 或者: from tensorflow.contrib.rnn import ResidualWrapper [as 别名]
def build_single_cell(self, n_hidden, use_residual):
        """构建一个单独的rnn cell
        Args:
            n_hidden: 隐藏层神经元数量
            use_residual: 是否使用residual wrapper
        """

        if self.cell_type == 'gru':
            cell_type = GRUCell
        else:
            cell_type = LSTMCell

        cell = cell_type(n_hidden)

        if self.use_dropout:
            cell = DropoutWrapper(
                cell,
                dtype=tf.float32,
                output_keep_prob=self.keep_prob_placeholder,
                seed=self.seed
            )

        if use_residual:
            cell = ResidualWrapper(cell)

        return cell 
开发者ID:yongzhuo,项目名称:nlp_xiaojiang,代码行数:28,代码来源:model_seq2seq.py

示例3: create_cell

# 需要导入模块: from tensorflow.contrib import rnn [as 别名]
# 或者: from tensorflow.contrib.rnn import ResidualWrapper [as 别名]
def create_cell(unit_type, hidden_units, num_layers, use_residual=False, input_keep_prob=1.0, output_keep_prob=1.0, devices=None):
    if unit_type == 'lstm':
        def _new_cell():
            return tf.nn.rnn_cell.BasicLSTMCell(hidden_units)
    elif unit_type == 'gru':
        def _new_cell():
            return tf.contrib.rnn.GRUCell(hidden_units)
    else:
        raise ValueError('cell_type must be either lstm or gru')

    def _new_cell_wrapper(residual_connection=False, device_id=None):
        c = _new_cell()

        if input_keep_prob < 1.0 or output_keep_prob < 1.0:
            c = rnn.DropoutWrapper(c, input_keep_prob=input_keep_prob, output_keep_prob=output_keep_prob)

        if residual_connection:
            c = rnn.ResidualWrapper(c)

        if device_id:
            c = rnn.DeviceWrapper(c, device_id)

        return c

    if num_layers > 1:
        cells = []

        for i in range(num_layers):
            is_residual = True if use_residual and i > 0 else False
            cells.append(_new_cell_wrapper(is_residual, devices[i] if devices else None))

        return tf.contrib.rnn.MultiRNNCell(cells)
    else:
        return _new_cell_wrapper(device_id=devices[0] if devices else None) 
开发者ID:nouhadziri,项目名称:THRED,代码行数:36,代码来源:rnn_factory.py

示例4: _single_cell

# 需要导入模块: from tensorflow.contrib import rnn [as 别名]
# 或者: from tensorflow.contrib.rnn import ResidualWrapper [as 别名]
def _single_cell(unit_type,
                 num_units,
                 forget_bias,
                 dropout,
                 mode,
                 residual_connection=False,
                 residual_fn=None,
                 trainable=True):
  """Create an instance of a single RNN cell."""
  # dropout (= 1 - keep_prob) is set to 0 during eval and infer
  dropout = dropout if mode == tf.estimator.ModeKeys.TRAIN else 0.0

  # Cell Type
  if unit_type == "lstm":
    single_cell = contrib_rnn.LSTMCell(
        num_units, forget_bias=forget_bias, trainable=trainable)
  elif unit_type == "gru":
    single_cell = contrib_rnn.GRUCell(num_units, trainable=trainable)
  elif unit_type == "layer_norm_lstm":
    single_cell = contrib_rnn.LayerNormBasicLSTMCell(
        num_units,
        forget_bias=forget_bias,
        layer_norm=True,
        trainable=trainable)
  elif unit_type == "nas":
    single_cell = contrib_rnn.NASCell(num_units, trainable=trainable)
  else:
    raise ValueError("Unknown unit type %s!" % unit_type)

  # Dropout (= 1 - keep_prob).
  if dropout > 0.0:
    single_cell = contrib_rnn.DropoutWrapper(
        cell=single_cell, input_keep_prob=(1.0 - dropout))

  # Residual.
  if residual_connection:
    single_cell = contrib_rnn.ResidualWrapper(
        single_cell, residual_fn=residual_fn)

  return single_cell 
开发者ID:google-research,项目名称:language,代码行数:42,代码来源:model_utils.py

示例5: rnn_cell

# 需要导入模块: from tensorflow.contrib import rnn [as 别名]
# 或者: from tensorflow.contrib.rnn import ResidualWrapper [as 别名]
def rnn_cell(rnn_cell_size, dropout_keep_prob, residual, is_training=True):
  """Builds an LSTMBlockCell based on the given parameters."""
  dropout_keep_prob = dropout_keep_prob if is_training else 1.0
  cells = []
  for i in range(len(rnn_cell_size)):
    cell = rnn.LSTMBlockCell(rnn_cell_size[i])
    if residual:
      cell = rnn.ResidualWrapper(cell)
      if i == 0 or rnn_cell_size[i] != rnn_cell_size[i - 1]:
        cell = rnn.InputProjectionWrapper(cell, rnn_cell_size[i])
    cell = rnn.DropoutWrapper(
        cell,
        input_keep_prob=dropout_keep_prob)
    cells.append(cell)
  return rnn.MultiRNNCell(cells) 
开发者ID:personads,项目名称:synvae,代码行数:17,代码来源:lstm_utils.py


注:本文中的tensorflow.contrib.rnn.ResidualWrapper方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。