本文整理汇总了Python中torch.utils.data.read方法的典型用法代码示例。如果您正苦于以下问题:Python data.read方法的具体用法?Python data.read怎么用?Python data.read使用的例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类torch.utils.data的用法示例。
在下文中一共展示了data.read方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: read_image_file
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def read_image_file(path):
    """Read a binary image file and return its pixels as a ByteTensor.

    The first 16 bytes are decoded with ``get_int`` as four 4-byte header
    fields: a magic number (must equal 2051), the number of images, the
    number of rows and the number of columns.  Every following byte is one
    pixel, decoded with ``parse_byte``, stored image by image and row by row.

    Args:
        path: Location of the file to read.

    Returns:
        ``torch.ByteTensor`` of shape ``(length, num_rows, num_cols)``.

    Raises:
        AssertionError: If the magic number is not 2051.
        IndexError: If the file holds fewer pixels than the header announces.
    """
    with open(path, 'rb') as f:
        data = f.read()
    assert get_int(data[:4]) == 2051
    length = get_int(data[4:8])
    num_rows = get_int(data[8:12])
    num_cols = get_int(data[12:16])

    # Pixels start right after the 16-byte header and are already laid out
    # in (image, row, column) order, so a flat list can simply be reshaped.
    pixel_count = length * num_rows * num_cols
    pixels = [parse_byte(data[16 + offset]) for offset in range(pixel_count)]

    # Shape comes from the header rather than a hard-coded 28x28; giving the
    # image count explicitly also keeps an empty file (length == 0) valid.
    return torch.ByteTensor(pixels).view(length, num_rows, num_cols)
示例2: read_image_file
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def read_image_file(path):
    """Read a binary image file and return its pixels as a ByteTensor.

    The first 16 bytes are decoded with ``get_int`` as four 4-byte header
    fields: a magic number (must equal 2051), the number of images, the
    number of rows and the number of columns.  Every following byte is one
    pixel, decoded with ``parse_byte``, stored image by image and row by row.

    Args:
        path: Location of the file to read.

    Returns:
        ``torch.ByteTensor`` of shape ``(length, num_rows, num_cols)``.

    Raises:
        AssertionError: If the magic number is not 2051.
        IndexError: If the file holds fewer pixels than the header announces.
    """
    with open(path, 'rb') as f:
        data = f.read()
    assert get_int(data[:4]) == 2051
    length = get_int(data[4:8])
    num_rows = get_int(data[8:12])
    num_cols = get_int(data[12:16])

    # Pixels start right after the 16-byte header and are already laid out
    # in (image, row, column) order, so a flat list can simply be reshaped.
    pixel_count = length * num_rows * num_cols
    pixels = [parse_byte(data[16 + offset]) for offset in range(pixel_count)]

    # Shape comes from the header rather than a hard-coded 28x28; giving the
    # image count explicitly also keeps an empty file (length == 0) valid.
    return torch.ByteTensor(pixels).view(length, num_rows, num_cols)
示例3: parse_cat_file
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def parse_cat_file(path):
    """Load the per-image category labels stored in a ``-cat`` file.

    The header is decoded by ``parse_header``; its ``'dim'`` entry must hold
    exactly one value, the number of labels.  Two 4-byte fields follow the
    header and are skipped, then one little-endian 32-bit integer is read
    per label.

    Args:
        path: Location of the ``-cat`` file.

    Returns:
        One-dimensional int64 tensor with one category per image.
    """
    with open(path, 'rb') as handle:
        (count,) = parse_header(handle)['dim']
        # The two words after the header are decoded and thrown away.
        for _ in range(2):
            struct.unpack('<BBBB', handle.read(4))
        categories = np.zeros(shape=count, dtype=np.int32)
        for position in range(count):
            (value,) = struct.unpack('<i', handle.read(4))
            categories[position] = value
    return torch.from_numpy(categories).long()
示例4: parse_dat_file
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def parse_dat_file(path):
    """Load the images stored in a ``-dat`` file.

    A ``-dat`` file stores N image pairs.  Each image pair,
    ``[i, :, :]`` and ``[i+1, :, :]``, includes two images taken from two
    cameras.  They share the category and additional information.

    The header is decoded by ``parse_header``; its ``'dim'`` entry must hold
    four values ``(num, c, h, w)``.  ``num * c`` images of ``h * w`` unsigned
    bytes each are then read in order.

    Args:
        path: Location of the ``-dat`` file.

    Returns:
        uint8 tensor of shape ``(num * c, h, w)``.

    Raises:
        struct.error: If the file ends before all images have been read.
    """
    with open(path, 'rb') as f:
        header = parse_header(f)
        num, c, h, w = header['dim']
        image_size = h * w
        # The unpack format is the same for every image; build it once.
        image_format = '<' + image_size * 'B'
        imgs = np.zeros(shape=(num * c, h, w), dtype=np.uint8)
        for i in range(num * c):
            img = struct.unpack(image_format, f.read(image_size))
            # Shape is passed positionally: the ``newshape=`` keyword is
            # deprecated in recent NumPy releases.
            imgs[i] = np.uint8(np.reshape(img, (h, w)))
    return torch.from_numpy(imgs)
