#include <stdio.h>
#include "comm/examples/angst/angst.h"

/*
    This is a simple program to test the performance of some collective
    operations.
 */     
/* Algorithm selectors defined outside this file; worker() hands them to
   PISetCombFunc (gsetop*) and PISetScatterFunc (gscattersetR). */
extern void gsetopL(), gsetopT(), gsetopTP();
extern void gscattersetR();

/* Timing routines.  Each takes (long reps, long len) and returns the elapsed
   time for reps repetitions; a negative len only prints the test's name. */
double sync();
double dsum();
double isum();
double scat();
double col();

/* Remaining routines of this file, declared here so that no call relies on
   an implicit function declaration. */
int worker();
int time_function();
int PrintHelp();

/* Nonzero once -native has been given on the command line */
static int use_native = 0;
/* Processor set used by the non-native operations; replaced by -pset */
static ProcSet *pset  = ALLPROCS;

/*
    Stub for general parallel call: hands worker() and the command line to
    PICall and returns whatever PICall returns.

    The return type is spelled out; the implicit-int form is not valid C
    from C99 on.
 */
int main(argc,argv)
int argc;
char *argv[];
{
return PICall( worker, argc, argv );
}

int worker(argc,argv)
int argc;
char *argv[];
{
int c;
double (* f)();
long   reps, len, error_flag;
double t;
long   first,last,incr, svals[3];
long   pvals[3], p;
int    rrsize;
char   psetname[50];
int    psetlen = 50;

SYusc_init();

if (SYArgHasName( &argc, argv, 1, "-help" )) {
  if (MYPROCID == 0) PrintHelp( argv );
  return 0;
  }
  
reps     = DEFAULT_REPS;
f        = sync;
svals[0] = 0;
svals[1] = 1024;
svals[2] = 32;
pvals[0] = 1024*1024;
pvals[1] = pvals[0];
pvals[2] = pvals[0];
use_native = 0;

if (SYArgGetString( &argc, argv, 1, "-pset", psetname, 50 )) {
    int n1, m1, i;
    pset = PSCreate( 1 );
    sscanf( psetname, "%d-%d", &n1, &m1 );
    for (i=n1; i<=m1; i++) {
	PSAddMember( pset, &i, 1 );
	}
    PSCompile( pset );
    }
if (SYArgGetInt( &argc, argv, 1, "-rr", &rrsize )) {
    PISetRRSize( rrsize );
    }
if (SYArgHasName( &argc, argv, 1, "-ring" )) {
    if (MYPROCID == 0)
	fprintf( stderr, "Ring not yet supported\n" );
    return 0;
    /*    PISetCombFunc( gsetopR ); */
    }
if (SYArgHasName( &argc, argv, 1, "-line" )) {
    PISetCombFunc( gsetopL );
    PISetScatterFunc( gscattersetR, (void (*)())0 );
    }
if (SYArgHasName( &argc, argv, 1, "-fast" )) {
    PISetNbrRoutines( PIDefFastTree, (void (*)())0, (void (*)())0 );
    PISetCollectiveTree( PISetTreeNodesFast );
    PISetupCollectiveTree();
    }
if (SYArgHasName( &argc, argv, 1, "-tree" ))
    PISetCombFunc( gsetopT );
if (SYArgHasName( &argc, argv, 1, "-dsum" ))
    f = dsum;
if (SYArgHasName( &argc, argv, 1, "-isum" ))
    f = isum;
if (SYArgHasName( &argc, argv, 1, "-sync" ))
    f = sync;
if (SYArgHasName( &argc, argv, 1, "-scatter" ))
    f = scat;
if (SYArgHasName( &argc, argv, 1, "-col" ))
    f = col;    
if (SYArgHasName( &argc, argv, 1, "-native" )) {
    use_native = 1;
#if !defined(GDSUMGLOB) || !defined(GSYNCGLOB) || !defined(GSCATTERGLOB) || \
    !defined(GCOLGLOB)
    if (MYPROCID == 0) 
	fprintf( stderr, "-native not supported\n" );
    return 0;
#endif    
    }
if (SYArgGetIntVec( &argc, argv, 1, "-pkt", 3, pvals )) {
    /* Only pipelined version for now */
    PISetCombFunc( gsetopTP );
    /* Native version does not have packet size control */
    if (use_native) pvals[1] = pvals[0];
    }
SYArgGetIntVec( &argc, argv, 1, "-size", 3, svals );
SYArgGetInt(    &argc, argv, 1, "-reps", &reps );

/* sync only makes sense for a single size */
if (f == sync) svals[0] = svals[1] = 0;

first = svals[0];
last  = svals[1];
incr  = svals[2];

/* Generate header */
if (MYPROCID == 0)
    (*f)(0,-1);

for (p=pvals[0]; p<=pvals[1]; p += pvals[2]) {
    PISetPacketSize( p );	
    time_function(reps,first,last,incr,f);
    }

GSYNC(ALLPROCS);
if (MYPROCID == 0)
    printf( "Done\n" );
return 0;    
}

