当前位置: 首页>>代码示例>>Python>>正文


Python ops.get_collection函数代码示例

本文整理汇总了Python中tensorflow.python.framework.ops.get_collection函数的典型用法代码示例。如果您正苦于以下问题：Python get_collection函数的具体用法？Python get_collection怎么用？Python get_collection使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了get_collection函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __init__

    def __init__(self,
                 checkpoint_dir,
                 display_steps=100,
                 maximum_train_steps=None,
                 do_summary=True,
                 is_chief=True):
        """ Initializes the hook.

        Reads the values to display from two graph collections,
        `Constants.DISPLAY_KEY_COLLECTION_NAME` (keys) and
        `Constants.DISPLAY_VALUE_COLLECTION_NAME` (values). It pairs them
        positionally and always adds the global step under "global_step".

        Args:
            checkpoint_dir: A string, base directory for the checkpoint files.
            display_steps: A python integer, display every N steps.
            maximum_train_steps: A python integer, the maximum training steps.
            do_summary: Whether to save summaries when displaying.
            is_chief: Whether this is the chief process (stored, currently unused).
        """

        tf.logging.info("Create DisplayHook.")
        self._checkpoint_dir = checkpoint_dir
        # display steps
        self._display_steps = display_steps
        self._maximum_train_steps = maximum_train_steps
        self._do_summary = do_summary
        self._is_chief = is_chief  # not used now

        # display values: the key and value collections are paired by position.
        global_step = training_util.get_global_step()
        display_keys = ops.get_collection(Constants.DISPLAY_KEY_COLLECTION_NAME)
        display_values = ops.get_collection(Constants.DISPLAY_VALUE_COLLECTION_NAME)
        if len(display_keys) != len(display_values):
            # zip() below silently drops the unmatched tail; make that visible
            # instead of losing display entries without notice.
            tf.logging.warning(
                "DisplayHook: %d display keys but %d display values; "
                "unmatched entries are ignored.",
                len(display_keys), len(display_values))
        self._display_args = dict(zip(display_keys, display_values))
        # Overrides any collected entry that is also keyed "global_step".
        self._display_args["global_step"] = global_step
        # timer & summary writer: all start as None here.
        self._timer = None
        self._logging_timer = None
        self._summary_writer = None
开发者ID:KIngpon,项目名称:NJUNMT-tf,代码行数:34,代码来源:hooks.py

示例2: testAddWeight

    def testAddWeight(self):
        """Checks variable tracking of the private `_Layer._add_variable` API.

        Covers a trainable variable (also expected in the TRAINABLE_VARIABLES
        collection), a non-trainable variable created outside `build`/`call`,
        and a variable with a regularizer, which must add exactly one loss.
        """
        with self.test_session():
            layer = base_layers._Layer(name="my_layer")

            # Test basic variable creation.
            variable = layer._add_variable("my_var", [2, 2], initializer=init_ops.zeros_initializer)
            self.assertEqual(variable.name, "my_var:0")
            self.assertListEqual(layer.variables, [variable])
            self.assertListEqual(layer.trainable_variables, [variable])
            self.assertListEqual(layer.non_trainable_variables, [])
            self.assertListEqual(layer.variables, ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))

            # Test non-trainable variable creation.
            # layer._add_variable should work even outside `build` and `call`.
            variable_2 = layer._add_variable(
                "non_trainable_var", [2, 2], initializer=init_ops.zeros_initializer, trainable=False
            )
            self.assertListEqual(layer.variables, [variable, variable_2])
            self.assertListEqual(layer.trainable_variables, [variable])
            self.assertListEqual(layer.non_trainable_variables, [variable_2])
            # The non-trainable variable must not join TRAINABLE_VARIABLES.
            self.assertEqual(len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)

            # Test with regularizer. Only the loss it registers is checked, so
            # the returned variable is intentionally not kept.
            regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
            layer._add_variable(
                "reg_var", [2, 2], initializer=init_ops.zeros_initializer, regularizer=regularizer
            )
            self.assertEqual(len(layer.losses), 1)
开发者ID:BloodD,项目名称:tensorflow,代码行数:28,代码来源:base_test.py

