當前位置: 首頁>>代碼示例>>Python>>正文


Python data.SubsetRandomSampler方法代碼示例

本文整理匯總了Python中torch.utils.data.SubsetRandomSampler方法的典型用法代碼示例。如果您正苦於以下問題:Python data.SubsetRandomSampler方法的具體用法?Python data.SubsetRandomSampler怎麽用?Python data.SubsetRandomSampler使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch.utils.data的用法示例。


在下文中一共展示了data.SubsetRandomSampler方法的2個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: get_dataloader

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import SubsetRandomSampler [as 別名]
def get_dataloader(synthetic_dataset, real_dataset, height, width, batch_size, workers,
                   is_train, keep_ratio):
  num_synthetic_dataset = len(synthetic_dataset)
  num_real_dataset = len(real_dataset)

  synthetic_indices = list(np.random.permutation(num_synthetic_dataset))
  synthetic_indices = synthetic_indices[num_real_dataset:]
  real_indices = list(np.random.permutation(num_real_dataset) + num_synthetic_dataset)
  concated_indices = synthetic_indices + real_indices
  assert len(concated_indices) == num_synthetic_dataset

  sampler = SubsetRandomSampler(concated_indices)
  concated_dataset = ConcatDataset([synthetic_dataset, real_dataset])
  print('total image: ', len(concated_dataset))

  data_loader = DataLoader(concated_dataset, batch_size=batch_size, num_workers=workers,
    shuffle=False, pin_memory=True, drop_last=True, sampler=sampler,
    collate_fn=AlignCollate(imgH=height, imgW=width, keep_ratio=keep_ratio))
  return concated_dataset, data_loader 
開發者ID:ayumiymk,項目名稱:aster.pytorch,代碼行數:21,代碼來源:main.py

示例2: main

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import SubsetRandomSampler [as 別名]
def main(device, model_param, optimizer_param, scheduler_param, dataset_param, dataloader_param,
         num_epochs, seed, load_model):
    print("Seed:", seed)
    print()
    torch.manual_seed(seed)

    dataloader_param["collate_fn"] = graph_collate

    # Create dataset
    dataset = GraphDataset(dataset_param["dataset_path"], dataset_param["target_name"])

    # split the dataset into training, validation, and test sets.
    split_file_path = dataset_param["split_file"]
    if split_file_path is not None and os.path.isfile(split_file_path):
        with open(split_file_path) as f:
            split = json.load(f)
    else:
        print("No split file. Default split: 256 (train), 32 (val), 32 (test)")
        split = {"train": range(256), "val": range(256, 288), "test": range(288, 320)}
    print(" ".join(["{}: {}".format(k, len(x)) for k, x in split.items()]))

    # Create a CGNN model
    model = create_model(device, model_param, optimizer_param, scheduler_param)
    if load_model:
        print("Loading weights from model.pth")
        model.load()
    #print("Model:", model.device)

    # Train
    train_sampler = SubsetRandomSampler(split["train"])
    val_sampler = SubsetRandomSampler(split["val"])
    train_dl = DataLoader(dataset, sampler=train_sampler, **dataloader_param)
    val_dl = DataLoader(dataset, sampler=val_sampler, **dataloader_param)
    model.train(train_dl, val_dl, num_epochs)
    if num_epochs > 0:
        model.save()

    # Test
    test_set = Subset(dataset, split["test"])
    test_dl = DataLoader(test_set, **dataloader_param)
    outputs, targets = model.evaluate(test_dl)
    names = [dataset.graph_names[i] for i in split["test"]]
    df_predictions = pd.DataFrame({"name": names, "prediction": outputs, "target": targets})
    df_predictions.to_csv("test_predictions.csv", index=False)

    print("\nEND") 
開發者ID:Tony-Y,項目名稱:cgnn,代碼行數:48,代碼來源:cgnn.py


注:本文中的torch.utils.data.SubsetRandomSampler方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。