本文整理汇总了Python中sonnet.get_variables_in_module方法的典型用法代码示例。如果您正苦于以下问题：Python sonnet.get_variables_in_module方法的具体用法？Python sonnet.get_variables_in_module怎么用？Python sonnet.get_variables_in_module使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块sonnet的用法示例。
在下文中一共展示了sonnet.get_variables_in_module方法的13个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_trainable_vars
# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def get_trainable_vars(self):
    """Collect the variables this module should train.

    The module's own variables (from ``snt.get_variables_in_module``) are
    always included. When ``self._config.model.base_network.trainable`` is
    truthy, the variables returned by ``self.base_network.get_trainable_vars()``
    are appended after them; otherwise only the module's own variables are
    returned.

    Returns:
        The module's variables, concatenated with the base network's
        trainable variables when base-network training is enabled.
    """
    own_vars = snt.get_variables_in_module(self)
    if not self._config.model.base_network.trainable:
        tf.logging.info('Not training variables from pretrained module')
        return own_vars

    base_vars = self.base_network.get_trainable_vars()
    count = len(base_vars)
    if count:
        # Log the name range so it is easy to see where fine-tuning starts.
        first_name = base_vars[0].name
        last_name = base_vars[-1].name
        tf.logging.info(
            'Training {} vars from pretrained module; '
            'from "{}" to "{}".'.format(count, first_name, last_name)
        )
    else:
        tf.logging.info('No vars from pretrained module to train.')
    own_vars += base_vars
    return own_vars
示例2: save
# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def save(network, sess, filename=None):
    """Save the variables contained by a network to disk.

    Each variable is keyed by the last two components of its name (with the
    ``:N`` output suffix removed). For example, ``"outer/module/w:0"`` is
    stored as ``result["module"]["w"]``.

    Args:
        network: module whose variables, as returned by
            ``snt.get_variables_in_module``, are saved.
        sess: session used to read the values with a single ``sess.run``.
            If ``None``, each variable is read with ``v.eval()`` instead.
        filename: optional path. When truthy, the resulting nested dict is
            pickled to this file.

    Returns:
        A ``collections.defaultdict(dict)`` mapping module name -> variable
        name -> value.

    Raises:
        ValueError: if a variable name has no scope component, so no module
            name can be derived from it.
    """
    variables = list(snt.get_variables_in_module(network))

    # Validate and split every name before touching the session, so a bad
    # name fails fast with a clear message instead of an IndexError.
    keys = []
    for v in variables:
        split = v.name.split(":")[0].split("/")
        if len(split) < 2:
            raise ValueError(
                "Cannot derive a module name from unscoped variable "
                "{!r}.".format(v.name))
        keys.append((split[-2], split[-1]))

    # One batched fetch instead of one session round-trip per variable.
    if sess is None:
        values = [v.eval() for v in variables]
    else:
        values = sess.run(variables)

    to_save = collections.defaultdict(dict)
    for (module_name, variable_name), value in zip(keys, values):
        to_save[module_name][variable_name] = value

    if filename:
        with open(filename, "wb") as f:
            pickle.dump(to_save, f)
    return to_save
示例3: w
# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def w(self):
    """Return the one module variable whose ``self._raw_name`` is ``"w"``.

    Fails the internal ``assert`` unless exactly one variable from
    ``snt.get_variables_in_module(self)`` matches.
    """
    weights = [
        var for var in snt.get_variables_in_module(self)
        if self._raw_name(var.name) == "w"
    ]
    assert len(weights) == 1
    return weights[0]
示例4: b
# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def b(self):
    """Return the one module variable whose ``self._raw_name`` is ``"b"``.

    Fails the internal ``assert`` unless exactly one variable from
    ``snt.get_variables_in_module(self)`` matches.
    """
    biases = [
        var for var in snt.get_variables_in_module(self)
        if self._raw_name(var.name) == "b"
    ]
    assert len(biases) == 1
    return biases[0]
示例5: remote_variables
# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def remote_variables(self):
    """Return this module's trainable and moving-average variables.

    Returns:
        list: the module's variables in the TRAINABLE_VARIABLES collection,
        followed by its variables in the MOVING_AVERAGE_VARIABLES collection.
    """
    # NOTE(review): a variable registered in both collections would appear
    # twice in the result — confirm whether that can happen here.
    gathered = []
    for key in (tf.GraphKeys.TRAINABLE_VARIABLES,
                tf.GraphKeys.MOVING_AVERAGE_VARIABLES):
        gathered.extend(snt.get_variables_in_module(self, key))
    return gathered
示例6: local_variables
# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def local_variables(self):
    """Variables that must be refreshed on every evaluation.

    They are intended to stay off the parameter server and to be reset each
    time a meta_objective loss is computed.

    Returns:
        vars: list of tf.Variable
    """
    # NOTE(review): this reads the TRAINABLE_VARIABLES collection; confirm
    # that matches the "local, reset per evaluation" intent described above.
    trainable = snt.get_variables_in_module(
        self, tf.GraphKeys.TRAINABLE_VARIABLES)
    return list(trainable)
