当前位置: 首页>>代码示例>>Python>>正文


Python chainer.Optimizer方法代码示例

本文整理汇总了Python中chainer.Optimizer方法的典型用法代码示例。如果您正苦于以下问题:Python chainer.Optimizer方法的具体用法?Python chainer.Optimizer怎么用?Python chainer.Optimizer使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在chainer的用法示例。


在下文中一共展示了chainer.Optimizer方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: serialize

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Optimizer [as 别名]
def serialize(self, serializer):
        """Save or restore the optimizer's bookkeeping through ``serializer``.

        Two kinds of data pass through the serializer:

        - the global counters :attr:`t` and :attr:`epoch`, and
        - the update rule attached to each parameter of the target link
          (every rule serializes itself into ``serializer[<param name>]``).

        **The parameters of the target link themselves are neither saved nor
        loaded here.** They have to be serialized separately.

        Args:
            serializer (~chainer.AbstractSerializer): Serializer or
                deserializer object.

        """
        # The serializer hands back the value to keep: the current one when
        # saving, the stored one when loading.
        for counter in ('t', 'epoch'):
            setattr(self, counter, serializer(counter, getattr(self, counter)))

        for param_name, param in self.target.namedparams():
            update_rule = getattr(param, 'update_rule', None)
            if update_rule is None:
                # Nothing to serialize for parameters without an update rule.
                continue
            update_rule.serialize(serializer[param_name])
开发者ID:chainer,项目名称:chainer,代码行数:24,代码来源:optimizer.py

示例2: __init__

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Optimizer [as 别名]
def __init__(self, optimizer):
        """
        Wrap an already constructed optimizer.

        Parameters
        ----------
        optimizer : :class:`chainer.Optimizer`
            the optimizer to wrap

        Raises
        ------
        RuntimeError
            if ``optimizer`` is not an instance of
            :class:`chainer.Optimizer`

        """
        # Guard clause: reject anything that is not a chainer optimizer.
        if not isinstance(optimizer, chainer.Optimizer):
            raise RuntimeError("Invalid optimizer class given: Expected "
                               "instance of chainer.Optimizer, but got %s"
                               % optimizer.__class__.__name__)

        self._optimizer = optimizer
开发者ID:delira-dev,项目名称:delira,代码行数:18,代码来源:data_parallel.py

示例3: from_optimizer_class

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Optimizer [as 别名]
def from_optimizer_class(cls, optim_cls, *args, **kwargs):
        """
        Build the wrapper from an optimizer class instead of an instance.

        Parameters
        ----------
        optim_cls : subclass of :class:`chainer.Optimizer`
            the optimizer to use internally
        *args :
            arbitrary positional arguments (will be used for
            initialization of internally used optimizer)
        **kwargs :
            arbitrary keyword arguments (will be used for initialization
            of internally used optimizer)

        Returns
        -------
        instance of ``cls``
            the result of calling ``cls`` with the newly created optimizer

        Raises
        ------
        RuntimeError
            if ``optim_cls`` is not a subclass of
            :class:`chainer.Optimizer` (this includes ``None`` and objects
            that are not classes at all)

        """
        # ``issubclass`` raises TypeError for non-class arguments, so make
        # sure we really have a class before asking.
        if not (isinstance(optim_cls, type)
                and issubclass(optim_cls, chainer.Optimizer)):
            # ``None`` and plain instances have no ``__name__``; fall back to
            # their repr so the intended RuntimeError is raised instead of an
            # AttributeError while building the message.
            raise RuntimeError("Invalid optimizer class given: Expected "
                               "Subclass of chainer.Optimizer, but got %s"
                               % getattr(optim_cls, "__name__",
                                         repr(optim_cls)))

        return cls(optim_cls(*args, **kwargs))
开发者ID:delira-dev,项目名称:delira,代码行数:25,代码来源:data_parallel.py

