当前位置: 首页>>代码示例>>Python>>正文


Python v2.TensorSpec方法代码示例

本文整理汇总了Python中tensorflow.compat.v2.TensorSpec的典型用法代码示例（注意：TensorSpec是一个类，而非方法）。如果您想了解Python v2.TensorSpec的具体用法、使用方式或实际示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解TensorSpec所在模块tensorflow.compat.v2的其他用法示例。


下文共展示了v2.TensorSpec的15个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更优质的Python代码示例。

示例1: __init__

# 需要导入模块: from tensorflow.compat import v2 [as 别名]
# 或者: from tensorflow.compat.v2 import TensorSpec [as 别名]
def __init__(self, input_tensor_spec=None, state_spec=(), name=None):
    """Creates an instance of `Network`.

    Calls the Keras layer constructor with autocast disabled, runs
    `common.check_tf1_allowed()` (presumably a guard against unsupported TF1
    usage -- confirm in `common`), and stores the input and state specs.

    Args:
      input_tensor_spec: A nest of `tf.TypeSpec` representing the
        input observations.  Optional.  If not provided, `create_variables()`
        will fail unless a spec is provided.
      state_spec: A nest of `tensor_spec.TensorSpec` representing the state
        needed by the network. Default is `()`, which means no state.
      name: (Optional.) A string representing the name of the network.
    """
    # Disable autocast because it may convert bfloats to other types, breaking
    # our spec checks.
    super(Network, self).__init__(name=name, autocast=False)
    common.check_tf1_allowed()

    # Required for summary() to work.
    self._is_graph_network = False

    self._input_tensor_spec = input_tensor_spec
    # Starts out unset; presumably populated elsewhere in the class (not shown
    # here).
    # NOTE(ebrevdo): Would have preferred to call this output_tensor_spec, but
    # looks like keras.Layer already reserves that one.
    self._network_output_spec = None
    self._state_spec = state_spec
开发者ID:tensorflow,项目名称:agents,代码行数:26,代码来源:network.py

示例2: testReadWriteSpecs

# 需要导入模块: from tensorflow.compat import v2 [as 别名]
# 或者: from tensorflow.compat.v2 import TensorSpec [as 别名]
def testReadWriteSpecs(self):
    """Specs written with `write_specs` are read back unchanged."""
    expected = {
        'a': tf.TensorSpec(shape=(2, 3), dtype=tf.float32),
        'b': {
            'b_1': tf.TensorSpec(shape=(5,), dtype=tf.string),
            'b_2': tf.TensorSpec(shape=(5, 6), dtype=tf.int32),
        }
    }
    output_dir = FLAGS.test_tmpdir
    utils.write_specs(output_dir, expected)

    # Compare each leaf of the original structure with its round-tripped copy.
    actual = utils.read_specs(output_dir)
    tf.nest.map_structure(self.assertEqual, expected, actual)
开发者ID:google-research,项目名称:valan,代码行数:19,代码来源:utils_test.py

示例3: test_run_eval_actor_once

# 需要导入模块: from tensorflow.compat import v2 [as 别名]
# 或者: from tensorflow.compat.v2 import TensorSpec [as 别名]
def test_run_eval_actor_once(self):
    """Runs one eval-actor iteration against a no-op aggregator gRPC server."""
    hparams = {}
    hparams['max_iter'] = 1
    hparams['num_episodes_per_iter'] = 5
    hparams['logdir'] = os.path.join(FLAGS.test_tmpdir, 'model')

    mock_problem = testing_utils.MockProblem(unroll_length=FLAGS.unroll_length)
    agent = mock_problem.get_agent()
    # Save an initial checkpoint (number 0) of the agent into `logdir`.
    ckpt_manager = _get_ckpt_manager(hparams['logdir'], agent=agent)
    ckpt_manager.save(checkpoint_number=0)

    # Create a no-op gRPC server that responds to Aggregator RPCs.
    server_address = 'unix:/tmp/eval_actor_test_grpc'
    server = grpc.Server([server_address])

    @tf.function(input_signature=[tf.TensorSpec(shape=(), dtype=tf.string)])
    def eval_enqueue(_):
      return []

    server.bind(eval_enqueue, batched=False)

    server.start()
    try:
      eval_actor.run_with_aggregator(mock_problem, server_address, hparams)
    finally:
      # Always stop the server, even if the actor raises, so the fixed unix
      # socket address is not left bound for later tests in the same process.
      server.shutdown()
