package tree;

import static tree.Codes.ALOAD_OP;
import static tree.Codes.AND_OP;
import static tree.Codes.ASTO_OP;
import static tree.Codes.ENTR_OP;
import static tree.Codes.EXIT_OP;
import static tree.Codes.FME_OP;
import static tree.Codes.FMT_OP;
import static tree.Codes.GLOAD_OP;
import static tree.Codes.GSTO_OP;
import static tree.Codes.INT_LIT;
import static tree.Codes.INT_TYPE;
import static tree.Codes.JNZ_OP;
import static tree.Codes.JSR_OP;
import static tree.Codes.JUMP_OP;
import static tree.Codes.JZ_OP;
import static tree.Codes.LOAD_OP;
import static tree.Codes.NEW_OP;
import static tree.Codes.OR_OP;
import static tree.Codes.PRI_OP;
import static tree.Codes.PRL_OP;
import static tree.Codes.PRS_OP;
import static tree.Codes.RDI_OP;
import static tree.Codes.RDS_OP;
import static tree.Codes.RTN_OP;
import static tree.Codes.RTS_OP;
import static tree.Codes.STO_OP;
import static tree.Codes.STR_LIT;
import io.UncheckedRandomAccess;

import java.io.IOException;
import java.io.RandomAccessFile;

import vartab.FunctionEntry;
import vartab.VariableEntry;

/**
 * Generates byte code out of the AST and writes it to the file system.
 */
public final class CodeGenerationVisitor implements IVisitor {

    private static final int MAGIC = 0xcafe;
    private static final int VERSION = 0x0099;
    private static final int UNKNOWN = -1;

    private Program thisProgram;
    private final UncheckedRandomAccess out;

    private long progSizePos;
    private long startAddress;
    private boolean deleteReturn = false;
    private long beginLabel;

    /**
     * Create the Visitor.
     * 
     * @param file
     *            output file object.
     */
    public CodeGenerationVisitor(RandomAccessFile file) {
        out = new UncheckedRandomAccess(file);
    }

    /*
     * Helpers for the handling of forward and backward jumps.
     */

    /**
     * Returns address of next byte-code.
     * 
     * @return next address.
     * @throws IOException
     *             when something went wrong
     */
    private long currentAddress() {
        return out.getFilePointer();
    }

    /**
     * Sets the address pointer.
     * 
     * @param address
     *            new address.
     * @throws IOException
     *             when something went wrong:
     */
    private void setAddress(long address) {
        out.seek(address);
    }

    /**
     * Writes the goal part of a JMP, JZ or JSR instruction.
     * 
     * @param goalAddress
     *            address of goal.
     * @return address of goal part.
     * @throws Exception
     *             when something went wrong.
     */
    private long writeJumpTarget(long goalAddress) {
        long targetAddress = currentAddress();
        int target =
            (goalAddress == UNKNOWN) ? UNKNOWN
                : (int) (goalAddress - targetAddress);
        out.writeShort(target);
        return targetAddress;
    }

    /**
     * Fixes a forward jump to the current address.
     * 
     * @param addressOfJump
     *            address of goal part.
     * @throws IOException
     *             when something went wrong.
     */
    private void jumpLabelFor(long addressOfJump) {
        long currentAddress = currentAddress();
        setAddress(addressOfJump);
        out.writeShort((int) (currentAddress - addressOfJump));
        setAddress(currentAddress);
    }

    public void visit(Program p) {
        writeHeader(p);
        writeStringLiterals();
        writeMainFunction();
        for (FunctionEntry fkt : thisProgram.listOfFunctions)
            writeFunction(fkt);
        fixByteCount();
    }

    private void writeHeader(Program p) {
        thisProgram = p;
        out.writeShort(MAGIC);
        out.writeShort(VERSION);
        progSizePos = currentAddress();
        out.writeShort(UNKNOWN);
        out.writeShort(thisProgram.globalFrame.numberOfVariables());
    }

    private void writeStringLiterals() {
        out.writeShort(thisProgram.literals.size());
        int count = 0;
        for (String s : thisProgram.literals) {
            thisProgram.literals.put(s, count++);
            out.writeShort(s.length());
            out.writeBytes(s);
        }
    }

    private void writeMainFunction() {
        startAddress = currentAddress();
        thisProgram.stmts.accept(this);
        out.write(EXIT_OP);
    }

    private void writeFunction(FunctionEntry fkt) {
        fkt.setAddress((int) (currentAddress() - startAddress));
        out.write(ENTR_OP);
        out.writeShort(fkt.getNumberOfParameters());
        out.writeShort(fkt.geVarCount());
        beginLabel = currentAddress();
        StatementSequence b = fkt.getBlock();
        b.accept(this);
        if (!(b.lastStatement(0) instanceof ReturnStmt))
            out.write(RTS_OP);
    }

