

/***********************************************************************

   This software is for research and educational purposes only.

************************************************************************/



#include "ZZ.h"
#include "tools.h"
#include <stdlib.h>
#include <limits.h>


// Out-of-class definition of the static data member ZZ::_zero.
// NOTE(review): presumably a shared zero-valued ZZ declared in the ZZ class
// (likely in ZZ.h) — confirm its intended use there.
ZZ ZZ::_zero;



long ZZ::size() const
{
   // Number of digits in the representation. A null rep and the canonical
   // zero (length word 1 with a single 0 digit) both report 0; otherwise the
   // magnitude of the signed length word rep[0] is returned.
   if (rep == 0) return 0;

   long len = rep[0];
   if (len == 1 && rep[1] == 0) return 0;

   return (len < 0) ? -len : len;
}

long digit(const ZZ& a, long i)
{
   verylong rep = a.rep;

   if (i < 0 || !rep) return 0;

   long sa = rep[0];
   if (sa < 0) sa = -sa;
   if (i >= sa) return 0;
   return rep[i+1];
}


long IsOne(const ZZ& a)
{
   return a.rep != 0 && a.rep[0] == 1 && a.rep[1] == 1;
}


void AddMod(ZZ& x, const ZZ& a, long b, const ZZ& n)
{
   // x = a + b mod n, with the long operand lifted into a reusable ZZ scratch.
   static ZZ b_zz;
   b_zz << b;
   AddMod(x, a, b_zz, n);
}

void AddMod(ZZ& x, long a, const ZZ& b, const ZZ& n)
{
   // x = a + b mod n, with the long operand lifted into a reusable ZZ scratch.
   static ZZ a_zz;
   a_zz << a;
   AddMod(x, a_zz, b, n);
}

void SubMod(ZZ& x, const ZZ& a, long b, const ZZ& n)
{
   // x = a - b mod n, with the long operand lifted into a reusable ZZ scratch.
   static ZZ b_zz;
   b_zz << b;
   SubMod(x, a, b_zz, n);
}

void SubMod(ZZ& x, long a, const ZZ& b, const ZZ& n)
{
   // x = a - b mod n, with the long operand lifted into a reusable ZZ scratch.
   static ZZ a_zz;
   a_zz << a;
   SubMod(x, a_zz, b, n);
}

void RandomBnd(ZZ& x, long n)
{
   // Forward to the ZZ-bound overload after converting the long bound.
   static ZZ bound;
   bound << n;
   RandomBnd(x, bound);
}

// ****** input and output

// Decimal I/O parameters shared by operator>> and operator<<.  They are
// computed lazily by InitZZIO(); iodigits == 0 means "not yet initialized".
static long iodigits = 0;
static long ioradix = 0;

// iodigits is the greatest integer such that 10^{iodigits} < ZZ_RADIX
// ioradix = 10^{iodigits}

static void InitZZIO()
{
   // The number of decimal digits of (ZZ_RADIX-1)/10 is the largest d with
   // 10^d < ZZ_RADIX; record d in iodigits and 10^d in ioradix.
   iodigits = 0;
   ioradix = 1;

   for (long rest = (ZZ_RADIX - 1) / 10; rest != 0; rest /= 10) {
      iodigits++;
      ioradix *= 10;
   }

   if (iodigits <= 0) Error("problem with I/O");
}

istream& operator>>(istream& s, ZZ& x)
{
   long c;
   long sign;
   long ndigits;
   long acc;
   static ZZ a;

   if (!s) Error("bad ZZ input");

   if (!iodigits) InitZZIO();

   a << 0;

   c = s.peek();
   while (c == ' ' || c == '\n' || c == '\t') {
      s.get();
      c = s.peek();
   }

   if (c == '-') {
      sign = -1;
      
      s.get();
      c = s.peek();
      while (c == ' ' || c == '\n' || c == '\t') {
         s.get();
         c = s.peek();
      }
   }
   else
      sign = 1;

   if (c < '0' || c > '9') Error("bad ZZ input");

   ndigits = 0;
   acc = 0;
   while (c >= '0' && c <= '9') {
      acc = acc*10 + c - '0';
      ndigits++;

      if (ndigits == iodigits) {
         mul(a, a, ioradix);
         add(a, a, acc);
         ndigits = 0;
         acc = 0;
      }

      s.get();
      c = s.peek();
   }

   if (ndigits != 0) {
      long mpy = 1;
      while (ndigits > 0) {
         mpy = mpy * 10;
         ndigits--;
      }

      mul(a, a, mpy);
      add(a, a, acc);
   }

   if (sign == -1)
      negate(a, a);

   x = a;

   return s;
}