示例3: after_create_session

  def after_create_session(self, session, coord):  # pylint: disable=unused-argument
    """Runs an evaluation pass before training starts.

    Only variables whose names appear both in the training graph's
    GLOBAL_VARIABLES and in this hook's placeholder map are kept. For those,
    a grouped assign op (eval var <- placeholder) is built in the evaluation
    graph and stored as `self._var_feed_op`. Then `self._evaluate(session)`
    is called.

    Raises:
      ValueError: if the SAVEABLE_OBJECTS collection is non-empty.
    """
    if ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS):
      raise ValueError(
          'InMemoryEvaluator does not support saveables other than global '
          'variables.')

    self._var_name_to_train_var = {
        var.name: var
        for var in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    }
    shared_names = (set(self._var_name_to_placeholder.keys()) &
                    set(self._var_name_to_train_var.keys()))

    # Drop entries that exist in only one of the two graphs.
    train_by_name = self._var_name_to_train_var
    eval_by_name = self._var_name_to_eval_var
    self._var_name_to_train_var = {
        name: train_by_name[name] for name in shared_names
    }
    self._var_name_to_eval_var = {
        name: eval_by_name[name] for name in shared_names
    }

    with self._graph.as_default():
      feed_assigns = [
          state_ops.assign(self._var_name_to_eval_var[name],
                           self._var_name_to_placeholder[name])
          for name in shared_names
      ]
      self._var_feed_op = control_flow_ops.group(feed_assigns)

    self._evaluate(session)
开发者ID:ChristinaEricka,项目名称:tensorflow,代码行数:30,代码来源:hooks.py

示例4: testAddWeight

  def testAddWeight(self):
    """Checks variable tracking of the public `Layer.add_variable` API.

    Covers a trainable variable (name-scoped under the layer and expected in
    TRAINABLE_VARIABLES), a non-trainable variable created outside
    `build`/`call`, and, in graph mode only, a regularized variable that must
    add exactly one loss.
    """
    layer = base_layers.Layer(name='my_layer')

    # Test basic variable creation.
    variable = layer.add_variable(
        'my_var', [2, 2], initializer=init_ops.zeros_initializer())
    self.assertEqual(variable.name, 'my_layer/my_var:0')
    self.assertListEqual(layer.variables, [variable])
    self.assertListEqual(layer.trainable_variables, [variable])
    self.assertListEqual(layer.non_trainable_variables, [])
    self.assertListEqual(layer.variables,
                         ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))

    # Test non-trainable variable creation.
    # layer.add_variable should work even outside `build` and `call`.
    variable_2 = layer.add_variable(
        'non_trainable_var', [2, 2],
        initializer=init_ops.zeros_initializer(),
        trainable=False)
    self.assertListEqual(layer.variables, [variable, variable_2])
    self.assertListEqual(layer.trainable_variables, [variable])
    self.assertListEqual(layer.non_trainable_variables, [variable_2])
    # The non-trainable variable must not join TRAINABLE_VARIABLES.
    self.assertEqual(
        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)

    if context.in_graph_mode():
      # regularizers only supported in GRAPH mode.
      regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
      # Only the registered loss is checked; the variable itself is not kept.
      layer.add_variable(
          'reg_var', [2, 2],
          initializer=init_ops.zeros_initializer(),
          regularizer=regularizer)
      self.assertEqual(len(layer.losses), 1)
开发者ID:keveman,项目名称:tensorflow,代码行数:33,代码来源:base_test.py

示例5: testVariableCollections

 def testVariableCollections(self):
   """Variables created with `collections=` appear in every listed collection.

   'a' and 'b' each have a private collection ('A' / 'B') and share 'C', which
   must list them in creation order.
   """
   with self.test_session():
     a = variables_lib2.variable('a', [], collections=['A', 'C'])
     b = variables_lib2.variable('b', [], collections=['B', 'C'])
     # assertEqual, not the deprecated alias assertEquals (removed in Python 3.12).
     self.assertEqual(a, ops.get_collection('A')[0])
     self.assertEqual(b, ops.get_collection('B')[0])
     self.assertListEqual([a, b], ops.get_collection('C'))
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:7,代码来源:variables_test.py

示例6: _train_op_fn

  def _train_op_fn(loss):
    """Returns the op to optimize the loss."""
    train_ops = []
    global_step = training_util.get_global_step()
    if dnn_logits is not None:
      train_ops.append(
          dnn_optimizer.minimize(
              loss,
              var_list=ops.get_collection(
                  ops.GraphKeys.TRAINABLE_VARIABLES,
                  scope=dnn_parent_scope)))
    if linear_logits is not None:
      train_ops.append(
          linear_optimizer.minimize(
              loss,
              var_list=ops.get_collection(
                  ops.GraphKeys.TRAINABLE_VARIABLES,
                  scope=linear_parent_scope)))

    train_op = control_flow_ops.group(*train_ops)
    with ops.control_dependencies([train_op]):
      with ops.colocate_with(global_step):
        return state_ops.assign_add(global_step, 1)

    return head.create_estimator_spec(
        features=features,
        mode=mode,
        labels=labels,
        train_op_fn=_train_op_fn,
        logits=logits)
