本文整理汇总了Python中tensorflow.python.pywrap_tensorflow.NewCheckpointReader方法的典型用法代码示例。如果您正苦于以下问题：Python pywrap_tensorflow.NewCheckpointReader方法的具体用法？Python pywrap_tensorflow.NewCheckpointReader怎么用？Python pywrap_tensorflow.NewCheckpointReader使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.python.pywrap_tensorflow的用法示例。
在下文中一共展示了pywrap_tensorflow.NewCheckpointReader方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: print_tensors_in_checkpoint_file
# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
    """Print tensors stored in a TensorFlow checkpoint.

    If ``all_tensors`` is true, every tensor is printed (sorted by name).
    Otherwise, if no ``tensor_name`` is given, the reader's debug string
    is printed; else only ``tensor_name`` is printed.

    Errors are printed rather than raised, with hints for two common
    causes (SNAPPY compression, passing a V2 shard file instead of the
    checkpoint prefix).

    Args:
        file_name: Name (or V2 prefix) of the checkpoint file.
        tensor_name: Name of the tensor in the checkpoint file to print.
        all_tensors: Boolean indicating whether to print all tensors.
    """
    try:
        reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        if all_tensors:
            # Sorted so the output order is stable between runs.
            for key in sorted(reader.get_variable_to_shape_map()):
                print("tensor_name: ", key)
                print(reader.get_tensor(key))
        elif not tensor_name:
            print(reader.debug_string().decode("utf-8"))
        else:
            print("tensor_name: ", tensor_name)
            print(reader.get_tensor(tensor_name))
    except Exception as e:  # pylint: disable=broad-except
        message = str(e)
        print(message)
        if "corrupted compressed block contents" in message:
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
        # Passing one of the .index/.meta/.data-* files of a V2 checkpoint
        # instead of its prefix is reported as "Data loss"; suggest the prefix.
        if ("Data loss" in message and
                any(ext in file_name for ext in (".index", ".meta", ".data"))):
            proposed_file = ".".join(file_name.split(".")[:-1])
            v2_file_error_template = (
                "\nIt's likely that this is a V2 checkpoint and you need to "
                "provide the filename\n*prefix*. Try removing the '.' and "
                "extension. Try:\ninspect checkpoint --file_name = {}")
            print(v2_file_error_template.format(proposed_file))
示例2: resore_form_rl_net
# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def resore_form_rl_net(self, ckpt_name, graph, sess):
    """Copy RL policy-network weights from a checkpoint into ``graph``.

    Every checkpoint variable name is printed. Only variables under
    ``pi/pi/net/`` or ``pi/pi/Trajectory_follower_mlp_net/`` are restored.
    For each one, everything up to and including the second '/' (the
    ``pi/pi/`` scope) is stripped and ``:0`` appended, giving the name of
    the destination tensor in ``graph``. The checkpoint value is then
    assigned to it through ``sess``. Variables that cannot be restored are
    reported and skipped.

    Args:
        ckpt_name: Checkpoint path passed to ``NewCheckpointReader``.
        graph: Graph in which the destination tensors are looked up.
        sess: Session used to run the assign ops.
    """
    print("Restore form RL net")
    print("===== Prase data from %s =====" % ckpt_name)
    net_prefix = 'pi/pi'
    restorable_prefixes = ('%s/net/' % net_prefix,
                           '%s/Trajectory_follower_mlp_net/' % net_prefix)
    reader = pywrap_tensorflow.NewCheckpointReader(ckpt_name)
    for _key in reader.get_variable_to_shape_map():
        print(_key)
        if not str(_key).startswith(restorable_prefixes):
            continue
        # The prefix check guarantees at least two '/', so index 2 exists.
        key = _key.split('/', 2)[2] + ":0"
        try:
            tensor = graph.get_tensor_by_name(key)
            # NOTE(review): tf.assign adds a new op to the graph on every
            # call; repeated restores grow the graph.
            sess.run(tf.assign(tensor, reader.get_tensor(_key)))
        except Exception as e:  # pylint: disable=broad-except
            # Best-effort restore: report and continue with the next variable.
            print(key, " can not be restored, e= ", str(e))
