当前位置: 首页>>代码示例>>Python>>正文


Python context.in_graph_mode函数代码示例

本文整理汇总了Python中tensorflow.python.eager.context.in_graph_mode函数的典型用法代码示例。如果您正苦于以下问题：Python in_graph_mode函数的具体用法？Python in_graph_mode怎么用？Python in_graph_mode使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了in_graph_mode函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: testAddWeight

  def testAddWeight(self):
    """Checks Layer.add_variable for trainable, frozen and regularized vars."""
    layer = base_layers.Layer(name='my_layer')
    zeros = init_ops.zeros_initializer

    # A default (trainable) variable is named under the layer's scope and is
    # the only variable the layer tracks so far.
    trainable_var = layer.add_variable(
        'my_var', [2, 2], initializer=zeros())
    self.assertEqual(trainable_var.name, 'my_layer/my_var:0')
    self.assertListEqual(layer.variables, [trainable_var])
    self.assertListEqual(layer.trainable_variables, [trainable_var])
    self.assertListEqual(layer.non_trainable_variables, [])
    if context.in_graph_mode():
      # Graph mode also exposes it through the TRAINABLE_VARIABLES collection.
      self.assertListEqual(
          layer.variables,
          ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))

    # trainable=False must keep the variable out of the trainable lists; this
    # also checks that add_variable works outside `build` and `call`.
    frozen_var = layer.add_variable(
        'non_trainable_var', [2, 2],
        initializer=zeros(),
        trainable=False)
    self.assertListEqual(layer.variables, [trainable_var, frozen_var])
    self.assertListEqual(layer.trainable_variables, [trainable_var])
    self.assertListEqual(layer.non_trainable_variables, [frozen_var])

    # The remaining checks (collections, regularizers) are graph-mode only.
    if not context.in_graph_mode():
      return

    self.assertEqual(
        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)

    # Adding a regularized variable should leave exactly one entry in
    # layer.losses.
    sum_penalty = lambda x: math_ops.reduce_sum(x) * 1e-3
    layer.add_variable(
        'reg_var', [2, 2],
        initializer=zeros(),
        regularizer=sum_penalty)
    self.assertEqual(len(layer.losses), 1)
开发者ID:DjangoPeng,项目名称:tensorflow,代码行数:35,代码来源:base_test.py

示例2: testAddVariable

  def testAddVariable(self):
    """Checks naming, initialization and checkpoint keys of added variables."""
    obj = NonLayerCheckpointable()
    # `six.assertRaisesRegex` works on both Python 2 and 3. The
    # `assertRaisesRegexp` alias is deprecated and was removed in Python 3.12.
    with six.assertRaisesRegex(self, ValueError, "do not specify shape"):
      checkpointable_utils.add_variable(
          obj, name="shape_specified_twice", shape=[], initializer=1)
    constant_initializer = checkpointable_utils.add_variable(
        obj, name="constant_initializer", initializer=1)
    with variable_scope.variable_scope("some_variable_scope"):
      ones_initializer = checkpointable_utils.add_variable(
          obj,
          name="ones_initializer",
          shape=[2],
          initializer=init_ops.ones_initializer(dtype=dtypes.float32))
    # The initializer class is passed here without being instantiated.
    bare_initializer = checkpointable_utils.add_variable(
        obj,
        name="bare_initializer",
        shape=[2, 2],
        dtype=dtypes.float64,
        initializer=init_ops.zeros_initializer)

    # Even in graph mode, there are no naming conflicts between objects, only
    # naming conflicts within an object.
    other_duplicate = resource_variable_ops.ResourceVariable(
        name="duplicate", initial_value=1.)
    duplicate = checkpointable_utils.add_variable(
        obj, name="duplicate", shape=[])
    with six.assertRaisesRegex(self, ValueError,
                               "'duplicate' already exists"):
      checkpointable_utils.add_variable(obj, name="duplicate", shape=[])

    if context.in_graph_mode():
      self.evaluate(variables.global_variables_initializer())
    self.assertEqual("constant_initializer:0", constant_initializer.name)
    self.assertEqual(1, self.evaluate(constant_initializer))
    self.assertEqual("some_variable_scope/ones_initializer:0",
                     ones_initializer.name)
    self.assertAllEqual([1, 1], self.evaluate(ones_initializer))
    self.assertAllEqual([[0., 0.],
                         [0., 0.]], self.evaluate(bare_initializer))
    self.assertEqual("a_variable:0", obj.a_variable.name)
    self.assertEqual("duplicate:0", other_duplicate.name)
    if context.in_graph_mode():
      # The .name attribute may be globally influenced, but the checkpoint name
      # won't be (tested below).
      self.assertEqual("duplicate_1:0", duplicate.name)
    else:
      # When executing eagerly, there's no uniquification of variable names. The
      # checkpoint name will be the same.
      self.assertEqual("duplicate:0", duplicate.name)
    named_variables, _ = checkpointable_utils._serialize_object_graph(obj)
    expected_checkpoint_names = (
        "a_variable/.ATTRIBUTES/VARIABLE_VALUE",
        "bare_initializer/.ATTRIBUTES/VARIABLE_VALUE",
        "constant_initializer/.ATTRIBUTES/VARIABLE_VALUE",
        "duplicate/.ATTRIBUTES/VARIABLE_VALUE",
        "ones_initializer/.ATTRIBUTES/VARIABLE_VALUE",
    )
    six.assertCountEqual(
        self, expected_checkpoint_names, named_variables.keys())