开发者ID:google-research,项目名称:valan,代码行数:26,代码来源:eval_actor_test.py

示例4: __init__

# 需要导入模块: from tensorflow.compat import v2 [as 别名]
# 或者: from tensorflow.compat.v2 import TensorSpec [as 别名]
def __init__(self, state_space_size, unroll_length=1):
    """Builds a toy environment over a 1-D chain of states.

    Intended dynamics (T is the transition function):
      states  = [0, 1, ..., state_space_size - 1] + [STOP]
      actions = [-1, 1]
      T(s, a) = s + a   if s + a is a valid state
              = STOP    otherwise

    Args:
      state_space_size: Number of non-terminal states.
      unroll_length: Every spec below gets a leading time dimension of size
        `unroll_length + 1`.
    """
    self._state_space_size = state_space_size
    self._action_space = [-1, 1]
    self._current_state = None

    time_steps = unroll_length + 1
    observation_spec = {
        'f1': tf.TensorSpec(shape=[time_steps, 4, 10], dtype=tf.float32),
        'f2': tf.TensorSpec(shape=[time_steps, 7, 10, 2], dtype=tf.float32),
    }
    self._env_spec = common.EnvOutput(
        reward=tf.TensorSpec(shape=[time_steps], dtype=tf.float32),
        done=tf.TensorSpec(shape=[time_steps], dtype=tf.bool),
        observation=observation_spec,
        info=tf.TensorSpec(shape=[time_steps], dtype=tf.string))
开发者ID:google-research,项目名称:valan,代码行数:23,代码来源:testing_utils.py

示例5: export_serving_model

# 需要导入模块: from tensorflow.compat import v2 [as 别名]
# 或者: from tensorflow.compat.v2 import TensorSpec [as 别名]
def export_serving_model(tf_transform_output, model, output_dir):
  """Saves `model` as a servable SavedModel fed with serialized tf.Examples.

  Args:
    tf_transform_output: Wrapper around the output of tf.Transform.
    model: The keras model to export.
    output_dir: Base export directory; the model is written into its `1`
      subdirectory.
  """
  # Attach the transform layer to the model so keras tracks it for saving.
  model.tft_layer = tf_transform_output.transform_features_layer()

  @tf.function
  def serve_tf_examples_fn(serialized_tf_examples):
    """Parses, transforms and scores a batch of serialized tf.Examples."""
    # Serving inputs carry no label, so drop it from the parsing spec.
    feature_spec = RAW_DATA_FEATURE_SPEC.copy()
    del feature_spec[LABEL_KEY]
    parsed = tf.io.parse_example(serialized_tf_examples, feature_spec)
    scores = model(model.tft_layer(parsed))
    batch_size = tf.shape(scores)[0]
    # Repeat the class labels ['0', '1'] once per example in the batch.
    classes = tf.tile(tf.constant([['0', '1']]), [batch_size, 1])
    return {'classes': classes, 'scores': scores}

  input_spec = tf.TensorSpec(shape=[None], dtype=tf.string, name='inputs')
  signatures = {
      'serving_default': serve_tf_examples_fn.get_concrete_function(input_spec)
  }

  # model_server only serves models from a versioned subdirectory, hence `1`.
  model.save(
      os.path.join(output_dir, '1'), save_format='tf', signatures=signatures)
开发者ID:tensorflow,项目名称:transform,代码行数:32,代码来源:census_example_v2.py

示例6: get_state_spec

# 需要导入模块: from tensorflow.compat import v2 [as 别名]
# 或者: from tensorflow.compat.v2 import TensorSpec [as 别名]
def get_state_spec(layer: tf.keras.layers.Layer) -> types.NestedTensorSpec:
  """Returns the state spec of `layer`.

  Args:
    layer: The layer to inspect; may be a `Network`.

  Returns:
    For a `Network`, its `state_spec`. Otherwise, a nest of `tf.TensorSpec`
    mirroring `layer.state_size` with dtype `layer.dtype`, or `()` when the
    layer has no `state_size`.

  Raises:
    TypeError: If `layer` is a `tf.keras.layers.RNN`; such layers must be
      wrapped in an `RNNWrapper` first.
    ValueError: If `layer` has a `get_initial_state` attribute but no
      `state_size`, so no state spec can be derived.
  """
  if isinstance(layer, Network):
    return layer.state_spec

  if isinstance(layer, tf.keras.layers.RNN):
    raise TypeError("RNN Layer must be wrapped inside "
                    "`tf_agents.keras_layers.RNNWrapper`: {}".format(layer))

  has_initial_state = getattr(layer, "get_initial_state", None) is not None
  state_size = getattr(layer, "state_size", None)
  if state_size is None:
    if has_initial_state:
      raise ValueError(
          "Layer lacks a `state_size` property.  Unable to extract state "
          "spec: {}".format(layer))
    return ()

  def _spec_for(size):
    return tf.TensorSpec(dtype=layer.dtype, shape=size)

  return tf.nest.map_structure(_spec_for, state_size)
