#include "savantio.h"
#include "savutil.h"
#include "yenta-savant.h"

/* Open the database stored under the given path prefix.
 * Returns non-zero on failure. */
int init_savant_yenta(char *);

/* Store a copy of a document vector in the next free slot and return the
 * index of the slot that was used. */
int save_dv(DenseDocVec *);
void cat_doc(int index, DenseDocVec *add);
/* Drop the document stored at index; its slot becomes reusable. */
void remove_doc(int index);

int setup_compare();
/* Must be called after a call to acquire and before a call to find_matches.
 * (Finds term frequencies, updates the word and document caches, etc.)
 * Call it repeatedly until it returns true: each call does a certain amount
 * of work, which may not be all of the work remaining. */

/* Write the database back to disk and release it.  Always returns 0. */
int close_savant_yenta(void);

/* Add one (document, weight) entry for a word to the global word tree. */
static int wvtree_insert(unsigned int *, unsigned int);

/* To do: fix init to include the possibility of other files than wordvec not
 * working. */

/* Room reserved for the file-name part when building "<path><name>" buffers
 * (added to strlen(path) in init_savant_yenta and checkpoint_database).
 * NOTE(review): presumably the longest of the *_FNAME constants; confirm
 * that none of them exceeds this length. */
#define MAXFNAMELEN 9
/* strlen("word_vecs"), not counting the terminating NUL */

static int init_new(char *db_path, char *db_filename)
/* create new files for storing the database, so that we know we can save
 * information we generate, and so we fail now instead of later. :)
 * Returns non-zero on failure. */
{
  char *db_namepos = db_filename + strlen(db_path);
  FILE *WORDVEC_FILE;
  int i;

  if (!(WORDVEC_FILE = fopen(db_filename, "w")))
    {
      free(db_filename);
      /* Report error */
      return 1;
    }
  
  Global_Tree = NULL;
  Global_path = strdup(db_path);
  Current_IDnum = 0;

  strcpy(db_namepos, DOCVEC_FNAME);

  DocVecs = load_coll(db_filename);

  fclose(WORDVEC_FILE);
  DVM_max = 8;
  DVM_rng = 0;
  DocVec_Mags = NULL;
  DVszs = (int *)malloc(sizeof(int)*8);
  DVColls = (DDV_coll **)malloc(sizeof(DDV_coll *)*8);
  for (i = 0; i < 8; i++)
    DVColls[i] = NULL;

  free(db_filename);

  return 0;
}

/* When defined, check_tree() and check_database() print which test failed,
 * and init_savant_yenta() calls exit(1) on an inconsistent database. */
#define _CHECK_DATABASE_MSG_
/* When also defined, check_database() prints a message on success. */
/* #define _CHECK_DATABASE_MSG_TELL_PASSED_ */

/* The checks performed by check_database() and check_tree() are:
 * - Matching ideas of the size of the database (DVM_rng == DocVecs->length)
 * - Existence of all centroids that should exist (including ones that
 *   aren't valid): every DocVecs->items[i] below DVM_rng is non-NULL
 * - no word is used by more documents than Num_DocVecs
 * - no wordvec entries for documents past the size of the database
 * NOTE(review): the original list also named "non-negative tfidf values for
 * all words", but no such test appears in the checking code below.
 */

static int check_tree(WV_Tree *tree)
/* Recursively validate a word-tree: no node may claim more documents than
 * Num_DocVecs, and no list entry may refer to a document number at or past
 * DVM_rng.  An empty (NULL) tree is valid.
 * Returns 1 if the subtree is consistent, 0 otherwise.  With
 * _CHECK_DATABASE_MSG_ defined, the first failure is printed on its own
 * line, like the messages of check_database(). */
{
  WV_List *trav;

  if (!tree)
    return 1;

  if (tree->num_docs > Num_DocVecs)
    {
#ifdef _CHECK_DATABASE_MSG_
      /* %p requires a void pointer */
      printf("(%p)->num_docs > Num_DocVecs\n", (void *)tree);
#endif
      return 0;
    }

  for (trav = tree->wvlist; trav; trav = trav->next)
    {
      if (WV_DVNUM(trav->docweight) >= DVM_rng)
        {
#ifdef _CHECK_DATABASE_MSG_
          printf("%p contributes to document #%d, out of range\n",
                 (void *)trav, WV_DVNUM(trav->docweight));
#endif
          return 0;
        }
    }

  return (check_tree(tree->left) && check_tree(tree->right));
}