开发者ID:dananjayamahesh,项目名称:tensorflow,代码行数:58,代码来源:checkpointable_utils_test.py

示例3: testDeferredSlotRestoration

  def testDeferredSlotRestoration(self):
    """Checks restoring optimizer slot variables before they are created.

    Two checkpoints are written: one before the optimizer is attached to the
    root (so the optimizer is not part of the saved object graph) and one
    after. Both are restored into a fresh object graph whose variable and
    optimizer do not exist yet. The test verifies that the pending restores
    complete as those objects are created.
    """
    checkpoint_directory = self.get_temp_dir()

    root = checkpointable.Checkpointable()
    root.var = checkpointable_utils.add_variable(
        root, name="var", initializer=0.)
    optimizer = CheckpointableAdam(0.1)
    # Run one minimize step; the "m" slot fetched via get_slot below is
    # expected to exist afterwards.
    if context.in_graph_mode():
      train_op = optimizer.minimize(root.var)
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(train_op)
    else:
      # Eagerly, the loss is passed as a callable (root.var.read_value).
      optimizer.minimize(root.var.read_value)
    # Checkpoint "no_slots": var == 12, optimizer not yet attached to root.
    self.evaluate(state_ops.assign(root.var, 12.))
    no_slots_path = checkpointable_utils.Saver(root).save(
        os.path.join(checkpoint_directory, "no_slots"))
    # Checkpoint "with_slots": optimizer attached, var == 13, slot "m" == 14.
    root.optimizer = optimizer
    self.evaluate(state_ops.assign(root.var, 13.))
    self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
                                   14.))
    slots_path = checkpointable_utils.Saver(root).save(
        os.path.join(checkpoint_directory, "with_slots"))
    new_root = checkpointable.Checkpointable()
    # Load the slot-containing checkpoint (deferred), then immediately overwrite
    # the non-slot variable (also deferred).
    slot_status = checkpointable_utils.Saver(new_root).restore(slots_path)
    no_slot_status = checkpointable_utils.Saver(new_root).restore(no_slots_path)
    # assert_consumed fails while new_root.var does not exist yet.
    with self.assertRaises(AssertionError):
      no_slot_status.assert_consumed()
    # Creating the variable lets the pending "no_slots" restore match it.
    new_root.var = checkpointable_utils.add_variable(
        new_root, name="var", shape=[])
    no_slot_status.assert_consumed()
    no_slot_status.run_restore_ops()
    self.assertEqual(12., self.evaluate(new_root.var))
    new_root.optimizer = CheckpointableAdam(0.1)
    # The slot checkpoint is still not fully consumed; the failure message is
    # expected to mention beta1_power.
    with self.assertRaisesRegexp(AssertionError, "beta1_power"):
      slot_status.assert_consumed()
    # var keeps the value from "no_slots", the checkpoint restored last.
    self.assertEqual(12., self.evaluate(new_root.var))
    if context.in_eager_mode():
      # Slot variables are only created with restoring initializers when
      # executing eagerly.
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
    else:
      # In graph mode the slot does not exist until minimize() is called below.
      self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
                    None)
    if context.in_graph_mode():
      train_op = new_root.optimizer.minimize(new_root.var)
      # The slot variable now exists; restore() didn't create it, but we should
      # now have a restore op for it.
      slot_status.run_restore_ops()
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
      self.evaluate(train_op)
    else:
      new_root.optimizer.minimize(new_root.var.read_value)
    # Everything in the slot-containing checkpoint has now been matched.
    slot_status.assert_consumed()
