本文整理汇总了Python中chainer.Link类的典型用法代码示例。如果您正苦于以下问题:Python chainer.Link类的具体用法?Python chainer.Link怎么用?Python chainer.Link使用的例子?那么,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块chainer的用法示例。
在下文中一共展示了chainer.Link类的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: set_shared_params
# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def set_shared_params(a, b):
    """Rebind a link's params and persistent values to shared buffers.

    Each parameter of ``a`` whose name is a key of ``b`` has its ``array``
    replaced by an array built with ``np.frombuffer`` over the matching
    entry of ``b``, using the parameter's own dtype and shape. Persistent
    values whose names are keys of ``b`` are passed on to
    ``_set_persistent_values_recursively``.

    Args:
        a (chainer.Link): link whose params are to be replaced
        b (dict): dict that consists of (param_name, multiprocessing.Array)

    Raises:
        AssertionError: if ``a`` is not a ``chainer.Link``, or if a key of
            ``b`` matched neither a parameter nor a persistent value.
    """
    assert isinstance(a, chainer.Link)

    # Keys are ticked off as they are consumed; leftovers are an error.
    unused_keys = set(b.keys())

    for name, param in a.namedparams():
        if name not in b:
            continue
        flat_view = np.frombuffer(b[name], dtype=param.dtype)
        param.array = flat_view.reshape(param.shape)
        unused_keys.remove(name)

    for name, _ in chainerrl.misc.namedpersistent(a):
        if name not in b:
            continue
        _set_persistent_values_recursively(a, name, b[name])
        unused_keys.remove(name)

    assert not unused_keys
示例2: extract_params_as_shared_arrays
# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def extract_params_as_shared_arrays(link):
    """Copy a link's params and persistent values into ``mp.RawArray`` objects.

    Args:
        link (chainer.Link): link to read parameters and persistent values
            from.

    Returns:
        dict: maps every parameter name and persistent name to an
        ``mp.RawArray`` whose typecode is taken from the value's dtype.

    Raises:
        AssertionError: if ``link`` is not a ``chainer.Link``, or if a
            persistent value is neither an ``np.ndarray`` nor a scalar.
    """
    assert isinstance(link, chainer.Link)

    shared = {
        name: mp.RawArray(param.array.dtype.char, param.array.ravel())
        for name, param in link.namedparams()
    }

    for name, value in chainerrl.misc.namedpersistent(link):
        if not isinstance(value, np.ndarray):
            assert np.isscalar(value)
            # multiprocessing.RawArray does not accept a 0-dim array, so a
            # scalar is wrapped in a 1-dim array first.
            wrapped = np.asarray([value])
            shared[name] = mp.RawArray(wrapped.dtype.char, wrapped)
            continue
        shared[name] = mp.RawArray(value.dtype.char, value.ravel())

    return shared
示例3: _batch_reset_recurrent_states_when_episodes_end
# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def _batch_reset_recurrent_states_when_episodes_end(
        model, batch_done, batch_reset, recurrent_states):
    """Mask the recurrent states of batch entries whose episode ended.

    An entry counts as ended when its ``batch_done`` flag or its
    ``batch_reset`` flag is truthy.

    Args:
        model (chainer.Link): Model that implements `StatelessRecurrent`.
        batch_done (array-like of bool): True iff episodes are terminal.
        batch_reset (array-like of bool): True iff episodes will be reset.
        recurrent_states (object): Recurrent state.

    Returns:
        object: The result of ``model.mask_recurrent_state_at`` for the ended
        indices, or ``recurrent_states`` itself when no episode ended.
    """
    ended = []
    for index, (done, reset) in enumerate(zip(batch_done, batch_reset)):
        if done or reset:
            ended.append(index)

    # Nothing ended: hand the states back untouched, without calling the model.
    if not ended:
        return recurrent_states
    return model.mask_recurrent_state_at(recurrent_states, ended)
