当前位置: 首页>>代码示例>>Python>>正文


Python sonnet.get_variables_in_module方法代码示例

本文整理汇总了Python中sonnet.get_variables_in_module方法的典型用法代码示例。如果您正苦于以下问题:Python sonnet.get_variables_in_module方法的具体用法?Python sonnet.get_variables_in_module怎么用?Python sonnet.get_variables_in_module使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在sonnet的用法示例。


在下文中一共展示了sonnet.get_variables_in_module方法的13个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_trainable_vars

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def get_trainable_vars(self):
        """Get trainable vars included in the module.
        """
        trainable_vars = snt.get_variables_in_module(self)
        if self._config.model.base_network.trainable:
            pretrained_trainable_vars = self.base_network.get_trainable_vars()
            if len(pretrained_trainable_vars):
                tf.logging.info(
                    'Training {} vars from pretrained module; '
                    'from "{}" to "{}".'.format(
                        len(pretrained_trainable_vars),
                        pretrained_trainable_vars[0].name,
                        pretrained_trainable_vars[-1].name,
                    )
                )
            else:
                tf.logging.info('No vars from pretrained module to train.')
            trainable_vars += pretrained_trainable_vars
        else:
            tf.logging.info('Not training variables from pretrained module')

        return trainable_vars 
开发者ID:Sargunan,项目名称:Table-Detection-using-Deep-learning,代码行数:24,代码来源:fasterrcnn.py

示例2: save

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def save(network, sess, filename=None):
  """Save the variables contained by a network to disk."""
  to_save = collections.defaultdict(dict)
  variables = snt.get_variables_in_module(network)

  for v in variables:
    split = v.name.split(":")[0].split("/")
    module_name = split[-2]
    variable_name = split[-1]
    to_save[module_name][variable_name] = v.eval(sess)

  if filename:
    with open(filename, "wb") as f:
      pickle.dump(to_save, f)

  return to_save 
开发者ID:deepmind,项目名称:learning-to-learn,代码行数:18,代码来源:networks.py

示例3: w

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def w(self):
    var_list = snt.get_variables_in_module(self)
    w = [x for x in var_list if self._raw_name(x.name) == "w"]
    assert len(w) == 1
    return w[0] 
开发者ID:itsamitgoel,项目名称:Gun-Detector,代码行数:7,代码来源:common.py

示例4: b

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def b(self):
    var_list = snt.get_variables_in_module(self)
    b = [x for x in var_list if self._raw_name(x.name) == "b"]
    assert len(b) == 1
    return b[0] 
开发者ID:itsamitgoel,项目名称:Gun-Detector,代码行数:7,代码来源:common.py

示例5: remote_variables

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def remote_variables(self):
    train = list(
        snt.get_variables_in_module(self, tf.GraphKeys.TRAINABLE_VARIABLES))
    train += list(
        snt.get_variables_in_module(self,
                                    tf.GraphKeys.MOVING_AVERAGE_VARIABLES))
    return train 
开发者ID:itsamitgoel,项目名称:Gun-Detector,代码行数:9,代码来源:more_local_weight_update.py

示例6: local_variables

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def local_variables(self):
    """List of variables that need to be updated for each evaluation.

    These variables should not be stored on a parameter server and
    should be reset every computation of a meta_objective loss.

    Returns:
      vars: list of tf.Variable
    """
    return list(
        snt.get_variables_in_module(self, tf.GraphKeys.TRAINABLE_VARIABLES)) 
开发者ID:itsamitgoel,项目名称:Gun-Detector,代码行数:13,代码来源:linear_regression.py

示例7: get_variables_in_modules

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def get_variables_in_modules(module_list):
  var_list = []
  for m in module_list:
    var_list.extend(snt.get_variables_in_module(m))
  return var_list 
开发者ID:itsamitgoel,项目名称:Gun-Detector,代码行数:7,代码来源:utils.py

