This page collects typical usage examples of the flat_structure function from the Python module tensorflow.python.data.ops.dataset_ops. If you have been wondering what exactly flat_structure does, how to use it, or what real uses look like, the hand-picked code examples below may help.
The following shows 15 code examples of the flat_structure function, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code examples.
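For context: in the TF 1.x codebase these snippets come from, flat_structure is a small helper that turns a dataset's element structure into the output_shapes and output_types attrs that the raw gen_dataset_ops kernels expect. Below is a minimal sketch of that behavior, assuming the private _element_structure API of that era; it is an illustration, not the verbatim implementation.

def flat_structure(dataset):
  """Helper for setting `output_shapes`/`output_types` attrs of dataset ops.

  Sketch only: flattens the dataset's (possibly nested) element structure
  into the flat lists of shapes and types the dataset kernels expect.
  """
  # pylint: disable=protected-access
  structure = dataset._element_structure
  return {
      "output_shapes": structure._flat_shapes,
      "output_types": structure._flat_types,
  }

Every example below passes this dictionary to a dataset op via **dataset_ops.flat_structure(...).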
Example 1: __init__
def __init__(self, input_dataset):
  self._input_dataset = input_dataset
  temp_variant_tensor = gen_dataset_ops.prefetch_dataset(
      input_dataset._variant_tensor,
      buffer_size=1,
      **dataset_ops.flat_structure(self))
  variant_tensor = gen_dataset_ops.model_dataset(
      temp_variant_tensor, **dataset_ops.flat_structure(self))
  super(_TestDataset, self).__init__(input_dataset, variant_tensor)
Example 2: __init__
def __init__(self, per_device_dataset, incarnation_id):
  # pylint: disable=protected-access
  self._structure = per_device_dataset._structure

  self._init_func = per_device_dataset._init_func
  self._init_captured_args = self._init_func.captured_inputs

  self._next_func = per_device_dataset._next_func
  self._next_captured_args = per_device_dataset._next_captured_args
  # The captured arguments to the next_func are string_handle, incarnation_id.
  # We update the incarnation id to the new one.
  self._next_captured_args[
      per_device_dataset._incarnation_id_index] = incarnation_id

  self._finalize_func = per_device_dataset._finalize_func
  self._finalize_captured_args = per_device_dataset._finalize_captured_args

  variant_tensor = gen_dataset_ops.generator_dataset(
      self._init_captured_args,
      self._next_captured_args,
      self._finalize_captured_args,
      init_func=self._init_func,
      next_func=self._next_func,
      finalize_func=self._finalize_func,
      **dataset_ops.flat_structure(self))
  super(_ReincarnatedPerDeviceGenerator, self).__init__(variant_tensor)
Example 3: _as_variant_tensor
def _as_variant_tensor(self):
  input_t = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
  return gen_dataset_ops.map_dataset(
      input_t,
      self._map_func.captured_inputs,
      f=self._map_func,
      **dataset_ops.flat_structure(self))
Example 4: _as_variant_tensor
def _as_variant_tensor(self):
  # pylint: disable=protected-access
  return (
      gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
          self._selector_input._variant_tensor,
          [data_input._variant_tensor for data_input in self._data_inputs],
          **dataset_ops.flat_structure(self)))
Example 5: __init__
def __init__(self, datasets, num_experiments=10):
  """Chooses the fastest of some input datasets.

  Given input datasets, produces elements as quickly as the fastest of the
  inputs. Note that this dataset assumes that input datasets have the same
  elements in the same order, though this is not enforced besides checking
  that the input datasets have compatible output types, output shapes, and
  cardinality at runtime. The resulting dataset produces elements that are
  identical to the input elements, and in the same order.

  Note that the time to first iteration is longer when this dataset is used
  due to the overhead of dynamically picking the fastest dataset. Namely,
  for the first num_experiments iterations, this dataset will pull from all
  of its inputs simultaneously in order to determine which input is the
  fastest. For all subsequent iterations, that input will be used.

  Args:
    datasets: A list of `Dataset`s that all have the same elements in the
      same order.
    num_experiments: The number of experiments to run before deciding which
      dataset is fastest. In each "experiment" iteration, the dataset will
      pull from all of its inputs simultaneously, and update its knowledge
      of which input is the fastest.

  Returns:
    A `Dataset` that has the same elements as the inputs.
  """
  self._datasets = list(datasets)
  self._structure = self._datasets[0]._element_structure  # pylint: disable=protected-access
  variant_tensor = (
      gen_experimental_dataset_ops.experimental_choose_fastest_dataset(
          [dataset._variant_tensor for dataset in self._datasets],  # pylint: disable=protected-access
          num_experiments=num_experiments,
          **dataset_ops.flat_structure(self)))
  super(_ChooseFastestDataset, self).__init__(variant_tensor)
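_ChooseFastestDataset is a private class with no public entry point, so the following is purely an illustration of how the constructor above might be driven; direct instantiation and the surrounding names are assumptions, not a documented API.

import tensorflow as tf

# Two pipelines that yield identical elements in the same order, differing
# only in how they are executed.
slow = tf.data.Dataset.range(1000).map(lambda x: x + 1)
fast = tf.data.Dataset.range(1000).map(lambda x: x + 1, num_parallel_calls=8)
# Hypothetical direct use of the private class defined above.
chosen = _ChooseFastestDataset([slow, fast], num_experiments=10)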
Example 6: __init__
def __init__(self, input_dataset, num_workers):
  self._input_dataset = input_dataset

  def recalculate_output_shapes(output_shapes):
    """Recalculates the output_shapes after dividing it by num_workers."""
    if len(output_shapes) < 1:
      raise ValueError("Input shape should have at least one dimension.")
    if (tensor_shape.dimension_value(output_shapes[0]) and
        tensor_shape.dimension_value(output_shapes[0]) % num_workers != 0):
      raise errors.InvalidArgumentError(
          None, None,
          "First dim of input shape: %d is not divisible by num_workers: %d" %
          (output_shapes[0], num_workers))
    output_dims = [d for d in output_shapes.dims]
    output_dims[0] = output_dims[0] // num_workers
    return tensor_shape.TensorShape(output_dims)

  input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
  input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
  input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
  output_shapes = nest.map_structure(recalculate_output_shapes, input_shapes)

  self._structure = structure.convert_legacy_structure(
      input_types, output_shapes, input_classes)
  variant_tensor = ged_ops.experimental_rebatch_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      num_workers=num_workers,
      **dataset_ops.flat_structure(self))
  super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
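The shape arithmetic in recalculate_output_shapes is just integer division of the leading (batch) dimension. A small standalone illustration of the same computation, with made-up shape values:

from tensorflow.python.framework import tensor_shape

global_shape = tensor_shape.TensorShape([64, 224, 224, 3])
num_workers = 4
# Mirrors recalculate_output_shapes: divide the first dim by num_workers.
per_worker = tensor_shape.TensorShape(
    [global_shape.dims[0] // num_workers] + global_shape.dims[1:])
print(per_worker)  # (16, 224, 224, 3)
# A leading dim of 10 with num_workers=4 would raise InvalidArgumentError,
# since 10 % 4 != 0.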
Example 7: _as_variant_tensor
def _as_variant_tensor(self):
  return gen_dataset_ops.slide_dataset(
      self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
      window_size=self._window_size,
      window_shift=self._window_shift,
      window_stride=self._window_stride,
      **dataset_ops.flat_structure(self))
Example 8: materialize
def materialize(self, shared_name=None, container=None):
  """Materialize creates a MaterializedIndexedDataset.

  IndexedDatasets can be combined through operations such as TBD. Therefore,
  they are only materialized when absolutely required.

  Args:
    shared_name: a string for the shared name to use for the resource.
    container: a string for the container to store the resource.

  Returns:
    A MaterializedIndexedDataset.
  """
  if container is None:
    container = ""
  if shared_name is None:
    shared_name = ""
  materialized_resource = (
      ged_ops.experimental_materialized_index_dataset_handle(
          container=container,
          shared_name=shared_name,
          **dataset_ops.flat_structure(self)))

  with ops.colocate_with(materialized_resource):
    materializer = ged_ops.experimental_indexed_dataset_materialize(
        self._as_variant_tensor(), materialized_resource)
  return MaterializedIndexedDataset(materialized_resource, materializer,
                                    self.output_classes, self.output_types,
                                    self.output_shapes)
Example 9: __init__
def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
             drop_remainder, use_legacy_function=False):
  """See `Dataset.map()` for details."""
  self._input_dataset = input_dataset

  self._map_func = dataset_ops.StructuredFunctionWrapper(
      map_func,
      "tf.data.experimental.map_and_batch()",
      dataset=input_dataset,
      use_legacy_function=use_legacy_function)
  self._batch_size_t = ops.convert_to_tensor(
      batch_size, dtype=dtypes.int64, name="batch_size")
  self._num_parallel_calls_t = ops.convert_to_tensor(
      num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
  self._drop_remainder_t = ops.convert_to_tensor(
      drop_remainder, dtype=dtypes.bool, name="drop_remainder")

  constant_drop_remainder = tensor_util.constant_value(self._drop_remainder_t)
  if constant_drop_remainder:
    # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
    # or `False` (explicitly retaining the remainder).
    self._structure = self._map_func.output_structure._batch(  # pylint: disable=protected-access
        tensor_util.constant_value(self._batch_size_t))
  else:
    self._structure = self._map_func.output_structure._batch(None)  # pylint: disable=protected-access
  variant_tensor = ged_ops.experimental_map_and_batch_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      self._map_func.function.captured_inputs,
      f=self._map_func.function,
      batch_size=self._batch_size_t,
      num_parallel_calls=self._num_parallel_calls_t,
      drop_remainder=self._drop_remainder_t,
      preserve_cardinality=True,
      **dataset_ops.flat_structure(self))
  super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor)
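The public wrapper that constructs this class is tf.data.experimental.map_and_batch; a minimal usage sketch of that API:

import tensorflow as tf

dataset = tf.data.Dataset.range(100)
# Fuses the map and the batch into the single experimental_map_and_batch op
# built above.
dataset = dataset.apply(
    tf.data.experimental.map_and_batch(
        map_func=lambda x: x * 2,
        batch_size=10,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
        drop_remainder=True))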
Example 10: _as_variant_tensor
def _as_variant_tensor(self):
  return gen_dataset_ops.set_stats_aggregator_dataset(
      self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
      self._stats_aggregator._resource,  # pylint: disable=protected-access
      self._tag,
      self._prefix,
      **dataset_ops.flat_structure(self))
Example 11: _as_variant_tensor
def _as_variant_tensor(self):
  return ged_ops.experimental_sliding_window_dataset(
      self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
      window_size=self._window_size,
      window_shift=self._window_shift,
      window_stride=self._window_stride,
      **dataset_ops.flat_structure(structure=self._output_structure))
Example 12: __init__
def __init__(self, input_dataset, features, num_parallel_calls):
  self._input_dataset = input_dataset
  if not input_dataset._element_structure.is_compatible_with(  # pylint: disable=protected-access
      structure.TensorStructure(dtypes.string, [None])):
    raise TypeError("Input dataset should be a dataset of vectors of strings")
  self._num_parallel_calls = num_parallel_calls
  # pylint: disable=protected-access
  self._features = parsing_ops._prepend_none_dimension(features)
  # sparse_keys and dense_keys come back sorted here.
  (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
   dense_shapes) = parsing_ops._features_to_raw_params(
       self._features, [
           parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
           parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature
       ])
  # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature.
  (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes,
   dense_shape_as_shape) = parsing_ops._process_raw_parameters(
       None, dense_defaults, sparse_keys, sparse_types, dense_keys,
       dense_types, dense_shapes)
  # pylint: enable=protected-access
  self._sparse_keys = sparse_keys
  self._sparse_types = sparse_types
  self._dense_keys = dense_keys
  self._dense_defaults = dense_defaults_vec
  self._dense_shapes = dense_shapes
  self._dense_types = dense_types

  input_dataset_shape = dataset_ops.get_legacy_output_shapes(
      self._input_dataset)
  dense_output_shapes = [input_dataset_shape.concatenate(shape)
                         for shape in dense_shape_as_shape]
  sparse_output_shapes = [input_dataset_shape.concatenate([None])
                          for _ in range(len(sparse_keys))]

  output_shapes = dict(
      zip(self._dense_keys + self._sparse_keys,
          dense_output_shapes + sparse_output_shapes))
  output_types = dict(
      zip(self._dense_keys + self._sparse_keys,
          self._dense_types + self._sparse_types))
  output_classes = dict(
      zip(self._dense_keys + self._sparse_keys,
          [ops.Tensor for _ in range(len(self._dense_defaults))] +
          [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
          ]))
  self._structure = structure.convert_legacy_structure(
      output_types, output_shapes, output_classes)

  variant_tensor = (
      gen_experimental_dataset_ops.experimental_parse_example_dataset(
          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
          self._num_parallel_calls,
          self._dense_defaults,
          self._sparse_keys,
          self._dense_keys,
          self._sparse_types,
          self._dense_shapes,
          **dataset_ops.flat_structure(self)))
  super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)
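The public wrapper here is tf.data.experimental.parse_example_dataset. A minimal usage sketch; the feature spec and file path below are made up for illustration:

import tensorflow as tf

features = {
    "label": tf.io.FixedLenFeature([], tf.int64),
    "tokens": tf.io.VarLenFeature(tf.string),
}
# The input must be a dataset of *vectors* of serialized tf.Examples,
# hence the batch() before the parse.
dataset = tf.data.TFRecordDataset("examples.tfrecord")  # hypothetical path
dataset = dataset.batch(32)
dataset = dataset.apply(
    tf.data.experimental.parse_example_dataset(features, num_parallel_calls=4))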
Example 13: __init__
def __init__(self, input_dataset, sleep_microseconds):
  self._input_dataset = input_dataset
  self._sleep_microseconds = sleep_microseconds
  variant_tensor = gen_experimental_dataset_ops.experimental_sleep_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      self._sleep_microseconds,
      **dataset_ops.flat_structure(self))
  super(_SleepDataset, self).__init__(input_dataset, variant_tensor)
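At the time, _SleepDataset was wrapped by an internal, unexported sleep() transformation in the tf.data experimental ops; the module path and signature below are assumptions based on the TF 1.x source tree, not public API.

import tensorflow as tf
from tensorflow.python.data.experimental.ops import sleep as sleep_ops

dataset = tf.data.Dataset.range(10)
# Delays each element by 100 microseconds; handy for simulating a slow
# input pipeline in tests (internal API; path is an assumption).
dataset = dataset.apply(sleep_ops.sleep(sleep_microseconds=100))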
Example 14: _as_variant_tensor
def _as_variant_tensor(self):
  input_t = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
  return gen_dataset_ops.scan_dataset(
      input_t,
      nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)),
      self._scan_func.captured_inputs,
      f=self._scan_func,
      **dataset_ops.flat_structure(self))
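The public wrapper for this transformation is tf.data.experimental.scan, which threads a state through the dataset; a minimal running-sum sketch:

import tensorflow as tf

dataset = tf.data.Dataset.range(10)
# scan_func maps (old_state, element) -> (new_state, output_element).
dataset = dataset.apply(
    tf.data.experimental.scan(
        initial_state=tf.constant(0, dtype=tf.int64),
        scan_func=lambda state, x: (state + x, state + x)))
# Yields the running sums 0, 1, 3, 6, 10, ...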
Example 15: __init__
def __init__(self, input_dataset, thread_pool):
  self._input_dataset = input_dataset
  self._thread_pool = thread_pool
  variant_tensor = ged_ops.experimental_thread_pool_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      self._thread_pool._resource,  # pylint: disable=protected-access
      **dataset_ops.flat_structure(self))
  super(_ThreadPoolDataset, self).__init__(input_dataset, variant_tensor)
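_ThreadPoolDataset was driven through the internal threadpool module rather than a public API; a usage sketch assuming that module's override_threadpool and PrivateThreadPool helpers from the TF 1.x source tree:

import tensorflow as tf
from tensorflow.python.data.experimental.ops import threadpool

dataset = tf.data.Dataset.range(10).map(lambda x: x + 1)
# Runs the pipeline on a dedicated 2-thread pool instead of the default one
# (internal API; names are assumptions based on the TF 1.x source tree).
dataset = threadpool.override_threadpool(
    dataset, threadpool.PrivateThreadPool(2, display_name="private_pool"))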