示例5: parse_info_file
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def parse_info_file(path):
    """Load the additional per-image information stored in an ``-info`` file.

    The specific meaning of each column is:
        (:, 0): 10 instances
        (:, 1): 9 elevation
        (:, 2): 18 azimuth
        (:, 3): 6 lighting conditions

    The header is decoded by ``parse_header``; its ``'dim'`` entry must hold
    two values, the number of rows and the number of columns.  One 4-byte
    field after the header is skipped, then the table is read row by row as
    little-endian 32-bit integers.

    Args:
        path: Location of the ``-info`` file.

    Returns:
        int32 tensor of shape ``(num, num_info)``.
    """
    with open(path, 'rb') as handle:
        num, num_info = parse_header(handle)['dim']
        # One word after the header is decoded and thrown away.
        struct.unpack('<BBBB', handle.read(4))
        table = np.zeros(shape=(num, num_info), dtype=np.int32)
        for row in range(num):
            for col in range(num_info):
                (value,) = struct.unpack('<i', handle.read(4))
                table[row, col] = value
    return torch.from_numpy(table)
示例6: read_label_file
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def read_label_file(self, path):
    """Read a binary label file and return the labels as an int64 tensor.

    The first two 4-byte header fields are decoded with ``self.get_int``:
    a magic number, which must equal 2049, and the number of labels.  Every
    byte after the 8-byte header is one label.

    Args:
        path: Location of the label file.

    Returns:
        One-dimensional int64 tensor with ``length`` entries.
    """
    with open(path, 'rb') as stream:
        raw = stream.read()
    magic = self.get_int(raw[:4])
    assert magic == 2049
    count = self.get_int(raw[4:8])
    # Interpret everything past the header as unsigned bytes, no copy made.
    label_bytes = np.frombuffer(raw, dtype=np.uint8, offset=8)
    labels = torch.from_numpy(label_bytes)
    return labels.view(count).long()
示例7: read_image_file
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def read_image_file(self, path):
    """Read a binary image file and return its pixels as a uint8 tensor.

    The first four 4-byte header fields are decoded with ``self.get_int``:
    a magic number, which must equal 2051, then the number of images, rows
    and columns.  Every byte after the 16-byte header is one pixel.

    Args:
        path: Location of the image file.

    Returns:
        uint8 tensor of shape ``(length, num_rows, num_cols)``.

    Raises:
        AssertionError: If the magic number is not 2051.
    """
    with open(path, 'rb') as f:
        data = f.read()
    assert self.get_int(data[:4]) == 2051
    length = self.get_int(data[4:8])
    num_rows = self.get_int(data[8:12])
    num_cols = self.get_int(data[12:16])
    # NOTE(review): the array is a view onto the bytes object, so the tensor
    # shares that memory -- presumably callers only read it; confirm.
    parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
    return torch.from_numpy(parsed).view(length, num_rows, num_cols)
###########################################################################
# Callable datasets
###########################################################################
示例8: download
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def download(self):
    """Download every URL in ``self.urls`` into ``self.root``, unpacking zips.

    Does nothing when ``self.check_exists()`` is true.  Otherwise
    ``self.root`` is created if needed and each URL is saved under the last
    component of its path.  A file whose extension is ``.zip`` is extracted
    into ``self.root`` and the archive is then deleted.

    Returns:
        None
    """
    # ``import urllib`` alone does not load the ``request`` submodule, so it
    # is imported explicitly here.
    import shutil
    import urllib.request
    import zipfile

    if self.check_exists():
        return

    # download files
    os.makedirs(self.root, exist_ok=True)

    for url in self.urls:
        print('Downloading ' + url)
        filename = url.rpartition('/')[2]
        file_path = os.path.join(self.root, filename)
        ext = os.path.splitext(file_path)[1]
        # Stream straight to disk; both handles are closed even on failure.
        with urllib.request.urlopen(url) as response, \
                open(file_path, 'wb') as f:
            shutil.copyfileobj(response, f)
        if ext == '.zip':
            with zipfile.ZipFile(file_path) as zip_f:
                zip_f.extractall(self.root)
            # NOTE(review): assumes only extracted archives are removed;
            # confirm against the original indentation.
            os.unlink(file_path)
    print('Done!')