// Minimal growable stack of longs, used by operator<< to buffer the
// base-ioradix chunks of a number so they can be printed most significant
// first.  Storage is malloc'd lazily by push() and released by the destructor.
struct lstack {
   long top;     // index of the top element; -1 when empty
   long alloc;   // capacity of elts in longs; 0 until the first push
   long *elts;   // malloc'd storage, or 0 before the first push

   lstack() { top = -1; alloc = 0; elts = 0; }
   ~lstack() { free(elts); }   // free(0) is a no-op

   // pop() does not check for underflow: test empty() first.
   long pop() { return elts[top--]; }
   long empty() { return (top == -1); }
   void push(long x);

private:
   // The object owns a raw buffer, so a member-wise copy would lead to a
   // double free; copying is forbidden (declared, never defined).
   lstack(const lstack&);
   lstack& operator=(const lstack&);
};

void lstack::push(long x)
{
   if (alloc == 0) {
      alloc = 100;
      elts = (long *) malloc(alloc * sizeof(long));
   }

   top++;

   if (top + 1 > alloc) {
      alloc = 2*alloc;
      elts = (long *) realloc(elts, alloc * sizeof(long));
   }

   if (!elts) {
      Error("out of space in ZZ output");
   }

   elts[top] = x;
}


static
void PrintDigits(ostream& s, long d, long justify)
{
   // Writes the decimal digits of d (0 <= d < ioradix) to s.  With justify
   // set, the output is left-padded with '0' to exactly iodigits characters;
   // without it, d == 0 prints nothing.
   static char *digits = 0;   // scratch buffer of iodigits chars, allocated once

   if (digits == 0) {
      digits = (char *) malloc(iodigits);
      if (digits == 0) Error("out of memory");
   }

   // Collect digits least significant first.
   long n;
   for (n = 0; d != 0; d /= 10)
      digits[n++] = (d % 10) + '0';

   if (justify)
      for (long pad = iodigits - n; pad > 0; pad--)
         s << "0";

   // Emit them most significant first.
   while (n-- > 0)
      s << digits[n];
}
      

   

ostream& operator<<(ostream& s, const ZZ& a)
{
   // Prints a in decimal.  |a| is split into base-ioradix chunks, least
   // significant first, stacked, and printed back with every chunk but the
   // leading one zero-padded to iodigits digits.
   static ZZ mag;
   static lstack chunks;

   if (!iodigits) InitZZIO();

   mag = a;

   long sgn = sign(mag);
   if (sgn == 0) {
      s << "0";
      return s;
   }
   if (sgn < 0) {
      s << "-";
      negate(mag, mag);
   }

   do {
      chunks.push(DivRem(mag, mag, ioradix));
   } while (!IsZero(mag));

   PrintDigits(s, chunks.pop(), 0);
   while (!chunks.empty())
      PrintDigits(s, chunks.pop(), 1);

   return s;
}


// ******  MultiMul 


// MultiMul: x = sum over i in [0, n) of a[i] * b[i].
//
// Works directly on the digit representation: x is sized to size+1 digits
// and zeroed.  Each product is accumulated in place with zaddmul, then the
// carry is propagated by hand, and finally the length word is normalized.
//
// NOTE(review): a[i].rep[0] is used directly as a digit count and no signs
// are inspected, so this appears to assume every a[i] is nonnegative and
// every b[i] is a nonnegative single-precision multiplier.  It also appears
// to require that `size` bound the digit length of every a[i] with room for
// the final carry digit.  Confirm against callers.
void MultiMul(ZZ& x, long n, const ZZ* a, const long* b, long size)
{
   verylong xx, yy;
   long i, sx;

   sx = size+1;

   zsetlength(&x.rep, sx);
   xx = x.rep;

   // clear digit slots 1..sx (slot 0, the length word, is set at the end)
   for (i = 1; i <= sx; i++)
      xx[i] = 0;

   xx++;   // xx now addresses the least significant digit

   for (i = 0; i < n; i++) {
      yy = a[i].rep;

      // skip terms with a null rep or a zero multiplier
      if (!yy || !b[i]) continue;

      // presumably adds b[i] * a[i] into the digits at xx — TODO confirm zaddmul semantics
      zaddmul(b[i], xx, yy);
      // point at the digit just above a[i]'s most significant digit
      yy = xx + yy[0];
   
      // if that digit reached ZZ_RADIX, reduce it and ripple the carry upward
      if ((*yy) >= ZZ_RADIX) {
         (*yy) -= ZZ_RADIX;
         yy++;
         while ((*yy) == ZZ_RADIX-1) {
            *yy = 0;
            yy++;
         }
         (*yy)++;
      }
   }

   xx--;   // back to the length word
   // drop leading zero digits, keeping at least one digit
   while (sx > 1 && xx[sx] == 0) sx--;
   xx[0] = sx;
}


