本文整理汇总了Python中preprocessing.get_input_tensors方法的典型用法代码示例。如果您正苦于以下问题：Python preprocessing.get_input_tensors方法的具体用法？Python preprocessing.get_input_tensors怎么用？Python preprocessing.get_input_tensors使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块preprocessing的用法示例。
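在下文中一共展示了preprocessing.get_input_tensors方法的11个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。注意：这些示例取自不同的项目和版本，get_input_tensors的参数签名并不统一——例如示例2、6、7、11将params对象作为第一个参数，示例1、3、4、5、10以批大小作为第一个参数，示例8、9则通过关键字参数list_tf_records传入文件列表——使用前请以您所用版本的preprocessing模块为准。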
示例1: extract_data
# Required import: import preprocessing [as alias]
# Or: from preprocessing import get_input_tensors [as alias]
def extract_data(self, tf_record, filter_amount=1):
    """Read every example of one TFRecord file back into Python values.

    Builds a batch-size-1 input pipeline with ``num_repeats=1`` and both
    record and example shuffling disabled, then evaluates it in a fresh
    ``tf.Session`` until ``tf.errors.OutOfRangeError`` is raised.

    Args:
        tf_record: Path of the TFRecord file to read.
        filter_amount: Forwarded unchanged to
            ``preprocessing.get_input_tensors``.

    Returns:
        A list with one ``(position, pi, value)`` tuple per successful
        ``sess.run`` call; ``pi`` and ``value`` come from the
        ``'pi_tensor'`` and ``'value_tensor'`` entries of the label dict.
    """
    features, labels = preprocessing.get_input_tensors(
        1,
        [tf_record],
        num_repeats=1,
        shuffle_records=False,
        shuffle_examples=False,
        filter_amount=filter_amount,
    )
    rows = []
    with tf.Session() as session:
        while True:
            try:
                position, label_dict = session.run([features, labels])
            except tf.errors.OutOfRangeError:
                # Used as the end-of-input signal: stop reading.
                break
            rows.append(
                (position, label_dict['pi_tensor'], label_dict['value_tensor']))
    return rows
示例2: extract_data
# Required import: import preprocessing [as alias]
# Or: from preprocessing import get_input_tensors [as alias]
def extract_data(self, tf_record, filter_amount=1):
    """Read every example of one TFRecord file back into Python values.

    This variant passes a ``model_params.DummyMiniGoParams()`` instance as
    the first argument of ``preprocessing.get_input_tensors``. The pipeline
    uses batch size 1, ``num_repeats=1`` and no shuffling. It is evaluated in
    a fresh ``tf.Session`` until ``tf.errors.OutOfRangeError`` is raised.

    Args:
        tf_record: Path of the TFRecord file to read.
        filter_amount: Forwarded unchanged to
            ``preprocessing.get_input_tensors``.

    Returns:
        A list of ``(position, pi, value)`` tuples, one per successful
        ``sess.run`` call, taken from the ``'pi_tensor'`` and
        ``'value_tensor'`` entries of the label dict.
    """
    dummy_params = model_params.DummyMiniGoParams()
    features, labels = preprocessing.get_input_tensors(
        dummy_params,
        1,
        [tf_record],
        num_repeats=1,
        shuffle_records=False,
        shuffle_examples=False,
        filter_amount=filter_amount,
    )
    rows = []
    with tf.Session() as session:
        while True:
            try:
                position, label_dict = session.run([features, labels])
            except tf.errors.OutOfRangeError:
                # Used as the end-of-input signal: stop reading.
                break
            rows.append(
                (position, label_dict['pi_tensor'], label_dict['value_tensor']))
    return rows
示例3: validate
# Required import: import preprocessing [as alias]
# Or: from preprocessing import get_input_tensors [as alias]
def validate(*tf_records):
    """Validate a model's performance on a set of holdout data.

    Args:
        *tf_records: TFRecord file paths that hold the holdout examples.

    Configuration is read from ``FLAGS``: ``use_tpu``, ``train_batch_size``,
    ``input_layout``, ``examples_to_validate``, ``num_tpu_cores`` and
    ``validate_name``. The number of evaluation steps is
    ``examples_to_validate // train_batch_size``. On TPU it is additionally
    floor-divided by ``num_tpu_cores``.
    """
    steps = FLAGS.examples_to_validate // FLAGS.train_batch_size
    if FLAGS.use_tpu:
        steps //= FLAGS.num_tpu_cores

        def _input_fn(params):
            # On TPU, the batch size and layout come from the ``params``
            # dict handed to the input function, not from FLAGS.
            return preprocessing.get_tpu_input_tensors(
                params['train_batch_size'],
                params['input_layout'],
                tf_records,
                filter_amount=1.0,
            )
    else:
        def _input_fn():
            return preprocessing.get_input_tensors(
                FLAGS.train_batch_size,
                FLAGS.input_layout,
                tf_records,
                filter_amount=1.0,
                shuffle_examples=False,
            )

    estimator = dual_net.get_estimator()
    with utils.logged_timer("Validating"):
        estimator.evaluate(_input_fn, steps=steps, name=FLAGS.validate_name)
示例4: train
# Required import: import preprocessing [as alias]
# Or: from preprocessing import get_input_tensors [as alias]
def train(working_dir, tf_records, generation_num, **hparams):
    """Train the estimator stored in ``working_dir`` up to a given generation.

    Args:
        working_dir: Model directory passed to ``get_estimator``.
        tf_records: TFRecord filenames used as training input.
        generation_num: Generation to train; must be greater than 0.
        **hparams: Extra hyperparameters forwarded to ``get_estimator``.

    Raises:
        ValueError: if ``generation_num`` is not greater than 0.
    """
    # An explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let generation 0 (random weights) through.
    # Same exception and message as the other train() variants.
    if generation_num <= 0:
        raise ValueError("Model 0 is random weights")
    estimator = get_estimator(working_dir, **hparams)
    print("generations = ", generation_num)
    # Estimator.train treats max_steps as a global-step limit, so the limit
    # grows with generation_num.
    max_steps = generation_num * EXAMPLES_PER_GENERATION // TRAIN_BATCH_SIZE
    print("max_steps = ", max_steps)

    def input_fn():
        return preprocessing.get_input_tensors(TRAIN_BATCH_SIZE, tf_records)

    update_ratio_hook = UpdateRatioSessionHook(working_dir)
    print("Train with TRAIN_BATCH_SIZE=", TRAIN_BATCH_SIZE)
    estimator.train(input_fn, hooks=[update_ratio_hook], max_steps=max_steps)
示例5: validate
# Required import: import preprocessing [as alias]
# Or: from preprocessing import get_input_tensors [as alias]
def validate(working_dir, tf_records, checkpoint_name=None, **hparams):
    """Evaluate one checkpoint of the model on holdout TFRecords.

    Runs 1000 evaluation steps. The input pipeline is built with
    ``shuffle_buffer_size=1000`` and ``filter_amount=0.05``.

    Args:
        working_dir: Model directory passed to ``get_estimator``.
        tf_records: TFRecord filenames containing the holdout data.
        checkpoint_name: Checkpoint to evaluate. When ``None``, the
            estimator's ``latest_checkpoint()`` is used.
        **hparams: Extra hyperparameters forwarded to ``get_estimator``.
    """
    estimator = get_estimator(working_dir, **hparams)
    if checkpoint_name is None:
        checkpoint_name = estimator.latest_checkpoint()

    def input_fn():
        return preprocessing.get_input_tensors(
            TRAIN_BATCH_SIZE, tf_records, shuffle_buffer_size=1000,
            filter_amount=0.05)

    # Pass the resolved checkpoint explicitly. Without checkpoint_path,
    # Estimator.evaluate always uses the latest checkpoint, and a
    # caller-supplied checkpoint_name would be ignored.
    estimator.evaluate(input_fn, steps=1000, checkpoint_path=checkpoint_name)
示例6: train
# Required import: import preprocessing [as alias]
# Or: from preprocessing import get_input_tensors [as alias]
def train(working_dir, tf_records, generation_num, params):
    """Train the model for a specific generation.

    Args:
        working_dir: Model directory where the estimator saves parameters,
            logs, checkpoints and so on.
        tf_records: List of TFRecord filenames used as training input.
        generation_num: Generation to be trained.
        params: Model hyperparameters. ``batch_size`` and
            ``examples_per_generation`` are read from it, and it is passed
            to both the Estimator and ``preprocessing.get_input_tensors``.

    Raises:
        ValueError: if generation_num is not greater than 0.
    """
    if generation_num <= 0:
        raise ValueError('Model 0 is random weights')

    estimator = tf.estimator.Estimator(
        dualnet_model.model_fn, model_dir=working_dir, params=params)
    step_limit = (
        generation_num * params.examples_per_generation // params.batch_size)
    # Writes profiling output into the model directory every 600 seconds.
    hooks = [tf.train.ProfilerHook(output_dir=working_dir, save_secs=600)]

    def training_input_fn():
        return preprocessing.get_input_tensors(
            params, params.batch_size, tf_records)

    estimator.train(input_fn=training_input_fn, hooks=hooks, max_steps=step_limit)
示例7: validate
# Required import: import preprocessing [as alias]
# Or: from preprocessing import get_input_tensors [as alias]
def validate(working_dir, tf_records, params):
    """Evaluate the model in ``working_dir`` on held-out data.

    Runs 1000 evaluation steps over an input pipeline built with
    ``filter_amount=0.05``.

    Args:
        working_dir: The model working directory.
        tf_records: List of TFRecord filenames with the holdout data.
        params: Model hyperparameters; ``batch_size`` is read from it.
    """
    def holdout_input_fn():
        return preprocessing.get_input_tensors(
            params, params.batch_size, tf_records, filter_amount=0.05)

    model = tf.estimator.Estimator(
        dualnet_model.model_fn, model_dir=working_dir, params=params)
    model.evaluate(input_fn=holdout_input_fn, steps=1000)
示例8: train
# Required import: import preprocessing [as alias]
# Or: from preprocessing import get_input_tensors [as alias]
def train(estimator_dir, tf_records, model_version, **kwargs):
    """Main training function for the PolicyValueNetwork.

    Args:
        estimator_dir (str): Path to the estimator directory.
        tf_records (list): TFRecords from which the training examples are
            parsed.
        model_version (int): Version of the model being trained. It also
            scales the ``max_steps`` limit.
        **kwargs: Forwarded to ``get_estimator``.
    """
    model = get_estimator(estimator_dir, **kwargs)
    logger.info("Training model version: {}".format(model_version))

    store = GLOBAL_PARAMETER_STORE
    step_limit = (model_version * store.EXAMPLES_PER_GENERATION
                  // store.TRAIN_BATCH_SIZE)

    def training_input_fn():
        return preprocessing.get_input_tensors(list_tf_records=tf_records)

    model.train(input_fn=training_input_fn, max_steps=step_limit)
    logger.info("Trained model version: {}".format(model_version))
示例9: validate
# Required import: import preprocessing [as alias]
# Or: from preprocessing import get_input_tensors [as alias]
def validate(estimator_dir, tf_records, checkpoint_path=None, **kwargs):
    """Evaluate a checkpoint of the model on validation TFRecords.

    The buffer size and number of evaluation steps are read from
    ``GLOBAL_PARAMETER_STORE`` (``VALIDATION_BUFFER_SIZE``,
    ``VALIDATION_NUMBER_OF_STEPS``).

    Args:
        estimator_dir (str): Path to the estimator directory.
        tf_records (list): TFRecords holding the validation examples.
        checkpoint_path: Checkpoint to evaluate; when ``None``, the model's
            ``latest_checkpoint()`` is used.
        **kwargs: Forwarded to ``get_estimator``.
    """
    model = get_estimator(estimator_dir, **kwargs)
    checkpoint_path = (model.latest_checkpoint() if checkpoint_path is None
                       else checkpoint_path)

    def validation_input_fn():
        return preprocessing.get_input_tensors(
            list_tf_records=tf_records,
            buffer_size=GLOBAL_PARAMETER_STORE.VALIDATION_BUFFER_SIZE)

    model.evaluate(
        input_fn=validation_input_fn,
        steps=GLOBAL_PARAMETER_STORE.VALIDATION_NUMBER_OF_STEPS,
        checkpoint_path=checkpoint_path,
    )
示例10: extract_data
# Required import: import preprocessing [as alias]
# Or: from preprocessing import get_input_tensors [as alias]
def extract_data(self, tf_record, filter_amount=1, random_rotation=False):
    """Build a batch-size-1, unshuffled pipeline over one TFRecord file.

    The tensors are handed to ``self.get_data_tensors``, and its result is
    returned unchanged.

    Args:
        tf_record: Path of the TFRecord file to read.
        filter_amount: Forwarded to ``preprocessing.get_input_tensors``.
        random_rotation: Forwarded to ``preprocessing.get_input_tensors``.
    """
    pipeline_options = dict(
        num_repeats=1,
        shuffle_records=False,
        shuffle_examples=False,
        filter_amount=filter_amount,
        random_rotation=random_rotation,
    )
    features, labels = preprocessing.get_input_tensors(
        1, [tf_record], **pipeline_options)
    return self.get_data_tensors(features, labels)
示例11: train
# Required import: import preprocessing [as alias]
# Or: from preprocessing import get_input_tensors [as alias]
def train(working_dir, tf_records, generation, params):
    """Train the model for a specific generation.

    Args:
        working_dir: Model directory where the estimator saves parameters,
            logs, checkpoints and so on.
        tf_records: List of TFRecord filenames used as training input.
        generation: Generation to be trained.
        params: Model hyperparameters. ``batch_size`` and
            ``examples_per_generation`` are read from it, and it is passed
            to both the Estimator and ``preprocessing.get_input_tensors``.

    Raises:
        ValueError: if generation is not greater than 0.
    """
    if generation <= 0:
        raise ValueError('Model 0 is random weights')

    model = tf.estimator.Estimator(
        dualnet_model.model_fn, model_dir=working_dir, params=params)
    total_examples = generation * params.examples_per_generation
    step_limit = total_examples // params.batch_size
    # Writes profiling output into the model directory every 600 seconds.
    profiler = tf.train.ProfilerHook(output_dir=working_dir, save_secs=600)

    def training_input_fn():
        return preprocessing.get_input_tensors(
            params, params.batch_size, tf_records)

    model.train(input_fn=training_input_fn, hooks=[profiler], max_steps=step_limit)