#include "computer.h"
#include "sbr.h"
#include "alphabeta.h"
#include "genrand.h"
#include "board.h"
#include "movegen.h"
#include "game.h"
#include "inline/game.h"
#include "names.h"
#include "timer.h"

/* Search tracing: in DEBUG builds abtrace() is printf(); otherwise every
 * abtrace(...) call expands to nothing.
 */
#ifdef DEBUG
#define abtrace printf
#else
#define abtrace(s,...)
#endif

/* Compile-time switches that are tested with #ifdef further down in this
 * file. All three are explicitly switched off here.
 */
#undef PRINT_MOVE_EVALUATIONS
#undef CLEAR_KILLER_TABLE
#undef CLEAR_TRANSPOSITION_TABLE

/* Output helpers. Both expand to a bare `if (callback) callback`, so
 *  - a variable named `game` must be in scope at the point of use, and
 *  - the call is skipped when the corresponding callback is NULL.
 * Because the expansion is an unbraced `if`, keep uses inside braces when
 * they are followed by an `else`, or the `else` will bind to the hidden `if`.
 */
#define print_iter if (game->output_iteration) game->output_iteration
#define uci if (game->uci_output) game->uci_output

/* Move formatting for code that has no move list at hand: passes NULL as
 * the second argument of short_move_string().
 */
#define move_string(m, null) short_move_string(m, NULL, null)
/* Print the principal variation that starts with `move`.
 *
 * The move is printed through print_iter and played on the board; the
 * resulting position is then looked up in the transposition table and, if an
 * entry is found, the function recurses on that entry's best move. Every
 * move played here is taken back before returning, so `game` is left in the
 * position it was in on entry.
 *
 * At most 30 moves are followed per top-level call. This also guards against
 * cycles in the chain of hash-table entries.
 */
void retrieve_principle_variation(gamestate_t *game, move_t move)
{
   /* Remaining recursion budget. It is handed back on the way out (see the
    * end of the function) so that every top-level call starts with the full
    * budget again instead of using it up once for the whole process.
    */
   static int n = 30;
   hash_table_entry_t *hash;

   if (n == 0) return;
   n--;

   print_iter("%s ", move_string(move, NULL));
   playmove(game, move);

   /* Look up the resulting position in the hash table */
   hash = query_table_entry(game->transposition_table, game->board->hash);

   /* Recursive call to get next move from the table */
   if (hash)
      retrieve_principle_variation(game, hash->best_move);

   takeback(game);

   /* Give the budget back to the caller's level */
   n++;
}

/* Write a principal variation back into the transposition table.
 *
 * This function is currently DISABLED: it does nothing and returns
 * immediately. The parameters are kept so that existing callers compile
 * unchanged.
 *
 * The disabled implementation is kept under `#if 0` for reference. It stores
 * `move` for the current position, plays it, and then stores and plays the
 * moves in state->mainline. Compared to the code that used to sit behind an
 * unconditional `return`, the undo step now takes back exactly as many moves
 * as were played (the old version took back one move too many whenever
 * state->size_of_variation was at least 1).
 */
void insert_pv(gamestate_t *game, alphabeta_state_t *state, move_t move, int score)
{
   (void)game;
   (void)state;
   (void)move;
   (void)score;

#if 0
   int depth = state->size_of_variation+1;
   int played = 0;
   int n;

   store_table_entry(game->transposition_table, game->board->hash,
                           depth, score, HASH_TYPE_EXACT, move);

   playmove(game, move);
   played++;
   for (n=0; n<state->size_of_variation-1; n++) {
      store_table_entry(game->transposition_table, game->board->hash,
            depth, score, HASH_TYPE_EXACT, state->mainline[n]);
      playmove(game, state->mainline[n]);
      played++;
   }

   /* return the board to its previous state */
   while (played-- > 0)
      takeback(game);
#endif
}

/* Comparison callback for qsort_r(): orders move indices by descending score.
 *
 * score_ptr  the context pointer given to qsort_r(); an array of int scores
 *            indexed by move index.
 * mi1, mi2   pointers to the two move indices (int) being compared.
 *
 * Returns a negative value if the move at *mi1 has the higher score (sorts
 * first), a positive value if it has the lower score, and 0 if both scores
 * are equal. The scores are compared rather than subtracted so that the
 * result cannot overflow for large score magnitudes.
 *
 * NOTE: the context pointer comes first, which is the BSD/macOS qsort_r()
 * callback convention; the calls in this file pass arguments in that order.
 */
