#include "api.hh"

#include <stdio.h>

#include <algorithm>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <vector>
using namespace std;

#define Map vector<vector<etat_case>> 


// Enhanced position: a cell (colonne, ligne, niveau) together with a reference
// point `mama_pos`. Instances order by squared Euclidean distance (in the
// colonne/ligne plane) from the cell to its own reference point.
struct EPosition{  // Enhanced position
    int colonne;
    int ligne;
    int niveau;
    position mama_pos;

    // Squared distance from this cell to mama_pos; no sqrt needed since only
    // the ordering matters.
    int dist_sq() const {
        const int dx = colonne - mama_pos.colonne;
        const int dy = ligne - mama_pos.ligne;
        return dx * dx + dy * dy;
    }

    // const-qualified so it works with const operands and standard comparators.
    bool operator<(const EPosition& other) const {
        return dist_sq() < other.dist_sq();
    }
};

// A candidate path (sequence of cells, in visiting order) and its cost.
// Ordered by cost, lowest first.
struct as_path{
    vector<position> cur;  // cells of the path, start first
    int cost;              // heuristic cost of the path (lower is better)

    // const-qualified with a const reference so const paths can be compared.
    bool operator<(const as_path& other) const {
        return cost < other.cost;
    }
};






#pragma region Constantes
// Phase identifiers (partie_init stores one of them in current_phase).
constexpr int CONQUETE_NIDS = 0;
constexpr int RECHERCHE_PAIN = 1;
// Unit steps (delta colonne, delta ligne) to the four orthogonal neighbours.
const vector<pair<int, int>> DIR{{1, 0}, {0, 1}, {-1, 0}, {0, -1}};
#pragma endregion

#pragma region Variables
// Our own player id (assigned from moi() in partie_init).
int player_id;
// Map map(HAUTEUR, vector<etat_case>(LARGEUR));  // do we really need this?
// Every NID cell found when scanning the map.
vector<position> nids;
// Not referenced in the code visible here.
vector<position> nids_obtenus;
// One of CONQUETE_NIDS / RECHERCHE_PAIN.
int current_phase;
// Queued moves still to be issued, per troupe id.
map<int, vector<direction>> paths;
// direction -> (delta colonne, delta ligne).
map<direction, pair<int, int>> MDIR;
#pragma endregion



int calc_cost(position target, vector<position>& path){
    int diffx = path.at(path.size() - 1).colonne - target.colonne;
    int diffy = path.at(path.size() - 1).ligne - target.ligne;
    int distance_cost = diffx * diffx + diffy * diffy;
    return distance_cost + path.size();
}



