本文整理汇总了Python中jax.lax.conv_general_dilated方法的典型用法代码示例。如果您正苦于以下问题：Python lax.conv_general_dilated方法的具体用法？Python lax.conv_general_dilated怎么用？Python lax.conv_general_dilated使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块jax.lax的用法示例。
在下文中一共展示了lax.conv_general_dilated方法的4个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: jax_conv
# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import conv_general_dilated [as 别名]
def jax_conv(inp, fltr, window_strides, padding, dimension_numbers,
             filter_dilation=None):
    """Convolve ``inp`` with ``fltr`` through ``lax.conv_general_dilated``.

    Thin wrapper that makes the caller spell out ``dimension_numbers`` and
    never dilates the input: ``lhs_dilation`` is always ``None``.

    Args:
      inp: (N+2)-D input array.
      fltr: (N+2)-D filter (kernel) array.
      window_strides: stride of the convolution window along each spatial
        dimension.
      padding: padding algorithm, "VALID" or "SAME".
      dimension_numbers: 3-tuple of layout strings for (input, filter,
        output). In the filter layout "I"/"O" mark the input/output feature
        dimensions; "C" marks the channel dimension; letters such as "W",
        "H" and "D" mark spatial dimensions.
      filter_dilation: optional per-spatial-dimension dilation of the filter
        (inserting "holes" between filter taps). ``None`` is forwarded
        unchanged to lax.

    Returns:
      The (N+2)-D convolution result.
    """
    input_dilation = None  # this wrapper never dilates the input
    return lax.conv_general_dilated(
        lhs=inp,
        rhs=fltr,
        window_strides=window_strides,
        padding=padding,
        lhs_dilation=input_dilation,
        rhs_dilation=filter_dilation,
        dimension_numbers=dimension_numbers,
    )
示例2: jax_conv
# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import conv_general_dilated [as 别名]
def jax_conv(inp, fltr, window_strides, padding, dimension_numbers,
             filter_dilation=None):
    """Run ``lax.conv_general_dilated`` with a required layout and no input dilation.

    Args:
      inp: the (N+2)-D array being convolved.
      fltr: the (N+2)-D filter/kernel array.
      window_strides: how far the window moves along each spatial dimension.
      padding: 'VALID' or 'SAME', the padding algorithm.
      dimension_numbers: (input, filter, output) layout strings. 'I' and 'O'
        name the filter's input/output feature dimensions, 'C' names the
        channel dimension, and letters like 'W', 'H', 'D' name spatial
        dimensions.
      filter_dilation: optional filter dilation rates ("holes" inserted
        between filter taps); passed straight through to lax.

    Returns:
      The (N+2)-D result of the convolution.
    """
    # Positional order of lax.conv_general_dilated:
    # lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
    # dimension_numbers.
    return lax.conv_general_dilated(
        inp,
        fltr,
        window_strides,
        padding,
        None,             # lhs_dilation: the input is never dilated
        filter_dilation,  # rhs_dilation
        dimension_numbers,
    )
示例3: GeneralConv
# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import conv_general_dilated [as 别名]
def GeneralConv(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID',
                kernel_init=None, bias_init=normal(1e-6), dilation=None):
    """Return a parametrized convolution layer for arbitrary data layouts.

    Args:
      dimension_numbers: ``(input_spec, kernel_spec, output_spec)`` layout
        strings, forwarded unchanged to ``lax.conv_general_dilated``.
      out_chan: size of the kernel's ``'O'`` axis (number of output channels).
      filter_shape: sizes of the kernel's remaining axes, consumed in the order
        those axes appear in ``kernel_spec``.
      strides: window strides; a falsy value means 1 along every axis of
        ``filter_shape``.
      padding: padding argument forwarded to ``lax.conv_general_dilated``.
      kernel_init: kernel initializer; a falsy value selects
        ``glorot_normal(kernel_spec.index('O'), kernel_spec.index('I'))``.
      bias_init: bias initializer.
      dilation: kernel (``rhs``) dilation; a falsy value means 1 along every
        axis of ``filter_shape``.

    Returns:
      A ``@parametrized`` function ``conv(inputs)`` that requests ``'kernel'``
      and ``'bias'`` parameters via ``parameter`` and returns the convolution
      result plus the bias.
    """
    in_spec, kernel_spec, result_spec = dimension_numbers
    unit = tuple([1] * len(filter_shape))
    if not strides:
        strides = unit
    if not kernel_init:
        # NOTE(review): if this is jax.nn.initializers.glorot_normal, its
        # signature is (in_axis, out_axis), so the arguments look swapped.
        # Glorot scaling averages fan-in and fan-out, so the result should be
        # unaffected. Confirm which glorot_normal is imported.
        kernel_init = glorot_normal(kernel_spec.index('O'), kernel_spec.index('I'))
    if not dilation:
        dilation = unit

    @parametrized
    def conv(inputs):
        # Build the kernel shape axis by axis following kernel_spec:
        # 'O' -> out_chan, 'I' -> channel count of the actual inputs,
        # anything else -> the next size from filter_shape.
        spatial_sizes = iter(filter_shape)
        kernel_shape = []
        for axis in kernel_spec:
            if axis == 'O':
                kernel_shape.append(out_chan)
            elif axis == 'I':
                kernel_shape.append(inputs.shape[in_spec.index('C')])
            else:
                kernel_shape.append(next(spatial_sizes))

        # One entry per output axis: out_chan on 'C', 1 elsewhere. Leading 1s
        # are stripped, and broadcasting aligns trailing dimensions, so the
        # bias still lines up with the output's 'C' axis.
        bias_dims = [out_chan if axis == 'C' else 1 for axis in result_spec]
        first = 0
        while first < len(bias_dims) and bias_dims[first] == 1:
            first += 1
        bias_shape = tuple(bias_dims[first:])

        kernel = parameter(kernel_shape, kernel_init, 'kernel')
        bias = parameter(bias_shape, bias_init, 'bias')
        convolved = lax.conv_general_dilated(
            inputs, kernel, strides, padding,
            lhs_dilation=unit, rhs_dilation=dilation,
            dimension_numbers=dimension_numbers)
        return convolved + bias

    return conv
示例4: conv2d
# 需要导入模块: from jax import lax [as 别名]
# 或者: from jax.lax import conv_general_dilated [as 别名]
def conv2d(self, x, kernel, strides=(1, 1), border_mode='same',
           image_shape=None, filter_shape=None, filter_dilation=(1, 1)):
    """2-D convolution of an NHWC batch with an HWIO kernel.

    Args:
      x: input laid out as "NHWC" (batch, height, width, channels).
      kernel: filter laid out as "HWIO" (height, width, input channels,
        output channels).
      strides: (row, column) window strides.
      border_mode: padding string forwarded to ``lax.conv_general_dilated``,
        e.g. 'same' or 'valid'. lax upper-cases padding strings, so either
        case is accepted.
      image_shape: accepted but ignored. This is presumably kept for
        compatibility with a Keras-1-style backend signature.
      filter_shape: accepted but ignored, for the same reason.
      filter_dilation: (row, column) kernel dilation rates. The default
        (1, 1) means no dilation, matching the previous behavior.

    Returns:
      The convolution result, laid out as "NHWC".
    """
    return lax.conv_general_dilated(
        x, kernel, strides, border_mode,
        rhs_dilation=filter_dilation,
        dimension_numbers=("NHWC", "HWIO", "NHWC"))