
#include "mat_ZZ.h"

// Instantiate the generic matrix(ZZ) machinery from macros declared in the
// included header(s).  NOTE(review): presumably these expand to the
// container implementation, stream I/O operators, and ==/!= respectively —
// confirm against the matrix macro header.
matrix_impl(ZZ)
matrix_io_impl(ZZ)
matrix_eq_impl(ZZ)



void add(matrix(ZZ)& X, const matrix(ZZ)& A, const matrix(ZZ)& B)
{
   // Entrywise sum X = A + B.  A and B must have identical shapes.
   const long rows = A.NumRows();
   const long cols = A.NumCols();

   const long sameShape = (B.NumRows() == rows) && (B.NumCols() == cols);
   if (!sameShape)
      Error("matrix add: dimension mismatch");

   X.SetDims(rows, cols);

   long r, c;
   for (r = 0; r < rows; r++)
      for (c = 0; c < cols; c++)
         add(X[r][c], A[r][c], B[r][c]);
}
  
void sub(matrix(ZZ)& X, const matrix(ZZ)& A, const matrix(ZZ)& B)
{
   // Entrywise difference X = A - B.  A and B must have identical shapes.
   const long rows = A.NumRows();
   const long cols = A.NumCols();

   const long shapeOk = (B.NumRows() == rows) && (B.NumCols() == cols);
   if (!shapeOk)
      Error("matrix sub: dimension mismatch");

   X.SetDims(rows, cols);

   long r, c;
   for (r = 0; r < rows; r++) {
      for (c = 0; c < cols; c++) {
         sub(X[r][c], A[r][c], B[r][c]);
      }
   }
}
  
void mul_aux(matrix(ZZ)& X, const matrix(ZZ)& A, const matrix(ZZ)& B)
{
   // Schoolbook product X = A * B.  The output is resized before any input
   // entry is read, so X should not be the same object as A or B; mul()
   // below routes aliased calls through a temporary.
   const long rows  = A.NumRows();
   const long inner = A.NumCols();
   const long cols  = B.NumCols();

   if (B.NumRows() != inner)
      Error("matrix mul: dimension mismatch");

   X.SetDims(rows, cols);

   ZZ sum, prod;
   long r, c, t;

   for (r = 0; r < rows; r++) {
      for (c = 0; c < cols; c++) {
         // Dot product of row r of A with column c of B.
         clear(sum);
         for (t = 0; t < inner; t++) {
            mul(prod, A[r][t], B[t][c]);
            add(sum, sum, prod);
         }
         X[r][c] = sum;
      }
   }
}
  
  
void mul(matrix(ZZ)& X, const matrix(ZZ)& A, const matrix(ZZ)& B)
{
   // X = A * B.  When X is the same object as an operand, the product is
   // built in a scratch matrix first so the inputs stay intact while read.
   const long aliased = (&X == &A) || (&X == &B);

   if (!aliased) {
      mul_aux(X, A, B);
      return;
   }

   matrix(ZZ) scratch;
   mul_aux(scratch, A, B);
   X = scratch;
}
  
  
static
void mul_aux(vector(ZZ)& x, const matrix(ZZ)& A, const vector(ZZ)& b)
{
   // Matrix times column vector: x = A * b.
   const long rows = A.NumRows();
   const long cols = A.NumCols();

   if (b.length() != cols)
      Error("matrix mul: dimension mismatch");

   x.SetLength(rows);

   ZZ sum, prod;
   long r, c;

   for (r = 0; r < rows; r++) {
      clear(sum);
      for (c = 0; c < cols; c++) {
         mul(prod, A[r][c], b[c]);
         add(sum, sum, prod);
      }
      x[r] = sum;
   }
}
  
  
void mul(vector(ZZ)& x, const matrix(ZZ)& A, const vector(ZZ)& b)
{
   // Compute A * b into a scratch vector and copy it out; this keeps the
   // result correct even when x and b are the same object.
   vector(ZZ) product;
   mul_aux(product, A, b);
   x = product;
}

