当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow.TensorShape方法代码示例

本文整理汇总了Python中tensorflow.TensorShape的典型用法代码示例。如果您正苦于以下问题：Python tensorflow.TensorShape的具体用法是什么？Python tensorflow.TensorShape怎么用？有哪些使用Python tensorflow.TensorShape的例子？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解其所在模块tensorflow的用法示例。


在下文中一共展示了tensorflow.TensorShape的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import TensorShape [as 别名]
def main():
    """Train the eager-mode Word2Vec model for NUM_TRAIN_STEPS steps.

    Batches come from the module-level ``gen`` generator: center words of
    shape [BATCH_SIZE] and target words of shape [BATCH_SIZE, 1]. The
    dataset is re-iterated until NUM_TRAIN_STEPS gradient steps have run.
    The average loss over the last SKIP_STEP steps is printed every
    SKIP_STEP steps.

    Raises:
        ValueError: If a full pass over the dataset yields no batches. Without
            this check, the outer ``while`` loop would spin forever, because
            ``num_train_steps`` could never advance.
    """
    dataset = tf.data.Dataset.from_generator(gen, (tf.int32, tf.int32),
                                             (tf.TensorShape([BATCH_SIZE]),
                                              tf.TensorShape([BATCH_SIZE, 1])))
    optimizer = tf.compat.v1.train.GradientDescentOptimizer(LEARNING_RATE)
    model = Word2Vec(vocab_size=VOCAB_SIZE, embed_size=EMBED_SIZE)
    # Returns (loss, [(grad, var), ...]) for the variables touched by the loss.
    grad_fn = tfe.implicit_value_and_gradients(model.compute_loss)
    total_loss = 0.0
    num_train_steps = 0
    while num_train_steps < NUM_TRAIN_STEPS:
        steps_before_pass = num_train_steps
        for center_words, target_words in tfe.Iterator(dataset):
            if num_train_steps >= NUM_TRAIN_STEPS:
                break
            loss_batch, grads = grad_fn(center_words, target_words)
            total_loss += loss_batch
            optimizer.apply_gradients(grads)
            if (num_train_steps + 1) % SKIP_STEP == 0:
                print('Average loss at step {}: {:5.1f}'.format(
                    num_train_steps, total_loss / SKIP_STEP
                ))
                total_loss = 0.0
            num_train_steps += 1
        # The while-condition guarantees at least one step was still due, so
        # a pass that made no progress means the dataset is empty.
        if num_train_steps == steps_before_pass:
            raise ValueError('Dataset yielded no batches; cannot train.')
开发者ID:wdxtub,项目名称:deep-learning-note,代码行数:24,代码来源:9_w2v_eager.py

示例2: reshape_by_blocks

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import TensorShape [as 别名]
def reshape_by_blocks(x, x_shape, memory_block_size):
  """Split the length axis of `x` into consecutive blocks.

  Args:
    x: a Tensor with shape [batch, heads, length, depth].
    x_shape: tf.TensorShape of x.
    memory_block_size: Integer which divides length.

  Returns:
    Tensor with shape
    [batch, heads, length // memory_block_size, memory_block_size, depth].
  """
  num_blocks = x_shape[2] // memory_block_size
  blocked_shape = [
      x_shape[0],
      x_shape[1],
      num_blocks,
      memory_block_size,
      x_shape[3],
  ]
  return tf.reshape(x, blocked_shape)
开发者ID:akzaidi,项目名称:fine-lm,代码行数:19,代码来源:common_attention.py

示例3: nn_distance

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import TensorShape [as 别名]
def nn_distance(xyz1, xyz2):
    """Compute nearest-neighbor distances between a pair of point clouds.

    This delegates directly to the custom op ``nn_distance_module.nn_distance``.

    Args:
        xyz1: (batch_size, #points_1, 3) the first point cloud.
        xyz2: (batch_size, #points_2, 3) the second point cloud.

    Returns:
        dist1: (batch_size, #point_1) distance from first to second.
        idx1:  (batch_size, #point_1) nearest neighbor from first to second.
        dist2: (batch_size, #point_2) distance from second to first.
        idx2:  (batch_size, #point_2) nearest neighbor from second to first.
    """
    return nn_distance_module.nn_distance(xyz1, xyz2)