示例4: setUp
# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def setUp(self):
    """Build the link fixture stored on ``self.link``.

    The link gets parameters ``x`` and ``u`` through the constructor,
    ``y`` and ``v`` inside ``init_scope``, a persistent value ``p`` and
    the name ``'a'``.
    """
    # The shape of ``x`` deliberately mixes a Python int with a numpy.int64.
    x_shape = (2, numpy.int64(3))
    link = chainer.Link(x=(x_shape, 'd'),
                        u=(None, 'd'))
    self.link = link

    with link.init_scope():
        link.y = chainer.Parameter(shape=(2,))
        link.v = chainer.Parameter()

    self.p = numpy.array([1, 2, 3], dtype='f')
    link.add_persistent('p', self.p)
    link.name = 'a'

    # ``x`` gets an update rule that is then disabled; ``u`` keeps the
    # rule exactly as constructed.
    link.x.update_rule = chainer.UpdateRule()
    link.x.update_rule.enabled = False
    link.u.update_rule = chainer.UpdateRule()

    if cuda.available:
        self.current_device_id = cuda.cupy.cuda.get_device_id()
示例5: test_serialize
# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def test_serialize(self, backend_config):
    """Serialize a link whose only parameter is uninitialized.

    Checks that the parameter stays uninitialized and that the serializer
    is called exactly once, with key ``'x'`` and value ``None``.
    """
    recorded = []

    def serializer(key, value):
        # Log the call and hand the value back unchanged.
        recorded.append((key, value))
        return value

    link = chainer.Link()
    with link.init_scope():
        link.x = chainer.Parameter()  # uninitialized
    link.to_device(backend_config.device)
    link.serialize(serializer)

    # Link is kept uninitialized
    self.assertIsNone(link.x.array)

    # Check inputs to the serializer
    self.assertEqual(len(recorded), 1)
    seen_key, seen_value = recorded[0]
    self.assertEqual(seen_key, 'x')
    self.assertIs(seen_value, None)
示例6: test_deserialize
# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def test_deserialize(self, backend_config):
    """Deserialize an uninitialized parameter into an uninitialized one.

    The serializer always returns ``None``; the test checks that the
    parameter stays uninitialized and that the serializer is called exactly
    once, with key ``'x'`` and value ``None``.
    """
    recorded = []

    def serializer(key, value):
        recorded.append((key, value))
        return None  # to be uninitialized

    link = chainer.Link()
    with link.init_scope():
        link.x = chainer.Parameter()  # uninitialized
    link.to_device(backend_config.device)
    link.serialize(serializer)

    # Link is kept uninitialized
    self.assertIsNone(link.x.array)

    # Check inputs to the serializer
    self.assertEqual(len(recorded), 1)
    seen_key, seen_value = recorded[0]
    self.assertEqual(seen_key, 'x')
    self.assertIs(seen_value, None)
示例7: test_copyparams
# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def test_copyparams(self):
    """Copy params from a freshly built chain list into ``self.c2``.

    The source hierarchy is a ``ChainList`` nested inside another
    ``ChainList``; after ``copyparams`` the ``x`` data of ``self.l1``,
    ``self.l2`` and ``self.l3`` must equal the source values.
    """
    src_l1 = chainer.Link()
    with src_l1.init_scope():
        src_l1.x = chainer.Parameter(shape=(2, 3))
        src_l1.y = chainer.Parameter()

    src_l2 = chainer.Link()
    with src_l2.init_scope():
        src_l2.x = chainer.Parameter(shape=2)

    src_l3 = chainer.Link()
    with src_l3.init_scope():
        src_l3.x = chainer.Parameter(shape=3)

    src_c1 = chainer.ChainList(src_l1, src_l2)
    src_c2 = chainer.ChainList(src_c1, src_l3)

    # Give every source link a distinct constant.
    for src_link, fill_value in ((src_l1, 0), (src_l2, 1), (src_l3, 2)):
        src_link.x.data.fill(fill_value)

    self.c2.copyparams(src_c2)

    copied_and_source = (
        (self.l1, src_l1),
        (self.l2, src_l2),
        (self.l3, src_l3),
    )
    for copied, source in copied_and_source:
        numpy.testing.assert_array_equal(copied.x.data, source.x.data)
