Skip to content

Commit

Permalink
test sparse dp, broadcast_coalesced, reduce_add_coalesced
Browse files Browse the repository at this point in the history
  • Loading branch information
ssnl authored and ezyang committed Oct 28, 2017
1 parent 01be4d6 commit 91a8d33
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 24 deletions.
106 changes: 82 additions & 24 deletions test/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ def make_tensor(t, *sizes):
return t(*sizes).copy_(torch.randn(*sizes))


def make_sparse_tensor(t, n, *sizes):
assert t.is_sparse
tensor = t()
i = tensor._indices()
i = i.new(len(sizes), n).copy_(
torch.cat([torch.LongTensor(1, n).random_(s) for s in sizes], 0))
v = tensor._values()
v = v.new(n).copy_(torch.randn(n))
return t(i, v, torch.Size(sizes))


def small_2d(t):
return make_tensor(t, S, S)

Expand Down Expand Up @@ -480,44 +491,43 @@ def test_broadcast_cpu(self):
def test_broadcast_gpu(self):
self._test_broadcast(torch.randn(5, 5))

@unittest.skipIf(torch.cuda.device_count() < 2, "only one GPU detected")
def test_broadcast_coalesced(self):
numel = 5
num_bytes = numel * 8
tensors = [
torch.randn(numel).long().cuda(),
torch.randn(numel).cuda(),
torch.randn(numel).long().cuda(),
torch.randn(numel).long().cuda(),
torch.randn(numel * 2).int().cuda(), # int is 2x shorter
torch.randn(numel).cuda(),
]

@staticmethod
def _test_broadcast_coalesced(self, tensors, buffer_size):
b_tensors = [comm.broadcast(t, (0, 1)) for t in tensors]
for (_, bt), t in zip(b_tensors, tensors):
self.assertEqual(bt.get_device(), 1)
self.assertEqual(bt, t)
self.assertIsInstance(bt, type(t))

