/*
 *
 *			   IPSEC for Linux
 *			 Preliminary Release
 * 
 *	 Copyright (C) 1996, 1997, John Ioannidis <ji@hol.gr>
 *       Copyright (C) 1996-1997 Robert Muchsel <muchsel@acm.org>
 *
 * Interface and firewall code ported from ENskip-0.67 by
 *  Petr Novak <pn@i.cz>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 *
 */

#define __NO_VERSION__
#include <linux/module.h>
#include <linux/config.h>
#include <linux/types.h>
#include <netinet/ip.h>
#include <netinet/udp.h>
#include <netinet/ip_icmp.h>
#include <linux/ip.h>
#include <linux/in.h>
#include <linux/skbuff.h>
#include <linux/netdevice.h>
#include <linux/firewall.h>
#include <linux/string.h>
#include <linux/ipsec.h>
#include <net/route.h>
#include <net/ip.h>
#include <net/icmp.h>
#include <net/netlink.h>
#include <unistd.h>
#include "radij.h"
#include "ipsec_encap.h"
#include "ipsec_radij.h"
#include "ipsec_netlink.h"
#include "ipsec_xform.h"
#include "ipsec_esp.h"
#include "ipsec_ah.h"
#include "ipsec_fw.h"

#ifdef DEBUG_IPSEC_TUNNEL
/* Bitmask of enabled debug categories; it is tested against the DB_TN_*
   flags further down.  -1 sets every bit, so all categories start enabled. */
int debug_tunnel = -1;
#endif

/* Forward declarations of helpers defined later in this file. */
static inline void interface_ship_out(struct sk_buff *);
static int dev_ipsec(struct device *, char *);
static int ipsec_output_packet(struct sk_buff **nskb);

/* Number of bytes reserved for IPSEC header growth.  Set by
   interface_init(), subtracted from the MTU of every attached device and
   added back to the MTU handed to forward_packet(). */
static int maxheadergrowth = 0;

#ifdef DEBUG_IPSEC_TUNNEL
/*
 * Debug helper: print the main fields of an IP header with printk.
 *
 * ip: header to dump.  A NULL pointer is reported and otherwise ignored.
 *
 * Prints the header length in bytes, version, protocol, source and
 * destination address in dotted-quad form, and the total packet length.
 * The payload hex dump is compiled out; change "#if 0" to enable it.
 */
static void print_ip(struct iphdr *ip)
{
	unsigned char *ipaddr;

	if (ip == NULL)
	{
		printk("ipsec_print_ip: NULL header\n");
		return;
	}

	printk("header len = %d, ", ip->ihl*4);
	printk("ip version: %d, ", ip->version);
	printk("ip protocol: %d\n", ip->protocol);

	/* Addresses are printed byte by byte, in the order they are stored. */
	ipaddr=(unsigned char *)&ip->saddr;
	printk("ipsec_print_ip: src addr: %u.%u.%u.%u, ", 
			*ipaddr, *(ipaddr+1), *(ipaddr+2), *(ipaddr+3));
	ipaddr=(unsigned char *)&ip->daddr;
	printk("dst addr: %u.%u.%u.%u, ", 
			*ipaddr, *(ipaddr+1), *(ipaddr+2), *(ipaddr+3));
	printk("pkt len: %d\n", ntohs(ip->tot_len));
#if 0
	{
		int i;

		printk("payload:");
		for (i=0; i < ntohs(ip->tot_len)-sizeof(struct iphdr); i++)
		{
			printk(" %02x", *((unsigned char *)(ip+1)+i));
		}
		printk("\n");
	}
#endif
}
#endif /* DEBUG_IPSEC_TUNNEL */

/* This is the output interceptor.  The feedback function is a kludge
   (we use ip_forward for lack of something better).  There is also a new
   return code, FW_QUEUE, which works like FW_BLOCK but does not cause a
   "permission denied" message on local packets (kernel function
   ip_build_xmit).
   Note: The dummy arguments are all 0; arg is for debugging only. */

