本文整理汇总了Python中torch.FloatStorage类的典型用法代码示例。如果您正苦于以下问题：Python torch.FloatStorage的具体用法？Python torch.FloatStorage怎么用？Python torch.FloatStorage使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块torch的用法示例。
在下文中一共展示了torch.FloatStorage类的13个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_from_file
# 需要导入模块: import torch [as 别名]
# 或者: from torch import FloatStorage [as 别名]
def test_from_file(self):
size = 10000
with tempfile.NamedTemporaryFile() as f:
s1 = torch.FloatStorage.from_file(f.name, True, size)
t1 = torch.FloatTensor(s1).copy_(torch.randn(size))
# check mapping
s2 = torch.FloatStorage.from_file(f.name, True, size)
t2 = torch.FloatTensor(s2)
self.assertEqual(t1, t2, 0)
# check changes to t1 from t2
rnum = random.uniform(-1, 1)
t1.fill_(rnum)
self.assertEqual(t1, t2, 0)
# check changes to t2 from t1
rnum = random.uniform(-1, 1)
t2.fill_(rnum)
self.assertEqual(t1, t2, 0)
示例2: test_serialization
# 需要导入模块: import torch [as 别名]
# 或者: from torch import FloatStorage [as 别名]
def test_serialization(self):
    """Round-trip tensors and storages through ``torch.save``/``torch.load``.

    The saved list holds two tensors (each referenced twice), the full
    storage of the first tensor, and a 3-element slice of that storage. It is
    saved once through the open file object and once by file name. The loaded
    copy must keep the same types and the same memory sharing.
    """
    a = [torch.randn(5, 5).float() for _ in range(2)]
    # b[0..3] are a[0], a[1], a[0], a[1]: repeated references to the same tensors.
    b = [a[i % 2] for i in range(4)]
    b += [a[0].storage()]        # b[4]: the 25-element storage behind a[0]
    b += [a[0].storage()[1:4]]   # b[5]: elements 1..3 of that storage
    for use_name in (False, True):
        with tempfile.NamedTemporaryFile() as f:
            handle = f if not use_name else f.name
            torch.save(b, handle)
            # Rewind so torch.load can read from the start when the file
            # object itself is the handle.
            f.seek(0)
            c = torch.load(handle)
            self.assertEqual(b, c, 0)
            self.assertTrue(isinstance(c[0], torch.FloatTensor))
            self.assertTrue(isinstance(c[1], torch.FloatTensor))
            self.assertTrue(isinstance(c[2], torch.FloatTensor))
            self.assertTrue(isinstance(c[3], torch.FloatTensor))
            self.assertTrue(isinstance(c[4], torch.FloatStorage))
            # Writes through c[0] must show up in its alias c[2] and in the
            # storage c[4] if sharing survived the round trip.
            c[0].fill_(10)
            self.assertEqual(c[0], c[2], 0)
            self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
            c[1].fill_(20)
            self.assertEqual(c[1], c[3], 0)
            # c[5] holds elements 1..3 of c[4], so compare it with that same
            # slice of c[4]. Slicing c[5] again would compare 25 elements
            # against 2.
            self.assertEqual(c[4][1:4], c[5], 0)
示例3: test_serialization_backwards_compat
# 需要导入模块: import torch [as 别名]
# 或者: from torch import FloatStorage [as 别名]
def test_serialization_backwards_compat(self):
    """Load a legacy-format file and compare it against freshly built data.

    The expected list holds two deterministic tensors (each referenced
    twice), the storage of the first tensor, and a 3-element slice of that
    storage. The downloaded file must load to an equal list with the same
    types and memory sharing.
    """
    expected_tensors = [
        torch.arange(1 + offset, 26 + offset).view(5, 5).float()
        for offset in range(2)
    ]
    expected = [expected_tensors[idx % 2] for idx in range(4)]
    expected.append(expected_tensors[0].storage())
    expected.append(expected_tensors[0].storage()[1:4])
    legacy_path = download_file('https://download.pytorch.org/test_data/legacy_serialized.pt')
    loaded = torch.load(legacy_path)
    self.assertEqual(expected, loaded, 0)
    for idx in range(4):
        self.assertTrue(isinstance(loaded[idx], torch.FloatTensor))
    self.assertTrue(isinstance(loaded[4], torch.FloatStorage))
    # Writes through one loaded tensor must be visible through its aliases.
    loaded[0].fill_(10)
    self.assertEqual(loaded[0], loaded[2], 0)
    self.assertEqual(loaded[4], torch.FloatStorage(25).fill_(10), 0)
    loaded[1].fill_(20)
    self.assertEqual(loaded[1], loaded[3], 0)
    # loaded[5] is elements 1..3 of loaded[4].
    self.assertEqual(loaded[4][1:4], loaded[5], 0)
