当前位置: 首页>>代码示例>>Python>>正文


Python data_loader.get_loader方法代码示例

本文整理汇总了Python中data_loader.get_loader方法的典型用法代码示例。如果您正苦于以下问题:Python data_loader.get_loader方法的具体用法?Python data_loader.get_loader怎么用?Python data_loader.get_loader使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在data_loader的用法示例。


在下文中一共展示了data_loader.get_loader方法的12个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: import data_loader [as 别名]
# 或者: from data_loader import get_loader [as 别名]
def main(config):
    """Build the SVHN/MNIST loaders and run the solver selected by ``config.mode``.

    ``'train'`` trains the solver and ``'sample'`` draws samples. Any other mode
    only creates the output directories.
    """
    svhn_loader, mnist_loader = get_loader(config)

    solver = Solver(config, svhn_loader, mnist_loader)
    cudnn.benchmark = True

    # Ensure the checkpoint and sample folders exist before running.
    for directory in (config.model_path, config.sample_path):
        if not os.path.exists(directory):
            os.makedirs(directory)

    mode = config.mode
    if mode == 'train':
        solver.train()
    elif mode == 'sample':
        solver.sample()
开发者ID:yunjey,项目名称:mnist-svhn-transfer,代码行数:18,代码来源:main.py

示例2: train

# 需要导入模块: import data_loader [as 别名]
# 或者: from data_loader import get_loader [as 别名]
def train(model):
    """Train ``model`` with Adam (lr=1e-3) and evaluate it after every epoch.

    Relies on module-level names not defined in this block: ``args`` (reads
    ``batch_size`` and ``epochs``), ``data_loader``, ``utils`` and ``evaluate``.

    Args:
        model: object exposing ``parameters()``, ``train()`` and
            ``run_on_batch(data, optimizer, epoch)``, which returns a dict with a
            ``'loss'`` entry.
    """
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    data_iter = data_loader.get_loader(batch_size=args.batch_size)
    # Loop-invariant: the loader's length does not change between batches.
    num_batches = len(data_iter)

    for epoch in range(args.epochs):
        model.train()

        run_loss = 0.0

        for idx, data in enumerate(data_iter):
            data = utils.to_var(data)
            ret = model.run_on_batch(data, optimizer, epoch)

            run_loss += ret['loss'].item()

            # '\r' plus end='' redraws a single progress line in place (the
            # Python 2 trailing-comma print did the same). flush=True is needed
            # because no newline is written to trigger a line-buffered flush.
            print('\r Progress epoch {}, {:.2f}%, average loss {}'.format(
                epoch, (idx + 1) * 100.0 / num_batches, run_loss / (idx + 1.0)),
                end='', flush=True)

        evaluate(model, data_iter)
开发者ID:caow13,项目名称:BRITS,代码行数:21,代码来源:main.py

示例3: main

# 需要导入模块: import data_loader [as 别名]
# 或者: from data_loader import get_loader [as 别名]
def main():
    """Build the test-partition loader and run ``test`` on it.

    Uses module-level names not defined in this block: ``opts`` (reads
    ``img_path``, ``data_path``, ``batch_size``, ``workers``), ``vocab``,
    ``transforms``, ``get_loader`` and ``test``.
    """
    # NOTE(review): this pipeline only resizes and center-crops. It has no
    # ``ToTensor``/``Normalize``, and an ImageNet ``Normalize`` that used to be
    # constructed here was never applied. Presumably the dataset converts and
    # normalizes images itself. TODO confirm against ``data_loader``.
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ])

    val_loader = get_loader(opts.img_path, val_transform, vocab, opts.data_path,
                            partition='test',
                            batch_size=opts.batch_size, shuffle=False,
                            num_workers=opts.workers, pin_memory=True)
    print('Validation loader prepared.')

    test(val_loader)
开发者ID:hwang1996,项目名称:ACME,代码行数:16,代码来源:test.py

示例4: main