示例9: download
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def download(self):
    """Download and unzip every URL in ``self.urls``.

    Does nothing when ``self._check_exists()`` is true.  Otherwise the
    folders ``self.root/self.raw_folder`` and
    ``self.root/self.processed_folder`` are created if missing.  Each URL is
    saved into the raw folder under the last component of its path and then
    extracted, as a zip archive, into the processed folder.

    Returns:
        None
    """
    from contextlib import closing
    from six.moves import urllib
    import zipfile

    if self._check_exists():
        return

    # download files
    # Each folder gets its own try block: when both calls shared one, an
    # already-existing raw folder skipped creating the processed folder.
    for folder in (self.raw_folder, self.processed_folder):
        try:
            os.makedirs(os.path.join(self.root, folder))
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

    for url in self.urls:
        print('== Downloading ' + url)
        filename = url.rpartition('/')[2]
        file_path = os.path.join(self.root, self.raw_folder, filename)
        # ``closing`` releases the connection even if the write fails.
        with closing(urllib.request.urlopen(url)) as response:
            with open(file_path, 'wb') as f:
                f.write(response.read())
        file_processed = os.path.join(self.root, self.processed_folder)
        print("== Unzip from " + file_path + " to " + file_processed)
        # Context manager closes the archive even if extraction raises.
        with zipfile.ZipFile(file_path, 'r') as zip_ref:
            zip_ref.extractall(file_processed)
    print("Download finished.")
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def get_current_classes(fname):
    """Return the lines of ``fname`` with ``/`` replaced by ``os.sep``.

    Args:
        fname: Path of a text file.

    Returns:
        List with one string per line of the file, line endings removed.
    """
    with open(fname) as listing:
        raw_text = listing.read()
    # Convert separators on the whole text first, then split into lines.
    native_text = raw_text.replace('/', os.sep)
    return native_text.splitlines()
开发者ID:orobix,项目名称:Prototypical-Networks-for-Few-shot-Learning-PyTorch,代码行数:6,代码来源:omniglot_dataset.py
示例11: read_label_file
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def read_label_file(path):
    """Read a binary label file and return the labels as a LongTensor.

    The first two 4-byte header fields are decoded with ``get_int``: a magic
    number, which must equal 2049, and the number of labels.  Every byte
    after the 8-byte header is decoded with ``parse_byte`` into one label,
    and the number of decoded labels must match the header.

    Args:
        path: Location of the label file.

    Returns:
        ``torch.LongTensor`` holding one entry per label.
    """
    with open(path, 'rb') as stream:
        raw = stream.read()
    magic = get_int(raw[:4])
    assert magic == 2049
    expected_count = get_int(raw[4:8])
    decoded = list(map(parse_byte, raw[8:]))
    # The payload size has to agree with the count announced in the header.
    assert len(decoded) == expected_count
    return torch.LongTensor(decoded)
示例12: read_label_file
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def read_label_file(path, extension):
    """Read a binary label file and pair each label with a partner's label.

    The first two 4-byte header fields are decoded with ``get_int``: a magic
    number, which must equal 2049, and the number of labels ``length``.
    Every byte after the 8-byte header is one label.

    Args:
        path: Location of the label file.
        extension: Indexable sequence with at least ``length`` integer
            entries; ``extension[i]`` is the index of the label paired with
            label ``i``.

    Returns:
        A 3-tuple of int64 tensors, each with ``length`` entries:
        the labels as stored in the file, a copy of those labels, and the
        labels selected through ``extension``.

    Raises:
        AssertionError: If the magic number is not 2049.
        IndexError: If the file or ``extension`` has fewer than ``length``
            entries, or an entry of ``extension`` is out of range.
    """
    with open(path, 'rb') as f:
        data = f.read()
    assert get_int(data[:4]) == 2049
    length = get_int(data[4:8])
    parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
    # np.int64 replaces the ``np.long`` alias, which NumPy 1.24 removed; the
    # arrays are converted to int64 tensors below either way.
    multi_labels_l = np.zeros(length, dtype=np.int64)
    multi_labels_r = np.zeros(length, dtype=np.int64)
    for im_id in range(length):
        multi_labels_l[im_id] = parsed[im_id]
        multi_labels_r[im_id] = parsed[extension[im_id]]
    return (
        torch.from_numpy(parsed).view(length).long(),
        torch.from_numpy(multi_labels_l).view(length).long(),
        torch.from_numpy(multi_labels_r).view(length).long(),
    )
