当前位置: 首页>>代码示例>>Python>>正文


Python data_flow_ops.dynamic_partition方法代码示例

本文整理汇总了Python中tensorflow.python.ops.data_flow_ops.dynamic_partition方法的典型用法代码示例。如果您正苦于以下问题:Python data_flow_ops.dynamic_partition方法的具体用法是什么?Python data_flow_ops.dynamic_partition怎么用?有哪些Python data_flow_ops.dynamic_partition的使用例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.python.ops.data_flow_ops的用法示例。


在下文中一共展示了data_flow_ops.dynamic_partition方法的7个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: scatter_update

# 需要导入模块: from tensorflow.python.ops import data_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.data_flow_ops import dynamic_partition [as 别名]
def scatter_update(cls, factor, indices, values, sharding_func, name=None):
    """Helper function for doing sharded scatter update.

    Args:
      factor: A `list` of shards to update (asserted). Each element is passed
        as the target of `state_ops.scatter_update`.
      indices: Indices of the rows to update. With several shards they are
        handed to `sharding_func` to be routed.
      values: The new values; partitioned with the same shard assignments as
        the ids.
      sharding_func: Callable such that `sharding_func(indices)` returns an
        `(assignments, new_ids)` pair. `assignments` must not be None; it is
        cast to int32 and used as the partition index of every id/value.
      name: Optional name for the returned op.

    Returns:
      With exactly one shard, the `.op` of the single scatter update.
      Otherwise a `control_flow_ops.group` over the per-shard updates.
    """
    assert isinstance(factor, list)
    num_shards = len(factor)
    if num_shards == 1:
      only_shard = factor[0]
      with ops.colocate_with(only_shard):
        # TODO(agarwal): assign instead of scatter update for full batch update.
        return state_ops.scatter_update(only_shard, indices, values,
                                        name=name).op

    assignments, new_ids = sharding_func(indices)
    assert assignments is not None
    assignments = math_ops.cast(assignments, dtypes.int32)
    sharded_ids = data_flow_ops.dynamic_partition(new_ids, assignments,
                                                  num_shards)
    sharded_values = data_flow_ops.dynamic_partition(values, assignments,
                                                     num_shards)
    # Builtin `range` works on Python 2 and 3 alike; `xrange` needs a shim
    # import on Python 3.
    updates = [
        state_ops.scatter_update(factor[i], sharded_ids[i], sharded_values[i])
        for i in range(num_shards)
    ]
    return control_flow_ops.group(*updates, name=name)
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:24,代码来源:factorization_ops.py

示例2: insert

# 需要导入模块: from tensorflow.python.ops import data_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.data_flow_ops import dynamic_partition [as 别名]
def insert(self, keys, values, name=None):
    """Insert `keys` -> `values`, routing every pair to the shard owning it.

    Args:
      keys: Keys to insert; validated with `self._check_keys` first.
      values: Values to insert; partitioned with the same shard indices as
        `keys`.
      name: Optional name forwarded to each shard's `insert` call. The
        grouping op created in the multi-shard case is left unnamed.

    Returns:
      The result of the only shard's `insert` when there is a single shard,
      otherwise a `control_flow_ops.group` over all per-shard inserts.
    """
    self._check_keys(keys)
    shard_count = self._num_shards
    if shard_count == 1:
      # Nothing to route: delegate straight to the only shard.
      return self._table_shards[0].insert(keys, values, name=name)

    shard_of_key = self._shard_indices(keys)
    # TODO(andreasst): support 'keys' that are not vectors
    keys_by_shard = data_flow_ops.dynamic_partition(keys, shard_of_key,
                                                    shard_count)
    values_by_shard = data_flow_ops.dynamic_partition(values, shard_of_key,
                                                      shard_count)

    shard_inserts = []
    for shard_id in range(shard_count):
      shard_inserts.append(
          self._table_shards[shard_id].insert(
              keys_by_shard[shard_id], values_by_shard[shard_id], name=name))

    return control_flow_ops.group(*shard_inserts)
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:20,代码来源:sharded_mutable_dense_hashtable.py

