package tree;

/******************************************************************
 *  requires:
 *        ASM = Java bytecode engineering library
 *        https://asm.ow2.io/ (formerly http://forge.ow2.org/projects/asm/)
 ******************************************************************/

import static tree.Codes.AND_OP;
import static tree.Codes.EQEQ_OP;
import static tree.Codes.ICONV_OP;
import static tree.Codes.INT_TYPE;
import static tree.Codes.MOD_OP;
import static tree.Codes.NEQ_OP;
import static tree.Codes.NOT_OP;
import static tree.Codes.OR_OP;
import static tree.Codes.SADD_OP;
import static tree.Codes.STRING_TYPE;
import static tree.Codes.javaCodes;

import java.io.FileOutputStream;
import java.io.IOException;

import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;

import vartab.FunctionEntry;
import vartab.Type;
import vartab.VariableEntry;

/**
 * Generates a java classfile that is semantically equivalent to the AST.
 */
public final class JavaCodeGenerationVisitor implements IVisitor, Opcodes {

    private static final String IN_VAR = "$$in$$";
    private final String className;
    private MethodVisitor mv;
    private boolean deleteReturn = false;
    private Label beginLabel;

    private static final String VOID = "V";
    private static final String INT = "I";
    private static final String PRINTER = "java/io/PrintStream";
    private static final String STRING = "java/lang/String";
    private static final String SCANNER = "java/util/Scanner";
    private static final String SYSTEM = "java/lang/System";
    private static final String OBJECT = "java/lang/Object";
    private static final String INPUT = "java/io/InputStream";
    private static final String BUILDER = "java/lang/StringBuilder";
    private static final String INTEGER = "java/lang/Integer";

    /**
     * Converts Java classname (separator is /) to type format for bytecode.
     * Names of elementary types are not changed.
     * 
     * @param className
     *            full Java classname with / as package name separator
     * @return type name as required by bytecode.
     */
    private static String type(String className) {
        return className.length() == 1 ? className : "L" + className + ";";
    }

    /**
     * Builds signature String. Java type-names are to be provided as Strings.
     * They are converted into the canonical type format.
     * 
     * @param result
     *            return-type
     * @param params
     *            type-names for function parameters
     * @return signature String
     */
    private static String sig(String result, String... params) {
        final StringBuilder b = new StringBuilder().append('(');
        for (String p : params) b.append(type(p));
        return b.append(")").append(type(result)).toString();
    }

    /**
     * Create the Visitor.
     * 
     * @param class name.
     */
    public JavaCodeGenerationVisitor(String name) {
        this.className = name;
    }