/*
    time_function - run one timing routine over a range of lengths and print
    one result line per length on process 0.

    Input Parameters:
    reps              - repetition count handed to f; also the divisor used
                        to form the average time
    first, last, incr - lengths tested are first, first+incr, ... while the
                        value is <= last.  A non-positive incr tests only
                        the first length instead of looping forever.
    f                 - timing routine, called as f(reps,len); its return
                        value is treated as the total time in seconds

    Each line holds the length, the total time, the average time per
    repetition in microseconds, and the rate (length per second).  The rate
    is printed as 0 when the measured time is not positive.

    Always returns 0.
 */
int time_function(reps,first,last,incr,f)
long reps,first,last,incr;
double (* f)();
{
    long   len;
    long   myproc;
    double t, mean_time, rate;

    myproc = MYPROCID;
    if (myproc == 0)
        printf( "\n#len\ttime\t\tave time (us)\trate\n" );

    for (len = first; len <= last; len += incr) {
        t         = (* f)( reps, len );
        mean_time = ( t / reps ) * 1.0e6;  /* average over trials, in us */
        if (mean_time > 0.0)
            rate = (double)len / (mean_time * 1.0e-6);
        else
            rate = 0.0;                    /* avoids printing inf or nan */
        /* len is a long, so it needs %ld */
        if (myproc == 0)
            printf( "%ld\t%f\t%f\t%f\n", len, t, mean_time, rate );
        if (incr <= 0) break;
    }
    return 0;
}

/*
    dsum - time the double precision reduction.

    Input Parameters:
    reps - number of back-to-back reduction calls inside the timed region
    len  - number of doubles in the vector; a negative value only prints
           the test name "Dsum" and returns 0.0

    The vector is filled with zeros; the work array is passed as allocated.
    All processes are synchronized with GSYNC(ALLPROCS) before the clock is
    started.  When GDSUMGLOB is defined and -native was requested the native
    routine is timed, otherwise GDSUM on the current processor set.

    Returns the value of SYuscDiff for the whole batch of reps calls.
 */
double dsum(reps,len)
long reps,len;
{
    SYusc_time_t start, stop;
    double       *val, *work;
    double       elapsed;
    long         k;
    int          nbytes;

    if (len < 0) {
        printf( "Dsum\n" );
        return 0.0;
    }

    nbytes = sizeof(double)*len;
    val  = (double *)MALLOC(nbytes);   CHKPTRV(val,0);
    work = (double *)MALLOC(nbytes);   CHKPTRV(work,0);
    k = 0;
    while (k < len) {
        val[k] = 0.0;
        k++;
    }

    GSYNC(ALLPROCS);
#ifdef GDSUMGLOB
    if (use_native) {
        SYusc_clock(&start);
        for (k = reps; k > 0; k--) {
            GDSUMGLOB(val,len,work);
        }
        SYusc_clock(&stop);
    } else
#endif
    {
        SYusc_clock(&start);
        for (k = reps; k > 0; k--) {
            GDSUM(val,len,work,pset);
        }
        SYusc_clock(&stop);
    }

    elapsed = SYuscDiff( &start, &stop );
    FREE( val );
    FREE( work );
    return elapsed;
}

