

Python utils.data Method Code Examples

This article collects typical usage examples of the torch.utils.data method in Python. If you are wondering what utils.data does, how to call it, or what real-world uses of it look like, the curated code examples below should help. You can also explore further usage examples from its parent module, torch.utils.


The following shows 15 code examples of the utils.data method, sorted by popularity by default.
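
Before the examples, here is a minimal, self-contained sketch of the pattern most of them build on: subclass torch.utils.data.Dataset and feed it to torch.utils.data.DataLoader. The TensorPairDataset class and the random tensors below are illustrative assumptions for this sketch, not code from any of the listed projects.

import torch
import torch.utils.data


class TensorPairDataset(torch.utils.data.Dataset):
    """A toy map-style dataset that returns (input, label) pairs."""
    def __init__(self, inputs, labels):
        assert len(inputs) == len(labels)
        self.inputs = inputs
        self.labels = labels

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        return self.inputs[index], self.labels[index]


if __name__ == '__main__':
    # Illustrative random data: 100 samples with 8 features and 3 classes.
    dataset = TensorPairDataset(torch.randn(100, 8), torch.randint(0, 3, (100,)))
    # The DataLoader handles batching, shuffling, and optional worker processes,
    # as in the examples below (e.g. batch_size=1, shuffle=False in Example 3).
    loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True, num_workers=0)
    for batch_inputs, batch_labels in loader:
        print(batch_inputs.shape, batch_labels.shape)  # torch.Size([16, 8]) torch.Size([16])
        break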

Example 1: eval_1

# Required import: from torch import utils [as alias]
# Or: from torch.utils import data [as alias]
def eval_1(self, model, testloader, device):
        model.eval()
        with open(self.filename, 'w') as fout:
            self.eval_1__header(fout)
            with torch.no_grad():
                for i, data in enumerate(testloader):
                    p0, p1, igt = data
                    res = self.do_estimate(p0, p1, model, device) # --> [1, 4, 4]
                    ig_gt = igt.cpu().contiguous().view(-1, 4, 4) # --> [1, 4, 4]
                    g_hat = res.cpu().contiguous().view(-1, 4, 4) # --> [1, 4, 4]

                    dg = g_hat.bmm(ig_gt) # if correct, dg == identity matrix.
                    dx = ptlk.se3.log(dg) # --> [1, 6] (if correct, dx == zero vector)
                    dn = dx.norm(p=2, dim=1) # --> [1]
                    dm = dn.mean()

                    self.eval_1__write(fout, ig_gt, g_hat)
                    LOGGER.info('test, %d/%d, %f', i, len(testloader), dm) 
Author: vinits5, Project: pointnet-registration-framework, Lines: 20, Source: test_pointlk.py

Example 2: trainBatch

# Required import: from torch import utils [as alias]
# Or: from torch.utils import data [as alias]
def trainBatch(net, criterion, optimizer):
    data = next(train_iter)
    cpu_images, cpu_texts = data
    batch_size = cpu_images.size(0)
    utils.loadData(image, cpu_images)
    t, l = converter.encode(cpu_texts)
    utils.loadData(text, t)
    utils.loadData(length, l)

    preds = crnn(image)
    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
    cost = criterion(preds, text, preds_size, length) / batch_size
    crnn.zero_grad()
    cost.backward()
    optimizer.step()
    return cost 
Author: zzzDavid, Project: ICDAR-2019-SROIE, Lines: 18, Source: train.py

Example 3: run

# Required import: from torch import utils [as alias]
# Or: from torch.utils import data [as alias]
def run(args, testset, action):
    if not torch.cuda.is_available():
        args.device = 'cpu'
    args.device = torch.device(args.device)

    LOGGER.debug('Testing (PID=%d), %s', os.getpid(), args)

    model = action.create_model()
    if args.pretrained:
        assert os.path.isfile(args.pretrained)
        model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
    model.to(args.device)

    # dataloader
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=1, shuffle=False, num_workers=args.workers)

    # testing
    LOGGER.debug('tests, begin')
    action.eval_1(model, testloader, args.device)
    LOGGER.debug('tests, end') 
Author: vinits5, Project: pointnet-registration-framework, Lines: 24, Source: test_pointlk.py

Example 4: eval_1

# Required import: from torch import utils [as alias]
# Or: from torch.utils import data [as alias]
def eval_1(self, model, testloader, device):
        model.eval()
        vloss = 0.0
        gloss = 0.0
        count = 0
        with torch.no_grad():
            for i, data in enumerate(testloader):
                loss, loss_g = self.compute_loss(model, data, device)

                vloss1 = loss.item()
                vloss += vloss1
                gloss1 = loss_g.item()
                gloss += gloss1
                count += 1

        ave_vloss = float(vloss)/count
        ave_gloss = float(gloss)/count
        return ave_vloss, ave_gloss 
Author: vinits5, Project: pointnet-registration-framework, Lines: 20, Source: train_pointlk.py

Example 5: train_1

# Required import: from torch import utils [as alias]
# Or: from torch.utils import data [as alias]
def train_1(self, model, trainloader, optimizer, device):
        model.train()
        vloss = 0.0
        pred  = 0.0
        count = 0
        for i, data in enumerate(trainloader):
            target, output, loss = self.compute_loss(model, data, device)
            # forward + backward + optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss1 = loss.item()
            vloss += loss1
            count += output.size(0)

            _, pred1 = output.max(dim=1)
            ag = (pred1 == target)
            am = ag.sum()
            pred += am.item()

        running_loss = float(vloss)/count
        accuracy = float(pred)/count
        return running_loss, accuracy 
Author: vinits5, Project: pointnet-registration-framework, Lines: 26, Source: train_classifier.py

Example 6: eval_1

# Required import: from torch import utils [as alias]
# Or: from torch.utils import data [as alias]
def eval_1(self, model, testloader, device):
        model.eval()
        vloss = 0.0
        pred  = 0.0
        count = 0
        with torch.no_grad():
            for i, data in enumerate(testloader):
                target, output, loss = self.compute_loss(model, data, device)

                loss1 = loss.item()
                vloss += loss1
                count += output.size(0)

                _, pred1 = output.max(dim=1)
                ag = (pred1 == target)
                am = ag.sum()
                pred += am.item()

        ave_loss = float(vloss)/count
        accuracy = float(pred)/count
        return ave_loss, accuracy 
Author: vinits5, Project: pointnet-registration-framework, Lines: 23, Source: train_classifier.py

Example 7: set_input

# Required import: from torch import utils [as alias]
# Or: from torch.utils import data [as alias]
def set_input(self, input:torch.Tensor):
        """ Set input and ground truth

        Args:
            input (FloatTensor): Input data for batch i.
        """
        with torch.no_grad():
            self.input.resize_(input[0].size()).copy_(input[0])
            self.gt.resize_(input[1].size()).copy_(input[1])
            self.label.resize_(input[1].size())

            # Copy the first batch as the fixed input.
            if self.total_steps == self.opt.batchsize:
                self.fixed_input.resize_(input[0].size()).copy_(input[0])

    ## 
Author: samet-akcay, Project: ganomaly, Lines: 18, Source: model.py

Example 8: make_batch_data_sampler

# Required import: from torch import utils [as alias]
# Or: from torch.utils import data [as alias]
def make_batch_data_sampler(
    dataset, sampler, aspect_grouping, images_per_batch, num_iters=None, start_iter=0
):
    if aspect_grouping:
        if not isinstance(aspect_grouping, (list, tuple)):
            aspect_grouping = [aspect_grouping]
        aspect_ratios = _compute_aspect_ratios(dataset)
        group_ids = _quantize(aspect_ratios, aspect_grouping)
        batch_sampler = samplers.GroupedBatchSampler(
            sampler, group_ids, images_per_batch, drop_uneven=False
        )
    else:
        batch_sampler = torch.utils.data.sampler.BatchSampler(
            sampler, images_per_batch, drop_last=False
        )
    if num_iters is not None:
        batch_sampler = samplers.IterationBasedBatchSampler(
            batch_sampler, num_iters, start_iter
        )
    return batch_sampler 
Author: Res2Net, Project: Res2Net-maskrcnn, Lines: 22, Source: build.py

Example 9: __init__

# Required import: from torch import utils [as alias]
# Or: from torch.utils import data [as alias]
def __init__(self, data, transform=lambda data: data, one_hot=None, shuffle=False, dir=None):
        """
        Load the cached data (.pkl) into memory.
        :author 申瑞珉 (Ruimin Shen)
        :param data: A list containing the data samples (dict).
        :param transform: A function that transforms (usually by applying a sequence of data augmentation operations) the labels in a dict.
        :param one_hot: If an int value (the total number of classes) is given, the class label (key "cls") will be generated in one-hot format.
        :param shuffle: Shuffle the loaded dataset.
        :param dir: The directory to store the exception data.
        """
        self.data = data
        if shuffle:
            random.shuffle(self.data)
        self.transform = transform
        self.one_hot = None if one_hot is None else sklearn.preprocessing.OneHotEncoder(one_hot, dtype=np.float32)
        self.dir = dir 
Author: ruiminshen, Project: yolo2-pytorch, Lines: 18, Source: data.py

Example 10: __call__

# Required import: from torch import utils [as alias]
# Or: from torch.utils import data [as alias]
def __call__(self, batch):
        height, width = self.next_size()
        dim = max(len(data['cls']) for data in batch)
        _batch = []
        for data in batch:
            try:
                data = self.resize(data, height, width)
                data['image'] = self.transform_image(data['image'])
                data = padding_labels(data, dim)
                if self.transform_tensor is not None:
                    data['tensor'] = self.transform_tensor(data['image'])
                _batch.append(data)
            except:
                if self.dir is not None:
                    os.makedirs(self.dir, exist_ok=True)
                    name = self.__module__ + '.' + type(self).__name__
                    with open(os.path.join(self.dir, name + '.pkl'), 'wb') as f:
                        pickle.dump(data, f)
                raise
        return torch.utils.data.dataloader.default_collate(_batch) 
Author: ruiminshen, Project: yolo2-pytorch, Lines: 22, Source: data.py

Example 11: get_loader

# Required import: from torch import utils [as alias]
# Or: from torch.utils import data [as alias]
def get_loader(self):
        paths = [os.path.join(self.cache_dir, phase + '.pkl') for phase in self.config.get('eval', 'phase').split()]
        dataset = utils.data.Dataset(utils.data.load_pickles(paths))
        logging.info('num_examples=%d' % len(dataset))
        size = tuple(map(int, self.config.get('image', 'size').split()))
        try:
            workers = self.config.getint('data', 'workers')
        except configparser.NoOptionError:
            workers = multiprocessing.cpu_count()
        collate_fn = utils.data.Collate(
            transform.parse_transform(self.config, self.config.get('transform', 'resize_eval')),
            [size],
            transform_image=transform.get_transform(self.config, self.config.get('transform', 'image_test').split()),
            transform_tensor=transform.get_transform(self.config, self.config.get('transform', 'tensor').split()),
        )
        return torch.utils.data.DataLoader(dataset, batch_size=self.args.batch_size, num_workers=workers, collate_fn=collate_fn) 
Author: ruiminshen, Project: yolo2-pytorch, Lines: 18, Source: eval.py

Example 12: __init__

# Required import: from torch import utils [as alias]
# Or: from torch.utils import data [as alias]
def __init__(self, args, config):
        self.args = args
        self.config = config
        self.model_dir = utils.get_model_dir(config)
        self.category = utils.get_category(config)
        self.anchors = torch.from_numpy(utils.get_anchors(config)).contiguous()
        self.dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config), self.anchors, len(self.category))
        self.dnn.eval()
        logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in self.dnn.state_dict().values())))
        if torch.cuda.is_available():
            self.dnn.cuda()
        self.height, self.width = tuple(map(int, config.get('image', 'size').split()))
        output = self.dnn(torch.autograd.Variable(utils.ensure_device(torch.zeros(1, 3, self.height, self.width)), volatile=True))
        _, _, self.rows, self.cols = output.size()
        self.i, self.j = self.rows // 2, self.cols // 2
        self.output = output[:, :, self.i, self.j]
        dataset = Dataset(self.height, self.width)
        try:
            workers = self.config.getint('data', 'workers')
        except configparser.NoOptionError:
            workers = multiprocessing.cpu_count()
        self.loader = torch.utils.data.DataLoader(dataset, batch_size=self.args.batch_size, num_workers=workers) 
Author: ruiminshen, Project: yolo2-pytorch, Lines: 24, Source: receptive_field_analyzer.py

Example 13: __call__

# Required import: from torch import utils [as alias]
# Or: from torch.utils import data [as alias]
def __call__(self):
        changed = np.zeros([self.height, self.width], dtype=bool)  # np.bool was removed in recent NumPy versions
        for yx in tqdm.tqdm(self.loader):
            batch_size = yx.size(0)
            tensor = torch.zeros(batch_size, 3, self.height, self.width)
            for i, _yx in enumerate(torch.unbind(yx)):
                y, x = torch.unbind(_yx)
                tensor[i, :, y, x] = 1
            tensor = utils.ensure_device(tensor)
            output = self.dnn(torch.autograd.Variable(tensor, volatile=True))
            output = output[:, :, self.i, self.j]
            cmp = output == self.output
            cmp = torch.prod(cmp, -1).data
            for _yx, c in zip(torch.unbind(yx), torch.unbind(cmp)):
                y, x = torch.unbind(_yx)
                changed[y, x] = c
        return changed 
Author: ruiminshen, Project: yolo2-pytorch, Lines: 19, Source: receptive_field_analyzer.py

Example 14: get_loader

# Required import: from torch import utils [as alias]
# Or: from torch.utils import data [as alias]
def get_loader(self):
        paths = [os.path.join(self.cache_dir, phase + '.pkl') for phase in self.config.get('train', 'phase').split()]
        dataset = utils.data.Dataset(
            utils.data.load_pickles(paths),
            transform=transform.augmentation.get_transform(self.config, self.config.get('transform', 'augmentation').split()),
            one_hot=None if self.config.getboolean('train', 'cross_entropy') else len(self.category),
            shuffle=self.config.getboolean('data', 'shuffle'),
            dir=os.path.join(self.model_dir, 'exception'),
        )
        logging.info('num_examples=%d' % len(dataset))
        try:
            workers = self.config.getint('data', 'workers')
            if torch.cuda.is_available():
                workers = workers * torch.cuda.device_count()
        except configparser.NoOptionError:
            workers = multiprocessing.cpu_count()
        collate_fn = utils.data.Collate(
            transform.parse_transform(self.config, self.config.get('transform', 'resize_train')),
            utils.train.load_sizes(self.config),
            maintain=self.config.getint('data', 'maintain'),
            transform_image=transform.get_transform(self.config, self.config.get('transform', 'image_train').split()),
            transform_tensor=transform.get_transform(self.config, self.config.get('transform', 'tensor').split()),
            dir=os.path.join(self.model_dir, 'exception'),
        )
        return torch.utils.data.DataLoader(dataset, batch_size=self.args.batch_size * torch.cuda.device_count() if torch.cuda.is_available() else self.args.batch_size, shuffle=True, num_workers=workers, collate_fn=collate_fn, pin_memory=torch.cuda.is_available()) 
Author: ruiminshen, Project: yolo2-pytorch, Lines: 27, Source: train.py

Example 15: iterate

# Required import: from torch import utils [as alias]
# Or: from torch.utils import data [as alias]
def iterate(self, data):
        for key in data:
            t = data[key]
            if torch.is_tensor(t):
                data[key] = utils.ensure_device(t)
        tensor = torch.autograd.Variable(data['tensor'])
        pred = pybenchmark.profile('inference')(model._inference)(self.inference, tensor)
        height, width = data['image'].size()[1:3]
        rows, cols = pred['feature'].size()[-2:]
        loss, debug = pybenchmark.profile('loss')(model.loss)(self.anchors, norm_data(data, height, width, rows, cols), pred, self.config.getfloat('model', 'threshold'))
        loss_hparam = {key: loss[key] * self.config.getfloat('hparam', key) for key in loss}
        loss_total = sum(loss_hparam.values())
        self.optimizer.zero_grad()
        loss_total.backward()
        try:
            clip = self.config.getfloat('train', 'clip')
            nn.utils.clip_grad_norm(self.inference.parameters(), clip)
        except configparser.NoOptionError:
            pass
        self.optimizer.step()
        return dict(
            height=height, width=width, rows=rows, cols=cols,
            data=data, pred=pred, debug=debug,
            loss_total=loss_total, loss=loss, loss_hparam=loss_hparam,
        ) 
Author: ruiminshen, Project: yolo2-pytorch, Lines: 27, Source: train.py


Note: The torch.utils.data examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets are drawn from open-source projects contributed by their respective authors, and copyright remains with the original authors; refer to each project's license before redistributing or using the code. Please do not reproduce this article without permission.