int output_packet(struct firewall_ops *dummy1, int pf, struct device *dev,
		  void *dummy3, void *arg, struct sk_buff **pskb)
{
  struct iphdr *ipp = (*pskb)->ip_hdr;
  struct sk_buff *newm, *qskb;
  int result;
  int retval = FW_BLOCK;

  /* ignore loopback */
  if (!dev_ipsec(dev, "output")) {
#ifdef	DEBUG_IPSEC_TUNNEL
	printk("ipsec_fw:output_packet: returning - device %s not IPSEC enabled\n", dev->name);
#endif
    	return FW_SKIP;
  }

  /* Recursion happens if the datagram has been fragmented:
     ip_queue_xmit -> ipsec -> ip_fragment -> ip_queue_xmit -> ipsec
  */
  
#ifdef	DEBUG_IPSEC_TUNNEL
  printk("ipsec_fw:output_packet >>> called with arg=%d\n", (int)(arg));
#endif

  if ((*pskb)->proto_priv[15] & SND_SEC)
    return FW_ACCEPT;

  /* this happens, too */
  if (ipp->protocol == IPPROTO_AH || ipp->protocol == IPPROTO_ESP)
    return FW_ACCEPT;

  if ((*pskb)->sk != NULL) {
    if ((*pskb)->sk->authentication == IPSEC_LEVEL_NONE &&
        (*pskb)->sk->encryption     == IPSEC_LEVEL_NONE) {
#ifdef	DEBUG_IPSEC_TUNNEL
	printk("ipsec_fw:output_packet: returning - socket has disabled IPSEC\n");
#endif
	return FW_SKIP;
    }
  }

  /* Clone the input skb; locks its data. Never free it while it is queued! */
  /* Only the sk_buff part is cloned, the data is left in the old skb */
  qskb = newm = skb_clone(*pskb, GFP_ATOMIC);

  result = ipsec_output_packet(&newm);

  if (result == IPSEC_PROCESSED)
  {
    /* nothing happened - if not transformed, we just say "OK" */ 

    kfree_skb(newm, FREE_WRITE);
    if (newm != qskb)
      kfree_skb(qskb, FREE_WRITE);

    /* check user level policy */
    if (!((*pskb)->sk &&
         ((*pskb)->sk->authentication >= IPSEC_LEVEL_USE ||
          (*pskb)->sk->encryption     >= IPSEC_LEVEL_USE))) {

#ifdef	DEBUG_IPSEC_TUNNEL
	printk("ipsec_fw:output_packet: ipsec void & socket OK\n");
#endif
       /* Set marker in skb + accept packet */
      (*pskb)->proto_priv[15] |= SND_SEC;
      retval = FW_SKIP;
    }
  }
  else if (result > IPSEC_PROCESSED)
  {
    /* tunneled/encrypted/authenticated */

    /* check user level policy */
    if ((*pskb)->sk &&
        ((!(result & IPSEC_P_AUTH) && 
          (*pskb)->sk->authentication >= IPSEC_LEVEL_USE) ||
         (!(result & IPSEC_P_ENCRYPT) &&
          (*pskb)->sk->encryption >= IPSEC_LEVEL_USE))) {

#ifdef	DEBUG_IPSEC_TUNNEL
		printk("ipsec_fw:output_packet: droping because of socket auth level\n");
#endif

      kfree_skb(newm, FREE_WRITE);
      if (newm != qskb)
        kfree_skb(qskb, FREE_WRITE);
    }
    else
    {
      newm->h.iph     = newm->ip_hdr;
      newm->ip_summed = 0;	/* ??? XXX */

      if (newm != qskb)
        kfree_skb(qskb, FREE_WRITE);

      /* Set marker in skb */
      newm->proto_priv[15] |= SND_SEC;

      interface_ship_out(newm);

      retval = FW_QUEUE;
    }
  }
  else if (result == IPSEC_QUEUED)
  {
    /* queued, will be fed back to us */

    if (newm != qskb)
      kfree_skb(newm, FREE_WRITE);
    
    qskb->dev = dev;

    retval = FW_QUEUE;
  }
  else
  {
    /* bad packet/policy/etc. */
    kfree_skb(newm, FREE_WRITE);
    if (newm != qskb)
      kfree_skb(qskb, FREE_WRITE);
  }

  return retval;
}