int check_database()
/* Consistency check over the in-memory database.
 * Verifies that a document collection is loaded, that DVM_rng agrees with
 * DocVecs->length, that every slot below DVM_rng holds a document vector,
 * and that the global word tree passes check_tree().
 * Returns 1 when everything is consistent, 0 otherwise. */
{
  int i;

  if (!DocVecs)
    {
      /* nothing loaded; the length comparison below would dereference NULL */
#ifdef _CHECK_DATABASE_MSG_
      printf("!DocVecs\n");
#endif
      return 0;
    }
  if (DVM_rng != DocVecs->length)
    {
#ifdef _CHECK_DATABASE_MSG_
      printf("DVM_rng != DocVec->length\n");
#endif
      return 0;
    }
  for (i = 0; i < DVM_rng; i++)
    {
      if (!DocVecs->items[i])
        {
#ifdef _CHECK_DATABASE_MSG_
          printf("!DocVecs->items[%d]\n", i);
#endif
          return 0;
        }
    }
  if (!check_tree(Global_Tree))
    return 0;
#ifdef _CHECK_DATABASE_MSG_
#ifdef _CHECK_DATABASE_MSG_TELL_PASSED_
  printf("Database passed check.\n");
#endif
#endif
  return 1;
}

int init_savant_yenta(char *db_path)
/* Open the database whose files are named db_path followed by the
 * WORDVEC_FNAME / DVMAG_FNAME / DOCVEC_FNAME suffixes.
 *
 * If the word-vector or size file cannot be opened, or the size file says
 * the database is empty (DVM_max == 0), a fresh database is started with
 * init_new() instead.
 *
 * Returns non-zero on failure.  With _CHECK_DATABASE_MSG_ defined, an
 * inconsistent database terminates the program. */
{
  size_t path_len = strlen(db_path);
  char *db_filename = malloc(path_len + MAXFNAMELEN + 2);
  char *db_namepos;
  FILE *WORDVEC_FILE, *DVMAG_FILE;
  int i;

  if (!db_filename)
    return 1;
  db_namepos = db_filename + path_len;

  strcpy(db_filename, db_path);

  strcpy(db_namepos, WORDVEC_FNAME);
  if (!(WORDVEC_FILE = fopen(db_filename, "r")))
    {
      /* perhaps we've got a new database (init_new frees db_filename) */
      return init_new(db_path, db_filename);
    }

  strcpy(db_namepos, DVMAG_FNAME);
  if (!(DVMAG_FILE = fopen(db_filename, "r")))
    {
      fclose(WORDVEC_FILE);
      return init_new(db_path, db_filename);
    }

  Global_path = strdup(db_path);
  if (!Global_path)
    {
      /* without the path the database could never be saved again */
      fclose(WORDVEC_FILE);
      fclose(DVMAG_FILE);
      free(db_filename);
      return 1;
    }

  read_sizes(DVMAG_FILE);

  DocVec_Mags = NULL;

  if (!DVM_max) /* There's nothing actually in the files... */
    {
      fclose(WORDVEC_FILE);
      fclose(DVMAG_FILE);
      /* init_new installs its own copy of the path; don't leak ours */
      free(Global_path);
      Global_path = NULL;
      return init_new(db_path, db_filename);
    }

  /* Allocate the per-document collection table before loading anything
   * else, so that running out of memory leaves little to undo.
   * NOTE(review): whatever read_sizes() allocated is not released here. */
  DVColls = malloc(sizeof *DVColls * DVM_max);
  if (!DVColls)
    {
      fclose(WORDVEC_FILE);
      fclose(DVMAG_FILE);
      free(Global_path);
      Global_path = NULL;
      free(db_filename);
      return 1;
    }
  for (i = 0; i < DVM_max; i++)
    DVColls[i] = NULL;

  Global_Tree = read_wvtree(WORDVEC_FILE);

  strcpy(db_namepos, DOCVEC_FNAME);

  DocVecs = load_coll(db_filename);

/*  DVszs = load_dv_szs(DVMAG_FILE);*/

  /* The first slot with a zero size is the first one free for reuse. */
  Current_IDnum = 0;
  while ((Current_IDnum < DocVecs->length) && (DVszs[Current_IDnum]))
    Current_IDnum++;

  fclose(WORDVEC_FILE);
  fclose(DVMAG_FILE);

  free(db_filename);

  /* Consistency check of what was just loaded. */
  if (!check_database())
    {
      printf("database inconsistent.\n");
#ifdef _CHECK_DATABASE_MSG_
      exit(1);
#endif
    }

  return 0;
}

