当前位置: 首页>>代码示例>>Python>>正文


Python sonnet.Conv2D方法代码示例

本文整理汇总了 Python 中 sonnet.Conv2D 方法的典型用法代码示例。如果您想了解 sonnet.Conv2D 的具体用法、调用方式或使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 sonnet 的其他用法示例。


下文共展示了 sonnet.Conv2D 方法的 15 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的 Python 代码示例。

示例1: testConv2dIntervalBounds

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import Conv2D [as 别名]
def testConv2dIntervalBounds(self):
    """Interval bounds through a 2x2 all-ones conv (bias 2) are [8, 16].

    The input is [1, 2, 3, 4] laid out as a 1x2x2x1 image, perturbed by +/-1.
    The lower bound is therefore (0 + 1 + 2 + 3) + 2 = 8 and the upper bound
    is (2 + 3 + 4 + 5) + 2 = 16.
    """
    init = {
        'w': tf.constant_initializer(1.),
        'b': tf.constant_initializer(2.),
    }
    conv = snt.Conv2D(output_channels=1, kernel_shape=(2, 2), padding='VALID',
                      stride=1, use_bias=True, initializers=init)
    z = tf.reshape(tf.constant([1, 2, 3, 4], dtype=tf.float32), [1, 2, 2, 1])
    conv(z)  # Connect once so the module creates its variables.
    wrapper = ibp.LinearConv2dWrapper(conv)
    bounds = wrapper.propagate_bounds(ibp.IntervalBounds(z - 1., z + 1.))
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        lower, upper = sess.run([bounds.lower, bounds.upper])
        self.assertAlmostEqual(8., lower.item())
        self.assertAlmostEqual(16., upper.item())
开发者ID:deepmind,项目名称:interval-bound-propagation,代码行数:26,代码来源:bounds_test.py

示例2: custom_build

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import Conv2D [as 别名]
def custom_build(inputs, is_training, keep_prob):
  """A custom build method to wrap into a sonnet Module.

  The network has two stages of conv -> batch norm -> ReLU -> 2x2 max-pool,
  followed by a 1x1 conv to 1024 channels with batch norm and ReLU. It then
  flattens, applies dropout, and ends with a 10-unit linear layer.

  Args:
    inputs: Input batch, reshaped here to `[-1, 28, 28, 1]`.
    is_training: Forwarded to every `snt.BatchNorm` connection.
    keep_prob: Keep probability forwarded to `tf.nn.dropout`.

  Returns:
    Output tensor of the final `snt.Linear(output_size=10)` layer.
  """
  x_inputs = tf.reshape(inputs, [-1, 28, 28, 1])
  outputs = snt.Conv2D(output_channels=32, kernel_shape=4, stride=2)(x_inputs)
  outputs = snt.BatchNorm()(outputs, is_training=is_training)
  outputs = tf.nn.relu(outputs)
  outputs = tf.nn.max_pool(outputs, ksize=[1, 2, 2, 1],
                           strides=[1, 2, 2, 1], padding='SAME')
  outputs = snt.Conv2D(output_channels=64, kernel_shape=4, stride=2)(outputs)
  outputs = snt.BatchNorm()(outputs, is_training=is_training)
  outputs = tf.nn.relu(outputs)
  outputs = tf.nn.max_pool(outputs, ksize=[1, 2, 2, 1],
                           strides=[1, 2, 2, 1], padding='SAME')
  # 1x1 conv used as a per-position projection to 1024 channels.
  outputs = snt.Conv2D(output_channels=1024, kernel_shape=1, stride=1)(outputs)
  outputs = snt.BatchNorm()(outputs, is_training=is_training)
  outputs = tf.nn.relu(outputs)
  outputs = snt.BatchFlatten()(outputs)
  outputs = tf.nn.dropout(outputs, keep_prob=keep_prob)
  outputs = snt.Linear(output_size=10)(outputs)
  return outputs
开发者ID:normanheckscher,项目名称:mnist-multi-gpu,代码行数:23,代码来源:mnist_multi_gpu_sonnet.py

