package puppetmaster.strats;

import java.text.DecimalFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Collections;
import java.util.List;

import org.apache.log4j.Logger;

import puppetmaster.model.CCY;
import puppetmaster.model.Contract;
import puppetmaster.model.ContractExtent;
import puppetmaster.model.INSTR;
import puppetmaster.model.Order;
import puppetmaster.model.Position;
import puppetmaster.model.PositionChange;
import puppetmaster.model.PositionRecord;
import puppetmaster.model.mktData.Collector;
import puppetmaster.model.mktData.MemCube;
import puppetmaster.model.mktData.OHLCV;
import puppetmaster.model.mktData.Quote;
import puppetmaster.model.portfolio.PortfolioComposer;
import puppetmaster.model.portfolio.PotentialTrade;
import puppetmaster.model.portfolio.PortfolioComposer.Bias;
import puppetmaster.model.portfolio.PortfolioComposer.Style;
import puppetmaster.model.strategy.BasicStratPart;
import puppetmaster.model.strategy.Descriptor;
import puppetmaster.model.strategy.Param;
import puppetmaster.model.strategy.Strategy;
import puppetmaster.model.strategy.StrategyEvent;

/**
 * Pairs trading strategy inspired by:
 * 
 * http://www.palantirfinance.com/analysis-blog/?p=194
 * 
 * The strategy looks at the universe of equities, identifying those with
 * correlation greater than some threshold over some longish time horizon. This
 * results in a set of equity pairs. Each pair's z-score is calculated and used
 * to signal long-short trades within a dynamic portfolio.
 * 
 * @author Tito Ingargiola
 */
public class Pairs extends BasicStratPart {

	static Logger _Log = Logger.getLogger(Pairs.class);

	final static String _Desc = 
		"Pairs trading strategy inspired by: \n\n"
		+ "    http://www.palantirfinance.com/analysis-blog/?p=194 \n\n"
		+ "The strategy looks at the universe of equities, identifying "
		+ "those with correlation greater than some threshold over some "
		+ "longish time horizon.  This results in a set of equity pairs.  "
		+ "Each pair's short-term z-score is calculated and used to signal "
		+ "long-short trades within a dynamic portfolio.";

	/** Descriptor key: number of instruments (two per pair) to hold. */
	public final static String N = "sizeOfPortfolio";
	/** Descriptor key: trailing days used for the correlation matrix. */
	public final static String TrailDays = "trailDays";
	/** Descriptor key: minimum correlation for a pair to be considered. */
	public final static String MinCorrelation = "minCorrelation";
	/** Descriptor key: lower bound on |z-score| for a tradable pair. */
	public final static String MinZScore = "minZscore";
	/** Descriptor key: upper bound on |z-score| for a tradable pair. */
	public final static String MaxZScore = "maxZscore";
	/** Descriptor key: trailing days used for the z-score calculation. */
	public final static String ZScoreDays = "zScoreDays";
	/** Descriptor key: evenly weight (true) or weight by z-score (false). */
	public final static String EvenlyWeight = "evenlyWeight?";

	/**
	 * Build the parameter descriptor for this strategy, including defaults.
	 *
	 * @return a Descriptor listing every tunable parameter of the strategy
	 */
	public static Descriptor Descriptor() {
		ArrayList<Param> params = new ArrayList<Param>();

		String cd = "Unleveraged capital to apply";
		params.add(new Param(InitialBalance, cd, false, 1000000.0));

		String n = "# of instruments to hold in portfolio";
		params.add(new Param(N, n, true, 20));

		String tc = "Trailing days over which to calculate correlations";
		params.add(new Param(TrailDays, tc, true, 252));

		String mc = "the minimum correlation of pairs to consider";
		params.add(new Param(MinCorrelation, mc, true, .75));

		String zd = "trailing days over which to calc z-score";
		params.add(new Param(ZScoreDays, zd, true, 21));

		String mizs = "minimum z-score";
		params.add(new Param(MinZScore, mizs, true, 1.5));

		String mazs = "maximum z-score";
		params.add(new Param(MaxZScore, mazs, true, 3.0));

		String ew = "evenly-weight portfolio? (or weight by z-score)";
		params.add(new Param(EvenlyWeight, ew, true, true));

		Descriptor _desc = new Descriptor
			("puppetmaster.strats.Pairs", _Desc,params);

		return _desc;
	}

	public Pairs(Strategy strat, Descriptor d) {
		super(strat, d);
	}

	/**
	 * Listen for select strategy events: initialise on activation, reload
	 * parameters when the descriptor changes, and react to position changes.
	 */
	public void strategyEvent(StrategyEvent event) {
		super.strategyEvent(event);
		switch (event.type) {
		case Activated:				_init();				break;
		case DescriptorChanged:		_readDesc();			break;
		case PositionChanged:		_posnChange(event); 	break;
		}
	}

