


/***********************************************************************

   This software is for research and educational purposes only.

************************************************************************/



#include "ZZX.h"
#include "zz_pX.h"


// Definition of the static ZZX::_zero member (default-constructed).
// NOTE(review): presumably backs a ZZX::zero() accessor declared in ZZX.h,
// which is not visible here.
ZZX ZZX::_zero;


// Constructs the constant polynomial a.  For a == 0 no coefficient is
// stored, so the object is the zero polynomial (IsZero tests for an empty
// coefficient vector).
ZZX::ZZX(INIT_VAL_TYPE, long a)
{
   if (a == 0)
      return;

   rep.SetLength(1);
   rep[0] << a;
}
// Constructs the constant polynomial a.  A zero argument leaves the
// coefficient vector untouched, giving the zero polynomial.
ZZX::ZZX(INIT_VAL_TYPE, const ZZ& a)
{
   if (IsZero(a))
      return;

   rep.SetLength(1);
   rep[0] = a;
}



// Converts an integer polynomial to a zz_pX.  Each coefficient is converted
// by the vector-level operator<< (presumably reducing modulo the current
// zz_p modulus); reduction can zero out leading coefficients, so the result
// is normalized afterwards.
void operator<<(zz_pX& x, const ZZX& a)
{
   x.rep << a.rep;
   x.normalize();
}

// Converts an integer polynomial to a ZZ_pX.  Each coefficient is converted
// by the vector-level operator<< (presumably reducing modulo the current
// ZZ_p modulus); the result is normalized because reduction may produce
// zero leading coefficients.
void operator<<(ZZ_pX& x, const ZZX& a)
{
   x.rep << a.rep;
   x.normalize();
}

// Lifts a ZZ_pX to an integer polynomial, coefficient by coefficient via the
// vector-level operator<< (presumably to the representatives returned by
// rep()), then normalizes.
void operator<<(ZZX& x, const ZZ_pX& a)
{
   x.rep << a.rep;
   x.normalize();
}

// Lifts a zz_pX to an integer polynomial, coefficient by coefficient via the
// vector-level operator<< (presumably to non-negative representatives; see
// BalCopy for a balanced lift), then normalizes.
void operator<<(ZZX& x, const zz_pX& a)
{
   x.rep << a.rep;
   x.normalize();
}



// Reads the coefficient vector (constant term first) using the input format
// of vector(ZZ), then normalizes so trailing zeros in the input are dropped.
istream& operator>>(istream& s, ZZX& x)
{
   s >> x.rep;
   x.normalize();
   return s;
}

// Writes the coefficient vector (constant term first) using the output format
// of vector(ZZ).
ostream& operator<<(ostream& s, const ZZX& a)
{
   return s << a.rep;
}


// Strips trailing zero coefficients, so that a non-empty rep always ends in
// a nonzero leading coefficient and the zero polynomial has an empty rep.
void ZZX::normalize()
{
   long len = rep.length();
   if (len == 0)
      return;

   while (len > 0 && IsZero(rep[len-1]))
      len--;

   rep.SetLength(len);
}


// Returns 1 iff a is the zero polynomial, i.e. has no stored coefficients.
long IsZero(const ZZX& a)
{
   if (a.rep.length() > 0)
      return 0;
   return 1;
}


// Returns 1 iff a is the constant polynomial 1.
long IsOne(const ZZX& a)
{
   if (a.rep.length() != 1)
      return 0;
   return IsOne(a.rep[0]) ? 1 : 0;
}

// Coefficient-wise equality.  Operands are expected to be normalized, so
// polynomials with different stored lengths compare unequal immediately.
long operator==(const ZZX& a, const ZZX& b)
{
   long len = a.rep.length();
   if (b.rep.length() != len)
      return 0;

   for (long k = 0; k < len; k++) {
      if (a.rep[k] != b.rep[k])
         return 0;
   }

   return 1;
}


// Logical negation of operator==.
long operator!=(const ZZX& a, const ZZX& b)
{
   if (a == b)
      return 0;
   return 1;
}

// Sets x to the coefficient of X^i in a; indices outside [0, deg(a)] give 0.
void GetCoeff(ZZ& x, const ZZX& a, long i)
{
   if (i >= 0 && i <= deg(a)) {
      x = a.rep[i];
      return;
   }
   clear(x);
}

// Sets the coefficient of X^i in x to a, growing x if i > deg(x), and
// re-normalizes (so setting the leading coefficient to 0 lowers the degree).
// Calls Error() for a negative index.
// a may alias a coefficient of x: it is copied before x.rep is resized.
void SetCoeff(ZZX& x, long i, const ZZ& a)
{
   long j, m;

   if (i < 0) 
      Error("SetCoeff: negative index");

   m = deg(x);

   if (i > m) {
      // Growing x.rep may reallocate its storage.  If a refers to one of
      // x's coefficients it would then dangle, so take a private copy first.
      ZZ aa = a;

      x.rep.SetLength(i+1);
      // zero the gap between the old leading term and the new one
      for (j = m+1; j < i; j++)
         clear(x.rep[j]);
      x.rep[i] = aa;
   }
   else
      x.rep[i] = a;

   x.normalize();
}

// Sets the coefficient of X^i in x to 1, growing x if needed.
// Calls Error() for a negative index.
void SetCoeff(ZZX& x, long i)
{
   if (i < 0) 
      Error("coefficient index out of range");

   long old_deg = deg(x);

   if (i > old_deg) {
      x.rep.SetLength(i+1);
      // entries strictly between the old leading term and X^i become 0
      for (long k = old_deg+1; k < i; k++)
         clear(x.rep[k]);
   }

   set(x.rep[i]);
   x.normalize();
}


// Sets x to the monomial X (zero constant term, coefficient 1 on X^1).
void SetX(ZZX& x)
{
   clear(x);
   SetCoeff(x, 1);
}


// Returns 1 iff a is exactly the monomial X.
long IsX(const ZZX& a)
{
   if (deg(a) != 1)
      return 0;
   if (!IsOne(LeadCoeff(a)))
      return 0;
   return IsZero(ConstTerm(a)) ? 1 : 0;
}
      
      

// Read-only access to the coefficient of X^i; out-of-range indices return a
// reference to the shared zero.
const ZZ& coeff(const ZZX& a, long i)
{
   if (i >= 0 && i <= deg(a))
      return a.rep[i];
   return ZZ::zero();
}


// Leading coefficient of a (the shared zero for the zero polynomial).
const ZZ& LeadCoeff(const ZZX& a)
{
   if (!IsZero(a))
      return a.rep[deg(a)];
   return ZZ::zero();
}

// Constant term of a (the shared zero for the zero polynomial).
const ZZ& ConstTerm(const ZZX& a)
{
   if (!IsZero(a))
      return a.rep[0];
   return ZZ::zero();
}



// Sets x to the constant polynomial a (the empty polynomial when a == 0).
void operator<<(ZZX& x, const ZZ& a)
{
   if (!IsZero(a)) {
      x.rep.SetLength(1);
      x.rep[0] = a;
      return;
   }
   x.rep.SetLength(0);
}

// Sets x to the constant polynomial a by converting through ZZ, so the
// zero / nonzero handling lives in the ZZ overload.
void operator<<(ZZX& x, long a)
{
   ZZ aa;
   aa << a;
   x << aa;
}


// Sets x to the polynomial whose coefficients (constant term first) are a,
// dropping any trailing zeros.
void operator<<(ZZX& x, const vector(ZZ)& a)
{
   x.rep = a;
   x.normalize();
}


// x = a + b.  x may alias a and/or b.
// When the degrees differ, the leading term comes from the longer operand
// and is nonzero, so normalization is needed only when the degrees match.
void add(ZZX& x, const ZZX& a, const ZZX& b)
{
   long da = deg(a);
   long db = deg(b);
   long common = min(da, db) + 1;   // positions present in both operands
   long k;

   x.rep.SetLength(max(da, db) + 1);

   for (k = 0; k < common; k++)
      add(x.rep[k], a.rep[k], b.rep[k]);

   if (da >= common && &x != &a) {
      // copy a's high part (already in place when x is a)
      for (k = common; k <= da; k++)
         x.rep[k] = a.rep[k];
   }
   else if (db >= common && &x != &b) {
      // copy b's high part (already in place when x is b)
      for (k = common; k <= db; k++)
         x.rep[k] = b.rep[k];
   }
   else
      x.normalize();
}

// x = a + b for a polynomial a and an integer b (added to the constant term).
// x may alias a, and b may alias a coefficient of x or of a.
void add(ZZX& x, const ZZX& a, const ZZ& b)
{
   if (a.rep.length() == 0) {
      x << b;
      return;
   }

   if (&x == &a) {
      // in place: no coefficient of x is overwritten before b is read
      add(x.rep[0], x.rep[0], b);
   }
   else {
      // The assignment x = a would overwrite b if b refers to a
      // coefficient of x, so take a private copy first.
      ZZ bb = b;
      x = a;
      add(x.rep[0], x.rep[0], bb);
   }

   // a degree-0 sum may cancel to zero
   x.normalize();
}

// x = a + b with the integer on the left; addition commutes, so this
// forwards to the (ZZX, ZZ) overload.
void add(ZZX& x, const ZZ& a, const ZZX& b)
{
   add(x, b, a);
}

// x = a + b for a machine-word integer b, converted to ZZ first.
// NOTE(review): B is function-static scratch shared across calls, so this
// routine is not reentrant / thread-safe.
void add(ZZX& x, const ZZX& a, long b)
{
   static ZZ B;
   B << b;
   add(x, a, B);
}

// x = a + b with the integer on the left; forwards to the (ZZX, long) overload.
void add(ZZX& x, long a, const ZZX& b)
{
   add(x, b, a);
}


// x = a - b.  x may alias a and/or b.
// If b is longer, its high part is negated into x (in place when x is b).
// Normalization is needed only when the degrees are equal.
void sub(ZZX& x, const ZZX& a, const ZZX& b)
{
   long da = deg(a);
   long db = deg(b);
   long common = min(da, db) + 1;   // positions present in both operands
   long k;

   x.rep.SetLength(max(da, db) + 1);

   for (k = 0; k < common; k++)
      sub(x.rep[k], a.rep[k], b.rep[k]);

   if (da >= common && &x != &a) {
      for (k = common; k <= da; k++)
         x.rep[k] = a.rep[k];
   }
   else if (db >= common) {
      for (k = common; k <= db; k++)
         negate(x.rep[k], b.rep[k]);
   }
   else
      x.normalize();
}

// x = a - b for a polynomial a and an integer b (subtracted from the
// constant term).  x may alias a, and b may alias a coefficient of x or a.
void sub(ZZX& x, const ZZX& a, const ZZ& b)
{
   if (a.rep.length() == 0) {
      x.rep.SetLength(1);
      negate(x.rep[0], b);
   }
   else if (&x == &a) {
      // in place: no coefficient of x is overwritten before b is read
      sub(x.rep[0], x.rep[0], b);
   }
   else {
      // The assignment x = a would overwrite b if b refers to a
      // coefficient of x, so take a private copy first.
      ZZ bb = b;
      x = a;
      sub(x.rep[0], x.rep[0], bb);
   }

   // handles b == 0 for zero a, and cancellation of a constant a
   x.normalize();
}

// x = a - b for a machine-word integer b: b is lifted to a constant
// polynomial and the polynomial subtraction is used.
void sub(ZZX& x, const ZZX& a, long b)
{
   ZZX t;
   t << b;
   sub(x, a, t);
}

