type checking

development
Thomas Lindner 2023-01-07 08:54:53 +01:00
parent 4f3d1d2f2e
commit f169088f36
8 changed files with 346 additions and 70 deletions

View File

@ -1,8 +1,16 @@
#include <error.hh> #include <error.hh>
#include <exception>
#include <iostream> #include <iostream>
namespace xlang { namespace xlang {
void ErrorListener::printError(size_t line, size_t charPositionInLine,
std::string_view msg) {
std::cerr << file << ":" << line << ":" << charPositionInLine + 1 << ": "
<< msg << std::endl;
has_error = true;
}
ErrorListener::ErrorListener(std::string_view inputfile) ErrorListener::ErrorListener(std::string_view inputfile)
: file{inputfile}, has_error{false} {} : file{inputfile}, has_error{false} {}
@ -10,21 +18,57 @@ bool ErrorListener::hasError() {
return has_error; return has_error;
} }
void ErrorListener::compilerError(const std::string &file, size_t line) {
std::cerr << "compiler bug in " << file << ":" << line << std::endl;
std::terminate();
}
void ErrorListener::duplicateFunction(antlr4::Token *token,
const std::string &name) {
printError(token->getLine(), token->getCharPositionInLine(),
"duplicate function '" + name + "'");
}
void ErrorListener::shadowedVariable(antlr4::Token *token,
const std::string &name) {
printError(token->getLine(), token->getCharPositionInLine(),
"definition of variable '" + name +
"' shadows previously defined variable");
}
void ErrorListener::syntaxError([[maybe_unused]] antlr4::Recognizer *recognizer, void ErrorListener::syntaxError([[maybe_unused]] antlr4::Recognizer *recognizer,
[[maybe_unused]] antlr4::Token *offendingSymbol, [[maybe_unused]] antlr4::Token *offendingSymbol,
size_t line, size_t charPositionInLine, size_t line, size_t charPositionInLine,
const std::string &msg, const std::string &msg,
[[maybe_unused]] std::exception_ptr e) { [[maybe_unused]] std::exception_ptr e) {
std::cerr << file << ":" << line << ":" << charPositionInLine + 1 << ": " printError(line, charPositionInLine, msg);
<< msg << std::endl;
has_error = true;
} }
void ErrorListener::typeError(size_t line, size_t charPositionInLine, void ErrorListener::typeMismatch(antlr4::Token *token, Type expected,
std::string_view msg) { Type actual) {
std::cerr << file << ":" << line << ":" << charPositionInLine + 1 << ": " printError(token->getLine(), token->getCharPositionInLine(),
<< msg << std::endl; "expected type '" + typeToString(expected) + "', but got '" +
has_error = true; typeToString(actual) + "'");
}
void ErrorListener::unknownFunction(antlr4::Token *token,
const std::string &name) {
printError(token->getLine(), token->getCharPositionInLine(),
"unknown function '" + name + "'");
}
void ErrorListener::unknownVariable(antlr4::Token *token,
const std::string &name) {
printError(token->getLine(), token->getCharPositionInLine(),
"unknown variable '" + name + "'");
}
void ErrorListener::wrongArgumentNumber(antlr4::Token *token,
const std::string &name,
size_t expected, size_t actual) {
printError(token->getLine(), token->getCharPositionInLine(),
"function '" + name + "' expects " + std::to_string(expected) +
" arguments, but got " + std::to_string(actual));
} }
} // namespace xlang } // namespace xlang

View File

