/*--------------------------------------------------------------------------
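 FFT routines: in-place radix-2 FFTs and direct O(N^2) DFTs, in 16-bit
 fixed point (short) and in single-precision float.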
--------------------------------------------------------------------------*/
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "imath.h"
#include "bfp.h"

static char sccsid[] = "@(#)fft.c	1.2 7/17/91";

/*--------------------------------------------------------------------------
 BFP_DELAY_DIVIDES only affects the inverse FFT (cifft_dif), and only when
 block floating point is active (bfp_do).
 If false, each stage of the inverse FFT that is asked to divide by 2 does
 so immediately, inside its butterflies (as normal).
 If true, such stages leave the data unscaled and instead decrement
 bfp_scale_ct once each, flagging to the caller that the result still
 has to be divided by 2 for every decrement.
 In an informal test on a single radar echo, this reduced peak error
 by a factor of four, and root mean squared error by a factor of 2.
 (It also caused the median error to jump from zero to 1 or 2, but
 we won't talk about that!)
--------------------------------------------------------------------------*/
#define BFP_DELAY_DIVIDES 1

/*
 Fixed-point format of the integer sin/cos tables: TAB_BITS fraction bits
 (FRAC_BITS is defined outside this file -- presumably imath.h), so
 TAB_UNITY represents 1.0 and TAB_HALF is the half-LSB rounding bias that
 is added before every ">> TAB_BITS".
*/
#define TAB_BITS FRAC_BITS
#define TAB_UNITY (1<<TAB_BITS)
#define TAB_HALF  (1<<(TAB_BITS-1))

/*--------------------------------------------------------------------------
 Convert a floating point number to a table entry: a fixed-point short
 with TAB_BITS of fraction, truncated toward zero by the cast.
 The argument is parenthesized so that e.g. ftoi(a + b) scales the whole
 sum rather than only its last term.
 NOTE(review): if FRAC_BITS is 15, ftoi(1.0) (e.g. cos(0)) is 32768, which
 does not fit in a short -- confirm FRAC_BITS <= 14, or saturate here.
--------------------------------------------------------------------------*/
#define ftoi(f) ((short)((f) * TAB_UNITY))

/*--------------------------------------------------------------------------
 Fill fixed-point lookup tables with sin and cos over the range [0..arc).
 Entry i holds the value at angle arc*i/n, converted with ftoi(): a short
 with TAB_BITS of fraction, truncated toward zero.
   arc      angular span covered by the n entries (callers in this file
            pass M_PI for the FFTs and 2*M_PI for the direct DFT)
   n        number of entries written to each table
   isintab  receives the sin values
   icostab  receives the cos values
--------------------------------------------------------------------------*/
void
isincos_table(double arc, int n, short *isintab, short *icostab)
{
    int idx = 0;

    while (idx < n) {
	/* The divisor goes through float, exactly as it always has. */
	double theta = (arc * idx) / (float) n;

	icostab[idx] = ftoi(cos(theta));
	isintab[idx] = ftoi(sin(theta));
	idx++;
    }
}

void
fsincos_table(n, fsintab, fcostab)
    int n;
    float *fsintab;
    float *fcostab;
{
    int i;
    for (i=0; i<n; i++) {
	fcostab[i] = cos((M_PI * i) / (float) n);
	fsintab[i] = sin((M_PI * i) / (float) n);
    }
}

/*--------------------------------------------------------------------------
 Shared twiddle-factor tables.
 For now, keep a single static table of each kind and rebuild it whenever
 a transform asks for a different size -- or, for the integer table, a
 different arc.  The arc must be part of the cache key: cfft_direct_int16()
 builds its table over [0, 2*pi) while the FFTs build theirs over [0, pi),
 so two calls with the same table size but different arcs would otherwise
 silently reuse the wrong table.
 Because this is file-scope mutable state, the routines that use these
 tables are not reentrant.
--------------------------------------------------------------------------*/
static int tab_fn = -1;		/* entries in tab_fcos/tab_fsin; -1 = none */
static float *tab_fcos, *tab_fsin;
static int tab_in = -1;		/* entries in tab_icos/tab_isin; -1 = none */
static double tab_iarc;		/* arc the integer tables were built for */
static short *tab_icos, *tab_isin;

/* Report a failed table allocation and stop.  The transforms return void
 * and index the tables unconditionally, so there is no way to recover. */
static void
tab_nomem(const char *which)
{
    fprintf(stderr, "fft: out of memory allocating %s twiddle table\n", which);
    abort();
}

/* Make tab_fcos/tab_fsin hold n float entries of cos/sin over [0, pi). */
static void
tab_fsetup(int n)
{
    if (tab_fn == n)
	return;			/* already built for this size */

    free(tab_fcos);		/* both start out NULL; free(NULL) is a no-op */
    free(tab_fsin);
    tab_fn = -1;		/* invalid until the rebuild completes */
    tab_fcos = malloc(sizeof *tab_fcos * (size_t) n);
    tab_fsin = malloc(sizeof *tab_fsin * (size_t) n);
    /* malloc(0) may legitimately return NULL, so only n > 0 can fail. */
    if (n > 0 && (tab_fcos == NULL || tab_fsin == NULL))
	tab_nomem("float");

    fsincos_table(n, tab_fsin, tab_fcos);
    tab_fn = n;
}

/* Make tab_icos/tab_isin hold n fixed-point entries of cos/sin over
 * [0, arc).  Every caller passes arc as the same constant expression
 * (M_PI or 2.0*M_PI), so exact comparison works as a cache key. */
static void
tab_isetup(double arc, int n)
{
    if (tab_in == n && tab_iarc == arc)
	return;			/* already built for this size and arc */

    free(tab_icos);
    free(tab_isin);
    tab_in = -1;		/* invalid until the rebuild completes */
    tab_icos = malloc(sizeof *tab_icos * (size_t) n);
    tab_isin = malloc(sizeof *tab_isin * (size_t) n);
    if (n > 0 && (tab_icos == NULL || tab_isin == NULL))
	tab_nomem("integer");

    isincos_table(arc, n, tab_isin, tab_icos);
    tab_in = n;
    tab_iarc = arc;
}

/*--------------------------------------------------------------------------
 Complex FFT, decimation-in-time, computed in place in 16-bit fixed point.
 xr and xi hold the real and imaginary parts, each (2 ^ log_n) long.
 The input must be in bit-reversed order; the output is in natural order.
 Twiddles are cos + j*sin taken from the shared integer table, and each
 twiddle product is rounded via (TAB_HALF + v) >> TAB_BITS.

 Reference: Oppenheim & Schafer, "Digital Signal Processing", 1975,
 p. 318, figure 6.10 (bit-reversed input, coefficients in normal order).
--------------------------------------------------------------------------*/
void
cfft_dit(int log_n, short *xr, short *xi)
{
    int n = 1 << log_n;
    int half = n / 2;		/* entries in the twiddle table */
    int stage;
    int j, k;			/* j doubles as BFP_SCALE's scratch index */

    tab_isetup(M_PI, half);

    for (stage = 1; stage <= log_n; stage++) {
	int span = 1 << (stage - 1);		/* butterflies per sub-DFT */
	int groups = 1 << (log_n - stage);	/* sub-DFTs in this stage */
	int base;				/* first index of a sub-DFT */

	/* When enabled (bfp.h), BFP_SCALE halves x if any |x[j]| is too big. */
	BFP_SCALE(n, xr, xi, j);

	for (j = 0, base = 0; j < groups; j++, base += 2 * span) {
	    int tw = 0;		/* twiddle index, advances by `groups` */

	    for (k = 0; k < span; k++, tw += groups) {
		int top = base + k;
		int bot = top + span;
		short wr, wi, ar, ai, br, bi;

		if (tw >= half)
		    tw = 0;
		wr = tab_icos[tw];
		wi = tab_isin[tw];

		/* Rotate the lower input by the twiddle, with rounding. */
		ar = xr[top];
		ai = xi[top];
		br = (TAB_HALF + (int)wr*xr[bot] - wi*xi[bot]) >> TAB_BITS;
		bi = (TAB_HALF + (int)wr*xi[bot] + wi*xr[bot]) >> TAB_BITS;

		xr[top] = ar + br;
		xi[top] = ai + bi;
		xr[bot] = ar - br;
		xi[bot] = ai - bi;
	    }
	}
    }
}

/*--------------------------------------------------------------------------
 One stage of cfft_dif(): n_dft sub-DFTs of n_butterfly DIF butterflies
 each.  Every butterfly replaces (a, b) with
     ((a + b) >> shift,  ((a - b) * w  rounded) >> shift)
 where w = cos + j*sin from the shared integer table; shift is 0 or 1.
 NOTE(review): when shift is 1 the sum output is floored rather than
 rounded, while the difference output is rounded; kept to preserve results.
--------------------------------------------------------------------------*/
static void
cfft_dif_pass(int n, int n_dft, int n_butterfly, short *xr, short *xi,
	      int shift)
{
    int bias = 1 << (TAB_BITS + shift - 1);	/* half an output LSB */
    int bits = TAB_BITS + shift;
    int j, k, twiddle;

    for (j = 0; j < n_dft; j++) {
	for (k = twiddle = 0; k < n_butterfly; k++, twiddle += n_dft) {
	    int index1 = j * 2 * n_butterfly + k;
	    int index2 = index1 + n_butterfly;
	    short xr1, xi1, xr2, xi2;
	    short wr, wi;

	    if (twiddle >= n/2) twiddle = 0;
	    wr = tab_icos[twiddle];
	    wi = tab_isin[twiddle];

	    xr1 = xr[index1];
	    xi1 = xi[index1];
	    xr2 = xr[index2];
	    xi2 = xi[index2];

	    /* DIF butterfly: twiddle is applied after the difference. */
	    xr[index1] = (xr1 + xr2) >> shift;
	    xi[index1] = (xi1 + xi2) >> shift;
	    xr[index2] = (bias + (int)(xr1-xr2)*wr - (xi1-xi2)*wi) >> bits;
	    xi[index2] = (bias + (int)(xr1-xr2)*wi + (xi1-xi2)*wr) >> bits;
	}
    }
}

/*--------------------------------------------------------------------------
 Complex FFT, decimation-in-frequency, in place, 16-bit fixed point.
 xr and xi hold the real and imaginary parts, each (2 ^ log_n) long.
 The input is in natural order; the output is in bit-reversed order.
 Twiddles are cos + j*sin (positive exponent).

 divide_stages: if NULL, no stage is scaled.  Otherwise it must hold log_n
 entries; the i-th stage executed (the first is the widest, with n/2
 butterflies per sub-DFT) divides its outputs by 2 when divide_stages[i]
 is nonzero.
--------------------------------------------------------------------------*/
void
cfft_dif(int log_n, short *xr, short *xi, short *divide_stages)
{
    int n = 1 << log_n;
    int stage;
    int j;		/* scratch index handed to BFP_SCALE */

    tab_isetup(M_PI, n/2);

    /* For each of the log_n stages of the FFT, widest first. */
    for (stage = log_n; stage >= 1; stage--) {
	int n_dft = 1 << (log_n - stage);	/* # of sub-DFTs */
	int n_butterfly = 1 << (stage - 1);	/* # of butterflies per sub-DFT */
	int shift;

	/* If this macro is enabled, scale x by 1/2 if |x[j]|>1 for any j. */
	BFP_SCALE(n, xr, xi, j);

	shift = (divide_stages && *divide_stages++) ? 1 : 0;
	cfft_dif_pass(n, n_dft, n_butterfly, xr, xi, shift);
    }
}

/*--------------------------------------------------------------------------
 Complex FFT, decimation-in-frequency, in place, single-precision float.
 xr and xi hold the real and imaginary parts, each (2 ^ log_n) long.
 The input is in natural order; the output is in bit-reversed order.
 Twiddles (cos + j*sin over [0, pi)) come from the shared float table.
 No scaling is applied.
--------------------------------------------------------------------------*/
void
cfft_dif_float32(int log_n, float *xr, float *xi)
{
    int n = 1 << log_n;
    int half = n / 2;		/* entries in the twiddle table */
    int span;

    tab_fsetup(half);

    /* Spans n/2, n/4, ..., 1: the same stage order as stage = log_n..1. */
    for (span = half; span >= 1; span >>= 1) {
	int groups = n / (2 * span);	/* sub-DFTs in this stage */
	int g;

	for (g = 0; g < groups; g++) {
	    float *top_r = xr + g * 2 * span;
	    float *top_i = xi + g * 2 * span;
	    float *bot_r = top_r + span;
	    float *bot_i = top_i + span;
	    int k, tw;

	    for (k = 0, tw = 0; k < span; k++, tw += groups) {
		float wr, wi, r1, i1, r2, i2;

		if (tw >= half)
		    tw = 0;
		wr = tab_fcos[tw];
		wi = tab_fsin[tw];

		r1 = top_r[k];
		i1 = top_i[k];
		r2 = bot_r[k];
		i2 = bot_i[k];

		/* DIF butterfly: twiddle is applied after the difference. */
		top_r[k] = r1 + r2;
		top_i[k] = i1 + i2;
		bot_r[k] = (r1-r2)*wr - (i1-i2)*wi;
		bot_i[k] = (r1-r2)*wi + (i1-i2)*wr;
	    }
	}
    }
}

/*--------------------------------------------------------------------------
 One stage of cifft_dif(): n_dft sub-DFTs of n_butterfly butterflies each.
 The lower input is first rotated by the conjugated twiddle
 w = cos - j*sin (rounded via TAB_HALF), then both sum and difference are
 shifted right by `shift` (0 or 1).
--------------------------------------------------------------------------*/
static void
cifft_dif_pass(int n, int n_dft, int n_butterfly, short *xr, short *xi,
	       int shift)
{
    int j, k, twiddle;

    for (j = 0; j < n_dft; j++) {
	for (k = twiddle = 0; k < n_butterfly; k++, twiddle += n_dft) {
	    int index1 = j * 2 * n_butterfly + k;
	    int index2 = index1 + n_butterfly;
	    short wr, wi, ar, ai, br, bi;

	    if (twiddle >= n/2) twiddle = 0;
	    wr = tab_icos[twiddle];
	    wi = -tab_isin[twiddle];	/* negative exponent on twiddle */

	    /* Grab old values, apply the twiddle to one input value. */
	    ar = xr[index1];
	    ai = xi[index1];
	    br = (TAB_HALF + wr*xr[index2] - wi*xi[index2]) >> TAB_BITS;
	    bi = (TAB_HALF + wr*xi[index2] + wi*xr[index2]) >> TAB_BITS;

	    xr[index1] = (ar + br) >> shift;
	    xi[index1] = (ai + bi) >> shift;
	    xr[index2] = (ar - br) >> shift;
	    xi[index2] = (ai - bi) >> shift;
	}
    }
}

/*--------------------------------------------------------------------------
 Inverse of cfft_dif(), in place, 16-bit fixed point.
 xr and xi hold the real and imaginary parts, each (2 ^ log_n) long.
 The input is in bit-reversed order (as cfft_dif() leaves it); the output
 is in natural order.

 This is the decimation-in-time butterfly structure, run through the
 stages in the opposite order to cfft_dif() (narrowest first), with
 conjugated twiddle factors (cos - j*sin), i.e. the opposite exponent sign.

 Scaling: if divide_stages is NULL every stage divides by 2, so the result
 is divided by N = 2 ^ log_n overall.  Otherwise divide_stages must hold
 log_n entries and the i-th stage executed divides by 2 only when
 divide_stages[i] is nonzero.  When BFP_DELAY_DIVIDES and bfp_do are both
 true, a requested division is not applied here: bfp_scale_ct is
 decremented instead, leaving the division to the caller.
--------------------------------------------------------------------------*/
void
cifft_dif(int log_n, short *xr, short *xi, short *divide_stages)
{
    int n = 1 << log_n;
    int stage;
    int j;		/* scratch index handed to BFP_SCALE */

    tab_isetup(M_PI, n/2);

    /* For each of the log_n stages of the FFT, narrowest first. */
    for (stage = 1; stage <= log_n; stage++) {
	int n_dft = 1 << (log_n - stage);	/* # of sub-DFTs */
	int n_butterfly = 1 << (stage - 1);	/* # of butterflies per sub-DFT */
	int do_div;
	int delay_any_div;

	/* If this macro is enabled, scale x by 1/2 if |x[j]|>1 for any j. */
	BFP_SCALE(n, xr, xi, j);

	do_div = (divide_stages == NULL || *divide_stages++ != 0);
	delay_any_div = (BFP_DELAY_DIVIDES && bfp_do);

	/* Divide by 2 inside the butterflies unless the division is deferred. */
	cifft_dif_pass(n, n_dft, n_butterfly, xr, xi,
		       (do_div && !delay_any_div) ? 1 : 0);

	/* A deferred division is reported to the caller via bfp_scale_ct. */
	if (do_div && delay_any_div)
	    bfp_scale_ct--;
    }
}

/*--------------------------------------------------------------------------
 Inverse counterpart of cfft_dif_float32(), in place, single-precision.
 xr and xi hold the real and imaginary parts, each (2 ^ log_n) long.
 The input is in bit-reversed order; the output is in natural order.
 Uses conjugated twiddles (cos - j*sin) and halves every butterfly
 output, so the result is divided by N = 2 ^ log_n overall.
--------------------------------------------------------------------------*/
void
cifft_dif_float32(int log_n, float *xr, float *xi)
{
    int n = 1 << log_n;
    int half = n / 2;		/* entries in the twiddle table */
    int span;

    tab_fsetup(half);

    /* Spans 1, 2, 4, ..., n/2: the same stage order as stage = 1..log_n. */
    for (span = 1; span <= half; span <<= 1) {
	int groups = n / (2 * span);	/* sub-DFTs in this stage */
	int g;

	for (g = 0; g < groups; g++) {
	    float *top_r = xr + g * 2 * span;
	    float *top_i = xi + g * 2 * span;
	    float *bot_r = top_r + span;
	    float *bot_i = top_i + span;
	    int k, tw;

	    for (k = 0, tw = 0; k < span; k++, tw += groups) {
		float wr, wi, ar, ai, br, bi;

		if (tw >= half)
		    tw = 0;
		wr = tab_fcos[tw];
		wi = -tab_fsin[tw];	/* negative exponent on twiddle */

		/* Rotate the lower input by the twiddle. */
		ar = top_r[k];
		ai = top_i[k];
		br = wr*bot_r[k] - wi*bot_i[k];
		bi = wr*bot_i[k] + wi*bot_r[k];

		/* Butterfly, halving both outputs. */
		top_r[k] = (ar + br)/2;
		top_i[k] = (ai + bi)/2;
		bot_r[k] = (ar - br)/2;
		bot_i[k] = (ai - bi)/2;
	    }
	}
    }
}


/*--------------------------------------------------------------------------
 Round a DFT accumulator that still carries TAB_BITS of extra fraction back
 to a 16-bit sample: floor((v + TAB_HALF) / TAB_UNITY), the same rounding
 that (TAB_HALF + v) >> TAB_BITS gives in the FFT butterflies (assuming an
 arithmetic right shift).  The result is saturated to the range of a short,
 because converting an out-of-range double to short is undefined behavior.
--------------------------------------------------------------------------*/
static short
cfft_direct_round(double v)
{
    double r = floor((v + TAB_HALF) / TAB_UNITY);

    if (r > SHRT_MAX)
	return SHRT_MAX;
    if (r < SHRT_MIN)
	return SHRT_MIN;
    return (short) r;
}

/*--------------------------------------------------------------------------
 Complex DFT, direct O(N^2) method, 16-bit fixed point.
 xr/xi are the (2 ^ log_n)-long real/imaginary inputs and yr/yi receive
 the outputs; inputs and outputs are both in natural order.
 Twiddles come from the shared integer table built over [0, 2*pi), with
 the same cos + j*sin (positive exponent) convention as the FFTs here.

 Products are accumulated in double, which holds the sums of
 16-bit x TAB_BITS-bit products exactly for any practical size (log_n up
 to about 20), so the only rounding is the final one per output.
--------------------------------------------------------------------------*/
void
cfft_direct_int16(int log_n, short *xr, short *xi, short *yr, short *yi)
{
    int j, k;
    int dft_size = 1 << log_n;

    tab_isetup(2.0*M_PI, dft_size);

    for (j = 0; j < dft_size; j++) {
	double vr = 0.0, vi = 0.0;
	int twiddle = 0;	/* (j*k) mod dft_size, maintained incrementally */

	for (k = 0; k < dft_size; k++, twiddle += j) {
	    double c, s;

	    if (twiddle >= dft_size) twiddle -= dft_size;
	    c = tab_icos[twiddle];
	    s = tab_isin[twiddle];
	    vr += xr[k]*c - xi[k]*s;
	    vi += xi[k]*c + xr[k]*s;
	}
	yr[j] = cfft_direct_round(vr);
	yi[j] = cfft_direct_round(vi);
    }
}

/*--------------------------------------------------------------------------
 Complex FFT, direct method.
 x and y are (2 ^ log_n) array of floats.
 Inputs and outputs are in normal order.
--------------------------------------------------------------------------*/
void
cfft_direct_float32(log_n, xr, xi, yr, yi)
    int log_n;
    float *xr, *yr;
    float *xi, *yi;
{
    int j, k;
    int dft_size = 1 << log_n;

    for (j=0; j<dft_size; j++) {
	float vr, vi;
	vr = 0;
	vi = 0;
	for (k=0; k<dft_size; k++) {
	    vr += xr[k]*cos(2*M_PI*j*k/dft_size) 
		- xi[k]*sin(2*M_PI*j*k/dft_size);
	    vi += xi[k]*cos(2*M_PI*j*k/dft_size) 
		+ xr[k]*sin(2*M_PI*j*k/dft_size);
	    /* printf("%d %d: vr += %g * %g - %g * %g, vi += %g * %g + %g * %g\n",
		j, k,
		xr[k], cos(2*M_PI*j*k/dft_size),
		xi[k], sin(2*M_PI*j*k/dft_size),
		xi[k], cos(2*M_PI*j*k/dft_size),
		xr[k], sin(2*M_PI*j*k/dft_size));
	    */

	}
	yr[j] = vr;
	yi[j] = vi;
    }
}

/*--------------------------------------------------------------------------
 Complex inverse FFT, direct method.
 x and y are (2 ^ log_n) array of floats.
 Inputs and outputs are in normal order.
--------------------------------------------------------------------------*/
void
cifft_direct_float32(log_n, xr, xi, yr, yi)
    int log_n;
    float *xr, *yr;
    float *xi, *yi;
{
    int j, k;
    int dft_size = 1 << log_n;

    for (j=0; j<dft_size; j++) {
	float vr, vi;
	vr = 0;
	vi = 0;
	for (k=0; k<dft_size; k++) {
	    vr += xr[k]*cos(2*M_PI*j*k/dft_size) 
		+ xi[k]*sin(2*M_PI*j*k/dft_size);
	    vi += xi[k]*cos(2*M_PI*j*k/dft_size) 
		- xr[k]*sin(2*M_PI*j*k/dft_size);
	}
	yr[j] = vr/dft_size;
	yi[j] = vi/dft_size;
    }
}