static
void mul_aux(vector(ZZ)& x, const vector(ZZ)& a, const matrix(ZZ)& B)
{
   // Row vector times matrix: x = a * B.
   const long rows = B.NumRows();
   const long cols = B.NumCols();

   if (a.length() != rows)
      Error("matrix mul: dimension mismatch");

   x.SetLength(cols);

   ZZ sum, prod;
   long c, r;

   for (c = 0; c < cols; c++) {
      // Dot product of a with column c of B.
      clear(sum);
      for (r = 0; r < rows; r++) {
         mul(prod, a[r], B[r][c]);
         add(sum, sum, prod);
      }
      x[c] = sum;
   }
}

void mul(vector(ZZ)& x, const vector(ZZ)& a, const matrix(ZZ)& B)
{
   // Compute a * B into a scratch vector, then copy it out; this keeps the
   // result correct even when x and a are the same object.
   vector(ZZ) product;
   mul_aux(product, a, B);
   x = product;
}

     
  
void ident(matrix(ZZ)& X, long n)
{
   // Overwrite X with the n x n identity matrix.
   X.SetDims(n, n);

   long r, c;
   for (r = 0; r < n; r++) {
      for (c = 0; c < n; c++) {
         if (r != c)
            clear(X[r][c]);
         else
            set(X[r][c]);
      }
   }
}

static
long DetBound(const matrix(ZZ)& a)
{
   long n = a.NumRows();
   long i;
   ZZ res, t1;

   set(res);

   for (i = 0; i < n; i++) {
      InnerProduct(t1, a[i], a[i]);
      if (t1 > 1) {
         SqrRoot(t1, t1);
         add(t1, t1, 1);
      }
      mul(res, res, t1);
   }

   return NumBits(res);
}



   

// determinant: sets rres = det(a) for a square integer matrix a.
//
// Strategy (multi-modular): det(a) is computed modulo the single-precision
// primes selected by zz_pFFTInit(0), zz_pFFTInit(1), ..., and the residues
// are combined into res (modulo prod = product of the primes used) by CRT.
// The loop stops once prod has more bits than bound, which is derived from
// DetBound(a), an upper bound on |det(a)|.
//
// deterministic: if 0 and bound > 1000 bits, the loop may stop early.  When
// the last prime left res stable and prod is still below a quarter of the
// bound, res is checked against det(a) modulo one random multi-precision
// prime P.  If CRT with P does not change res, res is accepted.  This early
// exit is a probabilistic test (NOTE(review): confirm the failure
// probability implied by RandomPrime's arguments).  If deterministic is
// nonzero, the full bound is always used.
//
// Errors (via Error) if a is not square.  A 0 x 0 matrix has determinant 1.
// Both the zz_p and ZZ_p moduli are re-initialised here, so the caller's
// current moduli are saved on entry and restored before return.
void determinant(ZZ& rres, const matrix(ZZ)& a, long deterministic)
{
   long n = a.NumRows();
   if (a.NumCols() != n)
      Error("determinant: nonsquare matrix");

   if (n == 0) {
      set(rres);
      return;
   }

   // Save caller's moduli; restored at the end of the function.
   zz_pBak zbak;
   zbak.save();

   ZZ_pBak Zbak;
   Zbak.save();

   // Holds CRT's return value for the last prime (nonzero presumably means
   // res changed).  The probabilistic early exit is only attempted when the
   // previous step was stable.
   long instable = 1;

   // Two extra bits of slack over the Hadamard bound; presumably so that the
   // signed (symmetric-range) CRT reconstruction is unambiguous.
   long bound = 2+DetBound(a);

   // res: det(a) mod prod;  prod: product of all primes used so far.
   ZZ res, prod;

   clear(res);
   set(prod);


   long i;
   for (i = 0; ; i++) {
      // Enough primes to determine det(a) exactly.
      if (NumBits(prod) > bound)
         break;

      // Optional probabilistic shortcut for large bounds (see header).
      if (!deterministic &&
          !instable && bound > 1000 && NumBits(prod) < 0.25*bound) {
         ZZ P;


         // Random prime of about 90 + log2(max(bound, bits of res)) bits.
         long plen = 90 + NumBits(max(bound, NumBits(res)));
         RandomPrime(P, plen, 40);

         ZZ_pInit(P);

         matrix(ZZ_p) A;
         A << a;

         ZZ_p t;
         determinant(t, A);

         // If the extra prime still changes res, keep going; otherwise
         // accept the current res.
         if (CRT(res, prod, rep(t), P))
            instable = 1;
         else
            break;
      }


      // i-th FFT prime as the current zz_p modulus.
      zz_pFFTInit(i);
      long p = zz_p::modulus();

      matrix(zz_p) A;
      A << a;

      zz_p t;
      determinant(t, A);

      // Fold det(a) mod p into res; prod is multiplied by p.
      instable = CRT(res, prod, rep(t), p);
   }

   rres = res;

   zbak.restore();
   Zbak.restore();
}




