Skip to content

Commit

Permalink
Add support for multiple tables and update CFI
Browse files Browse the repository at this point in the history
This commit makes a number of changes to the WebAssembly format,
some of which exceed the feature set desired for the MVP.

(1) It adds support for updated table definitions, including the
default, elementType, initial, and max attributes, plus a name.
Currently, the initial and max attributes must be equal to the
number of elements. The elementType attribute is interpreted as
a FunctionType index, and type homogeneity is enforced on table
elements, unless the specified FunctionType has name "anyfunc",
which corresponds to a FunctionType with a none parameter and
return type none. Format:
(table <name> [default] <type> <entries>)

(2) It adds support for multiple tables. If tables are used,
currently the first table must be default, and the remainder
must not. Example:
(table "foo" default (type $FUNCSIG$i) $a)
(table "bla" (type $anyfunc) $b $c $d)

(3) Indirect calls have an immediate argument that specifies
the index of the function call table. Example:
(call_indirect "foo" $FUNCSIG$i (get_local $1))

(4) Corresponding upstream LLVM changes are required to use
multiple tables, but the updated format is backwards compatible.
Example:
i32.call_indirect $0=, $pop0
i32.call_indirect.1 $0=, $pop0, $1, $2, $3

(5) Generating WebAssembly from code built with Clang/LLVM CFI now
utilizes multiple tables. This is the only enabled use case for
multiple tables; all others will default to a single table, if
tables are used. The value passed in the .indidx assembler
directive is now interpreted as the index of the indirect call
table to assign.
  • Loading branch information
