本文整理汇总了 Python 中 tensorflow.get_default_graph 方法的典型用法代码示例。如果您正苦于以下问题：Python tensorflow.get_default_graph 方法的具体用法？Python tensorflow.get_default_graph 怎么用？Python tensorflow.get_default_graph 使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 tensorflow 的用法示例。
在下文中一共展示了tensorflow.get_default_graph方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: build_from_pb
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def build_from_pb(self):
    """Load a serialized graph plus its JSON metadata and wire up this object.

    Reads a binary ``GraphDef`` from ``self.FLAGS.pbLoad`` and imports it into
    the default graph, loads JSON from ``self.FLAGS.metaLoad`` into
    ``self.meta``, builds ``self.framework`` from that metadata, looks up the
    ``input:0`` and ``output:0`` tensors, and finally calls
    ``self.setup_meta_ops()``.
    """
    # tf.gfile.GFile is used instead of the deprecated tf.gfile.FastGFile.
    with tf.gfile.GFile(self.FLAGS.pbLoad, "rb") as pb_file:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(pb_file.read())

    # An empty name prefix keeps the imported node names unchanged, which is
    # why the tensors below are looked up as plain 'input:0' / 'output:0'.
    tf.import_graph_def(graph_def, name="")

    with open(self.FLAGS.metaLoad, 'r') as meta_file:
        self.meta = json.load(meta_file)
    self.framework = create_framework(self.meta, self.FLAGS)

    # Placeholders
    graph = tf.get_default_graph()
    self.inp = graph.get_tensor_by_name('input:0')
    self.feed = dict()  # other placeholders
    self.out = graph.get_tensor_by_name('output:0')

    self.setup_meta_ops()
示例2: init_uninited_vars
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def init_uninited_vars(vars=None):
    """Run the initializer of every given variable that is not initialized yet.

    Args:
        vars: Variables to check. When ``None``, ``tf.global_variables()``
            is used.
    """
    if vars is None:
        vars = tf.global_variables()

    candidates = []
    init_checks = []
    with tf.control_dependencies(None):  # ignore surrounding control_dependencies
        for variable in vars:
            assert is_tf_expression(variable)
            marker_name = variable.name.replace(':0', '/IsVariableInitialized:0')
            try:
                tf.get_default_graph().get_tensor_by_name(marker_name)
            except KeyError:
                # Op does not exist => variable may be uninitialized.
                candidates.append(variable)
                with absolute_name_scope(variable.name.split(':')[0]):
                    init_checks.append(tf.is_variable_initialized(variable))

    # Evaluate all checks in one run() call, then initialize what is missing.
    check_results = run(init_checks)
    pending = [
        variable
        for variable, is_ready in zip(candidates, check_results)
        if not is_ready
    ]
    run([variable.initializer for variable in pending])
#----------------------------------------------------------------------------
# Set the values of given tf.Variables.
# Equivalent to the following, but more efficient and does not bloat the tf graph:
# tfutil.run([tf.assign(var, value) for var, value in var_to_value_dict.items()])
示例3: _py_func_with_gradient
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def _py_func_with_gradient(func, inp, Tout, stateful=True, name=None,
                           grad_func=None):
    """
    Wrap ``func`` in ``tf.py_func`` with ``grad_func`` registered as gradient.

    :param func: Custom function handed to ``tf.py_func``
    :param inp: Function inputs handed to ``tf.py_func``
    :param Tout: Output type of the custom function
    :param stateful: Forwarded unchanged to ``tf.py_func``
    :param name: Name of the PyFunction
    :param grad_func: Custom gradient function registered under a random name
    :return: The result of the ``tf.py_func`` call
    """
    # Generate random name in order to avoid conflicts with inbuilt names
    gradient_name = 'PyFuncGrad-' + '%0x' % getrandbits(30 * 4)

    # Register the gradient function under that name
    register_gradient = tf.RegisterGradient(gradient_name)
    register_gradient(grad_func)

    # Map both PyFunc op types onto the freshly registered gradient
    overrides = {"PyFunc": gradient_name, "PyFuncStateless": gradient_name}
    with tf.get_default_graph().gradient_override_map(overrides):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