void operator<<(matrix(zz_p)& x, const matrix(ZZ)& a)
{
   // Reduce every entry of a modulo the current zz_p modulus, row by row,
   // giving x the same shape as a.
   const long rows = a.NumRows();
   const long cols = a.NumCols();

   x.SetDims(rows, cols);

   long r = 0;
   while (r < rows) {
      x[r] << a[r];
      r++;
   }
}

void operator<<(matrix(ZZ_p)& x, const matrix(ZZ)& a)
{
   // Reduce every entry of a modulo the current ZZ_p modulus, row by row,
   // giving x the same shape as a.
   const long rows = a.NumRows();
   const long cols = a.NumCols();

   x.SetDims(rows, cols);

   long r = 0;
   while (r < rows) {
      x[r] << a[r];
      r++;
   }
}

long IsIdent(const matrix(ZZ)& A, long n)
{
   // Returns 1 iff A is exactly the n x n identity matrix, otherwise 0.
   if (A.NumRows() != n || A.NumCols() != n)
      return 0;

   long r, c;
   for (r = 0; r < n; r++) {
      for (c = 0; c < n; c++) {
         const long entryOk = (r == c) ? IsOne(A[r][c]) : IsZero(A[r][c]);
         if (!entryOk)
            return 0;
      }
   }

   return 1;
}


void transpose(matrix(ZZ)& X, const matrix(ZZ)& A)
{
   // X = transpose(A).  Supports X and A being the same object.
   const long rows = A.NumRows();
   const long cols = A.NumCols();
   long r, c;

   if (&X != &A) {
      // Distinct objects: write the transpose directly.
      X.SetDims(cols, rows);
      for (r = 0; r < rows; r++)
         for (c = 0; c < cols; c++)
            X[c][r] = A[r][c];
      return;
   }

   if (rows == cols) {
      // In place, square: swap entries across the main diagonal.
      for (r = 0; r < rows; r++)
         for (c = r+1; c < rows; c++)
            swap(X[r][c], X[c][r]);
      return;
   }

   // In place, non-square: the shape changes, so build the result
   // separately and then replace X.
   matrix(ZZ) scratch;
   scratch.SetDims(cols, rows);
   for (r = 0; r < rows; r++)
      for (c = 0; c < cols; c++)
         scratch[c][r] = A[r][c];
   X.kill();
   X = scratch;
}


// Incremental Chinese remaindering for an integer matrix.
//
// On entry g holds residues modulo a, and G holds residues modulo the
// current zz_p modulus p.  On exit each g[i][j] is the integer congruent to
// the old g[i][j] mod a and to G[i][j] mod p, reduced into a symmetric range
// around 0 of width a*p, and a is replaced by a*p.
//
// Returns 1 if any entry of g changed, 0 otherwise.  Errors (via Error) if
// g and G have different dimensions.
// NOTE(review): requires a to be invertible mod p (gcd(a, p) = 1), since
// a_inv is computed with inv() below.
long CRT(matrix(ZZ)& g, ZZ& a, const matrix(zz_p)& G)
{
   long n = g.NumRows();
   long m = g.NumCols();

   if (G.NumRows() != n || G.NumCols() != m) 
      Error("CRT: dimension mismatch");

   long p = zz_p::modulus();
   zz_p a_inv;

   // a_inv = a^{-1} mod p
   a_inv << a;
   inv(a_inv, a_inv);
 
   ZZ aa, new_a, new_a1;
   ZZ v_a, v_p, t;

   // aa = a * (a^{-1} mod p):  aa == 0 (mod a),  aa == 1 (mod p).
   mul(aa, a, rep(a_inv));
   // new_a = a*p is the combined modulus; new_a1 = floor(new_a/2) is the
   // upper end of the symmetric range.
   mul(new_a, a, p);

   RightShift(new_a1, new_a, 1);

   long modified = 0;
   long i, j;

   for (i = 0; i < n; i++) 
      for (j = 0; j < m; j++) {
         // k = -g * a^{-1} (mod p), so v_a = g + a*k satisfies
         // v_a == g (mod a) and v_a == 0 (mod p).
         zz_p k;
         k << g[i][j];
         mul(k, k, a_inv);
         negate(k, k);
         mul(v_a, a, rep(k));
         add(v_a, v_a, g[i][j]);
   
         // v_p == 0 (mod a), v_p == G (mod p); t = v_a + v_p meets both
         // congruences.
         mul(v_p, aa, rep(G[i][j]));
         add(t, v_a, v_p);
   
         // Reduce mod new_a, then shift into the symmetric range
         // (assumes rem yields a non-negative remainder for positive new_a).
         rem(t, t, new_a);
         if (t > new_a1)
            sub(t, t, new_a);
   
         if (t != g[i][j]) {
            g[i][j] = t;
            modified = 1;
         }
      }

   a = new_a;

   return modified;
}

