/*
 * tcp forwarder
 * coded by Fabien Menemenlis 08/01/2002
 *
 * This may or may not be useful: it simply accepts TCP connections, opens a
 * new connection to a fixed destination IP/port for each client, and relays
 * data from the client to the server and vice versa.
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <ctype.h>

#include <netdb.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/param.h>
#include <sys/select.h>
#include <sys/wait.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>

#define DEBUG

#define NBLISTEN 10       /* backlog passed to listen() */

#define TBUFFER 1024      /* size of the scratch buffer used to build the log/pid file names */
#define TBUF_INET 2048    /* size of the buffer used to relay socket data */


/* destination every accepted client is forwarded to */
struct in_addr ip_dest;
unsigned short port_dest;   /* host byte order; converted with htons() when connecting */

/* address the listening socket is bound to; filled in by createservsock() */
struct sockaddr_in server;

/* log file, opened by main() in append mode */
FILE *logfile;

/*
 * errno must come from <errno.h>: on modern C libraries it is a per-thread
 * macro, and a hand-written `extern int errno;` either fails to link or
 * refers to the wrong object.
 */
#include <errno.h>


/*
 * Create a TCP socket bound to src:noport and put it in the listening state.
 *
 * src:    listening address in dotted-decimal IPv4 notation
 * noport: listening port, 0..65535
 *
 * Side effect: fills the global `server` with the bound address.
 * Returns the listening descriptor, or -1 on failure. When socket(), bind()
 * or listen() fails, errno describes that call. When `src` or `noport` is
 * invalid, no descriptor is created and errno is left unspecified.
 * The descriptor is closed on every failure path, so nothing leaks.
 */
int createservsock(char *src, int noport)
{
  int servsock;
  int reuse = 1;
  int saved_errno;

  /* reject bad input before allocating a descriptor */
  if (src == NULL || noport < 0 || noport > 65535)
    return(-1);

  memset(&server, 0, sizeof(server));
  server.sin_family = AF_INET;
  server.sin_port = htons((unsigned short)noport);
  /* inet_pton writes the address only on success, so nothing is left uninitialized */
  if (inet_pton(AF_INET, src, &server.sin_addr) != 1)
    return(-1);

  if ((servsock = socket(AF_INET, SOCK_STREAM, 0)) == -1)
    return(-1);

  /* best effort: allow an immediate restart while old connections are in TIME_WAIT */
  (void)setsockopt(servsock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse));

  if (bind(servsock, (struct sockaddr *)&server, sizeof(server)) == -1
      || listen(servsock, NBLISTEN) == -1) {
    saved_errno = errno;   /* close() must not clobber the caller-visible error */
    close(servsock);
    errno = saved_errno;
    return(-1);
  }

  return(servsock);
}


/*
 * Write all `len` bytes of `buf` to `fd`, retrying after short writes.
 * Returns 0 on success, -1 if write() fails or reports no progress.
 */
static int write_all(int fd, const char *buf, size_t len)
{
  while (len > 0) {
    ssize_t n = write(fd, buf, len);

    if (n <= 0)
      return(-1);
    buf += n;
    len -= (size_t)n;
  }
  return(0);
}

/*
 * Read one chunk from `from` and forward all of it to `to`.
 * Returns 1 if data was forwarded, 0 if `from` reached end-of-file,
 * -1 on a read or write error.
 */
static int relay_chunk(int from, int to, char *buf, size_t bufsize)
{
  ssize_t n = read(from, buf, bufsize);

  if (n < 0)
    return(-1);
  if (n == 0)
    return(0);
  return(write_all(to, buf, (size_t)n) == 0 ? 1 : -1);
}

/*
 * Relay data in both directions between sock1 (the client) and sock2
 * (the server) until either side reaches end-of-file or an I/O error occurs.
 *
 * Every chunk read is written completely to the other side, including
 * across short writes. The sockets are not closed here; that is left to
 * the caller. Returns immediately if either descriptor is negative or too
 * large for an fd_set, because FD_SET on such a value is undefined.
 */