示例4: test_type_conversions
# 需要导入模块: import torch [as 别名]
# 或者: from torch import FloatStorage [as 别名]
def test_type_conversions(self):
    """Check the concrete classes produced by dtype and CPU/CUDA conversions.

    The same conversion chain is checked on a tensor and then on its storage.
    """
    def check_chain(obj, cpu_float, cuda_double, cuda_float, cpu_int):
        # Each conversion chain is rebuilt from ``obj`` for every assertion.
        self.assertIs(type(obj.float()), cpu_float)
        self.assertIs(type(obj.cuda()), cuda_double)
        self.assertIs(type(obj.cuda().float()), cuda_float)
        self.assertIs(type(obj.cuda().float().cpu()), cpu_float)
        self.assertIs(type(obj.cuda().float().cpu().int()), cpu_int)

    source = torch.randn(5, 5)
    check_chain(source, torch.FloatTensor, torch.cuda.DoubleTensor,
                torch.cuda.FloatTensor, torch.IntTensor)
    check_chain(source.storage(), torch.FloatStorage, torch.cuda.DoubleStorage,
                torch.cuda.FloatStorage, torch.IntStorage)
示例5: assign
# 需要导入模块: import torch [as 别名]
# 或者: from torch import FloatStorage [as 别名]
def assign():
    """Assign a str into a slice of a 10-element FloatStorage.

    Presumably used as the callable for an assertRaises-style check (the
    assignment is expected to fail); TODO confirm against the caller.
    """
    storage = torch.FloatStorage(10)
    storage[1:-1] = '1'
示例6: test_from_buffer
# 需要导入模块: import torch [as 别名]
# 或者: from torch import FloatStorage [as 别名]
def test_from_buffer(self):
a = bytearray([1, 2, 3, 4])
self.assertEqual(torch.ByteStorage.from_buffer(a).tolist(), [1, 2, 3, 4])
shorts = torch.ShortStorage.from_buffer(a, 'big')
self.assertEqual(shorts.size(), 2)
self.assertEqual(shorts.tolist(), [258, 772])
ints = torch.IntStorage.from_buffer(a, 'little')
self.assertEqual(ints.size(), 1)
self.assertEqual(ints[0], 67305985)
f = bytearray([0x40, 0x10, 0x00, 0x00])
floats = torch.FloatStorage.from_buffer(f, 'big')
self.assertEqual(floats.size(), 1)
self.assertEqual(floats[0], 2.25)
示例7: test_element_size
# 需要导入模块: import torch [as 别名]
# 或者: from torch import FloatStorage [as 别名]
def test_element_size(self):
byte = torch.ByteStorage().element_size()
char = torch.CharStorage().element_size()
short = torch.ShortStorage().element_size()
int = torch.IntStorage().element_size()
long = torch.LongStorage().element_size()
float = torch.FloatStorage().element_size()
double = torch.DoubleStorage().element_size()
self.assertEqual(byte, torch.ByteTensor().element_size())
self.assertEqual(char, torch.CharTensor().element_size())
self.assertEqual(short, torch.ShortTensor().element_size())
self.assertEqual(int, torch.IntTensor().element_size())
self.assertEqual(long, torch.LongTensor().element_size())
self.assertEqual(float, torch.FloatTensor().element_size())
self.assertEqual(double, torch.DoubleTensor().element_size())
self.assertGreater(byte, 0)
self.assertGreater(char, 0)
self.assertGreater(short, 0)
self.assertGreater(int, 0)
self.assertGreater(long, 0)
self.assertGreater(float, 0)
self.assertGreater(double, 0)
# These tests are portable, not necessarily strict for your system.
self.assertEqual(byte, 1)
self.assertEqual(char, 1)
self.assertGreaterEqual(short, 2)
self.assertGreaterEqual(int, 2)
self.assertGreaterEqual(int, short)
self.assertGreaterEqual(long, 4)
self.assertGreaterEqual(long, int)
self.assertGreaterEqual(double, float)
示例8: test_serialization
# 需要导入模块: import torch [as 别名]
# 或者: from torch import FloatStorage [as 别名]
def test_serialization(self):
    """Round-trip tensors, storages and storage views through save/load.

    The payload is saved once through the open file object and once by file
    name. After loading, the test checks entry types, aliasing between
    entries that shared memory before saving, and how repeated references to
    the same storage-view object are unpickled.
    """
    import os

    a = [torch.randn(5, 5).float() for _ in range(2)]
    # b[0..3] are a[0], a[1], a[0], a[1]: repeated references to the same tensors.
    b = [a[i % 2] for i in range(4)]
    b += [a[0].storage()]        # b[4]: the 25-element storage behind a[0]
    b += [a[0].storage()[1:4]]   # b[5]: elements 1..3 of that storage
    b += [torch.arange(1, 11).int()]
    # t1 and t2 are built on two separate view objects of the same slice.
    t1 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
    t2 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
    b += [(t1.storage(), t1.storage(), t2.storage())]
    b += [a[0].storage()[0:2]]
    for use_name in (False, True):
        with tempfile.NamedTemporaryFile(delete=True) as f:
            handle = f if not use_name else f.name
            scratch_path = None
            if sys.platform == 'win32' and use_name:
                # Windows cannot reopen the still-open temporary file by name,
                # so round-trip through a separate scratch file instead.
                # mkstemp creates it atomically (unlike the deprecated mktemp),
                # and it is removed below so nothing is left behind.
                fd, scratch_path = tempfile.mkstemp()
                os.close(fd)
                handle = scratch_path
            try:
                torch.save(b, handle)
                # Rewind so torch.load can read from the start when the file
                # object itself is the handle.
                f.seek(0)
                c = torch.load(handle)
            finally:
                if scratch_path is not None:
                    os.remove(scratch_path)
            self.assertEqual(b, c, 0)
            self.assertTrue(isinstance(c[0], torch.FloatTensor))
            self.assertTrue(isinstance(c[1], torch.FloatTensor))
            self.assertTrue(isinstance(c[2], torch.FloatTensor))
            self.assertTrue(isinstance(c[3], torch.FloatTensor))
            self.assertTrue(isinstance(c[4], torch.FloatStorage))
            c[0].fill_(10)
            self.assertEqual(c[0], c[2], 0)
            self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
            c[1].fill_(20)
            self.assertEqual(c[1], c[3], 0)
            # c[5] holds elements 1..3 of c[4], so compare it with that same
            # slice of c[4]. Slicing c[5] again would compare 25 elements
            # against 2.
            self.assertEqual(c[4][1:4], c[5], 0)
            # check that serializing the same storage view object unpickles
            # it as one object not two (and vice versa)
            views = c[7]
            self.assertEqual(views[0]._cdata, views[1]._cdata)
            self.assertEqual(views[0], views[2])
            self.assertNotEqual(views[0]._cdata, views[2]._cdata)
            # A slice starting at offset 0 must share its data pointer with c[0].
            rootview = c[8]
            self.assertEqual(rootview.data_ptr(), c[0].data_ptr())