示例7: get_variables_in_modules
# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def get_variables_in_modules(module_list):
    """Concatenate the variables of several modules into one list.

    Args:
        module_list: iterable of modules, each passed to
            ``snt.get_variables_in_module``.

    Returns:
        list: every module's variables, in the order the modules are given.
    """
    return [
        variable
        for module in module_list
        for variable in snt.get_variables_in_module(module)
    ]
示例8: _get_base_network_vars
# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def _get_base_network_vars(self):
    """Returns all of the base network's variables.

    Variables come from the MODEL_VARIABLES collection. When
    ``self.pretrained_weights_scope`` is set, the lookup is restricted to
    that scope via ``tf.get_collection``. Otherwise the variables of this
    module are used, via ``snt.get_variables_in_module``.

    Raises:
        ValueError: if no variables are found.
    """
    if self.pretrained_weights_scope:
        # We may have defined the base network in a particular scope
        module_variables = tf.get_collection(
            tf.GraphKeys.MODEL_VARIABLES,
            scope=self.pretrained_weights_scope
        )
    else:
        module_variables = snt.get_variables_in_module(
            self, tf.GraphKeys.MODEL_VARIABLES
        )
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would silently return an empty list here.
    if not module_variables:
        raise ValueError(
            'No base network variables found in MODEL_VARIABLES '
            '(pretrained_weights_scope={!r}).'.format(
                self.pretrained_weights_scope)
        )
    return module_variables
示例9: get_trainable_vars
# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def get_trainable_vars(self):
    """
    Returns a list of the variables that are trainable.

    If a value for `fine_tune_from` is specified in the config, only the
    variables starting from the first that contains this string in its name
    will be trainable. For example, specifying `vgg_16/fc6` for a VGG16
    will set only the variables in the fully connected layers to be
    trainable.
    If `fine_tune_from` is None, then all the variables will be trainable.

    Returns:
        trainable_variables: a tuple of `tf.Variable`.

    Raises:
        ValueError: if `fine_tune_from` matches no variable name.
    """
    all_variables = snt.get_variables_in_module(self)
    fine_tune_from = self._config.get('fine_tune_from')
    if fine_tune_from is None:
        return all_variables

    # Index of the first variable whose name contains `fine_tune_from`.
    # `next` with a default avoids raising the ValueError below inside an
    # `except StopIteration`, which would attach a misleading
    # "During handling of the above exception" context to the traceback.
    index = next(
        (i for i, var in enumerate(all_variables)
         if fine_tune_from in var.name),
        None,
    )
    if index is None:
        raise ValueError(
            '"{}" is an invalid value of fine_tune_from for this '
            'architecture.'.format(fine_tune_from)
        )
    return all_variables[index:]
示例10: get_trainable_vars
# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def get_trainable_vars(self):
    """Return this module's variables, plus the feature extractor's if enabled.

    The feature extractor's trainable variables (from
    ``self.feature_extractor.get_trainable_vars()``) are appended only when
    ``self._config.base_network.trainable`` is truthy.
    """
    module_vars = snt.get_variables_in_module(self)
    if not self._config.base_network.trainable:
        tf.logging.info('Not training variables from pretrained module')
        return module_vars

    extractor_vars = self.feature_extractor.get_trainable_vars()
    message = 'Training {} vars from pretrained module.'.format(
        len(extractor_vars))
    tf.logging.info(message)
    module_vars += extractor_vars
    return module_vars
示例11: get_trainable_vars
# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def get_trainable_vars(self):
    """
    List every variable of this module as trainable.

    Returns:
        trainable_variables: a tuple of `tf.Variable`, exactly as returned by
        `snt.get_variables_in_module(self)` called with no collection argument.
    """
    trainable_variables = snt.get_variables_in_module(self)
    return trainable_variables
示例12: testTrainable
# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def testTrainable(self):
    """Tests the network contains trainable variables."""
    gradients = tf.random_normal([10, 5])
    net = networks.CoordinateWiseDeepLSTM(layers=(1,))
    # Call the module once before counting its variables.
    net(gradients, net.initial_state_for_inputs(gradients))
    # Weights and biases for two layers.
    self.assertEqual(len(snt.get_variables_in_module(net)), 4)
示例13: testNonTrainable
# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def testNonTrainable(self):
    """Tests the network doesn't contain trainable variables."""
    gradients = tf.random_normal([10, 5])
    net = networks.Sgd()
    # Call the module once before checking that it owns no variables.
    initial_state = net.initial_state_for_inputs(gradients)
    net(gradients, initial_state)
    self.assertEqual(len(snt.get_variables_in_module(net)), 0)