void redirect_traffic(int sock1, int sock2)
{
  char buf[TBUF_INET];
  fd_set iofds, r_iofds;
  int maxfd;

  if (sock1 < 0 || sock2 < 0 || sock1 >= FD_SETSIZE || sock2 >= FD_SETSIZE)
    return;

  maxfd = (sock1 > sock2 ? sock1 : sock2) + 1;

  FD_ZERO(&iofds);
  FD_SET(sock1, &iofds);
  FD_SET(sock2, &iofds);

  for (;;) {
    /* select() overwrites its set, so work on a fresh copy each round */
    r_iofds = iofds;

    /* no timeout: block until at least one side is readable */
    if (select(maxfd, &r_iofds, NULL, NULL, NULL) <= 0)
      break;

    /* data has arrived from the client */
    if (FD_ISSET(sock1, &r_iofds)
        && relay_chunk(sock1, sock2, buf, sizeof(buf)) <= 0)
      break;

    /* coming from the server to the client */
    if (FD_ISSET(sock2, &r_iofds)
        && relay_chunk(sock2, sock1, buf, sizeof(buf)) <= 0)
      break;
  }
}


int main(int argc, char *argv[])
{
  int i;
  int noport;
  int sockserv;

  int sockf;
  struct sockaddr_in servers;
  struct sockaddr srcip;
  struct hostent *host;
  char buffer[TBUFFER + 1];
  char *remotename;
  char remotehost[MAXHOSTNAMELEN + 1];

  FILE *pidfile;


  printf("tcp forwarder\n\n");

  if (argc != 5) {
    fprintf(stderr, "%s <listening IP> <listening port> <dest IP> <dest port>\n", argv[0]);
    return(1);
  }
  noport = atoi(argv[2]);

  inet_aton(argv[3], &ip_dest);
  port_dest = (u_short)atoi(argv[4]);

  if ((sockserv = createservsock(argv[1], noport)) == -1) {
    perror("createservsock");
    return(1);
  }

#ifdef DEBUG
  printf("server socket created on port %d\n", noport);
#endif

  snprintf(buffer, TBUFFER, "%s.log", argv[0]);
  if ((logfile = fopen(buffer, "a")) == NULL)
  {
    perror("can't create log file");
    return(1);
  }
#ifndef SILENT
  fprintf(logfile, "forwarder launched and listening on port %d\n", noport);
  fflush(logfile);
#endif

  snprintf(buffer, TBUFFER, "%s.pid", argv[0]);
  if ((pidfile = fopen(buffer, "w")) == NULL)
  {
    fprintf(logfile, "can't create pid file: %s\n", buffer, strerror(errno));
    return(1);
  }

  if (daemon(NULL, 0) == -1)
  {
    perror(argv[0]);
    return(1);
  }

  fprintf(pidfile, "%d\n", getpid());
  fclose(pidfile);

  i = 0;

  for (;;) {
    struct sockaddr_in client;
    int clisock;
    int clientsize = sizeof(client);
    int found;


    clisock = accept(sockserv, (struct sockaddr *)&client, &clientsize);
    if (clisock == -1) {
      fprintf(logfile, "accept: %s\n", strerror(errno));
      fflush(logfile);
      continue;
    }
#ifndef SILENT
    fprintf(logfile, "connect from %s\n", inet_ntoa(client.sin_addr));
    fflush(logfile);
#endif

    if (fork() == 0) {
      /*
       * attach child to init
       */
      if (fork())
	exit(0);

      if ((sockf = socket(AF_INET, SOCK_STREAM, 0)) == -1)
      {
	fprintf(logfile, "can't create socket: %s\n", strerror(errno));
	fflush(logfile);
      }

      memset(&servers, 0, sizeof(servers));
      servers.sin_family = AF_INET;
      servers.sin_port = htons(port_dest);
      memcpy(&servers.sin_addr, &ip_dest, sizeof(struct in_addr));

      if (connect(sockf, (struct sockaddr *)&servers,
		  sizeof(struct sockaddr_in)) == -1)
      {
	fprintf(logfile, "can't connect to server at %s (%s)\n",
		inet_ntoa(ip_dest), strerror(errno));
	fflush(logfile);
      }

      redirect_traffic(clisock, sockf);

      exit(0);
    }
    else {
      close(clisock);
      wait(NULL);
    }
  }

  return(0);
}