static int movelist_sorter(void *score_ptr, const void *mi1, const void *mi2)
{
   const int *score = score_ptr;
   int s1 = score[*(const int *)mi1];
   int s2 = score[*(const int *)mi2];

   return (s2 > s1) - (s2 < s1);
}

/* From here on move_string() passes the address of a variable named
 * `movelist` (which must be in scope at the point of use) as the second
 * argument of short_move_string(); move_stringv() keeps passing NULL.
 * #undef takes only the macro name, without a parameter list.
 */
#undef move_string
#define move_string(m, null) short_move_string(m, &movelist, null)
#define move_stringv(m, null) short_move_string(m, NULL, null)
/* Search the current position and play the move the search prefers.
 *
 * game       game state; the chosen move is played on it and its score is
 *            stored in game->score[]. Progress is reported through the
 *            optional game->output_iteration and game->uci_output callbacks.
 * max_depth  iteration limit, clamped to MAX_SEARCH_DEPTH.
 *
 * Outline:
 *  1. Age the transposition table and reset the history table and the
 *     search statistics counters.
 *  2. Generate all moves and search each one once with a full window to
 *     obtain initial scores; moves that leave the mover in check are
 *     counted as illegal.
 *  3. Return without playing a move if there is no legal move, if the
 *     fifty-move counter has reached 100, or if the position is a draw by
 *     repetition. With exactly one legal move, play it at once.
 *  4. Otherwise iterate over increasing depths, searching the root moves
 *     with an aspiration window around the current best score (first move)
 *     or a null window (other moves), re-searching with a wider window on a
 *     fail high/low. Stop on abort_search or once a forced mate is found.
 *  5. Print a summary, play the best move and record its score (the score
 *     is not recorded if that move results in a draw by repetition).
 */
