当前位置: 首页>>代码示例>>Python>>正文


Python variable_scope.variable_creator_scope函数代码示例

本文整理汇总了Python中tensorflow.python.ops.variable_scope.variable_creator_scope函数的典型用法代码示例。如果您正苦于以下问题：Python variable_creator_scope函数的具体用法？Python variable_creator_scope怎么用？Python variable_creator_scope使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了variable_creator_scope函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: testCreatorStacksAreThreadLocal

  def testCreatorStacksAreThreadLocal(self):
    """Checks that each tower only sees its own variable creator.

    `model_fn` is run once per tower and pushes a creator that appends a
    tower-specific suffix to whatever the enclosing creator returns. The main
    thread's creator returns the plain string "main_thread", so every tower
    result must be that string followed by exactly one suffix.
    """
    devices = ["/device:CPU:0", "/device:GPU:0"]
    dist = mirrored_strategy.MirroredStrategy(devices)

    def model_fn(device_id):
      assert isinstance(device_id, int)
      def thread_creator_fn(next_creator, *args, **kwargs):
        # Tag the outer creator's result with this tower's id.
        return next_creator(*args, **kwargs) + ":thread_" + str(device_id)

      with variable_scope.variable_creator_scope(thread_creator_fn):
        # Create a variable in this scope.
        v = variable_scope.variable(1.0)

        # This will pause the current thread, and execute the other thread.
        distribute_lib.get_tower_context().merge_call(lambda _: _)
      return v

    def main_thread_creator(next_creator, *args, **kwargs):
      # We are not using the underlying next_creator for test purposes.
      del next_creator, args, kwargs
      return "main_thread"

    with context.graph_mode(), \
        dist.scope(), \
        variable_scope.variable_creator_scope(main_thread_creator):
      result = dist.call_for_each_tower(model_fn, dist.worker_device_index)
      result = dist.unwrap(result)
      expected = ["main_thread:thread_0", "main_thread:thread_1"]
      # `assertEqual` is the supported spelling; the `assertEquals` alias is
      # deprecated and no longer exists in newer unittest versions.
      self.assertEqual(expected, result)
开发者ID:Jackiefan,项目名称:tensorflow,代码行数:29,代码来源:mirrored_strategy_test.py

示例2: scope

  def scope(self):
    """Builds the context manager that makes this strategy the current one.

    Within a `with distribution_strategy.scope():` block the calling thread
    creates variables through a creator supplied by `distribution_strategy`
    and enters that strategy's "cross-tower context".

    Returns:
      A context manager.
    """
    if has_distribution_strategy():
      _require_cross_tower_context(self)
      return _SameScopeAgainContext(self)

    def disable_partitioned_variables(getter, *args, **kwargs):
      # Strip any partitioner (logging a warning) before delegating.
      partitioner = kwargs.pop("partitioner", None)
      if partitioner is not None:
        tf_logging.log_first_n(
            tf_logging.WARN, "Partitioned variables are disabled when using "
            "DistributionStrategy.", 1)
      return getter(*args, **kwargs)

    def creator_with_resource_vars(*args, **kwargs):
      # Force `use_resource=True` and let the strategy create the variable.
      _require_distribution_strategy_scope(self)
      kwargs["use_resource"] = True
      return self._create_variable(*args, **kwargs)

    creator_scope = variable_scope.variable_creator_scope(
        creator_with_resource_vars)
    getter_scope = variable_scope.variable_scope(
        variable_scope.get_variable_scope(),
        custom_getter=disable_partitioned_variables)
    return _CurrentDistributionContext(
        self, creator_scope, getter_scope, self._default_device)
开发者ID:LiuCKind,项目名称:tensorflow,代码行数:32,代码来源:distribute.py