// x = a - b for an integer a: a is lifted to a constant polynomial, which
// also makes aliasing between a and x harmless.
void sub(ZZX& x, const ZZ& a, const ZZX& b)
{
   ZZX t;
   t << a;
   sub(x, t, b);
}

// x = a - b for a machine-word integer a lifted to a constant polynomial.
void sub(ZZX& x, long a, const ZZX& b)
{
   ZZX t;
   t << a;
   sub(x, t, b);
}




// x = -a coefficient-wise; x may alias a.  Negation preserves
// normalization, so no normalize() call is needed.
void negate(ZZX& x, const ZZX& a)
{
   long len = a.rep.length();
   x.rep.SetLength(len);

   for (long k = 0; k < len; k++)
      negate(x.rep[k], a.rep[k]);
}

// Largest NumBits over all coefficients of f (0 for the zero polynomial).
long MaxBits(const ZZX& f)
{
   long best = 0;

   for (long k = 0; k <= deg(f); k++) {
      long bits = NumBits(f.rep[k]);
      if (bits > best)
         best = bits;
   }

   return best;
}


// Schoolbook multiplication x = a * b.  Squaring (a and b the same object)
// is forwarded to PlainSqr.  x may alias either input: the aliased input
// is copied before x is resized and overwritten.
void PlainMul(ZZX& x, const ZZX& a, const ZZX& b)
{
   if (&a == &b) {
      PlainSqr(x, a);
      return;
   }

   long da = deg(a);
   long db = deg(b);

   if (da < 0 || db < 0) {
      clear(x);
      return;
   }

   ZZX acopy, bcopy;
   const ZZ *ap = a.rep.elts();
   const ZZ *bp = b.rep.elts();

   if (&x == &a) {
      acopy = a;
      ap = acopy.rep.elts();
   }

   if (&x == &b) {
      bcopy = b;
      bp = bcopy.rep.elts();
   }

   long d = da + db;
   x.rep.SetLength(d+1);
   ZZ *xp = x.rep.elts();

   ZZ term, sum;

   for (long i = 0; i <= d; i++) {
      // x_i = sum of a_j * b_{i-j} over valid j
      long lo = max(0, i-db);
      long hi = min(da, i);
      clear(sum);
      for (long j = lo; j <= hi; j++) {
         mul(term, ap[j], bp[i-j]);
         add(sum, sum, term);
      }
      xp[i] = sum;
   }

   x.normalize();
}

// Schoolbook squaring x = a^2.  Each coefficient uses the symmetry
// a_j*a_{i-j} == a_{i-j}*a_j: the distinct pairs are summed once and
// doubled, and the middle square (present when i is even) is added.
// x may alias a: a is copied before x is overwritten.
void PlainSqr(ZZX& x, const ZZX& a)
{
   long da = deg(a);

   if (da < 0) {
      clear(x);
      return;
   }

   ZZX acopy;
   const ZZ *ap = a.rep.elts();

   if (&x == &a) {
      acopy = a;
      ap = acopy.rep.elts();
   }

   long d = 2*da;
   x.rep.SetLength(d+1);
   ZZ *xp = x.rep.elts();

   ZZ t, accum;

   for (long i = 0; i <= d; i++) {
      long jmin = max(0, i-da);
      long jmax = min(da, i);
      long cnt = jmax - jmin + 1;    // number of terms a_j * a_{i-j}
      long half = cnt >> 1;          // number of symmetric pairs

      clear(accum);
      for (long j = jmin; j < jmin + half; j++) {
         mul(t, ap[j], ap[i-j]);
         add(accum, accum, t);
      }
      add(accum, accum, accum);

      if (cnt & 1) {
         // middle term j == i - j
         sqr(t, ap[jmin + half]);
         add(accum, accum, t);
      }

      xp[i] = accum;
   }

   x.normalize();
}



// Incremental Chinese remaindering for polynomials.
// On entry g holds coefficients known modulo a, and G holds the same
// polynomial modulo p = zz_p::modulus().  a must be invertible mod p (the
// code inverts it in zz_p).
// On exit each coefficient of g is the value congruent to the old g mod a
// and to G mod p, reduced into a balanced range of width a*p (values above
// floor(a*p/2) are shifted down by a*p).  a is replaced by a*p.
// Returns 1 if any coefficient changed or was newly created, else 0.  The
// multi-modular routines in this file use a 0 return to detect that the
// reconstruction has stabilized.
// NOTE(review): g is not normalized on exit.
long CRT(ZZX& g, ZZ& a, const zz_pX& G)
{
   long p = zz_p::modulus();
   zz_p a_inv;

   // a_inv = a^{-1} mod p
   a_inv << a;
   inv(a_inv, a_inv);
 
   ZZ aa, new_a, new_a1;
   ZZ v_a, v_p, t;

   // aa = a * (a^{-1} mod p): congruent to 0 mod a and to 1 mod p
   mul(aa, a, rep(a_inv));
   mul(new_a, a, p);

   // new_a1 = floor(a*p / 2), the cutoff for the balanced representative
   RightShift(new_a1, new_a, 1);

   long n = g.rep.length();
   g.rep.SetLength(max(n, G.rep.length()));

   long modified = 0;
   long i;

   for (i = 0; i < g.rep.length(); i++) {
      if (i < n) {
         // v_a = g_i + a * (-g_i * a^{-1} mod p):
         // congruent to g_i mod a and to 0 mod p
         zz_p k;
         k << g.rep[i];
         mul(k, k, a_inv);
         negate(k, k);
         mul(v_a, a, rep(k));
         add(v_a, v_a, g.rep[i]);
      }
      else
         clear(v_a);

      // v_p = aa * G_i: congruent to 0 mod a and to G_i mod p
      if (i < G.rep.length()) 
         mul(v_p, aa, rep(G.rep[i]));
      else
         clear(v_p);

      add(t, v_a, v_p);

      // reduce mod a*p into the balanced range
      rem(t, t, new_a);
      if (t > new_a1)
         sub(t, t, new_a);

      if (i >= n || t != g.rep[i]) {
         g.rep[i] = t;
         modified = 1;
      }
   }

   a = new_a;

   return modified;
}


// Schoolbook product of the coefficient arrays ap[0..sa-1] and bp[0..sb-1],
// written to xp[0..sa+sb-2].  Does nothing if either array is empty.
// xp must not overlap ap or bp, because outputs are written while later
// iterations still read the inputs.
// Uses function-static scratch ZZs, so it is not reentrant.
static
void PlainMul(ZZ *xp, const ZZ *ap, long sa, const ZZ *bp, long sb)
{
   if (sa == 0 || sb == 0) return;

   long sx = sa+sb-1;

   long i, j, jmin, jmax;
   static ZZ t, accum;   // scratch, reused across calls

   for (i = 0; i < sx; i++) {
      // x_i collects a_j * b_{i-j} for 0 <= j < sa and 0 <= i-j < sb
      jmin = max(0, i-sb+1);
      jmax = min(sa-1, i);
      clear(accum);
      for (j = jmin; j <= jmax; j++) {
         mul(t, ap[j], bp[i-j]);
         add(accum, accum, t);
      }
      xp[i] = accum;
   }
}


// Karatsuba helper: folds b (length sb) at position hsa into T[0..hsa-1],
// i.e. T[k] = b[k] + b[hsa+k] where the high half has an entry, and
// T[k] = b[k] otherwise.  Callers pass hsa <= sb <= 2*hsa.
static
void KarFold(ZZ *T, const ZZ *b, long sb, long hsa)
{
   long overlap = sb - hsa;   // length of the high half b[hsa..sb-1]
   long k;

   for (k = 0; k < overlap; k++)
      add(T[k], b[k], b[hsa+k]);

   for (k = overlap; k < hsa; k++)
      T[k] = b[k];
}

// Karatsuba helper: T[0..sb-1] -= b[0..sb-1].
static
void KarSub(ZZ *T, const ZZ *b, long sb)
{
   for (long left = sb; left > 0; left--, T++, b++)
      sub(*T, *T, *b);
}

// Karatsuba helper: T[0..sb-1] += b[0..sb-1].
static
void KarAdd(ZZ *T, const ZZ *b, long sb)
{
   for (long left = sb; left > 0; left--, T++, b++)
      add(*T, *T, *b);
}

// Karatsuba helper for the unbalanced case: c[0..hsa-1] is overwritten
// with b[0..hsa-1], and c[hsa..sb-1] accumulates b[hsa..sb-1] (that
// range already holds the product computed into c + hsa).
static
void KarFix(ZZ *c, const ZZ *b, long sb, long hsa)
{
   long k;

   for (k = 0; k < hsa; k++)
      c[k] = b[k];

   for (k = hsa; k < sb; k++)
      add(c[k], c[k], b[k]);
}


/* Recursive Karatsuba product on coefficient arrays:
   c[0..sa+sb-2] = a[0..sa-1] * b[0..sb-1].
   stk points to caller-allocated scratch ZZs that are carved up for the
   temporaries of each recursion level; xover is the length of the shorter
   operand below which the schoolbook PlainMul is used.
   The operands are swapped first so that sa >= sb. */
static
void KarMul(ZZ *c, const ZZ *a, 
            long sa, const ZZ *b, long sb, ZZ *stk, long xover)
{
   if (sa < sb) {
      { long t = sa; sa = sb; sb = t; }
      { const ZZ *t = a; a = b; b = t; }
   }

   if (sb < xover) {
      PlainMul(c, a, sa, b, sb);
      return;
   }

   /* split point: a = a_lo + X^hsa * a_hi with len(a_lo) = hsa */
   long hsa = (sa + 1) >> 1;

   if (hsa < sb) {
      /* normal case */
      /* b is split at the same point; both halves of b are non-empty */

      long hsa2 = hsa << 1;

      ZZ *T1, *T2, *T3;

      T1 = stk; stk += hsa;
      T2 = stk; stk += hsa;
      T3 = stk; stk += hsa2 - 1;

      /* compute T1 = a_lo + a_hi */

      KarFold(T1, a, sa, hsa);

      /* compute T2 = b_lo + b_hi */

      KarFold(T2, b, sb, hsa);

      /* recursively compute T3 = T1 * T2 */

      KarMul(T3, T1, hsa, T2, hsa, stk, xover);

      /* recursively compute a_hi * b_hi into high part of c */
      /* and subtract from T3 */

      KarMul(c + hsa2, a+hsa, sa-hsa, b+hsa, sb-hsa, stk, xover);
      KarSub(T3, c + hsa2, sa + sb - hsa2 - 1);


      /* recursively compute a_lo*b_lo into low part of c */
      /* and subtract from T3 */

      KarMul(c, a, hsa, b, hsa, stk, xover);
      KarSub(T3, c, hsa2 - 1);

      /* a_lo*b_lo occupies c[0..hsa2-2] and a_hi*b_hi starts at
         c[hsa2], so c[hsa2-1] is the one slot not yet written */
      clear(c[hsa2 - 1]);

      /* finally, add T3 * X^{hsa} to c */

      KarAdd(c+hsa, T3, hsa2-1);
   }
   else {
      /* degenerate case */
      /* b is no longer than a_lo, so only a is split */

      ZZ *T;

      T = stk; stk += hsa + sb - 1;

      /* recursively compute b*a_hi into high part of c */

      KarMul(c + hsa, a + hsa, sa - hsa, b, sb, stk, xover);

      /* recursively compute b*a_lo into T */

      KarMul(T, a, hsa, b, sb, stk, xover);

      /* c = T + X^hsa * (b*a_hi): low hsa entries copied, the rest added */
      KarFix(c, T, hsa + sb - 1, hsa);
   }
}