DenseDocVec *ddv_dup(DenseDocVec *ddv)
/* Deep-copy a document vector: the entry count, the word codes
 * (WORD_ENCODE_WIDTH words per entry) and the weights.
 * The caller owns the result.  Returns NULL if memory runs out. */
{
  DenseDocVec *ret = malloc(sizeof *ret);
  int num;
  size_t code_bytes, weight_bytes;

  if (!ret)
    return NULL;

  ret->num_entries = num = ddv->num_entries;

  /* Sizes follow the element types of the arrays themselves instead of a
   * hard-coded int/float. */
  code_bytes = (size_t)num * WORD_ENCODE_WIDTH * sizeof *ret->wordcodes;
  weight_bytes = (size_t)num * sizeof *ret->weights;

  /* Never ask for zero bytes: malloc(0) may legitimately return NULL, which
   * would be indistinguishable from a failure. */
  ret->wordcodes = malloc(code_bytes ? code_bytes : 1);
  ret->weights = malloc(weight_bytes ? weight_bytes : 1);
  if (!ret->wordcodes || !ret->weights)
    {
      free(ret->wordcodes);
      free(ret->weights);
      free(ret);
      return NULL;
    }

  if (num > 0)
    {
      memcpy(ret->wordcodes, ddv->wordcodes, code_bytes);
      memcpy(ret->weights, ddv->weights, weight_bytes);
    }
  return ret;
}

void ddv_mult(DenseDocVec *ddv, int mult)
/* Scale every weight of the document vector by mult, in place. */
{
  int slot = 0;

  while (slot < ddv->num_entries)
    {
      ddv->weights[slot] *= mult;
      slot++;
    }
}

void wvtree_delete(unsigned int *wordcode, int index)
/* Remove document `index` from the entry list of the word `wordcode` in the
 * global word tree and decrement that word's entry count.
 * Does nothing if the word is not in the tree or has no entry for the
 * document (previously either case dereferenced a NULL pointer). */
{
  WV_Tree *trav = Global_Tree;
  WV_List **lst, *tmp;
  int cmp;

  index = index<<WEIGHT_WIDTH; /* move this over to simplify matters */
  while (trav)
    {
      cmp = wordcode_cmp(wordcode, trav->wordcode);
      if (cmp < 0)
        trav = trav->left;
      else if (cmp > 0)
        trav = trav->right;
      else
        {
          /* Walk the list by link, comparing only the document-number bits
           * above the weight field. */
          lst = &(trav->wvlist);
          while (*lst &&
                 ((*lst)->docweight & ~((1<<WEIGHT_WIDTH)-1)) != index)
            {
              lst = &((*lst)->next);
            }
          if (!*lst)
            return; /* this document has no entry for the word */
          tmp = *lst;
          *lst = tmp->next;
          free(tmp);
          trav->num_entries--;
          return;
        }
    }
  /* word not present in the tree: nothing to delete */
}