示例3: scatter_update

# 需要导入模块: from tensorflow.python.ops import data_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.data_flow_ops import dynamic_partition [as 别名]
def scatter_update(cls, factor, indices, values, sharding_func):
    """Helper function for doing sharded scatter update.

    Args:
      factor: A `list` of shards to update (asserted). Each element is passed
        as the target of `state_ops.scatter_update`.
      indices: Indices of the rows to update. With several shards they are
        handed to `sharding_func` to be routed.
      values: The new values; partitioned with the same shard assignments as
        the ids.
      sharding_func: Callable such that `sharding_func(indices)` returns an
        `(assignments, new_ids)` pair. `assignments` must not be None; it is
        cast to int32 and used as the partition index of every id/value.

    Returns:
      With exactly one shard, the `.op` of the single scatter update.
      Otherwise a `control_flow_ops.group` over the per-shard updates.
    """
    assert isinstance(factor, list)
    num_shards = len(factor)
    if num_shards == 1:
      only_shard = factor[0]
      with ops.colocate_with(only_shard):
        # TODO(agarwal): assign instead of scatter update for full batch update.
        return state_ops.scatter_update(only_shard, indices, values).op

    assignments, new_ids = sharding_func(indices)
    assert assignments is not None
    assignments = math_ops.cast(assignments, dtypes.int32)
    sharded_ids = data_flow_ops.dynamic_partition(new_ids, assignments,
                                                  num_shards)
    sharded_values = data_flow_ops.dynamic_partition(values, assignments,
                                                     num_shards)
    # Builtin `range` works on Python 2 and 3 alike; `xrange` needs a shim
    # import on Python 3.
    updates = [
        state_ops.scatter_update(factor[i], sharded_ids[i], sharded_values[i])
        for i in range(num_shards)
    ]
    return control_flow_ops.group(*updates)
开发者ID:abhisuri97,项目名称:auto-alt-text-lambda-api,代码行数:24,代码来源:factorization_ops.py

示例4: _DynamicPartitionGrads

# 需要导入模块: from tensorflow.python.ops import data_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.data_flow_ops import dynamic_partition [as 别名]
def _DynamicPartitionGrads(op, *grads):
  """Gradients for DynamicPartition.

  Args:
    op: The DynamicPartition op; `op.inputs[0]` is the data, `op.inputs[1]`
      the partition indices, and its "num_partitions" attr is read.
    *grads: Gradients with respect to each partition output.

  Returns:
    A two-element list: the gradient for the data input (the partition
    gradients stitched back and reshaped to the data's shape) and `None`
    for the indices input.
  """
  data = op.inputs[0]
  indices = op.inputs[1]
  partition_count = op.get_attr("num_partitions")

  # Number every element position of `indices`, keeping the same shape.
  indices_shape = array_ops.shape(indices)
  element_count = math_ops.reduce_prod(indices_shape)
  positions = array_ops.reshape(math_ops.range(element_count), indices_shape)

  # Partition the positions exactly as the forward op partitioned the data,
  # then use them to stitch the incoming gradients back together.
  positions_by_partition = data_flow_ops.dynamic_partition(
      positions, indices, partition_count)
  stitched = data_flow_ops.dynamic_stitch(positions_by_partition, grads)

  data_shape = array_ops.shape(data)
  data_grad = array_ops.reshape(stitched, data_shape)
  return [data_grad, None]
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:16,代码来源:data_flow_grad.py

示例5: _select_which_to_enqueue

