当前位置: 首页>>代码示例>>Python>>正文


Python ops.get_collection_ref函数代码示例

本文整理汇总了 Python 中 tensorflow.python.framework.ops.get_collection_ref 函数的典型用法代码示例。如果您想了解 get_collection_ref 函数的具体用法、调用方式或使用示例，这里精选的代码示例或许能为您提供帮助。


下文共展示了 get_collection_ref 函数的 15 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更优质的 Python 代码示例。

示例1: test_asset_loading

  def test_asset_loading(self):
    """Lookups keep working across repeated save/load of a v1 asset model.

    The model is loaded, then saved and reloaded twice. Before each reload the
    previous export directory is deleted and the TABLE_INITIALIZERS collection
    is emptied, so every generation must be usable from its own export
    directory alone.
    """

    def check_lookup(loaded):
      # Run the table initializers, then look up two keys via the signature.
      self.evaluate(lookup_ops.tables_initializer())
      serve = loaded.signatures["serving_default"]
      self.assertAllClose({"output": [2, 0]},
                          serve(start=constant_op.constant(["gamma", "alpha"])))

    def resave_and_reload(loaded, old_dir):
      # Export `loaded` to a fresh directory, delete the directory it came
      # from, clear TABLE_INITIALIZERS, and load the new export.
      new_dir = os.path.join(self.get_temp_dir(), "saved_model",
                             str(ops.uid()))
      save.save(loaded, new_dir, signatures=loaded.signatures)
      shutil.rmtree(old_dir)
      del ops.get_collection_ref(ops.GraphKeys.TABLE_INITIALIZERS)[:]
      return load.load(new_dir), new_dir

    first_path = self._v1_asset_saved_model()
    imported = load.load(first_path)
    check_lookup(imported)

    # Each generation stays referenced until the end of the test.
    second_import, second_path = resave_and_reload(imported, first_path)
    check_lookup(second_import)

    third_import, _ = resave_and_reload(second_import, second_path)
    check_lookup(third_import)
开发者ID:aritratony,项目名称:tensorflow,代码行数:28,代码来源:load_v1_in_v2_test.py

示例2: _call_func

  def _call_func(self, args, kwargs):
    """Invokes the wrapped template function and polices variable creation.

    While `self._variables_created` is falsy, the function runs inside
    `checkpointable_util.capture_dependencies(template=self)` and the flag is
    then set. On later calls the function runs directly. Afterwards, growth of
    the TRAINABLE_VARIABLES collection raises ValueError, and growth of
    GLOBAL_VARIABLES is only logged.

    Any exception is re-raised (the same object) with the template's original
    definition stack, derived from `self._stacktrace`, appended to its first
    argument.

    Args:
      args: Positional arguments forwarded to the wrapped function.
      kwargs: Keyword arguments forwarded to the wrapped function.

    Returns:
      Whatever the wrapped function returns.

    Raises:
      ValueError: If a trainable variable is created on a call after the first.
    """
    try:
      global_count_before = len(
          ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES))
      trainable_count_before = len(
          ops.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES))

      if not self._variables_created:
        # The first time we run, restore variables if necessary (via
        # Checkpointable).
        with checkpointable_util.capture_dependencies(template=self):
          result = self._func(*args, **kwargs)
      else:
        result = self._func(*args, **kwargs)

      if not self._variables_created:
        # First successful call: nothing to compare against yet.
        self._variables_created = True
        return result

      # Not the first call. A new *trainable* variable created as a side
      # effect here is almost certainly an error.
      trainable_now = ops.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      if len(trainable_now) != trainable_count_before:
        raise ValueError("Trainable variable created when calling a template "
                         "after the first time, perhaps you used tf.Variable "
                         "when you meant tf.get_variable: %s" %
                         (trainable_now[trainable_count_before:],))

      # New non-trainable (tracking) variables are a legitimate but advanced
      # use-case, so they are only logged.
      global_now = ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES)
      if len(global_now) != global_count_before:
        logging.info("New variables created when calling a template after "
                     "the first time, perhaps you used tf.Variable when you "
                     "meant tf.get_variable: %s",
                     global_now[global_count_before:])
      return result
    except Exception as exc:
      # Re-raise the same exception object, with the template's definition
      # site appended to its first argument.
      exc_args = exc.args
      head = exc_args[0] if exc_args else ""
      trace = "".join(_skip_common_stack_elements(self._stacktrace,
                                                  traceback.format_stack()))
      annotated = "%s\n\noriginally defined at:\n%s" % (head, trace)
      exc.args = (annotated,) + exc_args[1:]
      raise
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:55,代码来源:template.py