// Returns gcd(|a|, |b|) (so GCD(0, 0) == 0) using Euclid's algorithm.
// Calls Error("GCD: integer overflow") if either argument is LONG_MIN,
// whose magnitude is not representable as a long.
long GCD(long a, long b)
{
   // Check before negating: -LONG_MIN is signed overflow (undefined
   // behavior), and a "still negative after negation" test can be removed
   // by the optimizer.
   if (a == LONG_MIN || b == LONG_MIN)
      Error("GCD: integer overflow");

   if (a < 0) a = -a;
   if (b < 0) b = -b;

   while (b != 0) {
      long r = a % b;
      a = b;
      b = r;
   }

   return a;
}

         

// Extended Euclid: sets d = gcd(|a|, |b|) >= 0 and s, t with s*a + t*b == d.
// Calls Error("XGCD: integer overflow") if either argument is LONG_MIN.
void XGCD(long& d, long& s, long& t, long a, long b)
{
   long  u, v, u0, v0, u1, v1, u2, v2, q, r;

   long aneg = 0, bneg = 0;

   // Check before negating: -LONG_MIN is signed overflow (undefined
   // behavior), so a sign test after the negation is not reliable.
   if (a == LONG_MIN || b == LONG_MIN)
      Error("XGCD: integer overflow");

   if (a < 0) {
      a = -a;
      aneg = 1;
   }

   if (b < 0) {
      b = -b;
      bneg = 1;
   }

   // Invariants: u1*|a| + v1*|b| == u and u2*|a| + v2*|b| == v.
   u1=1; v1=0;
   u2=0; v2=1;
   u = a; v = b;

   while (v != 0) {
      q = u / v;
      r = u % v;
      u = v;
      v = r;
      u0 = u2;
      v0 = v2;
      u2 =  u1 - q*u2;
      v2 = v1- q*v2;
      u1 = u0;
      v1 = v0;
   }

   // Convert coefficients of |a|, |b| back to coefficients of a, b.
   if (aneg)
      u1 = -u1;

   if (bneg)
      v1 = -v1;

   d = u;
   s = u1;
   t = v1;
}
   

long InvMod(long a, long n)
{
   long d, s, t;

   XGCD(d, s, t, a, n);
   if (d != 1) Error("InvMod: inverse undefined");
   if (s < 0)
      return s + n;
   else
      return s;
}


long PowerMod(long a, long e, long n)
{
   // Right-to-left binary exponentiation: a^e mod n using MulMod.
   if (e < 0) {
      Error("negative exponent in PowerMod");
   }

   long result = 1;
   long base = a;

   for (; e != 0; e = e >> 1) {
      if (e & 1) result = MulMod(result, base, n);
      base = MulMod(base, base, n);
   }

   return result;
}

