#include "defs.h"
#include "mp.e"
#include "mp.h"

void
mp_bern		WITH_2_ARGS(
	mp_float_array,	x,
	mp_int,		n
)
/*
Computes the Bernoulli numbers b2 = 1/6, b4 = -1/30, b6 = 1/42, b8 =
-1/30, b10 = 5/66, b12 = -691/2730, etc., defined by the generating
function y/(exp(y) - 1).  The Bernoulli numbers b2, .., b(2n) are
returned in the array x, with b(2j) placed in x[j - 1].  Rounding
options are not implemented, and no guard digits are used; the relative
error in b(2j) is O(j^2 * b^(1 - t)).  If n is negative, int_abs(n)
Bernoulli numbers are returned, but the comment above about relative
error no longer applies - instead the precision decreases linearly from
the first to the last Bernoulli number (this is usually sufficient if
the Bernoulli numbers are to be used as coefficients in an
Euler-Maclaurin expansion).  The time taken is
O(t * int_min(n, t)^2 + n * M(t)).

Parameters:
    x	array that receives the results; the base b and number of digits
	t used throughout are read from element 0 of this array.
    n	number of Bernoulli numbers wanted.  n == 0 returns immediately
	without touching x; n < 0 selects the reduced-precision mode
	described above, with int_abs(n) numbers produced.

Outline of the method as implemented below:
    1.	The first n2 entries (n2 <= int_abs(n)) are produced by a
	recurrence on scaled values, then "unscaled" in two passes.
    2.	Any remaining entries are produced one from the next by repeated
	multiplication by -(2i + 1)(2i + 2) / (4 * pi^2).

The global rounding mode `round` is forced to MP_TRUNC for the duration
of the computation and restored before returning.
*/
{
    mp_ptr_type         base;
    mp_round_type       save_round = round;
    /* NOTE(review): el_size is declared but never used in this function. */
    mp_int              el_size, nl, n2, i, j;
    mp_base_type        b;
    mp_length           t;
    mp_acc_float        temp1, temp2;
    mp_ptr_type         temp1_ptr, temp2_ptr;

    /* Nothing to compute; note this exit happens before `round` is changed. */
    if (!n)
	return;

    DEBUG_BEGIN(DEBUG_BERN);
    DEBUG_PRINTF_1("+bern {\n");
    DEBUG_PRINTF_2("n = %d\n", n);

    /* nl is the number of results to produce, regardless of the sign of n. */
    nl = int_abs(n);

    /* Base and precision are taken from the first element of the array. */
    base = mp_array_element_ptr(x, 0);
    b = mp_b(base);
    t = mp_t(base);

    /* Two working accumulators with the same base/precision as the results. */
    mp_acc_float_alloc_2(b, t, temp1, temp2);
    round = MP_TRUNC;

    /*
    Compute upper limit for recurrence relation method.
    */

    n2 = mp_times_log2_b(t, b) / 2;

    /* Never run the recurrence for more entries than were requested. */
    if (n2 > nl)
	n2 = nl;


    DEBUG_PRINTF_4("b = %d, t = %d, n2 = %d\n", b, t, n2);


    /*
    Set all results to zero.
    */


    for (i = 0; i < nl; i++)
	mp_set_digits_zero(mp_digit_ptr(mp_array_element_ptr(x, i), 0), t);

    /* Both accumulators start out as 1/12. */
    mp_q_to_mp(1, 12, temp1);
    mp_copy(temp1, temp2);


    /*
    x_t(i) is the precision field of element i of x (usable as an lvalue
    below).  fix_pointers() re-fetches the raw pointers to the two
    accumulators whenever mp_has_changed() reports a change - presumably
    the underlying storage can be relocated by mp operations (confirm
    against the memory manager).
    NOTE(review): neither macro is #undef'd, so both remain defined for
    the rest of the translation unit.
    */

#define x_t(i)		mp_t(mp_array_element_ptr(x, i))
#define fix_pointers()	if (mp_has_changed())				\
			{						\
			    temp1_ptr = mp_acc_float_ptr(temp1);	\
			    temp2_ptr = mp_acc_float_ptr(temp2);	\
			}

    temp1_ptr = mp_acc_float_ptr(temp1);
    temp2_ptr = mp_acc_float_ptr(temp2);


    /*
    Recurrence phase: on pass j the current value of temp2 is stored as the
    (still scaled) result x[j], then temp1 and every earlier x[i] are
    rescaled and temp2 is rebuilt for the next pass.
    */

    for (j = 0; j < n2; j++)
    {
	register mp_int		j1 = j + 1;

	/*
	Decrease t if n is negative.
	*/

	if (n < 0)
	{
	    /*
	    Working precision shrinks linearly with j; it is capped at the
	    full precision t (the formula gives t + 2 for j == 0).
	    */

	    register mp_length	new_t = ((nl - j) * (t - 2)) / nl + 4;

	    if (new_t > t)
		new_t = t;

	    mp_t(temp1_ptr) = mp_t(temp2_ptr) = x_t(j) = new_t;
	}

	
	/* Store the scaled value for index j. */
	mp_copy(temp2, mp_array_element(x, j));

	/* The last entry needs no preparation for a following pass. */
	if (j == n2 - 1)
	    break;

	/*
	Presumably mp_mul_double_q(z, p, q, r, s) multiplies z by
	(p * q) / (r * s) in place - TODO confirm.  Under that reading temp1
	is divided by 4 * j1 * (4 * j1 + 6) here, and temp2 restarts from it.
	*/

	mp_mul_double_q(temp1, 1, 1, 4 * j1, 4 * j1 + 6);
	mp_copy(temp1, temp2);

	fix_pointers();

	for (i = 0; i <= j; i++)
	{
	    /*
	    Change t if n is negative.
	    */

	    mp_float	xi;

	    if (n < 0)
	    {
		/* Same linear precision schedule, now keyed on i. */
		register mp_length	new_t = ((nl - i) * (t - 2)) /
								nl + 4;

		if (new_t > t)
		    new_t = t;

		mp_t(temp2_ptr) = x_t(i) = new_t;
	    }

	    xi = mp_array_element(x, i);

	    /*
	    Rescale the stored x[i] in place, then remove it from temp2
	    (mp_sub_eq presumably computes temp2 -= xi - TODO confirm).
	    */

	    mp_mul_double_q(xi, 1, 1, 4 * (j - i) + 4, 4 * (j - i) + 6);
	    mp_sub_eq(temp2, xi);

	    fix_pointers();
	}
    }


    /*
    Now unscale results.
    */

    /*
    Restore full precision on both accumulators.
    NOTE(review): when the loop above exits through its break, the last
    fix_pointers() ran before the final mp_copy; if mp_copy can relocate
    storage these raw pointers could be stale here - confirm.
    */

    mp_t(temp1_ptr) = mp_t(temp2_ptr) = t;
    mp_int_to_mp(1, temp1);

    if (n2 > 1)
    {
	/*
	First unscaling pass, from x[n2 - 2] down to x[0]: temp1 picks up the
	factor c * (c + 2) for c = 4, 8, 12, ... and each x[i] is multiplied
	by the running product.  x[n2 - 1] is not touched by this pass.
	*/

	for (i = n2 - 2; i >= 0; i--)
	{
	    register mp_int	c = 4 * (n2 - 2 - i) + 4;

	    /* Back to full precision (it may have been reduced when n < 0). */
	    x_t(i) = t;

	    mp_mul_double_q(temp1, c, c + 2, 1, 1);
	    mp_mul_eq(mp_array_element(x, i), temp1);
	}


	/*
	We now have b(2j)/factorial(2j) in x.
	*/

	mp_int_to_mp(1, temp1);
    }

    /*
    Second unscaling pass: temp1 picks up (2i + 1) * (2i + 2) on each step,
    i.e. it is the running value factorial(2i + 2), and x[i] is multiplied
    by it.  When n2 == 1 the block above is skipped and temp1 is still the
    1 assigned before it.
    */

    for (i = 0; i < n2; i++)
    {
	register mp_int		c = 2 * i + 1;

	x_t(i) = t;

	mp_mul_double_q(temp1, c, c + 1, 1, 1);
	mp_mul_eq(mp_array_element(x, i), temp1);
    }

    if (nl > n2)
    {
	/*
	Compute remaining numbers.
	*/

	/* temp2 = pi^(-2) / (-4), i.e. -1 / (4 * pi^2). */
	mp_pi(temp2);
	mp_int_power(temp2, -2, temp2);
	mp_div_int_eq(temp2, -4);

	/* Start from the last number produced by the recurrence. */
	mp_copy(mp_array_element(x, n2 - 1), temp1);

	/*
	Each further entry is the previous one times
	-(2i + 1) * (2i + 2) / (4 * pi^2).
	*/

	for (i = n2; i < nl; i++)
	{
	    register mp_int	c = 2 * i + 1;

	    mp_mul_eq(temp1, temp2);
	    mp_mul_double_q(temp1, c, c + 1, 1, 1);
	    mp_copy(temp1, mp_array_element(x, i));
	}
    }

    /* Release the accumulators and put the caller's rounding mode back. */
    mp_acc_float_delete(temp2);
    mp_acc_float_delete(temp1);
    round = save_round;

    /* Dump every result when debugging support is compiled in. */
#ifndef NO_DEBUG
    DEBUG_PRINTF_1("\n\nFinal numbers:\n\n");
    for (i = 0; i < nl; i++)
	DEBUG_1("", mp_array_element_ptr(x, i));
#endif

    DEBUG_PRINTF_1("-}\n");
    DEBUG_END();
}