#@tf.RegisterShape('NnDistance')
#def _nn_distance_shape(op):
	#shape1=op.inputs[0].get_shape().with_rank(3)
	#shape2=op.inputs[1].get_shape().with_rank(3)
	#return [tf.TensorShape([shape1.dims[0],shape1.dims[1]]),tf.TensorShape([shape1.dims[0],shape1.dims[1]]),
		#tf.TensorShape([shape2.dims[0],shape2.dims[1]]),tf.TensorShape([shape2.dims[0],shape2.dims[1]])] 
开发者ID:vinits5,项目名称:pointnet-registration-framework,代码行数:19,代码来源:tf_nndistance.py

示例4: _assert_same_size

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import TensorShape [as 别名]
def _assert_same_size(outputs, output_size):
    """Check that every tensor in `outputs` matches `output_size`.

    For each flattened output, the shape of ``output[0]`` (i.e. the shape
    without the leading dimension, presumably the batch dimension) is
    compared with the corresponding flattened entry of `output_size`.

    Args:
        outputs: A Tensor or a (nested) tuple of tensors.
        output_size: Can be an Integer, a TensorShape, or a (nested) tuple of
            Integers or TensorShape, with the same nesting as `outputs`.

    Raises:
        ValueError: If an output's per-example shape differs from its
            required size. Structural mismatches are reported by
            `nest.assert_same_structure` before any shape is compared.
    """
    nest.assert_same_structure(outputs, output_size)
    flat_output_size = nest.flatten(output_size)
    flat_output = nest.flatten(outputs)

    for output, size in zip(flat_output, flat_output_size):
        # Indexing with [0] drops the leading dimension so only the
        # per-example shape is compared with `size`.
        actual_shape = output[0].shape
        expected_shape = tf.TensorShape(size)
        if actual_shape != expected_shape:
            raise ValueError(
                "The output size does not match the required output_size: "
                "got %s, expected %s" % (actual_shape, expected_shape))
开发者ID:qkaren,项目名称:Counterfactual-StoryRW,代码行数:18,代码来源:connectors.py

示例5: _compute_concat_output_shape

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import TensorShape [as 别名]
def _compute_concat_output_shape(input_shape, axis):
    """Infers the output shape of concat given the input shapes.

    The code is adapted from the ConcatLayer of lasagne
    (https://github.com/Lasagne/Lasagne/blob/master/lasagne/layers/merge.py)

    Along every axis other than `axis`, the output size is the first input
    size that is not `None`. Along `axis`, it is the sum of the input sizes,
    or `None` if any of them is unknown.

    Args:
        input_shape (list): A list of shapes, each of which is in turn a
            list or TensorShape.
        axis (int): Axis of the concat operation (negative values count from
            the end).

    Returns:
        list: Output shape of concat.

    Raises:
        ValueError: If `input_shape` is empty, or if the shapes do not all
            have the same rank. Previously, mixed ranks were silently
            truncated by `zip`, which produced a wrong shape.
    """
    input_shape = [tf.TensorShape(s).as_list() for s in input_shape]
    if not input_shape:
        raise ValueError("`input_shape` must contain at least one shape.")
    if len({len(s) for s in input_shape}) != 1:
        raise ValueError(
            "All input shapes must have the same rank, got %s" %
            (input_shape,))

    # The size of each axis of the output shape equals the first
    # input size of respective axis that is not `None`.
    output_shape = [next((s for s in sizes if s is not None), None)
                    for sizes in zip(*input_shape)]
    axis_sizes = [s[axis] for s in input_shape]
    concat_axis_size = (None if any(s is None for s in axis_sizes)
                        else sum(axis_sizes))
    output_shape[axis] = concat_axis_size
    return output_shape
