This article collects typical usage examples of the Python function tensorflow.python.ops.resource_variable_ops.is_resource_variable. If you have been wondering what exactly is_resource_variable does, how to use it, or what real usage looks like, the curated code examples below may help.
Fifteen code examples of is_resource_variable are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python examples.
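Before the examples, here is a minimal sketch of what the function itself reports (assuming TensorFlow 2.x, where tf.Variable creates resource variables by default; is_resource_variable is an internal API, so its module path may shift between releases):

import tensorflow as tf
from tensorflow.python.ops import resource_variable_ops

v = tf.Variable([1.0, 2.0])   # backed by a resource handle in TF 2.x
t = tf.constant([1.0, 2.0])   # a plain tensor, not a variable

print(resource_variable_ops.is_resource_variable(v))  # True
print(resource_variable_ops.is_resource_variable(t))  # False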
Example 1: update

def update(v, g):
  """Apply gradients to a replica variable."""
  assert v is not None

  try:
    # Convert the grad to Tensor or IndexedSlices if necessary.
    g = ops.convert_to_tensor_or_indexed_slices(g)
  except TypeError:
    raise TypeError("Gradient must be convertible to a Tensor"
                    " or IndexedSlices, or None: %s" % g)
  if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
    raise TypeError(
        "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
  p = _get_processor(v)

  if context.executing_eagerly() or (
      resource_variable_ops.is_resource_variable(v) and
      not v._in_graph_mode):  # pylint: disable=protected-access
    scope_name = v.name.split(":")[0]
  else:
    scope_name = v.op.name

  # device_policy is set because non-mirrored tensors will be read in
  # `update_op`. `_resource_apply_dense`, `lr_t`, `beta1_t` and `beta2_t`
  # are examples.
  with ops.name_scope("update_" + scope_name):
    # `update` is a nested helper; `self` is captured from the enclosing
    # optimizer method.
    return p.update_op(self, g)
Example 2: __init__

def __init__(self, var, slice_spec, name):
  self._var_device = var.device
  self._var_shape = var.shape
  if isinstance(var, ops.Tensor):
    self.handle_op = var.op.inputs[0]
    tensor = var
  elif resource_variable_ops.is_resource_variable(var):

    def _read_variable_closure(v):
      def f():
        with ops.device(v.device):
          x = v.read_value()
          # To allow variables placed on non-CPU devices to be checkpointed,
          # we copy them to CPU on the same machine first.
          with ops.device("/device:CPU:0"):
            return array_ops.identity(x)
      return f

    self.handle_op = var.handle
    tensor = _read_variable_closure(var)
  else:
    raise ValueError(
        "Saveable is neither a resource variable nor a read operation."
        " Got: %s" % repr(var))
  spec = saveable_object.SaveSpec(tensor, slice_spec, name,
                                  dtype=var.dtype, device=var.device)
  super(ResourceVariableSaveable, self).__init__(var, [spec], name)
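The closure above defers reading the variable until save time and copies the value to CPU so that variables placed on accelerators can be checkpointed. A sketch of the same read-then-copy pattern using only the public TF 2.x API (make_cpu_reader is an illustrative name, not from the source):

import tensorflow as tf

def make_cpu_reader(v):
  def read():
    with tf.device(v.device):
      x = v.read_value()
    # Copy to CPU so variables on GPUs/TPUs can be checkpointed.
    with tf.device("/device:CPU:0"):
      return tf.identity(x)
  return read

v = tf.Variable(3.0)
print(make_cpu_reader(v)())  # tf.Tensor(3.0, shape=(), dtype=float32)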
Example 3: _write_object_proto

def _write_object_proto(obj, proto, asset_file_def_index):
  """Saves an object into SavedObject proto."""
  if isinstance(obj, tracking.TrackableAsset):
    proto.asset.SetInParent()
    proto.asset.asset_file_def_index = asset_file_def_index[obj]
  elif resource_variable_ops.is_resource_variable(obj):
    proto.variable.SetInParent()
    if not obj.name.endswith(":0"):
      raise ValueError("Cowardly refusing to save variable %s because of"
                       " unexpected suffix which won't be restored." % obj.name)
    proto.variable.name = meta_graph._op_name(obj.name)  # pylint: disable=protected-access
    proto.variable.trainable = obj.trainable
    proto.variable.dtype = obj.dtype.as_datatype_enum
    proto.variable.synchronization = obj.synchronization.value
    proto.variable.aggregation = obj.aggregation.value
    proto.variable.shape.CopyFrom(obj.shape.as_proto())
  elif isinstance(obj, def_function.Function):
    proto.function.CopyFrom(
        function_serialization.serialize_function(obj))
  elif isinstance(obj, defun.ConcreteFunction):
    proto.bare_concrete_function.CopyFrom(
        function_serialization.serialize_bare_concrete_function(obj))
  elif isinstance(obj, _CapturedConstant):
    proto.constant.operation = obj.graph_tensor.op.name
  elif isinstance(obj, tracking.CapturableResource):
    proto.resource.device = obj._resource_device  # pylint: disable=protected-access
  else:
    registered_type_proto = revived_types.serialize(obj)
    if registered_type_proto is None:
      # Fallback for types with no matching registration.
      registered_type_proto = saved_object_graph_pb2.SavedUserObject(
          identifier="_generic_user_object",
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]))
    proto.user_object.CopyFrom(registered_type_proto)
Example 4: _write_object_proto

def _write_object_proto(obj, proto, asset_file_def_index, node_ids):
  """Saves an object into SavedObject proto."""
  if isinstance(obj, tracking.TrackableAsset):
    proto.asset.SetInParent()
    proto.asset.asset_file_def_index = asset_file_def_index[obj]
  elif resource_variable_ops.is_resource_variable(obj):
    proto.variable.SetInParent()
    proto.variable.trainable = obj.trainable
    proto.variable.dtype = obj.dtype.as_datatype_enum
    proto.variable.shape.CopyFrom(obj.shape.as_proto())
  elif isinstance(obj, def_function.Function):
    proto.function.CopyFrom(
        function_serialization.serialize_function(obj, node_ids))
  elif isinstance(obj, defun.ConcreteFunction):
    proto.concrete_function.CopyFrom(
        function_serialization.serialize_concrete_function(obj, node_ids))
  else:
    registered_type_proto = revived_types.serialize(obj)
    if registered_type_proto is None:
      # Fallback for types with no matching registration.
      registered_type_proto = saved_object_graph_pb2.SavedUserObject(
          identifier="_generic_user_object",
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]))
    proto.user_object.CopyFrom(registered_type_proto)
Example 5: _create_non_slot_variable

def _create_non_slot_variable(self, initial_value, name, colocate_with):
  """Add an extra variable, not associated with a slot."""
  # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
  eager = context.executing_eagerly()
  graph = None if eager else colocate_with.graph

  key = (name, graph)
  v = self._non_slot_dict.get(key, None)
  if v is None:
    self._maybe_initialize_trackable()
    distribution_strategy = distribute_ctx.get_strategy()
    with distribution_strategy.extended.colocate_vars_with(colocate_with):
      if eager:
        restored_initial_value = self._preload_simple_restoration(
            name=name, shape=None)
        if restored_initial_value is not None:
          initial_value = restored_initial_value
      v = variable_scope.variable(
          initial_value, name=name, trainable=False,
          use_resource=resource_variable_ops.is_resource_variable(
              colocate_with))
    # Restore this variable by name if necessary, but don't add a
    # Trackable dependency. Optimizers return the current graph's
    # non-slot variables from _checkpoint_dependencies explicitly rather
    # than unconditionally adding dependencies (since there may be multiple
    # non-slot variables with the same name in different graphs, trying to
    # save all of them would result in errors).
    self._handle_deferred_dependencies(name=name, trackable=v)
    self._non_slot_dict[key] = v

  return v
Example 6: _create_slot_var

def _create_slot_var(primary, val, scope, validate_shape, shape, dtype):
  """Helper function for creating a slot variable."""

  # TODO(lukaszkaiser): Consider allowing partitioners to be set in the current
  # scope.
  current_partitioner = variable_scope.get_variable_scope().partitioner
  variable_scope.get_variable_scope().set_partitioner(None)
  # When initializing from `val` rather than from a callable initializer, the
  # shape is expected to be None, not <unknown> or any fully defined shape.
  shape = shape if callable(val) else None
  slot = variable_scope.get_variable(
      scope, initializer=val, trainable=False,
      use_resource=resource_variable_ops.is_resource_variable(primary),
      shape=shape, dtype=dtype,
      validate_shape=validate_shape)
  variable_scope.get_variable_scope().set_partitioner(current_partitioner)

  # pylint: disable=protected-access
  if isinstance(primary, variables.Variable) and primary._save_slice_info:
    # Primary is a partitioned variable, so we need to also indicate that
    # the slot is a partitioned variable. Slots have the same partitioning
    # as their primaries.
    # For example, when using AdamOptimizer in a linear model, slot.name
    # here can be "linear//weights/Adam:0", while primary.op.name is
    # "linear//weights". We want "Adam" as the real_slot_name, so we
    # strip "'linear//weights' + '/'" and ":0".
    real_slot_name = slot.name[len(primary.op.name + "/"):-2]
    slice_info = primary._save_slice_info
    slot._set_save_slice_info(variables.Variable.SaveSliceInfo(
        slice_info.full_name + "/" + real_slot_name,
        slice_info.full_shape[:],
        slice_info.var_offset[:],
        slice_info.var_shape[:]))
  # pylint: enable=protected-access
  return slot
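The slot-name extraction described in the comment is plain string slicing; a standalone check with the values from the comment:

primary_op_name = "linear//weights"
slot_name = "linear//weights/Adam:0"
real_slot_name = slot_name[len(primary_op_name + "/"):-2]  # strip prefix and ":0"
print(real_slot_name)  # Adam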
Example 7: _write_object_graph

def _write_object_graph(saveable_view, export_dir, asset_file_def_index):
  """Save a SavedObjectGraph proto for `root`."""
  # SavedObjectGraph is similar to the CheckpointableObjectGraph proto in the
  # checkpoint. It will eventually go into the SavedModel.
  proto = saved_object_graph_pb2.SavedObjectGraph()
  saveable_view.fill_object_graph_proto(proto)

  node_ids = util.ObjectIdentityDictionary()
  for i, obj in enumerate(saveable_view.nodes):
    node_ids[obj] = i
    if resource_variable_ops.is_resource_variable(obj):
      node_ids[obj.handle] = i
    elif isinstance(obj, tracking.TrackableAsset):
      node_ids[obj.asset_path.handle] = i

  for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
    _write_object_proto(obj, obj_proto, asset_file_def_index, node_ids)

  extra_asset_dir = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
  file_io.recursive_create_dir(extra_asset_dir)
  object_graph_filename = os.path.join(
      extra_asset_dir, compat.as_bytes("object_graph.pb"))
  file_io.write_string_to_file(object_graph_filename, proto.SerializeToString())
Example 8: _get_tensor_from_node

def _get_tensor_from_node(self, node_id):
  obj = self._nodes[node_id]
  if resource_variable_ops.is_resource_variable(obj):
    return obj.handle
  elif isinstance(obj, tracking.TrackableAsset):
    return obj.asset_path.handle
  raise ValueError("Can't convert node %s to tensor" % type(obj))
Example 9: _map_resources

def _map_resources(accessible_objects):
  """Makes new resource handle ops corresponding to existing resource tensors.

  Creates resource handle ops in the current default graph, whereas
  `accessible_objects` will be from an eager context. Resource mapping adds
  resource handle ops to the main GraphDef of a SavedModel, which allows the
  C++ loader API to interact with variables.

  Args:
    accessible_objects: A list of objects, some of which may contain resources,
      to create replacements for.

  Returns:
    A tuple of (object_map, resource_map):
      object_map: A dictionary mapping from object in `accessible_objects` to
        replacement objects created to hold the new resource tensors.
      resource_map: A dictionary mapping from resource tensors extracted from
        `accessible_objects` to newly created resource tensors.
  """
  # TODO(allenl, rohanj): Map generic resources rather than just variables.
  # TODO(allenl): Handle MirroredVariables and other types of variables which
  # may need special casing.
  object_map = {}
  resource_map = {}
  for obj in accessible_objects:
    if resource_variable_ops.is_resource_variable(obj):
      new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
      object_map[obj] = new_variable
      resource_map[obj.handle] = new_variable.handle
  return object_map, resource_map
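The copy step can be exercised in isolation. A hedged sketch (copy_to_graph_uninitialized lives in the same internal module as is_resource_variable, though its exact location may vary by TF version):

import tensorflow as tf
from tensorflow.python.ops import resource_variable_ops

v = tf.Variable([1.0, 2.0])  # lives in the eager context
export_graph = tf.Graph()
with export_graph.as_default():
  # The copy shares v's dtype, shape and name but carries no initial value;
  # its handle op is created inside export_graph instead.
  copy = resource_variable_ops.copy_to_graph_uninitialized(v)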
Example 10: _restore_checkpoint

def _restore_checkpoint(self):
  """Load state from checkpoint into the deserialized objects."""
  variables_path = saved_model_utils.get_variables_path(self._export_dir)
  # TODO(andresp): Clean use of private methods of TrackableSaver.
  # pylint: disable=protected-access
  saver = util.TrackableSaver(graph_view.ObjectGraphView(self.get(0)))
  with ops.device("CPU"):
    saver._file_prefix_placeholder = constant_op.constant(variables_path)
  load_status = saver.restore(variables_path)
  load_status.assert_existing_objects_matched()
  checkpoint = load_status._checkpoint

  # When running in eager mode, the `restore` call above has already run and
  # restored the state of trackables, so calling `position.restore_ops()`
  # returns an empty list as there is nothing left to do. In graph mode, it
  # returns the list of ops that must run to restore the object at that
  # position. We have to wire them into the initializers of the objects so
  # that they get initialized properly when following common practices (e.g.
  # those used by ManagedSession) without further user action.
  for object_id, obj in dict(checkpoint.object_by_proto_id).items():
    position = base.CheckpointPosition(checkpoint=checkpoint,
                                       proto_id=object_id)
    restore_ops = position.restore_ops()
    if restore_ops:
      if resource_variable_ops.is_resource_variable(obj):
        obj._initializer_op = restore_ops
      else:
        raise NotImplementedError(
            "Missing functionality to restore state of object %r from the "
            "checkpoint." % obj)
Example 11: map_resources

def map_resources(self):
  """Makes new resource handle ops corresponding to existing resource tensors.

  Creates resource handle ops in the current default graph, whereas
  `accessible_objects` will be from an eager context. Resource mapping adds
  resource handle ops to the main GraphDef of a SavedModel, which allows the
  C++ loader API to interact with variables.

  Returns:
    A tuple of (object_map, resource_map, asset_info):
      object_map: A dictionary mapping from object in `accessible_objects` to
        replacement objects created to hold the new resource tensors.
      resource_map: A dictionary mapping from resource tensors extracted from
        `accessible_objects` to newly created resource tensors.
      asset_info: An _AssetInfo tuple describing external assets referenced
        from accessible_objects.
  """
  # Only makes sense when adding to the export Graph.
  assert not context.executing_eagerly()
  # TODO(allenl): Handle MirroredVariables and other types of variables which
  # may need special casing.
  object_map = object_identity.ObjectIdentityDictionary()
  resource_map = {}
  asset_info = _AssetInfo(
      asset_defs=[],
      asset_initializers_by_resource={},
      asset_filename_map={},
      asset_index={})

  for node_id, obj in enumerate(self.nodes):
    if isinstance(obj, tracking.TrackableResource):
      new_resource = obj._create_resource()  # pylint: disable=protected-access
      resource_map[obj.resource_handle] = new_resource
      self.captured_tensor_node_ids[obj.resource_handle] = node_id
    elif resource_variable_ops.is_resource_variable(obj):
      new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
      object_map[obj] = new_variable
      resource_map[obj.handle] = new_variable.handle
      self.captured_tensor_node_ids[obj.handle] = node_id
    elif isinstance(obj, tracking.TrackableAsset):
      _process_asset(obj, asset_info, resource_map)
      self.captured_tensor_node_ids[obj.asset_path] = node_id

  for concrete_function in self.concrete_functions:
    for capture in concrete_function.captured_inputs:
      if (tensor_util.is_tensor(capture)
          and capture.dtype not in _UNCOPIABLE_DTYPES
          and capture not in self.captured_tensor_node_ids):
        copied_tensor = constant_op.constant(
            tensor_util.constant_value(capture))
        node_id = len(self.nodes)
        node = _CapturedConstant(
            eager_tensor=capture, graph_tensor=copied_tensor)
        self.nodes.append(node)
        self.node_ids[capture] = node_id
        self.node_ids[node] = node_id
        self.captured_tensor_node_ids[capture] = node_id
        resource_map[capture] = copied_tensor

  return object_map, resource_map, asset_info
Example 12: _get_tensor_from_node

def _get_tensor_from_node(self, node_id):
  """Resolves a node id into a tensor to be captured for a function."""
  with ops.init_scope():
    obj = self._nodes[node_id]
    if resource_variable_ops.is_resource_variable(obj):
      return obj.handle
    elif isinstance(obj, tracking.TrackableAsset):
      return obj.asset_path
    elif tensor_util.is_tensor(obj):
      return obj
    elif isinstance(obj, tracking.CapturableResource):
      # Note: this executes restored functions in the CapturableResource.
      return obj.resource_handle
    raise ValueError("Can't convert node %s to tensor" % type(obj))
Example 13: _get_processor

def _get_processor(v):
  """The processor of v."""
  if context.executing_eagerly():
    if isinstance(v, ops.Tensor):
      return _TensorProcessor(v)
    else:
      return _DenseResourceVariableProcessor(v)
  if resource_variable_ops.is_resource_variable(v) and not v._in_graph_mode:  # pylint: disable=protected-access
    # True if and only if `v` was initialized eagerly.
    return _DenseResourceVariableProcessor(v)
  if v.op.type == "VarHandleOp":
    return _DenseResourceVariableProcessor(v)
  if isinstance(v, variables.Variable):
    return _RefVariableProcessor(v)
  if isinstance(v, ops.Tensor):
    return _TensorProcessor(v)
  raise NotImplementedError("Trying to optimize unsupported type %s" % v)
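The first branch hinges on whether code is currently executing eagerly. A small demonstration of that check via the public API (tf.executing_eagerly wraps context.executing_eagerly):

import tensorflow as tf

print(tf.executing_eagerly())  # True by default in TF 2.x

@tf.function
def traced():
  # Inside a tf.function, Python runs once while tracing a graph.
  print("eager while tracing:", tf.executing_eagerly())  # False
  return tf.constant(0)

traced()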
Example 14: _get_expanded_variable_list

def _get_expanded_variable_list(var_list):
  """Given a list of variables, expands them if they are partitioned.

  Args:
    var_list: A list of variables.

  Returns:
    A list of variables where each partitioned variable is expanded to its
    components.
  """
  returned_list = []
  for variable in var_list:
    if (isinstance(variable, variable_ops.Variable) or
        resource_variable_ops.is_resource_variable(variable)):
      returned_list.append(variable)  # Single variable case.
    else:  # Must be a PartitionedVariable, so convert into a list.
      returned_list.extend(list(variable))
  return returned_list
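The same expand-if-partitioned idea can be sketched with the public API. A PartitionedVariable is iterable over its component variables, which is what list(variable) relies on; the plain list below stands in for one, and expand_variables is an illustrative name:

import tensorflow as tf

def expand_variables(var_list):
  out = []
  for v in var_list:
    if isinstance(v, tf.Variable):
      out.append(v)        # single-variable case
    else:
      out.extend(list(v))  # assume an iterable of component variables
  return out

vs = expand_variables([tf.Variable(1.0), [tf.Variable(2.0), tf.Variable(3.0)]])
print(len(vs))  # 3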
Example 15: zero_initializer

def zero_initializer(ref, use_locking=True, name="zero_initializer"):
  """Initialize 'ref' with all zeros.

  The ref tensor should be uninitialized; if it is already initialized,
  a ValueError is raised. This op is intended to save memory during
  initialization.

  Args:
    ref: ref of the tensor that needs to be zero-initialized.
    use_locking: unused by this function.
    name: optional name for this operation.

  Returns:
    ref, initialized with zeros.

  Raises:
    ValueError: If the ref tensor is already initialized.
  """
  loader.load_op_library(
      resource_loader.get_path_to_datafile("_variable_ops.so"))
  if resource_variable_ops.is_resource_variable(ref):
    return gen_variable_ops.zero_var_initializer(
        ref.handle, shape=ref.shape, dtype=ref.dtype, name=name)
  else:
    return gen_variable_ops.zero_initializer(ref, name=name)
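A hypothetical usage sketch (it assumes the contrib custom-op library `_variable_ops.so` that the function loads is available, as in TF 1.x contrib; the variable name and session usage are illustrative):

import tensorflow as tf

with tf.Graph().as_default():
  v = tf.compat.v1.get_variable("w", shape=[3], use_resource=True)
  init_op = zero_initializer(v)  # zeros v without materializing an initial-value tensor
  with tf.compat.v1.Session() as sess:
    sess.run(init_op)
    print(sess.run(v))  # [0. 0. 0.]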