/* The forward function is used to get the true device MTU only.
   This is a hack... Real packets are forwarded via a pair of
   input_packet/output_packet. */

/*
 * Forward hook, used only to report the true device MTU.
 *
 * dev: device being forwarded through.
 * arg: optional pointer to an unsigned short MTU value; when dev is in the
 *      IPSEC device list, maxheadergrowth is added to it.
 *
 * The remaining arguments are unused.  Always returns FW_ACCEPT.
 */
int forward_packet(struct firewall_ops *dummy1, int dummy2, struct device *dev,
		  void *dummy3, void *arg, struct sk_buff **dummy4)
{
  unsigned short *mtu = (unsigned short *) arg;

  /* The device lookup runs first, whether or not an MTU pointer was given. */
  if (!dev_ipsec(dev, "forward"))
    return FW_ACCEPT;

  if (mtu)
    *mtu += maxheadergrowth;

  return FW_ACCEPT;
}


/* and the null op */

/*
 * Null firewall hook: ignores every argument and returns FW_SKIP.
 */
int packet_nop(struct firewall_ops *this, int pf, struct device *dev,
	       void *phdr, void *arg, struct sk_buff **pskb)
{
  /* None of the arguments is looked at. */
  (void) this;
  (void) pf;
  (void) dev;
  (void) phdr;
  (void) arg;
  (void) pskb;

  return FW_SKIP;
}

/* and one for keeping away standard firewalls -- on forward, a standard
   forward firewall call would be made. This call would also be made
   with local packets AND it would get the encrypted packet. */

/*
 * Firewall hook that lets packets carrying the SND_SEC marker pass.
 *
 * pskb: packet to inspect; only proto_priv[15] is read.
 *
 * Returns FW_ACCEPT when the SND_SEC bit is set, FW_SKIP otherwise.
 * All other arguments are unused.
 */
int packet_accept(struct firewall_ops *this, int pf, struct device *dev,
	       void *phdr, void *arg, struct sk_buff **pskb)
{
  struct sk_buff *skb = *pskb;

  return (skb->proto_priv[15] & SND_SEC) ? FW_ACCEPT : FW_SKIP;
}

/*
 *	This function assumes it is being called from output_packet()
 *	and that skb is filled properly by that function.
 */

