/*
 *
 *			   IPSEC for Linux
 *		         Preliminary Release
 * 
 *	 Copyright (C) 1996, 1997, John Ioannidis <ji@hol.gr>
 * 
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 *
 */

/*
 * $Id: ipsec_xform.c,v 0.5 1997/06/03 04:24:48 ji Rel $
 *
 * $Log: ipsec_xform.c,v $
 * Revision 0.5  1997/06/03 04:24:48  ji
 * Added ESP-3DES-MD5-96
 *
 * Revision 0.4  1997/01/15 01:28:15  ji
 * Added new transforms.
 *
 * Revision 0.3  1996/11/20 14:39:04  ji
 * Minor cleanups.
 * Rationalized debugging code.
 *
 * Revision 0.2  1996/11/02 00:18:33  ji
 * First limited release.
 *
 *
 */

#include <linux/config.h>
#include <asm/segment.h>
#include <asm/system.h>
#include <linux/types.h>
#include <linux/kernel.h>
#include <linux/sched.h>
#include <linux/string.h>
#include <linux/errno.h>
#include <linux/config.h>
#include <linux/random.h>

#include <linux/socket.h>
#include <linux/sockios.h>
#include <linux/in.h>
#include <linux/inet.h>
#include <linux/netdevice.h>
#include <linux/etherdevice.h>
#include <linux/icmp.h>
#include <linux/udp.h>
#include <net/ip.h>
#include <net/protocol.h>
#include <net/route.h>
#include <net/tcp.h>
#include <net/udp.h>
#include <net/sock.h>
#include <net/icmp.h>

#include <net/checksum.h>

#include <linux/fs.h>
#include <linux/mm.h>
#include <linux/miscdevice.h>

#include <linux/skbuff.h>
#include <linux/proc_fs.h>
#include <linux/stat.h>

#include <net/netlink.h>
#include <unistd.h>
#include "radij.h"
#include "ipsec_encap.h"
#include "ipsec_radij.h"
#include "ipsec_netlink.h"
#include "ipsec_xform.h"

#include "ipsec_ipe4.h"

#ifdef CONFIG_IPSEC_AH
#include "ipsec_ah.h"
#endif

#ifdef CONFIG_IPSEC_ESP
#include "ipsec_esp.h"
#endif

#ifdef DEBUG_IPSEC_XFORM
/*
 * Debug bitmask for this file; only exists in DEBUG_IPSEC_XFORM builds.
 * tdb_init() tests it against DB_XF_INIT before printing its
 * "no alg" diagnostic.  Starts at 0, so nothing is printed by default.
 */
int debug_xform = 0;
#endif

/*
 * Common routines for IPSEC transformations.
 */

/*
 * Transform switch table.  tdb_init() walks it from xformsw up to (but
 * not including) xformswNXFORMSW, matches em_alg against the first member
 * of each entry (xf_type) and calls that entry's xf_init handler.
 *
 * Each entry lists, in order: the transform type, its XFT_* flags, a
 * description string, and then the attach, init, zeroize, print, room,
 * input and output handlers.
 *
 * NOTE(review): the AH and ESP entries are present unconditionally, but
 * ipsec_ah.h / ipsec_esp.h are only included when CONFIG_IPSEC_AH /
 * CONFIG_IPSEC_ESP are defined.  Confirm the handlers are declared (and
 * linked) when those options are off, or guard the entries.
 */