@ -2,6 +2,7 @@
#include <BaseErrorListener.h> #include <BaseErrorListener.h>
#include <string_view> #include <string_view>
#include <typecheck.hh>
namespace xlang { namespace xlang {
@ -9,15 +10,25 @@ class ErrorListener : public antlr4::BaseErrorListener {
std::string_view file; std::string_view file;
bool has_error; bool has_error;
void printError(size_t line, size_t charPositionInLine, std::string_view msg);
public: public:
ErrorListener(std::string_view inputfile); ErrorListener(std::string_view inputfile);
bool hasError(); bool hasError();
void compilerError(const std::string &file, size_t line);
void duplicateFunction(antlr4::Token *token, const std::string &name);
void shadowedVariable(antlr4::Token *token, const std::string &name);
void syntaxError(antlr4::Recognizer *recognizer, void syntaxError(antlr4::Recognizer *recognizer,
antlr4::Token *offendingSymbol, size_t line, antlr4::Token *offendingSymbol, size_t line,
size_t charPositionInLine, const std::string &msg, size_t charPositionInLine, const std::string &msg,
std::exception_ptr e) override; std::exception_ptr e) override;
void typeError(size_t line, size_t charPositionInLine, std::string_view msg); void typeMismatch(antlr4::Token *token, Type expected, Type actual);
void unknownFunction(antlr4::Token *token, const std::string &name);
void unknownVariable(antlr4::Token *token, const std::string &name);
void wrongArgumentNumber(antlr4::Token *token, const std::string &name,
size_t expected, size_t actual);
}; };
} // namespace xlang } // namespace xlang

42
bootstrap/scope.hh Normal file
View File

@ -0,0 +1,42 @@
#pragma once
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>
namespace xlang {
template <typename T>
class Scope {
std::vector<std::unordered_map<std::string, T>> stack;
public:
void enter() {
stack.emplace_back();
}
void leave() {
stack.pop_back();
}
void add(const std::string &name, T v) {
stack.back().emplace(name, v);
}
std::optional<T> get(const std::string &name) {
for (auto scope = stack.end() - 1;; scope--) {
auto it = scope->find(name);
if (it != scope->end()) {
return it->second;
}
if (scope == stack.begin()) {
break;
}
}
return {};
}
};
} // namespace xlang

View File

