当前位置: 首页>>代码示例>>Python>>正文


Python serialization.load_lua方法代码示例

本文整理汇总了Python中torch.utils.serialization.load_lua方法的典型用法代码示例。如果您正苦于以下问题:Python serialization.load_lua方法的具体用法?Python serialization.load_lua怎么用?Python serialization.load_lua使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch.utils.serialization的用法示例。


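在下文中一共展示了serialization.load_lua方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。注意:torch.utils.serialization.load_lua 仅存在于 PyTorch 0.4.x 及更早版本,已在 PyTorch 1.0 中移除;在新版本中如需读取 Torch7 的 .t7 文件,可改用第三方 torchfile 包。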

示例1: _load_dataset

# 需要导入模块: from torch.utils import serialization [as 别名]
# 或者: from torch.utils.serialization import load_lua [as 别名]
def _load_dataset(self, img_root, caption_root, classes_filename, word_embedding):
    """Load caption records for every class listed in ``classes_filename``.

    Each non-empty line of ``<caption_root>/<classes_filename>`` names a
    sub-directory of ``caption_root``. Every file in that sub-directory is
    read with ``load_lua``. Its ``'char'`` entry is turned into word vectors
    by ``self._get_word_vectors``. Its ``'img'`` entry is joined onto
    ``img_root``.

    Returns:
        list of dicts with keys ``'img'`` (image path), ``'desc'`` and
        ``'len_desc'`` (as returned by ``self._get_word_vectors``).
    """
    output = []
    with open(os.path.join(caption_root, classes_filename)) as f:
        for line in f:
            cls = line.replace('\n', '')
            # Skip blank lines: an empty class name would make os.listdir
            # scan caption_root itself and try to load_lua every class dir.
            if not cls:
                continue
            cls_dir = os.path.join(caption_root, cls)
            # os.listdir order is arbitrary; sort so the dataset order is
            # reproducible across runs and machines.
            for filename in sorted(os.listdir(cls_dir)):
                datum = load_lua(os.path.join(cls_dir, filename))
                raw_desc = datum['char'].numpy()
                desc, len_desc = self._get_word_vectors(raw_desc, word_embedding)
                output.append({
                    'img': os.path.join(img_root, datum['img']),
                    'desc': desc,
                    'len_desc': len_desc,
                })
    return output
开发者ID:woozzu,项目名称:dong_iccv_2017,代码行数:19,代码来源:data.py

示例2: __init__

# 需要导入模块: from torch.utils import serialization [as 别名]
# 或者: from torch.utils.serialization import load_lua [as 别名]
def __init__(self, args):
    """Build the five encoder/decoder pairs from pretrained Torch7 files.

    The paths are read from ``args.vgg1`` / ``args.decoder1`` through
    ``args.vgg5`` / ``args.decoder5``. All ten files are loaded first, in
    the order vgg1, decoder1, vgg2, decoder2, ... Each pair is then wrapped
    as ``self.e<k>`` (``encoder<k>``) and ``self.d<k>`` (``decoder<k>``).
    """
    super(WCT, self).__init__()
    levels = (1, 2, 3, 4, 5)

    # Phase 1: load every pretrained network.
    loaded = [
        (load_lua(getattr(args, 'vgg%d' % k)),
         load_lua(getattr(args, 'decoder%d' % k)))
        for k in levels
    ]

    # Phase 2: wrap each pair, registering e1, d1, e2, d2, ... in order.
    wrappers = (
        (encoder1, decoder1),
        (encoder2, decoder2),
        (encoder3, decoder3),
        (encoder4, decoder4),
        (encoder5, decoder5),
    )
    for k, (enc_net, dec_net), (make_enc, make_dec) in zip(levels, loaded, wrappers):
        setattr(self, 'e%d' % k, make_enc(enc_net))
        setattr(self, 'd%d' % k, make_dec(dec_net))
开发者ID:sunshineatnoon,项目名称:PytorchWCT,代码行数:27,代码来源:util.py

示例3: __init__

# 需要导入模块: from torch.utils import serialization [as 别名]
# 或者: from torch.utils.serialization import load_lua [as 别名]
def __init__(self, state_dict='SP_GoogleNet_ImageNet.pt'):
    """Build SP_GoogLeNet from a Torch7-serialized checkpoint.

    Despite its name, ``state_dict`` is the path passed to ``load_lua``.
    The loaded object is indexed as follows:
      [0]             legacy feature model (switched to evaluation mode)
      [1][0], [1][1]  weight / bias of the 'adconv' convolution
      [2][0], [2][1]  weight / bias of the final linear classifier
    """
    super(SP_GoogLeNet, self).__init__()

    checkpoint = load_lua(state_dict)
    backbone = checkpoint[0]
    backbone.evaluate()
    self.features = LegacyModel(backbone)

    # Soft-proposal pooling head: adconv -> maps -> sp -> sum.
    head = nn.Sequential()
    head.add_module('adconv', nn.Conv2d(832, 1024, kernel_size=3, stride=1, padding=1, groups=2, bias=True))
    head.add_module('maps', nn.ReLU())
    head.add_module('sp', SoftProposal(factor=2.1))
    head.add_module('sum', SpatialSumOverMap())
    self.pooling = head

    adconv_params = checkpoint[1]
    head.adconv.weight.data.copy_(adconv_params[0])
    head.adconv.bias.data.copy_(adconv_params[1])

    # Classification layer, initialised from the checkpoint.
    self.classifier = nn.Linear(1024, 1000)
    fc_params = checkpoint[2]
    self.classifier.weight.data.copy_(fc_params[0])
    self.classifier.bias.data.copy_(fc_params[1])

    # Per-channel mean / std used for image normalization.
    self.image_normalization_mean = [0.485, 0.456, 0.406]
    self.image_normalization_std = [0.229, 0.224, 0.225]
开发者ID:yeezhu,项目名称:SPN.pytorch,代码行数:28,代码来源:SP_GoogLeNet.py

示例4: __init__

# 需要导入模块: from torch.utils import serialization [as 别名]
# 或者: from torch.utils.serialization import load_lua [as 别名]
def __init__(self, model_file_name, input_shape):
    """Load a Torch7 model file and build its graph for a batch of one.

    Raises:
        ValueError: if ``model_file_name`` does not exist.
    """
    super(TorchParser, self).__init__()
    # Reject a missing file up front with a descriptive message.
    if not os.path.exists(model_file_name):
        raise ValueError("Torch7 model file [{}] is not found.".format(model_file_name))

    loaded = load_lua(model_file_name)
    # Some files wrap the network in a 'hashable_uniq_dict' object whose
    # `.model` attribute holds the actual network.
    is_wrapper = type(loaded).__name__ == 'hashable_uniq_dict'
    model = loaded.model if is_wrapper else loaded
    model.evaluate()
    self.weight_loaded = True

    # Build the network graph for input shape [1, *input_shape].
    self.torch_graph = TorchGraph(model)
    batch_shape = [1]
    batch_shape.extend(int(dim) for dim in input_shape)
    self.torch_graph.build([batch_shape])
开发者ID:microsoft,项目名称:MMdnn,代码行数:15,代码来源:torch_parser.py

示例5: init_vgg16

# 需要导入模块: from torch.utils import serialization [as 别名]
# 或者: from torch.utils.serialization import load_lua [as 别名]
def init_vgg16(model_folder='model'):
    """Load the VGG-16 model features and cache them as a PyTorch weight file.

    Does nothing if ``<model_folder>/vgg16.weight`` already exists.
    Otherwise:
      1. Downloads ``<model_folder>/vgg16.t7`` unless it is already present.
      2. Loads it with ``load_lua``.
      3. Copies its parameters into a fresh ``net.Vgg16Part()``.
      4. Saves that module's state dict as ``<model_folder>/vgg16.weight``.

    Raises:
        subprocess.CalledProcessError: if ``wget`` exits with an error.
    """
    import subprocess

    weight_path = os.path.join(model_folder, 'vgg16.weight')
    if os.path.exists(weight_path):
        return
    t7_path = os.path.join(model_folder, 'vgg16.t7')
    if not os.path.exists(t7_path):
        # The destination must be given with -O: a bare second argument is
        # treated by wget as another URL, so the file would never reach
        # t7_path. An argument list also avoids shell quoting problems.
        subprocess.check_call(
            ['wget', 'http://bengxy.com/dataset/vgg16.t7', '-O', t7_path])
    vgglua = load_lua(t7_path)
    vgg = net.Vgg16Part()
    for src, dst in zip(vgglua.parameters()[0], vgg.parameters()):
        # Copy in place. Assigning `dst[:].data = ...` only replaces the data
        # of a temporary slice view and leaves `dst` at its initial values.
        dst.data.copy_(src)
    torch.save(vgg.state_dict(), weight_path)

# Gram Loss 
开发者ID:bengxy,项目名称:FastNeuralStyle,代码行数:17,代码来源:utils.py

示例6: convert

# 需要导入模块: from torch.utils import serialization [as 别名]
# 或者: from torch.utils.serialization import load_lua [as 别名]
def convert(src_model_path, dst_model_path, weights_indices):
    """Export selected layer parameters of a Torch7 model to an ``.npz`` file.

    For each index in ``weights_indices``, in order, the weight and then the
    bias of ``model.modules[index]`` are converted to NumPy arrays. All
    arrays are passed positionally to ``np.savez``.
    """
    model = load_lua(src_model_path)

    arrays = []
    for layer in (model.modules[i] for i in weights_indices):
        arrays += [layer.weight.numpy(), layer.bias.numpy()]

    np.savez(dst_model_path, *arrays)
开发者ID:elleryqueenhomels,项目名称:arbitrary_style_transfer,代码行数:14,代码来源:convertor.py

示例7: generateSampleFace

# 需要导入模块: from torch.utils import serialization [as 别名]
# 或者: from torch.utils.serialization import load_lua [as 别名]
def generateSampleFace(self, idx):
        """Build one (input crop, landmark heatmaps) sample for ``idx``.

        Landmarks are loaded from ``self.anno[idx]`` with ``load_lua``. The
        image path is ``self.anno[idx]`` with its last three characters
        replaced by '.jpg' (presumably the annotation ends in '.t7' -- confirm).
        The crop center ``c`` and scale ``s`` come from the landmarks'
        bounding box. When ``self.is_train`` is set, these are applied:
          - random scale jitter;
          - rotation, with probability 0.6;
          - a horizontal flip, with probability 0.5;
          - per-channel color jitter.

        Returns:
            inp: 256x256 crop normalized with ``self.mean`` / ``self.std``.
            out: ``(self.nParts, 64, 64)`` tensor of heatmaps (``draw_labelmap``).
            pts: landmark points (mirrored if the sample was flipped).
            c, s: crop center and scale that were used.
        """
        sf = self.scale_factor
        rf = self.rot_factor

        main_pts = load_lua(self.anno[idx])
        # NOTE(review): an older comment said "3D landmarks only", but the
        # .view(2) calls below require exactly 2 columns per point -- confirm.
        pts = main_pts
        mins_ = torch.min(pts, 0)[0].view(2)  # per-column min: (x_min, y_min)
        maxs_ = torch.max(pts, 0)[0].view(2)  # per-column max: (x_max, y_max)
        # Center of the landmark bounding box...
        c = torch.FloatTensor((maxs_[0] - (maxs_[0] - mins_[0]) / 2,
                               maxs_[1] - (maxs_[1] - mins_[1]) / 2))
        # ...with y shifted by -12% of the box height.
        c[1] -= ((maxs_[1] - mins_[1]) * 0.12)
        # Scale from box width + height; 195 is an empirical constant
        # carried over from the original code.
        s = (maxs_[0] - mins_[0] + maxs_[1] - mins_[1]) / 195

        img = load_image(self.anno[idx][:-3] + '.jpg')

        r = 0
        if self.is_train:
            # Scale factor ~ N(1, sf), clamped to [1 - sf, 1 + sf].
            s = s * torch.randn(1).mul_(sf).add_(1).clamp(1 - sf, 1 + sf)[0]
            # Rotation ~ N(0, rf) clamped to [-2rf, 2rf], applied 60% of the time.
            r = torch.randn(1).mul_(rf).clamp(-2 * rf, 2 * rf)[0] if random.random() <= 0.6 else 0

            # Horizontal flip with probability 0.5: mirror image, landmarks
            # (via shufflelr) and the x-coordinate of the center.
            if random.random() <= 0.5:
                img = torch.from_numpy(fliplr(img.numpy())).float()
                pts = shufflelr(pts, width=img.size(2), dataset='vw300')
                c[0] = img.size(2) - c[0]

            # Per-channel brightness jitter in [0.7, 1.3], clamped to [0, 1].
            img[0, :, :].mul_(random.uniform(0.7, 1.3)).clamp_(0, 1)
            img[1, :, :].mul_(random.uniform(0.7, 1.3)).clamp_(0, 1)
            img[2, :, :].mul_(random.uniform(0.7, 1.3)).clamp_(0, 1)

        inp = crop(img, c, s, [256, 256], rot=r)
        inp = color_normalize(inp, self.mean, self.std)

        # Heatmaps at 64x64 resolution, one per landmark.
        tpts = pts.clone()
        out = torch.zeros(self.nParts, 64, 64)
        for i in range(self.nParts):
            # Points with x <= 0 get no heatmap (presumably missing landmarks).
            if tpts[i, 0] > 0:
                # The +1 / -1 presumably converts between 0- and 1-based
                # coordinates expected by transform -- confirm.
                tpts[i, 0:2] = to_torch(transform(tpts[i, 0:2] + 1, c, s, [64, 64], rot=r))
                out[i] = draw_labelmap(out[i], tpts[i] - 1, sigma=1)

        return inp, out, pts, c, s
开发者ID:GuohongLi,项目名称:face-alignment-pytorch,代码行数:42,代码来源:VW300.py

示例8: generateSampleFace

# 需要导入模块: from torch.utils import serialization [as 别名]
# 或者: from torch.utils.serialization import load_lua [as 别名]
def generateSampleFace(self, idx):
        """Build one (input crop, landmark heatmaps) sample for ``idx``.

        Paths are derived from ``self.anno[idx]``; ``<prefix>`` is the part
        before the first '_':
          - landmarks: ``<img_folder>/landmarks/<prefix>/<anno[:-4]>.t7``,
            loaded with ``load_lua``;
          - image: ``<img_folder>/<prefix>/<anno[:-8]>.jpg``.
        ``main_pts[0]`` is used when ``self.pointType == '2D'``, otherwise
        ``main_pts[1]``. The crop center (225, 275) and scale 1.8 are fixed.
        When ``self.is_train`` is set, these are applied:
          - random scale jitter;
          - rotation, with probability 0.6;
          - a horizontal flip, with probability 0.5;
          - per-channel color jitter.

        Returns:
            inp: 256x256 crop normalized with ``self.mean`` / ``self.std``.
            out: ``(self.nParts, 64, 64)`` tensor of heatmaps (``draw_labelmap``).
            pts: landmark points (mirrored if the sample was flipped).
            c, s: crop center and scale that were used.
        """
        sf = self.scale_factor
        rf = self.rot_factor

        main_pts = load_lua(
            os.path.join(self.img_folder, 'landmarks', self.anno[idx].split('_')[0],
                         self.anno[idx][:-4] + '.t7'))
        pts = main_pts[0] if self.pointType == '2D' else main_pts[1]
        # Fixed crop center and scale; presumably the source images are
        # 450x450 with the face slightly below center -- confirm.
        c = torch.Tensor((450 / 2, 450 / 2 + 50))
        s = 1.8

        img = load_image(
            os.path.join(self.img_folder, self.anno[idx].split('_')[0], self.anno[idx][:-8] +
                         '.jpg'))

        r = 0
        if self.is_train:
            # Scale factor ~ N(1, sf), clamped to [1 - sf, 1 + sf].
            s = s * torch.randn(1).mul_(sf).add_(1).clamp(1 - sf, 1 + sf)[0]
            # Rotation ~ N(0, rf) clamped to [-2rf, 2rf], applied 60% of the time.
            r = torch.randn(1).mul_(rf).clamp(-2 * rf, 2 * rf)[0] if random.random() <= 0.6 else 0

            # Horizontal flip with probability 0.5: mirror image, landmarks
            # (via shufflelr) and the x-coordinate of the center.
            if random.random() <= 0.5:
                img = torch.from_numpy(fliplr(img.numpy())).float()
                pts = shufflelr(pts, width=img.size(2), dataset='w300lp')
                c[0] = img.size(2) - c[0]

            # Per-channel brightness jitter in [0.7, 1.3], clamped to [0, 1].
            img[0, :, :].mul_(random.uniform(0.7, 1.3)).clamp_(0, 1)
            img[1, :, :].mul_(random.uniform(0.7, 1.3)).clamp_(0, 1)
            img[2, :, :].mul_(random.uniform(0.7, 1.3)).clamp_(0, 1)

        inp = crop(img, c, s, [256, 256], rot=r)
        inp = color_normalize(inp, self.mean, self.std)

        # Heatmaps at 64x64 resolution, one per landmark.
        tpts = pts.clone()
        out = torch.zeros(self.nParts, 64, 64)
        for i in range(self.nParts):
            # Points with x <= 0 get no heatmap (presumably missing landmarks).
            if tpts[i, 0] > 0:
                # The +1 / -1 presumably converts between 0- and 1-based
                # coordinates expected by transform -- confirm.
                tpts[i, 0:2] = to_torch(transform(tpts[i, 0:2] + 1, c, s, [64, 64], rot=r))
                out[i] = draw_labelmap(out[i], tpts[i] - 1, sigma=1)

        return inp, out, pts, c, s
开发者ID:GuohongLi,项目名称:face-alignment-pytorch,代码行数:42,代码来源:W300.py

示例9: load_torch_model

# 需要导入模块: from torch.utils import serialization [as 别名]
# 或者: from torch.utils.serialization import load_lua [as 别名]
def load_torch_model(path):
    """Load a Torch7 model and patch modules that need special handling.

    ``load_lua`` is called with ``unknown_classes=True``. Then
    ``replace_module`` is applied twice:
      - every ``TorchObject`` whose ``torch_typename()`` is
        ``'nn.InstanceNormalization'`` is replaced via ``create_instance_norm``;
      - every ``SpatialFullConvolution`` is replaced via ``fix_full_conv``.
    Returns the patched model.
    """
    def is_instance_norm(m):
        return isinstance(m, TorchObject) and m.torch_typename() == 'nn.InstanceNormalization'

    def is_full_conv(m):
        return isinstance(m, SpatialFullConvolution)

    model = load_lua(path, unknown_classes=True)
    for predicate, replacement in ((is_instance_norm, create_instance_norm),
                                   (is_full_conv, fix_full_conv)):
        replace_module(model, predicate, replacement)
    return model
开发者ID:prisma-ai,项目名称:torch2coreml,代码行数:16,代码来源:convert-fast-neural-style.py

示例10: init_vgg16

# 需要导入模块: from torch.utils import serialization [as 别名]
# 或者: from torch.utils.serialization import load_lua [as 别名]
def init_vgg16(model_folder):
    """Create ``<model_folder>/vgg16.weight`` from the Torch7 VGG-16 model.

    Does nothing if the weight file already exists. Otherwise:
      1. Downloads ``vgg16.t7`` into ``model_folder`` unless it is present.
      2. Copies its parameters into a fresh ``Vgg16()``.
      3. Saves the resulting state dict.

    Raises:
        subprocess.CalledProcessError: if ``wget`` exits with an error.
    """
    import subprocess

    weight_path = os.path.join(model_folder, 'vgg16.weight')
    if os.path.exists(weight_path):
        return
    t7_path = os.path.join(model_folder, 'vgg16.t7')
    if not os.path.exists(t7_path):
        # Argument list, no shell: paths with spaces and the '?' in the URL
        # are passed verbatim, and a failed download raises instead of
        # surfacing later as a confusing load_lua error.
        subprocess.check_call(
            ['wget', 'https://www.dropbox.com/s/76l3rt4kyi3s8x7/vgg16.t7?dl=1', '-O', t7_path])
    vgglua = load_lua(t7_path)
    vgg = Vgg16()
    # parameters()[0] of the legacy model is assumed to list its parameter
    # tensors in the same order as vgg.parameters() -- confirm.
    for src, dst in zip(vgglua.parameters()[0], vgg.parameters()):
        dst.data[:] = src
    torch.save(vgg.state_dict(), weight_path)
开发者ID:abhiskk,项目名称:fast-neural-style,代码行数:12,代码来源:utils.py

示例11: torch_to_pytorch

# 需要导入模块: from torch.utils import serialization [as 别名]
# 或者: from torch.utils.serialization import load_lua [as 别名]
def torch_to_pytorch(t7_filename, outputname=None, save_output_to_file=True):
    """Convert a Torch7 ``.t7`` model into a PyTorch ``nn.Sequential``.

    Args:
        t7_filename: path of the Torch7 model file.
        outputname: output path prefix for ``<outputname>.py`` (generated
            source) and ``<outputname>.pth`` (state dict). Defaults to
            ``t7_filename`` with '.t7' removed and '.' / '-' replaced by '_'.
        save_output_to_file: when False, nothing is written to disk.

    Returns:
        The populated ``nn.Sequential``.
    """
    import os
    import re

    model = load_lua(t7_filename, unknown_classes=True)
    # Unwrap files saved as a 'hashable_uniq_dict' container.
    if type(model).__name__ == 'hashable_uniq_dict':
        model = model.model
    model.gradInput = None
    slist = lua_recursive_source(lnn.Sequential().add(model))
    s = simplify_source(slist)
    header = '''
import torch
import torch.nn as nn
import torch.legacy.nn as lnn

from functools import reduce
from torch.autograd import Variable

class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input

class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))

class LambdaMap(LambdaBase):
    def forward(self, input):
        return list(map(self.lambda_func,self.forward_prepare(input)))

class LambdaReduce(LambdaBase):
    def forward(self, input):
        return reduce(self.lambda_func,self.forward_prepare(input))
'''
    # Default output prefix (unchanged naming scheme, so files land where
    # they always did).
    default_name = t7_filename.replace('.t7', '').replace('.', '_').replace('-', '_')
    # The generated source assigns the network to `varname`, so it must be
    # a valid identifier. A directory part ('models/x') would make the
    # emitted file a SyntaxError, so use the base name only and replace any
    # other non-word characters.
    varname = re.sub(r'\W', '_', os.path.basename(t7_filename).replace('.t7', ''))
    if not varname or varname[0].isdigit():
        varname = '_' + varname
    # s[:-2] drops the last two characters of the simplified source
    # (presumably a trailing separator from simplify_source -- confirm).
    s = '{}\n\n{} = {}'.format(header, varname, s[:-2])
    if save_output_to_file:
        if outputname is None:
            outputname = default_name
        with open(outputname + '.py', "w") as pyfile:
            pyfile.write(s)

    n = nn.Sequential()
    lua_recursive_model(model, n)
    if save_output_to_file:
        torch.save(n.state_dict(), outputname + '.pth')
    return n
开发者ID:kipoi,项目名称:models,代码行数:51,代码来源:convert_Basset_to_pytorch.py

示例12: torch_to_pytorch

# 需要导入模块: from torch.utils import serialization [as 别名]
# 或者: from torch.utils.serialization import load_lua [as 别名]
def torch_to_pytorch(t7_filename, outputname=None):
    """Convert a Torch7 ``.t7`` model into PyTorch source code and weights.

    Writes ``<outputname>.py`` (generated module source) and
    ``<outputname>.pth`` (state dict of the converted ``nn.Sequential``).

    Args:
        t7_filename: path of the Torch7 model file.
        outputname: output path prefix; defaults to ``t7_filename`` with
            '.t7' removed and '.' / '-' replaced by '_'.
    """
    import os
    import re

    model = load_lua(t7_filename, unknown_classes=True)
    # Unwrap files saved as a 'hashable_uniq_dict' container.
    if type(model).__name__ == 'hashable_uniq_dict':
        model = model.model
    model.gradInput = None
    slist = lua_recursive_source(torch.legacy.nn.Sequential().add(model))
    s = simplify_source(slist)
    header = '''
import torch
import torch.nn as nn
from torch.autograd import Variable
from functools import reduce

class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input

class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))

class LambdaMap(LambdaBase):
    def forward(self, input):
        return list(map(self.lambda_func,self.forward_prepare(input)))

class LambdaReduce(LambdaBase):
    def forward(self, input):
        return reduce(self.lambda_func,self.forward_prepare(input))
'''
    # Default output prefix (unchanged naming scheme).
    default_name = t7_filename.replace('.t7', '').replace('.', '_').replace('-', '_')
    # The generated source assigns the network to `varname`, so it must be
    # a valid identifier: base name only, non-word characters replaced.
    varname = re.sub(r'\W', '_', os.path.basename(t7_filename).replace('.t7', ''))
    if not varname or varname[0].isdigit():
        varname = '_' + varname
    # s[:-2] drops the last two characters of the simplified source
    # (presumably a trailing separator from simplify_source -- confirm).
    s = '{}\n\n{} = {}'.format(header, varname, s[:-2])

    if outputname is None:
        outputname = default_name
    with open(outputname + '.py', "w") as pyfile:
        pyfile.write(s)

    n = nn.Sequential()
    lua_recursive_model(model, n)
    torch.save(n.state_dict(), outputname + '.pth')
开发者ID:rgeirhos,项目名称:Stylized-ImageNet,代码行数:48,代码来源:torch_to_pytorch.py

示例13: cvt2png

# 需要导入模块: from torch.utils import serialization [as 别名]
# 或者: from torch.utils.serialization import load_lua [as 别名]
def cvt2png(target_dir, patterns, pano_map_path):
    """Dump images stored in Torch7 files as individual image files.

    For every ``cat`` in ``cat_list`` and every pattern in ``patterns``:
      - the file ``ORGIN_DATA_DIR/<pat % cat>`` is loaded with ``load_lua``;
      - each image is written to ``<target_dir>/<cat>/<fname>``.
    The names ``fname`` come from a ground-truth list in ``ORGIN_GT_DIR``.
    For 'pano*' patterns the names are reordered using ``pano_map_path``.

    Raises:
        FileNotFoundError: if a source file or ground-truth list is missing.
        ValueError: if the pano map does not list every name exactly once,
            or the number of images differs from the number of names.
    """
    os.makedirs(target_dir, exist_ok=True)
    for cat in cat_list:
        for pat in patterns:
            # Define source file paths. Explicit raises (not assert) so the
            # checks still run under `python -O`.
            th_path = os.path.join(ORGIN_DATA_DIR, pat % cat)
            if not os.path.isfile(th_path):
                raise FileNotFoundError('%s not found !!!' % th_path)

            if pat.startswith('stanford'):
                gt_path = os.path.join(
                    ORGIN_GT_DIR, 'pano_id_%s.txt' % pat[-9:-3])
            else:
                gt_path = os.path.join(
                    ORGIN_GT_DIR, 'panoContext_%s.txt' % pat.split('_')[-1].split('.')[0])
            if not os.path.isfile(gt_path):
                raise FileNotFoundError('%s not found !!!' % gt_path)

            # Parse file names from gt list (one name per line).
            with open(gt_path) as f:
                fnames = [line.strip() for line in f]
            print('%-30s: %3d examples' % (pat % cat, len(fnames)))

            # Remap panoContext filenames. Each map line is
            # "<name> <index> <unused>": <name> must come from the gt list
            # and is placed at position <index>.
            if pat.startswith('pano'):
                fnames_cnt = dict.fromkeys(fnames, 0)
                with open(pano_map_path) as f:
                    for line in f:
                        v, k, _ = line.split()
                        k = int(k)
                        fnames[k] = v
                        fnames_cnt[v] += 1
                if any(cnt != 1 for cnt in fnames_cnt.values()):
                    raise ValueError(
                        'pano map %s must list every file name exactly once' % pano_map_path)

            # Parse th file
            imgs = load_lua(th_path).numpy()
            # zip() below would silently truncate on a mismatch.
            if imgs.shape[0] != len(fnames):
                raise ValueError('number of data and gt mismatched !!!')

            # Dump each image to the target directory.
            target_cat_dir = os.path.join(target_dir, cat)
            os.makedirs(target_cat_dir, exist_ok=True)
            for img, fname in zip(imgs, fnames):
                target_path = os.path.join(target_cat_dir, fname)
                # The * 255 scaling assumes pixel values in [0, 1] -- confirm.
                if img.shape[0] == 3:
                    # RGB: channels-first (3, H, W) -> (H, W, 3)
                    Image.fromarray(
                        (img.transpose([1, 2, 0]) * 255).astype(np.uint8)).save(target_path)
                else:
                    # Gray: only the first channel is written
                    Image.fromarray(
                        (img[0] * 255).astype(np.uint8)).save(target_path)
开发者ID:sunset1995,项目名称:pytorch-layoutnet,代码行数:52,代码来源:torch2pytorch_data.py

示例14: torch_to_pytorch

# 需要导入模块: from torch.utils import serialization [as 别名]
# 或者: from torch.utils.serialization import load_lua [as 别名]
def torch_to_pytorch(t7_filename, outputname=None):
    """Convert a Torch7 ``.t7`` model into PyTorch source code and weights.

    Writes ``<outputname>.py`` (generated module source) and
    ``<outputname>.pth`` (state dict of the converted ``nn.Sequential``).

    Args:
        t7_filename: path of the Torch7 model file.
        outputname: output path prefix; defaults to ``t7_filename`` with
            '.t7' removed and '.' / '-' replaced by '_'.
    """
    import os
    import re

    model = load_lua(t7_filename, unknown_classes=True)
    # Unwrap files saved as a 'hashable_uniq_dict' container.
    if type(model).__name__ == 'hashable_uniq_dict':
        model = model.model
    model.gradInput = None
    slist = lua_recursive_source(torch.legacy.nn.Sequential().add(model))
    s = simplify_source(slist)
    header = '''
import torch
import torch.nn as nn
from torch.autograd import Variable
from functools import reduce

class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input

class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))

class LambdaMap(LambdaBase):
    def forward(self, input):
        return list(map(self.lambda_func,self.forward_prepare(input)))

class LambdaReduce(LambdaBase):
    def forward(self, input):
        return reduce(self.lambda_func,self.forward_prepare(input))
'''
    # Default output prefix (unchanged naming scheme).
    default_name = t7_filename.replace('.t7', '').replace('.', '_').replace('-', '_')
    # The generated source assigns the network to `varname`, so it must be
    # a valid identifier: base name only, non-word characters replaced.
    varname = re.sub(r'\W', '_', os.path.basename(t7_filename).replace('.t7', ''))
    if not varname or varname[0].isdigit():
        varname = '_' + varname
    # s[:-2] drops the last two characters of the simplified source
    # (presumably a trailing separator from simplify_source -- confirm).
    s = '{}\n\n{} = {}'.format(header, varname, s[:-2])

    if outputname is None:
        outputname = default_name
    with open(outputname + '.py', "w") as pyfile:
        pyfile.write(s)

    n = nn.Sequential()
    lua_recursive_model(model, n)
    torch.save(n.state_dict(), outputname + '.pth')
开发者ID:fastai,项目名称:imagenet-fast,代码行数:47,代码来源:convert_torch.py

示例15: torch_to_pytorch

# 需要导入模块: from torch.utils import serialization [as 别名]
# 或者: from torch.utils.serialization import load_lua [as 别名]
def torch_to_pytorch(t7_filename, outputname=None):
    """Convert a Torch7 ``.t7`` model into PyTorch source code and weights.

    Writes ``<outputname>.py`` (generated module source) and
    ``<outputname>.pth`` (state dict of the converted ``nn.Sequential``).

    Args:
        t7_filename: path of the Torch7 model file.
        outputname: output path prefix; defaults to ``t7_filename`` with
            '.t7' removed and '.' / '-' replaced by '_'.
    """
    import os
    import re

    model = load_lua(t7_filename, unknown_classes=True)
    # Unwrap files saved as a 'hashable_uniq_dict' container.
    if type(model).__name__ == 'hashable_uniq_dict':
        model = model.model
    model.gradInput = None
    slist = lua_recursive_source(lnn.Sequential().add(model))
    s = simplify_source(slist)
    header = '''
import torch
import torch.nn as nn
import torch.legacy.nn as lnn

from functools import reduce
from torch.autograd import Variable

class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input

class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))

class LambdaMap(LambdaBase):
    def forward(self, input):
        return list(map(self.lambda_func,self.forward_prepare(input)))

class LambdaReduce(LambdaBase):
    def forward(self, input):
        return reduce(self.lambda_func,self.forward_prepare(input))
'''
    # Default output prefix (unchanged naming scheme).
    default_name = t7_filename.replace('.t7', '').replace('.', '_').replace('-', '_')
    # The generated source assigns the network to `varname`, so it must be
    # a valid identifier: base name only, non-word characters replaced.
    varname = re.sub(r'\W', '_', os.path.basename(t7_filename).replace('.t7', ''))
    if not varname or varname[0].isdigit():
        varname = '_' + varname
    # s[:-2] drops the last two characters of the simplified source
    # (presumably a trailing separator from simplify_source -- confirm).
    s = '{}\n\n{} = {}'.format(header, varname, s[:-2])

    if outputname is None:
        outputname = default_name
    with open(outputname + '.py', "w") as pyfile:
        pyfile.write(s)

    n = nn.Sequential()
    lua_recursive_model(model, n)
    torch.save(n.state_dict(), outputname + '.pth')
开发者ID:clcarwin,项目名称:convert_torch_to_pytorch,代码行数:49,代码来源:convert_torch.py


注:本文中的torch.utils.serialization.load_lua方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。