diff --git a/.github/workflows/scala.yml b/.github/workflows/scala.yml index 4677da77..d2f8348b 100644 --- a/.github/workflows/scala.yml +++ b/.github/workflows/scala.yml @@ -79,3 +79,4 @@ jobs: sbt 'testOnly gensym.wasm.TestConcolic' sbt 'testOnly gensym.wasm.TestDriver' sbt 'testOnly gensym.wasm.TestStagedEval' + sbt 'testOnly gensym.wasm.TestStagedConcolicEval' diff --git a/benchmarks/wasm/branch-strip-buggy.wat b/benchmarks/wasm/branch-strip-buggy.wat index c957db7f..0685f0be 100644 --- a/benchmarks/wasm/branch-strip-buggy.wat +++ b/benchmarks/wasm/branch-strip-buggy.wat @@ -29,6 +29,7 @@ else i32.const 0 call 2 + i32.const 1 ;; to satisfy the type checker, this line will never be reached end end ) diff --git a/benchmarks/wasm/staged/brtable_concolic.wat b/benchmarks/wasm/staged/brtable_concolic.wat new file mode 100644 index 00000000..04429e90 --- /dev/null +++ b/benchmarks/wasm/staged/brtable_concolic.wat @@ -0,0 +1,22 @@ +(module $brtable + (global (;0;) (mut i32) (i32.const 1048576)) + (type (;0;) (func (param i32))) + (func (;0;) (type 1) (result i32) + i32.const 2 + (block + (block + (block + i32.const 0 + i32.symbolic + br_table 0 1 2 0 ;; br_table will consume an element from the stack + ) + i32.const 1 + call 1 + br 1 + ) + i32.const 0 + call 1 + ) + ) + (import "console" "assert" (func (type 0))) + (start 0)) diff --git a/headers/wasm.hpp b/headers/wasm.hpp index 21da2ff7..36fe3849 100644 --- a/headers/wasm.hpp +++ b/headers/wasm.hpp @@ -2,5 +2,7 @@ #define WASM_HEADERS #include "wasm/concrete_rt.hpp" - +#include "wasm/symbolic_rt.hpp" +#include "wasm/concolic_driver.hpp" +#include "wasm/utils.hpp" #endif \ No newline at end of file diff --git a/headers/wasm/concolic_driver.hpp b/headers/wasm/concolic_driver.hpp new file mode 100644 index 00000000..8e8ca815 --- /dev/null +++ b/headers/wasm/concolic_driver.hpp @@ -0,0 +1,117 @@ +#ifndef CONCOLIC_DRIVER_HPP +#define CONCOLIC_DRIVER_HPP + +#include "concrete_rt.hpp" +#include "smt_solver.hpp" +#include
"symbolic_rt.hpp" +#include +#include +#include +#include + +class ConcolicDriver { + friend class ManagedConcolicCleanup; + +public: + ConcolicDriver(std::function entrypoint, std::string tree_file) + : entrypoint(entrypoint), tree_file(tree_file) {} + ConcolicDriver(std::function entrypoint) + : entrypoint(entrypoint), tree_file(std::nullopt) {} + void run(); + +private: + Solver solver; + std::function entrypoint; + std::optional tree_file; +};
+
+// Scope guard used by ConcolicDriver::run(): when it goes out of scope it
+// dumps the exploration tree as Graphviz to the driver's tree_file, if one
+// was configured.
+class ManagedConcolicCleanup {
+  const ConcolicDriver &driver;
+
+public:
+  ManagedConcolicCleanup(const ConcolicDriver &driver) : driver(driver) {}
+  ~ManagedConcolicCleanup() {
+    if (!driver.tree_file.has_value())
+      return;
+    // dump_graphviz throws when the file cannot be opened. Destructors are
+    // implicitly noexcept, so letting that escape would call std::terminate;
+    // report the failure and carry on instead.
+    try {
+      ExploreTree.dump_graphviz(driver.tree_file.value());
+    } catch (...) {
+      std::cerr << "Failed to dump the exploration tree to "
+                << driver.tree_file.value() << std::endl;
+    }
+  }
+};
+
+inline void ConcolicDriver::run() { + ManagedConcolicCleanup cleanup{*this}; + while (true) { + ExploreTree.reset_cursor(); + + auto unexplored = ExploreTree.pick_unexplored(); + if (!unexplored) { + std::cout << "No unexplored nodes found, exiting..." << std::endl; + return; + } + auto cond = unexplored->collect_path_conds(); + auto result = solver.solve(cond); + if (!result.has_value()) { + // TODO: current implementation is buggy, there could be other reachable + // unexplored paths + std::cout << "Found an unreachable path, marking it as unreachable..." + << std::endl; + unexplored->fillUnreachableNode(); + continue; + } + auto new_env = result.value(); + SymEnv.update(std::move(new_env)); + try { + entrypoint(); + std::cout << "Execution finished successfully with symbolic environment:" + << std::endl; + std::cout << SymEnv.to_string() << std::endl; + } catch (...)
{ + ExploreTree.fillFailedNode(); + std::cout << "Caught runtime error with symbolic environment:" + << std::endl; + std::cout << SymEnv.to_string() << std::endl; + return; + } + } +} + +static std::monostate reset_stacks() { + Stack.reset(); + Frames.reset(); + SymStack.reset(); + SymFrames.reset(); + initRand(); + Memory = Memory_t(1); + return std::monostate{}; +} + +static void start_concolic_execution_with( + std::function entrypoint, + std::string tree_file) { + ConcolicDriver driver([=]() { entrypoint(std::monostate{}); }, tree_file); + driver.run(); +} + +static void start_concolic_execution_with( + std::function entrypoint) { + + const char *env_tree_file = std::getenv("TREE_FILE"); + + ConcolicDriver driver = + env_tree_file ? ConcolicDriver([=]() { entrypoint(std::monostate{}); }, + env_tree_file) + : ConcolicDriver([=]() { entrypoint(std::monostate{}); }); + driver.run(); +} + +#endif // CONCOLIC_DRIVER_HPP \ No newline at end of file diff --git a/headers/wasm/concrete_rt.hpp b/headers/wasm/concrete_rt.hpp index 34d739f4..a0961453 100644 --- a/headers/wasm/concrete_rt.hpp +++ b/headers/wasm/concrete_rt.hpp @@ -1,3 +1,6 @@ +#ifndef WASM_CONCRETE_RT_HPP +#define WASM_CONCRETE_RT_HPP + #include #include #include @@ -49,8 +52,6 @@ static Num I32V(int v) { return v; } static Num I64V(int64_t v) { return v; } -using Slice = std::vector; - const int STACK_SIZE = 1024 * 64; class Stack_t { @@ -115,9 +116,12 @@ class Stack_t { } void initialize() { - // do nothing for now + // todo: remove this method + reset(); } + void reset() { count = 0; } + private: int32_t count; Num *stack_ptr; @@ -148,6 +152,8 @@ class Frames_t { count += size; } + void reset() { count = 0; } + private: int32_t count; Num *stack_ptr; @@ -200,4 +206,6 @@ struct Memory_t { } }; -static Memory_t Memory(1); // 1 page memory \ No newline at end of file +static Memory_t Memory(1); // 1 page memory + +#endif // WASM_CONCRETE_RT_HPP \ No newline at end of file diff --git 
a/headers/wasm/smt_solver.hpp b/headers/wasm/smt_solver.hpp new file mode 100644 index 00000000..f2450905 --- /dev/null +++ b/headers/wasm/smt_solver.hpp @@ -0,0 +1,126 @@ +#ifndef SMT_SOLVER_HPP +#define SMT_SOLVER_HPP + +#include "concrete_rt.hpp" +#include "symbolic_rt.hpp" +#include "z3++.h" +#include +#include +#include +#include +#include + +class Solver { +public: + Solver() {} + std::optional> solve(const std::vector &conditions) { + // make an conjunction of all conditions + z3::expr conjunction = z3_ctx.bool_val(true); + for (const auto &cond : conditions) { + auto z3_cond = build_z3_expr(cond); + conjunction = conjunction && z3_cond != z3_ctx.bv_val(0, 32); + } +#ifdef DEBUG + std::cout << "Symbolic conditions size: " << conditions.size() << std::endl; + std::cout << "Solving conditions: " << conjunction << std::endl; +#endif + // call z3 to solve the condition + z3::solver z3_solver(z3_ctx); + z3_solver.add(conjunction); + switch (z3_solver.check()) { + case z3::unsat: + return std::nullopt; // No solution found + case z3::sat: { + z3::model model = z3_solver.get_model(); + std::vector result; + // Reference: + // https://github.com/Z3Prover/z3/blob/master/examples/c%2B%2B/example.cpp#L59 + + std::cout << "Solved Z3 model" << std::endl << model << std::endl; + for (unsigned i = 0; i < model.size(); ++i) { + z3::func_decl var = model[i]; + z3::expr value = model.get_const_interp(var); + std::string name = var.name().str(); + if (name.starts_with("s_")) { + int id = std::stoi(name.substr(2)); + if (id >= result.size()) { + result.resize(id + 1); + } + result[id] = Num(value.get_numeral_int64()); + } else { + std::cout << "Find a variable that is not created by GenSym: " << name + << std::endl; + } + } + return result; + } + case z3::unknown: + throw std::runtime_error("Z3 solver returned unknown status"); + } + return std::nullopt; // Should not reach here + } + +private: + z3::context z3_ctx; + z3::expr build_z3_expr(const SymVal &sym_val); +}; + +inline 
z3::expr Solver::build_z3_expr(const SymVal &sym_val) { + if (auto sym = std::dynamic_pointer_cast(sym_val.symptr)) { + return z3_ctx.bv_const(("s_" + std::to_string(sym->get_id())).c_str(), 32); + } else if (auto concrete = + std::dynamic_pointer_cast(sym_val.symptr)) { + return z3_ctx.bv_val(concrete->value.value, 32); + } else if (auto binary = + std::dynamic_pointer_cast(sym_val.symptr)) { + auto bit_width = 32; + z3::expr zero_bv = + z3_ctx.bv_val(0, bit_width); // Represents 0 as a 32-bit bitvector + z3::expr one_bv = + z3_ctx.bv_val(1, bit_width); // Represents 1 as a 32-bit bitvector + + z3::expr left = build_z3_expr(binary->lhs); + z3::expr right = build_z3_expr(binary->rhs); + // TODO: make sure the semantics of these operations are aligned with wasm + switch (binary->op) { + case EQ: { + auto temp_bool = left == right; + return z3::ite(temp_bool, one_bv, zero_bv); + } + case NEQ: { + auto temp_bool = left != right; + return z3::ite(temp_bool, one_bv, zero_bv); + } + case LT: { + auto temp_bool = left < right; + return z3::ite(temp_bool, one_bv, zero_bv); + } + case LEQ: { + auto temp_bool = left <= right; + return z3::ite(temp_bool, one_bv, zero_bv); + } + case GT: { + auto temp_bool = left > right; + return z3::ite(temp_bool, one_bv, zero_bv); + } + case GEQ: { + auto temp_bool = left >= right; + return z3::ite(temp_bool, one_bv, zero_bv); + } + case ADD: { + return left + right; + } + case SUB: { + return left - right; + } + case MUL: { + return left * right; + } + case DIV: { + return left / right; + } + } + } + throw std::runtime_error("Unsupported symbolic value type"); +} +#endif // SMT_SOLVER_HPP \ No newline at end of file diff --git a/headers/wasm/symbolic_rt.hpp b/headers/wasm/symbolic_rt.hpp new file mode 100644 index 00000000..18629c80 --- /dev/null +++ b/headers/wasm/symbolic_rt.hpp @@ -0,0 +1,556 @@ +#ifndef WASM_SYMBOLIC_RT_HPP +#define WASM_SYMBOLIC_RT_HPP + +#include "concrete_rt.hpp" +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include + +class Symbolic { +public: + Symbolic() {} // TODO: remove this default constructor later + virtual ~Symbolic() = default; // Make Symbolic polymorphic +}; + +static int max_id = 0; + +class Symbol : public Symbolic { +public: + // TODO: add type information to determine the size of bitvector + // for now we just assume that only i32 will be used + Symbol(int id) : id(id) { max_id = std::max(max_id, id); } + int get_id() const { return id; } + +private: + int id; +}; + +class SymConcrete : public Symbolic { +public: + Num value; + SymConcrete(Num num) : value(num) {} +}; + +struct SymBinary; + +struct SymVal { + std::shared_ptr symptr; + + SymVal() : symptr(nullptr) {} + SymVal(std::shared_ptr symptr) : symptr(symptr) {} + + // data structure operations + SymVal makeSymbolic() const; + + // arithmetic operations + SymVal is_zero() const; + SymVal add(const SymVal &other) const; + SymVal minus(const SymVal &other) const; + SymVal mul(const SymVal &other) const; + SymVal div(const SymVal &other) const; + SymVal eq(const SymVal &other) const; + SymVal neq(const SymVal &other) const; + SymVal lt(const SymVal &other) const; + SymVal leq(const SymVal &other) const; + SymVal gt(const SymVal &other) const; + SymVal geq(const SymVal &other) const; + SymVal negate() const; +}; + +inline SymVal Concrete(Num num) { + return SymVal(std::make_shared(num)); +} + +enum Operation { ADD, SUB, MUL, DIV, EQ, NEQ, LT, LEQ, GT, GEQ }; + +struct SymBinary : Symbolic { + Operation op; + SymVal lhs; + SymVal rhs; + + SymBinary(Operation op, SymVal lhs, SymVal rhs) + : op(op), lhs(lhs), rhs(rhs) {} +}; + +inline SymVal SymVal::add(const SymVal &other) const { + return SymVal(std::make_shared(ADD, *this, other)); +} + +inline SymVal SymVal::minus(const SymVal &other) const { + return SymVal(std::make_shared(SUB, *this, other)); +} + +inline SymVal SymVal::mul(const SymVal &other) const { + return SymVal(std::make_shared(MUL, 
*this, other)); +} + +inline SymVal SymVal::div(const SymVal &other) const { + return SymVal(std::make_shared(DIV, *this, other)); +} + +inline SymVal SymVal::eq(const SymVal &other) const { + return SymVal(std::make_shared(EQ, *this, other)); +} + +inline SymVal SymVal::neq(const SymVal &other) const { + return SymVal(std::make_shared(NEQ, *this, other)); +} +inline SymVal SymVal::lt(const SymVal &other) const { + return SymVal(std::make_shared(LT, *this, other)); +} +inline SymVal SymVal::leq(const SymVal &other) const { + return SymVal(std::make_shared(LEQ, *this, other)); +} +inline SymVal SymVal::gt(const SymVal &other) const { + return SymVal(std::make_shared(GT, *this, other)); +} +inline SymVal SymVal::geq(const SymVal &other) const { + return SymVal(std::make_shared(GEQ, *this, other)); +} +inline SymVal SymVal::is_zero() const { + return SymVal(std::make_shared(EQ, *this, Concrete(I32V(0)))); +} +inline SymVal SymVal::negate() const { + return SymVal(std::make_shared(EQ, *this, Concrete(I32V(0)))); +} + +inline SymVal SymVal::makeSymbolic() const { + auto concrete = dynamic_cast(symptr.get()); + if (concrete) { + // If the symbolic value is a concrete value, use it to create a symbol + return SymVal(std::make_shared(concrete->value.toInt())); + } else { + throw std::runtime_error( + "Cannot make symbolic a non-concrete symbolic value"); + } +} + +class SymStack_t { +public: + void push(SymVal val) { + // Push a symbolic value to the stack + stack.push_back(val); + } + + SymVal pop() { + // Pop a symbolic value from the stack + auto ret = stack.back(); + stack.pop_back(); + return ret; + } + + SymVal peek() { return stack.back(); } + + void reset() { + // Reset the symbolic stack + stack.clear(); + } + + std::vector stack; +}; + +static SymStack_t SymStack; + +class SymFrames_t { +public: + void pushFrame(int size) { + // Push a new frame with the given size + stack.resize(size + stack.size()); + } + std::monostate popFrame(int size) { + // Pop the frame 
of the given size + stack.resize(stack.size() - size); + return std::monostate(); + } + + SymVal get(int index) { + // Get the symbolic value at the given frame index + return stack[stack.size() - 1 - index]; + } + + void set(int index, SymVal val) { + // Set the symbolic value at the given index + // Not implemented yet + stack[stack.size() - 1 - index] = val; + } + + void reset() { + // Reset the symbolic frames + stack.clear(); + } + + std::vector stack; +}; + +static SymFrames_t SymFrames; + +struct Node; + +struct NodeBox { + explicit NodeBox(NodeBox *parent); + std::unique_ptr node; + NodeBox *parent; + + std::monostate fillIfElseNode(SymVal cond); + std::monostate fillFinishedNode(); + std::monostate fillFailedNode(); + std::monostate fillUnreachableNode(); + + std::vector collect_path_conds(); +}; + +struct Node { + virtual ~Node(){}; + virtual std::string to_string() = 0; + void to_graphviz(std::ostream &os) { + os << "digraph G {\n"; + os << " rankdir=TB;\n"; + os << " node [shape=box, style=filled, fillcolor=lightblue];\n"; + current_id = 0; + generate_dot(os, -1, ""); + + os << "}\n"; + } + virtual void generate_dot(std::ostream &os, int parent_dot_id, + const std::string &edge_label) = 0; + +protected: + // Counter for unique node IDs across the entire graph, only for generating + // graphviz purpose + static int current_id; + void graphviz_node(std::ostream &os, const int node_id, + const std::string &label, const std::string &shape, + const std::string &fillcolor) { + os << " node" << node_id << " [label=\"" << label << "\", shape=" << shape + << ", style=filled, fillcolor=" << fillcolor << "];\n"; + } + + void graphviz_edge(std::ostream &os, int from_id, int target_id, + const std::string &edge_label) { + os << " node" << from_id << " -> node" << target_id; + if (!edge_label.empty()) { + os << " [label=\"" << edge_label << "\"]"; + } + os << ";\n"; + } +}; + +// TODO: use this header file in multiple compilation units will cause problems +// during 
linking +int Node::current_id = 0; + +struct IfElseNode : Node { + SymVal cond; + std::unique_ptr true_branch; + std::unique_ptr false_branch; + + IfElseNode(SymVal cond, NodeBox *parent) + : cond(cond), true_branch(std::make_unique(parent)), + false_branch(std::make_unique(parent)) {} + + std::string to_string() override { + std::string result = "IfElseNode {\n"; + result += " true_branch: "; + if (true_branch) { + result += true_branch->node->to_string(); + } else { + result += "nullptr"; + } + result += "\n"; + + result += " false_branch: "; + if (false_branch) { + result += false_branch->node->to_string(); + } else { + result += "nullptr"; + } + result += "\n"; + result += "}"; + return result; + } + + void generate_dot(std::ostream &os, int parent_dot_id, + const std::string &edge_label) override { + int current_node_dot_id = current_id; + current_id += 1; + + graphviz_node(os, current_node_dot_id, "If", "diamond", "lightyellow"); + + // Draw edge from parent if this is not the root node + if (parent_dot_id != -1) { + graphviz_edge(os, parent_dot_id, current_node_dot_id, edge_label); + } + assert(true_branch != nullptr); + assert(true_branch->node != nullptr); + true_branch->node->generate_dot(os, current_node_dot_id, "true"); + assert(false_branch != nullptr); + assert(false_branch->node != nullptr); + false_branch->node->generate_dot(os, current_node_dot_id, "false"); + } +}; + +struct UnExploredNode : Node { + UnExploredNode() {} + std::string to_string() override { return "UnexploredNode"; } + +protected: + void generate_dot(std::ostream &os, int parent_dot_id, + const std::string &edge_label) override { + int current_node_dot_id = current_id++; + graphviz_node(os, current_node_dot_id, "Unexplored", "octagon", + "lightgrey"); + + if (parent_dot_id != -1) { + graphviz_edge(os, parent_dot_id, current_node_dot_id, edge_label); + } + } +}; + +struct Finished : Node { + Finished() {} + std::string to_string() override { return "FinishedNode"; } + +protected: + 
void generate_dot(std::ostream &os, int parent_dot_id, + const std::string &edge_label) override { + int current_node_dot_id = current_id++; + graphviz_node(os, current_node_dot_id, "Finished", "box", "lightgreen"); + + if (parent_dot_id != -1) { + graphviz_edge(os, parent_dot_id, current_node_dot_id, edge_label); + } + } +}; + +struct Failed : Node { + Failed() {} + std::string to_string() override { return "FailedNode"; } + +protected: + void generate_dot(std::ostream &os, int parent_dot_id, + const std::string &edge_label) override { + int current_node_dot_id = current_id++; + graphviz_node(os, current_node_dot_id, "Failed", "box", "red"); + + if (parent_dot_id != -1) { + graphviz_edge(os, parent_dot_id, current_node_dot_id, edge_label); + } + } +}; + +struct Unreachable : Node { + Unreachable() {} + std::string to_string() override { return "UnreachableNode"; } + +protected: + void generate_dot(std::ostream &os, int parent_dot_id, + const std::string &edge_label) override { + int current_node_dot_id = current_id++; + graphviz_node(os, current_node_dot_id, "Unreachable", "box", "orange"); + + if (parent_dot_id != -1) { + graphviz_edge(os, parent_dot_id, current_node_dot_id, edge_label); + } + } +}; + +inline NodeBox::NodeBox(NodeBox *parent) + : node(std::make_unique()), + /* TODO: avoid allocation of unexplored node */ + parent(parent) {} + +inline std::monostate NodeBox::fillIfElseNode(SymVal cond) { + // fill the current NodeBox with an ifelse branch node it's unexplored + if (dynamic_cast(node.get())) { + node = std::make_unique(cond, this); + } + assert(dynamic_cast(node.get()) != nullptr && + "Current node is not an IfElseNode, cannot fill it!"); + return std::monostate(); +} + +inline std::monostate NodeBox::fillFinishedNode() { + if (dynamic_cast(node.get())) { + node = std::make_unique(); + } else { + assert(dynamic_cast(node.get()) != nullptr); + } + return std::monostate(); +} + +inline std::monostate NodeBox::fillFailedNode() { + if 
(dynamic_cast(node.get())) { + node = std::make_unique(); + } else { + assert(dynamic_cast(node.get()) != nullptr); + } + return std::monostate(); +} + +inline std::monostate NodeBox::fillUnreachableNode() { + if (dynamic_cast(node.get())) { + node = std::make_unique(); + } else { + assert(dynamic_cast(node.get()) != nullptr); + } + return std::monostate(); +} + +inline std::vector NodeBox::collect_path_conds() { + auto box = this; + auto result = std::vector(); + while (box->parent) { + auto parent = box->parent; + auto if_else_node = dynamic_cast(parent->node.get()); + if (if_else_node) { + if (if_else_node->true_branch.get() == box) { + // If the current box is the true branch, add the condition + result.push_back(if_else_node->cond); + } else if (if_else_node->false_branch.get() == box) { + // If the current box is the false branch, add the negated condition + result.push_back(if_else_node->cond.negate()); + } else { + throw std::runtime_error("Unexpected node structure in explore tree"); + } + } + // Move to parent + box = box->parent; + } + return result; +} + +class ExploreTree_t { +public: + explicit ExploreTree_t() + : root(std::make_unique(nullptr)), cursor(root.get()) {} + + void reset_cursor() { + // Reset the cursor to the root of the tree + cursor = root.get(); + } + + std::monostate fillFinishedNode() { return cursor->fillFinishedNode(); } + + std::monostate fillFailedNode() { return cursor->fillFailedNode(); } + + std::monostate fillIfElseNode(SymVal cond) { + return cursor->fillIfElseNode(cond); + } + + std::monostate moveCursor(bool branch) { + assert(cursor != nullptr); + auto if_else_node = dynamic_cast(cursor->node.get()); + assert( + if_else_node != nullptr && + "Can't move cursor when the branch node is not initialized correctly!"); + if (branch) { + cursor = if_else_node->true_branch.get(); + } else { + cursor = if_else_node->false_branch.get(); + } + return std::monostate(); + } + + std::monostate print() { + std::cout << 
root->node->to_string() << std::endl; + return std::monostate(); + } + + std::monostate to_graphviz(std::ostream &os) { + root->node->to_graphviz(os); + return std::monostate(); + } + + std::monostate dump_graphviz(std::string filepath) { + std::ofstream ofs(filepath); + if (!ofs.is_open()) { + throw std::runtime_error("Failed to open " + filepath + " for writing"); + } + to_graphviz(ofs); + return std::monostate(); + } + + std::optional> get_unexplored_conditions() { + // Get all unexplored conditions in the tree + std::vector result; + auto box = pick_unexplored(); + if (!box) { + return std::nullopt; + } + return box->collect_path_conds(); + } + + NodeBox *pick_unexplored() { + // Pick an unexplored node from the tree + // For now, we just iterate through the tree and return the first unexplored + return pick_unexplored_of(root.get()); + } + +private: + NodeBox *pick_unexplored_of(NodeBox *node) { + if (dynamic_cast(node->node.get()) != nullptr) { + return node; + } + auto if_else_node = dynamic_cast(node->node.get()); + if (if_else_node) { + NodeBox *result = pick_unexplored_of(if_else_node->true_branch.get()); + if (result) { + return result; + } + return pick_unexplored_of(if_else_node->false_branch.get()); + } + return nullptr; // No unexplored node found + } + std::unique_ptr root; + NodeBox *cursor; +}; + +static ExploreTree_t ExploreTree; + +class SymEnv_t { +public: + Num read(SymVal sym) { + auto symbol = dynamic_cast(sym.symptr.get()); + assert(symbol); + if (symbol->get_id() >= map.size()) { + map.resize(symbol->get_id() + 1); + } +#if DEBUG + std::cout << "Read symbol: " << symbol->get_id() + << " from symbolic environment" << std::endl; + std::cout << "Current symbolic environment: " << to_string() << std::endl; +#endif + return map[symbol->get_id()]; + } + + void update(std::vector new_env) { + map = std::move(new_env); + } + + std::string to_string() const { + std::string result; + result += "(\n"; + for (int i = 0; i < map.size(); ++i) { + const 
Num &num = map[i]; + result += + " (" + std::to_string(i) + "->" + std::to_string(num.value) + ")\n"; + } + result += ")"; + return result; + } + +private: + std::vector map; // The symbolic environment, a vector of Num +}; + +static SymEnv_t SymEnv; + +#endif // WASM_SYMBOLIC_RT_HPP \ No newline at end of file diff --git a/headers/wasm/utils.hpp b/headers/wasm/utils.hpp new file mode 100644 index 00000000..8a86ac98 --- /dev/null +++ b/headers/wasm/utils.hpp @@ -0,0 +1,15 @@ +#ifndef UTILS_HPP +#define UTILS_HPP + +#ifndef GENSYM_ASSERT +#define GENSYM_ASSERT(condition) \ + do { \ + if (!(condition)) { \ + throw std::runtime_error(std::string("Assertion failed: ") + " (" + \ + __FILE__ + ":" + std::to_string(__LINE__) + \ + ")"); \ + } \ + } while (0) +#endif + +#endif // UTILS_HPP \ No newline at end of file diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala new file mode 100644 index 00000000..833bbc9b --- /dev/null +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -0,0 +1,1102 @@ +package gensym.wasm.stagedconcolicminiwasm + +import scala.collection.mutable.{ArrayBuffer, HashMap} + +import lms.core.stub.Adapter +import lms.core.virtualize +import lms.macros.SourceContext +import lms.core.stub.{Base, ScalaGenBase, CGenBase} +import lms.core.Backend._ +import lms.core.Backend.{Block => LMSBlock, Const => LMSConst} +import lms.core.Graph + +import gensym.wasm.ast._ +import gensym.wasm.ast.{Const => WasmConst, Block => WasmBlock} +import gensym.wasm.miniwasm.{ModuleInstance} +import gensym.wasm.symbolic.{SymVal} +import gensym.lmsx.{SAIDriver, StringOps, SAIOps, SAICodeGenBase, CppSAIDriver, CppSAICodeGenBase} +import gensym.wasm.symbolic.Concrete +import gensym.wasm.symbolic.ExploreTree +import gensym.structure.freer.Explore + +@virtualize +trait StagedWasmEvaluator extends SAIOps { + def module: ModuleInstance + + trait ReturnSite + + trait StagedNum { + def tipe: ValueType = this match { + case 
I32(_, _) => NumType(I32Type) + case I64(_, _) => NumType(I64Type) + case F32(_, _) => NumType(F32Type) + case F64(_, _) => NumType(F64Type) + } + + def i: Rep[Num] + + def s: Rep[SymVal] + } + case class I32(i: Rep[Num], s: Rep[SymVal]) extends StagedNum + case class I64(i: Rep[Num], s: Rep[SymVal]) extends StagedNum + case class F32(i: Rep[Num], s: Rep[SymVal]) extends StagedNum + case class F64(i: Rep[Num], s: Rep[SymVal]) extends StagedNum + + def toStagedNum(num: Num): StagedNum = { + num match { + case I32V(_) => I32(num, Concrete(num)) + case I64V(_) => I64(num, Concrete(num)) + case F32V(_) => F32(num, Concrete(num)) + case F64V(_) => F64(num, Concrete(num)) + } + } + + implicit class ValueTypeOps(ty: ValueType) { + def size: Int = ty match { + case NumType(I32Type) => 4 + case NumType(I64Type) => 8 + case NumType(F32Type) => 4 + case NumType(F64Type) => 8 + } + + def toTagger: (Rep[Num], Rep[SymVal]) => StagedNum = { + ty match { + case NumType(I32Type) => I32 + case NumType(I64Type) => I64 + case NumType(F32Type) => F32 + case NumType(F64Type) => F64 + } + } + } + + case class Context( + stackTypes: List[ValueType], + frameTypes: List[ValueType] + ) { + def push(ty: ValueType): Context = { + Context(ty :: stackTypes, frameTypes) + } + + def pop(): (ValueType, Context) = { + val (ty :: rest) = stackTypes + (ty, Context(rest, frameTypes)) + } + + def shift(offset: Int, size: Int): Context = { + // Predef.println(s"[DEBUG] Shifting stack by $offset, size $size, $this") + Predef.assert(offset >= 0, s"Context shift offset must be non-negative, get $offset") + if (offset == 0) { + this + } else { + this.copy( + stackTypes = stackTypes.take(size) ++ stackTypes.drop(offset + size) + ) + } + } + } + + type MCont[A] = Unit => A + type Cont[A] = (MCont[A]) => A + type Trail[A] = List[Context => Rep[Cont[A]]] + + // a cache storing the compiled code for each function, to reduce re-compilation + val compileCache = new HashMap[Int, Rep[(MCont[Unit]) => Unit]] + + def 
makeDummy: Rep[Unit] = "dummy".reflectCtrlWith[Unit]() + + def funHere[A:Manifest,B:Manifest](f: Rep[A] => Rep[B], dummy: Rep[Unit]): Rep[A => B] = { + // to avoid LMS lifting a function, we create a dummy node and read it inside function + fun((x: Rep[A]) => { + "dummy-op".reflectCtrlWith[Unit](dummy) + f(x) + }) + } + + + def eval(insts: List[Instr], + kont: Context => Rep[Cont[Unit]], + mkont: Rep[MCont[Unit]], + trail: Trail[Unit]) + (implicit ctx: Context): Rep[Unit] = { + if (insts.isEmpty) return kont(ctx)(mkont) + + // Predef.println(s"[DEBUG] Evaluating instructions: ${insts.mkString(", ")}") + // Predef.println(s"[DEBUG] Current context: $ctx") + + val (inst, rest) = (insts.head, insts.tail) + inst match { + case Drop => + val (_, newCtx) = Stack.pop() + eval(rest, kont, mkont, trail)(newCtx) + case WasmConst(num) => + val newCtx = Stack.push(toStagedNum(num)) + eval(rest, kont, mkont, trail)(newCtx) + case Symbolic(ty) => + val (id, newCtx1) = Stack.pop() + val symVal = id.makeSymbolic() + val concVal = SymEnv.read(symVal) + val tagger = ty.toTagger + val value = tagger(concVal, symVal) + val newCtx2 = Stack.push(value)(newCtx1) + eval(rest, kont, mkont, trail)(newCtx2) + case LocalGet(i) => + val newCtx = Stack.push(Frames.get(i)) + eval(rest, kont, mkont, trail)(newCtx) + case LocalSet(i) => + val (num, newCtx) = Stack.pop() + Frames.set(i, num)(newCtx) + eval(rest, kont, mkont, trail)(newCtx) + case LocalTee(i) => + val (num, newCtx) = Stack.peek + Frames.set(i, num) + eval(rest, kont, mkont, trail)(newCtx) + case GlobalGet(i) => + val newCtx = Stack.push(Globals(i)) + eval(rest, kont, mkont, trail)(newCtx) + case GlobalSet(i) => + val (value, newCtx) = Stack.pop() + module.globals(i).ty match { + case GlobalType(tipe, true) => Globals(i) = value + case _ => throw new Exception("Cannot set immutable global") + } + eval(rest, kont, mkont, trail)(newCtx) + case Store(StoreOp(align, offset, ty, None)) => + val (value, newCtx1) = Stack.pop() + val (addr, 
newCtx2) = Stack.pop()(newCtx1) + Memory.storeInt(addr.toInt, offset, value.toInt) + eval(rest, kont, mkont, trail)(newCtx2) + case Nop => eval(rest, kont, mkont, trail) + case Load(LoadOp(align, offset, ty, None, None)) => + val (addr, newCtx1) = Stack.pop() + val value = Memory.loadInt(addr.toInt, offset) + val newCtx2 = Stack.push(value)(newCtx1) + eval(rest, kont, mkont, trail)(newCtx2) + case MemorySize => ??? + case MemoryGrow => + val (delta, newCtx1) = Stack.pop() + val ret = Memory.grow(delta.toInt) + val retNum = Values.I32V(ret) + val retSym = "Concrete".reflectCtrlWith[SymVal](retNum) + val newCtx2 = Stack.push(I32(retNum, retSym))(newCtx1) + eval(rest, kont, mkont, trail)(newCtx2) + case MemoryFill => ??? + case Unreachable => unreachable() + case Test(op) => + val (v, newCtx1) = Stack.pop() + val newCtx2 = Stack.push(evalTestOp(op, v))(newCtx1) + eval(rest, kont, mkont, trail)(newCtx2) + case Unary(op) => + val (v, newCtx1) = Stack.pop() + val newCtx2 = Stack.push(evalUnaryOp(op, v))(newCtx1) + eval(rest, kont, mkont, trail)(newCtx2) + case Binary(op) => + val (v2, newCtx1) = Stack.pop() + val (v1, newCtx2) = Stack.pop()(newCtx1) + val newCtx3 = Stack.push(evalBinOp(op, v1, v2))(newCtx2) + eval(rest, kont, mkont, trail)(newCtx3) + case Compare(op) => + val (v2, newCtx1) = Stack.pop() + val (v1, newCtx2) = Stack.pop()(newCtx1) + val newCtx3 = Stack.push(evalRelOp(op, v1, v2))(newCtx2) + eval(rest, kont, mkont, trail)(newCtx3) + case WasmBlock(ty, inner) => + // no need to modify the stack when entering a block + // the type system guarantees that we will never take more than the input size from the stack + val funcTy = ty.funcType + val exitSize = ctx.stackTypes.size - funcTy.inps.size + funcTy.out.size + val dummy = makeDummy + def restK(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { + info(s"Exiting the block, stackSize =", Stack.size) + val offset = restCtx.stackTypes.size - exitSize + val newRestCtx = Stack.shift(offset, 
funcTy.out.size)(restCtx) + eval(rest, kont, mk, trail)(newRestCtx) + }) + eval(inner, restK _, mkont, restK _ :: trail) + case Loop(ty, inner) => + val funcTy = ty.funcType + val exitSize = ctx.stackTypes.size - funcTy.inps.size + funcTy.out.size + val dummy = makeDummy + def restK(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { + info(s"Exiting the loop, stackSize =", Stack.size) + val offset = restCtx.stackTypes.size - exitSize + val newRestCtx = Stack.shift(offset, funcTy.out.size)(restCtx) + eval(rest, kont, mk, trail)(newRestCtx) + }) + val enterSize = ctx.stackTypes.size + def loop(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { + info(s"Entered the loop, stackSize =", Stack.size) + val offset = restCtx.stackTypes.size - enterSize + val newRestCtx = Stack.shift(offset, funcTy.inps.size)(restCtx) + eval(inner, restK _, mk, loop _ :: trail)(newRestCtx) + }) + loop(ctx)(mkont) + case If(ty, thn, els) => + val funcTy = ty.funcType + val (cond, newCtx) = Stack.pop() + val exitSize = newCtx.stackTypes.size - funcTy.inps.size + funcTy.out.size + // TODO: can we avoid code duplication here? 
/** Stages a `call` / `return_call` to the definition at `funcIndex`.
  *
  * Dispatches on `module.funcs(funcIndex)`:
  *  - `FuncDef` with a body: the body is staged once as a `topFun` and memoised in
  *    `compileCache`, then invoked with the appropriate meta-continuation.
  *  - `console.log` / `spectest.print_i32` imports: pop one value and print its integer view.
  *  - `console.assert` import: pop one value and emit a runtime assertion that it is non-zero.
  *  - any other import, or a non-callable definition: throws `Exception` at staging time.
  *
  * @param rest      instructions following the call in the current block
  * @param kont      continuation of the enclosing block
  * @param mkont     staged meta-continuation of the caller
  * @param trail     branch targets visible at the call site
  * @param funcIndex index into `module.funcs`
  * @param isTail    true for `return_call`; the callee then returns straight to `mkont`
  * @param ctx       static typing context (stack / frame types) at the call site
  */
def evalCall(rest: List[Instr],
             kont: Context => Rep[Cont[Unit]],
             mkont: Rep[MCont[Unit]],
             trail: Trail[Unit],
             funcIndex: Int,
             isTail: Boolean)
            (implicit ctx: Context): Rep[Unit] = {
  module.funcs(funcIndex) match {
    case FuncDef(_, FuncBodyDef(ty, _, bodyLocals, body)) =>
      // NOTE(review): the type list is `bodyLocals ++ ty.inps`, while `Frames.putAll` writes the
      // arguments into slots 0 until args.size -- confirm the type order matches the slot order.
      val locals = bodyLocals ++ ty.inps
      val callee =
        if (compileCache.contains(funcIndex)) {
          // Reuse the already staged function instead of staging the body again.
          compileCache(funcIndex)
        } else {
          val callee = topFun((mk: Rep[MCont[Unit]]) => {
            info(s"Entered the function at $funcIndex, stackSize =", Stack.size)
            // we can do some check here to ensure the function returns correct size of stack
            // The body starts from an empty static stack; both the block continuation and the
            // only trail entry simply forward to the meta-continuation.
            eval(body, (_: Context) => forwardKont, mk, ((_: Context) => forwardKont)::Nil)(Context(Nil, locals))
          })
          compileCache(funcIndex) = callee
          callee
        }
      // Predef.println(s"[DEBUG] locals size: ${locals.size}")
      // Emit the pops for the arguments before touching the frames.
      val (args, newCtx) = Stack.take(ty.inps.size)
      if (isTail) {
        // when tail call, return to the caller's return continuation
        // The caller's frame is released first, sized by the caller's static frame types.
        Frames.popFrame(ctx.frameTypes.size)
        Frames.pushFrame(locals)
        Frames.putAll(args)
        callee(mkont)
      } else {
        // We make a new trail by `restK`, since function creates a new block to escape
        // (more or less like `return`)
        val restK: Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => {
          info(s"Exiting the function at $funcIndex, stackSize =", Stack.size)
          Frames.popFrame(locals.size)
          // The stack types after the call are rebuilt from `ctx` (arguments dropped, results
          // pushed); `newCtx.stackTypes` is deliberately overwritten here.
          eval(rest, kont, mk, trail)(newCtx.copy(stackTypes = ty.out.reverse ++ ctx.stackTypes.drop(ty.inps.size)))
        })
        val dummy = makeDummy
        // Meta-continuation handed to the callee: ignore its unit argument and resume the caller.
        val newMKont: Rep[MCont[Unit]] = funHere((_u: Rep[Unit]) => {
          restK(mkont)
        }, dummy)
        Frames.pushFrame(locals)
        Frames.putAll(args)
        callee(newMKont)
      }
    case Import("console", "log", _)
       | Import("spectest", "print_i32", _) =>
      //println(s"[DEBUG] current stack: $stack")
      val (v, newCtx) = Stack.pop()
      println(v.toInt)
      eval(rest, kont, mkont, trail)(newCtx)
    case Import("console", "assert", _) =>
      val (v, newCtx) = Stack.pop()
      runtimeAssert(v.toInt != 0)
      eval(rest, kont, mkont, trail)(newCtx)
    case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex")
    case _ => throw new Exception(s"Definition at $funcIndex is not callable")
  }
}
/** Staged operand-stack operations.
  *
  * Each operation emits runtime nodes for the concrete stack ("stack-*") and, where shown,
  * for the symbolic stack ("sym-stack-*"), while the static typing `Context` is threaded
  * through at staging time.
  */
object Stack {
  /** Emits a concrete "stack-shift" when `offset` is positive and returns the shifted context.
    *
    * NOTE(review): unlike the `Rep` overload below, this overload does not emit
    * "sym-stack-shift"; confirm that the runtime keeps the symbolic stack in sync.
    */
  def shift(offset: Int, size: Int)(ctx: Context): Context = {
    if (offset > 0) {
      "stack-shift".reflectCtrlWith[Unit](offset, size)
    }
    ctx.shift(offset, size)
  }

  /** Emits the runtime stack initialisation. */
  def initialize(): Rep[Unit] = {
    "stack-init".reflectCtrlWith[Unit]()
  }

  /** Pops the top value (concrete first, then symbolic) and wraps it according to the
    * statically known type of the top stack slot.
    *
    * @return the staged value and the context with that slot removed
    */
  def pop()(implicit ctx: Context): (StagedNum, Context) = {
    val (ty, newContext) = ctx.pop()
    val num = ty match {
      case NumType(I32Type) => I32("stack-pop".reflectCtrlWith[Num](), "sym-stack-pop".reflectCtrlWith[SymVal]())
      case NumType(I64Type) => I64("stack-pop".reflectCtrlWith[Num](), "sym-stack-pop".reflectCtrlWith[SymVal]())
      case NumType(F32Type) => F32("stack-pop".reflectCtrlWith[Num](), "sym-stack-pop".reflectCtrlWith[SymVal]())
      // Was a second `F32Type` case: unreachable, so f64 slots fell through to a MatchError.
      case NumType(F64Type) => F64("stack-pop".reflectCtrlWith[Num](), "sym-stack-pop".reflectCtrlWith[SymVal]())
    }
    (num, newContext)
  }

  /** Reads the top value without removing it; the context is returned unchanged. */
  def peek(implicit ctx: Context): (StagedNum, Context) = {
    val ty = ctx.stackTypes.head
    val num = ty match {
      case NumType(I32Type) => I32("stack-peek".reflectCtrlWith[Num](), "sym-stack-peek".reflectCtrlWith[SymVal]())
      case NumType(I64Type) => I64("stack-peek".reflectCtrlWith[Num](), "sym-stack-peek".reflectCtrlWith[SymVal]())
      case NumType(F32Type) => F32("stack-peek".reflectCtrlWith[Num](), "sym-stack-peek".reflectCtrlWith[SymVal]())
      // Same duplicated-F32Type defect as in `pop`, fixed the same way.
      case NumType(F64Type) => F64("stack-peek".reflectCtrlWith[Num](), "sym-stack-peek".reflectCtrlWith[SymVal]())
    }
    (num, ctx)
  }

  /** Pushes the concrete and then the symbolic component of `num`; returns the context
    * extended with `num.tipe`.
    */
  def push(num: StagedNum)(implicit ctx: Context): Context = {
    num match {
      case I32(v, s) =>
        "stack-push".reflectCtrlWith[Unit](v)
        "sym-stack-push".reflectCtrlWith[Unit](s)
      case I64(v, s) =>
        "stack-push".reflectCtrlWith[Unit](v)
        "sym-stack-push".reflectCtrlWith[Unit](s)
      case F32(v, s) =>
        "stack-push".reflectCtrlWith[Unit](v)
        "sym-stack-push".reflectCtrlWith[Unit](s)
      case F64(v, s) =>
        "stack-push".reflectCtrlWith[Unit](v)
        "sym-stack-push".reflectCtrlWith[Unit](s)
    }
    ctx.push(num.tipe)
  }

  /** Pops `n` values, top of stack first.
    *
    * The context produced by each pop is passed to the next one, so every value is typed by
    * its own slot and the returned context has exactly `n` slots removed. Previously the
    * recursion reused the caller's context, which typed every value by the same top slot and
    * returned the context unchanged.
    */
  def take(n: Int)(implicit ctx: Context): (List[StagedNum], Context) = n match {
    case 0 => (Nil, ctx)
    case n =>
      val (v, newCtx1) = pop()
      val (rest, newCtx2) = take(n - 1)(newCtx1)
      (v::rest, newCtx2)
  }

  /** Pops `n` values and discards them, returning only the resulting context. */
  def drop(n: Int)(implicit ctx: Context): Context = {
    take(n)._2
  }

  /** Staged-offset variant: the emitted code shifts both the concrete and the symbolic stack
    * when `offset` is positive.
    */
  def shift(offset: Rep[Int], size: Rep[Int]): Rep[Unit] = {
    if (offset > 0) {
      "stack-shift".reflectCtrlWith[Unit](offset, size)
      "sym-stack-shift".reflectCtrlWith[Unit](offset, size)
    }
  }

  /** Emits a runtime dump of the stack. */
  def print(): Rep[Unit] = {
    "stack-print".reflectCtrlWith[Unit]()
  }

  /** Staged expression for the current runtime stack size. */
  def size: Rep[Int] = {
    "stack-size".reflectCtrlWith[Int]()
  }
}
/** Staged access to module globals: every read or write emits a concrete node and a
  * symbolic ("sym-") node for the same index.
  */
object Globals {
  /** Reads global `i`, wrapped according to the global's declared numeric type. */
  def apply(i: Int): StagedNum = {
    // Declared as `def`s so that nodes are emitted only inside a matching case,
    // concrete read first and symbolic read second.
    def concreteRead = "global-get".reflectCtrlWith[Num](i)
    def symbolicRead = "sym-global-get".reflectCtrlWith[SymVal](i)
    val declared = module.globals(i).ty
    declared match {
      case GlobalType(NumType(I32Type), _) => I32(concreteRead, symbolicRead)
      case GlobalType(NumType(I64Type), _) => I64(concreteRead, symbolicRead)
      case GlobalType(NumType(F32Type), _) => F32(concreteRead, symbolicRead)
      case GlobalType(NumType(F64Type), _) => F64(concreteRead, symbolicRead)
    }
  }

  /** Writes both components of `v` to global `i`; only the four numeric global types match. */
  def update(i: Int, v: StagedNum): Rep[Unit] = {
    val declared = module.globals(i).ty
    declared match {
      // All four numeric types emit the same pair of writes.
      case GlobalType(NumType(I32Type | I64Type | F32Type | F64Type), _) =>
        "global-set".reflectCtrlWith[Unit](i, v.i)
        "sym-global-set".reflectCtrlWith[Unit](i, v.s)
    }
  }
}
"sym-popcnt".reflectCtrlWith[SymVal](x_s)) + } + + def makeSymbolic(): Rep[SymVal] = { + "make-symbolic".reflectCtrlWith[SymVal](num.s) + } + + def +(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-add".reflectCtrlWith[Num](x_c, y_c), "sym-binary-add".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-add".reflectCtrlWith[Num](x_c, y_c), "sym-binary-add".reflectCtrlWith[SymVal](x_s, y_s)) + case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-add".reflectCtrlWith[Num](x_c, y_c), "sym-binary-add".reflectCtrlWith[SymVal](x_s, y_s)) + case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-add".reflectCtrlWith[Num](x_c, y_c), "sym-binary-add".reflectCtrlWith[SymVal](x_s, y_s)) + } + } + + + def -(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-sub".reflectCtrlWith[Num](x_c, y_c), "sym-binary-sub".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-sub".reflectCtrlWith[Num](x_c, y_c), "sym-binary-sub".reflectCtrlWith[SymVal](x_s, y_s)) + case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-sub".reflectCtrlWith[Num](x_c, y_c), "sym-binary-sub".reflectCtrlWith[SymVal](x_s, y_s)) + case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-sub".reflectCtrlWith[Num](x_c, y_c), "sym-binary-sub".reflectCtrlWith[SymVal](x_s, y_s)) + } + } + + def *(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-mul".reflectCtrlWith[Num](x_c, y_c), "sym-binary-mul".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-mul".reflectCtrlWith[Num](x_c, y_c), "sym-binary-mul".reflectCtrlWith[SymVal](x_s, y_s)) + case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-mul".reflectCtrlWith[Num](x_c, y_c), "sym-binary-mul".reflectCtrlWith[SymVal](x_s, y_s)) + case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-mul".reflectCtrlWith[Num](x_c, y_c), 
"sym-binary-mul".reflectCtrlWith[SymVal](x_s, y_s)) + } + } + + def /(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-div".reflectCtrlWith[Num](x_c, y_c), "sym-binary-div".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-div".reflectCtrlWith[Num](x_c, y_c), "sym-binary-div".reflectCtrlWith[SymVal](x_s, y_s)) + case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-div".reflectCtrlWith[Num](x_c, y_c), "sym-binary-div".reflectCtrlWith[SymVal](x_s, y_s)) + case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-div".reflectCtrlWith[Num](x_c, y_c), "sym-binary-div".reflectCtrlWith[SymVal](x_s, y_s)) + } + } + + def <<(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-shl".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shl".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-shl".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shl".reflectCtrlWith[SymVal](x_s, y_s)) + case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-shl".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shl".reflectCtrlWith[SymVal](x_s, y_s)) + case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-shl".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shl".reflectCtrlWith[SymVal](x_s, y_s)) + } + } + + def >>(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-shr".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shr".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-shr".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shr".reflectCtrlWith[SymVal](x_s, y_s)) + case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-shr".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shr".reflectCtrlWith[SymVal](x_s, y_s)) + case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-shr".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shr".reflectCtrlWith[SymVal](x_s, y_s)) + } + } + + def &(rhs: 
StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-and".reflectCtrlWith[Num](x_c, y_c), "sym-binary-and".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-and".reflectCtrlWith[Num](x_c, y_c), "sym-binary-and".reflectCtrlWith[SymVal](x_s, y_s)) + case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-and".reflectCtrlWith[Num](x_c, y_c), "sym-binary-and".reflectCtrlWith[SymVal](x_s, y_s)) + case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-and".reflectCtrlWith[Num](x_c, y_c), "sym-binary-and".reflectCtrlWith[SymVal](x_s, y_s)) + } + } + + def numEq(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-eq".reflectCtrlWith[Num](x_c, y_c), "sym-relation-eq".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-eq".reflectCtrlWith[Num](x_c, y_c), "sym-relation-eq".reflectCtrlWith[SymVal](x_s, y_s)) + } + } + + def numNe(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-ne".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ne".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-ne".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ne".reflectCtrlWith[SymVal](x_s, y_s)) + } + } + + def <(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-lt".reflectCtrlWith[Num](x_c, y_c), "sym-relation-lt".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-lt".reflectCtrlWith[Num](x_c, y_c), "sym-relation-lt".reflectCtrlWith[SymVal](x_s, y_s)) + } + } + + def ltu(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-ltu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ltu".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => 
I32("relation-ltu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ltu".reflectCtrlWith[SymVal](x_s, y_s)) + } + } + + def >(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-gt".reflectCtrlWith[Num](x_c, y_c), "sym-relation-gt".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-gt".reflectCtrlWith[Num](x_c, y_c), "sym-relation-gt".reflectCtrlWith[SymVal](x_s, y_s)) + } + } + + def gtu(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-gtu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-gtu".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-gtu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-gtu".reflectCtrlWith[SymVal](x_s, y_s)) + } + } + + def <=(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-le".reflectCtrlWith[Num](x_c, y_c), "sym-relation-le".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-le".reflectCtrlWith[Num](x_c, y_c), "sym-relation-le".reflectCtrlWith[SymVal](x_s, y_s)) + } + } + + def leu(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-leu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-leu".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-leu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-leu".reflectCtrlWith[SymVal](x_s, y_s)) + } + } + + def >=(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-ge".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ge".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-ge".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ge".reflectCtrlWith[SymVal](x_s, y_s)) + } + } + + def geu(rhs: StagedNum): StagedNum = { + (num, rhs) match { + case (I32(x_c, 
x_s), I32(y_c, y_s)) => I32("relation-geu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-geu".reflectCtrlWith[SymVal](x_s, y_s)) + case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-geu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-geu".reflectCtrlWith[SymVal](x_s, y_s)) + } + } + } + + implicit class SymbolicOps(s: Rep[SymVal]) { + def not(): Rep[SymVal] = { + "sym-not".reflectCtrlWith(s) + } + } +} + +trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { + // clear include path and headers by first + includePaths.clear() + headers.clear() + + registerHeader("headers", "\"wasm.hpp\"") + registerHeader("") + registerHeader("") + registerHeader("") + registerHeader("") + + override def mayInline(n: Node): Boolean = n match { + case Node(_, "stack-pop", _, _) + | Node(_, "stack-peek", _, _) + | Node(_, "sym-stack-pop", _, _) + => false + case _ => super.mayInline(n) + } + + override def remap(m: Manifest[_]): String = { + if (m.toString.endsWith("Num")) "Num" + else if (m.toString.endsWith("Frame")) "Frame" + else if (m.toString.endsWith("Stack")) "Stack" + else if (m.toString.endsWith("Global")) "Global" + else if (m.toString.endsWith("I32V")) "I32V" + else if (m.toString.endsWith("I64V")) "I64V" + else if (m.toString.endsWith("SymVal")) "SymVal" + + else super.remap(m) + } + + override def traverse(n: Node): Unit = n match { + case Node(_, "stack-push", List(value), _) => + emit("Stack.push("); shallow(value); emit(");\n") + case Node(_, "sym-stack-push", List(s_value), _) => + emit("SymStack.push("); shallow(s_value); emit(");\n") + case Node(_, "stack-drop", List(n), _) => + emit("Stack.drop("); shallow(n); emit(");\n") + case Node(_, "stack-init", _, _) => + emit("Stack.initialize();\n") + case Node(_, "stack-print", _, _) => + emit("Stack.print();\n") + case Node(_, "frame-push", List(i), _) => + emit("Frames.pushFrame("); shallow(i); emit(");\n") + case Node(_, "sym-frame-push", List(i), _) => + emit("SymFrames.pushFrame("); shallow(i); 
emit(");\n") + case Node(_, "frame-pop", List(i), _) => + emit("Frames.popFrame("); shallow(i); emit(");\n") + case Node(_, "frame-set", List(i, value), _) => + emit("Frames.set("); shallow(i); emit(", "); shallow(value); emit(");\n") + case Node(_, "sym-frame-set", List(i, s_value), _) => + emit("SymFrames.set("); shallow(i); emit(", "); shallow(s_value); emit(");\n") + case Node(_, "global-set", List(i, value), _) => + emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(");\n") + // Note: The following code is copied from the traverse of CppBackend.scala, try to avoid duplicated code + case n @ Node(f, "λ", (b: LMSBlock)::LMSConst(0)::rest, _) => + // TODO: Is a leading block followed by 0 a hint for top function? + super.traverse(n) + case n @ Node(f, "λ", (b: LMSBlock)::rest, _) => + val retType = remap(typeBlockRes(b.res)) + val argTypes = b.in.map(a => remap(typeMap(a))).mkString(", ") + emitln(s"std::function<$retType(${argTypes})> ${quote(f)};") + emit(quote(f)); emit(" = ") + quoteTypedBlock(b, false, true, capture = "&") + emitln(";") + case _ => super.traverse(n) + } + + override def shallow(n: Node): Unit = n match { + case Node(_, "reset-stacks", _, _) => + emit("reset_stacks()") + case Node(_, "frame-get", List(i), _) => + emit("Frames.get("); shallow(i); emit(")") + case Node(_, "sym-frame-get", List(i), _) => + emit("SymFrames.get("); shallow(i); emit(")") + case Node(_, "stack-drop", List(n), _) => + emit("Stack.drop("); shallow(n); emit(")") + case Node(_, "stack-push", List(value), _) => + emit("Stack.push("); shallow(value); emit(")") + case Node(_, "stack-shift", List(offset, size), _) => + emit("Stack.shift("); shallow(offset); emit(", "); shallow(size); emit(")") + case Node(_, "stack-pop", _, _) => + emit("Stack.pop()") + case Node(_, "sym-stack-pop", _, _) => + emit("SymStack.pop()") + case Node(_, "frame-pop", List(i), _) => + emit("Frames.popFrame("); shallow(i); emit(")") + case Node(_, "sym-frame-pop", List(i), _) => 
+ emit("SymFrames.popFrame("); shallow(i); emit(")") + case Node(_, "stack-peek", _, _) => + emit("Stack.peek()") + case Node(_, "sym-stack-peek", _, _) => + emit("SymStack.peek()") + case Node(_, "stack-take", List(n), _) => + emit("Stack.take("); shallow(n); emit(")") + case Node(_, "slice-reverse", List(slice), _) => + shallow(slice); emit(".reverse") + case Node(_, "memory-store-int", List(base, offset, value), _) => + emit("Memory.storeInt("); shallow(base); emit(", "); shallow(offset); emit(", "); shallow(value); emit(")") + case Node(_, "memory-load-int", List(base, offset), _) => + emit("Memory.loadInt("); shallow(base); emit(", "); shallow(offset); emit(")") + case Node(_, "memory-grow", List(delta), _) => + emit("Memory.grow("); shallow(delta); emit(")") + case Node(_, "stack-size", _, _) => + emit("Stack.size()") + case Node(_, "global-get", List(i), _) => + emit("Global.globalGet("); shallow(i); emit(")") + case Node(_, "is-zero", List(num), _) => + emit("(0 == "); shallow(num); emit(")") + case Node(_, "sym-is-zero", List(s_num), _) => + shallow(s_num); emit(".is_zero()") + case Node(_, "binary-add", List(lhs, rhs), _) => + shallow(lhs); emit(" + "); shallow(rhs) + case Node(_, "binary-sub", List(lhs, rhs), _) => + shallow(lhs); emit(" - "); shallow(rhs) + case Node(_, "binary-mul", List(lhs, rhs), _) => + shallow(lhs); emit(" * "); shallow(rhs) + case Node(_, "binary-div", List(lhs, rhs), _) => + shallow(lhs); emit(" / "); shallow(rhs) + case Node(_, "binary-shl", List(lhs, rhs), _) => + shallow(lhs); emit(" << "); shallow(rhs) + case Node(_, "binary-shr", List(lhs, rhs), _) => + shallow(lhs); emit(" >> "); shallow(rhs) + case Node(_, "binary-and", List(lhs, rhs), _) => + shallow(lhs); emit(" & "); shallow(rhs) + case Node(_, "relation-eq", List(lhs, rhs), _) => + shallow(lhs); emit(" == "); shallow(rhs) + case Node(_, "relation-ne", List(lhs, rhs), _) => + shallow(lhs); emit(" != "); shallow(rhs) + case Node(_, "relation-lt", List(lhs, rhs), _) => + 
shallow(lhs); emit(" < "); shallow(rhs) + case Node(_, "relation-ltu", List(lhs, rhs), _) => + shallow(lhs); emit(" < "); shallow(rhs) + case Node(_, "relation-gt", List(lhs, rhs), _) => + shallow(lhs); emit(" > "); shallow(rhs) + case Node(_, "relation-gtu", List(lhs, rhs), _) => + shallow(lhs); emit(" > "); shallow(rhs) + case Node(_, "relation-le", List(lhs, rhs), _) => + shallow(lhs); emit(" <= "); shallow(rhs) + case Node(_, "relation-leu", List(lhs, rhs), _) => + shallow(lhs); emit(" <= "); shallow(rhs) + case Node(_, "relation-ge", List(lhs, rhs), _) => + shallow(lhs); emit(" >= "); shallow(rhs) + case Node(_, "relation-geu", List(lhs, rhs), _) => + shallow(lhs); emit(" >= "); shallow(rhs) + case Node(_, "sym-binary-add", List(lhs, rhs), _) => + shallow(lhs); emit(".add("); shallow(rhs); emit(")") + case Node(_, "sym-binary-sub", List(lhs, rhs), _) => + shallow(lhs); emit(".minus("); shallow(rhs); emit(")") + case Node(_, "sym-binary-mul", List(lhs, rhs), _) => + shallow(lhs); emit(".mul("); shallow(rhs); emit(")") + case Node(_, "sym-binary-div", List(lhs, rhs), _) => + shallow(lhs); emit(".div("); shallow(rhs); emit(")") + case Node(_, "sym-relation-le", List(lhs, rhs), _) => + shallow(lhs); emit(".leq("); shallow(rhs); emit(")") + case Node(_, "sym-relation-leu", List(lhs, rhs), _) => + shallow(lhs); emit(".leu("); shallow(rhs); emit(")") + case Node(_, "sym-relation-ge", List(lhs, rhs), _) => + shallow(lhs); emit(".ge("); shallow(rhs); emit(")") + case Node(_, "sym-relation-geu", List(lhs, rhs), _) => + shallow(lhs); emit(".geu("); shallow(rhs); emit(")") + case Node(_, "sym-relation-eq", List(lhs, rhs), _) => + shallow(lhs); emit(".eq("); shallow(rhs); emit(")") + case Node(_, "sym-relation-ne", List(lhs, rhs), _) => + shallow(lhs); emit(".neq("); shallow(rhs); emit(")") + case Node(_, "num-to-int", List(num), _) => + shallow(num); emit(".toInt()") + case Node(_, "make-symbolic", List(num), _) => + shallow(num); emit(".makeSymbolic()") + case Node(_, 
"sym-env-read", List(sym), _) => + emit("SymEnv.read("); shallow(sym); emit(")") + case Node(_, "assert-true", List(cond), _) => + emit("GENSYM_ASSERT("); shallow(cond); emit(")") + case Node(_, "tree-fill-if-else", List(s), _) => + emit("ExploreTree.fillIfElseNode("); shallow(s); emit(")") + case Node(_, "tree-fill-finished", List(), _) => + emit("ExploreTree.fillFinishedNode()") + case Node(_, "tree-move-cursor", List(b), _) => + emit("ExploreTree.moveCursor("); shallow(b); emit(")") + case Node(_, "tree-print", List(), _) => + emit("ExploreTree.print()") + case Node(_, "tree-dump-graphviz", List(f), _) => + emit("ExploreTree.dump_graphviz("); shallow(f); emit(")") + case Node(_, "sym-not", List(s), _) => + shallow(s); emit(".negate()") + case Node(_, "dummy", _, _) => emit("std::monostate()") + case Node(_, "dummy-op", _, _) => emit("std::monostate()") + case Node(_, "no-op", _, _) => + emit("std::monostate()") + case _ => super.shallow(n) + } + + override def registerTopLevelFunction(id: String, streamId: String = "general")(f: => Unit) = + if (!registeredFunctions(id)) { + //if (ongoingFun(streamId)) ??? + //ongoingFun += streamId + registeredFunctions += id + withStream(functionsStreams.getOrElseUpdate(id, { + val functionsStream = new java.io.ByteArrayOutputStream() + val functionsWriter = new java.io.PrintStream(functionsStream) + (functionsWriter, functionsStream) + })._1)(f) + //ongoingFun -= streamId + } else { + // If a function is registered, don't re-register it. 
+ // withStream(functionsStreams(id)._1)(f) + } + + override def emitAll(g: Graph, name: String)(m1: Manifest[_], m2: Manifest[_]): Unit = { + val ng = init(g) + emitHeaders(stream) + emitln(""" + |/***************************************** + |Emitting Generated Code + |*******************************************/ + """.stripMargin) + val src = run(name, ng) + emitFunctionDecls(stream) + emitDatastructures(stream) + emitFunctions(stream) + emit(src) + emitln(""" + |/***************************************** + |End of Generated Code + |*******************************************/ + |int main(int argc, char *argv[]) { + | start_concolic_execution_with(Snippet); + | return 0; + |}""".stripMargin) + } +} + +trait WasmToCppCompilerDriver[A, B] extends CppSAIDriver[A, B] with StagedWasmEvaluator { q => + override val codegen = new StagedWasmCppGen { + val IR: q.type = q + import IR._ + } +} + +object WasmToCppCompiler { + case class GeneratedCpp(source: String, headerFolders: List[String]) + + def compile(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean, dumpTree: Option[String]): GeneratedCpp = { + println(s"Now compiling wasm module with entry function $main") + val driver = new WasmToCppCompilerDriver[Unit, Unit] { + def module: ModuleInstance = moduleInst + def snippet(x: Rep[Unit]): Rep[Unit] = { + evalTop(main, printRes, dumpTree) + } + } + GeneratedCpp(driver.code, driver.codegen.includePaths.toList) + } + + def compileToExe(moduleInst: ModuleInstance, + main: Option[String], + outputCpp: String, + outputExe: String, + printRes: Boolean, + dumpTree: Option[String]): Unit = { + val generated = compile(moduleInst, main, printRes, dumpTree) + val code = generated.source + + val writer = new java.io.PrintWriter(new java.io.File(outputCpp)) + try { + writer.write(code) + } finally { + writer.close() + } + + import sys.process._ + val command = s"g++ -std=c++20 $outputCpp -o $outputExe -O3 -g -l z3 " + generated.headerFolders.map(f => 
s"-I$f").mkString(" ") + if (command.! != 0) { + throw new RuntimeException(s"Compilation failed for $outputCpp") + } + } + +} + + diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala index 2e22e7a7..ea9dc9c6 100644 --- a/src/main/scala/wasm/StagedMiniWasm.scala +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -1,4 +1,4 @@ -package gensym.wasm.miniwasm +package gensym.wasm.stagedminiwasm import scala.collection.mutable.{ArrayBuffer, HashMap} @@ -12,6 +12,7 @@ import lms.core.Graph import gensym.wasm.ast._ import gensym.wasm.ast.{Const => WasmConst, Block => WasmBlock} +import gensym.wasm.miniwasm.ModuleInstance import gensym.lmsx.{SAIDriver, StringOps, SAIOps, SAICodeGenBase, CppSAIDriver, CppSAICodeGenBase} @virtualize @@ -433,10 +434,10 @@ trait StagedWasmEvaluator extends SAIOps { def push(v: StagedNum)(implicit ctx: Context): Context = { v match { - case I32(v) => NumType(I32Type); "stack-push".reflectCtrlWith[Unit](v) - case I64(v) => NumType(I64Type); "stack-push".reflectCtrlWith[Unit](v) - case F32(v) => NumType(F32Type); "stack-push".reflectCtrlWith[Unit](v) - case F64(v) => NumType(F64Type); "stack-push".reflectCtrlWith[Unit](v) + case I32(v) => "stack-push".reflectCtrlWith[Unit](v) + case I64(v) => "stack-push".reflectCtrlWith[Unit](v) + case F32(v) => "stack-push".reflectCtrlWith[Unit](v) + case F64(v) => "stack-push".reflectCtrlWith[Unit](v) } ctx.push(v.tipe) } diff --git a/src/test/scala/genwasym/TestStagedConcolicEval.scala b/src/test/scala/genwasym/TestStagedConcolicEval.scala new file mode 100644 index 00000000..a65d0eda --- /dev/null +++ b/src/test/scala/genwasym/TestStagedConcolicEval.scala @@ -0,0 +1,42 @@ +package gensym.wasm + +import org.scalatest.FunSuite + +import lms.core.stub.Adapter + +import gensym.wasm.miniwasm.{ModuleInstance} +import gensym.wasm.parser._ +import gensym.wasm.stagedconcolicminiwasm._ + +class TestStagedConcolicEval extends FunSuite { + def testFileToCpp(filename: String, 
main: Option[String] = None, expect: Option[List[Float]]=None) = { + val moduleInst = ModuleInstance(Parser.parseFile(filename)) + val cppFile = s"$filename.cpp" + val exe = s"$cppFile.exe" + val exploreTreeFile = s"$filename.tree.dot" + WasmToCppCompiler.compileToExe(moduleInst, main, cppFile, exe, true, Some(exploreTreeFile)) + + import sys.process._ + val result = s"./$exe".!! + println(result) + + expect.map(vs => { + val stackValues = result + .split("Stack contents: \n")(1) + .split("\n") + .map(_.toFloat) + .toList + assert(vs == stackValues) + }) + } + + test("ack-cpp") { testFileToCpp("./benchmarks/wasm/ack.wat", Some("real_main")) } + + test("bug-finding") { + testFileToCpp("./benchmarks/wasm/branch-strip-buggy.wat", Some("real_main")) + } + + test("brtable-bug-finding") { + testFileToCpp("./benchmarks/wasm/staged/brtable_concolic.wat") + } +} diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala index d4d1e960..3769428f 100644 --- a/src/test/scala/genwasym/TestStagedEval.scala +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -6,6 +6,7 @@ import lms.core.stub.Adapter import gensym.wasm.parser._ import gensym.wasm.miniwasm._ +import gensym.wasm.stagedminiwasm._ class TestStagedEval extends FunSuite { def testFileToScala(filename: String, main: Option[String] = None, printRes: Boolean = false) = {