tf.estimator.Estimators
中的热启动设置。
用法
tf.estimator.WarmStartSettings(
ckpt_to_initialize_from, vars_to_warm_start='.*',
var_name_to_vocab_info=None, var_name_to_prev_var_name=None
)
属性
-
ckpt_to_initialize_from
[必需] 一个字符串,指定带有检查点文件的目录或检查点的路径,从该目录中热启动模型参数。 -
vars_to_warm_start
[可选] 以下之一:捕获要热启动的变量的正则表达式(字符串)(请参阅 tf.compat.v1.get_collection)。此表达式将仅考虑 TRAINABLE_VARIABLES 集合中的变量——如果您需要热启动 non_TRAINABLE 变量(例如优化器累加器或批量标准统计信息),请使用以下选项。
字符串列表,每个字符串都是通过GLOBAL_VARIABLES 提供给 tf.compat.v1.get_collection 的正则表达式范围(请参阅 tf.compat.v1.get_collection)。出于向后兼容性的原因,这与 single-string 参数类型是分开的。
要热启动的变量列表。如果您无权访问调用站点的
Variable
对象,请使用上述选项。None
,在这种情况下,只有在var_name_to_vocab_info
中指定的 TRAINABLE 变量将被热启动。
默认为
'.*'
,其中 warm-starts TRAINABLE_VARIABLES 集合中的所有变量。请注意,这不包括变量,例如累加器和批量标准的移动统计信息。 -
var_name_to_vocab_info
[可选]tf.estimator.VocabInfo
的变量名称(字符串)字典。变量名称应该是"full" 变量,而不是分区的名称。如果未明确提供,则假定该变量没有(更改)词汇表。 -
var_name_to_prev_var_name
[可选] 变量名称(字符串)的字典到ckpt_to_initialize_from
中的 previously-trained 变量的名称。如果未明确提供,则假定变量的名称在先前的检查点和当前模型之间是相同的。请注意,这对热启动的变量集没有影响,并且仅控制名称映射(使用vars_to_warm_start
来控制要热启动的变量)。
与罐装 tf.estimator.DNNEstimator
一起使用的示例:
emb_vocab_file = tf.feature_column.embedding_column(
tf.feature_column.categorical_column_with_vocabulary_file(
"sc_vocab_file", "new_vocab.txt", vocab_size=100),
dimension=8)
emb_vocab_list = tf.feature_column.embedding_column(
tf.feature_column.categorical_column_with_vocabulary_list(
"sc_vocab_list", vocabulary_list=["a", "b"]),
dimension=8)
estimator = tf.estimator.DNNClassifier(
hidden_units=[128, 64], feature_columns=[emb_vocab_file, emb_vocab_list],
warm_start_from=ws)
其中ws
可以定义为:
热启动模型中的所有权重(输入层和隐藏权重)。可以提供目录或特定检查点(在前者的情况下,将使用最新的检查点):
ws = WarmStartSettings(ckpt_to_initialize_from="/tmp")
ws = WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000")
仅热启动嵌入(输入层):
ws = WarmStartSettings(ckpt_to_initialize_from="/tmp",
vars_to_warm_start=".*input_layer.*")
热启动所有权重,但对应于 sc_vocab_file
的嵌入参数与当前模型中使用的词汇不同:
vocab_info = tf.estimator.VocabInfo(
new_vocab=sc_vocab_file.vocabulary_file,
new_vocab_size=sc_vocab_file.vocabulary_size,
num_oov_buckets=sc_vocab_file.num_oov_buckets,
old_vocab="old_vocab.txt"
)
ws = WarmStartSettings(
ckpt_to_initialize_from="/tmp",
var_name_to_vocab_info={
"input_layer/sc_vocab_file_embedding/embedding_weights":vocab_info
})
仅热启动 sc_vocab_file
嵌入(没有其他变量),其词汇与当前模型中使用的词汇不同:
vocab_info = tf.estimator.VocabInfo(
new_vocab=sc_vocab_file.vocabulary_file,
new_vocab_size=sc_vocab_file.vocabulary_size,
num_oov_buckets=sc_vocab_file.num_oov_buckets,
old_vocab="old_vocab.txt"
)
ws = WarmStartSettings(
ckpt_to_initialize_from="/tmp",
vars_to_warm_start=None,
var_name_to_vocab_info={
"input_layer/sc_vocab_file_embedding/embedding_weights":vocab_info
})
热启动所有权重,但对应于 sc_vocab_file
的参数与当前检查点中使用的词汇不同,并且仅使用了其中的 100 个条目:
vocab_info = tf.estimator.VocabInfo(
new_vocab=sc_vocab_file.vocabulary_file,
new_vocab_size=sc_vocab_file.vocabulary_size,
num_oov_buckets=sc_vocab_file.num_oov_buckets,
old_vocab="old_vocab.txt",
old_vocab_size=100
)
ws = WarmStartSettings(
ckpt_to_initialize_from="/tmp",
var_name_to_vocab_info={
"input_layer/sc_vocab_file_embedding/embedding_weights":vocab_info
})
热启动所有权重,但sc_vocab_file
对应的参数与当前检查点中使用的参数不同,sc_vocab_list
对应的参数名称与当前检查点不同:
vocab_info = tf.estimator.VocabInfo(
new_vocab=sc_vocab_file.vocabulary_file,
new_vocab_size=sc_vocab_file.vocabulary_size,
num_oov_buckets=sc_vocab_file.num_oov_buckets,
old_vocab="old_vocab.txt",
old_vocab_size=100
)
ws = WarmStartSettings(
ckpt_to_initialize_from="/tmp",
var_name_to_vocab_info={
"input_layer/sc_vocab_file_embedding/embedding_weights":vocab_info
},
var_name_to_prev_var_name={
"input_layer/sc_vocab_list_embedding/embedding_weights":
"old_tensor_name"
})
热启动所有 TRAINABLE 变量:
ws = WarmStartSettings(ckpt_to_initialize_from="/tmp",
vars_to_warm_start=".*")
热启动所有变量(包括非 TRAINABLE):
ws = WarmStartSettings(ckpt_to_initialize_from="/tmp",
vars_to_warm_start=[".*"])
热启动非 TRAINABLE 变量 "v1"、"v1/Momentum" 和 "v2" 但不是 "v2/momentum":
ws = WarmStartSettings(ckpt_to_initialize_from="/tmp",
vars_to_warm_start=["v1", "v2[^/]"])
相关用法
- Python tf.estimator.TrainSpec用法及代码示例
- Python tf.estimator.LogisticRegressionHead用法及代码示例
- Python tf.estimator.MultiHead用法及代码示例
- Python tf.estimator.PoissonRegressionHead用法及代码示例
- Python tf.estimator.experimental.stop_if_lower_hook用法及代码示例
- Python tf.estimator.RunConfig用法及代码示例
- Python tf.estimator.MultiLabelHead用法及代码示例
- Python tf.estimator.experimental.stop_if_no_increase_hook用法及代码示例
- Python tf.estimator.BaselineEstimator用法及代码示例
- Python tf.estimator.DNNLinearCombinedEstimator用法及代码示例
- Python tf.estimator.Estimator用法及代码示例
- Python tf.estimator.experimental.LinearSDCA用法及代码示例
- Python tf.estimator.experimental.RNNClassifier用法及代码示例
- Python tf.estimator.experimental.make_early_stopping_hook用法及代码示例
- Python tf.estimator.LinearRegressor用法及代码示例
- Python tf.estimator.LinearEstimator用法及代码示例
- Python tf.estimator.DNNClassifier用法及代码示例
- Python tf.estimator.BaselineClassifier用法及代码示例
- Python tf.estimator.experimental.stop_if_higher_hook用法及代码示例
- Python tf.estimator.train_and_evaluate用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.estimator.WarmStartSettings。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。