struct xformsw xformsw[] = {
    {
	XF_IP4, 0, "IPv4 Simple Encapsulation",
	ipe4_attach, ipe4_init, ipe4_zeroize,
	ipe4_print, ipe4_room,
	ipe4_input, ipe4_output,
    },
    {
	XF_OLD_AH, XFT_AUTH, "Keyed Authentication, RFC 1828/1852",
	ah_old_attach, ah_old_init, ah_old_zeroize,
	ah_old_print, ah_old_room,
	ah_old_input, ah_old_output,
    },
    {
	XF_OLD_ESP, XFT_CONF, "Simple Encryption, RFC 1829/1851",
	esp_old_attach, esp_old_init, esp_old_zeroize,
	esp_old_print, esp_old_room,
	esp_old_input, esp_old_output,
    },
    {
	XF_NEW_AH, XFT_AUTH, "HMAC Authentication",
	ah_new_attach, ah_new_init, ah_new_zeroize,
	ah_new_print, ah_new_room,
	ah_new_input, ah_new_output,
    },
    {
	XF_NEW_ESP, XFT_CONF|XFT_AUTH,
	"Encryption + Authentication + Replay Protection",
	esp_new_attach, esp_new_init, esp_new_zeroize,
	esp_new_print, esp_new_room,
	esp_new_input, esp_new_output,
    },
};

/* One-past-the-end marker for xformsw[]; used as the loop bound. */
struct xformsw *xformswNXFORMSW =
    xformsw + (sizeof(xformsw) / sizeof(xformsw[0]));

/* Block of IPSEC_ZEROES_SIZE zero bytes (file-scope, hence zero-filled). */
unsigned char ipseczeroes[IPSEC_ZEROES_SIZE];

/*
 * Reserve an SPI.  A placeholder TDB flagged TDBF_INVALID is inserted in
 * the hash table so the SPI cannot be handed out twice; the SA itself is
 * not valid yet.
 *
 * tspi   - proposed SPI, or 0 to have one picked at random.
 * src    - address stored in the placeholder's tdb_dst.
 * proto  - IPsec protocol stored in the placeholder's tdb_sproto.
 * errval - receives EEXIST or ENOBUFS on failure; untouched on success.
 *
 * Returns the reserved SPI, or 0 on error (0 is never a valid SPI).
 * SPIs of 255 or less are reserved: a candidate in that range (including
 * a proposed one) is replaced by a random value.  If the candidate is
 * already in use and tspi was non-zero, the call fails with EEXIST;
 * otherwise another random value is tried.
 */

__u32
reserve_spi(__u32 tspi, struct in_addr src, __u8 proto, int *errval)
{
    struct tdb *placeholder;
    __u32 candidate = tspi;

    for (;;)
    {
	/* Draw random values until we leave the reserved range */
	while (candidate <= 255)
	  get_random_bytes((void *) &candidate, sizeof(candidate));

	/* Free slot found: go create the placeholder */
	if (gettdb(candidate, src, proto) == (struct tdb *) NULL)
	  break;

	/* In use; a caller-proposed SPI means we must report it */
	if (tspi != 0)
	{
	    *errval = EEXIST;
	    return 0;
	}

	/* Force a fresh random draw on the next pass */
	candidate = 0;
    }

    placeholder = (struct tdb *)kmalloc(sizeof(*placeholder), GFP_ATOMIC);
    if (placeholder == NULL)
    {
	*errval = ENOBUFS;
	return 0;
    }

    memset((caddr_t) placeholder, 0, sizeof(*placeholder));

    placeholder->tdb_flags |= TDBF_INVALID;
    placeholder->tdb_sproto = proto;
    placeholder->tdb_dst = src;
    placeholder->tdb_spi = candidate;

    puttdb(placeholder);

    return candidate;
}

/*
 * An IPSP SAID is really the concatenation of the SPI found in the
 * packet, the destination address of the packet and the IPsec protocol.
 * When we receive an IPSP packet, we need to look up its tunnel descriptor
 * block, based on the SPI in the packet and the destination address (which
 * is really one of our addresses if we received the packet!).
 *
 * Returns the matching TDB from the tdbh[] hash table, or NULL if there
 * is none.  The bucket is chosen with the same (spi + address + proto)
 * modulo TDB_HASHMOD expression that puttdb() and tdb_delete() use.
 */

