xlang/bootstrap/typecheck.cc

305 lines
8.6 KiB
C++

#include <error.hh>
#include <typecheck.hh>
namespace xlang {
/// Returns the source-level spelling of @p type for use in diagnostics.
/// Type::Invalid, and any value that is not a known enumerator, is spelled
/// "<invalid>".
std::string typeToString(Type type) {
  if (type == Type::Integer) {
    return "int";
  }
  if (type == Type::Boolean) {
    return "bool";
  }
  return "<invalid>";
}
/// Creates a type checker that reports all diagnostics to @p errorlistener.
/// The loop nesting counter starts at zero, i.e. outside of any loop.
TypeCheckVisitor::TypeCheckVisitor(ErrorListener &errorlistener)
    : errorlistener(errorlistener), loopcount(0) {}
/// Type-checks a whole file in two passes.
/// Pass 1 records the signature (return type and parameter types) of every
/// function, so calls can be resolved regardless of definition order. A
/// repeated function name is reported via duplicateFunction and only the
/// first definition's signature is kept.
/// Pass 2 visits all children, which type-checks the function bodies.
std::any TypeCheckVisitor::visitFile(xlangParser::FileContext *ctx) {
  for (auto *fn : ctx->function()) {
    auto *nameToken = fn->Identifier()->getSymbol();
    auto fnName = nameToken->getText();
    Signature sig{std::any_cast<Type>(visitType(fn->type())), {}};
    const bool alreadyDeclared = signatures.find(fnName) != signatures.end();
    if (alreadyDeclared) {
      errorlistener.duplicateFunction(nameToken, fnName);
      continue;
    }
    auto *params = fn->parameterList();
    if (params != nullptr) {
      for (auto *paramType : params->type()) {
        sig.parametertypes.push_back(std::any_cast<Type>(visitType(paramType)));
      }
    }
    signatures.emplace(fnName, sig);
  }
  visitChildren(ctx);
  return {};
}
/// Type-checks one function definition.
/// The parameters are added to a scope that encloses the function body.
/// The declared return type is stored in `returntype` so that `return`
/// statements in the body can be validated against it.
std::any TypeCheckVisitor::visitFunction(xlangParser::FunctionContext *ctx) {
  scope.enter();
  auto *params = ctx->parameterList();
  if (params != nullptr) {
    const size_t count = params->Identifier().size();
    for (size_t idx = 0; idx < count; ++idx) {
      auto paramName = params->Identifier(idx)->getSymbol()->getText();
      auto paramType = std::any_cast<Type>(visitType(params->type(idx)));
      scope.add(paramName, paramType);
    }
  }
  returntype = std::any_cast<Type>(visitType(ctx->type()));
  visitBlock(ctx->block());
  scope.leave();
  return {};
}
/// Maps a type annotation to its Type: `TypeInteger` -> Type::Integer,
/// `TypeBoolean` -> Type::Boolean.
std::any TypeCheckVisitor::visitType(xlangParser::TypeContext *ctx) {
  if (ctx->TypeInteger()) {
    return Type::Integer;
  }
  if (ctx->TypeBoolean()) {
    return Type::Boolean;
  }
  // Unreachable for a well-formed parse tree: report an internal compiler
  // error. A value is still returned because falling off the end of a
  // value-returning function is undefined behavior unless compilerError is
  // guaranteed never to return.
  errorlistener.compilerError(__FILE__, __LINE__);
  return Type::Invalid;
}
/// Type-checks the statements of a block inside a fresh scope, which is
/// opened with scope.enter() and closed with scope.leave().
std::any TypeCheckVisitor::visitBlock(xlangParser::BlockContext *ctx) {
  scope.enter();
  visitChildren(ctx); // aggregated child result is not needed
  scope.leave();
  return std::any{};
}
/// Type-checks a single statement.
/// - if / while / for: the condition must be boolean. Each construct gets its
///   own scope, because its header values may contain variable definitions.
/// - break / continue: only valid while `loopcount` is non-zero.
/// - return: the value must match the enclosing function's `returntype`.
/// - anything else: the children are visited.
std::any TypeCheckVisitor::visitStatement(xlangParser::StatementContext *ctx) {
  // Reports a mismatch unless the given condition is boolean.
  auto checkCondition = [this](xlangParser::ConditionContext *condition) {
    auto type = std::any_cast<Type>(visitCondition(condition));
    if (type != Type::Boolean) {
      errorlistener.typeMismatch(condition->getStart(), Type::Boolean, type);
    }
  };
  if (ctx->If()) {
    scope.enter();
    checkCondition(ctx->condition());
    visitBlock(ctx->block(0));
    if (ctx->Else()) {
      visitBlock(ctx->block(1));
    }
    scope.leave();
    return {};
  }
  if (ctx->While() || ctx->For()) {
    scope.enter();
    if (ctx->For()) {
      visitValue(ctx->value(0));
    }
    checkCondition(ctx->condition());
    if (ctx->For()) {
      visitValue(ctx->value(1));
    }
    // Track loop nesting so break/continue inside the body are accepted.
    loopcount++;
    visitBlock(ctx->block(0));
    loopcount--;
    scope.leave();
    return {};
  }
  if (ctx->Break()) {
    if (!loopcount) {
      errorlistener.loopControlWithoutLoop(ctx->Break()->getSymbol());
    }
    return {};
  }
  if (ctx->Continue()) {
    if (!loopcount) {
      errorlistener.loopControlWithoutLoop(ctx->Continue()->getSymbol());
    }
    return {};
  }
  if (ctx->Return()) {
    auto value = ctx->value(0);
    auto type = std::any_cast<Type>(visitValue(value));
    if (type != returntype) {
      errorlistener.typeMismatch(value->getStart(), returntype, type);
    }
    // The value has been checked. Falling through to visitChildren would
    // visit it a second time: every diagnostic inside it would be duplicated,
    // and any definition inside it would be re-added and reported as
    // shadowed.
    return {};
  }
  visitChildren(ctx);
  return {};
}
/// Returns the Type of a value, which is either an integer expression (`expr`)
/// or a condition.
std::any TypeCheckVisitor::visitValue(xlangParser::ValueContext *ctx) {
  if (auto expr = ctx->expr()) {
    return visitExpr(expr);
  }
  if (auto condition = ctx->condition()) {
    return visitCondition(condition);
  }
  // Unreachable for a well-formed parse tree: report an internal compiler
  // error. A value is still returned because falling off the end of a
  // value-returning function is undefined behavior unless compilerError is
  // guaranteed never to return.
  errorlistener.compilerError(__FILE__, __LINE__);
  return Type::Invalid;
}
// Type-checks both operands of a binary operator against `expected`.
// The left operand context `left` is checked with the visitor method
// `visitLeft`, and the right operand `right` with `visitRight`. Each mismatch
// is reported at the offending operand's first token.
// Each argument expression is evaluated exactly once, left before right. The
// do/while(0) wrapper keeps the temporaries local, so the macro acts as a
// single statement and can be used more than once in a scope (write it with a
// trailing semicolon).
#define CHECKOPERATOR(visitLeft, left, visitRight, right, expected)            \
  do {                                                                         \
    const auto xlang_expected_ = (expected);                                   \
    auto *xlang_left_ = (left);                                                \
    auto xlang_lefttype_ = std::any_cast<Type>(visitLeft(xlang_left_));        \
    auto *xlang_right_ = (right);                                              \
    auto xlang_righttype_ = std::any_cast<Type>(visitRight(xlang_right_));     \
    if (xlang_lefttype_ != xlang_expected_) {                                  \
      errorlistener.typeMismatch(xlang_left_->getStart(), xlang_expected_,     \
                                 xlang_lefttype_);                             \
    }                                                                          \
    if (xlang_righttype_ != xlang_expected_) {                                 \
      errorlistener.typeMismatch(xlang_right_->getStart(), xlang_expected_,    \
                                 xlang_righttype_);                            \
    }                                                                          \
  } while (0)
/// Returns the Type of a condition.
/// A condition with a single boolean operand has that operand's type. A
/// compound condition (nested condition plus boolean) requires both sides to
/// be boolean and is itself boolean.
std::any TypeCheckVisitor::visitCondition(xlangParser::ConditionContext *ctx) {
  auto *lhs = ctx->condition();
  if (lhs == nullptr) {
    return visitBoolean(ctx->boolean());
  }
  CHECKOPERATOR(visitCondition, lhs, visitBoolean, ctx->boolean(),
                Type::Boolean);
  return Type::Boolean;
}
/// Returns the Type of a boolean-rule node. The alternatives are tried in
/// this order:
/// - a nested boolean operand, which must be boolean;
/// - the literals True / False;
/// - two integer expressions, whose result is boolean;
/// - a variable, which has its own type;
/// - otherwise a nested condition.
std::any TypeCheckVisitor::visitBoolean(xlangParser::BooleanContext *ctx) {
  if (auto *operand = ctx->boolean()) {
    const auto operandType = std::any_cast<Type>(visitBoolean(operand));
    if (operandType != Type::Boolean) {
      errorlistener.typeMismatch(operand->getStart(), Type::Boolean,
                                 operandType);
    }
    return Type::Boolean;
  }
  const bool isLiteral = ctx->True() != nullptr || ctx->False() != nullptr;
  if (isLiteral) {
    return Type::Boolean;
  }
  if (!ctx->expr().empty()) {
    CHECKOPERATOR(visitExpr, ctx->expr(0), visitExpr, ctx->expr(1),
                  Type::Integer);
    return Type::Boolean;
  }
  if (auto *var = ctx->variable()) {
    return visitVariable(var);
  }
  return visitCondition(ctx->condition());
}
/// Returns the Type of an expression. A lone term has that term's type. An
/// expression combined with a term requires both to be integers and yields
/// an integer.
std::any TypeCheckVisitor::visitExpr(xlangParser::ExprContext *ctx) {
  auto *lhs = ctx->expr();
  if (lhs == nullptr) {
    return visitTerm(ctx->term());
  }
  CHECKOPERATOR(visitExpr, lhs, visitTerm, ctx->term(), Type::Integer);
  return Type::Integer;
}
/// Returns the Type of a term. A lone factor has that factor's type. A term
/// combined with a factor requires both to be integers and yields an integer.
std::any TypeCheckVisitor::visitTerm(xlangParser::TermContext *ctx) {
  auto *lhs = ctx->term();
  if (lhs == nullptr) {
    return visitFactor(ctx->factor());
  }
  CHECKOPERATOR(visitTerm, lhs, visitFactor, ctx->factor(), Type::Integer);
  return Type::Integer;
}
/// Returns the Type of a factor. The alternatives are tried in this order:
/// - a nested factor, which must be an integer;
/// - an Integer literal;
/// - a variable, which has its own type;
/// - otherwise a nested expression.
std::any TypeCheckVisitor::visitFactor(xlangParser::FactorContext *ctx) {
  if (auto *inner = ctx->factor()) {
    const auto innerType = std::any_cast<Type>(visitFactor(inner));
    if (innerType != Type::Integer) {
      errorlistener.typeMismatch(inner->getStart(), Type::Integer, innerType);
    }
    return Type::Integer;
  }
  if (ctx->Integer() != nullptr) {
    return Type::Integer;
  }
  if (auto *var = ctx->variable()) {
    return visitVariable(var);
  }
  return visitExpr(ctx->expr());
}
/// Type-checks an identifier-headed node and returns its Type.
/// - Define: the value is checked first, then the name is added to the
///   current scope. A name that is already visible is reported as shadowed.
/// - Assign: the value must match the variable's existing type.
/// - Identifier with an `expr`: both the variable and the expression must be
///   integers. The result is an integer.
/// - LeftParen (function call): arity and argument types are checked against
///   the recorded signature, and the signature's return type is returned.
/// - otherwise: a plain variable read.
/// Unknown variables and functions are reported and yield Type::Invalid, or,
/// for an assignment to an unknown variable, the value's type.
std::any TypeCheckVisitor::visitVariable(xlangParser::VariableContext *ctx) {
  auto token = ctx->Identifier()->getSymbol();
  auto name = token->getText();
  if (ctx->Define()) {
    auto type = std::any_cast<Type>(visitValue(ctx->value()));
    if (scope.get(name)) {
      errorlistener.shadowedVariable(token, name);
    }
    scope.add(name, type);
    return type;
  }
  if (ctx->Assign()) {
    auto type = std::any_cast<Type>(visitValue(ctx->value()));
    if (auto expected = scope.get(name)) {
      if (type != *expected) {
        errorlistener.typeMismatch(token, *expected, type);
      }
      return *expected;
    }
    errorlistener.unknownVariable(token, name);
    return type;
  }
  if (auto expr = ctx->expr()) {
    auto righttype = std::any_cast<Type>(visitExpr(expr));
    if (auto lefttype = scope.get(name)) {
      if (*lefttype != Type::Integer) {
        errorlistener.typeMismatch(token, Type::Integer, *lefttype);
      }
    } else {
      errorlistener.unknownVariable(token, name);
    }
    if (righttype != Type::Integer) {
      errorlistener.typeMismatch(token, Type::Integer, righttype);
    }
    return Type::Integer;
  }
  if (ctx->LeftParen()) {
    auto arg_list = ctx->argumentList();
    const size_t arg_num = arg_list ? arg_list->value().size() : 0;
    auto it = signatures.find(name);
    if (it == signatures.end()) {
      errorlistener.unknownFunction(token, name);
      // Still visit the arguments so that errors and definitions inside them
      // are not silently dropped.
      for (size_t i = 0; i < arg_num; i++) {
        visitValue(arg_list->value(i));
      }
      return Type::Invalid;
    }
    // Bind by reference: visiting arguments never modifies `signatures`.
    const auto &signature = it->second;
    const size_t arity = signature.parametertypes.size();
    if (arity != arg_num) {
      errorlistener.wrongArgumentNumber(token, name, arity, arg_num);
    }
    // Visit every argument, including any beyond the arity. Types can only be
    // compared for arguments that have a matching parameter.
    for (size_t i = 0; i < arg_num; i++) {
      auto arg = arg_list->value(i);
      auto type = std::any_cast<Type>(visitValue(arg));
      if (i < arity && type != signature.parametertypes[i]) {
        errorlistener.typeMismatch(arg->getStart(), signature.parametertypes[i],
                                   type);
      }
    }
    return signature.returntype;
  }
  if (auto type = scope.get(name)) {
    return *type;
  }
  errorlistener.unknownVariable(token, name);
  return Type::Invalid;
}
} // namespace xlang