示例4: save

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Optimizer [as 别名]
def save(self, dir_name):
        """Write the public entries of this mapping into a sub-directory.

        The destination is ``dir_name`` joined onto ``self._root_dir_path``;
        it is created (including missing parents) when it does not exist.
        Entries whose key starts with ``_`` are skipped. The remaining ones
        are stored according to the type of their value:

        - ``np.ndarray`` or ``list``: ``<key>.npy`` via ``np.save``
        - ``chainer.Chain`` / ``chainer.ChainList``: ``model.npz``
        - ``chainer.Optimizer``: ``optimizer.npz``
        - anything else: one ``key: value`` line appended to ``log.txt``

        Note that the model and optimizer file names are fixed, so several
        entries of the same kind overwrite each other.

        Args:
            dir_name (str): Directory name relative to
                ``self._root_dir_path``.

        """
        dir_path = os.path.join(self._root_dir_path, dir_name)
        # ``exist_ok`` avoids the check-then-create race of
        # ``os.path.exists`` + ``os.mkdir`` and also creates missing parents.
        os.makedirs(dir_path, exist_ok=True)

        others = []
        for key, value in self.items():
            if key.startswith('_'):
                continue

            if isinstance(value, (np.ndarray, list)):
                np.save(os.path.join(dir_path, key + ".npy"), value)
            elif isinstance(value, (chainer.Chain, chainer.ChainList)):
                model_path = os.path.join(dir_path, "model.npz")
                chainer.serializers.save_npz(model_path, value)
            elif isinstance(value, chainer.Optimizer):
                optimizer_path = os.path.join(dir_path, "optimizer.npz")
                chainer.serializers.save_npz(optimizer_path, value)
            else:
                others.append("{}: {}".format(key, value))

        # Opened in append mode: repeated saves extend the log instead of
        # replacing it.
        with open(os.path.join(dir_path, "log.txt"), "a") as f:
            text = "\n".join(others) + "\n"
            f.write(text)
开发者ID:ronekko,项目名称:deep_metric_learning,代码行数:26,代码来源:utils.py

示例5: set_shared_states

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Optimizer [as 别名]
def set_shared_states(a, b):
    """Replace the update-rule states of ``a`` with views on ``b``.

    ``b`` is indexed first by parameter name and then by state name. Each
    buffer found there is read with ``np.frombuffer`` using the dtype of the
    state entry it replaces and reshaped to that entry's shape.
    """
    assert isinstance(a, chainer.Optimizer)
    assert hasattr(a, 'target'), 'Optimizer.setup must be called first'
    for param_name, param in a.target.namedparams():
        ensure_initialized_update_rule(param)
        rule_state = param.update_rule.state
        for state_name, buffer in b[param_name].items():
            current = rule_state[state_name]
            flat_view = np.frombuffer(buffer, dtype=current.dtype)
            rule_state[state_name] = flat_view.reshape(current.shape)
开发者ID:chainer,项目名称:chainerrl,代码行数:13,代码来源:async_.py

示例6: extract_states_as_shared_arrays

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Optimizer [as 别名]
def extract_states_as_shared_arrays(optimizer):
    """Create an ``mp.RawArray`` for every update-rule state of ``optimizer``.

    Returns:
        dict: ``{param_name: {state_name: RawArray}}``; every array is built
        with type code ``'f'`` from the flattened (``ravel``) state value.
    """
    assert isinstance(optimizer, chainer.Optimizer)
    assert hasattr(optimizer, 'target'), 'Optimizer.setup must be called first'

    def _raw_arrays_for(param):
        # The update rule has to be initialized before its state is read.
        ensure_initialized_update_rule(param)
        return {
            state_name: mp.RawArray('f', state_val.ravel())
            for state_name, state_val in param.update_rule.state.items()
        }

    return {
        param_name: _raw_arrays_for(param)
        for param_name, param in optimizer.target.namedparams()
    }
开发者ID:chainer,项目名称:chainerrl,代码行数:14,代码来源:async_.py

示例7: as_shared_objects

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Optimizer [as 别名]
def as_shared_objects(obj):
    """Convert ``obj`` into its shared counterpart.

    Dispatches on the type of ``obj``:

    - ``tuple``: every element is converted recursively; a tuple is returned
    - ``chainer.Link``: result of ``share_params_as_shared_arrays(obj)``
    - ``chainer.Optimizer``: result of ``share_states_as_shared_arrays(obj)``
    - ``mp.sharedctypes.Synchronized``: returned unchanged

    Raises:
        ValueError: If ``obj`` is of any other type.
    """
    if isinstance(obj, tuple):
        return tuple(as_shared_objects(x) for x in obj)
    elif isinstance(obj, chainer.Link):
        return share_params_as_shared_arrays(obj)
    elif isinstance(obj, chainer.Optimizer):
        return share_states_as_shared_arrays(obj)
    elif isinstance(obj, mp.sharedctypes.Synchronized):
        return obj
    else:
        # Name the offending type; an empty message gives no hint about
        # which element of a nested tuple was rejected.
        raise ValueError(
            'Cannot convert object of type {} to shared objects'.format(
                type(obj).__name__))
