当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.saved_model.experimental.TrackableResource用法及代码示例


持有 tf.function 可以捕获的张量。

用法

tf.saved_model.experimental.TrackableResource(
    device=''
)

参数

  • device 指示此资源所需放置的字符串,例如"CPU" 如果必须在 CPU 设备上创建此资源。空白设备允许用户创建资源,因此通常这应该是空白的,除非资源仅在一个设备上有意义。

属性

  • resource_handle 返回与此资源关联的资源句柄。

TrackableResource 对于需要初始化的有状态张量最有用,例如 tf.lookup.StaticHashTableTrackableResource 是通过遍历对象属性图来发现的,例如在 tf.saved_model.save 期间。

一个 TrackableResource 有三个方法可以覆盖:

  • _create_resource 应该创建资源张量句柄。
  • _initialize 应该初始化保存在 self.resource_handle 的资源。
  • _destroy_resourceTrackableResource 的破坏时被调用,并且应该减少资源的引用计数。对于大多数资源,这应该通过调用 tf.raw_ops.DestroyResourceOp 来完成。

示例用法:

class DemoResource(tf.saved_model.experimental.TrackableResource):
  def __init__(self):
    super().__init__()
    self._initialize()
  def _create_resource(self):
    return tf.raw_ops.VarHandleOp(dtype=tf.float32, shape=[2])
  def _initialize(self):
    tf.raw_ops.AssignVariableOp(
        resource=self.resource_handle, value=tf.ones([2]))
  def _destroy_resource(self):
    tf.raw_ops.DestroyResourceOp(resource=self.resource_handle)
class DemoModule(tf.Module):
  def __init__(self):
    self.resource = DemoResource()
  def increment(self, tensor):
    return tensor + tf.raw_ops.ReadVariableOp(
        resource=self.resource.resource_handle, dtype=tf.float32)
demo = DemoModule()
demo.increment([5, 1])
<tf.Tensor:shape=(2,), dtype=float32, numpy=array([6., 2.], dtype=float32)>

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.saved_model.experimental.TrackableResource。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。