// Best-first search from `cur` toward `target`, ranking partial paths with
// calc_cost (squared distance to target + path length).
//
// Returns:
//  - {} if `target` itself is a BUISSON or a closed BARRIERE;
//  - the full path (cur ... target) as soon as `target` is reached;
//  - {cur} if cur is already on target's cell;
//  - otherwise, after `iter_lim` expansions, the cheapest unexpanded partial
//    path, or the cheapest expanded one if the frontier ran dry.
// Moves stay on cur.niveau; only colonne/ligne are compared with `target`.
vector<position> find_path2(position cur, position target, int iter_lim=40){
    const etat_case target_state = info_case(target);
    if(target_state.contenu == BUISSON || (target_state.contenu == BARRIERE && info_barriere(target) == FERMEE)){
        return {};
    }

    vector<position> init(1, cur);
    if(cur.colonne == target.colonne && cur.ligne == target.ligne){
        return init;
    }

    // Min-heap on cost; compare by reference to avoid copying paths.
    auto costlier = [](const as_path& left, const as_path& right) { return left.cost > right.cost; };
    priority_queue<as_path, vector<as_path>, decltype(costlier)> frontier(costlier);
    // Cells already expanded: re-expanding them only produces longer paths.
    set<pair<int, int>> expanded;

    as_path cheapest_seen = {init, calc_cost(target, init)};
    frontier.push(cheapest_seen);

    int count_it = 0;
    while(count_it < iter_lim && !frontier.empty()){
        as_path best = frontier.top();
        frontier.pop();

        const position& tail = best.cur.back();
        if(!expanded.insert(make_pair(tail.colonne, tail.ligne)).second){
            continue;  // stale entry: this cell was reached earlier more cheaply
        }
        if(best.cost < cheapest_seen.cost){
            cheapest_seen = best;
        }

        for(const pair<int, int>& p : DIR){
            const position tmp_pos = {tail.colonne + p.first, tail.ligne + p.second, tail.niveau};
            // assumes colonne in [0, LARGEUR) and ligne in [0, HAUTEUR) -- TODO confirm against api.hh
            if(tmp_pos.colonne < 0 || tmp_pos.colonne >= LARGEUR || tmp_pos.ligne < 0 || tmp_pos.ligne >= HAUTEUR){
                continue;
            }
            if(expanded.count(make_pair(tmp_pos.colonne, tmp_pos.ligne)) != 0){
                continue;
            }
            const type_case contenu = info_case(tmp_pos).contenu;
            // Check the barrier on the neighbour cell itself (was checking target).
            if(contenu == BUISSON || (contenu == BARRIERE && info_barriere(tmp_pos) == FERMEE)){
                continue;
            }

            vector<position> new_path = best.cur;
            new_path.push_back(tmp_pos);
            if(tmp_pos.colonne == target.colonne && tmp_pos.ligne == target.ligne){
                return new_path;
            }
            const int new_cost = calc_cost(target, new_path);
            frontier.push({std::move(new_path), new_cost});
        }
        count_it++;
    }

    // top() on an empty priority_queue is undefined behaviour: fall back to the
    // cheapest path expanded so far when the frontier is exhausted.
    if(!frontier.empty()){
        return frontier.top().cur;
    }
    return cheapest_seen.cur;
}


// Converts a cell sequence into the moves between consecutive cells:
// +1 colonne -> EST, -1 colonne -> OUEST, +1 ligne -> NORD, -1 ligne -> SUD.
// Returns one direction per step, so pos.size() - 1 entries; an empty vector
// for fewer than two cells (previously pos.size() - 1 underflowed for {}).
// A pair of cells that is not orthogonally adjacent leaves a value-initialised
// direction in that slot.
vector<direction> convert2dirs(const vector<position>& pos){
    if(pos.size() < 2){
        return {};
    }

    vector<direction> ret(pos.size() - 1);
    for(size_t i = 0; i + 1 < pos.size(); i++){
        const int diffx = pos.at(i + 1).colonne - pos.at(i).colonne;
        const int diffy = pos.at(i + 1).ligne - pos.at(i).ligne;

        if(diffx == 1){
            ret.at(i) = EST;
        }else if(diffx == -1){
            ret.at(i) = OUEST;
        }else if(diffy == 1){
            ret.at(i) = NORD;
        }else if(diffy == -1){
            ret.at(i) = SUD;
        }
    }
    return ret;
}


vector<position> sort_by_distance(vector<position>& pos, position& mama_pos){
    
    vector<EPosition> ep(pos.size());
    for(int i = 0; i < pos.size(); i++){
        ep.at(i) = {pos.at(i).colonne, pos.at(i).ligne, pos.at(i).niveau, mama_pos};
    }

    sort(ep.begin(), ep.end());

    vector<position> ret(pos.size());
    for(int i = 0; i < pos.size(); i++){
        ret.at(i) = {ep.at(i).colonne, ep.at(i).ligne, ep.at(i).niveau};
    }
    return ret;
}


bool can_exit_1st_phase(){
    // 1. Tous les nids sont occupes
    bool occupied = true;
    for(position nid_p : nids){
        etat_nid etat = info_nid(nid_p);
        occupied = occupied & !(etat == JOUEUR_0 || etat == JOUEUR_1);
    }
    if(occupied){
        return true;
    }

    
    // 2. Atteint la 20eme tour
    if(tour_actuel() > 20){
        return true;
    }


    return false;
}



