当前位置: 首页>>代码示例>>Python>>正文


Python chainer.Link方法代码示例

本文整理汇总了Python中chainer.Link方法的典型用法代码示例。如果您正苦于以下问题：Python chainer.Link方法的具体用法是什么？Python chainer.Link怎么用？有没有Python chainer.Link的使用示例？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块chainer的用法示例。（注：chainer.Link 实际上是一个类，而非方法。）


在下文中一共展示了chainer.Link方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: set_shared_params

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def set_shared_params(a, b):
    """Re-point a link's parameters and persistent values at shared buffers.

    Each parameter of ``a`` whose name is a key of ``b`` has its ``array``
    replaced by ``np.frombuffer`` over the matching buffer (no copy), using
    the parameter's dtype and reshaped to the parameter's shape. Each
    persistent value of ``a`` whose name is a key of ``b`` is updated via
    ``_set_persistent_values_recursively``.

    Args:
      a (chainer.Link): link whose params are to be replaced
      b (dict): dict that consists of (param_name, multiprocessing.Array)

    Raises:
      AssertionError: if ``a`` is not a ``chainer.Link``, or if some key of
        ``b`` matched neither a parameter nor a persistent value of ``a``.
    """
    assert isinstance(a, chainer.Link)
    unused = set(b.keys())

    for name, param in a.namedparams():
        if name not in b:
            continue
        # frombuffer does not copy, so the parameter now reads and writes
        # the shared buffer directly.
        flat_view = np.frombuffer(b[name], dtype=param.dtype)
        param.array = flat_view.reshape(param.shape)
        unused.remove(name)

    for name, _ in chainerrl.misc.namedpersistent(a):
        if name not in b:
            continue
        _set_persistent_values_recursively(a, name, b[name])
        unused.remove(name)

    # Every entry of ``b`` must have been consumed by a param or persistent.
    assert not unused
开发者ID:chainer,项目名称:chainerrl,代码行数:23,代码来源:async_.py

示例2: extract_params_as_shared_arrays

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def extract_params_as_shared_arrays(link):
    """Copy a link's parameters and persistent values into shared memory.

    Args:
        link (chainer.Link): link to read parameters and persistent values
            from.

    Returns:
        dict: maps each name yielded by ``link.namedparams()`` and
        ``chainerrl.misc.namedpersistent(link)`` to a ``mp.RawArray`` that
        holds a flattened copy of the values. The NumPy dtype character is
        used as the RawArray typecode.
    """
    assert isinstance(link, chainer.Link)

    def _to_raw(values):
        # RawArray takes a typecode plus a flat sequence of initial values.
        return mp.RawArray(values.dtype.char, values.ravel())

    shared = {name: _to_raw(param.array)
              for name, param in link.namedparams()}

    for name, value in chainerrl.misc.namedpersistent(link):
        if not isinstance(value, np.ndarray):
            assert np.isscalar(value)
            # Wrap by a 1-dim array because multiprocessing.RawArray does not
            # accept a 0-dim array.
            value = np.asarray([value])
        shared[name] = _to_raw(value)
    return shared
开发者ID:chainer,项目名称:chainerrl,代码行数:23,代码来源:async_.py

示例3: _batch_reset_recurrent_states_when_episodes_end

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def _batch_reset_recurrent_states_when_episodes_end(
        model, batch_done, batch_reset, recurrent_states):
    """Mask the recurrent states of batch entries whose episode is over.

    Args:
        model (chainer.Link): Model that implements `StatelessRecurrent`;
            only its ``mask_recurrent_state_at`` method is called here.
        batch_done (array-like of bool): True iff episodes are terminal.
        batch_reset (array-like of bool): True iff episodes will be reset.
        recurrent_states (object): Recurrent state.

    Returns:
        object: ``recurrent_states`` unchanged when no entry is done or reset.
        Otherwise, the result of ``model.mask_recurrent_state_at`` called with
        the list of such batch indices, in increasing order.
    """
    finished = []
    for idx, (done, reset) in enumerate(zip(batch_done, batch_reset)):
        if done or reset:
            finished.append(idx)

    if not finished:
        return recurrent_states
    return model.mask_recurrent_state_at(recurrent_states, finished)
开发者ID:chainer,项目名称:chainerrl,代码行数:23,代码来源:dqn.py