/*
    isum - time the integer reduction.

    Input Parameters:
    reps - number of back-to-back reduction calls inside the timed region
    len  - number of longs in the vector; a negative value only prints the
           test name "Isum" and returns 0.0

    The vector is filled with zeros; the work array is passed as allocated.
    All processes are synchronized with GSYNC(ALLPROCS) before the clock is
    started.  When GISUMGLOB is defined and -native was requested the native
    routine is timed, otherwise GISUM on the current processor set.

    Returns the value of SYuscDiff for the whole batch of reps calls.
 */
double isum(reps,len)
     long reps,len;
{
    SYusc_time_t start, stop;
    long         *val, *work;
    double       elapsed;
    long         k;
    int          nbytes;

    if (len < 0) {
        printf( "Isum\n" );
        return 0.0;
    }

    nbytes = sizeof(long)*len;
    val  = (long *)MALLOC(nbytes);   CHKPTRV(val,0);
    work = (long *)MALLOC(nbytes);   CHKPTRV(work,0);
    k = 0;
    while (k < len) {
        val[k] = 0;
        k++;
    }

    GSYNC(ALLPROCS);
#ifdef GISUMGLOB
    if (use_native) {
        SYusc_clock(&start);
        for (k = reps; k > 0; k--) {
            GISUMGLOB(val,len,work);
        }
        SYusc_clock(&stop);
    } else
#endif
    {
        SYusc_clock(&start);
        for (k = reps; k > 0; k--) {
            GISUM(val,len,work,pset);
        }
        SYusc_clock(&stop);
    }

    elapsed = SYuscDiff( &start, &stop );
    FREE( val );
    FREE( work );
    return elapsed;
}

/*
    scat - time the scatter operation.

    Input Parameters:
    reps - number of back-to-back scatter calls inside the timed region
    len  - buffer length in chars; a negative value only prints the test
           name "Scatter" and returns 0.0

    The buffer is filled with blanks.  The third argument of the scatter
    call is true only on process 0.  All processes are synchronized with
    GSYNC(ALLPROCS) before the clock is started.  When GSCATTERGLOB is
    defined and -native was requested the native routine is timed, otherwise
    GSCATTER on the current processor set.

    Returns the value of SYuscDiff for the whole batch of reps calls.
 */
double scat(reps,len)
long reps,len;
{
    SYusc_time_t start, stop;
    char         *buf;
    double       elapsed;
    long         k, me;

    if (len < 0) {
        printf( "Scatter\n" );
        return 0.0;
    }

    buf = (char *)MALLOC(len);          CHKPTRV(buf,0);
    k = 0;
    while (k < len) {
        buf[k] = ' ';
        k++;
    }
    me = MYPROCID;

    GSYNC(ALLPROCS);
#ifdef GSCATTERGLOB
    if (use_native) {
        SYusc_clock(&start);
        for (k = reps; k > 0; k--) {
            GSCATTERGLOB(buf,len,me==0,MSG_OTHER);
        }
        SYusc_clock(&stop);
    } else
#endif
    {
        SYusc_clock(&start);
        for (k = reps; k > 0; k--) {
            GSCATTER(buf,len,me==0,pset,MSG_OTHER);
        }
        SYusc_clock(&stop);
    }

    elapsed = SYuscDiff( &start, &stop );
    FREE( buf );
    return elapsed;
}

/*
    col - time the collect operation; this version collects contributions
    of the same size.

    Input Parameters:
    reps - number of back-to-back collect calls inside the timed region
    len  - length in chars of the local contribution; a negative value only
           prints the test name "Collect" and returns 0.0

    The local buffer is filled with blanks; the gather buffer has room for
    len*NUMNODES chars.  All processes are synchronized with GSYNC(ALLPROCS)
    before the clock is started.  When GCOLGLOB is defined and -native was
    requested the native routine is timed, otherwise GCOL on the current
    processor set.

    Returns the value of SYuscDiff for the whole batch of reps calls.
 */
