本文整理汇总了Python中gin.tf方法的典型用法代码示例。如果您正苦于以下问题:Python gin.tf方法的具体用法?Python gin.tf怎么用?Python gin.tf使用的例子?那么,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类gin的用法示例。
在下文中一共展示了gin.tf方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_eval_hooks
# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def get_eval_hooks(self, config, params):
    """Build the evaluation hooks handed to the estimator spec.

    Deprecated: a warning is logged on every call.

    Args:
        config: Run configuration; `model_dir` and `save_summary_steps` are
            read from it.
        params: Optional mapping. When it is not None, its 'eval_name' entry
            (if present) replaces the default sub-directory name 'eval'.

    Returns:
        An empty list when `tf.summary.merge_all()` yields None, otherwise a
        one-element list holding a `tf.train.SummarySaverHook` that writes
        into `<config.model_dir>/<eval_name>`.
    """
    logging.warning('This function is deprecated and will be replaced.')

    merged_summaries = tf.summary.merge_all()
    if merged_summaries is None:
        # Nothing to save, so no hook is needed.
        return []

    eval_name = 'eval' if params is None else params.get('eval_name', 'eval')
    summary_dir = os.path.join(config.model_dir, eval_name)
    saver_hook = tf.train.SummarySaverHook(
        output_dir=summary_dir,
        save_steps=config.save_summary_steps,
        summary_op=merged_summaries)
    return [saver_hook]
#############################################################################
# END DEPRECATED functions which will be removed soon.
#############################################################################
示例2: get_variable_dtype
# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def get_variable_dtype(
        master_dtype=tf.bfloat16,
        slice_dtype=tf.float32,
        activation_dtype=tf.float32):
    """Assemble the datatypes to use for a run.

    Every argument is normalized through `tf.as_dtype`, so any value that
    function accepts may be supplied.

    Args:
        master_dtype: datatype for checkpoints; keep this the same between
            training and eval/inference.
        slice_dtype: datatype for variables in memory; must be tf.float32
            for training.
        activation_dtype: datatype for activations; tf.bfloat16 uses less
            memory but may cause numerical issues.

    Returns:
        a mtf.VariableDType
    """
    checkpoint_dtype = tf.as_dtype(master_dtype)
    in_memory_dtype = tf.as_dtype(slice_dtype)
    compute_dtype = tf.as_dtype(activation_dtype)
    return mtf.VariableDType(
        master_dtype=checkpoint_dtype,
        slice_dtype=in_memory_dtype,
        activation_dtype=compute_dtype)
示例3: clean_decodes
# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def clean_decodes(ids, eos_id=1, pad_id=0, length_axis=-1):
    """Replace everything after the first EOS with PAD along `length_axis`.

    The first EOS itself is kept: positions are retained while the exclusive
    running count of EOS tokens seen before them is still zero.

    Args:
        ids: a Tensor of type int.
        eos_id: int, EOS id.
        pad_id: int, PAD id.
        length_axis: an integer, the axis along which EOS tokens are counted.

    Returns:
        a Tensor of type int holding the cleaned ids.
    """
    is_eos = tf.cast(tf.equal(ids, eos_id), tf.int32)
    # Exclusive cumsum: number of EOS tokens strictly before each position.
    eos_seen_before = tf.cumsum(is_eos, exclusive=True, axis=length_axis)
    keep_mask = tf.equal(eos_seen_before, 0)
    return tf.where_v2(keep_mask, ids, pad_id)
示例4: _get_latest_checkpoint_from_dir
# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def _get_latest_checkpoint_from_dir(model_dir):
    """Return the number of the most recent checkpoint in a directory.

    The number is whatever follows the last "ckpt-" in the path reported by
    `tf.train.latest_checkpoint`, converted with `int`.

    Args:
        model_dir: str, directory with checkpoint files.

    Returns:
        an int, latest checkpoint number.

    Raises:
        ValueError: if no checkpoints are found.
    """
    latest_path = tf.train.latest_checkpoint(model_dir)
    if latest_path is not None:
        # Drop everything up to and including the final "ckpt-" marker.
        step_text = re.sub(".*ckpt-", "", latest_path)
        return int(step_text)
    raise ValueError("No checkpoints found in model directory: %s" % model_dir)