开发者ID:qkaren,项目名称:Counterfactual-StoryRW,代码行数:26,代码来源:layers.py

示例6: _make_shapes_consistent

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import TensorShape [as 别名]
def _make_shapes_consistent(output, labels):
  """Try to make inputs have the same shape by adding dimensions of size 1."""
  shape1 = output.shape
  shape2 = labels.shape
  len1 = len(shape1)
  len2 = len(shape2)
  if len1 == len2:
    return (output, labels)
  if isinstance(shape1, tf.TensorShape):
    shape1 = tuple(shape1.as_list())
  if isinstance(shape2, tf.TensorShape):
    shape2 = tuple(shape2.as_list())
  if len1 > len2 and all(i == 1 for i in shape1[len2:]):
    for i in range(len1 - len2):
      labels = tf.expand_dims(labels, -1)
    return (output, labels)
  if len2 > len1 and all(i == 1 for i in shape2[len1:]):
    for i in range(len2 - len1):
      output = tf.expand_dims(output, -1)
    return (output, labels)
  raise ValueError("Incompatible shapes for outputs and labels: %s versus %s" %
                   (str(shape1), str(shape2))) 
开发者ID:deepchem,项目名称:deepchem,代码行数:24,代码来源:losses.py

示例7: __init__

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import TensorShape [as 别名]
def __init__(self,
                 dtype,
                 batch_shape=None,
                 value_shape=None,
                 group_ndims=0,
                 is_continuous=None,
                 **kwargs):
        """Initialize with an explicit dtype and explicit batch/value shapes.

        Args:
            dtype: Sample dtype. `None` selects `tf.float32`; any other value
                is normalized with `tf.as_dtype(dtype).base_dtype`.
            batch_shape: Anything accepted by `tf.TensorShape`; stored as
                `self.explicit_batch_shape`.
            value_shape: Anything accepted by `tf.TensorShape`; stored as
                `self.explicit_value_shape`.
            group_ndims: Forwarded unchanged to the base-class constructor.
            is_continuous: Whether samples are continuous. When `None`, it is
                taken from `dtype.is_floating`.
            **kwargs: Forwarded unchanged to the base-class constructor.
        """
        if dtype is None:
            dtype = tf.float32
        else:
            dtype = tf.as_dtype(dtype).base_dtype

        # `tf.TensorShape(None)` yields an unknown-rank shape.
        self.explicit_batch_shape = tf.TensorShape(batch_shape)
        self.explicit_value_shape = tf.TensorShape(value_shape)

        continuous = (dtype.is_floating if is_continuous is None
                      else is_continuous)

        # No parameter dtype; reparameterization and path derivatives are
        # explicitly disabled for this distribution.
        super(Empirical, self).__init__(
            dtype=dtype,
            param_dtype=None,
            is_continuous=continuous,
            is_reparameterized=False,
            use_path_derivative=False,
            group_ndims=group_ndims,
            **kwargs)
开发者ID:thu-ml,项目名称:zhusuan,代码行数:26,代码来源:special.py

