//$$ newmat6.cxx            Operators, element access, submatrices

// Copyright (C) 1991: R B Davies and DSIR

#include "include.hxx"

#include "newmat.hxx"
#include "newmatrc.hxx"


//#define REPORT { static ExeCounter ExeCount(__LINE__,6); ExeCount++; }

#define REPORT {}

/*************************** general utilities *************************/

// Number of stored elements in an n x n triangular matrix with the
// diagonal included, i.e. n(n+1)/2.
static int tristore(int n)
{
   int twice = n * (n + 1);
   return twice / 2;
}


/****************************** operators *******************************/

// Range-checked, 1-based access to element (m,n) of a rectangular matrix.
// Elements are laid out row by row, ncols per row.
real& Matrix::operator()(int m, int n)
{
   if (!(m >= 1 && m <= nrows && n >= 1 && n <= ncols))
      MatrixError("Index out of range");
   int row = m - 1, col = n - 1;
   return store[row * ncols + col];
}

// Range-checked, 1-based access to element (m,n) of a symmetric matrix.
// Only the lower triangle is stored (row by row), so (m,n) and (n,m)
// resolve to the same storage slot.
real& SymmetricMatrix::operator()(int m, int n)
{
   if (!(m >= 1 && n >= 1 && m <= nrows && n <= ncols))
      MatrixError("Index out of range");
   int r = (m >= n) ? m : n;          // larger index picks the stored row
   int c = (m >= n) ? n : m;          // smaller index is the column in it
   return store[tristore(r - 1) + c - 1];
}

// Range-checked, 1-based access to element (m,n), m <= n, of an upper
// triangular matrix stored row by row from the diagonal rightwards.
real& UpperTriangularMatrix::operator()(int m, int n)
{
   if (!(m >= 1 && m <= n && n <= ncols))
      MatrixError("Index out of range");
   int row = m - 1;
   // row*ncols would be the rectangular offset; subtract the tristore(row)
   // below-diagonal slots that earlier rows do not store.
   return store[row * ncols - tristore(row) + n - 1];
}

// Range-checked, 1-based access to element (m,n), n <= m, of a lower
// triangular matrix stored row by row.
real& LowerTriangularMatrix::operator()(int m, int n)
{
   if (!(n >= 1 && n <= m && m <= nrows))
      MatrixError("Index out of range");
   int rowStart = tristore(m - 1);    // elements held by the rows above m
   return store[rowStart + (n - 1)];
}

// Range-checked, 1-based access to (m,n) of a diagonal matrix; only
// positions with m == n exist.
real& DiagonalMatrix::operator()(int m, int n)
{
   if (!(m == n && n >= 1 && n <= nrows && n <= ncols))
      MatrixError("Index out of range");
   return store[n - 1];
}

// Range-checked, 1-based access to the m-th diagonal element.
real& DiagonalMatrix::operator()(int m)
{
   if (!(m >= 1 && m <= nrows))
      MatrixError("Index out of range");
   return store[m - 1];
}

// Range-checked, 1-based access to entry m of a column vector.
real& ColumnVector::operator()(int m)
{
   if (!(m >= 1 && m <= nrows))
      MatrixError("Index out of range");
   return store[m - 1];
}

// Range-checked, 1-based access to entry n of a row vector.
real& RowVector::operator()(int n)
{
   if (!(n >= 1 && n <= ncols))
      MatrixError("Index out of range");
   return store[n - 1];
}

#ifndef __ZTC__

// Read-only, range-checked, 1-based value of element (m,n); storage is
// row by row, ncols per row.
real Matrix::operator()(int m, int n) const
{
   if (!(m >= 1 && m <= nrows && n >= 1 && n <= ncols))
      MatrixError("Index out of range");
   int row = m - 1, col = n - 1;
   return store[row * ncols + col];
}

// Read-only, range-checked, 1-based value of (m,n) in a symmetric matrix;
// the lower triangle is what is stored, so indices are ordered first.
real SymmetricMatrix::operator()(int m, int n) const
{
   if (!(m >= 1 && n >= 1 && m <= nrows && n <= ncols))
      MatrixError("Index out of range");
   int r = (m >= n) ? m : n;
   int c = (m >= n) ? n : m;
   return store[tristore(r - 1) + c - 1];
}

// Read-only, range-checked, 1-based value of (m,n), m <= n, in an upper
// triangular matrix.
real UpperTriangularMatrix::operator()(int m, int n) const
{
   if (!(m >= 1 && m <= n && n <= ncols))
      MatrixError("Index out of range");
   int row = m - 1;
   return store[row * ncols - tristore(row) + n - 1];
}

// Read-only, range-checked, 1-based value of (m,n), n <= m, in a lower
// triangular matrix.
real LowerTriangularMatrix::operator()(int m, int n) const
{
   if (!(n >= 1 && n <= m && m <= nrows))
      MatrixError("Index out of range");
   int rowStart = tristore(m - 1);
   return store[rowStart + (n - 1)];
}