// Karatsuba multiplication c = a * b for polynomials.
// Squaring (a and b the same object) is forwarded to KarSqr.  c may alias
// either input; the aliased input is copied before c is resized.
// The schoolbook crossover is chosen from the coefficient sizes, and a
// scratch ZZVec big enough for every recursion level is allocated up front.
void KarMul(ZZX& c, const ZZX& a, const ZZX& b)
{
   if (IsZero(a) || IsZero(b)) {
      clear(c);
      return;
   }

   if (&a == &b) {
      KarSqr(c, a);
      return;
   }

   vector(ZZ) mem;

   const ZZ *ap, *bp;
   ZZ *cp;

   long sa = a.rep.length();
   long sb = b.rep.length();

   long maxa, maxb, xover;
   double k;

   // Measure coefficient sizes before c is resized.  If c aliases a or b,
   // the SetLength below lengthens that input too, and MaxBits would then
   // also scan the extra (stale) entries.
   maxa = MaxBits(a);
   maxb = MaxBits(b);

   if (&a == &c) {
      mem = a.rep;
      ap = mem.elts();
   }
   else
      ap = a.rep.elts();

   if (&b == &c) {
      mem = b.rep;
      bp = mem.elts();
   }
   else
      bp = b.rep.elts();

   c.rep.SetLength(sa+sb-1);
   cp = c.rep.elts();

   // k = (words per coefficient of a) * (words per coefficient of b)
   k = double((maxa + ZZ_NBITS - 1)/ZZ_NBITS) *
       double((maxb + ZZ_NBITS - 1)/ZZ_NBITS);
 
   if (k <= 36)
      xover = 6;
   else if (k <= 196)
      xover = 4;
   else
      xover = 2;

   if (sa < xover || sb < xover)
      PlainMul(cp, ap, sa, bp, sb);
   else {
      /* karatsuba */

      long n, hn, sp, depth;

      // total scratch: each level needs about 4*hn - 1 ZZs
      n = max(sa, sb);
      sp = 0;
      depth = 0;
      do {
         hn = (n+1) >> 1;
         sp += (hn << 2) - 1;
         n = hn;
         depth++;
      } while (n >= xover);

      // each scratch ZZ is preallocated with room for a product coefficient
      // plus some growth per recursion level
      ZZVec stk;
      stk.SetSize(sp, 
         ((maxa + maxb + NumBits(min(sa, sb)) + 2*depth + 10) 
          + ZZ_NBITS-1)/ZZ_NBITS);

      KarMul(cp, ap, sa, bp, sb, stk.elts(), xover);
   }

   c.normalize();
}








/* Compute a = b * 2^l mod p, where p = 2^n+1. 0<=l<=n and 0<b<p are
   assumed. */
/* Method: write b = hi * 2^(n-l) + lo with 0 <= lo < 2^(n-l).  Because
   2^n == -1 (mod p), b * 2^l == lo * 2^l - hi (mod p).  That difference
   lies strictly between -p and p, so one conditional addition of p gives
   the result in [0, p).  a may alias b.
   Uses a function-static scratch ZZ, so it is not reentrant. */
static void LeftRotate(ZZ& a, const ZZ& b, long l, const ZZ& p, long n)
{
  if (l == 0) {
    if (&a != &b) {
      a = b;
    }
    return;
  }

  /* tmp := upper l bits of b */
  static ZZ tmp;
  RightShift(tmp, b, n - l);
  /* a := 2^l * lower n - l bits of b */
  LowBits(a, b, n - l);
  LeftShift(a, a, l);
  /* a -= tmp */
  sub(a, a, tmp);
  if (sign(a) < 0) {
    add(a, a, p);
  }
}


/* Compute a = b * 2^l mod p, where p = 2^n+1. 0<=b<p is assumed.
   l may be any integer, including negative: since 2^(2n) == 1 (mod p),
   l is first reduced into [0, 2n).  Shifts of n or more use
   2^n == -1 (mod p): rotate by l-n, then negate mod p.  a may alias b. */
static void Rotate(ZZ& a, const ZZ& b, long l, const ZZ& p, long n)
{
  if (IsZero(b)) {
    clear(a);
    return;
  }

  /* l %= 2n */
  /* (the negative branch avoids C's sign-following % on negative l) */
  if (l >= 0) {
    l %= (n << 1);
  } else {
    l = (n << 1) - 1 - (-(l + 1) % (n << 1));
  }

  /* a = b * 2^l mod p */
  if (l < n) {
    LeftRotate(a, b, l, p, n);
  } else {
    LeftRotate(a, b, l - n, p, n);
    SubPos(a, p, a);
  }
}



/* Fast Fourier Transform. a is a vector of length 2^l, 2^l divides 2n,
   p = 2^n+1, w = 2^r mod p is a primitive (2^l)th root of
   unity. Returns a(1),a(w),...,a(w^{2^l-1}) mod p in bit-reverse
   order. */
/* Works in place in l rounds.  Each butterfly replaces (x, y) by
   (x + y, (x - y) * w^e) mod p; multiplication by a power of w = 2^r is a
   shift-and-fold done by Rotate, so no general multiplication is needed.
   Entries are assumed to be in [0, p) on entry and stay in that range.
   r is a by-value copy and is doubled each round (the twiddle step grows
   as w^{2^round}). */
static void fft(vector(ZZ)& a, long r, long l, const ZZ& p, long n)
{
  long round;
  long off, i, j, e;
  long halfsize;
  ZZ tmp, tmp1;

  for (round = 0; round < l; round++, r <<= 1) {
    /* 2^round blocks, each made of two halves of halfsize entries */
    halfsize =  1L << (l - 1 - round);
    for (i = (1L << round) - 1, off = 0; i >= 0; i--, off += halfsize) {
      for (j = 0, e = 0; j < halfsize; j++, off++, e+=r) {
	/* One butterfly : 
	 ( a[off], a[off+halfsize] ) *= ( 1  w^{j2^round} )
	                                ( 1 -w^{j2^round} ) */
	/* tmp = a[off] - a[off + halfsize] mod p */
	sub(tmp, a[off], a[off + halfsize]);
	if (sign(tmp) < 0) {
	  add(tmp, tmp, p);
	}
	/* a[off] += a[off + halfsize] mod p */
	add(a[off], a[off], a[off + halfsize]);
	sub(tmp1, a[off], p);
	if (sign(tmp1) >= 0) {
	  a[off] = tmp1;
	}
	/* a[off + halfsize] = tmp * w^{j2^round} mod p */
	Rotate(a[off + halfsize], tmp, e, p, n);
      }
    }
  }
}

/* Inverse FFT. r must be the same as in the call to FFT. Result is
   by 2^l too large. */
/* Runs the rounds of fft in reverse order.  Each butterfly replaces (x, y)
   by (x + y*w^{-e}, x - y*w^{-e}) mod p, with the negative power of w
   handled by Rotate.  Input is expected in the bit-reversed order produced
   by fft, and entries stay in [0, p).  r is pre-shifted to the last round's
   step and halved each round. */
static void ifft(vector(ZZ)& a, long r, long l, const ZZ& p, long n)
{
  long round;
  long off, i, j, e;
  long halfsize;
  ZZ tmp, tmp1;

  for (round = l - 1, r <<= l - 1; round >= 0; round--, r >>= 1) {
    halfsize = 1L << (l - 1 - round);
    for (i = (1L << round) - 1, off = 0; i >= 0; i--, off += halfsize) {
      for (j = 0, e = 0; j < halfsize; j++, off++, e+=r) {
	/* One inverse butterfly : 
	 ( a[off], a[off+halfsize] ) *= ( 1               1             )
	                                ( w^{-j2^round}  -w^{-j2^round} ) */
	/* a[off + halfsize] *= w^{-j2^round} mod p */
	Rotate(a[off + halfsize], a[off + halfsize], -e, p, n);
	/* tmp = a[off] - a[off + halfsize] */
	sub(tmp, a[off], a[off + halfsize]);

	/* a[off] += a[off + halfsize] mod p */
	add(a[off], a[off], a[off + halfsize]);
	sub(tmp1, a[off], p);
	if (sign(tmp1) >= 0) {
	  a[off] = tmp1;
	}
	/* a[off+halfsize] = tmp mod p */
	if (sign(tmp) < 0) {
	  add(a[off+halfsize], tmp, p);
	} else {
	  a[off+halfsize] = tmp;
	}
      }
    }
  }
}



/* Multiplication a la Schoenhage & Strassen, modulo a "Fermat" number
   p = 2^{mr}+1, where m is a power of two and r >= 1 (r need not be odd:
   w = 2^r satisfies w^m = 2^{mr} = -1 mod p, so its order is exactly 2m).
   Thus w is a primitive 2mth root of unity, i.e., polynomials whose
   product has degree less than 2m can be multiplied, provided that the
   coefficients of the product polynomial are less than 2^{mr-1} in
   absolute value. The algorithm is not called recursively;
   coefficient arithmetic is done directly.
   Falls back to PlainMul when either operand has degree <= 0.
   c may alias a or b: c is only resized after both inputs have been
   copied into aa and bb.  c is not normalized at the end because the
   leading coefficient of a product of nonzero polynomials is nonzero. */

void SSMul(ZZX& c, const ZZX& a, const ZZX& b)
{
  long na = deg(a);
  long nb = deg(b);

  if (na <= 0 || nb <= 0) {
    PlainMul(c, a, b);
    return;
  }

  long n = na + nb; /* degree of the product */


  /* Choose m and r suitably */
  long l = NextPowerOfTwo(n + 1) - 1; /* 2^l <= n < 2^{l+1} */
  long m2 = 1L << (l + 1); /* m2 = 2m = 2^{l+1} */
  /* Bitlength of the product: if the coefficients of a are absolutely less
     than 2^ka and the coefficients of b are absolutely less than 2^kb, then
     the coefficients of ab are absolutely less than
     (min(na,nb)+1)2^{ka+kb} <= 2^bound. */
  long bound = 2 + NumBits(min(na, nb)) + MaxBits(a) + MaxBits(b);
  /* Let r be minimal so that mr > bound */
  long r = (bound >> l) + 1;
  long mr = r << l;

  /* p := 2^{mr}+1 */
  ZZ p;
  set(p);
  LeftShift(p, p, mr);
  add(p, p, 1);

  /* Make coefficients of a and b positive */
  /* (entries beyond deg(a), deg(b) are the zero padding up to m2) */
  vector(ZZ) aa, bb;
  aa.SetLength(m2);
  bb.SetLength(m2);

  long i;
  for (i = 0; i <= deg(a); i++) {
    if (sign(a.rep[i]) >= 0) {
      aa[i] = a.rep[i];
    } else {
      add(aa[i], a.rep[i], p);
    }
  }

  for (i = 0; i <= deg(b); i++) {
    if (sign(b.rep[i]) >= 0) {
      bb[i] = b.rep[i];
    } else {
      add(bb[i], b.rep[i], p);
    }
  }

  /* 2m-point FFT's mod p */
  fft(aa, r, l + 1, p, mr);
  fft(bb, r, l + 1, p, mr);

  /* Pointwise multiplication aa := aa * bb mod p */
  /* reduction uses 2^{mr} == -1: ai = hi*2^{mr} + lo == lo - hi (mod p) */
  ZZ tmp, ai;
  for (i = 0; i < m2; i++) {
    mul(ai, aa[i], bb[i]);
    if (NumBits(ai) > mr) {
      RightShift(tmp, ai, mr);
      LowBits(ai, ai, mr);
      sub(ai, ai, tmp);
      if (sign(ai) < 0) {
	add(ai, ai, p);
      }
    }
    aa[i] = ai;
  }
  
  ifft(aa, r, l + 1, p, mr);

  /* Retrieve c, dividing by 2m, and subtracting p where necessary */
  c.rep.SetLength(n + 1);
  for (i = 0; i <= n; i++) {
    ai = aa[i];
    ZZ& ci = c.rep[i];
    if (!IsZero(ai)) {
      /* ci = -ai * 2^{mr-l-1} = ai * 2^{-l-1} = ai / 2m mod p */
      LeftRotate(ai, ai, mr - l - 1, p, mr);
      sub(tmp, p, ai);
      if (NumBits(tmp) >= mr) { /* ci >= (p-1)/2 */
	negate(ci, ai); /* ci = -ai = ci - p */
      }
      else
        ci = tmp;
    } 
    else
       clear(ci);
  }
}