# 需要导入模块: import data_loader [as 别名]
# 或者: from data_loader import get_loader [as 别名]
def main(config):
    """Create output folders, build the selected loaders and run StarGAN.

    ``config.dataset`` is one of 'CelebA', 'RaFD' or 'Both'. The loader of a
    dataset that is not selected is passed to the solver as ``None``.
    ``config.mode`` 'train' or 'test' picks the single- or multi-dataset
    routine. Any other mode does nothing after the solver is built.
    """
    # Enable cuDNN autotuning for faster training.
    cudnn.benchmark = True

    for folder in (config.log_dir, config.model_save_dir,
                   config.sample_dir, config.result_dir):
        if not os.path.exists(folder):
            os.makedirs(folder)

    wants_celeba = config.dataset in ['CelebA', 'Both']
    wants_rafd = config.dataset in ['RaFD', 'Both']

    celeba_loader = (
        get_loader(config.celeba_image_dir, config.attr_path, config.selected_attrs,
                   config.celeba_crop_size, config.image_size, config.batch_size,
                   'CelebA', config.mode, config.num_workers)
        if wants_celeba else None)
    rafd_loader = (
        get_loader(config.rafd_image_dir, None, None,
                   config.rafd_crop_size, config.image_size, config.batch_size,
                   'RaFD', config.mode, config.num_workers)
        if wants_rafd else None)

    solver = Solver(celeba_loader, rafd_loader, config)

    single_dataset = config.dataset in ['CelebA', 'RaFD']
    both_datasets = config.dataset in ['Both']

    if config.mode == 'train':
        if single_dataset:
            solver.train()
        elif both_datasets:
            solver.train_multi()
    elif config.mode == 'test':
        if single_dataset:
            solver.test()
        elif both_datasets:
            solver.test_multi()
开发者ID:yunjey,项目名称:stargan,代码行数:43,代码来源:main.py

示例5: main

# 需要导入模块: import data_loader [as 别名]
# 或者: from data_loader import get_loader [as 别名]
def main(config):
    """Seed torch, build the A/B data loaders, then train or test the model.

    In test mode the data path is ``config.test_data_path`` (falling back to
    ``config.data_path`` when it is ``None``). The batch size is
    ``config.sample_per_image``, and ``config.load_path`` must be set.

    Raises:
        Exception: in test mode when ``config.load_path`` is empty.
    """
    prepare_dirs_and_logger(config)

    torch.manual_seed(config.random_seed)
    if config.num_gpu > 0:
        torch.cuda.manual_seed(config.random_seed)

    if config.is_train:
        data_path, batch_size = config.data_path, config.batch_size
    else:
        data_path = (config.data_path if config.test_data_path is None
                     else config.test_data_path)
        batch_size = config.sample_per_image

    a_data_loader, b_data_loader = get_loader(
        data_path, batch_size, config.input_scale_size,
        config.num_worker, config.skip_pix2pix_processing)

    trainer = Trainer(config, a_data_loader, b_data_loader)

    if not config.is_train:
        if not config.load_path:
            raise Exception("[!] You should specify `load_path` to load a pretrained model")
        trainer.test()
        return

    save_config(config)
    trainer.train()
开发者ID:BMIRDS,项目名称:HistoGAN,代码行数:32,代码来源:main.py

示例6: main

# 需要导入模块: import data_loader [as 别名]
# 或者: from data_loader import get_loader [as 别名]
def main(config):
    """Build the SMIT data loader, then train and/or test per ``config.mode``.

    Returns right after building the loader when ``set_score(config)`` is
    truthy. 'train' runs training followed by a test on
    ``config.dataset_real``. 'test' runs the DEMO on ``config.DEMO_PATH`` when
    it is set, otherwise the same test. Other modes do nothing.
    """
    from torch.backends import cudnn
    cudnn.benchmark = True  # cuDNN autotuning for faster training

    data_loader = get_loader(
        config.mode_data, config.image_size, config.batch_size,
        config.dataset_fake, config.mode,
        num_workers=config.num_workers,
        all_attr=config.ALL_ATTR,
        c_dim=config.c_dim)

    from misc.scores import set_score
    if set_score(config):
        return

    if config.mode not in ('train', 'test'):
        return

    if config.mode == 'train':
        from train import Train
        Train(config, data_loader)

    from test import Test
    tester = Test(config, data_loader)
    if config.mode == 'test' and config.DEMO_PATH:
        tester.DEMO(config.DEMO_PATH)
    else:
        tester(dataset=config.dataset_real)
