This page collects typical usage examples of the Python method tensorflow.contrib.data.map_and_batch. If you have been wondering what exactly data.map_and_batch does, how to use it, or what example code looks like, the curated examples below may help. You can also explore further usage examples of the containing module, tensorflow.contrib.data.
Below are 15 code examples of data.map_and_batch, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code examples.
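For orientation before the real-world examples, here is a minimal, self-contained sketch of the fused transform; the file name and feature spec are hypothetical, and it assumes a TF 1.x build that still ships tensorflow.contrib:

import tensorflow as tf
from tensorflow.contrib.data import map_and_batch

def parse(record):
    # Hypothetical single-feature spec, purely for illustration.
    features = {'x': tf.FixedLenFeature([], tf.float32)}
    return tf.parse_single_example(record, features)

ds = tf.data.TFRecordDataset('example.tfrecord')  # hypothetical input file
# map_and_batch fuses the map and batch steps into a single op, so parsing
# of the next batch can overlap with consumption of the current one.
ds = ds.apply(map_and_batch(parse, batch_size=32, num_parallel_batches=4))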
Example 1: _input_fn
# Required import: from tensorflow.contrib import data [as alias]
# Or: from tensorflow.contrib.data import map_and_batch [as alias]
def _input_fn():
    features = {
        'data': tf.FixedLenFeature([], tf.string),
        'labels': tf.FixedLenFeature([], tf.int64),
    }

    def parse(record):
        return tf.parse_single_example(record, features)

    ds = PipeModeDataset(config.channel, benchmark=True)
    if config.epochs > 1:
        ds = ds.repeat(config.epochs)
    if config.prefetch_size > 0:
        ds = ds.prefetch(config.prefetch_size)
    ds = ds.apply(map_and_batch(parse, batch_size=config.batch_size,
                                num_parallel_batches=config.parallel_transform_calls))
    return ds
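The _input_fn above is meant to be handed to an Estimator, which calls it to build the input pipeline; PipeModeDataset here comes from Amazon's sagemaker_tensorflow package. A hedged usage sketch, with the estimator constructed elsewhere:

# Assumes `estimator` is a tf.estimator.Estimator configured elsewhere.
estimator.train(input_fn=_input_fn, steps=1000)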
Example 2: _input_fn
# Required import: from tensorflow.contrib import data [as alias]
# Or: from tensorflow.contrib.data import map_and_batch [as alias]
def _input_fn():
    features = {
        'data': tf.FixedLenFeature([], tf.string),
        'labels': tf.FixedLenFeature([], tf.int64),
    }

    def parse(record):
        parsed = tf.parse_single_example(record, features)
        return ({
            'data': tf.decode_raw(parsed['data'], tf.float64)
        }, parsed['labels'])

    ds = PipeModeDataset(config.channel, benchmark=True)
    if config.epochs > 1:
        ds = ds.repeat(config.epochs)
    if config.prefetch_size > 0:
        ds = ds.prefetch(config.prefetch_size)
    ds = ds.apply(map_and_batch(parse, batch_size=config.batch_size,
                                num_parallel_batches=config.parallel_transform_calls))
    return ds
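Unlike Example 1, the parse function here also reinterprets the serialized byte string as a float64 tensor. Note that tf.decode_raw only reinterprets bytes, so the dtype must match whatever was written; if the record size is known, a reshape can restore a static shape. A hypothetical variant (the size 784 is made up for illustration):

def parse(record):
    parsed = tf.parse_single_example(record, features)
    data = tf.decode_raw(parsed['data'], tf.float64)
    data = tf.reshape(data, [784])  # hypothetical fixed record size
    return ({'data': data}, parsed['labels'])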
Example 3: input_fn_builder
# Required import: from tensorflow.contrib import data [as alias]
# Or: from tensorflow.contrib.data import map_and_batch [as alias]
def input_fn_builder(input_file, is_training, drop_remainder,
                     names_to_features):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""

    def _decode_record(record, name_to_features):
        """Decodes a record to a TensorFlow example."""
        example = tf.parse_single_example(record, name_to_features)

        # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
        # So cast all int64 to int32.
        for name in list(example.keys()):
            t = example[name]
            if t.dtype == tf.int64:
                t = tf.to_int32(t)
            example[name] = t

        return example

    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        # For training, we want a lot of parallel reading and shuffling.
        # For eval, we want no shuffling and parallel reading doesn't matter.
        d = tf.data.TFRecordDataset(input_file)
        if is_training:
            d = d.repeat()
            d = d.shuffle(buffer_size=100)

        d = d.apply(
            contrib_data.map_and_batch(
                lambda record: _decode_record(record, name_to_features),
                batch_size=batch_size,
                drop_remainder=drop_remainder))

        return d

    return input_fn
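The builder returns a closure rather than a dataset because TPUEstimator injects the per-core batch size through params at call time. A hedged sketch of wiring it up, with a hypothetical file path and a feature spec assumed to be defined elsewhere (drop_remainder=True because TPUs require statically shaped batches):

train_input_fn = input_fn_builder(
    input_file='train.tf_record',        # hypothetical path
    is_training=True,
    drop_remainder=True,
    names_to_features=name_to_features)  # feature spec defined elsewhere
estimator.train(input_fn=train_input_fn, max_steps=10000)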
Example 4: build_model
# Required import: from tensorflow.contrib import data [as alias]
# Or: from tensorflow.contrib.data import map_and_batch [as alias]
def build_model(self):
    """ Graph Input """
    # images
    Image_Data_Class = ImageData(self.img_size, self.c_dim, self.custom_dataset)
    inputs = tf.data.Dataset.from_tensor_slices(self.data)

    gpu_device = '/gpu:0'
    inputs = inputs.\
        apply(shuffle_and_repeat(self.dataset_num)).\
        apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size,
                            num_parallel_batches=16, drop_remainder=True)).\
        apply(prefetch_to_device(gpu_device, self.batch_size))

    inputs_iterator = inputs.make_one_shot_iterator()
    self.inputs = inputs_iterator.get_next()

    # noises
    self.z = tf.random_normal(shape=[self.batch_size, 1, 1, self.z_dim], name='random_z')

    """ Loss Function """
    # output of D for real images
    real_logits = self.discriminator(self.inputs)

    # output of D for fake images
    fake_images = self.generator(self.z)
    fake_logits = self.discriminator(fake_images, reuse=True)

    if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan':
        GP = self.gradient_penalty(real=self.inputs, fake=fake_images)
    else:
        GP = 0

    # get loss for discriminator
    self.d_loss = discriminator_loss(self.gan_type, real=real_logits, fake=fake_logits, moment=self.moment) + GP

    # get loss for generator
    self.g_loss = generator_loss(self.gan_type, fake=fake_logits, moment=self.moment)

    """ Training """
    # divide trainable variables into a group for D and a group for G
    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if 'discriminator' in var.name]
    g_vars = [var for var in t_vars if 'generator' in var.name]

    # optimizers
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        self.d_optim = tf.train.AdamOptimizer(self.d_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(self.d_loss, var_list=d_vars)
        self.g_optim = tf.train.AdamOptimizer(self.g_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(self.g_loss, var_list=g_vars)

    """ Testing """
    # for test
    self.fake_images = self.generator(self.z, is_training=False, reuse=True)

    """ Summary """
    self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
    self.g_sum = tf.summary.scalar("g_loss", self.g_loss)

##################################################################################
# Train
##################################################################################
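Where the fused contrib ops are unavailable, roughly the same input pipeline as in Example 4 can be expressed with chained tf.data methods. A behavior-approximating sketch, not the author's code, written as if inside the same build_model method and assuming a TF version that provides tf.data.experimental.AUTOTUNE:

inputs = tf.data.Dataset.from_tensor_slices(self.data)
inputs = inputs.shuffle(self.dataset_num).repeat()
inputs = inputs.map(Image_Data_Class.image_processing,
                    num_parallel_calls=tf.data.experimental.AUTOTUNE)
inputs = inputs.batch(self.batch_size, drop_remainder=True)
inputs = inputs.prefetch(1)  # prefetch_to_device additionally pins the buffer on the GPU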
Example 5: build_model
# Required import: from tensorflow.contrib import data [as alias]
# Or: from tensorflow.contrib.data import map_and_batch [as alias]
def build_model(self):
    """ Graph Input """
    # images
    if self.custom_dataset:
        Image_Data_Class = ImageData(self.img_size, self.c_dim)
        inputs = tf.data.Dataset.from_tensor_slices(self.data)

        gpu_device = '/gpu:0'
        inputs = inputs.\
            apply(shuffle_and_repeat(self.dataset_num)).\
            apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size,
                                num_parallel_batches=16, drop_remainder=True)).\
            apply(prefetch_to_device(gpu_device, self.batch_size))

        inputs_iterator = inputs.make_one_shot_iterator()
        self.inputs = inputs_iterator.get_next()
    else:
        self.inputs = tf.placeholder(tf.float32,
                                     [self.batch_size, self.img_size, self.img_size, self.c_dim],
                                     name='real_images')

    # noises
    self.z = tf.placeholder(tf.float32, [self.batch_size, 1, 1, self.z_dim], name='z')

    """ Loss Function """
    # output of D for real images
    real_logits = self.discriminator(self.inputs)

    # output of D for fake images
    fake_images = self.generator(self.z)
    fake_logits = self.discriminator(fake_images, reuse=True)

    if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan':
        GP = self.gradient_penalty(real=self.inputs, fake=fake_images)
    else:
        GP = 0

    # get loss for discriminator
    self.d_loss = discriminator_loss(self.gan_type, real=real_logits, fake=fake_logits) + GP

    # get loss for generator
    self.g_loss = generator_loss(self.gan_type, fake=fake_logits)

    """ Training """
    # divide trainable variables into a group for D and a group for G
    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if 'discriminator' in var.name]
    g_vars = [var for var in t_vars if 'generator' in var.name]

    # optimizers
    self.d_optim = tf.train.AdamOptimizer(self.d_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(self.d_loss, var_list=d_vars)
    self.g_optim = tf.train.AdamOptimizer(self.g_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(self.g_loss, var_list=g_vars)

    """ Testing """
    # for test
    self.fake_images = self.generator(self.z, is_training=False, reuse=True)

    """ Summary """
    self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
    self.g_sum = tf.summary.scalar("g_loss", self.g_loss)

##################################################################################
# Train
##################################################################################
Example 6: build_model
# Required import: from tensorflow.contrib import data [as alias]
# Or: from tensorflow.contrib.data import map_and_batch [as alias]
def build_model(self):
    """ Graph Input """
    # images
    if self.custom_dataset:
        Image_Data_Class = ImageData(self.img_size, self.c_dim)
        inputs = tf.data.Dataset.from_tensor_slices(self.data)

        gpu_device = '/gpu:0'
        inputs = inputs.\
            apply(shuffle_and_repeat(self.dataset_num)).\
            apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size,
                                num_parallel_batches=16, drop_remainder=True)).\
            apply(prefetch_to_device(gpu_device, self.batch_size))

        inputs_iterator = inputs.make_one_shot_iterator()
        self.inputs = inputs_iterator.get_next()
    else:
        self.inputs = tf.placeholder(tf.float32,
                                     [self.batch_size, self.img_size, self.img_size, self.c_dim],
                                     name='real_images')

    # noises
    self.z = tf.placeholder(tf.float32, [self.batch_size, 1, 1, self.z_dim], name='z')

    """ Loss Function """
    # output of D for real images
    real_logits = self.discriminator(self.inputs)

    # output of D for fake images
    fake_images = self.generator(self.z)
    fake_logits = self.discriminator(fake_images, reuse=True)

    if self.gan_type.__contains__('gp') or self.gan_type.__contains__('lp') or self.gan_type.__contains__('dragan'):
        GP = self.gradient_penalty(real=self.inputs, fake=fake_images)
    else:
        GP = 0

    # get loss for discriminator
    self.d_loss = discriminator_loss(self.Ra, self.gan_type, real=real_logits, fake=fake_logits) + GP

    # get loss for generator
    self.g_loss = generator_loss(self.Ra, self.gan_type, real=real_logits, fake=fake_logits)

    """ Training """
    # divide trainable variables into a group for D and a group for G
    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if 'discriminator' in var.name]
    g_vars = [var for var in t_vars if 'generator' in var.name]

    # optimizers
    self.d_optim = tf.train.AdamOptimizer(self.d_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(self.d_loss, var_list=d_vars)
    self.g_optim = tf.train.AdamOptimizer(self.g_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(self.g_loss, var_list=g_vars)

    """ Testing """
    # for test
    self.fake_images = self.generator(self.z, is_training=False, reuse=True)

    """ Summary """
    self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
    self.g_sum = tf.summary.scalar("g_loss", self.g_loss)

##################################################################################
# Train
##################################################################################
Example 7: file_based_input_fn_builder
# Required import: from tensorflow.contrib import data [as alias]
# Or: from tensorflow.contrib.data import map_and_batch [as alias]
def file_based_input_fn_builder(input_file, seq_length, is_training,
                                drop_remainder, task_name, use_tpu, bsz,
                                multiple=1):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""
    labeltype = tf.float32 if task_name == "sts-b" else tf.int64

    name_to_features = {
        "input_ids": tf.FixedLenFeature([seq_length * multiple], tf.int64),
        "input_mask": tf.FixedLenFeature([seq_length * multiple], tf.int64),
        "segment_ids": tf.FixedLenFeature([seq_length * multiple], tf.int64),
        "label_ids": tf.FixedLenFeature([], labeltype),
        "is_real_example": tf.FixedLenFeature([], tf.int64),
    }

    def _decode_record(record, name_to_features):
        """Decodes a record to a TensorFlow example."""
        example = tf.parse_single_example(record, name_to_features)

        # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
        # So cast all int64 to int32.
        for name in list(example.keys()):
            t = example[name]
            if t.dtype == tf.int64:
                t = tf.to_int32(t)
            example[name] = t

        return example

    def input_fn(params):
        """The actual input function."""
        if use_tpu:
            batch_size = params["batch_size"]
        else:
            batch_size = bsz

        # For training, we want a lot of parallel reading and shuffling.
        # For eval, we want no shuffling and parallel reading doesn't matter.
        d = tf.data.TFRecordDataset(input_file)
        if is_training:
            d = d.repeat()
            d = d.shuffle(buffer_size=100)

        d = d.apply(
            contrib_data.map_and_batch(
                lambda record: _decode_record(record, name_to_features),
                batch_size=batch_size,
                drop_remainder=drop_remainder))

        return d

    return input_fn
Example 8: input_fn_builder
# Required import: from tensorflow.contrib import data [as alias]
# Or: from tensorflow.contrib.data import map_and_batch [as alias]
def input_fn_builder(input_file, seq_length, is_training,
                     drop_remainder, use_tpu, bsz, is_v2):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""
    name_to_features = {
        "unique_ids": tf.FixedLenFeature([], tf.int64),
        "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
        "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
        "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
    }

    # p_mask is not required for SQuAD v1.1
    if is_v2:
        name_to_features["p_mask"] = tf.FixedLenFeature([seq_length], tf.int64)

    if is_training:
        name_to_features["start_positions"] = tf.FixedLenFeature([], tf.int64)
        name_to_features["end_positions"] = tf.FixedLenFeature([], tf.int64)
        name_to_features["is_impossible"] = tf.FixedLenFeature([], tf.int64)

    def _decode_record(record, name_to_features):
        """Decodes a record to a TensorFlow example."""
        example = tf.parse_single_example(record, name_to_features)

        # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
        # So cast all int64 to int32.
        for name in list(example.keys()):
            t = example[name]
            if t.dtype == tf.int64:
                t = tf.to_int32(t)
            example[name] = t

        return example

    def input_fn(params):
        """The actual input function."""
        if use_tpu:
            batch_size = params["batch_size"]
        else:
            batch_size = bsz

        # For training, we want a lot of parallel reading and shuffling.
        # For eval, we want no shuffling and parallel reading doesn't matter.
        d = tf.data.TFRecordDataset(input_file)
        if is_training:
            d = d.repeat()
            d = d.shuffle(buffer_size=100)

        d = d.apply(
            contrib_data.map_and_batch(
                lambda record: _decode_record(record, name_to_features),
                batch_size=batch_size,
                drop_remainder=drop_remainder))

        return d

    return input_fn
Example 9: file_based_input_fn_builder
# Required import: from tensorflow.contrib import data [as alias]
# Or: from tensorflow.contrib.data import map_and_batch [as alias]
def file_based_input_fn_builder(input_file, seq_length, is_training,
                                drop_remainder):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""
    name_to_features = {
        "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
        "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
        "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
        "label_ids": tf.FixedLenFeature([], tf.int64),
        "is_real_example": tf.FixedLenFeature([], tf.int64),
    }

    def _decode_record(record, name_to_features):
        """Decodes a record to a TensorFlow example."""
        example = tf.parse_single_example(record, name_to_features)

        # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
        # So cast all int64 to int32.
        for name in list(example.keys()):
            t = example[name]
            if t.dtype == tf.int64:
                t = tf.to_int32(t)
            example[name] = t

        return example

    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        # For training, we want a lot of parallel reading and shuffling.
        # For eval, we want no shuffling and parallel reading doesn't matter.
        d = tf.data.TFRecordDataset(input_file)
        if is_training:
            d = d.repeat()
            d = d.shuffle(buffer_size=100)

        d = d.apply(
            contrib_data.map_and_batch(
                lambda record: _decode_record(record, name_to_features),
                batch_size=batch_size,
                drop_remainder=drop_remainder))

        return d

    return input_fn
Example 10: file_based_input_fn_builder
# Required import: from tensorflow.contrib import data [as alias]
# Or: from tensorflow.contrib.data import map_and_batch [as alias]
def file_based_input_fn_builder(input_file, seq_length, is_training,
                                drop_remainder, num_labels):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""
    name_to_features = {
        "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
        "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
        "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
        "label_ids": tf.FixedLenFeature([num_labels], tf.float32),
        "is_real_example": tf.FixedLenFeature([], tf.int64),
    }

    def _decode_record(record, name_to_features):
        """Decodes a record to a TensorFlow example."""
        example = tf.parse_single_example(record, name_to_features)

        # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
        # So cast all int64 to int32.
        for name in list(example.keys()):
            t = example[name]
            if t.dtype == tf.int64:
                t = tf.to_int32(t)
            example[name] = t

        return example

    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        # For training, we want a lot of parallel reading and shuffling.
        # For eval, we want no shuffling and parallel reading doesn't matter.
        d = tf.data.TFRecordDataset(input_file)
        if is_training:
            d = d.repeat()
            d = d.shuffle(buffer_size=100)

        d = d.apply(
            contrib_data.map_and_batch(
                lambda record: _decode_record(record, name_to_features),
                batch_size=batch_size,
                drop_remainder=drop_remainder))

        return d

    return input_fn
Example 11: file_based_input_fn_builder
# Required import: from tensorflow.contrib import data [as alias]
# Or: from tensorflow.contrib.data import map_and_batch [as alias]
def file_based_input_fn_builder(input_file, seq_length, is_training,
                                drop_remainder):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""
    name_to_features = {
        "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
        "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
        "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
        "label_ids": tf.FixedLenFeature([], tf.int64),
        "probs": tf.FixedLenFeature([2], tf.float32)
    }

    def _decode_record(record, name_to_features):
        """Decodes a record to a TensorFlow example."""
        example = tf.parse_single_example(record, name_to_features)

        # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
        # So cast all int64 to int32.
        for name in list(example.keys()):
            t = example[name]
            if t.dtype == tf.int64:
                t = tf.to_int32(t)
            example[name] = t

        return example

    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        # For training, we want a lot of parallel reading and shuffling.
        # For eval, we want no shuffling and parallel reading doesn't matter.
        d = tf.data.TFRecordDataset(input_file)
        if is_training:
            d = d.repeat()
            d = d.shuffle(buffer_size=100)

        d = d.apply(
            contrib_data.map_and_batch(
                lambda record: _decode_record(record, name_to_features),
                batch_size=batch_size,
                drop_remainder=drop_remainder))

        return d

    return input_fn
Example 12: input_fn_builder
# Required import: from tensorflow.contrib import data [as alias]
# Or: from tensorflow.contrib.data import map_and_batch [as alias]
def input_fn_builder(input_file, seq_length, is_training, drop_remainder):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""
    name_to_features = {
        "unique_ids": tf.FixedLenFeature([], tf.int64),
        "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
        "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
        "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
    }

    if is_training:
        name_to_features["start_positions"] = tf.FixedLenFeature([], tf.int64)
        name_to_features["end_positions"] = tf.FixedLenFeature([], tf.int64)

    def _decode_record(record, name_to_features):
        """Decodes a record to a TensorFlow example."""
        example = tf.parse_single_example(record, name_to_features)

        # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
        # So cast all int64 to int32.
        for name in list(example.keys()):
            t = example[name]
            if t.dtype == tf.int64:
                t = tf.to_int32(t)
            example[name] = t

        return example

    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        # For training, we want a lot of parallel reading and shuffling.
        # For eval, we want no shuffling and parallel reading doesn't matter.
        d = tf.data.TFRecordDataset(input_file)
        if is_training:
            d = d.repeat()
            d = d.shuffle(buffer_size=100)

        d = d.apply(
            contrib_data.map_and_batch(
                lambda record: _decode_record(record, name_to_features),
                batch_size=batch_size,
                drop_remainder=drop_remainder))

        return d

    return input_fn
Example 13: file_based_input_fn_builder
# Required import: from tensorflow.contrib import data [as alias]
# Or: from tensorflow.contrib.data import map_and_batch [as alias]
def file_based_input_fn_builder(input_file, seq_length, is_training,
                                drop_remainder):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""
    name_to_features = {
        "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
        "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
        "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
        "label_ids": tf.FixedLenFeature([], tf.int64),
    }

    def _decode_record(record, name_to_features):
        """Decodes a record to a TensorFlow example."""
        example = tf.parse_single_example(record, name_to_features)

        # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
        # So cast all int64 to int32.
        for name in list(example.keys()):
            t = example[name]
            if t.dtype == tf.int64:
                t = tf.to_int32(t)
            example[name] = t

        return example

    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        # For training, we want a lot of parallel reading and shuffling.
        # For eval, we want no shuffling and parallel reading doesn't matter.
        d = tf.data.TFRecordDataset(input_file)
        if is_training:
            d = d.repeat()
            d = d.shuffle(buffer_size=100)

        d = d.apply(
            contrib_data.map_and_batch(
                lambda record: _decode_record(record, name_to_features),
                batch_size=batch_size,
                drop_remainder=drop_remainder))

        return d

    return input_fn
Example 14: input_fn_builder
# Required import: from tensorflow.contrib import data [as alias]
# Or: from tensorflow.contrib.data import map_and_batch [as alias]
def input_fn_builder(input_file, seq_length, is_training, drop_remainder):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""
    name_to_features = {
        "unique_ids": tf.FixedLenFeature([], tf.int64),
        "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
        "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
        "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
    }

    if is_training:
        name_to_features["label_ids"] = tf.FixedLenFeature([], tf.int64)

    def _decode_record(record, name_to_features):
        """Decodes a record to a TensorFlow example."""
        example = tf.parse_single_example(record, name_to_features)

        # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
        # So cast all int64 to int32.
        for name in list(example.keys()):
            t = example[name]
            if t.dtype == tf.int64:
                t = tf.to_int32(t)
            example[name] = t

        return example

    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        # For training, we want a lot of parallel reading and shuffling.
        # For eval, we want no shuffling and parallel reading doesn't matter.
        d = tf.data.TFRecordDataset(input_file)
        if is_training:
            d = d.repeat()
            d = d.shuffle(buffer_size=100)

        d = d.apply(
            contrib_data.map_and_batch(
                lambda record: _decode_record(record, name_to_features),
                batch_size=batch_size,
                drop_remainder=drop_remainder))

        return d

    return input_fn
Example 15: file_based_input_fn_builder
# Required import: from tensorflow.contrib import data [as alias]
# Or: from tensorflow.contrib.data import map_and_batch [as alias]
def file_based_input_fn_builder(input_file,
                                seq_length,
                                is_training,
                                drop_remainder,
                                task_name,
                                use_tpu,
                                bsz,
                                multiple=1):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""
    labeltype = tf.float32 if task_name == "sts-b" else tf.int64

    name_to_features = {
        "input_ids": tf.FixedLenFeature([seq_length * multiple], tf.int64),
        "input_mask": tf.FixedLenFeature([seq_length * multiple], tf.int64),
        "segment_ids": tf.FixedLenFeature([seq_length * multiple], tf.int64),
        "label_ids": tf.FixedLenFeature([], labeltype),
        "is_real_example": tf.FixedLenFeature([], tf.int64),
    }

    def _decode_record(record, name_to_features):
        """Decodes a record to a TensorFlow example."""
        example = tf.parse_single_example(record, name_to_features)

        # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
        # So cast all int64 to int32.
        for name in list(example.keys()):
            t = example[name]
            if t.dtype == tf.int64:
                t = tf.to_int32(t)
            example[name] = t

        return example

    def input_fn(params):
        """The actual input function."""
        if use_tpu:
            batch_size = params["batch_size"]
        else:
            batch_size = bsz

        # For training, we want a lot of parallel reading and shuffling.
        # For eval, we want no shuffling and parallel reading doesn't matter.
        d = tf.data.TFRecordDataset(input_file)
        if is_training:
            d = d.repeat()
            d = d.shuffle(buffer_size=100)

        d = d.apply(
            contrib_data.map_and_batch(
                lambda record: _decode_record(record, name_to_features),
                batch_size=batch_size,
                drop_remainder=drop_remainder))

        return d

    return input_fn
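A closing note on portability: in later TensorFlow 1.x releases this transform moved to tf.data.experimental.map_and_batch and was then deprecated, because a plain map followed by batch is fused automatically by the tf.data optimizer. The recurring pattern in the examples above can therefore be written without contrib; a hedged equivalent of the fused apply call:

d = tf.data.TFRecordDataset(input_file)
d = d.map(lambda record: _decode_record(record, name_to_features))
d = d.batch(batch_size, drop_remainder=drop_remainder)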