// Multi-modular multiplication x = a * b.
// The product's coefficients are less than
// (min(da,db)+1) * 2^(MaxBits(a)+MaxBits(b)) in absolute value.  Enough
// FFT primes are taken that their product exceeds 2^bound (bound includes
// a 2-bit margin).  The product is computed modulo each prime with zz_pX
// arithmetic and the results are combined by Chinese remaindering into
// balanced representatives.
// Squaring (a and b the same object) is forwarded to HomSqr.  x may alias
// a or b: the inputs are only read before x is resized.
// The caller's zz_p context is saved and restored via zz_pBak.
void HomMul(ZZX& x, const ZZX& a, const ZZX& b)
{
   if (&a == &b) {
      HomSqr(x, a);
      return;
   }

   long da = deg(a);
   long db = deg(b);

   if (da < 0 || db < 0) {
      clear(x);
      return;
   }

   long bound = 2 + NumBits(min(da, db)+1) + MaxBits(a) + MaxBits(b);


   ZZ prod;
   set(prod);

   long i, nprimes;

   zz_pBak bak;
   bak.save();

   // prod = product of the first nprimes FFT primes, with prod > 2^bound
   // (zz_pFFTInit presumably extends the FFTPrime table when needed)
   for (nprimes = 0; NumBits(prod) <= bound; nprimes++) {
      if (nprimes >= NumFFTPrimes)
         zz_pFFTInit(nprimes);
      mul(prod, prod, FFTPrime[nprimes]);
   }


   // NOTE: this local shadows the coeff() accessor within this function
   ZZ coeff;
   ZZ t1;
   long tt;

   vector(ZZ) c;

   c.SetLength(da+db+1);

   long j;

   for (i = 0; i < nprimes; i++) {
      zz_pFFTInit(i);
      long p = zz_p::modulus();

      // coeff = (prod/p) * ((prod/p)^{-1} mod p):
      // congruent to 1 mod p and to 0 mod every other prime
      div(t1, prod, p);
      tt = rem(t1, p);
      tt = InvMod(tt, p);
      mul(coeff, t1, tt);

      zz_pX A, B, C;

      A << a;
      B << b;
      mul(C, A, B);

      long m = deg(C);

      for (j = 0; j <= m; j++) {
         /* c[j] += coeff*rep(C.rep[j]) */
         mul(t1, coeff, rep(C.rep[j]));
         add(c[j], c[j], t1); 
      }
   }

   x.rep.SetLength(da+db+1);

   ZZ prod2;
   RightShift(prod2, prod, 1);

   // reduce each accumulated value mod prod into (-prod/2, prod/2]
   for (j = 0; j <= da+db; j++) {
      rem(t1, c[j], prod);

      if (t1 > prod2)
         sub(x.rep[j], t1, prod);
      else
         x.rep[j] = t1;
   }

   x.normalize();

   bak.restore();
}

// c = a * b, dispatching between Karatsuba, Schoenhage-Strassen and
// multi-modular multiplication using size heuristics.  Squaring (a and b
// the same object) is forwarded to sqr.
void mul(ZZX& c, const ZZX& a, const ZZX& b)
{
   if (IsZero(a) || IsZero(b)) {
      clear(c);
      return;
   }

   if (&a == &b) {
      sqr(c, a);
      return;
   }

   long maxa = MaxBits(a);
   long maxb = MaxBits(b);

   // Karatsuba while the shorter operand is short relative to the word
   // size of the smaller coefficients.
   long small_words = (min(maxa, maxb) + ZZ_NBITS - 1)/ZZ_NBITS;
   long small_len = min(deg(a), deg(b)) + 1;

   if (small_len < 1100/small_words) {
      KarMul(c, a, b);
      return;
   }

   long prod_len = deg(a) + deg(b) + 1;
   long prod_words = (maxa + maxb + ZZ_NBITS - 1)/ZZ_NBITS;

   // SS pays off for large coefficients relative to the product length
   if (prod_words >= 34 && prod_len/4 < maxa + maxb)
      SSMul(c, a, b);
   else
      HomMul(c, a, b);
}


/* Squaring a la Schoenhage & Strassen: c = a^2.  Uses the same
   Fermat-number scheme as SSMul (p = 2^{mr}+1, w = 2^r a primitive 2m-th
   root of unity mod p) but needs only one forward transform.
   Falls back to PlainSqr when deg(a) <= 0.
   c may alias a: c is only resized after a has been copied into aa. */
void SSSqr(ZZX& c, const ZZX& a)
{
  long na = deg(a);
  if (na <= 0) {
    PlainSqr(c, a);
    return;
  }

  long n = na + na; /* degree of the product */


  long l = NextPowerOfTwo(n + 1) - 1; /* 2^l <= n < 2^{l+1} */
  long m2 = 1L << (l + 1); /* m2 = 2m = 2^{l+1} */
  /* coefficients of a^2 are absolutely less than
     (na+1)*2^{2*MaxBits(a)} <= 2^{bound-2} */
  long bound = 2 + NumBits(na) + 2*MaxBits(a);
  /* Let r be minimal so that mr > bound */
  long r = (bound >> l) + 1;
  long mr = r << l;

  /* p := 2^{mr}+1 */
  ZZ p;
  set(p);
  LeftShift(p, p, mr);
  add(p, p, 1);

  /* Make coefficients of a positive (zero-padded up to m2 entries) */
  vector(ZZ) aa;
  aa.SetLength(m2);

  long i;
  for (i = 0; i <= deg(a); i++) {
    if (sign(a.rep[i]) >= 0) {
      aa[i] = a.rep[i];
    } else {
      add(aa[i], a.rep[i], p);
    }
  }


  /* 2m-point FFT's mod p */
  fft(aa, r, l + 1, p, mr);

  /* Pointwise multiplication aa := aa * aa mod p */
  /* reduction uses 2^{mr} == -1: ai = hi*2^{mr} + lo == lo - hi (mod p) */
  ZZ tmp, ai;
  for (i = 0; i < m2; i++) {
    sqr(ai, aa[i]);
    if (NumBits(ai) > mr) {
      RightShift(tmp, ai, mr);
      LowBits(ai, ai, mr);
      sub(ai, ai, tmp);
      if (sign(ai) < 0) {
        add(ai, ai, p);
      }
    }
    aa[i] = ai;
  }
  
  ifft(aa, r, l + 1, p, mr);

  /* Retrieve c, dividing by 2m, and subtracting p where necessary */
  c.rep.SetLength(n + 1);

  for (i = 0; i <= n; i++) {
    ai = aa[i];
    ZZ& ci = c.rep[i];
    if (!IsZero(ai)) {
      /* ci = -ai * 2^{mr-l-1} = ai * 2^{-l-1} = ai / 2m mod p */
      LeftRotate(ai, ai, mr - l - 1, p, mr);
      sub(tmp, p, ai);
      if (NumBits(tmp) >= mr) { /* ci >= (p-1)/2 */
        negate(ci, ai); /* ci = -ai = ci - p */
      }
      else
        ci = tmp;
    } 
    else
       clear(ci);
  }
}

// Multi-modular squaring x = a^2.
// Same scheme as HomMul: the coefficients of a^2 are less than
// (da+1) * 2^(2*MaxBits(a)) in absolute value.  FFT primes are taken until
// their product exceeds 2^bound, a is squared modulo each prime, and the
// results are recombined by CRT into balanced representatives.
// x may alias a: a is only read before x is resized.
// The caller's zz_p context is saved and restored via zz_pBak.
void HomSqr(ZZX& x, const ZZX& a)
{

   long da = deg(a);

   if (da < 0) {
      clear(x);
      return;
   }

   long bound = 2 + NumBits(da+1) + 2*MaxBits(a);


   ZZ prod;
   set(prod);

   long i, nprimes;

   zz_pBak bak;
   bak.save();

   // prod = product of the first nprimes FFT primes, with prod > 2^bound
   for (nprimes = 0; NumBits(prod) <= bound; nprimes++) {
      if (nprimes >= NumFFTPrimes)
         zz_pFFTInit(nprimes);
      mul(prod, prod, FFTPrime[nprimes]);
   }


   // NOTE: this local shadows the coeff() accessor within this function
   ZZ coeff;
   ZZ t1;
   long tt;

   vector(ZZ) c;

   c.SetLength(da+da+1);

   long j;

   for (i = 0; i < nprimes; i++) {
      zz_pFFTInit(i);
      long p = zz_p::modulus();

      // coeff = (prod/p) * ((prod/p)^{-1} mod p):
      // congruent to 1 mod p and to 0 mod every other prime
      div(t1, prod, p);
      tt = rem(t1, p);
      tt = InvMod(tt, p);
      mul(coeff, t1, tt);

      zz_pX A, C;

      A << a;
      sqr(C, A);

      long m = deg(C);

      for (j = 0; j <= m; j++) {
         /* c[j] += coeff*rep(C.rep[j]) */
         mul(t1, coeff, rep(C.rep[j]));
         add(c[j], c[j], t1); 
      }
   }

   x.rep.SetLength(da+da+1);

   ZZ prod2;
   RightShift(prod2, prod, 1);

   // reduce each accumulated value mod prod into (-prod/2, prod/2]
   for (j = 0; j <= da+da; j++) {
      rem(t1, c[j], prod);

      if (t1 > prod2)
         sub(x.rep[j], t1, prod);
      else
         x.rep[j] = t1;
   }

   x.normalize();

   bak.restore();
}


// Schoolbook squaring of the coefficient array ap[0..sa-1] into
// xp[0..2*sa-2]; does nothing for sa == 0.  xp must not overlap ap.
// Symmetric pairs are summed once and doubled; the middle square (present
// when i is even) is added separately.
// Uses function-static scratch ZZs, so it is not reentrant.
void PlainSqr(ZZ* xp, const ZZ* ap, long sa)
{
   if (sa == 0) return;

   long da = sa-1;
   long d = 2*da;

   static ZZ t, accum;

   for (long i = 0; i <= d; i++) {
      long jmin = max(0, i-da);
      long jmax = min(da, i);
      long cnt = jmax - jmin + 1;   // terms a_j * a_{i-j}
      long half = cnt >> 1;         // distinct symmetric pairs

      clear(accum);
      for (long j = jmin; j < jmin + half; j++) {
         mul(t, ap[j], ap[i-j]);
         add(accum, accum, t);
      }
      add(accum, accum, accum);

      if (cnt & 1) {
         sqr(t, ap[jmin + half]);
         add(accum, accum, t);
      }

      xp[i] = accum;
   }
}


