This article collects typical usage examples of the Python method data_loader.get_loader. If you are wondering what data_loader.get_loader does, how to call it, or where to find working examples, the hand-picked code samples below may help. You can also explore further usage examples from the data_loader module in which this method is defined.

The following shows 12 code examples of data_loader.get_loader, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code examples.
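Each project below defines its own get_loader, so the signatures differ from example to example. As a point of reference, here is a minimal, hypothetical version of such a factory; the argument names, the ImageFolder dataset, and the transforms are assumptions, not the API of any project shown here:

import torch.utils.data as data
from torchvision import datasets, transforms

def get_loader(image_dir, image_size, batch_size, mode='train', num_workers=2):
    """Hypothetical loader factory: build a dataset and wrap it in a DataLoader."""
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    dataset = datasets.ImageFolder(image_dir, transform)
    # Shuffle only while training; keep the order deterministic otherwise.
    return data.DataLoader(dataset,
                           batch_size=batch_size,
                           shuffle=(mode == 'train'),
                           num_workers=num_workers)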
Example 1: main
# Required import: import data_loader [as alias]
# Or: from data_loader import get_loader [as alias]
def main(config):
    svhn_loader, mnist_loader = get_loader(config)

    solver = Solver(config, svhn_loader, mnist_loader)
    cudnn.benchmark = True

    # Create directories if they do not exist.
    if not os.path.exists(config.model_path):
        os.makedirs(config.model_path)
    if not os.path.exists(config.sample_path):
        os.makedirs(config.sample_path)

    if config.mode == 'train':
        solver.train()
    elif config.mode == 'sample':
        solver.sample()
Example 2: train
# Required import: import data_loader [as alias]
# Or: from data_loader import get_loader [as alias]
def train(model):
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    data_iter = data_loader.get_loader(batch_size=args.batch_size)

    for epoch in range(args.epochs):
        model.train()
        run_loss = 0.0

        for idx, data in enumerate(data_iter):
            data = utils.to_var(data)
            ret = model.run_on_batch(data, optimizer, epoch)
            run_loss += ret['loss'].item()

            print('\r Progress epoch {}, {:.2f}%, average loss {}'.format(
                epoch, (idx + 1) * 100.0 / len(data_iter),
                run_loss / (idx + 1.0)), end='')

        evaluate(model, data_iter)
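The evaluate helper called at the end of each epoch is not part of this snippet. Below is a minimal sketch of what it might look like, assuming torch, utils, and the model's run_on_batch interface from the surrounding module, and assuming that passing no optimizer skips the weight update; these are assumptions, not the project's actual code:

def evaluate(model, data_iter):
    """Hypothetical validation pass: average the loss without updating weights."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for data in data_iter:
            data = utils.to_var(data)
            ret = model.run_on_batch(data, None, None)  # assumed: no optimizer means forward pass only
            total_loss += ret['loss'].item()
    print('\nValidation loss: {:.4f}'.format(total_loss / len(data_iter)))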
Example 3: main
# Required import: import data_loader [as alias]
# Or: from data_loader import get_loader [as alias]
def main():
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),  # convert to a tensor so the normalization below applies
        normalize])

    val_loader = get_loader(opts.img_path, val_transform, vocab, opts.data_path,
                            partition='test',
                            batch_size=opts.batch_size, shuffle=False,
                            num_workers=opts.workers, pin_memory=True)
    print('Validation loader prepared.')

    test(val_loader)
Example 4: main
# Required import: import data_loader [as alias]
# Or: from data_loader import get_loader [as alias]
def main(config):
    # For fast training.
    cudnn.benchmark = True

    # Create directories if not exist.
    if not os.path.exists(config.log_dir):
        os.makedirs(config.log_dir)
    if not os.path.exists(config.model_save_dir):
        os.makedirs(config.model_save_dir)
    if not os.path.exists(config.sample_dir):
        os.makedirs(config.sample_dir)
    if not os.path.exists(config.result_dir):
        os.makedirs(config.result_dir)

    # Data loader.
    celeba_loader = None
    rafd_loader = None

    if config.dataset in ['CelebA', 'Both']:
        celeba_loader = get_loader(config.celeba_image_dir, config.attr_path, config.selected_attrs,
                                   config.celeba_crop_size, config.image_size, config.batch_size,
                                   'CelebA', config.mode, config.num_workers)
    if config.dataset in ['RaFD', 'Both']:
        rafd_loader = get_loader(config.rafd_image_dir, None, None,
                                 config.rafd_crop_size, config.image_size, config.batch_size,
                                 'RaFD', config.mode, config.num_workers)

    # Solver for training and testing StarGAN.
    solver = Solver(celeba_loader, rafd_loader, config)

    if config.mode == 'train':
        if config.dataset in ['CelebA', 'RaFD']:
            solver.train()
        elif config.dataset in ['Both']:
            solver.train_multi()
    elif config.mode == 'test':
        if config.dataset in ['CelebA', 'RaFD']:
            solver.test()
        elif config.dataset in ['Both']:
            solver.test_multi()
Example 5: main
# Required import: import data_loader [as alias]
# Or: from data_loader import get_loader [as alias]
def main(config):
    prepare_dirs_and_logger(config)

    torch.manual_seed(config.random_seed)
    if config.num_gpu > 0:
        torch.cuda.manual_seed(config.random_seed)

    if config.is_train:
        data_path = config.data_path
        batch_size = config.batch_size
    else:
        if config.test_data_path is None:
            data_path = config.data_path
        else:
            data_path = config.test_data_path
        batch_size = config.sample_per_image

    a_data_loader, b_data_loader = get_loader(
        data_path, batch_size, config.input_scale_size,
        config.num_worker, config.skip_pix2pix_processing)

    trainer = Trainer(config, a_data_loader, b_data_loader)

    if config.is_train:
        save_config(config)
        trainer.train()
    else:
        if not config.load_path:
            raise Exception("[!] You should specify `load_path` to load a pretrained model")
        trainer.test()
Example 6: main
# Required import: import data_loader [as alias]
# Or: from data_loader import get_loader [as alias]
def main(config):
    from torch.backends import cudnn

    # For fast training
    cudnn.benchmark = True

    data_loader = get_loader(
        config.mode_data,
        config.image_size,
        config.batch_size,
        config.dataset_fake,
        config.mode,
        num_workers=config.num_workers,
        all_attr=config.ALL_ATTR,
        c_dim=config.c_dim)

    from misc.scores import set_score
    if set_score(config):
        return

    if config.mode == 'train':
        from train import Train
        Train(config, data_loader)

        from test import Test
        test = Test(config, data_loader)
        test(dataset=config.dataset_real)

    elif config.mode == 'test':
        from test import Test
        test = Test(config, data_loader)
        if config.DEMO_PATH:
            test.DEMO(config.DEMO_PATH)
        else:
            test(dataset=config.dataset_real)
Example 7: __init__
# Required import: import data_loader [as alias]
# Or: from data_loader import get_loader [as alias]
def __init__(self, config):
    super(Scores, self).__init__(config)
    self.data_loader = get_loader(
        config.mode_data,
        config.image_size,
        1,
        config.dataset_fake,
        config.mode,
        num_workers=config.num_workers,
        all_attr=config.ALL_ATTR,
        c_dim=config.c_dim)
Example 8: main
# Required import: import data_loader [as alias]
# Or: from data_loader import get_loader [as alias]
def main(config):
    prepare_dirs_and_logger(config)

    rng = np.random.RandomState(config.random_seed)
    tf.set_random_seed(config.random_seed)

    if config.is_train:
        data_path = config.data_path
        batch_size = config.batch_size
        do_shuffle = True
    else:
        setattr(config, 'batch_size', 64)
        if config.test_data_path is None:
            data_path = config.data_path
        else:
            data_path = config.test_data_path
        batch_size = config.sample_per_image
        do_shuffle = False

    data_loader = get_loader(
        data_path, config.batch_size, config.input_scale_size,
        config.data_format, config.split)
    trainer = Trainer(config, data_loader)

    if config.is_train:
        save_config(config)
        trainer.train()
    else:
        if not config.load_path:
            raise Exception("[!] You should specify `load_path` to load a pretrained model")
        trainer.test()
Example 9: DEMO
# Required import: import data_loader [as alias]
# Or: from data_loader import get_loader [as alias]
def DEMO(self, path):
    from data_loader import get_loader
    last_name = self.resume_name()
    save_folder = os.path.join(self.config.sample_path,
                               '{}_test'.format(last_name))
    create_dir(save_folder)

    batch_size = 1
    no_label = self.config.dataset_fake in self.Binary_Datasets
    data_loader = get_loader(
        path,
        self.config.image_size,
        batch_size,
        shuffling=False,
        dataset='DEMO',
        Detect_Face=True,
        mode='test')

    label = self.config.DEMO_LABEL
    if self.config.DEMO_LABEL != '':
        label = torch.FloatTensor([int(i) for i in label.split(',')]).view(1, -1)
    else:
        label = None

    _debug = range(self.config.style_label_debug + 1)
    style_all = self.G.random_style(max(self.config.batch_size, 50))
    name = TimeNow_str()

    for i, real_x in enumerate(data_loader):
        save_path = os.path.join(save_folder, 'DEMO_{}_{}.jpg'.format(name, i + 1))
        self.PRINT('Translated test images and saved into "{}"..!'.format(save_path))
        for k in _debug:
            self.generate_SMIT(
                real_x,
                save_path,
                label=label,
                Multimodal=k,
                fixed_style=style_all,
                TIME=not i,
                no_label=no_label,
                circle=True)
            self.generate_SMIT(
                real_x,
                save_path,
                label=label,
                Multimodal=k,
                no_label=no_label,
                circle=True)
# ==================================================================#
# ==================================================================#
Example 10: main
# Required import: import data_loader [as alias]
# Or: from data_loader import get_loader [as alias]
def main(args):
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing, normalization for the pretrained resnet
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build data loader
    data_loader = get_loader(args.image_dir, args.caption_path, vocab,
                             transform, args.batch_size,
                             shuffle=True, num_workers=args.num_workers)

    # Build the models
    encoder = EncoderCNN(args.embed_size).to(device)
    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab), args.num_layers).to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    params = list(decoder.parameters()) + list(encoder.linear.parameters()) + list(encoder.bn.parameters())
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    # Train the models
    total_step = len(data_loader)
    for epoch in range(args.num_epochs):
        for i, (images, captions, lengths) in enumerate(data_loader):

            # Set mini-batch dataset
            images = images.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

            # Forward, backward and optimize
            features = encoder(images)
            outputs = decoder(features, captions, lengths)
            loss = criterion(outputs, targets)
            decoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            optimizer.step()

            # Print log info
            if i % args.log_step == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                      .format(epoch, args.num_epochs, i, total_step, loss.item(), np.exp(loss.item())))

            # Save the model checkpoints
            if (i+1) % args.save_step == 0:
                torch.save(decoder.state_dict(), os.path.join(
                    args.model_path, 'decoder-{}-{}.ckpt'.format(epoch+1, i+1)))
                torch.save(encoder.state_dict(), os.path.join(
                    args.model_path, 'encoder-{}-{}.ckpt'.format(epoch+1, i+1)))
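The get_loader used above is assumed to return batches of images together with padded captions and their lengths, already sorted by length so that pack_padded_sequence can be applied directly. A hypothetical collate_fn implementing that contract is sketched below; the function name and the (image, caption) item layout are assumptions, not the project's actual code:

import torch

def caption_collate_fn(batch):
    """Hypothetical collate_fn: each item is an (image_tensor, caption_tensor) pair."""
    # Sort by caption length (descending) so pack_padded_sequence can be used downstream.
    batch.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions = zip(*batch)
    images = torch.stack(images, 0)
    lengths = [len(cap) for cap in captions]
    # Pad every caption to the longest one in the batch.
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        targets[i, :lengths[i]] = cap[:lengths[i]]
    return images, targets, lengths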
Example 11: main
# Required import: import data_loader [as alias]
# Or: from data_loader import get_loader [as alias]
def main(config):
    # For fast training
    cudnn.benchmark = True

    # Create directories if not exist
    if not os.path.exists(config.log_path):
        os.makedirs(config.log_path)
    if not os.path.exists(config.model_save_path):
        os.makedirs(config.model_save_path)

    # Data loader
    of_loader = None
    img_size = config.image_size

    rgb_loader = get_loader(
        config.metadata_path,
        img_size,
        img_size,
        config.batch_size,
        config.mode,
        demo=config.DEMO,
        num_workers=config.num_workers,
        OF=False,
        verbose=True,
        imagenet=(config.finetuning == 'imagenet'))

    if config.OF:
        of_loader = get_loader(
            config.metadata_path,
            img_size,
            img_size,
            config.batch_size,
            config.mode,
            demo=config.DEMO,
            num_workers=config.num_workers,
            OF=True,
            verbose=True,
            imagenet=(config.finetuning == 'imagenet'))

    # Solver
    from solver import Solver
    solver = Solver(rgb_loader, config, of_loader=of_loader)

    if config.SHOW_MODEL:
        solver.display_net()
        return

    if config.DEMO:
        solver.DEMO()
        return

    if config.mode == 'train':
        solver.train()
        solver.test()
    elif config.mode == 'val':
        solver.val(load=True, init=True)
    elif config.mode == 'test':
        solver.test()
    elif config.mode == 'sample':
        solver.sample()
Example 12: val
# Required import: import data_loader [as alias]
# Or: from data_loader import get_loader [as alias]
def val(self, init=False, load=False):
    if init:
        from data_loader import get_loader
        self.rgb_loader_val = get_loader(self.metadata_path,
                                         self.image_size, self.image_size,
                                         self.batch_size, 'val')
        if self.OF:
            self.of_loader_val = get_loader(
                self.metadata_path,
                self.image_size,
                self.image_size,
                self.batch_size,
                'val',
                OF=True)
        txt_path = os.path.join(self.model_save_path, '0_init_val.txt')

    if load:
        last_name = os.path.basename(self.test_model).split('.')[0]
        txt_path = os.path.join(self.model_save_path,
                                '{}_{}_val.txt'.format(last_name, '{}'))
        try:
            output_txt = sorted(glob.glob(txt_path.format('*')))[-1]
            number_file = len(glob.glob(output_txt))
        except BaseException:
            number_file = 0
        txt_path = txt_path.format(str(number_file).zfill(2))

        D_path = os.path.join(self.model_save_path,
                              '{}.pth'.format(last_name))
        self.C.load_state_dict(torch.load(D_path))

    self.C.eval()

    if load:
        self.f = open(txt_path, 'a')

    self.thresh = np.linspace(0.01, 0.99, 200).astype(np.float32)

    if not self.OF:
        self.of_loader_val = None

    f1, _, _, loss, f1_one = F1_TEST(
        self,
        self.rgb_loader_val,
        mode='VAL',
        OF=self.of_loader_val,
        verbose=load)

    if load:
        self.f.close()

    if init:
        return f1, loss, f1_one
    else:
        return f1, loss
# ====================================================================#
# ====================================================================#