示例8: _get_base_network_vars

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def _get_base_network_vars(self):
        """Returns a list of all the base network's variables."""
        if self.pretrained_weights_scope:
            # We may have defined the base network in a particular scope
            module_variables = tf.get_collection(
                tf.GraphKeys.MODEL_VARIABLES,
                scope=self.pretrained_weights_scope
            )
        else:
            module_variables = snt.get_variables_in_module(
                self, tf.GraphKeys.MODEL_VARIABLES
            )
        assert len(module_variables) > 0
        return module_variables 
开发者ID:Sargunan,项目名称:Table-Detection-using-Deep-learning,代码行数:16,代码来源:base_network.py

示例9: get_trainable_vars

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def get_trainable_vars(self):
        """
        Returns a list of the variables that are trainable.

        If a value for `fine_tune_from` is specified in the config, only the
        variables starting from the first that contains this string in its name
        will be trainable. For example, specifying `vgg_16/fc6` for a VGG16
        will set only the variables in the fully connected layers to be
        trainable.
        If `fine_tune_from` is None, then all the variables will be trainable.

        Returns:
            trainable_variables: a tuple of `tf.Variable`.
        """
        all_variables = snt.get_variables_in_module(self)

        fine_tune_from = self._config.get('fine_tune_from')
        if fine_tune_from is None:
            return all_variables

        # Get the index of the first trainable variable
        var_iter = enumerate(v.name for v in all_variables)
        try:
            index = next(i for i, name in var_iter if fine_tune_from in name)
        except StopIteration:
            raise ValueError(
                '"{}" is an invalid value of fine_tune_from for this '
                'architecture.'.format(fine_tune_from)
            )

        return all_variables[index:] 
开发者ID:Sargunan,项目名称:Table-Detection-using-Deep-learning,代码行数:33,代码来源:base_network.py

示例10: get_trainable_vars

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def get_trainable_vars(self):
        """Get trainable vars included in the module.
        """
        trainable_vars = snt.get_variables_in_module(self)
        if self._config.base_network.trainable:
            pretrained_trainable_vars = (
                self.feature_extractor.get_trainable_vars()
            )
            tf.logging.info('Training {} vars from pretrained module.'.format(
                len(pretrained_trainable_vars)))
            trainable_vars += pretrained_trainable_vars
        else:
            tf.logging.info('Not training variables from pretrained module')

        return trainable_vars 
开发者ID:Sargunan,项目名称:Table-Detection-using-Deep-learning,代码行数:17,代码来源:ssd.py

示例11: get_trainable_vars

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def get_trainable_vars(self):
        """
        Returns a list of the variables that are trainable.

        Returns:
            trainable_variables: a tuple of `tf.Variable`.
        """
        return snt.get_variables_in_module(self) 
开发者ID:Sargunan,项目名称:Table-Detection-using-Deep-learning,代码行数:10,代码来源:feature_extractor.py

示例12: testTrainable

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def testTrainable(self):
    """Tests the network contains trainable variables."""
    shape = [10, 5]
    gradients = tf.random_normal(shape)
    net = networks.CoordinateWiseDeepLSTM(layers=(1,))
    state = net.initial_state_for_inputs(gradients)
    net(gradients, state)
    # Weights and biases for two layers.
    variables = snt.get_variables_in_module(net)
    self.assertEqual(len(variables), 4) 
开发者ID:deepmind,项目名称:learning-to-learn,代码行数:12,代码来源:networks_test.py

示例13: testNonTrainable

# 需要导入模块: import sonnet [as 别名]
# 或者: from sonnet import get_variables_in_module [as 别名]
def testNonTrainable(self):
    """Tests the network doesn't contain trainable variables."""
    shape = [10, 5]
    gradients = tf.random_normal(shape)
    net = networks.Sgd()
    state = net.initial_state_for_inputs(gradients)
    net(gradients, state)
    variables = snt.get_variables_in_module(net)
    self.assertEqual(len(variables), 0) 
开发者ID:deepmind,项目名称:learning-to-learn,代码行数:11,代码来源:networks_test.py


注:本文中的sonnet.get_variables_in_module方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。