double col(reps,len)
long reps,len;
{
    SYusc_time_t start, stop;
    char         *local, *gathered;
    double       elapsed;
    long         k, me;
    int          got, capacity;

    if (len < 0) {
        printf( "Collect\n" );
        return 0.0;
    }

    local    = (char *)MALLOC(len);           CHKPTRV(local,0);
    gathered = (char *)MALLOC(len*NUMNODES);  CHKPTRV(gathered,0);
    capacity = len * NUMNODES;

    k = 0;
    while (k < len) {
        local[k] = ' ';
        k++;
    }
    me = MYPROCID;

    GSYNC(ALLPROCS);
#ifdef GCOLGLOB
    if (use_native) {
        SYusc_clock(&start);
        for (k = reps; k > 0; k--) {
            GCOLGLOB(local,len,gathered,capacity,&got,MSG_OTHER);
        }
        SYusc_clock(&stop);
    } else
#endif
    {
        SYusc_clock(&start);
        for (k = reps; k > 0; k--) {
            GCOL(local,len,gathered,capacity,&got,pset,MSG_OTHER);
        }
        SYusc_clock(&stop);
    }

    elapsed = SYuscDiff( &start, &stop );
    FREE( local );
    FREE( gathered );
    return elapsed;
}

/*
    sync - time the synchronization operation.

    Input Parameters:
    reps - number of back-to-back synchronizations inside the timed region
    len  - only its sign is used: a negative value prints the test name
           "Sync" and returns 0.0

    All processes are synchronized once with GSYNC(ALLPROCS) before the
    clock is started.  When GSYNCGLOB is defined and -native was requested
    the native routine is timed, otherwise GSYNC on the current processor
    set.

    Returns the value of SYuscDiff for the whole batch of reps calls.
 */
double sync(reps,len) long reps,len;
{
    SYusc_time_t start, stop;
    double       elapsed;
    long         k, me;

    if (len < 0) {
        printf( "Sync\n" );
        return 0.0;
    }

    me = MYPROCID;

    GSYNC(ALLPROCS);
#ifdef GSYNCGLOB
    if (use_native) {
        SYusc_clock(&start);
        for (k = reps; k > 0; k--) {
            GSYNCGLOB();
        }
        SYusc_clock(&stop);
    } else
#endif
    {
        SYusc_clock(&start);
        for (k = reps; k > 0; k--) {
            GSYNC(pset);
        }
        SYusc_clock(&stop);
    }

    elapsed = SYuscDiff( &start, &stop );
    return elapsed;
}

/*
    PrintHelp - write the usage message to stderr.

    Input Parameter:
    argv - argument vector; only argv[0] (the program name) is used

    The text matches what worker() does with the options: the three values
    of -size and -pkt are first, last and increment, -col and -pset are
    listed, and -ring is marked as unsupported.

    Always returns 0.
 */
int PrintHelp( argv )
char **argv;
{
    /* Everything after the first line is fixed text, one entry per line */
    static const char *const text[] = {
        "-rr n -ring -line -tree -dsum -isum -sync -scatter -col\n",
        "-pkt first last incr -size first last incr -reps n\n",
        "-native -fast -pset n-m\n",

        "\nTests:\n",
        "-dsum     : reduction (double precision)\n",
        "-isum     : reduction (integer)\n",
        "-sync     : synchronization\n",
        "-scatter  : scatter\n",
        "-col      : collect (contributions of the same size)\n",

        "\nTest control:\n",
        "-size first last incr : test lengths first, first+incr, ... up to last\n",
        "-reps n               : number of times to repeat test\n",
        "-pset n-m             : processor set consisting of nodes n to m\n",

        "\nCollective communication algorithms:\n",
        "-rr n   : use a ready-receiver version for messages n bytes or longer\n",
        "          (many operations do not yet have rr versions)\n",
        "-ring   : use a ring algorithm (not yet supported)\n",
        "-tree   : use a tree algorithm\n",
        "-line   : use a linear algorithm\n",
        "-native : use the native routines (if available)\n",

        "\nOptions for algorithms\n",
        "-pkt first last incr : packet sizes first, first+incr, ... up to last\n",
        "                       (uses the pipelined algorithm)\n",
        "-fast                : use a (possibly) contention-free tree\n"
    };
    int i;
    int nlines = (int)(sizeof(text) / sizeof(text[0]));

    fprintf( stderr, "%s - test collection operations for speed\n", argv[0] );
    for (i = 0; i < nlines; i++) {
        fputs( text[i], stderr );
    }
    return 0;
}
