本文整理汇总了Python中tensorflow.compat.v1.get_default_graph方法的典型用法代码示例。如果您正苦于以下问题:Python v1.get_default_graph方法的具体用法?Python v1.get_default_graph怎么用?Python v1.get_default_graph使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.compat.v1的用法示例。
在下文中一共展示了v1.get_default_graph方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: underlying_variable
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def underlying_variable(t):
  """Return the tf.Variable that backs tensor `t`.

  A name -> variable cache is kept as a `var_index` attribute on the default
  graph. On each call it is extended with the entries of
  tf.global_variables() beyond the number already cached.

  Args:
    t: a Tensor

  Returns:
    tf.Variable.
  """
  ref = underlying_variable_ref(t)
  assert ref is not None
  graph = tf.get_default_graph()
  # Lazily attach the cache to the graph on first use.
  if not hasattr(graph, "var_index"):
    graph.var_index = {}
  index = graph.var_index
  # Skip the first len(index) globals, which are assumed to be already
  # cached; this presumes tf.global_variables() only ever grows by appending
  # (TODO confirm).
  for var in tf.global_variables()[len(index):]:
    index[var.name] = var
  return index[ref.name]
示例2: testLossDecorated
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def testLossDecorated(self):
  """Decorated GammaFlops loss equals the reference total scaled by 0.5.

  The DummyDecorator is configured with scale=0.5. The loss is checked both
  for all Conv2D ops in the default graph and for an empty op list.
  """
  self.BuildWithBatchNorm(True)
  self.AddRegularizer()
  # Network regularizer whose op regularization goes through DummyDecorator.
  self.gamma_flop_reg = flop_regularizer.GammaFlopsRegularizer(
      [self.conv3.op, self.conv4.op],
      gamma_threshold=0.45,
      regularizer_decorator=dummy_decorator.DummyDecorator,
      decorator_parameters={'scale': 0.5})
  graph_ops = tf.get_default_graph().get_operations()
  all_convs = list(filter(lambda op: op.type == 'Conv2D', graph_ops))
  expected_loss = 1410376.375 * 0.5
  self.assertAllClose(expected_loss, self.GetLoss(all_convs))
  self.assertAllClose(expected_loss, self.GetLoss([]))
示例3: test_group_lasso_conv3d
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def test_group_lasso_conv3d(self):
  """GroupLassoFlopsRegularizer cost/term for a VALID 3x3x3 Conv3D.

  The expected values are computed from the conv weights themselves. The
  RMS of the weights over the first four filter axes gives one value per
  output channel, and channels whose RMS exceeds `threshold` count as alive.
  """
  shape = [3, 3, 3]
  video = tf.zeros([2, 3, 3, 3, 1])
  net = slim.conv3d(
      video,
      5,
      shape,
      padding='VALID',
      weights_initializer=tf.glorot_normal_initializer(),
      scope='vconv1')
  conv3d_op = tf.get_default_graph().get_operation_by_name('vconv1/Conv3D')
  conv3d_weights = conv3d_op.inputs[1]
  threshold = 0.09
  flop_reg = flop_regularizer.GroupLassoFlopsRegularizer([net.op],
                                                         threshold=threshold)
  # Conv3D filters are laid out [depth, height, width, in, out]; reducing
  # axes 0-3 leaves one norm per output channel.
  norm = tf.sqrt(tf.reduce_mean(tf.square(conv3d_weights), [0, 1, 2, 3]))
  alive = tf.reduce_sum(tf.cast(norm > threshold, tf.float32))
  # 2 * kernel volume (the input here has a single channel). This is
  # loop/session invariant, so compute it up front.
  flop_coeff = 2 * shape[0] * shape[1] * shape[2]
  with self.session():
    # `tf` is already the compat.v1 module, as in the rest of this file.
    tf.global_variables_initializer().run()
    self.assertAllClose(flop_reg.get_cost(), flop_coeff * alive)
    self.assertAllClose(flop_reg.get_regularization_term(),
                        flop_coeff * tf.reduce_sum(norm))