// Probabilistic primality test for a single-precision n.
// Small primes 2, 3, 5, 7 are handled by trial division; n >= ZZ_RADIX is
// delegated to the ZZ overload; otherwise NumTests Miller-Rabin rounds with
// random bases from RandomBnd(n) are run.  Returns 1 if n is probably
// prime, 0 if it is definitely composite (or n <= 1).
long ProbPrime(long n, long NumTests)
{
   // Plain locals: these were previously function-level statics, which made
   // the routine non-reentrant and unsafe across threads for no benefit.
   long m, x, y, z;
   long i, j, k;

   if (n <= 1) return 0;


   if (n == 2) return 1;
   if (n % 2 == 0) return 0;

   if (n == 3) return 1;
   if (n % 3 == 0) return 0;

   if (n == 5) return 1;
   if (n % 5 == 0) return 0;

   if (n == 7) return 1;
   if (n % 7 == 0) return 0;

   if (n >= ZZ_RADIX) {
      ZZ nn;
      nn << n;
      return ProbPrime(nn, NumTests);
   }

   m = n - 1;
   k = 0;
   while((m & 1) == 0) {
      m = m >> 1;
      k++;
   }

   // n - 1 == 2^k * m, m odd

   for (i = 0; i < NumTests; i++) {
      x = RandomBnd(n);


      if (x == 0) continue;
      z = PowerMod(x, m, n);
      if (z == 1) continue;
   
      // square until reaching 1 or exhausting the k squarings
      j = 0;
      do {
         y = z;
         z = MulMod(y, y, n);
         j++;
      } while (!(j == k || z == 1));

      // composite unless x^(n-1) == 1 and the last square root of 1 was -1
      if (z != 1 || y !=  n-1) return 0;
   }

   return 1;
}


long MillerWitness(const ZZ& n, const ZZ& x)
{
   static ZZ  m, y, z;
   long j, k;

   if (x == 0) return 0;

   add(m, n, -1);
   k = MakeOdd(m);
   // n - 1 == 2^k * m, m odd

   PowerMod(z, x, m, n);
   if (z == 1) return 0;

   j = 0;
   do {
      y = z;
      SqrMod(z, y, n);
      j++;
   } while (!(j == k || z == 1));

   if (z != 1) return 1;
   add(y, y, 1);
   if (y != n) return 1;
   return 0;
}

long MillerWitness(const ZZ& n, long x)
{
   static ZZ  m, y, z;
   long j, k;

   add(m, n, -1);
   k = MakeOdd(m);
   // n - 1 == 2^k * m, m odd

   if (x == 0) return 0;
   PowerMod(z, x, m, n);
   if (z == 1) return 0;

   j = 0;
   do {
      y = z;
      SqrMod(z, y, n);
      j++;
   } while (!(j == k || z == 1));

   if (z != 1) return 1;
   add(y, y, 1);
   if (y != n) return 1;
   return 0;
}

long ProbPrime(const ZZ& n, long NumTrials)
{
   if (n <= 1) return 0;

   if (n.size() == 1) {
      long nn;
      nn << n;
      return ProbPrime(nn, NumTrials);
   }

   PrimeSeq s;
   long p;

   p = s.next();
   while (p && p < 1000) {
      if (rem(n, p) == 0)
         return 0;

      p = s.next();
   }

   long w;

   w = RandomLen(ZZ_NBITS);

   if (MillerWitness(n, w))
      return 0;

   ZZ W;
   long i;

   for (i = 0; i < NumTrials; i++) {
      RandomBnd(W, n);
      if (MillerWitness(n, W)) 
         return 0;
   }

   return 1;
}


void RandomPrime(ZZ& n, long l, long NumTrials)
{
   // Sets n to a random l-bit probable prime by rejection sampling.
   if (l <= 1)
      Error("RandomPrime: l out of range");

   for (;;) {
      RandomLen(n, l);
      if (ProbPrime(n, NumTrials)) break;
   }
}

void NextPrime(ZZ& n, const ZZ& m, long NumTrials)
{
   // Sets n to the smallest probable prime >= m (2 when m <= 2).
   if (m <= 2) {
      n << 2;
      return;
   }

   ZZ candidate;
   candidate = m;

   for (; !ProbPrime(candidate, NumTrials); add(candidate, candidate, 1))
      ;

   n = candidate;
}

long NextPrime(long m, long NumTrials)
{
   // Smallest probable prime >= m, searched linearly below ZZ_RADIX.
   if (m <= 2) 
      return 2;

   long x = m;
   for (; x < ZZ_RADIX; x++) {
      if (ProbPrime(x, NumTrials)) break;
   }

   if (x >= ZZ_RADIX)
      Error("NextPrime: no more primes");

   return x;
}

   



// Returns the least k >= 0 with 2^k >= m (0 for m <= 1).
// The doubling runs in unsigned long arithmetic, so it is well defined for
// every long m, including values above LONG_MAX/2.  (Doubling a signed long
// there overflowed: undefined behavior, in practice a non-terminating loop.)
long NextPowerOfTwo(long m)
{
   if (m <= 1) return 0;

   unsigned long target = (unsigned long) m;
   unsigned long n = 1;
   long k = 0;

   while (n < target) {
      n <<= 1;
      k++;
   }

   return k;
}