示例3: testOptimizerInsideModelFn

  def testOptimizerInsideModelFn(self, distribution, optimizer_fn):
    """Checks the variables created when model_fn builds the optimizer.

    A variable creator records the name of every variable made while one
    training step is built and run. The recorded names are compared with the
    names expected for the optimizer and the number of parameter devices.
    """
    created_variables = []
    trainable_variables = []

    def appending_creator(next_creator, *args, **kwargs):
      new_var = next_creator(*args, **kwargs)
      created_variables.append(new_var.name)
      if kwargs.get("trainable"):
        trainable_variables.append(new_var.name)
      return new_var

    # Creator scope needs to be set before it's used inside
    # `distribution.scope`.
    with variable_scope.variable_creator_scope(
        appending_creator), distribution.scope():
      model_fn, dataset_fn, _ = minimize_loss_example(
          optimizer_fn,
          use_bias=True,
          use_callable_loss=True,
          create_optimizer_inside_model_fn=True)

      def step_fn(ctx, inputs):
        del ctx  # Unused
        per_replica = distribution.extended.call_for_each_replica(
            model_fn, args=(inputs,))
        return distribution.group(per_replica)

      iterator = self._get_iterator(distribution, dataset_fn)

      def run_step():
        steps = distribution.extended.experimental_run_steps_on_iterator(
            step_fn, iterator, iterations=1)
        return steps.run_op

      if not context.executing_eagerly():
        with self.cached_session() as sess:
          run_step = sess.make_callable(run_step())
      self.evaluate(variables_lib.global_variables_initializer())
      run_step()

      def get_expected_variables(optimizer_fn, num_parameter_devices):
        optimizer = optimizer_fn()
        name = optimizer._name

        # Pick the name table matching the optimizer implementation.
        if isinstance(optimizer, optimizer_v2.OptimizerV2):
          var_map = VAR_MAP_V2
        else:
          var_map = VAR_MAP_V1
        base_names = list(var_map[name])

        # Replicas after the first get a "/replica_<n>" suffix.
        replica_names = [
            base + "/replica_{}".format(replica)
            for base in base_names
            for replica in range(1, num_parameter_devices)
        ]
        return {n + ":0" for n in base_names + replica_names}

      expected = get_expected_variables(
          optimizer_fn, len(distribution.extended.parameter_devices))
      self.assertEqual(expected, set(created_variables))
开发者ID:aritratony,项目名称:tensorflow,代码行数:60,代码来源:minimize_loss_test.py

示例4: notify_about_variables

def notify_about_variables(callback):
  """Reports every `tf.{Variable,get_variable}` result to `callback`.

  The callback is expected to leave the variable untouched. Code that needs to
  modify variables should use `variable_creator_scope` directly so that it
  sits within the variable creator stack.

  >>> variables = []
  >>> with notify_about_variables(variables.append):
  ...   v = tf.Variable(1.0, name='v')
  ...   w = tf.get_variable('w', [])
  >>> assert variables == [v, w]

  Args:
    callback: a callable taking a single argument which is a tf.Variable.

  Yields:
    `None` - used for contextmanager API.
  """
  # NOTE(review): the usage above treats this generator as a context manager,
  # so a `contextlib.contextmanager` decorator is presumably applied where
  # this is defined; it is not visible in this snippet - confirm.
  def _report_then_return(getter, **kwargs):
    created = getter(**kwargs)
    callback(created)
    return created

  with variable_scope_ops.variable_creator_scope(_report_then_return):
    yield
开发者ID:ccchang0111,项目名称:sonnet,代码行数:26,代码来源:util.py

示例5: call

 def call(self, inputs, mask=None, training=None):
   """Invokes `self.function` on `inputs` with the stored keyword arguments.

   Args:
     inputs: Value passed to `self.function` as its first positional argument.
     mask: Forwarded as the `mask` keyword when `self._fn_expects_mask_arg`
       is set.
     training: Forwarded as the `training` keyword when
       `self._fn_expects_training_arg` is set.

   Returns:
     Whatever `self.function` returns.
   """
   # Work on a shallow copy: writing `mask`/`training` directly into
   # `self.arguments` would persist per-call values on the instance, where
   # they would leak into later calls and into anything that reads
   # `self.arguments` afterwards.
   arguments = dict(self.arguments)
   if self._fn_expects_mask_arg:
     arguments['mask'] = mask
   if self._fn_expects_training_arg:
     arguments['training'] = training
   with variable_scope.variable_creator_scope(self._variable_creator):
     return self.function(inputs, **arguments)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:8,代码来源:core.py

示例6: tower_local_var_scope

  def tower_local_var_scope(self, reduce_method):
    """Returns a creator scope that forces `trainable=False`.

    `reduce_method` is accepted but not used here, and the creator does not
    switch variables to resource variables.
    """
    _require_distribution_strategy_scope(self)

    def create_tower_local_variable(next_creator, *args, **kwargs):
      # Re-check the strategy scope at creation time, then delegate with the
      # variable marked as non-trainable.
      _require_distribution_strategy_scope(self)
      kwargs["trainable"] = False
      return next_creator(*args, **kwargs)

    return variable_scope.variable_creator_scope(create_tower_local_variable)
开发者ID:LiuCKind,项目名称:tensorflow,代码行数:9,代码来源:distribute.py

