From f169088f36d49d5aba0dc09430689c056f0cd1e9 Mon Sep 17 00:00:00 2001 From: Thomas Lindner Date: Sat, 7 Jan 2023 08:54:53 +0100 Subject: [PATCH] type checking --- bootstrap/error.cc | 60 +++++++++-- bootstrap/error.hh | 13 ++- bootstrap/scope.hh | 42 ++++++++ bootstrap/typecheck.cc | 235 +++++++++++++++++++++++++++++++++-------- bootstrap/typecheck.hh | 34 ++++-- bootstrap/xlang.g4 | 14 ++- test/42.x | 2 +- test/fib.x | 16 +-- 8 files changed, 346 insertions(+), 70 deletions(-) create mode 100644 bootstrap/scope.hh diff --git a/bootstrap/error.cc b/bootstrap/error.cc index a594bbc..ebc6083 100644 --- a/bootstrap/error.cc +++ b/bootstrap/error.cc @@ -1,8 +1,16 @@ #include +#include #include namespace xlang { +void ErrorListener::printError(size_t line, size_t charPositionInLine, + std::string_view msg) { + std::cerr << file << ":" << line << ":" << charPositionInLine + 1 << ": " + << msg << std::endl; + has_error = true; +} + ErrorListener::ErrorListener(std::string_view inputfile) : file{inputfile}, has_error{false} {} @@ -10,21 +18,57 @@ bool ErrorListener::hasError() { return has_error; } +void ErrorListener::compilerError(const std::string &file, size_t line) { + std::cerr << "compiler bug in " << file << ":" << line << std::endl; + std::terminate(); +} + +void ErrorListener::duplicateFunction(antlr4::Token *token, + const std::string &name) { + printError(token->getLine(), token->getCharPositionInLine(), + "duplicate function '" + name + "'"); +} + +void ErrorListener::shadowedVariable(antlr4::Token *token, + const std::string &name) { + printError(token->getLine(), token->getCharPositionInLine(), + "definition of variable '" + name + + "' shadows previously defined variable"); +} + void ErrorListener::syntaxError([[maybe_unused]] antlr4::Recognizer *recognizer, [[maybe_unused]] antlr4::Token *offendingSymbol, size_t line, size_t charPositionInLine, const std::string &msg, [[maybe_unused]] std::exception_ptr e) { - std::cerr << file << ":" << line << ":" << charPositionInLine + 1 << ": " - << msg << std::endl; - has_error = true; + printError(line, charPositionInLine, msg); } -void ErrorListener::typeError(size_t line, size_t charPositionInLine, - std::string_view msg) { - std::cerr << file << ":" << line << ":" << charPositionInLine + 1 << ": " - << msg << std::endl; - has_error = true; +void ErrorListener::typeMismatch(antlr4::Token *token, Type expected, + Type actual) { + printError(token->getLine(), token->getCharPositionInLine(), + "expected type '" + typeToString(expected) + "', but got '" + + typeToString(actual) + "'"); +} + +void ErrorListener::unknownFunction(antlr4::Token *token, + const std::string &name) { + printError(token->getLine(), token->getCharPositionInLine(), + "unknown function '" + name + "'"); +} + +void ErrorListener::unknownVariable(antlr4::Token *token, + const std::string &name) { + printError(token->getLine(), token->getCharPositionInLine(), + "unknown variable '" + name + "'"); +} + +void ErrorListener::wrongArgumentNumber(antlr4::Token *token, + const std::string &name, + size_t expected, size_t actual) { + printError(token->getLine(), token->getCharPositionInLine(), + "function '" + name + "' expects " + std::to_string(expected) + + " arguments, but got " + std::to_string(actual)); } } // namespace xlang diff --git a/bootstrap/error.hh b/bootstrap/error.hh index 6208ef4..c72eac6 100644 --- a/bootstrap/error.hh +++ b/bootstrap/error.hh @@ -2,6 +2,7 @@ #include #include +#include namespace xlang { @@ -9,15 +10,25 @@ class ErrorListener : public antlr4::BaseErrorListener { std::string_view file; bool has_error; + void printError(size_t line, size_t charPositionInLine, std::string_view msg); + public: ErrorListener(std::string_view inputfile); bool hasError(); + + void compilerError(const std::string &file, size_t line); + void duplicateFunction(antlr4::Token *token, const std::string &name); + void shadowedVariable(antlr4::Token *token, const std::string &name); void syntaxError(antlr4::Recognizer *recognizer, antlr4::Token *offendingSymbol, size_t line, size_t charPositionInLine, const std::string &msg, std::exception_ptr e) override; - void typeError(size_t line, size_t charPositionInLine, std::string_view msg); + void typeMismatch(antlr4::Token *token, Type expected, Type actual); + void unknownFunction(antlr4::Token *token, const std::string &name); + void unknownVariable(antlr4::Token *token, const std::string &name); + void wrongArgumentNumber(antlr4::Token *token, const std::string &name, + size_t expected, size_t actual); }; } // namespace xlang diff --git a/bootstrap/scope.hh b/bootstrap/scope.hh new file mode 100644 index 0000000..9fbab0c --- /dev/null +++ b/bootstrap/scope.hh @@ -0,0 +1,42 @@ +#pragma once + +#include +#include +#include +#include + +namespace xlang { + +template +class Scope { + std::vector> stack; + + public: + void enter() { + stack.emplace_back(); + } + + void leave() { + stack.pop_back(); + } + + void add(const std::string &name, T v) { + stack.back().emplace(name, v); + } + + std::optional get(const std::string &name) { + for (auto scope = stack.end() - 1;; scope--) { + auto it = scope->find(name); + + if (it != scope->end()) { + return it->second; + } + if (scope == stack.begin()) { + break; + } + } + return {}; + } +}; + +} // namespace xlang diff --git a/bootstrap/typecheck.cc b/bootstrap/typecheck.cc index 9f352ba..8f10b08 100644 --- a/bootstrap/typecheck.cc +++ b/bootstrap/typecheck.cc @@ -1,17 +1,16 @@ +#include #include namespace xlang { -bool TypeCheckVisitor::inScope(const std::string &name) { - for (auto it = scope.end() - 1;; it--) { - if (it->find(name) != it->end()) { - return true; - } - if (it == scope.begin()) { - break; - } +std::string typeToString(Type type) { + switch (type) { + case Type::Integer: + return "int"; + case Type::Boolean: + return "bool"; } - return false; + return ""; } TypeCheckVisitor::TypeCheckVisitor(ErrorListener &errorlistener) @@ -21,91 +20,243 @@ std::any TypeCheckVisitor::visitFile(xlangParser::FileContext *ctx) { for (auto function : ctx->function()) { auto token = function->Identifier()->getSymbol(); auto name = token->getText(); + Signature signature{std::any_cast(visitType(function->type())), {}}; - if (function_arity.find(name) != function_arity.end()) { - errorlistener.typeError(token->getLine(), token->getCharPositionInLine(), - "duplicate function '" + name + "'"); + if (signatures.find(name) != signatures.end()) { + errorlistener.duplicateFunction(token, name); continue; } if (auto param_list = function->parameterList()) { - function_arity[name] = param_list->Identifier().size(); - } else { - function_arity[name] = 0; + for (auto type : param_list->type()) { + signature.parametertypes.push_back( + std::any_cast(visitType(type))); + } } + signatures.emplace(name, signature); } visitChildren(ctx); return {}; } std::any TypeCheckVisitor::visitFunction(xlangParser::FunctionContext *ctx) { - scope.emplace_back(); + scope.enter(); if (auto param_list = ctx->parameterList()) { - for (const auto param : param_list->Identifier()) { - scope.back().emplace(param->getSymbol()->getText()); + for (size_t i = 0, n = param_list->Identifier().size(); i < n; i++) { + auto name = param_list->Identifier(i)->getSymbol()->getText(); + auto type = std::any_cast(visitType(param_list->type(i))); + + scope.add(name, type); } } + returntype = std::any_cast(visitType(ctx->type())); visitBlock(ctx->block()); - scope.pop_back(); + scope.leave(); + return {}; +} + +std::any TypeCheckVisitor::visitType(xlangParser::TypeContext *ctx) { + if (ctx->TypeInteger()) { + return Type::Integer; + } + if (ctx->TypeBoolean()) { + return Type::Boolean; + } + // unreachable + errorlistener.compilerError(__FILE__, __LINE__); return {}; } std::any TypeCheckVisitor::visitBlock(xlangParser::BlockContext *ctx) { - scope.emplace_back(); + scope.enter(); visitChildren(ctx); - scope.pop_back(); + scope.leave(); return {}; } std::any TypeCheckVisitor::visitStatement(xlangParser::StatementContext *ctx) { if (auto identifier = ctx->Identifier()) { - visitValue(ctx->value()); - scope.back().emplace(identifier->getSymbol()->getText()); + auto name = identifier->getSymbol()->getText(); + auto type = std::any_cast(visitValue(ctx->value())); + + if (ctx->Define()) { + if (scope.get(name)) { + errorlistener.shadowedVariable(identifier->getSymbol(), name); + } + scope.add(name, type); + return {}; + } + if (ctx->Assign()) { + if (auto expected = scope.get(name)) { + if (type != *expected) { + errorlistener.typeMismatch(identifier->getSymbol(), *expected, type); + } + } else { + errorlistener.unknownVariable(identifier->getSymbol(), name); + } + return {}; + } + // unreachable + errorlistener.compilerError(__FILE__, __LINE__); return {}; } + if (auto condition = ctx->condition()) { + auto type = std::any_cast(visitCondition(condition)); + + if (type != Type::Boolean) { + errorlistener.typeMismatch(condition->getStart(), Type::Boolean, type); + } + visitBlock(ctx->block(0)); + if (ctx->Else()) { + visitBlock(ctx->block(1)); + } + return {}; + } + if (ctx->Return()) { + auto type = std::any_cast(visitValue(ctx->value())); + + if (type != returntype) { + errorlistener.typeMismatch(ctx->value()->getStart(), returntype, type); + } + } visitChildren(ctx); return {}; } +std::any TypeCheckVisitor::visitValue(xlangParser::ValueContext *ctx) { + if (auto expr = ctx->expr()) { + return visitExpr(expr); + } + if (auto condition = ctx->condition()) { + return visitCondition(condition); + } + // unreachable + errorlistener.compilerError(__FILE__, __LINE__); + return {}; +} + +#define CHECKOPERATOR(visitLeft, left, visitRight, right, expected) \ + auto lefttype = std::any_cast(visitLeft(left)); \ + auto righttype = std::any_cast(visitRight(right)); \ + \ + if (lefttype != expected) { \ + errorlistener.typeMismatch(left->getStart(), expected, lefttype); \ + } \ + if (righttype != expected) { \ + errorlistener.typeMismatch(right->getStart(), expected, righttype); \ + } + +std::any TypeCheckVisitor::visitCondition(xlangParser::ConditionContext *ctx) { + if (auto condition = ctx->condition()) { + CHECKOPERATOR(visitCondition, condition, visitBoolean, ctx->boolean(), + Type::Boolean); + return Type::Boolean; + } + return visitBoolean(ctx->boolean()); +} + +std::any TypeCheckVisitor::visitBoolean(xlangParser::BooleanContext *ctx) { + if (auto boolean = ctx->boolean()) { + auto type = std::any_cast(visitBoolean(boolean)); + + if (type != Type::Boolean) { + errorlistener.typeMismatch(boolean->getStart(), Type::Boolean, type); + } + return Type::Boolean; + } + if (ctx->True() || ctx->False()) { + return Type::Boolean; + } + if (ctx->expr().size()) { + CHECKOPERATOR(visitExpr, ctx->expr(0), visitExpr, ctx->expr(1), + Type::Integer); + return Type::Boolean; + } + if (auto variable = ctx->variable()) { + return visitVariable(variable); + } + return visitCondition(ctx->condition()); +} + +std::any TypeCheckVisitor::visitExpr(xlangParser::ExprContext *ctx) { + if (auto expr = ctx->expr()) { + CHECKOPERATOR(visitExpr, expr, visitTerm, ctx->term(), Type::Integer); + return Type::Integer; + } + return visitTerm(ctx->term()); +} + +std::any TypeCheckVisitor::visitTerm(xlangParser::TermContext *ctx) { + if (auto term = ctx->term()) { + CHECKOPERATOR(visitTerm, term, visitFactor, ctx->factor(), Type::Integer); + return Type::Integer; + } + return visitFactor(ctx->factor()); +} + +std::any TypeCheckVisitor::visitFactor(xlangParser::FactorContext *ctx) { + if (auto factor = ctx->factor()) { + auto type = std::any_cast(visitFactor(factor)); + + if (type != Type::Integer) { + errorlistener.typeMismatch(factor->getStart(), Type::Integer, type); + } + return Type::Integer; + } + if (ctx->Integer()) { + return Type::Integer; + } + if (auto variable = ctx->variable()) { + return visitVariable(variable); + } + return visitExpr(ctx->expr()); +} + std::any TypeCheckVisitor::visitVariable(xlangParser::VariableContext *ctx) { auto token = ctx->Identifier()->getSymbol(); auto name = token->getText(); - auto line = token->getLine(); - auto charPositionInLine = token->getCharPositionInLine(); if (ctx->LeftParen()) { - auto it = function_arity.find(name); + auto it = signatures.find(name); - if (it == function_arity.end()) { - errorlistener.typeError(line, charPositionInLine, - "unknown function '" + name + "'"); + if (it == signatures.end()) { + errorlistener.unknownFunction(token, name); + return Type::Integer; } else { - auto arity = it->second; + auto signature = it->second; + auto arity = signature.parametertypes.size(); if (auto arg_list = ctx->argumentList()) { auto arg_num = arg_list->value().size(); if (arity != arg_num) { - errorlistener.typeError( - line, charPositionInLine, - "function '" + name + "' expects " + std::to_string(arity) + - " arguments, but got " + std::to_string(arg_num)); + errorlistener.wrongArgumentNumber(token, name, arity, arg_num); + } + for (size_t i = 0; i < arg_num && i < arity; i++) { + auto type = std::any_cast(visitValue(arg_list->value(i))); + auto token = arg_list->value(i)->getStart(); + + if (type != signature.parametertypes[i]) { + errorlistener.typeMismatch(token, signature.parametertypes[i], + type); + } } } else { if (arity) { - errorlistener.typeError(line, charPositionInLine, - "function '" + name + "' expects " + - std::to_string(arity) + - " arguments, but got none"); + errorlistener.wrongArgumentNumber(token, name, arity, 0); } } + return signature.returntype; } - visitChildren(ctx); } else { - if (!inScope(name)) { - errorlistener.typeError(line, charPositionInLine, - "variable '" + name + "' is not in scope"); + auto type = scope.get(name); + if (!type) { + errorlistener.unknownVariable(token, name); + return Type::Integer; } + return *type; } + // unreachable + errorlistener.compilerError(__FILE__, __LINE__); return {}; } diff --git a/bootstrap/typecheck.hh b/bootstrap/typecheck.hh index cf84daf..9dfffee 100644 --- a/bootstrap/typecheck.hh +++ b/bootstrap/typecheck.hh @@ -1,28 +1,48 @@ #pragma once -#include +#include #include #include -#include #include #include namespace xlang { -class TypeCheckVisitor : public xlangBaseVisitor { - ErrorListener &errorlistener; - std::unordered_map function_arity; - std::vector> scope; +class ErrorListener; - bool inScope(const std::string &name); +enum class Type +{ + Integer, + Boolean, +}; + +std::string typeToString(Type type); + +class TypeCheckVisitor : public xlangBaseVisitor { + struct Signature { + Type returntype; + std::vector parametertypes; + }; + + ErrorListener &errorlistener; + std::unordered_map signatures; + Type returntype; + Scope scope; public: TypeCheckVisitor(ErrorListener &errorlistener); std::any visitFile(xlangParser::FileContext *ctx) override; std::any visitFunction(xlangParser::FunctionContext *ctx) override; + std::any visitType(xlangParser::TypeContext *ctx) override; std::any visitBlock(xlangParser::BlockContext *ctx) override; std::any visitStatement(xlangParser::StatementContext *ctx) override; + std::any visitValue(xlangParser::ValueContext *ctx) override; + std::any visitCondition(xlangParser::ConditionContext *ctx) override; + std::any visitBoolean(xlangParser::BooleanContext *ctx) override; + std::any visitExpr(xlangParser::ExprContext *ctx) override; + std::any visitTerm(xlangParser::TermContext *ctx) override; + std::any visitFactor(xlangParser::FactorContext *ctx) override; std::any visitVariable(xlangParser::VariableContext *ctx) override; }; diff --git a/bootstrap/xlang.g4 b/bootstrap/xlang.g4 index 2370060..50618ad 100644 --- a/bootstrap/xlang.g4 +++ b/bootstrap/xlang.g4 @@ -1,10 +1,14 @@ grammar xlang; file : function+ EOF; -function : Identifier LeftParen parameterList? RightParen block; -parameterList : Identifier (Comma Identifier)*; +function : Identifier LeftParen parameterList? RightParen Colon type block; +parameterList : Identifier Colon type (Comma Identifier Colon type)*; +type : TypeInteger + | TypeBoolean + ; block : LeftBrace statement* RightBrace; -statement : Identifier Assign value Semicolon +statement : Identifier Define value Semicolon + | Identifier Assign value Semicolon | If condition block (Else block)? | While condition block | Return value Semicolon @@ -40,6 +44,8 @@ variable : Identifier ; argumentList : value (Comma value)*; +TypeInteger : 'int'; +TypeBoolean : 'bool'; If : 'if'; Else : 'else'; While : 'while'; @@ -54,8 +60,10 @@ False : 'false'; LeftParen : '('; RightParen : ')'; +Colon : ':'; LeftBrace : '{'; RightBrace : '}'; +Define : ':='; Assign : '='; Less : '<'; LessEqual : '<='; diff --git a/test/42.x b/test/42.x index fd51ea0..cd4223d 100644 --- a/test/42.x +++ b/test/42.x @@ -1,3 +1,3 @@ -main() { +main() : int { print 42; } diff --git a/test/fib.x b/test/fib.x index fb2690b..e097db3 100644 --- a/test/fib.x +++ b/test/fib.x @@ -1,5 +1,5 @@ -main() { - i = 0; +main() : int { + i := 0; while i < 10 { print fib_rec(i); print fib_iter(i); @@ -7,19 +7,19 @@ main() { } } -fib_rec(n) { +fib_rec(n : int) : int { if n < 2 { return 1; } return fib_rec(n - 1) + fib_rec(n - 2); } -fib_iter(n) { - x0 = 1; - x1 = 1; - i = 0; +fib_iter(n : int) : int { + x0 := 1; + x1 := 1; + i := 0; while i < n { - t = x0 + x1; + t := x0 + x1; x0 = x1; x1 = t; i = i + 1;