static int
ipsec_output_packet(struct sk_buff **nskb)
{
	struct sk_buff *skb = *nskb;
	struct iphdr  *iph;		/* Our new IP header */
	__u32          target;		/* The other host's IP address */
	int	iphlen;
	int	pyldsz;
	int     max_headroom;		/* The extra header space needed */
	int	max_tailroom;		/*  The extra stuffing needed */

	struct udphdr *udpp;
	struct tcphdr *tcpp;
	

	struct sockaddr_encap matcher, *gw;
	struct eroute *er;
	struct tdb *tdbp, *tdbq;
	int oerror;
	
	/*
	 *	Return if there is nothing to do.  (Does this ever happen?)
	 */
	if (skb == NULL)
        {
#ifdef DEBUG_IPSEC_TUNNEL
		if (debug_tunnel & DB_TN_XMIT)
			printk ( KERN_INFO "ipsec_fw:ipsec_output_packet: Nothing to do!\n" );
#endif DEBUG_IPSEC_TUNNEL
		return IPSEC_DISCARD;
	}

	iph = skb->ip_hdr;
	iphlen = iph->ihl << 2;
	pyldsz = ntohs(iph->tot_len) - iphlen;

#ifdef DEBUG_IPSEC_TUNNEL
	if (debug_tunnel & DB_TN_XMIT)
	{
		printk("ipsec_fw:ipsec_output_packet:skb-> len=%ld data=%p iph=%p ip_hdr=%p head=%p tail=%p end=%p saddr=%x daddr=%x\n", skb->len, skb->data, skb->h.iph, skb->ip_hdr, skb->head, skb->tail, skb->end, skb->saddr, skb->daddr); 
		printk("ipsec_fw:ipsec_output_packet: packet contents:");
		print_ip(iph);
	}
#endif DEBUG_IPSEC_TUNNEL

	/*
	 * First things first -- look us up in the erouting tables.
	 */

	matcher.sen_len = SENT_IP4_LEN;
	matcher.sen_family = AF_ENCAP;
	matcher.sen_type = SENT_IP4;
	
	matcher.sen_ip_src.s_addr = iph->saddr;
	matcher.sen_ip_dst.s_addr = iph->daddr;
	matcher.sen_proto = iph->protocol;

	switch (iph->protocol)
	{
	  case IPPROTO_UDP:
		udpp = (struct udphdr *)((char *)skb->ip_hdr + iphlen);
		matcher.sen_sport = ntohs(udpp->source);
		matcher.sen_dport = ntohs(udpp->dest);
		break;

	  case IPPROTO_TCP:
		tcpp = (struct tcphdr *)((char *)skb->ip_hdr + iphlen);
		matcher.sen_sport = ntohs(tcpp->source);
		matcher.sen_dport = ntohs(tcpp->dest);
		break;

	  default:
		matcher.sen_sport = 0;
		matcher.sen_dport = 0;
	}
	
	er = ipsec_findroute(&matcher);
	if (er == NULL)
	{
#ifdef DEBUG_IPSEC_TUNNEL
		printk("ipsec_fw:ipsec_output_packet: Packet [%lx->%lx] with no e-route\n", 
		       ntohl(iph->saddr), ntohl(iph->daddr));
#endif DEBUG_IPSEC_TUNNEL
		return IPSEC_PROCESSED;
	}

	gw = (struct sockaddr_encap *)&(er->er_dst);
	target = gw->sen_ipsp_dst.s_addr;

	if ((gw != NULL) && (target == 0) &&
		(gw->sen_ipsp_sproto == 0) && (gw->sen_ipsp_spi == 0))
	{
		/* The eroute tells us not to do any IPSEC processing */
		return IPSEC_PROCESSED;
	}

	if (gw == NULL || gw->sen_type != SENT_IPSP)
	{
		printk("ipsec_fw:ipsec_output_packet: no gw or gw data not IPSP\n");
		return IPSEC_PROCESSED;
	}

	tdbp = gettdb(gw->sen_ipsp_spi, gw->sen_ipsp_dst, gw->sen_ipsp_sproto);
	if (tdbp == NULL)
	{
		printk( KERN_INFO "ipsec_fw:ipsec_output_packet: no tdb for spi=%x dst=%08x\n", 
		       (u_int)ntohl(gw->sen_ipsp_spi),
			(u_int)ntohl(target));
		return IPSEC_PROCESSED;
	}

	skb_pull(skb, (char *)skb->ip_hdr - (char *)skb->data);
#ifdef	DEBUG_IPSEC_TUNNEL
	printk("ipsec_fw:ipsec_output_packet: skb after pull: len=%ld data=%p iph=%p ip_hdr=%p head=%p tail=%p end=%p\n", skb->len, skb->data, skb->h.iph, skb->ip_hdr, skb->head, skb->tail, skb->end); 
#endif
	max_headroom = max_tailroom = 0;
	
	tdbq = tdbp;			/* save */
	while (tdbp && tdbp->tdb_xform)
	{
		/* Check for tunneling */
		if (tdbp->tdb_flags & TDBF_TUNNELING)
		{
#ifdef DEBUG_IPSEC_TUNNEL
			if (debug_tunnel & DB_TN_CROUT)
				printk("ipsec_fw:ipsec_output_packet: calling room for IPE4\n");
#endif DEBUG_IPSEC_TUNNEL
			oerror = ipe4_room(tdbp, iphlen, pyldsz + max_headroom + max_tailroom, &max_headroom, &max_tailroom);
		}

#ifdef DEBUG_IPSEC_TUNNEL
		if (debug_tunnel & DB_TN_CROUT)
			printk("ipsec_fw:ipsec_output_packet: calling room for <%s>\n", 
			       tdbp->tdb_xform->xf_name);
#endif DEBUG_IPSEC_TUNNEL
		oerror = (*(tdbp->tdb_xform->xf_room))
		  (tdbp, iphlen, pyldsz + max_headroom + max_tailroom, &max_headroom, &max_tailroom);
		
		tdbp = tdbp->tdb_onext;
	}
	tdbp = tdbq;			/* restore */

	/*
	 * Okay, now see if we can stuff it in the buffer as-is.
	 * (the test assumes that the packet goes out via the same dev
	 * as before. This may be not true while tunneling. XXX
	 */
	max_headroom += ((skb->dev->hard_header_len+15)&~15);

#ifdef DEBUG_IPSEC_TUNNEL
	if (debug_tunnel & DB_TN_CROUT)
	{
		printk("ipsec_fw:ipsec_output_packet: Room left at head, tail: %d,%d\n", skb_headroom(skb), skb_tailroom(skb));
		printk("ipsec_fw:ipsec_output_packet: Required room: %d,%d\n", 
		       max_headroom, max_tailroom);
	}
#endif DEBUG_IPSEC_TUNNEL

	if ((skb_headroom(skb) >= max_headroom) && 
	    (skb_tailroom(skb) >= max_tailroom) && skb->free)
	{
#ifdef DEBUG_IPSEC_TUNNEL
		if (debug_tunnel & DB_TN_CROUT)
			printk("ipsec_fw:ipsec_output_packet: data fits in existing skb\n");
#else /*  DEBUG_IPSEC_TUNNEL */
		;
#endif DEBUG_IPSEC_TUNNEL

	}
	else
	{
		struct sk_buff *new_skb;

#ifdef DEBUG_IPSEC_TUNNEL
		if (debug_tunnel & DB_TN_CROUT)
			printk("ipsec_fw:ipsec_output_packet: allocating new skb\n");
#endif DEBUG_IPSEC_TUNNEL
		if ( !(new_skb = alloc_skb(skb->len+max_headroom+max_tailroom, GFP_ATOMIC)) ) 
		{
			printk( KERN_INFO "ipsec_fw:ipsec_output_packet: Out of memory, dropped packet\n");
			return IPSEC_DISCARD;
		}
		new_skb->free = 1;

		/*
		 * Reserve space for our header and the lower device header
		 */
		skb_reserve(new_skb, max_headroom);

		/*
		 * Copy the old packet to the new buffer.
		 * Note that new_skb->h.iph will be our (tunnel driver's) header
		 * and new_skb->ip_hdr is the IP header of the old packet.
		 */
		new_skb->ip_hdr = (struct iphdr *) skb_put(new_skb, skb->len);
		new_skb->dev = skb->dev;
		memcpy(new_skb->ip_hdr, skb->data, skb->len);
		memset(new_skb->proto_priv, 0, sizeof(skb->proto_priv));

		/* Tack on our header */
		new_skb->h.iph = (struct iphdr *)(new_skb->data);
		
		/* Free the old packet, we no longer need it */
		kfree_skb(skb, FREE_WRITE);
		skb = new_skb;
		*nskb = skb;
	}

	while (tdbp && tdbp->tdb_xform)
	{
		if (tdbp->tdb_flags & TDBF_TUNNELING) {
#ifdef DEBUG_IPSEC_TUNNEL
			if (debug_tunnel & DB_TN_OXFS)
				printk("ipsec_fw:ipsec_output_packet: calling output for IPE4...");
#endif DEBUG_IPSEC_TUNNEL
			oerror = ipe4_output(skb, tdbp);
#ifdef DEBUG_IPSEC_TUNNEL
			if (debug_tunnel & DB_TN_OXFS)
				printk("ipsec_fw:ipsec_output_packet: returns %d\n", oerror);
#endif DEBUG_IPSEC_TUNNEL
		}
		
#ifdef DEBUG_IPSEC_TUNNEL
		if (debug_tunnel & DB_TN_OXFS)
			printk("ipsec_fw:ipsec_output_packet: calling output for <%s>...", 
			       tdbp->tdb_xform->xf_name);
#endif DEBUG_IPSEC_TUNNEL
		oerror = (*(tdbp->tdb_xform->xf_output))(skb, tdbp);
#ifdef DEBUG_IPSEC_TUNNEL
		if (debug_tunnel & DB_TN_OXFS)
			printk("ipsec_fw:ipsec_output_packet: returns %d\n", oerror);
#endif DEBUG_IPSEC_TUNNEL
		tdbp = tdbp->tdb_onext;
	}

#ifdef DEBUG_IPSEC_TUNNEL
	if (debug_tunnel & DB_TN_XMIT)
	{
		printk("ipsec_fw:ipsec_output_packet: packet contents after xforms:");
		print_ip(skb->h.iph);
	}
#endif DEBUG_IPSEC_TUNNEL

	return IPSEC_PROCESSED | IPSEC_P_AUTH | IPSEC_P_ENCRYPT;
}

