
#include "RR.h"
#include "IsFinite.h"
#include <math.h>
#include <float.h>


long RR::prec = 90;

// Set the bit precision used by RR arithmetic.  Requests below 53 bits
// are raised to 53; requests of 2^(ZZ_BITS_PER_LONG-3) bits or more are
// rejected through Error().
void RR::SetPrecision(long p)
{
   const long limit = 1L << (ZZ_BITS_PER_LONG-3);
   const long bits = (p < 53) ? 53 : p;

   if (bits >= limit)
      Error("RR: precision too high");

   prec = bits;
}

long RR::oprec = 10;

// Set the output precision oprec; any value below 1 is clamped to 1.
void RR::SetOutputPrecision(long p)
{
   oprec = (p < 1) ? 1 : p;
}



// Round y to at most RR::prec significant bits and store the result in z.
//
// The value of y is y.x * 2^(y.e).  When y.x has more than RR::prec bits
// its magnitude is rounded to nearest; an exact tie is broken first by
// `residual` and otherwise by round-to-even.  A nonzero `residual` says
// the true value differs from y by a quantity too small to appear in y;
// only its sign matters (it is compared with the sign of y.x below).
// residual == 0 means y is exact.
//
// Afterwards z.x is odd (or z is zero with z.e == 0), so every nonzero
// value has a single (x, e) representation -- operator== relies on this.
// Calls Error() if z.e reaches 2^(ZZ_BITS_PER_LONG-3) in magnitude.
//
// NOTE(review): callers (e.g. operator= on self-assignment, negate(v, v))
// may pass z and y aliased; the raw digit pointer `a` is only read before
// RightShift writes z.x.
void normalize(RR& z, const RR& y, long residual = 0)
{
   long len = NumBits(y.x);

   if (len > RR::prec) {
      // Raw digit words of y.x.  NOTE(review): assumes the old NTL layout --
      // a[0] carries the sign (and length) and digit words start at a[1],
      // least significant first, as the sign test and 1-based indexing
      // below imply.
      const long *a = y.x.rep;

      long sgn;

      // direction: 1 = round magnitude up, -1 = truncate,
      //            0 = exact tie (still undecided)
      long direction;
      long p, wh, bl;

      if (a[0] > 0)
         sgn = 1;
      else
         sgn = -1;

      // p = bit index (0 = least significant) of the highest discarded
      // bit, i.e. the rounding bit; bits above p are kept.
      p = len - RR::prec - 1;
      // bl = 1-based word holding bit p, wh = mask for bit p in that word
      bl = (p/ZZ_NBITS);
      wh = 1L << (p - ZZ_NBITS*bl);
      bl++;

      if (a[bl] & wh) {
         // bit is 1...we have to see if lower bits are all 0
         // in order to implement "round to even"

         if (a[bl] & (wh - 1)) 
            direction = 1;
         else {
            long i = bl - 1;
            while (i > 0 && a[i] == 0) i--;
            if (i > 0)
               direction = 1;
            else
               direction = 0;
         }

         // use residual to break ties: a residual with the same sign as y
         // makes the magnitude strictly larger than the halfway point

         if (direction == 0 && residual != 0) {
            if (residual == sgn)
               direction = 1;
            else 
               direction = -1;
         }

         if (direction == 0) {
            // round to even

            // move the mask to the lowest kept bit (p+1), which may be
            // in the next word
            wh = wh << 1;
            if (wh == ZZ_RADIX) {
               wh = 1;
               bl++;
            }

            if (a[bl] & wh)
               direction = 1;
            else
               direction = -1;
         }
      }
      else
         direction = -1;

      // Keep the top RR::prec bits.  NOTE(review): presumably RightShift
      // truncates the magnitude with the sign preserved, so adding sgn
      // rounds the magnitude up.
      RightShift(z.x, y.x, len - RR::prec);
      if (direction == 1) 
         add(z.x, z.x, sgn);

      z.e = y.e + len - RR::prec;
   }
   else if (len == 0) {
      clear(z.x);
      z.e = 0;
   }
   else {
      z.x = y.x;
      z.e = y.e;
   }

   // move trailing zero bits of the mantissa into the exponent
   if (!IsOdd(z.x))
      z.e += MakeOdd(z.x);

   if (z.e >= (1L << (ZZ_BITS_PER_LONG-3)))
      Error("RR: overflow");

   if (z.e <= -(1L << (ZZ_BITS_PER_LONG-3)))
      Error("RR: underflow");
}