示例3: custom_build

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import Conv2D [as 别名]
def custom_build(self, inputs):
    """A custom build method to wrap into a sonnet Module.

    The network applies four 16-channel conv + ReLU layers, then dropout
    using `self.placeholders['keep_prob']`. It then flattens and ends with a
    128-unit linear layer followed by ReLU.
    """
    # (kernel_shape, stride) for each conv layer, in construction order.
    conv_specs = (
        ([7, 7], [1, 1]),
        ([5, 5], [1, 2]),
        ([5, 5], [1, 2]),
        ([5, 5], [2, 2]),
    )
    net = inputs
    for kernel, step in conv_specs:
        conv = snt.Conv2D(output_channels=16, kernel_shape=kernel, stride=step)
        net = tf.nn.relu(conv(net))
    net = tf.nn.dropout(net, self.placeholders['keep_prob'])
    net = snt.BatchFlatten()(net)
    net = tf.nn.relu(snt.Linear(128)(net))
    return net
开发者ID:tu-rbo,项目名称:differentiable-particle-filters,代码行数:18,代码来源:dpf_kitti.py

示例4: test_incompatible_higher_rank_inputs_raises

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import Conv2D [as 别名]
def test_incompatible_higher_rank_inputs_raises(self,
                                                use_edges,
                                                use_receiver_nodes,
                                                use_sender_nodes,
                                                use_globals,
                                                field):
    """An exception should be raised if the inputs have incompatible shapes."""
    input_graph = self._get_shaped_input_graph()
    # Swap axes 1 and 2 of `field` (the permutation implies a rank-4 tensor)
    # so that it no longer matches the other graph fields.
    input_graph = input_graph.replace(
        **{field: tf.transpose(getattr(input_graph, field), [0, 2, 1, 3])})
    network = blocks.EdgeBlock(
        functools.partial(snt.Conv2D, output_channels=10, kernel_shape=[3, 3]),
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals
    )
    # `assertRaisesRegexp` is a deprecated alias that was removed in Python
    # 3.12; `assertRaisesRegex` is the supported name.
    with self.assertRaisesRegex(
            tf.errors.InvalidArgumentError, "Dimensions of inputs should match"):
        network(input_graph)
开发者ID:deepmind,项目名称:graph_nets,代码行数:22,代码来源:blocks_test.py

示例5: test_incompatible_higher_rank_inputs_no_raise

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import Conv2D [as 别名]
def test_incompatible_higher_rank_inputs_no_raise(self,
                                                  use_edges,
                                                  use_receiver_nodes,
                                                  use_sender_nodes,
                                                  use_globals,
                                                  field):
    """No exception should occur if a differently shaped field is not used."""
    graph = self._get_shaped_input_graph()
    # Swap axes 1 and 2 of the selected field only.
    transposed = tf.transpose(getattr(graph, field), [0, 2, 1, 3])
    graph = graph.replace(**{field: transposed})
    edge_model_fn = functools.partial(
        snt.Conv2D, output_channels=10, kernel_shape=[3, 3])
    network = blocks.EdgeBlock(
        edge_model_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals,
    )
    self._assert_build_and_run(network, graph)
开发者ID:deepmind,项目名称:graph_nets,代码行数:20,代码来源:blocks_test.py

示例6: test_incompatible_higher_rank_partial_outputs_raises

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import Conv2D [as 别名]
def test_incompatible_higher_rank_partial_outputs_raises(self):
    """An error should be raised if partial outputs have incompatible shapes."""
    input_graph = self._get_shaped_input_graph()
    edge_model_fn, node_model_fn, global_model_fn = self._get_shaped_model_fns()
    # A Conv2D with stride [1, 2] downsamples one axis, so its output no longer
    # lines up with the outputs of the other sub-models.
    mismatched_model_fn = functools.partial(
        snt.Conv2D, output_channels=10, kernel_shape=[3, 3], stride=[1, 2])
    # Substitute the mismatched model first for the edge model, then for the
    # node model; both configurations must fail.
    for model_fns in ((mismatched_model_fn, node_model_fn, global_model_fn),
                      (edge_model_fn, mismatched_model_fn, global_model_fn)):
        graph_network = modules.GraphNetwork(*model_fns)
        # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
        with self.assertRaisesRegex(
                tf.errors.InvalidArgumentError,
                "Dimensions of inputs should match"):
            graph_network(input_graph)
