#include <error.hh>
#include <typecheck.hh>

namespace xlang {

// Returns the textual name of a type: "int" for Type::Integer, "bool" for
// Type::Boolean and "<invalid>" for Type::Invalid or any other value.
std::string typeToString(Type type) {
  if (type == Type::Integer) {
    return "int";
  }
  if (type == Type::Boolean) {
    return "bool";
  }
  return "<invalid>";
}

// Creates a visitor that reports every diagnostic through `errorlistener`.
// The loop counter starts at zero; it is incremented only while the body of a
// while/for statement is being visited.
TypeCheckVisitor::TypeCheckVisitor(ErrorListener &errorlistener)
    : errorlistener{errorlistener}, loopcount{0} {}

// Entry point for a whole file.
//
// First pass: records the signature (return type plus parameter types) of
// every function under its name, so that calls can be checked regardless of
// declaration order. A name that is already registered is reported as a
// duplicate and its signature is not stored.
// Second pass: visits all children of the file node.
std::any TypeCheckVisitor::visitFile(xlangParser::FileContext *ctx) {
  for (auto decl : ctx->function()) {
    auto ident = decl->Identifier()->getSymbol();
    auto fname = ident->getText();
    Signature entry{std::any_cast<Type>(visitType(decl->type())), {}};

    const bool known = signatures.find(fname) != signatures.end();
    if (known) {
      errorlistener.duplicateFunction(ident, fname);
    } else {
      if (auto params = decl->parameterList()) {
        for (auto paramtype : params->type()) {
          auto resolved = std::any_cast<Type>(visitType(paramtype));
          entry.parametertypes.push_back(resolved);
        }
      }
      signatures.emplace(fname, entry);
    }
  }
  visitChildren(ctx);
  return std::any{};
}

// Type-checks one function definition.
//
// Opens a scope, registers every parameter in it under its declared type,
// remembers the declared return type in `returntype` (checked by return
// statements), visits the body block and finally closes the scope again.
std::any TypeCheckVisitor::visitFunction(xlangParser::FunctionContext *ctx) {
  scope.enter();

  if (auto params = ctx->parameterList()) {
    const size_t count = params->Identifier().size();
    for (size_t index = 0; index < count; ++index) {
      auto paramname = params->Identifier(index)->getSymbol()->getText();
      auto paramtype = std::any_cast<Type>(visitType(params->type(index)));
      scope.add(paramname, paramtype);
    }
  }

  returntype = std::any_cast<Type>(visitType(ctx->type()));
  visitBlock(ctx->block());

  scope.leave();
  return std::any{};
}

// Maps a type node to the corresponding Type value.
//
// Returns Type::Integer or Type::Boolean depending on which type token is
// present. If neither is present an internal compiler error is reported and
// Type::Invalid is returned.
std::any TypeCheckVisitor::visitType(xlangParser::TypeContext *ctx) {
  if (ctx->TypeInteger()) {
    return Type::Integer;
  }
  if (ctx->TypeBoolean()) {
    return Type::Boolean;
  }
  // unreachable
  errorlistener.compilerError(__FILE__, __LINE__);
  // Flowing off the end of a value-returning function is undefined behaviour,
  // so hand callers a well-defined value in case compilerError() returns.
  return Type::Invalid;
}

// Visits all children of a block inside their own scope level: the scope is
// entered before the children are visited and left again afterwards.
std::any TypeCheckVisitor::visitBlock(xlangParser::BlockContext *ctx) {
  scope.enter();
  visitChildren(ctx);
  scope.leave();
  return std::any{};
}

// Type-checks a single statement.
//
// - if:        the condition must be Boolean; both blocks are visited inside
//              one additional scope.
// - while/for: the condition must be Boolean; for a `for` statement value(0)
//              is visited before and value(1) after the condition. The body
//              is visited with `loopcount` incremented.
// - break / continue: with an integer operand, the operand must be non-zero
//              and must not exceed the number of enclosing loops; without an
//              operand there must be at least one enclosing loop.
// - return:    the type of the returned value must equal `returntype`.
// - anything else: the children are visited generically.
//
// Always returns an empty std::any.
std::any TypeCheckVisitor::visitStatement(xlangParser::StatementContext *ctx) {
  if (ctx->If()) {
    scope.enter();
    auto condition = ctx->condition();
    auto type = std::any_cast<Type>(visitCondition(condition));
    if (type != Type::Boolean) {
      errorlistener.typeMismatch(condition->getStart(), Type::Boolean, type);
    }
    visitBlock(ctx->block(0));
    if (ctx->Else()) {
      visitBlock(ctx->block(1));
    }
    scope.leave();
    return {};
  }
  if (ctx->While() || ctx->For()) {
    scope.enter();
    if (ctx->For()) {
      visitValue(ctx->value(0));
    }
    auto condition = ctx->condition();
    auto type = std::any_cast<Type>(visitCondition(condition));
    if (type != Type::Boolean) {
      errorlistener.typeMismatch(condition->getStart(), Type::Boolean, type);
    }
    if (ctx->For()) {
      visitValue(ctx->value(1));
    }
    // Only the body counts as "inside the loop" for break/continue checks.
    loopcount++;
    visitBlock(ctx->block(0));
    loopcount--;
    scope.leave();
    return {};
  }
  if (ctx->Break()) {
    auto token = ctx->Break()->getSymbol();
    if (auto integer = ctx->Integer()) {
      auto num = stoul(integer->getSymbol()->getText());
      if (!num) {
        errorlistener.breakZero(token);
      }
      if (num > loopcount) {
        errorlistener.breakTooMany(token, num, loopcount);
      }
    } else {
      if (!loopcount) {
        errorlistener.loopControlWithoutLoop(token);
      }
    }
    return {};
  }
  if (ctx->Continue()) {
    // Diagnostics must be attached to the Continue token: this branch is only
    // reached when Break() is null, so dereferencing Break() here would crash.
    auto token = ctx->Continue()->getSymbol();
    if (auto integer = ctx->Integer()) {
      auto num = stoul(integer->getSymbol()->getText());
      if (!num) {
        errorlistener.continueZero(token);
      }
      if (num > loopcount) {
        errorlistener.continueTooMany(token, num, loopcount);
      }
    } else {
      if (!loopcount) {
        errorlistener.loopControlWithoutLoop(token);
      }
    }
    return {};
  }
  if (ctx->Return()) {
    auto type = std::any_cast<Type>(visitValue(ctx->value(0)));

    if (type != returntype) {
      errorlistener.typeMismatch(ctx->value(0)->getStart(), returntype, type);
    }
    // The value has been visited above. Falling through to visitChildren()
    // would visit it a second time, duplicating its diagnostics and flagging
    // any variable defined inside it as shadowed.
    return {};
  }
  visitChildren(ctx);
  return {};
}

// Determines the type of a value node by delegating to its expression or
// condition child, whichever is present.
//
// If neither child exists an internal compiler error is reported and
// Type::Invalid is returned.
std::any TypeCheckVisitor::visitValue(xlangParser::ValueContext *ctx) {
  if (auto expr = ctx->expr()) {
    return visitExpr(expr);
  }
  if (auto condition = ctx->condition()) {
    return visitCondition(condition);
  }
  // unreachable
  errorlistener.compilerError(__FILE__, __LINE__);
  // Flowing off the end of a value-returning function is undefined behaviour,
  // so hand callers a well-defined value in case compilerError() returns.
  return Type::Invalid;
}

// Type-checks both operands of a binary operator.
//
// Visits `left` with `visitLeft` and `right` with `visitRight`, then reports a
// typeMismatch at the start token of every operand whose type differs from
// `expected`. `left` and `right` must be pointers to parse-tree nodes.
//
// Every argument is evaluated exactly once. The expansion is a single
// statement with its own scope (terminate the invocation with `;`), so it adds
// no names to the caller and may be used more than once in the same block.
#define CHECKOPERATOR(visitLeft, left, visitRight, right, expected)          \
  do {                                                                       \
    auto *const checkop_left_ = (left);                                      \
    auto *const checkop_right_ = (right);                                    \
    const Type checkop_expected_ = (expected);                               \
    const auto checkop_lefttype_ =                                           \
        std::any_cast<Type>(visitLeft(checkop_left_));                       \
    const auto checkop_righttype_ =                                          \
        std::any_cast<Type>(visitRight(checkop_right_));                     \
    if (checkop_lefttype_ != checkop_expected_) {                            \
      errorlistener.typeMismatch(checkop_left_->getStart(),                  \
                                 checkop_expected_, checkop_lefttype_);      \
    }                                                                        \
    if (checkop_righttype_ != checkop_expected_) {                           \
      errorlistener.typeMismatch(checkop_right_->getStart(),                 \
                                 checkop_expected_, checkop_righttype_);     \
    }                                                                        \
  } while (0)

// Determines the type of a condition node.
//
// A condition without a nested condition has the type of its boolean child.
// Otherwise the nested condition and the boolean child are the two operands
// of an operator: both must be Boolean, and the result is Boolean.
std::any TypeCheckVisitor::visitCondition(xlangParser::ConditionContext *ctx) {
  auto nested = ctx->condition();
  if (!nested) {
    return visitBoolean(ctx->boolean());
  }
  CHECKOPERATOR(visitCondition, nested, visitBoolean, ctx->boolean(),
                Type::Boolean);
  return Type::Boolean;
}

// Determines the type of a boolean node. The alternatives are tried in order:
//
// 1. a nested boolean node (presumably a negation -- confirm against the
//    grammar): the operand must be Boolean, the result is Boolean;
// 2. a True/False literal: Boolean;
// 3. a pair of expressions: both must be Integer, the result is Boolean;
// 4. a variable: the type reported by visitVariable;
// 5. otherwise: the type of the contained condition.
std::any TypeCheckVisitor::visitBoolean(xlangParser::BooleanContext *ctx) {
  if (auto operand = ctx->boolean()) {
    auto operandtype = std::any_cast<Type>(visitBoolean(operand));
    if (operandtype != Type::Boolean) {
      errorlistener.typeMismatch(operand->getStart(), Type::Boolean,
                                 operandtype);
    }
    return Type::Boolean;
  }

  const bool literal = ctx->True() || ctx->False();
  if (literal) {
    return Type::Boolean;
  }

  if (!ctx->expr().empty()) {
    CHECKOPERATOR(visitExpr, ctx->expr(0), visitExpr, ctx->expr(1),
                  Type::Integer);
    return Type::Boolean;
  }

  auto variable = ctx->variable();
  if (!variable) {
    return visitCondition(ctx->condition());
  }
  return visitVariable(variable);
}

// Determines the type of an expression node.
//
// Without a nested expression the result is the type of the term child.
// Otherwise the nested expression and the term are the operands of an
// operator: both must be Integer, and the result is Integer.
std::any TypeCheckVisitor::visitExpr(xlangParser::ExprContext *ctx) {
  auto nested = ctx->expr();
  if (!nested) {
    return visitTerm(ctx->term());
  }
  CHECKOPERATOR(visitExpr, nested, visitTerm, ctx->term(), Type::Integer);
  return Type::Integer;
}

// Determines the type of a term node.
//
// Without a nested term the result is the type of the factor child.
// Otherwise the nested term and the factor are the operands of an operator:
// both must be Integer, and the result is Integer.
std::any TypeCheckVisitor::visitTerm(xlangParser::TermContext *ctx) {
  auto nested = ctx->term();
  if (!nested) {
    return visitFactor(ctx->factor());
  }
  CHECKOPERATOR(visitTerm, nested, visitFactor, ctx->factor(), Type::Integer);
  return Type::Integer;
}

// Determines the type of a factor node. The alternatives are tried in order:
//
// 1. a nested factor (presumably a unary operator -- confirm against the
//    grammar): the operand must be Integer, the result is Integer;
// 2. an integer literal: Integer;
// 3. a variable: the type reported by visitVariable;
// 4. otherwise: the type of the contained expression.
std::any TypeCheckVisitor::visitFactor(xlangParser::FactorContext *ctx) {
  if (auto operand = ctx->factor()) {
    auto operandtype = std::any_cast<Type>(visitFactor(operand));
    if (operandtype != Type::Integer) {
      errorlistener.typeMismatch(operand->getStart(), Type::Integer,
                                 operandtype);
    }
    return Type::Integer;
  }

  if (ctx->Integer()) {
    return Type::Integer;
  }

  auto variable = ctx->variable();
  if (!variable) {
    return visitExpr(ctx->expr());
  }
  return visitVariable(variable);
}

// Type-checks every use of an identifier and returns its type. The form is
// selected by the tokens/children present on the node, tried in this order:
//
// - Define:    the value is visited and the name is added to the scope with
//              the value's type; a name that is already visible is reported
//              as shadowed. Returns the value's type.
// - Assign:    the name must be known and the value's type must match the
//              variable's type. Returns the variable's type, or the value's
//              type if the variable is unknown.
// - expr child (presumably a compound assignment -- confirm against the
//              grammar): the variable and the expression must both be
//              Integer. Returns Integer.
// - LeftParen: a function call. The function must be registered in
//              `signatures`, the number of arguments must equal its arity and
//              each argument must match the corresponding parameter type.
//              Returns the function's return type, or Type::Invalid if the
//              function is unknown.
// - otherwise: a plain read. Returns the variable's type, or Type::Invalid if
//              the variable is unknown.
std::any TypeCheckVisitor::visitVariable(xlangParser::VariableContext *ctx) {
  auto token = ctx->Identifier()->getSymbol();
  auto name = token->getText();

  if (ctx->Define()) {
    // The value is visited before the name is added, so the initializer
    // cannot refer to the variable being defined.
    auto type = std::any_cast<Type>(visitValue(ctx->value()));

    if (scope.get(name)) {
      errorlistener.shadowedVariable(token, name);
    }
    scope.add(name, type);
    return type;
  }
  if (ctx->Assign()) {
    auto type = std::any_cast<Type>(visitValue(ctx->value()));

    if (auto expected = scope.get(name)) {
      if (type != *expected) {
        errorlistener.typeMismatch(token, *expected, type);
      }
      return *expected;
    }
    errorlistener.unknownVariable(token, name);
    return type;
  }
  if (auto expr = ctx->expr()) {
    auto righttype = std::any_cast<Type>(visitExpr(expr));

    if (auto lefttype = scope.get(name)) {
      if (*lefttype != Type::Integer) {
        errorlistener.typeMismatch(token, Type::Integer, *lefttype);
      }
    } else {
      errorlistener.unknownVariable(token, name);
    }
    if (righttype != Type::Integer) {
      errorlistener.typeMismatch(token, Type::Integer, righttype);
    }
    return Type::Integer;
  }
  if (ctx->LeftParen()) {
    auto it = signatures.find(name);

    if (it == signatures.end()) {
      errorlistener.unknownFunction(token, name);
      return Type::Invalid;
    }
    // Bind by reference: the signature owns a vector of parameter types and
    // is only read here, so copying it for every call is wasted work.
    const auto &signature = it->second;
    const auto arity = signature.parametertypes.size();

    if (auto arg_list = ctx->argumentList()) {
      const auto arg_num = arg_list->value().size();

      if (arity != arg_num) {
        errorlistener.wrongArgumentNumber(token, name, arity, arg_num);
      }
      // Visit every argument, including surplus ones: skipping them would
      // hide errors inside them and lose variables they define. Only
      // arguments that have a matching parameter are compared against it.
      for (size_t i = 0; i < arg_num; i++) {
        auto argument = arg_list->value(i);
        auto type = std::any_cast<Type>(visitValue(argument));

        if (i < arity && type != signature.parametertypes[i]) {
          errorlistener.typeMismatch(argument->getStart(),
                                     signature.parametertypes[i], type);
        }
      }
    } else if (arity) {
      errorlistener.wrongArgumentNumber(token, name, arity, 0);
    }
    return signature.returntype;
  }
  if (auto type = scope.get(name)) {
    return *type;
  }
  errorlistener.unknownVariable(token, name);
  return Type::Invalid;
}

}  // namespace xlang