// Assignment rounds a to the current precision RR::prec (see normalize).
void RR::operator=(const RR& a)
{
   const long no_residual = 0;
   normalize(*this, a, no_residual);
}

// Copy construction also rounds a to the current precision RR::prec.
RR::RR(const RR& a)
{
   normalize(*this, a, 0);
}

long IsZero(const RR& a)
{
   return IsZero(a.x);
}

// Returns 1 iff a == 1, i.e. exponent 0 and mantissa 1; otherwise 0.
long IsOne(const RR& a)
{
   if (a.e != 0)
      return 0;

   return IsOne(a.x) != 0;
}

long sign(const RR& a)
{
   return sign(a.x);
}

// z = 0 (mantissa 0, exponent 0).
void clear(RR& z)
{
   clear(z.x);
   z.e = 0;
}

// z = 1 (mantissa 1, exponent 0).
void set(RR& z)
{
   set(z.x);
   z.e = 0;
}


// z = a + b, rounded to RR::prec bits.
// When the operand with the smaller exponent (lo) is far enough below
// the retained precision of the other one (hi), it only influences the
// result through its sign, which is handed to normalize() as the
// residual.  Otherwise the exact sum is formed and then rounded.
void add(RR& z, const RR& a, const RR& b)
{
   static RR t;

   if (IsZero(a.x)) {
      z = b;
      return;
   }
   if (IsZero(b.x)) {
      z = a;
      return;
   }

   if (a.e == b.e) {
      add(t.x, a.x, b.x);
      t.e = a.e;
      normalize(z, t);
      return;
   }

   const RR& hi = (a.e > b.e) ? a : b;
   const RR& lo = (a.e > b.e) ? b : a;
   const long gap = hi.e - lo.e;

   if (gap - max(RR::prec-NumBits(hi.x),0) >= NumBits(lo.x) + 2) {
      normalize(z, hi, sign(lo));
   }
   else {
      // align hi to lo's exponent and add exactly
      LeftShift(t.x, hi.x, gap);
      add(t.x, t.x, lo.x);
      t.e = lo.e;
      normalize(z, t);
   }
}

// z = a - b, rounded to RR::prec bits.
// Computes hi - lo, where hi is the operand with the larger exponent, and
// negates at the end when hi is b (since a - b = -(b - a)).  A lo that is
// far below hi's retained precision only contributes its (negated) sign
// as the residual passed to normalize().
void sub(RR& z, const RR& a, const RR& b)
{
   static RR t;

   if (IsZero(a.x)) {
      negate(z, b);
      return;
   }
   if (IsZero(b.x)) {
      z = a;
      return;
   }

   if (a.e == b.e) {
      sub(t.x, a.x, b.x);
      t.e = a.e;
      normalize(z, t);
      return;
   }

   const long flip = (a.e < b.e);
   const RR& hi = flip ? b : a;
   const RR& lo = flip ? a : b;
   const long gap = hi.e - lo.e;

   if (gap - max(RR::prec-NumBits(hi.x),0) >= NumBits(lo.x) + 2) {
      normalize(z, hi, -sign(lo));
      if (flip)
         negate(z.x, z.x);
   }
   else {
      // exact difference at lo's exponent, then round
      LeftShift(t.x, hi.x, gap);
      sub(t.x, t.x, lo.x);
      if (flip)
         negate(t.x, t.x);
      t.e = lo.e;
      z = t;
   }
}

// z = -a (a is first rounded to RR::prec by the assignment).
void negate(RR& z, const RR& a)
{
   z = a;
   ZZ& mant = z.x;
   negate(mant, mant);
}

// z = |a| (a is first rounded to RR::prec by the assignment).
void abs(RR& z, const RR& a)
{
   z = a;
   ZZ& mant = z.x;
   abs(mant, mant);
}