/* Recursive Karatsuba squaring on coefficient arrays:
   c[0..2*sa-2] = a[0..sa-1]^2, using caller-allocated scratch ZZs at stk.
   Below length xover the schoolbook PlainSqr is used.
   With a = a_lo + X^hsa * a_hi, the middle term is obtained as
   (a_lo + a_hi)^2 - a_hi^2 - a_lo^2 = 2 * a_lo * a_hi. */
void KarSqr(ZZ *c, const ZZ *a, long sa, ZZ *stk, long xover)
{
   if (sa < xover) {
      PlainSqr(c, a, sa);
      return;
   }

   long hsa = (sa + 1) >> 1;
   long hsa2 = hsa << 1;

   ZZ *T1, *T2;

   T1 = stk; stk += hsa;
   T2 = stk; stk += hsa2-1;

   /* T1 = a_lo + a_hi, T2 = T1^2 */
   KarFold(T1, a, sa, hsa);
   KarSqr(T2, T1, hsa, stk, xover);


   /* a_hi^2 into the high part of c, subtracted from T2 */
   KarSqr(c + hsa2, a+hsa, sa-hsa, stk, xover);
   KarSub(T2, c + hsa2, sa + sa - hsa2 - 1);


   /* a_lo^2 into the low part of c, subtracted from T2 */
   KarSqr(c, a, hsa, stk, xover);
   KarSub(T2, c, hsa2 - 1);

   /* the only slot between a_lo^2 and a_hi^2 not yet written */
   clear(c[hsa2 - 1]);

   /* add the middle term T2 * X^hsa */
   KarAdd(c+hsa, T2, hsa2-1);
}
      
// Karatsuba squaring c = a^2 for polynomials.  c may alias a; a is copied
// before c is resized.  The schoolbook crossover is chosen from the
// coefficient size, and scratch for all recursion levels is allocated up
// front.
void KarSqr(ZZX& c, const ZZX& a)
{
   if (IsZero(a)) {
      clear(c);
      return;
   }

   vector(ZZ) mem;

   const ZZ *ap;
   ZZ *cp;

   long sa = a.rep.length();

   long maxa, k, xover;

   // Measure coefficient sizes before c is resized.  If c aliases a,
   // the SetLength below lengthens a as well, and MaxBits would then also
   // scan the extra (stale) entries.
   maxa = MaxBits(a);

   if (&a == &c) {
      mem = a.rep;
      ap = mem.elts();
   }
   else
      ap = a.rep.elts();

   c.rep.SetLength(sa+sa-1);
   cp = c.rep.elts();

   // words per coefficient
   k = (maxa + ZZ_NBITS - 1)/ZZ_NBITS;

   if (k <= 6)
      xover = 10;
   else if (k <= 10)
      xover = 6;
   else if (k <= 14)
      xover = 5;
   else if (k <= 18)
      xover = 4;
   else
      xover = 2;


   if (sa < xover)
      PlainSqr(cp, ap, sa);
   else {
      /* karatsuba */

      long n, hn, sp, depth;

      // total scratch: each level needs 3*hn - 1 ZZs (T1 and T2)
      n = sa;
      sp = 0;
      depth = 0;
      do {
         hn = (n+1) >> 1;
         sp += hn+hn+hn - 1;
         n = hn;
         depth++;
      } while (n >= xover);

      ZZVec stk;
      stk.SetSize(sp, 
         ((2*maxa + NumBits(sa) + 2*depth + 10) 
          + ZZ_NBITS-1)/ZZ_NBITS);

      KarSqr(cp, ap, sa, stk.elts(), xover);
   }

   c.normalize();
}

// c = a^2, dispatching between Karatsuba, Schoenhage-Strassen and
// multi-modular squaring using size heuristics.
void sqr(ZZX& c, const ZZX& a)
{
   if (IsZero(a)) {
      clear(c);
      return;
   }

   long maxa = MaxBits(a);

   long words = (maxa + ZZ_NBITS - 1)/ZZ_NBITS;
   long len = deg(a) + 1;

   if (len < 1100/words) {
      KarSqr(c, a);
      return;
   }

   long sq_len = 2*deg(a) + 1;
   long sq_words = (2*maxa + ZZ_NBITS - 1)/ZZ_NBITS;

   if (sq_words >= 34 && sq_len/4 < 2*maxa)
      SSSqr(c, a);
   else
      HomSqr(c, a);
}


// x = a * b for an integer scalar b.  b is copied first, so it may alias a
// coefficient of x.  For b != 0 no coefficient can vanish, so no
// normalization is needed.
void mul(ZZX& x, const ZZX& a, const ZZ& b)
{
   if (IsZero(b)) {
      clear(x);
      return;
   }

   ZZ scale;
   scale = b;

   long da = deg(a);
   x.rep.SetLength(da+1);

   for (long k = 0; k <= da; k++)
      mul(x.rep[k], a.rep[k], scale);
}


// x = a * b for a machine-word scalar, converted to ZZ first.
// NOTE(review): B is function-static scratch shared across calls, so this
// routine is not reentrant / thread-safe.
void mul(ZZX& x, const ZZX& a, long b)
{
   static ZZ B;
   B << b;
   mul(x, a, B);
}


// x = formal derivative of a.  x may alias a: in that case the vector is
// only shrunk after every a.rep[k] has been read.
void diff(ZZX& x, const ZZX& a)
{
   long n = deg(a);

   if (n <= 0) {
      clear(x);
      return;
   }

   if (&x != &a)
      x.rep.SetLength(n);

   // d/dX (a_k X^k) = k a_k X^{k-1}
   for (long k = 1; k <= n; k++)
      mul(x.rep[k-1], a.rep[k], k);

   if (&x == &a)
      x.rep.SetLength(n);

   x.normalize();
}

// Multi-modular pseudo-division: computes q, r with
//    LC(b)^(deg(a)-deg(b)+1) * a = b*q + r,   deg(r) < deg(b),
// where LC(b) is the leading coefficient of b.  Each iteration works
// modulo the next FFT prime (primes dividing LC(b) are skipped) and folds
// the modular quotient/remainder into q and r by CRT.  It stops once both
// have stabilized and the modulus product exceeds a coefficient size bound.
// If deg(a) < deg(b), the result is q = 0 and r = a.
// Calls Error() if b is zero.  The caller's zz_p context is saved and
// restored via zz_pBak.
void HomDivRem(ZZX& q, ZZX& r, const ZZX& a, const ZZX& b)
{
   if (IsZero(b)) Error("division by zero");

   long da = deg(a);
   long db = deg(b);

   if (da < db) {
      // the dividend is already reduced: it is its own remainder
      r = a;
      clear(q);
      return;
   }

   ZZ LC;
   LC = LeadCoeff(b);

   ZZ LC1;

   power(LC1, LC, da-db+1);

   // bit size of the scaled dividend LC^(da-db+1) * a
   long a_bound = NumBits(LC1) + MaxBits(a);

   LC1.kill();

   long b_bound = MaxBits(b);

   zz_pBak bak;
   bak.save();

   ZZX qq, rr;

   ZZ prod, t;
   set(prod);

   clear(qq);
   clear(rr);

   long i;
   long Qinstable, Rinstable;

   Qinstable = 1;
   Rinstable = 1;

   for (i = 0; ; i++) {
      zz_pFFTInit(i);
      long p = zz_p::modulus();


      if (divide(LC, p)) continue;

      zz_pX A, B, Q, R;

      A << a;
      B << b;
      
      if (!IsOne(LC)) {
         // scale A by LC^(da-db+1) to match the pseudo-division
         zz_p y;
         y << LC;
         power(y, y, da-db+1);
         mul(A, A, y);
      }

      if (!Qinstable) {
         // q looks stable: check it mod p and only reconstruct r
         Q << qq;
         mul(R, B, Q);
         sub(R, A, R);

         if (deg(R) >= db)
            Qinstable = 1;
         else
            Rinstable = CRT(rr, prod, R);
      }

      if (Qinstable) {
         DivRem(Q, R, A, B);
         t = prod;
         Qinstable = CRT(qq, t, Q);
         Rinstable =  CRT(rr, prod, R);
      }

      if (!Qinstable && !Rinstable) {
         // stabilized...check if prod is big enough

         long bound1 = b_bound + MaxBits(qq) + NumBits(min(db, da-db)+1);
         long bound2 = MaxBits(rr);
         long bound = max(bound1, bound2);

         if (a_bound > bound)
            bound = a_bound;

         bound += 4;

         if (NumBits(prod) > bound)
            break;
      }
   }

   bak.restore();

   q = qq;
   r = rr;
}




// Quotient-only wrapper around HomDivRem; the remainder is discarded.
void HomDiv(ZZX& q, const ZZX& a, const ZZX& b)
{
   ZZX unused_rem;
   HomDivRem(q, unused_rem, a, b);
}

// Remainder-only wrapper around HomDivRem; the quotient is discarded.
void HomRem(ZZX& r, const ZZX& a, const ZZX& b)
{
   ZZX unused_quot;
   HomDivRem(unused_quot, r, a, b);
}

// Schoolbook pseudo-division: computes q, r with
//    LC^(deg(a)-deg(b)+1) * a = b*q + r,   deg(r) < deg(b),
// where LC is the leading coefficient of b (ordinary division when LC == 1).
// If deg(a) < deg(b), q = 0 and r = a.  Calls Error() if b is zero.
// Aliasing: a is copied into a work vector, and b is copied when q aliases
// it; r is written only after b is no longer read.
void PlainDivRem(ZZX& q, ZZX& r, const ZZX& a, const ZZX& b)
{
   long da, db, dq, i, j, LCIsOne;
   const ZZ *bp;
   ZZ *qp;
   ZZ *xp;


   ZZ  s, t;

   da = deg(a);
   db = deg(b);

   if (db < 0) Error("ZZX: division by zero");

   if (da < db) {
      r = a;
      clear(q);
      return;
   }

   ZZX lb;

   if (&q == &b) {
      lb = b;
      bp = lb.rep.elts();
   }
   else
      bp = b.rep.elts();

   ZZ LC = bp[db];
   LCIsOne = IsOne(LC);


   vector(ZZ) x;

   x = a.rep;
   xp = x.elts();

   dq = da - db;
   q.rep.SetLength(dq+1);
   qp = q.rep.elts();

   // pre-scale the low coefficients: xp[i] *= LC^(dq-i) for i < dq
   if (!LCIsOne) {
      t = LC;
      for (i = dq-1; i >= 0; i--) {
         mul(xp[i], xp[i], t);
         if (i > 0) mul(t, t, LC);
      }
   }

   // elimination step i: take the current top coefficient as the
   // (unscaled) quotient digit and update x := LC*x - digit*b*X^i
   for (i = dq; i >= 0; i--) {
      t = xp[i+db];
      qp[i] = t;

      for (j = db-1; j >= 0; j--) {
	 mul(s, t, bp[j]);
         if (!LCIsOne) mul(xp[i+j], xp[i+j], LC);
	 sub(xp[i+j], xp[i+j], s);
      }
   }

   // post-scale the quotient digits: qp[i] *= LC^i
   if (!LCIsOne) {
      t = LC;
      for (i = 1; i <= dq; i++) {
         mul(qp[i], qp[i], t);
         if (i < dq) mul(t, t, LC);
      }
   }
      

   r.rep.SetLength(db);
   for (i = 0; i < db; i++)
      r.rep[i] = xp[i];
   r.normalize();
}


// Quotient-only wrapper around PlainDivRem; the remainder is discarded.
void PlainDiv(ZZX& q, const ZZX& a, const ZZX& b)
{
   ZZX unused_rem;
   PlainDivRem(q, unused_rem, a, b);
}