示例3: resort_para_form_checkpoint
# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def resort_para_form_checkpoint(self, _ckpt_name_vec, graph, sess):
    """Restore graph tensors from one or more checkpoints by name.

    Each checkpoint variable ``name`` is assigned to the tensor
    ``name + ":0"`` in ``graph`` under a ``restore`` name scope. Checkpoints
    are processed in order, so a later checkpoint overwrites values restored
    from an earlier one. Variables with no matching tensor, or whose
    assignment fails, are skipped and reported in the return value.

    Args:
        _ckpt_name_vec: A checkpoint path, or a list/tuple of paths.
        graph: Graph in which the destination tensors are looked up.
        sess: Session used to run the assign ops.

    Returns:
        A list of ``(ckpt_name, variable_name)`` pairs that were not restored.
    """
    if isinstance(_ckpt_name_vec, (list, tuple)):
        ckpt_name_vec = list(_ckpt_name_vec)
    else:
        ckpt_name_vec = [_ckpt_name_vec]
    skipped = []
    with tf.name_scope("restore"):
        for ckpt_name in ckpt_name_vec:
            print("===== Restore data from %s =====" % ckpt_name)
            reader = pywrap_tensorflow.NewCheckpointReader(ckpt_name)
            for key in reader.get_variable_to_shape_map():
                try:
                    tensor = graph.get_tensor_by_name(key + ":0")
                    sess.run(tf.assign(tensor, reader.get_tensor(key)))
                except Exception:  # pylint: disable=broad-except
                    # Best-effort restore: record the skip and keep going.
                    # Only Exception subclasses are caught, so
                    # KeyboardInterrupt/SystemExit still propagate.
                    skipped.append((ckpt_name, key))
    return skipped
示例4: create_model_from_npz_file
# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def create_model_from_npz_file(npz, model, target):
    """Create and save a TensorFlow model from variables stored in an npz file.

    Each npz entry for which ``key_contained_in_map(key, target_map)`` holds
    (``target_map`` being the variable-to-shape map of ``target``) becomes a
    ``tf.Variable`` named after its key in the default graph. All variables
    are then initialized and saved with ``tf.train.Saver`` to ``model``.

    Args:
        npz: Path to the npz structure containing files representing the
            variables in the model.
        model: Path in which the final model should be stored.
        target: A target model which contains the desired names for the
            structure.
    """
    reader = pywrap_tensorflow.NewCheckpointReader(target)
    target_map = reader.get_variable_to_shape_map()
    variables = variables_dictionary_from_npz_file(npz)
    for key in variables:
        if key_contained_in_map(key, target_map):
            # The variable lives in the default graph, which the Saver below
            # picks up. No Python reference is needed (the former
            # exec("varN = val") could not create function locals anyway).
            tf.Variable(variables[key], name=key)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        save_path = saver.save(sess, model)
        print("Model saved in file: %s" % save_path)
示例5: print_all_in_ckpt
# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def print_all_in_ckpt(ckpt_path):
    """Print the name and shape of every variable in a checkpoint.

    Shapes are read from the checkpoint's variable-to-shape map, so no
    tensor values are loaded.
    """
    reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
    for key, shape in reader.get_variable_to_shape_map().items():
        # tuple() keeps the "(d0, d1, ...)" text that ndarray.shape printed.
        # NOTE(review): assumes the map's shapes are sequences of ints.
        print('{} {}\n'.format(key, tuple(shape)))
示例6: get_all_variable_names
# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def get_all_variable_names(ckpt_path):
    """Return a list of every variable name stored in the checkpoint at ``ckpt_path``."""
    shape_map = pywrap_tensorflow.NewCheckpointReader(ckpt_path).get_variable_to_shape_map()
    return list(shape_map)