// z = a * b: exact product of mantissas, exponents added, then rounded
// to RR::prec bits by the assignment.
void mul(RR& z, const RR& a, const RR& b)
{
   static RR prod;

   prod.e = a.e + b.e;
   mul(prod.x, a.x, b.x);
   z = prod;
}

// z = a^2: exact square of the mantissa, exponent doubled, then rounded
// to RR::prec bits by the assignment.
void sqr(RR& z, const RR& a)
{
   static RR sq;

   sq.e = 2*a.e;
   sqr(sq.x, a.x);
   z = sq;
}

// z = a / b, rounded to RR::prec bits.  Calls Error() when b is zero.
void div(RR& z, const RR& a, const RR& b)
{
   static RR q;
   static ZZ num, den, rem;

   if (IsZero(b))
      Error("RR: division by zero");

   if (IsZero(a)) {
      clear(z);
      return;
   }

   // Pre-scale |a.x| by 2^shift so the integer quotient has at least
   // RR::prec+1 bits, leaving a rounding bit for normalize().
   long shift = RR::prec - NumBits(a.x) + NumBits(b.x) + 1;
   if (shift < 0)
      shift = 0;

   const long negative = (sign(a) != sign(b));

   abs(num, a.x);
   LeftShift(num, num, shift);
   abs(den, b.x);
   DivRem(q.x, rem, num, den);
   q.e = a.e - b.e - shift;

   // a nonzero remainder means the true quotient lies above q.x
   normalize(z, q, !IsZero(rem));

   if (negative)
      negate(z.x, z.x);
}

// z = sqrt(a), rounded to RR::prec bits.
// Calls Error() if a is negative; sqrt(0) is 0.
void SqrRoot(RR& z, const RR& a)
{
   if (sign(a) < 0)
      Error("RR: attempt to take square root of negative number");

   if (IsZero(a)) {
      clear(z);
      return;
   }

   RR t;
   ZZ T1, T2;
   long k;

   // Scale a.x by 2^k so that its integer square root has at least
   // RR::prec+1 bits (leaving a rounding bit), and make a.e - k even so
   // that the exponent halves exactly.
   k = 2*RR::prec - NumBits(a.x) + 1;

   if (k < 0) k = 0;

   if ((a.e - k) & 1) k++;

   LeftShift(T1, a.x, k);
   SqrRoot(t.x, T1);          // t.x = floor(sqrt(T1))
   t.e = (a.e - k)/2;

   // The root is inexact iff t.x^2 < T1.  In that case the true root lies
   // strictly above t.x, and normalize() needs that residual to round
   // ties correctly.  The square must be of the root t.x, not of T1.
   sqr(T2, t.x);

   normalize(z, t, T2 < T1);
}
   




// Exchange the contents of a and b (no rounding takes place).
void swap(RR& a, RR& b)
{
   swap(a.e, b.e);
   swap(a.x, b.x);
}

long compare(const RR& a, const RR& b)
{
   static RR t;

   sub(t, a, b);
   return sign(t);
}



// Representation equality; valid because normalized values have an odd
// mantissa, making (x, e) unique for each nonzero value.
long operator==(const RR& a, const RR& b) 
{
   if (a.e != b.e)
      return 0;

   return (a.x == b.x) != 0;
}


// z = integer part of a.
// NOTE(review): relies on ZZ RightShift rounding toward zero to give
// truncation semantics for negative a.
void trunc(RR& z, const RR& a)
{
   static RR t;

   if (a.e < 0) {
      // shift out the fractional bits
      RightShift(t.x, a.x, -a.e);
      t.e = 0;
      z = t;
   }
   else
      z = a;   // x * 2^e with e >= 0 is already an integer
}

// z = largest integer <= a.
void floor(RR& z, const RR& a)
{
   static RR t;

   if (a.e >= 0) {
      z = a;   // x * 2^e with e >= 0 is already an integer
      return;
   }

   RightShift(t.x, a.x, -a.e);

   // A normalized a.x is odd, so with a.e < 0 the value is never an
   // integer here and a negative input must step down one more.
   if (sign(a.x) < 0)
      add(t.x, t.x, -1);

   t.e = 0;
   z = t;
}