// Returns the number of significant bits in |a| (0 for a == 0).
long NumBits(long a)
{
   // Take |a| in unsigned arithmetic: 0UL - (unsigned long) a is well
   // defined even for LONG_MIN, where -a would be signed overflow.
   unsigned long aa = (a < 0) ? 0UL - (unsigned long) a : (unsigned long) a;

   long k = 0;
   for (; aa != 0; aa >>= 1)
      k++;

   return k;
}


// Returns bit k (0 = least significant) of |a|; 0 if k is outside
// [0, ZZ_BITS_PER_LONG).
long bit(long a, long k)
{
   // Take |a| in unsigned arithmetic: 0UL - (unsigned long) a is well
   // defined even for LONG_MIN, where -a would be signed overflow.
   unsigned long aa = (a < 0) ? 0UL - (unsigned long) a : (unsigned long) a;

   if (k < 0 || k >= ZZ_BITS_PER_LONG) 
      return 0;
   else
      return long((aa >> k) & 1);
}



long divide(ZZ& q, const ZZ& a, const ZZ& b)
{
   // Returns 1 and sets q = a/b if b divides a exactly; returns 0 (q
   // untouched) otherwise.  Zero divides only zero, with quotient 0.
   static ZZ quot, rmd;

   if (IsZero(b)) {
      if (!IsZero(a)) return 0;
      clear(q);
      return 1;
   }

   if (IsOne(b)) {
      q = a;
      return 1;
   }

   DivRem(quot, rmd, a, b);
   if (IsZero(rmd)) {
      q = quot;
      return 1;
   }
   return 0;
}

long divide(const ZZ& a, const ZZ& b)
{
   static ZZ r;

   if (IsZero(b)) return IsZero(a);
   if (IsOne(b)) return 1;

   rem(r, a, b);
   return IsZero(r);
}

long divide(ZZ& q, const ZZ& a, long b)
{
   // Returns 1 and sets q = a/b if the long b divides a exactly; returns 0
   // (q untouched) otherwise.  Zero divides only zero, with quotient 0.
   static ZZ quot;

   if (b == 0) {
      if (!IsZero(a)) return 0;
      clear(q);
      return 1;
   }

   if (b == 1) {
      q = a;
      return 1;
   }

   if (DivRem(quot, a, b) != 0) return 0;
   q = quot;
   return 1;
}

long divide(const ZZ& a, long b)
{
   // Divisibility test by a long: 1 iff b divides a (zero divides only zero).
   if (b == 0) return IsZero(a);
   if (b == 1) return 1;
   return rem(a, b) == 0;
}
   

long RandomLen(long l)
{
   if (l <= 0) return 0;
   if (l == 1) return 1;
   if (l >= ZZ_BITS_PER_LONG) 
      Error("RandomLen: l out out of range");
   return RandomBnd(1L << (l-1)) + (1L << (l-1)); 
}



long RandomPrime(long l, long NumTrials)
{
   // Random l-bit single-precision probable prime, by rejection sampling.
   if (l <= 1 || l >= ZZ_BITS_PER_LONG)
      Error("RandomPrime: l out of range");

   for (;;) {
      long candidate = RandomLen(l);
      if (ProbPrime(candidate, NumTrials))
         return candidate;
   }
}


PrimeSeq::PrimeSeq()
{
   // No sieve window yet; pshift == -1 makes the first next() return 2.
   exhausted = 0;
   pindex = -1;
   pshift = -1;
   movesieve_mem = 0;
   movesieve = 0;
}

PrimeSeq::~PrimeSeq()
{
   // movesieve_mem is either 0 or this object's malloc'd window buffer;
   // free(0) is a no-op, so no null check is required.
   free(movesieve_mem);
}

