package tree;

import static tree.Codes.EQEQ_OP;
import static tree.Codes.GT_OP;
import static tree.Codes.ICONV_OP;
import static tree.Codes.INT_TYPE;
import static tree.Codes.LEN_OP;
import static tree.Codes.NEQ_OP;
import static tree.Codes.PLUS_OP;
import static tree.Codes.SADD_OP;
import static tree.Codes.SGT_OP;
import static tree.Codes.STRING_TYPE;
import static tree.Codes.VOID_TYPE;
import vartab.ArrayType;
import vartab.Errors;
import vartab.FunctionEntry;
import vartab.Functions;
import vartab.Type;

/**
 * Performs all semantic type checking.
 * In some cases the AST is modified in order to provide more specific
 * information. If necessary, automatic type coercion is inserted.
 *
 * <p>Expression visitors report the type they computed through the field
 * {@code _RESULT_}; {@link #typeOf(INode)} visits a subtree and reads that
 * field back.
 */
public final class TypeCheckVisitor implements IVisitor {
    /** Prints offending source fragments into error messages. */
    private static final PrintVisitor PRINTER = new PrintVisitor(System.err);

    /** Type computed by the most recently visited expression. */
    private Type _RESULT_;
    /** Function table of the program currently being checked. */
    private Functions listOfFunctions;
    /** Function whose body is being checked; {@code null} while checking global code. */
    private FunctionEntry thisContext;

    /*
     * Helper functions to enhance the readability of type checking and
     * error reporting.
     */

    /**
     * Tests a given condition and reports an error in case of failure.
     * @param cond Condition that is required to be true.
     * @param before leading error-message text.
     * @param node source code that shall be included in the error message.
     */
    private void require(boolean cond, String before, INode node) {
        require(cond, before, node, "");
    }

    /**
     * Tests a given condition and reports an error in case of failure.
     * @param cond Condition that is required to be true.
     * @param before leading text of error message.
     * @param node source code that shall be included in the error message.
     * @param after trailing text of error message.
     */
    private void require(boolean cond, String before, INode node, String after) {
        if (!cond) fail(before, node, after);
    }

    /**
     * Reports an error message.
     * @param before leading error-message text.
     * @param node source code that shall be included in the error message.
     */
    private void fail(String before, INode node) {
        fail(before, node, "");
    }

    /**
     * Reports an error message: the leading text via {@code Errors.error},
     * then the printed node and the trailing text on {@code System.err}.
     * @param before leading text of error message.
     * @param node source code that shall be included in the error message.
     * @param after trailing text of error message.
     */
    private void fail(String before, INode node, String after) {
        Errors.error(before);
        node.accept(PRINTER);
        System.err.println(after);
    }

    /**
     * Traverses subtree n and returns type of subtree.
     * @param n subtree
     * @return type of subtree, as left in {@code _RESULT_} by the visit
     */
    private Type typeOf(INode n) {
        n.accept(this);
        return _RESULT_;
    }

    /**
     * Checks every function body (with {@code thisContext} set to that
     * function) and then the global statements (with no function context).
     */
    public void visit(Program p) {
        listOfFunctions = p.listOfFunctions;
        for (FunctionEntry fkt : listOfFunctions) {
            thisContext = fkt;
            fkt.getBlock().accept(this);
        }
        thisContext = null;
        p.stmts.accept(this);
    }

    public void visit(StatementSequence b) {
        for (INode n : b.stmts) n.accept(this);
    }

    public void visit(IfStmt s) {
        require(typeOf(s.condition) == INT_TYPE,
            "needs int: if(", s.condition, ")...");
        s.thenPart.accept(this);
        if (s.elsePart != null) s.elsePart.accept(this);
    }

    public void visit(WhileStmt s) {
        require(typeOf(s.condition) == INT_TYPE,
            "needs int: while(", s.condition, ")...");
        s.body.accept(this);
    }

    public void visit(DoWhileStmt s) {
        // Message mirrors the if/while variants so the parentheses balance.
        require(typeOf(s.condition) == INT_TYPE,
            "needs int: do ... while(", s.condition, ")");
        s.body.accept(this);
    }

    /**
     * Checks an optional int format expression and records the type of each
     * printed expression in {@code s.types}.
     */
    public void visit(PrintStmt s) {
        if (s.format != null)
            require(typeOf(s.format) == INT_TYPE,
                "print-format must be int: ", s);
        int i = 0;
        for (INode n : s.expressions) s.types[i++] = typeOf(n);
    }

    /** Only int and string can be read; an optional prompt must be a string. */
    public void visit(ReadNode n) {
        require(n.type == INT_TYPE || n.type == STRING_TYPE,
                "read works for int and string only: ", n);
        if (n.prompt != null)
            require(typeOf(n.prompt) == STRING_TYPE,
                "illegal read-prompt: ", n);
        _RESULT_ = n.type;
    }

    /**
     * Checks that a return occurs inside a function and that its value type
     * matches the function's result type; on success path and mismatch alike
     * the node is annotated with the function's parameter count and result type.
     */
    public void visit(ReturnStmt s) {
        if (thisContext == null) {
            Errors.errorln("return is not allowed in global code");
            return;
        }
        Type returnType = VOID_TYPE;
        if (s.expression != null)
            returnType = typeOf(s.expression);
        if (!thisContext.getResultType().equals(returnType))
            Errors.errorln("illegal return in function: "
                    + thisContext.getName());
        s.numberOfArguments = thisContext.getNumberOfParameters();
        s.resultType = thisContext.getResultType();
    }