示例9: test_serialization_backwards_compat
# 需要导入模块: import torch [as 别名]
# 或者: from torch import FloatStorage [as 别名]
def test_serialization_backwards_compat(self):
    """Load a legacy-format file and compare it with freshly built data.

    ``download_file`` is expected to fetch DATA_URL into a ``data`` directory
    next to this source file and to return a truthy value on success. If it
    fails, a RuntimeWarning is emitted and the test returns early instead of
    failing.
    """
    a = [torch.arange(1 + i, 26 + i).view(5, 5).float() for i in range(2)]
    # b[0..3] are a[0], a[1], a[0], a[1]: repeated references to the same tensors.
    b = [a[i % 2] for i in range(4)]
    b += [a[0].storage()]        # b[4]: the 25-element storage behind a[0]
    b += [a[0].storage()[1:4]]   # b[5]: elements 1..3 of that storage
    DATA_URL = 'https://download.pytorch.org/test_data/legacy_serialized.pt'
    data_dir = os.path.join(os.path.dirname(__file__), 'data')
    test_file_path = os.path.join(data_dir, 'legacy_serialized.pt')
    succ = download_file(DATA_URL, test_file_path)
    if not succ:
        warnings.warn(("Couldn't download the test file for backwards compatibility! "
                       "Tests will be incomplete!"), RuntimeWarning)
        return
    c = torch.load(test_file_path)
    self.assertEqual(b, c, 0)
    self.assertTrue(isinstance(c[0], torch.FloatTensor))
    self.assertTrue(isinstance(c[1], torch.FloatTensor))
    self.assertTrue(isinstance(c[2], torch.FloatTensor))
    self.assertTrue(isinstance(c[3], torch.FloatTensor))
    self.assertTrue(isinstance(c[4], torch.FloatStorage))
    # Writes through c[0] must show up in its alias c[2] and in the storage c[4].
    c[0].fill_(10)
    self.assertEqual(c[0], c[2], 0)
    self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
    c[1].fill_(20)
    self.assertEqual(c[1], c[3], 0)
    # c[5] holds elements 1..3 of c[4], so compare it with that same slice of
    # c[4]. Slicing c[5] again would compare 25 elements against 2.
    self.assertEqual(c[4][1:4], c[5], 0)
