当前位置: 首页>>代码示例>>Python>>正文


Python data.read方法代码示例

本文整理汇总了Python中torch.utils.data.read方法的典型用法代码示例。如果您正苦于以下问题：Python data.read方法的具体用法？Python data.read怎么用？Python data.read使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch.utils.data的用法示例。


在下文中一共展示了data.read方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: read_image_file

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def read_image_file(path):
    """Load an IDX image file into a ``(length, num_rows, num_cols)`` tensor.

    The first 16 bytes are treated as four 4-byte header fields decoded with
    the module-level ``get_int`` helper: a magic number that must equal 2051,
    the image count, the row count and the column count. Every following
    byte is one pixel, converted with the module-level ``parse_byte`` helper.

    Args:
        path: location of the binary image file.

    Returns:
        ``torch.ByteTensor`` shaped ``(length, num_rows, num_cols)``. The
        shape comes from the header rather than a fixed 28x28, so files with
        other image sizes (or zero images) load correctly.

    Raises:
        AssertionError: if the magic number is not 2051.
        IndexError: if the file holds fewer pixels than the header announces.
    """
    with open(path, 'rb') as f:
        data = f.read()

    assert get_int(data[:4]) == 2051
    length = get_int(data[4:8])
    num_rows = get_int(data[8:12])
    num_cols = get_int(data[12:16])

    header_size = 16
    pixels_per_image = num_rows * num_cols
    images = []
    for n in range(length):
        base = header_size + n * pixels_per_image
        # Index byte by byte (rather than slicing) so a truncated file still
        # fails loudly with IndexError instead of yielding ragged rows.
        images.append([
            [parse_byte(data[base + r * num_cols + c])
             for c in range(num_cols)]
            for r in range(num_rows)
        ])
    return torch.ByteTensor(images).view(length, num_rows, num_cols)
开发者ID:vithursant,项目名称:MagnetLoss-PyTorch,代码行数:22,代码来源:fashion.py

示例2: read_image_file

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def read_image_file(path):
    """Read an IDX image file and return its pixels as a byte tensor.

    Header layout as consumed here: bytes 0-3 magic number (must decode to
    2051 through the module-level ``get_int``), bytes 4-7 number of images,
    bytes 8-11 rows per image, bytes 12-15 columns per image. The pixel
    bytes start at offset 16 and are converted one by one with the
    module-level ``parse_byte``.

    Args:
        path: filesystem path of the binary image file.

    Returns:
        ``torch.ByteTensor`` of shape ``(length, num_rows, num_cols)``. The
        dimensions are taken from the header instead of being fixed to
        28x28, which also makes an empty (zero-image) file loadable.

    Raises:
        AssertionError: if the magic number is not 2051.
        IndexError: if the payload is shorter than the header promises.
    """
    with open(path, 'rb') as f:
        data = f.read()

    assert get_int(data[:4]) == 2051
    length = get_int(data[4:8])
    num_rows = get_int(data[8:12])
    num_cols = get_int(data[12:16])

    images = []
    idx = 16  # first pixel byte, right after the four header fields
    for _ in range(length):
        img = []
        for _ in range(num_rows):
            # Per-byte indexing keeps the IndexError on truncated input.
            img.append([parse_byte(data[idx + c]) for c in range(num_cols)])
            idx += num_cols
        images.append(img)
    return torch.ByteTensor(images).view(length, num_rows, num_cols)
开发者ID:TLESORT,项目名称:Generative_Continual_Learning,代码行数:22,代码来源:fashion.py

示例3: parse_cat_file

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def parse_cat_file(path):
    """
        -cat file stores the category of every image.

        Args:
            path: location of the binary -cat file.

        Return:
            LongTensor of shape (N,), where N is the single entry of
            ``header['dim']`` (labels are decoded as little-endian int32
            and widened with ``.long()``).

        Raises:
            struct.error: if the file ends before all fields are read.
    """
    with open(path, 'rb') as f:
        header = parse_header(f)
        num, = header['dim']
        # Two 4-byte fields sit between the parsed header and the labels;
        # they are consumed and ignored. NOTE(review): presumably unused
        # dimension slots of the header -- confirm against parse_header.
        struct.unpack('<BBBB', f.read(4))
        struct.unpack('<BBBB', f.read(4))

        # Decode every label with one read/unpack instead of one per label.
        raw_labels = struct.unpack('<%di' % num, f.read(4 * num))
        labels = np.array(raw_labels, dtype=np.int32)

        return torch.from_numpy(labels).long()
