当前位置: 首页>>代码示例>>Python>>正文


Python lax.conv_general_dilated方法代码示例

本文整理汇总了Python中jax.lax.conv_general_dilated方法的典型用法代码示例。如果您正苦于以下问题:Python lax.conv_general_dilated方法的具体用法?Python lax.conv_general_dilated怎么用?Python lax.conv_general_dilated使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在jax.lax的用法示例。


在下文中一共展示了lax.conv_general_dilated方法的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: jax_conv

# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import conv_general_dilated [as 别名]
def jax_conv(inp, fltr, window_strides, padding, dimension_numbers,
             filter_dilation=None):
  """A wrapper around `lax.conv_general_dilated`.

  It requires `dimension_numbers` and disallows `inp_dilation`.

  Args:
    inp: an (N+2)-D array. The input of the convolution.
    fltr: an (N+2)-D array. The filter (i.e. kernel) of the convolution.
    window_strides: the strides for moving the convolution window.
    padding: a string, either "VALID" or "SAME". The padding algorithm.
    dimension_numbers: a tuple of three strings encoding the data format of
      input, filter and output. "I" means input; "O" means output; "C" means
      channel; other characters such as "W", "H" and "D" means spatial
      dimensions.
    filter_dilation: the dilation rates for the filter. Dilating the filter
      means adding "holes" to the filter.

  Returns:
    An (N+2)-D array. The convolution result.
  """
  return lax.conv_general_dilated(inp, fltr, window_strides, padding,
                                  lhs_dilation=None,
                                  rhs_dilation=filter_dilation,
                                  dimension_numbers=dimension_numbers) 
开发者ID:yyht,项目名称:BERT,代码行数:27,代码来源:backend.py

示例2: jax_conv

# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import conv_general_dilated [as 别名]
def jax_conv(inp, fltr, window_strides, padding, dimension_numbers,
             filter_dilation=None):
  """A wrapper around `lax.conv_general_dilated`.

  It requires `dimension_numbers` and disallows `inp_dilation`.

  Args:
    inp: an (N+2)-D array. The input of the convolution.
    fltr: an (N+2)-D array. The filter (i.e. kernel) of the convolution.
    window_strides: the strides for moving the convolution window.
    padding: a string, either 'VALID' or 'SAME'. The padding algorithm.
    dimension_numbers: a tuple of three strings encoding the data format of
      input, filter and output. 'I' means input; 'O' means output; 'C' means
      channel; other characters such as 'W', 'H' and 'D' means spatial
      dimensions.
    filter_dilation: the dilation rates for the filter. Dilating the filter
      means adding "holes" to the filter.

  Returns:
    An (N+2)-D array. The convolution result.
  """
  return lax.conv_general_dilated(inp, fltr, window_strides, padding,
                                  lhs_dilation=None,
                                  rhs_dilation=filter_dilation,
                                  dimension_numbers=dimension_numbers) 
开发者ID:google,项目名称:trax,代码行数:27,代码来源:jax.py

示例3: GeneralConv

# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import conv_general_dilated [as 别名]
def GeneralConv(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID',
                kernel_init=None, bias_init=normal(1e-6), dilation=None):
    """Layer construction function for a general convolution layer."""
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    one = (1,) * len(filter_shape)
    strides = strides or one
    kernel_init = kernel_init or glorot_normal(rhs_spec.index('O'), rhs_spec.index('I'))
    dilation = dilation or one

    @parametrized
    def conv(inputs):
        filter_shape_iter = iter(filter_shape)
        kernel_shape = [out_chan if c == 'O' else
                        inputs.shape[lhs_spec.index('C')] if c == 'I' else
                        next(filter_shape_iter) for c in rhs_spec]
        bias_shape = tuple(itertools.dropwhile(lambda x: x == 1,
                                               [out_chan if c == 'C' else 1 for c in out_spec]))

        kernel = parameter(kernel_shape, kernel_init, 'kernel')
        bias = parameter(bias_shape, bias_init, 'bias')
        return lax.conv_general_dilated(inputs, kernel, strides, padding,
                                        lhs_dilation=one, rhs_dilation=dilation,
                                        dimension_numbers=dimension_numbers) + bias

    return conv 
开发者ID:JuliusKunze,项目名称:jaxnet,代码行数:27,代码来源:modules.py

示例4: conv2d

# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import conv_general_dilated [as 别名]
def conv2d(self, x, kernel, strides=(1, 1), border_mode='same',
               image_shape=None, filter_shape=None):
        return lax.conv_general_dilated(x, kernel, strides, border_mode,
                                    dimension_numbers=("NHWC", "HWIO", "NHWC")) 
开发者ID:sharadmv,项目名称:deepx,代码行数:6,代码来源:jax.py


注:本文中的jax.lax.conv_general_dilated方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。