示例7: testOptimizerInsideModelFn

  def testOptimizerInsideModelFn(self, distribution, optimizer_fn):
    """Checks the variables created when model_fn builds the optimizer.

    A variable creator records the name of every variable made while one
    training step is built and run. The recorded names are compared with the
    names expected for the optimizer and the number of parameter devices.
    """
    created_variables = []
    trainable_variables = []

    def appending_creator(next_creator, *args, **kwargs):
      new_var = next_creator(*args, **kwargs)
      created_variables.append(new_var.name)
      if kwargs.get("trainable"):
        trainable_variables.append(new_var.name)
      return new_var

    # Creator scope needs to be set before it's used inside
    # `distribution.scope`.
    with variable_scope.variable_creator_scope(
        appending_creator), distribution.scope():
      model_fn, dataset, layer = minimize_loss_example(
          optimizer_fn,
          use_bias=True,
          use_callable_loss=True,
          create_optimizer_inside_model_fn=True)

      iterator = distribution.distribute_dataset(dataset)

      def run_step():
        per_tower = distribution.call_for_each_tower(
            model_fn, iterator.get_next(), run_concurrently=layer.built)
        return distribution.group(per_tower)

      if not context.executing_eagerly():
        with self.test_session() as sess:
          run_step = sess.make_callable(run_step())
        self.evaluate(variables_lib.global_variables_initializer())

      run_step()

      def get_expected_variables(optimizer_fn, num_parameter_devices):
        variables_map = {
            "GradientDescent": ["dense/kernel", "dense/bias"],
            "Adam": [
                "dense/kernel", "dense/bias", "beta1_power", "beta2_power",
                "dense/kernel/Adam", "dense/kernel/Adam_1", "dense/bias/Adam",
                "dense/bias/Adam_1"
            ]
        }
        base_names = variables_map[optimizer_fn().get_name()]
        # Replicas after the first get a "/replica_<n>" suffix.
        replica_names = [
            base + "/replica_{}".format(replica)
            for base in base_names
            for replica in range(1, num_parameter_devices)
        ]
        return {n + ":0" for n in base_names + replica_names}

      expected = get_expected_variables(
          optimizer_fn, len(distribution.parameter_devices))
      self.assertEqual(expected, set(created_variables))
开发者ID:bikong2,项目名称:tensorflow,代码行数:56,代码来源:minimize_loss_test.py

示例8: save

  def save(self, session=None, checkpoint_number=None):
    """Creates a new checkpoint and manages it.

    Args:
      session: The session to evaluate variables in. Ignored when executing
        eagerly. If not provided when graph building, the default session is
        used.
      checkpoint_number: An optional integer, or an integer-dtype `Variable` or
        `Tensor`, used to number the checkpoint. If `None` (default),
        checkpoints are numbered using `checkpoint.save_counter`. Even if
        `checkpoint_number` is provided, `save_counter` is still incremented. A
        user-provided `checkpoint_number` is not incremented even if it is a
        `Variable`.

    Returns:
      The path to the new checkpoint. It is also recorded in the `checkpoints`
      and `latest_checkpoint` properties.
    """
    # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
    # slightly with a custom numbering option.
    if context.executing_eagerly():
      save_counter = self._checkpoint.save_counter
      save_counter.assign_add(1)
    else:
      if session is None:
        session = ops.get_default_session()

      def _initializing_creator(next_creator, **kwargs):
        """Initialize the save counter if it has been newly created."""
        v = next_creator(**kwargs)
        session.run(v.initializer)
        return v

      # Any variable created while reading `save_counter` goes through
      # `_initializing_creator`, which runs its initializer in `session`.
      with variable_scope.variable_creator_scope(_initializing_creator):
        save_counter = self._checkpoint.save_counter
      # Build the increment op once and reuse it on later calls.
      if self._save_counter_assign is None:
        self._save_counter_assign = save_counter.assign_add(1, read_value=False)
      session.run(self._save_counter_assign)
    # No explicit number given: number the checkpoint with the save counter.
    if checkpoint_number is None:
      checkpoint_number = save_counter
    # Anything that is not already a plain integer is resolved to one through
    # `training_util.global_step`.
    if not isinstance(checkpoint_number, compat.integral_types):
      checkpoint_number = training_util.global_step(
          sess=session, global_step_tensor=checkpoint_number)
    prefix = "%s-%d" % (self._prefix, checkpoint_number)
    save_path = self._checkpoint.write(prefix)
    timestamp = time.time()
    # If this is an overwritten checkpoint we were previously tracking, delete
    # and reinsert it to make sure it goes to the end of the queue.
    if save_path in self._maybe_delete:
      del self._maybe_delete[save_path]
    self._maybe_delete[save_path] = timestamp
    self._latest_checkpoint = save_path
    self._sweep()
    self._record_state()
    return save_path