开发者ID:deepmind,项目名称:graph_nets,代码行数:20,代码来源:modules_test.py

示例7: test_incompatible_higher_rank_inputs_raises

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import Conv2D [as 别名]
def test_incompatible_higher_rank_inputs_raises(self,
                                                use_edges,
                                                use_receiver_nodes,
                                                use_sender_nodes,
                                                use_globals,
                                                field):
    """An exception should be raised if the inputs have incompatible shapes."""
    input_graph = self._get_shaped_input_graph()
    # Swap axes 1 and 2 of `field` (the permutation implies a rank-4 tensor)
    # so that it no longer matches the other graph fields.
    input_graph = input_graph.replace(
        **{field: tf.transpose(getattr(input_graph, field), [0, 2, 1, 3])})
    network = blocks.EdgeBlock(
        functools.partial(snt.Conv2D, output_channels=10, kernel_shape=[3, 3]),
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals
    )
    # `assertRaisesRegexp` is a deprecated alias that was removed in Python
    # 3.12; `assertRaisesRegex` is the supported name.
    with self.assertRaisesRegex(ValueError, "in both shapes must be equal"):
        network(input_graph)
开发者ID:deepmind,项目名称:graph_nets,代码行数:21,代码来源:blocks_test.py

示例8: test_incompatible_higher_rank_partial_outputs_raises

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import Conv2D [as 别名]
def test_incompatible_higher_rank_partial_outputs_raises(self):
    """An error should be raised if partial outputs have incompatible shapes."""
    input_graph = self._get_shaped_input_graph()
    edge_model_fn, node_model_fn, global_model_fn = self._get_shaped_model_fns()
    # A Conv2D with stride [1, 2] downsamples one axis, so its output no longer
    # lines up with the outputs of the other sub-models.
    mismatched_model_fn = functools.partial(
        snt.Conv2D, output_channels=10, kernel_shape=[3, 3], stride=[1, 2])
    # Substitute the mismatched model first for the edge model, then for the
    # node model; both configurations must fail.
    for model_fns in ((mismatched_model_fn, node_model_fn, global_model_fn),
                      (edge_model_fn, mismatched_model_fn, global_model_fn)):
        graph_network = modules.GraphNetwork(*model_fns)
        # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
        with self.assertRaisesRegex(ValueError, "in both shapes must be equal"):
            graph_network(input_graph)
开发者ID:deepmind,项目名称:graph_nets,代码行数:18,代码来源:modules_test.py

示例9: _build

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import Conv2D [as 别名]
def _build(self, inputs):
    """Connects the classifier and returns the output of a 10-unit linear layer.

    Inputs are reshaped to [28, 28, 1] per example. The network then runs two
    stages of 5x5 SAME conv (2, then 4 channels), `_NONLINEARITY`, and a 2x2
    stride-2 `_POOL` pool. It finishes with a flatten, a 32-unit linear layer,
    and a 10-unit linear layer. When `FLAGS.l2_reg` is truthy, every conv and
    linear layer gets an L2 regularizer on both weights and biases, scaled by
    `FLAGS.l2_reg`.
    """
    regularizers = None
    if FLAGS.l2_reg:
        def l2_penalty(param):
            return FLAGS.l2_reg * tf.nn.l2_loss(param)
        regularizers = {'w': l2_penalty, 'b': l2_penalty}

    net = snt.BatchReshape([28, 28, 1])(inputs)
    for channels in (2, 4):
        conv = snt.Conv2D(channels, 5, padding=snt.SAME,
                          regularizers=regularizers)
        net = tf.nn.pool(_NONLINEARITY(conv(net)), window_shape=(2, 2),
                         pooling_type=_POOL, padding=snt.SAME, strides=(2, 2))

    flat = snt.BatchFlatten()(net)
    hidden = snt.Linear(32, regularizers=regularizers)(flat)
    return snt.Linear(10, regularizers=regularizers)(hidden)
