This article collects typical usage examples of the Python function tensorflow.python.eager.tape.watch_variable. If you have been wondering what exactly watch_variable does, how to call it, and where it is used in practice, the curated examples below should help.
15 code examples of watch_variable are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code examples.
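For orientation before the examples: watch_variable marks a variable so that the active gradient tape records operations reading it. Here is a minimal sketch of the same effect through the public tf.GradientTape API (GradientTape.watch delegates to tape.watch_variable for variable-like objects, as Examples 5 and 7 below show); the names are illustrative:

import tensorflow as tf

v = tf.Variable(3.0)
with tf.GradientTape() as g:
  g.watch(v)               # for objects with a `handle`, this ends up in tape.watch_variable
  y = v * v
print(g.gradient(y, v))    # tf.Tensor(6.0, shape=(), dtype=float32)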
Example 1: _read_variable_op

def _read_variable_op(self):
  if hasattr(self, "_trainable") and self._trainable:
    tape.watch_variable(self)  # record the read so gradients can reach this variable
    return read_variable_op(self._handle, dtype=self._dtype)
  else:
    return gen_resource_variable_ops.read_variable_op(self._handle,
                                                      self._dtype)
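Because _read_variable_op watches trainable variables on every read, gradients flow without an explicit watch call. A hedged usage sketch of that behavior through the public API (names illustrative):

import tensorflow as tf

v = tf.Variable(2.0)            # trainable by default
with tf.GradientTape() as g:
  y = v.read_value() * 3.0      # the read itself triggers tape.watch_variable
print(g.gradient(y, v))         # tf.Tensor(3.0, shape=(), dtype=float32)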
Example 2: _lazy_read

def _lazy_read(self, op):
  if hasattr(self, "_trainable") and self._trainable:
    tape.watch_variable(self)
  return _UnreadVariable(
      self._handle, self.dtype, self._handle_device, self._shape,
      self._in_graph_mode,
      self._handle_deleter if not self._in_graph_mode else None, op)
Example 3: _lazy_read

def _lazy_read(self, op):
  if self.trainable:
    tape.watch_variable(self)
  return _UnreadVariable(
      self._handle, self.dtype, self._shape, self._in_graph_mode,
      self._handle_deleter if not self._in_graph_mode else None, op,
      self._unique_id)
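_lazy_read wraps the result of an assignment op so that a later read of the variable is still recorded on the tape. A hedged sketch of the observable effect through the public API (names illustrative):

import tensorflow as tf

v = tf.Variable(1.0)
with tf.GradientTape() as g:
  u = v.assign_add(1.0)   # returns a lazily-read view of v, watched because v is trainable
  y = u * u
print(g.gradient(y, v))   # tf.Tensor(4.0, ...): d(u*u)/dv evaluated at u = 2.0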
Example 4: sparse_read

def sparse_read(self, indices, name=None):
  """Reads the value of this variable sparsely, using `gather`."""
  with ops.name_scope("Gather" if name is None else name) as name:
    if self._trainable:
      tape.watch_variable(self)
    value = gen_resource_variable_ops.resource_gather(
        self._handle, indices, dtype=self._dtype, name=name)
  return array_ops.identity(value)
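sparse_read backs Variable.sparse_read (and gathers on resource variables), so gradients come back as IndexedSlices that touch only the gathered rows. A hedged sketch, with illustrative names:

import tensorflow as tf

v = tf.Variable([[1., 2.], [3., 4.]])
with tf.GradientTape() as g:
  rows = v.sparse_read([0])      # gather row 0; watched because v is trainable
grad = g.gradient(tf.reduce_sum(rows), v)
print(grad)                      # IndexedSlices affecting only row 0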
Example 5: watch

def watch(self, tensor):
  """Ensures that `tensor` is being traced by this tape.

  Args:
    tensor: a Tensor or list of Tensors.
  """
  for t in nest.flatten(tensor):
    if hasattr(t, "handle"):
      # There are many variable-like objects, all of them currently have
      # `handle` attribute that points to a tensor. If this changes, internals
      # of watch_variable need to change as well.
      tape.watch_variable(self._tape, t)
    else:
      tape.watch(self._tape, t)
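The hasattr(t, "handle") branch is what lets a single watch call accept plain tensors and variables alike. A hedged sketch through the public wrapper (names illustrative):

import tensorflow as tf

x = tf.constant(2.0)
v = tf.Variable(3.0)
with tf.GradientTape() as g:
  g.watch([x, v])              # flattened: x goes to tape.watch, v to tape.watch_variable
  y = x * v
print(g.gradient(y, [x, v]))   # [3.0, 2.0]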
Example 6: __call__

def __call__(self, *args):
  """Executes the passed function in eager mode."""
  for v in self._variables:
    if v._trainable:  # pylint: disable=protected-access
      tape.watch_variable(v)
  tensor_inputs = [
      x for x in nest.flatten(args)
      if isinstance(x, ops.Tensor)
  ]
  if tape.should_record(tensor_inputs) or tape.should_record(
      self._extra_inputs):
    if not self._has_backprop:
      self._compute_backprop()
    return self._backprop_call(tensor_inputs)

  ctx = context.context()
  if ctx.in_graph_mode():
    g = ops.get_default_graph()
    if self._fdef.name not in g._functions:  # pylint: disable=protected-access
      g._add_function(self._fdef)  # pylint: disable=protected-access
    for f in self._graph._functions.values():  # pylint: disable=protected-access
      if f.name not in g._functions:  # pylint: disable=protected-access
        g._add_function(f)  # pylint: disable=protected-access
    signature = self._fdef.definition.signature
    args = list(tensor_inputs) + self._extra_inputs
    op = g.create_op(
        signature.name, [ops.convert_to_tensor(x) for x in args],
        [dtypes.DType(x.type) for x in signature.output_arg],
        op_def=signature,
        name="FunctionCall",
        compute_shapes=False)
    result = op.outputs
    if not result:
      return op
    for i, s in enumerate(self._output_shapes):
      result[i].set_shape(s)
  else:
    result = execute.execute(
        str(self._func_name),
        num_outputs=self._num_outputs,
        inputs=tensor_inputs + self._extra_inputs,
        attrs=None,
        ctx=ctx)
  return self._build_call_outputs(result)
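The loop at the top of __call__ is what keeps variables captured by a compiled function differentiable. Roughly the same behavior is visible through today's tf.function; a hedged analog, not this exact class:

import tensorflow as tf

w = tf.Variable(5.0)

@tf.function
def f(x):
  return w * x            # w is a captured variable; it is watched when f runs under a tape

with tf.GradientTape() as g:
  y = f(tf.constant(2.0))
print(g.gradient(y, w))   # tf.Tensor(2.0, shape=(), dtype=float32)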
Example 7: watch

def watch(self, tensor):
  """Ensures that `tensor` is being traced by this tape.

  Args:
    tensor: a Tensor or list of Tensors.
  """
  for t in nest.flatten(tensor):
    if not t.dtype.is_floating:
      logging.log_first_n(
          logging.WARN, "The dtype of the watched tensor must be "
          "floating (e.g. tf.float32), got %r", 5, t.dtype)
    if hasattr(t, "handle"):
      # There are many variable-like objects, all of them currently have
      # `handle` attribute that points to a tensor. If this changes, internals
      # of watch_variable need to change as well.
      tape.watch_variable(self._tape, t)
    else:
      tape.watch(self._tape, t)
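This variant adds a one-time warning for non-floating dtypes, since gradients only exist for floating-point (and complex) tensors. A hedged sketch of what triggers it:

import tensorflow as tf

with tf.GradientTape() as g:
  t = tf.constant(1)       # int32
  g.watch(t)               # logs: "The dtype of the watched tensor must be floating ..."
  y = t * t
print(g.gradient(y, t))    # None: nothing is recorded for integer tensors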
Example 8: __call__

def __call__(self, *args):
  nest.assert_same_structure(self.shape_and_dtypes, args, check_types=False)
  if not all([
      shape.is_compatible_with(arg.shape)
      for shape, arg in zip(self.flattened_shapes, nest.flatten(args))
  ]):
    raise ValueError(
        "Declared shapes do not match argument shapes: Expected %s, found %s."
        % (self.flattened_shapes, [arg.shape for arg in nest.flatten(args)]))

  initialized = [resource_variable_ops.var_is_initialized_op(
      v.handle).numpy() for v in self._call_fn.variables]
  if all(x for x in initialized):
    for v in self._call_fn.variables:
      if v._trainable:  # pylint: disable=protected-access
        tape.watch_variable(v)
    return self._call_fn(*args)
  elif all(not x for x in initialized):
    return self._init_fn(*args)
  else:
    raise ValueError("Some, but not all, variables are initialized.")
Example 9: __call__

def __call__(self, *args):
  """Executes the passed function in eager mode."""
  for v in self._variables:
    if v._trainable:  # pylint: disable=protected-access
      tape.watch_variable(v)
  tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
  if tape.should_record(tensor_inputs) or tape.should_record(
      self._extra_inputs):
    if self._backward_function is None:
      self._construct_backprop_function()
    return self._backprop_call(tensor_inputs)

  ctx = context.context()
  if ctx.executing_eagerly():
    result = execute.execute(
        str(self._func_name),
        num_outputs=self._num_outputs,
        inputs=tensor_inputs + self._extra_inputs,
        attrs=None,
        ctx=ctx)
  else:
    g = ops.get_default_graph()
    self.add_to_graph(g)
    signature = self._function_def.definition.signature
    args = list(tensor_inputs) + self._extra_inputs
    op = g.create_op(
        signature.name,
        [ops.internal_convert_to_tensor(x, ctx=ctx) for x in args],
        tuple(dtypes_module.DType(x.type) for x in signature.output_arg),
        op_def=signature,
        name="FunctionCall",
        compute_shapes=False)
    result = op.outputs
    if not result:
      return op
    for i, s in enumerate(self._output_shapes):
      result[i].set_shape(s)
  return self._build_call_outputs(result)
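Example 9 is a later revision of Example 6: `ctx.in_graph_mode()` became `ctx.executing_eagerly()` (with the two branches swapped accordingly), `_has_backprop`/`_compute_backprop` became `_backward_function`/`_construct_backprop_function`, the manual `g._add_function` loop is wrapped up in `self.add_to_graph(g)`, and tensor conversion goes through `ops.internal_convert_to_tensor`. The variable-watching loop at the top is unchanged.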
Example 10: g

def g(x):
  tape.watch_variable(three)
  return f(x)
Example 11: f

def f():
  tape.watch_variable(embedding)
  embedded_x = embedding_ops.embedding_lookup(embedding, x)
  return tensor.Tensor(1.0, dtypes.float32) - embedded_x
Example 12: fn

def fn():
  tape.watch_variable(x)
  b = tensor.Tensor(2.0)
  c = math_ops.add(x.value(), b)
  return math_ops.add(c, tensor.Tensor(3.0))
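Closures like fn are typically exercised in the TF test suite through backprop helpers. A hedged sketch of one way to drive it, assuming x is a ResourceVariable(1.0) in the enclosing test scope (the fn above relies on old internal APIs, so this is illustrative only):

from tensorflow.python.eager import backprop

# d/dx (x + 2 + 3) = 1.0; implicit_grad pairs each gradient with its variable.
grads_and_vars = backprop.implicit_grad(fn)()
print(grads_and_vars[0][0])   # 1.0, paired with the variable x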
Example 13: read

def read(self, want_gradients=True):
  v = self.variable
  if want_gradients and self.trainable:
    # watch_variable is called purely for its side effect (it returns None),
    # so read from the variable itself rather than from its return value.
    tape.watch_variable(v)
  return v.read_value()
Example 14: inner

def inner():
  tape.watch_variable(v)
  return v * v
Example 15: f

def f():
  tape.watch_variable(embedding)
  embedded_x = embedding_ops.embedding_lookup(embedding, x)
  return constant_op.constant(1.0, dtypes.float32) - embedded_x
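Example 15 is the same test body as Example 11 with one difference: the early eager alias tensor.Tensor was replaced by the standard constant_op.constant constructor. Either way, the watch_variable call is what makes the embedding lookup differentiable with respect to embedding.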