// z = smallest integer >= a.
void ceil(RR& z, const RR& a)
{
   static RR t;

   if (a.e >= 0) {
      z = a;   // x * 2^e with e >= 0 is already an integer
      return;
   }

   RightShift(t.x, a.x, -a.e);

   // A normalized a.x is odd, so with a.e < 0 the value is never an
   // integer here and a positive input must step up one more.
   if (sign(a.x) > 0)
      add(t.x, t.x, 1);

   t.e = 0;
   z = t;
}
   

// z = a, rounded to RR::prec bits by the assignment.
void operator<<(RR& z, const ZZ& a)
{
   static RR t;

   t.e = 0;
   t.x = a;
   z = t;
}

// z = a, converting through ZZ.
void operator<<(RR& z, long a)
{
   static ZZ as_zz;

   as_zz << a;
   z << as_zz;
}


// z = a for a finite double (rounded to RR::prec bits if needed).
// Calls Error() for infinities and NaNs.
void operator<<(RR& z, double a)
{
   static RR t;
   int e;
   double f;

   if (!IsFinite(&a))
      Error("RR: conversion of a non-finite double");

   if (a == 0) {
      clear(z);
      return;
   }

   // Split a = f * 2^e with 0.5 <= |f| < 1, then scale f into an integer.
   // NOTE(review): assumes ZZ_FDOUBLE_PRECISION == 2^(ZZ_DOUBLE_PRECISION-1),
   // so the total scale 2^(ZZ_DOUBLE_PRECISION+1) matches t.e below.
   f = frexp(a, &e);
   f = f * ZZ_FDOUBLE_PRECISION;
   f = f * 4;

   t.x << f;
   t.e = e - (ZZ_DOUBLE_PRECISION + 1);
   z = t;
}


// z = floor(a) as an integer.
void operator<<(ZZ& z, const RR& a)
{
   static RR fl;

   floor(fl, a);
   // after floor the exponent is non-negative, so this is a true left shift
   LeftShift(z, fl.x, fl.e);
}

// z = floor(a), converted through ZZ.
void operator<<(long& z, const RR& a)
{
   static ZZ whole;

   whole << a;
   z << whole;
}

// z = a converted to double.
// a is first rounded to ZZ_DOUBLE_PRECISION bits so that its mantissa
// fits in a double.  Converting the full RR::prec-bit mantissa directly
// could overflow the ZZ -> double conversion once RR::prec exceeds the
// double exponent range, even when a itself is representable.  (Setting
// RR::prec around a ZZ -> double conversion has no effect on its own.)
// Calls Error() if the rounded exponent does not fit in an int, as
// required by ldexp.
void operator<<(double& z, const RR& a)
{
   long old_p;
   double x;
   int e;
   RR t;

   // round a to double precision; t.x then has at most
   // ZZ_DOUBLE_PRECISION bits
   old_p = RR::prec;
   RR::prec = ZZ_DOUBLE_PRECISION;
   t = a;
   RR::prec = old_p;

   e = t.e;

   if (e != t.e) Error("RR: overflow in conversion to double");

   x << t.x;

   z = ldexp(x, e);
}

// Mixed-type addition: the long/double operand is converted to RR
// (rounded to RR::prec) and added with add(RR&, const RR&, const RR&).

void add(RR& z, const RR& a, long b)
{
   static RR rb;
   rb << b;
   add(z, a, rb);
}


void add(RR& z, const RR& a, double b)
{
   static RR rb;
   rb << b;
   add(z, a, rb);
}


void add(RR& z, long a, const RR& b)
{
   static RR ra;
   ra << a;
   add(z, ra, b);
}


void add(RR& z, double a, const RR& b)
{
   static RR ra;
   ra << a;
   add(z, ra, b);
}


// Mixed-type subtraction: the long/double operand is converted to RR
// (rounded to RR::prec) and subtracted with sub(RR&, const RR&, const RR&).

void sub(RR& z, const RR& a, long b)
{
   static RR rb;
   rb << b;
   sub(z, a, rb);
}


void sub(RR& z, const RR& a, double b)
{
   static RR rb;
   rb << b;
   sub(z, a, rb);
}


