;;; TENSOR.SCM -- Tensor-like support functions for JACAL
;;; Copyright (C) 1993 Jerry D. Hedden
;;; See the file `COPYING' for terms applying to this program.

; Does not support the notion of contra-/covariant indices.  Users must
;  keep track of this information themselves.

; Assumes that all matrices are "proper" (i.e., that all "dimensions" of
;  the matrix are the same length, e.g., 4x4x4) and "compatible" (e.g.,
;  a 3x3 matrix is not compatible with a 4x4x4 matrix).

;;; One-line descriptions for the four tensor builtins installed by the
;;; `defbltn' forms later in this file.  Each call pairs a builtin's
;;; symbol with its description string.
;;; NOTE(review): presumably `definfo' feeds JACAL's interactive help;
;;; it is defined outside this file -- confirm there.

(definfo 'indexshift
  "Shifts an index within a tensor")

(definfo 'indexswap
  "Swaps two indices within a tensor")

(definfo 'contract
  "Tensor contraction")

(definfo 'tmult
  "Tensor multiplication")


;;; Returns the rank of M: the number of nested bunch levels found by
;;; repeatedly taking the first element until something that is not a
;;; bunch is reached.  A non-bunch M has rank 0.  Only the first element
;;; at each level is inspected (the file assumes "proper" tensors).
(define (tnsr:rank m)
  (do ((depth 0 (+ depth 1))
       (node m (car node)))
      ((not (bunch? node)) depth)))

;;; Moves index A of tensor M rightwards to position B; the indices that
;;; were in positions A+1..B each move one place to the left.
;;;   m -- tensor as nested lists
;;;   n -- index position of M's outermost level (the builtins pass 1)
;;;   a -- position of the index to move
;;;   b -- destination position; callers in this file always pass a < b
(define (idxright m n a b)
  ;; Exchange the two outermost nesting levels of X.
  (define (flip x) (apply map list x))
  ;; X's outermost level holds the travelling index; POS is the position
  ;; it occupies after one more exchange.
  (define (push-in x pos)
    (if (= pos b)
        (flip x)
        (map (lambda (sub) (push-in sub (+ pos 1)))
             (flip x))))
  ;; Leave the levels above position A untouched.
  (define (descend x level)
    (if (< level a)
        (map (lambda (sub) (descend sub (+ level 1))) x)
        (push-in x (+ a 1))))
  (descend m n))

;;; Moves index B of tensor M leftwards to position A; the indices that
;;; were in positions A..B-1 each move one place to the right.
;;;   m -- tensor as nested lists
;;;   n -- index position of M's outermost level (the builtins pass 1)
;;;   a -- destination position
;;;   b -- position of the index to move; callers always pass a < b
(define (idxleft m n a b)
  ;; Exchange the two outermost nesting levels of X.
  (define (flip x) (apply map list x))
  ;; Work from the inside out: first bring the travelling index up to
  ;; the level just below X's outermost one, then exchange those two.
  (define (pull-out x pos)
    (if (= pos b)
        (flip x)
        (flip (map (lambda (sub) (pull-out sub (+ pos 1))) x))))
  ;; Leave the levels above position A untouched.
  (define (descend x level)
    (if (< level a)
        (map (lambda (sub) (descend sub (+ level 1))) x)
        (pull-out x (+ a 1))))
  (descend m n))

;;; Exchanges indices A and B of tensor M, leaving every other index in
;;; its original position.
;;;   m -- tensor as nested lists
;;;   n -- index position of M's outermost level (the builtins pass 1)
;;;   a, b -- positions to exchange; callers always pass a < b
(define (idxswap m n a b)
  ;; Exchange the two outermost nesting levels of X.
  (define (flip x) (apply map list x))
  ;; Adjacent positions need a single exchange.  Otherwise exchange the
  ;; outer pair, swap the rest of the way down inside every element,
  ;; then exchange the outer pair again.
  (define (exchange x pos)
    (if (= pos b)
        (flip x)
        (flip (map (lambda (sub) (exchange sub (+ pos 1)))
                   (flip x)))))
  ;; Leave the levels above position A untouched.
  (define (descend x level)
    (if (< level a)
        (map (lambda (sub) (descend sub (+ level 1))) x)
        (exchange x (+ a 1))))
  (descend m n))

;;; Builtin `indexshift': (indexshift m), (indexshift m a) or
;;; (indexshift m a b).  Moves index A of tensor M to position B.
;;;   rank 0            -- M is returned unchanged
;;;   rank 1            -- each element is wrapped in its own list
;;;   rank 2 or no args -- the first two indices are exchanged
;;;   otherwise         -- B defaults to A+1; both are clamped to
;;;                        1..rank, and equal positions are separated
;;;                        into an adjacent pair
(defbltn 'indexshift
  (lambda (m . args)
    (let ((rank (tnsr:rank m))
          ;; Force position I into the range 1..TOP.
          (clamp (lambda (i top)
                   (cond ((< i 1) 1)
                         ((> i top) top)
                         (else i)))))
      (cond ((= rank 0) m)
            ((= rank 1) (map list m))
            ((or (= rank 2) (null? args)) (apply map list m))
            (else
             (let* ((raw-from (car args))
                    (raw-to (if (null? (cdr args))
                                (+ raw-from 1)
                                (cadr args)))
                    (from (clamp raw-from rank))
                    (to (clamp raw-to rank)))
               (cond ((< from to) (idxright m 1 from to))
                     ((> from to) (idxleft m 1 to from))
                     ;; Equal positions: use the adjacent pair, taken
                     ;; below the top position when already at it.
                     ((= from rank) (idxright m 1 (- rank 1) rank))
                     (else (idxright m 1 from (+ from 1))))))))))

