本文整理汇总了Python中chainer.serializers模块的典型用法代码示例。如果您正苦于以下问题：Python chainer.serializers模块的具体用法？Python chainer.serializers怎么用？Python chainer.serializers使用的例子？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该模块所在包chainer的用法示例。
在下文中一共展示了chainer.serializers模块的5个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: load
# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import serializers [as 别名]
def load(self, snapshot_root_directory, epoch):
    """Best-effort load of HDF5 parameters from a snapshot directory.

    Looks for ``self.snapshot_filename`` inside ``snapshot_root_directory``
    and, if present, loads it into ``self.parameters`` via
    ``chainer.serializers.load_hdf5``.

    Args:
        snapshot_root_directory: Directory holding the snapshot file.
        epoch: Accepted for interface compatibility; not used here.

    Returns:
        True if the file existed and loaded without raising, otherwise False.
        Any exception raised while loading is printed, not propagated.
    """
    snapshot_path = os.path.join(snapshot_root_directory,
                                 self.snapshot_filename)
    try:
        if not os.path.exists(snapshot_path):
            return False
        print("loading {}".format(snapshot_path))
        chainer.serializers.load_hdf5(snapshot_path, self.parameters)
    except Exception as exc:
        # Deliberate best-effort: report the failure and signal it via False.
        print(exc)
        return False
    return True
示例2: load_encdec_from_config
# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import serializers [as 别名]
def load_encdec_from_config(config_fn, model_fn):
    """Rebuild an encoder-decoder from its JSON config and load its weights.

    Args:
        config_fn: Path to the JSON model config. It must contain an
            ``"indexer"`` entry naming a JSON file with the character list.
        model_fn: Path to the ``.npz`` weights loaded with
            ``serializers.load_npz``.

    Returns:
        A tuple ``(ced, charlist, chardict)``: the model built by
        ``create_model(config)``, the character list read from the indexer
        file, and a dict mapping each character to its list index.
    """
    # Use ``with`` so the handles are closed promptly, and read as UTF-8
    # (the JSON interchange encoding) instead of the platform locale
    # encoding, which can break non-ASCII character lists.
    with open(config_fn, encoding="utf-8") as config_file:
        config = json.load(config_file)
    ced = create_model(config)
    with open(config["indexer"], "r", encoding="utf-8") as indexer_file:
        charlist = json.load(indexer_file)
    chardict = {c: i for i, c in enumerate(charlist)}
    serializers.load_npz(model_fn, ced)
    return ced, charlist, chardict
示例3: test_raise
# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import serializers [as 别名]
def test_raise(self):
    """Check that the HDF5 serializer entry points raise RuntimeError when
    ``chainer.serializers.hdf5`` reports itself unavailable.

    The serializer modules are dropped from ``sys.modules`` and re-imported
    so their module-level code runs again. The test then asserts that
    ``hdf5._available`` is False.

    NOTE(review): presumably setUp (not visible here) hides the HDF5 backend
    before this runs, and tearDown restores it. Confirm this, because the
    re-imported modules remain in ``sys.modules`` after the test.
    """
    # Drop the cached modules (submodules before the package) so the import
    # below re-executes them instead of returning the cached objects.
    del sys.modules['chainer.serializers.hdf5']
    del sys.modules['chainer.serializers.npz']
    del sys.modules['chainer.serializers']
    # This import binds ``chainer`` as a local name for the whole function.
    import chainer.serializers
    self.assertFalse(chainer.serializers.hdf5._available)
    # Each HDF5 entry point must raise RuntimeError; the None arguments are
    # never inspected because the unavailability check is expected first.
    with self.assertRaises(RuntimeError):
        chainer.serializers.save_hdf5(None, None, None)
    with self.assertRaises(RuntimeError):
        chainer.serializers.load_hdf5(None, None)
    with self.assertRaises(RuntimeError):
        chainer.serializers.HDF5Serializer(None)
    with self.assertRaises(RuntimeError):
        chainer.serializers.HDF5Deserializer(None)
示例4: _build_model
# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import serializers [as 别名]
def _build_model(self, config, src_vocab, trg_vocab):
    """Instantiate the model named in ``config['Model']`` and attach vocab info.

    Every key of ``config['Model']`` except ``name`` is passed to the model
    constructor as a keyword argument. Each value is first coerced to ``int``
    or ``float`` when its text parses as one. If ``<self.save_dir>/model.hdf``
    exists, its weights are loaded into the new model with
    ``chainer.serializers.load_hdf5``.

    Args:
        config: Mapping with ``'Model'`` and ``'Training'`` sections
            (presumably a ``configparser`` object; confirm against the caller).
        src_vocab: Source vocabulary; its ``stoi`` callable maps tokens to ids.
        trg_vocab: Target vocabulary; also attached as ``m.vocab`` when the
            ``byte`` training option is enabled.

    Returns:
        The configured model instance.
    """
    def convert(val):
        """Coerce a config value to int, else float, else return it unchanged.

        ``int`` is tried first so signed integers such as ``'-1'`` stay
        ``int``; the old ``str.isdigit`` gate rejected the sign and yielded a
        float. It also accepted digits like ``'²'`` that ``int()`` cannot
        parse. Non-text values are returned as-is rather than being truncated
        by ``int()``.
        """
        if not isinstance(val, (str, bytes)):
            return val
        for cast in (int, float):
            try:
                return cast(val)
            except ValueError:
                pass
        return val

    model_config = config['Model']
    kwargs = {k: convert(v) for k, v in model_config.items() if k != 'name'}
    m = getattr(models, model_config['name'])(**kwargs)

    # Resume from previously saved weights when a snapshot is present.
    model_path = os.path.join(self.save_dir, 'model.hdf')
    if os.path.exists(model_path):
        chainer.serializers.load_hdf5(model_path, m)

    xstoi = src_vocab.stoi
    ystoi = trg_vocab.stoi
    xbos = xstoi('<s>')
    xeos = xstoi('</s>')
    ybos = ystoi('<s>')
    yeos = ystoi('</s>')
    m.set_symbols(xbos, xeos, ybos, yeos)
    m.name = model_config['name']

    training_config = config['Training']
    m.byte = self._load_binary_config(training_config, 'byte')
    m.reverse_output = self._load_binary_config(
        training_config, 'reverse_output')
    if m.byte:
        m.vocab = trg_vocab
    return m
示例5: save
# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import serializers [as 别名]
def save(self):
    """Write the model weights and vocabularies into ``self.save_dir``.

    A copy of ``self.model`` gets the original's ``name``. The copy, not
    ``self.model``, is moved to CPU and written to ``model.hdf`` with
    ``chainer.serializers.save_hdf5``. The ``(self.src_vcb, self.trg_vcb)``
    pair is then pickled to ``vocab.pkl`` in the same directory.
    """
    target_dir = self.save_dir
    snapshot = self.model.copy()
    snapshot.name = self.model.name
    snapshot.to_cpu()
    chainer.serializers.save_hdf5(os.path.join(target_dir, 'model.hdf'),
                                  snapshot)
    vocab_path = os.path.join(target_dir, "vocab.pkl")
    with open(vocab_path, "wb") as out:
        pickle.dump((self.src_vcb, self.trg_vcb), out)