float find_idf(unsigned int *wordcode)
/* Look a word up in the global word tree and return its "idf": the log of
 * Num_DocVecs divided by the number of documents using the word.  Nothing is
 * modified; this is what computing document magnitudes needs.
 * Returns 0.0 for a word that is not in the tree. */
{
  WV_Tree *node;
  int order;

  for (node = Global_Tree; node != NULL;
       node = (order < 0) ? node->left : node->right)
    {
      order = wordcode_cmp(wordcode, node->wordcode);
      if (order == 0)
        return log(((float) Num_DocVecs)/node->num_docs);
    }
  return 0.0;
}

void tfidf_inc_word(unsigned int *wordcode)
/* Count one more document as using the word.  A word already in the global
 * word tree gets its num_docs incremented; an unknown word is added as a
 * leaf with num_docs 1 and an empty entry list. */
{
  WV_Tree **link = &Global_Tree;
  WV_Tree *node;
  int order, k;

  for (;;)
    {
      node = *link;
      if (node == NULL)
        break;
      order = wordcode_cmp(wordcode, node->wordcode);
      if (order == 0)
        {
          node->num_docs++;
          return;
        }
      link = (order < 0) ? &(node->left) : &(node->right);
    }

  /* Fell off the tree: hang a fresh node on the link we stopped at. */
  node = (WV_Tree *)malloc(sizeof(WV_Tree));
  *link = node;
  node->num_entries = 0;
  node->num_docs = 1;
  node->wvlist = NULL;
  node->left = NULL;
  node->right = NULL;
  for (k = 0; k < WORD_ENCODE_WIDTH; k++)
    node->wordcode[k] = wordcode[k];
}

float wvtree_replace(unsigned int *wordcode, unsigned int newweight)
/* Overwrite the stored weight of one (word, document) entry.
 *
 * NOTE: the replacement weight must be for the same document; thus the
 * docnum part of newweight (the bits above WEIGHT_WIDTH) must be right, as
 * it is what selects the entry to replace.
 *
 * Returns the "idf" for this word: the log of the number of documents we've
 * acquired divided by the number that use this word.  It is used by incr and
 * found in the wvtree, so it is returned here for efficiency, not because of
 * a logical connection.
 * Returns 0.0, changing nothing, if the word is not in the tree or the
 * document has no entry for it. */
{
  WV_Tree *trav = Global_Tree;
  WV_List *lst;
  int cmp;
  unsigned int match = newweight & ~((1<<WEIGHT_WIDTH) - 1);

  while (trav)
    {
      cmp = wordcode_cmp(wordcode, trav->wordcode);
      if (cmp < 0)
        trav = trav->left;
      else if (cmp > 0)
        trav = trav->right;
      else
        {
          for (lst = trav->wvlist; lst; lst = lst->next)
            {
              if ((lst->docweight & ~((1<<WEIGHT_WIDTH)-1)) == match)
                {
                  lst->docweight = newweight;
                  return log(((float)Num_DocVecs)/trav->num_docs);
                }
            }
          /* Report error: the document is missing its wordvec entry */
          return 0.0;
        }
    }
  /* Report error: the word is not in the tree at all */
  return 0.0;
}

void remove_doc(int index)
/* Remove the document stored at index: delete its entries from the word
 * tree, mark the slot empty (size 0), close its per-document collection if
 * one is open, and make the slot the next one to be reused if it lies below
 * Current_IDnum.
 * An index outside the collection, or one whose slot is already empty, is
 * ignored. */
{
  DenseDocVec *ddv;
  int i;

  /* Validate before touching items[]: valid slots are 0 .. length-1. */
  if ((index < 0) || (index >= DocVecs->length) || (!DVszs[index]))
    {
#if 0
      printf("Error-- removing a non-existant document!\n");
#endif
      /* report error */
      return;
    }

  ddv = DocVecs->items[index];
  for (i = 0; i < ddv->num_entries; i++)
    wvtree_delete(&(ddv->wordcodes[WORD_ENCODE_WIDTH*i]), index);
  DVszs[index] = 0;

  if (DVColls[index])
    {
      close_coll(DVColls[index]);
      DVColls[index] = NULL;
    }

  if (Current_IDnum > index) /* reuse holes */
    Current_IDnum = index;
}