示例8: _sample

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import TensorShape [as 别名]
def _sample(self, n_samples):
        """Draw `n_samples` samples expressed as per-category counts.

        Draws `n_samples * n_experiments` categorical indices per batch row,
        reshapes them to `[n_samples, n_experiments] + batch_shape`, one-hot
        encodes them over `n_categories`, and sums over the experiments axis.
        The result therefore has shape
        `[n_samples] + batch_shape + [n_categories]` and dtype `self.dtype`.

        Raises:
            ValueError: If `self.n_experiments` is None.
        """
        if self.n_experiments is None:
            raise ValueError('Cannot sample when `n_experiments` is None')

        # `tf.random.categorical` needs 2-D logits; collapse any extra
        # leading dimensions into a single row dimension.
        if self.logits.get_shape().ndims != 2:
            flat_logits = tf.reshape(self.logits, [-1, self.n_categories])
        else:
            flat_logits = self.logits

        # categorical() returns [rows, draws]; transpose so draws come first.
        flat_draws = tf.transpose(
            tf.random.categorical(flat_logits, n_samples * self.n_experiments))
        draws = tf.reshape(
            flat_draws,
            tf.concat([[n_samples, self.n_experiments], self.batch_shape], 0))

        # Restore whatever static shape information is known in Python.
        known_samples = n_samples if isinstance(n_samples, int) else None
        known_experiments = (self.n_experiments
                             if isinstance(self.n_experiments, int) else None)
        draws.set_shape(
            tf.TensorShape([known_samples, known_experiments]).concatenate(
                self.get_batch_shape()))

        # Count how often each category was drawn across experiments.
        one_hot_draws = tf.one_hot(draws, self.n_categories, dtype=self.dtype)
        return tf.reduce_sum(one_hot_draws, axis=1)
开发者ID:thu-ml,项目名称:zhusuan,代码行数:26,代码来源:multivariate.py

示例9: main

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import TensorShape [as 别名]
def main():
    """Build the skip-gram dataset from `gen` and train word2vec on it."""
    output_types = (tf.int32, tf.int32)
    output_shapes = (tf.TensorShape([BATCH_SIZE]),
                     tf.TensorShape([BATCH_SIZE, 1]))
    word2vec(tf.data.Dataset.from_generator(gen, output_types, output_shapes))
开发者ID:wdxtub,项目名称:deep-learning-note,代码行数:7,代码来源:10_w2v_graph.py

示例10: main

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import TensorShape [as 别名]
def main():
    """Build the skip-gram dataset, then build, train and visualize the model."""
    shapes = (tf.TensorShape([BATCH_SIZE]), tf.TensorShape([BATCH_SIZE, 1]))
    dataset = tf.data.Dataset.from_generator(gen, (tf.int32, tf.int32), shapes)

    model = SkipGramModel(dataset, VOCAB_SIZE, EMBED_SIZE, BATCH_SIZE,
                          NUM_SAMPLED, LEARNING_RATE)
    model.build_graph()
    model.train(NUM_TRAIN_STEPS)
    model.visualize(VISUAL_FLD, NUM_VISUALIZE)
开发者ID:wdxtub,项目名称:deep-learning-note,代码行数:10,代码来源:11_w2v_visual.py

示例11: test_net_slice_char_logits_with_correct_shape

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import TensorShape [as 别名]
def test_net_slice_char_logits_with_correct_shape(self):
    # NetSlice.create_logits() must return logits whose static shape is
    # [batch_size, seq_length, num_char_classes].
    batch_size, seq_length, num_char_classes = 2, 4, 3
    expected_shape = tf.TensorShape([batch_size, seq_length, num_char_classes])

    layer = create_layer(sequence_layers.NetSlice, batch_size, seq_length,
                         num_char_classes)
    actual_shape = layer.create_logits().get_shape()

    self.assertEqual(expected_shape, actual_shape)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:14,代码来源:sequence_layers_test.py

示例12: test_net_slice_with_autoregression_char_logits_with_correct_shape

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import TensorShape [as 别名]
def test_net_slice_with_autoregression_char_logits_with_correct_shape(self):
    # NetSliceWithAutoregression.create_logits() must return logits whose
    # static shape is [batch_size, seq_length, num_char_classes].
    batch_size, seq_length, num_char_classes = 2, 4, 3
    expected_shape = tf.TensorShape([batch_size, seq_length, num_char_classes])

    layer = create_layer(sequence_layers.NetSliceWithAutoregression,
                         batch_size, seq_length, num_char_classes)
    actual_shape = layer.create_logits().get_shape()

    self.assertEqual(expected_shape, actual_shape)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:14,代码来源:sequence_layers_test.py