示例3: _clear_saved_model_collections

def _clear_saved_model_collections():
  """Empty the collections the SavedModel builder expects to start empty.

  The builder uses the assets, legacy-init-op, main-op and train-op
  collections to track ops needed to restore graph state, so they must be
  empty before MetaGraphs are added to it. Each collection list is truncated
  in place; the list object itself is kept.
  """
  for key in (constants.ASSETS_KEY,
              constants.LEGACY_INIT_OP_KEY,
              constants.MAIN_OP_KEY,
              constants.TRAIN_OP_KEY):
    del ops.get_collection_ref(key)[:]
开发者ID:AnishShah,项目名称:tensorflow,代码行数:11,代码来源:saved_model_estimator.py

示例4: testBasicMemory

  def testBasicMemory(self):
    """Checks the peak and per-tensor CPU figures in the memory report."""
    with test_util.device(use_gpu=False):
      a = constant_op.constant(10, name="a")
      b = constant_op.constant(20, name="b")
      c = math_ops.add_n([a, b], name="c")
      d = math_ops.add_n([b, c], name="d")
      # Put d in the TRAIN_OP collection recorded in the MetaGraph (presumably
      # used as the fetch node by the analyzer).
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(d)
      mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())

    report = cost_analyzer.GenerateMemoryReport(mg)

    # Print the report to make it easier to debug
    print(report)

    # assertIn, unlike assertTrue(x in report), shows both the missing text
    # and the report when it fails.
    self.assertIn(
        "Peak usage for device /job:localhost/replica:0/task:0/device:CPU:0: "
        "16 bytes", report)
    for name in ("a", "b", "c", "d"):
      self.assertIn("  %s:0 uses 4 bytes" % name, report)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:25,代码来源:cost_analyzer_test.py

示例5: build

  def build(self, inputs_shape):
    """Builds the parent cell, then adds mask/threshold and a masked kernel.

    Creates a non-trainable `mask` variable of shape
    `[input_depth + num_units, 4 * num_units]`, initialized to ones, and a
    non-trainable scalar `threshold`, initialized to zeros. Also creates
    `self._masked_kernel = mask * self._kernel`. The first time a given mask
    is seen, the mask, masked kernel, threshold and kernel are registered in
    the `core_layers` collections.

    Args:
      inputs_shape: Static input shape; `inputs_shape[1].value` is used as
        the input depth (assumes a TF1 `TensorShape` of `Dimension`s).
    """
    # Call the build method of the parent class. `self._kernel`, read below,
    # is presumably created there.
    super(MaskedBasicLSTMCell, self).build(inputs_shape)

    # Clear the built flag while adding extra variables; it is set back to
    # True at the end of this method.
    self.built = False

    num_inputs = inputs_shape[1].value
    num_units = self._num_units
    self._mask = self.add_variable(
        name="mask",
        shape=[num_inputs + num_units, 4 * num_units],
        initializer=init_ops.ones_initializer(),
        trainable=False,
        dtype=self.dtype)
    self._threshold = self.add_variable(
        name="threshold",
        shape=[],
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        dtype=self.dtype)
    # Add masked_weights in the weights namescope so as to make it easier
    # for the quantization library to add quant ops.
    self._masked_kernel = math_ops.multiply(self._mask, self._kernel,
                                            core_layers.MASKED_WEIGHT_NAME)

    # Register only once per mask, so repeated builds don't duplicate entries.
    already_registered = self._mask in ops.get_collection_ref(
        core_layers.MASK_COLLECTION)
    if not already_registered:
      registrations = (
          (core_layers.MASK_COLLECTION, self._mask),
          (core_layers.MASKED_WEIGHT_COLLECTION, self._masked_kernel),
          (core_layers.THRESHOLD_COLLECTION, self._threshold),
          (core_layers.WEIGHT_COLLECTION, self._kernel),
      )
      for collection_name, value in registrations:
        ops.add_to_collection(collection_name, value)

    self.built = True
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:32,代码来源:rnn_cells.py