开发者ID:keithc61,项目名称:tensorflow,代码行数:57,代码来源:checkpointable_utils_test.py

示例4: testActivation

  def testActivation(self):
    """Checks which op produces a Dense layer's output, with/without relu."""
    # (constructor kwargs, expected name of the op producing the output)
    cases = (
        (dict(activation=nn_ops.relu, name='dense1'), 'dense1/Relu'),
        (dict(name='dense2'), 'dense2/BiasAdd'),
    )
    for layer_kwargs, expected_op_name in cases:
      layer = core_layers.Dense(2, **layer_kwargs)
      outputs = layer(random_ops.random_uniform((5, 3), seed=1))
      if context.in_graph_mode():
        # Op names are only checked for graph-mode (symbolic) outputs.
        self.assertEqual(outputs.op.name, expected_op_name)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:12,代码来源:core_test.py

示例5: test_variable_reuse_exception_nested

 def test_variable_reuse_exception_nested(self):
   """A variable from an outer isolated test is unusable in a nested one."""
   with test_util.IsolateTest(), session.Session():
     outer_variable = resource_variable_ops.ResourceVariable(
         name="first_container_variable",
         initial_value=1)
     if context.in_graph_mode():
       self.evaluate([variables.global_variables_initializer()])
     with test_util.IsolateTest(), session.Session():
       # Eagerly, the read itself is expected to raise ValueError; in graph
       # mode, evaluating the read is expected to raise RuntimeError.
       if not context.in_graph_mode():
         with self.assertRaises(ValueError):
           outer_variable.read_value()
       else:
         with self.assertRaises(RuntimeError):
           self.evaluate(outer_variable.read_value())
开发者ID:SylChan,项目名称:tensorflow,代码行数:14,代码来源:test_util_test.py

示例6: test_name_scopes_for_variable_scopes

  def test_name_scopes_for_variable_scopes(self):
    """Checks name-scope uniquification for templates that share a name.

    Uses `assertEqual`; the `assertEquals` alias is deprecated and was removed
    from unittest in Python 3.12.
    """
    # Test that name scopes are not unnecessarily uniquified (but are
    # still uniquified when necessary).
    def linear_module(x, output_size):
      w = variable_scope.get_variable(
          "w", shape=[x.get_shape()[1], output_size],
          initializer=init_ops.zeros_initializer())
      b = variable_scope.get_variable(
          "b", shape=[output_size],
          initializer=init_ops.zeros_initializer())
      return (math_ops.matmul(x, w) + b), w

    def make_linear_module(output_size, name):
      return template.make_template(
          name,
          linear_module,
          output_size=output_size,
          create_scope_now_=True)

    inputs = array_ops.ones((3, 4))

    linear1 = make_linear_module(output_size=2, name="foo")
    outputs_a, w1 = linear1(inputs)
    outputs_b, _ = linear1(inputs)
    self.assertEqual("foo", linear1.variable_scope.name)
    self.assertEqual("foo/w:0", w1.name)
    if context.in_graph_mode():
      self.assertEqual("foo/add:0", outputs_a.name,
                       "First application of template should get "
                       "same name scope as variables.")
      self.assertEqual("foo_1/add:0", outputs_b.name,
                       "Second application of template should get "
                       "a freshly uniquified name scope.")

    linear2 = make_linear_module(output_size=2, name="foo")
    outputs_c, w2 = linear2(inputs)
    outputs_d, _ = linear2(inputs)
    self.assertEqual("foo_1", linear2.variable_scope.name,
                     "New template gets a freshly uniquified variable scope "
                     "because 'foo' is already taken.")
    self.assertEqual("foo_1/w:0", w2.name)
    if context.in_graph_mode():
      self.assertEqual("foo_1_1/add:0", outputs_c.name,
                       "First application of template would get "
                       "same name scope as variables, but 'foo_1' is already "
                       "a name scope.")
      self.assertEqual("foo_1_2/add:0", outputs_d.name,
                       "Second application of template should also get "
                       "a freshly uniquified name scope.")
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:49,代码来源:template_test.py

