(** (c) 2016 peio borthelle. Ce code est distribué sous licence GPLv3.
 *
 * Cette IA effectue un Monte-Carlo Tree Search (avec la formule UCB1). Pour
 * aider un peu la recherche, elle utilise des heuristiques de sélection des
 * coups possibles assez violentes (à implémenter dans heur_succ) : cela réduit
 * drastiquement le facteur de branchement de l'arbre et la marge de
 * manœuvre. *)


open Api

(** Finite maps keyed by board positions, ordered with the polymorphic
    [compare]. *)
module M = Map.Make(struct type t = position let compare = compare end)


(**************************************)
(* Reimplementation du moteur de jeu. *)
(**************************************)

(** Simulated game state.
    - [map]: tile type of every board cell.
    - [plasma]: plasma charges in transit, as (amount, position) pairs.
    - [aspi]: aspiration values. Bases hold the negation of their aspiration
      power; other cells hold the value computed by [get_aspiration]; [none]
      means "nothing to flow towards". Positions may be absent from the map.
    - [score_me], [score_him]: accumulated scores of both players.
    - [turn]: turn number, used to time pulsar emissions. *)
type game_state = {map: case_type M.t; plasma: (float * position) list;
                   aspi: int M.t; score_me: float; score_him: float;
                   turn: int}

(** Actions the AI can take; [exec] simulates them and [exec_really] sends
    them to the server. *)
type action =
| Construire of position                      (* build a pipe on an empty cell *)
| Detruire of position                        (* destroy a pipe, leaving debris *)
| Ameliorer of position                       (* upgrade a pipe to a super pipe *)
| Deplacer_aspiration of position * position  (* move one unit of aspiration:
                                                 (source base, destination base) *)

(** Sentinel aspiration value meaning "no aspiration value": used when no
    neighbour provides one and for destroyed cells. *)
let none = max_int

