当前位置: 首页>>代码示例>>Python>>正文


Python link.Chain类代码示例

本文整理汇总了Python中chainer.link.Chain类的典型用法代码示例。如果您正苦于以下问题：Python link.Chain类的具体用法？Python link.Chain怎么用？Python link.Chain使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块chainer.link的用法示例。


在下文中一共展示了link.Chain类的11个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_deserialize_hierarchy

# 需要导入模块: from chainer import link [as 别名]
# 或者: from chainer.link import Chain [as 别名]
def test_deserialize_hierarchy(self):
        """Loading into a nested Chain restores the saved top-level link.

        After ``self.deserializer.load(target)``:

        * ``target.linear`` (W and b) must equal ``self.source.linear``;
        * ``target.child.linear2`` (W and b) must keep the values it had
          before loading (presumably because the fixture's deserializer is
          configured to skip that entry -- confirm against ``setUp``).
        """
        # Build the target: a parent chain with ``linear`` plus a child
        # chain with ``linear2``.
        child = link.Chain()
        with child.init_scope():
            child.linear2 = links.Linear(2, 3)
        target = link.Chain()
        with target.init_scope():
            target.linear = links.Linear(3, 2)
            target.child = child
        # Snapshot the child's initial parameters; copies are needed because
        # load() may overwrite the arrays in place.
        target_child_W = numpy.copy(child.linear2.W.data)
        target_child_b = numpy.copy(child.linear2.b.data)

        self.deserializer.load(target)

        # Top-level link is restored from the saved source (each parameter
        # is checked exactly once).
        numpy.testing.assert_array_equal(
            self.source.linear.W.data, target.linear.W.data)
        numpy.testing.assert_array_equal(
            self.source.linear.b.data, target.linear.b.data)
        # Child link keeps its pre-load values.
        numpy.testing.assert_array_equal(
            target.child.linear2.W.data, target_child_W)
        numpy.testing.assert_array_equal(
            target.child.linear2.b.data, target_child_b)
开发者ID:chainer,项目名称:chainer,代码行数:27,代码来源:test_npz.py

示例2: setUp

# 需要导入模块: from chainer import link [as 别名]
# 或者: from chainer.link import Chain [as 别名]
def setUp(self):
        """Save a two-level Chain to a temporary ``.npz`` file and open a deserializer.

        Sets ``self.temp_file_path`` (the file written), ``self.source`` (the
        saved parent chain), ``self.npzfile`` (the loaded archive) and
        ``self.deserializer``. ``self.ignore_names`` is read here and is
        presumably supplied by a test parameterization outside this view.
        """
        handle, file_path = tempfile.mkstemp()
        # Only the path is kept; the file itself is written by save_npz below.
        os.close(handle)
        self.temp_file_path = file_path

        sub = link.Chain()
        with sub.init_scope():
            sub.linear2 = links.Linear(2, 3)
        root = link.Chain()
        with root.init_scope():
            root.linear = links.Linear(3, 2)
            root.child = sub
        npz.save_npz(self.temp_file_path, root)
        self.source = root

        # NOTE(review): the temp file and npzfile are presumably cleaned up
        # in a tearDown not visible here.
        self.npzfile = numpy.load(file_path)
        self.deserializer = npz.NpzDeserializer(
            self.npzfile, ignore_names=self.ignore_names)
开发者ID:chainer,项目名称:chainer,代码行数:20,代码来源:test_npz.py

示例3: test_deserialize_ignore_names

# 需要导入模块: from chainer import link [as 别名]
# 或者: from chainer.link import Chain [as 别名]
def test_deserialize_ignore_names(self):
        """Only parameters outside the ignore list are loaded.

        After loading, ``linear.b`` and ``child.linear2.W`` must match the
        saved source, while ``linear.W`` and ``child.linear2.b`` must keep
        their pre-load values (presumably the names listed in
        ``self.ignore_names`` -- confirm against the test parameterization).
        """
        sub = link.Chain()
        with sub.init_scope():
            sub.linear2 = links.Linear(2, 3)
        target = link.Chain()
        with target.init_scope():
            target.linear = links.Linear(3, 2)
            target.child = sub
        # Copies, since load() may write into the existing arrays.
        initial_W = numpy.copy(target.linear.W.data)
        initial_child_b = numpy.copy(sub.linear2.b.data)
        self.deserializer.load(target)

        check = numpy.testing.assert_array_equal
        # Expected to be loaded from the source.
        check(self.source.linear.b.data, target.linear.b.data)
        check(self.source.child.linear2.W.data, target.child.linear2.W.data)
        # Expected to be skipped, i.e. unchanged by load().
        check(target.linear.W.data, initial_W)
        check(target.child.linear2.b.data, initial_child_b)
开发者ID:chainer,项目名称:chainer,代码行数:22,代码来源:test_npz.py

示例4: setUp