开发者ID:BCV-Uniandes,项目名称:SMIT,代码行数:35,代码来源:main.py

示例7: __init__

# 需要导入模块: import data_loader [as 别名]
# 或者: from data_loader import get_loader [as 别名]
def __init__(self, config):
        """Initialise the parent with ``config`` and build a batch-size-1 loader."""
        super(Scores, self).__init__(config)
        # The batch size (third positional argument) is pinned to 1 here.
        single_batch = 1
        self.data_loader = get_loader(
            config.mode_data, config.image_size, single_batch,
            config.dataset_fake, config.mode,
            num_workers=config.num_workers,
            all_attr=config.ALL_ATTR,
            c_dim=config.c_dim)
开发者ID:BCV-Uniandes,项目名称:SMIT,代码行数:14,代码来源:scores.py

示例8: main

# 需要导入模块: import data_loader [as 别名]
# 或者: from data_loader import get_loader [as 别名]
def main(config):
    """Seed TensorFlow, build the data loader, then train or test BEGAN.

    In test mode ``config.batch_size`` is overwritten with 64 before the loader
    is built. Both the loader call and the ``Trainer`` (which receives
    ``config``) therefore see that value. The data path is
    ``config.test_data_path`` when set, else ``config.data_path``.

    Raises:
        Exception: in test mode when ``config.load_path`` is empty.
    """
    prepare_dirs_and_logger(config)

    tf.set_random_seed(config.random_seed)

    if config.is_train:
        data_path = config.data_path
    else:
        config.batch_size = 64
        # NOTE(review): ``config.sample_per_image`` is not used as the loader
        # batch size in test mode; confirm that 64 is intended.
        data_path = (config.data_path if config.test_data_path is None
                     else config.test_data_path)

    data_loader = get_loader(
            data_path, config.batch_size, config.input_scale_size,
            config.data_format, config.split)
    trainer = Trainer(config, data_loader)

    if config.is_train:
        save_config(config)
        trainer.train()
    else:
        if not config.load_path:
            raise Exception("[!] You should specify `load_path` to load a pretrained model")
        trainer.test()
开发者ID:carpedm20,项目名称:BEGAN-tensorflow,代码行数:33,代码来源:main.py

示例9: DEMO

# 需要导入模块: import data_loader [as 别名]
# 或者: from data_loader import get_loader [as 别名]
def DEMO(self, path):
        """Translate the images under ``path`` and save the results.

        Builds an unshuffled batch-size-1 loader (``dataset='DEMO'``,
        ``Detect_Face=True``, ``mode='test'``). For every image and every value
        in ``0..config.style_label_debug`` it calls ``generate_SMIT`` twice:
        first with a fixed batch of random styles, then without one. Output
        files go to ``<config.sample_path>/<resume_name()>_test``.
        """
        from data_loader import get_loader
        last_name = self.resume_name()
        save_folder = os.path.join(self.config.sample_path,
                                   '{}_test'.format(last_name))
        create_dir(save_folder)
        no_label = self.config.dataset_fake in self.Binary_Datasets
        data_loader = get_loader(
            path,
            self.config.image_size,
            1,  # batch size
            shuffling=False,
            dataset='DEMO',
            Detect_Face=True,
            mode='test')

        # DEMO_LABEL is a comma-separated list of ints; '' means no label.
        label_text = self.config.DEMO_LABEL
        label = None
        if label_text != '':
            values = [int(token) for token in label_text.split(',')]
            label = torch.FloatTensor(values).view(1, -1)

        multimodal_values = range(self.config.style_label_debug + 1)
        style_all = self.G.random_style(max(self.config.batch_size, 50))

        name = TimeNow_str()
        for idx, real_x in enumerate(data_loader):
            save_path = os.path.join(save_folder,
                                     'DEMO_{}_{}.jpg'.format(name, idx + 1))
            self.PRINT('Translated test images and saved into "{}"..!'.format(
                save_path))
            for multimodal in multimodal_values:
                shared = dict(label=label, Multimodal=multimodal,
                              no_label=no_label, circle=True)
                self.generate_SMIT(real_x, save_path, fixed_style=style_all,
                                   TIME=not idx, **shared)
                self.generate_SMIT(real_x, save_path, **shared)

    # ==================================================================#
    # ==================================================================# 
