#include "composantes.hh"

position UHASH(int h) {
    position p;
    p.ligne = h / 6;
    p.colonne = h % 6;
    return p;
}

bool in_map(position p) {
    return p.ligne > 0 && p.ligne < 6 && p.colonne > 0 && p.colonne < 6;
}

vector<position> voisins(position p) {
    vector<position> ret;
    position vois = p;
    vois.ligne++;
    if(in_map(vois)) {
        ret.push_back(vois);
    }
    vois.ligne--;
    vois.ligne--;
    if(in_map(vois)) {
        ret.push_back(vois);
    }
    vois.ligne++;
    vois.colonne++;
    if(in_map(vois)) {
        ret.push_back(vois);
    }
    vois.colonne--;
    vois.colonne--;
    if(in_map(vois)) {
        ret.push_back(vois);
    }
    return ret;
}

int repr(vector<int> &ufind, int v) {
    if(ufind[v] == v || ufind[v] == -1)
        return ufind[v];
    else {
        int t = repr(ufind, ufind[v]);
        ufind[v] = t;
        return t;
    }
}

vector<int> make_ufind() {
    vector<int> ufind(36);
    for(int i=0 ; i < 36 ; i++) {
        ufind[i] = i;
    }

    position p;
    for(p.ligne=0 ; p.ligne < 6 ; p.ligne++) {
        for(p.colonne=0 ; p.colonne < 6 ; p.colonne++) {
            for(position vois : voisins(p)) {
                if(est_vide(p, moi()))
                    ufind[HASH(p)] = -1;
                if(type_case(p, moi()) == type_case(vois, moi())) {
                    ufind[HASH(vois)] = repr(ufind, HASH(p));
                }
            }
        }
    }
    return ufind;
}

int count_composantes() {
    auto ufind = make_ufind();
    int comp = 0;
    for(int i=0 ; i < 36 ; i++) {
        if(ufind[i] == i) {
            comp++;
        }
    }
    return comp;
}

vector<vector<int>> make_dist() {
    vector<vector<int>> dist(36, vector<int>(36, 1000));

    position p;
    for(p.ligne=0 ; p.ligne < 6 ; p.ligne++) {
        for(p.colonne=0 ; p.colonne < 6 ; p.colonne++) {
            dist[HASH(p)][HASH(p)] = 0;
            for(position vois : voisins(p)) {
                if(!est_vide(p, moi()) && !est_vide(vois, moi())) {
                    dist[HASH(vois)][HASH(p)] = dist[HASH(p)][HASH(vois)] = 1;
                }
            }
        }
    }

    for(int i=0 ; i < 36 ; i++) {
        for(int j=0 ; j < 36 ; j++) {
            for(int k=0 ; k < 36 ; k++) {
                dist[i][j] = min(dist[i][j], dist[i][k] + dist[k][j]);
            }
        }
    }
    return dist;
}

vector<position> connect(position a, position b, vector<vector<int>> &dist) {
    int mn = 1000;
    position am, bm;
    vector<position> ret;

    if(dist[HASH(a)][HASH(b)] > 500) {
        return ret;
    }

    for(position i : positions_region(a, moi())) {
        for(position j : positions_region(b, moi())) {
            if(dist[HASH(i)][HASH(j)] < mn) {
                mn = dist[HASH(i)][HASH(j)];
                am = i;
                bm = j;
            }
        }
    }

    int N = 0;
    ret.push_back(am);
    while(mn > 0) {
        for(position v : voisins(ret[ret.size()-1])) {
            if(dist[HASH(v)][HASH(bm)] < mn) {
                ret.push_back(v);
                mn = dist[HASH(v)][HASH(bm)];
                break;
            }
        }
        N++;
        if(N > 100) {
            printf("WTF\n");
            ret.empty();
            return ret;
        }
    }
    return ret;
}

