當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.saved_model.experimental.TrackableResource用法及代碼示例


持有 tf.function 可以捕獲的張量。

用法

tf.saved_model.experimental.TrackableResource(
    device=''
)

參數

  • device 指示此資源所需放置的字符串,例如"CPU" 如果必須在 CPU 設備上創建此資源。空白設備允許用戶創建資源,因此通常這應該是空白的,除非資源僅在一個設備上有意義。

屬性

  • resource_handle 返回與此資源關聯的資源句柄。

TrackableResource 對於需要初始化的有狀態張量最有用,例如 tf.lookup.StaticHashTableTrackableResource 是通過遍曆對象屬性圖來發現的,例如在 tf.saved_model.save 期間。

一個 TrackableResource 有三個方法可以覆蓋:

  • _create_resource 應該創建資源張量句柄。
  • _initialize 應該初始化保存在 self.resource_handle 的資源。
  • _destroy_resourceTrackableResource 的破壞時被調用,並且應該減少資源的引用計數。對於大多數資源,這應該通過調用 tf.raw_ops.DestroyResourceOp 來完成。

示例用法:

class DemoResource(tf.saved_model.experimental.TrackableResource):
  def __init__(self):
    super().__init__()
    self._initialize()
  def _create_resource(self):
    return tf.raw_ops.VarHandleOp(dtype=tf.float32, shape=[2])
  def _initialize(self):
    tf.raw_ops.AssignVariableOp(
        resource=self.resource_handle, value=tf.ones([2]))
  def _destroy_resource(self):
    tf.raw_ops.DestroyResourceOp(resource=self.resource_handle)
class DemoModule(tf.Module):
  def __init__(self):
    self.resource = DemoResource()
  def increment(self, tensor):
    return tensor + tf.raw_ops.ReadVariableOp(
        resource=self.resource.resource_handle, dtype=tf.float32)
demo = DemoModule()
demo.increment([5, 1])
<tf.Tensor:shape=(2,), dtype=float32, numpy=array([6., 2.], dtype=float32)>

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.saved_model.experimental.TrackableResource。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。