void sub(RR& z, long a, const RR& b)
{
   static RR ra;
   ra << a;
   sub(z, ra, b);
}


void sub(RR& z, double a, const RR& b)
{
   static RR ra;
   ra << a;
   sub(z, ra, b);
}



// Mixed-type multiplication: the long/double operand is converted to RR
// (rounded to RR::prec) and multiplied with mul(RR&, const RR&, const RR&).

void mul(RR& z, const RR& a, long b)
{
   static RR rb;
   rb << b;
   mul(z, a, rb);
}


void mul(RR& z, const RR& a, double b)
{
   static RR rb;
   rb << b;
   mul(z, a, rb);
}


void mul(RR& z, long a, const RR& b)
{
   static RR ra;
   ra << a;
   mul(z, ra, b);
}



void mul(RR& z, double a, const RR& b)
{
   static RR ra;
   ra << a;
   mul(z, ra, b);
}





// Mixed-type division: the long/double operand is converted to RR
// (rounded to RR::prec) and divided with div(RR&, const RR&, const RR&),
// which reports division by zero through Error().

void div(RR& z, const RR& a, long b)
{
   static RR rb;
   rb << b;
   div(z, a, rb);
}


void div(RR& z, const RR& a, double b)
{
   static RR rb;
   rb << b;
   div(z, a, rb);
}


void div(RR& z, long a, const RR& b)
{
   static RR ra;
   ra << a;
   div(z, ra, b);
}


void div(RR& z, double a, const RR& b)
{
   static RR ra;
   ra << a;
   div(z, ra, b);
}



long compare(const RR& a, long b)
{
   static RR B;
   B << b;
   return compare(a, B);
}


long compare(const RR& a, double b)
{
   static RR B;
   B << b;
   return compare(a, B);
}


// Mixed-type equality: b is converted to RR (rounded to RR::prec) and
// compared by representation with operator==(const RR&, const RR&).

long operator==(const RR& a, long b) 
{
   static RR rb;
   rb << b;
   return a == rb;
}


long operator==(const RR& a, double b) 
{
   static RR rb;
   rb << b;
   return a == rb;
}


// z = a^e by left-to-right binary exponentiation over the bits of |e|;
// a negative exponent inverts the result at the end.  Each sqr/mul step
// rounds to RR::prec bits, so rounding error can accumulate with the
// number of steps.
void power(RR& z, const RR& a, long e)
{
   RR base, acc;
   long invert = 0;

   base = a;
   if (e < 0) {
      invert = 1;
      e = -e;
   }

   set(acc);
   for (long i = NumBits(e) - 1; i >= 0; i--) {
      sqr(acc, acc);
      if (bit(e, i))
         mul(acc, acc, base);
   }

   if (invert)
      inv(z, acc);
   else
      z = acc;
}

// Convenience overloads: the base is converted to RR (rounded to
// RR::prec) and passed to power(RR&, const RR&, long).

void power(RR& z, long a, long e)
{
   RR base;
   base << a;
   power(z, base, e);
}

void power(RR& z, int a, long e)
{
   RR base;
   base << a;
   power(z, base, e);
}

void power(RR& z, double a, long e)
{
   RR base;
   base << a;
   power(z, base, e);
}
   