示例4: network_surgery
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def network_surgery():
    """Re-import a saved meta graph onto new placeholders and export it again.

    Resets the default graph, creates four placeholders, imports
    ``model_files/model.tf.meta`` with those placeholders mapped onto the
    saved graph's tensors, and writes the result as text to
    ``model_files/model.tf-modified.meta``. Returns nothing.
    """
    tf.reset_default_graph()

    # Creation order is kept as-is: it determines op order in the graph.
    inputs = tf.placeholder(
        tf.float32, shape=(None, 131072, 4), name='inputs')
    targets = tf.placeholder(
        tf.float32, shape=(None, 1024, 4229), name='targets')
    targets_na = tf.placeholder(
        tf.bool, shape=(None, 1024), name="targets_na")
    preds_adhoc = tf.placeholder(
        tf.float32, shape=(None, 960, 4229), name="Placeholder_15")

    replacements = {
        'Placeholder_15:0': preds_adhoc,
        'Placeholder:0': targets_na,
        'inputs:0': inputs,
        'targets:0': targets,
    }
    restored_saver = tf.train.import_meta_graph(
        "model_files/model.tf.meta", input_map=replacements)

    graph_ops = tf.get_default_graph().get_operations()
    exported = tf.train.export_meta_graph(
        filename='model_files/model.tf-modified.meta', as_text=True)
    # NOTE(review): bare expression whose value is discarded -- looks like a
    # leftover from interactive use; confirm before removing.
    graph_ops[:15]