示例4: testShareParams
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def testShareParams(self):
  """A reused ConfigurableOps conv2d shares weights with a plain conv2d.

  The parameterization maps 'first/Conv2D' to `first_outputs`. The test
  expects it to take precedence over the `alternate_num_outputs` passed to
  the decorated call, so both convolutions evaluate to the same output.
  """
  first_outputs = 2
  alternate_num_outputs = 12
  decorator = ops.ConfigurableOps(
      parameterization={'first/Conv2D': first_outputs})
  explicit = layers.conv2d(self.inputs, first_outputs, 3, scope='first')
  with arg_scope([layers.conv2d], reuse=True):
    decorated = decorator.conv2d(
        self.inputs,
        num_outputs=alternate_num_outputs,
        kernel_size=3,
        scope='first')
  with self.cached_session():
    tf.global_variables_initializer().run()
    # Shared parameters => identical activations.
    self.assertAllClose(explicit.eval(), decorated.eval())
  conv_ops = sorted(
      op.name
      for op in tf.get_default_graph().get_operations()
      if op.type == 'Conv2D')
  self.assertAllEqual(['first/Conv2D', 'first_1/Conv2D'], conv_ops)
示例5: test_fused_batchnorm
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def test_fused_batchnorm(self, use_depthwise):
  """The non-Keras feature extractor graph contains a FusedBatchNorm* op."""
  height, width = 256, 256
  depth_multiplier, pad_to_multiple = 1, 1
  inputs = tf.placeholder(tf.float32, [1, height, width, 3])
  extractor = self._create_feature_extractor(
      depth_multiplier,
      pad_to_multiple,
      use_keras=False,
      use_depthwise=use_depthwise)
  extractor.extract_features(extractor.preprocess(inputs))
  graph_ops = tf.get_default_graph().get_operations()
  # Substring match so any FusedBatchNorm variant (e.g. V3) counts.
  has_fused_bn = any('FusedBatchNorm' in op.type for op in graph_ops)
  self.assertTrue(has_fused_bn)
示例6: test_overwriting_activation_fn
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def test_overwriting_activation_fn(self):
  """activation_fn=tf.nn.relu6 is honored by every ResNet-v1 variant.

  All architectures are built into the same default graph, so ops from
  earlier iterations remain visible in later ones. The resnet_v1_50-specific
  name checks therefore pass vacuously for the 101/152 variants. Each
  iteration also requires a Relu6 op whose name mentions the current
  architecture.
  """
  for architecture in ['resnet_v1_50', 'resnet_v1_101', 'resnet_v1_152']:
    feature_extractor = self._build_feature_extractor(
        first_stage_features_stride=16,
        architecture=architecture,
        activation_fn=tf.nn.relu6)
    preprocessed_inputs = tf.random_uniform([4, 224, 224, 3],
                                            maxval=255,
                                            dtype=tf.float32)
    rpn_feature_map, _ = feature_extractor.extract_proposal_features(
        preprocessed_inputs, scope='TestStage1Scope')
    _ = feature_extractor.extract_box_classifier_features(
        rpn_feature_map, scope='TestStaget2Scope')
    op_names = [
        op.name for op in tf.get_default_graph().get_operations()
        if op.type == 'Relu6'
    ]
    # A list is never None; require at least one Relu6 op instead.
    self.assertTrue(op_names, 'No Relu6 ops in the graph')
    arch_token = '/%s/' % architecture
    self.assertTrue(
        any(arch_token in name for name in op_names),
        'No Relu6 op found for %s' % architecture)
    self.assertIn('TestStage1Scope/resnet_v1_50/resnet_v1_50/conv1/Relu6',
                  op_names)
    self.assertIn(
        'TestStaget2Scope/resnet_v1_50/block4/unit_1/bottleneck_v1/conv1/Relu6',
        op_names)
示例7: testQuantizationBuilderSetsUpCorrectTrainArguments
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def testQuantizationBuilderSetsUpCorrectTrainArguments(self):
  """Training-mode quantization rewrite forwards graph/delay, summarizes vars.

  Builds a GraphRewriter proto with quantization (delay=10, 8-bit weights and
  activations) and runs the returned rewrite function. The test then checks
  the kwargs passed to the patched experimental_create_training_graph and
  the collection name passed to slim.summarize_collection.
  """
  quant_patch = mock.patch.object(
      contrib_quantize, 'experimental_create_training_graph')
  summarize_patch = mock.patch.object(slim, 'summarize_collection')
  with quant_patch as mock_quant_fn, summarize_patch as mock_summarize_col:
    rewriter_config = graph_rewriter_pb2.GraphRewriter()
    quant_config = rewriter_config.quantization
    quant_config.delay = 10
    quant_config.weight_bits = 8
    quant_config.activation_bits = 8
    rewrite_fn = graph_rewriter_builder.build(
        rewriter_config, is_training=True)
    rewrite_fn()
    _, kwargs = mock_quant_fn.call_args
    self.assertEqual(kwargs['input_graph'], tf.get_default_graph())
    self.assertEqual(kwargs['quant_delay'], 10)
    mock_summarize_col.assert_called_with('quant_vars')