开发者ID:tensorflow,项目名称:kfac,代码行数:29,代码来源:classifier_mnist.py

示例10: _build

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import Conv2D [as 别名]
def _build(self, x):
    """Maps `x` through conv and linear layers to a `(mu, sigma)` pair.

    Args:
      x: Input tensor for the first `snt.Conv2D` in `self.layers`.

    Returns:
      A `(mu, sigma)` tuple. Both halves come from one linear projection to
      `2 * self.n_latent` units: `mu` is the first `n_latent` columns, and
      `sigma` is softplus of the remaining columns, which keeps it positive.
    """
    h = x
    # Each entry of `self.layers` supplies the positional
    # (output_channels, kernel_shape, stride) arguments of `snt.Conv2D`.
    for layer_spec in self.layers:
        conv = snt.Conv2D(layer_spec[0], layer_spec[1], layer_spec[2])
        h = tf.nn.relu(conv(h))

    # Flatten each example using the static shape of the conv output.
    h_shape = h.get_shape().as_list()
    h = tf.reshape(h, [-1, h_shape[1] * h_shape[2] * h_shape[3]])
    for output_size in self.padding_linear_layers:
        h = snt.Linear(output_size)(h)
    pre_z = snt.Linear(2 * self.n_latent)(h)
    mu = pre_z[:, :self.n_latent]
    sigma = tf.nn.softplus(pre_z[:, self.n_latent:])
    return mu, sigma
开发者ID:magenta,项目名称:magenta,代码行数:15,代码来源:nn.py

示例11: _inputs_for_observed_module

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import Conv2D [as 别名]
def _inputs_for_observed_module(self, subgraph):
    """Extracts input tensors from a connected Sonnet module.

    This default implementation supports common layer types, but should be
    overridden if custom layer types are to be supported.

    Args:
      subgraph: `snt.ConnectedSubGraph` specifying the Sonnet module being
        connected, and its inputs and outputs.

    Returns:
      A one-element tuple containing the module's input tensor, or None if
      the module is not a supported Sonnet module.
    """
    m = subgraph.module
    # Only support a few operations for now.
    supported_types = (snt.BatchReshape, snt.Linear, snt.Conv1D, snt.Conv2D,
                       snt.BatchNorm, layers.ImageNorm)
    if not isinstance(m, supported_types):
        return None

    # `snt.BatchNorm` receives its input as `input_batch`; the other
    # supported modules receive it as `inputs`.
    input_key = 'input_batch' if isinstance(m, snt.BatchNorm) else 'inputs'
    return (subgraph.inputs[input_key],)
开发者ID:deepmind,项目名称:interval-bound-propagation,代码行数:29,代码来源:model.py

示例12: _wrapper_for_observed_module

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import Conv2D [as 别名]
def _wrapper_for_observed_module(self, subgraph):
    """Creates a wrapper for a connected Sonnet module.

    This default implementation supports common layer types, but should be
    overridden if custom layer types are to be supported.

    Args:
      subgraph: `snt.ConnectedSubGraph` specifying the Sonnet module being
        connected, and its inputs and outputs.

    Returns:
      `ibp.VerifiableWrapper` for the Sonnet module.

    Raises:
      ValueError: if the module is not one of the supported types.
    """
    m = subgraph.module
    if isinstance(m, snt.BatchReshape):
        # Target shape is the static output shape without the batch dimension.
        shape = subgraph.outputs.get_shape()[1:].as_list()
        return verifiable_wrapper.BatchReshapeWrapper(m, shape)
    if isinstance(m, snt.Linear):
        return verifiable_wrapper.LinearFCWrapper(m)
    if isinstance(m, snt.Conv1D):
        return verifiable_wrapper.LinearConv1dWrapper(m)
    if isinstance(m, snt.Conv2D):
        return verifiable_wrapper.LinearConv2dWrapper(m)
    if isinstance(m, layers.ImageNorm):
        return verifiable_wrapper.ImageNormWrapper(m)
    if isinstance(m, snt.BatchNorm):
        return verifiable_wrapper.BatchNormWrapper(m)
    # Raise explicitly: an `assert` here is stripped under `python -O`, which
    # would silently hand an unsupported module to `BatchNormWrapper`.
    raise ValueError('Cannot create a verifiable wrapper for {}.'.format(m))