// Remainder-only wrapper around PlainDivRem; the quotient is discarded.
void PlainRem(ZZX& r, const ZZX& a, const ZZX& b)
{
   ZZX unused_quot;
   PlainDivRem(unused_quot, r, a, b);
}


// Multi-modular exact-division test.  If b divides a in Z[X], sets q = a/b
// and returns 1; otherwise returns 0 and leaves q unchanged.
// 0/0 and 0/b yield q = 0 and return 1.
// Contents are split off first (ca/cb must divide exactly).  Cheap
// necessary conditions on leading and constant terms are checked, then the
// quotient of the primitive parts is reconstructed by CRT over FFT primes.
// Primes dividing LC(bb) are skipped.  The caller's zz_p context is saved
// and restored via zz_pBak.
long HomDivide(ZZX& q, const ZZX& a, const ZZX& b)
{
   if (IsZero(b)) {
      if (IsZero(a)) {
         clear(q);
         return 1;
      }
      else
         return 0;
   }

   if (IsZero(a)) {
      clear(q);
      return 1;
   }

   if (deg(a) < deg(b)) return 0;

   ZZ ca, cb, cq;

   content(ca, a);
   content(cb, b);

   if (!divide(cq, ca, cb)) return 0;

   // aa, bb: primitive parts of a and b
   ZZX aa, bb;

   divide(aa, a, ca);
   divide(bb, b, cb);

   if (!divide(LeadCoeff(aa), LeadCoeff(bb)))
      return 0;

   if (!divide(ConstTerm(aa), ConstTerm(bb)))
      return 0;

   zz_pBak bak;
   bak.save();

   ZZX qq;

   ZZ prod;
   set(prod);

   clear(qq);
   long res = 1;
   long Qinstable = 1;


   long a_bound = MaxBits(aa);
   long b_bound = MaxBits(bb);


   long i;
   for (i = 0; ; i++) {
      zz_pFFTInit(i);
      long p = zz_p::modulus();

      if (divide(LeadCoeff(bb), p)) continue;

      zz_pX A, B, Q, R;

      A << aa;
      B << bb;

      if (!Qinstable) {
         // qq looks stable: verify A == B*qq mod p
         Q << qq;
         mul(R, B, Q);
         sub(R, A, R);

         if (deg(R) >= deg(B))
            Qinstable = 1;
         else if (!IsZero(R)) {
            // nonzero remainder mod p: b cannot divide a
            res = 0;
            break;
         }
         else
            mul(prod, prod, p);
      }

      if (Qinstable) {
         if (!divide(Q, A, B)) {
            res = 0;
            break;
         }

         Qinstable = CRT(qq, prod, Q);
      }

      if (!Qinstable) {
         // stabilized...check if prod is big enough

         long bound = b_bound + MaxBits(qq) + 
                     NumBits(min(deg(bb), deg(qq)) + 1);

         if (a_bound > bound)
            bound = a_bound;

         bound += 3;

         if (NumBits(prod) > bound) 
            break;
      }
   }

   bak.restore();

   if (res) mul(q, qq, cq);
   return res;

}


long HomDivide(const ZZX& a, const ZZX& b)
{
   ZZX q;
   return HomDivide(q, a, b);
}

// Schoolbook exact-division test.  If bb divides aa in Z[X], sets
// qq = aa/bb and returns 1; otherwise returns 0 and leaves qq unchanged.
// A zero dividend is divisible by anything (qq = 0), matching HomDivide.
// Contents are split off first, and cheap checks on the leading and
// constant terms are made.  Then the primitive parts are divided with
// exact integer division at every step, failing as soon as a step does not
// divide.
long PlainDivide(ZZX& qq, const ZZX& aa, const ZZX& bb)
{
   if (IsZero(bb)) {
      if (IsZero(aa)) {
         clear(qq);
         return 1;
      }
      else
         return 0;
   }

   // 0 = bb * 0; without this the degree test below would reject it
   if (IsZero(aa)) {
      clear(qq);
      return 1;
   }

   long da, db, dq, i, j, LCIsOne;
   const ZZ *bp;
   ZZ *qp;
   ZZ *xp;


   ZZ  s, t;

   da = deg(aa);
   db = deg(bb);

   if (da < db) {
      return 0;
   }

   ZZ ca, cb, cq;

   content(ca, aa);
   content(cb, bb);

   if (!divide(cq, ca, cb)) {
      return 0;
   } 


   // a, b: primitive parts (local copies, a is reduced in place below)
   ZZX a, b, q;

   divide(a, aa, ca);
   divide(b, bb, cb);

   if (!divide(LeadCoeff(a), LeadCoeff(b)))
      return 0;

   if (!divide(ConstTerm(a), ConstTerm(b)))
      return 0;


   bp = b.rep.elts();

   ZZ LC;
   LC = bp[db];

   LCIsOne = IsOne(LC);

   xp = a.rep.elts();

   dq = da - db;
   q.rep.SetLength(dq+1);
   qp = q.rep.elts();

   for (i = dq; i >= 0; i--) {
      // next quotient digit must be an exact integer
      if (!LCIsOne) {
         if (!divide(t, xp[i+db], LC))
            return 0;
      }
      else
         t = xp[i+db];

      qp[i] = t;

      for (j = db-1; j >= 0; j--) {
         mul(s, t, bp[j]);
         sub(xp[i+j], xp[i+j], s);
      }
   }

   // the remainder must vanish
   for (i = 0; i < db; i++)
      if (!IsZero(xp[i]))
         return 0;

   mul(qq, q, cq);
   return 1;
}

long PlainDivide(const ZZX& a, const ZZX& b)
{
   ZZX q;
   return PlainDivide(q, a, b);
}


// If the integer b divides every coefficient of a, sets q = a/b and
// returns 1; otherwise returns 0 and leaves q unchanged.  0/0 gives q = 0.
// Quotients go to a scratch vector first, so q is untouched on failure
// and may alias a (or hold b).
long divide(ZZX& q, const ZZX& a, const ZZ& b)
{
   if (IsZero(b)) {
      if (!IsZero(a))
         return 0;
      clear(q);
      return 1;
   }

   if (IsOne(b)) {
      q = a;
      return 1;
   }

   if (b == -1) {
      negate(q, a);
      return 1;
   }

   long len = a.rep.length();
   vector(ZZ) quot(INIT_SIZE, len);

   for (long k = 0; k < len; k++) {
      if (!divide(quot[k], a.rep[k], b))
         return 0;
   }

   q.rep = quot;
   return 1;
}

// Returns 1 iff the integer b divides every coefficient of a
// (b == 0 divides only the zero polynomial).
long divide(const ZZX& a, const ZZ& b)
{
   if (IsZero(b))
      return IsZero(a);

   if (IsOne(b) || b == -1)
      return 1;

   long len = a.rep.length();
   for (long k = 0; k < len; k++) {
      if (!divide(a.rep[k], b))
         return 0;
   }

   return 1;
}

// Machine-word version of divide(q, a, b): sets q = a/b and returns 1 iff b
// divides every coefficient of a; q is untouched on failure.
long divide(ZZX& q, const ZZX& a, long b)
{
   if (b == 0) {
      if (!IsZero(a))
         return 0;
      clear(q);
      return 1;
   }

   if (b == 1) {
      q = a;
      return 1;
   }

   if (b == -1) {
      negate(q, a);
      return 1;
   }

   long len = a.rep.length();
   vector(ZZ) quot(INIT_SIZE, len);

   for (long k = 0; k < len; k++) {
      if (!divide(quot[k], a.rep[k], b))
         return 0;
   }

   q.rep = quot;
   return 1;
}

// Returns 1 iff the machine-word integer b divides every coefficient of a
// (b == 0 divides only the zero polynomial).
long divide(const ZZX& a, long b)
{
   if (b == 0)
      return IsZero(a);

   if (b == 1 || b == -1)
      return 1;

   long len = a.rep.length();
   for (long k = 0; k < len; k++) {
      if (!divide(a.rep[k], b))
         return 0;
   }

   return 1;
}

   

// Content of f: the gcd of its coefficients, negated when f's leading
// coefficient is negative.  The content of the zero polynomial is 0.
void content(ZZ& d, const ZZX& f)
{
   // Accumulate in a local so d may safely alias a coefficient of f.
   ZZ g;
   clear(g);

   // Stop as soon as the running gcd reaches 1; it cannot shrink further.
   const long len = f.rep.length();
   for (long k = 0; k < len && !IsOne(g); k++)
      GCD(g, g, f.rep[k]);

   if (sign(LeadCoeff(f)) < 0)
      negate(g, g);

   d = g;
}

// Primitive part of f: f divided by content(f).  content carries the sign
// of f's leading coefficient, so the result has a positive leading
// coefficient.  PrimitivePart(0) is 0.
void PrimitivePart(ZZX& pp, const ZZX& f)
{
   if (!IsZero(f)) {
      ZZ c;
      content(c, f);
      divide(pp, f, c);
   }
   else
      clear(pp);
}


// Lifts G (over Z/pZ, p the current zz_p modulus) to an integer polynomial
// g.  Each residue is mapped to its balanced representative: residues
// greater than floor(p/2) are shifted down by p.
void BalCopy(ZZX& g, const zz_pX& G)
{
   const long p = zz_p::modulus();
   const long half = p >> 1;
   const long len = G.rep.length();

   g.rep.SetLength(len);
   for (long k = 0; k < len; k++) {
      long r = rep(G.rep[k]);
      if (r > half)
         r -= p;
      g.rep[k] << r;
   }
}


   
// GCD of two integer polynomials, computed with a modular algorithm.
//
// The result d is gcd(content(a), content(b)) times the gcd of the
// primitive parts of a and b, normalized to a non-negative leading
// coefficient.  If one argument is zero, the other is returned, also
// normalized to a non-negative leading coefficient.
//
// Outline:
//  - Split off contents.
//  - For successive FFT primes p that divide neither leading coefficient,
//    compute the monic gcd of the primitive parts mod p, scaled by
//    ld = gcd of their leading coefficients.
//  - Keep only images of the smallest degree seen and combine them by CRT.
//  - When the CRT image stops changing, take its primitive part and
//    accept it if it divides both primitive parts exactly.
void GCD(ZZX& d, const ZZX& a, const ZZX& b)
{
   if (IsZero(a)) {
      d = b;
      if (sign(LeadCoeff(d)) < 0) negate(d, d);
      return;
   }

   if (IsZero(b)) {
      d = a;
      if (sign(LeadCoeff(d)) < 0) negate(d, d);
      return;
   }

   ZZ c1, c2, c;
   ZZX f1, f2;

   // f1, f2: primitive parts (content carries the leading sign, so their
   // leading coefficients are positive).
   content(c1, a);
   divide(f1, a, c1);

   content(c2, b);
   divide(f2, b, c2);

   GCD(c, c1, c2);

   // Scaling the monic modular gcd by ld lets its lift carry the true
   // leading coefficient, which divides both leading coefficients.
   ZZ ld;
   GCD(ld, LeadCoeff(f1), LeadCoeff(f2));

   ZZX g, res;   // g: current CRT image; res: accepted primitive gcd

   ZZ prod;      // product of the primes combined into g
   set(prod);

   // zz_pFFTInit below switches the global zz_p modulus; restore on exit.
   zz_pBak bak;
   bak.save();


   long FirstTime = 1;

   long i;
   for (i = 0; ;i++) {
      zz_pFFTInit(i);
      long p = zz_p::modulus();

      // Skip primes at which f1 or f2 would drop in degree.
      if (divide(LeadCoeff(f1), p) || divide(LeadCoeff(f2), p)) continue;

      zz_pX G, F1, F2;
      zz_p  LD;

      F1 << f1;
      F2 << f2;
      LD << ld;

      GCD(G, F1, F2);
      mul(G, G, LD);


      // A constant modular gcd means the primitive parts are coprime.
      if (deg(G) == 0) { 
         set(res);
         break;
      }

      if (FirstTime || deg(G) < deg(g)) {
         // First image, or a smaller degree than before (so earlier
         // primes were unlucky): restart the CRT from this image.
         FirstTime = 0;
         prod << p;
         BalCopy(g, G);
      }
      else if (deg(G) > deg(g)) 
         continue;   // unlucky prime: discard this image
      else if (!CRT(g, prod, G)) {
         // g was not changed by CRT (presumably what a zero return from
         // the zz_pX overload means): verify the candidate by exact
         // division.
         PrimitivePart(res, g);
         if (divide(f1, res) && divide(f2, res))
            break;
      }

   }

   bak.restore();

   mul(d, res, c);
   if (sign(LeadCoeff(d)) < 0) negate(d, d);
}

