/*
 * SIS - Source IP Selector
 *
 * (c) 2002 Peter Palfrader <peter@palfrader.org>
 *
 *   This program is free software; you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation; either version 2 of the License, or
 *   (at your option) any later version.
 *   
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *   
 *   You should have received a copy of the GNU General Public License
 *   along with this program; if not, write to the Free Software
 *   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 *   
 *
 * $Id: sis.c,v 1.1 2002/05/22 22:34:53 weasel Exp $
 *
 * Release: 0.1.1
 */

// FIXME: check bind() return value

#define _GNU_SOURCE

#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <dlfcn.h>

/*
 * Parameter lists of the two interposed calls. Each macro is expanded in the
 * function-pointer declaration, the prototype and the definition of the
 * corresponding wrapper, so all three always agree.
 */
#define CONNECT_SIGNATURE int __fd, const struct sockaddr * __addr, socklen_t __len
#define SENDTO_SIGNATURE int __fd, const void *__msg, size_t __len, int __flags, const struct sockaddr *__to, socklen_t __tolen

/*
 * Addresses of the next connect()/sendto() definitions, assigned in _init()
 * from dlsym(RTLD_NEXT, ...). Being static they start out NULL, and the
 * wrappers check for NULL before calling through them.
 */
static int (*realconnect)(CONNECT_SIGNATURE);
static int (*realsendto)(SENDTO_SIGNATURE);

/* Exported Function Prototypes */
void _init(void);
int connect(CONNECT_SIGNATURE);
int sendto(SENDTO_SIGNATURE);


void _init(void) {
	realconnect = dlsym(RTLD_NEXT, "connect");
	realsendto = dlsym(RTLD_NEXT, "sendto");
}

/*
 * connect() wrapper that selects the local (source) address.
 *
 * If the SIS_SOURCE environment variable is set and the destination is an
 * AF_INET address, the socket is bound to the IPv4 address given in
 * SIS_SOURCE (port left at 0) before the call is forwarded to the real
 * connect(). In every other case the call is forwarded unchanged.
 *
 * Binding is best effort: a malformed SIS_SOURCE or a failing bind() is
 * reported on stderr where useful, but never prevents the real connect()
 * from being attempted.
 *
 * Parameters are those of connect(2).
 *
 * Returns the result of the real connect(), or -1 with errno set to ENOSYS
 * when the real connect() could not be resolved.
 */
int connect(CONNECT_SIGNATURE) {
	const char *source;
	struct sockaddr_in my_address;
	int saved_errno;

	if (realconnect == NULL) {
		fprintf(stderr, "Unresolved symbol: connect\n");
		errno = ENOSYS;
		return(-1);
	}

	source = getenv("SIS_SOURCE");
	if (source == NULL) {
		fprintf(stderr, "SIS_SOURCE environment is not set, Ignoring.\n");
		return(realconnect(__fd, __addr, __len));
	}

#ifdef DEBUG
	fprintf(stderr, "Got connection request\n");
#endif

	/* Nothing to inspect: let the real connect() diagnose the bad pointer. */
	if (__addr == NULL) {
		return(realconnect(__fd, __addr, __len));
	}

	if (__addr->sa_family != AF_INET) {
		fprintf(stderr, "Connection isn't in family AF_INET. Ignoring.\n");
		return(realconnect(__fd, __addr, __len));
	}

	memset(&my_address, 0, sizeof(my_address));
	my_address.sin_family = AF_INET;
	if (inet_pton(AF_INET, source, &my_address.sin_addr) != 1) {
		/* Do not bind to the zeroed address a failed parse leaves behind. */
		fprintf(stderr, "SIS_SOURCE is not a valid IPv4 address: %s\n", source);
		return(realconnect(__fd, __addr, __len));
	}

#ifdef DEBUG
	fprintf(stderr, " Binding.\n");
#endif

	/* Keep the caller's errno intact if the best-effort bind fails. */
	saved_errno = errno;
	if (bind(__fd, (struct sockaddr *) &my_address, sizeof(my_address)) != 0) {
		/*
		 * EINVAL is what bind() reports for a socket that already has a
		 * local address, which is expected when the application bound it
		 * itself; only other failures are worth a warning.
		 */
		if (errno != EINVAL) {
			fprintf(stderr, "Binding to %s failed: %s\n", source, strerror(errno));
		}
		errno = saved_errno;
	}

	return(realconnect(__fd, __addr, __len));
}

/*
 * sendto() wrapper that selects the local (source) address.
 *
 * If the SIS_SOURCE environment variable is set and an AF_INET destination
 * is supplied, the socket is bound to the IPv4 address given in SIS_SOURCE
 * (port left at 0) before the call is forwarded to the real sendto(). In
 * every other case -- including a NULL destination, which is valid for
 * connected sockets -- the call is forwarded unchanged.
 *
 * Binding is best effort: a malformed SIS_SOURCE or a failing bind() is
 * reported on stderr where useful, but never prevents the real sendto()
 * from being attempted.
 *
 * Parameters are those of sendto(2).
 *
 * Returns the result of the real sendto(), or -1 with errno set to ENOSYS
 * when the real sendto() could not be resolved.
 *
 * NOTE(review): POSIX declares sendto() as returning ssize_t; this wrapper
 * keeps the int return type used by the file's prototype and function
 * pointer. The two differ on LP64 targets -- confirm the intended targets.
 */
int sendto(SENDTO_SIGNATURE) {
	const char *source;
	struct sockaddr_in my_address;
	int saved_errno;

	if (realsendto == NULL) {
		fprintf(stderr, "Unresolved symbol: sendto\n");
		errno = ENOSYS;
		return(-1);
	}

	source = getenv("SIS_SOURCE");
	if (source == NULL) {
		fprintf(stderr, "SIS_SOURCE environment is not set, Ignoring.\n");
		return(realsendto(__fd, __msg, __len, __flags, __to, __tolen));
	}

#ifdef DEBUG
	fprintf(stderr, "Got sendto request\n");
#endif

	/*
	 * A NULL destination is legal on a connected socket; there is no
	 * address family to inspect, so forward without touching the socket.
	 */
	if (__to == NULL) {
		return(realsendto(__fd, __msg, __len, __flags, __to, __tolen));
	}

	if (__to->sa_family != AF_INET) {
		fprintf(stderr, "Connection isn't in family AF_INET. Ignoring.\n");
		return(realsendto(__fd, __msg, __len, __flags, __to, __tolen));
	}

	memset(&my_address, 0, sizeof(my_address));
	my_address.sin_family = AF_INET;
	if (inet_pton(AF_INET, source, &my_address.sin_addr) != 1) {
		/* Do not bind to the zeroed address a failed parse leaves behind. */
		fprintf(stderr, "SIS_SOURCE is not a valid IPv4 address: %s\n", source);
		return(realsendto(__fd, __msg, __len, __flags, __to, __tolen));
	}

#ifdef DEBUG
	fprintf(stderr, " Binding.\n");
#endif

	/* Keep the caller's errno intact if the best-effort bind fails. */
	saved_errno = errno;
	if (bind(__fd, (struct sockaddr *) &my_address, sizeof(my_address)) != 0) {
		/*
		 * EINVAL is what bind() reports for a socket that already has a
		 * local address, e.g. on the second sendto() through this wrapper
		 * on the same socket; only other failures are worth a warning.
		 */
		if (errno != EINVAL) {
			fprintf(stderr, "Binding to %s failed: %s\n", source, strerror(errno));
		}
		errno = saved_errno;
	}

	return(realsendto(__fd, __msg, __len, __flags, __to, __tolen));
}