Skip to content

Commit

Permalink
Refactor algorithm tests to take transport parameter
Browse files Browse the repository at this point in the history
This is required to run the same algorithm tests against multiple
transports. The test suite is comprised of both class algorithms and
function algorithms. The former require fixed buffers, which are only
supported on the tcp backend (and will be phased out). The latter
require the unbound buffers, which are supported on both the tcp and
uv (not yet merged) transports. This means we can't set the transport
globally and forget about it, since it depends on the test whether or
not it works with the selected transport.

Instead, by parameterizing the tests we allow for multiple transports
to be tested in the same test run. If a transport is not
available (e.g. not compiled) then those tests can be skipped.

Side notes:

* There are a series of small modifications included here as well,
  including more consistent variable naming, and a more consistent
  usage of test parameterization generator functions.
* Because this commit touches many lines anyway, I also ran
  clang-format on all files under `gloo/test` to make the style
  consistent.

ghstack-source-id: 46122ee90d31cd486e5fba23165d0b1d3f952a42
Pull Request resolved: #208
  • Loading branch information
pietern committed Aug 14, 2019
1 parent 21a32c1 commit dfb770d
Show file tree
Hide file tree
Showing 21 changed files with 1,104 additions and 1,099 deletions.
96 changes: 45 additions & 51 deletions gloo/test/allgather_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,19 @@ namespace test {
namespace {

// Test parameterization.
using Param = std::tuple<int, int, int>;
using Param = std::tuple<Transport, int, int, int>;

// Test fixture.
class AllgatherTest : public BaseTest,
public ::testing::WithParamInterface<Param> {};

TEST_P(AllgatherTest, VarNumPointer) {
auto contextSize = std::get<0>(GetParam());
auto dataSize = std::get<1>(GetParam());
auto numPtrs = std::get<2>(GetParam());
const auto transport = std::get<0>(GetParam());
const auto contextSize = std::get<0>(GetParam());
const auto dataSize = std::get<1>(GetParam());
const auto numPtrs = std::get<2>(GetParam());

spawn(contextSize, [&](std::shared_ptr<Context> context) {
spawn(transport, contextSize, [&](std::shared_ptr<Context> context) {
Fixture<float> inFixture(context, numPtrs, dataSize);
inFixture.assignValues();

Expand Down Expand Up @@ -59,11 +60,12 @@ TEST_P(AllgatherTest, VarNumPointer) {
}

TEST_F(AllgatherTest, MultipleAlgorithms) {
auto contextSize = 4;
auto dataSize = 1000;
auto numPtrs = 8;
const auto transport = Transport::TCP;
const auto contextSize = 4;
const auto dataSize = 1000;
const auto numPtrs = 8;

spawn(contextSize, [&](std::shared_ptr<Context> context) {
spawn(transport, contextSize, [&](std::shared_ptr<Context> context) {
Fixture<float> inFixture(context, numPtrs, dataSize);
inFixture.assignValues();

Expand Down Expand Up @@ -92,82 +94,74 @@ TEST_F(AllgatherTest, MultipleAlgorithms) {
});
}

std::vector<int> genMemorySizes() {
std::vector<int> v;
v.push_back(sizeof(float));
v.push_back(100);
v.push_back(1000);
v.push_back(10000);
return v;
}

INSTANTIATE_TEST_CASE_P(
AllgatherRing,
AllgatherTest,
::testing::Combine(
::testing::ValuesIn(kTransportsForClassAlgorithms),
::testing::Range(2, 10),
::testing::ValuesIn(genMemorySizes()),
::testing::Values(4, 100, 1000, 10000),
::testing::Range(1, 4)));

using NewParam = std::tuple<int, int, bool>;
using NewParam = std::tuple<Transport, int, int, bool>;

class AllgatherNewTest : public BaseTest,
public ::testing::WithParamInterface<NewParam> {};

TEST_P(AllgatherNewTest, Default) {
auto contextSize = std::get<0>(GetParam());
auto dataSize = std::get<1>(GetParam());
auto passBuffers = std::get<2>(GetParam());
const auto transport = std::get<0>(GetParam());
const auto contextSize = std::get<1>(GetParam());
const auto dataSize = std::get<2>(GetParam());
const auto passBuffers = std::get<3>(GetParam());

auto validate = [dataSize](
const std::shared_ptr<Context>& context,
Fixture<uint64_t>& output) {
const std::shared_ptr<Context>& context,
Fixture<uint64_t>& output) {
const auto ptr = output.getPointer();
const auto stride = context->size;
for (auto j = 0; j < context->size; j++) {
for (auto k = 0; k < dataSize; k++) {
ASSERT_EQ(j + k * stride, ptr[k + j * dataSize])
<< "Mismatch at index " << (k + j * dataSize);
<< "Mismatch at index " << (k + j * dataSize);
}
}
};

spawn(contextSize, [&](std::shared_ptr<Context> context) {
auto input = Fixture<uint64_t>(context, 1, dataSize);
auto output = Fixture<uint64_t>(context, 1, contextSize * dataSize);

AllgatherOptions opts(context);

if (passBuffers) {
// Run with (optionally cached) unbound buffers in options
opts.setInput<uint64_t>(context->createUnboundBuffer(
input.getPointer(),
dataSize * sizeof(uint64_t)));
opts.setOutput<uint64_t>(context->createUnboundBuffer(
output.getPointer(),
contextSize * dataSize * sizeof(uint64_t)));
} else {
// Run with raw pointers and sizes in options
opts.setInput(input.getPointer(), dataSize);
opts.setOutput(output.getPointer(), contextSize * dataSize);
}
spawn(transport, contextSize, [&](std::shared_ptr<Context> context) {
auto input = Fixture<uint64_t>(context, 1, dataSize);
auto output = Fixture<uint64_t>(context, 1, contextSize * dataSize);

AllgatherOptions opts(context);

if (passBuffers) {
// Run with (optionally cached) unbound buffers in options
opts.setInput<uint64_t>(context->createUnboundBuffer(
input.getPointer(), dataSize * sizeof(uint64_t)));
opts.setOutput<uint64_t>(context->createUnboundBuffer(
output.getPointer(), contextSize * dataSize * sizeof(uint64_t)));
} else {
// Run with raw pointers and sizes in options
opts.setInput(input.getPointer(), dataSize);
opts.setOutput(output.getPointer(), contextSize * dataSize);
}

input.assignValues();
allgather(opts);
validate(context, output);
});
input.assignValues();
allgather(opts);
validate(context, output);
});
}

INSTANTIATE_TEST_CASE_P(
AllgatherNewDefault,
AllgatherNewTest,
::testing::Combine(
::testing::ValuesIn(kTransportsForFunctionAlgorithms),
::testing::Values(1, 2, 4, 7),
::testing::ValuesIn(genMemorySizes()),
::testing::Values(4, 100, 1000, 10000),
::testing::Values(false, true)));

TEST_F(AllgatherNewTest, TestTimeout) {
spawn(2, [&](std::shared_ptr<Context> context) {
spawn(Transport::TCP, 2, [&](std::shared_ptr<Context> context) {
Fixture<uint64_t> input(context, 1, 1);
Fixture<uint64_t> output(context, 1, context->size);
AllgatherOptions opts(context);
Expand Down
14 changes: 8 additions & 6 deletions gloo/test/allgatherv_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,18 @@ namespace gloo {
namespace test {
namespace {

using Param = std::tuple<int, int, bool>;
using Param = std::tuple<Transport, int, int, bool>;

class AllgathervTest : public BaseTest,
public ::testing::WithParamInterface<Param> {};

TEST_P(AllgathervTest, Default) {
auto contextSize = std::get<0>(GetParam());
auto dataSize = std::get<1>(GetParam());
auto passBuffers = std::get<2>(GetParam());
const auto transport = std::get<0>(GetParam());
const auto contextSize = std::get<1>(GetParam());
const auto dataSize = std::get<2>(GetParam());
const auto passBuffers = std::get<3>(GetParam());

spawn(contextSize, [&](std::shared_ptr<Context> context) {
spawn(transport, contextSize, [&](std::shared_ptr<Context> context) {
// This test uses the same output size for every iteration,
// but assigns different counts to different ranks.
std::vector<uint64_t> output(
Expand Down Expand Up @@ -84,12 +85,13 @@ INSTANTIATE_TEST_CASE_P(
AllgathervDefault,
AllgathervTest,
::testing::Combine(
::testing::ValuesIn(kTransportsForFunctionAlgorithms),
::testing::Values(1, 2, 4, 7),
::testing::Values(1, 10, 100, 1000),
::testing::Values(false, true)));

TEST_F(AllgathervTest, TestTimeout) {
spawn(2, [&](std::shared_ptr<Context> context) {
spawn(Transport::TCP, 2, [&](std::shared_ptr<Context> context) {
Fixture<uint64_t> output(context, 1, context->size);
std::vector<size_t> counts({1, 1});
AllgathervOptions opts(context);
Expand Down
Loading

0 comments on commit dfb770d

Please sign in to comment.