	/**
	 * Market data goes here, but we only act on the BOD ("Beginning of day")
	 * quote.  Each BOD we find qualifying pairs, select a non-overlapping
	 * subset, compute the desired portfolio and trade towards it.  Any
	 * exception is logged and the day is skipped.
	 */
	public void quote(Quote q) {
		if (q.type() != Quote.Type.BOD) return; // daily strat...

		try {
			List<_Pair> pairs = _getPairs();
			List<_Pair> sel = _selectPairs(pairs);
			// no cast needed: _trade accepts any List<? extends Position>
			List<? extends Position> dp = _desiredPfolio(pairs, sel);
			_Log.debug("PORTFOLIO: " + dp);
			_trade(dp);
		} catch (Exception e) {
			_Log.error(e.getMessage(), e);
		}
	}

	// /////// IMPL -------

	/**
	 * React to position changed events: once a position is flat and its
	 * contract is no longer part of any pair we want, stop its market data.
	 */
	void _posnChange(StrategyEvent event) {
		PositionChange pc = (PositionChange) event.obj;
		if (pc == null || pc.after == null || pc.after.contract() == null)
			return;
		Contract c = pc.after.contract();
		if (pc.after.qty() == 0 && _findContaining(_pairs, c).isEmpty()) {
			_unsubscribe(c);
		}
	}

	/**
	 * Given a desired portfolio, generate and place the orders that move the
	 * current positions to it.  All orders are market-on-close; a failure to
	 * place one order is logged and does not stop the others.
	 */
	void _trade(List<? extends Position> desiredFolio) {
		PositionRecord[] currFolio = posns();
		// Given our current portfolio (posns) and our desired state (dfolio),
		// we generate the set of orders which will transform from the former
		// to the latter
		//
		List<Order> orders = PortfolioComposer.TransformPortfolio(currFolio,
				desiredFolio, _orderF(), _strat, _strat.account());

		// place the orders for execution
		for (Order order : orders) {
			try {
				order.setType(Order.Type.MOC); // we trade at the close only
				_subscribe(order.contract());
				_execP().placeOrder(order);
			} catch (Exception e) {
				_Log.error(e.getMessage(), e);
			}
		}
	}

	/**
	 * Given the list of all acceptable pairs as well those selected, determine
	 * and return what our new desired portfolio should be.
	 *
	 * Side effect: updates {@code _pairs}. Held pairs that are no longer
	 * acceptable are dropped. Held pairs that still qualify are replaced by
	 * today's instance so that direction and weighting use the current
	 * z-score. New pairs from {@code sel} are added, up to {@code _n / 2}
	 * pairs in total, provided neither contract is already held.
	 *
	 * @param pairs all pairs meeting today's correlation and z-score criteria
	 * @param sel   non-overlapping subset of {@code pairs}, in priority order
	 * @return the desired positions (two per held pair)
	 */
	List<? extends Position> _desiredPfolio(List<_Pair> pairs, List<_Pair> sel) {

		// first check if we want to remove any existing positions as they
		// no longer meet the requirements for being in a pair (ie, they're
		// insufficiently correlated or their z-scores are out of range);
		// survivors are refreshed with today's z-score
		//
		ArrayList<_Pair> kept = new ArrayList<_Pair>();
		for (_Pair held : _pairs) {
			_Pair current = _find(pairs, held);
			if (current != null) { kept.add(current); }
		}
		_Log.debug("Removing #" + (_pairs.size() - kept.size()) + "/"
				+ _pairs.size());
		_pairs.clear();
		_pairs.addAll(kept);

		// calculate how many pairs to add (each pair is two instruments) and
		// then select them from the sorted collection of selectable pairs
		int pairsToAdd = (_n / 2) - _pairs.size();
		_Log.debug("We have #" + _pairs.size() + " and we will add "
				+ pairsToAdd);
		List<_Pair> toadd = new ArrayList<_Pair>();
		for (_Pair p : sel) {
			// check the limit first so a full portfolio adds nothing
			if (toadd.size() >= pairsToAdd) break;
			// never hold the same name in two pairs; sel is already
			// non-overlapping, so only the held pairs need checking
			if (_findContaining(_pairs, p.first).isEmpty()
					&& _findContaining(_pairs, p.second).isEmpty()) {
				toadd.add(p);
			}
		}
		_pairs.addAll(toadd);

		List<PotentialTrade> pts = new ArrayList<PotentialTrade>();
		for (_Pair p : _pairs) {
			// z > 0: sell first / buy second, otherwise the reverse; the
			// first leg carries -|z| and the second leg +|z|
			Order.Action firstAct = (p.zScore > 0) ? Order.Action.SELL
					: Order.Action.BUY;
			double firstZ = (p.zScore > 0) ? -p.zScore : p.zScore;
			pts.add(new PotentialTrade
				(p.first, _lastPx(p.first), 0, firstZ, 0, firstAct));
			pts.add(new PotentialTrade
				(p.second, _lastPx(p.second), 0, -firstZ,0, firstAct.invert()));
		}
		// Now that we have a raw collection of "pairs" trades, let's
		// assemble them into a long-short portfolio (evenly or z-weighted).
		// NOTE(review): the composer is given _alloc / 1000 -- presumably it
		// expects capital in thousands; confirm against PortfolioComposer.
		pts = _evenlyWeight 
			? PortfolioComposer.CreateEvenlyWeightedPortfolio
				(pts, Bias.LongShort, Style.Trending, _n, _alloc / 1000)
			: PortfolioComposer.CreateWeightedPortfolio
				(pts, Bias.LongShort, Style.Trending, _n, _alloc / 1000);

		_Log.debug("From #" + _pairs.size() + " pairs, we've created "
				+ pts.size() + " desired posns");

		return pts;
	}