开发者ID:yl-1993,项目名称:Matrix-Capsules-EM-PyTorch,代码行数:19,代码来源:norb.py

示例4: parse_dat_file

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def parse_dat_file(path):
    """
        -dat file stores N samples of C images each (``header['dim']`` is
        ``(N, C, H, W)``). The images of one sample are stored back to back,
        so sample ``k`` occupies rows ``k*C`` to ``k*C + C - 1`` of the
        result. NOTE(review): for the stereo dataset this presumably means
        image pairs from two cameras, i.e. shape (2*N, 96, 96) -- confirm.

        Args:
            path: location of the binary -dat file.

        Return:
            ByteTensor of shape (N*C, H, W)

        Raises:
            struct.error: if the file ends before all images are read.
    """
    with open(path, 'rb') as f:
        header = parse_header(f)
        num, c, h, w = header['dim']
        imgs = np.zeros(shape=(num * c, h, w), dtype=np.uint8)

        frame_size = h * w
        for i in range(num * c):
            raw = f.read(frame_size)
            if len(raw) != frame_size:
                # Same exception type the previous struct.unpack-based
                # decoding raised on a truncated file.
                raise struct.error(
                    'image %d: expected %d bytes, got %d'
                    % (i, frame_size, len(raw)))
            # Interpret the bytes directly instead of unpacking them into a
            # tuple of h*w Python ints first.
            imgs[i] = np.frombuffer(raw, dtype=np.uint8).reshape(h, w)

        return torch.from_numpy(imgs)
开发者ID:yl-1993,项目名称:Matrix-Capsules-EM-PyTorch,代码行数:22,代码来源:norb.py

示例5: parse_info_file

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def parse_info_file(path):
    """
        -info file stores the additional info for each image.
        The specific meaning of each dimension is:

            (:, 0): 10 instances
            (:, 1): 9 elevation
            (:, 2): 18 azimuth
            (:, 3): 6 lighting conditions

        Args:
            path: location of the binary -info file.

        Return:
            IntTensor (int32) of shape (N, num_info), with both sizes taken
            from ``header['dim']``; (N, 4) for the layout listed above.

        Raises:
            struct.error: if the file ends before all fields are read.
    """
    with open(path, 'rb') as f:
        header = parse_header(f)
        num, num_info = header['dim']
        # One 4-byte field follows the parsed header and is ignored.
        # NOTE(review): presumably an unused dimension slot -- confirm
        # against parse_header.
        struct.unpack('<BBBB', f.read(4))

        # All values are little-endian int32 stored row by row, so a single
        # unpack followed by a reshape replaces the per-field reads.
        count = num * num_info
        flat = struct.unpack('<%di' % count, f.read(4 * count))
        info = np.array(flat, dtype=np.int32).reshape(num, num_info)

        return torch.from_numpy(info)
开发者ID:yl-1993,项目名称:Matrix-Capsules-EM-PyTorch,代码行数:25,代码来源:norb.py

示例6: read_label_file

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def read_label_file(self, path):
        """Load the labels of an IDX label file as a 1-D int64 tensor.

        Bytes 0-3 must decode (via ``self.get_int``) to 2049 and bytes 4-7
        give the number of labels; every byte from offset 8 on is one label.

        Args:
            path: location of the binary label file.

        Returns:
            Tensor of shape ``(length,)`` converted with ``.long()``.

        Raises:
            AssertionError: if the magic number is not 2049.
        """
        with open(path, 'rb') as stream:
            raw = stream.read()

        magic = self.get_int(raw[:4])
        assert magic == 2049
        count = self.get_int(raw[4:8])

        # Everything after the 8-byte header is label data.
        label_bytes = np.frombuffer(raw, dtype=np.uint8, offset=8)
        labels = torch.from_numpy(label_bytes)
        return labels.view(count).long()
开发者ID:igolan,项目名称:bgd,代码行数:9,代码来源:datasets.py