示例8: test_output_nodes_for_tflite
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def test_output_nodes_for_tflite(self):
  """With unroll_length=1 the graph contains every name in `tflite_nodes`.

  Each expected name only has to appear as a substring of some op name.
  """
  image_height = 64
  image_width = 64
  depth_multiplier = 1.0
  pad_to_multiple = 1
  image_placeholder = tf.placeholder(tf.float32,
                                     [1, image_height, image_width, 3])
  feature_extractor = self._create_feature_extractor(depth_multiplier,
                                                     pad_to_multiple)
  preprocessed_image = feature_extractor.preprocess(image_placeholder)
  _ = feature_extractor.extract_features(preprocessed_image, unroll_length=1)
  tflite_nodes = [
      'raw_inputs/init_lstm_c',
      'raw_inputs/init_lstm_h',
      'raw_inputs/base_endpoint',
      'raw_outputs/lstm_c',
      'raw_outputs/lstm_h',
      'raw_outputs/base_endpoint_1',
      'raw_outputs/base_endpoint_2'
  ]
  ops_names = [op.name for op in tf.get_default_graph().get_operations()]
  for node in tflite_nodes:
    # Name the missing node so a failure is actionable.
    self.assertTrue(
        any(node in s for s in ops_names),
        'No op name contains %r' % node)
开发者ID:tensorflow,项目名称:models,代码行数:26,代码来源:lstm_ssd_interleaved_mobilenet_v2_feature_extractor_test.py
示例9: test_fixed_concat_nodes
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def test_fixed_concat_nodes(self):
  """With is_quantized=True the graph contains every name in `concat_nodes`.

  Each expected name only has to appear as a substring of some op name.
  """
  image_height = 64
  image_width = 64
  depth_multiplier = 1.0
  pad_to_multiple = 1
  image_placeholder = tf.placeholder(tf.float32,
                                     [1, image_height, image_width, 3])
  feature_extractor = self._create_feature_extractor(
      depth_multiplier, pad_to_multiple, is_quantized=True)
  preprocessed_image = feature_extractor.preprocess(image_placeholder)
  _ = feature_extractor.extract_features(preprocessed_image, unroll_length=1)
  concat_nodes = [
      'MobilenetV2_1/expanded_conv_16/project/Relu6',
      'MobilenetV2_2/expanded_conv_16/project/Relu6'
  ]
  ops_names = [op.name for op in tf.get_default_graph().get_operations()]
  for node in concat_nodes:
    # Name the missing node so a failure is actionable.
    self.assertTrue(
        any(node in s for s in ops_names),
        'No op name contains %r' % node)
开发者ID:tensorflow,项目名称:models,代码行数:21,代码来源:lstm_ssd_interleaved_mobilenet_v2_feature_extractor_test.py
示例10: run_benchmark
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def run_benchmark(bench_cnn, num_iters):
  """Runs the all-reduce benchmark.

  Args:
    bench_cnn: The BenchmarkCNN where params, the variable manager, and other
      attributes are obtained.
    num_iters: Number of iterations to do all-reduce for.

  Raises:
    ValueError: Invalid params of bench_cnn.
  """
  if bench_cnn.params.variable_update != 'replicated':
    # Adjacent literals are joined verbatim, so the trailing space matters.
    raise ValueError('--variable_update=replicated must be specified to use '
                     'the all-reduce benchmark')
  if bench_cnn.params.variable_consistency == 'relaxed':
    raise ValueError('--variable_consistency=relaxed is not supported')
  benchmark_op = build_graph(bench_cnn.raw_devices,
                             get_var_shapes(bench_cnn.model),
                             bench_cnn.variable_mgr, num_iters)
  init_ops = [
      tf.global_variables_initializer(),
      bench_cnn.variable_mgr.get_post_init_ops()
  ]
  # A no-op stands in for the loss op that run_graph expects.
  loss_op = tf.no_op()
  if bench_cnn.graph_file:
    # Optionally dump the GraphDef; a file name ending in 'txt' selects
    # text format, anything else binary.
    path, filename = os.path.split(bench_cnn.graph_file)
    as_text = filename.endswith('txt')
    log_fn('Writing GraphDef as %s to %s' % (
        'text' if as_text else 'binary', bench_cnn.graph_file))
    tf.train.write_graph(tf.get_default_graph().as_graph_def(add_shapes=True),
                         path, filename, as_text)
  run_graph(benchmark_op, bench_cnn, init_ops, loss_op)