    private void fixByteCount() {
        long endAddress = currentAddress();
        setAddress(progSizePos);
        out.writeShort((int) (endAddress - startAddress));
    }

    public void visit(StatementSequence b) {
        for (INode n : b.stmts)
            n.accept(this);
    }

    public void visit(IfStmt s) {
        s.condition.accept(this);
        out.write(JZ_OP);
        long thenJump = writeJumpTarget(UNKNOWN);
        s.thenPart.accept(this);
        if (s.elsePart == null)
            jumpLabelFor(thenJump);
        else {
            out.write(JUMP_OP);
            long elseJump = writeJumpTarget(UNKNOWN);
            jumpLabelFor(thenJump);
            s.elsePart.accept(this);
            jumpLabelFor(elseJump);
        }
    }

    public void visit(WhileStmt s) {
        out.write(JUMP_OP);
        long loopCondition = writeJumpTarget(UNKNOWN);
        long loopStart = currentAddress();
        s.body.accept(this);
        jumpLabelFor(loopCondition);
        s.condition.accept(this);
        out.write(JNZ_OP);
        writeJumpTarget(loopStart);
    }

    public void visit(DoWhileStmt s) {
        long loopStart = currentAddress();
        s.body.accept(this);
        s.condition.accept(this);
        out.write(JNZ_OP);
        writeJumpTarget(loopStart);
    }

    public void visit(ReturnStmt s) {
        if (s.expression != null) {
            deleteReturn = false;
            s.expression.accept(this);
            if (!deleteReturn) out.write(RTN_OP);
        } else
            out.write(RTS_OP);
    }

    public void visit(AssignStmt s) {
        s.rightHandSide.accept(this);
        VarNode left = s.leftHandSide.reference;
        if (s.leftHandSide.isArray()) 
            storeArray((IndexExpr) left);
        else
            storeVariable((VarRef) left);
    }

    private void storeArray(IndexExpr index) {
        index.reference.accept(this);
        index.index.accept(this);
        out.write(ASTO_OP);
    }
    
    private void storeVariable(VarRef n) {
        VariableEntry e = n.descriptor;
        out.write(e.isGlobal() ? GSTO_OP : STO_OP);
        out.writeShort(e.getAddress());
    }
    
    public void visit(VarExpr e) {
        e.reference.accept(this);
    }
    
    public void visit(IndexExpr n) {
        n.reference.accept(this);
        n.index.accept(this);
        out.write(ALOAD_OP);
    }

    public void visit(VarRef v) {
        VariableEntry e = v.descriptor;
        out.write(e.isGlobal() ? GLOAD_OP : LOAD_OP);
        out.writeShort(e.getAddress());
    }

    public void visit(PrintStmt s) {
        if (s.format != null) {
            s.format.accept(this);
            out.write(FMT_OP);
        }
        int i = 0;
        for (INode n : s.expressions) {
            n.accept(this);
            out.write(s.types[i++] == INT_TYPE ? PRI_OP : PRS_OP);
        }
        if (s.format != null)
            out.write(FME_OP);
        if (s.printNewLine)
            out.write(PRL_OP);
    }

    public void visit(ReadNode s) {
        if (s.prompt != null) {
            s.prompt.accept(this);
            out.write(PRS_OP);
        }
        out.write(s.type == INT_TYPE ? RDI_OP : RDS_OP);
    }

    public void visit(NewOp n) {
        n.size.accept(this);
        out.write(NEW_OP);
    }

    public void visit(CallStmt f) {
        f.functionCall.accept(this);
    }

    public void visit(FunctionCall f) {
        for (INode n : f.arguments) n.accept(this);
        if (f.lastCall) {
            int n = f.function.getNumberOfParameters();
            for (int i = n - 1; i >= 0; i--) {
                out.write(STO_OP);
                out.writeShort(i);
            }
            out.write(JUMP_OP);
            writeJumpTarget(beginLabel);
            deleteReturn = true;
        }
        else {
            out.write(JSR_OP);
            // fixed by LinkVisitor:
            f.callAddress = writeJumpTarget(UNKNOWN);
        }
    }

    public void visit(BinOp op) {
        op.left.accept(this);
        if (op.operator != OR_OP && op.operator != AND_OP) {
            op.right.accept(this);
            out.write(op.operator);
        } else { // && or ||:
            out.write(op.operator);
            long skipRightOperand = writeJumpTarget(UNKNOWN);
            op.right.accept(this);
            jumpLabelFor(skipRightOperand);
            deleteReturn = false;
        }
    }

    public void visit(UnOp op) {
        op.operand.accept(this);
        out.write(op.operator);
    }

    public void visit(IntLiteral lit) {
        out.write(INT_LIT);
        out.writeInt(lit.number);
    }

    public void visit(StringLiteral lit) {
        out.write(STR_LIT);
        out.writeShort(thisProgram.literals.get(lit.string));
    }
}