示例7: __call__

  def __call__(self, *args):
    """Invokes the wrapped function on the Tensors found in `args`.

    `args` is flattened and only its `ops.Tensor` leaves are used as inputs.
    If the tape records any of those inputs (or `self._extra_inputs`), the
    backprop-aware call path is used. Otherwise, in graph mode a single
    function-call op is added to the default graph; when executing eagerly
    the function is executed directly.
    """
    tensor_inputs = [
        leaf for leaf in nest.flatten(args) if isinstance(leaf, ops.Tensor)]

    needs_backprop = (tape.should_record(tensor_inputs) or
                      tape.should_record(self._extra_inputs))
    if needs_backprop:
      if not self._has_backprop:
        self._compute_backprop()
      return self._backprop_call(tensor_inputs)

    if not context.in_graph_mode():
      result = execute.execute(
          str(self._func_name),
          num_outputs=self._num_outputs,
          inputs=tensor_inputs + self._extra_inputs)
      return self._build_call_outputs(self._returns, result)

    graph = ops.get_default_graph()
    # Register the function definition with this graph only once.
    if self._fdef.name not in graph._functions:  # pylint: disable=protected-access
      graph._add_function(self._fdef)  # pylint: disable=protected-access
    signature = self._fdef.definition.signature
    call_inputs = list(tensor_inputs) + self._extra_inputs
    call_op = graph.create_op(
        signature.name,
        [ops.convert_to_tensor(value) for value in call_inputs],
        [dtypes.DType(out_arg.type) for out_arg in signature.output_arg],
        op_def=signature,
        name="FunctionCall",
        compute_shapes=False)
    result = call_op.outputs
    # Shape inference is disabled above (compute_shapes=False), so apply the
    # recorded output shapes explicitly.
    for index, shape in enumerate(self._output_shapes):
      result[index].set_shape(shape)
    return self._build_call_outputs(self._returns, result)
开发者ID:allanbian1017,项目名称:tensorflow,代码行数:34,代码来源:function.py

示例8: test_no_sharing

 def test_no_sharing(self):
   """Same-named variables in nested isolated tests keep separate values."""
   with test_util.IsolateTest(), session.Session():
     outer = resource_variable_ops.ResourceVariable(
         name="same_name",
         initial_value=1)
     if context.in_graph_mode():
       self.evaluate([variables.global_variables_initializer()])
     with test_util.IsolateTest(), session.Session():
       inner = resource_variable_ops.ResourceVariable(
           name="same_name",
           initial_value=2)
       if context.in_graph_mode():
         self.evaluate([variables.global_variables_initializer()])
       inner_value = self.evaluate(inner.read_value())
       self.assertEqual(2, inner_value)
     # Back in the outer isolated test, the original value is still visible.
     outer_value = self.evaluate(outer.read_value())
     self.assertEqual(1, outer_value)
开发者ID:SylChan,项目名称:tensorflow,代码行数:16,代码来源:test_util_test.py

示例9: _init_from_proto

  def _init_from_proto(self, variable_def, import_scope=None):
    """Rebuilds this variable from a serialized `VariableDef` proto.

    Args:
      variable_def: A `variable_pb2.VariableDef` whose `is_resource` is set.
      import_scope: Optional name scope prepended to every graph-element name
        read from `variable_def` before it is looked up in the default graph.

    Raises:
      ValueError: If `variable_def` does not describe a resource variable.
    """
    # Restoring from a proto is only supported in graph mode (checked with
    # assert, so the check is skipped when assertions are disabled).
    assert context.in_graph_mode()
    self._in_graph_mode = True
    assert isinstance(variable_def, variable_pb2.VariableDef)
    if not variable_def.is_resource:
      raise ValueError("Trying to restore Variable as ResourceVariable.")

    graph = ops.get_default_graph()

    def lookup(name):
      # Resolve `name`, prefixed with `import_scope`, in the default graph.
      return graph.as_graph_element(
          ops.prepend_name_scope(name, import_scope=import_scope))

    self._handle = lookup(variable_def.variable_name)
    self._handle_device = self._handle.device
    self._handle_name = self._handle.name
    self._initializer_op = lookup(variable_def.initializer_name)
    self._cached_value = (
        lookup(variable_def.snapshot_name) if variable_def.snapshot_name
        else None)
    if variable_def.HasField("save_slice_info_def"):
      self._save_slice_info = variables.Variable.SaveSliceInfo(
          save_slice_info_def=variable_def.save_slice_info_def)
    else:
      self._save_slice_info = None
    self._caching_device = None
    self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
    self._graph_element = self.value()
    self._constraint = None
开发者ID:1000sprites,项目名称:tensorflow,代码行数:34,代码来源:resource_variable_ops.py