struct tdb *
gettdb(u_long spi, struct in_addr dst, __u8 proto)
{
	struct tdb *cand;
	int bucket;

	bucket = (spi+dst.s_addr + proto) % TDB_HASHMOD;

	cand = tdbh[bucket];
	while (cand != NULL)
	{
		if ((cand->tdb_spi == spi) &&
		    (cand->tdb_dst.s_addr == dst.s_addr) &&
		    (cand->tdb_sproto == proto))
			return cand;

		cand = cand->tdb_hnext;
	}

	return NULL;
}

/*
 * Insert a TDB at the head of its hash chain in tdbh[].  The bucket is
 * derived from the TDB's own SPI, destination address and protocol.
 * No check for an existing identical entry is made here.
 */
void
puttdb(struct tdb *tdbp)
{
	struct tdb **bucket;

	bucket = &tdbh[(tdbp->tdb_spi + tdbp->tdb_dst.s_addr + tdbp->tdb_sproto) % TDB_HASHMOD];

	tdbp->tdb_hnext = *bucket;
	*bucket = tdbp;
}

/*
 * Unlink a TDB from the tdbh[] hash table, let its transform wipe any
 * private state (xf_zeroize) and free it.
 *
 * tdbp     - the TDB to delete.
 * delchain - if non-zero, also delete every TDB reachable through the
 *            tdb_onext links of the deleted entries.
 *
 * Returns 0 on success, or EINVAL if tdbp is NULL or a TDB to be deleted
 * is not found in its hash chain (in which case it is left untouched and
 * the rest of the chain is not processed).
 */
int
tdb_delete(struct tdb *tdbp, int delchain)
{
    struct tdb **linkp;
    struct tdb *next;
    int hashval;

    if (tdbp == NULL)
      return EINVAL;

    for (;;)
    {
	hashval = ((tdbp->tdb_spi + tdbp->tdb_dst.s_addr + tdbp->tdb_sproto)
		   % TDB_HASHMOD);

	/*
	 * Find the pointer that refers to tdbp (bucket head or a
	 * predecessor's tdb_hnext) and stop as soon as it is found, so
	 * head and non-head entries are handled alike.
	 */
	for (linkp = &tdbh[hashval]; *linkp != NULL;
	     linkp = &(*linkp)->tdb_hnext)
	  if (*linkp == tdbp)
	    break;

	if (*linkp == NULL)
	  return EINVAL;		/* Should never happen */

	*linkp = tdbp->tdb_hnext;

	/* If there was something before us in the chain, make it point nowhere */
	if (tdbp->tdb_inext)
	  tdbp->tdb_inext->tdb_onext = NULL;

	/*
	 * Likewise for whatever follows us: its back pointer must not
	 * keep referring to memory we are about to free.
	 */
	next = tdbp->tdb_onext;
	if (next && (next->tdb_inext == tdbp))
	  next->tdb_inext = NULL;

	if (tdbp->tdb_xform)
	  (*(tdbp->tdb_xform->xf_zeroize))(tdbp);

	kfree(tdbp);

	if (!delchain || (next == NULL))
	  return 0;

	tdbp = next;
    }
}

/*
 * Allocate a zero-filled struct flow with GFP_ATOMIC.
 * Returns the new flow, or NULL if the allocation failed.
 */
struct flow *
get_flow(void)
{
    struct flow *fresh;

    fresh = (struct flow *)kmalloc(sizeof(struct flow), GFP_ATOMIC);
    if (fresh != (struct flow *) NULL)
      memset(fresh, 0, sizeof(struct flow));

    return fresh;
}

#if 0
/*
 * Allocate a zero-filled struct expiration with GFP_ATOMIC.
 * Returns the new entry, or NULL if the allocation failed.
 */
struct expiration *
get_expiration(void)
{
    struct expiration *fresh;

    fresh = (struct expiration *)kmalloc(sizeof(struct expiration), GFP_ATOMIC);
    if (fresh != (struct expiration *) NULL)
      memset(fresh, 0, sizeof(struct expiration));

    return fresh;
}