	/**
	 * Utility to get an approximate price for position sizing: today's open
	 * if available, otherwise the collector's most recent close, otherwise 0.
	 */
	double _lastPx(Contract c) {
		double px = 0;
		OHLCV o = _qube.getTime(c, _now());
		// we get today's open - note that this is only OK because
		// we're trading on the close
		//
		px = (o == null) ? 0 : o.open();
		if (px == 0) { // or yesterday's close
			Collector col = _collector(c);
			px = (col == null) ? 0 : col.ohlcv(0).close(); 
		}
		return px;
	}

	/**
	 * Get the _trailDays correlation matrix across all contracts and, for each
	 * pair with corr >= _minCorr, check |zscore|.  If it is within
	 * [_minZscore, _maxZscore] the pair is added to the returned list, which
	 * is sorted by ascending |z-score|.
	 */
	ArrayList<_Pair> _getPairs() throws Exception {
		ArrayList<_Pair> pairs = new ArrayList<_Pair>();
		_updateCorrs();
		int ccount = 0; // pairs meeting the correlation requirement
		int cpairs = 0; // pairs examined
		for (int i = 0; i < _corrs.length; i++) {
			for (int j = i + 1; j < _corrs.length; j++) {
				cpairs++;
				if (_corrs[i][j] >= _minCorr) { // if correlation
					ccount++;
					double zscore = _qube.zscore(i, j, _now(), _zScoreDays);
					double absZ = Math.abs(zscore);
					if (absZ >= _minZscore && absZ <= _maxZscore) {
						// if zscore create a pair and add to list
						pairs.add(new _Pair(_qube.atIndex(i), _qube.atIndex(j),
							_corrs[i][j], zscore));
					}
				}
			}
		}
		// sort once, by natural ordering; the stable sort gives the same
		// order as re-sorting after every row
		Collections.sort(pairs);
		if (_Log.isDebugEnabled()) {
			SimpleDateFormat sdf = new SimpleDateFormat("yy/MM/dd");
			_Log.debug(sdf.format(_now()) + ": Out of " + _qube.data.length
					+ " contracts and " + cpairs + " pairs, " + ccount
					+ " met correlation " + "req and " + pairs.size()
					+ " met z-score req.");
		}
		return pairs;
	}

	/**
	 * We update the correlation matrix once a month.  The key combines year
	 * and month so that a gap of exactly twelve months still forces a refresh.
	 */
	void _updateCorrs() {
		int month = _gcNow().get(Calendar.YEAR) * 12
				+ _gcNow().get(Calendar.MONTH);
		if (_corrs == null || _month != month) {
			_corrs = _qube.correlation(_now(), _trailDays, true);
			_month = month;
		}
	}

	/**
	 * Selection is based on z-score: pairs are taken greedily in the order
	 * given, and a pair is skipped if either of its names is already used.
	 * No name may appear in more than one pair, which leaves plenty of room
	 * for optimization.
	 */
	List<_Pair> _selectPairs(List<_Pair> pairs) {
		ArrayList<Contract> cons = new ArrayList<Contract>();
		ArrayList<_Pair> ps = new ArrayList<_Pair>();
		for (_Pair p : pairs) {
			if (!cons.contains(p.first) && !cons.contains(p.second)) {
				cons.add(p.first);
				cons.add(p.second);
				ps.add(p);
			}
		}
		Collections.sort(ps);
		_Log.debug("We've (arbitrarily) selected " + ps.size() + "/"
				+ pairs.size() + " pairs");

		return ps;
	}