示例4: setUp

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def setUp(self):
        # The declared shape of "x" mixes a Python int with a numpy.int64.
        # "u" is declared with shape None; "v" is a Parameter with no arguments.
        link = chainer.Link(
            x=((2, numpy.int64(3)), 'd'),
            u=(None, 'd'),
        )
        self.link = link
        with link.init_scope():
            link.y = chainer.Parameter(shape=(2,))
            link.v = chainer.Parameter()

        self.p = numpy.array([1, 2, 3], dtype='f')
        link.add_persistent('p', self.p)
        link.name = 'a'

        # "x" gets an UpdateRule that is explicitly disabled; "u" gets a
        # fresh UpdateRule left as constructed.
        link.x.update_rule = chainer.UpdateRule()
        link.x.update_rule.enabled = False
        link.u.update_rule = chainer.UpdateRule()

        # Record the active CUDA device, when CUDA is available.
        if cuda.available:
            self.current_device_id = cuda.cupy.cuda.get_device_id()
开发者ID:chainer,项目名称:chainer,代码行数:18,代码来源:test_link.py

示例5: test_serialize

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def test_serialize(self, backend_config):
        # Serializing a link with an uninitialized parameter must leave the
        # parameter uninitialized. The serializer must see exactly one call,
        # for key 'x' with value None.
        recorded = []

        def serializer(key, value):
            recorded.append((key, value))
            return value

        link = chainer.Link()
        with link.init_scope():
            link.x = chainer.Parameter()  # uninitialized
        link.to_device(backend_config.device)

        link.serialize(serializer)

        # The parameter is still uninitialized after serialization.
        self.assertIsNone(link.x.array)

        # The serializer was called exactly once, for 'x', with None.
        self.assertEqual(len(recorded), 1)
        key, value = recorded[0]
        self.assertEqual(key, 'x')
        self.assertIs(value, None)
开发者ID:chainer,项目名称:chainer,代码行数:23,代码来源:test_link.py

示例6: test_deserialize

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def test_deserialize(self, backend_config):
        # Deserializes uninitialized parameters into uninitialized ones: the
        # serializer returns None for every key, so 'x' must stay unset.
        recorded = []

        def serializer(key, value):
            recorded.append((key, value))
            return None  # to be uninitialized

        link = chainer.Link()
        with link.init_scope():
            link.x = chainer.Parameter()  # uninitialized
        link.to_device(backend_config.device)

        link.serialize(serializer)

        # The parameter is still uninitialized after deserialization.
        self.assertIsNone(link.x.array)

        # The serializer was called exactly once, for 'x', with None.
        self.assertEqual(len(recorded), 1)
        key, value = recorded[0]
        self.assertEqual(key, 'x')
        self.assertIs(value, None)
开发者ID:chainer,项目名称:chainer,代码行数:24,代码来源:test_link.py

示例7: test_copyparams

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def test_copyparams(self):
        # Build a source hierarchy ChainList(ChainList(l1, l2), l3), fill each
        # link's "x" with a distinct constant, copy it into self.c2, and check
        # that self.l1/self.l2/self.l3 received the values. Presumably self.c2
        # mirrors this structure (it is built in setUp, not shown here).
        src_l1 = chainer.Link()
        with src_l1.init_scope():
            src_l1.x = chainer.Parameter(shape=(2, 3))
            src_l1.y = chainer.Parameter()
        src_l2 = chainer.Link()
        with src_l2.init_scope():
            src_l2.x = chainer.Parameter(shape=2)
        src_l3 = chainer.Link()
        with src_l3.init_scope():
            src_l3.x = chainer.Parameter(shape=3)
        src_c2 = chainer.ChainList(
            chainer.ChainList(src_l1, src_l2), src_l3)

        # Fill l1.x with 0, l2.x with 1 and l3.x with 2.
        for fill_value, src in enumerate((src_l1, src_l2, src_l3)):
            src.x.data.fill(fill_value)

        self.c2.copyparams(src_c2)

        pairs = ((self.l1, src_l1), (self.l2, src_l2), (self.l3, src_l3))
        for dst, src in pairs:
            numpy.testing.assert_array_equal(dst.x.data, src.x.data)
开发者ID:chainer,项目名称:chainer,代码行数:24,代码来源:test_link.py