/*
 * Remove and free every entry on explist that refers to the SA identified
 * by (dst, spi, sproto).  Entries for other SAs are left in place.
 * No timer is rescheduled here, even if the list head is removed.
 */
void
cleanup_expirations(struct in_addr dst, __u32 spi, __u8 sproto)
{
    struct expiration *exp, *nexp;

    for (exp = explist; exp; exp = nexp)
    {
	/* Remember the successor now; exp may be freed below */
	nexp = exp->exp_next;

	if ((exp->exp_dst.s_addr != dst.s_addr) ||
	    (exp->exp_spi != spi) || (exp->exp_sproto != sproto))
	  continue;

	/* Link previous to next */
	if (exp->exp_prev == (struct expiration *) NULL)
	  explist = nexp;
	else
	  exp->exp_prev->exp_next = nexp;

	/* Link next (if it exists) to previous */
	if (nexp != (struct expiration *) NULL)
	  nexp->exp_prev = exp->exp_prev;

	kfree(exp);
    }
}

/*
 * Timer callback: process every entry at the front of explist whose
 * exp_timeout is not later than the current time (time.tv_sec).  Each
 * such entry is unlinked and freed; if its TDB still exists, soft and
 * hard expiration notifications are sent and, on hard expiration, the
 * TDB is deleted.  When entries remain, the callback is re-armed for the
 * new list head.
 *
 * arg is unused.
 *
 * NOTE(review): timeout(), hz and time look like BSD kernel facilities
 * that have not been ported; this code is compiled out by "#if 0".
 */
void 
handle_expirations(void *arg)
{
    struct expiration *exp;
    struct tdb *tdb;
    
    if (explist == (struct expiration *) NULL)
      return;
    
    while (1)
    {
	exp = explist;

	/* List drained: return without re-arming the timer */
	if (exp == (struct expiration *) NULL)
	  return;
	else
	  /* Head is not due yet: stop and re-arm the timer below */
	  if (exp->exp_timeout > time.tv_sec)
	    break;
	
	/* Advance pointer */
	explist = explist->exp_next;
	if (explist)
	  explist->exp_prev = NULL;
	
	tdb = gettdb(exp->exp_spi, exp->exp_dst, exp->exp_sproto);
	if (tdb == (struct tdb *) NULL)
	{
	    kfree(exp);
	    continue;			/* TDB is gone, ignore this */
	}
	
	/*
	 * Soft expirations: notify only, and clear the flag that fired.
	 * The first-use check is only reached when TDBF_SOFT_TIMER is set
	 * and that timer has not yet expired.
	 */
	if (tdb->tdb_flags & TDBF_SOFT_TIMER)
	{
	  if (tdb->tdb_soft_timeout <= time.tv_sec)
	  {
	      encap_sendnotify(NOTIFY_SOFT_EXPIRE, tdb);
	      tdb->tdb_flags &= ~TDBF_SOFT_TIMER;
	  }
	  else
	    if (tdb->tdb_flags & TDBF_SOFT_FIRSTUSE)
	      if (tdb->tdb_first_use + tdb->tdb_soft_first_use <=
		  time.tv_sec)
	      {
		  encap_sendnotify(NOTIFY_SOFT_EXPIRE, tdb);
		  tdb->tdb_flags &= ~TDBF_SOFT_FIRSTUSE;
	      }
	}

	/*
	 * Hard expirations: notify, then delete this TDB only (delchain
	 * is 0).  tdb must not be used after tdb_delete(); nothing below
	 * touches it.  As above, the first-use check is nested under the
	 * TDBF_TIMER test.
	 */
	if (tdb->tdb_flags & TDBF_TIMER)
	{
	  if (tdb->tdb_exp_timeout <= time.tv_sec)
	  {
	      encap_sendnotify(NOTIFY_HARD_EXPIRE, tdb);
	      tdb_delete(tdb, 0);
	  }
	  else
	    if (tdb->tdb_flags & TDBF_FIRSTUSE)
	      if (tdb->tdb_first_use + tdb->tdb_exp_first_use <=
		  time.tv_sec)
	      {
		  encap_sendnotify(NOTIFY_HARD_EXPIRE, tdb);
		  tdb_delete(tdb, 0);
	      }
	}

	kfree(exp);
    }

    /* Re-arm for the (not yet due) entry now at the head of the list */
    if (explist)
      timeout(handle_expirations, (void *) NULL, 
	      hz * (explist->exp_timeout - time.tv_sec));
}