开发者ID:m-colombo,项目名称:tensorflow,代码行数:30,代码来源:dnn_linear_combined.py

示例7: testSaveAsText

  def testSaveAsText(self):
    """A SavedModel written in text format restores the saved weights.

    Only the "foo" meta graph is added together with its variables (v = 42).
    "bar" is added afterwards without updating the weights, so loading either
    tag must yield 42.
    """
    export_dir = self._get_export_dir("test_astext")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # First meta graph: "v" = 42, stored together with its weights.
    with self.test_session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 42)
      builder.add_meta_graph_and_variables(sess, ["foo"])

    # Second meta graph: "v" = 43 in-session, but only the graph is added.
    with self.test_session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 43)
      builder.add_meta_graph(["bar"])

    builder.save(as_text=True)

    # Both tags see the weights that were saved with "foo".
    for tag in ("foo", "bar"):
      with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, [tag], export_dir)
        restored_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        self.assertEqual(42, restored_vars[0].eval())
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:30,代码来源:saved_model_test.py

示例8: testTrainOpGroup

  def testTrainOpGroup(self):
    """A train op registered on the builder is restored under TRAIN_OP_KEY.

    Variables tracked in the custom "v" collection must also round-trip with
    their initial values (1 and 2).
    """
    export_dir = self._get_export_dir("test_train_op_group")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.test_session(graph=ops.Graph()) as sess:
      # Two variables, registered in the custom "v" collection in order.
      for var_name, initial_value in (("v1", 1), ("v2", 2)):
        ops.add_to_collection("v", variables.Variable(initial_value, name=var_name))

      sess.run(variables.global_variables_initializer())
      # An empty group is sufficient to exercise train-op registration.
      train_op = control_flow_ops.group()
      sess.run(train_op)

      # TODO(karmel): remove explicit call when in the public method.
      builder._add_train_op(train_op)
      builder.add_meta_graph_and_variables(sess, ["foo"])

    builder.save()

    with self.test_session(graph=ops.Graph()) as sess:
      loader.load(sess, ["foo"], export_dir)
      restored = ops.get_collection("v")
      self.assertEqual(1, restored[0].eval())
      self.assertEqual(2, restored[1].eval())
      restored_train_ops = ops.get_collection(constants.TRAIN_OP_KEY)
      self.assertIsInstance(restored_train_ops[0], ops.Operation)
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:28,代码来源:saved_model_test.py

示例9: testCustomMainOp

  def testCustomMainOp(self):
    """A custom main_op passed to the builder is applied on load.

    v3 starts at 42. The custom main_op assigns v3 := v1 + v2, so after
    loading, the "v" collection must read back 1, 2, 3.
    """
    export_dir = self._get_export_dir("test_main_op")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.test_session(graph=ops.Graph()) as sess:
      # v1 and v2 are the operands; v3 is initialized to 42 and later overwritten.
      v1, v2, v3 = [
          variables.Variable(initial_value, name=var_name)
          for initial_value, var_name in ((1, "v1"), (2, "v2"), (42, "v3"))
      ]
      for var in (v1, v2, v3):
        ops.add_to_collection("v", var)

      # The sum and its assignment to v3 carry a control dependency on
      # main_op.main_op().
      with ops.control_dependencies([main_op.main_op()]):
        operand_sum = math_ops.add(v1._ref(), v2._ref())
        custom_main_op = control_flow_ops.group(state_ops.assign(v3, operand_sum))

      sess.run(custom_main_op)
      builder.add_meta_graph_and_variables(
          sess, ["foo"], main_op=custom_main_op)

    builder.save()

    with self.test_session(graph=ops.Graph()) as sess:
      loader.load(sess, ["foo"], export_dir)
      restored = ops.get_collection("v")
      self.assertEqual(1, restored[0].eval())
      self.assertEqual(2, restored[1].eval())
      # 3 == 1 + 2: the main_op assignment was applied following the restore.
      self.assertEqual(3, restored[2].eval())
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:34,代码来源:saved_model_test.py