# TODO(reedwm): Reduce redundancy with tf_cnn_benchmarks
示例11: find_ops
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def find_ops(optype, graph=None):
  """Find ops of a given type in a graph.

  Args:
    optype: operation type (e.g. Conv2D)
    graph: graph to search. Defaults to tf.get_default_graph(), which is the
      original (and backward-compatible) behavior.

  Returns:
    List of operations whose `type` equals `optype`, in the order returned by
    graph.get_operations().
  """
  if graph is None:
    graph = tf.get_default_graph()
  return [op for op in graph.get_operations() if op.type == optype]
示例12: _run_eval
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def _run_eval(self):
"""Evaluate a model every self.params.eval_interval_secs.
Returns:
Dictionary containing eval statistics. Currently returns an empty
dictionary.
Raises:
ValueError: If self.params.train_dir is unspecified.
"""
if self.params.train_dir is None:
raise ValueError('Trained model directory not specified')
graph_info = self._build_eval_graph()
saver = tf.train.Saver(self.variable_mgr.savable_variables())
summary_writer = tf.summary.FileWriter(self.params.eval_dir,
tf.get_default_graph())
target = ''
# TODO(huangyp): Check if checkpoints haven't updated for hours and abort.
while True:
with tf.Session(
target=target, config=create_config_proto(self.params)) as sess:
image_producer = None
try:
global_step = load_checkpoint(saver, sess, self.params.train_dir)
image_producer = self._initialize_eval_graph(
graph_info.enqueue_ops, graph_info.input_producer_op,
graph_info.local_var_init_op_group, sess)
except CheckpointNotFoundException:
log_fn('Checkpoint not found in %s' % self.params.train_dir)
else: # Only executes if an exception was not thrown
self._eval_once(sess, summary_writer, graph_info.fetches,
graph_info.summary_op, image_producer, global_step)
if image_producer is not None:
image_producer.done()
if self.params.eval_interval_secs <= 0:
break
time.sleep(self.params.eval_interval_secs)
return {}
示例13: remove_summaries
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def remove_summaries():
  """Empty the SUMMARIES collection of the default graph in place.

  The current contents are logged at debug level before being cleared.
  """
  graph = tf.get_default_graph()
  summaries_key = tf.GraphKeys.SUMMARIES
  log_debug("Remove summaries %s" % str(graph.get_collection(summaries_key)))
  # Mutate the graph's own collection list rather than a copy.
  summaries = graph.get_collection_ref(summaries_key)
  del summaries[:]
  assert not graph.get_collection(summaries_key)
示例14: framework
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def framework(msg='err'):
  """Return tf.contrib.framework, or a DummyModule stand-in under TF2.

  Args:
    msg: unused.

  Returns:
    When `is_tf2` is false, the `tensorflow.contrib.framework` module
    (imported lazily). Otherwise a DummyModule exposing arg_scope (None),
    get_name_scope, name_scope, deprecated, nest and argsort.
  """
  del msg  # Unused.
  if not is_tf2:
    from tensorflow.contrib import framework as contrib_framework  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
    return contrib_framework

  def _get_name_scope():
    # Resolved at call time against whatever graph is default then.
    return tf.get_default_graph().get_name_scope()

  return DummyModule(
      arg_scope=None,
      get_name_scope=_get_name_scope,
      name_scope=tf.name_scope,
      deprecated=deprecated,
      nest=tf.nest,
      argsort=tf.argsort)
示例15: _get_beta_accumulators
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def _get_beta_accumulators(self):
  """Return the (beta1_power, beta2_power) non-slot variables.

  Inside tf.init_scope(), the graph key passed to _get_non_slot_variable is
  None when executing eagerly and the default graph otherwise.
  """
  with tf.init_scope():
    graph = None if tf.executing_eagerly() else tf.get_default_graph()
    beta1_power = self._get_non_slot_variable("beta1_power", graph=graph)
    beta2_power = self._get_non_slot_variable("beta2_power", graph=graph)
    return (beta1_power, beta2_power)