


/***********************************************************************

   This software is for research and educational purposes only.

************************************************************************/



#include "ZZ_p.h"
#include "tools.h"
#include "FFT.h"


// Record the modulus and its derived sizes.  The heavier per-modulus
// tables are not built here: they stay unallocated (null pointers,
// initialized == 0) until init() is called.
ZZ_pInfoT::ZZ_pInfoT(const ZZ& NewP)
{
   // nothing allocated yet; the destructor relies on initialized == 0
   initialized = 0;
   x = 0;
   u = 0;
   tbl = 0;
   tbl1 = 0;

   p = NewP;
   size = p.size();

   // number of ZZ_NBITS-bit digits needed to hold one machine long (ceiling)
   const long DigitsPerLong = (ZZ_BITS_PER_LONG + ZZ_NBITS - 1)/ZZ_NBITS;
   ExtendedModulusSize = 2*size + DigitsPerLong;
}



// Build the per-modulus tables for multi-modular (FFT-prime) arithmetic
// modulo p.  The constructor only records p; this routine selects the FFT
// primes, computes the CRT-style reconstruction data (CoeffModP, x, u,
// MinusMModP), and, depending on SINGLE_MUL / TBL_REM, allocates and fills
// tbl or tbl1.  Any allocation failure calls Error("out of space").
void ZZ_pInfoT::init()
{
   ZZ B, M, M1, MinusM;
   long n, i;
   long q, t;

   // Set up front: the destructor frees x, u and the tables only when this
   // flag is nonzero.
   initialized = 1;

   // B = p^2 * 2^(FFTMaxRoot+FFTFudge): the product of the chosen primes
   // must exceed this.  NOTE(review): presumably a bound on coefficients of
   // a product of two length-2^FFTMaxRoot polynomials mod p -- confirm.
   sqr(B, p);

   LeftShift(B, B, FFTMaxRoot+FFTFudge);

   // Multiply successive FFT primes (enabling each with UseFFTPrime) into M
   // until M > B.  set(M) makes M = 1, and B >= 1 for p > 1, so the loop
   // runs at least once and q is assigned before it is used below.
   set(M);
   n = 0;
   while (M <= B) {
      UseFFTPrime(n);
      q = FFTPrime[n];
      n++;
      mul(M, M, q);
   }

   NumPrimes = n;
   // computed from the last prime selected by the loop above
   MaxRoot = CalcMaxRoot(q);

   if (NumBits(n) >= ZZ_DOUBLE_PRECISION - 10)
      Error("modulus too big");
   QuickCRT = (NumBits(n) < ZZ_DOUBLE_PRECISION - ZZ_NBITS - 10);
   // These tests ensure the Kahan addition is accurate enough

   // MinusMModP = (-M) mod p, where M is the product of all chosen primes.
   negate(MinusM, M);
   rem(MinusMModP, MinusM, p);

   CoeffModP.SetSize(n, p.size());

   if (!(x = (double *) malloc(n * (sizeof (double)))))
      Error("out of space");

   if (!(u = (long *) malloc(n * (sizeof (long)))))
      Error("out of space");

   // For each prime q_i, with M_i = M / q_i and t_i = M_i^{-1} mod q_i:
   //    CoeffModP[i] = (M_i * t_i) mod p
   //    x[i]         = t_i / q_i   (as a double)
   //    u[i]         = t_i
   // These are the Chinese-remainder reconstruction terms, reduced mod p.
   for (i = 0; i < n; i++) {
      q = FFTPrime[i];

      div(M1, M, q);
      t = rem(M1, q);
      t = InvMod(t, q);
      mul(M1, M1, t);
      rem(CoeffModP[i], M1, p);
      x[i] = ((double) t)/((double) q);
      u[i] = t;
   }

   // mmsize = word size of p * 2^ZZ_NBITS * NumPrimes.
   // NOTE(review): presumably a scratch-size bound for CRT accumulation;
   // confirm against the callers that read mmsize.
   B = p;
   LeftShift(B, B, ZZ_NBITS);
   mul(B, B, NumPrimes);
   mmsize = B.size();

#if (defined(SINGLE_MUL))
   // tbl[i][j] = (2^ZZ_NBITS)^j mod q_i, stored as double, for
   // 0 <= j < p.size().  One row per prime.
   tbl = (double **) malloc(NumPrimes * sizeof(double *));
   if (!tbl) Error("out of space");
   for (i = 0; i < NumPrimes; i++) {
      tbl[i] = (double *) malloc(p.size() * sizeof(double));
      if (!tbl[i]) Error ("out of space");
   }

   long t1;
   long j;

   for (i = 0; i < NumPrimes; i++) {
      q = FFTPrime[i];
      // t = 2^ZZ_NBITS mod q; t1 walks through its successive powers
      t = (((long)1) << ZZ_NBITS) % q;
      t1 = 1;
      tbl[i][0] = (double) 1;
      for (j = 1; j < p.size(); j++) {
         t1 = MulMod(t1, t, q);
         tbl[i][j] = (double) t1;
      }
   }
#else
   tbl = 0;
#endif

#if (defined(TBL_REM) && !defined(SINGLE_MUL))
   // Same power table as above, but stored as long:
   // tbl1[i][j] = (2^ZZ_NBITS)^j mod q_i for 0 <= j < p.size().
   tbl1 = (long **) malloc(NumPrimes * sizeof(long *));
   if (!tbl1) Error("out of space");
   for (i = 0; i < NumPrimes; i++) {
      tbl1[i] = (long *) malloc(p.size() * sizeof(long));
      if (!tbl1[i]) Error ("out of space");
   }

   long t1;
   long j;

   for (i = 0; i < NumPrimes; i++) {
      q = FFTPrime[i];
      // t = 2^ZZ_NBITS mod q; t1 walks through its successive powers
      t = (((long)1) << ZZ_NBITS) % q;
      t1 = 1;
      tbl1[i][0] = 1;
      for (j = 1; j < p.size(); j++) {
         t1 = MulMod(t1, t, q);
         tbl1[i][j] = t1;
      }
   }
#else
   tbl1 = 0;
#endif
}

