当前位置: 首页>>代码示例>>Python>>正文


Python data.SubsetRandomSampler方法代码示例

本文整理汇总了Python中torch.utils.data.SubsetRandomSampler方法的典型用法代码示例。如果您正苦于以下问题:Python data.SubsetRandomSampler方法的具体用法?Python data.SubsetRandomSampler怎么用?Python data.SubsetRandomSampler使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch.utils.data的用法示例。


在下文中一共展示了data.SubsetRandomSampler方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_dataloader

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import SubsetRandomSampler [as 别名]
def get_dataloader(synthetic_dataset, real_dataset, height, width, batch_size, workers,
                   is_train, keep_ratio):
  num_synthetic_dataset = len(synthetic_dataset)
  num_real_dataset = len(real_dataset)

  synthetic_indices = list(np.random.permutation(num_synthetic_dataset))
  synthetic_indices = synthetic_indices[num_real_dataset:]
  real_indices = list(np.random.permutation(num_real_dataset) + num_synthetic_dataset)
  concated_indices = synthetic_indices + real_indices
  assert len(concated_indices) == num_synthetic_dataset

  sampler = SubsetRandomSampler(concated_indices)
  concated_dataset = ConcatDataset([synthetic_dataset, real_dataset])
  print('total image: ', len(concated_dataset))

  data_loader = DataLoader(concated_dataset, batch_size=batch_size, num_workers=workers,
    shuffle=False, pin_memory=True, drop_last=True, sampler=sampler,
    collate_fn=AlignCollate(imgH=height, imgW=width, keep_ratio=keep_ratio))
  return concated_dataset, data_loader 
开发者ID:ayumiymk,项目名称:aster.pytorch,代码行数:21,代码来源:main.py

示例2: main

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import SubsetRandomSampler [as 别名]
def main(device, model_param, optimizer_param, scheduler_param, dataset_param, dataloader_param,
         num_epochs, seed, load_model):
    print("Seed:", seed)
    print()
    torch.manual_seed(seed)

    dataloader_param["collate_fn"] = graph_collate

    # Create dataset
    dataset = GraphDataset(dataset_param["dataset_path"], dataset_param["target_name"])

    # split the dataset into training, validation, and test sets.
    split_file_path = dataset_param["split_file"]
    if split_file_path is not None and os.path.isfile(split_file_path):
        with open(split_file_path) as f:
            split = json.load(f)
    else:
        print("No split file. Default split: 256 (train), 32 (val), 32 (test)")
        split = {"train": range(256), "val": range(256, 288), "test": range(288, 320)}
    print(" ".join(["{}: {}".format(k, len(x)) for k, x in split.items()]))

    # Create a CGNN model
    model = create_model(device, model_param, optimizer_param, scheduler_param)
    if load_model:
        print("Loading weights from model.pth")
        model.load()
    #print("Model:", model.device)

    # Train
    train_sampler = SubsetRandomSampler(split["train"])
    val_sampler = SubsetRandomSampler(split["val"])
    train_dl = DataLoader(dataset, sampler=train_sampler, **dataloader_param)
    val_dl = DataLoader(dataset, sampler=val_sampler, **dataloader_param)
    model.train(train_dl, val_dl, num_epochs)
    if num_epochs > 0:
        model.save()

    # Test
    test_set = Subset(dataset, split["test"])
    test_dl = DataLoader(test_set, **dataloader_param)
    outputs, targets = model.evaluate(test_dl)
    names = [dataset.graph_names[i] for i in split["test"]]
    df_predictions = pd.DataFrame({"name": names, "prediction": outputs, "target": targets})
    df_predictions.to_csv("test_predictions.csv", index=False)

    print("\nEND") 
开发者ID:Tony-Y,项目名称:cgnn,代码行数:48,代码来源:cgnn.py


注:本文中的torch.utils.data.SubsetRandomSampler方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。