This article collects typical usage examples of the Python function torch.utils.serialization.load_lua. If you are unsure how serialization.load_lua works or how to call it, the curated code samples below may help. You can also explore further usage examples from torch.utils.serialization, the module in which this function is defined.
The following shows 15 code examples of serialization.load_lua, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code examples.
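For context, load_lua reads a serialized Torch7 (Lua) .t7 file and returns the deserialized object: a tensor, a table-like dict, or a legacy nn module. Below is a minimal sketch of the basic call pattern; the file path is a placeholder, and the API assumes an older PyTorch release, since torch.utils.serialization was removed in PyTorch 1.0:

from torch.utils.serialization import load_lua

# Load a Torch7 .t7 checkpoint into Python objects (tensors, tables, legacy nn modules).
checkpoint = load_lua('model.t7')

# Serialized Lua classes that have no Python counterpart can be kept as generic
# TorchObject wrappers instead of raising an error:
checkpoint = load_lua('model.t7', unknown_classes=True)

The examples that follow use both forms: plain loading of data files, and unknown_classes=True when converting whole legacy models.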
Example 1: _load_dataset
# Required import: from torch.utils import serialization [as alias]
# Or: from torch.utils.serialization import load_lua [as alias]
def _load_dataset(self, img_root, caption_root, classes_filename, word_embedding):
    output = []
    with open(os.path.join(caption_root, classes_filename)) as f:
        lines = f.readlines()
        for line in lines:
            cls = line.replace('\n', '')
            filenames = os.listdir(os.path.join(caption_root, cls))
            for filename in filenames:
                datum = load_lua(os.path.join(caption_root, cls, filename))
                raw_desc = datum['char'].numpy()
                desc, len_desc = self._get_word_vectors(raw_desc, word_embedding)
                output.append({
                    'img': os.path.join(img_root, datum['img']),
                    'desc': desc,
                    'len_desc': len_desc
                })
    return output
Example 2: __init__
# Required import: from torch.utils import serialization [as alias]
# Or: from torch.utils.serialization import load_lua [as alias]
def __init__(self, args):
    super(WCT, self).__init__()
    # load pre-trained network
    vgg1 = load_lua(args.vgg1)
    decoder1_torch = load_lua(args.decoder1)
    vgg2 = load_lua(args.vgg2)
    decoder2_torch = load_lua(args.decoder2)
    vgg3 = load_lua(args.vgg3)
    decoder3_torch = load_lua(args.decoder3)
    vgg4 = load_lua(args.vgg4)
    decoder4_torch = load_lua(args.decoder4)
    vgg5 = load_lua(args.vgg5)
    decoder5_torch = load_lua(args.decoder5)
    self.e1 = encoder1(vgg1)
    self.d1 = decoder1(decoder1_torch)
    self.e2 = encoder2(vgg2)
    self.d2 = decoder2(decoder2_torch)
    self.e3 = encoder3(vgg3)
    self.d3 = decoder3(decoder3_torch)
    self.e4 = encoder4(vgg4)
    self.d4 = decoder4(decoder4_torch)
    self.e5 = encoder5(vgg5)
    self.d5 = decoder5(decoder5_torch)
Example 3: __init__
# Required import: from torch.utils import serialization [as alias]
# Or: from torch.utils.serialization import load_lua [as alias]
def __init__(self, state_dict='SP_GoogleNet_ImageNet.pt'):
    super(SP_GoogLeNet, self).__init__()
    state_dict = load_lua(state_dict)
    pretrained_model = state_dict[0]
    pretrained_model.evaluate()
    self.features = LegacyModel(pretrained_model)
    self.pooling = nn.Sequential()
    self.pooling.add_module('adconv', nn.Conv2d(832, 1024, kernel_size=3, stride=1, padding=1, groups=2, bias=True))
    self.pooling.add_module('maps', nn.ReLU())
    self.pooling.add_module('sp', SoftProposal(factor=2.1))
    self.pooling.add_module('sum', SpatialSumOverMap())
    self.pooling.adconv.weight.data.copy_(state_dict[1][0])
    self.pooling.adconv.bias.data.copy_(state_dict[1][1])
    # classification layer
    self.classifier = nn.Linear(1024, 1000)
    self.classifier.weight.data.copy_(state_dict[2][0])
    self.classifier.bias.data.copy_(state_dict[2][1])
    # image normalization
    self.image_normalization_mean = [0.485, 0.456, 0.406]
    self.image_normalization_std = [0.229, 0.224, 0.225]
Example 4: __init__
# Required import: from torch.utils import serialization [as alias]
# Or: from torch.utils.serialization import load_lua [as alias]
def __init__(self, model_file_name, input_shape):
    super(TorchParser, self).__init__()
    if not os.path.exists(model_file_name):
        raise ValueError("Torch7 model file [{}] is not found.".format(model_file_name))
    model = load_lua(model_file_name)
    if type(model).__name__ == 'hashable_uniq_dict':
        model = model.model
    model.evaluate()
    self.weight_loaded = True
    # Build network graph
    self.torch_graph = TorchGraph(model)
    self.torch_graph.build([[1] + list(map(int, input_shape))])
Example 5: init_vgg16
# Required import: from torch.utils import serialization [as alias]
# Or: from torch.utils.serialization import load_lua [as alias]
def init_vgg16(model_folder='model'):
    """load the vgg16 model feature"""
    if not os.path.exists(model_folder + '/vgg16.weight'):
        if not os.path.exists(model_folder + '/vgg16.t7'):
            os.system('wget http://bengxy.com/dataset/vgg16.t7 -O ' + model_folder + '/vgg16.t7')
        vgglua = load_lua(model_folder + '/vgg16.t7')
        vgg = net.Vgg16Part()
        for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()):
            dst[:].data = src[:]
            # here comes a bug in pytorch version 0.1.10
            # change to dst[:].data = src[:]
            # ref to issue:
        torch.save(vgg.state_dict(), model_folder + '/vgg16.weight')

# Gram Loss
Example 6: convert
# Required import: from torch.utils import serialization [as alias]
# Or: from torch.utils.serialization import load_lua [as alias]
def convert(src_model_path, dst_model_path, weights_indices):
    model = load_lua(src_model_path)
    weights = []
    for idx in weights_indices:
        kernel = model.modules[idx].weight.numpy()
        bias = model.modules[idx].bias.numpy()
        weights.append(kernel)
        weights.append(bias)
    np.savez(dst_model_path, *weights)
Example 7: generateSampleFace
# Required import: from torch.utils import serialization [as alias]
# Or: from torch.utils.serialization import load_lua [as alias]
def generateSampleFace(self, idx):
    sf = self.scale_factor
    rf = self.rot_factor
    main_pts = load_lua(self.anno[idx])
    pts = main_pts  # 3D landmarks only. # if self.pointType == '2D' else main_pts[1]
    mins_ = torch.min(pts, 0)[0].view(2)  # min vals
    maxs_ = torch.max(pts, 0)[0].view(2)  # max vals
    c = torch.FloatTensor((maxs_[0] - (maxs_[0] - mins_[0]) / 2,
                           maxs_[1] - (maxs_[1] - mins_[1]) / 2))
    c[1] -= ((maxs_[1] - mins_[1]) * 0.12)
    s = (maxs_[0] - mins_[0] + maxs_[1] - mins_[1]) / 195
    img = load_image(self.anno[idx][:-3] + '.jpg')
    r = 0
    if self.is_train:
        s = s * torch.randn(1).mul_(sf).add_(1).clamp(1 - sf, 1 + sf)[0]
        r = torch.randn(1).mul_(rf).clamp(-2 * rf, 2 * rf)[0] if random.random() <= 0.6 else 0
        if random.random() <= 0.5:
            img = torch.from_numpy(fliplr(img.numpy())).float()
            pts = shufflelr(pts, width=img.size(2), dataset='vw300')
            c[0] = img.size(2) - c[0]
        img[0, :, :].mul_(random.uniform(0.7, 1.3)).clamp_(0, 1)
        img[1, :, :].mul_(random.uniform(0.7, 1.3)).clamp_(0, 1)
        img[2, :, :].mul_(random.uniform(0.7, 1.3)).clamp_(0, 1)
    inp = crop(img, c, s, [256, 256], rot=r)
    inp = color_normalize(inp, self.mean, self.std)
    tpts = pts.clone()
    out = torch.zeros(self.nParts, 64, 64)
    for i in range(self.nParts):
        if tpts[i, 0] > 0:
            tpts[i, 0:2] = to_torch(transform(tpts[i, 0:2] + 1, c, s, [64, 64], rot=r))
            out[i] = draw_labelmap(out[i], tpts[i] - 1, sigma=1)
    return inp, out, pts, c, s
Example 8: generateSampleFace
# Required import: from torch.utils import serialization [as alias]
# Or: from torch.utils.serialization import load_lua [as alias]
def generateSampleFace(self, idx):
    sf = self.scale_factor
    rf = self.rot_factor
    main_pts = load_lua(
        os.path.join(self.img_folder, 'landmarks', self.anno[idx].split('_')[0],
                     self.anno[idx][:-4] + '.t7'))
    pts = main_pts[0] if self.pointType == '2D' else main_pts[1]
    c = torch.Tensor((450 / 2, 450 / 2 + 50))
    s = 1.8
    img = load_image(
        os.path.join(self.img_folder, self.anno[idx].split('_')[0], self.anno[idx][:-8] +
                     '.jpg'))
    r = 0
    if self.is_train:
        s = s * torch.randn(1).mul_(sf).add_(1).clamp(1 - sf, 1 + sf)[0]
        r = torch.randn(1).mul_(rf).clamp(-2 * rf, 2 * rf)[0] if random.random() <= 0.6 else 0
        if random.random() <= 0.5:
            img = torch.from_numpy(fliplr(img.numpy())).float()
            pts = shufflelr(pts, width=img.size(2), dataset='w300lp')
            c[0] = img.size(2) - c[0]
        img[0, :, :].mul_(random.uniform(0.7, 1.3)).clamp_(0, 1)
        img[1, :, :].mul_(random.uniform(0.7, 1.3)).clamp_(0, 1)
        img[2, :, :].mul_(random.uniform(0.7, 1.3)).clamp_(0, 1)
    inp = crop(img, c, s, [256, 256], rot=r)
    inp = color_normalize(inp, self.mean, self.std)
    tpts = pts.clone()
    out = torch.zeros(self.nParts, 64, 64)
    for i in range(self.nParts):
        if tpts[i, 0] > 0:
            tpts[i, 0:2] = to_torch(transform(tpts[i, 0:2] + 1, c, s, [64, 64], rot=r))
            out[i] = draw_labelmap(out[i], tpts[i] - 1, sigma=1)
    return inp, out, pts, c, s
Example 9: load_torch_model
# Required import: from torch.utils import serialization [as alias]
# Or: from torch.utils.serialization import load_lua [as alias]
def load_torch_model(path):
    model = load_lua(path, unknown_classes=True)
    replace_module(
        model,
        lambda m: isinstance(m, TorchObject) and
                  m.torch_typename() == 'nn.InstanceNormalization',
        create_instance_norm
    )
    replace_module(
        model,
        lambda m: isinstance(m, SpatialFullConvolution),
        fix_full_conv
    )
    return model
Example 10: init_vgg16
# Required import: from torch.utils import serialization [as alias]
# Or: from torch.utils.serialization import load_lua [as alias]
def init_vgg16(model_folder):
    if not os.path.exists(os.path.join(model_folder, 'vgg16.weight')):
        if not os.path.exists(os.path.join(model_folder, 'vgg16.t7')):
            os.system(
                'wget https://www.dropbox.com/s/76l3rt4kyi3s8x7/vgg16.t7?dl=1 -O ' + os.path.join(model_folder, 'vgg16.t7'))
        vgglua = load_lua(os.path.join(model_folder, 'vgg16.t7'))
        vgg = Vgg16()
        for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()):
            dst.data[:] = src
        torch.save(vgg.state_dict(), os.path.join(model_folder, 'vgg16.weight'))
Example 11: torch_to_pytorch
# Required import: from torch.utils import serialization [as alias]
# Or: from torch.utils.serialization import load_lua [as alias]
def torch_to_pytorch(t7_filename, outputname=None, save_output_to_file=True):
    model = load_lua(t7_filename, unknown_classes=True)
    if type(model).__name__ == 'hashable_uniq_dict':
        model = model.model
    model.gradInput = None
    slist = lua_recursive_source(lnn.Sequential().add(model))
    s = simplify_source(slist)
    header = '''
import torch
import torch.nn as nn
import torch.legacy.nn as lnn

from functools import reduce
from torch.autograd import Variable

class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input

class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))

class LambdaMap(LambdaBase):
    def forward(self, input):
        return list(map(self.lambda_func, self.forward_prepare(input)))

class LambdaReduce(LambdaBase):
    def forward(self, input):
        return reduce(self.lambda_func, self.forward_prepare(input))
'''
    varname = t7_filename.replace('.t7', '').replace('.', '_').replace('-', '_')
    s = '{}\n\n{} = {}'.format(header, varname, s[:-2])
    if save_output_to_file:
        if outputname is None:
            outputname = varname
        with open(outputname + '.py', "w") as pyfile:
            pyfile.write(s)
    n = nn.Sequential()
    lua_recursive_model(model, n)
    if save_output_to_file:
        torch.save(n.state_dict(), outputname + '.pth')
    return n
Example 12: torch_to_pytorch
# Required import: from torch.utils import serialization [as alias]
# Or: from torch.utils.serialization import load_lua [as alias]
def torch_to_pytorch(t7_filename, outputname=None):
    model = load_lua(t7_filename, unknown_classes=True)
    if type(model).__name__ == 'hashable_uniq_dict':
        model = model.model
    model.gradInput = None
    slist = lua_recursive_source(torch.legacy.nn.Sequential().add(model))
    s = simplify_source(slist)
    header = '''
import torch
import torch.nn as nn
from torch.autograd import Variable
from functools import reduce

class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input

class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))

class LambdaMap(LambdaBase):
    def forward(self, input):
        return list(map(self.lambda_func, self.forward_prepare(input)))

class LambdaReduce(LambdaBase):
    def forward(self, input):
        return reduce(self.lambda_func, self.forward_prepare(input))
'''
    varname = t7_filename.replace('.t7', '').replace('.', '_').replace('-', '_')
    s = '{}\n\n{} = {}'.format(header, varname, s[:-2])
    if outputname is None:
        outputname = varname
    with open(outputname + '.py', "w") as pyfile:
        pyfile.write(s)
    n = nn.Sequential()
    lua_recursive_model(model, n)
    torch.save(n.state_dict(), outputname + '.pth')
Example 13: cvt2png
# Required import: from torch.utils import serialization [as alias]
# Or: from torch.utils.serialization import load_lua [as alias]
def cvt2png(target_dir, patterns, pano_map_path):
    os.makedirs(target_dir, exist_ok=True)
    for cat in cat_list:
        for pat in patterns:
            # Define source file paths
            th_path = os.path.join(ORGIN_DATA_DIR, pat % cat)
            assert os.path.isfile(th_path), '%s not found !!!' % th_path
            if pat.startswith('stanford'):
                gt_path = os.path.join(
                    ORGIN_GT_DIR, 'pano_id_%s.txt' % pat[-9:-3])
            else:
                gt_path = os.path.join(
                    ORGIN_GT_DIR, 'panoContext_%s.txt' % pat.split('_')[-1].split('.')[0])
            assert os.path.isfile(gt_path), '%s not found !!!' % gt_path

            # Parse file names from gt list
            with open(gt_path) as f:
                fnames = [line.strip() for line in f]
            print('%-30s: %3d examples' % (pat % cat, len(fnames)))

            # Remapping panoContext filenames
            if pat.startswith('pano'):
                fnames_cnt = dict([(v, 0) for v in fnames])
                with open(pano_map_path) as f:
                    for line in f:
                        v, k, _ = line.split()
                        k = int(k)
                        fnames[k] = v
                        fnames_cnt[v] += 1
                for v in fnames_cnt.values():
                    assert v == 1

            # Parse th file
            imgs = load_lua(th_path).numpy()
            assert imgs.shape[0] == len(fnames), 'number of data and gt mismatched !!!'

            # Dump each image to the target directory
            target_cat_dir = os.path.join(target_dir, cat)
            os.makedirs(target_cat_dir, exist_ok=True)
            for img, fname in zip(imgs, fnames):
                target_path = os.path.join(target_cat_dir, fname)
                if img.shape[0] == 3:
                    # RGB
                    Image.fromarray(
                        (img.transpose([1, 2, 0]) * 255).astype(np.uint8)).save(target_path)
                else:
                    # Gray
                    Image.fromarray(
                        (img[0] * 255).astype(np.uint8)).save(target_path)
Example 14: torch_to_pytorch
# Required import: from torch.utils import serialization [as alias]
# Or: from torch.utils.serialization import load_lua [as alias]
def torch_to_pytorch(t7_filename, outputname=None):
    model = load_lua(t7_filename, unknown_classes=True)
    if type(model).__name__ == 'hashable_uniq_dict':
        model = model.model
    model.gradInput = None
    slist = lua_recursive_source(torch.legacy.nn.Sequential().add(model))
    s = simplify_source(slist)
    header = '''
import torch
import torch.nn as nn
from torch.autograd import Variable
from functools import reduce

class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input

class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))

class LambdaMap(LambdaBase):
    def forward(self, input):
        return list(map(self.lambda_func, self.forward_prepare(input)))

class LambdaReduce(LambdaBase):
    def forward(self, input):
        return reduce(self.lambda_func, self.forward_prepare(input))
'''
    varname = t7_filename.replace('.t7', '').replace('.', '_').replace('-', '_')
    s = '{}\n\n{} = {}'.format(header, varname, s[:-2])
    if outputname is None:
        outputname = varname
    with open(outputname + '.py', "w") as pyfile:
        pyfile.write(s)
    n = nn.Sequential()
    lua_recursive_model(model, n)
    torch.save(n.state_dict(), outputname + '.pth')
Example 15: torch_to_pytorch
# Required import: from torch.utils import serialization [as alias]
# Or: from torch.utils.serialization import load_lua [as alias]
def torch_to_pytorch(t7_filename, outputname=None):
    model = load_lua(t7_filename, unknown_classes=True)
    if type(model).__name__ == 'hashable_uniq_dict':
        model = model.model
    model.gradInput = None
    slist = lua_recursive_source(lnn.Sequential().add(model))
    s = simplify_source(slist)
    header = '''
import torch
import torch.nn as nn
import torch.legacy.nn as lnn

from functools import reduce
from torch.autograd import Variable

class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input

class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))

class LambdaMap(LambdaBase):
    def forward(self, input):
        return list(map(self.lambda_func, self.forward_prepare(input)))

class LambdaReduce(LambdaBase):
    def forward(self, input):
        return reduce(self.lambda_func, self.forward_prepare(input))
'''
    varname = t7_filename.replace('.t7', '').replace('.', '_').replace('-', '_')
    s = '{}\n\n{} = {}'.format(header, varname, s[:-2])
    if outputname is None:
        outputname = varname
    with open(outputname + '.py', "w") as pyfile:
        pyfile.write(s)
    n = nn.Sequential()
    lua_recursive_model(model, n)
    torch.save(n.state_dict(), outputname + '.pth')