开发者ID:BCV-Uniandes,项目名称:SMIT,代码行数:53,代码来源:test.py

示例10: main

# 需要导入模块: import data_loader [as 别名]
# 或者: from data_loader import get_loader [as 别名]
def main(args):
    """Train the image-captioning encoder/decoder pair.

    Logs loss and perplexity every ``args.log_step`` steps. Every
    ``args.save_step`` steps it writes ``decoder-<epoch>-<step>.ckpt`` and
    ``encoder-<epoch>-<step>.ckpt`` into ``args.model_path`` (both 1-based).
    """
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Random crop/flip augmentation, then normalization with the ImageNet
    # statistics expected by the pretrained ResNet encoder.
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ])

    # NOTE: pickle.load executes arbitrary code; only load a trusted vocab file.
    with open(args.vocab_path, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    data_loader = get_loader(args.image_dir, args.caption_path, vocab,
                             transform, args.batch_size,
                             shuffle=True, num_workers=args.num_workers)

    encoder = EncoderCNN(args.embed_size).to(device)
    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab),
                         args.num_layers).to(device)

    criterion = nn.CrossEntropyLoss()
    # Only the decoder and the encoder's linear + bn layers are optimized.
    trainable = [*decoder.parameters(),
                 *encoder.linear.parameters(),
                 *encoder.bn.parameters()]
    optimizer = torch.optim.Adam(trainable, lr=args.learning_rate)

    def save_checkpoints(epoch_no, step_no):
        # Decoder first, then encoder, matching the file naming below.
        for pattern, module in (('decoder-{}-{}.ckpt', decoder),
                                ('encoder-{}-{}.ckpt', encoder)):
            torch.save(module.state_dict(), os.path.join(
                args.model_path, pattern.format(epoch_no, step_no)))

    steps_per_epoch = len(data_loader)
    for epoch in range(args.num_epochs):
        for step, (images, captions, lengths) in enumerate(data_loader):
            images = images.to(device)
            captions = captions.to(device)
            # [0] is the PackedSequence data: the non-padded caption tokens.
            targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

            outputs = decoder(encoder(images), captions, lengths)
            loss = criterion(outputs, targets)
            decoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            optimizer.step()

            if step % args.log_step == 0:
                loss_value = loss.item()
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                      .format(epoch, args.num_epochs, step, steps_per_epoch,
                              loss_value, np.exp(loss_value)))

            if (step + 1) % args.save_step == 0:
                save_checkpoints(epoch + 1, step + 1)
开发者ID:yunjey,项目名称:pytorch-tutorial,代码行数:63,代码来源:train.py

示例11: main