开发者ID:AnishShah,项目名称:tensorflow,代码行数:55,代码来源:checkpoint_management.py

示例9: _call_func

  def _call_func(self, args, kwargs):
    """Calls `self._func`, checking variable creation on repeat calls.

    While `self._variables_created` is false, the function runs under a
    variable creator scope using `self._checkpointable_custom_creator`, and the
    flag is set afterwards. On later calls, growth of the TRAINABLE_VARIABLES
    collection raises an error and growth of the GLOBAL_VARIABLES collection is
    logged.

    Args:
      args: Positional arguments forwarded to `self._func`.
      kwargs: Keyword arguments forwarded to `self._func`.

    Returns:
      The result of `self._func(*args, **kwargs)`.

    Raises:
      ValueError: If a trainable variable was created on a call made after
        variables had already been created.
      Exception: Any exception raised inside is re-raised with a trace built
        from `self._stacktrace` appended to its first argument.
    """
    try:
      # Snapshot collection sizes so that newly added variables can be
      # detected after the call.
      vars_at_start = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
      trainable_at_start = len(
          ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
      if self._variables_created:
        result = self._func(*args, **kwargs)
      else:
        # The first time we run, restore variables if necessary (via
        # Checkpointable).
        with variable_scope.variable_creator_scope(
            self._checkpointable_custom_creator):
          result = self._func(*args, **kwargs)

      if self._variables_created:
        # Variables were previously created, implying this is not the first
        # time the template has been called. Check to make sure that no new
        # trainable variables were created this time around.
        trainable_variables = ops.get_collection(
            ops.GraphKeys.TRAINABLE_VARIABLES)
        # If a variable that we intend to train is created as a side effect
        # of creating a template, then that is almost certainly an error.
        if trainable_at_start != len(trainable_variables):
          raise ValueError("Trainable variable created when calling a template "
                           "after the first time, perhaps you used tf.Variable "
                           "when you meant tf.get_variable: %s" %
                           (trainable_variables[trainable_at_start:],))

        # Non-trainable tracking variables are a legitimate reason why a new
        # variable would be created, but it is a relatively advanced use-case,
        # so log it.
        variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        if vars_at_start != len(variables):
          logging.info("New variables created when calling a template after "
                       "the first time, perhaps you used tf.Variable when you "
                       "meant tf.get_variable: %s",
                       variables[vars_at_start:])
      else:
        self._variables_created = True
      return result
    except Exception as exc:
      # Reraise the exception, but append the original definition to the
      # trace.
      # Note: this rebinds the `args` parameter to the exception's arguments.
      args = exc.args
      if not args:
        arg0 = ""
      else:
        arg0 = args[0]
      trace = "".join(_skip_common_stack_elements(self._stacktrace,
                                                  traceback.format_stack()))
      arg0 = "%s\n\noriginally defined at:\n%s" % (arg0, trace)
      # Only the first argument is rewritten; the remaining ones are kept.
      new_args = [arg0]
      new_args.extend(args[1:])
      exc.args = tuple(new_args)
      raise
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:55,代码来源:template.py

示例10: scope

  def scope(self):
    """Returns a context manager setting a variable creator and `self` as current.

    Raises:
      RuntimeError: If `has_distribution_strategy()` reports true, since
        DistributionStrategy scopes must not be nested.
    """
    already_active = distribution_strategy_context.has_distribution_strategy()
    if already_active:
      raise RuntimeError("Must not nest DistributionStrategy scopes.")

    def creator(next_creator, *args, **kwargs):
      # Verify the strategy scope at creation time, then delegate unchanged.
      _require_distribution_strategy_scope(self)
      return next_creator(*args, **kwargs)

    creator_scope = variable_scope.variable_creator_scope(creator)
    return _CurrentDistributionContext(self, creator_scope)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:11,代码来源:distribute.py

示例11: testSharedVariable

  def testSharedVariable(self):
    """Variables named "foo" made through the shared creators are one object.

    Three creator functions are built over the same store with indices 0..2;
    the variable produced under each later creator must be identical to the
    one produced under the first.
    """
    store = {}
    num_devices = 3
    creator_fns = [
        shared_variable_creator.make_fn(store, device_index)
        for device_index in range(num_devices)
    ]

    created = []
    for creator_fn in creator_fns:
      with variable_scope.variable_creator_scope(creator_fn):
        created.append(variable_scope.variable(1.0, name="foo"))
    first, second, third = created

    # The second and third variables should be the same object as the first.
    self.assertIs(second, first)
    self.assertIs(third, first)
开发者ID:Jackiefan,项目名称:tensorflow,代码行数:21,代码来源:shared_variable_creator_test.py

示例12: model_fn

    def model_fn(device_id):
      """Creates a variable under a creator that tags it with `device_id`."""
      assert isinstance(device_id, int)
      suffix = ":thread_" + str(device_id)

      def thread_creator_fn(next_creator, *args, **kwargs):
        # Append this caller's suffix to whatever the outer creator returns.
        return next_creator(*args, **kwargs) + suffix

      with variable_scope.variable_creator_scope(thread_creator_fn):
        # The variable is created while the tagging creator is active.
        created = variable_scope.variable(1.0)

        # This will pause the current thread, and execute the other thread.
        distribute_lib.get_tower_context().merge_call(lambda _: _)
      return created
开发者ID:Jackiefan,项目名称:tensorflow,代码行数:12,代码来源:mirrored_strategy_test.py

示例13: scope

  def scope(self):
    """Returns a context manager setting a variable creator and `self` as current.

    Raises:
      RuntimeError: If `has_distribution_strategy()` reports true, since
        DistributionStrategy scopes must not be nested.
    """
    already_active = has_distribution_strategy()
    if already_active:
      raise RuntimeError("Must not nest DistributionStrategy scopes.")

    def creator(next_creator, *args, **kwargs):
      _require_distribution_strategy_scope(self)
      # A non-None `tower_local_reduce_method` is consumed here and turns the
      # variable into a non-trainable one.
      reduce_method = kwargs.pop("tower_local_reduce_method", None)
      if reduce_method is not None:
        kwargs["trainable"] = False
      return next_creator(*args, **kwargs)

    creator_scope = variable_scope.variable_creator_scope(creator)
    return _CurrentDistributionContext(self, creator_scope)
开发者ID:kimr843,项目名称:tensorflow,代码行数:13,代码来源:distribute.py

示例14: one_host_numpy_dataset

def one_host_numpy_dataset(numpy_input, colocate_with, session):
  """Create a dataset on `colocate_with` from `numpy_input`.

  Each flattened element of `numpy_input` gets a non-trainable variable built
  from zeros of the same shape and dtype, created with `colocate_with` passed
  to the variable creator. Every variable is then handed to
  `init_var_from_numpy` with its source array and `session`, and the
  variables, packed back into the structure of `numpy_input`, are sliced into
  a dataset.
  """
  def create_colocated_variable(next_creator, *args, **kwargs):
    kwargs["colocate_with"] = colocate_with
    return next_creator(*args, **kwargs)

  flat_arrays = nest.flatten(numpy_input)
  flat_vars = []
  with variable_scope.variable_creator_scope(create_colocated_variable):
    for array in flat_arrays:
      zeros = array_ops.zeros(array.shape, array.dtype)
      flat_vars.append(variable_scope.variable(zeros, trainable=False))
  flat_vars = tuple(flat_vars)

  for var, array in zip(flat_vars, flat_arrays):
    init_var_from_numpy(var, array, session)

  nested_vars = nest.pack_sequence_as(numpy_input, flat_vars)
  return dataset_ops.Dataset.from_tensor_slices(nested_vars)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:15,代码来源:numpy_dataset.py

示例15: testScopeVarCreatorNestingError

  def testScopeVarCreatorNestingError(self):
    """Exiting a strategy scope inside a newer creator scope must fail.

    The strategy scope is entered manually, a variable creator scope is opened
    on top of it, and exiting the strategy scope from inside that creator
    scope is expected to raise a nesting error. Exiting again after the
    creator scope has closed must succeed and restore the default state.
    """

    def creator(next_creator, **kwargs):
      return next_creator(**kwargs)

    _assert_in_default_state(self)
    dist = _TestStrategy()
    scope = dist.scope()
    scope.__enter__()
    self.assertIs(dist, ds_context.get_strategy())
    with variable_scope.variable_creator_scope(creator):
      # `assertRaisesRegex` is the supported spelling; the `assertRaisesRegexp`
      # alias is deprecated and no longer exists in newer unittest versions.
      with self.assertRaisesRegex(RuntimeError,
                                  "Variable creator scope nesting error"):
        scope.__exit__(None, None, None)
    # With the inner creator scope closed, the exit is expected to succeed.
    scope.__exit__(None, None, None)
    _assert_in_default_state(self)
开发者ID:aritratony,项目名称:tensorflow,代码行数:16,代码来源:distribute_lib_test.py


注:本文中的tensorflow.python.ops.variable_scope.variable_creator_scope函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。