/* Put skb back to kernel output processing. Caution! Causes
   the output interceptor to be called recursively. Since it
   is done via the timer, there should be no way to crash the
   stack!? */

/*
 * Hand skb to ip_forward() on its own device, without TTL decrement
 * (IPFWD_NOTTLDEC), addressed to the destination found in skb->h.iph.
 * The skb is freed here when ip_forward() returns non-zero.
 */
static inline void interface_ship_out(struct sk_buff *skb)
{
  int failed;

  IS_SKB(skb);

  failed = ip_forward(skb, skb->dev, IPFWD_NOTTLDEC, skb->h.iph->daddr);
  if (failed)
    kfree_skb(skb, FREE_WRITE);
}


/* feed back queued packet */
/*
 * Feed a previously queued outgoing packet back into output processing.
 *
 * skb: the queued packet.  It is run through output_packet() again and is
 *      either shipped out or freed, depending on the verdict.
 *
 * Always returns 0.
 *
 * NOTE(review): every verdict that is >= FW_ACCEPT ships skb itself;
 * confirm that this is the intended handling for all such codes.
 */
int interface_feed_out(struct sk_buff *skb)
{
  int verdict;

  IS_SKB(skb);

  /* fix the packet for ip_forward (because ip_build_xmit might not have) */
  skb->h.iph = skb->ip_hdr;

  /* re-check the packet against the output interceptor */
  verdict = output_packet(NULL, PF_IPSEC, skb->dev, NULL, NULL, &skb);

  if (verdict >= FW_ACCEPT)
    interface_ship_out(skb);		/* ...and ship it */
  else
    kfree_skb(skb, FREE_WRITE);

  return 0;
}


