Skip to content

Commit

Permalink
add test to check zero_strided tensors in blas level 2 and 3 functions
Browse files Browse the repository at this point in the history
  • Loading branch information
albanD authored and soumith committed Oct 30, 2017
1 parent d04574b commit 1ae10a4
Showing 1 changed file with 98 additions and 0 deletions.
98 changes: 98 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,104 @@ def test_dot(self):
res2 += i * j
self.assertEqual(res1, res2)

# Test 0-strided
for tname, _prec in types.items():
v1 = torch.randn(1).type(tname).expand(100)
v2 = torch.randn(100).type(tname)
res1 = torch.dot(v1, v2)
res2 = 0
for i, j in zip(v1, v2):
res2 += i * j
self.assertEqual(res1, res2)

def test_ger(self):
types = {
'torch.DoubleTensor': 1e-8,
'torch.FloatTensor': 1e-4,
}
for tname, _prec in types.items():
v1 = torch.randn(100).type(tname)
v2 = torch.randn(100).type(tname)
res1 = torch.ger(v1, v2)
res2 = torch.zeros(100, 100).type(tname)
for i in range(100):
for j in range(100):
res2[i, j] = v1[i] * v2[j]
self.assertEqual(res1, res2)

# Test 0-strided
for tname, _prec in types.items():
v1 = torch.randn(1).type(tname).expand(100)
v2 = torch.randn(100).type(tname)
res1 = torch.ger(v1, v2)
res2 = torch.zeros(100, 100).type(tname)
for i in range(100):
for j in range(100):
res2[i, j] = v1[i] * v2[j]
self.assertEqual(res1, res2)

def test_addmv(self):
types = {
'torch.DoubleTensor': 1e-8,
'torch.FloatTensor': 1e-4,
}
for tname, _prec in types.items():
t = torch.randn(10).type(tname)
m = torch.randn(10, 100).type(tname)
v = torch.randn(100).type(tname)
res1 = torch.addmv(t, m, v)
res2 = torch.zeros(10).type(tname)
res2 += t
for i in range(10):
for j in range(100):
res2[i] += m[i, j] * v[j]
self.assertEqual(res1, res2)

# Test 0-strided
for tname, _prec in types.items():
t = torch.randn(1).type(tname).expand(10)
m = torch.randn(10, 1).type(tname).expand(10, 100)
v = torch.randn(100).type(tname)
res1 = torch.addmv(t, m, v)
res2 = torch.zeros(10).type(tname)
res2 += t
for i in range(10):
for j in range(100):
res2[i] += m[i, j] * v[j]
self.assertEqual(res1, res2)

def test_addmm(self):
types = {
'torch.DoubleTensor': 1e-8,
'torch.FloatTensor': 1e-4,
}
for tname, _prec in types.items():
M = torch.randn(10, 25).type(tname)
m1 = torch.randn(10, 50).type(tname)
m2 = torch.randn(50, 25).type(tname)
res1 = torch.addmm(M, m1, m2)
res2 = torch.zeros(10, 25).type(tname)
res2 += M
for i in range(10):
for j in range(25):
for k in range(50):
res2[i, j] += m1[i, k] * m2[k, j]
self.assertEqual(res1, res2)

# Test 0-strided
for tname, _prec in types.items():
M = torch.randn(10, 1).type(tname).expand(10, 25)
m1 = torch.randn(10, 1).type(tname).expand(10, 50)
m2 = torch.randn(50, 25).type(tname)
res1 = torch.addmm(M, m1, m2)
res2 = torch.zeros(10, 25).type(tname)
res2 += M
for i in range(10):
for j in range(25):
for k in range(50):
res2[i, j] += m1[i, k] * m2[k, j]
self.assertEqual(res1, res2)

def _testMath(self, torchfn, mathfn):
size = (10, 5)
# contiguous
Expand Down

0 comments on commit 1ae10a4

Please sign in to comment.