示例10: test_serialization
# 需要导入模块: import torch [as 别名]
# 或者: from torch import FloatStorage [as 别名]
def test_serialization(self):
    """Round-trip tensors, storages and storage views through save/load.

    The payload is saved once through the open file object and once by file
    name. After loading, the test checks entry types, aliasing between
    entries that shared memory before saving, and how repeated references to
    the same storage-view object are unpickled.
    """
    a = [torch.randn(5, 5).float() for _ in range(2)]
    # b[0..3] are a[0], a[1], a[0], a[1]: repeated references to the same tensors.
    b = [a[i % 2] for i in range(4)]
    b += [a[0].storage()]        # b[4]: the 25-element storage behind a[0]
    b += [a[0].storage()[1:4]]   # b[5]: elements 1..3 of that storage
    b += [torch.arange(1, 11).int()]
    # t1 and t2 are built on two separate view objects of the same slice.
    t1 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
    t2 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
    b += [(t1.storage(), t1.storage(), t2.storage())]
    b += [a[0].storage()[0:2]]
    for use_name in (False, True):
        with tempfile.NamedTemporaryFile() as f:
            handle = f if not use_name else f.name
            torch.save(b, handle)
            # Rewind so torch.load can read from the start when the file
            # object itself is the handle.
            f.seek(0)
            c = torch.load(handle)
            self.assertEqual(b, c, 0)
            self.assertTrue(isinstance(c[0], torch.FloatTensor))
            self.assertTrue(isinstance(c[1], torch.FloatTensor))
            self.assertTrue(isinstance(c[2], torch.FloatTensor))
            self.assertTrue(isinstance(c[3], torch.FloatTensor))
            self.assertTrue(isinstance(c[4], torch.FloatStorage))
            c[0].fill_(10)
            self.assertEqual(c[0], c[2], 0)
            self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
            c[1].fill_(20)
            self.assertEqual(c[1], c[3], 0)
            # c[5] holds elements 1..3 of c[4], so compare it with that same
            # slice of c[4]. Slicing c[5] again would compare 25 elements
            # against 2.
            self.assertEqual(c[4][1:4], c[5], 0)
            # check that serializing the same storage view object unpickles
            # it as one object not two (and vice versa)
            views = c[7]
            self.assertEqual(views[0]._cdata, views[1]._cdata)
            self.assertEqual(views[0], views[2])
            self.assertNotEqual(views[0]._cdata, views[2]._cdata)
            # A slice starting at offset 0 must share its data pointer with c[0].
            rootview = c[8]
            self.assertEqual(rootview.data_ptr(), c[0].data_ptr())
示例11: test_serialization
# 需要导入模块: import torch [as 别名]
# 或者: from torch import FloatStorage [as 别名]
def test_serialization(self):
    """Round-trip tensors, storages and storage views through save/load.

    The payload is written through the open file object and then by file
    name. Each time, the loaded copy is checked for entry types, aliasing
    between entries that shared memory, and identity of repeated storage
    views.
    """
    originals = [torch.randn(5, 5).float() for _ in range(2)]
    payload = [originals[idx % 2] for idx in range(4)]
    payload.append(originals[0].storage())
    payload.append(originals[0].storage()[1:4])
    payload.append(torch.arange(1, 11).int())
    first_view = torch.FloatTensor().set_(originals[0].storage()[1:4], 0, (3,), (1,))
    second_view = torch.FloatTensor().set_(originals[0].storage()[1:4], 0, (3,), (1,))
    payload.append((first_view.storage(), first_view.storage(), second_view.storage()))
    payload.append(originals[0].storage()[0:2])
    for use_name in (False, True):
        with tempfile.NamedTemporaryFile() as tmp:
            target = tmp.name if use_name else tmp
            torch.save(payload, target)
            tmp.seek(0)
            loaded = torch.load(target)
            self.assertEqual(payload, loaded, 0)
            for idx in range(4):
                self.assertTrue(isinstance(loaded[idx], torch.FloatTensor))
            self.assertTrue(isinstance(loaded[4], torch.FloatStorage))
            # Writes through one loaded tensor must reach its aliases.
            loaded[0].fill_(10)
            self.assertEqual(loaded[0], loaded[2], 0)
            self.assertEqual(loaded[4], torch.FloatStorage(25).fill_(10), 0)
            loaded[1].fill_(20)
            self.assertEqual(loaded[1], loaded[3], 0)
            self.assertEqual(loaded[4][1:4], loaded[5], 0)
            # The same view object saved twice must load as a single object,
            # while a distinct view of the same memory loads separately.
            view_triple = loaded[7]
            self.assertEqual(view_triple[0]._cdata, view_triple[1]._cdata)
            self.assertEqual(view_triple[0], view_triple[2])
            self.assertNotEqual(view_triple[0]._cdata, view_triple[2]._cdata)
            # A slice starting at offset 0 shares its data pointer with loaded[0].
            self.assertEqual(loaded[8].data_ptr(), loaded[0].data_ptr())