开发者ID:tensorflow,项目名称:agents,代码行数:36,代码来源:network.py

示例7: __init__

# 需要导入模块: from tensorflow.compat import v2 [as 别名]
# 或者: from tensorflow.compat.v2 import TensorSpec [as 别名]
def __init__(self, time_step_spec, net):
    """Stores `net` and declares a float32 action spec of shape `(None,)`."""
    action_spec = tf.TensorSpec((None,), tf.float32)
    super(MyPolicy, self).__init__(time_step_spec, action_spec=action_spec)
    self._net = net
开发者ID:tensorflow,项目名称:agents,代码行数:7,代码来源:nest_map_test.py

示例8: testIncompatibleStructureInputs

# 需要导入模块: from tensorflow.compat import v2 [as 别名]
# 或者: from tensorflow.compat.v2 import TensorSpec [as 别名]
def testIncompatibleStructureInputs(self):
    """NestMap rejects specs, inputs and states with mismatched structure."""
    layers_vs_spec = (
        r'`nested_layers` and `input_spec` do not have matching structures')
    inputs_vs_layers = (
        r'`inputs` and `self.nested_layers` do not have matching structures')
    state_vs_spec = (
        r'`network_state` and `state_spec` do not have matching structures')

    # A single layer paired with a dict-structured input_spec.
    with self.assertRaisesRegex(ValueError, layers_vs_spec):
      nest_map.NestMap(
          tf.keras.layers.Dense(8),
          input_spec={'ick': tf.TensorSpec(8, tf.float32)})

    # A single layer given dict-structured specs in create_variables().
    with self.assertRaisesRegex(ValueError, inputs_vs_layers):
      net = nest_map.NestMap(tf.keras.layers.Dense(8))
      net.create_variables({'ick': tf.TensorSpec((1,), dtype=tf.float32)})

    # A single layer called on dict-structured tensors.
    with self.assertRaisesRegex(ValueError, inputs_vs_layers):
      net = nest_map.NestMap(tf.keras.layers.Dense(8))
      net({'ick': tf.constant([[1.0]])})

    # An LSTM called with a network_state whose structure is rejected.
    with self.assertRaisesRegex(ValueError, state_vs_spec):
      lstm = tf.keras.layers.LSTM(8, return_state=True, return_sequences=True)
      net = nest_map.NestMap(lstm)
      net(tf.ones((1, 2)), network_state=(tf.ones((1, 1)), ()))
开发者ID:tensorflow,项目名称:agents,代码行数:28,代码来源:nest_map_test.py

示例9: setUp

# 需要导入模块: from tensorflow.compat import v2 [as 别名]
# 或者: from tensorflow.compat.v2 import TensorSpec [as 别名]
def setUp(self):
    """Builds the observation and time-step specs shared by the tests."""
    super(PolicyInfoUpdaterWrapperTest, self).setUp()
    observation_spec = tensor_spec.TensorSpec([2], tf.float32)
    self._obs_spec = observation_spec
    self._time_step_spec = ts.time_step_spec(observation_spec)
开发者ID:tensorflow,项目名称:agents,代码行数:6,代码来源:policy_info_updater_wrapper_test.py

示例10: test_model_id_updater

