avoid recursion and ambiguity

This commit is contained in:
Thomas Lindner 2023-01-13 00:07:00 +01:00
parent 61dbff27b3
commit 3f3f323898
5 changed files with 323 additions and 277 deletions

View file

@ -21,8 +21,8 @@ std::any EmitVisitor::visitFunction(xlangParser::FunctionContext *ctx) {
<< "export function w $" << ctx->Identifier()->getSymbol()->getText() << "export function w $" << ctx->Identifier()->getSymbol()->getText()
<< "("; << "(";
if (auto param_list = ctx->parameterList()) { if (auto param_list = ctx->parameterList()) {
for (const auto param : param_list->Identifier()) { for (auto param : param_list->parameter()) {
output << "w %_" << param->getSymbol()->getText() << ", "; output << "w %_" << param->Identifier()->getSymbol()->getText() << ", ";
} }
} }
output << ")" << std::endl << "{" << std::endl << "@start" << std::endl; output << ")" << std::endl << "{" << std::endl << "@start" << std::endl;
@ -36,10 +36,9 @@ std::any EmitVisitor::visitFunction(xlangParser::FunctionContext *ctx) {
std::any EmitVisitor::visitStatement(xlangParser::StatementContext *ctx) { std::any EmitVisitor::visitStatement(xlangParser::StatementContext *ctx) {
if (ctx->If()) { if (ctx->If()) {
int block = blockcount++; int block = blockcount++;
output << "@if" << block << "_expr" << std::endl; output << "@if" << block << "_expr" << std::endl;
tmpcount = 0; tmpcount = 0;
visitCondition(ctx->condition()); visitExpr(ctx->expr(0));
output << " jnz " << last << ", @if" << block << "_then, @if" << block output << " jnz " << last << ", @if" << block << "_then, @if" << block
<< "_else" << std::endl; << "_else" << std::endl;
output << "@if" << block << "_then" << std::endl; output << "@if" << block << "_then" << std::endl;
@ -54,12 +53,11 @@ std::any EmitVisitor::visitStatement(xlangParser::StatementContext *ctx) {
} }
if (ctx->While()) { if (ctx->While()) {
int block = blockcount++; int block = blockcount++;
loopstack.push_back(block); loopstack.push_back(block);
output << "@loop" << block << "_continue" << std::endl; output << "@loop" << block << "_continue" << std::endl;
output << "@loop" << block << "_expr" << std::endl; output << "@loop" << block << "_expr" << std::endl;
tmpcount = 0; tmpcount = 0;
visitCondition(ctx->condition()); visitExpr(ctx->expr(0));
output << " jnz " << last << ", @loop" << block << "_body, @loop" << block output << " jnz " << last << ", @loop" << block << "_body, @loop" << block
<< "_end" << std::endl; << "_end" << std::endl;
output << "@loop" << block << "_body" << std::endl; output << "@loop" << block << "_body" << std::endl;
@ -71,20 +69,19 @@ std::any EmitVisitor::visitStatement(xlangParser::StatementContext *ctx) {
} }
if (ctx->For()) { if (ctx->For()) {
int block = blockcount++; int block = blockcount++;
loopstack.push_back(block); loopstack.push_back(block);
tmpcount = 0; tmpcount = 0;
visitValue(ctx->value(0)); visitExpr(ctx->expr(0));
output << "@loop" << block << "_expr" << std::endl; output << "@loop" << block << "_expr" << std::endl;
tmpcount = 0; tmpcount = 0;
visitCondition(ctx->condition()); visitExpr(ctx->expr(1));
output << " jnz " << last << ", @loop" << block << "_body, @loop" << block output << " jnz " << last << ", @loop" << block << "_body, @loop" << block
<< "_end" << std::endl; << "_end" << std::endl;
output << "@loop" << block << "_body" << std::endl; output << "@loop" << block << "_body" << std::endl;
visitBlock(ctx->block(0)); visitBlock(ctx->block(0));
output << "@loop" << block << "_continue" << std::endl; output << "@loop" << block << "_continue" << std::endl;
tmpcount = 0; tmpcount = 0;
visitValue(ctx->value(1)); visitExpr(ctx->expr(2));
output << " jmp @loop" << block << "_expr" << std::endl; output << " jmp @loop" << block << "_expr" << std::endl;
output << "@loop" << block << "_end" << std::endl; output << "@loop" << block << "_end" << std::endl;
loopstack.pop_back(); loopstack.pop_back();
@ -114,44 +111,74 @@ std::any EmitVisitor::visitStatement(xlangParser::StatementContext *ctx) {
} }
if (ctx->Return()) { if (ctx->Return()) {
tmpcount = 0; tmpcount = 0;
visitValue(ctx->value(0)); visitExpr(ctx->expr(0));
output << " ret " << last << std::endl; output << " ret " << last << std::endl;
output << "@dead" << blockcount++ << std::endl; output << "@dead" << blockcount++ << std::endl;
return {}; return {};
} }
if (ctx->Print()) { if (ctx->Print()) {
tmpcount = 0; tmpcount = 0;
visitValue(ctx->value(0)); visitExpr(ctx->expr(0));
output << " call $printf(l $printformat, ..., w " << last << ")" output << " call $printf(l $printformat, ..., w " << last << ")"
<< std::endl; << std::endl;
return {}; return {};
} }
visitValue(ctx->value(0)); visitExpr(ctx->expr(0));
return {}; return {};
} }
#define OPERATOR(Operator, visitLeft, left, visitRight, right, ssa_op) \ std::any EmitVisitor::visitExpr(xlangParser::ExprContext *ctx) {
if (ctx->Operator()) { \ if (auto op = ctx->assignmentOp()) {
visitLeft(ctx->left); \ auto name = ctx->Identifier()->getSymbol()->getText();
std::string l = last; \ if (op->Define() || op->Assign()) {
visitRight(ctx->right); \ visitExpr(ctx->expr());
std::string r = last; \ output << " %_" << name << " = w copy " << last << std::endl;
output << " " << tmp() << " = w " ssa_op " " << l << ", " << r \ last = "%_" + name;
<< std::endl; \ return {};
return {}; \ }
if (op->Increment()) {
visitExpr(ctx->expr());
output << " %_" << name << " = w add %_" << name << ", " << last
<< std::endl;
last = "%_" + name;
return {};
}
if (op->Decrement()) {
visitExpr(ctx->expr());
output << " %_" << name << " = w sub %_" << name << ", " << last
<< std::endl;
last = "%_" + name;
return {};
}
}
visitBooleanExpr(ctx->booleanExpr());
return {};
}
#define OPERATOR(Operator, ssa_op) \
if (op->Operator()) { \
output << " " << tmp() << " = w " ssa_op " " << left << ", " << right \
<< std::endl; \
} }
std::any EmitVisitor::visitCondition(xlangParser::ConditionContext *ctx) { std::any EmitVisitor::visitBooleanExpr(xlangParser::BooleanExprContext *ctx) {
OPERATOR(And, visitCondition, condition(), visitBoolean, boolean(), "and"); visitComparisonExpr(ctx->comparisonExpr(0));
OPERATOR(Or, visitCondition, condition(), visitBoolean, boolean(), "or"); for (size_t i = 0, n = ctx->booleanOp().size(); i < n; i++) {
OPERATOR(Xor, visitCondition, condition(), visitBoolean, boolean(), "xor"); std::string left = last;
visitBoolean(ctx->boolean()); visitComparisonExpr(ctx->comparisonExpr(i + 1));
std::string right = last;
auto op = ctx->booleanOp(i);
OPERATOR(And, "and");
OPERATOR(Or, "or");
OPERATOR(Xor, "xor");
}
return {}; return {};
} }
std::any EmitVisitor::visitBoolean(xlangParser::BooleanContext *ctx) { std::any EmitVisitor::visitComparisonExpr(
xlangParser::ComparisonExprContext *ctx) {
if (ctx->Not()) { if (ctx->Not()) {
visitBoolean(ctx->boolean()); visitComparisonExpr(ctx->comparisonExpr());
std::string b = last; std::string b = last;
output << " " << tmp() << " = w xor 1, " << b << std::endl; output << " " << tmp() << " = w xor 1, " << b << std::endl;
return {}; return {};
@ -164,33 +191,51 @@ std::any EmitVisitor::visitBoolean(xlangParser::BooleanContext *ctx) {
last = "0"; last = "0";
return {}; return {};
} }
OPERATOR(Less, visitExpr, expr(0), visitExpr, expr(1), "csltw"); visitAdditiveExpr(ctx->additiveExpr(0));
OPERATOR(LessEqual, visitExpr, expr(0), visitExpr, expr(1), "cslew"); if (auto op = ctx->comparisonOp()) {
OPERATOR(Greater, visitExpr, expr(0), visitExpr, expr(1), "csgtw"); std::string left = last;
OPERATOR(GreaterEqual, visitExpr, expr(0), visitExpr, expr(1), "csgew"); visitAdditiveExpr(ctx->additiveExpr(1));
OPERATOR(Equal, visitExpr, expr(0), visitExpr, expr(1), "ceqw"); std::string right = last;
OPERATOR(NotEqual, visitExpr, expr(0), visitExpr, expr(1), "cnew"); OPERATOR(Less, "csltw");
visitChildren(ctx); OPERATOR(LessEqual, "cslew");
OPERATOR(Greater, "csgtw");
OPERATOR(GreaterEqual, "csgew");
OPERATOR(Equal, "ceqw");
OPERATOR(NotEqual, "cnew");
}
return {}; return {};
} }
std::any EmitVisitor::visitExpr(xlangParser::ExprContext *ctx) { std::any EmitVisitor::visitAdditiveExpr(xlangParser::AdditiveExprContext *ctx) {
OPERATOR(Plus, visitExpr, expr(), visitTerm, term(), "add"); visitMultiplicativeExpr(ctx->multiplicativeExpr(0));
OPERATOR(Minus, visitExpr, expr(), visitTerm, term(), "sub"); for (size_t i = 0, n = ctx->additiveOp().size(); i < n; i++) {
OPERATOR(BitAnd, visitExpr, expr(), visitTerm, term(), "and"); std::string left = last;
OPERATOR(BitOr, visitExpr, expr(), visitTerm, term(), "or"); visitMultiplicativeExpr(ctx->multiplicativeExpr(i + 1));
OPERATOR(BitXor, visitExpr, expr(), visitTerm, term(), "xor"); std::string right = last;
OPERATOR(ShiftLeft, visitExpr, expr(), visitTerm, term(), "shl"); auto op = ctx->additiveOp(i);
OPERATOR(ShiftRight, visitExpr, expr(), visitTerm, term(), "shr"); OPERATOR(Plus, "add");
visitTerm(ctx->term()); OPERATOR(Minus, "sub");
OPERATOR(BitAnd, "and");
OPERATOR(BitOr, "or");
OPERATOR(BitXor, "xor");
OPERATOR(ShiftLeft, "shl");
OPERATOR(ShiftRight, "shr");
}
return {}; return {};
} }
std::any EmitVisitor::visitTerm(xlangParser::TermContext *ctx) { std::any EmitVisitor::visitMultiplicativeExpr(
OPERATOR(Mul, visitTerm, term(), visitFactor, factor(), "mul"); xlangParser::MultiplicativeExprContext *ctx) {
OPERATOR(Div, visitTerm, term(), visitFactor, factor(), "div"); visitFactor(ctx->factor(0));
OPERATOR(Rem, visitTerm, term(), visitFactor, factor(), "rem"); for (size_t i = 0, n = ctx->multiplicativeOp().size(); i < n; i++) {
visitFactor(ctx->factor()); std::string left = last;
visitFactor(ctx->factor(i + 1));
std::string right = last;
auto op = ctx->multiplicativeOp(i);
OPERATOR(Mul, "mul");
OPERATOR(Div, "div");
OPERATOR(Rem, "rem");
}
return {}; return {};
} }
@ -211,50 +256,28 @@ std::any EmitVisitor::visitFactor(xlangParser::FactorContext *ctx) {
last = integer->getSymbol()->getText(); last = integer->getSymbol()->getText();
return {}; return {};
} }
visitChildren(ctx); if (auto identifier = ctx->Identifier()) {
return {}; auto name = identifier->getSymbol()->getText();
} if (ctx->LeftParen()) {
std::vector<std::string> args;
std::any EmitVisitor::visitVariable(xlangParser::VariableContext *ctx) { if (auto arg_list = ctx->argumentList()) {
auto name = ctx->Identifier()->getSymbol()->getText(); for (auto expr : arg_list->expr()) {
visitExpr(expr);
if (ctx->Define() || ctx->Assign()) { args.emplace_back(last);
visitValue(ctx->value()); }
output << " %_" << name << " = w copy " << last << std::endl;
last = "%_" + name;
return {};
}
if (ctx->Increment()) {
visitExpr(ctx->expr());
output << " %_" << name << " = w add %_" << name << ", " << last
<< std::endl;
last = "%_" + name;
return {};
}
if (ctx->Decrement()) {
visitExpr(ctx->expr());
output << " %_" << name << " = w sub %_" << name << ", " << last
<< std::endl;
last = "%_" + name;
return {};
}
if (ctx->LeftParen()) {
std::vector<std::string> args;
if (auto arg_list = ctx->argumentList()) {
for (auto value : arg_list->value()) {
visitValue(value);
args.emplace_back(last);
} }
output << " " << tmp() << " = w call $" << name << "(";
for (auto arg : args) {
output << "w " << arg << ", ";
}
output << ")" << std::endl;
return {};
} else {
last = "%_" + name;
return {};
} }
output << " " << tmp() << " = w call $" << name << "(";
for (auto arg : args) {
output << "w " << arg << ", ";
}
output << ")" << std::endl;
return {};
} }
last = "%_" + name; visitExpr(ctx->expr());
return {}; return {};
} }

View file

@ -22,12 +22,14 @@ class EmitVisitor : public xlangBaseVisitor {
std::any visitFile(xlangParser::FileContext *ctx) override; std::any visitFile(xlangParser::FileContext *ctx) override;
std::any visitFunction(xlangParser::FunctionContext *ctx) override; std::any visitFunction(xlangParser::FunctionContext *ctx) override;
std::any visitStatement(xlangParser::StatementContext *ctx) override; std::any visitStatement(xlangParser::StatementContext *ctx) override;
std::any visitCondition(xlangParser::ConditionContext *ctx) override;
std::any visitBoolean(xlangParser::BooleanContext *ctx) override;
std::any visitExpr(xlangParser::ExprContext *ctx) override; std::any visitExpr(xlangParser::ExprContext *ctx) override;
std::any visitTerm(xlangParser::TermContext *ctx) override; std::any visitBooleanExpr(xlangParser::BooleanExprContext *ctx) override;
std::any visitComparisonExpr(
xlangParser::ComparisonExprContext *ctx) override;
std::any visitAdditiveExpr(xlangParser::AdditiveExprContext *ctx) override;
std::any visitMultiplicativeExpr(
xlangParser::MultiplicativeExprContext *ctx) override;
std::any visitFactor(xlangParser::FactorContext *ctx) override; std::any visitFactor(xlangParser::FactorContext *ctx) override;
std::any visitVariable(xlangParser::VariableContext *ctx) override;
}; };
} // namespace xlang } // namespace xlang