vector<position> chemin(vector<int> &ufind, position a, position b, vector<vector<bool>> &visited, bool start = true) {
    vector<position> ret;
    if(repr(ufind, HASH(a)) == repr(ufind, HASH(b))) {
        ret.push_back(b);
        return ret;
    }
    else {
        for(position vois : voisins(a)) {
            if(!visited[HASH(a)][HASH(vois)]) {
                visited[HASH(a)][HASH(vois)] = true;
                bool n_start = start && repr(ufind, HASH(vois)) == repr(ufind, HASH(a));
                vector<position> c_vois = chemin(ufind, vois, b, visited, n_start);
                if(ret.size() == 0 || c_vois.size() < ret.size()) {
                    if(!n_start) {
                        c_vois.push_back(a);
                    }
                    ret = c_vois;
                }
            }
        }
        return ret;
    }
}

vector<position> get_catalyse_path() {
    auto ufind = make_ufind();

    vector<tuple<int, int>> composantes;
    for(int i=0 ; i < 36 ; i++) {
        if(ufind[i] == i) {
            composantes.push_back(make_tuple(taille_region(UHASH(i), moi()), i));
        }
    }
    sort(begin(composantes), end(composantes));

    vector<tuple<position, position>> pairs;
    for(int i=0 ; i < composantes.size() ; i++) {
        for(int j=0 ; j < i ; j++) {
            int a = get<1>(composantes[i]);
            int b = get<1>(composantes[j]);
            if(type_case(UHASH(a), moi()) == type_case(UHASH(b), moi())) {
                pairs.push_back(make_tuple(UHASH(a), UHASH(b)));
            }
        }
    }
    sort(begin(pairs), end(pairs), [](tuple<position, position> a, tuple<position, position> b) {
        position a1, a2, b1, b2;
        tie(a1, a2) = a;
        tie(b1, b2) = b;
        return taille_region(b1, moi()) + taille_region(b2, moi()) < taille_region(a1, moi()) + taille_region(a2, moi());
    });

    auto dist = make_dist();
    for(auto &pair : pairs) {
        vector<vector<bool>> visited(36, vector<bool>(false, 36));
        position a, b;
        tie(a, b) = pair;
        //auto path = chemin(ufind, a, b, visited);
        auto path = connect(a, b, dist);
        if(path.size()-2 <= nombre_catalyseurs()) {
            return path;
        }
    }

    return vector<position>();
}

/* ******************************** */

int sq(int x) {
    return x*x;
}

int eval_catalysis(position p, case_type m, auto &ufind) {
    int init = sq(quantite_transmutation_or(taille_region(p, adversaire())));
    for(auto vois : voisins(p)) {
        if(type_case(vois, adversaire()) != type_case(p, adversaire())) {
            init += sq(quantite_transmutation_or(taille_region(vois, adversaire())));
        }
    }

    if(catalyser(p, adversaire(), m) == 0) {
        int after = sq(quantite_transmutation_or(taille_region(p, adversaire())));
        for(auto vois : voisins(p)) {
            if(type_case(vois, adversaire()) != type_case(p, adversaire())) {
                after += sq(quantite_transmutation_or(taille_region(vois, adversaire())));
            }
        }
        annuler();
        return init-after;
    }
    else {
        return 0;
    }
}

vector<tuple<position, case_type>> get_attack_cells() {
    vector<case_type> materials({PLOMB, FER, CUIVRE, SOUFRE, MERCURE});
    vector<tuple<position, case_type>> ret;
    for(case_type mater : materials) {
        position p;
        for(p.ligne=0 ; p.ligne < 6 ; p.ligne++) {
            for(p.colonne=0 ; p.colonne < 6 ; p.colonne++) {
                if(!est_vide(p, adversaire())) {
                    ret.push_back(make_tuple(p, mater));
                }
            }
        }
    }

    auto ufind = make_ufind();
    sort(begin(ret), end(ret), [ufind](tuple<position, case_type> a, tuple<position, case_type> b) {
        position p1, p2;
        case_type m1, m2;
        tie(p1, m1) = a;
        tie(p2, m2) = b;
        return eval_catalysis(p1, m1, ufind) > eval_catalysis(p2, m2, ufind);
    });

    return ret;
}