# 需要导入模块: from tensorflow.compat import v2 [as 别名]
# 或者: from tensorflow.compat.v2 import TensorSpec [as 别名]
def test_model_id_updater(self):
    """The wrapper reports a `model_id` info entry from `ModelIdUpdater`."""
    action_spec = tensor_spec.BoundedTensorSpec([1], tf.float32, tf.float32.min,
                                                tf.float32.max)
    normal = tfp.distributions.Normal([0.0], [0.5])
    wrapped_info_spec = {
        'test_info':
            tf.TensorSpec(shape=(1,), dtype=tf.int32, name='test_info')
    }
    wrapped_policy = DistributionPolicy(
        distribution=normal,
        time_step_spec=self._time_step_spec,
        action_spec=action_spec,
        info_spec=wrapped_info_spec)

    # The wrapper's info spec is `model_id` plus the wrapped policy's entries.
    updater_info_spec = {
        'model_id': tf.TensorSpec(shape=(1,), dtype=tf.int32, name='model_id')
    }
    updater_info_spec.update(wrapped_policy.info_spec)
    policy = policy_info_updater_wrapper.PolicyInfoUpdaterWrapper(
        policy=wrapped_policy,
        info_spec=updater_info_spec,
        updater_fn=ModelIdUpdater(),
        name='model_id_updater')

    self.assertEqual(policy.time_step_spec, self._time_step_spec)
    self.assertEqual(policy.action_spec, action_spec)

    # A batch of two restart time steps.
    time_step = ts.restart(
        tf.constant([[1, 2], [3, 4]], dtype=tf.float32), batch_size=2)
    action_step = policy.action(time_step)
    distribution_step = policy.distribution(time_step)

    for step in (action_step, distribution_step):
      tf.nest.assert_same_structure(action_spec, step.action)

    for step in (action_step, distribution_step):
      self.assertListEqual(list(self.evaluate(step.info['model_id'])), [2])
开发者ID:tensorflow,项目名称:agents,代码行数:39,代码来源:policy_info_updater_wrapper_test.py

示例11: _check_value

# 需要导入模块: from tensorflow.compat import v2 [as 别名]
# 或者: from tensorflow.compat.v2 import TensorSpec [as 别名]
def _check_value(self, tensor: tf.Tensor, tensorspec: tf.TensorSpec):
    """Raises ValueError if `tensor`'s squeezed shape mismatches `tensorspec`.

    NOTE(review): `tf.squeeze` is applied to the shape returned by
    `get_shape()` (the vector of dimension sizes), not to `tensor` itself.
    Confirm this is the intended comparison.
    """
    squeezed_shape = tf.TensorShape(tf.squeeze(tensor.get_shape()))
    if squeezed_shape.is_compatible_with(tensorspec.shape):
      return
    raise ValueError(
        'Tensor {} is not compatible with specification {}.'.format(
            tensor, tensorspec))
开发者ID:tensorflow,项目名称:agents,代码行数:8,代码来源:policy_info_updater_wrapper.py

示例12: assert_matches_spec

# 需要导入模块: from tensorflow.compat import v2 [as 别名]
# 或者: from tensorflow.compat.v2 import TensorSpec [as 别名]
def assert_matches_spec(specs, tensor_list):
  """Asserts that a flat list of tensors matches the given TensorSpecs.

  Args:
    specs: A nest of `tf.TensorSpec` describing the expected structure.
    tensor_list: A flat sequence of tensors (or values convertible to
      tensors), packed into the structure of `specs` via
      `tf.nest.pack_sequence_as`.

  Raises:
    AssertionError: If a tensor is not compatible with its spec. Raised
      explicitly rather than via the `assert` statement, so the check still
      runs when Python is started with `-O`.
  """
  # Weirdly `tf.nest.pack_sequence_as` doesn't fail if tensor_list doesn't
  # conform to the specs type. So first pack the sequence, then explicitly
  # check the compatibility of each tensor with the corresponding spec.
  packed_tensors = tf.nest.pack_sequence_as(specs, tensor_list)
  packed_tensors = tf.nest.map_structure(tf.convert_to_tensor, packed_tensors)

  def is_compatible(sp, tensor):
    if not sp.is_compatible_with(tensor):
      raise AssertionError(
          'TensorSpec {} is not compatible with tensor {}'.format(sp, tensor))

  tf.nest.map_structure(is_compatible, specs, packed_tensors)
开发者ID:google-research,项目名称:valan,代码行数:16,代码来源:testing_utils.py

示例13: add_time_dimension

# 需要导入模块: from tensorflow.compat import v2 [as 别名]
# 或者: from tensorflow.compat.v2 import TensorSpec [as 别名]
def add_time_dimension(s: tf.TensorSpec):
  """Returns `s` with a leading time axis of size `FLAGS.unroll_length + 1`.

  The result keeps `s.dtype` but not `s.name`.
  """
  new_shape = [FLAGS.unroll_length + 1]
  new_shape.extend(s.shape.as_list())
  return tf.TensorSpec(new_shape, s.dtype)