# 需要导入模块: import data_loader [as 别名]
# 或者: from data_loader import get_loader [as 别名]
def main(config):
    """Build the RGB (and, if ``config.OF``, the ``OF``) loaders and run the solver.

    ``config.SHOW_MODEL`` and ``config.DEMO`` each return early, before any
    train/val/test/sample routine runs. In 'train' mode the solver is tested
    after training.
    """
    # Enable cuDNN autotuning for faster training.
    cudnn.benchmark = True

    for folder in (config.log_path, config.model_save_path):
        if not os.path.exists(folder):
            os.makedirs(folder)

    img_size = config.image_size

    def build_loader(use_of):
        # The two loaders differ only in the ``OF`` flag (presumably
        # optical-flow input -- confirm in data_loader).
        return get_loader(
            config.metadata_path,
            img_size,
            img_size,
            config.batch_size,
            config.mode,
            demo=config.DEMO,
            num_workers=config.num_workers,
            OF=use_of,
            verbose=True,
            imagenet=config.finetuning == 'imagenet')

    rgb_loader = build_loader(False)
    of_loader = build_loader(True) if config.OF else None

    from solver import Solver
    solver = Solver(rgb_loader, config, of_loader=of_loader)

    if config.SHOW_MODEL:
        solver.display_net()
        return
    if config.DEMO:
        solver.DEMO()
        return

    mode = config.mode
    if mode == 'train':
        solver.train()
        solver.test()
    elif mode == 'val':
        solver.val(load=True, init=True)
    elif mode == 'test':
        solver.test()
    elif mode == 'sample':
        solver.sample()
开发者ID:BCV-Uniandes,项目名称:AUNets,代码行数:62,代码来源:main.py

示例12: val

# 需要导入模块: import data_loader [as 别名]
# 或者: from data_loader import get_loader [as 别名]
def val(self, init=False, load=False):
        """Evaluate ``self.C`` on the validation loader(s) via ``F1_TEST``.

        Args:
            init: first build ``self.rgb_loader_val`` (and ``self.of_loader_val``
                when ``self.OF``) for the 'val' split of ``self.metadata_path``.
            load: load weights from ``<model_save_path>/<stem>.pth``, where
                ``stem`` is the basename of ``self.test_model`` up to its first
                '.'. Then open ``self.f`` in append mode on a new
                ``<stem>_NN_val.txt`` file. NN is the number of existing
                ``<stem>_*_val.txt`` files, zero-padded to 2 digits. Presumably
                ``F1_TEST`` writes its report there -- confirm.

        Returns:
            ``(f1, loss, f1_one)`` when ``init`` is true, else ``(f1, loss)``.
        """
        if init:
            from data_loader import get_loader
            self.rgb_loader_val = get_loader(self.metadata_path,
                                             self.image_size, self.image_size,
                                             self.batch_size, 'val')
            if self.OF:
                self.of_loader_val = get_loader(
                    self.metadata_path,
                    self.image_size,
                    self.image_size,
                    self.batch_size,
                    'val',
                    OF=True)

            # NOTE(review): never read. The file is only opened when ``load``
            # is true, and that branch overwrites ``txt_path``.
            txt_path = os.path.join(self.model_save_path, '0_init_val.txt')

        if load:
            last_name = os.path.basename(self.test_model).split('.')[0]
            txt_template = os.path.join(self.model_save_path,
                                        '{}_{}_val.txt'.format(last_name, '{}'))
            # Count every existing report for this model so the new one gets
            # the next index (00, 01, 02, ...). Globbing the single newest path
            # would always yield 1 and reuse the same file forever.
            number_file = len(glob.glob(txt_template.format('*')))
            txt_path = txt_template.format(str(number_file).zfill(2))

            D_path = os.path.join(self.model_save_path,
                                  '{}.pth'.format(last_name))
            self.C.load_state_dict(torch.load(D_path))

        self.C.eval()

        if load:
            self.f = open(txt_path, 'a')
        try:
            self.thresh = np.linspace(0.01, 0.99, 200).astype(np.float32)
            if not self.OF:
                self.of_loader_val = None
            f1, _, _, loss, f1_one = F1_TEST(
                self,
                self.rgb_loader_val,
                mode='VAL',
                OF=self.of_loader_val,
                verbose=load)
        finally:
            # Close the report even if evaluation raises, so the handle is not leaked.
            if load:
                self.f.close()
        if init:
            return f1, loss, f1_one
        return f1, loss

    # ====================================================================#
    # ====================================================================# 
开发者ID:BCV-Uniandes,项目名称:AUNets,代码行数:57,代码来源:solver.py


注:本文中的data_loader.get_loader方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。