示例8: test_adam_w

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def test_adam_w(self, backend_config):
        # Run a single update of Adam configured with ``eta`` and
        # ``weight_decay_rate`` (the AdamW variant, per this test's name) on a
        # one-element parameter, and compare the result with a reference
        # value.
        device = backend_config.device
        xp = backend_config.xp

        link = chainer.Link(x=(1,))
        link.to_device(device)

        optimizer = optimizers.Adam(eta=0.5, weight_decay_rate=0.1)
        optimizer.setup(link)

        param = link.x
        param.data.fill(1)
        param.grad = device.send(xp.ones_like(param.data))

        optimizer.update()

        # compare against the value computed with v5 impl
        expected = np.array([0.9495])
        testing.assert_allclose(param.data, expected, atol=1e-7, rtol=1e-7)
开发者ID:chainer,项目名称:chainer,代码行数:20,代码来源:test_optimizers.py

示例9: forward_postprocess

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def forward_postprocess(
            self,
            args: _ForwardPostprocessCallbackArgs
    ) -> None:
        """Hook invoked after a link's forward method returns.

        The default implementation does nothing. Subclasses override it to
        inspect or record the finished forward call.

        Args:
            args: Callback data. It has the following attributes:

                * link (:class:`~chainer.Link`)
                    Link object.
                * forward_name (:class:`str`)
                    Name of the forward method.
                * args (:class:`tuple`)
                    Non-keyword arguments given to the forward method.
                * kwargs (:class:`dict`)
                    Keyword arguments given to the forward method.
                * out
                    Return value of the forward method.
        """
开发者ID:chainer,项目名称:chainer,代码行数:23,代码来源:link_hook.py

示例10: use_cleargrads

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def use_cleargrads(self, use=True):
        """Enables or disables use of :func:`~chainer.Link.cleargrads` in `update`.

        Emits a ``DeprecationWarning`` and stores ``use`` in
        ``self._use_cleargrads``.

        Args:
            use (bool): If ``True``, this function enables use of
                `cleargrads`. If ``False``, disables use of `cleargrads`
                (`zerograds` is used).

        .. deprecated:: v2.0
           Note that :meth:`update` calls :meth:`~Link.cleargrads` by default.
           :meth:`~Link.cleargrads` is more efficient than
           :meth:`~Link.zerograds`, so one does not have to call
           :meth:`use_cleargrads`. This method remains for backward
           compatibility.

        """
        # stacklevel=2 attributes the warning to the caller's line (the code
        # that should stop calling this method) instead of this function.
        warnings.warn(
            'GradientMethod.use_cleargrads is deprecated.',
            DeprecationWarning,
            stacklevel=2)

        self._use_cleargrads = use
开发者ID:chainer,项目名称:chainer,代码行数:23,代码来源:optimizer.py

示例11: create_simple_link

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def create_simple_link():
    """Return a bare ``chainer.Link`` with one parameter named ``param``.

    ``param`` is registered inside ``init_scope`` and initialized from
    ``np.zeros(1)``.
    """
    initial_value = np.zeros(1)
    simple = chainer.Link()
    with simple.init_scope():
        simple.param = chainer.Parameter(initial_value)
    return simple
开发者ID:chainer,项目名称:chainerrl,代码行数:7,代码来源:test_agent.py

示例12: test_namedpersistent

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def test_namedpersistent():
    """Check names, identities and order yielded by ``namedpersistent``.

    Builds the hierarchy c2 -> {c1 -> {l1, l2}, l3} and registers one
    persistent array on each of l2, l3, c1 and c2. Asserts that
    ``chainerrl.misc.namedpersistent(c2)`` yields exactly those four
    persistents as (path, object) pairs, in the order listed in the final
    assertion. Parameters (each link's ``x``) and l1, which has no
    persistent, must not appear.
    """
    # This test case is adopted from
    # https://github.com/chainer/chainer/pull/6788

    # l1: a parameter only, no persistent.
    l1 = chainer.Link()
    with l1.init_scope():
        l1.x = chainer.Parameter(shape=(2, 3))

    # l2: an initialized parameter plus the persistent 'l2_a'.
    l2 = chainer.Link()
    with l2.init_scope():
        l2.x = chainer.Parameter(shape=2)
    l2.add_persistent(
        'l2_a', numpy.array([1, 2, 3], dtype=numpy.float32))

    # l3: a parameter with no shape given, plus the persistent 'l3_a'.
    l3 = chainer.Link()
    with l3.init_scope():
        l3.x = chainer.Parameter()
    l3.add_persistent(
        'l3_a', numpy.array([1, 2, 3], dtype=numpy.float32))

    # c1 registers l1 via init_scope and l2 via add_link, then gets its own
    # persistent 'c1_a'.
    c1 = chainer.Chain()
    with c1.init_scope():
        c1.l1 = l1
    c1.add_link('l2', l2)
    c1.add_persistent(
        'c1_a', numpy.array([1, 2, 3], dtype=numpy.float32))

    # c2 is the root: children c1 and l3, plus its own persistent 'c2_a'.
    c2 = chainer.Chain()
    with c2.init_scope():
        c2.c1 = c1
        c2.l3 = l3
    c2.add_persistent(
        'c2_a', numpy.array([1, 2, 3], dtype=numpy.float32))
    namedpersistent = list(chainerrl.misc.namedpersistent(c2))
    # Compare by id() so the yielded values must be the very same objects,
    # not just equal arrays. In the expected order, a link's own persistents
    # come before those of its children.
    assert (
        [(name, id(p)) for name, p in namedpersistent] ==
        [('/c2_a', id(c2.c2_a)), ('/c1/c1_a', id(c2.c1.c1_a)),
         ('/c1/l2/l2_a', id(c2.c1.l2.l2_a)), ('/l3/l3_a', id(c2.l3.l3_a))])