示例10: testLegacyInitOp

  def testLegacyInitOp(self):
    """A legacy_init_op passed to the builder is applied on load.

    v3 starts at 42. The legacy init op assigns v3 := v1 + v2, so after
    loading, the "v" collection must read back 1, 2, 3.
    """
    export_dir = self._get_export_dir("test_legacy_init_op")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.test_session(graph=ops.Graph()) as sess:
      v1 = variables.Variable(1, name="v1")
      v2 = variables.Variable(2, name="v2")
      # v3 (initially 42) is non-trainable and created with collections=[],
      # so it is reachable here only through the custom "v" collection.
      v3 = variables.Variable(42, name="v3", trainable=False, collections=[])
      for var in (v1, v2, v3):
        ops.add_to_collection("v", var)

      # The legacy init op performs v3 := v1 + v2.
      legacy_init_op = control_flow_ops.group(
          state_ops.assign(v3, math_ops.add(v1, v2)), name="legacy_init_op")

      sess.run(variables.global_variables_initializer())
      builder.add_meta_graph_and_variables(
          sess, ["foo"], legacy_init_op=legacy_init_op)

    builder.save()

    with self.test_session(graph=ops.Graph()) as sess:
      loader.load(sess, ["foo"], export_dir)
      restored = ops.get_collection("v")
      self.assertEqual(1, restored[0].eval())
      self.assertEqual(2, restored[1].eval())
      # 3 == 1 + 2: the legacy_init_op was applied following the restore.
      self.assertEqual(3, restored[2].eval())
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:33,代码来源:saved_model_test.py

示例11: test_example

  def test_example(self):
    """Reducing metric variables accumulates every tower into the first one.

    After one reduction over 3 towers, tower 0's three metric variables hold
    the element-wise sums, and the other towers' variables read zero.
    """
    with self.test_session() as session:
      for tower_id in range(3):
        self.create_tower_metrics(tower_id)

      session.run(
          variables.variables_initializer(
              ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))

      session.run(
          replicate_model_fn._reduce_metric_variables(number_of_towers=3))

      # 1st tower = 1.3, 2.3,  [3.3, 3.5, 3.7]
      # 2nd tower = 2.6, 4.6,  [6.6, 7.0, 7.4]
      # 3rd tower = 3.9, 6.9,  [9.9, 10.5, 11.1]
      # Reduced =   7.8, 13.8, [19.8, 21.0, 22.2]
      # Towers are accumulated in the first tower.
      local_metrics = session.run(
          ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))

      self.assertNear(7.8, local_metrics[0], 0.01)
      self.assertNear(13.8, local_metrics[1], 0.01)
      # Expected values must match the reduced sums above (3.7+7.4+11.1 = 22.2).
      self.assertAllClose([19.8, 21.0, 22.2], local_metrics[2], 0.01)
      self.assertNear(0.0, local_metrics[3], 0.01)
      self.assertNear(0.0, local_metrics[4], 0.01)
      self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
      self.assertNear(0.0, local_metrics[6], 0.01)
      self.assertNear(0.0, local_metrics[7], 0.01)
      self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:29,代码来源:replicate_model_fn_test.py

示例12: testMultipleConvMaskAdded

  def testMultipleConvMaskAdded(self):
    """Each masked_conv2d layer registers one mask and one masked weight.

    Five masked convolutions are stacked, with output depth growing by
    `depth_step` per layer. Both the mask collection and the masked-weight
    collection must then hold one tensor per layer, shaped like that layer's
    kernel.
    """
    number_of_layers = 5

    kernel_size = 3
    base_depth = 4
    depth_step = 7

    top_layer = array_ops.ones((8, self.height, self.width, base_depth))
    for layer_number in range(1, number_of_layers + 1):
      top_layer = layers.masked_conv2d(
          top_layer, base_depth + layer_number * depth_step, kernel_size)

    # Layer i maps depth base_depth + i*step to base_depth + (i+1)*step.
    expected_shapes = [
        [kernel_size, kernel_size,
         base_depth + i * depth_step, base_depth + (i + 1) * depth_step]
        for i in range(number_of_layers)
    ]

    for collection_name in (core_layers.MASK_COLLECTION,
                            core_layers.MASKED_WEIGHT_COLLECTION):
      collected = ops.get_collection(collection_name)
      self.assertEqual(len(collected), number_of_layers)
      for tensor, expected_shape in zip(collected, expected_shapes):
        self.assertListEqual(tensor.get_shape().as_list(), expected_shape)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:30,代码来源:layers_test.py