bc_tensors = comm.broadcast_coalesced(tensors, (0, 1), buffer_size=num_bytes * 5 // 2)
bc_tensors = comm.broadcast_coalesced(tensors, (0, 1), buffer_size=buffer_size)
bc_tensors_t = list(zip(*bc_tensors))
self.assertEqual(b_tensors, bc_tensors_t)
for (_, bt), (_, bct) in zip(b_tensors, bc_tensors_t):
self.assertEqual(bt.get_device(), bct.get_device())
self.assertIsInstance(bct, type(bt))

@unittest.skipIf(torch.cuda.device_count() < 2, "only one GPU detected")
def test_reduce_add(self):
x = torch.randn(5, 5)
y = torch.randn(5, 5)
x_cuda = x.cuda(0)
y_cuda = y.cuda(1)
result = comm.reduce_add((x_cuda, y_cuda))
self.assertEqual(result.get_device(), 0)
self.assertEqual(result.cpu(), x + y)
def test_broadcast_coalesced(self):
numel = 5
num_bytes = numel * 8
tensors = [
make_sparse_tensor(torch.cuda.sparse.DoubleTensor, 1, 2, 3),
torch.randn(numel).long().cuda(),
torch.randn(numel).cuda(),
make_sparse_tensor(torch.cuda.sparse.DoubleTensor, 10, 2, 3),
make_sparse_tensor(torch.cuda.sparse.DoubleTensor, 5, 2, 3),
make_sparse_tensor(torch.cuda.sparse.LongTensor, 7, 3, 3),
make_sparse_tensor(torch.cuda.sparse.FloatTensor, 2, 2, 3),
torch.randn(numel).long().cuda(),
torch.randn(numel).long().cuda(),
make_sparse_tensor(torch.cuda.sparse.LongTensor, 3, 2, 7),
torch.randn(numel * 2).int().cuda(), # int is 2x shorter
torch.randn(numel).cuda(),
]
self._test_broadcast_coalesced(self, tensors, num_bytes * 5 // 2)

@unittest.skipIf(torch.cuda.device_count() < 2, "only one GPU detected")
def test_reduce_add_coalesced(self):
def test_broadcast_coalesced_dense_only(self):
numel = 5
num_bytes = numel * 8
tensors = [
Expand All @@ -528,6 +538,20 @@ def test_reduce_add_coalesced(self):
torch.randn(numel * 2).int().cuda(), # int is 2x shorter
torch.randn(numel).cuda(),
]
self._test_broadcast_coalesced(self, tensors, num_bytes * 5 // 2)

@unittest.skipIf(torch.cuda.device_count() < 2, "only one GPU detected")
def test_reduce_add(self):
x = torch.randn(5, 5)
y = torch.randn(5, 5)
x_cuda = x.cuda(0)
y_cuda = y.cuda(1)
result = comm.reduce_add((x_cuda, y_cuda))
self.assertEqual(result.get_device(), 0)
self.assertEqual(result.cpu(), x + y)

@staticmethod
def _test_reduce_add_coalesced(self, tensors, buffer_size):
dup_tensors = [tensors, list(map(lambda t: t.cuda(1), tensors))]

r_tensors = list(map(comm.reduce_add, zip(*dup_tensors)))
Expand All @@ -536,12 +560,46 @@ def test_reduce_add_coalesced(self):
self.assertEqual(r, t * 2)
self.assertIsInstance(r, type(t))

rc_tensors = comm.reduce_add_coalesced(dup_tensors, buffer_size=num_bytes * 5 // 2)
rc_tensors = comm.reduce_add_coalesced(dup_tensors, buffer_size=buffer_size)
self.assertEqual(r_tensors, rc_tensors)
for r, rc in zip(r_tensors, rc_tensors):
self.assertEqual(rc.get_device(), r.get_device())
self.assertIsInstance(rc, type(r))

@unittest.skipIf(torch.cuda.device_count() < 2, "only one GPU detected")
def test_reduce_add_coalesced(self):
numel = 5
num_bytes = numel * 8
tensors = [
make_sparse_tensor(torch.cuda.sparse.DoubleTensor, 1, 2, 3),
torch.randn(numel).long().cuda(),
torch.randn(numel).cuda(),
make_sparse_tensor(torch.cuda.sparse.DoubleTensor, 10, 2, 3),
make_sparse_tensor(torch.cuda.sparse.DoubleTensor, 5, 2, 3),
make_sparse_tensor(torch.cuda.sparse.LongTensor, 7, 3, 3),
make_sparse_tensor(torch.cuda.sparse.FloatTensor, 2, 2, 3),
torch.randn(numel).long().cuda(),
torch.randn(numel).long().cuda(),
make_sparse_tensor(torch.cuda.sparse.LongTensor, 3, 2, 7),
torch.randn(numel * 2).int().cuda(), # int is 2x shorter
torch.randn(numel).cuda(),
]
self._test_reduce_add_coalesced(self, tensors, num_bytes * 5 // 2)

@unittest.skipIf(torch.cuda.device_count() < 2, "only one GPU detected")
def test_reduce_add_coalesced_dense_only(self):
numel = 5
num_bytes = numel * 8
tensors = [
torch.randn(numel).long().cuda(),
torch.randn(numel).cuda(),
torch.randn(numel).long().cuda(),
torch.randn(numel).long().cuda(),
torch.randn(numel * 2).int().cuda(), # int is 2x shorter
torch.randn(numel).cuda(),
]
self._test_reduce_add_coalesced(self, tensors, num_bytes * 5 // 2)

def _test_scatter(self, input, chunk_sizes=None, dim=0):
if torch.cuda.device_count() < 2:
raise unittest.SkipTest("only one GPU detected")
Expand Down
27 changes: 27 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1518,6 +1518,33 @@ def test_data_parallel(self):
l = l.cuda()
out = dp.data_parallel(l, i)

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_data_parallel_sparse(self):
l = nn.Embedding(10, 5, sparse=True).cuda(1)
i = Variable(torch.LongTensor(20, 5).random_(0, 10).cuda(1))
expected_out = l(i)
loss = expected_out.sum()
loss.backward()
expected_grads = []
for param in l.parameters():
expected_grads.append(param.grad.clone())
dev_ids_list = [(0, 1), (1, 0)]
for dev_id in dev_ids_list:
with torch.cuda.device(dev_id[0]):
l.cuda()
l.zero_grad()
out = dp.data_parallel(l, i, dev_id)
loss = out.sum()
loss.backward()
self.assertEqual(out.get_device(), dev_id[0])
self.assertEqual(out.data, expected_out.data)
for expected, param in zip(expected_grads, l.parameters()):
self.assertEqual(param.grad.data, expected.data)

# Check for None device_ids
l = l.cuda()
out = dp.data_parallel(l, i)

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_data_parallel_nested_output(self):
def fn(input):
Expand Down

0 comments on commit 91a8d33

Please sign in to comment.