当前位置: 首页>>代码示例>>Python>>正文


Python gen_nn_ops.conv3d方法代码示例

本文整理汇总了Python中tensorflow.python.ops.gen_nn_ops.conv3d方法的典型用法代码示例。如果您正苦于以下问题:Python gen_nn_ops.conv3d方法的具体用法?Python gen_nn_ops.conv3d怎么用?Python gen_nn_ops.conv3d使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在tensorflow.python.ops.gen_nn_ops的用法示例。


在下文中一共展示了gen_nn_ops.conv3d方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __init__

# 需要导入模块: from tensorflow.python.ops import gen_nn_ops [as 别名]
# 或者: from tensorflow.python.ops.gen_nn_ops import conv3d [as 别名]
def __init__(self,
               input_shape,
               filter_shape,  # pylint: disable=redefined-builtin
               padding, data_format=None,
               strides=None, name=None):
    filter_shape = filter_shape.with_rank(input_shape.ndims)
    self.padding = padding
    self.name = name
    input_shape = input_shape.with_rank(filter_shape.ndims)
    if input_shape.ndims is None:
      raise ValueError("Rank of convolution must be known")
    if input_shape.ndims < 3 or input_shape.ndims > 5:
      raise ValueError(
          "`input` and `filter` must have rank at least 3 and at most 5")
    conv_dims = input_shape.ndims - 2
    if strides is None:
      strides = [1] * conv_dims
    elif len(strides) != conv_dims:
      raise ValueError("len(strides)=%d, but should be %d" %
                       (len(strides), conv_dims))
    if conv_dims == 1:
      # conv1d uses the 2-d data format names
      if data_format is None or data_format == "NWC":
        data_format_2d = "NHWC"
      elif data_format == "NCW":
        data_format_2d = "NCHW"
      else:
        raise ValueError("data_format must be \"NWC\" or \"NCW\".")
      self.strides = strides[0]
      self.data_format = data_format_2d
      self.conv_op = self._conv1d
    elif conv_dims == 2:
      if data_format is None or data_format == "NHWC":
        data_format = "NHWC"
        strides = [1] + list(strides) + [1]
      elif data_format == "NCHW":
        strides = [1, 1] + list(strides)
      else:
        raise ValueError("data_format must be \"NHWC\" or \"NCHW\".")
      self.strides = strides
      self.data_format = data_format
      self.conv_op = gen_nn_ops.conv2d
    elif conv_dims == 3:
      if data_format is None or data_format == "NDHWC":
        strides = [1] + list(strides) + [1]
      elif data_format == "NCDHW":
        strides = [1, 1] + list(strides)
      else:
        raise ValueError("data_format must be \"NDHWC\" or \"NCDHW\". Have: %s"
                         % data_format)
      self.strides = strides
      self.data_format = data_format
      self.conv_op = gen_nn_ops.conv3d

  # Note that we need this adapter since argument names for conv1d don't match
  # those for gen_nn_ops.conv2d and gen_nn_ops.conv3d.
  # pylint: disable=redefined-builtin 
开发者ID:PacktPublishing,项目名称:Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda,代码行数:59,代码来源:nn_ops.py


注:本文中的tensorflow.python.ops.gen_nn_ops.conv3d方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。