diff --git a/bootstrap/emit.cc b/bootstrap/emit.cc index e478853..9589e3b 100644 --- a/bootstrap/emit.cc +++ b/bootstrap/emit.cc @@ -1,38 +1,163 @@ #include -#include +#include namespace xlang { -EmitListener::EmitListener(std::string_view outputfile) : output{outputfile} {} +class EmitFunctionVisitor : public xlangBaseVisitor {}; -void EmitListener::enterFile(xlangParser::FileContext *ctx) { - output << "data $printformat = { b \"%ld\\n\", b 0 }" << std::endl; - (void)ctx; +EmitVisitor::EmitVisitor(std::string_view outputfile) : output{outputfile} {} + +std::any EmitVisitor::visitFile(xlangParser::FileContext *ctx) { + output << "data $printformat = { b \"%d\\n\", b 0 }" << std::endl; + visitChildren(ctx); + return {}; } -void EmitListener::enterFunction(xlangParser::FunctionContext *ctx) { +std::any EmitVisitor::visitFunction(xlangParser::FunctionContext *ctx) { output << std::endl << "export function w $" << ctx->Identifier()->getSymbol()->getText() - << "()" << std::endl - << "{" << std::endl - << "@start" << std::endl; + << "("; + if (auto arg_list = ctx->argumentList()) { + for (const auto arg : arg_list->Identifier()) { + output << "w %_" << arg->getSymbol()->getText() << ", "; + } + } + output << ")" << std::endl << "{" << std::endl << "@start" << std::endl; + blockcount = 0; + visitChildren(ctx); + output << " ret 0" << std::endl; + output << "}" << std::endl; + return {}; } -void EmitListener::exitFunction(xlangParser::FunctionContext *ctx) { - output << " ret 0" << std::endl << "}" << std::endl; - (void)ctx; +std::any EmitVisitor::visitBlock(xlangParser::BlockContext *ctx) { + output << "@block" << blockcount++ << std::endl; + visitChildren(ctx); + return {}; } -void EmitListener::exitStatement(xlangParser::StatementContext *ctx) { +std::any EmitVisitor::visitStatement(xlangParser::StatementContext *ctx) { + if (auto identifier = ctx->Identifier()) { + tmpcount = 0; + visitChildren(ctx); + output << " %_" << identifier->getSymbol()->getText() << " = w copy %t" + << tmpcount - 1 << std::endl; + return {}; + } + if (ctx->If()) { + int block = blockcount; + output << "@ifexpr" << block << std::endl; + tmpcount = 0; + visitExpr(ctx->expr()); + output << " jnz %t" << tmpcount - 1 << ", @block" << block << ", @ifelse" + << block << std::endl; + visitBlock(ctx->block(0)); + output << "@ifelse" << block << std::endl; + if (ctx->Else()) { + visitBlock(ctx->block(1)); + } + output << "@ifend" << block << std::endl; + return {}; + } + if (ctx->While()) { + int block = blockcount; + output << "@whileexpr" << block << std::endl; + tmpcount = 0; + visitExpr(ctx->expr()); + output << " jnz %t" << tmpcount - 1 << ", @block" << block << ", @whileend" + << block << std::endl; + visitBlock(ctx->block(0)); + output << " jmp @whileexpr" << block << std::endl; + output << "@whileend" << block << std::endl; + return {}; + } + if (ctx->Return()) { + tmpcount = 0; + visitChildren(ctx); + output << " ret %t" << tmpcount - 1 << std::endl; + output << "@dead" << blockcount++ << std::endl; + return {}; + } if (ctx->Print()) { - output << " call $printf(l $printformat, ..., w %v)" << std::endl; + tmpcount = 0; + visitChildren(ctx); + output << " call $printf(l $printformat, ..., w %t" << tmpcount - 1 << ")" + << std::endl; + return {}; } + // unreachable + return {}; } -void EmitListener::exitFactor(xlangParser::FactorContext *ctx) { - if (auto integer = ctx->Integer()) { - output << " %v = w copy " << integer->getSymbol()->getText() << std::endl; +#define OPERATOR(Operator, visitLeft, left, visitRight, right, ssa_op) \ + if (ctx->Operator()) { \ + visitLeft(ctx->left); \ + int l = tmpcount - 1; \ + visitRight(ctx->right); \ + output << " %t" << tmpcount << " = w " ssa_op " %t" << l << ", %t" \ + << tmpcount - 1 << std::endl; \ + tmpcount++; \ + return {}; \ } + +std::any EmitVisitor::visitExpr(xlangParser::ExprContext *ctx) { + OPERATOR(Less, visitSum, sum(0), visitSum, sum(1), "csltw"); + OPERATOR(LessEqual, visitSum, sum(0), visitSum, sum(1), "cslew"); + OPERATOR(Greater, visitSum, sum(0), visitSum, sum(1), "csgtw"); + OPERATOR(GreaterEqual, visitSum, sum(0), visitSum, sum(1), "csgew"); + OPERATOR(Equal, visitSum, sum(0), visitSum, sum(1), "ceqw"); + OPERATOR(NotEqual, visitSum, sum(0), visitSum, sum(1), "cnew"); + visitSum(ctx->sum(0)); + return {}; +} + +std::any EmitVisitor::visitSum(xlangParser::SumContext *ctx) { + OPERATOR(Plus, visitSum, sum(), visitTerm, term(), "add"); + OPERATOR(Minus, visitSum, sum(), visitTerm, term(), "sub"); + visitTerm(ctx->term()); + return {}; +} + +std::any EmitVisitor::visitTerm(xlangParser::TermContext *ctx) { + OPERATOR(Mul, visitTerm, term(), visitFactor, factor(), "mul"); + OPERATOR(Div, visitTerm, term(), visitFactor, factor(), "div"); + visitFactor(ctx->factor()); + return {}; +} + +std::any EmitVisitor::visitFactor(xlangParser::FactorContext *ctx) { + if (auto integer = ctx->Integer()) { + output << " %t" << tmpcount++ << " = w copy " + << integer->getSymbol()->getText() << std::endl; + return {}; + } + if (auto identifier = ctx->Identifier()) { + if (ctx->LeftParen()) { + std::vector args; + if (auto expr_list = ctx->exprList()) { + for (auto expr : expr_list->expr()) { + visitExpr(expr); + args.push_back(tmpcount - 1); + } + } + output << " %t" << tmpcount++ << " = w call $" + << identifier->getSymbol()->getText() << "("; + for (auto arg : args) { + output << "w %t" << arg << ", "; + } + output << ")" << std::endl; + } else { + output << " %t" << tmpcount++ << " = w copy %_" + << identifier->getSymbol()->getText() << std::endl; + } + return {}; + } + if (auto sum = ctx->sum()) { + visitSum(sum); + return {}; + } + // unreachable + return {}; } } // namespace xlang diff --git a/bootstrap/emit.hh b/bootstrap/emit.hh index 405f63b..695171a 100644 --- a/bootstrap/emit.hh +++ b/bootstrap/emit.hh @@ -2,21 +2,26 @@ #include #include -#include +#include namespace xlang { -class EmitListener : public xlangBaseListener { +class EmitVisitor : public xlangBaseVisitor { std::ofstream output; + int blockcount; + int tmpcount; public: - EmitListener(std::string_view outputfile); + EmitVisitor(std::string_view outputfile); - void enterFile(xlangParser::FileContext *ctx) override; - void enterFunction(xlangParser::FunctionContext *ctx) override; - void exitFunction(xlangParser::FunctionContext *ctx) override; - void exitStatement(xlangParser::StatementContext *ctx) override; - void exitFactor(xlangParser::FactorContext *ctx) override; + std::any visitFile(xlangParser::FileContext *ctx) override; + std::any visitFunction(xlangParser::FunctionContext *ctx) override; + std::any visitBlock(xlangParser::BlockContext *ctx) override; + std::any visitStatement(xlangParser::StatementContext *ctx) override; + std::any visitExpr(xlangParser::ExprContext *ctx) override; + std::any visitSum(xlangParser::SumContext *ctx) override; + std::any visitTerm(xlangParser::TermContext *ctx) override; + std::any visitFactor(xlangParser::FactorContext *ctx) override; }; } // namespace xlang diff --git a/bootstrap/main.cc b/bootstrap/main.cc index 4c4c6c3..888b0b7 100644 --- a/bootstrap/main.cc +++ b/bootstrap/main.cc @@ -43,7 +43,7 @@ int main(int argc, char **argv) { xlang::xlangParser parser{&tokens}; auto *tree = parser.file(); - xlang::EmitListener emit{outputfile}; - antlr4::tree::ParseTreeWalker::DEFAULT.walk(&emit, tree); + xlang::EmitVisitor emit{outputfile}; + emit.visitFile(tree); return 0; } diff --git a/bootstrap/xlang.g4 b/bootstrap/xlang.g4 index 0c8c36d..b16bbbb 100644 --- a/bootstrap/xlang.g4 +++ b/bootstrap/xlang.g4 @@ -11,13 +11,19 @@ statement : Identifier Assign expr Semicolon | Print expr Semicolon | expr Semicolon ; -expr : sum ((Less|LessEqual|Greater|GreaterEqual|Equal|NotEqual) sum)*; -sum : term ((Plus|Minus) term)*; -term : factor ((Mul|Div) factor)*; +expr : sum ((Less|LessEqual|Greater|GreaterEqual|Equal|NotEqual) sum) + | sum + ; +sum : sum ((Plus|Minus) term) + | term + ; +term : term ((Mul|Div) factor) + | factor + ; factor : Integer | Identifier | Identifier LeftParen exprList? RightParen - | LeftParen expr RightParen + | LeftParen sum RightParen ; exprList : expr (Comma expr)*; @@ -45,7 +51,7 @@ Div : '/'; Comma : ','; Semicolon : ';'; -Identifier : [a-zA-Z][a-zA-Z0-9]*; +Identifier : [_a-zA-Z][_a-zA-Z0-9]*; Integer : [0-9]+; Comment : '//' ~[\n]* '\n' -> skip; diff --git a/test/fib.x b/test/fib.x index f6133fb..fb2690b 100644 --- a/test/fib.x +++ b/test/fib.x @@ -1,10 +1,20 @@ main() { - // print 5th fibonacci number - print fib(5); - return 0; + i = 0; + while i < 10 { + print fib_rec(i); + print fib_iter(i); + i = i + 1; + } } -fib(n) { +fib_rec(n) { + if n < 2 { + return 1; + } + return fib_rec(n - 1) + fib_rec(n - 2); +} + +fib_iter(n) { x0 = 1; x1 = 1; i = 0;