开发者ID:chainer,项目名称:chainerrl,代码行数:13,代码来源:async_.py

示例8: synchronize_to_shared_objects

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Optimizer [as 别名]
def synchronize_to_shared_objects(obj, shared_memory):
    """Bind ``obj`` to ``shared_memory`` and return the synchronized object.

    Dispatches on the type of ``obj``:

    - ``tuple``: elements of ``obj`` and ``shared_memory`` are paired with
      ``zip`` and synchronized recursively; a tuple is returned
    - ``chainer.Link``: ``set_shared_params(obj, shared_memory)`` is called
      and ``obj`` is returned
    - ``chainer.Optimizer``: ``set_shared_states(obj, shared_memory)`` is
      called and ``obj`` is returned
    - ``mp.sharedctypes.Synchronized``: ``shared_memory`` itself is returned

    Raises:
        ValueError: If ``obj`` is of any other type.
    """
    if isinstance(obj, tuple):
        return tuple(synchronize_to_shared_objects(o, s)
                     for o, s in zip(obj, shared_memory))
    elif isinstance(obj, chainer.Link):
        set_shared_params(obj, shared_memory)
        return obj
    elif isinstance(obj, chainer.Optimizer):
        set_shared_states(obj, shared_memory)
        return obj
    elif isinstance(obj, mp.sharedctypes.Synchronized):
        return shared_memory
    else:
        # Name the offending type; an empty message gives no hint about
        # which element of a nested tuple was rejected.
        raise ValueError(
            'Cannot synchronize object of type {} to shared objects'.format(
                type(obj).__name__))
开发者ID:chainer,项目名称:chainerrl,代码行数:16,代码来源:async_.py

示例9: set_shared_states

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Optimizer [as 别名]
def set_shared_states(a, b):
    """Replace entries of ``a._states`` with views on the buffers in ``b``.

    ``b`` is indexed first by state name and then by parameter name. Each
    buffer is read with ``np.frombuffer`` using the dtype of the entry it
    replaces and reshaped to that entry's shape.
    """
    assert isinstance(a, chainer.Optimizer)
    assert hasattr(a, 'target'), 'Optimizer.setup must be called first'
    for state_name, shared_state in b.items():
        for param_name, buffer in shared_state.items():
            own_state = a._states[state_name]
            previous = own_state[param_name]
            flat_view = np.frombuffer(buffer, dtype=previous.dtype)
            own_state[param_name] = flat_view.reshape(previous.shape)
开发者ID:muupan,项目名称:async-rl,代码行数:11,代码来源:async.py

示例10: extract_states_as_shared_arrays

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Optimizer [as 别名]
def extract_states_as_shared_arrays(optimizer):
    """Create an ``mp.RawArray`` for every entry of ``optimizer._states``.

    Returns:
        dict: ``{state_name: {param_name: RawArray}}``; every array is built
        with type code ``'f'`` from the flattened (``ravel``) entry.
    """
    assert isinstance(optimizer, chainer.Optimizer)
    assert hasattr(optimizer, 'target'), 'Optimizer.setup must be called first'
    return {
        state_name: {
            param_name: mp.RawArray('f', param.ravel())
            for param_name, param in state.items()
        }
        for state_name, state in optimizer._states.items()
    }
开发者ID:muupan,项目名称:async-rl,代码行数:12,代码来源:async.py

示例11: test_all_optimizers_coverage

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Optimizer [as 别名]
def test_all_optimizers_coverage(self):
        """Check ``_all_optimizers`` against ``chainer.optimizers``.

        The names of all attributes of ``chainer.optimizers`` that are
        classes derived from ``chainer.Optimizer`` must equal
        ``_all_optimizers`` (both compared after sorting).
        """
        namespace = chainer.optimizers

        def _is_optimizer_class(candidate):
            return (isinstance(candidate, type)
                    and issubclass(candidate, chainer.Optimizer))

        found = [attr for attr in dir(namespace)
                 if _is_optimizer_class(getattr(namespace, attr))]

        assert sorted(_all_optimizers) == sorted(found)