示例5: main
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def main(_):
    """Build the eval graph, then evaluate once or repeatedly.

    The single positional argument is ignored. Evaluation repeats with a
    pause of ``FLAGS.eval_interval_secs`` seconds until ``FLAGS.run_once``
    is truthy after a run.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.gfile.MakeDirs(FLAGS.eval_dir)
    tf.logging.info('Building eval graph...')

    model = graphs.get_model()
    eval_ops, moving_averaged_variables = model.eval_graph(FLAGS.eval_data)

    saver = tf.train.Saver(moving_averaged_variables)
    writer = tf.summary.FileWriter(
        FLAGS.eval_dir, graph=tf.get_default_graph())

    # Always evaluate at least once; keep going only when not in run-once mode.
    run_eval(eval_ops, writer, saver)
    while not FLAGS.run_once:
        time.sleep(FLAGS.eval_interval_secs)
        run_eval(eval_ops, writer, saver)
示例6: count_weights
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def count_weights(scope=None, exclude=None, graph=None):
    """Count learnable parameters.

    Args:
      scope: Restrict the count to variables whose name starts with this
        scope; a trailing '/' is appended when missing.
      exclude: Regex; variables whose name matches it from the start are
        left out.
      graph: Operate on a graph other than the current default graph.

    Returns:
      Number of learnable parameters as integer.
    """
    prefix = scope
    if prefix and not prefix.endswith('/'):
        prefix = prefix + '/'

    target_graph = graph or tf.get_default_graph()
    variables = target_graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

    if prefix:
        variables = [v for v in variables if v.name.startswith(prefix)]
    if exclude:
        pattern = re.compile(exclude)
        variables = [v for v in variables if not pattern.match(v.name)]

    # Each variable contributes the product of its static shape dimensions.
    total = 0
    for variable in variables:
        total += np.prod(variable.get_shape().as_list())
    return int(total)
示例7: build_from_pb
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def build_from_pb(self):
    """Restore a frozen graph and its metadata, then finish setup.

    Steps, in order: parse the binary ``GraphDef`` stored at
    ``self.FLAGS.pbLoad`` and import it into the default graph; load the JSON
    at ``self.FLAGS.metaLoad`` into ``self.meta``; create ``self.framework``
    from the metadata; fetch the ``input:0`` and ``output:0`` tensors; call
    ``self.setup_meta_ops()``.
    """
    # tf.gfile.GFile is used instead of the deprecated tf.gfile.FastGFile.
    with tf.gfile.GFile(self.FLAGS.pbLoad, "rb") as graph_file:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(graph_file.read())

    # name="" imports the nodes without a prefix, so the lookups below can
    # use the bare tensor names.
    tf.import_graph_def(graph_def, name="")

    with open(self.FLAGS.metaLoad, 'r') as meta_fp:
        self.meta = json.load(meta_fp)
    self.framework = create_framework(self.meta, self.FLAGS)

    # Placeholders
    default_graph = tf.get_default_graph()
    self.inp = default_graph.get_tensor_by_name('input:0')
    self.feed = dict()  # other placeholders
    self.out = default_graph.get_tensor_by_name('output:0')

    self.setup_meta_ops()
示例8: __init__
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def __init__(self, checkpoint_filename, input_name="images",
             output_name="features"):
    """Open a session and import a serialized graph under the "net" prefix.

    Args:
        checkpoint_filename: Path of the binary ``GraphDef`` file to load.
        input_name: Name (without prefix and ``:0``) of the input tensor.
        output_name: Name (without prefix and ``:0``) of the output tensor.

    The output tensor must have rank 2 and the input tensor rank 4; both are
    checked with ``assert``.
    """
    self.session = tf.Session()

    graph_def = tf.GraphDef()
    with tf.gfile.GFile(checkpoint_filename, "rb") as file_handle:
        graph_def.ParseFromString(file_handle.read())
    tf.import_graph_def(graph_def, name="net")

    graph = tf.get_default_graph()
    self.input_var = graph.get_tensor_by_name("net/%s:0" % input_name)
    self.output_var = graph.get_tensor_by_name("net/%s:0" % output_name)

    output_shape = self.output_var.get_shape()
    input_shape = self.input_var.get_shape()
    assert len(output_shape) == 2
    assert len(input_shape) == 4

    # Last output dimension, and all input dimensions after the first.
    self.feature_dim = output_shape.as_list()[-1]
    self.image_shape = input_shape.as_list()[1:]
示例9: main
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def main():
    """Restore a checkpoint and write the network out as a constant GraphDef.

    Builds the network on a uint8 ``images`` placeholder inside a fresh graph,
    restores variables from ``args.checkpoint_in``, converts them to
    constants, and serializes the result to ``args.graphdef_out``.
    """
    args = parse_args()

    with tf.Session(graph=tf.Graph()) as session:
        images = tf.placeholder(tf.uint8, (None, 128, 64, 3), name="images")
        float_images = tf.cast(images, tf.float32)
        preprocessed = tf.map_fn(
            lambda image: _preprocess(image), float_images, back_prop=False)

        build_network = _network_factory()
        raw_features, _ = build_network(preprocessed, reuse=None)
        # Give the output a fixed node name so it can be looked up later.
        features = tf.identity(raw_features, name="features")

        restorer = tf.train.Saver(slim.get_variables_to_restore())
        restorer.restore(session, args.checkpoint_in)

        output_node_names = [features.name.split(":")[0]]
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            session, tf.get_default_graph().as_graph_def(), output_node_names)

        with tf.gfile.GFile(args.graphdef_out, "wb") as file_handle:
            file_handle.write(frozen_graph_def.SerializeToString())
示例10: __init__
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def __init__(self, train_df, word_count, batch_size, epochs):
    """Configure the TF session and record training and vocabulary sizes.

    Args:
        train_df: Object exposing ``brand_name``, ``item_condition_id``,
            ``subcat_0``, ``subcat_1`` and ``subcat_2`` attributes, each
            with a ``max()`` method (presumably a pandas DataFrame --
            confirm against the caller).
        word_count: Stored, plus one, as ``self.max_text``.
        batch_size: Stored as ``self.batch_size``.
        epochs: Stored as ``self.epochs``.
    """
    tf.set_random_seed(4)
    session_conf = tf.ConfigProto(
        intra_op_parallelism_threads=2, inter_op_parallelism_threads=8)
    session = tf.Session(graph=tf.get_default_graph(), config=session_conf)
    backend.set_session(session)

    def one_past_max(column):
        # Largest value in the column, plus one.
        return np.max(column.max()) + 1

    self.batch_size = batch_size
    self.epochs = epochs

    self.max_name_seq = 10
    self.max_item_desc_seq = 75
    self.max_text = word_count + 1

    self.max_brand = one_past_max(train_df.brand_name)
    self.max_condition = one_past_max(train_df.item_condition_id)
    self.max_subcat0 = one_past_max(train_df.subcat_0)
    self.max_subcat1 = one_past_max(train_df.subcat_1)
    self.max_subcat2 = one_past_max(train_df.subcat_2)
示例11: __init__
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def __init__(self, num_features, discriminator=discriminator,
             generator=generator_gatedcnn, mode='train', log_dir='./log'):
    """Build the model, start a session, and set up logging in train mode.

    Args:
        num_features: Stored and used as the middle entry of
            ``self.input_shape``.
        discriminator: Stored as ``self.discriminator``.
        generator: Stored as ``self.generator``.
        mode: When equal to ``'train'``, a timestamped log directory,
            summary writer and summaries are also created.
        log_dir: Parent directory of the timestamped log directory.
    """
    self.num_features = num_features
    # [batch_size, num_features, num_frames]
    self.input_shape = [None, num_features, None]

    self.discriminator = discriminator
    self.generator = generator
    self.mode = mode

    self.build_model()
    self.optimizer_initializer()

    self.saver = tf.train.Saver()
    self.sess = tf.Session()
    self.sess.run(tf.global_variables_initializer())

    if self.mode != 'train':
        return

    self.train_step = 0
    timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
    self.log_dir = os.path.join(log_dir, timestamp)
    self.writer = tf.summary.FileWriter(self.log_dir, tf.get_default_graph())
    self.generator_summaries, self.discriminator_summaries = self.summary()
示例12: test_with_dynamic_shape
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def test_with_dynamic_shape(self):
    """Check the map op is created and row sums are right for two batch sizes.

    The placeholder's first dimension is ``None``, so the two feeds below use
    3 and 2 rows respectively.
    """
    def sum_elements(tensor):
        return tf.reduce_sum(tensor)

    input_tensor = tf.placeholder(tf.float32, shape=(None, 2))
    map_fn_output = shape_utils.static_or_dynamic_map_fn(
        sum_elements, input_tensor)

    # At least one op name in the default graph must start with 'map'.
    graph_ops = tf.get_default_graph().get_operations()
    self.assertTrue(any(op.name.startswith('map') for op in graph_ops))

    first_batch = [[1, 2], [3, 1], [0, 4]]
    second_batch = [[-1, 1], [0, 9]]
    with self.test_session() as sess:
        result1 = sess.run(map_fn_output, feed_dict={input_tensor: first_batch})
        result2 = sess.run(map_fn_output, feed_dict={input_tensor: second_batch})
        self.assertAllEqual(result1, [3, 4, 4])
        self.assertAllEqual(result2, [0, 9])
示例13: test_has_fused_batchnorm
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def test_has_fused_batchnorm(self, use_keras):
    """Check that building the extractor adds a 'FusedBatchNorm' op.

    Args:
        use_keras: Forwarded to ``self._create_feature_extractor``; when
            truthy the extractor is called directly, otherwise its
            ``extract_features`` method is used.
    """
    image_height = image_width = 40
    depth_multiplier = pad_to_multiple = 1

    images = tf.placeholder(tf.float32, [1, image_height, image_width, 3])
    extractor = self._create_feature_extractor(
        depth_multiplier, pad_to_multiple, use_keras=use_keras)
    preprocessed = extractor.preprocess(images)

    # Only the graph-building side effect matters; the output is discarded.
    build_features = extractor if use_keras else extractor.extract_features
    build_features(preprocessed)

    graph_ops = tf.get_default_graph().get_operations()
    self.assertTrue(any(op.type == 'FusedBatchNorm' for op in graph_ops))
开发者ID:ahmetozlu,项目名称:vehicle_counting_tensorflow,代码行数:19,代码来源:ssd_mobilenet_v2_feature_extractor_test.py
示例14: build
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def build(graph_rewriter_config, is_training):
    """Returns a function that modifies default graph based on options.

    Args:
      graph_rewriter_config: graph_rewriter_pb2.GraphRewriter proto.
      is_training: whether in training or eval mode.

    Returns:
      A zero-argument function; calling it raises ValueError unless both
      weight_bits and activation_bits of the quantization config equal 8.
    """
    def graph_rewrite_fn():
        """Function to quantize weights and activation of the default graph."""
        quantization = graph_rewriter_config.quantization
        if not (quantization.weight_bits == 8 and
                quantization.activation_bits == 8):
            raise ValueError('Only 8bit quantization is supported')

        # Quantize the graph by inserting quantize ops for weights and
        # activations
        graph = tf.get_default_graph()
        if is_training:
            tf.contrib.quantize.create_training_graph(
                input_graph=graph, quant_delay=quantization.delay)
        else:
            tf.contrib.quantize.create_eval_graph(input_graph=graph)

        tf.contrib.layers.summarize_collection('quant_vars')

    return graph_rewrite_fn
示例15: testQuantizationBuilderSetsUpCorrectTrainArguments
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def testQuantizationBuilderSetsUpCorrectTrainArguments(self):
    """Check the training rewrite passes the right arguments to quantize.

    ``create_training_graph`` and ``summarize_collection`` are patched, the
    builder is run with ``is_training=True``, and the recorded keyword
    arguments are compared against the default graph and the configured
    delay of 10.
    """
    graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter()
    graph_rewriter_proto.quantization.delay = 10
    graph_rewriter_proto.quantization.weight_bits = 8
    graph_rewriter_proto.quantization.activation_bits = 8

    patch_quant = mock.patch.object(
        tf.contrib.quantize, 'create_training_graph')
    patch_summarize = mock.patch.object(
        tf.contrib.layers, 'summarize_collection')

    with patch_quant as mock_quant_fn, patch_summarize as mock_summarize_col:
        graph_rewrite_fn = graph_rewriter_builder.build(
            graph_rewriter_proto, is_training=True)
        graph_rewrite_fn()

        # call_args is an (args, kwargs) pair; only the kwargs are checked.
        call_kwargs = mock_quant_fn.call_args[1]
        self.assertEqual(call_kwargs['input_graph'], tf.get_default_graph())
        self.assertEqual(call_kwargs['quant_delay'], 10)
        mock_summarize_col.assert_called_with('quant_vars')