# 需要导入模块: from chainer import link [as 别名]
# 或者: from chainer.link import Chain [as 别名]
def setUp(self):
        """Save a parent/child Chain to a temporary HDF5 file and open a deserializer.

        Sets ``self.temp_file_path``, ``self.source`` (the saved chain),
        ``self.hdf5file`` (opened read-only) and ``self.deserializer``
        (constructed with ``strict=False``).
        """
        handle, file_path = tempfile.mkstemp()
        # Only the path is kept; the file itself is written by save_hdf5 below.
        os.close(handle)
        self.temp_file_path = file_path

        sub = link.Chain()
        with sub.init_scope():
            sub.linear = links.Linear(2, 3)
        root = link.Chain()
        with root.init_scope():
            root.linear = links.Linear(3, 2)
            root.child = sub
        hdf5.save_hdf5(self.temp_file_path, root)
        self.source = root

        # strict=False presumably lets loading skip entries absent from the
        # file -- TODO confirm against HDF5Deserializer docs.
        # NOTE(review): file/temp path are presumably closed and removed in a
        # tearDown not visible here.
        self.hdf5file = h5py.File(file_path, 'r')
        self.deserializer = hdf5.HDF5Deserializer(
            self.hdf5file, strict=False)
开发者ID:chainer,项目名称:chainer,代码行数:19,代码来源:test_hdf5.py

示例5: test_deserialize_hierarchy

# 需要导入模块: from chainer import link [as 别名]
# 或者: from chainer.link import Chain [as 别名]
def test_deserialize_hierarchy(self):
        """Top-level ``linear`` is restored; ``child.linear2`` keeps its values.

        Presumably ``child.linear2`` has no matching entry in the saved file
        and the fixture deserializer is non-strict -- confirm against setUp.
        """
        sub = link.Chain()
        with sub.init_scope():
            sub.linear2 = links.Linear(2, 3)
        target = link.Chain()
        with target.init_scope():
            target.linear = links.Linear(3, 2)
            target.child = sub
        # Copies, since load() may write into the existing arrays.
        before_W = numpy.copy(sub.linear2.W.data)
        before_b = numpy.copy(sub.linear2.b.data)
        self.deserializer.load(target)

        check = numpy.testing.assert_array_equal
        # Restored from the saved source.
        check(self.source.linear.W.data, target.linear.W.data)
        check(self.source.linear.b.data, target.linear.b.data)
        # Left untouched by load().
        check(target.child.linear2.W.data, before_W)
        check(target.child.linear2.b.data, before_b)
开发者ID:chainer,项目名称:chainer,代码行数:22,代码来源:test_hdf5.py

示例6: copy_model

# 需要导入模块: from chainer import link [as 别名]
# 或者: from chainer.link import Chain [as 别名]
def copy_model(src, dst):
    """Copy parameter arrays from chain ``src`` into chain ``dst``.

    Each direct child of ``src`` is considered only if ``dst.__dict__`` has
    an entry with the same name whose value has exactly the same type.
    ``Chain`` children are handled recursively. A child's parameters are
    then copied only if every zipped pair from ``namedparams()`` agrees on
    name and on ``.data.shape``; otherwise the child is reported and skipped.

    Args:
        src: source ``link.Chain`` whose parameter arrays are read.
        dst: destination ``link.Chain`` whose parameters are overwritten.
    """
    assert isinstance(src, link.Chain)
    assert isinstance(dst, link.Chain)
    for child in src.children():
        if child.name not in dst.__dict__:
            continue
        dst_child = dst[child.name]
        # Exact type match required (subclasses are not accepted).
        if type(child) is not type(dst_child):
            continue
        if isinstance(child, link.Chain):
            copy_model(child, dst_child)
        # NOTE(review): presumably Chain subclasses Link, so a Chain child
        # also reaches this branch after recursion and is compared/copied as
        # a whole.
        if isinstance(child, link.Link):
            # zip() stops at the shorter list, so extra trailing params on
            # either side are not compared.
            match = all(
                src_name == dst_name
                and src_param.data.shape == dst_param.data.shape
                for (src_name, src_param), (dst_name, dst_param)
                in zip(child.namedparams(), dst_child.namedparams()))
            if not match:
                # print(...) with one argument behaves the same on Py2 and Py3.
                print('Ignore %s because of parameter mismatch' % child.name)
                continue
            for (_, src_param), (_, dst_param) in zip(
                    child.namedparams(), dst_child.namedparams()):
                # Rebinds dst to the same array object (no copy is made).
                dst_param.data = src_param.data
            print('Copy %s' % child.name)
开发者ID:dsanno,项目名称:chainer-dfi,代码行数:26,代码来源:create_chainer_model.py

示例7: copy_model