// Read-only, range-checked, 1-based value at (m,n) of a diagonal matrix;
// off-diagonal positions are rejected.
real DiagonalMatrix::operator()(int m, int n) const
{
   if (!(m == n && n >= 1 && n <= nrows && n <= ncols))
      MatrixError("Index out of range");
   return store[n - 1];
}

// Read-only, range-checked, 1-based value of the m-th diagonal element.
real DiagonalMatrix::operator()(int m) const
{
   if (!(m >= 1 && m <= nrows))
      MatrixError("Index out of range");
   return store[m - 1];
}

// Read-only, range-checked, 1-based value of entry m of a column vector.
real ColumnVector::operator()(int m) const
{
   if (!(m >= 1 && m <= nrows))
      MatrixError("Index out of range");
   return store[m - 1];
}

// Read-only, range-checked, 1-based value of entry n of a row vector.
real RowVector::operator()(int n) const
{
   if (!(n >= 1 && n <= ncols))
      MatrixError("Index out of range");
   return store[n - 1];
}

#endif

// Convert a matrix expression that evaluates to a 1x1 matrix into a scalar.
// The evaluated result is released with tDelete() BEFORE a size error is
// reported, so it is not leaked if MatrixError unwinds (throws) or exits.
// The result's store is only dereferenced when the result is 1x1, because
// a non-1x1 (possibly empty) result may have no first element to read.
BaseMatrix::operator real()
{
   REPORT
   GeneralMatrix* gm = Evaluate();
   int is1x1 = (gm->nrows == 1 && gm->ncols == 1);
   real x = 0;
   if (is1x1) x = *(gm->store);
   gm->tDelete();
   if (!is1x1)
      MatrixError("Attempt to convert non 1x1 matrix to scalar");
   return x;
}

// Sum of two matrix expressions: returns an AddedMatrix constructed from
// pointers to this expression and bm.
AddedMatrix BaseMatrix::operator+(BaseMatrix& bm)
{
   REPORT
   return AddedMatrix(this, &bm);
}

// Product of two matrix expressions: returns a MultipliedMatrix constructed
// from pointers to this expression (left) and bm (right).
MultipliedMatrix BaseMatrix::operator*(BaseMatrix& bm)
{
   REPORT
   return MultipliedMatrix(this, &bm);
}

// inverse * bmx: returns a SolvedMatrix built from the wrapped expression
// (member bm) and bmx, rather than a product with an explicit inverse.
SolvedMatrix InvertedMatrix::operator*(BaseMatrix& bmx)
{
   REPORT
   return SolvedMatrix(bm, &bmx);
}

// Difference of two matrix expressions: returns a SubtractedMatrix
// constructed from pointers to this expression and bm.
SubtractedMatrix BaseMatrix::operator-(BaseMatrix& bm)
{
   REPORT
   return SubtractedMatrix(this, &bm);
}

// Add scalar f to this expression: returns a ShiftedMatrix with shift f.
ShiftedMatrix BaseMatrix::operator+(real f)
{
   REPORT
   return ShiftedMatrix(this, f);
}

// Multiply this expression by scalar f: returns a ScaledMatrix with factor f.
ScaledMatrix BaseMatrix::operator*(real f)
{
   REPORT
   return ScaledMatrix(this, f);
}

// Divide this expression by scalar f, expressed as scaling by 1/f.
// No check is made for f == 0.
ScaledMatrix BaseMatrix::operator/(real f)
{
   REPORT
   return ScaledMatrix(this, 1.0 / f);
}

// Subtract scalar f from this expression, expressed as a shift by -f.
ShiftedMatrix BaseMatrix::operator-(real f)
{
   REPORT
   return ShiftedMatrix(this, -f);
}

// Transpose of this expression, wrapped as a TransposedMatrix.
TransposedMatrix BaseMatrix::t()
{
   REPORT
   return TransposedMatrix(this);
}

// Unary minus: this expression wrapped as a NegatedMatrix.
NegatedMatrix BaseMatrix::operator-()
{
   REPORT
   return NegatedMatrix(this);
}

// Inverse of this expression, wrapped as an InvertedMatrix.
InvertedMatrix BaseMatrix::i()
{
   REPORT
   return InvertedMatrix(this);
}

// Wrap this matrix as a ConstMatrix. Rejected (via MatrixError) when the
// matrix is a temporary, i.e. when its tag is anything other than -1.
ConstMatrix GeneralMatrix::c() const
{
   int isTemporary = (tag != -1);
   if (isTemporary) MatrixError(".c() applied to temporary matrix");
   REPORT
   return ConstMatrix(this);
}

// Reinterpret this expression as a row: returns a RowedMatrix wrapper.
RowedMatrix BaseMatrix::CopyToRow()
{
   REPORT
   return RowedMatrix(this);
}

// Reinterpret this expression as a column: returns a ColedMatrix wrapper.
ColedMatrix BaseMatrix::CopyToColumn()
{
   REPORT
   return ColedMatrix(this);
}

// Reinterpret this expression as a diagonal: returns a DiagedMatrix wrapper.
DiagedMatrix BaseMatrix::CopyToDiagonal()
{
   REPORT
   return DiagedMatrix(this);
}

