#include #include namespace xlang { std::string typeToString(Type type) { switch (type) { case Type::Integer: return "int"; case Type::Boolean: return "bool"; default: case Type::Invalid: return ""; } } TypeCheckVisitor::TypeCheckVisitor(ErrorListener &errorlistener) : errorlistener{errorlistener}, loopcount{0} {} std::any TypeCheckVisitor::visitFile(xlangParser::FileContext *ctx) { for (auto function : ctx->function()) { auto token = function->Identifier()->getSymbol(); auto name = token->getText(); Signature signature{std::any_cast(visitType(function->type())), {}}; if (signatures.find(name) != signatures.end()) { errorlistener.duplicateFunction(token, name); continue; } if (auto param_list = function->parameterList()) { for (auto type : param_list->type()) { signature.parametertypes.push_back( std::any_cast(visitType(type))); } } signatures.emplace(name, signature); } visitChildren(ctx); return {}; } std::any TypeCheckVisitor::visitFunction(xlangParser::FunctionContext *ctx) { scope.enter(); if (auto param_list = ctx->parameterList()) { for (size_t i = 0, n = param_list->Identifier().size(); i < n; i++) { auto name = param_list->Identifier(i)->getSymbol()->getText(); auto type = std::any_cast(visitType(param_list->type(i))); scope.add(name, type); } } returntype = std::any_cast(visitType(ctx->type())); visitBlock(ctx->block()); scope.leave(); return {}; } std::any TypeCheckVisitor::visitType(xlangParser::TypeContext *ctx) { if (ctx->TypeInteger()) { return Type::Integer; } if (ctx->TypeBoolean()) { return Type::Boolean; } // unreachable errorlistener.compilerError(__FILE__, __LINE__); } std::any TypeCheckVisitor::visitBlock(xlangParser::BlockContext *ctx) { scope.enter(); visitChildren(ctx); scope.leave(); return {}; } std::any TypeCheckVisitor::visitStatement(xlangParser::StatementContext *ctx) { if (auto condition = ctx->condition()) { auto type = std::any_cast(visitCondition(condition)); if (type != Type::Boolean) { errorlistener.typeMismatch(condition->getStart(), Type::Boolean, type); } if (ctx->While()) { loopcount++; } visitBlock(ctx->block(0)); if (ctx->While()) { loopcount--; } if (ctx->Else()) { visitBlock(ctx->block(1)); } return {}; } if (ctx->Break()) { if (!loopcount) { errorlistener.loopControlWithoutLoop(ctx->Break()->getSymbol()); } return {}; } if (ctx->Continue()) { if (!loopcount) { errorlistener.loopControlWithoutLoop(ctx->Continue()->getSymbol()); } return {}; } if (ctx->Return()) { auto type = std::any_cast(visitValue(ctx->value())); if (type != returntype) { errorlistener.typeMismatch(ctx->value()->getStart(), returntype, type); } } visitChildren(ctx); return {}; } std::any TypeCheckVisitor::visitValue(xlangParser::ValueContext *ctx) { if (auto expr = ctx->expr()) { return visitExpr(expr); } if (auto condition = ctx->condition()) { return visitCondition(condition); } // unreachable errorlistener.compilerError(__FILE__, __LINE__); } #define CHECKOPERATOR(visitLeft, left, visitRight, right, expected) \ auto lefttype = std::any_cast(visitLeft(left)); \ auto righttype = std::any_cast(visitRight(right)); \ \ if (lefttype != expected) { \ errorlistener.typeMismatch(left->getStart(), expected, lefttype); \ } \ if (righttype != expected) { \ errorlistener.typeMismatch(right->getStart(), expected, righttype); \ } std::any TypeCheckVisitor::visitCondition(xlangParser::ConditionContext *ctx) { if (auto condition = ctx->condition()) { CHECKOPERATOR(visitCondition, condition, visitBoolean, ctx->boolean(), Type::Boolean); return Type::Boolean; } return visitBoolean(ctx->boolean()); } std::any TypeCheckVisitor::visitBoolean(xlangParser::BooleanContext *ctx) { if (auto boolean = ctx->boolean()) { auto type = std::any_cast(visitBoolean(boolean)); if (type != Type::Boolean) { errorlistener.typeMismatch(boolean->getStart(), Type::Boolean, type); } return Type::Boolean; } if (ctx->True() || ctx->False()) { return Type::Boolean; } if (ctx->expr().size()) { CHECKOPERATOR(visitExpr, ctx->expr(0), visitExpr, ctx->expr(1), Type::Integer); return Type::Boolean; } if (auto variable = ctx->variable()) { return visitVariable(variable); } return visitCondition(ctx->condition()); } std::any TypeCheckVisitor::visitExpr(xlangParser::ExprContext *ctx) { if (auto expr = ctx->expr()) { CHECKOPERATOR(visitExpr, expr, visitTerm, ctx->term(), Type::Integer); return Type::Integer; } return visitTerm(ctx->term()); } std::any TypeCheckVisitor::visitTerm(xlangParser::TermContext *ctx) { if (auto term = ctx->term()) { CHECKOPERATOR(visitTerm, term, visitFactor, ctx->factor(), Type::Integer); return Type::Integer; } return visitFactor(ctx->factor()); } std::any TypeCheckVisitor::visitFactor(xlangParser::FactorContext *ctx) { if (auto factor = ctx->factor()) { auto type = std::any_cast(visitFactor(factor)); if (type != Type::Integer) { errorlistener.typeMismatch(factor->getStart(), Type::Integer, type); } return Type::Integer; } if (ctx->Integer()) { return Type::Integer; } if (auto variable = ctx->variable()) { return visitVariable(variable); } return visitExpr(ctx->expr()); } std::any TypeCheckVisitor::visitVariable(xlangParser::VariableContext *ctx) { auto token = ctx->Identifier()->getSymbol(); auto name = token->getText(); if (ctx->Define()) { auto type = std::any_cast(visitValue(ctx->value())); if (scope.get(name)) { errorlistener.shadowedVariable(token, name); } scope.add(name, type); return type; } if (ctx->Assign()) { auto type = std::any_cast(visitValue(ctx->value())); if (auto expected = scope.get(name)) { if (type != *expected) { errorlistener.typeMismatch(token, *expected, type); } return *expected; } else { errorlistener.unknownVariable(token, name); return type; } } if (ctx->LeftParen()) { auto it = signatures.find(name); if (it == signatures.end()) { errorlistener.unknownFunction(token, name); return Type::Invalid; } else { auto signature = it->second; auto arity = signature.parametertypes.size(); if (auto arg_list = ctx->argumentList()) { auto arg_num = arg_list->value().size(); if (arity != arg_num) { errorlistener.wrongArgumentNumber(token, name, arity, arg_num); } for (size_t i = 0; i < arg_num && i < arity; i++) { auto type = std::any_cast(visitValue(arg_list->value(i))); auto token = arg_list->value(i)->getStart(); if (type != signature.parametertypes[i]) { errorlistener.typeMismatch(token, signature.parametertypes[i], type); } } } else { if (arity) { errorlistener.wrongArgumentNumber(token, name, arity, 0); } } return signature.returntype; } } if (auto type = scope.get(name)) { return *type; } else { errorlistener.unknownVariable(token, name); return Type::Invalid; } } } // namespace xlang