# 需要导入模块: from chainer import link [as 别名]
# 或者: from chainer.link import Chain [as 别名]
def copy_model(src, dst):
    """Transfer matching parameter arrays from chain ``src`` to chain ``dst``.

    For every direct child of ``src`` that has a same-named entry in
    ``dst.__dict__`` of exactly the same type: recurse if it is a ``Chain``,
    then, if it is a ``Link``, copy its parameters when all zipped
    ``namedparams()`` pairs agree on name and shape; otherwise log and skip.
    """
    assert isinstance(src, link.Chain)
    assert isinstance(dst, link.Chain)
    for src_child in src.children():
        name = src_child.name
        if name not in dst.__dict__:
            continue
        dst_child = dst[name]
        if type(src_child) != type(dst_child):
            continue
        if isinstance(src_child, link.Chain):
            copy_model(src_child, dst_child)
        if not isinstance(src_child, link.Link):
            continue
        # Pairs are compared until the first name/shape disagreement.
        compatible = all(
            src_name == dst_name
            and src_param.data.shape == dst_param.data.shape
            for (src_name, src_param), (dst_name, dst_param)
            in zip(src_child.namedparams(), dst_child.namedparams()))
        if not compatible:
            print('Ignore %s because of parameter mismatch' % name)
            continue
        for (_, src_param), (_, dst_param) in zip(
                src_child.namedparams(), dst_child.namedparams()):
            # Shares the source array object with the destination parameter.
            dst_param.data = src_param.data
        print('Copy %s' % name)
开发者ID:yusuketomoto,项目名称:chainer-fast-neuralstyle,代码行数:26,代码来源:create_chainer_model.py

示例8: copy_chainermodel

# 需要导入模块: from chainer import link [as 别名]
# 或者: from chainer.link import Chain [as 别名]
def copy_chainermodel(src, dst):
    """Copy parameter arrays layer by layer from chain ``src`` into chain ``dst``.

    Prints a header naming both classes, then walks ``src.children()``.
    A child is used only when ``dst.__dict__`` has a same-named entry of
    exactly the same type; ``Chain`` children are recursed into, and a
    ``Link`` child's parameters are copied when every zipped
    ``namedparams()`` pair agrees on name and ``.data.shape``.
    """
    from chainer import link
    assert isinstance(src, link.Chain)
    assert isinstance(dst, link.Chain)
    print('Copying layers %s -> %s:' %
          (src.__class__.__name__, dst.__class__.__name__))
    for src_child in src.children():
        if src_child.name not in dst.__dict__:
            continue
        dst_child = dst[src_child.name]
        if type(src_child) != type(dst_child):
            continue
        if isinstance(src_child, link.Chain):
            copy_chainermodel(src_child, dst_child)
        if not isinstance(src_child, link.Link):
            continue
        params_agree = all(
            s_name == d_name and s_param.data.shape == d_param.data.shape
            for (s_name, s_param), (d_name, d_param)
            in zip(src_child.namedparams(), dst_child.namedparams()))
        if not params_agree:
            print('Ignore %s because of parameter mismatch.' % src_child.name)
            continue
        for (_, s_param), (_, d_param) in zip(src_child.namedparams(),
                                             dst_child.namedparams()):
            # Rebinds to the same array object; no data copy is made.
            d_param.data = s_param.data
        print(' layer: %s -> %s' % (src_child.name, dst_child.name))



# -----------------------------------------------------------------------------
# Data Util
# ----------------------------------------------------------------------------- 
开发者ID:oyam,项目名称:Semantic-Segmentation-using-Adversarial-Networks,代码行数:37,代码来源:utils.py

示例9: test_load_npz_ignore_names

# 需要导入模块: from chainer import link [as 别名]
# 或者: from chainer.link import Chain [as 别名]
def test_load_npz_ignore_names(self):
        """With ``ignore_names``, ``x`` is loaded but ``yy`` must not match the saved value."""
        model = link.Chain()
        with model.init_scope():
            model.x = chainer.variable.Parameter(shape=())
            model.yy = chainer.variable.Parameter(shape=(2, 3))
        npz.load_npz(
            self.temp_file_path, model, ignore_names=self.ignore_names)
        self.assertEqual(model.x.data, self.x)
        # yy is expected to be ignored, so not every element equals self.yy.
        self.assertFalse(numpy.all(model.yy.data == self.yy))
开发者ID:chainer,项目名称:chainer,代码行数:11,代码来源:test_npz.py

示例10: test_load_with_path

# 需要导入模块: from chainer import link [as 别名]
# 或者: from chainer.link import Chain [as 别名]
def test_load_with_path(self):
        """Loading with ``path='child/'`` fills ``child_linear`` from that sub-tree."""
        model = link.Chain()
        with model.init_scope():
            model.child_linear = links.Linear(2, 3)
        npz.load_npz(self.file, model, path='child/')
        numpy.testing.assert_array_equal(
            self.source_child.child_linear.W.data,
            model.child_linear.W.data)
开发者ID:chainer,项目名称:chainer,代码行数:9,代码来源:test_npz.py

示例11: test_load_without_path

# 需要导入模块: from chainer import link [as 别名]
# 或者: from chainer.link import Chain [as 别名]
def test_load_without_path(self):
        """Loading with an empty ``path`` fills ``parent_linear`` from the archive root."""
        model = link.Chain()
        with model.init_scope():
            model.parent_linear = links.Linear(3, 2)
        npz.load_npz(self.file, model, path='')
        expected_W = self.source_parent.parent_linear.W.data
        numpy.testing.assert_array_equal(
            expected_W, model.parent_linear.W.data)
开发者ID:chainer,项目名称:chainer,代码行数:10,代码来源:test_npz.py


注：本文中的chainer.link.Chain类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。