示例6: record_summaries_every_n_global_steps

def record_summaries_every_n_global_steps(n):
  """Sets the should_record_summaries Tensor to true if global_step % n == 0.

  Replaces the contents of the `_SHOULD_RECORD_SUMMARIES_NAME` collection with
  a single boolean tensor for the duration of the block. The previous contents
  are restored afterwards, including when the block raises. This generator is
  presumably wrapped by a context-manager decorator that is not shown here.

  Args:
    n: Summaries are recorded on steps where `global_step % n == 0`.

  Yields:
    None; control passes to the body of the `with` statement.
  """
  # Function-local import keeps this fix self-contained; math_ops belongs to
  # the same TensorFlow package this module already depends on.
  from tensorflow.python.ops import math_ops  # pylint: disable=g-import-not-at-top

  collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  old = collection_ref[:]
  # In TF 1.x, Python `==` on a Tensor compares object identity, so
  # `step % n == 0` is always the Python bool False. Build an element-wise
  # comparison op instead.
  collection_ref[:] = [
      math_ops.equal(training_util.get_global_step() % n, 0)]
  try:
    yield
  finally:
    # Restore the previous setting even if the with-body raised; otherwise
    # the override would leak into later code using this graph.
    collection_ref[:] = old
开发者ID:benoitsteiner,项目名称:tensorflow-opencl,代码行数:7,代码来源:summary_ops.py

示例7: never_record_summaries

def never_record_summaries():
  """Sets the should_record_summaries Tensor to always false.

  Replaces the contents of the `_SHOULD_RECORD_SUMMARIES_NAME` collection with
  `[False]` for the duration of the block. The previous contents are restored
  afterwards, including when the block raises. This generator is presumably
  wrapped by a context-manager decorator that is not shown here.

  Yields:
    None; control passes to the body of the `with` statement.
  """
  collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  old = collection_ref[:]
  collection_ref[:] = [False]
  try:
    yield
  finally:
    # Restore the previous setting even if the with-body raised; otherwise
    # summaries would stay disabled for the rest of the program.
    collection_ref[:] = old
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:7,代码来源:summary_ops.py