示例5: default_init_from_checkpoint_fn
# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def default_init_from_checkpoint_fn(checkpoint, allow_partial_restore=False):
    """Initialize model variables from a TF checkpoint.

    Each variable returned by `contrib_framework.get_variables()` is looked
    up in the checkpoint by its op name. Variables that are found are
    collected into an assignment map, which is then handed to
    `tf.train.init_from_checkpoint`.

    Args:
        checkpoint: String pointing to the path of the TF checkpoint.
        allow_partial_restore: If True, variables absent from the checkpoint
            are skipped with a warning; otherwise a missing variable is an
            error.

    Raises:
        ValueError: if a variable is missing from the checkpoint and partial
            restore is not explicitly enabled.
    """
    logging.info('Initializing model weights from %s', checkpoint)
    reader = tf.train.load_checkpoint(checkpoint)

    assignment_map = {}
    for variable in contrib_framework.get_variables():
        name = variable.op.name
        if reader.has_tensor(name):
            logging.info('Loading variable %s from checkpoint', name)
            assignment_map[name] = variable
            continue
        if not allow_partial_restore:
            raise ValueError('Attempting to restore variable {} which is '
                             'not in the checkpoint.'.format(name))
        logging.warning('Variable %s is not in the checkpoint, skipping.',
                        name)

    tf.train.init_from_checkpoint(checkpoint, assignment_map)
示例6: get_feature_specification
# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def get_feature_specification(self, mode):
    """Declare the features required by model_fn/model_inference_fn.

    The model_fn may consume extra features for debugging or development,
    but create_export_outputs_fn only needs the required features declared
    here. That subset alone is used to generate the automatic tf.Example
    extractors and numpy placeholders for the serving models.

    This definition has no body beyond its docstring.

    Args:
        mode: The mode for feature specifications.
    """
示例7: get_run_config
# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def get_run_config(self):
    """Create the RunConfig for the Estimator model.

    Returns:
        The object built by `gin_configurable_run_config_cls` with
        `session_config` set to `self.get_session_config()`; intended to be
        the tf.estimator.RunConfig() for this model.
    """
    session_config = self.get_session_config()
    return gin_configurable_run_config_cls(session_config=session_config)
示例8: get_session_config
# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def get_session_config(self):
    """Provide the session config for the Estimator model.

    The default is None, which tells tf.Estimator to fall back to its own
    default session config. Not used in TPU jobs at the moment.

    Returns:
        None, or the desired session config.
    """
    default_session_config = None
    return default_session_config
示例9: get_feature_specification
# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def get_feature_specification(self, mode):
    """Return the wrapped model's feature spec with float32 swapped for bfloat16.

    Args:
        mode: The mode forwarded to the wrapped model's
            `get_feature_specification`.

    Returns:
        The result of `tensorspec_utils.replace_dtype` applied to the wrapped
        model's feature specification.
    """
    wrapped_spec = self._t2r_model.get_feature_specification(mode)
    return tensorspec_utils.replace_dtype(
        wrapped_spec, from_dtype=tf.float32, to_dtype=tf.bfloat16)
示例10: get_label_specification
# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def get_label_specification(self, mode):
    """Return the wrapped model's label spec with float32 swapped for bfloat16.

    Args:
        mode: The mode forwarded to the wrapped model's
            `get_label_specification`.

    Returns:
        The result of `tensorspec_utils.replace_dtype` applied to the wrapped
        model's label specification.
    """
    wrapped_spec = self._t2r_model.get_label_specification(mode)
    return tensorspec_utils.replace_dtype(
        wrapped_spec, from_dtype=tf.float32, to_dtype=tf.bfloat16)
示例11: metric_sum
# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def metric_sum(values, name=None, **kwargs):
    """Streaming metric that accumulates the sum of `values`.

    Args:
        values: Tensor whose elements are cast to float32 and summed.
        name: Optional variable-scope name; "metric_sum" is the default.
        **kwargs: Accepted and discarded.

    Returns:
        A tuple `(accum, update_op)`: the scalar float32 local variable that
        holds the running total, and the op that adds the sum of `values`
        to it.
    """
    del kwargs
    with tf.variable_scope(name, "metric_sum", [values]):
        accum = tf.get_variable(
            "accum",
            shape=[],
            dtype=tf.float32,
            trainable=False,
            collections=[tf.GraphKeys.LOCAL_VARIABLES],
            initializer=tf.zeros_initializer())
        batch_total = tf.reduce_sum(tf.cast(values, tf.float32))
        update_op = tf.assign_add(accum, batch_total)
        return accum, update_op