void incr_dv(DenseDocVec *ddv, int index)
/* Merge the document vector ddv into the document stored at index.
 *
 * Words both vectors share have their weights added together and the word
 * tree entry replaced; words new to the stored document are appended to it
 * and inserted into the word tree.  If a summed weight no longer fits in
 * WEIGHT_WIDTH bits, every weight of the stored document (and every
 * remaining weight of the incoming one) is halved.  The cached magnitude is
 * updated when the magnitude cache exists, a copy of ddv is added to the
 * document's own collection when one is open, and DVszs[index] is bumped.
 *
 * ddv itself is not modified; the work is done on a private copy.
 *
 * NOTE(review): a weight of 0 in the private copy doubles as the "already
 * merged" marker, so an unmerged entry whose weight is (or is halved to) 0
 * is counted in the new size but never appended -- confirm whether weights
 * can be 0 here. */
{
  DenseDocVec *dup = ddv_dup(ddv), *old = DocVecs->items[index];
  int i, j, k;
  int num, oldnum, newnum;
  unsigned int *wordcode;
  float mag = 0.0; /* really mag^2 */
  float idf;

  if (!dup)
    return; /* Report error: no memory for the working copy */

  num = newnum = dup->num_entries;
  oldnum = old->num_entries;

  /* The magnitude cache may not exist yet; it is only read and written
   * when it does. */
  if (DocVec_Mags)
    mag = DocVec_Mags[index * NUM_FIELD_TYPES] *
      DocVec_Mags[index * NUM_FIELD_TYPES]; /* a^2 = b^2 */

  for (i = 0; i < num; i++)
    {
      wordcode = dup->wordcodes + i*WORD_ENCODE_WIDTH;
      idf = find_idf(wordcode);
      mag += dup->weights[i] * dup->weights[i] * idf; /* + c^2 */
      for (j = 0; j < oldnum; j++)
        {
          if (wordcode_cmp(old->wordcodes+j*WORD_ENCODE_WIDTH, wordcode))
            continue; /* not the same word */

          mag += 2 * old->weights[j] * dup->weights[i] * idf;
          /* - 2 b c cos(A) (Note: this does sum out right, really) */
          old->weights[j] += dup->weights[i];
          if (old->weights[j] >= (1<<WEIGHT_WIDTH))
            { /* The document has overrun the size; shrink everything */
              for (k = 0; k < oldnum; k++)
                {
                  old->weights[k] = old->weights[k] / 2;
                  wvtree_replace(old->wordcodes+k*WORD_ENCODE_WIDTH,
                                 old->weights[k] + (index<<WEIGHT_WIDTH));
                }
              mag /= 4; /* half mag, quarter mag^2 */
              for (k = 0; k < num; k++)
                dup->weights[k] /= 2;
            }
          else
            wvtree_replace(wordcode,
                           old->weights[j] + (index<<WEIGHT_WIDTH));

          dup->weights[i] = 0; /* signal that it's handled */
          newnum--;
          break;
        }
    }

  if (DVColls[index])
    coll_add_doc(DVColls[index], ddv_dup(ddv));

  /* newnum is now the number of words the stored document lacks. */
  if (newnum > 0)
    {
      int total = oldnum + newnum;
      int slot = total;
      void *grown;

      /* Grow through a temporary so a failed realloc does not lose the
       * existing arrays. */
      grown = realloc(old->wordcodes,
                      total * WORD_ENCODE_WIDTH * sizeof *old->wordcodes);
      if (grown)
        {
          old->wordcodes = grown;
          grown = realloc(old->weights, total * sizeof *old->weights);
        }
      if (!grown)
        {
          /* Report error: shared words were merged, new words were not */
          destroy_ddv(dup);
          return;
        }
      old->weights = grown;
      old->num_entries = total;

      /* Append the unmerged entries, filling from the top slot downwards. */
      for (i = 0; i < num; i++)
        {
          if (dup->weights[i]) /* didn't handle before */
            {
              old->weights[--slot] = dup->weights[i];
              for (j = 0; j<WORD_ENCODE_WIDTH; j++)
                {
                  old->wordcodes[j + slot*WORD_ENCODE_WIDTH] =
                    dup->wordcodes[j + i*WORD_ENCODE_WIDTH];
                }
              wvtree_insert_one(old->wordcodes + slot*WORD_ENCODE_WIDTH,
                                (index<<WEIGHT_WIDTH) + dup->weights[i]);
            }
        }
    }

  if (DocVec_Mags)
    DocVec_Mags[index*NUM_FIELD_TYPES] = (float)sqrt(mag); /* new mag */
  DVszs[index]++;
  destroy_ddv(dup);
}