// Returns the next prime in the sequence, or 0 once it is exhausted.
//
// The first call (pshift < 0) returns 2.  Later calls scan the current
// sieve window from pindex+1, where slot i stands for the odd number
// pshift + 2*i + 3.  When a window is used up, it slides forward by
// 2*PRIME_BND.  Windows are sieved only with the odd primes up to
// 2*PRIME_BND+1 held in lowsieve, so the sequence stops once a window
// would reach past (2*PRIME_BND+1)^2.
long PrimeSeq::next()
{
   if (exhausted) {
      return 0;
   }

   // not started: install the first window and emit the only even prime
   if (pshift < 0) {
      shift(0);
      return 2;
   }

   for (;;) {
      char *p = movesieve;
      long i = pindex;

      // next surviving (prime) slot in the current window
      while ((++i) < PRIME_BND) {
         if (p[i]) {
            pindex = i;
            return pshift + 2 * i + 3;
         }
      }

      long newshift = pshift + 2*PRIME_BND;

      // The new window's largest value is newshift + 2*PRIME_BND + 1; this
      // bound keeps it <= (2*PRIME_BND+1)^2, the limit lowsieve can sieve.
      if (newshift > 2 * PRIME_BND * (2 * PRIME_BND + 1)) {
         /* end of the road */
         exhausted = 1;
         return 0;
      }

      shift(newshift);
   }
}

// Base sieve shared by all PrimeSeq objects, built lazily by
// PrimeSeq::start(): lowsieve[i] != 0 iff the odd number 2*i + 3 is prime,
// for i in [0, PRIME_BND).  Not freed anywhere in this file.
static char *lowsieve = 0;

// Positions the sieve window so that slot i (0 <= i < PRIME_BND) represents
// the odd number newshift + 2*i + 3, and resets the scan position.
//
// newshift < 0 rewinds the sequence so that the next next() returns 2.
// newshift == 0 uses the shared lowsieve table directly.  Any other shift
// sieves into this object's own movesieve_mem buffer: for each odd prime p
// from lowsieve with p*p <= the window's largest value, it crosses off the
// odd multiples of p in the window, starting no lower than p*p.
// NOTE(review): newshift appears to be expected to be a nonnegative multiple
// of 2*PRIME_BND (as next() and reset() produce); other values are not
// checked.
void PrimeSeq::shift(long newshift)
{
   long i;
   long j;
   long jstep;
   long jstart;
   long ibound;
   char *p;

   // build the shared base sieve on first use
   if (!lowsieve)
      start();

   pindex = -1;
   exhausted = 0;

   if (newshift < 0) {
      pshift = -1;
      return;
   }

   // requested window is already in place
   if (newshift == pshift) return;

   pshift = newshift;

   if (pshift == 0) {
      movesieve = lowsieve;
   } 
   else {
      if (!movesieve_mem) {
         movesieve_mem = (char *) malloc(PRIME_BND);
         if (!movesieve_mem) 
            Error("out of memory in PrimeSeq");
      }

      p = movesieve = movesieve_mem;
      for (i = 0; i < PRIME_BND; i++)
         p[i] = 1;

      // jstep tracks 2*i + 3, the odd number represented by lowsieve[i];
      // ibound is the largest value in the window.
      jstep = 3;
      ibound = pshift + 2 * PRIME_BND + 1;
      for (i = 0; jstep * jstep <= ibound; i++) {
         if (lowsieve[i]) {
            // smallest multiplier m with m*jstep >= pshift + 3, rounded up
            // to odd so that m*jstep is odd
            if (!((jstart = (pshift + 2) / jstep + 1) & 1))
               jstart++;
            // never cross off the prime itself: start at jstep*jstep
            if (jstart <= jstep)
               jstart = jstep;
            // convert the value m*jstep to its slot index
            jstart = (jstart * jstep - pshift - 3) / 2;
            for (j = jstart; j < PRIME_BND; j += jstep)
               p[j] = 0;
         }
         jstep += 2;
      }
   }
}


// Builds the file-static lowsieve table with the sieve of Eratosthenes:
// lowsieve[i] != 0 iff the odd number 2*i + 3 is prime, for i in
// [0, PRIME_BND), i.e. odd numbers from 3 to 2*PRIME_BND + 1.
void PrimeSeq::start()
{
   long i;
   long j;
   long jstep;
   long jstart;
   long ibnd;
   char *p;

   p = lowsieve = (char *) malloc(PRIME_BND);
   if (!p)
      Error("out of memory in PrimeSeq");

   for (i = 0; i < PRIME_BND; i++)
      p[i] = 1;
      
   // Per iteration, jstep becomes 2*i + 3 (the number at slot i) and jstart
   // the slot of jstep^2.  ibnd is the slot of the largest odd number not
   // exceeding sqrt(2*PRIME_BND + 1); larger primes need no crossing-off.
   jstep = 1;
   jstart = -1;
   ibnd = (SqrRoot(2 * PRIME_BND + 1) - 3) / 2;
   for (i = 0; i <= ibnd; i++) {
      jstart += 2 * ((jstep += 2) - 1);
      if (p[i])
         for (j = jstart; j < PRIME_BND; j += jstep)
            p[j] = 0;
   }
}

