Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: Add support for multiple tables #642

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/asm2wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -660,11 +660,11 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
// TODO: when not using aliasing function pointers, we could merge them by noticing that
// index 0 in each table is the null func, and each other index should only have one
// non-null func. However, that breaks down when function pointer casts are emulated.
functionTableStarts[name] = wasm.table.names.size(); // this table starts here
functionTableStarts[name] = wasm.getDefaultTable()->values.size(); // this table starts here
Ref contents = value[1];
for (unsigned k = 0; k < contents->size(); k++) {
IString curr = contents[k][1]->getIString();
wasm.table.names.push_back(curr);
wasm.getDefaultTable()->values.push_back(curr);
}
} else {
abort_on("invalid var element", pair);
Expand Down Expand Up @@ -1404,6 +1404,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
}
// function pointers
auto ret = allocator.alloc<CallIndirect>();
ret->table = wasm.getDefaultTable()->name;
Ref target = ast[1];
assert(target[0] == SUB && target[1][0] == NAME && target[2][0] == BINARY && target[2][1] == AND && target[2][3][0] == NUM); // FUNCTION_TABLE[(expr) & mask]
ret->target = process(target[2]); // TODO: as an optimization, we could look through the mask
Expand Down
4 changes: 3 additions & 1 deletion src/ast_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ struct ExpressionManipulator {
return ret;
}
Expression* visitCallIndirect(CallIndirect *curr) {
auto* ret = builder.makeCallIndirect(curr->fullType, curr->target, {}, curr->type);
auto* ret = builder.makeCallIndirect(curr->table, curr->fullType, curr->target, {}, curr->type);
for (Index i = 0; i < curr->operands.size(); i++) {
ret->operands.push_back(copy(curr->operands[i]));
}
Expand Down Expand Up @@ -467,6 +467,7 @@ struct ExpressionAnalyzer {
break;
}
case Expression::Id::CallIndirectId: {
CHECK(CallIndirect, table);
PUSH(CallIndirect, target);
CHECK(CallIndirect, fullType);
CHECK(CallIndirect, operands.size());
Expand Down Expand Up @@ -678,6 +679,7 @@ struct ExpressionAnalyzer {
break;
}
case Expression::Id::CallIndirectId: {
HASH_NAME(CallIndirect, table);
PUSH(CallIndirect, target);
HASH_NAME(CallIndirect, fullType);
HASH(CallIndirect, operands.size());
Expand Down
4 changes: 2 additions & 2 deletions src/binaryen-c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ BinaryenExpressionRef BinaryenSelect(BinaryenModuleRef module, BinaryenExpressio
}
BinaryenExpressionRef BinaryenReturn(BinaryenModuleRef module, BinaryenExpressionRef value) {
auto* ret = Builder(*((Module*)module)).makeReturn((Expression*)value);

if (tracing) {
auto id = noteExpression(ret);
std::cout << " expressions[" << id << "] = BinaryenReturn(the_module, expressions[" << expressions[value] << "]);\n";
Expand Down Expand Up @@ -730,7 +730,7 @@ void BinaryenSetFunctionTable(BinaryenModuleRef module, BinaryenFunctionRef* fun

auto* wasm = (Module*)module;
for (BinaryenIndex i = 0; i < numFuncs; i++) {
wasm->table.names.push_back(((Function*)funcs[i])->name);
wasm->getDefaultTable()->values.push_back(((Function*)funcs[i])->name);
}
}

Expand Down
4 changes: 0 additions & 4 deletions src/js/wasm.js-post.js
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,6 @@ function integrateWasmJS(Module) {
for (var i = 0; i < methods.length; i++) {
var curr = methods[i];

Module['printErr']('trying binaryen method: ' + curr);

if (curr === 'native-wasm') {
if (exports = doNativeWasm(global, env, providedBuffer)) break;
} else if (curr === 'asmjs') {
Expand All @@ -316,8 +314,6 @@ function integrateWasmJS(Module) {

if (!exports) throw 'no binaryen method succeeded';

Module['printErr']('binaryen method succeeded.');

return exports;
};

Expand Down
10 changes: 6 additions & 4 deletions src/passes/DuplicateFunctionElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,12 @@ struct DuplicateFunctionElimination : public Pass {
replacerRunner.add<FunctionReplacer>(&replacements);
replacerRunner.run();
// replace in table
for (auto& name : module->table.names) {
auto iter = replacements.find(name);
if (iter != replacements.end()) {
name = iter->second;
for (auto& curr : module->tables) {
for (auto& name : curr->values) {
auto iter = replacements.find(name);
if (iter != replacements.end()) {
name = iter->second;
}
}
}
// replace in start
Expand Down
16 changes: 11 additions & 5 deletions src/passes/Print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,9 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
printCallBody(curr);
}
void visitCallIndirect(CallIndirect *curr) {
printOpening(o, "call_indirect ") << curr->fullType;
printOpening(o, "call_indirect ");
printName(curr->table);
o << ' ' << curr->fullType;
incIndent();
printFullLine(curr->target);
for (auto operand : curr->operands) {
Expand Down Expand Up @@ -575,8 +577,12 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
decIndent();
}
void visitTable(Table *curr) {
printOpening(o, "table");
for (auto name : curr->names) {
printOpening(o, "table ");
printName(curr->name) << ' ';
if (curr->isDefault)
o << "default" << ' ';
visitFunctionType(curr->elementType, true);
for (auto name : curr->values) {
o << ' ';
printName(name);
}
Expand Down Expand Up @@ -647,9 +653,9 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
visitGlobal(child.get());
o << maybeNewLine;
}
if (curr->table.names.size() > 0) {
for (auto& child : curr->tables) {
doIndent(o, indent);
visitTable(&curr->table);
visitTable(child.get());
o << maybeNewLine;
}
for (auto& child : curr->functions) {
Expand Down
6 changes: 4 additions & 2 deletions src/passes/RemoveUnusedFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ struct RemoveUnusedFunctions : public Pass {
root.push_back(module->getFunction(curr->value));
}
// For now, all functions that can be called indirectly are marked as roots.
for (auto& curr : module->table.names) {
root.push_back(module->getFunction(curr));
for (auto& child : module->tables) {
for (auto& curr : child->values) {
root.push_back(module->getFunction(curr));
}
}
// Compute function reachability starting from the root set.
DirectCallGraphAnalyzer analyzer(module, root);
Expand Down
6 changes: 4 additions & 2 deletions src/passes/ReorderFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ struct ReorderFunctions : public WalkerPass<PostWalker<ReorderFunctions, Visitor
for (auto& curr : module->exports) {
counts[curr->value]++;
}
for (auto& curr : module->table.names) {
counts[curr]++;
for (auto& child : module->tables) {
for (auto& curr : child->values) {
counts[curr]++;
}
}
std::sort(module->functions.begin(), module->functions.end(), [this](
const std::unique_ptr<Function>& a,
Expand Down
8 changes: 7 additions & 1 deletion src/s2wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,11 @@ class S2WasmBuilder {
return cashew::IString(str.c_str(), false);
}

uint32_t getTable() {
if (!match(".")) return 0;
return getInt();
}

std::vector<char> getQuoted() {
assert(*s == '"');
s++;
Expand Down Expand Up @@ -863,6 +868,7 @@ class S2WasmBuilder {
auto makeCall = [&](WasmType type) {
if (match("_indirect")) {
// indirect call
uint32_t table = getTable();
Name assign = getAssign();
int num = getNumInputs();
auto inputs = getInputs(num);
Expand All @@ -871,7 +877,7 @@ class S2WasmBuilder {
std::vector<Expression*> operands(++input, inputs.end());
auto* funcType = ensureFunctionType(getSig(type, operands), wasm);
assert(type == funcType->result);
auto* indirect = builder.makeCallIndirect(funcType, target, std::move(operands));
auto* indirect = builder.makeCallIndirect(linkerObj->getIndirectTable(table, funcType)->name, funcType, target, std::move(operands));
setOutput(indirect, assign);
} else {
// non-indirect call
Expand Down
85 changes: 62 additions & 23 deletions src/wasm-binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
writeSignatures();
writeImports();
writeFunctionSignatures();
writeFunctionTable();
writeFunctionTables();
writeMemory();
writeGlobals();
writeExports();
Expand Down Expand Up @@ -556,14 +556,22 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
finishSection(start);
}

int32_t getFunctionTypeIndex(Name type) {
Index getFunctionTypeIndex(Name type) {
// TODO: optimize
for (size_t i = 0; i < wasm->functionTypes.size(); i++) {
if (wasm->functionTypes[i]->name == type) return i;
}
abort();
}

Index getFunctionTableIndex(Name type) {
// TODO: optimize
for (size_t i = 0; i < wasm->tables.size(); i++) {
if (wasm->tables[i]->name == type) return i;
}
abort();
}

void writeImports() {
if (wasm->imports.size() == 0) return;
if (debug) std::cerr << "== writeImports" << std::endl;
Expand Down Expand Up @@ -685,7 +693,7 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {

void writeExports() {
if (wasm->exports.size() == 0) return;
if (debug) std::cerr << "== writeexports" << std::endl;
if (debug) std::cerr << "== writeExports" << std::endl;
auto start = startSection(BinaryConsts::Section::ExportTable);
o << U32LEB(wasm->exports.size());
for (auto& curr : wasm->exports) {
Expand Down Expand Up @@ -724,8 +732,8 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
assert(mappedImports.count(name));
return mappedImports[name];
}
std::map<Name, uint32_t> mappedFunctions; // name of the Function => index

std::map<Name, uint32_t> mappedFunctions; // name of the Function => entry index
uint32_t getFunctionIndex(Name name) {
if (!mappedFunctions.size()) {
// Create name => index mapping.
Expand All @@ -738,13 +746,21 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
return mappedFunctions[name];
}

void writeFunctionTable() {
if (wasm->table.names.size() == 0) return;
if (debug) std::cerr << "== writeFunctionTable" << std::endl;
void writeFunctionTables() {
if (wasm->tables.size() == 0) return;
if (debug) std::cerr << "== writeFunctionTables" << std::endl;
auto start = startSection(BinaryConsts::Section::FunctionTable);
o << U32LEB(wasm->table.names.size());
for (auto name : wasm->table.names) {
o << U32LEB(getFunctionIndex(name));
o << U32LEB(wasm->tables.size());
for (auto& curr : wasm->tables) {
if (debug) std::cerr << "write one" << std::endl;
o << int8_t(curr->isDefault);
o << U32LEB(getFunctionTypeIndex(curr->elementType->name));
assert(curr->initial == curr->values.size() && curr->initial == curr->max);
o << U32LEB(curr->initial);
o << U32LEB(curr->max);
for (auto name : curr->values) {
o << U32LEB(getFunctionIndex(name));
}
}
finishSection(start);
}
Expand Down Expand Up @@ -924,7 +940,7 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
for (auto* operand : curr->operands) {
recurse(operand);
}
o << int8_t(BinaryConsts::CallIndirect) << U32LEB(curr->operands.size()) << U32LEB(getFunctionTypeIndex(curr->fullType));
o << int8_t(BinaryConsts::CallIndirect) << U32LEB(curr->operands.size()) << U32LEB(getFunctionTypeIndex(curr->fullType)) << U32LEB(getFunctionTableIndex(curr->table));
}
void visitGetLocal(GetLocal *curr) {
if (debug) std::cerr << "zz node: GetLocal " << (o.size() + 1) << std::endl;
Expand Down Expand Up @@ -1256,7 +1272,7 @@ class WasmBinaryBuilder {
else if (match(BinaryConsts::Section::ExportTable)) readExports();
else if (match(BinaryConsts::Section::Globals)) readGlobals();
else if (match(BinaryConsts::Section::DataSegments)) readDataSegments();
else if (match(BinaryConsts::Section::FunctionTable)) readFunctionTable();
else if (match(BinaryConsts::Section::FunctionTable)) readFunctionTables();
else if (match(BinaryConsts::Section::Names)) readNames();
else {
std::cerr << "unfamiliar section: ";
Expand Down Expand Up @@ -1452,6 +1468,11 @@ class WasmBinaryBuilder {
assert(numResults == 1);
curr->result = getWasmType();
}
// TODO: Handle "anyfunc" properly. This sets the name to "anyfunc" if
// it does not already exist, and matches the expected type signature.
if (!wasm.checkFunctionType(FunctionType::kAnyFunc) && FunctionType::isAnyFuncType(curr)) {
curr->name = FunctionType::kAnyFunc;
}
wasm.addFunctionType(curr);
}
}
Expand All @@ -1466,7 +1487,7 @@ class WasmBinaryBuilder {
curr->name = Name(std::string("import$") + std::to_string(i));
auto index = getU32LEB();
assert(index < wasm.functionTypes.size());
curr->type = wasm.getFunctionType(index);
curr->type = wasm.functionTypes[index].get();
assert(curr->type->name.is());
curr->module = getInlineString();
curr->base = getInlineString();
Expand Down Expand Up @@ -1638,9 +1659,12 @@ class WasmBinaryBuilder {
}
}

for (size_t index : functionTable) {
assert(index < wasm.functions.size());
wasm.table.names.push_back(wasm.functions[index]->name);
for (auto& pair : functionTable) {
assert(pair.first < wasm.tables.size());
assert(pair.second < wasm.functions.size());
assert(wasm.tables[pair.first]->values.size() <= wasm.tables[pair.first]->max);
assert(wasm.tables[pair.first]->elementType->name == FunctionType::kAnyFunc || wasm.tables[pair.first]->elementType == wasm.getFunctionType(wasm.functions[pair.second]->type));
wasm.tables[pair.first]->values.push_back(wasm.functions[pair.second]->name);
}
}

Expand All @@ -1660,14 +1684,28 @@ class WasmBinaryBuilder {
}
}

std::vector<size_t> functionTable;
std::vector<std::pair<size_t, size_t>> functionTable;

void readFunctionTable() {
if (debug) std::cerr << "== readFunctionTable" << std::endl;
auto num = getU32LEB();
for (size_t i = 0; i < num; i++) {
void readFunctionTables() {
if (debug) std::cerr << "== readFunctionTables" << std::endl;
size_t numTables = getU32LEB();
for (size_t i = 0; i < numTables; i++) {
if (debug) std::cerr << "read one" << std::endl;
auto curr = new Table;
auto flag = getInt8();
assert((!i && flag) || (i && !flag));
curr->isDefault = flag;
auto index = getU32LEB();
functionTable.push_back(index);
assert(index < functionTypes.size());
curr->elementType = wasm.getFunctionType(index);
curr->initial = getU32LEB();
curr->max = getU32LEB();
assert(curr->initial == curr->max);
for (size_t j = 0; j < curr->initial; j++) {
auto index = getU32LEB();
functionTable.push_back(std::make_pair<>(i, index));
}
wasm.addTable(curr);
}
}

Expand Down Expand Up @@ -1887,6 +1925,7 @@ class WasmBinaryBuilder {
curr->fullType = fullType->name;
auto num = fullType->params.size();
assert(num == arity);
curr->table = wasm.getTable(getU32LEB())->name;
curr->operands.resize(num);
for (size_t i = 0; i < num; i++) {
curr->operands[num - i - 1] = popExpression();
Expand Down
6 changes: 4 additions & 2 deletions src/wasm-builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,18 @@ class Builder {
call->operands.set(args);
return call;
}
CallIndirect* makeCallIndirect(FunctionType* type, Expression* target, const std::vector<Expression*>& args) {
CallIndirect* makeCallIndirect(Name table, FunctionType* type, Expression* target, const std::vector<Expression*>& args) {
auto* call = allocator.alloc<CallIndirect>();
call->table = table;
call->fullType = type->name;
call->type = type->result;
call->target = target;
call->operands.set(args);
return call;
}
CallIndirect* makeCallIndirect(Name fullType, Expression* target, const std::vector<Expression*>& args, WasmType type) {
CallIndirect* makeCallIndirect(Name table, Name fullType, Expression* target, const std::vector<Expression*>& args, WasmType type) {
auto* call = allocator.alloc<CallIndirect>();
call->table = table;
call->fullType = fullType;
call->type = type;
call->target = target;
Expand Down
Loading