#include #include namespace xlang { std::string typeToString(Type type) { switch (type) { case Type::Integer: return "int"; case Type::Boolean: return "bool"; default: case Type::Invalid: return ""; } } TypeCheckVisitor::TypeCheckVisitor(ErrorListener &errorlistener) : errorlistener{errorlistener}, loopcount{0} {} std::any TypeCheckVisitor::visitFile(xlangParser::FileContext *ctx) { for (auto function : ctx->function()) { auto token = function->Identifier()->getSymbol(); auto name = token->getText(); if (signatures.find(name) != signatures.end()) { errorlistener.duplicateFunction(token, name); continue; } Signature signature{std::any_cast(visitType(function->type())), {}}; if (auto param_list = function->parameterList()) { for (auto param : param_list->parameter()) { signature.parametertypes.push_back( std::any_cast(visitType(param->type()))); } } signatures.emplace(name, signature); } visitChildren(ctx); return {}; } std::any TypeCheckVisitor::visitFunction(xlangParser::FunctionContext *ctx) { scope.enter(); if (auto param_list = ctx->parameterList()) { for (auto param : param_list->parameter()) { auto name = param->Identifier()->getSymbol()->getText(); auto type = std::any_cast(visitType(param->type())); scope.add(name, type); } } returntype = std::any_cast(visitType(ctx->type())); visitBlock(ctx->block()); scope.leave(); return {}; } std::any TypeCheckVisitor::visitType(xlangParser::TypeContext *ctx) { if (ctx->TypeInteger()) { return Type::Integer; } if (ctx->TypeBoolean()) { return Type::Boolean; } // unreachable errorlistener.compilerError(__FILE__, __LINE__); } std::any TypeCheckVisitor::visitBlock(xlangParser::BlockContext *ctx) { scope.enter(); visitChildren(ctx); scope.leave(); return {}; } std::any TypeCheckVisitor::visitStatement(xlangParser::StatementContext *ctx) { if (ctx->If()) { scope.enter(); auto expr = ctx->expr(0); auto type = std::any_cast(visitExpr(expr)); if (type != Type::Boolean) { errorlistener.typeMismatch(expr->getStart(), Type::Boolean, type); } visitBlock(ctx->block(0)); if (ctx->Else()) { visitBlock(ctx->block(1)); } scope.leave(); return {}; } if (ctx->While()) { scope.enter(); auto expr = ctx->expr(0); auto type = std::any_cast(visitExpr(expr)); if (type != Type::Boolean) { errorlistener.typeMismatch(expr->getStart(), Type::Boolean, type); } loopcount++; visitBlock(ctx->block(0)); loopcount--; scope.leave(); return {}; } if (ctx->For()) { scope.enter(); visitExpr(ctx->expr(0)); auto expr = ctx->expr(1); auto type = std::any_cast(visitExpr(expr)); if (type != Type::Boolean) { errorlistener.typeMismatch(expr->getStart(), Type::Boolean, type); } visitExpr(ctx->expr(2)); loopcount++; visitBlock(ctx->block(0)); loopcount--; scope.leave(); return {}; } if (ctx->Break()) { if (auto integer = ctx->Integer()) { auto num = stoul(integer->getSymbol()->getText()); if (!num) { errorlistener.breakZero(ctx->Break()->getSymbol()); } if (num > loopcount) { errorlistener.breakTooMany(ctx->Break()->getSymbol(), num, loopcount); } } else { if (!loopcount) { errorlistener.loopControlWithoutLoop(ctx->Break()->getSymbol()); } } return {}; } if (ctx->Continue()) { if (auto integer = ctx->Integer()) { auto num = stoul(integer->getSymbol()->getText()); if (!num) { errorlistener.continueZero(ctx->Break()->getSymbol()); } if (num > loopcount) { errorlistener.continueTooMany(ctx->Break()->getSymbol(), num, loopcount); } } else { if (!loopcount) { errorlistener.loopControlWithoutLoop(ctx->Continue()->getSymbol()); } } return {}; } if (ctx->Return()) { auto type = std::any_cast(visitExpr(ctx->expr(0))); if (type != returntype) { errorlistener.typeMismatch(ctx->expr(0)->getStart(), returntype, type); } return {}; } visitChildren(ctx); return {}; } std::any TypeCheckVisitor::visitExpr(xlangParser::ExprContext *ctx) { if (auto op = ctx->assignmentOp()) { auto token = ctx->Identifier()->getSymbol(); auto name = token->getText(); if (op->Define()) { auto type = std::any_cast(visitExpr(ctx->expr())); if (scope.get(name)) { errorlistener.shadowedVariable(token, name); } scope.add(name, type); return type; } if (op->Assign()) { auto type = std::any_cast(visitExpr(ctx->expr())); if (auto expected = scope.get(name)) { if (type != *expected) { errorlistener.typeMismatch(token, *expected, type); } return *expected; } else { errorlistener.unknownVariable(token, name); return type; } } if (op->Increment() || op->Decrement()) { if (auto type = scope.get(name)) { if (*type != Type::Integer) { errorlistener.typeMismatch(token, Type::Integer, *type); } } else { errorlistener.unknownVariable(token, name); } auto expr = ctx->expr(); auto type = std::any_cast(visitExpr(expr)); if (type != Type::Integer) { errorlistener.typeMismatch(expr->getStart(), Type::Integer, type); } return Type::Integer; } } if (auto boolean_expr = ctx->booleanExpr()) { return visitBooleanExpr(boolean_expr); } // unreachable errorlistener.compilerError(__FILE__, __LINE__); } std::any TypeCheckVisitor::visitBooleanExpr( xlangParser::BooleanExprContext *ctx) { if (ctx->booleanOp().size()) { for (auto comparison_expr : ctx->comparisonExpr()) { auto type = std::any_cast(visitComparisonExpr(comparison_expr)); if (type != Type::Boolean) { errorlistener.typeMismatch(comparison_expr->getStart(), Type::Boolean, type); } } return Type::Boolean; } if (auto comparison_expr = ctx->comparisonExpr(0)) { return visitComparisonExpr(comparison_expr); } // unreachable errorlistener.compilerError(__FILE__, __LINE__); } std::any TypeCheckVisitor::visitComparisonExpr( xlangParser::ComparisonExprContext *ctx) { if (auto comparison_expr = ctx->comparisonExpr()) { auto type = std::any_cast(visitComparisonExpr(comparison_expr)); if (type != Type::Boolean) { errorlistener.typeMismatch(comparison_expr->getStart(), Type::Boolean, type); } return Type::Boolean; } if (ctx->True() || ctx->False()) { return Type::Boolean; } if (ctx->comparisonOp()) { for (auto additive_expr : ctx->additiveExpr()) { auto type = std::any_cast(visitAdditiveExpr(additive_expr)); if (type != Type::Integer) { errorlistener.typeMismatch(additive_expr->getStart(), Type::Integer, type); } } return Type::Boolean; } if (auto additive_expr = ctx->additiveExpr(0)) { return visitAdditiveExpr(additive_expr); } // unreachable errorlistener.compilerError(__FILE__, __LINE__); } std::any TypeCheckVisitor::visitAdditiveExpr( xlangParser::AdditiveExprContext *ctx) { if (ctx->additiveOp().size()) { for (auto multiplicative_expr : ctx->multiplicativeExpr()) { auto type = std::any_cast(visitMultiplicativeExpr(multiplicative_expr)); if (type != Type::Integer) { errorlistener.typeMismatch(multiplicative_expr->getStart(), Type::Integer, type); } } return Type::Integer; } if (auto multiplicative_expr = ctx->multiplicativeExpr(0)) { return visitMultiplicativeExpr(multiplicative_expr); } // unreachable errorlistener.compilerError(__FILE__, __LINE__); } std::any TypeCheckVisitor::visitMultiplicativeExpr( xlangParser::MultiplicativeExprContext *ctx) { if (ctx->multiplicativeOp().size()) { for (auto factor : ctx->factor()) { auto type = std::any_cast(visitFactor(factor)); if (type != Type::Integer) { errorlistener.typeMismatch(factor->getStart(), Type::Integer, type); } } return Type::Integer; } if (auto factor = ctx->factor(0)) { return visitFactor(factor); } // unreachable errorlistener.compilerError(__FILE__, __LINE__); } std::any TypeCheckVisitor::visitFactor(xlangParser::FactorContext *ctx) { if (auto factor = ctx->factor()) { auto type = std::any_cast(visitFactor(factor)); if (type != Type::Integer) { errorlistener.typeMismatch(factor->getStart(), Type::Integer, type); } return Type::Integer; } if (ctx->Integer()) { return Type::Integer; } if (auto identifier = ctx->Identifier()) { auto token = identifier->getSymbol(); auto name = token->getText(); if (ctx->LeftParen()) { auto it = signatures.find(name); if (it == signatures.end()) { errorlistener.unknownFunction(token, name); return Type::Invalid; } else { auto signature = it->second; auto arity = signature.parametertypes.size(); if (auto arg_list = ctx->argumentList()) { auto arg_num = arg_list->expr().size(); if (arity != arg_num) { errorlistener.wrongArgumentNumber(token, name, arity, arg_num); } for (size_t i = 0; i < arg_num && i < arity; i++) { auto expr = arg_list->expr(i); auto type = std::any_cast(visitExpr(expr)); if (type != signature.parametertypes[i]) { errorlistener.typeMismatch(expr->getStart(), signature.parametertypes[i], type); } } } else { if (arity) { errorlistener.wrongArgumentNumber(token, name, arity, 0); } } return signature.returntype; } } if (auto type = scope.get(name)) { return *type; } else { errorlistener.unknownVariable(token, name); return Type::Invalid; } } if (auto expr = ctx->expr()) { return visitExpr(expr); } // unreachable errorlistener.compilerError(__FILE__, __LINE__); } } // namespace xlang