// Repositions the sequence so that the next call to next() returns the
// smallest prime >= b (2 when b <= 2).  If b exceeds (2*PRIME_BND+1)^2,
// the limit the sieve can cover, the sequence is marked exhausted instead.
void PrimeSeq::reset(long b)
{
   if (b > (2*PRIME_BND+1)*(2*PRIME_BND+1)) {
      exhausted = 1;
      return;
   }

   if (b <= 2) {
      shift(-1);
      return;
   }

   // even b > 2 is never prime; start from the next odd number
   if ((b & 1) == 0) b++;

   // window containing b; pindex is one slot before b because next()
   // pre-increments before testing
   shift(((b-3) / (2*PRIME_BND))* (2*PRIME_BND));
   pindex = (b - pshift - 3)/2 - 1;
}
 
// Returns the Jacobi symbol (aa / nn) as 1, -1, or 0.
//
// Binary Jacobi algorithm: repeatedly strip factors of 2 from a (flipping
// the sign when an odd number of 2s is removed and n = 3 or 5 mod 8), apply
// quadratic reciprocity (flip when a = n = 3 mod 4), swap a and n, and
// reduce a modulo n.  The loop ends with n = gcd(aa, nn); a result other
// than 1 means the symbol is 0.
// NOTE(review): assumes nn is odd and positive and aa is nonnegative
// (MakeOdd/LowBits semantics on negative values not visible here) —
// confirm with callers.
long Jacobi(const ZZ& aa, const ZZ& nn)
{
   ZZ a, n;
   long t, k;
   long d;

   a = aa;
   n = nn;
   t = 1;

   while (a != 0) {
      k = MakeOdd(a);          // presumably a /= 2^k, returning k
      d = LowBits(n, 3);       // n mod 8
      if ((k & 1) && (d == 3 || d == 5)) t = -t;

      if (LowBits(a, 2) == 3 && (d & 3) == 3) t = -t;
      swap(a, n);
      QuickRem(a, n);          // presumably a = a mod n
   }

   if (n == 1)
      return t;
   else
      return 0;
}


