本文整理汇总了Python中tensorflow.python.ops.gen_nn_ops.conv3d方法的典型用法代码示例。如果您正苦于以下问题:Python gen_nn_ops.conv3d方法的具体用法?Python gen_nn_ops.conv3d怎么用?Python gen_nn_ops.conv3d使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类tensorflow.python.ops.gen_nn_ops
的用法示例。
在下文中一共展示了gen_nn_ops.conv3d方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: __init__
# 需要导入模块: from tensorflow.python.ops import gen_nn_ops [as 别名]
# 或者: from tensorflow.python.ops.gen_nn_ops import conv3d [as 别名]
def __init__(self,
input_shape,
filter_shape, # pylint: disable=redefined-builtin
padding, data_format=None,
strides=None, name=None):
filter_shape = filter_shape.with_rank(input_shape.ndims)
self.padding = padding
self.name = name
input_shape = input_shape.with_rank(filter_shape.ndims)
if input_shape.ndims is None:
raise ValueError("Rank of convolution must be known")
if input_shape.ndims < 3 or input_shape.ndims > 5:
raise ValueError(
"`input` and `filter` must have rank at least 3 and at most 5")
conv_dims = input_shape.ndims - 2
if strides is None:
strides = [1] * conv_dims
elif len(strides) != conv_dims:
raise ValueError("len(strides)=%d, but should be %d" %
(len(strides), conv_dims))
if conv_dims == 1:
# conv1d uses the 2-d data format names
if data_format is None or data_format == "NWC":
data_format_2d = "NHWC"
elif data_format == "NCW":
data_format_2d = "NCHW"
else:
raise ValueError("data_format must be \"NWC\" or \"NCW\".")
self.strides = strides[0]
self.data_format = data_format_2d
self.conv_op = self._conv1d
elif conv_dims == 2:
if data_format is None or data_format == "NHWC":
data_format = "NHWC"
strides = [1] + list(strides) + [1]
elif data_format == "NCHW":
strides = [1, 1] + list(strides)
else:
raise ValueError("data_format must be \"NHWC\" or \"NCHW\".")
self.strides = strides
self.data_format = data_format
self.conv_op = gen_nn_ops.conv2d
elif conv_dims == 3:
if data_format is None or data_format == "NDHWC":
strides = [1] + list(strides) + [1]
elif data_format == "NCDHW":
strides = [1, 1] + list(strides)
else:
raise ValueError("data_format must be \"NDHWC\" or \"NCDHW\". Have: %s"
% data_format)
self.strides = strides
self.data_format = data_format
self.conv_op = gen_nn_ops.conv3d
# Note that we need this adapter since argument names for conv1d don't match
# those for gen_nn_ops.conv2d and gen_nn_ops.conv3d.
# pylint: disable=redefined-builtin
开发者ID:PacktPublishing,项目名称:Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda,代码行数:59,代码来源:nn_ops.py