@ -1,17 +1,16 @@
#include <error.hh>
#include <typecheck.hh> #include <typecheck.hh>
namespace xlang { namespace xlang {
bool TypeCheckVisitor::inScope(const std::string &name) { std::string typeToString(Type type) {
for (auto it = scope.end() - 1;; it--) { switch (type) {
if (it->find(name) != it->end()) { case Type::Integer:
return true; return "int";
} case Type::Boolean:
if (it == scope.begin()) { return "bool";
break;
}
} }
return false; return "<invalid>";
} }
TypeCheckVisitor::TypeCheckVisitor(ErrorListener &errorlistener) TypeCheckVisitor::TypeCheckVisitor(ErrorListener &errorlistener)
@ -21,91 +20,243 @@ std::any TypeCheckVisitor::visitFile(xlangParser::FileContext *ctx) {
for (auto function : ctx->function()) { for (auto function : ctx->function()) {
auto token = function->Identifier()->getSymbol(); auto token = function->Identifier()->getSymbol();
auto name = token->getText(); auto name = token->getText();
Signature signature{std::any_cast<Type>(visitType(function->type())), {}};
if (function_arity.find(name) != function_arity.end()) { if (signatures.find(name) != signatures.end()) {
errorlistener.typeError(token->getLine(), token->getCharPositionInLine(), errorlistener.duplicateFunction(token, name);
"duplicate function '" + name + "'");
continue; continue;
} }
if (auto param_list = function->parameterList()) { if (auto param_list = function->parameterList()) {
function_arity[name] = param_list->Identifier().size(); for (auto type : param_list->type()) {
} else { signature.parametertypes.push_back(
function_arity[name] = 0; std::any_cast<Type>(visitType(type)));
}
} }
signatures.emplace(name, signature);
} }
visitChildren(ctx); visitChildren(ctx);
return {}; return {};
} }
std::any TypeCheckVisitor::visitFunction(xlangParser::FunctionContext *ctx) { std::any TypeCheckVisitor::visitFunction(xlangParser::FunctionContext *ctx) {
scope.emplace_back(); scope.enter();
if (auto param_list = ctx->parameterList()) { if (auto param_list = ctx->parameterList()) {
for (const auto param : param_list->Identifier()) { for (size_t i = 0, n = param_list->Identifier().size(); i < n; i++) {
scope.back().emplace(param->getSymbol()->getText()); auto name = param_list->Identifier(i)->getSymbol()->getText();
auto type = std::any_cast<Type>(visitType(param_list->type(i)));
scope.add(name, type);
} }
} }
returntype = std::any_cast<Type>(visitType(ctx->type()));
visitBlock(ctx->block()); visitBlock(ctx->block());
scope.pop_back(); scope.leave();
return {};
}
std::any TypeCheckVisitor::visitType(xlangParser::TypeContext *ctx) {
if (ctx->TypeInteger()) {
return Type::Integer;
}
if (ctx->TypeBoolean()) {
return Type::Boolean;
}
// unreachable
errorlistener.compilerError(__FILE__, __LINE__);
return {}; return {};
} }
std::any TypeCheckVisitor::visitBlock(xlangParser::BlockContext *ctx) { std::any TypeCheckVisitor::visitBlock(xlangParser::BlockContext *ctx) {
scope.emplace_back(); scope.enter();
visitChildren(ctx); visitChildren(ctx);
scope.pop_back(); scope.leave();
return {}; return {};
} }
std::any TypeCheckVisitor::visitStatement(xlangParser::StatementContext *ctx) { std::any TypeCheckVisitor::visitStatement(xlangParser::StatementContext *ctx) {
if (auto identifier = ctx->Identifier()) { if (auto identifier = ctx->Identifier()) {
visitValue(ctx->value()); auto name = identifier->getSymbol()->getText();
scope.back().emplace(identifier->getSymbol()->getText()); auto type = std::any_cast<Type>(visitValue(ctx->value()));
if (ctx->Define()) {
if (scope.get(name)) {
errorlistener.shadowedVariable(identifier->getSymbol(), name);
}
scope.add(name, type);
return {};
}
if (ctx->Assign()) {
if (auto expected = scope.get(name)) {
if (type != *expected) {
errorlistener.typeMismatch(identifier->getSymbol(), *expected, type);
}
} else {
errorlistener.unknownVariable(identifier->getSymbol(), name);
}
return {};
}
// unreachable
errorlistener.compilerError(__FILE__, __LINE__);
return {}; return {};
} }
if (auto condition = ctx->condition()) {
auto type = std::any_cast<Type>(visitCondition(condition));
if (type != Type::Boolean) {
errorlistener.typeMismatch(condition->getStart(), Type::Boolean, type);
}
visitBlock(ctx->block(0));
if (ctx->Else()) {
visitBlock(ctx->block(1));
}
return {};
}
if (ctx->Return()) {
auto type = std::any_cast<Type>(visitValue(ctx->value()));
if (type != returntype) {
errorlistener.typeMismatch(ctx->value()->getStart(), returntype, type);
}
}
visitChildren(ctx); visitChildren(ctx);
return {}; return {};
} }
std::any TypeCheckVisitor::visitValue(xlangParser::ValueContext *ctx) {
if (auto expr = ctx->expr()) {
return visitExpr(expr);
}
if (auto condition = ctx->condition()) {
return visitCondition(condition);
}
// unreachable
errorlistener.compilerError(__FILE__, __LINE__);
return {};
}
#define CHECKOPERATOR(visitLeft, left, visitRight, right, expected) \
auto lefttype = std::any_cast<Type>(visitLeft(left)); \
auto righttype = std::any_cast<Type>(visitRight(right)); \
\
if (lefttype != expected) { \
errorlistener.typeMismatch(left->getStart(), expected, lefttype); \
} \
if (righttype != expected) { \
errorlistener.typeMismatch(right->getStart(), expected, righttype); \
}
std::any TypeCheckVisitor::visitCondition(xlangParser::ConditionContext *ctx) {
if (auto condition = ctx->condition()) {
CHECKOPERATOR(visitCondition, condition, visitBoolean, ctx->boolean(),
Type::Boolean);
return Type::Boolean;
}
return visitBoolean(ctx->boolean());
}
std::any TypeCheckVisitor::visitBoolean(xlangParser::BooleanContext *ctx) {
if (auto boolean = ctx->boolean()) {
auto type = std::any_cast<Type>(visitBoolean(boolean));
if (type != Type::Boolean) {
errorlistener.typeMismatch(boolean->getStart(), Type::Boolean, type);
}
return Type::Boolean;
}
if (ctx->True() || ctx->False()) {
return Type::Boolean;
}
if (ctx->expr().size()) {
CHECKOPERATOR(visitExpr, ctx->expr(0), visitExpr, ctx->expr(1),
Type::Integer);
return Type::Boolean;
}
if (auto variable = ctx->variable()) {
return visitVariable(variable);
}
return visitCondition(ctx->condition());
}
std::any TypeCheckVisitor::visitExpr(xlangParser::ExprContext *ctx) {
if (auto expr = ctx->expr()) {
CHECKOPERATOR(visitExpr, expr, visitTerm, ctx->term(), Type::Integer);
return Type::Integer;
}
return visitTerm(ctx->term());
}
std::any TypeCheckVisitor::visitTerm(xlangParser::TermContext *ctx) {
if (auto term = ctx->term()) {
CHECKOPERATOR(visitTerm, term, visitFactor, ctx->factor(), Type::Integer);
return Type::Integer;
}
return visitFactor(ctx->factor());
}
std::any TypeCheckVisitor::visitFactor(xlangParser::FactorContext *ctx) {
if (auto factor = ctx->factor()) {
auto type = std::any_cast<Type>(visitFactor(factor));
if (type != Type::Integer) {
errorlistener.typeMismatch(factor->getStart(), Type::Integer, type);
}
return Type::Integer;
}
if (ctx->Integer()) {
return Type::Integer;
}
if (auto variable = ctx->variable()) {
return visitVariable(variable);
}
return visitExpr(ctx->expr());
}
std::any TypeCheckVisitor::visitVariable(xlangParser::VariableContext *ctx) { std::any TypeCheckVisitor::visitVariable(xlangParser::VariableContext *ctx) {
auto token = ctx->Identifier()->getSymbol(); auto token = ctx->Identifier()->getSymbol();
auto name = token->getText(); auto name = token->getText();
auto line = token->getLine();
auto charPositionInLine = token->getCharPositionInLine();
if (ctx->LeftParen()) { if (ctx->LeftParen()) {
auto it = function_arity.find(name); auto it = signatures.find(name);
if (it == function_arity.end()) { if (it == signatures.end()) {
errorlistener.typeError(line, charPositionInLine, errorlistener.unknownFunction(token, name);
"unknown function '" + name + "'"); return Type::Integer;
} else { } else {
auto arity = it->second; auto signature = it->second;
auto arity = signature.parametertypes.size();
if (auto arg_list = ctx->argumentList()) { if (auto arg_list = ctx->argumentList()) {
auto arg_num = arg_list->value().size(); auto arg_num = arg_list->value().size();
if (arity != arg_num) { if (arity != arg_num) {
errorlistener.typeError( errorlistener.wrongArgumentNumber(token, name, arity, arg_num);
line, charPositionInLine, }
"function '" + name + "' expects " + std::to_string(arity) + for (size_t i = 0; i < arg_num && i < arity; i++) {
" arguments, but got " + std::to_string(arg_num)); auto type = std::any_cast<Type>(visitValue(arg_list->value(i)));
auto token = arg_list->value(i)->getStart();
if (type != signature.parametertypes[i]) {
errorlistener.typeMismatch(token, signature.parametertypes[i],
type);
}
} }
} else { } else {
if (arity) { if (arity) {
errorlistener.typeError(line, charPositionInLine, errorlistener.wrongArgumentNumber(token, name, arity, 0);
"function '" + name + "' expects " +
std::to_string(arity) +
" arguments, but got none");
} }
} }
return signature.returntype;
} }
visitChildren(ctx);
} else { } else {
if (!inScope(name)) { auto type = scope.get(name);
errorlistener.typeError(line, charPositionInLine, if (!type) {
"variable '" + name + "' is not in scope"); errorlistener.unknownVariable(token, name);
return Type::Integer;
} }
return *type;
} }
// unreachable
errorlistener.compilerError(__FILE__, __LINE__);
return {}; return {};
} }

View File

@ -1,28 +1,48 @@
#pragma once #pragma once
#include <error.hh> #include <scope.hh>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <vector> #include <vector>
#include <xlangBaseVisitor.h> #include <xlangBaseVisitor.h>
namespace xlang { namespace xlang {
class TypeCheckVisitor : public xlangBaseVisitor { class ErrorListener;
ErrorListener &errorlistener;
std::unordered_map<std::string, unsigned int> function_arity;
std::vector<std::unordered_set<std::string>> scope;
bool inScope(const std::string &name); enum class Type
{
Integer,
Boolean,
};
std::string typeToString(Type type);
class TypeCheckVisitor : public xlangBaseVisitor {
struct Signature {
Type returntype;
std::vector<Type> parametertypes;
};
ErrorListener &errorlistener;
std::unordered_map<std::string, Signature> signatures;
Type returntype;
Scope<Type> scope;
public: public:
TypeCheckVisitor(ErrorListener &errorlistener); TypeCheckVisitor(ErrorListener &errorlistener);
std::any visitFile(xlangParser::FileContext *ctx) override; std::any visitFile(xlangParser::FileContext *ctx) override;
std::any visitFunction(xlangParser::FunctionContext *ctx) override; std::any visitFunction(xlangParser::FunctionContext *ctx) override;
std::any visitType(xlangParser::TypeContext *ctx) override;
std::any visitBlock(xlangParser::BlockContext *ctx) override; std::any visitBlock(xlangParser::BlockContext *ctx) override;
std::any visitStatement(xlangParser::StatementContext *ctx) override; std::any visitStatement(xlangParser::StatementContext *ctx) override;
std::any visitValue(xlangParser::ValueContext *ctx) override;
std::any visitCondition(xlangParser::ConditionContext *ctx) override;
std::any visitBoolean(xlangParser::BooleanContext *ctx) override;
std::any visitExpr(xlangParser::ExprContext *ctx) override;
std::any visitTerm(xlangParser::TermContext *ctx) override;
std::any visitFactor(xlangParser::FactorContext *ctx) override;
std::any visitVariable(xlangParser::VariableContext *ctx) override; std::any visitVariable(xlangParser::VariableContext *ctx) override;
}; };

View File

@ -1,10 +1,14 @@
grammar xlang; grammar xlang;
file : function+ EOF; file : function+ EOF;
function : Identifier LeftParen parameterList? RightParen block; function : Identifier LeftParen parameterList? RightParen Colon type block;
parameterList : Identifier (Comma Identifier)*; parameterList : Identifier Colon type (Comma Identifier Colon type)*;
type : TypeInteger
| TypeBoolean
;
block : LeftBrace statement* RightBrace; block : LeftBrace statement* RightBrace;
statement : Identifier Assign value Semicolon statement : Identifier Define value Semicolon
| Identifier Assign value Semicolon
| If condition block (Else block)? | If condition block (Else block)?
| While condition block | While condition block
| Return value Semicolon | Return value Semicolon
@ -40,6 +44,8 @@ variable : Identifier
; ;
argumentList : value (Comma value)*; argumentList : value (Comma value)*;
TypeInteger : 'int';
TypeBoolean : 'bool';
If : 'if'; If : 'if';
Else : 'else'; Else : 'else';
While : 'while'; While : 'while';
@ -54,8 +60,10 @@ False : 'false';
LeftParen : '('; LeftParen : '(';
RightParen : ')'; RightParen : ')';
Colon : ':';
LeftBrace : '{'; LeftBrace : '{';
RightBrace : '}'; RightBrace : '}';
Define : ':=';
Assign : '='; Assign : '=';
Less : '<'; Less : '<';
LessEqual : '<='; LessEqual : '<=';

View File

@ -1,3 +1,3 @@
main() { main() : int {
print 42; print 42;
} }

View File

@ -1,5 +1,5 @@
main() { main() : int {
i = 0; i := 0;
while i < 10 { while i < 10 {
print fib_rec(i); print fib_rec(i);
print fib_iter(i); print fib_iter(i);
@ -7,19 +7,19 @@ main() {
} }
} }
fib_rec(n) { fib_rec(n : int) : int {
if n < 2 { if n < 2 {
return 1; return 1;
} }
return fib_rec(n - 1) + fib_rec(n - 2); return fib_rec(n - 1) + fib_rec(n - 2);
} }
fib_iter(n) { fib_iter(n : int) : int {
x0 = 1; x0 := 1;
x1 = 1; x1 := 1;
i = 0; i := 0;
while i < n { while i < n {
t = x0 + x1; t := x0 + x1;
x0 = x1; x0 = x1;
x1 = t; x1 = t;
i = i + 1; i = i + 1;