;;; Builtin `indexswap': (indexswap m), (indexswap m a) or
;;; (indexswap m a b).  Exchanges indices A and B of tensor M.
;;;   rank 0            -- M is returned unchanged
;;;   rank 1            -- each element is wrapped in its own list
;;;   rank 2 or no args -- the first two indices are exchanged
;;;   otherwise         -- B defaults to A+1; the pair is put in
;;;                        ascending order, each position is clamped to
;;;                        1..rank, and equal positions are separated
;;;                        into an adjacent pair
;;; Both positions are clamped independently, so a request such as
;;; indices 0 and 10 on a rank-3 tensor becomes 1 and 3 instead of
;;; sending `idxswap' below the innermost level.
(defbltn 'indexswap
  (lambda (m . args)
    (let ((rank (tnsr:rank m)))
      (cond ((= rank 0) m)
	    ((= rank 1) (map list m))
	    ((or (= rank 2) (null? args)) (apply map list m))
	    (else
	       (let* ((clamp (lambda (i)
			       (cond ((< i 1) 1)
				     ((> i rank) rank)
				     (else i))))
		      (x (car args))
		      (y (if (null? (cdr args)) (+ x 1) (cadr args)))
		      ;; a <= b holds after ordering; clamping keeps it so.
		      (a (clamp (if (< y x) y x)))
		      (b (clamp (if (< y x) x y))))
		 (cond ((< a b) (idxswap m 1 a b))
		       ;; Equal positions: use the adjacent pair, taken
		       ;; below the top position when already at it.
		       ((= a rank) (idxswap m 1 (- rank 1) rank))
		       (else (idxswap m 1 a (+ a 1))))))))))

;;; Exchanges the two outermost nesting levels of M (a matrix
;;; transpose when M is a list of rows).
(define tnsr:xpose
  (lambda (m)
    (apply map (cons list m))))

;;; Sums the diagonal entries of M over its two outermost levels,
;;; combining entries with `app*' and `$1+$2'.  Starts from the first
;;; entry of the first row, then repeatedly drops the first row and
;;; first column and adds the new leading entry.
(define (tnsr:contract m)
  (do ((minor (map cdr (cdr m)) (map cdr (cdr minor)))
       (total (car (car m)) (app* $1+$2 total (car (car minor)))))
      ((null? minor) total)))