示例8: testSimpleSwap

  def testSimpleSwap(self):
    """Check that the swap annotations are followed.

    Marks input 0 of `d` with `_swap_to_host`. The test expects the MANUAL
    memory optimizer to route `b/read` -> `swap_out_d_0` -> `swap_in_d_0` -> `d`.
    """
    a = variables.Variable(10, name='a')
    b = variables.Variable(20, name='b')
    c = math_ops.add_n([a, b], name='c')
    d = math_ops.add_n([b, c], name='d')
    fetches = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    fetches.append(d)

    # NOTE(review): this mutation only reaches the graph if `node_def` returns
    # the live NodeDef rather than a copy -- confirm for the TF version in use.
    d.op.node_def.attr['_swap_to_host'].i = 0

    meta = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
    nodes_before = len(meta.graph_def.node)

    config = rewriter_config_pb2.RewriterConfig(
        disable_model_pruning=True,
        memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
    optimized = tf_optimizer.OptimizeGraph(config, meta)

    # Exactly two nodes (the swap pair) should have been added.
    self.assertEqual(len(optimized.node), nodes_before + 2)
    names = {node.name for node in optimized.node}
    self.assertTrue(
        names > {'a', 'b', 'c', 'd', 'swap_in_d_0', 'swap_out_d_0'})
    for node in optimized.node:
      if node.name == 'd':
        self.assertEqual('swap_in_d_0', node.input[0])
        self.assertEqual('c', node.input[1])
      elif node.name == 'swap_out_d_0':
        self.assertEqual('b/read', node.input[0])
      elif node.name == 'swap_in_d_0':
        self.assertEqual('swap_out_d_0', node.input[0])
        self.assertEqual('^b/read', node.input[1])
开发者ID:Dr4KK,项目名称:tensorflow,代码行数:32,代码来源:memory_optimizer_test.py

示例9: apply_mask

def apply_mask(x, scope=''):
  """Apply mask to a given weight tensor.

  Multiplies `x` by its mask variable. The first time a given mask is seen,
  the threshold, mask, masked weights and `x` are recorded in their
  respective collections.

  Args:
    x: Input weight tensor
    scope: The current variable scope. Defaults to ""
  Returns:
    Tensor representing masked_weights
  """
  mask = _weight_mask_variable(x, scope)
  threshold = _weight_threshold_variable(x, scope)
  # Add masked_weights in the weights namescope so as to make it easier
  # for the quantization library to add quant ops.
  masked_weights = math_ops.multiply(mask, x, _MASKED_WEIGHT_NAME)

  # A mask must not be registered more than once. This matters when the mask
  # is applied to RNN weight variables, which can reach here repeatedly.
  if mask in ops.get_collection_ref(_MASK_COLLECTION):
    return masked_weights

  for collection_name, value in ((_THRESHOLD_COLLECTION, threshold),
                                 (_MASK_COLLECTION, mask),
                                 (_MASKED_WEIGHT_COLLECTION, masked_weights),
                                 (_WEIGHT_COLLECTION, x)):
    ops.add_to_collection(collection_name, value)
  return masked_weights
开发者ID:SylChan,项目名称:tensorflow,代码行数:25,代码来源:pruning.py

示例10: testUpdates

  def testUpdates(self):
    """tf_item reflects edits to the item's MetaGraph node placements.

    `tf_item` must compare equal while the MetaGraph is untouched, unequal
    after device assignments change, and equal again when the same
    assignment is repeated.
    """
    with ops.Graph().as_default() as g:
      a = constant_op.constant(10)
      b = constant_op.constant(20)
      total = a + b
      fetches = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      fetches.append(total)
      meta = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(meta)

    def place_all_nodes_on_cpu():
      for node in grappler_item.metagraph.graph_def.node:
        node.device = '/cpu:0'

    before = grappler_item.tf_item
    unchanged = grappler_item.tf_item
    self.assertEqual(before, unchanged)

    # Modify the placement.
    place_all_nodes_on_cpu()
    after_move = grappler_item.tf_item
    self.assertNotEqual(before, after_move)

    # Assign the same placement.
    place_all_nodes_on_cpu()
    after_repeat = grappler_item.tf_item
    self.assertEqual(after_move, after_repeat)
开发者ID:AnddyWang,项目名称:tensorflow,代码行数:25,代码来源:item_test.py

示例11: testFromGenerator

  def testFromGenerator(self):
    """IteratorGetNext shape matches `output_shapes` given to from_generator."""

    def make_generator(value):
      # Factory binds `value` per case instead of closing over a loop var.
      def generator():
        yield value

      return generator

    cases = [
        (0, tensor_shape.TensorShape([])),
        (np.array([1, 2, 3]), tensor_shape.TensorShape([3])),
        (np.array([[1, 2, 3]]), tensor_shape.TensorShape([1, 3])),
    ]

    for value, expected_shape in cases:
      with ops.Graph().as_default() as g:
        dataset = dataset_ops.Dataset.from_generator(
            make_generator(value), dtypes.int64, output_shapes=expected_shape)
        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()
        fetches = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        fetches.append(next_element)
        meta = meta_graph.create_meta_graph_def(graph=g)
        grappler_item = item.Item(meta)
        properties = grappler_item.GetOpProperties()
        self.assertEqual(expected_shape,
                         properties['IteratorGetNext'][0].shape)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:35,代码来源:datasets_test.py

示例12: testPruning

  def testPruning(self):
    """Enter-node count in the optimized graph tracks which loop outputs are used.

    Expects one Enter node when only the loop counter is in TRAIN_OP, and two
    once the stacked TensorList output is added as well.
    """
    x = constant_op.constant(1)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=x.dtype, element_shape=x.shape)

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5

    def Body(x, tl):
      return x + 1, list_ops.tensor_list_push_back(tl, x)

    outputs = while_loop_v1(Cond, Body, [x, tensor_list])

    fetches = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    fetches.append(outputs[0])

    def count_enter_nodes():
      # Optimize the current default graph and count its Enter ops.
      meta = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
      config = rewriter_config_pb2.RewriterConfig(
          constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
          memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
      optimized = tf_optimizer.OptimizeGraph(config, meta)
      return len([node for node in optimized.node if node.op == "Enter"])

    self.assertEqual(count_enter_nodes(), 1)

    stacked = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
    fetches.append(stacked)
    self.assertEqual(count_enter_nodes(), 2)
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:32,代码来源:while_v2_test.py

示例13: testMap

  def testMap(self):
    """IteratorGetNext shape reflects Dataset.map(array_ops.transpose)."""
    cases = [
        (0, tensor_shape.TensorShape([])),
        (np.array([1, 2, 3]), tensor_shape.TensorShape([3])),
        (np.array([[1, 2, 3]]), tensor_shape.TensorShape([3, 1])),
        (np.array([[[1, 2, 3], [4, 5, 6]]]),
         tensor_shape.TensorShape([3, 2, 1])),
    ]

    for value, expected_shape in cases:
      with ops.Graph().as_default() as g:
        dataset = dataset_ops.Dataset.from_tensors(value).map(
            array_ops.transpose)
        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()
        fetches = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        fetches.append(next_element)
        meta = meta_graph.create_meta_graph_def(graph=g)
        grappler_item = item.Item(meta)
        properties = grappler_item.GetOpProperties()
        self.assertEqual(expected_shape,
                         properties['IteratorGetNext'][0].shape)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:28,代码来源:datasets_test.py

示例14: testFromStringHandle

  def testFromStringHandle(self):
    """IteratorGetNext shape comes from `output_shapes` of from_string_handle."""
    expected_shapes = [
        tensor_shape.TensorShape([]),
        tensor_shape.TensorShape([3]),
        tensor_shape.TensorShape([1, 2]),
        tensor_shape.TensorShape([1, 2, 3]),
    ]

    for expected_shape in expected_shapes:
      with ops.Graph().as_default() as g:
        iterator = iterator_ops.Iterator.from_structure(dtypes.int64)
        handle = iterator.string_handle()
        # Rebuild an iterator from the handle, declaring the shape to check.
        iterator = iterator_ops.Iterator.from_string_handle(
            handle, dtypes.int64, output_shapes=expected_shape)
        next_element = iterator.get_next()
        fetches = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        fetches.append(next_element)
        meta = meta_graph.create_meta_graph_def(graph=g)
        grappler_item = item.Item(meta)
        properties = grappler_item.GetOpProperties()
        self.assertEqual(expected_shape,
                         properties['IteratorGetNext'][0].shape)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:25,代码来源:datasets_test.py

示例15: testInterleave

  def testInterleave(self):
    """IteratorGetNext shape after interleaving from_tensors(v).repeat(n)."""

    def make_dataset(tensor):
      # Factory binds `tensor` per case instead of closing over a loop var.
      def dataset_fn(n):
        return dataset_ops.Dataset.from_tensors(tensor).repeat(n)

      return dataset_fn

    cases = [
        (0, tensor_shape.TensorShape([])),
        (np.array([1, 2, 3]), tensor_shape.TensorShape([3])),
        (np.array([[1, 2, 3]]), tensor_shape.TensorShape([1, 3])),
    ]

    for value, expected_shape in cases:
      with ops.Graph().as_default() as g:
        dataset = dataset_ops.Dataset.range(42).interleave(
            make_dataset(value), cycle_length=42)
        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()
        fetches = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        fetches.append(next_element)
        meta = meta_graph.create_meta_graph_def(graph=g)
        grappler_item = item.Item(meta)
        properties = grappler_item.GetOpProperties()
        self.assertEqual(expected_shape,
                         properties['IteratorGetNext'][0].shape)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:34,代码来源:datasets_test.py


注：本文中的 tensorflow.python.framework.ops.get_collection_ref 函数示例由纯净天空整理自 GitHub/MSDocs 等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的 License；未经允许，请勿转载。