/* feed back queued packet */
/*
 * Feed a previously queued incoming packet back to the kernel.
 *
 * skb: packet to re-inject; it is tagged as ETH_P_IP with ip_summed
 *      cleared and passed to netif_rx().
 *
 * Always returns 0.
 */
int interface_feed_in(struct sk_buff *skb)
{
  skb->ip_summed = 0;
  skb->protocol  = htons(ETH_P_IP);

  /* netif_rx() takes the skb over and frees it for us */
  netif_rx(skb);

  return 0;
}

/* One entry in the singly linked list of IPSEC-enabled interfaces. */
struct devlist {
  char *name;			/* kmalloc'ed copy of the device name */
  struct devlist *next;		/* next entry, NULL at the end of the list */
};

/* Head of the list of IPSEC-enabled interfaces; NULL while it is empty. */
static struct devlist *ipsec_devs = NULL;

/* Check interface */
/*
 * Check whether an interface is in the list of IPSEC-enabled devices.
 *
 * dev: device to look up (matched by name); NULL never matches.
 * s:   name of the calling context, used only in the warning message.
 *
 * Returns 1 when a list entry with the same name exists, 0 otherwise.
 * A warning is printed when the list is empty.
 */
static int dev_ipsec(struct device *dev, char *s)
{
  struct devlist *tmp = ipsec_devs;

  /* XXX */
  if (!tmp) {
    /* it is the device list that is empty here, not dev */
    printk("Warning: dev_ipsec %s: no IPSEC devices registered\n", s);
    return 0;
  }
  /* XXX */

  /* callers may pass skb->dev, which is not guaranteed to be set */
  if (dev == NULL)
    return 0;

  while (tmp && strcmp(tmp->name, dev->name))
    tmp = tmp->next;

  return (tmp != NULL);
}