示例7: getTransformEstimator
# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def getTransformEstimator(cls, priors: Tuple[Distribution, ...], K: int,
                          chptFile: str, dtype: tf.DType = tf.float32,
                          noiseUniformity: NoiseUniformity = HOMOGENEOUS,
                          stopCriterionInit=LlhStall(10),
                          stopCriterionEM=LlhStall(100),
                          stopCriterionBCD=LlhImprovementThreshold(1e-2),
                          path: str = "/tmp", device: str = "/cpu:0"):
    """Build a transform-mode Estimator warm-started from ``chptFile``.

    Every variable of the checkpoint is warm-started except ``U/0``,
    ``global_step``, ``stop`` and anything inside the per-phase
    ``stopCriterion<Phase name>/`` scopes (INIT, EM, BCD). The model
    function delegates to ``cls.__estimatorSpec`` with ``transform=True``,
    ``isFullyObserved=True`` and ``cv=None``.
    """
    # NOTE(review): default stop-criterion instances are created once at
    # definition time and shared between calls; confirm they are stateless.
    excludedNames = {"U/0", "global_step", "stop"}
    excludedScopes = tuple(f"stopCriterion{phase.name}/"
                           for phase in (Phase.INIT, Phase.EM, Phase.BCD))

    # configuring warm start settings
    reader = pywrap_tensorflow.NewCheckpointReader(chptFile)
    warmStartNames = []
    for varName in reader.get_variable_to_shape_map().keys():
        if varName in excludedNames or varName.startswith(excludedScopes):
            continue
        warmStartNames.append(varName)
    # NOTE(review): the names are joined unescaped into one regex; if the
    # list is empty the pattern "" matches every variable. Confirm that
    # cannot happen for the checkpoints passed in.
    warmStart = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=chptFile,
        vars_to_warm_start="|".join(warmStartNames))

    def model_fn(features, labels, mode):
        return cls.__estimatorSpec(mode=mode, features=features,
                                   isFullyObserved=True,
                                   device=device, priors=priors,
                                   noiseUniformity=noiseUniformity,
                                   stopCriterionInit=stopCriterionInit,
                                   stopCriterionEM=stopCriterionEM,
                                   stopCriterionBCD=stopCriterionBCD,
                                   K=K, path=path, cv=None,
                                   transform=True, dtype=dtype)

    return tf.estimator.Estimator(model_fn=model_fn,
                                  model_dir=path,
                                  warm_start_from=warmStart)
示例8: fit_transform
# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def fit_transform(self, X: np.ndarray) -> np.ndarray:
    """Fit on ``X``, then return tensor ``U/0`` from the newest checkpoint."""
    self.fit(X)
    latest = self.__tefa.latest_checkpoint()
    return pywrap_tensorflow.NewCheckpointReader(latest).get_tensor("U/0")
示例9: transform
# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def transform(self, X: np.ndarray,
              transformModelDirectory: str) -> np.ndarray:
    """Project ``X`` using the fitted factorisation.

    ``X`` (cast to the model dtype) is fed as feature ``"test"`` in a single
    unshuffled batch. A transform estimator warm-started from the fitted
    model's latest checkpoint is trained in ``transformModelDirectory`` for
    up to ``self.__maxIterations`` steps with a ``StopHook``. Tensor
    ``"U/0tr"`` is then read back from that estimator's latest checkpoint.
    """
    # create input_fn
    features = {"test": X.astype(self.__dtype)}
    input_fn = tf.estimator.inputs.numpy_input_fn(
        features, y=None, batch_size=X.shape[0],
        shuffle=False, num_epochs=None)

    fittedCheckpoint = self.__tefa.latest_checkpoint()
    estimator = TensorFactorisation.getTransformEstimator(
        priors=self.__priors,
        K=self.n_components,
        dtype=tf.as_dtype(self.__dtype),
        path=transformModelDirectory,
        chptFile=fittedCheckpoint,
        noiseUniformity=self.noiseUniformity,
        stopCriterionInit=self.__stopCriterionInit,
        stopCriterionEM=self.__stopCriterionEM,
        stopCriterionBCD=self.__stopCriterionBCD)
    estimator.train(input_fn=input_fn,
                    steps=self.__maxIterations,
                    hooks=[StopHook()])

    transformCheckpoint = estimator.latest_checkpoint()
    return pywrap_tensorflow.NewCheckpointReader(transformCheckpoint).get_tensor("U/0tr")
示例10: print_checkpoint
# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def print_checkpoint(path:str, name:str):
    """Print one tensor (``name``), or every tensor with its shape when ``name`` is empty.

    ``path`` is first resolved with ``get_model_name``.
    """
    reader = pywrap_tensorflow.NewCheckpointReader(get_model_name(path))
    if name:
        print(name)
        print(reader.get_tensor(name))
        return
    shape_map = reader.get_variable_to_shape_map()
    for var_name in sorted(shape_map):
        value = reader.get_tensor(var_name)
        print('{}:{}'.format(var_name, value.shape))
        print(value)
示例11: get_var_from_checkpoint
# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def get_var_from_checkpoint(path:str, tensor_name:str):
    """Return the value of ``tensor_name`` from the checkpoint resolved by ``get_model_name(path)``."""
    checkpoint = get_model_name(path)
    return pywrap_tensorflow.NewCheckpointReader(checkpoint).get_tensor(tensor_name)