    public void visit(AssignStmt s) {
        require(typeOf(s.leftHandSide).equals(typeOf(s.rightHandSide)),
            "type error in assign: ", s);
    }

    @Override
    public void visit(VarExpr varExpr) {
        varExpr.reference.accept(this);
        _RESULT_ = varExpr.refType();
    }

    /** Reports references to undefined variables; does not assign {@code _RESULT_}. */
    public void visit(VarRef v) {
        require(v.type(0) != VOID_TYPE, "undefinedVariable: ", v);
    }

    public void visit(IndexExpr n) {
        require(typeOf(n.index) == INT_TYPE,
            "array index: ", n, " must be int");
    }

    /** The size must be int; the result is an array of {@code n.type}. */
    public void visit(NewOp n) {
        require(typeOf(n.size) == INT_TYPE, "illegal type for new-size in: ", n);
        _RESULT_ = new ArrayType(n.type);
    }

    /** A call used as a statement must invoke a void function. */
    public void visit(CallStmt c) {
        require(typeOf(c.functionCall) == VOID_TYPE,
            "call to non void function: ", c.functionCall);
    }

    /**
     * Resolves the callee (storing it in {@code n.function}), checks the
     * argument count and each argument's type, and yields the callee's
     * result type. An unknown function yields {@code VOID_TYPE}.
     */
    public void visit(FunctionCall n) {
        FunctionEntry function = listOfFunctions.get(n.name);
        if (function == null) {
            fail("Function ", n, " not found");
            _RESULT_ = VOID_TYPE;
        } else {
            n.function = function;
            Type[] argTypes = function.getTypesOfParameters();
            if (argTypes.length == n.arguments.size()) {
                int i = 0;
                for (INode arg : n.arguments) {
                    require(typeOf(arg).equals(argTypes[i]),
                        "type error in arg #" + i++ + ": ", n);
                }
            } else
                fail("function ", n, " is called with wrong argument number");
            _RESULT_ = function.getResultType();
        }
    }

    /**
     * Dispatches on the operand types: equal types are handled by
     * {@link #doBinOpSameType}, a mixed {@code +} by
     * {@link #doBinOpPlusIntString}; anything else is an error.
     */
    public void visit(BinOp n) {
        Type t1 = typeOf(n.left);
        Type t2 = typeOf(n.right);
        if (t1.equals(t2))
            doBinOpSameType(n, t1);
        else if (n.operator == PLUS_OP)
            doBinOpPlusIntString(n, t1, t2);
        else {
            fail("type-error in: ", n);
            _RESULT_ = VOID_TYPE;
        }
    }

    /**
     * Handles {@code int + string} / {@code string + int}: the int side is
     * wrapped in an ICONV conversion and the node becomes a string concatenation.
     */
    private void doBinOpPlusIntString(BinOp op, Type t1, Type t2) {
        if (t1 == INT_TYPE && t2 == STRING_TYPE) {
            op.left = new UnOp(ICONV_OP, op.left);
        } else if (t1 == STRING_TYPE && t2 == INT_TYPE) {
            op.right = new UnOp(ICONV_OP, op.right);
        } else {
            // Not an int/string mix: report it, but do not rewrite the node
            // into a string concatenation it never was.
            fail("type-error in: ", op);
            _RESULT_ = STRING_TYPE;
            return;
        }
        op.operator = SADD_OP;
        _RESULT_ = STRING_TYPE;
    }

    /**
     * Handles operands of identical type: strings, ints, and (in)equality
     * comparison of arrays. Everything else is an error.
     */
    private void doBinOpSameType(BinOp op, Type t1) {
        if (t1 == STRING_TYPE) {
            doBinOpTwoStrings(op);
        } else if (t1 == INT_TYPE) {
            _RESULT_ = INT_TYPE;
        } else if (t1.isArray() && (op.operator == EQEQ_OP || op.operator == NEQ_OP)) {
            // Only genuine arrays may be compared by reference; two void
            // operands must not slip through as an "array" comparison.
            op.arrayOp = true;
            _RESULT_ = INT_TYPE;
        } else {
            fail("type-error in: ", op);
            _RESULT_ = INT_TYPE;
        }
    }

    /**
     * Maps string {@code +} to SADD and string comparisons to their string
     * counterparts. Assumes the comparison codes GT_OP..NEQ_OP are contiguous
     * and mirrored, in the same order, by the codes starting at SGT_OP.
     */
    private void doBinOpTwoStrings(BinOp op) {
        if (op.operator == PLUS_OP) {
            op.operator = SADD_OP;
            _RESULT_ = STRING_TYPE;
            return;
        }
        if (op.operator >= GT_OP && op.operator <= NEQ_OP) {
            op.operator += (SGT_OP - GT_OP);
        } else {
            // Leave the operator code intact; shifting an out-of-range code
            // would turn it into an unrelated operator.
            fail("illegal string operation: ", op);
        }
        _RESULT_ = INT_TYPE;
    }

    /**
     * LEN requires an array operand; every other unary operator requires an
     * int operand. Both yield int.
     */
    public void visit(UnOp n) {
        Type opType = typeOf(n.operand);
        boolean ok = (n.operator == LEN_OP) ? opType.isArray() : opType == INT_TYPE;
        if (ok) {
            _RESULT_ = INT_TYPE;
        } else {
            fail("type-error in: ", n);
            _RESULT_ = VOID_TYPE;
        }
    }

    public void visit(IntLiteral n) {
        _RESULT_ = INT_TYPE;
    }

    public void visit(StringLiteral n) {
        _RESULT_ = STRING_TYPE;
    }
}