當前位置: 首頁>>代碼示例>>Python>>正文


Python gen_nn_ops.conv3d方法代碼示例

本文整理匯總了Python中tensorflow.python.ops.gen_nn_ops.conv3d方法的典型用法代碼示例。如果您正苦於以下問題:Python gen_nn_ops.conv3d方法的具體用法?Python gen_nn_ops.conv3d怎麽用?Python gen_nn_ops.conv3d使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在tensorflow.python.ops.gen_nn_ops的用法示例。


在下文中一共展示了gen_nn_ops.conv3d方法的1個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: __init__

# 需要導入模塊: from tensorflow.python.ops import gen_nn_ops [as 別名]
# 或者: from tensorflow.python.ops.gen_nn_ops import conv3d [as 別名]
def __init__(self,
               input_shape,
               filter_shape,  # pylint: disable=redefined-builtin
               padding, data_format=None,
               strides=None, name=None):
    filter_shape = filter_shape.with_rank(input_shape.ndims)
    self.padding = padding
    self.name = name
    input_shape = input_shape.with_rank(filter_shape.ndims)
    if input_shape.ndims is None:
      raise ValueError("Rank of convolution must be known")
    if input_shape.ndims < 3 or input_shape.ndims > 5:
      raise ValueError(
          "`input` and `filter` must have rank at least 3 and at most 5")
    conv_dims = input_shape.ndims - 2
    if strides is None:
      strides = [1] * conv_dims
    elif len(strides) != conv_dims:
      raise ValueError("len(strides)=%d, but should be %d" %
                       (len(strides), conv_dims))
    if conv_dims == 1:
      # conv1d uses the 2-d data format names
      if data_format is None or data_format == "NWC":
        data_format_2d = "NHWC"
      elif data_format == "NCW":
        data_format_2d = "NCHW"
      else:
        raise ValueError("data_format must be \"NWC\" or \"NCW\".")
      self.strides = strides[0]
      self.data_format = data_format_2d
      self.conv_op = self._conv1d
    elif conv_dims == 2:
      if data_format is None or data_format == "NHWC":
        data_format = "NHWC"
        strides = [1] + list(strides) + [1]
      elif data_format == "NCHW":
        strides = [1, 1] + list(strides)
      else:
        raise ValueError("data_format must be \"NHWC\" or \"NCHW\".")
      self.strides = strides
      self.data_format = data_format
      self.conv_op = gen_nn_ops.conv2d
    elif conv_dims == 3:
      if data_format is None or data_format == "NDHWC":
        strides = [1] + list(strides) + [1]
      elif data_format == "NCDHW":
        strides = [1, 1] + list(strides)
      else:
        raise ValueError("data_format must be \"NDHWC\" or \"NCDHW\". Have: %s"
                         % data_format)
      self.strides = strides
      self.data_format = data_format
      self.conv_op = gen_nn_ops.conv3d

  # Note that we need this adapter since argument names for conv1d don't match
  # those for gen_nn_ops.conv2d and gen_nn_ops.conv3d.
  # pylint: disable=redefined-builtin 
開發者ID:PacktPublishing,項目名稱:Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda,代碼行數:59,代碼來源:nn_ops.py


注:本文中的tensorflow.python.ops.gen_nn_ops.conv3d方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。