This article collects typical usage examples of the Python method torch.utils.data.SubsetRandomSampler. If you have been wondering what data.SubsetRandomSampler does, how to use it, or what real code that uses it looks like, the curated examples below may help. You can also explore further usage examples from its containing module, torch.utils.data.
Two code examples of data.SubsetRandomSampler are presented below.
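Before turning to the real-world examples, here is a minimal sketch of the typical pattern (the toy dataset and the 80/20 split are assumptions for illustration): SubsetRandomSampler draws, without replacement, only from the index list you hand it, which makes it a natural tool for train/validation splits over a single dataset.

import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Hypothetical toy dataset: 100 samples with 8 features each.
dataset = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))

# Shuffle the indices once, then carve out an 80/20 train/val split.
indices = torch.randperm(len(dataset)).tolist()
train_sampler = SubsetRandomSampler(indices[:80])
val_sampler = SubsetRandomSampler(indices[80:])

# Each loader sees only its own subset, reshuffled every epoch.
train_loader = DataLoader(dataset, batch_size=16, sampler=train_sampler)
val_loader = DataLoader(dataset, batch_size=16, sampler=val_sampler)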
Example 1: get_dataloader
# Required import: from torch.utils import data
# Or: from torch.utils.data import SubsetRandomSampler
import numpy as np
from torch.utils.data import SubsetRandomSampler, ConcatDataset, DataLoader

# AlignCollate is a project-specific collate function that resizes images
# to height x width (optionally preserving aspect ratio).
def get_dataloader(synthetic_dataset, real_dataset, height, width, batch_size, workers,
                   is_train, keep_ratio):
    num_synthetic_dataset = len(synthetic_dataset)
    num_real_dataset = len(real_dataset)

    # Drop as many synthetic indices as there are real samples, so the
    # combined epoch always contains len(synthetic_dataset) items.
    synthetic_indices = list(np.random.permutation(num_synthetic_dataset))
    synthetic_indices = synthetic_indices[num_real_dataset:]
    # Offset the real indices past the synthetic ones to match their
    # position in the concatenated dataset built below.
    real_indices = list(np.random.permutation(num_real_dataset) + num_synthetic_dataset)
    concated_indices = synthetic_indices + real_indices
    assert len(concated_indices) == num_synthetic_dataset

    sampler = SubsetRandomSampler(concated_indices)
    concated_dataset = ConcatDataset([synthetic_dataset, real_dataset])
    print('total images: ', len(concated_dataset))

    # shuffle must stay False whenever a custom sampler is supplied.
    data_loader = DataLoader(concated_dataset, batch_size=batch_size, num_workers=workers,
                             shuffle=False, pin_memory=True, drop_last=True, sampler=sampler,
                             collate_fn=AlignCollate(imgH=height, imgW=width, keep_ratio=keep_ratio))
    return concated_dataset, data_loader
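The index arithmetic in get_dataloader is worth spelling out: dropping the first num_real_dataset synthetic indices makes exactly enough room for the real samples, so every epoch draws a fixed-size mixture of len(synthetic_dataset) items. A quick sketch with toy sizes (the numbers are assumptions, not from the original code):

import numpy as np

num_synthetic, num_real = 10, 3

synthetic_indices = list(np.random.permutation(num_synthetic))[num_real:]  # 7 synthetic indices
real_indices = list(np.random.permutation(num_real) + num_synthetic)  # 3 real indices, offset
combined = synthetic_indices + real_indices
assert len(combined) == num_synthetic  # epoch size stays constant at 10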
Example 2: main
# Required import: from torch.utils import data
# Or: from torch.utils.data import SubsetRandomSampler
import os
import json
import torch
import pandas as pd
from torch.utils.data import DataLoader, SubsetRandomSampler, Subset

# GraphDataset, graph_collate, and create_model are project-specific helpers
# defined alongside this function in the original repository.
def main(device, model_param, optimizer_param, scheduler_param, dataset_param, dataloader_param,
         num_epochs, seed, load_model):
    print("Seed:", seed)
    print()
    torch.manual_seed(seed)
    dataloader_param["collate_fn"] = graph_collate

    # Create the dataset.
    dataset = GraphDataset(dataset_param["dataset_path"], dataset_param["target_name"])

    # Split the dataset into training, validation, and test sets.
    split_file_path = dataset_param["split_file"]
    if split_file_path is not None and os.path.isfile(split_file_path):
        with open(split_file_path) as f:
            split = json.load(f)
    else:
        print("No split file. Default split: 256 (train), 32 (val), 32 (test)")
        split = {"train": range(256), "val": range(256, 288), "test": range(288, 320)}
    print(" ".join(["{}: {}".format(k, len(x)) for k, x in split.items()]))

    # Create a CGNN model.
    model = create_model(device, model_param, optimizer_param, scheduler_param)
    if load_model:
        print("Loading weights from model.pth")
        model.load()

    # Train: each sampler confines its DataLoader to one part of the split.
    train_sampler = SubsetRandomSampler(split["train"])
    val_sampler = SubsetRandomSampler(split["val"])
    train_dl = DataLoader(dataset, sampler=train_sampler, **dataloader_param)
    val_dl = DataLoader(dataset, sampler=val_sampler, **dataloader_param)
    model.train(train_dl, val_dl, num_epochs)
    if num_epochs > 0:
        model.save()

    # Test: a plain Subset is enough here, since test order need not be random.
    test_set = Subset(dataset, split["test"])
    test_dl = DataLoader(test_set, **dataloader_param)
    outputs, targets = model.evaluate(test_dl)
    names = [dataset.graph_names[i] for i in split["test"]]
    df_predictions = pd.DataFrame({"name": names, "prediction": outputs, "target": targets})
    df_predictions.to_csv("test_predictions.csv", index=False)
    print("\nEND")
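Judging from the fallback default, the split file that main reads is a JSON object mapping "train", "val", and "test" to lists of dataset indices. A minimal sketch of producing such a file (the file name and split sizes are assumptions):

import json
import random

indices = list(range(320))
random.shuffle(indices)

# Same shape as the in-code default: index lists keyed by split name.
split = {"train": indices[:256], "val": indices[256:288], "test": indices[288:]}
with open("split.json", "w") as f:
    json.dump(split, f)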