/* Add interface to list */
/*
 * Append an interface to the end of the list of IPSEC-enabled devices.
 *
 * dev: device whose name is copied into a new list entry.
 *
 * Both allocations use GFP_ATOMIC and may fail.  On failure nothing is
 * added, anything already allocated is released and a message is printed;
 * the function returns void, so the caller is not informed.
 */
static void dev_addlist(struct device *dev)
{
  struct devlist *node;
  struct devlist **link;

  node = kmalloc(sizeof(*node), GFP_ATOMIC);
  if (node == NULL) {
    printk("ipsec_fw:dev_addlist: %s: out of memory, interface not added\n",
           dev->name);
    return;
  }

  node->name = kmalloc(strlen(dev->name) + 1, GFP_ATOMIC);
  if (node->name == NULL) {
    printk("ipsec_fw:dev_addlist: %s: out of memory, interface not added\n",
           dev->name);
    kfree(node);
    return;
  }
  strcpy(node->name, dev->name);
  node->next = NULL;

  /* Walk to the terminating NULL link (the head itself when the list is
     empty) and hook the new entry in there. */
  for (link = &ipsec_devs; *link != NULL; link = &(*link)->next)
    ;
  *link = node;
}

/* Remove interface from list */
/*
 * Remove an interface from the list of IPSEC-enabled devices.
 *
 * dev: device to remove (matched by name).
 *
 * Only the first entry with a matching name is unlinked; its name string
 * and the entry itself are freed.  Nothing happens when no entry matches
 * or the list is empty.
 */
static void dev_rmlist(struct device *dev)
{
  struct devlist **link;
  struct devlist *victim;

  /* link always points at the pointer that leads to the current entry,
     so the head and inner entries are unlinked the same way */
  for (link = &ipsec_devs; *link != NULL; link = &(*link)->next) {
    if (strcmp((*link)->name, dev->name) != 0)
      continue;

    victim = *link;
    *link = victim->next;
    kfree(victim->name);
    kfree(victim);
    return;
  }
}

/* Attach/detach to/from interfaces */
/*
 * Attach IPSEC processing to every AF_INET interface that has the given
 * address and is not yet in the IPSEC device list.
 *
 * dummy:  unused.
 * ipaddr: pointer to the 4 address bytes, compared against dev->pa_addr
 *         as stored; need not be aligned.  NULL attaches nothing.
 *
 * For each matching interface the MTU is reduced by maxheadergrowth and
 * the interface is added to the list, unless its MTU is below
 * 68 + maxheadergrowth, in which case a message is printed instead.
 *
 * Returns 0 if at least one interface was attached, -1 otherwise.
 */