# 需要导入模块: from tensorflow.python.ops import data_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.data_flow_ops import dynamic_partition [as 别名]
def _select_which_to_enqueue(tensor_list, keep_input):
  """Select which examples to enqueue based on vector `keep_input`.

  Args:
    tensor_list: Iterable of tensors, each partitioned by `keep_input`.
    keep_input: Selector that is cast to int32 and used as the partition
      index of every example (two partitions: 0 and 1).

  Returns:
    A list with, for each input tensor, the contents of partition 1.
  """
  partition_ids = math_ops.cast(keep_input, dtypes.int32)
  selected = []
  for tensor in tensor_list:
    partitions = data_flow_ops.dynamic_partition(
        tensor, partition_ids, num_partitions=2)
    # Only partition 1 is kept; partition 0 is discarded.
    selected.append(partitions[1])
  return selected
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:9,代码来源:input.py

示例6: lookup

# 需要导入模块: from tensorflow.python.ops import data_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.data_flow_ops import dynamic_partition [as 别名]
def lookup(self, keys, name=None):
    """Look up `keys`, fetching each one from the shard that owns it.

    Args:
      keys: Keys to look up. Their `dtype` must equal `self._key_dtype`.
      name: Optional name forwarded to each shard's `lookup` call.

    Returns:
      The only shard's `lookup` result when there is a single shard.
      Otherwise the per-shard results stitched back into the order of
      `keys`, with static shape `[num_keys] + self._value_shape` set on it.

    Raises:
      TypeError: If `keys.dtype` differs from `self._key_dtype`.
    """
    actual_dtype = keys.dtype
    expected_dtype = self._key_dtype
    if actual_dtype != expected_dtype:
      raise TypeError('Signature mismatch. Keys must be dtype %s, got %s.' %
                      (expected_dtype, actual_dtype))
    self._check_keys(keys)

    shard_count = self._num_shards
    if shard_count == 1:
      # Nothing to route: delegate straight to the only shard.
      return self._table_shards[0].lookup(keys, name=name)

    shard_of_key = self._shard_indices(keys)
    # TODO(andreasst): support 'keys' that are not vectors
    keys_by_shard = data_flow_ops.dynamic_partition(keys, shard_of_key,
                                                    shard_count)
    values_by_shard = []
    for shard_id in range(shard_count):
      values_by_shard.append(
          self._table_shards[shard_id].lookup(keys_by_shard[shard_id],
                                              name=name))

    # Partition the positions 0..num_keys-1 the same way as the keys so the
    # per-shard values can be stitched back into the original key order.
    num_keys = keys.get_shape().dims[0]
    positions = math_ops.range(num_keys)
    positions_by_shard = data_flow_ops.dynamic_partition(positions,
                                                         shard_of_key,
                                                         shard_count)
    stitched = data_flow_ops.dynamic_stitch(positions_by_shard,
                                            values_by_shard)
    static_shape = tensor_shape.TensorShape([num_keys]).concatenate(
        self._value_shape)
    stitched.set_shape(static_shape)
    return stitched
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:29,代码来源:sharded_mutable_dense_hashtable.py

示例7: _select_which_to_enqueue

# 需要导入模块: from tensorflow.python.ops import data_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.data_flow_ops import dynamic_partition [as 别名]
def _select_which_to_enqueue(tensor_list, keep_input):
  """Select which examples to enqueue based on vector `keep_input`.

  Args:
    tensor_list: Iterable of tensors, each partitioned by `keep_input`.
    keep_input: Selector that is converted with `math_ops.to_int32` and used
      as the partition index of every example (two partitions: 0 and 1).

  Returns:
    A list with, for each input tensor, the contents of partition 1.
  """
  partition_ids = math_ops.to_int32(keep_input)
  selected = []
  for tensor in tensor_list:
    partitions = data_flow_ops.dynamic_partition(
        tensor, partition_ids, num_partitions=2)
    # Only partition 1 is kept; partition 0 is discarded.
    selected.append(partitions[1])
  return selected
开发者ID:PacktPublishing,项目名称:Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda,代码行数:9,代码来源:input.py


注:本文中的tensorflow.python.ops.data_flow_ops.dynamic_partition方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。