示例12: get_layers
# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def get_layers(path):
    """Return the variable-name -> shape map of a checkpoint.

    Args:
        path: A checkpoint path/prefix, or a directory. For a directory, the
            checkpoint recorded by ``tf.train.get_checkpoint_state`` is used.

    Raises:
        FileNotFoundError: if ``path`` is a directory with no checkpoint state.
    """
    if os.path.isdir(path):
        state = tf.train.get_checkpoint_state(path)
        if state is None:
            # get_checkpoint_state returns None when no checkpoint is recorded;
            # fail clearly instead of with an AttributeError on None.
            raise FileNotFoundError("No checkpoint found in directory: %s" % path)
        path = state.model_checkpoint_path
    reader = pywrap_tensorflow.NewCheckpointReader(path)
    return reader.get_variable_to_shape_map()
示例13: get_weights
# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def get_weights(path, name):
    """Return the value of tensor ``name`` from a checkpoint.

    Args:
        path: A checkpoint path/prefix, or a directory. For a directory, the
            checkpoint recorded by ``tf.train.get_checkpoint_state`` is used.
        name: Name of the tensor to read.

    Raises:
        FileNotFoundError: if ``path`` is a directory with no checkpoint state.
    """
    if os.path.isdir(path):
        state = tf.train.get_checkpoint_state(path)
        if state is None:
            # get_checkpoint_state returns None when no checkpoint is recorded;
            # fail clearly instead of with an AttributeError on None.
            raise FileNotFoundError("No checkpoint found in directory: %s" % path)
        path = state.model_checkpoint_path
    reader = pywrap_tensorflow.NewCheckpointReader(path)
    return reader.get_tensor(name)
示例14: load_inception_resnet_v2
# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def load_inception_resnet_v2(checkpoint_path, enable_aux=False):
    """Build a chainer InceptionResnetV2 and load weights from a TF checkpoint.

    Args:
        checkpoint_path: Checkpoint passed to ``NewCheckpointReader``.
        enable_aux: Forwarded to the ``InceptionResnetV2`` constructor.

    Returns:
        The model with parameters loaded via ``model.load_tf_checkpoint``
        under the ``'InceptionResnetV2'`` scope.

    Raises:
        RuntimeError: if tensorflow could not be imported (``_tf_import_error``).
    """
    # Fail fast: check the tensorflow import before paying for model
    # construction and a dummy forward pass.
    if _tf_import_error is not None:
        raise RuntimeError('could not import tensorflow; the import error as follows:\n' + str(_tf_import_error))
    model = InceptionResnetV2(enable_aux=enable_aux)
    with chainer.no_backprop_mode():
        model(np.random.randn(2, 3, 299, 299).astype('f'))  # initialize params
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    model.load_tf_checkpoint(reader, 'InceptionResnetV2')
    return model
示例15: print_tensors_in_checkpoint_file
# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
    """Prints tensors in a checkpoint file.

    If no `tensor_name` is provided, prints the tensor names and shapes
    in the checkpoint file.

    If `tensor_name` is provided, prints the content of the tensor.

    Errors are printed rather than raised, with hints for SNAPPY-compressed
    files and for V2 checkpoints addressed by a shard file instead of the
    prefix.

    Args:
        file_name: Name of the checkpoint file.
        tensor_name: Name of the tensor in the checkpoint file to print.
        all_tensors: Boolean indicating whether to print all tensors.
    """
    try:
        reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        if all_tensors:
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in sorted(var_to_shape_map):
                print("tensor_name: ", key)
                print(reader.get_tensor(key))
        elif not tensor_name:
            print(reader.debug_string().decode("utf-8"))
        else:
            print("tensor_name: ", tensor_name)
            print(reader.get_tensor(tensor_name))
    except Exception as e:  # pylint: disable=broad-except
        message = str(e)
        print(message)
        if "corrupted compressed block contents" in message:
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
        # `ext` (not `e`) so the caught exception is not shadowed.
        if ("Data loss" in message and
                any(ext in file_name for ext in (".index", ".meta", ".data"))):
            proposed_file = ".".join(file_name.split(".")[:-1])
            # Same text as before, without relying on column-0 lines in a
            # triple-quoted literal.
            v2_file_error_template = (
                "\nIt's likely that this is a V2 checkpoint and you need to "
                "provide the filename\n*prefix*. Try removing the '.' and "
                "extension. Try:\ninspect checkpoint --file_name = {}")
            print(v2_file_error_template.format(proposed_file))