示例10: new_func

 def new_func(*args, **kwargs):
   """Calls `func`, first warning about any deprecated arguments in the call.

   Warnings are only considered in graph mode and when
   `_PRINT_DEPRECATION_WARNINGS` is set. With `warn_once`, each
   (func, argument) pair is reported at most once.
   """
   # TODO(apassos) figure out a way to have reasonable performance with
   # deprecation warnings and eager mode.
   if not (context.in_graph_mode() and _PRINT_DEPRECATION_WARNINGS):
     return func(*args, **kwargs)

   invalid_args = []
   named_args = tf_inspect.getcallargs(func, *args, **kwargs)

   # Deprecated arguments passed positionally, unless given their OK value.
   for arg_name, position_spec in deprecated_positions.items():
     if position_spec.position >= len(args):
       continue
     if (position_spec.has_ok_value and
         _same_value(named_args[arg_name], position_spec.ok_value)):
       continue
     invalid_args.append(arg_name)

   # Deprecated *varargs (extra positional values were passed) and
   # **kwargs (any keyword argument was passed).
   if is_varargs_deprecated and len(args) > len(arg_spec.args):
     invalid_args.append(arg_spec.varargs)
   if is_kwargs_deprecated and kwargs:
     invalid_args.append(arg_spec.keywords)

   # Deprecated arguments passed by keyword, unless given their OK value.
   for arg_name in deprecated_arg_names:
     if arg_name not in kwargs:
       continue
     keyword_spec = deprecated_positions[arg_name]
     if (keyword_spec.has_ok_value and
         _same_value(named_args[arg_name], keyword_spec.ok_value)):
       continue
     invalid_args.append(arg_name)

   for arg_name in invalid_args:
     warning_key = (func, arg_name)
     if warning_key in _PRINTED_WARNING:
       continue
     if warn_once:
       _PRINTED_WARNING[warning_key] = True
     logging.warning(
         'From %s: calling %s (from %s) with %s is deprecated and will '
         'be removed %s.\nInstructions for updating:\n%s',
         _call_location(), decorator_utils.get_qualified_name(func),
         func.__module__, arg_name,
         'in a future version' if date is None else ('after %s' % date),
         instructions)
   return func(*args, **kwargs)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:34,代码来源:deprecation.py

示例11: testRandomSeed

 def testRandomSeed(self):
   """Checks random_seed.get_seed for combinations of graph and op seeds."""
   # Each entry pairs the (graph_seed, op_seed) passed in with the
   # (graph_seed, op_seed) pair get_seed is expected to return.
   test_cases = [
       ((None, None), (None, None)),
       ((None, 1), (random_seed.DEFAULT_GRAPH_SEED, 1)),
       ((1, 1), (1, 1)),
       ((0, 0), (0, 2**31 - 1)),  # Avoid nondeterministic (0, 0) output
       ((2**31 - 1, 0), (0, 2**31 - 1)),  # Don't wrap to (0, 0) either
       ((0, 2**31 - 1), (0, 2**31 - 1)),  # Wrapping for the other argument
   ]
   # With only a graph seed set, graph mode derives the op seed from
   # default_graph._lastid (0 here). When executing eagerly the op seed is a
   # random number generated from the global seed; that case is not tested
   # because it may differ across platforms or versions.
   if context.in_graph_mode():
     test_cases.append(((1, None), (1, 0)))
   for (graph_in, op_in), expected in test_cases:
     random_seed.set_random_seed(graph_in)
     g_seed, op_seed = random_seed.get_seed(op_in)
     actual = (g_seed, op_seed)
     msg = 'test_case = {0}, got {1}, want {2}'.format(
         (graph_in, op_in), actual, expected)
     self.assertEqual(actual, expected, msg=msg)
     random_seed.set_random_seed(None)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:29,代码来源:random_seed_test.py

示例12: __init__

 def __init__(self, handle, dtype, handle_device,  # pylint: disable=super-init-not-called
              shape, in_graph_mode, deleter, parent_op):
   """Wraps an existing resource `handle` without running the base __init__.

   Optional state (initial value, constraint, cached value, initializer ops,
   save-slice info) starts out empty, and the variable is not trainable.
   """
   # The base-class initializer is deliberately not called.
   self._handle = handle
   self._handle_device = handle_device
   self._handle_name = (
       "" if isinstance(handle, ops.EagerTensor) else handle.name)
   self._dtype = dtype
   self._shape = shape
   self._in_graph_mode = in_graph_mode
   self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
   self._parent_op = parent_op
   self._trainable = False
   self._save_slice_info = None
   self._initial_value = None
   self._constraint = None
   self._cached_value = None
   self._is_initialized_op = None
   self._initializer_op = None
   # read_value() runs only after every field above is set. Its graph-mode
   # result becomes the graph element; eagerly there is none.
   self._graph_element = self.read_value() if context.in_graph_mode() else None
   self._handle_deleter = deleter
