当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.Tensor.ref用法及代码示例


用法

ref()

返回此张量的可散列引用对象。

此 API 的主要用例是将张量放入集合/字典中。我们不能将张量放在集合/字典中,因为从 Tensorflow 2.0 开始,tensor.__hash__() 不再可用。

以下将引发从 2.0 开始的异常

x = tf.constant(5)
y = tf.constant(10)
z = tf.constant(10)
tensor_set = {x, y, z}
Traceback (most recent call last):

TypeError:Tensor is unhashable. Instead, use tensor.ref() as the key.
tensor_dict = {x:'five', y:'ten'}
Traceback (most recent call last):

TypeError:Tensor is unhashable. Instead, use tensor.ref() as the key.

相反,我们可以使用 tensor.ref()

tensor_set = {x.ref(), y.ref(), z.ref()}
x.ref() in tensor_set
True
tensor_dict = {x.ref():'five', y.ref():'ten', z.ref():'ten'}
tensor_dict[y.ref()]
'ten'

此外,参考对象提供了返回原始张量的.deref() 函数。

x = tf.constant(5)
x.ref().deref()
<tf.Tensor:shape=(), dtype=int32, numpy=5>

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.Tensor.ref。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。