void mul(matrix(ZZ)& X, const matrix(ZZ)& A, const ZZ& b_in)
{
   // Scalar multiple X = A * b_in.  The scalar is copied up front because
   // b_in could be an entry of X, which the loop overwrites.
   ZZ b = b_in;
   const long rows = A.NumRows();
   const long cols = A.NumCols();

   X.SetDims(rows, cols);

   long r, c;
   for (r = 1; r <= rows; r++)
      for (c = 1; c <= cols; c++)
         mul(X(r, c), A(r, c), b);
}


static
void ExactDiv(vector(ZZ)& x, const ZZ& d)
{
   // Divide each entry of x by d in place.  Reports Error("inexact
   // division") at the first entry that is not a multiple of d.
   const long len = x.length();
   long k;

   for (k = 1; k <= len; k++) {
      if (!divide(x(k), x(k), d))
         Error("inexact division");
   }
}

static
void ExactDiv(matrix(ZZ)& x, const ZZ& d)
{
   // Divide each entry of x by d in place (row-major order).  Reports
   // Error("inexact division") at the first entry not divisible by d.
   const long rows = x.NumRows();
   const long cols = x.NumCols();
   long r, c;

   for (r = 1; r <= rows; r++) {
      for (c = 1; c <= cols; c++) {
         if (!divide(x(r, c), x(r, c), d))
            Error("inexact division");
      }
   }
}

void diag(matrix(ZZ)& X, long n, const ZZ& d_in)
{
   // X = d_in * I_n.  The diagonal value is copied first because d_in could
   // be an entry of X, which the loop overwrites.
   ZZ d = d_in;
   X.SetDims(n, n);

   long r, c;
   for (r = 0; r < n; r++) {
      for (c = 0; c < n; c++) {
         if (r != c)
            clear(X[r][c]);
         else
            X[r][c] = d;
      }
   }
}

long IsDiag(const matrix(ZZ)& A, long n, const ZZ& d)
{
   // Returns 1 iff A is the n x n matrix d * I, otherwise 0.
   if (A.NumRows() != n || A.NumCols() != n)
      return 0;

   long r, c;
   for (r = 0; r < n; r++) {
      for (c = 0; c < n; c++) {
         const long entryOk = (r == c) ? (A[r][c] == d) : IsZero(A[r][c]);
         if (!entryOk)
            return 0;
      }
   }

   return 1;
}