示例7: read_image_file

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def read_image_file(self, path):
        """Load an IDX image file as a ``(length, num_rows, num_cols)`` tensor.

        The 16-byte header is decoded with ``self.get_int``: a magic number
        that must equal 2051, then the image count, rows and columns. All
        bytes from offset 16 on are pixel data.

        Args:
            path: location of the binary image file.

        Returns:
            uint8 tensor of shape ``(length, num_rows, num_cols)`` that owns
            its memory.

        Raises:
            AssertionError: if the magic number is not 2051.
        """
        with open(path, 'rb') as f:
            data = f.read()

        assert self.get_int(data[:4]) == 2051
        length = self.get_int(data[4:8])
        num_rows = self.get_int(data[8:12])
        num_cols = self.get_int(data[12:16])

        # np.frombuffer over ``bytes`` yields a read-only view; copy it so
        # the returned tensor is backed by writable memory (PyTorch warns
        # that writing to a tensor built on a non-writable array is
        # undefined behaviour).
        parsed = np.frombuffer(data, dtype=np.uint8, offset=16).copy()
        return torch.from_numpy(parsed).view(length, num_rows, num_cols)


###########################################################################
# Callable datasets
########################################################################### 
开发者ID:igolan,项目名称:bgd,代码行数:17,代码来源:datasets.py