示例8: test_adam_w
# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def test_adam_w(self, backend_config):
    """Run one Adam update with weight decay on a one-element parameter.

    The parameter starts at 1 with a gradient of ones; after one update
    with ``eta=0.5`` and ``weight_decay_rate=0.1`` it must match 0.9495.
    """
    xp = backend_config.xp
    device = backend_config.device

    link = chainer.Link(x=(1,))
    link.to_device(device)

    optimizer = optimizers.Adam(eta=0.5, weight_decay_rate=0.1)
    optimizer.setup(link)

    link.x.data.fill(1)
    ones_grad = xp.ones_like(link.x.data)
    link.x.grad = device.send(ones_grad)
    optimizer.update()

    # compare against the value computed with v5 impl
    expected = np.array([0.9495])
    testing.assert_allclose(link.x.data, expected,
                            atol=1e-7, rtol=1e-7)
示例9: forward_postprocess
# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def forward_postprocess(
        self,
        args: _ForwardPostprocessCallbackArgs
) -> None:
    """Hook invoked after a link's forward call; does nothing here.

    Args:
        args: Callback data. It has the following attributes:

            * link (:class:`~chainer.Link`)
                Link object.
            * forward_name (:class:`str`)
                Name of the forward method.
            * args (:class:`tuple`)
                Non-keyword arguments given to the forward method.
            * kwargs (:class:`dict`)
                Keyword arguments given to the forward method.
            * out
                Return value of the forward method.
    """
示例10: use_cleargrads
# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def use_cleargrads(self, use=True):
    """Enables or disables use of :func:`~chainer.Link.cleargrads` in `update`.

    Every call emits a :class:`DeprecationWarning` and then stores ``use``
    in ``self._use_cleargrads``.

    Args:
        use (bool): If ``True``, this function enables use of
            `cleargrads`. If ``False``, disables use of `cleargrads`
            (`zerograds` is used).

    .. deprecated:: v2.0
       Note that :meth:`update` calls :meth:`~Link.cleargrads` by default.
       :meth:`~Link.cleargrads` is more efficient than
       :meth:`~Link.zerograds`, so one does not have to call
       :meth:`use_cleargrads`. This method remains for backward
       compatibility.

    """
    # stacklevel=2 attributes the warning to the code calling this
    # deprecated method instead of to this line.
    warnings.warn(
        'GradientMethod.use_cleargrads is deprecated.',
        DeprecationWarning,
        stacklevel=2)

    self._use_cleargrads = use
示例11: create_simple_link
# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def create_simple_link():
    """Return a ``chainer.Link`` with one parameter named ``param``.

    The parameter is created from ``np.zeros(1)``.
    """
    initial_value = np.zeros(1)
    simple = chainer.Link()
    with simple.init_scope():
        simple.param = chainer.Parameter(initial_value)
    return simple