示例13: read_image_file
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def read_image_file(path):
    """Read a binary image file and build one overlapped image per original.

    The first four 4-byte header fields are decoded with ``get_int``: a
    magic number, which must equal 2051, then the number of images, rows and
    columns.  Every byte after the 16-byte header is one pixel.

    For each image a partner image is drawn at random with
    ``np.random.permutation``.  Both are pasted onto a 36x36 canvas, the
    partner shifted by 6 pixels in each direction, taking the element-wise
    maximum where they overlap; the canvas is then resized to 28x28.  The
    canvas arithmetic assumes 28x28 input images.

    Args:
        path: Location of the image file.

    Returns:
        A 3-tuple ``(images, multi_images, extension)``:
        ``images`` is a uint8 tensor of shape ``(length, num_rows,
        num_cols)`` with the pixels as stored in the file; ``multi_images``
        is a tensor of the same shape built from a float64 array holding the
        overlapped images; ``extension`` is an int32 NumPy array where
        ``extension[i]`` is the index of the partner chosen for image ``i``.

    Raises:
        AssertionError: If the magic number is not 2051.
    """
    with open(path, 'rb') as f:
        data = f.read()
    assert get_int(data[:4]) == 2051
    length = get_int(data[4:8])
    num_rows = get_int(data[8:12])
    num_cols = get_int(data[12:16])
    parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
    pv = parsed.reshape(length, num_rows, num_cols)
    multi_data = np.zeros((length, num_rows, num_cols))
    extension = np.zeros(length, dtype=np.int32)
    for left in range(length):
        # Exactly one partner per image, kept as a length-1 array so the
        # random draws stay identical to the original implementation.
        chosen_ones = np.random.permutation(length)[:1]
        extension[left:left + 1] = chosen_ones
        for j, right in enumerate(chosen_ones):
            lim = pv[left]
            rim = pv[right]
            # Left image in the top-left corner, partner offset by 6 pixels;
            # the shared 22x22 region keeps the brighter of the two pixels.
            new_im = np.zeros((36, 36))
            new_im[0:28, 0:28] = lim
            new_im[6:34, 6:34] = rim
            new_im[6:28, 6:28] = np.maximum(lim[6:28, 6:28], rim[0:22, 0:22])
            # NOTE(review): ``m`` is defined elsewhere in the file --
            # presumably ``scipy.misc``, whose ``imresize`` no longer exists
            # in current SciPy releases; confirm.
            multi_data[left + j] = m.imresize(new_im, (28, 28), interp='nearest')
    return (
        torch.from_numpy(parsed).view(length, num_rows, num_cols),
        torch.from_numpy(multi_data).view(length, num_rows, num_cols),
        extension,
    )
示例14: read_label_file
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def read_label_file(path):
    """Read a binary label file and return the labels as an int64 tensor.

    The first two 4-byte header fields are decoded with ``get_int``: a magic
    number, which must equal 2049, and the number of labels.  Every byte
    after the 8-byte header is one label.

    Args:
        path: Location of the label file.

    Returns:
        One-dimensional int64 tensor with ``length`` entries.
    """
    with open(path, 'rb') as stream:
        raw = stream.read()
    magic = get_int(raw[:4])
    assert magic == 2049
    count = get_int(raw[4:8])
    # Interpret everything past the header as unsigned bytes, no copy made.
    label_bytes = np.frombuffer(raw, dtype=np.uint8, offset=8)
    labels = torch.from_numpy(label_bytes)
    return labels.view(count).long()
示例15: read_image_file
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import read [as 别名]
def read_image_file(path):
    """Read a binary image file and return its pixels as a uint8 tensor.

    The first four 4-byte header fields are decoded with ``get_int``: a
    magic number, which must equal 2051, then the number of images, rows and
    columns.  Every byte after the 16-byte header is one pixel.

    Args:
        path: Location of the image file.

    Returns:
        uint8 tensor of shape ``(length, num_rows, num_cols)``.

    Raises:
        AssertionError: If the magic number is not 2051.
    """
    with open(path, 'rb') as f:
        data = f.read()
    assert get_int(data[:4]) == 2051
    length = get_int(data[4:8])
    num_rows = get_int(data[8:12])
    num_cols = get_int(data[12:16])
    # NOTE(review): the array is a view onto the bytes object, so the tensor
    # shares that memory -- presumably callers only read it; confirm.
    parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
    return torch.from_numpy(parsed).view(length, num_rows, num_cols)