// x = a % X^m: keep only the coefficients of X^0 .. X^{m-1}.
// The output may alias the input.
void trunc(ZZX& x, const ZZX& a, long m)
{
   if (&x != &a) {
      long keep = min(a.rep.length(), m);
      x.rep.SetLength(keep);

      ZZ* dst = x.rep.elts();
      const ZZ* src = a.rep.elts();
      for (long k = 0; k < keep; k++)
         dst[k] = src[k];

      x.normalize();
      return;
   }

   // In place: shorten only if there is something beyond X^{m-1}.
   if (x.rep.length() > m) {
      x.rep.SetLength(m);
      x.normalize();
   }
}



// x = a * X^n.  A negative n shifts the other way (x = a / X^{-n}, the
// low coefficients being discarded).  The output may alias the input.
void LeftShift(ZZX& x, const ZZX& a, long n)
{
   if (IsZero(a)) {
      clear(x);
      return;
   }

   if (n < 0) {
      // With n < 0 the copy loop below would index x.rep below 0, so treat
      // this as a right shift.  deg(a) >= 0 here, so deg(a) + n cannot
      // overflow.  When it is >= 0, -n <= deg(a) and -n is safe to form.
      if (deg(a) + n < 0)
         clear(x);
      else
         RightShift(x, a, -n);
      return;
   }

   long m = a.rep.length();

   x.rep.SetLength(m+n);

   // Copy from the top down so that x == a is handled correctly.
   long i;
   for (i = m-1; i >= 0; i--)
      x.rep[i+n] = a.rep[i];

   for (i = 0; i < n; i++)
      clear(x.rep[i]);
}


// x = a / X^n: coefficients of degree < n are discarded.  A negative n
// shifts the other way (x = a * X^{-n}).  The output may alias the input.
void RightShift(ZZX& x, const ZZX& a, long n)
{
   if (n < 0) {
      // With n < 0 the copy loop below would read a.rep at negative
      // indices, so delegate to LeftShift.
      if (IsZero(a)) {
         clear(x);
         return;
      }
      // -(n+1) never overflows.  It equals the largest long exactly when
      // n is the most negative long, whose negation would overflow.
      if (-(n + 1) == (long) (~0UL >> 1))
         Error("RightShift: overflow");
      LeftShift(x, a, -n);
      return;
   }

   long da = deg(a);
   long i;

   if (da < n) {
      clear(x);
      return;
   }

   if (&x != &a)
      x.rep.SetLength(da-n+1);

   // Ascending copy reads index i+n >= i, so x == a is safe.
   for (i = 0; i <= da-n; i++)
      x.rep[i] = a.rep[i+n];

   if (&x == &a)
      x.rep.SetLength(da-n+1);

   x.normalize();
}


// Power sums of the roots of the monic polynomial ff, via Newton's
// identities.  With n = deg(ff), S is resized to n, S[0] = n, and for
// 1 <= k < n:
//    S[k] = -(k*f[n-k] + sum_{j=1}^{k-1} f[n-j]*S[k-j]).
// Raises an error if ff is not monic.
void TraceVec(vector(ZZ)& S, const ZZX& ff)
{
   if (!IsOne(LeadCoeff(ff)))
      Error("TraceVec: bad args");

   // Work on a private copy (presumably so S may share storage with ff).
   ZZX f;
   f = ff;

   const long n = deg(f);
   S.SetLength(n);
   if (n == 0)
      return;

   S[0] << n;

   ZZ sum, term;
   for (long k = 1; k < n; k++) {
      mul(sum, f.rep[n-k], k);
      for (long j = 1; j < k; j++) {
         mul(term, f.rep[n-j], S[k-j]);
         add(sum, sum, term);
      }
      negate(S[k], sum);
   }
}

// Upper bound l on the Euclidean norm of a's coefficient vector.
// If the squared norm exceeds 1, l is its integer square root plus one;
// otherwise l is the squared norm itself (0 or 1, which is then exact).
static
void EuclLength(ZZ& l, const ZZX& a)
{
   ZZ sq_norm, sq;
   clear(sq_norm);

   const long len = a.rep.length();
   for (long k = 0; k < len; k++) {
      sqr(sq, a.rep[k]);
      add(sq_norm, sq_norm, sq);
   }

   if (sq_norm > 1) {
      SqrRoot(l, sq_norm);
      add(l, l, 1);
      return;
   }

   l = sq_norm;
}



static
long ResBound(const ZZX& a, const ZZX& b)
{
   if (IsZero(a) || IsZero(b)) 
      return 0;

   ZZ t1, t2, t;
   EuclLength(t1, a);
   EuclLength(t2, b);
   power(t1, t1, deg(b));
   power(t2, t2, deg(a));
   mul(t, t1, t2);
   return NumBits(t);
}



// Sets rres to the resultant of a and b (0 if either is zero), computed
// by multi-modular Chinese remaindering.
//
// For successive FFT primes p that divide neither leading coefficient,
// resultant(a mod p, b mod p) is combined into res by CRT.  This continues
// until the product of the primes used exceeds bound = 2 + ResBound(a, b)
// bits, where ResBound is a Hadamard-type bound on the resultant.
//
// If deterministic == 0, a probabilistic shortcut is also tried.  It
// applies once the CRT value has stopped changing (instable == 0),
// bound > 1000, and fewer than a quarter of the bound's bits have been
// accumulated.  res is then checked against the resultant modulo one
// random prime P; if CRT leaves res unchanged there, the loop ends early.
//
// Both the zz_p and ZZ_p moduli are saved on entry and restored on exit.
void resultant(ZZ& rres, const ZZX& a, const ZZX& b, long deterministic)
{
   if (IsZero(a) || IsZero(b)) {
      clear(rres);
      return;
   }

   zz_pBak zbak;
   zbak.save();

   ZZ_pBak Zbak;
   Zbak.save();

   // Nonzero while the most recent CRT step changed res.
   long instable = 1;

   long bound = 2+ResBound(a, b);

   // res: CRT value so far; prod: product of the primes used so far.
   ZZ res, prod;

   clear(res);
   set(prod);


   long i;
   for (i = 0; ; i++) {
      if (NumBits(prod) > bound)
         break;

      if (!deterministic &&
          !instable && bound > 1000 && NumBits(prod) < 0.25*bound) {

         ZZ P;


         long plen = 90 + NumBits(max(bound, NumBits(res)));

         // Draw a random plen-bit prime dividing neither leading coefficient.
         do {
            RandomPrime(P, plen, 40);
         }
         while (divide(LeadCoeff(a), P) || divide(LeadCoeff(b), P));

         ZZ_pInit(P);

         ZZ_pX A, B;
         A << a;
         B << b;

         ZZ_p t;
         resultant(t, A, B);

         // CRT presumably returns nonzero when res changed: keep iterating.
         // Otherwise res agrees modulo P as well and is accepted.
         if (CRT(res, prod, rep(t), P))
            instable = 1;
         else
            break;
      }


      zz_pFFTInit(i);
      long p = zz_p::modulus();
      // Skip primes at which a or b would drop in degree.
      if (divide(LeadCoeff(a), p) || divide(LeadCoeff(b), p))
         continue;

      zz_pX A, B;
      A << a;
      B << b;

      zz_p t;
      resultant(t, A, B);

      instable = CRT(res, prod, rep(t), p);
   }

   rres = res;

   zbak.restore();
   Zbak.restore();
}



// Polynomial Chinese remaindering against an image modulo the current
// ZZ_p modulus p.
//
// On entry g is assumed to have coefficients in balanced form modulo a,
// and a is assumed invertible mod p.  On return each coefficient of g is
// the representative in (-a*p/2, a*p/2] that is congruent to the old
// coefficient mod a and to the matching coefficient of G mod p.  Missing
// coefficients on either side are treated as 0.  a is replaced by a*p.
//
// Returns 1 if any coefficient of g changed (or g grew), 0 otherwise.
long CRT(ZZX& g, ZZ& a, const ZZ_pX& G)
{
   ZZ p;
   p = ZZ_p::modulus();
   ZZ_p a_inv;

   a_inv << a;
   inv(a_inv, a_inv);
 
   ZZ aa, new_a, new_a1;
   ZZ v_a, v_p, t;

   // aa == 0 (mod a) and aa == 1 (mod p).
   mul(aa, a, rep(a_inv));
   mul(new_a, a, p);

   // new_a1 = floor(a*p/2): threshold for the balanced representative.
   RightShift(new_a1, new_a, 1);

   long n = g.rep.length();
   g.rep.SetLength(max(n, G.rep.length()));

   long modified = 0;
   long i;

   for (i = 0; i < g.rep.length(); i++) {
      if (i < n) {
         // v_a = g[i] + a*k with k = -g[i]/a mod p, so
         // v_a == g[i] (mod a) and v_a == 0 (mod p).
         ZZ_p k;
         k << g.rep[i];
         mul(k, k, a_inv);
         negate(k, k);
         mul(v_a, a, rep(k));
         add(v_a, v_a, g.rep[i]);
      }
      else
         clear(v_a);

      // v_p == 0 (mod a) and v_p == G[i] (mod p).
      if (i < G.rep.length()) 
         mul(v_p, aa, rep(G.rep[i]));
      else
         clear(v_p);

      add(t, v_a, v_p);

      // Reduce into the balanced range (-a*p/2, a*p/2].
      rem(t, t, new_a);
      if (t > new_a1)
         sub(t, t, new_a);

      if (i >= n || t != g.rep[i]) {
         g.rep[i] = t;
         modified = 1;
      }
   }

   a = new_a;

   return modified;
}



// Minimal polynomial gg of a modulo f.  Requires f monic, deg(f) >= 1 and
// deg(a) < deg(f); otherwise raises an error.  The minimal polynomial of
// 0 is X.
//
// For each FFT prime p, the minimal polynomial of a mod (f, p) is
// computed.  Images of smaller degree than the current g are skipped.  An
// image of larger degree discards everything accumulated so far.  Images
// of equal degree are combined by CRT into g, with prod the product of
// the primes used.
//
// Termination:
//  - If deg(g) == deg(f), g is the characteristic polynomial, so the loop
//    stops once prod exceeds 2 + CharPolyBound(a, f) bits.
//  - Otherwise, once CRT stops changing g, g is tested modulo a random
//    prime P.  compose presumably evaluates g(a) mod f there; a zero
//    result ends the loop.  At full degree this test is also used when
//    the bound is large (> 1000 bits) and under 3/4 of it is reached.
//
// Both the ZZ_p and zz_p moduli are saved on entry and restored on exit.
void MinPoly(ZZX& gg, const ZZX& a, const ZZX& f)