示例12: test_element_size
# 需要导入模块: import torch [as 别名]
# 或者: from torch import FloatStorage [as 别名]
def test_element_size(self):
byte = torch.ByteStorage().element_size()
char = torch.CharStorage().element_size()
short = torch.ShortStorage().element_size()
int = torch.IntStorage().element_size()
long = torch.LongStorage().element_size()
float = torch.FloatStorage().element_size()
double = torch.DoubleStorage().element_size()
self.assertEqual(byte, torch.ByteTensor().element_size())
self.assertEqual(char, torch.CharTensor().element_size())
self.assertEqual(short, torch.ShortTensor().element_size())
self.assertEqual(int, torch.IntTensor().element_size())
self.assertEqual(long, torch.LongTensor().element_size())
self.assertEqual(float, torch.FloatTensor().element_size())
self.assertEqual(double, torch.DoubleTensor().element_size())
self.assertGreater(byte, 0)
self.assertGreater(char, 0)
self.assertGreater(short, 0)
self.assertGreater(int, 0)
self.assertGreater(long, 0)
self.assertGreater(float, 0)
self.assertGreater(double, 0)
# These tests are portable, not necessarily strict for your system.
self.assertEqual(byte, 1)
self.assertEqual(char, 1)
self.assertGreaterEqual(short, 2)
self.assertGreaterEqual(int, 2)
self.assertGreaterEqual(int, short)
self.assertGreaterEqual(long, 4)
self.assertGreaterEqual(long, int)
self.assertGreaterEqual(double, float)
示例13: test_serialization
# 需要导入模块: import torch [as 别名]
# 或者: from torch import FloatStorage [as 别名]
def test_serialization(self):
    """Round-trip tensors, storages and storage views through save/load.

    The payload is written through the open file object and, except on
    Windows, by file name. Each time, the loaded copy is checked for entry
    types, aliasing between entries that shared memory, and identity of
    repeated storage views.
    """
    originals = [torch.randn(5, 5).float() for _ in range(2)]
    payload = [originals[idx % 2] for idx in range(4)]
    payload.append(originals[0].storage())
    payload.append(originals[0].storage()[1:4])
    payload.append(torch.arange(1, 11).int())
    first_view = torch.FloatTensor().set_(originals[0].storage()[1:4], 0, (3,), (1,))
    second_view = torch.FloatTensor().set_(originals[0].storage()[1:4], 0, (3,), (1,))
    payload.append((first_view.storage(), first_view.storage(), second_view.storage()))
    payload.append(originals[0].storage()[0:2])
    for use_name in (False, True):
        # Saving by file name reopens the still-open temporary file, which
        # Windows does not support.
        if use_name and sys.platform == "win32":
            continue
        with tempfile.NamedTemporaryFile() as tmp:
            target = tmp.name if use_name else tmp
            torch.save(payload, target)
            tmp.seek(0)
            loaded = torch.load(target)
            self.assertEqual(payload, loaded, 0)
            for idx in range(4):
                self.assertTrue(isinstance(loaded[idx], torch.FloatTensor))
            self.assertTrue(isinstance(loaded[4], torch.FloatStorage))
            # Writes through one loaded tensor must reach its aliases.
            loaded[0].fill_(10)
            self.assertEqual(loaded[0], loaded[2], 0)
            self.assertEqual(loaded[4], torch.FloatStorage(25).fill_(10), 0)
            loaded[1].fill_(20)
            self.assertEqual(loaded[1], loaded[3], 0)
            self.assertEqual(loaded[4][1:4], loaded[5], 0)
            # The same view object saved twice must load as a single object,
            # while a distinct view of the same memory loads separately.
            view_triple = loaded[7]
            self.assertEqual(view_triple[0]._cdata, view_triple[1]._cdata)
            self.assertEqual(view_triple[0], view_triple[2])
            self.assertNotEqual(view_triple[0]._cdata, view_triple[2]._cdata)
            # A slice starting at offset 0 shares its data pointer with loaded[0].
            self.assertEqual(loaded[8].data_ptr(), loaded[0].data_ptr())