    public void visit(Program p) {
        ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_MAXS);
        writeHeader(p, cw);
        writeMainFunction(p, cw);
        for (FunctionEntry fkt : p.listOfFunctions)
            writeFunction(fkt, cw);
        writeToClassfile(cw);
    }

    /**
     * Converts asm-bytecode structure into byte array and writes it to the
     * output file. IOExceptions are wrapped into RuntimeExceptions
     * 
     * @param cw ClassWriter object
     * 
     * @throws RuntimeException
     *             if some exception occurs
     */
    private void writeToClassfile(ClassWriter cw) {
        cw.visitEnd();
        byte[] b = cw.toByteArray();
        try {
            FileOutputStream f = new FileOutputStream(className + ".class");
            f.write(b);
            f.close();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    private void writeHeader(Program p, ClassWriter cw) {
        cw.visit(V1_5, ACC_PUBLIC + ACC_SUPER, className, null, OBJECT, null);
        cw.visitSource(className + ".tc", null);
        for (VariableEntry e : p.getGlobalFrame().topScope())
            cw.visitField(ACC_STATIC, e.getName(), e.getType().javaSignature(),
                null, null).visitEnd();
        cw.visitField(ACC_STATIC, IN_VAR, type(SCANNER), null, null).visitEnd();
    }

    private void writeMainFunction(Program p, ClassWriter cw) {
        mv = cw.visitMethod(ACC_PUBLIC + ACC_STATIC, "main", sig(VOID, "[",
                STRING), null, null);
        mv.visitCode();
        initializeScanner();
        p.stmts.accept(this);
        if (!(p.stmts.lastStatement(0) instanceof ReturnStmt))
            mv.visitInsn(RETURN);
        mv.visitMaxs(0, 0);
        mv.visitEnd();
    }

    private void initializeScanner() {
        mv.visitTypeInsn(NEW, SCANNER);
        mv.visitInsn(DUP);
        mv.visitFieldInsn(GETSTATIC, SYSTEM, "in", type(INPUT));
        mv.visitMethodInsn(INVOKESPECIAL, SCANNER, "<init>", sig(VOID, INPUT));
        mv.visitFieldInsn(PUTSTATIC, className, IN_VAR, type(SCANNER));
    }

    private void writeFunction(FunctionEntry fkt, ClassWriter cw) {
        mv = cw.visitMethod(ACC_PUBLIC + ACC_STATIC, fkt.getName(),
                fkt.getSignature(), null, null);
        mv.visitCode();
        StatementSequence b = fkt.getBlock();
        beginLabel = new Label();
        mv.visitLabel(beginLabel);
        b.accept(this);
        if (!(b.lastStatement(0) instanceof ReturnStmt))
            mv.visitInsn(RETURN);
        mv.visitMaxs(0, 0);
        mv.visitEnd();
    }

    public void visit(StatementSequence seq) {
        for (INode n : seq.stmts)
            n.accept(this);
    }

    /* IF (s.condition) s.thenStmt ELSE s.elseStmt
     *      s.condition
     *      IFEQ _endOfThen
     *      s.thenStmt
     *      GOTO _endOfElse
     * _endOfThen:
     *      s.elseStmt  
     * _endOfElse:      
     */
    public void visit(IfStmt s) {
        s.condition.accept(this);
        Label endOfThen = new Label();
        mv.visitJumpInsn(IFEQ, endOfThen);
        s.thenPart.accept(this);
        if (s.elsePart == null)
            mv.visitLabel(endOfThen);
        else {
            Label endOfElse = new Label();
            mv.visitJumpInsn(GOTO, endOfElse);
            mv.visitLabel(endOfThen);
            s.elsePart.accept(this);
            mv.visitLabel(endOfElse);
        }
    }

    /* WHILE (s.condition) s.stmt
     *      GOTO _condition
     * _startOfBody:
     *      s.stmt
     * _condition:
     *      s.condition
     *      IFNE _startOfBody
     */
    public void visit(WhileStmt s) {
        Label condition = new Label();
        mv.visitJumpInsn(GOTO, condition);
        Label startOfBody = new Label();
        mv.visitLabel(startOfBody);
        s.body.accept(this);
        mv.visitLabel(condition);
        s.condition.accept(this);
        mv.visitJumpInsn(IFNE, startOfBody);
    }

    /* DO { s-stmt} WHILE (s.condition)
     * _startOfBody:
     *      s.stmt
     *      s.condition
     *      IFNE _startOfBody
     */
    public void visit(DoWhileStmt s) {
        Label startOfBody = new Label();
        mv.visitLabel(startOfBody);
        s.body.accept(this);
        s.condition.accept(this);
        mv.visitJumpInsn(IFNE, startOfBody);
    }

    public void visit(ReturnStmt s) {
        if (s.expression != null) {
            deleteReturn = false;
            s.expression.accept(this);
            if (!deleteReturn)
                mv.visitInsn(s.resultType == INT_TYPE ? IRETURN : ARETURN);
        } else
            mv.visitInsn(RETURN);
    }

    public void visit(AssignStmt s) {
        VarExpr left = s.leftHandSide;
        if (left.isArray()) {
            left.reference.accept(this);
            s.rightHandSide.accept(this);
            createArrayStore(left.refType());
        }
        else {
            s.rightHandSide.accept(this);
            createVarStore((VarRef) left.reference);
        }
    }

    private void createArrayStore(Type type) {
        int code = (type == INT_TYPE) ? IASTORE : AASTORE;
        mv.visitInsn(code);
    }

    private void createVarStore(VarRef left) {
        VariableEntry e = left.descriptor;
        Type t = e.getType();
        if (e.isGlobal()) {
            mv.visitFieldInsn(PUTSTATIC, className, e.getName(),
                t.javaSignature());
        } else {
            int code = (t == INT_TYPE) ? ISTORE : ASTORE;
            mv.visitVarInsn(code, e.getAddress());
        }
    }

    public void visit(VarExpr e) {
        e.reference.accept(this);
        if (e.isArray()) {
            int code = e.refType() == INT_TYPE?
                IALOAD:
                AALOAD;
            mv.visitInsn(code);
        }
    }

    public void visit(VarRef v) {
        VariableEntry e = v.descriptor;
        Type t = e.getType();
        if (e.isGlobal())
            mv.visitFieldInsn(GETSTATIC, className, e.getName(),
                t.javaSignature());
        else {
            int code = (t == INT_TYPE) ? ILOAD : ALOAD;
            mv.visitVarInsn(code, e.getAddress());
        }
    }

    public void visit(IndexExpr n) {
        n.reference.accept(this);
        if (n.reference instanceof IndexExpr)
            mv.visitInsn(AALOAD);
        n.index.accept(this);
    }

    public void visit(PrintStmt s) {
        int i = 0;
        for (INode n : s.expressions) {
            mv.visitFieldInsn(GETSTATIC, SYSTEM, "out", type(PRINTER));
            Type type = s.types[i++];
            buildFormat(s.format, type);
            mv.visitInsn(ICONST_1);
            mv.visitTypeInsn(ANEWARRAY, OBJECT);
            mv.visitInsn(DUP);
            mv.visitInsn(ICONST_0);
            n.accept(this);
            if (type == INT_TYPE)
                mv.visitMethodInsn(INVOKESTATIC, INTEGER, "valueOf", sig(
                    INTEGER, INT));
            mv.visitInsn(AASTORE);
            mv.visitMethodInsn(INVOKEVIRTUAL, PRINTER, "printf", sig(PRINTER,
                STRING, "[", OBJECT));
            mv.visitInsn(POP);
        }
        if (s.printNewLine) {
            mv.visitFieldInsn(GETSTATIC, SYSTEM, "out", type(PRINTER));
            mv.visitMethodInsn(INVOKEVIRTUAL, PRINTER, "println", sig(VOID));
        }
    }

    private void buildFormat(INode format, Type type) {
        if (format != null) {
            mv.visitTypeInsn(NEW, BUILDER);
            mv.visitInsn(DUP);
            mv.visitLdcInsn("%");
            mv.visitMethodInsn(INVOKESPECIAL, BUILDER, "<init>", sig(VOID,
                STRING));
            format.accept(this);
            mv.visitMethodInsn(INVOKESTATIC, STRING, "valueOf",
                sig(STRING, INT));
            mv.visitMethodInsn(INVOKEVIRTUAL, BUILDER, "append", sig(BUILDER,
                STRING));
            String spec = (type == INT_TYPE) ? "d" : "s";
            mv.visitLdcInsn(spec);
            mv.visitMethodInsn(INVOKEVIRTUAL, BUILDER, "append", sig(BUILDER,
                STRING));
            mv.visitMethodInsn(INVOKEVIRTUAL, BUILDER, "toString", sig(STRING));
        } else
            mv.visitLdcInsn("%s");
    }

    public void visit(ReadNode s) {
        if (s.prompt != null) {
            mv.visitFieldInsn(GETSTATIC, SYSTEM, "out", type(PRINTER));
            s.prompt.accept(this);
            mv.visitMethodInsn(INVOKEVIRTUAL, PRINTER, "print", sig(VOID, STRING));
        }
        mv.visitFieldInsn(GETSTATIC, className, IN_VAR, type(SCANNER));
        if (s.type == INT_TYPE) {
//            mv.visitInsn(DUP);
            mv.visitMethodInsn(INVOKEVIRTUAL, SCANNER, "nextInt", sig(INT));
//            mv.visitInsn(SWAP);
//            mv.visitMethodInsn(INVOKEVIRTUAL, SCANNER, "nextLine", sig(STRING));
//            mv.visitInsn(POP);
        } else
            mv.visitMethodInsn(INVOKEVIRTUAL, SCANNER, "nextLine", sig(STRING));
    }

    public void visit(NewOp n) {
        n.size.accept(this);
        if (n.type == INT_TYPE)
            mv.visitIntInsn(NEWARRAY, T_INT);
        else if (n.type == STRING_TYPE)
            mv.visitTypeInsn(ANEWARRAY, STRING);
        else
            mv.visitTypeInsn(ANEWARRAY, n.type.javaSignature());
    }

    public void visit(CallStmt f) {
        f.functionCall.accept(this);
    }

    public void visit(FunctionCall f) {
        for (INode n : f.arguments) n.accept(this);
        if (f.lastCall) {
            Type[] parameterTypes = f.function.getTypesOfParameters();
            for (int addr = parameterTypes.length - 1; addr >= 0; addr--) {
                Type t = parameterTypes[addr];
                mv.visitVarInsn((t == INT_TYPE) ? ISTORE : ASTORE, addr);
            }
            mv.visitJumpInsn(GOTO, beginLabel);
            deleteReturn = true;
        }
        else
            mv.visitMethodInsn(INVOKESTATIC, className, f.name,
                f.function.getSignature());
    }

    public void visit(BinOp b) {
        if (b.operator == SADD_OP)
            generateConcat(b.left, b.right);
        else if (b.operator != OR_OP && b.operator != AND_OP) {
            b.left.accept(this);
            b.right.accept(this);
            if (b.operator <= MOD_OP)
                mv.visitInsn(javaCodes[b.operator]);
            else if (b.operator <= NEQ_OP)
                generateBoolOp(b.operator, b.arrayOp);
            else
                // if (op.op < SNEQ_OP)
                generateStringCmp(b);
        } else
            generateAndOr(b);
    }

    private void generateAndOr(BinOp op) {
        int compareCode = IFNE;
        int shortCutCode = ICONST_1;
        if (op.operator == AND_OP) {
            compareCode = IFEQ;
            shortCutCode = ICONST_0;
        }
        op.left.accept(this);
        Label shortCut = new Label();
        mv.visitJumpInsn(compareCode, shortCut);
        op.right.accept(this);
        Label finish = new Label();
        mv.visitJumpInsn(GOTO, finish);
        mv.visitLabel(shortCut);
        mv.visitInsn(shortCutCode);
        mv.visitLabel(finish);
        deleteReturn = false;
    }

    private void generateConcat(INode left, INode right) {
        mv.visitTypeInsn(NEW, BUILDER);
        mv.visitInsn(DUP);
        left.accept(this);
        mv.visitMethodInsn(INVOKESPECIAL, BUILDER, "<init>", sig(VOID, STRING));
        right.accept(this);
        mv.visitMethodInsn(INVOKEVIRTUAL, BUILDER, "append", sig(BUILDER,
            STRING));
        mv.visitMethodInsn(INVOKEVIRTUAL, BUILDER, "toString", sig(STRING));
    }

    private void generateStringCmp(BinOp op) {
        mv.visitMethodInsn(INVOKEVIRTUAL, STRING, "compareTo", sig(INT, OBJECT));
        generateBoolOp(op.operator, false);
    }

    private void generateBoolOp(int op, boolean isArray) {
        Label L1 = new Label();
        Label L2 = new Label();
        int code = javaCodes[op];
        if (isArray)
            code = (op == EQEQ_OP) ? IF_ACMPEQ : IF_ACMPNE;
        mv.visitJumpInsn(code, L1);
        mv.visitInsn(ICONST_0);
        mv.visitJumpInsn(GOTO, L2);
        mv.visitLabel(L1);
        mv.visitInsn(ICONST_1);
        mv.visitLabel(L2);

    }

    public void visit(UnOp u) {
        u.operand.accept(this);
        if (u.operator == NOT_OP)
            generateBoolOp(u.operator, false);
        else if (u.operator == ICONV_OP)
            mv.visitMethodInsn(INVOKESTATIC, STRING, "valueOf",
                sig(STRING, INT));
        else
            mv.visitInsn(javaCodes[u.operator]);
    }

    public void visit(IntLiteral lit) {
        int nr = lit.number;
        if (nr >= 0 && nr <= 5)
            mv.visitInsn(ICONST_0 + nr);
        else if (nr >= Byte.MIN_VALUE && nr <= Byte.MAX_VALUE)
            mv.visitIntInsn(BIPUSH, nr);
        else if (nr >= Short.MIN_VALUE && nr <= Short.MAX_VALUE)
            mv.visitIntInsn(SIPUSH, nr);
        else
            mv.visitLdcInsn(Integer.valueOf(nr));
    }

    public void visit(StringLiteral lit) {
        mv.visitLdcInsn(lit.string);
    }
}