;;; Contracts indices A and B of tensor M (sums the entries where the
;;; two indices are equal).
;;;   m -- tensor as nested lists
;;;   n -- index position of M's outermost level (the builtin passes 1)
;;;   a, b -- positions to contract; the builtin always passes a < b
;;;   d -- position of the innermost index (the builtin passes the rank)
(define (contract m n a b d)
  ;; X's two outermost levels are the pair being contracted.  Carry the
  ;; pair inward one level at a time until it is innermost, then sum
  ;; its diagonal.
  (define (sink-pair x pos)
    (if (= pos d)
        (tnsr:contract x)
        (map (lambda (sub) (sink-pair sub (+ pos 1)))
             (apply map list (map tnsr:xpose x)))))
  ;; X's outermost level is index A.  Carry it inward until it sits
  ;; directly above index B.
  (define (join-pair x pos)
    (if (< pos b)
        (map (lambda (sub) (join-pair sub (+ pos 1)))
             (apply map list x))
        (sink-pair x b)))
  ;; Leave the levels above position A untouched.
  (define (descend x level)
    (if (< level a)
        (map (lambda (sub) (descend sub (+ level 1))) x)
        (join-pair x (+ a 1))))
  (descend m n))

;;; Builtin `contract': (contract m), (contract m a) or
;;; (contract m a b).  Contracts indices A and B of tensor M.
;;;   rank 0    -- M is returned unchanged
;;;   rank 1    -- the elements of M are summed
;;;   rank 2    -- the diagonal of M is summed (index arguments ignored)
;;;   otherwise -- A defaults to 1 and B to A+1; the pair is put in
;;;                ascending order, each position is clamped to
;;;                1..rank, and equal positions are separated into an
;;;                adjacent pair
;;; Fixes relative to the earlier version:
;;;   * rank 1 folded the index arguments instead of M's elements;
;;;   * rank >= 3 with no index arguments took the car of an empty list;
;;;   * a position below 1 paired with one above the rank left the
;;;     larger position unclamped.
(defbltn 'contract
  (lambda (m . args)
    (let ((rank (tnsr:rank m)))
      (cond ((= rank 0) m)
	    ((= rank 1) (reduce (lambda (x y) (app* $1+$2 x y)) m))
	    ((= rank 2) (tnsr:contract m))
	    (else
	       (let* ((clamp (lambda (i)
			       (cond ((< i 1) 1)
				     ((> i rank) rank)
				     (else i))))
		      ;; With no arguments, use the first two indices, as
		      ;; the no-argument forms of indexshift/indexswap do.
		      (x (if (null? args) 1 (car args)))
		      (y (if (or (null? args) (null? (cdr args)))
			     (+ x 1)
			     (cadr args)))
		      ;; a <= b holds after ordering; clamping keeps it so.
		      (a (clamp (if (< y x) y x)))
		      (b (clamp (if (< y x) x y))))
		 (cond ((< a b) (contract m 1 a b rank))
		       ;; Equal positions: use the adjacent pair, taken
		       ;; below the top position when already at it.
		       ((= a rank) (contract m 1 (- rank 1) rank rank))
		       (else (contract m 1 a (+ a 1) rank)))))))))


;;; Multiplies tensors M1 and M2, summing over index A1 of M1 paired
;;; with index A2 of M2.  The result nests M1's remaining indices
;;; outermost, followed by M2's remaining indices.
;;;   m1, m2 -- tensors as nested lists
;;;   n1, n2 -- index position of each outermost level (builtin passes 1)
;;;   a1, a2 -- positions of the indices to sum over
;;;   d1, d2 -- position of each innermost index (builtin passes the rank)
(define (tmult m1 n1 a1 d1 m2 n2 a2 d2)
  ;; Sum of the element-wise products of the lists U and V.
  (define (sum-of-products u v)
    (reduce (lambda (x y) (app* $1+$2 x y))
            (map (lambda (x y) (app* $1*$2 x y)) u v)))
  ;; M2 side: Y's outermost level is index A2; carry it inward until it
  ;; is innermost, then combine the resulting list with U.
  (define (sink2 u y pos)
    (if (< pos d2)
        (map (lambda (sub) (sink2 u sub (+ pos 1)))
             (apply map list y))
        (sum-of-products u y)))
  ;; M2 side: leave the levels above position A2 untouched.
  (define (descend2 u y level)
    (if (< level a2)
        (map (lambda (sub) (descend2 u sub (+ level 1))) y)
        (sink2 u y a2)))
  ;; M1 side: X's outermost level is index A1; carry it inward until it
  ;; is innermost, then walk all of M2 for that list.
  (define (sink1 x pos)
    (if (< pos d1)
        (map (lambda (sub) (sink1 sub (+ pos 1)))
             (apply map list x))
        (descend2 x m2 n2)))
  ;; M1 side: leave the levels above position A1 untouched.
  (define (descend1 x level)
    (if (< level a1)
        (map (lambda (sub) (descend1 sub (+ level 1))) x)
        (sink1 x a1)))
  (descend1 m1 n1))

;;; Forms the products of every innermost element of M1 with M2.
;;;   m1, m2 -- tensors as nested lists
;;;   a1, a2 -- position of each outermost level (builtin passes 1)
;;;   d1, d2 -- position of each innermost index (builtin passes the rank)
;;; Result nesting, outermost first: all but the last index of M1, all
;;; but the last index of M2, the last index of M1, and finally
;;; whatever `app*' returns for an element of M1 times an innermost
;;; list of M2.
;;; NOTE(review): when M2 has rank > 1 this is not the plain
;;; "M1's indices, then M2's indices" order -- confirm it is intended.
(define (outerproduct m1 a1 d1 m2 a2 d2)
  ;; Multiply each element of the list U by the whole list V.
  (define (scale-each u v)
    (map (lambda (x) (app* $1*$2 x v)) u))
  ;; Walk M2 down to its innermost lists, keeping U fixed.
  (define (walk2 u y pos)
    (if (< pos d2)
        (map (lambda (sub) (walk2 u sub (+ pos 1))) y)
        (scale-each u y)))
  ;; Walk M1 down to its innermost lists, then walk all of M2 for each.
  (define (walk1 x pos)
    (if (< pos d1)
        (map (lambda (sub) (walk1 sub (+ pos 1))) x)
        (walk2 x m2 a2)))
  (walk1 m1 a1))

;;; Builtin `tmult': (tmult m1 m2), (tmult m1 m2 a) or
;;; (tmult m1 m2 a b).
;;;   either operand of rank 0 -- product of M1 and M2 via `app*'
;;;   no index arguments       -- `outerproduct' of M1 and M2
;;;   otherwise                -- sums over index A of M1 paired with
;;;                               index B of M2; B defaults to A, and
;;;                               each is clamped to 1..rank of its own
;;;                               tensor
(defbltn 'tmult
  (lambda (m1 m2 . args)
    (let ((r1 (tnsr:rank m1))
          (r2 (tnsr:rank m2))
          ;; Force position I into the range 1..TOP.
          (clamp (lambda (i top)
                   (cond ((< i 1) 1)
                         ((> i top) top)
                         (else i)))))
      (cond ((or (= r1 0) (= r2 0)) (app* $1*$2 m1 m2))
            ((null? args) (outerproduct m1 1 r1 m2 1 r2))
            (else
             (let* ((raw-a (car args))
                    (raw-b (if (null? (cdr args)) raw-a (cadr args))))
               (tmult m1 1 (clamp raw-a r1) r1
                      m2 1 (clamp raw-b r2) r2)))))))