// Write a to s in decimal with RR::OutputPrecision() significant digits,
// rounded to nearest.  Trailing zeros are dropped, and the digits are
// printed as an integer, as a fixed-point number, or -- when the decimal
// point would be far from the digits -- as "0.<digits>e<exponent>".
// Zero prints as "0".  The working precision is temporarily raised to
// about OutputPrecision()*log2(10) + 10 bits and restored before return.
ostream& operator<<(ostream& s, const RR& a)
{
   if (IsZero(a)) {
      s << "0";
      return s;
   }

   long old_p = RR::precision();

   // 3.321928095 ~= log2(10) bits per decimal digit, plus 10 guard bits
   RR::SetPrecision(long(RR::OutputPrecision()*3.321928095) + 10);

   // b = |a|; neg remembers the sign for printing
   RR b;
   long neg;

   if (a < 0) {
      negate(b, a);
      neg = 1;
   }
   else {
      b = a;
      neg = 0;
   }

   // Initial estimate of the power of ten k that scales b to roughly
   // OutputPrecision() decimal digits.
   long k;

   k = long((RR::OutputPrecision()*3.321928095-NumBits(b.mantissa())
            -b.exponent()) / 3.321928095);


   RR c;

   power(c, 10, k);
   mul(b, b, c);

   // Correct the estimate so that b lies just below c = 10^OutputPrecision()
   // (roughly in [c/10, c)); k counts the net factors of 10 applied to |a|.
   power(c, 10, RR::OutputPrecision());

   while (b < c) {
      mul(b, b, 10);
      k++;
   }

   while (b >= c) {
      div(b, b, 10);
      k--;
   }

   // Round to nearest integer (B << b below takes the floor).
   // After the negation, |a| ~= B * 10^k.
   add(b, b, 0.5);
   k = -k;

   ZZ B;
   B << b;

   // Rounding may carry B up to 10^OutputPrecision() (one extra digit);
   // the buffer also holds the terminating '\0'.
   // NOTE(review): standard operator new throws rather than returning
   // null, so the check below only matters on pre-standard compilers.
   char *bp = new char[RR::OutputPrecision()+10];

   if (!bp) Error("RR output: out of memory");

   long len, i;

   // Extract the decimal digits of B, least significant first ...
   len = 0;
   do {
      bp[len] = DivRem(B, B, 10) + '0';
      len++;
   } while (B > 0);

   // ... then reverse them into reading order.
   for (i = 0; i < len/2; i++) {
      char tmp;
      tmp = bp[i];
      bp[i] = bp[len-1-i];
      bp[len-1-i] = tmp;
   }

   // Drop trailing zeros, folding them into the exponent k.  B >= 1 here
   // (b >= about 10^(OutputPrecision()-1)), so a nonzero digit stops the scan.
   i = len-1;
   while (bp[i] == '0') i--;

   k += (len-1-i);
   len = i+1;

   bp[len] = '\0';

   // Now |a| ~= <digits in bp> * 10^k; pick a layout.
   if (k > 3 || k < -len - 3) {
      // use scientific notation

      if (neg) s << "-";
      s << "0." << bp << "e" << k + len;
   }
   else if (k >= 0) {
      // integer: digits followed by k zeros
      if (neg) s << "-";
      s << bp;
      for (i = 0; i < k; i++) 
         s << "0";
   }
   else if (k <= -len) {
      // pure fraction: "0." then leading zeros, then the digits
      if (neg) s << "-";
      s << "0.";
      for (i = 0; i < -len-k; i++)
         s << "0";
      s << bp;
   }
   else {
      // decimal point falls inside the digit string
      if (neg) s << "-";
      for (i = 0; i < len+k; i++)
         s << bp[i];

      s << ".";

      for (i = len+k; i < len; i++)
         s << bp[i];
   }

   RR::SetPrecision(old_p);
   delete [] bp;
   return s;
}