开发者ID:deepmind,项目名称:interval-bound-propagation,代码行数:30,代码来源:model.py

示例13: __init__

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import Conv2D [as 别名]
def __init__(self, module):
    """Wraps `module`, which must be a `snt.Conv2D` instance.

    Args:
      module: The `snt.Conv2D` module to wrap.

    Raises:
      ValueError: if `module` is not a `snt.Conv2D`.
    """
    if isinstance(module, snt.Conv2D):
        super(LinearConv2dWrapper, self).__init__(module)
        return
    raise ValueError(
        'Cannot wrap {} with a LinearConv2dWrapper.'.format(module))
开发者ID:deepmind,项目名称:interval-bound-propagation,代码行数:7,代码来源:verifiable_wrapper.py

示例14: testConv2dSymbolicBounds

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import Conv2D [as 别名]
def testConv2dSymbolicBounds(self):
    """Symbolic bounds through a 2x2 all-ones conv (bias 2) concretize to [8, 16].

    The input [1, 2, 3, 4] (as a 1x2x2x1 image) is perturbed by +/-1. The
    interval box is converted to symbolic bounds, propagated through the
    wrapped conv, and converted back to an interval.
    """
    init = {
        'w': tf.constant_initializer(1.),
        'b': tf.constant_initializer(2.),
    }
    conv = snt.Conv2D(output_channels=1, kernel_shape=(2, 2), padding='VALID',
                      stride=1, use_bias=True, initializers=init)
    z = tf.reshape(tf.constant([1, 2, 3, 4], dtype=tf.float32), [1, 2, 2, 1])
    conv(z)  # Connect once so the module creates its variables.
    wrapper = ibp.LinearConv2dWrapper(conv)
    symbolic_in = ibp.SymbolicBounds.convert(ibp.IntervalBounds(z - 1., z + 1.))
    bounds = ibp.IntervalBounds.convert(wrapper.propagate_bounds(symbolic_in))
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        lower, upper = sess.run([bounds.lower, bounds.upper])
        self.assertAlmostEqual(8., lower.item())
        self.assertAlmostEqual(16., upper.item())
开发者ID:deepmind,项目名称:interval-bound-propagation,代码行数:28,代码来源:fastlin_test.py

示例15: testConv2dBackwardBounds

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import Conv2D [as 别名]
def testConv2dBackwardBounds(self):
    """Backward bounds through a 2x2 all-ones conv (bias 2) concretize to [8, 16].

    Interval bounds are propagated first. A spec from
    `_generate_identity_spec` is then propagated through the same wrapper and
    concretized.
    """
    init = {
        'w': tf.constant_initializer(1.),
        'b': tf.constant_initializer(2.),
    }
    conv = snt.Conv2D(output_channels=1, kernel_shape=(2, 2), padding='VALID',
                      stride=1, use_bias=True, initializers=init)
    z = tf.reshape(tf.constant([1, 2, 3, 4], dtype=tf.float32), [1, 2, 2, 1])
    conv(z)  # Connect once so the module creates its variables.
    wrapper = ibp.LinearConv2dWrapper(conv)
    # The interval pass must run before the spec is built and propagated.
    wrapper.propagate_bounds(ibp.IntervalBounds(z - 1., z + 1.))
    spec = _generate_identity_spec([wrapper], shape=(1, 1, 1, 1, 1))
    concrete = wrapper.propagate_bounds(spec).concretize()
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        lower, upper = sess.run([concrete.lower, concrete.upper])
        self.assertAlmostEqual(8., lower.item())
        self.assertAlmostEqual(16., upper.item())
开发者ID:deepmind,项目名称:interval-bound-propagation,代码行数:29,代码来源:crown_test.py


注：本文中的 sonnet.Conv2D 方法示例由纯净天空整理自 GitHub/MSDocs 等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的 License；未经允许，请勿转载。