示例12: metric_max
# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def metric_max(values, name=None, **kwargs):
    """Streaming metric that tracks the maximum of `values`.

    The accumulator is zero-initialized, so the reported maximum is never
    below zero.

    Args:
        values: Tensor whose elements are cast to float32 and reduced with max.
        name: Optional variable-scope name; "metric_max" is the default.
        **kwargs: Accepted and discarded.

    Returns:
        A tuple `(accum, update_op)`: the scalar float32 local variable that
        holds the running maximum, and the op that assigns it the larger of
        its current value and the maximum of `values`.
    """
    del kwargs
    with tf.variable_scope(name, "metric_max", [values]):
        accum = tf.get_variable(
            "accum",
            shape=[],
            dtype=tf.float32,
            trainable=False,
            collections=[tf.GraphKeys.LOCAL_VARIABLES],
            initializer=tf.zeros_initializer())
        batch_max = tf.reduce_max(tf.cast(values, tf.float32))
        running_max = tf.maximum(accum, batch_max)
        update_op = tf.assign(accum, running_max)
        return accum, update_op
示例13: get_inputs_from_file
# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def get_inputs_from_file(input_filename, ignore_comments=False):
    """Read lines from a file, stripping trailing whitespace from each.

    Args:
        input_filename: str, path readable through `tf.io.gfile.GFile`.
        ignore_comments: bool, when True lines starting with "#" are dropped.

    Returns:
        A list of str, one per line, without trailing whitespace and without
        a final empty line. An empty file yields an empty list.
    """
    # Open via a context manager so the handle is closed once reading is done.
    with tf.io.gfile.GFile(input_filename) as input_file:
        inputs = [line.rstrip() for line in input_file]
    # Strip the last empty line. The emptiness check comes first because an
    # empty file produces no lines at all and indexing [-1] would raise.
    if inputs and not inputs[-1]:
        inputs.pop()
    if ignore_comments:
        inputs = [line for line in inputs if not line.startswith("#")]
    return inputs
示例14: decode
# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def decode(estimator, input_fn, vocabulary, checkpoint_path=None):
    """Decode from an input_fn.

    Each prediction's "inputs" and "outputs" entries are detokenized unless
    they are already binary strings, in which case they pass through
    untouched. Progress is logged for index 0 and every power-of-two index.

    Args:
        estimator: a TPUEstimator
        input_fn: function that returns a tf.Dataset
        vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
            targets_vocabulary) tuple
        checkpoint_path: an optional string

    Returns:
        list of decoded outputs, one per prediction
    """
    predictions = estimator.predict(
        input_fn, checkpoint_path=checkpoint_path)

    def _maybe_detokenize(value, vocab):
        # Binary strings are already text; only id sequences need decoding.
        if not isinstance(value, six.binary_type):
            return vocab.decode([int(token) for token in value])
        return value

    decodes = []
    for index, prediction in enumerate(predictions):
        input_string = _maybe_detokenize(
            prediction["inputs"], inputs_vocabulary(vocabulary))
        output_string = _maybe_detokenize(
            prediction["outputs"], targets_vocabulary(vocabulary))
        decodes.append(output_string)
        # True for 0 and for every power of two.
        at_log_point = (index & (index - 1)) == 0
        if at_log_point:
            tf.logging.info("decoded {}: {}".format(index, input_string))
            tf.logging.info(" -> {}".format(output_string))
    return decodes
示例15: write_lines_to_file
# 需要导入模块: import gin [as 别名]
# 或者: from gin import tf [as 别名]
def write_lines_to_file(lines, filename):
    """Write every entry of `lines` to `filename`, one per line.

    Any existing file at `filename` is removed before writing.

    Args:
        lines: list of str, lines to write out.
        filename: str, path to filename.
    """
    already_present = tf.io.gfile.exists(filename)
    if already_present:
        tf.io.gfile.remove(filename)

    with tf.io.gfile.GFile(filename, "w") as sink:
        for entry in lines:
            sink.write("{}\n".format(entry))