// solve: for a square integer matrix A and vector b, sets d_out = det(A)
// and, when possible, x_out to a vector x with x*A = d_out*b (x is a row
// vector; the verification below computes y = x*A).
//
// Uses the same multi-modular strategy as determinant():
//   Phase 1 (check == 0): for each FFT prime p, solve xx*A = b mod p, scale
//     xx by det(A) mod p, and lift both d and the scaled xx with CRT.  Once
//     neither lift changes, the candidate x is verified over ZZ by checking
//     x*A == d*b.
//   Phase 2 (check == 1): x is verified; only d keeps being refined until
//     the product of primes exceeds the DetBound-derived bound.
//
// deterministic: if 0 and the bound exceeds 1000 bits, d may be accepted
// early after a consistency check modulo one random multi-precision prime
// (probabilistic).  If nonzero, the full bound is always used.
//
// x_out is assigned only if a candidate was verified.  When d_out != 0 the
// loop cannot exit without one, so x_out is always set in that case; when A
// is singular, x_out may be left unchanged.
//
// Errors (via Error) if A is not square or b has the wrong length.  For a
// 0 x 0 system, d_out = 1 and x_out is empty.  The zz_p and ZZ_p moduli are
// saved on entry and restored before return.
void solve(ZZ& d_out, vector(ZZ)& x_out,
           const matrix(ZZ)& A, const vector(ZZ)& b,
           long deterministic)
{
   long n = A.NumRows();
   
   if (A.NumCols() != n)
      Error("solve: nonsquare matrix");

   if (b.length() != n)
      Error("solve: dimension mismatch");

   if (n == 0) {
      set(d_out);
      x_out.SetLength(0);
      return;
   }

   zz_pBak zbak;
   zbak.save();

   ZZ_pBak Zbak;
   Zbak.save();

   // x: CRT lift of d*solution;  d: CRT lift of det(A);
   // d1: value of d at the moment x was verified.
   vector(ZZ) x(INIT_SIZE, n);
   ZZ d, d1;

   // Products of the primes used for the d and x lifts respectively
   // (they differ: primes where det vanishes are skipped for x).
   ZZ d_prod, x_prod;
   set(d_prod);
   set(x_prod);

   // Nonzero when the last CRT step (presumably) changed the lift.
   long d_instable = 1;
   long x_instable = 1;

   // Becomes 1 once a candidate x passed the exact check x*A == d*b.
   long check = 0;


   vector(ZZ) y, b1;

   long i;
   long bound = 2+DetBound(A);

   for (i = 0; ; i++) {
      // Termination checks apply only once x is verified (or d is currently
      // 0, so no x is needed) and the last prime left d stable.
      if ((check || IsZero(d)) && !d_instable) {
         if (NumBits(d_prod) > bound) {
            break;
         }
         else if (!deterministic &&
                  bound > 1000 && NumBits(d_prod) < 0.25*bound) {

            // Probabilistic shortcut: confirm d modulo one random prime.
            ZZ P;
   
            long plen = 90 + NumBits(max(bound, NumBits(d)));
            RandomPrime(P, plen, 40);
   
            ZZ_pInit(P);
   
            matrix(ZZ_p) AA;
            AA << A;
   
            ZZ_p dd;
            determinant(dd, AA);
   
            if (CRT(d, d_prod, rep(dd), P))
               d_instable = 1;
            else 
               break;
         }
      }


      zz_pFFTInit(i);
      long p = zz_p::modulus();

      matrix(zz_p) AA;
      AA << A;

      if (!check) {
         // Phase 1: solve mod p and lift both det and d*x.
         vector(zz_p) bb, xx;
         bb << b;

         zz_p dd; 

         solve(dd, xx, AA, bb);

         d_instable = CRT(d, d_prod, rep(dd), p);
         if (!IsZero(dd)) {
            // dd*xx is the image mod p of the integral vector det(A)*x.
            mul(xx, xx, dd);
            x_instable = CRT(x, x_prod, xx);
         }
         else
            // det(A) == 0 mod p: this prime gives no information about x.
            x_instable = 1;

         if (!d_instable && !x_instable) {
            // Both lifts stable: verify the candidate exactly over ZZ.
            mul(y, x, A);
            mul(b1, b, d);
            if (y == b1) {
               d1 = d;
               check = 1;
            }
         }
      }
      else {
         // Phase 2: x verified; only refine d.
         zz_p dd;
         determinant(dd, AA);
         d_instable = CRT(d, d_prod, rep(dd), p);
      }
   }

   // If d changed after verification, rescale: x*A = d1*b implies
   // (x*d/d1)*A = d*b.  The division is expected to be exact once d = det(A).
   if (check && d1 != d) {
      mul(x, x, d);
      ExactDiv(x, d1);
   }

   d_out = d;
   if (check) x_out = x;

   zbak.restore();
   Zbak.restore();
}