// Release everything init() allocated.  A context whose init() never ran
// owns no heap memory, so there is nothing to do in that case.
ZZ_pInfoT::~ZZ_pInfoT()
{
   if (!initialized)
      return;

   free(x);
   free(u);

#if (defined(SINGLE_MUL))
   // per-prime rows first, then the array of row pointers
   for (long row = 0; row < NumPrimes; row++)
      free(tbl[row]);
   free(tbl);
#elif (defined(TBL_REM))
   for (long row = 0; row < NumPrimes; row++)
      free(tbl1[row]);
   free(tbl1);
#endif
}



// The currently installed modulus context.  It is null until ZZ_pInit()
// installs one.  ZZ_pInit() replaces (and deletes) it, and ZZ_pBak can
// detach it and later reinstall a saved one.
ZZ_pInfoT *ZZ_pInfo = 0; 

void ZZ_pInit(const ZZ& NewP)
{
   if (NewP <= 1) Error("ZZ_pInit: modulus must be > 1");

   delete ZZ_pInfo;
   ZZ_pInfo = new ZZ_pInfoT(NewP);
}

void ZZ_pBak::save()
{
   delete ptr;
   MustRestore = 1;
   ptr = ZZ_pInfo;
   ZZ_pInfo = 0;
}

void ZZ_pBak::move()
{
   delete ptr;
   MustRestore = 0;
   ptr = ZZ_pInfo;
   ZZ_pInfo = 0;
}

void ZZ_pBak::restore()
{
   delete ZZ_pInfo;
   ZZ_pInfo = ptr;
   MustRestore = 0;
   ptr = 0;
}





// Shared zero element.  It is built with ZZ_p_NoAlloc, presumably so that
// its construction does not depend on a modulus being installed during
// static initialization -- TODO confirm against the header.
ZZ_p ZZ_p::_zero(ZZ_p_NoAlloc);

// Optional hook that inv() calls when asked to invert a nonzero element
// with no inverse modulo p.  When null, such a failure calls Error().
ZZ_p::DivHandlerPtr ZZ_p::DivHandler = 0;

// Construct a residue from a ZZ.  rep is pre-sized for the current
// modulus; the value itself is assigned through operator<<.
ZZ_p::ZZ_p(INIT_VAL_TYPE, const ZZ& a) : rep(INIT_SIZE, ModulusSize())
{
   ZZ_p& self = *this;
   self << a;
}

// Construct a residue from a machine long.  rep is pre-sized for the
// current modulus; the value itself is assigned through operator<<.
ZZ_p::ZZ_p(INIT_VAL_TYPE, long a) : rep(INIT_SIZE, ModulusSize())
{
   ZZ_p& self = *this;
   self << a;
}


// Assign a machine long to a residue by first converting it to a ZZ.
// The intermediate ZZ is a function-level static so its storage is reused
// across calls; as a consequence this is not reentrant or thread-safe.
void operator<<(ZZ_p& x, long a)
{
   static ZZ asZZ;

   asZZ << a;
   x << asZZ;
}

// Read an integer from the stream and assign it to x.  The integer is read
// into a static ZZ buffer (reused across calls, so not reentrant), and x is
// assigned from it unconditionally, even if the read fails.
istream& operator>>(istream& s, ZZ_p& x)
{
   static ZZ readBuf;

   s >> readBuf;
   x << readBuf;
   return s;
}

// x = a / b (mod p), computed as a * b^{-1}.  The inverse goes into a
// separate temporary, so x may be the same object as a or b.
void div(ZZ_p& x, const ZZ_p& a, const ZZ_p& b)
{
   ZZ_p bInverse;

   inv(bInverse, b);
   mul(x, a, bInverse);
}

// x = a / b (mod p) for a machine-long divisor: lift b to a residue and
// defer to the ZZ_p / ZZ_p overload.
void div(ZZ_p& x, const ZZ_p& a, long b)
{
   ZZ_p divisor;

   divisor << b;
   div(x, a, divisor);
}

// x = a / b (mod p) for a machine-long dividend: lift a to a residue and
// defer to the ZZ_p / ZZ_p overload.
void div(ZZ_p& x, long a, const ZZ_p& b)
{
   ZZ_p dividend;

   dividend << a;
   div(x, dividend, b);
}


// x = a^{-1} (mod p).
//
// Failure handling:
//   - a == 0: Error("ZZ_p: division by zero").
//   - a nonzero but not invertible: ZZ_p::DivHandler(a) is called if it is
//     installed; otherwise Error("ZZ_p: division by non-invertible element").
// If DivHandler returns, x receives whatever InvModStatus produced
// (per NTL documentation, gcd(a, p)), exactly as before.
void inv(ZZ_p& x, const ZZ_p& a)
{
   // Compute into a local rather than straight into x.rep: with inv(x, x)
   // the output would overwrite a before the checks below run.  That would
   // make IsZero test the wrong value and pass a clobbered element to
   // DivHandler.
   ZZ result;

   if (InvModStatus(result, a.rep, ZZ_p::modulus())) {
      if (IsZero(a.rep))
         Error("ZZ_p: division by zero");
      else if (ZZ_p::DivHandler)
         (*ZZ_p::DivHandler)(a);
      else
         Error("ZZ_p: division by non-invertible element");
   }

   x.rep = result;
}