// Sets x to a square root of aa modulo nn (Cipolla-style method).
//
// Pick t with t^2 - 4*aa a non-residue (Jacobi symbol -1), and let
// f = X^2 - t*X + aa.  Compute X^{(nn+1)/2} mod f by square-and-multiply,
// tracking u*X + v and reducing with X^2 = t*X - aa.  When nn is an odd
// prime and aa a quadratic residue, the result is the constant v, which
// is a square root of aa.
// NOTE(review): no check is made that aa is a residue or that nn is prime.
// If no t has Jacobi(t^2 - 4*aa, nn) == -1 (e.g. nn a perfect square), the
// search loop never terminates.  Assumes aa is reduced modulo nn — confirm
// with callers.
void SqrRootMod(ZZ& x, const ZZ& aa, const ZZ& nn)
{
   if (aa == 0) {
      x << 0;
      return;
   }

   long i, k;
   ZZ ma, n, t, u, v, e;
   ZZ t1, t2, t3;

   n = nn;
   NegateMod(ma, aa, n);   // ma = -aa mod n

   // find t such that t^2 - 4*a is not a square

   MulMod(t1, ma, 4, n);
   do {
      RandomBnd(t, n);
      SqrMod(t2, t, n);
      AddMod(t2, t2, t1, n);
   } while (Jacobi(t2, n) != -1);

   // compute u*X + v = X^{(n+1)/2} mod f, where f = X^2 - t*X + a

   add(e, n, 1);
   RightShift(e, e, 1);

   u << 0;
   v << 1;

   k = NumBits(e);

   for (i = k - 1; i >= 0; i--) {
      // square: (uX + v)^2 = (t*u^2 + 2uv) X + (v^2 - a*u^2)
      SqrMod(t1, u, n);
      SqrMod(t2, v, n);
      MulMod(t3, u, v, n);
      MulMod(t3, t3, 2, n);
      MulMod(u, t1, t, n);
      AddMod(u, u, t3, n);
      MulMod(v, t1, ma, n);
      AddMod(v, v, t2, n);

      // multiply by X: (uX + v) X = (t*u + v) X - a*u
      if (bit(e, i)) {
         MulMod(t1, u, t, n);
         AddMod(t1, t1, v, n);
         MulMod(v, u, ma, n);
         u = t1;
      }

   }

   x = v;
}
// One step of incremental Chinese remaindering.
//
// On entry g is a residue modulo a and G a residue modulo p.  On return,
// a is replaced by a*p.  g becomes the value congruent to the old g modulo
// the old a and to G modulo p: first reduced into [0, a*p), then lowered
// by a*p if it exceeds floor(a*p/2), giving a balanced representative.
// Returns 1 if g changed, 0 otherwise.
// Requires a to be invertible modulo p (InvMod is applied to a mod p).
// NOTE(review): assumes rem() yields a nonnegative remainder — confirm.
long CRT(ZZ& g, ZZ& a, const ZZ& G, const ZZ& p)
{
   ZZ a_inv;

   // a_inv = a^{-1} mod p
   rem(a_inv, a, p);
   InvMod(a_inv, a_inv, p);
 
   ZZ aa, new_a, new_a1;
   ZZ v_a, v_p, t;

   mul(aa, a, a_inv);              // aa = 0 mod a, 1 mod p
   mul(new_a, a, p);
   RightShift(new_a1, new_a, 1);   // floor(a*p / 2)

   long modified = 0;

   // v_a = g + a*k with k = -g * a^{-1} mod p: v_a = g mod a, 0 mod p
   ZZ k;
   rem(k, g, p);
   MulMod(k, k, a_inv, p);
   NegateMod(k, k, p);
   mul(v_a, a, k);
   add(v_a, v_a, g);
   // v_p = aa*G: 0 mod a, G mod p
   mul(v_p, aa, G);
   add(t, v_a, v_p);
   rem(t, t, new_a);
   if (t > new_a1)
      sub(t, t, new_a);

   if (t != g) {
      g = t;
      modified = 1;
   }

   a = new_a;

   return modified;
}



// Single-precision-modulus variant of the incremental CRT step: combines
// g (mod a) with G (mod p) into a balanced residue modulo a*p, replaces a
// by a*p, and returns 1 if g changed.  A p >= ZZ_RADIX is converted to ZZ
// and delegated to the multiprecision overload.
// Requires a to be invertible modulo p (InvMod is applied to a mod p).
// NOTE(review): assumes rem() yields a nonnegative remainder — confirm.
long CRT(ZZ& g, ZZ& a, long G, long p)
{
   if (p >= ZZ_RADIX) {
      ZZ GG, pp;
      GG << G;
      pp << p;
      return CRT(g, a, GG, pp);
   }

   long a_inv;

   // a_inv = a^{-1} mod p
   a_inv = rem(a, p);
   a_inv = InvMod(a_inv, p);
 
   ZZ aa, new_a, new_a1;
   ZZ v_a, v_p, t;

   mul(aa, a, a_inv);              // aa = 0 mod a, 1 mod p
   mul(new_a, a, p);
   RightShift(new_a1, new_a, 1);   // floor(a*p / 2)

   long modified = 0;

   // v_a = g + a*k with k = -g * a^{-1} mod p: v_a = g mod a, 0 mod p
   long k;
   k = rem(g, p);
   k = MulMod(k, a_inv, p);
   k = SubMod(0, k, p);
   mul(v_a, a, k);
   add(v_a, v_a, g);
   // v_p = aa*G: 0 mod a, G mod p
   mul(v_p, aa, G);
   add(t, v_a, v_p);
   rem(t, t, new_a);
   if (t > new_a1)
      sub(t, t, new_a);

   if (t != g) {
      g = t;
      modified = 1;
   }

   a = new_a;

   return modified;
}

void sub(ZZ& x, const ZZ& a, long b)
{
   // x = a - b: lift b into a reusable ZZ and use the ZZ - ZZ routine.
   static ZZ b_zz;
   b_zz << b;
   sub(x, a, b_zz);
}

void sub(ZZ& x, long a, const ZZ& b)
{
   // x = a - b: lift a into a reusable ZZ and use the ZZ - ZZ routine.
   static ZZ a_zz;
   a_zz << a;
   sub(x, a_zz, b);
}


