当前位置: 首页>>代码示例>>Python>>正文


Python data_gen.data_transforms方法代码示例

本文整理汇总了Python中data_gen.data_transforms方法的典型用法代码示例。如果您正苦于以下问题:Python data_gen.data_transforms方法的具体用法?Python data_gen.data_transforms怎么用?Python data_gen.data_transforms使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在data_gen的用法示例。


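在下文中一共展示了data_gen.data_transforms的4个代码示例,这些例子默认根据受欢迎程度排序。注意:从示例代码可以看出,data_transforms 并不是一个方法,而是按数据划分名称(如 'val'、'valid')索引的预处理变换映射,使用方式是先通过 data_transforms['val'] 取出变换,再将其作用于图像。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。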

示例1: gen_feature

# 需要导入模块: import data_gen [as 别名]
# 或者: from data_gen import data_transforms [as 别名]
def gen_feature(path, model):
    """Write one normalised feature file per image found under ``path``.

    Every ``.jpg``/``.png`` file yielded by ``walkdir`` is embedded twice by
    ``model``: once loaded with ``flip=False`` and once with ``flip=True``.
    The two embeddings are summed, divided by ``np.linalg.norm`` of the sum,
    and saved with ``write_feature`` to ``<image path>_0.bin``. Spaces in the
    image path are replaced by underscores in the output name.

    Args:
        path: Root directory handed to ``walkdir``.
        model: Embedding network; it is switched to eval mode here.
            Presumably returns one feature vector per input image -- confirm.
    """
    model.eval()

    print('gen features {}...'.format(path))
    # Collect every image up front so the total count is known for batching.
    image_paths = list(walkdir(path, ('.jpg', '.png')))
    total = len(image_paths)

    transform = data_transforms['val']
    chunk = 128

    def embed(batch_paths, flipped):
        # Fill a preallocated (N, 3, 112, 112) float tensor on `device`, run
        # the model once over the whole batch and return the output as numpy.
        batch = torch.zeros([len(batch_paths), 3, 112, 112], dtype=torch.float, device=device)
        for slot, image_path in enumerate(batch_paths):
            batch[slot] = get_image(transform, image_path, flip=flipped)
        return model(batch.to(device)).cpu().numpy()

    with torch.no_grad():
        for begin in tqdm(range(0, total, chunk)):
            batch_paths = image_paths[begin:begin + chunk]

            plain = embed(batch_paths, False)
            mirrored = embed(batch_paths, True)

            for slot, image_path in enumerate(batch_paths):
                out_path = image_path.replace(' ', '_') + '_0.bin'
                combined = plain[slot] + mirrored[slot]
                write_feature(out_path, combined / np.linalg.norm(combined))
开发者ID:foamliu,项目名称:InsightFace-PyTorch,代码行数:46,代码来源:megaface_utils.py

示例2: gen_feature

# 需要导入模块: import data_gen [as 别名]
# 或者: from data_gen import data_transforms [as 别名]
def _load_checkpoint_model(checkpoint_path):
    """Load the network stored as ``checkpoint['model'].module`` and move it to ``device``."""
    print('loading model: {}...'.format(checkpoint_path))
    # SECURITY: the checkpoint contains a whole pickled module, so loading it
    # can execute arbitrary code -- only load checkpoints from a trusted source.
    try:
        # map_location lets a GPU-saved checkpoint load on a CPU-only host.
        # PyTorch >= 2.6 defaults to weights_only=True, which refuses to
        # unpickle full nn.Module objects, so opt out explicitly.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    except TypeError:
        # PyTorch releases before 1.13 do not accept `weights_only`.
        checkpoint = torch.load(checkpoint_path, map_location=device)
    # The stored object exposes the network via `.module` (presumably an
    # nn.DataParallel wrapper -- confirm against the training script).
    return checkpoint['model'].module.to(device)


def gen_feature(path, model=None):
    """Write one normalised feature file per ``.jpg`` image under ``path``.

    Each image is embedded twice, loaded by ``get_image`` with ``flip=False``
    and with ``flip=True``. The two embeddings are summed, divided by
    ``np.linalg.norm`` of the sum, and written with ``write_feature`` to
    ``<image path>_0.bin``.

    Args:
        path: Root directory handed to ``walkdir``.
        model: Embedding network. When ``None`` it is loaded from
            ``BEST_checkpoint.tar`` in the working directory.
    """
    transformer = data_transforms['val']

    if model is None:
        model = _load_checkpoint_model('BEST_checkpoint.tar')

    model.eval()

    print('gen features {}...'.format(path))
    # Collect all files first so the total count is known for batching.
    files = list(walkdir(path, '.jpg'))
    file_count = len(files)

    batch_size = 128

    with torch.no_grad():
        for start_idx in tqdm(range(0, file_count, batch_size)):
            batch_files = files[start_idx:start_idx + batch_size]
            length = len(batch_files)

            # Unflipped pass.
            imgs_0 = torch.zeros([length, 3, 112, 112], dtype=torch.float)
            for idx, filepath in enumerate(batch_files):
                imgs_0[idx] = get_image(filepath, transformer, flip=False)
            features_0 = model(imgs_0.to(device)).cpu().numpy()

            # Flipped pass.
            imgs_1 = torch.zeros([length, 3, 112, 112], dtype=torch.float)
            for idx, filepath in enumerate(batch_files):
                imgs_1[idx] = get_image(filepath, transformer, flip=True)
            features_1 = model(imgs_1.to(device)).cpu().numpy()

            for idx, filepath in enumerate(batch_files):
                out_path = filepath + '_0.bin'
                feature = features_0[idx] + features_1[idx]
                write_feature(out_path, feature / np.linalg.norm(feature))
开发者ID:LcenArthas,项目名称:CCF-BDCI2019-Multi-person-Face-Recognition-Competition-Baseline,代码行数:49,代码来源:megaface_utils.py

示例3: test

# 需要导入模块: import data_gen [as 别名]
# 或者: from data_gen import data_transforms [as 别名]
def test(model):
    """Run ``model`` over the composited test set and return mean losses.

    Each name from ``gen_test_names()`` is parsed as ``"<fg>_<bg>.<ext>"``
    into indices into ``fg_test_files`` / ``bg_test_files``. The matching
    trimap is read from disk and the sample is composited via
    ``process_test``. A 4-channel input (3 image channels plus the trimap
    divided by 255) goes through ``model``. The prediction is forced to 0/1
    where the trimap is 0/255, written under ``images/test/out/``, and
    scored with ``compute_mse`` / ``compute_sad``.

    Returns:
        ``(mean SAD, mean MSE)`` over all test names.
    """
    model.eval()

    transformer = data_transforms['valid']

    names = gen_test_names()

    mse_losses = AverageMeter()
    sad_losses = AverageMeter()

    trimap_dir = 'data/Combined_Dataset/Test_set/Adobe-licensed images/trimaps/'
    for position, name in enumerate(tqdm(names)):
        parts = name.split('.')[0].split('_')
        fg_index = int(parts[0])
        bg_index = int(parts[1])
        im_name = fg_test_files[fg_index]
        bg_name = bg_test_files[bg_index]
        # Trimap variants are numbered 0..19 and cycle with the position in
        # `names` (presumably each foreground appears 20 times in a row).
        trimap_name = '{}_{}.png'.format(im_name.split('.')[0], position % 20)

        trimap = cv.imread(trimap_dir + trimap_name, 0)

        img, alpha, fg, bg, new_trimap = process_test(im_name, bg_name, trimap, trimap_name)
        h, w = img.shape[:2]

        # Reverse the channel order (the original code marks this as "RGB"),
        # convert to PIL and apply the validation transform.
        rgb = transformer(transforms.ToPILImage()(img[..., ::-1]))
        x = torch.zeros((1, 4, h, w), dtype=torch.float)
        x[:, :3] = rgb
        x[:, 3] = torch.from_numpy(new_trimap.copy() / 255.)

        # Move to GPU, if available.
        x = x.type(torch.FloatTensor).to(device)
        alpha = alpha / 255.

        with torch.no_grad():
            pred = model(x)

        # The model output is flattened back to an (h, w) alpha map.
        pred = pred.cpu().numpy().reshape((h, w))

        # Known background / foreground regions of the trimap override the prediction.
        pred[new_trimap == 0] = 0.0
        pred[new_trimap == 255] = 1.0
        cv.imwrite('images/test/out/' + trimap_name, pred * 255)

        mse_loss = compute_mse(pred, alpha, trimap)
        sad_loss = compute_sad(pred, alpha)

        mse_losses.update(mse_loss.item())
        sad_losses.update(sad_loss.item())

    return sad_losses.avg, mse_losses.avg
开发者ID:foamliu,项目名称:Mobile-Image-Matting,代码行数:59,代码来源:test.py

示例4: evaluate

# 需要导入模块: import data_gen [as 别名]
# 或者: from data_gen import data_transforms [as 别名]
def evaluate(model):
    """Compute the angle between embeddings of LFW test pairs and save them.

    ``samples`` is loaded from the ``lfw_pickle`` file. Image pairs are read
    from ``data/lfw_test_pair.txt``, one whitespace-separated
    ``file0 file1 is_same`` entry per line; blank lines are skipped. Both
    images of each pair are embedded in one batch, and the angle in degrees
    between the normalised embeddings is written as ``"<angle> <is_same>"``,
    one line per pair, to ``data/angles.txt``. The average inference time
    per image is printed.

    Args:
        model: Embedding network; switched to eval mode here.
    """
    model.eval()

    # NOTE(review): pickle.load can execute arbitrary code; lfw_pickle must
    # come from a trusted source.
    with open(lfw_pickle, 'rb') as file:
        data = pickle.load(file)

    samples = data['samples']

    filename = 'data/lfw_test_pair.txt'
    with open(filename, 'r') as file:
        # Skip blank lines (e.g. a trailing newline) instead of crashing on tokens[0].
        pairs = [line.split() for line in file if line.strip()]

    transformer = data_transforms['val']

    angles = []

    start = time.time()
    with torch.no_grad():
        for tokens in tqdm(pairs):
            file0 = tokens[0]
            img0 = get_image(samples, transformer, file0)
            file1 = tokens[1]
            img1 = get_image(samples, transformer, file1)
            imgs = torch.zeros([2, 3, 112, 112], dtype=torch.float, device=device)
            imgs[0] = img0
            imgs[1] = img1

            output = model(imgs)

            feature0 = output[0].cpu().numpy()
            feature1 = output[1].cpu().numpy()
            x0 = feature0 / np.linalg.norm(feature0)
            x1 = feature1 / np.linalg.norm(feature1)
            # Clip so rounding error cannot push the cosine outside acos's domain.
            cosine = np.clip(np.dot(x0, x1), -1.0, 1.0)
            theta = math.acos(cosine)
            theta = theta * 180 / math.pi
            is_same = tokens[2]
            angles.append('{} {}\n'.format(theta, is_same))

    elapsed_time = time.time() - start
    # Previously divided by a hard-coded 6000 pairs; use the real image count.
    num_images = 2 * len(pairs)
    per_image = elapsed_time / num_images if num_images else 0.0
    print('elapsed time(sec) per image: {}'.format(per_image))

    with open('data/angles.txt', 'w') as file:
        file.writelines(angles)
开发者ID:foamliu,项目名称:InsightFace-v2,代码行数:48,代码来源:lfw_eval.py


注:本文中的data_gen.data_transforms方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。