// Read a decimal floating-point number from s into x.
//
// Accepted forms (leading ' ', '\n', '\t' are skipped, and may also
// appear between a leading '-' and the number):
//    [-]digits[.[digits]][exp]
//    [-].digits[exp]
//    [-]exp                      (mantissa taken as 1)
// where exp is (e|E)[+|-]digits.  Malformed input triggers
// Error("bad RR input"); an exponent magnitude of 2^(ZZ_BITS_PER_LONG-3)
// or more triggers Error("RR input overflow").  The stored result is
// rounded to the caller's precision RR::prec, which is restored before
// the final assignment to x.
// NOTE(review): '\r' and other whitespace characters are not skipped.
istream& operator>>(istream& s, RR& x)
{
   long c;
   long sign;
   // the mantissa is accumulated exactly as the fraction a / b,
   // where b is a power of 10
   ZZ a, b;

   if (!s) Error("bad RR input");


   c = s.peek();
   while (c == ' ' || c == '\n' || c == '\t') {
      s.get();
      c = s.peek();
   }

   if (c == '-') {
      sign = -1;


      s.get();
      c = s.peek();
      while (c == ' ' || c == '\n' || c == '\t') {
         s.get();
         c = s.peek();
      }
   }
   else
      sign = 1;

   // got1: digits before '.', got_dot: saw '.', got2: digits after '.'
   long got1 = 0;
   long got_dot = 0;
   long got2 = 0;

   a << 0;
   b << 1;

   if (c >= '0' && c <= '9') {
      got1 = 1;

      while (c >= '0' && c <= '9') {
         mul(a, a, 10);
         add(a, a, c-'0');
         s.get();
         c = s.peek();
      }
   }

   if (c == '.') {
      got_dot = 1;

      s.get();
      c = s.peek();

      if (c >= '0' && c <= '9') {
         got2 = 1;
   
         // each fractional digit also scales the denominator by 10
         while (c >= '0' && c <= '9') {
            mul(a, a, 10);
            add(a, a, c-'0');
            mul(b, b, 10);
            s.get();
            c = s.peek();
         }
      }
   }

   if (got_dot && !got1 && !got2)  Error("bad RR input");

   ZZ e;

   long got_e = 0;
   long e_sign;

   if (c == 'e' || c == 'E') {
      got_e = 1;

      s.get();
      c = s.peek();

      if (c == '-') {
         e_sign = -1;
         s.get();
         c = s.peek();
      }
      else if (c == '+') {
         e_sign = 1;
         s.get();
         c = s.peek();
      }
      else
         e_sign = 1;

      if (c < '0' || c > '9') Error("bad RR input");

      e << 0;
      while (c >= '0' && c <= '9') {
         mul(e, e, 10);
         add(e, e, c-'0');
         s.get();
         c = s.peek();
      }
   }

   if (!got1 && !got2 && !got_e) Error("bad RR input");

   RR t1, t2, v;

   long old_p = RR::precision();

   if (got1 || got2) {
      // Convert a and b without rounding (precision >= their bit length),
      // then divide at the caller's precision, plus 10 guard bits when an
      // exponent still has to be applied.
      RR::SetPrecision(max(NumBits(a), NumBits(b)));
      t1 << a;
      t2 << b;
      if (got_e)
         RR::SetPrecision(old_p + 10);
      else
         RR::SetPrecision(old_p);
      div(v, t1, t2);
   }
   else
      set(v);

   if (sign < 0)
      negate(v, v);

   if (got_e) {
      // scale by 10^(+/-e) with 10 guard bits, then restore the precision
      if (e >= 1L << (ZZ_BITS_PER_LONG-3)) Error("RR input overflow");
      long E;
      E << e;
      if (e_sign < 0) E = -E;
      RR::SetPrecision(old_p + 10);
      power(t1, 10, E);
      mul(v, v, t1);
      RR::SetPrecision(old_p);
   }

   x = v;
   return s;
}


// z = a, computed as a.mantissa() * XD_BOUND^(a.exponent()) in RR
// arithmetic (each step rounded to RR::prec bits).
void operator<<(RR& z, xdouble a)
{
   static RR mant, scale;

   mant << a.mantissa();

   if (a.exponent() < 0)
      power(scale, 1/XD_BOUND, -a.exponent());
   else
      power(scale, XD_BOUND, a.exponent());

   mul(z, mant, scale);
}


// z = a converted to xdouble.
// a is first rounded to ZZ_DOUBLE_PRECISION bits so that its mantissa
// converts to a finite double.  Converting the full RR::prec-bit mantissa
// directly could overflow the ZZ -> double conversion once RR::prec
// exceeds the double exponent range; setting RR::prec around that
// conversion has no effect on its own.  The binary exponent is then
// applied in xdouble arithmetic, which keeps the wide exponent range.
void operator<<(xdouble& z, const RR& a)
{
   long old_p;
   double x;
   xdouble y;
   RR t;

   // round a to double precision; t.x then has at most
   // ZZ_DOUBLE_PRECISION bits
   old_p = RR::prec;
   RR::prec = ZZ_DOUBLE_PRECISION;
   t = a;
   RR::prec = old_p;

   x << t.x;
   power(y, 2, t.e);
   z = xdouble(x)*y;
}
      
