

Python data.random_split Method Code Examples

This article collects typical usage examples of the Python method torch.utils.data.random_split. If you are struggling with questions such as: What exactly does data.random_split do? How is it called? What does real code that uses it look like? Then the curated code examples below should help. You can also explore further usage examples from torch.utils.data.


Below are 8 code examples of data.random_split, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
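Before the examples, a minimal sketch of the API itself may help (the toy dataset below is purely illustrative): random_split takes a map-style dataset and a list of lengths summing to len(dataset), and returns one Subset per length.

import torch
from torch.utils.data import TensorDataset, random_split

# A toy dataset of 100 samples, purely illustrative.
dataset = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))

# The lengths must sum to len(dataset); each returned object is a
# torch.utils.data.Subset that indexes back into the original dataset.
train_set, val_set = random_split(dataset, [80, 20])
print(len(train_set), len(val_set))  # 80 20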

Example 1: load_data

# Required imports: from torch.utils import data [as alias]
# Or: from torch.utils.data import random_split [as alias]
def load_data(test_split, batch_size):
    """Loads the data"""
    sonar_dataset = SonarDataset('./sonar.all-data')
    # Compute the sizes for the split
    dataset_size = len(sonar_dataset)
    test_size = int(test_split * dataset_size)
    train_size = dataset_size - test_size

    train_dataset, test_dataset = random_split(sonar_dataset,
                                               [train_size, test_size])

    # Pass the Subset objects themselves; `train_dataset.dataset` would
    # refer to the full underlying dataset and defeat the split.
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True)
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=True)

    return train_loader, test_loader
Developer ID: GoogleCloudPlatform, Project: cloudml-samples, Lines of code: 23, Source file: data_utils.py
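For context on why the loaders above should take the Subset objects directly: random_split returns torch.utils.data.Subset wrappers that hold a reference to the full dataset plus the sampled indices, so passing subset.dataset to a DataLoader reintroduces every sample. A minimal sketch of that behavior:

import torch
from torch.utils.data import TensorDataset, random_split

ds = TensorDataset(torch.arange(10))
train, test = random_split(ds, [8, 2])
print(len(train), len(test))   # 8 2
print(len(train.dataset))      # 10 -- `.dataset` is the full, unsplit dataset
print(train.indices)           # the 8 indices sampled for this subset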

Example 2: load_data

# Required imports: from torch.utils import data [as alias]
# Or: from torch.utils.data import random_split [as alias]
def load_data(test_split, seed, batch_size):
    """Loads the data"""
    # Seed the RNG so that random_split is reproducible across runs.
    torch.manual_seed(seed)
    sonar_dataset = SonarDataset('./sonar.all-data')
    # Compute the sizes for the split
    dataset_size = len(sonar_dataset)
    test_size = int(test_split * dataset_size)
    train_size = dataset_size - test_size

    train_dataset, test_dataset = random_split(sonar_dataset,
                                               [train_size, test_size])

    # Pass the Subset objects themselves, not `.dataset` (the full dataset).
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True)
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=True)

    return train_loader, test_loader
Developer ID: GoogleCloudPlatform, Project: cloudml-samples, Lines of code: 23, Source file: data_utils.py
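A side note on seeding: recent PyTorch versions (1.6 and later) accept a generator argument on random_split, which scopes the randomness to the split instead of the global RNG. A minimal sketch, reusing the names from the example above:

import torch
from torch.utils.data import random_split

# Seeding a dedicated generator makes the split reproducible without
# touching the global RNG state; sonar_dataset and the sizes are as above.
train_dataset, test_dataset = random_split(
    sonar_dataset, [train_size, test_size],
    generator=torch.Generator().manual_seed(seed))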

Example 3: get_iterators

# Required imports: from torch.utils import data [as alias]
# Or: from torch.utils.data import random_split [as alias]
def get_iterators(config, train_file, test_file, val_file=None):
    train_set = MyDataset(train_file, config)
    test_set = MyDataset(test_file, config)

    # If a validation file exists, load it. Otherwise carve the validation
    # set out of the training data with a 90/10 split.
    if val_file:
        val_set = MyDataset(val_file, config)
    else:
        train_size = int(0.9 * len(train_set))
        val_size = len(train_set) - train_size
        train_set, val_set = data.random_split(train_set, [train_size, val_size])

    train_iterator = DataLoader(train_set, batch_size=config.batch_size, shuffle=True)
    test_iterator = DataLoader(test_set, batch_size=config.batch_size)
    val_iterator = DataLoader(val_set, batch_size=config.batch_size)
    return train_iterator, test_iterator, val_iterator
Developer ID: AnubhavGupta3377, Project: Text-Classification-Models-Pytorch, Lines of code: 18, Source file: utils.py

Example 4: _get_samples

# Required imports: from torch.utils import data [as alias]
# Or: from torch.utils.data import random_split [as alias]
def _get_samples(dataset, sample_dataset_size=1):
    import math
    if int(len(dataset) * sample_dataset_size) <= 0:
        raise ValueError(
            "Dataset (size %d) is too small. `sample_dataset_size` is %f"
            % (len(dataset), sample_dataset_size))
    # A float is interpreted as a proportion, an int as an absolute amount.
    size_is_prop = isinstance(sample_dataset_size, float)
    size_is_amount = isinstance(sample_dataset_size, int)
    if size_is_prop:
        if not (0 < sample_dataset_size <= 1):
            raise ValueError("sample_dataset_size proportion should be between 0 and 1")
        subdata_size = math.floor(sample_dataset_size * len(dataset))
    elif size_is_amount:
        if not (sample_dataset_size < len(dataset)):
            raise ValueError("sample_dataset_size amount should be smaller than the length of the dataset")
        subdata_size = sample_dataset_size
    else:
        raise Exception("sample_dataset_size should be float or int. "
                        "%s was given" % str(sample_dataset_size))
    # Split off a random subset and load it as one single batch.
    sample_dataset, _ = random_split(dataset, [subdata_size, len(dataset) - subdata_size])
    sample_loader = DataLoader(sample_dataset, batch_size=subdata_size, shuffle=True)
    [samples_data] = list(sample_loader)
    return samples_data
Developer ID: dingguanglei, Project: jdit, Lines of code: 24, Source file: dataset.py
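Worth knowing when writing proportion logic like the above: in recent PyTorch versions (1.13 and later), random_split itself accepts fractions summing to 1, so the floor arithmetic can be left to the library. A minimal sketch (dataset stands for any map-style dataset):

from torch.utils.data import random_split

# Fractions summing to 1 are converted to integer lengths internally,
# and any rounding remainder is distributed across the splits.
sample_dataset, rest = random_split(dataset, [0.25, 0.75])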

Example 5: setup_data_loaders

# Required imports: from torch.utils import data [as alias]
# Or: from torch.utils.data import random_split [as alias]
def setup_data_loaders(data, batch_size, input_size):
    # Encode string labels as integers and wrap everything in a TensorDataset.
    label_encoder = LabelEncoder()
    y = label_encoder.fit_transform(data[1])
    data = TensorDataset(torch.from_numpy(data[0]), torch.from_numpy(y), torch.from_numpy(data[2]))
    # 90/10 train/validation split.
    train_size = int(0.9 * len(data))
    sizes = (train_size, len(data) - train_size)
    train_set, test_set = random_split(data, sizes)
    # Random crops for training, center crops for evaluation.
    train_transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomCrop(input_size),
        transforms.ToTensor(),
    ])
    train_set = TransformDataset(train_set, train_transforms)
    val_transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
    ])
    test_set = TransformDataset(test_set, val_transforms)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(test_set, batch_size=batch_size)
    return train_loader, val_loader
Developer ID: luizgh, Project: sigver, Lines of code: 24, Source file: train.py

Example 6: prepare_data

# Required imports: from torch.utils import data [as alias]
# Or: from torch.utils.data import random_split [as alias]
def prepare_data(self):
    mnist_train = self.download_data(self.data_dir)

    # Hold out 5,000 of the 60,000 MNIST training images for validation.
    self.mnist_train, self.mnist_val = random_split(
        mnist_train, [55000, 5000])
Developer ID: ray-project, Project: ray, Lines of code: 7, Source file: mnist_pytorch_lightning.py

Example 7: create_dataloader

# Required imports: from torch.utils import data [as alias]
# Or: from torch.utils.data import random_split [as alias]
def create_dataloader(self,
                      df: pd.DataFrame,
                      batch_size: int = 32,
                      shuffle: bool = False,
                      valid_pct: float = None):
    "Process rows in pd.DataFrame using n_cpu cores and return a DataLoader"

    tqdm.pandas()
    with ProcessPoolExecutor(max_workers=n_cpu) as executor:
        result = list(
            tqdm(executor.map(self.process_row, df.iterrows(), chunksize=8192),
                 desc=f"Processing {len(df)} examples on {n_cpu} cores",
                 total=len(df)))

    features = [r[0] for r in result]
    labels = [r[1] for r in result]

    dataset = TensorDataset(torch.tensor(features, dtype=torch.long),
                            torch.tensor(labels, dtype=torch.long))

    # Optionally carve out a validation split; only the training loader shuffles.
    if valid_pct is not None:
        valid_size = int(valid_pct * len(df))
        train_size = len(df) - valid_size
        valid_dataset, train_dataset = random_split(dataset, [valid_size, train_size])
        valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        return train_loader, valid_loader

    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             num_workers=0,
                             shuffle=shuffle,
                             pin_memory=torch.cuda.is_available())
    return data_loader
Developer ID: prrao87, Project: fine-grained-sentiment, Lines of code: 36, Source file: loader.py

Example 8: main

# Required imports: from torch.utils import data [as alias]
# Or: from torch.utils.data import random_split [as alias]
def main(conf):
    total_set = DNSDataset(conf['data']['json_dir'])
    train_len = int(len(total_set) * (1 - conf['data']['val_prop']))
    val_len = len(total_set) - train_len
    train_set, val_set = random_split(total_set, [train_len, val_len])

    train_loader = DataLoader(train_set, shuffle=True,
                              batch_size=conf['training']['batch_size'],
                              num_workers=conf['training']['num_workers'],
                              drop_last=True)
    val_loader = DataLoader(val_set, shuffle=False,
                            batch_size=conf['training']['batch_size'],
                            num_workers=conf['training']['num_workers'],
                            drop_last=True)

    # Define the model and optimizer in a local function (defined in the
    # recipe). The advantage: re-instantiating them for retraining and
    # for evaluation is straightforward.
    model, optimizer = make_model_and_optimizer(conf)
    # Define scheduler
    scheduler = None
    if conf['training']['half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf['main_args']['exp_dir']
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = partial(distance, is_complex=conf['main_args']['is_complex'])
    system = SimpleSystem(model=model, loss_func=loss_func, optimizer=optimizer,
                          train_loader=train_loader, val_loader=val_loader,
                          scheduler=scheduler, config=conf)

    # Define callbacks
    checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
    checkpoint = ModelCheckpoint(checkpoint_dir, monitor='val_loss',
                                 mode='min', save_top_k=5, verbose=1)
    early_stopping = False
    if conf['training']['early_stop']:
        early_stopping = EarlyStopping(monitor='val_loss', patience=10,
                                       verbose=1)

    # Don't request GPUs if none are available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(max_nb_epochs=conf['training']['epochs'],
                         checkpoint_callback=checkpoint,
                         early_stop_callback=early_stopping,
                         default_save_path=exp_dir,
                         gpus=gpus,
                         distributed_backend='dp',
                         train_percent_check=1.0,  # Useful for fast experiment
                         gradient_clip_val=5.,)
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0) 
Developer ID: mpariente, Project: asteroid, Lines of code: 63, Source file: train.py


Note: The torch.utils.data.random_split method examples in this article were compiled by 纯净天空 from GitHub, MSDocs, and other open-source code and documentation platforms. The snippets are selected from open-source projects contributed by many developers; the copyright of the source code belongs to the original authors, and its distribution and use should follow the corresponding project's License. Do not reproduce without permission.