// Reinterpret this expression with nrx rows and ncx columns: returns a
// MatedMatrix wrapper carrying the requested dimensions.
MatedMatrix BaseMatrix::CopyToMatrix(int nrx, int ncx)
{
   REPORT
   return MatedMatrix(this, nrx, ncx);
}

// Set each of the `storage` stored elements to the scalar f.
void GeneralMatrix::operator=(real f)
{
   REPORT
   real* dst = store;
   for (int left = storage; left != 0; left--) *dst++ = f;
}

// Assign from matrix expression X: CheckConversion(X) runs first, then Eq
// with target type MatrixType::Rect.
void Matrix::operator=(BaseMatrix& X)
{
   REPORT
   CheckConversion(X);
   Eq(X, MatrixType::Rect);
}

// Assign from matrix expression X with target type MatrixType::RowV; the
// result is rejected afterwards unless it has exactly one row.
void RowVector::operator=(BaseMatrix& X)
{
   REPORT
   CheckConversion(X);
   Eq(X, MatrixType::RowV);
   if (nrows == 1) return;
   MatrixError("Illegal conversion to row vector");
}

// Assign from matrix expression X with target type MatrixType::ColV; the
// result is rejected afterwards unless it has exactly one column.
void ColumnVector::operator=(BaseMatrix& X)
{
   REPORT
   CheckConversion(X);
   Eq(X, MatrixType::ColV);
   if (ncols == 1) return;
   MatrixError("Illegal conversion to column vector");
}

// Assign from matrix expression X: CheckConversion(X), then Eq with target
// type MatrixType::Sym.
void SymmetricMatrix::operator=(BaseMatrix& X)
{
   REPORT
   CheckConversion(X);
   Eq(X, MatrixType::Sym);
}
 
// Assign from matrix expression X: CheckConversion(X), then Eq with target
// type MatrixType::UT.
void UpperTriangularMatrix::operator=(BaseMatrix& X)
{
   REPORT
   CheckConversion(X);
   Eq(X, MatrixType::UT);
}

// Assign from matrix expression X: CheckConversion(X), then Eq with target
// type MatrixType::LT.
void LowerTriangularMatrix::operator=(BaseMatrix& X)
{
   REPORT
   CheckConversion(X);
   Eq(X, MatrixType::LT);
}

// Assign from matrix expression X: CheckConversion(X), then Eq with target
// type MatrixType::Diag.
void DiagonalMatrix::operator=(BaseMatrix& X)
{
   REPORT
   CheckConversion(X);
   Eq(X, MatrixType::Diag);
}

// Load the matrix from array r: copies `storage` values, in storage order.
// The caller must supply at least that many values; r is not checked.
void GeneralMatrix::operator<<(const real* r)
{
   REPORT
   real* dst = store;
   const real* src = r;
   for (int left = storage; left != 0; left--) *dst++ = *src++;
}


/************************* element access *********************************/

// Range-checked, 0-based access to element (m,n) of a rectangular matrix.
real& Matrix::element(int m, int n)
{
   if (!(m >= 0 && m < nrows && n >= 0 && n < ncols))
      MatrixError("Index out of range");
   int offset = m * ncols + n;
   return store[offset];
}

// Range-checked, 0-based access to (m,n) of a symmetric matrix; the stored
// lower triangle is addressed with the larger index as the row.
real& SymmetricMatrix::element(int m, int n)
{
   if (!(m >= 0 && n >= 0 && m < nrows && n < ncols))
      MatrixError("Index out of range");
   int r = (m >= n) ? m : n;
   int c = (m >= n) ? n : m;
   return store[tristore(r) + c];
}

// Range-checked, 0-based access to (m,n), m <= n, of an upper triangular
// matrix.
real& UpperTriangularMatrix::element(int m, int n)
{
   if (!(m >= 0 && m <= n && n < ncols))
      MatrixError("Index out of range");
   return store[m * ncols - tristore(m) + n];
}

// Range-checked, 0-based access to (m,n), n <= m, of a lower triangular
// matrix.
real& LowerTriangularMatrix::element(int m, int n)
{
   if (!(n >= 0 && n <= m && m < nrows))
      MatrixError("Index out of range");
   int rowStart = tristore(m);        // elements held by rows 0..m-1
   return store[rowStart + n];
}

// Range-checked, 0-based access to (m,n) of a diagonal matrix; only
// positions with m == n are accepted.
real& DiagonalMatrix::element(int m, int n)
{
   if (!(m == n && n >= 0 && n < nrows && n < ncols))
      MatrixError("Index out of range");
   return store[n];
}

// Range-checked, 0-based access to the m-th diagonal element.
real& DiagonalMatrix::element(int m)
{
   if (!(m >= 0 && m < nrows))
      MatrixError("Index out of range");
   return store[m];
}

// Range-checked, 0-based access to entry m of a column vector.
real& ColumnVector::element(int m)
{
   if (!(m >= 0 && m < nrows))
      MatrixError("Index out of range");
   return store[m];
}

// Range-checked, 0-based access to entry n of a row vector.
real& RowVector::element(int n)
{
   if (!(n >= 0 && n < ncols))
      MatrixError("Index out of range");
   return store[n];
}

