當前位置: 首頁>>代碼示例>>Python>>正文


Python lax.conv_general_dilated方法代碼示例

本文整理匯總了Python中jax.lax.conv_general_dilated方法的典型用法代碼示例。如果您正苦於以下問題:Python lax.conv_general_dilated方法的具體用法?Python lax.conv_general_dilated怎麽用?Python lax.conv_general_dilated使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在jax.lax的用法示例。


在下文中一共展示了lax.conv_general_dilated方法的4個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: jax_conv

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import conv_general_dilated [as 別名]
def jax_conv(inp, fltr, window_strides, padding, dimension_numbers,
             filter_dilation=None):
  """A wrapper around `lax.conv_general_dilated`.

  It requires `dimension_numbers` and disallows `inp_dilation`.

  Args:
    inp: an (N+2)-D array. The input of the convolution.
    fltr: an (N+2)-D array. The filter (i.e. kernel) of the convolution.
    window_strides: the strides for moving the convolution window.
    padding: a string, either "VALID" or "SAME". The padding algorithm.
    dimension_numbers: a tuple of three strings encoding the data format of
      input, filter and output. "I" means input; "O" means output; "C" means
      channel; other characters such as "W", "H" and "D" means spatial
      dimensions.
    filter_dilation: the dilation rates for the filter. Dilating the filter
      means adding "holes" to the filter.

  Returns:
    An (N+2)-D array. The convolution result.
  """
  return lax.conv_general_dilated(inp, fltr, window_strides, padding,
                                  lhs_dilation=None,
                                  rhs_dilation=filter_dilation,
                                  dimension_numbers=dimension_numbers) 
開發者ID:yyht,項目名稱:BERT,代碼行數:27,代碼來源:backend.py

示例2: jax_conv

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import conv_general_dilated [as 別名]
def jax_conv(inp, fltr, window_strides, padding, dimension_numbers,
             filter_dilation=None):
  """A wrapper around `lax.conv_general_dilated`.

  It requires `dimension_numbers` and disallows `inp_dilation`.

  Args:
    inp: an (N+2)-D array. The input of the convolution.
    fltr: an (N+2)-D array. The filter (i.e. kernel) of the convolution.
    window_strides: the strides for moving the convolution window.
    padding: a string, either 'VALID' or 'SAME'. The padding algorithm.
    dimension_numbers: a tuple of three strings encoding the data format of
      input, filter and output. 'I' means input; 'O' means output; 'C' means
      channel; other characters such as 'W', 'H' and 'D' means spatial
      dimensions.
    filter_dilation: the dilation rates for the filter. Dilating the filter
      means adding "holes" to the filter.

  Returns:
    An (N+2)-D array. The convolution result.
  """
  return lax.conv_general_dilated(inp, fltr, window_strides, padding,
                                  lhs_dilation=None,
                                  rhs_dilation=filter_dilation,
                                  dimension_numbers=dimension_numbers) 
開發者ID:google,項目名稱:trax,代碼行數:27,代碼來源:jax.py

示例3: GeneralConv

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import conv_general_dilated [as 別名]
def GeneralConv(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID',
                kernel_init=None, bias_init=normal(1e-6), dilation=None):
    """Layer construction function for a general convolution layer."""
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    one = (1,) * len(filter_shape)
    strides = strides or one
    kernel_init = kernel_init or glorot_normal(rhs_spec.index('O'), rhs_spec.index('I'))
    dilation = dilation or one

    @parametrized
    def conv(inputs):
        filter_shape_iter = iter(filter_shape)
        kernel_shape = [out_chan if c == 'O' else
                        inputs.shape[lhs_spec.index('C')] if c == 'I' else
                        next(filter_shape_iter) for c in rhs_spec]
        bias_shape = tuple(itertools.dropwhile(lambda x: x == 1,
                                               [out_chan if c == 'C' else 1 for c in out_spec]))

        kernel = parameter(kernel_shape, kernel_init, 'kernel')
        bias = parameter(bias_shape, bias_init, 'bias')
        return lax.conv_general_dilated(inputs, kernel, strides, padding,
                                        lhs_dilation=one, rhs_dilation=dilation,
                                        dimension_numbers=dimension_numbers) + bias

    return conv 
開發者ID:JuliusKunze,項目名稱:jaxnet,代碼行數:27,代碼來源:modules.py

示例4: conv2d

# 需要導入模塊: from jax import lax [as 別名]
# 或者: from jax.lax import conv_general_dilated [as 別名]
def conv2d(self, x, kernel, strides=(1, 1), border_mode='same',
               image_shape=None, filter_shape=None):
        return lax.conv_general_dilated(x, kernel, strides, border_mode,
                                    dimension_numbers=("NHWC", "HWIO", "NHWC")) 
開發者ID:sharadmv,項目名稱:deepx,代碼行數:6,代碼來源:jax.py


注:本文中的jax.lax.conv_general_dilated方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。