(** [is_allowed pa action state] tells whether [action] is legal in [state],
    judging only from tile types and aspiration values.
    [pa] is currently ignored (NOTE(review): presumably the remaining action
    points -- confirm and check the action's cost here).
    @raise Not_found if a position of a build/destroy/upgrade action is not on
    the board. *)
let is_allowed pa action state =
    let tile p = M.find p state.map in
    match action with
    | Construire p -> tile p = Vide
    | Detruire p -> tile p = Tuyau || tile p = Super_tuyau
    | Ameliorer p -> tile p = Tuyau
    | Deplacer_aspiration (p, p') ->
            (* Bases store the NEGATION of their aspiration power in [aspi]
               (see how bases are seeded from [puissance_aspiration]): the
               source must still have some power and the destination must
               currently have less than 5. A position without an aspiration
               value cannot take part in a transfer. *)
            (try
                let power q = - (M.find q state.aspi) in
                power p > 0 && power p' < 5
            with Not_found -> false)

(** The orthogonal neighbours of [(x, y)] that lie on the board, listed in
    the order (x, y+1), (x+1, y), (x, y-1), (x-1, y). *)
let all_children (x, y) =
    let on_board (a, b) =
        a >= 0 && a < taille_terrain && b >= 0 && b < taille_terrain in
    List.filter on_board [(x, y + 1); (x + 1, y); (x, y - 1); (x - 1, y)]

(** On-board neighbours of [pos] whose tile is a pipe or a super pipe. *)
let children state pos =
    let is_pipe p = match M.find p state.map with
    | Tuyau | Super_tuyau -> true
    | _ -> false in
    List.filter is_pipe (all_children pos)

(** [mini (_, init) f l] returns the smallest of [init] and the values [f x]
    for [x] in [l] ([init] when [l] is empty). The first component of the
    default pair is ignored. *)
let mini def f l =
    let (_, init) = def in
    List.fold_left
        (fun best x -> let v = f x in if v < best then v else best)
        init l

(** [argmax f l] returns the first element of [l] that maximises [f].
    The original compared with [<] and therefore returned the argmin.
    @raise Failure if [l] is empty. *)
let argmax f l = match l with
| [] -> failwith "empty list"
| first :: rest ->
    let pick (best, f_best) x =
        let fx = f x in
        (* Strict comparison keeps the earliest element on ties. *)
        if fx > f_best then (x, fx) else (best, f_best) in
    fst (List.fold_left pick (first, f first) rest)

(** [get_aspiration state pos] computes the aspiration value of [pos]: one
    more than the smallest value stored in [state.aspi] among the adjacent
    pipes, super pipes and bases, or [none] when none of them has a value.
    Bases hold the negation of their aspiration power in [state.aspi], so
    stronger bases yield smaller values. They must be taken into account
    here: they are the only cells that originate aspiration.
    NOTE(review): a base with zero power is stored as 0 and still attracts;
    confirm against the game rules. *)
let get_aspiration state pos =
    let flows_into p = match M.find p state.map with
    | Tuyau | Super_tuyau | Base -> true
    | _ -> false in
    (* Cells not reached by propagation yet have no entry: treat as [none]. *)
    let value p = try M.find p state.aspi with Not_found -> none in
    let min_asp =
        mini ((0, 0), none) value (List.filter flows_into (all_children pos)) in
    if min_asp <> none then min_asp + 1 else none

(** [propagate_aspi state pos] recomputes the aspiration value of every pipe
    adjacent to [pos]. It then recurses from each pipe whose value changed,
    and returns the updated state. Positions missing from [state.aspi] count
    as [none].
    NOTE(review): if a cycle of pipes loses its last connection to a base,
    its values appear to keep increasing by one per recursive call
    (count-to-infinity), so the recursion can become extremely deep;
    recomputing distances from the bases may be needed. *)
let rec propagate_aspi state pos =
    List.fold_left
        (fun st p ->
            let asp = get_aspiration st p in
            let current = try M.find p st.aspi with Not_found -> none in
            if current <> asp then
                propagate_aspi {st with aspi = M.add p asp st.aspi} p
            else
                (* Keep the accumulator: returning [state] here would drop
                   the updates made for earlier neighbours. *)
                st)
        state (children state pos)

(** [exec st action] applies [action] to the simulated state [st] and returns
    the new state, keeping aspiration values up to date with
    [propagate_aspi]. Legality is not checked here (see [is_allowed]). *)
let exec st action = match action with
| Construire p ->
        propagate_aspi {st with aspi = M.add p (get_aspiration st p) st.aspi;
                                map = M.add p Tuyau st.map} p
| Detruire p ->
        propagate_aspi {st with map = M.add p Debris st.map;
                                aspi = M.add p none st.aspi} p
| Ameliorer p ->
        {st with map = M.add p Super_tuyau st.map}
| Deplacer_aspiration (p, p') ->
        (* Bases store the NEGATION of their power: the source [p] losing one
           unit means +1 on its stored value, and the destination [p']
           gaining one unit means -1. *)
        let aspi = M.add p (M.find p st.aspi + 1)
                   (M.add p' (M.find p' st.aspi - 1) st.aspi) in
        propagate_aspi (propagate_aspi {st with aspi = aspi} p) p'

(** Send [action] to the game server through the matching API call and
    return that call's result. *)
let exec_really = function
| Construire pos -> construire pos
| Detruire pos -> detruire pos
| Ameliorer pos -> ameliorer pos
| Deplacer_aspiration (src, dst) -> deplacer_aspiration src dst

(** [flat_map f l] maps [f] over [l] and concatenates the results in order. *)
let flat_map f l = List.concat (List.map f l)

(** Simulate the end of a turn and return the resulting state:
    - every plasma charge on a pipe moves one cell (two on a super pipe)
      towards the adjacent pipes/bases carrying the smallest aspiration
      value, split evenly between ties; charges on any other tile, or with
      no neighbour that has an aspiration value, are dropped;
    - charges that end on a base are added to the score of that base's
      owner;
    - pulsars emitting this turn add a quarter of their power on each of
      their on-board neighbours;
    - the turn counter advances by one, so pulsar timing progresses across
      successive simulated turns. *)
let finish_turn state =
    (* Aspiration value of [p]; [none] when it was never computed. *)
    let aspi_of p = try M.find p state.aspi with Not_found -> none in

    (* Plasma can flow into pipes and into bases. *)
    let flows_into p = match M.find p state.map with
    | Tuyau | Super_tuyau | Base -> true
    | _ -> false in

    let spawn_plasma pos =
        let info = info_pulsar pos in
        if info.pulsations_totales - state.turn / info.periode > 0 &&
               state.turn mod info.periode = 0 then
            List.map (fun c -> info.puissance /. 4., c) (all_children pos)
        else [] in

    (* One step. A charge already on a base stays put: this lets a charge
       that reaches a base on the first step of a super pipe stay there. *)
    let mv_1 (x, pos) =
        if M.find pos state.map = Base then [(x, pos)]
        else
            let next = List.filter flows_into (all_children pos) in
            let target =
                List.fold_left (fun m p -> min m (aspi_of p)) none next in
            if target = none then []
            else
                let succ = List.filter (fun p -> aspi_of p = target) next in
                let n = float (List.length succ) in
                List.map (fun s -> (x /. n, s)) succ in

    let mv_plasma (x, pos) = match M.find pos state.map with
    | Tuyau -> mv_1 (x, pos)
    | Super_tuyau -> flat_map mv_1 (mv_1 (x, pos))
    | _ -> [] in

    let on_base, moved =
        List.partition (fun (_, p) -> M.find p state.map = Base)
                       (flat_map mv_plasma state.plasma) in
    let sum = List.fold_left (fun s (x, _) -> s +. x) 0. in
    let s_me, s_him =
        let mine, theirs =
            List.partition (fun (_, p) -> proprietaire_base p = moi ()) on_base in
        sum mine, sum theirs in

    {state with score_me = state.score_me +. s_me;
                score_him = state.score_him +. s_him;
                plasma = List.rev_append
                           (flat_map spawn_plasma (Array.to_list (liste_pulsars ())))
                           moved;
                turn = state.turn + 1}

(** Build a [game_state] snapshot from the server's current view of the game:
    the tile type of every cell, the plasma charges, both scores and the turn
    number. [aspi] is then seeded with every base (enemy bases first, then
    ours), and [propagate_aspi] is run from each of them. *)
let state_from_current () =
    let side = taille_terrain in
    let rec build_map i acc =
        if i < 0 then acc
        else
            let pos = (i / side, i mod side) in
            build_map (i - 1) (M.add pos (type_case pos) acc) in
    let map = build_map (side * side - 1) M.empty in
    let plasma =
        List.map (fun pos -> charges_presentes pos, pos)
                 (Array.to_list (liste_plasmas ())) in
    let initial = {map; plasma; aspi = M.empty; turn = tour_actuel ();
                   score_me = float (score (moi ()));
                   score_him = float (score (adversaire ()))} in
    (* A base is stored with the negation of its aspiration power. *)
    let add_base st base =
        propagate_aspi
            {st with aspi = M.add base (- (puissance_aspiration base)) st.aspi}
            base in
    let with_enemy = Array.fold_left add_base initial (base_ennemie ()) in
    Array.fold_left add_base with_enemy (ma_base ())


(***************)
(* MCTS search *)
(***************)

(** Candidate actions to explore from a state. Not implemented yet (TODO):
    it always returns the empty list, so expanded nodes get no successors. *)
let heur_succ _st = []

(** Monte-Carlo search tree.
    - [Leaf s]: a state that has not been simulated yet.
    - [Node c]: an expanded state. [c.succ] pairs each candidate action with
      the subtree it leads to, and [c.w] counts the winning simulations out
      of [c.n] visits. [c.player] is not read in the code shown here. *)
type node_content = {succ: (action * game_tree) list; player: int; w: int; n: int}
and game_tree =
| Leaf of game_state
| Node of node_content

(** Playout stub: always reports a win, whatever the state.
    TODO: play the game out from the state and report the real outcome. *)
let simulate _state = true

(** Pick the child to descend into: the first unexpanded [Leaf] if there is
    one; otherwise the [Node] chosen by [argmax] on the UCB1 score
    [w/n + sqrt (2 ln N / n)], where [N] is the total number of child visits.
    Always returns an element of [nodes] itself.
    @raise Failure if [nodes] is empty. *)
let select_next nodes =
    let is_leaf = function Leaf _ -> true | Node _ -> false in
    if List.exists is_leaf nodes then List.find is_leaf nodes
    else
        let visits = function Node c -> c.n | Leaf _ -> 0 in
        let total = float (List.fold_left (fun acc t -> acc + visits t) 0 nodes) in
        let ucb1 = function
        | Node c ->
            let nc = float c.n in
            float c.w /. nc +. sqrt (2. *. log total /. nc)
        | Leaf _ -> neg_infinity (* unreachable: no leaves on this branch *) in
        argmax ucb1 nodes

(** One MCTS iteration on [tree]. It descends through [select_next] down to
    the first unexpanded [Leaf], simulates once from it, and turns that leaf
    into a [Node] whose children come from [heur_succ]. The result is then
    backed up along the path. Returns the simulation result together with
    the updated tree (the expanded child is stored back into its parent). *)
let rec expand tree = match tree with
| Leaf state ->
        let res = simulate state in
        res, Node {succ = List.map (fun a -> a, Leaf (exec state a)) (heur_succ state);
                   (* TODO: the player to move is not tracked yet; 0 is a
                      placeholder so that the record is complete. *)
                   player = 0;
                   w = (if res then 1 else 0); n = 1}
| Node ({succ = []; _} as node) ->
        (* No successors: nothing to descend into. Every visit of such a node
           adds the same result, so [w] is either 0 or [n]; replay it. *)
        let res = 2 * node.w > node.n in
        res, Node {node with n = node.n + 1; w = node.w + (if res then 1 else 0)}
| Node node ->
        let next = select_next (List.map snd node.succ) in
        let res, next' = expand next in
        (* [select_next] returns an element of the list it was given, so
           physical equality identifies the child to replace. *)
        let rec replace = function
        | [] -> []
        | (a, t) :: rest when t == next -> (a, next') :: rest
        | c :: rest -> c :: replace rest in
        res, Node {node with succ = replace node.succ;
                             n = node.n + 1;
                             w = node.w + (if res then 1 else 0)}


(*************)
(* Callbacks *)
(*************)

(** Called once when the game starts; only flushes stderr, then stdout. *)
let partie_init () =
    List.iter flush [stderr; stdout];;

(** Called once per turn. Grows a Monte-Carlo search tree from the current
    state for about 0.9 s of CPU time ([Sys.time]), then sends to the server
    the root action whose subtree has the best empirical win rate. If no
    action is available, it prints a message instead. Both output channels
    are flushed on every path. *)
let jouer_tour () =
    let ts = Sys.time () in
    let tree = Leaf (state_from_current ()) in

    (* [Sys.time] returns float seconds: compare elapsed time with [-.]. *)
    let rec loop tree =
        if Sys.time () -. ts < 0.9 then loop (snd (expand tree)) else tree in

    (* Win rate of a root child; never-visited leaves are never preferred. *)
    let win_rate (_, child) = match child with
    | Node node -> float node.w /. float node.n
    | Leaf _ -> neg_infinity in

    (match loop tree with
     | Node {succ = (_ :: _) as succ; _} ->
         let best, _ = argmax win_rate succ in
         ignore (exec_really best)
     | Node _ | Leaf _ ->
         print_string "oups! couldn't compute anything to do...");

    flush stderr; flush stdout;;

(** Called once when the game ends; only flushes stderr, then stdout.
    The [;;] below is required: it ends this definition before the top-level
    [Callback.register] expressions that follow. *)
let partie_fin () =
    List.iter flush [stderr; stdout];;

(* /!\ Ne touche pas à ce qui suit /!\ *)
(* Register the three entry points under the names given below, so that the
   host program can look them up from C. *)
Callback.register "ml_partie_init" partie_init;;
Callback.register "ml_jouer_tour" jouer_tour;;
Callback.register "ml_partie_fin" partie_fin;;
