本文整理匯總了Python中tensorflow.python.ops.gen_nn_ops.conv3d方法的典型用法代碼示例。如果您正苦於以下問題:Python gen_nn_ops.conv3d方法的具體用法?Python gen_nn_ops.conv3d怎麽用?Python gen_nn_ops.conv3d使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在類tensorflow.python.ops.gen_nn_ops
的用法示例。
在下文中一共展示了gen_nn_ops.conv3d方法的1個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。
示例1: __init__
# 需要導入模塊: from tensorflow.python.ops import gen_nn_ops [as 別名]
# 或者: from tensorflow.python.ops.gen_nn_ops import conv3d [as 別名]
def __init__(self,
input_shape,
filter_shape, # pylint: disable=redefined-builtin
padding, data_format=None,
strides=None, name=None):
filter_shape = filter_shape.with_rank(input_shape.ndims)
self.padding = padding
self.name = name
input_shape = input_shape.with_rank(filter_shape.ndims)
if input_shape.ndims is None:
raise ValueError("Rank of convolution must be known")
if input_shape.ndims < 3 or input_shape.ndims > 5:
raise ValueError(
"`input` and `filter` must have rank at least 3 and at most 5")
conv_dims = input_shape.ndims - 2
if strides is None:
strides = [1] * conv_dims
elif len(strides) != conv_dims:
raise ValueError("len(strides)=%d, but should be %d" %
(len(strides), conv_dims))
if conv_dims == 1:
# conv1d uses the 2-d data format names
if data_format is None or data_format == "NWC":
data_format_2d = "NHWC"
elif data_format == "NCW":
data_format_2d = "NCHW"
else:
raise ValueError("data_format must be \"NWC\" or \"NCW\".")
self.strides = strides[0]
self.data_format = data_format_2d
self.conv_op = self._conv1d
elif conv_dims == 2:
if data_format is None or data_format == "NHWC":
data_format = "NHWC"
strides = [1] + list(strides) + [1]
elif data_format == "NCHW":
strides = [1, 1] + list(strides)
else:
raise ValueError("data_format must be \"NHWC\" or \"NCHW\".")
self.strides = strides
self.data_format = data_format
self.conv_op = gen_nn_ops.conv2d
elif conv_dims == 3:
if data_format is None or data_format == "NDHWC":
strides = [1] + list(strides) + [1]
elif data_format == "NCDHW":
strides = [1, 1] + list(strides)
else:
raise ValueError("data_format must be \"NDHWC\" or \"NCDHW\". Have: %s"
% data_format)
self.strides = strides
self.data_format = data_format
self.conv_op = gen_nn_ops.conv3d
# Note that we need this adapter since argument names for conv1d don't match
# those for gen_nn_ops.conv2d and gen_nn_ops.conv3d.
# pylint: disable=redefined-builtin
開發者ID:PacktPublishing,項目名稱:Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda,代碼行數:59,代碼來源:nn_ops.py