int save_dv(DenseDocVec *ddv)
/* Store a copy of ddv as a new document in slot Current_IDnum.
 *
 * The copy is appended to DocVecs or replaces the vector in a reused slot.
 * The per-document tables are doubled when the slot is past their end.  The
 * document's magnitude is cached when the magnitude cache exists, a
 * temporary per-document collection holding another copy of ddv is
 * created, and the vector's entries are inserted into the word tree.
 * Afterwards Current_IDnum is advanced to the next slot not in use.
 *
 * The caller keeps ownership of ddv.
 * Returns the slot used, or -1 if the tables could not be grown (in which
 * case nothing has been stored). */
{
  int i;

  float mag; /* really mag^2 */

  /* Grow the tables first, so that a failure leaves the database exactly
   * as it was. */
  if (Current_IDnum >= DVM_max)
    {
      int grown = DVM_max * 2;
      float *mags;
      int *szs;
      DDV_coll **colls;

      if (DocVec_Mags)
        {
          mags = realloc(DocVec_Mags,
                         grown * sizeof *mags * NUM_FIELD_TYPES);
          if (!mags)
            return -1;
          DocVec_Mags = mags;
        }
      /* element size taken from the array itself (it holds ints) */
      szs = realloc(DVszs, grown * sizeof *szs);
      if (!szs)
        return -1;
      DVszs = szs;
      colls = realloc(DVColls, grown * sizeof *colls);
      if (!colls)
        return -1;
      DVColls = colls;
      for (i = DVM_max; i < grown; i++)
        {
          DVszs[i] = 0; /* new slots hold no document */
          DVColls[i] = NULL;
        }
      DVM_max = grown;
    }

  if (Current_IDnum == DocVecs->length)
    coll_add_doc(DocVecs, ddv_dup(ddv)); /* don't steal the caller's copy */
  else
    {
      destroy_ddv(DocVecs->items[Current_IDnum]);
      DocVecs->items[Current_IDnum] = ddv_dup(ddv);
    }

  mag = 0;

  for (i = 0; i < ddv->num_entries; i++)
    {
      mag += ddv->weights[i] * ddv->weights[i] *
        find_idf(ddv->wordcodes + i * WORD_ENCODE_WIDTH);
    }

  if (DocVec_Mags)
    DocVec_Mags[Current_IDnum*NUM_FIELD_TYPES] = (float)sqrt((double)mag);
  DVszs[Current_IDnum] = 1;
  DVColls[Current_IDnum] = temp_coll();
  coll_add_doc(DVColls[Current_IDnum], ddv_dup(ddv));

  /* for each wordvec entry in the ddv, do a tree insert */

  i = Current_IDnum; /* remember the slot; Current_IDnum moves on below */

  wvtree_insert_ddv(ddv);

  while ((Current_IDnum < DVM_rng) && (DVszs[Current_IDnum]))
    Current_IDnum++; /* find the next spot not in use */
  if (Current_IDnum > DVM_rng)
    DVM_rng = Current_IDnum;
  return i;
}