	/** util: true if a pair with the same contracts as sought is in pairs */
	boolean _contains(List<_Pair> pairs, _Pair sought) {
		return _find(pairs, sought) != null;
	}

	/**
	 * util: return the element of pairs with the same contracts as sought
	 * (ignoring correlation and z-score), or null if none.
	 */
	_Pair _find(List<_Pair> pairs, _Pair sought) {
		for (_Pair p : pairs) {
			if (p.equals(sought))
				return p;
		}
		return null;
	}

	/**
	 * Return all pairs in the supplied list that contain the specified
	 * contract, or an empty list if none are found.  Contracts are compared
	 * by reference.
	 * NOTE(review): _selectPairs compares contracts with equals(); confirm
	 * that Contract instances are canonical so the two agree.
	 */
	List<_Pair> _findContaining(List<_Pair> pairs, Contract c) {
		ArrayList<_Pair> found = new ArrayList<_Pair>();
		for (_Pair p : pairs) {
			if (p.first == c || p.second == c)
				found.add(p);
		}
		return found;
	}

	/**
	 * Initialize the strategy.  Per-run state is reset before anything that
	 * can throw, so a failed init (or a clone made via super.clone()) never
	 * keeps stale or shared pair/correlation state.
	 */
	void _init() {
		_corrs = null;
		_month = -1;
		_pairs = new ArrayList<_Pair>();
		_readDesc();
		try {
			_proxy = ContractExtent.Default().getContract(INSTR.STK, "SMART",
					"SPY", CCY.USD);
			_subscribe(_proxy);
			_qube = MemCube.DailyDefault();
		} catch (Exception e) {
			_Log.error(e.getMessage(), e);
		}
	}

	/** Read the metadata descriptor and store its values in fields. */
	void _readDesc() {
		_alloc = ((Number) _desc.valueOf(InitialBalance)).doubleValue();
		_n = ((Number) _desc.valueOf(N)).intValue();
		_trailDays = ((Number) _desc.valueOf(TrailDays)).intValue();
		_minCorr = ((Number) _desc.valueOf(MinCorrelation)).doubleValue();
		_zScoreDays = ((Number) _desc.valueOf(ZScoreDays)).intValue();
		_minZscore = ((Number) _desc.valueOf(MinZScore)).doubleValue();
		_maxZscore = ((Number) _desc.valueOf(MaxZScore)).doubleValue();
		_evenlyWeight = (Boolean) _desc.valueOf(EvenlyWeight);
	}

	/** Necessary to play nice in stratbox; the clone gets fresh state. */
	public Pairs clone() throws CloneNotSupportedException {
		Pairs clone = (Pairs) super.clone();
		clone._init();
		return clone;
	}

	double _alloc;
	int _n;
	int _trailDays;
	double _minCorr;
	int _zScoreDays;
	double _minZscore;
	double _maxZscore;
	Contract _proxy;
	MemCube<OHLCV> _qube;
	double[][] _corrs;
	int _month; // year * 12 + month of the last correlation update, or -1
	boolean _evenlyWeight;
	ArrayList<_Pair> _pairs = new ArrayList<_Pair>(); // pairs we should have positions in

	/**
	 * Struct to hold a pair of contracts, their correlation and z-score.
	 * Equality considers only the contracts (by reference).  Natural
	 * ordering is by |z-score|, so it is not consistent with equals.
	 */
	static class _Pair implements Comparable<_Pair> {
		final Contract first;
		final Contract second;
		final double correlation;
		/** if zscore > 0 -> sell first, buy second else vice-versa. */
		final double zScore;

		/** ctor */
		_Pair(Contract one, Contract two, double corr, double z) {
			first = one;
			second = two;
			correlation = corr;
			zScore = z;
		}

		@Override
		public String toString() {
			DecimalFormat df = new DecimalFormat("0.0#");
			return first + "-" + second + "(" + df.format(correlation) + "/"
					+ df.format(zScore) + ")";
		}

		/** compare by |z-score|, ascending */
		public int compareTo(_Pair o) {
			return Double.compare(Math.abs(zScore), Math.abs(o.zScore));
		}

		/** equal iff both legs are the same contract instances, in order */
		@Override
		public boolean equals(Object o) {
			if (this == o)
				return true;
			if (!(o instanceof _Pair))
				return false;
			_Pair other = (_Pair) o;
			return first == other.first && second == other.second;
		}

		/** consistent with equals, which compares contracts by reference */
		@Override
		public int hashCode() {
			return 31 * System.identityHashCode(first)
					+ System.identityHashCode(second);
		}

	} // _Pair
}