开发者ID:google-research,项目名称:valan,代码行数:4,代码来源:actor.py

示例14: _write_tensor_specs

# 需要导入模块: from tensorflow.compat import v2 [as 别名]
# 或者: from tensorflow.compat.v2 import TensorSpec [as 别名]
def _write_tensor_specs(initial_agent_state: Any,
                        env_output: common.EnvOutput,
                        agent_output: common.AgentOutput,
                        actor_action: common.ActorAction,
                        loss_type: Optional[int] = common.AC_LOSS):
  """Writes the TensorSpecs of an `ActorOutput` built from the inputs.

  `env_output`, `agent_output` and `actor_action` get a leading time axis via
  `add_time_dimension`. The remaining fields (`initial_agent_state`,
  `loss_type` and an empty-string `info`) keep their shapes. The result is
  written to `FLAGS.logdir` with `utils.write_specs`.

  Args:
    initial_agent_state: A tensor or nested structure of tensors without any
      time or batch dimensions.
    env_output: An instance of `EnvOutput` whose tensors have no time or batch
      dimensions.
    agent_output: An instance of `AgentOutput` whose tensors have no time or
      batch dimensions.
    actor_action: An instance of `ActorAction`.
    loss_type: A scalar int denoting the loss type.
  """
  actor_output = common.ActorOutput(
      initial_agent_state,
      env_output,
      agent_output,
      actor_action,
      loss_type,
      info='')

  def _to_spec(value):
    return tf.TensorSpec.from_tensor(tf.convert_to_tensor(value))

  specs = tf.nest.map_structure(_to_spec, actor_output)
  # Only these three parts receive the leading time axis.
  specs = specs._replace(**{
      field: tf.nest.map_structure(add_time_dimension, getattr(specs, field))
      for field in ('env_output', 'agent_output', 'actor_action')
  })
  utils.write_specs(FLAGS.logdir, specs)
开发者ID:google-research,项目名称:valan,代码行数:36,代码来源:actor.py

示例15: eval_on_shapes

# 需要导入模块: from tensorflow.compat import v2 [as 别名]
# 或者: from tensorflow.compat.v2 import TensorSpec [as 别名]
def eval_on_shapes(f, static_argnums=()):
  """Returns a function that evaluates `f` given input shapes and dtypes.

  It transforms function `f` to a function that performs the same computation as
  `f` but only on shapes and dtypes (a.k.a. shape inference).

  Args:
    f: the function to be transformed.
    static_argnums: See documentation of `jit`. A single int is also accepted
      and treated as a one-element tuple.

  Returns:
    A function whose input arguments can be either the same as `f`'s or only
    their shapes/dtypes represented by `TensorSpec`, and whose return values are
    `TensorSpec`s with the same nested structure as `f`'s return values.
  """
  # Accept a bare int (e.g. `static_argnums=0`); otherwise the membership test
  # `i in static_argnums` below would raise TypeError.
  if isinstance(static_argnums, int):
    static_argnums = (static_argnums,)

  # TODO(wangpeng): tf.function could add a knob to turn off materializing the
  #   graph, so that we don't waste computation and memory when we just want
  #   shape inference.
  tf_f = jit(f, static_argnums=static_argnums).tf_function

  # pylint: disable=missing-docstring
  def f_return(*args):

    def abstractify(x):
      # Replace tensor/ndarray arguments by their TensorSpec; keep others.
      x = _canonicalize_jit_arg(x)
      if isinstance(x, (tf.Tensor, tf_np.ndarray)):
        return tf.TensorSpec(x.shape, x.dtype)
      else:
        return x

    def to_tensor_spec(x):
      # Symbolic outputs of the traced function become TensorSpecs.
      if isinstance(x, tf.Tensor):
        return tf.TensorSpec(x.shape, x.dtype)
      else:
        return x

    new_args = []
    for i, arg in enumerate(args):
      if i in static_argnums:
        new_args.append(arg)
      else:
        new_args.append(tf.nest.map_structure(abstractify, arg))
    res = tf_f.get_concrete_function(*new_args).structured_outputs

    return tf.nest.map_structure(to_tensor_spec, res)

  # Provides access to `tf_f` for testing purpose.
  f_return._tf_function = tf_f  # pylint: disable=protected-access
  return f_return
开发者ID:google,项目名称:trax,代码行数:51,代码来源:extensions.py


注:本文中的tensorflow.compat.v2.TensorSpec方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。