示例8: download

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def download(self):
        """
        Download, and unzip in the correct location.

        Does nothing when ``self.check_exists()`` is true. Otherwise creates
        ``self.root`` and, for every entry of ``self.urls``, stores the
        payload in ``self.root`` under the last path component of the URL.
        Files ending in ``.zip`` are extracted into ``self.root`` and the
        archive is then removed.

        Returns:
            None

        """
        import shutil
        # ``import urllib`` alone does not guarantee that the ``request``
        # submodule is loaded; import it explicitly.
        import urllib.request
        import zipfile

        if self.check_exists():
            return

        # download files
        os.makedirs(self.root, exist_ok=True)

        for url in self.urls:
            print('Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, filename)
            ext = os.path.splitext(file_path)[1]
            # Close the response deterministically and stream it to disk
            # instead of holding the whole payload in memory.
            with urllib.request.urlopen(url) as response:
                with open(file_path, 'wb') as f:
                    shutil.copyfileobj(response, f)
            if ext == '.zip':
                with zipfile.ZipFile(file_path) as zip_f:
                    zip_f.extractall(self.root)
                os.unlink(file_path)

        print('Done!')
开发者ID:rdevon,项目名称:cortex,代码行数:37,代码来源:toysets.py

示例9: download

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def download(self):
        """Fetch every archive in ``self.urls`` and unzip it.

        Does nothing when ``self._check_exists()`` is true. Otherwise each
        URL is saved under ``<root>/<raw_folder>/`` (named after the last
        path component of the URL) and extracted as a zip archive into
        ``<root>/<processed_folder>/``.

        Returns:
            None
        """
        from contextlib import closing
        import shutil
        from six.moves import urllib
        import zipfile

        if self._check_exists():
            return

        raw_dir = os.path.join(self.root, self.raw_folder)
        processed_dir = os.path.join(self.root, self.processed_folder)

        # download files
        # Create each folder in its own try block: with a shared block an
        # already existing raw folder would skip creating the processed one.
        for folder in (raw_dir, processed_dir):
            try:
                os.makedirs(folder)
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise

        for url in self.urls:
            print('== Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(raw_dir, filename)
            # closing() releases the connection even if writing fails, and
            # works for responses that are not context managers themselves.
            with closing(urllib.request.urlopen(url)) as response:
                with open(file_path, 'wb') as f:
                    shutil.copyfileobj(response, f)
            print("== Unzip from " + file_path + " to " + processed_dir)
            with zipfile.ZipFile(file_path, 'r') as zip_ref:
                zip_ref.extractall(processed_dir)
        print("Download finished.")
开发者ID:dragen1860,项目名称:MAML-Pytorch,代码行数:32,代码来源:omniglot.py

示例10: get_current_classes

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def get_current_classes(fname):
    """Return the lines of ``fname`` with ``/`` replaced by ``os.sep``.

    Args:
        fname: path of a text file.

    Returns:
        List of strings, one per line of the file (line endings removed),
        with every forward slash swapped for the platform path separator.
    """
    with open(fname) as listing:
        raw_text = listing.read()
    normalized = raw_text.replace('/', os.sep)
    return normalized.splitlines()
开发者ID:orobix,项目名称:Prototypical-Networks-for-Few-shot-Learning-PyTorch,代码行数:6,代码来源:omniglot_dataset.py

示例11: read_label_file

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def read_label_file(path):
    """Load the labels of an IDX label file into a ``torch.LongTensor``.

    Bytes 0-3 must decode (via the module-level ``get_int``) to 2049 and
    bytes 4-7 give the expected label count. Each remaining byte is passed
    through the module-level ``parse_byte`` to obtain one label.

    Args:
        path: location of the binary label file.

    Returns:
        1-D ``torch.LongTensor`` holding one entry per label byte.

    Raises:
        AssertionError: if the magic number is not 2049 or the number of
            label bytes differs from the count in the header.
    """
    with open(path, 'rb') as fh:
        raw = fh.read()

    magic = get_int(raw[:4])
    assert magic == 2049
    expected = get_int(raw[4:8])

    labels = list(map(parse_byte, raw[8:]))
    assert len(labels) == expected
    return torch.LongTensor(labels)
开发者ID:vithursant,项目名称:MagnetLoss-PyTorch,代码行数:10,代码来源:fashion.py

示例12: read_label_file

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def read_label_file(path, extension):
    """Load labels plus the left/right label arrays for composite images.

    Bytes 0-3 must decode (via the module-level ``get_int``) to 2049 and
    bytes 4-7 give the number of labels; every byte from offset 8 on is one
    label.

    Args:
        path: location of the binary label file.
        extension: integer sequence with at least ``length`` entries;
            ``extension[i]`` is the index of the label paired with label
            ``i``. NOTE(review): presumably the third value returned by the
            matching image reader -- confirm against the caller.

    Returns:
        Tuple of three int64 tensors of shape ``(length,)``: the labels
        themselves, the "left" labels (equal to the labels), and the
        "right" labels ``labels[extension[i]]``.

    Raises:
        AssertionError: if the magic number is not 2049.
    """
    with open(path, 'rb') as f:
        data = f.read()

    assert get_int(data[:4]) == 2049
    length = get_int(data[4:8])
    parsed = np.frombuffer(data, dtype=np.uint8, offset=8)

    # np.int64 instead of np.long: that alias is missing from NumPy
    # 1.24-1.26 and platform dependent in the releases that do provide it.
    multi_labels_l = np.zeros(length, dtype=np.int64)
    multi_labels_r = np.zeros(length, dtype=np.int64)
    for im_id in range(length):
        multi_labels_l[im_id] = parsed[im_id]
        multi_labels_r[im_id] = parsed[extension[im_id]]

    return (torch.from_numpy(parsed).view(length).long(),
            torch.from_numpy(multi_labels_l).view(length).long(),
            torch.from_numpy(multi_labels_r).view(length).long())
开发者ID:intel-isl,项目名称:MultiObjectiveOptimization,代码行数:15,代码来源:multi_mnist_loader.py

示例13: read_image_file

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def read_image_file(path):
    """Load an IDX image file and build overlapping two-image composites.

    The 16-byte header is decoded with the module-level ``get_int``: magic
    number (must be 2051), image count, rows, columns. For every image
    ("left") one partner ("right") is drawn with
    ``np.random.permutation(length)[0]``. Both are placed on a 36x36
    canvas -- left at rows/cols 0-27, right at rows/cols 6-33, the
    overlapping region holding their element-wise maximum -- and the canvas
    is shrunk to 28x28 with ``m.imresize(..., interp='nearest')``.

    NOTE(review): the canvas offsets assume 28x28 inputs. ``m`` is a
    module-level name, presumably ``scipy.misc`` (whose ``imresize`` was
    removed in SciPy 1.3) -- confirm the environment provides it.

    Args:
        path: location of the binary image file.

    Returns:
        Tuple ``(originals, composites, extension)``: a uint8 tensor of
        shape ``(length, num_rows, num_cols)`` with the raw images, a
        float64 tensor of the same shape with the composites, and an int32
        NumPy array where ``extension[i]`` is the index of the right-hand
        image used for composite ``i``.

    Raises:
        AssertionError: if the magic number is not 2051.
    """
    with open(path, 'rb') as f:
        data = f.read()

    assert get_int(data[:4]) == 2051
    length = get_int(data[4:8])
    num_rows = get_int(data[8:12])
    num_cols = get_int(data[12:16])

    # Copy: np.frombuffer over ``bytes`` is read-only, and PyTorch warns
    # that tensors built on non-writable arrays have undefined behaviour
    # when written to.
    parsed = np.frombuffer(data, dtype=np.uint8, offset=16).copy()
    pv = parsed.reshape(length, num_rows, num_cols)

    multi_data = np.zeros((length, num_rows, num_cols))
    extension = np.zeros(length, dtype=np.int32)
    for left in range(length):
        # Same RNG call as before (a full permutation per image) so seeded
        # runs keep producing the same pairing.
        right = np.random.permutation(length)[0]
        extension[left] = right

        lim = pv[left, :, :]
        rim = pv[right, :, :]
        new_im = np.zeros((36, 36))
        new_im[0:28, 0:28] = lim
        new_im[6:34, 6:34] = rim
        # Where the two digits overlap keep the brighter pixel.
        new_im[6:28, 6:28] = np.maximum(lim[6:28, 6:28], rim[0:22, 0:22])
        multi_data_im = m.imresize(new_im, (28, 28), interp='nearest')
        multi_data[left, :, :] = multi_data_im

    return (torch.from_numpy(parsed).view(length, num_rows, num_cols),
            torch.from_numpy(multi_data).view(length, num_rows, num_cols),
            extension)
开发者ID:intel-isl,项目名称:MultiObjectiveOptimization,代码行数:28,代码来源:multi_mnist_loader.py

示例14: read_label_file

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def read_label_file(path):
    """Load the labels of an IDX label file as a 1-D int64 tensor.

    Bytes 0-3 must decode (via the module-level ``get_int``) to 2049 and
    bytes 4-7 give the number of labels; every byte from offset 8 on is one
    label.

    Args:
        path: location of the binary label file.

    Returns:
        Tensor of shape ``(length,)`` converted with ``.long()``.

    Raises:
        AssertionError: if the magic number is not 2049.
    """
    with open(path, 'rb') as stream:
        raw = stream.read()

    magic = get_int(raw[:4])
    assert magic == 2049
    count = get_int(raw[4:8])

    # Everything after the 8-byte header is label data.
    label_bytes = np.frombuffer(raw, dtype=np.uint8, offset=8)
    labels = torch.from_numpy(label_bytes)
    return labels.view(count).long()
开发者ID:michaal94,项目名称:torch_DCEC,代码行数:9,代码来源:mnist.py

示例15: read_image_file

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def read_image_file(path):
    """Load an IDX image file as a ``(length, num_rows, num_cols)`` tensor.

    The 16-byte header is decoded with the module-level ``get_int``: a magic
    number that must equal 2051, then the image count, rows and columns.
    All bytes from offset 16 on are pixel data.

    Args:
        path: location of the binary image file.

    Returns:
        uint8 tensor of shape ``(length, num_rows, num_cols)`` that owns its
        memory.

    Raises:
        AssertionError: if the magic number is not 2051.
    """
    with open(path, 'rb') as f:
        data = f.read()

    assert get_int(data[:4]) == 2051
    length = get_int(data[4:8])
    num_rows = get_int(data[8:12])
    num_cols = get_int(data[12:16])

    # np.frombuffer over ``bytes`` yields a read-only view; copy it so the
    # returned tensor is backed by writable memory (PyTorch warns that
    # writing to a tensor built on a non-writable array is undefined
    # behaviour).
    parsed = np.frombuffer(data, dtype=np.uint8, offset=16).copy()
    return torch.from_numpy(parsed).view(length, num_rows, num_cols)
开发者ID:michaal94,项目名称:torch_DCEC,代码行数:12,代码来源:mnist.py


注：本文中的torch.utils.data.read方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。