forked from dmlc/tl2cgen
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_compiler_param.cc
101 lines (94 loc) · 3.21 KB
/
test_compiler_param.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
/*!
* Copyright (c) 2023 by Contributors
* \file test_compiler_param.cc
* \author Hyunsu Cho
* \brief C++ tests for ingesting compiler parameters
*/
#include <string>
#include <vector>

#include <fmt/format.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <tl2cgen/compiler_param.h>
#include <tl2cgen/logging.h>
using namespace testing; // NOLINT(build/namespaces)
namespace tl2cgen::compiler {

// Every key in a well-formed JSON document is parsed into the matching
// CompilerParam field.
TEST(CompilerParam, Basic) {
  std::string const json_str = R"JSON(
{
"quantize": 1,
"parallel_comp": 100,
"native_lib_name": "predictor",
"annotate_in": "annotation.json",
"verbose": 3,
"code_folding_req": 1.0,
"dump_array_as_elf": 0
})JSON";
  CompilerParam const param = CompilerParam::ParseFromJSON(json_str.c_str());
  EXPECT_EQ(param.quantize, 1);
  EXPECT_EQ(param.parallel_comp, 100);
  EXPECT_EQ(param.native_lib_name, "predictor");
  EXPECT_EQ(param.annotate_in, "annotation.json");
  EXPECT_EQ(param.verbose, 3);
  // Floating-point field: use the ULP-tolerant comparison instead of exact ==.
  EXPECT_DOUBLE_EQ(param.code_folding_req, 1.0);
  EXPECT_EQ(param.dump_array_as_elf, 0);
}

// Unrecognized keys are rejected with an error that names the offending key.
// Both a scalar value and a nested object are checked.
TEST(CompilerParam, NonExistentKey) {
  std::string json_str = R"JSON(
{
"quantize": 1,
"parallel_comp": 100,
"nonexistent": 0.3
})JSON";
  EXPECT_THAT([&]() { CompilerParam::ParseFromJSON(json_str.c_str()); },
              ThrowsMessage<tl2cgen::Error>(HasSubstr("Unrecognized key 'nonexistent'")));

  json_str = R"JSON(
{
"quantize": 1,
"parallel_comp": 100,
"extra_object": {
"extra": 30
}
})JSON";
  EXPECT_THAT([&]() { CompilerParam::ParseFromJSON(json_str.c_str()); },
              ThrowsMessage<tl2cgen::Error>(HasSubstr("Unrecognized key 'extra_object'")));
}

// A value of the wrong JSON type yields a type-specific error message.
// A document that is not valid JSON at all is reported as invalid.
TEST(CompilerParam, IncorrectType) {
  // String where an integer is expected.
  std::string json_str = R"JSON(
{
"quantize": "bad_type"
})JSON";
  EXPECT_THAT([&]() { CompilerParam::ParseFromJSON(json_str.c_str()); },
              ThrowsMessage<tl2cgen::Error>(HasSubstr("Expected an integer for 'quantize'")));

  // Object where a floating-point number is expected.
  json_str = R"JSON(
{
"code_folding_req": {
"bad_type": 30
}
})JSON";
  EXPECT_THAT([&]() { CompilerParam::ParseFromJSON(json_str.c_str()); },
              ThrowsMessage<tl2cgen::Error>(
                  HasSubstr("Expected a floating-point decimal for 'code_folding_req'")));

  // Number where a string is expected.
  json_str = R"JSON(
{
"native_lib_name": -10.0
})JSON";
  EXPECT_THAT([&]() { CompilerParam::ParseFromJSON(json_str.c_str()); },
              ThrowsMessage<tl2cgen::Error>(HasSubstr("Expected a string for 'native_lib_name'")));

  // Syntactically malformed JSON (`13bad` is not a valid token).
  json_str = R"JSON(
{
"code_folding_req": 13bad
})JSON";
  EXPECT_THAT([&]() { CompilerParam::ParseFromJSON(json_str.c_str()); },
              ThrowsMessage<tl2cgen::Error>(HasSubstr("Got an invalid JSON string")));
}

// Negative values for the listed numeric parameters are rejected with
// "'<key>' must be 0 or greater".
TEST(CompilerParam, InvalidRange) {
  for (std::string const& key : std::vector<std::string>{
           "quantize", "parallel_comp", "code_folding_req", "dump_array_as_elf"}) {
    // Without this, every iteration reports the same source line on failure.
    // This trace tags the failure with the key being exercised.
    SCOPED_TRACE(key);
    // code_folding_req is a floating-point parameter, so feed it a decimal literal.
    std::string const literal = (key == "code_folding_req" ? "-1.0" : "-1");
    std::string const json_str = fmt::format(R"JSON({{ "{0}": {1} }})JSON", key, literal);
    std::string const expected_error = fmt::format("'{}' must be 0 or greater", key);
    EXPECT_THAT([&]() { CompilerParam::ParseFromJSON(json_str.c_str()); },
                ThrowsMessage<tl2cgen::Error>(HasSubstr(expected_error)));
  }
}

}  // namespace tl2cgen::compiler