Skip to content

Commit

Permalink
Add support for multiple tables and update CFI
Browse files Browse the repository at this point in the history
This commit makes a number of changes to the WebAssembly format,
some of which exceed the feature set desired for the MVP.

(1) It adds support for updated table definitions, including the
default, elementType, initial, and max attributes, plus a name.
Currently, the initial and max attributes must be equal to the
number of elements. The elementType attribute is interpreted as
a FunctionType index, and type homogeneity is enforced on table
elements, unless the specified FunctionType has name "anyfunc",
which corresponds to a FunctionType with a none parameter and
return type none. Format:
(table <name> [default] <type> <entries>)

(2) It adds support for multiple tables. If tables are used,
currently the first table must be marked default, and the remaining
tables must not be. Example:
(table "foo" default (type $FUNCSIG$i) $a)
(table "bla" (type $anyfunc) $b $c $d)

(3) Indirect calls have an immediate argument that specifies
the index of the function call table. Example:
(call_indirect "foo" $FUNCSIG$i (get_local $1))

(4) Corresponding upstream LLVM changes are required to use
multiple tables, but the updated format is backwards compatible.
Example:
i32.call_indirect $0=, $pop0
i32.call_indirect.1 $0=, $pop0, $1, $2, $3

(5) Generating WebAssembly from code built with Clang/LLVM CFI now
utilizes multiple tables. This is the only enabled use case for
multiple tables; all others will default to a single table, if
tables are used. The value passed in the .indidx assembler
directive is now interpreted as the index of the indirect call
table to assign.
  • Loading branch information