示例13: testTags

  def testTags(self):
    """Meta graphs are looked up by tag set, and all requested tags must match.

    Only the TRAINING meta graph is saved with weights (v = 42). The SERVING
    and custom-tag meta graphs are added without updating weights, so every
    valid tag set restores 42. An unknown tag, or a set that only partially
    matches a saved meta graph, must raise RuntimeError.
    """
    export_dir = os.path.join(test.get_temp_dir(), "test_tags")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Saved with weights under a single predefined tag.
    with self.test_session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 42)
      builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])

    # Graph-only additions (weights are not updated): first a single
    # predefined tag, then multiple custom tags.
    for session_value, meta_graph_tags in ((43, [tag_constants.SERVING]),
                                           (44, ["foo", "bar"])):
      with self.test_session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", session_value)
        builder.add_meta_graph(meta_graph_tags)

    builder.save()

    # Every matching tag set restores the weights saved with TRAINING. The
    # duplicated "foo" checks that requested tags behave as a set.
    for load_tags in ([tag_constants.TRAINING],
                      [tag_constants.SERVING],
                      ["foo", "bar", "foo"]):
      with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, load_tags, export_dir)
        self.assertEqual(
            42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())

    # A nonexistent tag, and a tag set of which only a subset matches, both
    # fail because tag matching follows "all" semantics.
    for bad_tags in (["INVALID"], ["foo", "baz"]):
      with self.test_session(graph=ops.Graph()) as sess:
        self.assertRaises(RuntimeError, loader.load, sess, bad_tags,
                          export_dir)
开发者ID:adityaatluri,项目名称:tensorflow,代码行数:60,代码来源:saved_model_test.py

示例14: test_reduce_is_idempotent

  def test_reduce_is_idempotent(self):
    """Running the metric reduction repeatedly gives the same result as once.

    The reduction op is run 20 times. The values must equal those of a single
    reduction: tower 0 holds the per-variable sums, and the other towers
    read zero.
    """
    with self.test_session() as session:
      for tower_id in range(3):
        self.create_tower_metrics(tower_id)

      session.run(
          variables.variables_initializer(
              ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))

      for _ in range(20):
        session.run(
            replicate_model_fn._reduce_metric_variables(number_of_towers=3))

      local_metrics = session.run(
          ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))

      self.assertNear(7.8, local_metrics[0], 0.01)
      self.assertNear(13.8, local_metrics[1], 0.01)
      # Same expected sums as a single reduction: 3.7 + 7.4 + 11.1 = 22.2.
      self.assertAllClose([19.8, 21.0, 22.2], local_metrics[2], 0.01)
      self.assertNear(0.0, local_metrics[3], 0.01)
      self.assertNear(0.0, local_metrics[4], 0.01)
      self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
      self.assertNear(0.0, local_metrics[6], 0.01)
      self.assertNear(0.0, local_metrics[7], 0.01)
      self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:25,代码来源:replicate_model_fn_test.py

示例15: testCreateBN

  def testCreateBN(self):
    """BatchNormalization(axis=1) on a (5, 4, 3) input builds the expected state.

    Checks the output shape, the counts of updates and of trainable and
    non-trainable variables, and that the updates and trainable variables are
    registered in the UPDATE_OPS and TRAINABLE_VARIABLES collections.
    """
    # Call layer.
    bn = normalization_layers.BatchNormalization(axis=1)
    inputs = random_ops.random_uniform((5, 4, 3), seed=1)
    training = array_ops.placeholder(dtype='bool')
    outputs = bn.apply(inputs, training=training)

    # Verify shape.
    self.assertListEqual(outputs.get_shape().as_list(), [5, 4, 3])

    # Verify layer attributes.
    self.assertEqual(len(bn.updates), 2)
    self.assertEqual(len(bn.variables), 4)
    self.assertEqual(len(bn.trainable_variables), 2)
    self.assertEqual(len(bn.non_trainable_variables), 2)

    # Test that updates were added to UPDATE_OPS (their count of 2 is
    # already asserted above).
    self.assertListEqual(
        ops.get_collection(ops.GraphKeys.UPDATE_OPS), bn.updates)

    # Test that weights were created and added to TRAINABLE_VARIABLES.
    self.assertListEqual(
        ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES),
        bn.trainable_variables)
开发者ID:adityaatluri,项目名称:tensorflow,代码行数:25,代码来源:normalization_test.py


注：本文中的tensorflow.python.framework.ops.get_collection函数示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。