开发者ID:keithc61,项目名称:tensorflow,代码行数:26,代码来源:resource_variable_ops.py

示例13: testMaskingSingleInput

  def testMaskingSingleInput(self):
    """Exercises a layer that defines compute_mask, symbolically and eagerly."""

    class MaskedLayer(base_layers.Layer):

      def call(self, inputs, mask=None):
        return inputs if mask is None else inputs * mask

      def compute_mask(self, inputs, mask=None):
        return array_ops.ones_like(inputs)

    if not context.in_graph_mode():
      # Eagerly: the incoming _keras_mask is applied, and the output carries
      # the mask produced by compute_mask.
      values = constant_op.constant([2] * 32)
      mask = constant_op.constant([0, 1] * 16)
      values._keras_mask = mask
      masked = MaskedLayer().apply(values)
      self.assertTrue(hasattr(masked, '_keras_mask'))
      self.assertAllEqual(self.evaluate(array_ops.ones_like(mask)),
                          self.evaluate(getattr(masked, '_keras_mask')))
      self.assertAllEqual(self.evaluate(values * mask), self.evaluate(masked))
      return

    inputs = base_layers.Input(shape=(32,))
    masked_output = MaskedLayer()(inputs)  # pylint: disable=not-callable
    network = base_layers.Network(inputs, masked_output)

    # The network must be callable on a fresh Input ...
    fresh_input = base_layers.Input(shape=(32,))
    self.assertEqual(network(fresh_input).get_shape().as_list(), [None, 32])

    # ... and on a regular placeholder tensor.
    placeholder = array_ops.placeholder(dtype='float32', shape=(None, 32))
    self.assertEqual(network(placeholder).get_shape().as_list(), [None, 32])
开发者ID:keveman,项目名称:tensorflow,代码行数:35,代码来源:base_test.py

示例14: _delay_checks

  def _delay_checks(self):
    """Batches tensor-dependent checks into a single evaluation.

    Session.run has per-call overhead that can dominate tests which evaluate
    many tensors one at a time. This context manager lets callers register
    check callbacks together with the tensors they depend on. In graph mode
    the registered arguments are evaluated in one `self.evaluate` call when
    the context exits, and each callback then receives its concrete values.

    Yields:
      A function `add_check(check, *args, **kwargs)`. `check` is the callback
      and `*args`/`**kwargs` are the associated Tensors. In EAGER mode `check`
      runs immediately inside `add_check`; otherwise it is deferred until the
      context exits.
    """
    pending = []

    def add_check(check, *args, **kwargs):
      if not context.in_eager_mode():
        pending.append((check, args, kwargs))
        return
      evaluated_args, evaluated_kwargs = self.evaluate([args, kwargs])
      check(*evaluated_args, **evaluated_kwargs)

    yield add_check

    if not context.in_graph_mode():
      return
    # A single evaluate() call covers the arguments of every deferred check.
    evaluated = self.evaluate(
        [[args, kwargs] for _, args, kwargs in pending])
    for (check, _, _), (args, kwargs) in zip(pending, evaluated):
      check(*args, **kwargs)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:33,代码来源:atrous_convolution_test.py

示例15: __init__

  def __init__(self, root_checkpointable):
    """Sets up a saver rooted at `root_checkpointable`.

    Args:
      root_checkpointable: Root of the object graph to save/restore. It and
        every object it depends on are written on save; on restore, objects
        are matched and restored starting from this root.
    """
    # Allow passing in a weak reference to avoid reference cycles when
    # `Checkpointable` objects save themselves.
    self._root_checkpointable_ref = root_checkpointable
    # In graph mode the file prefix is held in a constant tensor (initially
    # "model"); when executing eagerly no tensor is created.
    self._file_prefix_placeholder = (
        constant_op.constant("model") if context.in_graph_mode() else None)

    # Cached state reused across save() calls.
    self._object_graph_feed_tensor = None
    self._last_save_object_graph = None
    self._last_save_saver = None

    # Cached state reused across restore() calls.
    self._object_graph_restore_tensor = None
    self._last_restore_object_graph = None
    self._last_restore_checkpoint = None
开发者ID:hhu-luqi,项目名称:tensorflow,代码行数:25,代码来源:checkpointable_utils.py


注:本文中的tensorflow.python.eager.context.in_graph_mode函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。