void computer_play(gamestate_t *game, int max_depth)
{
   static const char *output_lines[] = {
      "<%2d>   ---   % 4.2f %7d   %+7.2f -> %+7.2f  %2d /%2d %-5s",
      "<%2d>         % 4.2f %7d   %+7.2f -> %+7.2f  %2d /%2d %-5s",
      "<%2d>   !!!   % 4.2f %7d   %+7.2f -> %+7.2f  %2d /%2d %-5s",
      "[%2d] %+7.2f % 4.2f %7d (<BR>=%.2f, EBR=%.2f) %2d /%2d %-5s"
   };
   uint64_t start_time, time, prev_time, prev_prev_time;
   alphabeta_state_t principle_variation;
   movelist_t movelist;
   int move_perm[MAX_MOVES];
   int score[MAX_MOVES];
   int alpha, beta;
   int previous_positions_evaluated;
   int best_score;
   int best_move;
   int legal_moves;
   int depth;
   int sdepth, sqdepth;
   int me;
   int c, n;
   int best_mate_score;
   hash_table_entry_t *hash;

   game->extra_time = 0;
   game->running_clock = game->player>>7;
   game->root_moves_played = game->moves_played;
   start_time = get_timer();
   start_clock(game);

   /* Setup data structures */
#ifdef CLEAR_KILLER_TABLE
   memset(game->killer, 0, sizeof game->killer);
   //memset(game->mate_killer, 0, sizeof game->mate_killer);
#endif

#ifdef CLEAR_TRANSPOSITION_TABLE
   if (game->transposition_table) {
      memset(game->transposition_table->data, 0, game->transposition_table->number_of_elements*sizeof(hash_table_entry_t));
      memset(game->transposition_table->depth_data, 0, game->transposition_table->number_of_elements*sizeof(hash_table_entry_t));
   }
#endif

   /* Age entries in the transposition table: this is to ensure that
    * obsolete entries will be replaced eventually.
    */
   if (game->transposition_table) {
      for (n=0; n<game->transposition_table->number_of_elements; n++) {
         if (game->transposition_table->data[n].age)
            game->transposition_table->data[n].age--;
         if (game->transposition_table->depth_data[n].age)
            game->transposition_table->depth_data[n].age--;
      }
   }

   memset(score, 0, sizeof score);
   memset(game->history, 0, sizeof game->history);
   game->max_history = 0;
#ifdef PRINCIPLE_VARIATION_IN_GAME_STRUCT
   memset(game->principle_variation, 0, sizeof game->principle_variation);
   game->length_of_variation[0] = 0;
#endif

   generate_moves(&movelist, game->board, game->player);

   if (max_depth>MAX_SEARCH_DEPTH) max_depth = MAX_SEARCH_DEPTH;
   best_move = 0;
   best_score = -3*CHECKMATE;
   alpha = -3*CHECKMATE;
   beta = 3*CHECKMATE;
   principle_variation.size_of_variation = 0;
   sdepth = sqdepth = 0;

   me = game->player;
   legal_moves = movelist.num_moves;

   abort_search = false;
   previous_positions_evaluated = 0;
   positions_evaluated = 0;
   moves_searched = 0;
   positions_in_hashtable = 0;
   branches_pruned = 0;
   time = prev_time = prev_prev_time = 0;

   game->score[game->moves_played+1] = -game->score[game->moves_played];

   /* If we found a mate in the position before, we want to make sure we're
    * at least going to find a mate that is as good in this iteration.
    */
   best_mate_score = CHECKMATE-1000;
   if (game->moves_played>1 &&
       game->score[game->moves_played-1] > best_mate_score) {
      best_mate_score = game->score[game->moves_played-1]+1;
   }
   if (game->score[game->moves_played] < -best_mate_score) {
      best_mate_score = -game->score[game->moves_played]+1;
   }

   /* First step: do a 2-ply initial full-breadth search (to establish move
    * ordering)
    */
   prev_time = time = get_timer();
   prev_prev_time = prev_time;
   for (n=0; n<movelist.num_moves; n++) {
      move_perm[n] = n;
      alphabeta_state_t state;

#ifdef PRINT_MOVE_EVALUATIONS
       print_iter("     %2d/%d (%+3.2f) %-5s %+3.2f",
             n+1, legal_moves,
             best_score/100.0, 
             move_string(movelist.move[move_perm[n]], NULL),
             score[move_perm[n]]/100.0);
       fflush(stdout);
#endif

      playmove(game, movelist.move[n]);
      if (!player_in_check(game, me)) {   /* We found a legal move! */
         state = alphabeta(game, 1*PLY, 1*PLY, alpha, beta, false);
         score[n] = -state.score;
         if (score[n] > best_score) {
            best_score = score[n];
            best_move = move_perm[n];
            memcpy(&principle_variation, &state, sizeof state);
            sdepth = state.depth;
            sqdepth = state.qdepth;
         }
      } else {
         /* Illegal moves get the lowest possible score so that the sort
          * below moves them behind all legal moves.
          */
         legal_moves--;
         score[n] = -3*CHECKMATE;
#ifdef PRINT_MOVE_EVALUATIONS
         print_iter(" (illegal)\n");
#endif
      }
      takeback(game);
#ifdef PRINT_MOVE_EVALUATIONS
         print_iter(" -> %+3.2f %2d/%d\n",
                     state.score/100.0, state.depth, state.qdepth);
#endif
   }
   prev_prev_time = prev_time;
   prev_time = time;
   time = get_timer();

   /* Hmm... no legal moves? */
   if (legal_moves == 0)
      return;

   if (game->fifty_counter[game->moves_played] >= 100) {
      print_iter("Draw (50 moves without pawn move or capture)\n");
      return;
   }
   if (game->repetition_hash_table[game->board->hash&0xFFFF]>2
         && draw_by_repetition(game)) {
      print_iter("Draw (position repeated 3 times)\n");
      return;
   }

   /* Look up the root position in the hash table, re-search the move from
    * the hash table first
    */
   hash = query_table_entry(game->transposition_table, game->board->hash);

   /* Detect problems with the hash table: the current position is scored
    * as a mate-in-N, but its descendant is scored as a mate-in-N+1
    */
   if (hash && hash->best_move.m && (hash->score) > (CHECKMATE - 1000)) {
      best_mate_score = hash->score;
   }

   /* Work around an annoying problem: if the current position is scored as
    * a mate, but it has been repeated, then there is a problem somewhere
    * and the score is unreliable... Wipe the table in that case (only if
    * there is a table to wipe).
    */
   if (best_mate_score > (CHECKMATE-1000) && count_repetition(game)>=1
         && game->transposition_table) {
      memset(game->transposition_table->data, 0, game->transposition_table->number_of_elements*sizeof(hash_table_entry_t));
      memset(game->transposition_table->depth_data, 0, game->transposition_table->number_of_elements*sizeof(hash_table_entry_t));
   }

   /* Sort the move list based on the score */
   qsort_r(move_perm, movelist.num_moves, sizeof *move_perm, score,
                                                            movelist_sorter);
   /* Move the hash move (if it is a legal root move) to the front */
   if (hash && !moves_are_equal(hash->best_move,movelist.move[move_perm[0]])) {
      for (n=0; n<legal_moves; n++) {
         if (moves_are_equal(hash->best_move, movelist.move[move_perm[n]])) {
            int k = move_perm[n];
            move_perm[n] = move_perm[0];
            move_perm[0] = k;
            break;
         }
      }
   }
   best_move = move_perm[0];
   /* NOTE(review): with a monotonic timer start_time is not later than
    * prev_prev_time, so the unsigned difference in the EBR argument wraps
    * around (or is zero); the EBR column of this first line does not look
    * meaningful. Left as is because only this diagnostic is affected.
    */
   print_iter(output_lines[3], 2,
         score[best_move]/100.0, (get_timer() - start_time)/1000000.0,
         positions_evaluated,
         (float)moves_searched/positions_evaluated,
         (float)(time - prev_time)/(start_time - prev_prev_time),
         sdepth, sqdepth,
         move_string(movelist.move[best_move], NULL));
   print_iter("\n");

#ifdef PRINCIPLE_VARIATION_IN_GAME_STRUCT
   /* Store principle variation (such as it is at the moment) */
   game->principle_variation[0][0] = movelist.move[best_move];
   for (c=1; c < game->length_of_variation[1]; c++) {
      game->principle_variation[0][c] =
         game->principle_variation[1][c];
   }
   game->length_of_variation[0] = game->length_of_variation[1];
#endif

   /* Only one legal move: play it without searching any deeper */
   if (legal_moves == 1) {
      playmove(game, movelist.move[best_move]);
      return;
   }

   uci("info depth 1 score cp %d nodes %d time %d\n", score[best_move], positions_evaluated, peek_timer(game));

   for (depth=2; depth<max_depth; depth++) {
      /* Best score of this iteration; deliberately separate from the
       * function-level best_score used by the initial search above.
       */
      int best_score = score[best_move];
      int new_best_move;

#ifdef USE_HISTORY_HEURISTIC
      /* Scale down history scores */
      while (game->max_history > 200000) {
         int k;
         for (n=0; n<64; n++) {
            for (k=0; k<64; k++) {
               game->history[0][n][k] /= 2;
               game->history[1][n][k] /= 2;
            }
         }
         game->max_history /= 2;
      }
#endif

      new_best_move = best_move;

      previous_positions_evaluated = positions_evaluated;

      for (n=0; n<legal_moves; n++) {
         alphabeta_state_t state;
         bool open_window_wide = false;
         bool in_pv = (n == 0);
         int open_alpha_window = 1;
         int open_beta_window = 1;

         uci("info currmove %s%s\n",
               square_str[movelist.move[move_perm[n]].from],
               square_str[movelist.move[move_perm[n]].to]);

         /* Don't waste time re-searching moves for which we already know
          * they lead to a forced mate.
          */
         if (score[move_perm[n]] < -(CHECKMATE-1000)) continue;

         /* If we've found a forced check mate, there's no point in trying to
          * improve our score here...
          */
         if (best_score >= best_mate_score) {
            break;
         }

         /* Aspiration search: a window of +/-50 around the best score for
          * the current best move, a null window for all other moves.
          */
         if (move_perm[n] == new_best_move) {
            alpha = best_score - 50;
            beta = best_score + 50;
         } else {
            alpha = best_score;
            beta = best_score + 1;
         }

#ifdef PRINT_MOVE_EVALUATIONS
       print_iter("     %2d/%d (%+3.2f) %-5s %+3.2f",
             n+1, legal_moves,
             best_score/100.0, 
             move_string(movelist.move[move_perm[n]], NULL),
             score[move_perm[n]]/100.0);
      fflush(stdout);
#endif

         playmove(game, movelist.move[move_perm[n]]);
#ifdef PRINT_MOVE_EVALUATIONS
         hash = query_table_entry(game->transposition_table, game->board->hash);
#endif
         while (true) {
            state = alphabeta(game, depth*PLY, depth*PLY, -beta, -alpha, in_pv);
            state.score = -state.score;

            /* Fail high/low, adjust window.
             * NB: this assumes fail soft alpha-beta, which may return a
             * score outside of [alpha, beta]!
             */
            /* Non-PV moves are expected to fail low, so only re-search a
             * fail low for the first (PV) move.
             */
            if (state.score < alpha && n==0) {
               abtrace("alphabeta [%d %d] returned %d\n",alpha,beta,state.score);
               alpha = state.score - open_alpha_window;
               open_alpha_window *= 100;
               abtrace("\n");
               abtrace("Fail low %d, set alpha to %d\n", state.score, alpha);
               /* Best move failed low, allocate extra time */
               game->extra_time =
                  peek_timer(game) / 2 + game->time_inc[game->player>>7];
               continue;
            }
            if (state.score >= beta) {
               abtrace("alphabeta [%d %d] returned %d\n",alpha,beta,state.score);
               beta = state.score + open_beta_window;
               open_beta_window *= 100;
               abtrace("\n");
               abtrace("Fail high %d, set beta to %d\n", state.score, beta);
               open_window_wide = true;
               in_pv = true;
               continue;
            }
            break;
         } 
         takeback(game);

#ifdef PRINT_MOVE_EVALUATIONS
         if (abort_search) {
            print_iter("(interrupted)\n");
            break;
         }
         print_iter(" -> %+3.2f %2d/%d\n",
                     state.score/100.0, state.depth, state.qdepth);
         /*
         if (state.size_of_variation) {
            int n;
            print_iter("         ");
            for (n=0; n<state.size_of_variation; n++) {
               print_iter("%s ",
                     move_string(state.mainline[n], NULL));
            }
            print_iter("\n");
         }
         */
         if (hash) {
#if 0
            print_iter("           Transposition table: %.2f %02x %d ... %s\n",
                  -hash->score/100.0, hash->flags, hash->depth+1,
                  move_stringv(hash->best_move, NULL));
            print_iter("           < ");

            if (state.size_of_variation) {
               int n;
               for (n=0; n<state.size_of_variation; n++) {
                  print_iter("%s ", move_stringv(state.mainline[n], NULL));
               }
            } else {
               retrieve_principle_variation(game, movelist.move[move_perm[n]]);
            }
            print_iter(">\n");
#endif
         } else {
            print_iter("           (no hash)\n");
         }
#endif
         if (abort_search) break;

         /* store principle variation */
         if (n == 0) {
            memcpy(&principle_variation, &state, sizeof state);
            sdepth = state.depth;
            sqdepth = state.qdepth;
         }

         /* Update score for this move */
         score[move_perm[n]] = state.score;

         /* Best move score decreased? */
         if (new_best_move == move_perm[n] && state.score<(best_score-50)) {
            print_iter(output_lines[0],
                  depth+1, (get_timer() - start_time)/1000000.0,
                  positions_evaluated,
                  best_score/100.0, state.score/100.0,
                  state.depth, state.qdepth,
                  move_string(movelist.move[move_perm[n]], NULL));
            if (state.size_of_variation) {
               int k;
               for (k=0; k<state.size_of_variation; k++) {
                  print_iter("%-5s ", move_string(state.mainline[k], NULL));
               }
            }
            print_iter("\n");
         } else if (state.score>(best_score+50)) {
            /* New best move, or score improved substantially */
            print_iter(output_lines[2],
               depth+1, (get_timer() - start_time)/1000000.0,
               positions_evaluated,
               best_score/100.0, state.score/100.0,
               state.depth, state.qdepth,
               move_string(movelist.move[move_perm[n]], NULL));
            if (state.size_of_variation) {
               int k;
               for (k=0; k<state.size_of_variation; k++) {
                  print_iter("%-5s ", move_string(state.mainline[k], NULL));
               }
            }
            print_iter("\n");
         }

         /* New best move, or score increased substantially */
         if ( (state.score > best_score) || (new_best_move == move_perm[n])) {
            if (new_best_move != move_perm[n] && state.score<=(best_score+50)) {
               print_iter(output_lines[1],
                     depth+1, (get_timer() - start_time)/1000000.0,
                     positions_evaluated,
                     best_score/100.0, state.score/100.0,
                     state.depth, state.qdepth,
                     move_string(movelist.move[move_perm[n]], NULL));
               if (state.size_of_variation) {
                  /* The index is named k so that it does not hide the move
                   * loop's n, which the insert_pv() call below needs.
                   */
                  int k;
                  for (k=0; k<state.size_of_variation; k++) {
                     print_iter("%-5s ", move_string(state.mainline[k], NULL));
                  }

                  /* Insert the principle variation back into the hash
                   * table
                   */
                  insert_pv(game, &state,
                        movelist.move[move_perm[n]], state.score);
               }
               print_iter("\n");
            }

            /* Copy principle variation */
            if (new_best_move != move_perm[n]) {
#ifdef PRINCIPLE_VARIATION_IN_GAME_STRUCT
               game->principle_variation[0][0] = movelist.move[move_perm[n]];
               for (c=1; c < game->length_of_variation[1]; c++) {
                  game->principle_variation[0][c] =
                     game->principle_variation[1][c];
               }
               game->length_of_variation[0] = game->length_of_variation[1];
#endif
               memcpy(&principle_variation, &state, sizeof state);
            }

            best_score = state.score;
            new_best_move = move_perm[n];

#ifdef PRINCIPLE_VARIATION_IN_GAME_STRUCT
            int c2;
            for (c2 = 0; c2<1+0*depth; c2++) {
               printf(" (%2d)                                                   ", game->length_of_variation[c2]);
               for (c=0; c<game->length_of_variation[c2]; c++) {
                  move_t move = game->principle_variation[c2][c];
                  printf("%-5s ", move_string(move, NULL));
                  //printf("(%s%s) ", square_str[move.from], square_str[move.to]);
               }
               printf("\n");
            }
#endif

            sdepth = state.depth;
            sqdepth = state.qdepth;
            memcpy(&principle_variation, &state, sizeof state);
            store_table_entry(game->transposition_table, game->board->hash, 
                  depth, best_score, HASH_TYPE_EXACT,
                  movelist.move[new_best_move]);
         }
      }  /* end of loop over all moves */

      if (abort_search) break;

      /* Sort the move list based on the score */
      qsort_r(move_perm, movelist.num_moves, sizeof *move_perm, score,
            movelist_sorter);
      //new_best_move = move_perm[0];
      best_move = new_best_move;

      prev_prev_time = prev_time;
      prev_time = time;
      time = get_timer();

      print_iter(output_lines[3], depth+1,
            score[best_move]/100.0, (get_timer() - start_time)/1000000.0,
            positions_evaluated,
            (float)moves_searched/positions_evaluated,
            (float)(time - prev_time)/(prev_time - prev_prev_time),
            sdepth, sqdepth,
            move_string(movelist.move[best_move], NULL));
      for (n=0; n<principle_variation.size_of_variation; n++) {
         print_iter("%-5s ", move_stringv(principle_variation.mainline[n], NULL));
      }
      print_iter("\n");

#if 0
      if (principle_variation.size_of_variation) {
         print_iter("         ");
         for (n=0; n<principle_variation.size_of_variation; n++) {
            print_iter("%s ", move_stringv(principle_variation.mainline[n], NULL));
         }
         print_iter("\n");
      }
#endif

      uci("info depth %d score cp %d nodes %d time %d\n", depth+1,
         score[best_move], positions_evaluated, peek_timer(game));
      {
         /* Nodes per second. get_timer() differences are treated as
          * microseconds elsewhere in this function (divided by 1000000.0
          * to print seconds), so scale accordingly. The line is skipped
          * when no time has elapsed to avoid dividing by zero.
          */
         uint64_t elapsed = get_timer() - start_time;
         if (elapsed > 0) {
            uci("info nps %lld\n",
                  (long long)((uint64_t)positions_evaluated * 1000000 / elapsed));
         }
      }

      /* Break out early if we've found a checkmate */
      if (abs(best_score) >= best_mate_score) {
         int ply = CHECKMATE - abs(best_score);
         print_iter("Mate in %d (%d ply)\n", ply/2 + 1, ply);
         break;
      }
   }

   print_iter(" --> %s %.2f\n", move_string(movelist.move[best_move], NULL), score[best_move]/100.0);
   print_iter("     ");
   retrieve_principle_variation(game, movelist.move[best_move]);
   print_iter("\n");
   print_iter("%d nodes visited (%d [%.2f%%] in transposition table), "
              "%g nodes/s\n"
              "%d branches pruned\n"
              "%d moves searched (average branching ratio %.2f)\n",
         positions_evaluated,
         positions_in_hashtable,
         100.0*(float)positions_in_hashtable/(positions_in_hashtable + positions_evaluated),
         1.0e6*(positions_evaluated+positions_in_hashtable) / (time - start_time),
         branches_pruned,
         moves_searched,
         (float)moves_searched/positions_evaluated);

   playmove(game, movelist.move[best_move]);
   if (game->repetition_hash_table[game->board->hash&0xFFFF]>2
         && draw_by_repetition(game)) {
      print_iter("Draw (position will repeat for the third time after %s)\n",
                  move_string(movelist.move[best_move], NULL));
      return;
   }

   game->score[game->moves_played] = score[best_move];

   return;
}