View file

@ -22,16 +22,15 @@ std::any TypeCheckVisitor::visitFile(xlangParser::FileContext *ctx) {
for (auto function : ctx->function()) { for (auto function : ctx->function()) {
auto token = function->Identifier()->getSymbol(); auto token = function->Identifier()->getSymbol();
auto name = token->getText(); auto name = token->getText();
Signature signature{std::any_cast<Type>(visitType(function->type())), {}};
if (signatures.find(name) != signatures.end()) { if (signatures.find(name) != signatures.end()) {
errorlistener.duplicateFunction(token, name); errorlistener.duplicateFunction(token, name);
continue; continue;
} }
Signature signature{std::any_cast<Type>(visitType(function->type())), {}};
if (auto param_list = function->parameterList()) { if (auto param_list = function->parameterList()) {
for (auto type : param_list->type()) { for (auto param : param_list->parameter()) {
signature.parametertypes.push_back( signature.parametertypes.push_back(
std::any_cast<Type>(visitType(type))); std::any_cast<Type>(visitType(param->type())));
} }
} }
signatures.emplace(name, signature); signatures.emplace(name, signature);
@ -43,10 +42,9 @@ std::any TypeCheckVisitor::visitFile(xlangParser::FileContext *ctx) {
std::any TypeCheckVisitor::visitFunction(xlangParser::FunctionContext *ctx) { std::any TypeCheckVisitor::visitFunction(xlangParser::FunctionContext *ctx) {
scope.enter(); scope.enter();
if (auto param_list = ctx->parameterList()) { if (auto param_list = ctx->parameterList()) {
for (size_t i = 0, n = param_list->Identifier().size(); i < n; i++) { for (auto param : param_list->parameter()) {
auto name = param_list->Identifier(i)->getSymbol()->getText(); auto name = param->Identifier()->getSymbol()->getText();
auto type = std::any_cast<Type>(visitType(param_list->type(i))); auto type = std::any_cast<Type>(visitType(param->type()));
scope.add(name, type); scope.add(name, type);
} }
} }
@ -77,10 +75,10 @@ std::any TypeCheckVisitor::visitBlock(xlangParser::BlockContext *ctx) {
std::any TypeCheckVisitor::visitStatement(xlangParser::StatementContext *ctx) { std::any TypeCheckVisitor::visitStatement(xlangParser::StatementContext *ctx) {
if (ctx->If()) { if (ctx->If()) {
scope.enter(); scope.enter();
auto condition = ctx->condition(); auto expr = ctx->expr(0);
auto type = std::any_cast<Type>(visitCondition(condition)); auto type = std::any_cast<Type>(visitExpr(expr));
if (type != Type::Boolean) { if (type != Type::Boolean) {
errorlistener.typeMismatch(condition->getStart(), Type::Boolean, type); errorlistener.typeMismatch(expr->getStart(), Type::Boolean, type);
} }
visitBlock(ctx->block(0)); visitBlock(ctx->block(0));
if (ctx->Else()) { if (ctx->Else()) {
@ -89,19 +87,28 @@ std::any TypeCheckVisitor::visitStatement(xlangParser::StatementContext *ctx) {
scope.leave(); scope.leave();
return {}; return {};
} }
if (ctx->While() || ctx->For()) { if (ctx->While()) {
scope.enter(); scope.enter();
if (ctx->For()) { auto expr = ctx->expr(0);
visitValue(ctx->value(0)); auto type = std::any_cast<Type>(visitExpr(expr));
}
auto condition = ctx->condition();
auto type = std::any_cast<Type>(visitCondition(condition));
if (type != Type::Boolean) { if (type != Type::Boolean) {
errorlistener.typeMismatch(condition->getStart(), Type::Boolean, type); errorlistener.typeMismatch(expr->getStart(), Type::Boolean, type);
} }
if (ctx->For()) { loopcount++;
visitValue(ctx->value(1)); visitBlock(ctx->block(0));
loopcount--;
scope.leave();
return {};
}
if (ctx->For()) {
scope.enter();
visitExpr(ctx->expr(0));
auto expr = ctx->expr(1);
auto type = std::any_cast<Type>(visitExpr(expr));
if (type != Type::Boolean) {
errorlistener.typeMismatch(expr->getStart(), Type::Boolean, type);
} }
visitExpr(ctx->expr(2));
loopcount++; loopcount++;
visitBlock(ctx->block(0)); visitBlock(ctx->block(0));
loopcount--; loopcount--;
@ -142,90 +149,153 @@ std::any TypeCheckVisitor::visitStatement(xlangParser::StatementContext *ctx) {
return {}; return {};
} }
if (ctx->Return()) { if (ctx->Return()) {
auto type = std::any_cast<Type>(visitValue(ctx->value(0))); auto type = std::any_cast<Type>(visitExpr(ctx->expr(0)));
if (type != returntype) { if (type != returntype) {
errorlistener.typeMismatch(ctx->value(0)->getStart(), returntype, type); errorlistener.typeMismatch(ctx->expr(0)->getStart(), returntype, type);
} }
return {};
} }
visitChildren(ctx); visitChildren(ctx);
return {}; return {};
} }
std::any TypeCheckVisitor::visitValue(xlangParser::ValueContext *ctx) { std::any TypeCheckVisitor::visitExpr(xlangParser::ExprContext *ctx) {
if (auto expr = ctx->expr()) { if (auto op = ctx->assignmentOp()) {
return visitExpr(expr); auto token = ctx->Identifier()->getSymbol();
auto name = token->getText();
if (op->Define()) {
auto type = std::any_cast<Type>(visitExpr(ctx->expr()));
if (scope.get(name)) {
errorlistener.shadowedVariable(token, name);
}
scope.add(name, type);
return type;
}
if (op->Assign()) {
auto type = std::any_cast<Type>(visitExpr(ctx->expr()));
if (auto expected = scope.get(name)) {
if (type != *expected) {
errorlistener.typeMismatch(token, *expected, type);
}
return *expected;
} else {
errorlistener.unknownVariable(token, name);
return type;
}
}
if (op->Increment() || op->Decrement()) {
if (auto type = scope.get(name)) {
if (*type != Type::Integer) {
errorlistener.typeMismatch(token, Type::Integer, *type);
}
} else {
errorlistener.unknownVariable(token, name);
}
auto expr = ctx->expr();
auto type = std::any_cast<Type>(visitExpr(expr));
if (type != Type::Integer) {
errorlistener.typeMismatch(expr->getStart(), Type::Integer, type);
}
return Type::Integer;
}
} }
if (auto condition = ctx->condition()) { if (auto boolean_expr = ctx->booleanExpr()) {
return visitCondition(condition); return visitBooleanExpr(boolean_expr);
} }
// unreachable // unreachable
errorlistener.compilerError(__FILE__, __LINE__); errorlistener.compilerError(__FILE__, __LINE__);
} }
#define CHECKOPERATOR(visitLeft, left, visitRight, right, expected) \ std::any TypeCheckVisitor::visitBooleanExpr(
auto lefttype = std::any_cast<Type>(visitLeft(left)); \ xlangParser::BooleanExprContext *ctx) {
auto righttype = std::any_cast<Type>(visitRight(right)); \ if (ctx->booleanOp().size()) {
\ for (auto comparison_expr : ctx->comparisonExpr()) {
if (lefttype != expected) { \ auto type = std::any_cast<Type>(visitComparisonExpr(comparison_expr));
errorlistener.typeMismatch(left->getStart(), expected, lefttype); \ if (type != Type::Boolean) {
} \ errorlistener.typeMismatch(comparison_expr->getStart(), Type::Boolean,
if (righttype != expected) { \ type);
errorlistener.typeMismatch(right->getStart(), expected, righttype); \ }
} }
std::any TypeCheckVisitor::visitCondition(xlangParser::ConditionContext *ctx) {
if (auto condition = ctx->condition()) {
CHECKOPERATOR(visitCondition, condition, visitBoolean, ctx->boolean(),
Type::Boolean);
return Type::Boolean; return Type::Boolean;
} }
return visitBoolean(ctx->boolean()); if (auto comparison_expr = ctx->comparisonExpr(0)) {
return visitComparisonExpr(comparison_expr);
}
// unreachable
errorlistener.compilerError(__FILE__, __LINE__);
} }
std::any TypeCheckVisitor::visitBoolean(xlangParser::BooleanContext *ctx) { std::any TypeCheckVisitor::visitComparisonExpr(
if (auto boolean = ctx->boolean()) { xlangParser::ComparisonExprContext *ctx) {
auto type = std::any_cast<Type>(visitBoolean(boolean)); if (auto comparison_expr = ctx->comparisonExpr()) {
auto type = std::any_cast<Type>(visitComparisonExpr(comparison_expr));
if (type != Type::Boolean) { if (type != Type::Boolean) {
errorlistener.typeMismatch(boolean->getStart(), Type::Boolean, type); errorlistener.typeMismatch(comparison_expr->getStart(), Type::Boolean,
type);
} }
return Type::Boolean; return Type::Boolean;
} }
if (ctx->True() || ctx->False()) { if (ctx->True() || ctx->False()) {
return Type::Boolean; return Type::Boolean;
} }
if (ctx->expr().size()) { if (ctx->comparisonOp()) {
CHECKOPERATOR(visitExpr, ctx->expr(0), visitExpr, ctx->expr(1), for (auto additive_expr : ctx->additiveExpr()) {
Type::Integer); auto type = std::any_cast<Type>(visitAdditiveExpr(additive_expr));
if (type != Type::Integer) {
errorlistener.typeMismatch(additive_expr->getStart(), Type::Integer,
type);
}
}
return Type::Boolean; return Type::Boolean;
} }
if (auto variable = ctx->variable()) { if (auto additive_expr = ctx->additiveExpr(0)) {
return visitVariable(variable); return visitAdditiveExpr(additive_expr);
} }
return visitCondition(ctx->condition()); // unreachable
errorlistener.compilerError(__FILE__, __LINE__);
} }
std::any TypeCheckVisitor::visitExpr(xlangParser::ExprContext *ctx) { std::any TypeCheckVisitor::visitAdditiveExpr(
if (auto expr = ctx->expr()) { xlangParser::AdditiveExprContext *ctx) {
CHECKOPERATOR(visitExpr, expr, visitTerm, ctx->term(), Type::Integer); if (ctx->additiveOp().size()) {
for (auto multiplicative_expr : ctx->multiplicativeExpr()) {
auto type =
std::any_cast<Type>(visitMultiplicativeExpr(multiplicative_expr));
if (type != Type::Integer) {
errorlistener.typeMismatch(multiplicative_expr->getStart(),
Type::Integer, type);
}
}
return Type::Integer; return Type::Integer;
} }
return visitTerm(ctx->term()); if (auto multiplicative_expr = ctx->multiplicativeExpr(0)) {
return visitMultiplicativeExpr(multiplicative_expr);
}
// unreachable
errorlistener.compilerError(__FILE__, __LINE__);
} }
std::any TypeCheckVisitor::visitTerm(xlangParser::TermContext *ctx) { std::any TypeCheckVisitor::visitMultiplicativeExpr(
if (auto term = ctx->term()) { xlangParser::MultiplicativeExprContext *ctx) {
CHECKOPERATOR(visitTerm, term, visitFactor, ctx->factor(), Type::Integer); if (ctx->multiplicativeOp().size()) {
for (auto factor : ctx->factor()) {
auto type = std::any_cast<Type>(visitFactor(factor));
if (type != Type::Integer) {
errorlistener.typeMismatch(factor->getStart(), Type::Integer, type);
}
}
return Type::Integer; return Type::Integer;
} }
return visitFactor(ctx->factor()); if (auto factor = ctx->factor(0)) {
return visitFactor(factor);
}
// unreachable
errorlistener.compilerError(__FILE__, __LINE__);
} }
std::any TypeCheckVisitor::visitFactor(xlangParser::FactorContext *ctx) { std::any TypeCheckVisitor::visitFactor(xlangParser::FactorContext *ctx) {
if (auto factor = ctx->factor()) { if (auto factor = ctx->factor()) {
auto type = std::any_cast<Type>(visitFactor(factor)); auto type = std::any_cast<Type>(visitFactor(factor));
if (type != Type::Integer) { if (type != Type::Integer) {
errorlistener.typeMismatch(factor->getStart(), Type::Integer, type); errorlistener.typeMismatch(factor->getStart(), Type::Integer, type);
} }
@ -234,92 +304,50 @@ std::any TypeCheckVisitor::visitFactor(xlangParser::FactorContext *ctx) {
if (ctx->Integer()) { if (ctx->Integer()) {
return Type::Integer; return Type::Integer;
} }
if (auto variable = ctx->variable()) { if (auto identifier = ctx->Identifier()) {
return visitVariable(variable); auto token = identifier->getSymbol();
} auto name = token->getText();
return visitExpr(ctx->expr()); if (ctx->LeftParen()) {
} auto it = signatures.find(name);
if (it == signatures.end()) {
std::any TypeCheckVisitor::visitVariable(xlangParser::VariableContext *ctx) { errorlistener.unknownFunction(token, name);
auto token = ctx->Identifier()->getSymbol(); return Type::Invalid;
auto name = token->getText(); } else {
auto signature = it->second;
if (ctx->Define()) { auto arity = signature.parametertypes.size();
auto type = std::any_cast<Type>(visitValue(ctx->value())); if (auto arg_list = ctx->argumentList()) {
auto arg_num = arg_list->expr().size();
if (scope.get(name)) { if (arity != arg_num) {
errorlistener.shadowedVariable(token, name); errorlistener.wrongArgumentNumber(token, name, arity, arg_num);
} }
scope.add(name, type); for (size_t i = 0; i < arg_num && i < arity; i++) {
return type; auto expr = arg_list->expr(i);
} auto type = std::any_cast<Type>(visitExpr(expr));
if (ctx->Assign()) { if (type != signature.parametertypes[i]) {
auto type = std::any_cast<Type>(visitValue(ctx->value())); errorlistener.typeMismatch(expr->getStart(),
signature.parametertypes[i], type);
if (auto expected = scope.get(name)) { }
if (type != *expected) { }
errorlistener.typeMismatch(token, *expected, type); } else {
if (arity) {
errorlistener.wrongArgumentNumber(token, name, arity, 0);
}
}
return signature.returntype;
} }
return *expected; }
if (auto type = scope.get(name)) {
return *type;
} else { } else {
errorlistener.unknownVariable(token, name); errorlistener.unknownVariable(token, name);
return type; return Type::Invalid;
} }
} }
if (auto expr = ctx->expr()) { if (auto expr = ctx->expr()) {
auto righttype = std::any_cast<Type>(visitExpr(expr)); return visitExpr(expr);
if (auto lefttype = scope.get(name)) {
if (*lefttype != Type::Integer) {
errorlistener.typeMismatch(token, Type::Integer, *lefttype);
}
} else {
errorlistener.unknownVariable(token, name);
}
if (righttype != Type::Integer) {
errorlistener.typeMismatch(token, Type::Integer, righttype);
}
return Type::Integer;
}
if (ctx->LeftParen()) {
auto it = signatures.find(name);
if (it == signatures.end()) {
errorlistener.unknownFunction(token, name);
return Type::Invalid;
} else {
auto signature = it->second;
auto arity = signature.parametertypes.size();
if (auto arg_list = ctx->argumentList()) {
auto arg_num = arg_list->value().size();
if (arity != arg_num) {
errorlistener.wrongArgumentNumber(token, name, arity, arg_num);
}
for (size_t i = 0; i < arg_num && i < arity; i++) {
auto type = std::any_cast<Type>(visitValue(arg_list->value(i)));
auto token = arg_list->value(i)->getStart();
if (type != signature.parametertypes[i]) {
errorlistener.typeMismatch(token, signature.parametertypes[i],
type);
}
}
} else {
if (arity) {
errorlistener.wrongArgumentNumber(token, name, arity, 0);
}
}
return signature.returntype;
}
}
if (auto type = scope.get(name)) {
return *type;
} else {
errorlistener.unknownVariable(token, name);
return Type::Invalid;
} }
// unreachable
errorlistener.compilerError(__FILE__, __LINE__);
} }
} // namespace xlang } // namespace xlang

View file

@ -39,13 +39,14 @@ class TypeCheckVisitor : public xlangBaseVisitor {
std::any visitType(xlangParser::TypeContext *ctx) override; std::any visitType(xlangParser::TypeContext *ctx) override;
std::any visitBlock(xlangParser::BlockContext *ctx) override; std::any visitBlock(xlangParser::BlockContext *ctx) override;
std::any visitStatement(xlangParser::StatementContext *ctx) override; std::any visitStatement(xlangParser::StatementContext *ctx) override;
std::any visitValue(xlangParser::ValueContext *ctx) override;
std::any visitCondition(xlangParser::ConditionContext *ctx) override;
std::any visitBoolean(xlangParser::BooleanContext *ctx) override;
std::any visitExpr(xlangParser::ExprContext *ctx) override; std::any visitExpr(xlangParser::ExprContext *ctx) override;
std::any visitTerm(xlangParser::TermContext *ctx) override; std::any visitBooleanExpr(xlangParser::BooleanExprContext *ctx) override;
std::any visitComparisonExpr(
xlangParser::ComparisonExprContext *ctx) override;
std::any visitAdditiveExpr(xlangParser::AdditiveExprContext *ctx) override;
std::any visitMultiplicativeExpr(
xlangParser::MultiplicativeExprContext *ctx) override;
std::any visitFactor(xlangParser::FactorContext *ctx) override; std::any visitFactor(xlangParser::FactorContext *ctx) override;
std::any visitVariable(xlangParser::VariableContext *ctx) override;
}; };
} // namespace xlang } // namespace xlang

View file

@ -2,53 +2,45 @@ grammar xlang;
file : function+ EOF; file : function+ EOF;
function : Identifier LeftParen parameterList? RightParen Colon type block; function : Identifier LeftParen parameterList? RightParen Colon type block;
parameterList : Identifier Colon type (Comma Identifier Colon type)*; parameterList : parameter (Comma parameter)*;
parameter : Identifier Colon type;
type : TypeInteger type : TypeInteger
| TypeBoolean | TypeBoolean
; ;
block : LeftBrace statement* RightBrace; block : LeftBrace statement* RightBrace;
statement : If condition block (Else block)? statement : If expr block (Else block)?
| While condition block | While expr block
| For value Semicolon condition Semicolon value block | For expr Semicolon expr Semicolon expr block
| Break Integer? Semicolon | Break Integer? Semicolon
| Continue Integer? Semicolon | Continue Integer? Semicolon
| Return value Semicolon | Return expr Semicolon
| Print value Semicolon | Print expr Semicolon
| value Semicolon | expr Semicolon
; ;
value : expr expr : Identifier assignmentOp expr
| condition | booleanExpr
;
condition : condition (And|Or|Xor) boolean
| boolean
;
boolean : Not boolean
| True
| False
| expr (Less|LessEqual|Greater|GreaterEqual|Equal|NotEqual) expr
| variable
| LeftParen condition RightParen
;
expr : expr (Plus|Minus|BitAnd|BitOr|BitXor|ShiftLeft|ShiftRight) term
| term
;
term : term (Mul|Div|Rem) factor
| factor
; ;
assignmentOp : Define | Assign | Increment | Decrement;
booleanExpr : comparisonExpr (booleanOp comparisonExpr)*;
booleanOp : And | Or | Xor;
comparisonExpr : Not comparisonExpr
| True
| False
| additiveExpr (comparisonOp additiveExpr)?
;
comparisonOp : Less | LessEqual | Greater | GreaterEqual | Equal | NotEqual;
additiveExpr : multiplicativeExpr (additiveOp multiplicativeExpr)*;
additiveOp : Plus | Minus | BitAnd | BitOr | BitXor | ShiftLeft | ShiftRight;
multiplicativeExpr : factor (multiplicativeOp factor)*;
multiplicativeOp : Mul | Div | Rem;
factor : Minus factor factor : Minus factor
| BitNot factor | BitNot factor
| Integer | Integer
| variable | Identifier
| Identifier LeftParen argumentList? RightParen
| LeftParen expr RightParen | LeftParen expr RightParen
; ;
variable : Identifier argumentList : expr (Comma expr)*;
| Identifier Define value
| Identifier Assign value
| Identifier Increment expr
| Identifier Decrement expr
| Identifier LeftParen argumentList? RightParen
;
argumentList : value (Comma value)*;
TypeInteger : 'int'; TypeInteger : 'int';
TypeBoolean : 'bool'; TypeBoolean : 'bool';