/*
 * Insert an expiration entry into explist, which is kept sorted by
 * ascending exp_timeout.  An entry is placed before the first existing
 * entry with a strictly later timeout, so entries with equal timeouts
 * stay in insertion order.  If the new entry becomes the list head, the
 * handle_expirations timer is (re)armed for it.
 *
 * A NULL exp is ignored (and logged when ENCDEBUG and encdebug are set).
 *
 * NOTE(review): exp_next/exp_prev of the new entry are not reset here;
 * this appears to rely on the entry coming zero-filled from
 * get_expiration() -- confirm.
 */
void
put_expiration(struct expiration *exp)
{
    struct expiration *expt;
    int reschedflag = 0;	/* 0: timer unchanged, 1: arm, 2: re-arm */
    
    if (exp == (struct expiration *) NULL)
    {
#ifdef ENCDEBUG
	if (encdebug)
	  log(LOG_WARNING, "put_expiration(): NULL argument\n");
#endif /* ENCDEBUG */	
	return;
    }
    
    /* Empty list: the new entry is the head and no timer is pending */
    if (explist == (struct expiration *) NULL)
    {
	explist = exp;
	reschedflag = 1;
    }
    else
      /* Earlier than the current head: becomes the new head */
      if (explist->exp_timeout > exp->exp_timeout)
      {
	  exp->exp_next = explist;
	  explist->exp_prev = exp;
	  explist = exp;
	  reschedflag = 2;
      }
      else
      {
	  /* Walk to the last entry that is not later than exp */
	  for (expt = explist; expt->exp_next; expt = expt->exp_next)
	    if (expt->exp_next->exp_timeout > exp->exp_timeout)
	    {
		expt->exp_next->exp_prev = exp;
		exp->exp_next = expt->exp_next;
		expt->exp_next = exp;
		exp->exp_prev = expt;
		break;
	    }

	  /*
	   * Reached the tail without inserting: append.  After an
	   * insertion above expt->exp_next is exp, so this is skipped.
	   */
	  if (expt->exp_next == (struct expiration *) NULL)
	  {
	      expt->exp_next = exp;
	      exp->exp_prev = expt;
	  }
      }

    switch (reschedflag)
    {
	case 1:
	    /* First entry: start the timer */
	    timeout(handle_expirations, (void *) NULL, 
		    hz * (explist->exp_timeout - time.tv_sec));
	    break;
	    
	case 2:
	    /* New, earlier head: cancel the pending timer and restart it */
	    untimeout(handle_expirations, (void *) NULL);
	    timeout(handle_expirations, (void *) NULL,
		    hz * (explist->exp_timeout - time.tv_sec));
	    break;
	    
	default:
	    break;
    }
}
#endif

/*
 * Search the flow list of one TDB for an exact match on source and
 * destination address, both masks, protocol and both ports.
 * Returns the first matching flow, or NULL if none matches.
 */
struct flow *
find_flow(struct in_addr src, struct in_addr srcmask, struct in_addr dst,
	  struct in_addr dstmask, __u8 proto, __u16 sport,
	  __u16 dport, struct tdb *tdb)
{
    struct flow *cur = tdb->tdb_flow;

    while (cur != (struct flow *) NULL)
    {
	/* Any differing field rules this entry out */
	if ((src.s_addr != cur->flow_src.s_addr) ||
	    (dst.s_addr != cur->flow_dst.s_addr) ||
	    (srcmask.s_addr != cur->flow_srcmask.s_addr) ||
	    (dstmask.s_addr != cur->flow_dstmask.s_addr) ||
	    (proto != cur->flow_proto) ||
	    (sport != cur->flow_sport) || (dport != cur->flow_dport))
	{
	    cur = cur->flow_next;
	    continue;
	}

	return cur;
    }

    return (struct flow *) NULL;
}