int wvtree_insert(unsigned int *wordcode,
		  unsigned int docweight)
{
  int cmp,j;
  WV_Tree **treeptr = &Global_Tree;
  WV_List *others;
  
  while (1) {
    
    if(*treeptr == NULL) {  /* new word node */
      *treeptr = (WV_Tree *)malloc(sizeof(WV_Tree));
      for(j=0; j<WORD_ENCODE_WIDTH; j++) {
	(*treeptr)->wordcode[j] = wordcode[j];
      }
      (*treeptr)->num_entries = 1;
      (*treeptr)->wvlist = (WV_List *)malloc(sizeof(WV_List));
      (*treeptr)->wvlist->docweight = docweight;
      (*treeptr)->wvlist->next = NULL;
      (*treeptr)->right = (*treeptr)->left = NULL;
      return 0;
    }
  
    cmp = wordcode_cmp((*treeptr)->wordcode, wordcode);
    if(cmp < 0) {
      treeptr = &((*treeptr)->right);
    }
    else if(cmp > 0) {
      treeptr = &((*treeptr)->left);
    }
    else { /* proper node found */
      /* note that this works for adding 1 to 0 */
      others = (*treeptr)->wvlist;
      (*treeptr)->wvlist = (WV_List *)malloc(sizeof(WV_List));
      (*treeptr)->wvlist->docweight = docweight;
      (*treeptr)->wvlist->next = others;
      return(++((*treeptr)->num_entries));
    }
  }

}

static int checkpoint_database()
/* Write the word tree to the WORDVEC_FNAME file and the per-document sizes
 * to the DVMAG_FNAME file under Global_path.  check_database() is run
 * first; its verdict is not acted upon.
 * Returns 0 on success and non-zero if memory ran out or either file could
 * not be opened or closed cleanly; the second file is still attempted when
 * the first one fails. */
{
  FILE *file;
  char *db_filename;
  char *db_filepos;
  size_t path_len = strlen(Global_path);
  int failed = 0;

  db_filename = malloc(path_len + MAXFNAMELEN + 1);
  if (!db_filename)
    return 1;
  db_filepos = db_filename + path_len;

  check_database();

  strcpy(db_filename, Global_path);

  strcpy(db_filepos, WORDVEC_FNAME);
  file = fopen(db_filename, "w");
  if (file)
    {
      write_wvtree(file, Global_Tree);
      if (fclose(file)) /* fclose flushes; a failure means data was lost */
        failed = 1;
    }
  else
    failed = 1; /* Report error */

  strcpy(db_filepos, DVMAG_FNAME);
  file = fopen(db_filename, "w");
/*  write_dv_mags(file, DVMags, DocVecs->length);*/
  if (file)
    {
      write_sizes(file, DVszs, DocVecs->length);
      if (fclose(file))
        failed = 1;
    }
  else
    failed = 1; /* Report error */

  free(db_filename);
  return failed;
}

int close_savant_yenta()
/* Save the database and shut it down: checkpoint the tree and sizes, close
 * the document collection and forget the database path.  Does nothing when
 * no database was ever opened.  Always returns 0. */
{
  if (Global_path)
    {
      checkpoint_database();
      close_coll(DocVecs);

      free(Global_path);
      Global_path = NULL;
    }

  return 0;
}

int checkpoint_savant_yenta()
/* Save the database without closing it: checkpoint the tree and sizes, then
 * the document collection.  Does nothing when no database is open.
 * Always returns 0. */
{
  if (Global_path)
    {
      checkpoint_database();
      checkpoint_coll(DocVecs);
    }

  return 0;
}

  