ddcc committed Jul 26, 2016
1 parent 96e226d commit 55a766c
Show file tree
Hide file tree
Showing 48 changed files with 531 additions and 319 deletions.
5 changes: 3 additions & 2 deletions src/asm2wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -660,11 +660,11 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
// TODO: when not using aliasing function pointers, we could merge them by noticing that
// index 0 in each table is the null func, and each other index should only have one
// non-null func. However, that breaks down when function pointer casts are emulated.
functionTableStarts[name] = wasm.table.names.size(); // this table starts here
functionTableStarts[name] = wasm.getDefaultTable()->values.size(); // this table starts here
Ref contents = value[1];
for (unsigned k = 0; k < contents->size(); k++) {
IString curr = contents[k][1]->getIString();
wasm.table.names.push_back(curr);
wasm.getDefaultTable()->values.push_back(curr);
}
} else {
abort_on("invalid var element", pair);
Expand Down Expand Up @@ -1404,6 +1404,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
}
// function pointers
auto ret = allocator.alloc<CallIndirect>();
ret->table = wasm.getDefaultTable()->name;
Ref target = ast[1];
assert(target[0] == SUB && target[1][0] == NAME && target[2][0] == BINARY && target[2][1] == AND && target[2][3][0] == NUM); // FUNCTION_TABLE[(expr) & mask]
ret->target = process(target[2]); // TODO: as an optimization, we could look through the mask
Expand Down
4 changes: 3 additions & 1 deletion src/ast_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ struct ExpressionManipulator {
return ret;
}
Expression* visitCallIndirect(CallIndirect *curr) {
auto* ret = builder.makeCallIndirect(curr->fullType, curr->target, {}, curr->type);
auto* ret = builder.makeCallIndirect(curr->table, curr->fullType, curr->target, {}, curr->type);
for (Index i = 0; i < curr->operands.size(); i++) {
ret->operands.push_back(copy(curr->operands[i]));
}
Expand Down Expand Up @@ -459,6 +459,7 @@ struct ExpressionAnalyzer {
break;
}
case Expression::Id::CallIndirectId: {
CHECK(CallIndirect, table);
PUSH(CallIndirect, target);
CHECK(CallIndirect, fullType);
CHECK(CallIndirect, operands.size());
Expand Down Expand Up @@ -661,6 +662,7 @@ struct ExpressionAnalyzer {
break;
}
case Expression::Id::CallIndirectId: {
HASH_NAME(CallIndirect, table);
PUSH(CallIndirect, target);
HASH_NAME(CallIndirect, fullType);
HASH(CallIndirect, operands.size());
Expand Down
4 changes: 2 additions & 2 deletions src/binaryen-c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ BinaryenExpressionRef BinaryenSelect(BinaryenModuleRef module, BinaryenExpressio
}
BinaryenExpressionRef BinaryenReturn(BinaryenModuleRef module, BinaryenExpressionRef value) {
auto* ret = Builder(*((Module*)module)).makeReturn((Expression*)value);

if (tracing) {
auto id = noteExpression(ret);
std::cout << " expressions[" << id << "] = BinaryenReturn(the_module, expressions[" << expressions[value] << "]);\n";
Expand Down Expand Up @@ -730,7 +730,7 @@ void BinaryenSetFunctionTable(BinaryenModuleRef module, BinaryenFunctionRef* fun

auto* wasm = (Module*)module;
for (BinaryenIndex i = 0; i < numFuncs; i++) {
wasm->table.names.push_back(((Function*)funcs[i])->name);
wasm->getDefaultTable()->values.push_back(((Function*)funcs[i])->name);
}
}

Expand Down
10 changes: 6 additions & 4 deletions src/passes/DuplicateFunctionElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,12 @@ struct DuplicateFunctionElimination : public Pass {
replacerRunner.add<FunctionReplacer>(&replacements);
replacerRunner.run();
// replace in table
for (auto& name : module->table.names) {
auto iter = replacements.find(name);
if (iter != replacements.end()) {
name = iter->second;
for (auto& curr : module->tables) {
for (auto& name : curr->values) {
auto iter = replacements.find(name);
if (iter != replacements.end()) {
name = iter->second;
}
}
}
// replace in start
Expand Down
16 changes: 11 additions & 5 deletions src/passes/Print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,9 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
printCallBody(curr);
}
void visitCallIndirect(CallIndirect *curr) {
printOpening(o, "call_indirect ") << curr->fullType;
printOpening(o, "call_indirect ");
printText(o, curr->table.str);
o << ' ' << curr->fullType;
incIndent();
printFullLine(curr->target);
for (auto operand : curr->operands) {
Expand Down Expand Up @@ -555,8 +557,12 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
decIndent();
}
void visitTable(Table *curr) {
printOpening(o, "table");
for (auto name : curr->names) {
printOpening(o, "table ");
printText(o, curr->name.str) << ' ';
if (curr->isDefault)
o << "default" << ' ';
visitFunctionType(curr->elementType, true);
for (auto name : curr->values) {
o << ' ';
printName(name);
}
Expand Down Expand Up @@ -621,9 +627,9 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
visitExport(child.get());
o << maybeNewLine;
}
if (curr->table.names.size() > 0) {
for (auto& child : curr->tables) {
doIndent(o, indent);
visitTable(&curr->table);
visitTable(child.get());
o << maybeNewLine;
}
for (auto& child : curr->functions) {
Expand Down
6 changes: 4 additions & 2 deletions src/passes/RemoveUnusedFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ struct RemoveUnusedFunctions : public Pass {
root.push_back(module->getFunction(curr->value));
}
// For now, all functions that can be called indirectly are marked as roots.
for (auto& curr : module->table.names) {
root.push_back(module->getFunction(curr));
for (auto& child : module->tables) {
for (auto& curr : child->values) {
root.push_back(module->getFunction(curr));
}
}
// Compute function reachability starting from the root set.
DirectCallGraphAnalyzer analyzer(module, root);
Expand Down
6 changes: 4 additions & 2 deletions src/passes/ReorderFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ struct ReorderFunctions : public WalkerPass<PostWalker<ReorderFunctions, Visitor
for (auto& curr : module->exports) {
counts[curr->value]++;
}
for (auto& curr : module->table.names) {
counts[curr]++;
for (auto& child : module->tables) {
for (auto& curr : child->values) {
counts[curr]++;
}
}
std::sort(module->functions.begin(), module->functions.end(), [this](
const std::unique_ptr<Function>& a,
Expand Down
13 changes: 8 additions & 5 deletions src/s2wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,11 @@ class S2WasmBuilder {
return cashew::IString(str.c_str(), false);
}

// Parses an optional ".<int>" suffix at the current input position and
// returns the integer that follows the dot. When no "." is present the
// result is 0.
// NOTE(review): presumably this is the table index suffix of an indirect
// call mnemonic such as `i32.call_indirect.1` — confirm against the caller.
uint32_t getTable() {
  if (match(".")) {
    // A dot was consumed by match(); the number after it is the result.
    return getInt();
  }
  // No suffix present.
  return 0;
}

std::vector<char> getQuoted() {
assert(*s == '"');
s++;
Expand Down Expand Up @@ -622,7 +627,7 @@ class S2WasmBuilder {
};
wasm::Builder builder(*wasm);
std::vector<NameType> params;
int64_t indirectIndex = -1;
uint64_t indirectIndex = 0;
WasmType resultType = none;
std::vector<NameType> vars;

Expand All @@ -643,9 +648,6 @@ class S2WasmBuilder {
} else if (match(".indidx")) {
indirectIndex = getInt64();
skipWhitespace();
if (indirectIndex < 0) {
abort_on("indidx");
}
} else if (match(".local")) {
while (1) {
Name name = getNextId();
Expand Down Expand Up @@ -859,6 +861,7 @@ class S2WasmBuilder {
auto makeCall = [&](WasmType type) {
if (match("_indirect")) {
// indirect call
uint32_t table = getTable();
Name assign = getAssign();
int num = getNumInputs();
auto inputs = getInputs(num);
Expand All @@ -867,7 +870,7 @@ class S2WasmBuilder {
std::vector<Expression*> operands(++input, inputs.end());
auto* funcType = ensureFunctionType(getSig(type, operands), wasm);
assert(type == funcType->result);
auto* indirect = builder.makeCallIndirect(funcType, target, std::move(operands));
auto* indirect = builder.makeCallIndirect(linkerObj->getIndirectTable(table, funcType)->name, funcType, target, std::move(operands));
setOutput(indirect, assign);
} else {
// non-indirect call
Expand Down
85 changes: 62 additions & 23 deletions src/wasm-binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
writeSignatures();
writeImports();
writeFunctionSignatures();
writeFunctionTable();
writeFunctionTables();
writeMemory();
writeExports();
writeStart();
Expand Down Expand Up @@ -559,14 +559,22 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
finishSection(start);
}

int32_t getFunctionTypeIndex(Name type) {
Index getFunctionTypeIndex(Name type) {
// TODO: optimize
for (size_t i = 0; i < wasm->functionTypes.size(); i++) {
if (wasm->functionTypes[i]->name == type) return i;
}
abort();
}

// Returns the position within wasm->tables of the table whose name equals
// `type`. Despite the parameter's name, the value is compared against each
// table's `name`. Aborts if no table carries that name.
Index getFunctionTableIndex(Name type) {
  // TODO: optimize (linear scan over all tables)
  size_t position = 0;
  for (auto& table : wasm->tables) {
    if (table->name == type) {
      return position;
    }
    position++;
  }
  // No table with the requested name exists.
  abort();
}

void writeImports() {
if (wasm->imports.size() == 0) return;
if (debug) std::cerr << "== writeImports" << std::endl;
Expand Down Expand Up @@ -670,7 +678,7 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {

void writeExports() {
if (wasm->exports.size() == 0) return;
if (debug) std::cerr << "== writeexports" << std::endl;
if (debug) std::cerr << "== writeExports" << std::endl;
auto start = startSection(BinaryConsts::Section::ExportTable);
o << U32LEB(wasm->exports.size());
for (auto& curr : wasm->exports) {
Expand Down Expand Up @@ -709,8 +717,8 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
assert(mappedImports.count(name));
return mappedImports[name];
}
std::map<Name, uint32_t> mappedFunctions; // name of the Function => index

std::map<Name, uint32_t> mappedFunctions; // name of the Function => entry index
uint32_t getFunctionIndex(Name name) {
if (!mappedFunctions.size()) {
// Create name => index mapping.
Expand All @@ -723,13 +731,21 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
return mappedFunctions[name];
}

void writeFunctionTable() {
if (wasm->table.names.size() == 0) return;
if (debug) std::cerr << "== writeFunctionTable" << std::endl;
void writeFunctionTables() {
if (wasm->tables.size() == 0) return;
if (debug) std::cerr << "== writeFunctionTables" << std::endl;
auto start = startSection(BinaryConsts::Section::FunctionTable);
o << U32LEB(wasm->table.names.size());
for (auto name : wasm->table.names) {
o << U32LEB(getFunctionIndex(name));
o << U32LEB(wasm->tables.size());
for (auto& curr : wasm->tables) {
if (debug) std::cerr << "write one" << std::endl;
o << int8_t(curr->isDefault);
o << U32LEB(getFunctionTypeIndex(curr->elementType->name));
assert(curr->initial == curr->values.size() && curr->initial == curr->max);
o << U32LEB(curr->initial);
o << U32LEB(curr->max);
for (auto name : curr->values) {
o << U32LEB(getFunctionIndex(name));
}
}
finishSection(start);
}
Expand Down Expand Up @@ -909,7 +925,7 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
for (auto* operand : curr->operands) {
recurse(operand);
}
o << int8_t(BinaryConsts::CallIndirect) << U32LEB(curr->operands.size()) << U32LEB(getFunctionTypeIndex(curr->fullType));
o << int8_t(BinaryConsts::CallIndirect) << U32LEB(curr->operands.size()) << U32LEB(getFunctionTypeIndex(curr->fullType)) << U32LEB(getFunctionTableIndex(curr->table));
}
void visitGetLocal(GetLocal *curr) {
if (debug) std::cerr << "zz node: GetLocal " << (o.size() + 1) << std::endl;
Expand Down Expand Up @@ -1231,7 +1247,7 @@ class WasmBinaryBuilder {
else if (match(BinaryConsts::Section::Functions)) readFunctions();
else if (match(BinaryConsts::Section::ExportTable)) readExports();
else if (match(BinaryConsts::Section::DataSegments)) readDataSegments();
else if (match(BinaryConsts::Section::FunctionTable)) readFunctionTable();
else if (match(BinaryConsts::Section::FunctionTable)) readFunctionTables();
else if (match(BinaryConsts::Section::Names)) readNames();
else {
std::cerr << "unfamiliar section: ";
Expand Down Expand Up @@ -1427,6 +1443,11 @@ class WasmBinaryBuilder {
assert(numResults == 1);
curr->result = getWasmType();
}
// TODO: Handle "anyfunc" properly. This sets the name to "anyfunc" if
// it does not already exist, and matches the expected type signature.
if (!wasm.checkFunctionType(FunctionType::kAnyFunc) && FunctionType::isAnyFuncType(curr)) {
curr->name = FunctionType::kAnyFunc;
}
wasm.addFunctionType(curr);
}
}
Expand All @@ -1441,7 +1462,7 @@ class WasmBinaryBuilder {
curr->name = Name(std::string("import$") + std::to_string(i));
auto index = getU32LEB();
assert(index < wasm.functionTypes.size());
curr->type = wasm.getFunctionType(index);
curr->type = wasm.functionTypes[index].get();
assert(curr->type->name.is());
curr->module = getInlineString();
curr->base = getInlineString();
Expand Down Expand Up @@ -1596,9 +1617,12 @@ class WasmBinaryBuilder {
}
}

for (size_t index : functionTable) {
assert(index < wasm.functions.size());
wasm.table.names.push_back(wasm.functions[index]->name);
for (auto& pair : functionTable) {
assert(pair.first < wasm.tables.size());
assert(pair.second < wasm.functions.size());
assert(wasm.tables[pair.first]->values.size() <= wasm.tables[pair.first]->max);
assert(wasm.tables[pair.first]->elementType->name == FunctionType::kAnyFunc || wasm.tables[pair.first]->elementType == wasm.getFunctionType(wasm.functions[pair.second]->type));
wasm.tables[pair.first]->values.push_back(wasm.functions[pair.second]->name);
}
}

Expand All @@ -1618,14 +1642,28 @@ class WasmBinaryBuilder {
}
}

std::vector<size_t> functionTable;
std::vector<std::pair<size_t, size_t>> functionTable;

void readFunctionTable() {
if (debug) std::cerr << "== readFunctionTable" << std::endl;
auto num = getU32LEB();
for (size_t i = 0; i < num; i++) {
void readFunctionTables() {
if (debug) std::cerr << "== readFunctionTables" << std::endl;
size_t numTables = getU32LEB();
for (size_t i = 0; i < numTables; i++) {
if (debug) std::cerr << "read one" << std::endl;
auto curr = new Table;
auto flag = getInt8();
assert((!i && flag) || (i && !flag));
curr->isDefault = flag;
auto index = getU32LEB();
functionTable.push_back(index);
assert(index < functionTypes.size());
curr->elementType = wasm.getFunctionType(index);
curr->initial = getU32LEB();
curr->max = getU32LEB();
assert(curr->initial == curr->max);
for (size_t j = 0; j < curr->initial; j++) {
auto index = getU32LEB();
functionTable.push_back(std::make_pair<>(i, index));
}
wasm.addTable(curr);
}
}

Expand Down Expand Up @@ -1843,6 +1881,7 @@ class WasmBinaryBuilder {
curr->fullType = fullType->name;
auto num = fullType->params.size();
assert(num == arity);
curr->table = wasm.getTable(getU32LEB())->name;
curr->operands.resize(num);
for (size_t i = 0; i < num; i++) {
curr->operands[num - i - 1] = popExpression();
Expand Down
Loading

0 comments on commit 55a766c

Please sign in to comment.