ddcc committed Aug 5, 2016
1 parent 1bdda7b commit 2566183
Show file tree
Hide file tree
Showing 47 changed files with 415 additions and 260 deletions.
1 change: 1 addition & 0 deletions src/asm2wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -1404,6 +1404,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
}
// function pointers
auto ret = allocator.alloc<CallIndirect>();
ret->table = wasm.getDefaultTable()->name;
Ref target = ast[1];
assert(target[0] == SUB && target[1][0] == NAME && target[2][0] == BINARY && target[2][1] == AND && target[2][3][0] == NUM); // FUNCTION_TABLE[(expr) & mask]
ret->target = process(target[2]); // TODO: as an optimization, we could look through the mask
Expand Down
4 changes: 3 additions & 1 deletion src/ast_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ struct ExpressionManipulator {
return ret;
}
Expression* visitCallIndirect(CallIndirect *curr) {
auto* ret = builder.makeCallIndirect(curr->fullType, curr->target, {}, curr->type);
auto* ret = builder.makeCallIndirect(curr->table, curr->fullType, curr->target, {}, curr->type);
for (Index i = 0; i < curr->operands.size(); i++) {
ret->operands.push_back(copy(curr->operands[i]));
}
Expand Down Expand Up @@ -467,6 +467,7 @@ struct ExpressionAnalyzer {
break;
}
case Expression::Id::CallIndirectId: {
CHECK(CallIndirect, table);
PUSH(CallIndirect, target);
CHECK(CallIndirect, fullType);
CHECK(CallIndirect, operands.size());
Expand Down Expand Up @@ -678,6 +679,7 @@ struct ExpressionAnalyzer {
break;
}
case Expression::Id::CallIndirectId: {
HASH_NAME(CallIndirect, table);
PUSH(CallIndirect, target);
HASH_NAME(CallIndirect, fullType);
HASH(CallIndirect, operands.size());
Expand Down
10 changes: 8 additions & 2 deletions src/passes/Print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,9 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
printCallBody(curr);
}
void visitCallIndirect(CallIndirect *curr) {
printOpening(o, "call_indirect ") << curr->fullType;
printOpening(o, "call_indirect ");
printName(curr->table);
o << ' ' << curr->fullType;
incIndent();
printFullLine(curr->target);
for (auto operand : curr->operands) {
Expand Down Expand Up @@ -575,7 +577,11 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
decIndent();
}
void visitTable(Table *curr) {
printOpening(o, "table");
printOpening(o, "table ");
printName(curr->name) << ' ';
if (curr->isDefault)
o << "default" << ' ';
visitFunctionType(curr->elementType, true);
for (auto name : curr->values) {
o << ' ';
printName(name);
Expand Down
8 changes: 7 additions & 1 deletion src/s2wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,11 @@ class S2WasmBuilder {
return cashew::IString(str.c_str(), false);
}

// Reads an optional table selector from the input.
// Yields the value produced by getInt() when match(".") succeeds, and 0
// when it does not.
// NOTE(review): presumably this consumes the ".N" suffix seen in forms such
// as "i32.call_indirect.1", with 0 meaning "no suffix" -- confirm against
// the definitions of match() and getInt().
uint32_t getTable() {
  if (match(".")) {
    return getInt();
  }
  return 0;
}

std::vector<char> getQuoted() {
assert(*s == '"');
s++;
Expand Down Expand Up @@ -859,6 +864,7 @@ class S2WasmBuilder {
auto makeCall = [&](WasmType type) {
if (match("_indirect")) {
// indirect call
uint32_t table = getTable();
Name assign = getAssign();
int num = getNumInputs();
auto inputs = getInputs(num);
Expand All @@ -867,7 +873,7 @@ class S2WasmBuilder {
std::vector<Expression*> operands(++input, inputs.end());
auto* funcType = ensureFunctionType(getSig(type, operands), wasm);
assert(type == funcType->result);
auto* indirect = builder.makeCallIndirect(funcType, target, std::move(operands));
auto* indirect = builder.makeCallIndirect(linkerObj->getIndirectTable(table, funcType)->name, funcType, target, std::move(operands));
setOutput(indirect, assign);
} else {
// non-indirect call
Expand Down
43 changes: 35 additions & 8 deletions src/wasm-binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,14 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
abort();
}

// Looks up the position within wasm->tables of the table whose name equals
// the given name. Note that although the parameter is called `type`, it is
// compared against table names. Aborts when no table carries that name.
Index getFunctionTableIndex(Name type) {
  // TODO: optimize
  size_t position = 0;
  for (auto& candidate : wasm->tables) {
    if (candidate->name == type) {
      return position;
    }
    position++;
  }
  abort();
}

void writeImports() {
if (wasm->imports.size() == 0) return;
if (debug) std::cerr << "== writeImports" << std::endl;
Expand Down Expand Up @@ -742,11 +750,14 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
if (wasm->tables.size() == 0) return;
if (debug) std::cerr << "== writeFunctionTables" << std::endl;
auto start = startSection(BinaryConsts::Section::FunctionTable);
assert(wasm->tables.size() == 1);
// o << U32LEB(wasm->tables.size());
o << U32LEB(wasm->tables.size());
for (auto& curr : wasm->tables) {
if (debug) std::cerr << "write one" << std::endl;
o << U32LEB(curr->values.size());
o << int8_t(curr->isDefault);
o << U32LEB(getFunctionTypeIndex(curr->elementType->name));
assert(curr->initial == curr->values.size() && curr->initial == curr->max);
o << U32LEB(curr->initial);
o << U32LEB(curr->max);
for (auto name : curr->values) {
o << U32LEB(getFunctionIndex(name));
}
Expand Down Expand Up @@ -929,7 +940,7 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
for (auto* operand : curr->operands) {
recurse(operand);
}
o << int8_t(BinaryConsts::CallIndirect) << U32LEB(curr->operands.size()) << U32LEB(getFunctionTypeIndex(curr->fullType));
o << int8_t(BinaryConsts::CallIndirect) << U32LEB(curr->operands.size()) << U32LEB(getFunctionTypeIndex(curr->fullType)) << U32LEB(getFunctionTableIndex(curr->table));
}
void visitGetLocal(GetLocal *curr) {
if (debug) std::cerr << "zz node: GetLocal " << (o.size() + 1) << std::endl;
Expand Down Expand Up @@ -1457,6 +1468,11 @@ class WasmBinaryBuilder {
assert(numResults == 1);
curr->result = getWasmType();
}
// TODO: Handle "anyfunc" properly. This sets the name to "anyfunc" if
// it does not already exist, and matches the expected type signature.
if (!wasm.checkFunctionType(FunctionType::kAnyFunc) && FunctionType::isAnyFuncType(curr)) {
curr->name = FunctionType::kAnyFunc;
}
wasm.addFunctionType(curr);
}
}
Expand All @@ -1471,7 +1487,7 @@ class WasmBinaryBuilder {
curr->name = Name(std::string("import$") + std::to_string(i));
auto index = getU32LEB();
assert(index < wasm.functionTypes.size());
curr->type = wasm.getFunctionType(index);
curr->type = wasm.functionTypes[index].get();
assert(curr->type->name.is());
curr->module = getInlineString();
curr->base = getInlineString();
Expand Down Expand Up @@ -1646,6 +1662,8 @@ class WasmBinaryBuilder {
for (auto& pair : functionTable) {
assert(pair.first < wasm.tables.size());
assert(pair.second < wasm.functions.size());
assert(wasm.tables[pair.first]->values.size() <= wasm.tables[pair.first]->max);
assert(wasm.tables[pair.first]->elementType->name == FunctionType::kAnyFunc || wasm.tables[pair.first]->elementType == wasm.getFunctionType(wasm.functions[pair.second]->type));
wasm.tables[pair.first]->values.push_back(wasm.functions[pair.second]->name);
}
}
Expand All @@ -1670,12 +1688,20 @@ class WasmBinaryBuilder {

void readFunctionTables() {
if (debug) std::cerr << "== readFunctionTables" << std::endl;
size_t numTables = 1; // getU32LEB()
size_t numTables = getU32LEB();
for (size_t i = 0; i < numTables; i++) {
if (debug) std::cerr << "read one" << std::endl;
auto curr = new Table;
auto size = getU32LEB();
for (size_t j = 0; j < size; j++) {
auto flag = getInt8();
assert((!i && flag) || (i && !flag));
curr->isDefault = flag;
auto index = getU32LEB();
assert(index < functionTypes.size());
curr->elementType = wasm.getFunctionType(index);
curr->initial = getU32LEB();
curr->max = getU32LEB();
assert(curr->initial == curr->max);
for (size_t j = 0; j < curr->initial; j++) {
auto index = getU32LEB();
functionTable.push_back(std::make_pair<>(i, index));
}
Expand Down Expand Up @@ -1899,6 +1925,7 @@ class WasmBinaryBuilder {
curr->fullType = fullType->name;
auto num = fullType->params.size();
assert(num == arity);
curr->table = wasm.getTable(getU32LEB())->name;
curr->operands.resize(num);
for (size_t i = 0; i < num; i++) {
curr->operands[num - i - 1] = popExpression();
Expand Down
6 changes: 4 additions & 2 deletions src/wasm-builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,18 @@ class Builder {
call->operands.set(args);
return call;
}
CallIndirect* makeCallIndirect(FunctionType* type, Expression* target, const std::vector<Expression*>& args) {
CallIndirect* makeCallIndirect(Name table, FunctionType* type, Expression* target, const std::vector<Expression*>& args) {
auto* call = allocator.alloc<CallIndirect>();
call->table = table;
call->fullType = type->name;
call->type = type->result;
call->target = target;
call->operands.set(args);
return call;
}
CallIndirect* makeCallIndirect(Name fullType, Expression* target, const std::vector<Expression*>& args, WasmType type) {
CallIndirect* makeCallIndirect(Name table, Name fullType, Expression* target, const std::vector<Expression*>& args, WasmType type) {
auto* call = allocator.alloc<CallIndirect>();
call->table = table;
call->fullType = fullType;
call->type = type;
call->target = target;
Expand Down
4 changes: 2 additions & 2 deletions src/wasm-interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -666,12 +666,12 @@ class ModuleInstance {
LiteralList arguments;
Flow flow = generateArguments(curr->operands, arguments);
if (flow.breaking()) return flow;
Table *table = instance.wasm.getDefaultTable();
Table *table = instance.wasm.getTable(curr->table);
if (table->elementType->name != FunctionType::kAnyFunc && table->elementType->name != curr->fullType) trap("callIndirect: bad type");
size_t index = target.value.geti32();
if (index >= table->values.size()) trap("callIndirect: overflow");
Name name = table->values[index];
Function *func = instance.wasm.getFunction(name);
if (func->type.is() && func->type != curr->fullType) trap("callIndirect: bad type");
if (func->params.size() != arguments.size()) trap("callIndirect: bad # of arguments");
for (size_t i = 0; i < func->params.size(); i++) {
if (func->params[i] != arguments[i].type) {
Expand Down
60 changes: 29 additions & 31 deletions src/wasm-linker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,21 +116,24 @@ void Linker::layout() {
makeDummyFunction();

// Pre-assign the function indexes
for (auto& pair : out.indirectIndexes) {
Index tableIndex = Table::kDefault;
Table *table = out.getIndirectTable(tableIndex);
if (functionIndexes.count(pair.second) != 0) {
Fatal() << "Function " << pair.second << " already has an index " <<
functionIndexes[pair.second].first << " while setting index " << pair.first;
}
if (debug) {
std::cerr << "pre-assigned function index: " << pair.second << ": "
<< pair.first << '\n';
for (auto &pair : out.indirectIndexes) {
Index tableIndex = pair.first;
for (auto &funcName : pair.second) {
Function *func = out.wasm.getFunction(funcName);
FunctionType *funcType = ensureFunctionType(getSig(func), &out.wasm);
func->type = funcType->name;
Table *table = out.getIndirectTable(tableIndex, funcType);
if (functionIndexes.count(funcName) != 0) {
Fatal() << "Function " << funcName << " already has an index " <<
functionIndexes[funcName].first << " while setting index " << pair.first;
}
if (debug) {
std::cerr << "pre-assigned function index: " << funcName << ": "
<< pair.first << '\n';
}
functionIndexes[funcName] = std::make_pair(tableIndex, table->values.size());
table->values.push_back(funcName);
}
assert(table->values.size() == pair.first);
table->values.push_back(pair.second);
auto indexes = std::make_pair(tableIndex, pair.first);
functionIndexes[pair.second] = indexes;
}

for (auto& relocation : out.relocations) {
Expand Down Expand Up @@ -169,14 +172,6 @@ void Linker::layout() {
}
}

// Create the actual tables in the underlying module. This is delayed because
// table references may be out of order, and the underlying object is a vector.
Index counter = 0;
for (auto& pair : out.tables) {
if (pair.first != counter++) Fatal() << "Tables are nonconsecutive!" << '\n';
out.wasm.addTable(pair.second);
}

if (!!startFunction) {
if (out.symbolInfo.implementedFunctions.count(startFunction) == 0) {
Fatal() << "Unknown start function: `" << startFunction << "`\n";
Expand Down Expand Up @@ -212,12 +207,12 @@ void Linker::layout() {
}
}

// ensure an explicit function type for indirect call targets
for (auto& table : out.wasm.tables) {
for (auto& name : table->values) {
auto* func = out.wasm.getFunction(name);
func->type = ensureFunctionType(getSig(func), &out.wasm)->name;
}
// Create the actual tables in the underlying module. This is delayed because
// table references may be out of order, and the underlying object is a vector.
Index counter = 0;
for (auto& pair : out.tables) {
if (pair.first != counter++) Fatal() << "Tables are nonconsecutive!" << '\n';
out.wasm.addTable(pair.second);
}
}

Expand Down Expand Up @@ -380,8 +375,11 @@ void Linker::emscriptenGlue(std::ostream& o) {

Index Linker::getFunctionIndex(Name name) {
if (!functionIndexes.count(name)) {
Function *func = out.wasm.getFunction(name);
FunctionType *funcType = ensureFunctionType(getSig(func), &out.wasm);
func->type = funcType->name;
Index tableIndex = Table::kDefault;
Table *table = out.getIndirectTable(tableIndex);
Table *table = out.getIndirectTable(tableIndex, funcType);
functionIndexes[name] = std::make_pair(tableIndex, table->values.size());
table->values.push_back(name);
if (debug) {
Expand Down Expand Up @@ -416,7 +414,7 @@ void Linker::makeDummyFunction() {
Expression *unreachable = wasmBuilder.makeUnreachable();
Function *dummy = wasmBuilder.makeFunction(Name(dummyFunction), {}, WasmType::none, {}, unreachable);
out.wasm.addFunction(dummy);
getFunctionIndex(dummy->name);
out.addIndirectIndex(dummy->name, Table::kDefault);
}

void Linker::makeDynCallThunks() {
Expand All @@ -440,7 +438,7 @@ void Linker::makeDynCallThunks() {
for (unsigned i = 0; i < funcType->params.size(); ++i) {
args.push_back(wasmBuilder.makeGetLocal(i + 1, funcType->params[i]));
}
Expression* call = wasmBuilder.makeCallIndirect(funcType, fptr, args);
Expression* call = wasmBuilder.makeCallIndirect(table->name, funcType, fptr, args);
f->body = call;
out.wasm.addFunction(f);
exportFunction(f->name, true);
Expand Down
33 changes: 20 additions & 13 deletions src/wasm-linker.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,20 +113,25 @@ class LinkerObject {

// Create a table locally, because insertion into the underlying wasm vector
// needs to be delayed until all tables have been encountered.
Table *getIndirectTable(Index index) {
if (tables.count(index))
Table *getIndirectTable(Index index, FunctionType* type) {
Table *table;
if (tables.count(index)) {
return tables[index];

// Add the first default table, if it is missing and another table is
// being requested.
if (index && !tables.count(Table::kDefault)) {
getIndirectTable(Table::kDefault);
}

// Otherwise, proceed and create the requested table.
assert(index == Table::kDefault);
tables[index] = Table::createDefaultTable();
return tables[index];
if (index != Table::kDefault) {
table = new Table();
table->name = Name::fromInt(index);
table->isDefault = false;
table->elementType = type;
} else {
table = Table::createDefaultTable();
table->elementType = wasm.getAnyFuncType();
}
tables[index] = table;

return table;
}

// Add an initializer segment for the named static variable.
Expand Down Expand Up @@ -160,8 +165,10 @@ class LinkerObject {
}

void addIndirectIndex(Name name, Address index) {
assert(!indirectIndexes.count(index));
indirectIndexes[index] = name;
if (!indirectIndexes.count(index)) {
indirectIndexes[index] = std::vector<Name>();
}
indirectIndexes[index].push_back(name);
}

bool isEmpty() {
Expand Down Expand Up @@ -198,7 +205,7 @@ class LinkerObject {
std::map<Index, Table *> tables; // index => table index (in wasm module)

// preassigned indexes for functions called indirectly
std::map<Address, Name> indirectIndexes;
std::map<Address, std::vector<Name>> indirectIndexes;

std::vector<Name> initializerFunctions;

Expand Down
Loading

0 comments on commit 2566183

Please sign in to comment.