/* $Id$ 
 *
 * Generate intermediate code, array specific parts.
 *
 * Copyright (C) 2008-2009 FAUmachine Team <info@faumachine.org>.
 * This program is free software. You can redistribute it and/or modify it
 * under the terms of the GNU General Public License, either version 2 of
 * the License, or (at your option) any later version. See COPYING.
 */


#include "frontend/visitor/GCArrays.hpp"
#include <cassert>
#include "frontend/visitor/GenCode.hpp"
#include "frontend/visitor/ResolveTypes.hpp"
#include "frontend/visitor/GCTypes.hpp"
#include "frontend/ast/DiscreteRange.hpp"
#include "frontend/ast/UnconstrainedArrayType.hpp"
#include "intermediate/operands/RegisterFactory.hpp"
#include "intermediate/operands/ImmediateOperand.hpp"
#include "intermediate/operands/IndirectOperand.hpp"
#include "intermediate/container/LabelFactory.hpp"
#include "intermediate/container/TypeFactory.hpp"
#include "intermediate/opcodes/Mov.hpp"
#include "intermediate/opcodes/Je.hpp"
#include "intermediate/opcodes/Jb.hpp"
#include "intermediate/opcodes/Jbe.hpp"
#include "intermediate/opcodes/Jmp.hpp"
#include "intermediate/opcodes/Sub.hpp"
#include "intermediate/opcodes/IMul.hpp"
#include "intermediate/opcodes/Add.hpp"
#include "intermediate/opcodes/AOffset.hpp"