开发者ID:chainer,项目名称:chainerrl,代码行数:40,代码来源:test_namedpersistent.py

示例13: _assert_same_pointers_to_persistent_values

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def _assert_same_pointers_to_persistent_values(a, b):
    """Assert that links ``a`` and ``b`` share their persistent-value buffers.

    Both links must expose the same set of persistent names. Every persistent
    value must be a ``np.ndarray``, and same-named values must report the same
    data address (``ndarray.ctypes.data``).
    """
    for link in (a, b):
        assert isinstance(link, chainer.Link)

    persistents_a = dict(chainerrl.misc.namedpersistent(a))
    persistents_b = dict(chainerrl.misc.namedpersistent(b))
    # dict_keys compare like sets, so the key order does not matter here.
    assert persistents_a.keys() == persistents_b.keys()

    for name, value_a in persistents_a.items():
        value_b = persistents_b[name]
        assert isinstance(value_a, np.ndarray)
        assert isinstance(value_b, np.ndarray)
        assert value_a.ctypes.data == value_b.ctypes.data
开发者ID:chainer,项目名称:chainerrl,代码行数:14,代码来源:test_async.py

示例14: _assert_same_pointers_to_param_data

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def _assert_same_pointers_to_param_data(a, b):
    """Assert that links ``a`` and ``b`` share their parameter data buffers.

    Both links must have the same parameter names. Every parameter must be a
    ``chainer.Variable``, and same-named parameters' ``array`` must report the
    same data address (``ctypes.data``).
    """
    for link in (a, b):
        assert isinstance(link, chainer.Link)

    params_a = dict(a.namedparams())
    params_b = dict(b.namedparams())
    # dict_keys compare like sets, so the key order does not matter here.
    assert params_a.keys() == params_b.keys()

    for name, param_a in params_a.items():
        param_b = params_b[name]
        assert isinstance(param_a, chainer.Variable)
        assert isinstance(param_b, chainer.Variable)
        address_a = param_a.array.ctypes.data
        address_b = param_b.array.ctypes.data
        assert address_a == address_b
开发者ID:chainer,项目名称:chainerrl,代码行数:13,代码来源:test_async.py

示例15: _assert_different_pointers_to_param_grad

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def _assert_different_pointers_to_param_grad(a, b):
    """Assert that links ``a`` and ``b`` keep separate gradient buffers.

    Both links must have the same parameter names. Every parameter must be a
    ``chainer.Variable``, and same-named parameters' ``grad`` must report
    different data addresses (``ctypes.data``).
    """
    for link in (a, b):
        assert isinstance(link, chainer.Link)

    params_a = dict(a.namedparams())
    params_b = dict(b.namedparams())
    # dict_keys compare like sets, so the key order does not matter here.
    assert params_a.keys() == params_b.keys()

    for name, param_a in params_a.items():
        param_b = params_b[name]
        assert isinstance(param_a, chainer.Variable)
        assert isinstance(param_b, chainer.Variable)
        grad_address_a = param_a.grad.ctypes.data
        grad_address_b = param_b.grad.ctypes.data
        assert grad_address_a != grad_address_b
开发者ID:chainer,项目名称:chainerrl,代码行数:13,代码来源:test_async.py


注:本文中的chainer.Link方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。