示例12: test_namedpersistent
# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def test_namedpersistent():
    """Check names and identities yielded by ``namedpersistent``.

    Builds nested chains with persistent values attached at several levels
    and compares the ``(name, id)`` pairs against the expected sequence.
    """
    # This test case is adopted from
    # https://github.com/chainer/chainer/pull/6788

    def new_values():
        # A fresh array per call, so each persistent has its own identity.
        return numpy.array([1, 2, 3], dtype=numpy.float32)

    l1 = chainer.Link()
    with l1.init_scope():
        l1.x = chainer.Parameter(shape=(2, 3))

    l2 = chainer.Link()
    with l2.init_scope():
        l2.x = chainer.Parameter(shape=2)
    l2.add_persistent('l2_a', new_values())

    l3 = chainer.Link()
    with l3.init_scope():
        l3.x = chainer.Parameter()
    l3.add_persistent('l3_a', new_values())

    c1 = chainer.Chain()
    with c1.init_scope():
        c1.l1 = l1
    c1.add_link('l2', l2)
    c1.add_persistent('c1_a', new_values())

    c2 = chainer.Chain()
    with c2.init_scope():
        c2.c1 = c1
        c2.l3 = l3
    c2.add_persistent('c2_a', new_values())

    found = list(chainerrl.misc.namedpersistent(c2))
    actual = [(name, id(value)) for name, value in found]
    expected = [
        ('/c2_a', id(c2.c2_a)),
        ('/c1/c1_a', id(c2.c1.c1_a)),
        ('/c1/l2/l2_a', id(c2.c1.l2.l2_a)),
        ('/l3/l3_a', id(c2.l3.l3_a)),
    ]
    assert actual == expected
示例13: _assert_same_pointers_to_persistent_values
# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def _assert_same_pointers_to_persistent_values(a, b):
    """Assert that two links' persistent arrays have the same data address.

    Both links must expose the same set of persistent names, every
    persistent value must be an ``np.ndarray``, and same-named arrays must
    report the same ``ctypes.data`` address.

    Args:
        a (chainer.Link): first link.
        b (chainer.Link): second link.
    """
    assert isinstance(a, chainer.Link)
    assert isinstance(b, chainer.Link)

    persistents_a = dict(chainerrl.misc.namedpersistent(a))
    persistents_b = dict(chainerrl.misc.namedpersistent(b))
    assert set(persistents_a.keys()) == set(persistents_b.keys())

    for name, value_a in persistents_a.items():
        value_b = persistents_b[name]
        assert isinstance(value_a, np.ndarray)
        assert isinstance(value_b, np.ndarray)
        assert value_a.ctypes.data == value_b.ctypes.data
示例14: _assert_same_pointers_to_param_data
# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def _assert_same_pointers_to_param_data(a, b):
    """Assert that two links' parameter arrays have the same data address.

    Both links must expose the same set of parameter names, every parameter
    must be a ``chainer.Variable``, and same-named parameters must report
    the same ``array.ctypes.data`` address.

    Args:
        a (chainer.Link): first link.
        b (chainer.Link): second link.
    """
    assert isinstance(a, chainer.Link)
    assert isinstance(b, chainer.Link)

    params_a = dict(a.namedparams())
    params_b = dict(b.namedparams())
    assert set(params_a.keys()) == set(params_b.keys())

    for name, param_a in params_a.items():
        assert isinstance(param_a, chainer.Variable)
        param_b = params_b[name]
        assert isinstance(param_b, chainer.Variable)
        address_a = param_a.array.ctypes.data
        address_b = param_b.array.ctypes.data
        assert address_a == address_b
示例15: _assert_different_pointers_to_param_grad
# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import Link [as 别名]
def _assert_different_pointers_to_param_grad(a, b):
    """Assert that two links' parameter gradients have different addresses.

    Both links must expose the same set of parameter names, every parameter
    must be a ``chainer.Variable``, and same-named parameters must report
    different ``grad.ctypes.data`` addresses.

    Args:
        a (chainer.Link): first link.
        b (chainer.Link): second link.
    """
    assert isinstance(a, chainer.Link)
    assert isinstance(b, chainer.Link)

    params_a = dict(a.namedparams())
    params_b = dict(b.namedparams())
    assert set(params_a.keys()) == set(params_b.keys())

    for name, param_a in params_a.items():
        assert isinstance(param_a, chainer.Variable)
        param_b = params_b[name]
        assert isinstance(param_b, chainer.Variable)
        address_a = param_a.grad.ctypes.data
        address_b = param_b.grad.ctypes.data
        assert address_a != address_b