// inv: for a square integer matrix A, sets d_out = det(A) and, when
// possible, x_out to a matrix x with x*A = d_out*I (i.e. d_out * A^{-1},
// the adjugate of A, when A is nonsingular).
//
// Uses the same multi-modular strategy as determinant()/solve():
//   Phase 1 (check == 0): for each FFT prime p, invert A mod p, scale the
//     inverse by det(A) mod p, and lift both d and the scaled inverse with
//     CRT.  Once neither lift changes, the candidate x is verified over ZZ
//     by checking x*A == d*I (IsDiag).
//   Phase 2 (check == 1): x is verified; only d keeps being refined until
//     the product of primes exceeds the DetBound-derived bound.
//
// deterministic: if 0 and the bound exceeds 1000 bits, d may be accepted
// early after a consistency check modulo one random multi-precision prime
// (probabilistic).  If nonzero, the full bound is always used.
//
// x_out is assigned only if a candidate was verified.  When d_out != 0 the
// loop cannot exit without one, so x_out is always set in that case; when A
// is singular, x_out may be left unchanged.
//
// Errors (via Error) if A is not square.  A 0 x 0 matrix gives d_out = 1 and
// an empty x_out.  The zz_p and ZZ_p moduli are saved on entry and restored
// before return.
void inv(ZZ& d_out, matrix(ZZ)& x_out, const matrix(ZZ)& A, long deterministic)
{
   long n = A.NumRows();
   
   // The message names this routine (it previously said "solve:").
   if (A.NumCols() != n)
      Error("inv: nonsquare matrix");

   if (n == 0) {
      set(d_out);
      x_out.SetDims(0, 0);
      return;
   }

   zz_pBak zbak;
   zbak.save();

   ZZ_pBak Zbak;
   Zbak.save();

   // x: CRT lift of d*A^{-1};  d: CRT lift of det(A);
   // d1: value of d at the moment x was verified.
   matrix(ZZ) x(INIT_SIZE, n, n);
   ZZ d, d1;

   // Products of the primes used for the d and x lifts respectively
   // (they differ: primes where det vanishes are skipped for x).
   ZZ d_prod, x_prod;
   set(d_prod);
   set(x_prod);

   // Nonzero when the last CRT step (presumably) changed the lift.
   long d_instable = 1;
   long x_instable = 1;

   // Becomes 1 once a candidate x passed the exact check x*A == d*I.
   long check = 0;


   matrix(ZZ) y;

   long i;
   long bound = 2+DetBound(A);

   for (i = 0; ; i++) {
      // Termination checks apply only once x is verified (or d is currently
      // 0, so no x is needed) and the last prime left d stable.
      if ((check || IsZero(d)) && !d_instable) {
         if (NumBits(d_prod) > bound) {
            break;
         }
         else if (!deterministic &&
                  bound > 1000 && NumBits(d_prod) < 0.25*bound) {

            // Probabilistic shortcut: confirm d modulo one random prime.
            ZZ P;
   
            long plen = 90 + NumBits(max(bound, NumBits(d)));
            RandomPrime(P, plen, 40);
   
            ZZ_pInit(P);
   
            matrix(ZZ_p) AA;
            AA << A;
   
            ZZ_p dd;
            determinant(dd, AA);
   
            if (CRT(d, d_prod, rep(dd), P))
               d_instable = 1;
            else 
               break;
         }
      }


      zz_pFFTInit(i);
      long p = zz_p::modulus();

      matrix(zz_p) AA;
      AA << A;

      if (!check) {
         // Phase 1: invert mod p and lift both det and d*A^{-1}.
         matrix(zz_p) xx;

         zz_p dd; 

         inv(dd, xx, AA);

         d_instable = CRT(d, d_prod, rep(dd), p);
         if (!IsZero(dd)) {
            // dd*xx is the image mod p of the integral matrix det(A)*A^{-1}.
            mul(xx, xx, dd);
            x_instable = CRT(x, x_prod, xx);
         }
         else
            // det(A) == 0 mod p: this prime gives no information about x.
            x_instable = 1;

         if (!d_instable && !x_instable) {
            // Both lifts stable: verify the candidate exactly over ZZ.
            mul(y, x, A);
            if (IsDiag(y, n, d)) {
               d1 = d;
               check = 1;
            }
         }
      }
      else {
         // Phase 2: x verified; only refine d.
         zz_p dd;
         determinant(dd, AA);
         d_instable = CRT(d, d_prod, rep(dd), p);
      }
   }

   // If d changed after verification, rescale: x*A = d1*I implies
   // (x*d/d1)*A = d*I.  The division is expected to be exact once d = det(A).
   if (check && d1 != d) {
      mul(x, x, d);
      ExactDiv(x, d1);
   }

   d_out = d;
   if (check) x_out = x;

   zbak.restore();
   Zbak.restore();
}