{
   if (!IsOne(LeadCoeff(f)) || deg(f) < 1 || deg(a) >= deg(f))
      Error("MinPoly: bad args");

   if (IsZero(a)) {
      SetX(gg);
      return;
   }

   ZZ_pBak Zbak;
   Zbak.save();
   zz_pBak zbak;
   zbak.save();

   long n = deg(f);

   // Nonzero while the most recent CRT step changed g.
   long instable = 1;

   ZZ prod;
   ZZX g;

   clear(g);
   set(prod);

   // Computed lazily, only once g reaches full degree n.
   long bound = -1;

   long i;
   for (i = 0; ; i++) {
      if (deg(g) == n) {
         if (bound < 0)
            bound = 2+CharPolyBound(a, f);

         if (NumBits(prod) > bound)
            break;
      }

      if (!instable && 
         (deg(g) < n || 
         (deg(g) == n && bound > 1000 && NumBits(prod) < 0.75*bound))) {

         // guarantees 2^{-80} error probability
         long plen = 90 + max( 2*NumBits(n) + NumBits(MaxBits(f)),
                         max( NumBits(n) + NumBits(MaxBits(a)),
                              NumBits(MaxBits(g)) ));

         ZZ P;
         RandomPrime(P, plen, 40);
         ZZ_pInit(P);


         ZZ_pX A, F, G;
         A << a;
         F << f;
         G << g;

         ZZ_pXModulus FF;
         build(FF, F);

         // H = G(A) mod F (presumably); zero means g annihilates a mod P.
         ZZ_pX H;
         compose(H, G, A, FF);
         
         if (IsZero(H))
            break;

         instable = 1;
      } 
         
      zz_pFFTInit(i);

      zz_pX A, F;
      A << a;
      F << f;

      zz_pXModulus FF;
      build(FF, F);

      zz_pX G;
      MinPoly(G, A, FF);

      // Smaller degree: unlucky prime, skip it.
      if (deg(G) < deg(g))
         continue;

      // Larger degree: all previous images were unlucky; restart.
      if (deg(G) > deg(g)) {
         clear(g);
         set(prod);
      }

      instable = CRT(g, prod, G);
   }

   gg = g;

   Zbak.restore();
   zbak.restore();
}


// Computes r = resultant(a, b) together with polynomials s, t satisfying
// a*s + b*t = r, and stores them in rr, ss, tt.
// If r == 0 only rr is set (to 0); ss and tt are left unchanged.
//
// For each FFT prime p dividing neither lc(a), lc(b) nor r:
//  - If (s, t) is currently stable, check a*s + b*t == r mod p; on
//    success just absorb p into prod, otherwise mark unstable.
//  - If unstable, compute S, T with A*S + B*T = D mod p, scale them by r,
//    and merge them into (s, t) by CRT.
// The loop stops once (s, t) is stable and prod exceeds a bound derived
// from the bit sizes of r and of the products a*s and b*t.
//
// The zz_p modulus is saved on entry and restored on exit.
void XGCD(ZZ& rr, ZZX& ss, ZZX& tt, const ZZX& a, const ZZX& b, 
          long deterministic)
{
   ZZ r;

   resultant(r, a, b, deterministic);

   if (IsZero(r)) {
      clear(rr);
      return;
   }

   zz_pBak bak;
   bak.save();

   long i;
   // Nonzero while the most recent CRT step changed s or t.
   long instable = 1;

   ZZ tmp;
   ZZ prod;
   ZZX s, t;

   set(prod);
   clear(s);
   clear(t);

   for (i = 0; ; i++) {
      zz_pFFTInit(i);
      long p = zz_p::modulus();

      if (divide(LeadCoeff(a), p) || divide(LeadCoeff(b), p) || divide(r, p))
         continue;

      zz_p R;
      R << r;

      zz_pX D, S, T, A, B;
      A << a;
      B << b;

      if (!instable) {
         // Cheap check: does the current (s, t) already satisfy the
         // identity mod p?
         S << s;
         T << t;
         zz_pX t1, t2;
         mul(t1, A, S); 
         mul(t2, B, T);
         add(t1, t1, t2);

         if (deg(t1) == 0 && ConstTerm(t1) == R)
            mul(prod, prod, p);
         else
            instable = 1;
      }

      if (instable) {
         XGCD(D, S, T, A, B);
   
         mul(S, S, R);
         mul(T, T, R);
   
         // CRT updates its modulus argument, so s is merged using a copy
         // of prod and t then updates prod itself.
         tmp = prod;
         long Sinstable = CRT(s, tmp, S);
         long Tinstable = CRT(t, prod, T);
   
         instable = Sinstable || Tinstable;
      }

      if (!instable) {
         // Bit-size estimates for the coefficients of a*s and b*t.
         long bound1 = NumBits(min(deg(a), deg(s)) + 1) 
                      + MaxBits(a) + MaxBits(s);
         long bound2 = NumBits(min(deg(b), deg(t)) + 1) 
                      + MaxBits(b) + MaxBits(t);

         long bound = 4 + max(NumBits(r), max(bound1, bound2));

         if (NumBits(prod) > bound)
            break;
      }
   }

   rr = r;
   ss = s;
   tt = t;

   bak.restore();
}

// Norm of a in Z[X]/(f), computed as resultant(f, a).
// Requires f monic, deg(f) >= 1 and deg(a) < deg(f); the norm of 0 is 0.
void norm(ZZ& x, const ZZX& a, const ZZX& f, long deterministic)
{
   long ok = IsOne(LeadCoeff(f)) && deg(a) < deg(f) && deg(f) > 0;
   if (!ok)
      Error("norm: bad args");

   if (!IsZero(a)) {
      resultant(x, f, a, deterministic);
      return;
   }

   clear(x);
}

// Trace of a in Z[X]/(f): the inner product of a's coefficients with the
// power sums of the roots of f (computed by TraceVec).
// Requires f monic, deg(f) >= 1 and deg(a) < deg(f).
void trace(ZZ& res, const ZZX& a, const ZZX& f)
{
   long ok = IsOne(LeadCoeff(f)) && deg(a) < deg(f) && deg(f) > 0;
   if (!ok)
      Error("trace: bad args");

   vector(ZZ) power_sums;
   TraceVec(power_sums, f);
   InnerProduct(res, power_sums, a.rep);
}


// Discriminant of a with m = deg(a):
//    (-1)^{m(m-1)/2} * resultant(a, a') / lc(a).
// The discriminant of the zero polynomial is taken to be 0.
void discriminant(ZZ& d, const ZZX& a, long deterministic)
{
   const long m = deg(a);
   if (m < 0) {
      clear(d);
      return;
   }

   ZZX da;
   diff(da, a);

   ZZ r;
   resultant(r, a, da, deterministic);

   if (!divide(r, r, LeadCoeff(a)))
      Error("discriminant: inexact division");

   // m(m-1)/2 is odd exactly when m is 2 or 3 (mod 4).
   const long m_mod4 = m & 3;
   if (m_mod4 == 2 || m_mod4 == 3)
      negate(r, r);

   d = r;
}


// x = a*b mod f, for monic f with deg(a) < deg(f) and deg(b) < deg(f).
void MulMod(ZZX& x, const ZZX& a, const ZZX& b, const ZZX& f)
{
   const long n = deg(f);
   if (deg(a) >= n || deg(b) >= n || !IsOne(LeadCoeff(f)))
      Error("MulMod: bad args");

   // Multiply into a temporary, then reduce into x.
   ZZX prod;
   mul(prod, a, b);
   rem(x, prod, f);
}

// x = a^2 mod f, for monic f with deg(a) < deg(f).
// Raises an error if the preconditions fail.
void SqrMod(ZZX& x, const ZZX& a, const ZZX& f)
{
   if (deg(a) >= deg(f) || !IsOne(LeadCoeff(f)))
      Error("SqrMod: bad args");

   // Square into a temporary, then reduce into x.
   ZZX t;
   sqr(t, a);
   rem(x, t, f);
}



// h = X * a mod f, for monic f with deg(a) < deg(f).
// Raises an error if the preconditions fail.
// Coefficients of h are written from the highest index down, and each
// write happens after the lower-index input it depends on has been read,
// so h may alias a.
void MulByXMod(ZZX& h, const ZZX& a, const ZZX& f)
{
   long i, n, m;
   ZZ* hh;
   const ZZ *aa, *ff;

   ZZ t, z;


   n = deg(f);
   m = deg(a);

   if (m >= n || !IsOne(LeadCoeff(f)))
      Error("MulByXMod: bad args");

   if (m < 0) {
      clear(h);
      return;
   }

   if (m < n-1) {
      // deg(X*a) < deg(f): no reduction needed, just shift up by one.
      h.rep.SetLength(m+2);
      hh = h.rep.elts();
      aa = a.rep.elts();
      for (i = m+1; i >= 1; i--)
         hh[i] = aa[i-1];
      clear(hh[0]);
   }
   else {
      // deg(a) == n-1: X*a has degree n.  Since f is monic,
      // X^n == -(f_0 + ... + f_{n-1} X^{n-1}) mod f, hence
      // h_i = a_{i-1} - a_{n-1} f_i and h_0 = -a_{n-1} f_0.
      h.rep.SetLength(n);
      hh = h.rep.elts();
      aa = a.rep.elts();
      ff = f.rep.elts();
      negate(z, aa[n-1]);   // read before any write, for the h == a case
      for (i = n-1; i >= 1; i--) {
         mul(t, z, ff[i]);
         add(hh[i], aa[i-1], t);
      }
      mul(hh[0], z, ff[0]);
      h.normalize();
   }
}

// Variant of EuclLength using the squared norm sum a_i^2 + 2|a_0| + 1,
// i.e. the norm of a's coefficient vector with |a_0| replaced by
// |a_0| + 1.  If that value exceeds 1, l is its integer square root plus
// one; otherwise l is the value itself.
static
void EuclLength1(ZZ& l, const ZZX& a)
{
   ZZ sq_norm, sq;
   clear(sq_norm);

   const long len = a.rep.length();
   for (long k = 0; k < len; k++) {
      sqr(sq, a.rep[k]);
      add(sq_norm, sq_norm, sq);
   }

   // Add 2|a_0| + 1.
   ZZ extra;
   abs(extra, ConstTerm(a));
   mul(extra, extra, 2);
   add(extra, extra, 1);
   add(sq_norm, sq_norm, extra);

   if (sq_norm > 1) {
      SqrRoot(l, sq_norm);
      add(l, l, 1);
      return;
   }

   l = sq_norm;
}


long CharPolyBound(const ZZX& a, const ZZX& f)
// This computes a bound on the size of the
// coefficients of the characterstic polynomial.
// It use the relation characterization of the char poly as
// resultant_y(f(y), x-a(y)), and then interpolates this
// through complex primimitive (deg(f)+1)-roots of unity.

{
   if (IsZero(a) || IsZero(f))
      Error("CharPolyBound: bad args");

   ZZ t1, t2, t;
   EuclLength1(t1, a);
   EuclLength(t2, f);
   power(t1, t1, deg(f));
   power(t2, t2, deg(a));
   mul(t, t1, t2);
   return NumBits(t);
}




// vectors

// Instantiate vector(ZZX) through the library's vector_*_impl macros.
// Going by their names, these presumably provide the basic container
// implementation, equality comparison, and stream I/O respectively.
vector_impl(ZZX)

vector_eq_impl(ZZX)

vector_io_impl(ZZX)

