/* This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
 *
 * Copyright (C) 2014 Nanjing University, Nanjing, China
 */
 
package napping.policy;

import napping.core.Action;
import napping.core.Trajectory;
import napping.core.State;
import napping.core.Task;
import napping.core.Tuple;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import napping.core.Experiment;
import napping.core.Policy;
import napping.utills.Stats;
import weka.classifiers.Classifier;
import weka.classifiers.meta.AdditiveRegression;
import weka.classifiers.meta.Bagging;
import weka.classifiers.trees.REPTree;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Implements the \pi_v policy from the paper:
 * Qing Da, Yang Yu, and Zhi-Hua Zhou. Napping for functional representation 
 * of policy. In: Proceedings of the 2014 International Conference on Autonomous
 * Agents and Multi-Agent Systems (AAMAS'14), Paris, France, 2014.
 *
 * @author Qing Da <daq@lamda.nju.edu.cn>
 */
public class NPPGWithNappingV extends Policy {

    // Random policy used before the first update and for epsilon-greedy exploration.
    private RandomPolicy rp;
    // Weight of each potential function in the additive ensemble (same indexing as potentialFunctions).
    private List<Double> alphas;
    // Regressors whose alpha-weighted sum scores a state-action feature vector.
    private List<Classifier> potentialFunctions;
    // Prototype learner; a fresh copy is trained on every update.
    private Classifier base;
    // Numerator of the step size schedule stepsize / sqrt(t) used when a new potential function is added.
    private double stepsize;
    // Dataset header; update() also appends every training instance here so napping() can sample from it.
    public Instances dataHead;
    // Fraction of maxStep taken from the tail of each trajectory as training samples.
    private double stationaryRate;
    // Probability of delegating to the random policy in makeDecisionS.
    private double epsionGreedy;
    // Multiplicative factor applied to epsionGreedy after every update.
    private double epsionGreedyDecay;
    // napping() is triggered when (iter + 1) is a multiple of this value in trainAndTest.
    private int nappingInterval;
    // Upper bound on the number of instances sampled from dataHead when napping.
    private int numTrain;

    /**
     * Creates an untrained policy.
     *
     * @param nappingInterval number of training iterations between two napping steps
     * @param numTrain maximum number of stored instances used to fit the napped model
     * @param rand random source; it is kept as this policy's generator and is also used to
     *             draw the seeds of the random policy, the tree learner and the bagging
     *             ensemble (in that order)
     */
    public NPPGWithNappingV(int nappingInterval, int numTrain, Random rand) {
        this.nappingInterval = nappingInterval;
        this.numTrain = numTrain;

        // First seed drawn from rand: the fallback random policy.
        this.rp = new RandomPolicy(new Random(rand.nextInt()));

        numIteration = 0;
        random = rand;
        this.alphas = new ArrayList<Double>();
        this.potentialFunctions = new ArrayList<Classifier>();

        // Second seed: the tree that serves as the bagged weak learner.
        REPTree weakLearner = new REPTree();
        weakLearner.setMaxDepth(100);
        weakLearner.setSeed(rand.nextInt());

        // Third seed: the bagging ensemble that is copied for every update.
        Bagging ensemble = new Bagging();
        ensemble.setClassifier(weakLearner);
        ensemble.setNumIterations(10);
        ensemble.setSeed(rand.nextInt());
        this.base = ensemble;

        // Default hyper-parameters; see the corresponding setters.
        this.stepsize = 1;
        this.stationaryRate = 0.8;
        this.epsionGreedy = 0.1;
        this.epsionGreedyDecay = 1;
    }

    /**
     * Not implemented by this policy; training in this class is driven by
     * {@code trainAndTest}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public long[] train(Task task, int iteration, int trialsPerIter, State initialState, int maxStep, boolean isPara, Random random) {
        final String message = "Not supported yet.";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Unit of work that runs one task with a given policy through
     * {@code Experiment.runTask} and keeps the resulting trajectory. It can be
     * handed to an executor or invoked directly via {@link #run()}.
     */
    class ParallelExecute implements Runnable {

        private final Task task;
        private final Policy policy;
        private final State initialState;
        private final int maxStep;
        // Private generator so that each run is reproducible from its seed.
        private final Random random;
        // Stays null until run() has completed.
        private Trajectory rollout;

        /**
         * @param task task to execute
         * @param policy policy that picks the actions
         * @param initialState state the run starts from
         * @param maxStep step limit forwarded to Experiment.runTask
         * @param seed seed of the generator used for this run only
         */
        public ParallelExecute(Task task, Policy policy, State initialState, int maxStep, int seed) {
            this.random = new Random(seed);
            this.maxStep = maxStep;
            this.initialState = initialState;
            this.policy = policy;
            this.task = task;
        }

        /** Executes the task and stores the collected samples as a trajectory. */
        @Override
        public void run() {
            rollout = new Trajectory(task,
                    Experiment.runTask(task, initialState, policy, maxStep, true, random));
        }

        /**
         * @return the trajectory produced by {@link #run()}, or null if it has not finished
         */
        public Trajectory getRollout() {
            return rollout;
        }
    }

    /**
     * Runs the current policy once on every task and summarises the rollouts.
     *
     * @param tasks tasks to evaluate; each one starts from its own initial state
     * @param isPara whether the rollouts are executed on a thread pool
     * @param maxStep step limit of each rollout
     * @param iter iteration number stored in the first slot of the result
     * @return an array of length 7: [iter, reward statistic 0, reward statistic 1,
     *         step statistic 0, step statistic 1, 0, 0], where the statistics are the two
     *         values returned by {@code Stats.mean_std} (presumably mean and standard
     *         deviation) over the per-step average reward and the length of the rollouts;
     *         the last two slots are left for the caller's timing figures
     * @throws IllegalStateException if the calling thread is interrupted while waiting for
     *         the rollouts, or if a rollout did not produce a trajectory
     */
    private double[] test(List<Task> tasks, boolean isPara, int maxStep, int iter) {
        final int numTasks = tasks.size();
        double[] avaRewards = new double[numTasks];
        double[] avaSteps = new double[numTasks];

        List<ParallelExecute> list = new ArrayList<ParallelExecute>(numTasks);
        // The pool is only needed (and only created) for parallel evaluation.
        ExecutorService exec = isPara ? Executors.newFixedThreadPool(23) : null;
        try {
            for (int i = 0; i < numTasks; i++) {
                ParallelExecute run = new ParallelExecute(tasks.get(i), this, tasks.get(i).getInitialState(), maxStep, random.nextInt());
                list.add(run);
                if (isPara) {
                    exec.execute(run);
                } else {
                    run.run();
                }
            }
            if (isPara) {
                exec.shutdown();
                while (!exec.awaitTermination(10, TimeUnit.SECONDS)) {
                    // keep waiting until every rollout has finished
                }
            }
        } catch (InterruptedException ex) {
            // Restore the flag and stop: unfinished rollouts have no trajectory to report.
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while waiting for test rollouts to finish", ex);
        } finally {
            // Do not leave worker threads behind on any exceptional path.
            if (exec != null && !exec.isTerminated()) {
                exec.shutdownNow();
            }
        }

        for (int i = 0; i < list.size(); i++) {
            Trajectory rollout = list.get(i).getRollout();
            if (rollout == null) {
                // The worker died before storing a trajectory (its exception went to the thread's handler).
                throw new IllegalStateException("Test rollout " + i + " did not produce a trajectory");
            }
            double rewardSum = 0;
            for (Tuple t : rollout.samples) {
                rewardSum += t.reward;
            }
            avaRewards[i] = rewardSum / rollout.samples.size();
            avaSteps[i] = rollout.samples.size();
        }

        double[] mean_std_reward = Stats.mean_std(avaRewards);
        double[] mean_std_step = Stats.mean_std(avaSteps);
        double[] records = new double[7];
        records[0] = iter;
        records[1] = mean_std_reward[0];
        records[2] = mean_std_reward[1];
        records[3] = mean_std_step[0];
        records[4] = mean_std_step[1];

        return records;
    }

    /**
     * Alternates between collecting rollouts on the training task, updating the policy
     * (with a napping step at the configured interval) and evaluating on the test tasks.
     *
     * @param tasks tasks used for evaluation after every iteration
     * @param task task used to collect the training rollouts
     * @param iteration number of training iterations
     * @param trialsPerIter number of rollouts collected per iteration
     * @param initialState not used; every rollout starts from {@code task.getInitialState()}
     * @param maxStep step limit of each rollout
     * @param isPara whether rollouts are executed on a thread pool
     * @param random source of the per-rollout seeds; if null, this policy's own generator
     *               is used
     * @return one row per evaluation (row 0 is taken before any training); each row is
     *         [iteration, reward statistic 0, reward statistic 1, step statistic 0,
     *         step statistic 1, training time, testing time], with times derived from
     *         wall-clock milliseconds
     * @throws IllegalStateException if the calling thread is interrupted while waiting for
     *         the rollouts, or if a rollout did not produce a trajectory
     */
    public double[][] trainAndTest(List<Task> tasks, Task task, int iteration,
            int trialsPerIter, State initialState, int maxStep, boolean isPara, Random random) {
        // Same null fallback as makeDecisionD / makeDecisionS.
        final Random seedSource = (random == null) ? this.random : random;
        // Size of the worker pool; also the factor applied to the timing figures below.
        final int poolSize = 23;

        List<double[]> resultList = new ArrayList<double[]>();
        long check01 = System.currentTimeMillis();
        double[] records0 = test(tasks, isPara, maxStep, 0);
        long check02 = System.currentTimeMillis();
        records0[5] = 0;
        records0[6] = check02 - check01;
        resultList.add(records0);

        for (int iter = 0; iter < iteration; iter++) {
            System.out.println("iter=" + iter);
            System.out.println("collecting samples...");

            long check1 = System.currentTimeMillis();
            ParallelExecute[] runs = new ParallelExecute[trialsPerIter];
            // The pool is only needed (and only created) for parallel collection.
            ExecutorService exec = isPara ? Executors.newFixedThreadPool(poolSize) : null;
            try {
                for (int i = 0; i < trialsPerIter; i++) {
                    runs[i] = new ParallelExecute(task, this, task.getInitialState(), maxStep, seedSource.nextInt());
                    if (isPara) {
                        exec.execute(runs[i]);
                    } else {
                        runs[i].run();
                    }
                }
                if (isPara) {
                    exec.shutdown();
                    while (!exec.awaitTermination(10, TimeUnit.SECONDS)) {
                        // keep waiting until every rollout has finished
                    }
                }
            } catch (InterruptedException ex) {
                // Restore the flag and stop: unfinished rollouts cannot be used for the update.
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while collecting training rollouts", ex);
            } finally {
                // Do not leave worker threads behind on any exceptional path.
                if (exec != null && !exec.isTerminated()) {
                    exec.shutdownNow();
                }
            }
            long check2 = System.currentTimeMillis();
            System.out.println("collecting samples is done! Updating policy...");

            List<Trajectory> rollouts = new ArrayList<Trajectory>(trialsPerIter);
            for (int i = 0; i < trialsPerIter; i++) {
                Trajectory rollout = runs[i].getRollout();
                if (rollout == null) {
                    // The worker died before storing a trajectory.
                    throw new IllegalStateException("Training rollout " + i + " did not produce a trajectory");
                }
                rollouts.add(rollout);
            }

            update(rollouts, maxStep);

            // A non-positive interval disables napping instead of dividing by zero.
            if (nappingInterval > 0 && iter > 0 && (iter + 1) % nappingInterval == 0) {
                napping();
            }

            long check3 = System.currentTimeMillis();
            double[] records = test(tasks, isPara, maxStep, iter + 1);
            System.out.println("averaged reward is "+records[1]+"\n");
            long check4 = System.currentTimeMillis();

            // NOTE(review): the pool-size factor is applied even when isPara is false, and the
            // initial record above is not scaled - confirm that this is the intended accounting.
            records[5] = poolSize * (check2 - check1) + (check3 - check2);
            records[6] = poolSize * (check4 - check3);
            resultList.add(records);
        }

        return resultList.toArray(new double[resultList.size()][]);
    }

    /**
     * Napping step: compresses the current additive ensemble into a single regressor.
     * <p>
     * Up to {@code numTrain} stored instances are drawn (after shuffling {@code dataHead}),
     * relabelled with the value of the current ensemble, and used to fit an
     * {@code AdditiveRegression} of {@code REPTree}s. On success the ensemble is replaced
     * by that single model with weight 1. If there is nothing to sample from, or the model
     * cannot be built, the current ensemble is left untouched.
     */
    private void napping() {
        if (dataHead == null) {
            return;
        }
        int NS = Math.min(numTrain, dataHead.numInstances());
        if (NS <= 0) {
            // Fitting on an empty sample would replace the policy by a meaningless model.
            return;
        }

        dataHead.randomize(random);
        Instances sampledData = new Instances(dataHead, NS);
        for (int i = 0; i < NS; i++) {
            Instance ins = dataHead.instance(i);
            // The regression target is the output of the ensemble that is being compressed.
            ins.setClassValue(potentialFunctionValue(ins, potentialFunctions.size()));
            sampledData.add(ins);
        }

        REPTree reptree = new REPTree();
        reptree.setMaxDepth(100);

        AdditiveRegression ar = new AdditiveRegression();
        ar.setClassifier(reptree);
        ar.setNumIterations(10);

        try {
            ar.buildClassifier(sampledData);
        } catch (Exception ex) {
            // Keep the working ensemble rather than swapping in a model that was never built.
            ex.printStackTrace();
            return;
        }

        alphas.clear();
        potentialFunctions.clear();
        alphas.add(1.0);
        potentialFunctions.add(ar);

        // The former trailing loop that re-scored every sampled instance for three hard-coded
        // actions has been dropped: it only modified the local sample after the model had
        // already been built, so it had no effect on the policy.
    }

    /**
     * Adds one potential function to the ensemble, fitted on the given rollouts.
     * <p>
     * For each trajectory only the last {@code stationaryRate * maxStep} samples are used.
     * Walking backwards from the end, the reward-to-go {@code Q} is accumulated (no
     * discounting), and each sample becomes a regression instance with features
     * {@code task.getSAFeature(s, a)} and target {@code Q * pi * (1 - pi)}, where
     * {@code pi} is the {@code pi_s_a} value recorded with the action. The instances are
     * also appended to {@code dataHead} for later use by napping.
     * <p>
     * If the rollouts yield no samples the policy is left unchanged. A learner that cannot
     * be built is reported and not added to the ensemble.
     *
     * @param rollouts trajectories collected with the current policy
     * @param maxStep step limit that was used to collect the trajectories
     */
    @Override
    protected void update(List<Trajectory> rollouts,
            int maxStep) {
        double gamma = 1;
        List<double[]> features = new ArrayList<double[]>();
        List<Double> weights = new ArrayList<Double>();
        List<Double> QHat = new ArrayList<Double>();

        int LAST = (int) (stationaryRate * maxStep);

        for (Trajectory rollout : rollouts) {
            Task task = rollout.task;
            List<Tuple> samples = rollout.samples;

            double E = 0;
            int firstStep = Math.max(0, samples.size() - LAST);
            for (int step = samples.size() - 1; step >= firstStep; step--) {
                Tuple sample = samples.get(step);
                E = gamma * E + (sample.reward);

                features.add(task.getSAFeature(sample.s, sample.a));
                weights.add(sample.a.pi_s_a);
                QHat.add(E);
            }
        }

        if (features.isEmpty()) {
            // No rollouts or no usable samples: nothing to fit, keep the policy as it is.
            return;
        }

        if (null == dataHead) {
            int na = rollouts.get(0).task.actionSet.length;
            dataHead = constructDataHead(features.get(0).length, na);
        }
        Instances data = new Instances(dataHead, features.size());
        for (int i = 0; i < features.size(); i++) {
            double pi = weights.get(i);
            double Q = QHat.get(i);
            Instance ins = contructInstance(features.get(i), Q * pi * (1 - pi), 1.0);
            data.add(ins);
            // add data for sampling
            dataHead.add(ins);
        }

        Classifier c = getBaseLearner();
        boolean built = false;
        if (c != null) {
            try {
                c.buildClassifier(data);
                built = true;
            } catch (Exception ex) {
                ex.printStackTrace();
            }
        }

        // Only a successfully trained learner joins the ensemble; an unbuilt one would make
        // every later prediction fail.
        if (built) {
            int t = alphas.size() + 1;
            alphas.add(stepsize / Math.sqrt(t));
            potentialFunctions.add(c);
        }

        epsionGreedy = epsionGreedy * epsionGreedyDecay;
        numIteration++;
    }

    /**
     * Creates a fresh, untrained learner by copying the base classifier.
     *
     * @return a copy of the base classifier, or null if the copy failed (the failure is
     *         printed to standard error)
     */
    public Classifier getBaseLearner() {
        try {
            return Classifier.makeCopy(base);
        } catch (Exception ex) {
            ex.printStackTrace();
            return null;
        }
    }

    /**
     * Picks the action with the highest utility, breaking ties uniformly at random.
     * Before the first update the decision is delegated to the random policy.
     *
     * @param s current state
     * @param t task that defines the action set and the features
     * @param outRand random source for tie-breaking; if null, this policy's own generator
     *                is used
     * @return the chosen action, constructed together with its utility
     */
    @Override
    public Action makeDecisionD(State s, Task t, Random outRand) {
        Random thisRand = outRand == null ? random : outRand;
        if (numIteration == 0) {
            return rp.makeDecisionS(s, t, thisRand);
        }

        int A = t.actionSet.length;
        double[] utilities = getUtility(s, t);

        int bestAction = 0, num_ties = 1;
        for (int a = 1; a < A; a++) {
            double value = utilities[a];
            if (value > utilities[bestAction] + Double.MIN_VALUE) {
                // Strictly better: earlier ties belong to a smaller value, so the tie
                // count starts over. Without this reset the later maxima were picked
                // with the wrong probability.
                bestAction = a;
                num_ties = 1;
            } else if (value >= utilities[bestAction]) {
                // Reservoir sampling: the k-th tied action replaces the current pick with
                // probability 1/k, giving every tied action the same chance.
                num_ties++;
                if (0 == thisRand.nextInt(num_ties)) {
                    bestAction = a;
                }
            }
        }

        return new Action(bestAction, utilities[bestAction]);
    }

    /**
     * Epsilon-greedy decision: with probability {@code epsionGreedy} (and always before the
     * first update) the random policy chooses; otherwise the choice is made by
     * {@link #makeDecisionD}.
     *
     * @param s current state
     * @param t task that defines the action set and the features
     * @param outRand random source; if null, this policy's own generator is used
     * @return the chosen action
     */
    @Override
    public Action makeDecisionS(State s, Task t, Random outRand) {
        Random thisRand = (outRand != null) ? outRand : random;

        // The random draw is only consumed once the policy has been updated at least once.
        boolean explore = numIteration == 0 || thisRand.nextDouble() < epsionGreedy;
        if (explore) {
            return rp.makeDecisionS(s, t, thisRand);
        }
        return makeDecisionD(s, t, outRand);
    }

    /**
     * Computes a softmax distribution over the actions of a task in the given state.
     * The score of an action is the alpha-weighted sum of all potential functions
     * evaluated on the state-action features; a potential function that fails to
     * predict contributes nothing and its exception is printed.
     *
     * @param s state to evaluate
     * @param t task that defines the action set and the features
     * @return one value per action; the values are the softmax of the scores
     */
    public double[] getUtility(State s, Task t) {
        int A = t.actionSet.length;
        double[] utilities = new double[A];
        double maxUtility = Double.NEGATIVE_INFINITY;
        for (int a = 0; a < A; a++) {
            double[] stateActionFeature = t.getSAFeature(s, new Action(a));
            Instance ins = contructInstance(stateActionFeature, 0, 1.0);
            if (null == dataHead) {
                dataHead = constructDataHead(stateActionFeature.length, A);
            }
            ins.setDataset(dataHead);

            double score = 0;
            for (int j = 0; j < potentialFunctions.size(); j++) {
                try {
                    score += alphas.get(j) * potentialFunctions.get(j).classifyInstance(ins);
                } catch (Exception ex) {
                    ex.printStackTrace();
                }
            }
            utilities[a] = score;
            if (score > maxUtility) {
                maxUtility = score;
            }
        }

        // Subtract the largest score before exponentiating. The softmax is unchanged by a
        // common shift, but without it large scores overflow to infinity and the
        // normalisation yields NaN.
        double norm = 0;
        for (int a = 0; a < A; a++) {
            utilities[a] = Math.exp(utilities[a] - maxUtility);
            norm += utilities[a];
        }

        for (int a = 0; a < A; a++) {
            utilities[a] /= norm;
        }

        return utilities;
    }

    /**
     * Builds a Weka instance from a feature vector and a label.
     *
     * @param stateActionTaskFeature feature values; they are copied, not referenced
     * @param label value stored after the last feature
     * @param weight weight of the instance
     * @return an instance with {@code stateActionTaskFeature.length + 1} values that is not
     *         yet attached to a dataset
     */
    public static Instance contructInstance(double[] stateActionTaskFeature, double label, double weight) {
        final int featureCount = stateActionTaskFeature.length;

        // Layout: [feature_0 ... feature_{n-1}, label]
        double[] row = new double[featureCount + 1];
        System.arraycopy(stateActionTaskFeature, 0, row, 0, featureCount);
        row[featureCount] = label;

        return new Instance(weight, row);
    }

    /**
     * Builds an empty dataset that describes the instances created by
     * {@code contructInstance}: {@code D - 1} numeric attributes, one nominal
     * {@code action} attribute with the values "0" .. "na - 1", and a numeric
     * {@code class} attribute that is set as the class.
     *
     * @param D length of the state-action feature vector (the action is its last entry)
     * @param na number of actions
     * @return an empty dataset named "dataHead" with {@code D + 1} attributes
     */
    public static Instances constructDataHead(int D, int na) {
        final int actionIndex = D - 1;

        // Numeric attributes for everything that precedes the action.
        FastVector attributes = new FastVector();
        for (int idx = 0; idx < actionIndex; idx++) {
            attributes.addElement(new Attribute("att_" + idx, idx));
        }

        // The action is nominal; its values are the action indices as strings.
        FastVector actionLabels = new FastVector(na);
        for (int label = 0; label < na; label++) {
            actionLabels.addElement(Integer.toString(label));
        }
        attributes.addElement(new Attribute("action", actionLabels, actionIndex));

        // The regression target comes last.
        attributes.addElement(new Attribute("class", D));

        Instances header = new Instances("dataHead", attributes, 0);
        header.setClassIndex(header.numAttributes() - 1);
        return header;
    }

    /**
     * Evaluates the alpha-weighted sum of the first {@code numIteration} potential
     * functions on an instance. A potential function that fails to predict contributes
     * nothing and its exception is printed.
     *
     * @param ins instance to evaluate
     * @param numIteration number of leading potential functions to include; values larger
     *                     than the current ensemble size are treated as the ensemble size,
     *                     values below 1 give 0
     * @return the weighted sum of the predictions
     */
    public double potentialFunctionValue(Instance ins, int numIteration) {
        // The ensemble can be shorter than the requested count (for example after napping);
        // clamping avoids an index error, and its stack trace, for every missing entry.
        int terms = Math.min(numIteration, Math.min(alphas.size(), potentialFunctions.size()));

        double value = 0;
        for (int j = 0; j < terms; j++) {
            try {
                value += alphas.get(j) * potentialFunctions.get(j).classifyInstance(ins);
            } catch (Exception ex) {
                ex.printStackTrace();
            }
        }
        return value;
    }

    /**
     * @return the base step size; update() weights the t-th added potential function
     *         by stepsize / sqrt(t)
     */
    public double getStepsize() {
        return stepsize;
    }

    /**
     * @param stepsize base step size used for potential functions added by later updates
     */
    public void setStepsize(double stepsize) {
        this.stepsize = stepsize;
    }

    /**
     * Sets the iteration counter, capped at the number of stored potential functions.
     * A counter of 0 makes the decision methods fall back to the random policy.
     *
     * @param numIteration requested counter value (no lower bound is enforced)
     */
    @Override
    public void setNumIteration(int numIteration) {
        this.numIteration = Math.min(potentialFunctions.size(), numIteration);
    }

    /**
     * @param epsionGreedy probability with which makeDecisionS delegates to the random policy
     */
    public void setEpsionGreedy(double epsionGreedy) {
        this.epsionGreedy = epsionGreedy;
    }

    /**
     * @param epsionGreedyDecay factor by which the exploration probability is multiplied
     *                          after every update
     */
    public void setEpsionGreedyDecay(double epsionGreedyDecay) {
        this.epsionGreedyDecay = epsionGreedyDecay;
    }

    /**
     * @param stationaryRate fraction of maxStep that determines how many samples from the
     *                       end of each trajectory are used by update()
     */
    public void setStationaryRate(double stationaryRate) {
        this.stationaryRate = stationaryRate;
    }
}