开发者ID:chainer,项目名称:chainer,代码行数:11,代码来源:test_optimizers.py

示例12: _check_set_up

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Optimizer [as 别名]
def _check_set_up(self):
        """Raise ``RuntimeError`` while ``self._hookable`` is still ``None``."""
        is_set_up = self._hookable is not None
        if is_set_up:
            return
        raise RuntimeError('Optimizer is not set up. Call `setup` method.')
开发者ID:chainer,项目名称:chainer,代码行数:5,代码来源:optimizer.py

示例13: get_optimizer

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Optimizer [as 别名]
def get_optimizer(self, name):
        """Look up a registered optimizer by its name.

        Args:
            name (str): Name of the optimizer.

        Returns:
            ~chainer.Optimizer: The optimizer stored under ``name``.

        Raises:
            KeyError: If no optimizer is stored under ``name`` (assuming
                ``self._optimizers`` is a plain dict).

        """
        registry = self._optimizers
        return registry[name]
开发者ID:chainer,项目名称:chainer,代码行数:13,代码来源:standard_updater.py

示例14: __call__

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Optimizer [as 别名]
def __call__(self, optimizer: chainer.Optimizer):
        """
        Add the gradients of all replicated modules onto the first module
        if the optimizer's target is a ``DataParallelChainerNetwork``;
        do nothing otherwise

        Parameters
        ----------
        optimizer : chainer.Optimizer
            the optimizer holding the target, whose gradients should be
            summed across the replications

        """
        target = optimizer.target
        if not isinstance(target, DataParallelChainerNetwork):
            return

        # modules[0] collects the gradients; every further entry is a
        # replica whose gradients are added onto it.
        for replica in target.modules[1:]:
            target.modules[0].addgrads(replica)
开发者ID:delira-dev,项目名称:delira,代码行数:17,代码来源:data_parallel.py

示例15: __init__

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Optimizer [as 别名]
def __init__(
            self,
            args,
            loss_maker,
            main_optimizer,
            main_lossfun,
            reinput_optimizer=None,
            reinput_lossfun=None,
            discriminator_optimizer=None,
            discriminator_lossfun=None,
            *_args, **kwargs
    ):
        # type: (any, comicolorization.loss.LossMaker, any, typing.Callable[[typing.Dict], any], typing.List[chainer.Optimizer], typing.Callable[[int, typing.Dict], any], any, typing.Callable[[typing.Dict], any], *any, **any) -> None
        """Store the loss helpers and register all optimizers with the parent.

        The parent initializer receives a dict with the key ``'main'``, one
        ``'reinput<i>'`` key per element of ``reinput_optimizer`` and, when
        given, ``'discriminator'``.
        """
        optimizer_table = {'main': main_optimizer}
        if reinput_optimizer is not None:
            optimizer_table.update(
                ('reinput{}'.format(index), reinput)
                for index, reinput in enumerate(reinput_optimizer)
            )
        if discriminator_optimizer is not None:
            optimizer_table['discriminator'] = discriminator_optimizer

        super().__init__(optimizer=optimizer_table, *_args, **kwargs)

        # Original note: chainer.reporter cannot work with several
        # optimizers that focus on the same model. Presumably that is why
        # this fallback (re-using the main optimizer once per reinput blend
        # ratio) happens only after the parent has been initialized, so it
        # is not registered a second time - TODO confirm.
        if args.separate_backward_reinput and reinput_optimizer is None:
            reinput_optimizer = (
                [main_optimizer] * len(args.loss_blend_ratio_reinput))

        self.args = args
        self.loss_maker = loss_maker

        self.main_optimizer = main_optimizer
        self.main_lossfun = main_lossfun

        self.reinput_optimizer = reinput_optimizer
        self.reinput_lossfun = reinput_lossfun

        self.discriminator_optimizer = discriminator_optimizer
        self.discriminator_lossfun = discriminator_lossfun
开发者ID:DwangoMediaVillage,项目名称:Comicolorization,代码行数:37,代码来源:multi_updater.py


注:本文中的chainer.Optimizer方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。