示例13: test_attention_char_logits_with_correct_shape

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import TensorShape [as 别名]
def test_attention_char_logits_with_correct_shape(self):
    # Attention.create_logits() must return logits whose static shape is
    # [batch_size, seq_length, num_char_classes].
    batch_size, seq_length, num_char_classes = 2, 4, 3
    expected_shape = tf.TensorShape([batch_size, seq_length, num_char_classes])

    layer = create_layer(sequence_layers.Attention, batch_size, seq_length,
                         num_char_classes)
    actual_shape = layer.create_logits().get_shape()

    self.assertEqual(expected_shape, actual_shape)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:14,代码来源:sequence_layers_test.py

示例14: test_attention_with_autoregression_char_logits_with_correct_shape

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import TensorShape [as 别名]
def test_attention_with_autoregression_char_logits_with_correct_shape(self):
    # AttentionWithAutoregression.create_logits() must return logits whose
    # static shape is [batch_size, seq_length, num_char_classes].
    batch_size, seq_length, num_char_classes = 2, 4, 3
    expected_shape = tf.TensorShape([batch_size, seq_length, num_char_classes])

    layer = create_layer(sequence_layers.AttentionWithAutoregression,
                         batch_size, seq_length, num_char_classes)
    actual_shape = layer.create_logits().get_shape()

    self.assertEqual(expected_shape, actual_shape)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:14,代码来源:sequence_layers_test.py

示例15: testParsingReaderOpWhileLoop

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import TensorShape [as 别名]
def testParsingReaderOpWhileLoop(self):
    # Runs the gold_parse_reader op repeatedly inside a tf.while_loop and
    # checks that the loop stops once the reader's 'epoch' output reaches 2.
    feature_size = 3
    batch_size = 5

    def ParserEndpoints():
      # Builds a new gold_parse_reader op over 'training-corpus'. Its outputs
      # unpack as (features, epoch, gold_actions); see Body below.
      return gen_parser_ops.gold_parse_reader(self._task_context,
                                              feature_size,
                                              batch_size,
                                              corpus_name='training-corpus')

    with self.test_session() as sess:
      # The 'condition' and 'body' functions expect as many arguments as there
      # are loop variables. 'condition' depends on the 'epoch' loop variable
      # only, so we disregard the remaining unused function arguments. 'body'
      # returns a list of updated loop variables.
      def Condition(epoch, *unused_args):
        # Keep looping while fewer than 2 epochs have been read.
        return tf.less(epoch, 2)

      def Body(epoch, num_actions, *feature_args):
        # By adding one of the outputs of the reader op ('epoch') as a control
        # dependency to the reader op we force the repeated evaluation of the
        # reader op.
        # Note: the incoming 'epoch' loop variable is rebound here to the new
        # reader op's epoch output, which becomes the next loop value.
        with epoch.graph.control_dependencies([epoch]):
          features, epoch, gold_actions = ParserEndpoints()
        # Track one more than the largest gold action id seen so far.
        num_actions = tf.maximum(num_actions,
                                 tf.reduce_max(gold_actions, [0], False) + 1)
        # Forward one feature output per extra loop variable. With the two
        # loop_vars used below, feature_args is empty and nothing is added.
        feature_ids = []
        for i in range(len(feature_args)):
          feature_ids.append(features[i])
        return [epoch, num_actions] + feature_ids

      # Initial loop values: the epoch output (index -2) of a first reader op,
      # and an action count of 0.
      epoch = ParserEndpoints()[-2]
      num_actions = tf.constant(0)
      loop_vars = [epoch, num_actions]

      # Unknown-shape invariants let both loop variables change shape between
      # iterations; parallel_iterations=1 runs iterations one at a time.
      res = sess.run(
          tf.while_loop(Condition, Body, loop_vars,
                        shape_invariants=[tf.TensorShape(None)] * 2,
                        parallel_iterations=1))
      logging.info('Result: %s', res)
      # The loop exits exactly when the epoch counter reaches 2.
      self.assertEqual(res[0], 2)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:43,代码来源:reader_ops_test.py


注:本文中的tensorflow.TensorShape方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。