/*
 * Search every TDB in the tdbh[] hash table for a flow matching the given
 * selectors (see find_flow()).  Buckets are visited in index order and
 * each chain from its head; the first match found is returned.
 * Returns NULL if no TDB carries a matching flow.
 */
struct flow *
find_global_flow(struct in_addr src, struct in_addr srcmask,
		 struct in_addr dst, struct in_addr dstmask,
		 __u8 proto, __u16 sport, __u16 dport)
{
    struct flow *hit;
    struct tdb *sa;
    int bucket;

    for (bucket = 0; bucket < TDB_HASHMOD; bucket++)
    {
	sa = tdbh[bucket];
	while (sa != NULL)
	{
	    hit = find_flow(src, srcmask, dst, dstmask, proto, sport,
			    dport, sa);
	    if (hit != (struct flow *) NULL)
	      return hit;

	    sa = sa->tdb_hnext;
	}
    }

    return (struct flow *) NULL;
}

/*
 * Attach a flow to a TDB: the flow is pushed on the front of the TDB's
 * doubly-linked flow list and its flow_sa back pointer is set to the TDB.
 */
void
put_flow(struct flow *flow, struct tdb *tdb)
{
    struct flow *old_head = tdb->tdb_flow;

    flow->flow_sa = tdb;
    flow->flow_prev = (struct flow *) NULL;
    flow->flow_next = old_head;

    /* The previous head, if any, now has a predecessor */
    if (old_head != (struct flow *) NULL)
      old_head->flow_prev = flow;

    tdb->tdb_flow = flow;
}

/*
 * Unlink a flow from a TDB's doubly-linked flow list and free it.
 * A flow that is not the list head is unlinked through its flow_prev
 * pointer, so it must actually be on a list in that case.
 */
void
delete_flow(struct flow *flow, struct tdb *tdb)
{
    struct flow *after = flow->flow_next;

    if (tdb->tdb_flow != flow)
    {
	/* Interior or tail entry: bridge predecessor and successor */
	struct flow *before = flow->flow_prev;

	before->flow_next = after;
	if (after)
	  after->flow_prev = before;
    }
    else
    {
	/* Head entry: the successor (if any) becomes the new head */
	tdb->tdb_flow = after;
	if (after)
	  after->flow_prev = (struct flow *) NULL;
    }

    kfree(flow);
}

/*
 * Initialise a TDB for the transform requested in an encap message.
 * The xformsw[] table is searched for the first entry whose xf_type
 * equals em->em_alg, and that entry's xf_init handler is called.
 *
 * Returns whatever xf_init returns, or EINVAL if no entry matches (with
 * a diagnostic when DEBUG_IPSEC_XFORM is built in and DB_XF_INIT is set
 * in debug_xform).
 */
int
tdb_init(struct tdb *tdbp, struct encap_msghdr *em)
{
	struct xformsw *entry = xformsw;
	int wanted = em->em_alg;

	while (entry < xformswNXFORMSW)
	{
		if (entry->xf_type == wanted)
			return (*(entry->xf_init))(tdbp, entry, em);

		entry++;
	}

#ifdef DEBUG_IPSEC_XFORM
	if (debug_xform & DB_XF_INIT)
	  printk("tdbinit: no alg %d for spi %x, addr %x, proto %d\n", wanted, (u_int)tdbp->tdb_spi, (u_int)ntohl(tdbp->tdb_dst.s_addr), tdbp->tdb_sproto);
#endif
	return EINVAL;
}