int interface_attach(void *dummy, u_char *ipaddr)
{
  struct device *dev;
  __u32 addr;
  int result = -1;

  if (ipaddr == NULL)
    return result;

  /* Copy the address out once; a byte pointer must not be dereferenced
     as a 32-bit word because it may be misaligned. */
  memcpy(&addr, ipaddr, sizeof(addr));

  for (dev = dev_base; dev != NULL; dev = dev->next) { 
    if ((dev->family == AF_INET) && (dev->pa_addr == addr)
      && !dev_ipsec(dev, "attach")) {
      if (dev->mtu < 68 + maxheadergrowth)
        printk("ipsec_fw:interface_attach: %s: interface mtu of %d is too small\n",
                dev->name, dev->mtu);
      else {
        dev->mtu -= maxheadergrowth;
	dev_addlist(dev);
        result = 0;
      }
    }
  }

  return result;
}

/*
 * Detach IPSEC processing from every AF_INET interface that has the given
 * address and is currently in the IPSEC device list.
 *
 * dummy:  unused.
 * ipaddr: pointer to the 4 address bytes, compared against dev->pa_addr
 *         as stored; need not be aligned.  NULL detaches nothing.
 *
 * Each matching interface is removed from the list and gets
 * maxheadergrowth added back to its MTU.
 *
 * Returns 0 if at least one interface was detached, -1 otherwise.
 */
int interface_detach(void *dummy, u_char *ipaddr)
{
  struct device *dev;
  __u32 addr;
  int result = -1;

  if (ipaddr == NULL)
    return result;

  /* Copy the address out once; a byte pointer must not be dereferenced
     as a 32-bit word because it may be misaligned. */
  memcpy(&addr, ipaddr, sizeof(addr));

  for (dev = dev_base; dev != NULL; dev = dev->next) { 
    if ((dev->family == AF_INET) && (dev->pa_addr == addr) 
       && dev_ipsec(dev, "detach")) {
      dev_rmlist(dev);
      dev->mtu += maxheadergrowth;
      result = 0;
    }
  }

  return result;
}


/* Decrease the MTUs of all IPSEC interfaces */

/*
 * Set maxheadergrowth and enable IPSEC on all suitable interfaces.
 *
 * Every AF_INET device is examined.  A device whose MTU is below
 * 68 + maxheadergrowth only gets a message printed.  Any other device
 * that is not a loopback device has its MTU reduced by maxheadergrowth
 * and is added to the IPSEC device list.
 *
 * Always returns 0.
 */
int interface_init(void)
{
  struct device *dev;

  maxheadergrowth = 96;	/* !!! XXX */

  for (dev = dev_base; dev != NULL; dev = dev->next) {
    if (dev->family != AF_INET)
      continue;

    /* the size check comes before the loopback test, so a small loopback
       device is reported as well */
    if (dev->mtu < 68 + maxheadergrowth) {
      printk("ipsec_fw:interface_init: %s: interface mtu of %d is too small\n",
              dev->name, dev->mtu);
      continue;
    }

    if (dev->flags & IFF_LOOPBACK)
      continue;

    dev->mtu -= maxheadergrowth;
    dev_addlist(dev);
  }

  return 0;
}


/* Restore the MTUs */

/*
 * Undo interface_init()/interface_attach(): every AF_INET device found in
 * the IPSEC device list is removed from it and gets maxheadergrowth added
 * back to its MTU.
 *
 * Always returns 0.
 */
int interface_exit(void)
{
  struct device *dev;

  for (dev = dev_base; dev != NULL; dev = dev->next) {
    if (dev->family != AF_INET)
      continue;

    /* the list is consulted for AF_INET devices only */
    if (!dev_ipsec(dev, "exit"))
      continue;

    dev_rmlist(dev);
    dev->mtu += maxheadergrowth;
  }

  return 0;
}