// Fonction appelée au début de la partie.
void partie_init(void)
{
    player_id = moi();
    current_phase = CONQUETE_NIDS;
    MDIR[NORD] = make_pair(0, 1);
    MDIR[SUD] = make_pair(0, -1);
    MDIR[EST] = make_pair(1, 0);
    MDIR[OUEST] = make_pair(-1, 0);


    for(troupe tr : troupes_joueur(player_id)){
        paths[tr.id] = vector<direction>(0);
    }

    #pragma region Charge la carte
    for(int y = 0; y < HAUTEUR; y++){
        for(int x = 0; x < LARGEUR; x++){
            position pos = {y, x, 0};
            // map.at(y).at(x) = info_case(pos);
            // ajoute au tableau si c'est un nid
            if(info_case(pos).contenu == NID){
                nids.push_back(pos);
            }
        }
    }
    #pragma endregion
}


// Fonction appelée à chaque tour.
void jouer_tour(void)
{
    debug_poser_pigeon({0, 0, 0}, PIGEON_ROUGE);
    debug_poser_pigeon({39, 39, 0}, PIGEON_BLEU);
    debug_poser_pigeon({0, 39, 0}, PIGEON_JAUNE);
    // position pos = {12, 12, 0}, pos2 = {17, 17, 0};
    // vector<direction> path2br = trouver_chemin(pos, pos2);
    // printf("len: %d\n", path2br.size());
    // for(direction d : path2br){
    //     cout << d << endl;
    // }
    // vector<position> path2br = find_path2(pos, pos2);
    // printf("len: %d\n", path2br.size());
    // for(position p : path2br){
    //     printf("x: %d, y: %d\n", p.colonne, p.ligne);
    // }
    position pos = {20, 15, 0};
    debug_poser_pigeon(pos, PIGEON_ROUGE);

    vector<troupe> troupes = troupes_joueur(player_id);
    for(troupe tr : troupes){
        // printf("POSX: %d, POSY: %d\n", tr.maman.colonne, tr.maman.ligne);
        debug_poser_pigeon(tr.maman, PIGEON_JAUNE);

        if(paths.at(tr.id).size() > 0){
            for(int i = 0; i < min(5, (int)paths.at(tr.id).size()); i++){
                avancer(tr.id, paths.at(tr.id).at(0));
                paths.at(tr.id).erase(paths.at(tr.id).begin());
            }
            printf("x: %d, y: %d\n", tr.maman.colonne, tr.maman.ligne);
        }else{
            
            // vector<direction> dirs = convert2dirs(find_path2(tr.maman, pos));
            direction cur_dir = tr.dir;
            pair<int, int> opposite = MDIR.at(cur_dir);
            opposite.first *= -1; opposite.second *= -1;

            for(pair<int, int> p : DIR){
                position new_case = {tr.maman.colonne + p.first, tr.maman.ligne + p.second, tr.maman.niveau};
                if(p != opposite){
                    vector<direction> dirs = trouver_chemin(new_case, pos);
                    if(dirs.size() > 0){
                        paths.at(tr.id) = dirs;
                        break;
                    }
                }
            }

            
            // paths.at(tr.id) = dirs;
            
            // for(direction d : dirs){
            //     string s = "Nope";
            //     switch(d){
            //     case NORD:
            //         s = "nord";
            //         break;
            //     case SUD:
            //         s = "sud";
            //         break;
            //     case EST:
            //         s = "est";
            //         break;
            //     case OUEST:
            //         s = "ouest";
            //         break;
            //     default:
            //         break;
            // }
            // cout << s << " ";
            // }
        }
    }  
}

// Called once when the game ends. Nothing to do yet (was a TODO).
void partie_fin(void) {}


// Leftover debug output from find_path2's neighbour loop:
// printf("%d %d ===", p.first, p.second);
// printf("%d %d -- ", best.cur.at(best.cur.size() - 1).colonne, best.cur.at(best.cur.size() - 1).ligne);
// printf("aaaa %d, %d\n", tmp_pos.colonne, tmp_pos.ligne);