namespace ast {

using namespace intermediate;

/*
 * ===================== ARRAY HANDLING =======================
 */

/**
 * Set up array handling for the array type at.
 *
 * The index constraints of at are collected with
 * GCTypes::GenTypeElements. For every index range, one immediate operand
 * each for its left bound, right bound and direction is appended to the
 * given lists.
 *
 * @param at array type declaration (must have base type BASE_TYPE_ARRAY).
 * @param b operand used as base of the array (passed to AOffset when
 *        subscribing).
 * @param container code container that receives generated code.
 * @param lbounds initial left bound operands (extended, see above).
 * @param rbounds initial right bound operands (extended, see above).
 * @param directs initial direction operands (extended, see above).
 */
ArrayHandling::ArrayHandling(
	TypeDeclaration *at,
	Operand *b,
	CodeContainer &container,
	std::list<Operand *> lbounds,
	std::list<Operand *> rbounds,
	std::list<Operand *> directs
) :		arrayType(at),
		base(b),
		cc(container),
		leftBounds(lbounds),
		rightBounds(rbounds),
		directions(directs)
{
	assert(at != NULL);
	assert(at->baseType == BASE_TYPE_ARRAY);

	GCTypes::GenTypeElements collector(false, NULL, *this->arrayType, NULL);
	this->arrayType->accept(collector);

	this->indices = collector.getIndices();

	// mirror each statically known index range as immediate operands
	std::list<DiscreteRange*>::const_iterator range = this->indices.begin();
	while (range != this->indices.end()) {
		DiscreteRange *dr = *range;

		this->leftBounds.push_back(
			new ImmediateOperand(dr->getLeftBound()));
		this->rightBounds.push_back(
			new ImmediateOperand(dr->getRightBound()));
		this->directions.push_back(
			new ImmediateOperand(dr->getDirection()));

		++range;
	}

	// exactly one composite element and one referred (element) type
	// are expected for an array type
	assert(collector.composite.size() == 1);
	assert(collector.referredTypes.size() == 1);

	std::string elemTypeName = collector.composite.front()->name;
	this->itype = TypeFactory::lookupType(elemTypeName);
	assert(this->itype != NULL);

	this->elementType = collector.referredTypes.front();
}

/**
 * Destructor.
 *
 * Nothing is released here: operands and opcodes created by this class
 * are handed to the CodeContainer or kept in lists without being freed
 * (presumably they are owned elsewhere — TODO confirm).
 */
ArrayHandling::~ArrayHandling()
{
	/* intentionally empty */
}

void
ArrayHandling::factorize(void)
{
	// reverse dimension sizes
	// (..., d6, d5, d4, d3, d2)
	std::list<Operand *> rsizes;

	assert(this->leftBounds.size() == this->rightBounds.size());
	assert(this->leftBounds.size() > 0);

	std::list<Operand *>::const_reverse_iterator lbit = 
		this->leftBounds.rbegin();
	std::list<Operand *>::const_reverse_iterator rbit = 
		this->rightBounds.rbegin();
	std::list<Operand *>::const_reverse_iterator dirit = 
		this->directions.rbegin();

	std::list<Operand *>::const_reverse_iterator llast = 
		this->leftBounds.rend();
	llast--;
	// iterate over all but *first* entry (first is the actual dimension
	// that gets subscribed to with the first subscription element,
	// so the factor is always constant 1)
	for (; lbit != llast; lbit++, rbit++, dirit++) {
		// size := [(right - left) * direction] + 1
		
		Register *r1 = this->cc.createRegister(OP_TYPE_INTEGER);
		Sub *diff = new Sub(*rbit, *lbit, r1);
		this->cc.addCode(diff);

		Register *r2 = this->cc.createRegister(OP_TYPE_INTEGER);
		Register *r3 = this->cc.createRegister(OP_TYPE_INTEGER);
		
		IMul *m = new IMul(r1, *dirit, r2);

		Add *inc = new Add(r2, ImmediateOperand::getOne(), r3);
		this->cc.addCode(m);
		this->cc.addCode(inc);
		rsizes.push_back(r3);
	}

	// factors should be
	// (d2 * d3 * d4 * d5 * d6) 
	// (d3 * d4 * d5 * d6)
	// (d4 * d5 * d6)
	// (d5 * d6)
	// (d6)
	intermediate::Operand *factorial = ImmediateOperand::getOne();
	for (std::list<Operand *>::const_iterator i = rsizes.begin();
		i != rsizes.end(); i++) {

		Register *r1 = this->cc.createRegister(OP_TYPE_INTEGER);
		IMul *m = new IMul(factorial, *i, r1);
		this->cc.addCode(m);
		factorial = r1;
		this->dimensionFactors.push_front(factorial);
	}
	
	// also store 1 as factor for first dimension
	this->dimensionFactors.push_front(ImmediateOperand::getOne());
}

/**
 * Emit code that computes the address of an array element.
 *
 * For a subscription array(x, y, z) of an array with dimensions
 * (d1, d2, d3, ...) the element offset is
 *
 *       (x - left(d1)) * factor(d1) * direction(d1)
 *     + (y - left(d2)) * factor(d2) * direction(d2)
 *     + (z - left(d3)) * factor(d3) * direction(d3)
 *
 * where factor(dk) is the k-th entry of dimensionFactors. If
 * dimensionFactors is still empty, it is created by factorize() first.
 * At most as many indices as there are factors may be given.
 *
 * @param relativeIndices index operands, one per subscribed dimension,
 *        in dimension order.
 * @return pointer register computed by an AOffset of base, the summed
 *         offset and the element type itype.
 */
Register *
ArrayHandling::subscribe(std::list<Operand *> relativeIndices)
{
	if (this->dimensionFactors.empty()) {
		this->factorize();
	}

	assert(! this->dimensionFactors.empty());
	assert(this->dimensionFactors.size() >= relativeIndices.size());
	assert(this->itype != NULL);

	std::list<Operand *>::const_iterator factor =
		this->dimensionFactors.begin();
	std::list<Operand *>::const_iterator left =
		this->leftBounds.begin();
	std::list<Operand *>::const_iterator direction =
		this->directions.begin();

	// running sum of the per-dimension offsets
	Operand *sum = ImmediateOperand::getZero();

	for (std::list<Operand *>::const_iterator idx =
		relativeIndices.begin();
		idx != relativeIndices.end();
		++idx, ++factor, ++left, ++direction) {

		Register *relative = this->cc.createRegister(OP_TYPE_INTEGER);
		Register *scaled = this->cc.createRegister(OP_TYPE_INTEGER);
		Register *oriented = this->cc.createRegister(OP_TYPE_INTEGER);
		Register *partial = this->cc.createRegister(OP_TYPE_INTEGER);

		// relative = idx - left
		this->cc.addCode(new Sub(*idx, *left, relative));
		// scaled = relative * factor
		this->cc.addCode(new IMul(relative, *factor, scaled));
		// oriented = scaled * direction
		this->cc.addCode(new IMul(scaled, *direction, oriented));
		// partial = oriented + sum
		this->cc.addCode(new Add(oriented, sum, partial));

		sum = partial;
	}

	Register *element = this->cc.createRegister(OP_TYPE_POINTER);
	this->cc.addCode(new AOffset(this->base, sum, this->itype, element));
	return element;
}

/**
 * Transform an index of a one-dimensional array type into its offset
 * relative to the left bound, oriented by the range direction.
 *
 * @param arrayType array type declaration; its index constraint must
 *        consist of exactly one discrete range (asserted).
 * @param idx index value.
 * @return (idx - left bound) * direction of that range.
 */
universal_integer
ArrayHandling::transform_idx(
	const TypeDeclaration *arrayType,
	universal_integer idx
)
{
	std::list<DiscreteRange *> ranges;
	ResolveTypes::pickupIndexConstraint(arrayType, ranges);
	assert(ranges.size() == 1);

	DiscreteRange *range = ranges.front();
	return (idx - range->getLeftBound()) * range->getDirection();
}

/*
 * =================== STATIC ARRAY ITERATE ===================
 */
void
StaticArrayIterate::iterate(void)
{
	std::list<universal_integer> lbounds = std::list<universal_integer>();
	std::list<universal_integer> rbounds = std::list<universal_integer>();
	std::list<universal_integer> ds = std::list<universal_integer>();
	std::list<universal_integer> i = std::list<universal_integer>();
	std::list<Operand*> offsetL = std::list<Operand*>();

	for (std::list<DiscreteRange*>::const_iterator d = 
		this->indices.begin();
		d != this->indices.end(); d++) {

		universal_integer lb = (*d)->getLeftBound();
		universal_integer rb = (*d)->getRightBound();
		universal_integer dir = (*d)->getDirection();

		lbounds.push_back(lb);
		i.push_back(lb);
		rbounds.push_back(rb);
		ds.push_back(dir);
	}

	while (StaticArrayIterate::checkLoop(i, rbounds, ds)) {

		for (std::list<universal_integer>::const_iterator i1 = 
			i.begin(); i1 != i.end(); i1++) {

			offsetL.push_back(new ImmediateOperand(*i1));
		}

		Register *element = this->subscribe(offsetL);
		this->iterateBody(element, i);

		StaticArrayIterate::incCounters(i, lbounds, rbounds, ds);
		offsetL.clear();
	}
}

/**
 * Check whether the index tuple counters still lies within the array.
 *
 * @param counters current index per dimension.
 * @param rbounds right bound per dimension.
 * @param directions direction per dimension (the step used by
 *        incCounters(); presumably +1 for "to" and -1 for "downto" —
 *        TODO confirm).
 * @return false as soon as one counter has passed its right bound
 *         (with respect to its direction), true otherwise.
 */
bool
StaticArrayIterate::checkLoop(
	const std::list<universal_integer> &counters,
	const std::list<universal_integer> &rbounds,
	const std::list<universal_integer> &directions
)
{
	std::list<universal_integer>::const_iterator i = counters.begin();
	std::list<universal_integer>::const_iterator j = rbounds.begin();
	std::list<universal_integer>::const_iterator d = directions.begin();

	// All three iterators must advance together. Each dimension has to be
	// compared using its *own* direction, otherwise mixed to/downto
	// arrays are checked against the wrong orientation.
	for (; i != counters.end(); ++i, ++j, ++d) {
		// Multiplying by the direction turns descending ranges into
		// ascending ones, so a single comparison covers both cases.
		if (((*i) * (*d)) > ((*j) * (*d))) {
			return false;
		}
	}

	return true;
}

void
StaticArrayIterate::incCounters(
	std::list<universal_integer> &counters,
	const std::list<universal_integer> &lbounds,
	const std::list<universal_integer> &rbounds,
	const std::list<universal_integer> &directions
)
{
	bool carry = false;
	std::list<universal_integer>::iterator i = counters.begin();
	std::list<universal_integer>::const_iterator l = lbounds.begin();
	std::list<universal_integer>::const_iterator r = rbounds.begin();
	std::list<universal_integer>::const_iterator d = directions.begin();

	(*i) += (*d);

	while (i != counters.end()) {
		if (carry) {
			(*i) += (*d);
			carry = false;
		}

		if (((*r) * (*d)) < ((*i) * (*d))) {
			carry = true;
			(*i) = (*l);
		}

		i++;
		r++;
		l++;
		d++;
	}

	// make sure, that at least one index overflows if all 
	// values have been handled, otherwise checkLoop would
	// never yield false.
	if (carry) {
		i = counters.begin();
		r = rbounds.begin();
		d = directions.begin();
		(*i) = (*r) + (*d);
	}
}

/*
 * =================== ARRAY ITERATE ===================
 */

void
ArrayIterate::initCounters(void)
{
	for (std::list<Operand *>::const_iterator i = 
		this->leftBounds.begin();
		i != this->leftBounds.end();
		i++) {

		Register *b = this->cc.createRegister(OP_TYPE_INTEGER);
		Mov *m = new Mov(*i, b);
		this->cc.addCode(m);
		this->counters.push_back(b);
	}
}

void
ArrayIterate::incCounters(void)
{
	std::list<Operand*>::const_reverse_iterator rbi = 
		this->rightBounds.rbegin();
	std::list<Operand*>::const_reverse_iterator lbi = 
		this->leftBounds.rbegin();
	std::list<Operand*>::const_reverse_iterator di = 
		this->directions.rbegin();

	for (std::list<Register *>::reverse_iterator i = 
		this->counters.rbegin(); i != this->counters.rend();
		i++) {

		// c = c + direction
		Add *a = new Add(*di, *i, *i);
		this->cc.addCode(a);

		Register *r1 = this->cc.createRegister(OP_TYPE_INTEGER);
		Register *r2 = this->cc.createRegister(OP_TYPE_INTEGER);

		// r1 = c * direction
		// r2 = rightBound * direction
		IMul *m1 = new IMul(*i, *di, r1);
		IMul *m2 = new IMul(*rbi, *di, r2);

		this->cc.addCode(m1);
		this->cc.addCode(m2);

		// if r1 <= r2 goto loop
		Jbe *jbe = new Jbe(r1, r2, this->loop);
		this->cc.addCode(jbe);

		// c = lowerBound.
		Mov *m = new Mov(*lbi, *i);
		this->cc.addCode(m);

		rbi++;
		lbi++;
		di++;
	}

	// we're done with the iteration.
}

void
ArrayIterate::iterate(void)
{
	this->initCounters();
	this->cc.addCode(this->loop);
	std::list<Operand *> idx;

	for (std::list<Register *>::iterator i = this->counters.begin();
		i != this->counters.end(); i++) {

		idx.push_back(*i);
	}

	Register *elem = this->subscribe(idx);
	this->iterateBody(elem, this->counters);

	this->incCounters();
}

}; /* namespace ast */
