This article collects typical usage examples of the Python function tensorflow.contrib.data.python.ops.interleave_ops.parallel_interleave. If you have been wondering what exactly parallel_interleave does, how to use it, and what it looks like in real code, the curated samples below should help.
Fifteen code examples of parallel_interleave are shown below, sorted by popularity by default.
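Before the examples: parallel_interleave is a tf.data transformation from the TF 1.x contrib API. It maps a user function over an input dataset to produce nested datasets and pulls elements from several of them concurrently, which is the usual way to read many files in parallel. The sketch below is a minimal illustration rather than one of the collected examples; the file names are placeholders, and it assumes a TF 1.x installation where tf.contrib.data.parallel_interleave is still exported.

import tensorflow as tf  # assumes TF 1.x with the contrib namespace available

# Placeholder file names; substitute real TFRecord paths.
filenames = ["/tmp/train-00000.tfrecord", "/tmp/train-00001.tfrecord"]

dataset = tf.data.Dataset.from_tensor_slices(filenames)
# Open up to 4 files at once and interleave their records.
dataset = dataset.apply(
    tf.contrib.data.parallel_interleave(
        lambda filename: tf.data.TFRecordDataset(filename),
        cycle_length=4,    # how many nested datasets are read concurrently
        block_length=16,   # consecutive records pulled from each one per turn
        sloppy=True))      # trade deterministic order for throughput

iterator = dataset.make_one_shot_iterator()
next_record = iterator.get_next()  # one serialized record (scalar string) per call

The examples that follow show this same transformation used inside TensorFlow's own readers, tests, and benchmark pipelines.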
Example 1: _make_parallel_scan_dataset

def _make_parallel_scan_dataset(self, ds, num_parallel_scans,
                                normalized_probability, normalized_columns):
  """Builds a parallel dataset from a given range.

  Args:
    ds: A `_BigtableSampleKeyPairsDataset` returning ranges of keys to use.
    num_parallel_scans: The number of concurrent parallel scans to use.
    normalized_probability: A number between 0 and 1 for the keep probability.
    normalized_columns: The column families and column qualifiers to retrieve.

  Returns:
    A @{tf.data.Dataset} representing the result of the parallel scan.
  """
  if num_parallel_scans is None:
    num_parallel_scans = 50

  ds = ds.shuffle(buffer_size=10000)  # TODO(saeta): Make configurable.

  def _interleave_fn(start, end):
    return _BigtableScanDataset(
        self,
        prefix="",
        start=start,
        end=end,
        normalized=normalized_columns,
        probability=normalized_probability)

  # Note prefetch_input_elements must be set in order to avoid rpc timeouts.
  ds = ds.apply(
      interleave_ops.parallel_interleave(
          _interleave_fn,
          cycle_length=num_parallel_scans,
          sloppy=True,
          prefetch_input_elements=1))
  return ds
Example 2: testShutdownRace

def testShutdownRace(self):
  dataset = dataset_ops.Dataset.range(20)
  map_fn = lambda x: dataset_ops.Dataset.range(20 * x, 20 * (x + 1))
  dataset = dataset.apply(
      interleave_ops.parallel_interleave(
          map_fn,
          cycle_length=3,
          sloppy=False,
          buffer_output_elements=1,
          prefetch_input_elements=0))
  dataset = dataset.batch(32)
  iterator = dataset.make_initializable_iterator()
  next_element = iterator.get_next()

  results = []
  with self.cached_session() as sess:
    for _ in range(2):
      elements = []
      sess.run(iterator.initializer)
      try:
        while True:
          elements.extend(sess.run(next_element))
      except errors.OutOfRangeError:
        pass
      results.append(elements)

  self.assertAllEqual(results[0], results[1])
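The test above depends on sloppy=False, which keeps the interleave order deterministic, so two complete passes over the pipeline must produce identical element sequences. Below is a rough sketch of that property outside the test harness, assuming TF 1.x graph mode; it is illustrative, not part of the original test.

import tensorflow as tf  # TF 1.x, graph mode assumed

dataset = tf.data.Dataset.range(20)
dataset = dataset.apply(
    tf.contrib.data.parallel_interleave(
        lambda x: tf.data.Dataset.range(20 * x, 20 * (x + 1)),
        cycle_length=3,
        sloppy=False))  # deterministic order; sloppy=True trades order for speed

iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
  runs = []
  for _ in range(2):
    sess.run(iterator.initializer)
    elements = []
    try:
      while True:
        elements.append(sess.run(next_element))
    except tf.errors.OutOfRangeError:
      pass
    runs.append(elements)
  assert runs[0] == runs[1]  # holds because sloppy=False keeps order fixed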
Example 3: testParallelInterleave

def testParallelInterleave(self):
  dataset = dataset_ops.Dataset.from_tensor_slices(
      self._createTFRecordFiles())
  dataset = dataset.apply(interleave_ops.parallel_interleave(
      readers.TFRecordDataset,
      cycle_length=4,
      block_length=self._num_records))
  dataset = input_ops.auto_shard_dataset(
      dataset, self._num_shards, self._shard_index)

  # Since block_length == num records in each file, the output will still
  # contain records in order of files.
  self._verifySimpleShardingOutput(dataset, self._record)
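As the comment in the example notes, block_length controls how many consecutive elements are taken from each nested dataset before the interleave moves to the next one, so setting it to the number of records per file keeps each file's records contiguous. The following is a small hedged sketch of how cycle_length and block_length interact, using integer ranges in place of files (assumes TF 1.x):

import tensorflow as tf  # TF 1.x assumed

# Three "files", represented as the ranges [0, 4), [10, 14) and [20, 24).
dataset = tf.data.Dataset.range(3).apply(
    tf.contrib.data.parallel_interleave(
        lambda i: tf.data.Dataset.range(i * 10, i * 10 + 4),
        cycle_length=2,    # read two sources at a time
        block_length=4))   # 4 == records per source, so each block drains a source

# With block_length=4 the expected order is 0,1,2,3, 10,11,12,13, 20,21,22,23;
# with block_length=1 it would be 0,10,1,11,2,12,3,13, 20,21,22,23.
next_element = dataset.make_one_shot_iterator().get_next()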
Example 4: testErrorsInInputFn

def testErrorsInInputFn(self):

  def map_py_fn(x):
    if x == 5:
      raise ValueError()
    return x

  def map_fn(x):
    return script_ops.py_func(map_py_fn, [x], x.dtype)

  def interleave_fn(x):
    dataset = dataset_ops.Dataset.from_tensors(x)
    dataset = dataset.repeat(x)
    return dataset

  self.dataset = (
      dataset_ops.Dataset.from_tensor_slices(self.input_values).map(map_fn)
      .repeat(self.repeat_count).apply(
          interleave_ops.parallel_interleave(interleave_fn, self.cycle_length,
                                             self.block_length, self.sloppy,
                                             self.buffer_output_elements,
                                             self.prefetch_input_elements)))

  self.iterator = self.dataset.make_initializable_iterator()
  self.init_op = self.iterator.initializer
  self.next_element = self.iterator.get_next()

  with self.test_session() as sess:
    sess.run(
        self.init_op,
        feed_dict={
            self.input_values: [4, 5, 6],
            self.cycle_length: 2,
            self.block_length: 1,
            self.sloppy: False,
            self.buffer_output_elements: 1,
            self.prefetch_input_elements: 0,
        })
    for i, expected_element in enumerate(
        self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)):
      if expected_element == 5:
        with self.assertRaises(errors.InvalidArgumentError):
          sess.run(self.next_element)
      else:
        actual_element = sess.run(self.next_element)
        self.assertEqual(expected_element, actual_element,
                         "At index %s: %s expected, got: %s" %
                         (i, expected_element, actual_element))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(self.next_element)
Example 5: setUp

def setUp(self):
  self.input_values = array_ops.placeholder(dtypes.int64, shape=[None])
  self.cycle_length = array_ops.placeholder(dtypes.int64, shape=[])
  self.block_length = array_ops.placeholder(dtypes.int64, shape=[])
  self.sloppy = array_ops.placeholder(dtypes.bool, shape=[])
  self.buffer_output_elements = array_ops.placeholder(dtypes.int64, shape=[])
  self.prefetch_input_elements = array_ops.placeholder(dtypes.int64, shape=[])

  self.error = None
  self.repeat_count = 2

  # Set up threading events used to sequence when items are produced that
  # are subsequently interleaved. These events allow us to deterministically
  # simulate slowdowns and force sloppiness.
  self.read_coordination_events = {}
  self.write_coordination_events = {}
  # input values [4, 5, 6] are the common case for the tests; set defaults
  for i in range(4, 7):
    self.read_coordination_events[i] = threading.Semaphore(0)
    self.write_coordination_events[i] = threading.Event()

  def map_py_fn(x):
    self.write_coordination_events[x].wait()
    self.write_coordination_events[x].clear()
    self.read_coordination_events[x].release()
    if self.error:
      err = self.error
      self.error = None
      raise err  # pylint: disable=raising-bad-type
    return x * x

  def map_fn(x):
    return script_ops.py_func(map_py_fn, [x], x.dtype)

  def interleave_fn(x):
    dataset = dataset_ops.Dataset.from_tensors(x)
    dataset = dataset.repeat(x)
    return dataset.map(map_fn)

  self.dataset = (
      dataset_ops.Dataset.from_tensor_slices(self.input_values)
      .repeat(self.repeat_count).apply(
          interleave_ops.parallel_interleave(interleave_fn, self.cycle_length,
                                             self.block_length, self.sloppy,
                                             self.buffer_output_elements,
                                             self.prefetch_input_elements)))
  self.iterator = self.dataset.make_initializable_iterator()
  self.init_op = self.iterator.initializer
  self.next_element = self.iterator.get_next()
Example 6: _testTooManyReaders

def _testTooManyReaders(self, sloppy=False):

  def interleave_fn(x):
    dataset = dataset_ops.Dataset.from_tensors(x)
    dataset = dataset.repeat(math_ops.cast(x, dtype=dtypes.int64))
    return dataset

  dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6])
  dataset = dataset.repeat(self.repeat_count)
  dataset = dataset.apply(
      interleave_ops.parallel_interleave(
          interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy))
  iterator = dataset.make_one_shot_iterator()

  with self.test_session() as sess:
    output_values = []
    for _ in range(30):
      output_values.append(sess.run(iterator.get_next()))

  expected_values = self._interleave(
      [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2)
  self.assertItemsEqual(output_values, expected_values)
Example 7: testSparse

def testSparse(self):
  def _map_fn(i):
    return sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])

  def _interleave_fn(x):
    return dataset_ops.Dataset.from_tensor_slices(
        sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))

  dataset = dataset_ops.Dataset.range(10).map(_map_fn)
  iterator = dataset.apply(
      interleave_ops.parallel_interleave(
          _interleave_fn, cycle_length=1)).make_initializable_iterator()
  init_op = iterator.initializer
  get_next = iterator.get_next()

  with self.test_session() as sess:
    sess.run(init_op)
    for i in range(10):
      for j in range(2):
        expected = [i, 0] if j % 2 == 0 else [0, -i]
        self.assertAllEqual(expected, sess.run(get_next))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
Example 8: make_batched_features_dataset

# ......... (part of the code is omitted here) .........
  ```
  features: {
    "age": FixedLenFeature([], dtype=tf.int64, default_value=-1),
    "gender": FixedLenFeature([], dtype=tf.string),
    "kws": VarLenFeature(dtype=tf.string),
  }
  ```

  And the expected output is:

  ```python
  {
    "age": [[0], [-1]],
    "gender": [["f"], ["f"]],
    "kws": SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],
      values=["code", "art", "sports"],
      dense_shape=[2, 2]),
  }
  ```

  Args:
    file_pattern: List of files or patterns of file paths containing
      `Example` records. See `tf.gfile.Glob` for pattern rules.
    batch_size: An int representing the number of records to combine
      in a single batch.
    features: A `dict` mapping feature keys to `FixedLenFeature` or
      `VarLenFeature` values. See `tf.parse_example`.
    reader: A function or class that can be
      called with a `filenames` tensor and (optional) `reader_args` and returns
      a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`.
    reader_args: Additional arguments to pass to the reader class.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If None, cycles through the dataset forever. Defaults to `None`.
    shuffle: A boolean, indicates whether the input should be shuffled. Defaults
      to `True`.
    shuffle_buffer_size: Buffer size of the ShuffleDataset. A large capacity
      ensures better shuffling but would increase memory usage and startup time.
    shuffle_seed: Randomization seed to use for shuffling.
    prefetch_buffer_size: Number of feature batches to prefetch in order to
      improve performance. Recommended value is the number of batches consumed
      per training step (default is 1).
    reader_num_threads: Number of threads used to read `Example` records. If >1,
      the results will be interleaved.
    parser_num_threads: Number of threads to use for parsing `Example` tensors
      into a dictionary of `Feature` tensors.
    sloppy_ordering: If `True`, reading performance will be improved at
      the cost of non-deterministic ordering. If `False`, the order of elements
      produced is deterministic prior to shuffling (elements are still
      randomized if `shuffle=True`. Note that if the seed is set, then order
      of elements after shuffling is deterministic). Defaults to `False`.
    drop_final_batch: If `True`, and the batch size does not evenly divide the
      input dataset size, the final smaller batch will be dropped. Defaults to
      `False`.

  Returns:
    A dataset of `dict` elements. Each `dict` maps feature keys to
    `Tensor` or `SparseTensor` objects.
  """
  # Create dataset of all matching filenames
  filenames = _get_file_names(file_pattern, False)
  dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
  if shuffle:
    dataset = dataset.shuffle(len(filenames), shuffle_seed)

  # Read `Example` records from files as tensor objects.
  if reader_args is None:
    reader_args = []

  # Read files sequentially (if reader_num_threads=1) or in parallel
  dataset = dataset.apply(
      interleave_ops.parallel_interleave(
          lambda filename: reader(filename, *reader_args),
          cycle_length=reader_num_threads,
          sloppy=sloppy_ordering))

  # Extract values if the `Example` tensors are stored as key-value tuples.
  if dataset.output_types == (dtypes.string, dtypes.string):
    dataset = dataset.map(lambda _, v: v)

  # Apply dataset repeat and shuffle transformations.
  dataset = _maybe_shuffle_and_repeat(
      dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)

  if drop_final_batch:
    dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size))
  else:
    dataset = dataset.batch(batch_size)

  # Parse `Example` tensors to a dictionary of `Feature` tensors.
  dataset = dataset.map(
      lambda x: parsing_ops.parse_example(x, features),
      num_parallel_calls=parser_num_threads)

  # TODO(rachelim): Add an optional label_name argument for extracting the label
  # from the features dictionary, to comply with the type expected by the
  # input_fn to a `tf.Estimator.train` or `tf.Estimator.evaluate` function.
  dataset = dataset.prefetch(prefetch_buffer_size)
  return dataset
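For reference, here is a hedged sketch of calling this reader through its public TF 1.x alias, tf.contrib.data.make_batched_features_dataset. The file pattern and feature spec are illustrative only; setting reader_num_threads above 1 exercises the parallel_interleave path shown above.

import tensorflow as tf  # TF 1.x assumed

features = {
    "age": tf.FixedLenFeature([], dtype=tf.int64, default_value=-1),
    "gender": tf.FixedLenFeature([], dtype=tf.string),
    "kws": tf.VarLenFeature(dtype=tf.string),
}

# Hypothetical file pattern pointing at TFRecord files of tf.Example protos.
dataset = tf.contrib.data.make_batched_features_dataset(
    file_pattern="/tmp/examples-*.tfrecord",
    batch_size=2,
    features=features,
    reader_num_threads=4,     # cycle_length of the interleave above
    sloppy_ordering=True)     # allow non-deterministic file order

iterator = dataset.make_one_shot_iterator()
feature_batch = iterator.get_next()  # dict of Tensor / SparseTensor values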
Example 9: make_csv_dataset

# ......... (part of the code is omitted here) .........
      randomized if `shuffle=True`. Note that if the seed is set, then order
      of elements after shuffling is deterministic). Defaults to `False`.
    num_rows_for_inference: Number of rows of a file to use for type inference
      if record_defaults is not provided. If None, reads all the rows of all
      the files. Defaults to 100.

  Returns:
    A dataset, where each element is a (features, labels) tuple that corresponds
    to a batch of `batch_size` CSV rows. The features dictionary maps feature
    column names to `Tensor`s containing the corresponding column data, and
    labels is a `Tensor` containing the column data for the label column
    specified by `label_name`.

  Raises:
    ValueError: If any of the arguments is malformed.
  """
  # Create dataset of all matching filenames
  filenames = _get_file_names(file_pattern, False)
  dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
  if shuffle:
    dataset = dataset.shuffle(len(filenames), shuffle_seed)

  # Clean arguments; figure out column names and defaults
  if column_names is None:
    if not header:
      raise ValueError("Cannot infer column names without a header line.")
    # If column names are not provided, infer from the header lines
    column_names = _infer_column_names(filenames, field_delim, use_quote_delim)
  if len(column_names) != len(set(column_names)):
    raise ValueError("Cannot have duplicate column names.")

  if select_columns is not None:
    select_columns = _get_sorted_col_indices(select_columns, column_names)

  if column_defaults is not None:
    column_defaults = [
        constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
        for x in column_defaults
    ]
  else:
    # If column defaults are not provided, infer from records at graph
    # construction time
    column_defaults = _infer_column_defaults(
        filenames, len(column_names), field_delim, use_quote_delim, na_value,
        header, num_rows_for_inference, select_columns)

  if select_columns is not None and len(column_defaults) != len(select_columns):
    raise ValueError(
        "If specified, column_defaults and select_columns must have same "
        "length."
    )
  if select_columns is not None and len(column_names) > len(select_columns):
    # Pick the relevant subset of column names
    column_names = [column_names[i] for i in select_columns]

  if label_name is not None and label_name not in column_names:
    raise ValueError("`label_name` provided must be one of the columns.")

  def filename_to_dataset(filename):
    return CsvDataset(
        filename,
        record_defaults=column_defaults,
        field_delim=field_delim,
        use_quote_delim=use_quote_delim,
        na_value=na_value,
        select_cols=select_columns,
        header=header)

  def map_fn(*columns):
    """Organizes columns into a features dictionary.

    Args:
      *columns: list of `Tensor`s corresponding to one csv record.
    Returns:
      An OrderedDict of feature names to values for that particular record. If
      label_name is provided, extracts the label feature to be returned as the
      second element of the tuple.
    """
    features = collections.OrderedDict(zip(column_names, columns))
    if label_name is not None:
      label = features.pop(label_name)
      return features, label
    return features

  # Read files sequentially (if num_parallel_reads=1) or in parallel
  dataset = dataset.apply(
      interleave_ops.parallel_interleave(
          filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy))

  dataset = _maybe_shuffle_and_repeat(
      dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)

  # Apply batch before map for perf, because map has high overhead relative
  # to the size of the computation in each map
  dataset = dataset.batch(batch_size=batch_size)
  dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_parser_calls)
  dataset = dataset.prefetch(prefetch_buffer_size)
  return dataset
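A hedged usage sketch of the corresponding public API, tf.contrib.data.make_csv_dataset (TF 1.x). The file pattern and column names are illustrative; num_parallel_reads becomes the cycle_length of the parallel_interleave call above.

import tensorflow as tf  # TF 1.x assumed

# Hypothetical CSV files with a header line such as "age,gender,income".
dataset = tf.contrib.data.make_csv_dataset(
    file_pattern="/tmp/people-*.csv",
    batch_size=32,
    label_name="income",      # returned as the second element of each tuple
    num_parallel_reads=4,     # files read concurrently via parallel_interleave
    sloppy=True)              # allow non-deterministic file order

iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()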
Example 10: make_csv_dataset

# ......... (part of the code is omitted here) .........
    specified by `label_name`.

  Raises:
    ValueError: If any of the arguments is malformed.
  """
  # Create dataset of all matching filenames
  filenames = _get_file_names(file_pattern, False)
  dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
  if shuffle:
    dataset = dataset.shuffle(len(filenames), shuffle_seed)

  # Clean arguments; figure out column names and defaults
  if comment is not None and len(comment) != 1:
    raise ValueError("`comment` arg must be a single-character string or None")

  if column_names is None:
    if not header:
      raise ValueError("Cannot infer column names without a header line.")
    # If column names are not provided, infer from the header lines
    column_names = _infer_column_names(filenames, field_delim, use_quote_delim)
  if len(column_names) != len(set(column_names)):
    raise ValueError("Cannot have duplicate column names.")

  if column_defaults is not None:
    column_defaults = [
        constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
        for x in column_defaults
    ]
  else:
    # If column defaults are not provided, infer from records at graph
    # construction time
    column_defaults = _infer_column_defaults(
        filenames, len(column_names), field_delim, use_quote_delim, na_value,
        header, comment, default_float_type, num_rows_for_inference)

  if label_name is not None and label_name not in column_names:
    raise ValueError("`label_name` provided must be one of the columns.")

  # Define map and filter functions
  def filter_fn(line):
    return math_ops.not_equal(string_ops.substr(line, 0, 1), comment)

  def filename_to_dataset(filename):
    ds = core_readers.TextLineDataset(filename)
    if header:
      ds = ds.skip(1)
    if comment is not None:
      ds = ds.filter(filter_fn)
    return ds

  def decode_csv(line):
    """Decodes CSV line into features.

    Args:
      line: String tensor corresponding to one csv record.
    Returns:
      A dictionary of feature names to values for that particular record. If
      label_name is provided, extracts the label feature to be returned as the
      second element of the tuple.
    """
    columns = parsing_ops.decode_csv(
        line,
        column_defaults,
        field_delim=field_delim,
        use_quote_delim=use_quote_delim,
        na_value=na_value,
    )
    features = dict(zip(column_names, columns))
    if label_name is not None:
      label = features.pop(label_name)
      return features, label
    return features

  # Read files sequentially or in parallel
  dataset = dataset.apply(
      interleave_ops.parallel_interleave(
          filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy))

  if num_epochs != 1 and shuffle:
    # Use shuffle_and_repeat for perf
    dataset = dataset.apply(
        shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs,
                                       shuffle_seed))
  elif shuffle:
    dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed)
  elif num_epochs != 1:
    dataset = dataset.repeat(num_epochs)

  # Use map_and_batch for perf
  # TODO(b/76425672): use num_parallel_calls for better performance tuning when
  # that is added
  dataset = dataset.apply(
      batching.map_and_batch(
          map_func=decode_csv,
          batch_size=batch_size,
          num_parallel_batches=int(
              ceil(num_parallel_parser_calls / batch_size))))

  dataset = dataset.prefetch(prefetch_buffer_size)
  return dataset
Example 11: make_batched_features_dataset

# ......... (part of the code is omitted here) .........
    "age": [[0], [-1]],
    "gender": [["f"], ["f"]],
    "kws": SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],
      values=["code", "art", "sports"],
      dense_shape=[2, 2]),
  }
  ```

  Args:
    file_pattern: List of files or patterns of file paths containing
      `Example` records. See `tf.gfile.Glob` for pattern rules.
    batch_size: An int representing the number of records to combine
      in a single batch.
    features: A `dict` mapping feature keys to `FixedLenFeature` or
      `VarLenFeature` values. See `tf.parse_example`.
    reader: A function or class that can be
      called with a `filenames` tensor and (optional) `reader_args` and returns
      a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`.
    label_key: (Optional) A string corresponding to the key labels are stored in
      `tf.Examples`. If provided, it must be one of the `features` keys,
      otherwise results in `ValueError`.
    reader_args: Additional arguments to pass to the reader class.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If None, cycles through the dataset forever. Defaults to `None`.
    shuffle: A boolean, indicates whether the input should be shuffled. Defaults
      to `True`.
    shuffle_buffer_size: Buffer size of the ShuffleDataset. A large capacity
      ensures better shuffling but would increase memory usage and startup time.
    shuffle_seed: Randomization seed to use for shuffling.
    prefetch_buffer_size: Number of feature batches to prefetch in order to
      improve performance. Recommended value is the number of batches consumed
      per training step. Defaults to auto-tune.
    reader_num_threads: Number of threads used to read `Example` records. If >1,
      the results will be interleaved.
    parser_num_threads: Number of threads to use for parsing `Example` tensors
      into a dictionary of `Feature` tensors.
    sloppy_ordering: If `True`, reading performance will be improved at
      the cost of non-deterministic ordering. If `False`, the order of elements
      produced is deterministic prior to shuffling (elements are still
      randomized if `shuffle=True`. Note that if the seed is set, then order
      of elements after shuffling is deterministic). Defaults to `False`.
    drop_final_batch: If `True`, and the batch size does not evenly divide the
      input dataset size, the final smaller batch will be dropped. Defaults to
      `False`.

  Returns:
    A dataset of `dict` elements, (or a tuple of `dict` elements and label).
    Each `dict` maps feature keys to `Tensor` or `SparseTensor` objects.

  Raises:
    ValueError: If `label_key` is not one of the `features` keys.
  """
  # Create dataset of all matching filenames
  filenames = _get_file_names(file_pattern, False)
  dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
  if shuffle:
    dataset = dataset.shuffle(len(filenames), shuffle_seed)

  # Read `Example` records from files as tensor objects.
  if reader_args is None:
    reader_args = []

  # Read files sequentially (if reader_num_threads=1) or in parallel
  dataset = dataset.apply(
      interleave_ops.parallel_interleave(
          lambda filename: reader(filename, *reader_args),
          cycle_length=reader_num_threads,
          sloppy=sloppy_ordering))

  # Extract values if the `Example` tensors are stored as key-value tuples.
  if dataset.output_types == (dtypes.string, dtypes.string):
    dataset = dataset_ops.MapDataset(
        dataset, lambda _, v: v, use_inter_op_parallelism=False)

  # Apply dataset repeat and shuffle transformations.
  dataset = _maybe_shuffle_and_repeat(
      dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)

  # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
  # improve the shape inference, because it makes the batch dimension static.
  # It is safe to do this because in that case we are repeating the input
  # indefinitely, and all batches will be full-sized.
  dataset = dataset.batch(
      batch_size, drop_remainder=drop_final_batch or num_epochs is None)

  # Parse `Example` tensors to a dictionary of `Feature` tensors.
  dataset = dataset.apply(
      parsing_ops.parse_example_dataset(
          features, num_parallel_calls=parser_num_threads))

  if label_key:
    if label_key not in features:
      raise ValueError(
          "The `label_key` provided (%r) must be one of the `features` keys." %
          label_key)
    dataset = dataset.map(lambda x: (x, x.pop(label_key)))

  dataset = dataset.prefetch(prefetch_buffer_size)
  return dataset
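Compared with Example 8, this version accepts a label_key and returns (features, label) tuples. A brief hedged sketch follows, assuming a later TF 1.x release whose tf.contrib.data.make_batched_features_dataset export includes label_key; the file pattern and feature names are illustrative.

import tensorflow as tf  # assumes a TF 1.x release with label_key support

features = {
    "age": tf.FixedLenFeature([], dtype=tf.int64, default_value=-1),
    "label": tf.FixedLenFeature([], dtype=tf.int64),
}

# Hypothetical file pattern; label_key must be one of the feature keys.
dataset = tf.contrib.data.make_batched_features_dataset(
    file_pattern="/tmp/examples-*.tfrecord",
    batch_size=32,
    features=features,
    label_key="label",
    reader_num_threads=4)

iterator = dataset.make_one_shot_iterator()
feature_dict, labels = iterator.get_next()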
Example 12: _build_dataset

def _build_dataset():
  return dataset_ops.Dataset.range(10).map(_map_fn).apply(
      interleave_ops.parallel_interleave(_interleave_fn, 1))
Example 13: _build_ds

def _build_ds(self, cycle_length, block_length, sloppy=False):
  return (dataset_ops.Dataset.from_tensor_slices(
      self.input_values).repeat(self.num_repeats).apply(
          interleave_ops.parallel_interleave(
              lambda x: dataset_ops.Dataset.range(10 * x, 11 * x),
              cycle_length, block_length, sloppy)))
Example 14: StreamingFilesDataset

# ......... (part of the code is omitted here) .........
      generated. By default, it will repeat infinitely.
    filename_shuffle_buffer_size: An optional integer whose value controls the
      shuffling of the file names. If you would like to read from the files in
      the same order, set to 0 or False.
    num_parallel_reads: An optional integer controlling the number of files to
      read from concurrently. (Set to 1 for no parallelism.)
    batch_transfer_size: An optional integer controlling the batching used to
      amortize the remote function invocation overhead. Set to a very large
      number to increase throughput. Set to a very small number to reduce memory
      consumption. Set to False to skip batching.
    sloppy: (Optional.) If `False`, read input data while maintaining a
      deterministic order. (This may have significant performance impacts.)
      sloppy defaults to: True.

  Returns:
    A `tf.data.Dataset` with an infinite stream of elements generated by a
    parallel interleaving of the set of files matched (or generated) by `files`,
    whose type is the output of the dataset specified by `filetype`.

  Raises:
    ValueError: if any argument is not of the expected type.
  """
  if filetype is None:
    filetype = 'tfrecord'

  if isinstance(filetype, str):
    if filetype not in _FILETYPE_MAP:
      raise ValueError('Unexpected filetype: %s' % filetype)
    reader_fn = _FILETYPE_MAP[filetype]
  elif callable(filetype):
    reader_fn = filetype
  else:
    raise ValueError('filetype should be a string or a callable')

  file_reader_job = file_reader_job or 'coordinator'
  worker_job = worker_job or 'worker'

  if filename_shuffle_buffer_size is None:
    filename_shuffle_buffer_size = 4096

  num_parallel_reads = num_parallel_reads or 8

  if batch_transfer_size is None:
    batch_transfer_size = 256

  if sloppy is None:
    sloppy = True

  with ops.device('/job:%s' % file_reader_job):
    if isinstance(files, str):
      source_dataset = dataset_ops.Dataset.list_files(files)
    elif isinstance(files, dataset_ops.Dataset):
      source_dataset = files
    else:
      raise ValueError('files was not a string or a dataset: %s' % files)

    if filename_shuffle_buffer_size:
      source_dataset = source_dataset.shuffle(
          buffer_size=filename_shuffle_buffer_size)

    # NOTE: We perform the `repeat` on the source dataset, because the output
    # dataset does not currently have enough information to recreate an iterator
    # over the source dataset when it reaches the end.
    source_dataset = source_dataset.repeat(num_epochs)

    source_dataset = source_dataset.apply(
        interleave_ops.parallel_interleave(
            reader_fn, cycle_length=num_parallel_reads, sloppy=sloppy))

    if batch_transfer_size:
      source_dataset = source_dataset.batch(batch_transfer_size)

    source_dataset = source_dataset.prefetch(1)

    source_iterator = source_dataset.make_one_shot_iterator()
    source_handle = source_iterator.string_handle()

  @function.Defun(dtypes.string)
  def LoadingFunc(h):
    remote_iterator = iterator_ops.Iterator.from_string_handle(
        h, source_dataset.output_types, source_dataset.output_shapes)
    return remote_iterator.get_next()

  def MapFn(unused_input):
    return functional_ops.remote_call(
        args=[source_handle],
        Tout=[dtypes.string],
        f=LoadingFunc,
        target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job)[0]

  with ops.device('/job:%s' % worker_job):
    output_dataset = dataset_ops.Dataset.range(2).repeat().map(
        MapFn, num_parallel_calls=4 if sloppy else None)
    output_dataset = output_dataset.prefetch(1)

    if batch_transfer_size:
      # Undo the batching used during the transfer.
      output_dataset = output_dataset.apply(batching.unbatch()).prefetch(1)

  return output_dataset
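A hedged sketch of constructing this dataset for a TPU input pipeline. The import path is an assumption about where the contrib TPU datasets module lives in TF 1.x (verify against your installation), and the GCS pattern and job names are illustrative; only parameters visible in the code above are used.

# Assumed import path for the contrib TPU datasets module.
from tensorflow.contrib.tpu.python.tpu import datasets as tpu_datasets

# Hypothetical GCS file pattern; the coordinator job reads the files and the
# worker job consumes them via remote_call, as shown in the function above.
streaming_ds = tpu_datasets.StreamingFilesDataset(
    files='gs://my-bucket/train-*.tfrecord',
    filetype='tfrecord',
    worker_job='worker',
    num_parallel_reads=8,
    sloppy=True)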
Example 15: minibatch

def minibatch(self, dataset, subset, use_datasets, cache_data,
              shift_ratio=-1):
  if shift_ratio < 0:
    shift_ratio = self.shift_ratio
  with tf.name_scope('batch_processing'):
    # Build final results per split.
    images = [[] for _ in range(self.num_splits)]
    labels = [[] for _ in range(self.num_splits)]
    if use_datasets:
      glob_pattern = dataset.tf_record_pattern(subset)
      file_names = gfile.Glob(glob_pattern)
      if not file_names:
        raise ValueError('Found no files in --data_dir matching: {}'
                         .format(glob_pattern))
      ds = tf.data.TFRecordDataset.list_files(file_names)
      ds = ds.apply(
          interleave_ops.parallel_interleave(
              tf.data.TFRecordDataset, cycle_length=10))
      if cache_data:
        ds = ds.take(1).cache().repeat()
      counter = tf.data.Dataset.range(self.batch_size)
      counter = counter.repeat()
      ds = tf.data.Dataset.zip((ds, counter))
      ds = ds.prefetch(buffer_size=self.batch_size)
      ds = ds.shuffle(buffer_size=10000)
      ds = ds.repeat()
      ds = ds.apply(
          batching.map_and_batch(
              map_func=self.parse_and_preprocess,
              batch_size=self.batch_size_per_split,
              num_parallel_batches=self.num_splits))
      ds = ds.prefetch(buffer_size=self.num_splits)
      ds_iterator = ds.make_one_shot_iterator()
      for d in xrange(self.num_splits):
        labels[d], images[d] = ds_iterator.get_next()
    else:
      record_input = data_flow_ops.RecordInput(
          file_pattern=dataset.tf_record_pattern(subset),
          seed=301,
          parallelism=64,
          buffer_size=10000,
          batch_size=self.batch_size,
          shift_ratio=shift_ratio,
          name='record_input')
      records = record_input.get_yield_op()
      records = tf.split(records, self.batch_size, 0)
      records = [tf.reshape(record, []) for record in records]
      for idx in xrange(self.batch_size):
        value = records[idx]
        (label, image) = self.parse_and_preprocess(value, idx)
        split_index = idx % self.num_splits
        labels[split_index].append(label)
        images[split_index].append(image)

    for split_index in xrange(self.num_splits):
      if not use_datasets:
        images[split_index] = tf.parallel_stack(images[split_index])
        labels[split_index] = tf.concat(labels[split_index], 0)
      images[split_index] = tf.cast(images[split_index], self.dtype)
      depth = 3
      images[split_index] = tf.reshape(
          images[split_index],
          shape=[self.batch_size_per_split, self.height, self.width, depth])
      labels[split_index] = tf.reshape(labels[split_index],
                                       [self.batch_size_per_split])
    return images, labels