From 3f3f323898ff3b628a4ee4f1695b14de7f2c0cd3 Mon Sep 17 00:00:00 2001 From: Thomas Lindner Date: Fri, 13 Jan 2023 00:07:00 +0100 Subject: [PATCH] avoid recursion and ambiguity --- bootstrap/emit.cc | 205 +++++++++++++++------------ bootstrap/emit.hh | 10 +- bootstrap/typecheck.cc | 314 ++++++++++++++++++++++------------------- bootstrap/typecheck.hh | 11 +- bootstrap/xlang.g4 | 60 ++++---- 5 files changed, 323 insertions(+), 277 deletions(-) diff --git a/bootstrap/emit.cc b/bootstrap/emit.cc index 32b94a4..359339c 100644 --- a/bootstrap/emit.cc +++ b/bootstrap/emit.cc @@ -21,8 +21,8 @@ std::any EmitVisitor::visitFunction(xlangParser::FunctionContext *ctx) { << "export function w $" << ctx->Identifier()->getSymbol()->getText() << "("; if (auto param_list = ctx->parameterList()) { - for (const auto param : param_list->Identifier()) { - output << "w %_" << param->getSymbol()->getText() << ", "; + for (auto param : param_list->parameter()) { + output << "w %_" << param->Identifier()->getSymbol()->getText() << ", "; } } output << ")" << std::endl << "{" << std::endl << "@start" << std::endl; @@ -36,10 +36,9 @@ std::any EmitVisitor::visitFunction(xlangParser::FunctionContext *ctx) { std::any EmitVisitor::visitStatement(xlangParser::StatementContext *ctx) { if (ctx->If()) { int block = blockcount++; - output << "@if" << block << "_expr" << std::endl; tmpcount = 0; - visitCondition(ctx->condition()); + visitExpr(ctx->expr(0)); output << " jnz " << last << ", @if" << block << "_then, @if" << block << "_else" << std::endl; output << "@if" << block << "_then" << std::endl; @@ -54,12 +53,11 @@ std::any EmitVisitor::visitStatement(xlangParser::StatementContext *ctx) { } if (ctx->While()) { int block = blockcount++; - loopstack.push_back(block); output << "@loop" << block << "_continue" << std::endl; output << "@loop" << block << "_expr" << std::endl; tmpcount = 0; - visitCondition(ctx->condition()); + visitExpr(ctx->expr(0)); output << " jnz " << last << ", @loop" << block << "_body, @loop" << block << "_end" << std::endl; output << "@loop" << block << "_body" << std::endl; @@ -71,20 +69,19 @@ std::any EmitVisitor::visitStatement(xlangParser::StatementContext *ctx) { } if (ctx->For()) { int block = blockcount++; - loopstack.push_back(block); tmpcount = 0; - visitValue(ctx->value(0)); + visitExpr(ctx->expr(0)); output << "@loop" << block << "_expr" << std::endl; tmpcount = 0; - visitCondition(ctx->condition()); + visitExpr(ctx->expr(1)); output << " jnz " << last << ", @loop" << block << "_body, @loop" << block << "_end" << std::endl; output << "@loop" << block << "_body" << std::endl; visitBlock(ctx->block(0)); output << "@loop" << block << "_continue" << std::endl; tmpcount = 0; - visitValue(ctx->value(1)); + visitExpr(ctx->expr(2)); output << " jmp @loop" << block << "_expr" << std::endl; output << "@loop" << block << "_end" << std::endl; loopstack.pop_back(); @@ -114,44 +111,74 @@ std::any EmitVisitor::visitStatement(xlangParser::StatementContext *ctx) { } if (ctx->Return()) { tmpcount = 0; - visitValue(ctx->value(0)); + visitExpr(ctx->expr(0)); output << " ret " << last << std::endl; output << "@dead" << blockcount++ << std::endl; return {}; } if (ctx->Print()) { tmpcount = 0; - visitValue(ctx->value(0)); + visitExpr(ctx->expr(0)); output << " call $printf(l $printformat, ..., w " << last << ")" << std::endl; return {}; } - visitValue(ctx->value(0)); + visitExpr(ctx->expr(0)); return {}; } -#define OPERATOR(Operator, visitLeft, left, visitRight, right, ssa_op) \ - if (ctx->Operator()) { \ - visitLeft(ctx->left); \ - std::string l = last; \ - visitRight(ctx->right); \ - std::string r = last; \ - output << " " << tmp() << " = w " ssa_op " " << l << ", " << r \ - << std::endl; \ - return {}; \ +std::any EmitVisitor::visitExpr(xlangParser::ExprContext *ctx) { + if (auto op = ctx->assignmentOp()) { + auto name = ctx->Identifier()->getSymbol()->getText(); + if (op->Define() || op->Assign()) { + visitExpr(ctx->expr()); + output << " %_" << name << " = w copy " << last << std::endl; + last = "%_" + name; + return {}; + } + if (op->Increment()) { + visitExpr(ctx->expr()); + output << " %_" << name << " = w add %_" << name << ", " << last + << std::endl; + last = "%_" + name; + return {}; + } + if (op->Decrement()) { + visitExpr(ctx->expr()); + output << " %_" << name << " = w sub %_" << name << ", " << last + << std::endl; + last = "%_" + name; + return {}; + } + } + visitBooleanExpr(ctx->booleanExpr()); + return {}; +} + +#define OPERATOR(Operator, ssa_op) \ + if (op->Operator()) { \ + output << " " << tmp() << " = w " ssa_op " " << left << ", " << right \ + << std::endl; \ } -std::any EmitVisitor::visitCondition(xlangParser::ConditionContext *ctx) { - OPERATOR(And, visitCondition, condition(), visitBoolean, boolean(), "and"); - OPERATOR(Or, visitCondition, condition(), visitBoolean, boolean(), "or"); - OPERATOR(Xor, visitCondition, condition(), visitBoolean, boolean(), "xor"); - visitBoolean(ctx->boolean()); +std::any EmitVisitor::visitBooleanExpr(xlangParser::BooleanExprContext *ctx) { + visitComparisonExpr(ctx->comparisonExpr(0)); + for (size_t i = 0, n = ctx->booleanOp().size(); i < n; i++) { + std::string left = last; + visitComparisonExpr(ctx->comparisonExpr(i + 1)); + std::string right = last; + auto op = ctx->booleanOp(i); + OPERATOR(And, "and"); + OPERATOR(Or, "or"); + OPERATOR(Xor, "xor"); + } return {}; } -std::any EmitVisitor::visitBoolean(xlangParser::BooleanContext *ctx) { +std::any EmitVisitor::visitComparisonExpr( + xlangParser::ComparisonExprContext *ctx) { if (ctx->Not()) { - visitBoolean(ctx->boolean()); + visitComparisonExpr(ctx->comparisonExpr()); std::string b = last; output << " " << tmp() << " = w xor 1, " << b << std::endl; return {}; @@ -164,33 +191,51 @@ std::any EmitVisitor::visitBoolean(xlangParser::BooleanContext *ctx) { last = "0"; return {}; } - OPERATOR(Less, visitExpr, expr(0), visitExpr, expr(1), "csltw"); - OPERATOR(LessEqual, visitExpr, expr(0), visitExpr, expr(1), "cslew"); - OPERATOR(Greater, visitExpr, expr(0), visitExpr, expr(1), "csgtw"); - OPERATOR(GreaterEqual, visitExpr, expr(0), visitExpr, expr(1), "csgew"); - OPERATOR(Equal, visitExpr, expr(0), visitExpr, expr(1), "ceqw"); - OPERATOR(NotEqual, visitExpr, expr(0), visitExpr, expr(1), "cnew"); - visitChildren(ctx); + visitAdditiveExpr(ctx->additiveExpr(0)); + if (auto op = ctx->comparisonOp()) { + std::string left = last; + visitAdditiveExpr(ctx->additiveExpr(1)); + std::string right = last; + OPERATOR(Less, "csltw"); + OPERATOR(LessEqual, "cslew"); + OPERATOR(Greater, "csgtw"); + OPERATOR(GreaterEqual, "csgew"); + OPERATOR(Equal, "ceqw"); + OPERATOR(NotEqual, "cnew"); + } return {}; } -std::any EmitVisitor::visitExpr(xlangParser::ExprContext *ctx) { - OPERATOR(Plus, visitExpr, expr(), visitTerm, term(), "add"); - OPERATOR(Minus, visitExpr, expr(), visitTerm, term(), "sub"); - OPERATOR(BitAnd, visitExpr, expr(), visitTerm, term(), "and"); - OPERATOR(BitOr, visitExpr, expr(), visitTerm, term(), "or"); - OPERATOR(BitXor, visitExpr, expr(), visitTerm, term(), "xor"); - OPERATOR(ShiftLeft, visitExpr, expr(), visitTerm, term(), "shl"); - OPERATOR(ShiftRight, visitExpr, expr(), visitTerm, term(), "shr"); - visitTerm(ctx->term()); +std::any EmitVisitor::visitAdditiveExpr(xlangParser::AdditiveExprContext *ctx) { + visitMultiplicativeExpr(ctx->multiplicativeExpr(0)); + for (size_t i = 0, n = ctx->additiveOp().size(); i < n; i++) { + std::string left = last; + visitMultiplicativeExpr(ctx->multiplicativeExpr(i + 1)); + std::string right = last; + auto op = ctx->additiveOp(i); + OPERATOR(Plus, "add"); + OPERATOR(Minus, "sub"); + OPERATOR(BitAnd, "and"); + OPERATOR(BitOr, "or"); + OPERATOR(BitXor, "xor"); + OPERATOR(ShiftLeft, "shl"); + OPERATOR(ShiftRight, "shr"); + } return {}; } -std::any EmitVisitor::visitTerm(xlangParser::TermContext *ctx) { - OPERATOR(Mul, visitTerm, term(), visitFactor, factor(), "mul"); - OPERATOR(Div, visitTerm, term(), visitFactor, factor(), "div"); - OPERATOR(Rem, visitTerm, term(), visitFactor, factor(), "rem"); - visitFactor(ctx->factor()); +std::any EmitVisitor::visitMultiplicativeExpr( + xlangParser::MultiplicativeExprContext *ctx) { + visitFactor(ctx->factor(0)); + for (size_t i = 0, n = ctx->multiplicativeOp().size(); i < n; i++) { + std::string left = last; + visitFactor(ctx->factor(i + 1)); + std::string right = last; + auto op = ctx->multiplicativeOp(i); + OPERATOR(Mul, "mul"); + OPERATOR(Div, "div"); + OPERATOR(Rem, "rem"); + } return {}; } @@ -211,50 +256,28 @@ std::any EmitVisitor::visitFactor(xlangParser::FactorContext *ctx) { last = integer->getSymbol()->getText(); return {}; } - visitChildren(ctx); - return {}; -} - -std::any EmitVisitor::visitVariable(xlangParser::VariableContext *ctx) { - auto name = ctx->Identifier()->getSymbol()->getText(); - - if (ctx->Define() || ctx->Assign()) { - visitValue(ctx->value()); - output << " %_" << name << " = w copy " << last << std::endl; - last = "%_" + name; - return {}; - } - if (ctx->Increment()) { - visitExpr(ctx->expr()); - output << " %_" << name << " = w add %_" << name << ", " << last - << std::endl; - last = "%_" + name; - return {}; - } - if (ctx->Decrement()) { - visitExpr(ctx->expr()); - output << " %_" << name << " = w sub %_" << name << ", " << last - << std::endl; - last = "%_" + name; - return {}; - } - if (ctx->LeftParen()) { - std::vector args; - - if (auto arg_list = ctx->argumentList()) { - for (auto value : arg_list->value()) { - visitValue(value); - args.emplace_back(last); + if (auto identifier = ctx->Identifier()) { + auto name = identifier->getSymbol()->getText(); + if (ctx->LeftParen()) { + std::vector args; + if (auto arg_list = ctx->argumentList()) { + for (auto expr : arg_list->expr()) { + visitExpr(expr); + args.emplace_back(last); + } } + output << " " << tmp() << " = w call $" << name << "("; + for (auto arg : args) { + output << "w " << arg << ", "; + } + output << ")" << std::endl; + return {}; + } else { + last = "%_" + name; + return {}; } - output << " " << tmp() << " = w call $" << name << "("; - for (auto arg : args) { - output << "w " << arg << ", "; - } - output << ")" << std::endl; - return {}; } - last = "%_" + name; + visitExpr(ctx->expr()); return {}; } diff --git a/bootstrap/emit.hh b/bootstrap/emit.hh index 40596c6..f753043 100644 --- a/bootstrap/emit.hh +++ b/bootstrap/emit.hh @@ -22,12 +22,14 @@ class EmitVisitor : public xlangBaseVisitor { std::any visitFile(xlangParser::FileContext *ctx) override; std::any visitFunction(xlangParser::FunctionContext *ctx) override; std::any visitStatement(xlangParser::StatementContext *ctx) override; - std::any visitCondition(xlangParser::ConditionContext *ctx) override; - std::any visitBoolean(xlangParser::BooleanContext *ctx) override; std::any visitExpr(xlangParser::ExprContext *ctx) override; - std::any visitTerm(xlangParser::TermContext *ctx) override; + std::any visitBooleanExpr(xlangParser::BooleanExprContext *ctx) override; + std::any visitComparisonExpr( + xlangParser::ComparisonExprContext *ctx) override; + std::any visitAdditiveExpr(xlangParser::AdditiveExprContext *ctx) override; + std::any visitMultiplicativeExpr( + xlangParser::MultiplicativeExprContext *ctx) override; std::any visitFactor(xlangParser::FactorContext *ctx) override; - std::any visitVariable(xlangParser::VariableContext *ctx) override; }; } // namespace xlang diff --git a/bootstrap/typecheck.cc b/bootstrap/typecheck.cc index 2a52e25..8d2c689 100644 --- a/bootstrap/typecheck.cc +++ b/bootstrap/typecheck.cc @@ -22,16 +22,15 @@ std::any TypeCheckVisitor::visitFile(xlangParser::FileContext *ctx) { for (auto function : ctx->function()) { auto token = function->Identifier()->getSymbol(); auto name = token->getText(); - Signature signature{std::any_cast(visitType(function->type())), {}}; - if (signatures.find(name) != signatures.end()) { errorlistener.duplicateFunction(token, name); continue; } + Signature signature{std::any_cast(visitType(function->type())), {}}; if (auto param_list = function->parameterList()) { - for (auto type : param_list->type()) { + for (auto param : param_list->parameter()) { signature.parametertypes.push_back( - std::any_cast(visitType(type))); + std::any_cast(visitType(param->type()))); } } signatures.emplace(name, signature); @@ -43,10 +42,9 @@ std::any TypeCheckVisitor::visitFile(xlangParser::FileContext *ctx) { std::any TypeCheckVisitor::visitFunction(xlangParser::FunctionContext *ctx) { scope.enter(); if (auto param_list = ctx->parameterList()) { - for (size_t i = 0, n = param_list->Identifier().size(); i < n; i++) { - auto name = param_list->Identifier(i)->getSymbol()->getText(); - auto type = std::any_cast(visitType(param_list->type(i))); - + for (auto param : param_list->parameter()) { + auto name = param->Identifier()->getSymbol()->getText(); + auto type = std::any_cast(visitType(param->type())); scope.add(name, type); } } @@ -77,10 +75,10 @@ std::any TypeCheckVisitor::visitBlock(xlangParser::BlockContext *ctx) { std::any TypeCheckVisitor::visitStatement(xlangParser::StatementContext *ctx) { if (ctx->If()) { scope.enter(); - auto condition = ctx->condition(); - auto type = std::any_cast(visitCondition(condition)); + auto expr = ctx->expr(0); + auto type = std::any_cast(visitExpr(expr)); if (type != Type::Boolean) { - errorlistener.typeMismatch(condition->getStart(), Type::Boolean, type); + errorlistener.typeMismatch(expr->getStart(), Type::Boolean, type); } visitBlock(ctx->block(0)); if (ctx->Else()) { @@ -89,19 +87,28 @@ std::any TypeCheckVisitor::visitStatement(xlangParser::StatementContext *ctx) { scope.leave(); return {}; } - if (ctx->While() || ctx->For()) { + if (ctx->While()) { scope.enter(); - if (ctx->For()) { - visitValue(ctx->value(0)); - } - auto condition = ctx->condition(); - auto type = std::any_cast(visitCondition(condition)); + auto expr = ctx->expr(0); + auto type = std::any_cast(visitExpr(expr)); if (type != Type::Boolean) { - errorlistener.typeMismatch(condition->getStart(), Type::Boolean, type); + errorlistener.typeMismatch(expr->getStart(), Type::Boolean, type); } - if (ctx->For()) { - visitValue(ctx->value(1)); + loopcount++; + visitBlock(ctx->block(0)); + loopcount--; + scope.leave(); + return {}; + } + if (ctx->For()) { + scope.enter(); + visitExpr(ctx->expr(0)); + auto expr = ctx->expr(1); + auto type = std::any_cast(visitExpr(expr)); + if (type != Type::Boolean) { + errorlistener.typeMismatch(expr->getStart(), Type::Boolean, type); } + visitExpr(ctx->expr(2)); loopcount++; visitBlock(ctx->block(0)); loopcount--; @@ -142,90 +149,153 @@ std::any TypeCheckVisitor::visitStatement(xlangParser::StatementContext *ctx) { return {}; } if (ctx->Return()) { - auto type = std::any_cast(visitValue(ctx->value(0))); - + auto type = std::any_cast(visitExpr(ctx->expr(0))); if (type != returntype) { - errorlistener.typeMismatch(ctx->value(0)->getStart(), returntype, type); + errorlistener.typeMismatch(ctx->expr(0)->getStart(), returntype, type); } + return {}; } visitChildren(ctx); return {}; } -std::any TypeCheckVisitor::visitValue(xlangParser::ValueContext *ctx) { - if (auto expr = ctx->expr()) { - return visitExpr(expr); +std::any TypeCheckVisitor::visitExpr(xlangParser::ExprContext *ctx) { + if (auto op = ctx->assignmentOp()) { + auto token = ctx->Identifier()->getSymbol(); + auto name = token->getText(); + if (op->Define()) { + auto type = std::any_cast(visitExpr(ctx->expr())); + if (scope.get(name)) { + errorlistener.shadowedVariable(token, name); + } + scope.add(name, type); + return type; + } + if (op->Assign()) { + auto type = std::any_cast(visitExpr(ctx->expr())); + if (auto expected = scope.get(name)) { + if (type != *expected) { + errorlistener.typeMismatch(token, *expected, type); + } + return *expected; + } else { + errorlistener.unknownVariable(token, name); + return type; + } + } + if (op->Increment() || op->Decrement()) { + if (auto type = scope.get(name)) { + if (*type != Type::Integer) { + errorlistener.typeMismatch(token, Type::Integer, *type); + } + } else { + errorlistener.unknownVariable(token, name); + } + auto expr = ctx->expr(); + auto type = std::any_cast(visitExpr(expr)); + if (type != Type::Integer) { + errorlistener.typeMismatch(expr->getStart(), Type::Integer, type); + } + return Type::Integer; + } } - if (auto condition = ctx->condition()) { - return visitCondition(condition); + if (auto boolean_expr = ctx->booleanExpr()) { + return visitBooleanExpr(boolean_expr); } // unreachable errorlistener.compilerError(__FILE__, __LINE__); } -#define CHECKOPERATOR(visitLeft, left, visitRight, right, expected) \ - auto lefttype = std::any_cast(visitLeft(left)); \ - auto righttype = std::any_cast(visitRight(right)); \ - \ - if (lefttype != expected) { \ - errorlistener.typeMismatch(left->getStart(), expected, lefttype); \ - } \ - if (righttype != expected) { \ - errorlistener.typeMismatch(right->getStart(), expected, righttype); \ - } - -std::any TypeCheckVisitor::visitCondition(xlangParser::ConditionContext *ctx) { - if (auto condition = ctx->condition()) { - CHECKOPERATOR(visitCondition, condition, visitBoolean, ctx->boolean(), - Type::Boolean); +std::any TypeCheckVisitor::visitBooleanExpr( + xlangParser::BooleanExprContext *ctx) { + if (ctx->booleanOp().size()) { + for (auto comparison_expr : ctx->comparisonExpr()) { + auto type = std::any_cast(visitComparisonExpr(comparison_expr)); + if (type != Type::Boolean) { + errorlistener.typeMismatch(comparison_expr->getStart(), Type::Boolean, + type); + } + } return Type::Boolean; } - return visitBoolean(ctx->boolean()); + if (auto comparison_expr = ctx->comparisonExpr(0)) { + return visitComparisonExpr(comparison_expr); + } + // unreachable + errorlistener.compilerError(__FILE__, __LINE__); } -std::any TypeCheckVisitor::visitBoolean(xlangParser::BooleanContext *ctx) { - if (auto boolean = ctx->boolean()) { - auto type = std::any_cast(visitBoolean(boolean)); - +std::any TypeCheckVisitor::visitComparisonExpr( + xlangParser::ComparisonExprContext *ctx) { + if (auto comparison_expr = ctx->comparisonExpr()) { + auto type = std::any_cast(visitComparisonExpr(comparison_expr)); if (type != Type::Boolean) { - errorlistener.typeMismatch(boolean->getStart(), Type::Boolean, type); + errorlistener.typeMismatch(comparison_expr->getStart(), Type::Boolean, + type); } return Type::Boolean; } if (ctx->True() || ctx->False()) { return Type::Boolean; } - if (ctx->expr().size()) { - CHECKOPERATOR(visitExpr, ctx->expr(0), visitExpr, ctx->expr(1), - Type::Integer); + if (ctx->comparisonOp()) { + for (auto additive_expr : ctx->additiveExpr()) { + auto type = std::any_cast(visitAdditiveExpr(additive_expr)); + if (type != Type::Integer) { + errorlistener.typeMismatch(additive_expr->getStart(), Type::Integer, + type); + } + } return Type::Boolean; } - if (auto variable = ctx->variable()) { - return visitVariable(variable); + if (auto additive_expr = ctx->additiveExpr(0)) { + return visitAdditiveExpr(additive_expr); } - return visitCondition(ctx->condition()); + // unreachable + errorlistener.compilerError(__FILE__, __LINE__); } -std::any TypeCheckVisitor::visitExpr(xlangParser::ExprContext *ctx) { - if (auto expr = ctx->expr()) { - CHECKOPERATOR(visitExpr, expr, visitTerm, ctx->term(), Type::Integer); +std::any TypeCheckVisitor::visitAdditiveExpr( + xlangParser::AdditiveExprContext *ctx) { + if (ctx->additiveOp().size()) { + for (auto multiplicative_expr : ctx->multiplicativeExpr()) { + auto type = + std::any_cast(visitMultiplicativeExpr(multiplicative_expr)); + if (type != Type::Integer) { + errorlistener.typeMismatch(multiplicative_expr->getStart(), + Type::Integer, type); + } + } return Type::Integer; } - return visitTerm(ctx->term()); + if (auto multiplicative_expr = ctx->multiplicativeExpr(0)) { + return visitMultiplicativeExpr(multiplicative_expr); + } + // unreachable + errorlistener.compilerError(__FILE__, __LINE__); } -std::any TypeCheckVisitor::visitTerm(xlangParser::TermContext *ctx) { - if (auto term = ctx->term()) { - CHECKOPERATOR(visitTerm, term, visitFactor, ctx->factor(), Type::Integer); +std::any TypeCheckVisitor::visitMultiplicativeExpr( + xlangParser::MultiplicativeExprContext *ctx) { + if (ctx->multiplicativeOp().size()) { + for (auto factor : ctx->factor()) { + auto type = std::any_cast(visitFactor(factor)); + if (type != Type::Integer) { + errorlistener.typeMismatch(factor->getStart(), Type::Integer, type); + } + } return Type::Integer; } - return visitFactor(ctx->factor()); + if (auto factor = ctx->factor(0)) { + return visitFactor(factor); + } + // unreachable + errorlistener.compilerError(__FILE__, __LINE__); } std::any TypeCheckVisitor::visitFactor(xlangParser::FactorContext *ctx) { if (auto factor = ctx->factor()) { auto type = std::any_cast(visitFactor(factor)); - if (type != Type::Integer) { errorlistener.typeMismatch(factor->getStart(), Type::Integer, type); } @@ -234,92 +304,50 @@ std::any TypeCheckVisitor::visitFactor(xlangParser::FactorContext *ctx) { if (ctx->Integer()) { return Type::Integer; } - if (auto variable = ctx->variable()) { - return visitVariable(variable); - } - return visitExpr(ctx->expr()); -} - -std::any TypeCheckVisitor::visitVariable(xlangParser::VariableContext *ctx) { - auto token = ctx->Identifier()->getSymbol(); - auto name = token->getText(); - - if (ctx->Define()) { - auto type = std::any_cast(visitValue(ctx->value())); - - if (scope.get(name)) { - errorlistener.shadowedVariable(token, name); - } - scope.add(name, type); - return type; - } - if (ctx->Assign()) { - auto type = std::any_cast(visitValue(ctx->value())); - - if (auto expected = scope.get(name)) { - if (type != *expected) { - errorlistener.typeMismatch(token, *expected, type); + if (auto identifier = ctx->Identifier()) { + auto token = identifier->getSymbol(); + auto name = token->getText(); + if (ctx->LeftParen()) { + auto it = signatures.find(name); + if (it == signatures.end()) { + errorlistener.unknownFunction(token, name); + return Type::Invalid; + } else { + auto signature = it->second; + auto arity = signature.parametertypes.size(); + if (auto arg_list = ctx->argumentList()) { + auto arg_num = arg_list->expr().size(); + if (arity != arg_num) { + errorlistener.wrongArgumentNumber(token, name, arity, arg_num); + } + for (size_t i = 0; i < arg_num && i < arity; i++) { + auto expr = arg_list->expr(i); + auto type = std::any_cast(visitExpr(expr)); + if (type != signature.parametertypes[i]) { + errorlistener.typeMismatch(expr->getStart(), + signature.parametertypes[i], type); + } + } + } else { + if (arity) { + errorlistener.wrongArgumentNumber(token, name, arity, 0); + } + } + return signature.returntype; } - return *expected; + } + if (auto type = scope.get(name)) { + return *type; } else { errorlistener.unknownVariable(token, name); - return type; + return Type::Invalid; } } if (auto expr = ctx->expr()) { - auto righttype = std::any_cast(visitExpr(expr)); - - if (auto lefttype = scope.get(name)) { - if (*lefttype != Type::Integer) { - errorlistener.typeMismatch(token, Type::Integer, *lefttype); - } - } else { - errorlistener.unknownVariable(token, name); - } - if (righttype != Type::Integer) { - errorlistener.typeMismatch(token, Type::Integer, righttype); - } - return Type::Integer; - } - if (ctx->LeftParen()) { - auto it = signatures.find(name); - - if (it == signatures.end()) { - errorlistener.unknownFunction(token, name); - return Type::Invalid; - } else { - auto signature = it->second; - auto arity = signature.parametertypes.size(); - - if (auto arg_list = ctx->argumentList()) { - auto arg_num = arg_list->value().size(); - - if (arity != arg_num) { - errorlistener.wrongArgumentNumber(token, name, arity, arg_num); - } - for (size_t i = 0; i < arg_num && i < arity; i++) { - auto type = std::any_cast(visitValue(arg_list->value(i))); - auto token = arg_list->value(i)->getStart(); - - if (type != signature.parametertypes[i]) { - errorlistener.typeMismatch(token, signature.parametertypes[i], - type); - } - } - } else { - if (arity) { - errorlistener.wrongArgumentNumber(token, name, arity, 0); - } - } - return signature.returntype; - } - } - if (auto type = scope.get(name)) { - return *type; - } else { - errorlistener.unknownVariable(token, name); - return Type::Invalid; + return visitExpr(expr); } + // unreachable + errorlistener.compilerError(__FILE__, __LINE__); } } // namespace xlang diff --git a/bootstrap/typecheck.hh b/bootstrap/typecheck.hh index 9cfca99..25ee7b6 100644 --- a/bootstrap/typecheck.hh +++ b/bootstrap/typecheck.hh @@ -39,13 +39,14 @@ class TypeCheckVisitor : public xlangBaseVisitor { std::any visitType(xlangParser::TypeContext *ctx) override; std::any visitBlock(xlangParser::BlockContext *ctx) override; std::any visitStatement(xlangParser::StatementContext *ctx) override; - std::any visitValue(xlangParser::ValueContext *ctx) override; - std::any visitCondition(xlangParser::ConditionContext *ctx) override; - std::any visitBoolean(xlangParser::BooleanContext *ctx) override; std::any visitExpr(xlangParser::ExprContext *ctx) override; - std::any visitTerm(xlangParser::TermContext *ctx) override; + std::any visitBooleanExpr(xlangParser::BooleanExprContext *ctx) override; + std::any visitComparisonExpr( + xlangParser::ComparisonExprContext *ctx) override; + std::any visitAdditiveExpr(xlangParser::AdditiveExprContext *ctx) override; + std::any visitMultiplicativeExpr( + xlangParser::MultiplicativeExprContext *ctx) override; std::any visitFactor(xlangParser::FactorContext *ctx) override; - std::any visitVariable(xlangParser::VariableContext *ctx) override; }; } // namespace xlang diff --git a/bootstrap/xlang.g4 b/bootstrap/xlang.g4 index 325ccd5..8cb414c 100644 --- a/bootstrap/xlang.g4 +++ b/bootstrap/xlang.g4 @@ -2,53 +2,45 @@ grammar xlang; file : function+ EOF; function : Identifier LeftParen parameterList? RightParen Colon type block; -parameterList : Identifier Colon type (Comma Identifier Colon type)*; +parameterList : parameter (Comma parameter)*; +parameter : Identifier Colon type; type : TypeInteger | TypeBoolean ; block : LeftBrace statement* RightBrace; -statement : If condition block (Else block)? - | While condition block - | For value Semicolon condition Semicolon value block +statement : If expr block (Else block)? + | While expr block + | For expr Semicolon expr Semicolon expr block | Break Integer? Semicolon | Continue Integer? Semicolon - | Return value Semicolon - | Print value Semicolon - | value Semicolon + | Return expr Semicolon + | Print expr Semicolon + | expr Semicolon ; -value : expr - | condition - ; -condition : condition (And|Or|Xor) boolean - | boolean - ; -boolean : Not boolean - | True - | False - | expr (Less|LessEqual|Greater|GreaterEqual|Equal|NotEqual) expr - | variable - | LeftParen condition RightParen - ; -expr : expr (Plus|Minus|BitAnd|BitOr|BitXor|ShiftLeft|ShiftRight) term - | term - ; -term : term (Mul|Div|Rem) factor - | factor +expr : Identifier assignmentOp expr + | booleanExpr ; +assignmentOp : Define | Assign | Increment | Decrement; +booleanExpr : comparisonExpr (booleanOp comparisonExpr)*; +booleanOp : And | Or | Xor; +comparisonExpr : Not comparisonExpr + | True + | False + | additiveExpr (comparisonOp additiveExpr)? + ; +comparisonOp : Less | LessEqual | Greater | GreaterEqual | Equal | NotEqual; +additiveExpr : multiplicativeExpr (additiveOp multiplicativeExpr)*; +additiveOp : Plus | Minus | BitAnd | BitOr | BitXor | ShiftLeft | ShiftRight; +multiplicativeExpr : factor (multiplicativeOp factor)*; +multiplicativeOp : Mul | Div | Rem; factor : Minus factor | BitNot factor | Integer - | variable + | Identifier + | Identifier LeftParen argumentList? RightParen | LeftParen expr RightParen ; -variable : Identifier - | Identifier Define value - | Identifier Assign value - | Identifier Increment expr - | Identifier Decrement expr - | Identifier LeftParen argumentList? RightParen - ; -argumentList : value (Comma value)*; +argumentList : expr (Comma expr)*; TypeInteger : 'int'; TypeBoolean : 'bool';