/*

sshsocksfilter.c

  Author: Sami J. Lehtinen <sjl@ssh.com>

  Created: Wed Dec 12 16:12:28 2001.

  Copyright (C) 2001-2002 SSH Communications Security Corp, Helsinki, Finland
  All rights reserved.

*/

#include "ssh2includes.h"
#include "sshsocksfilter.h"
#include "sshfsm.h"
#include "sshchtcpfwd.h"
#include "sshtimeouts.h"
#include "sshsocks.h"

#define SSH_DEBUG_MODULE "SshSocksFilter"

/* Per-connection state of one SOCKS filter.  Allocated (zeroed) by
   ssh_socks_filter_create() and released in the finish_fsm state. */
typedef struct SshSocksFilterRec
{
  /* The state machine and its three threads.  `main' runs the SOCKS
     exchange, `writer' drains `out_buf' to the stream and `reader'
     appends stream input to `in_buf'.  The reader/writer pointers are
     cleared when those threads finish on their own. */
  SshFSM fsm;
  SshFSMThread main;
  SshFSMThread writer;
  SshFSMThread reader;

  /* Client stream; NULL once destroyed or handed to the channel code. */
  SshStream stream;
  SshCommon common;

  /* Unparsed client input, and output queued for the client. */
  SshBuffer in_buf;
  SshBuffer out_buf;

  /* Signalled by the reader after appending a byte to in_buf. */
  SshFSMCondition read_more;
  /* Signalled by the writer after writing part of out_buf. */
  SshFSMCondition written_some;
  /* Signalled by main when the current message is incomplete. */
  SshFSMCondition in_buf_shrunk;
  /* Signalled by main after queueing data in out_buf. */
  SshFSMCondition out_buf_grown;

  /* Set once the request has been parsed; the reader then stops
     reading from the stream. */
  Boolean flushing_data;

  /* Handle returned to our caller (cleared on abort) and a lower-level
     handle that socks_filter_abort aborts if set. */
  SshOperationHandle upper_handle;
  SshOperationHandle handle;

  /* Reply code used by send_error, and the negotiated SOCKS version. */
  unsigned int error;
  unsigned int socks_version;

  /* Forward record that receives the requested destination; NULL
     after it has been passed on in connect_to_dst. */
  SshTcpForwardActive fwd;

  /* Called from finish_fsm unless the filter was aborted. */
  SshSocksFilterCompletionCB completion;
  void *completion_context;
} SshSocksFilterStruct, *SshSocksFilter;

/* Forward declarations. */

/* Abort callback attached to the handle returned to the caller. */
static void socks_filter_abort(void *context);

/* Stream notification callback; wakes the reader/writer threads. */
static void iocb(SshStreamNotification notification, void *context);

/* States of the main (protocol) thread, in the order they normally run. */
SSH_FSM_STEP(recv_methods);
SSH_FSM_STEP(send_methods);
SSH_FSM_STEP(recv_request);
SSH_FSM_STEP(send_reply);
SSH_FSM_STEP(connect_to_dst);
SSH_FSM_STEP(send_error);
SSH_FSM_STEP(finish_fsm);

/* Single states of the writer and reader threads. */
SSH_FSM_STEP(write_out);
SSH_FSM_STEP(read_in);

/*** Public functions **************************************************/
/* Start a SOCKS server filter on `stream'.  The filter runs the SOCKS
   (version 4 or 5) handshake with the client and, for a CONNECT
   request, stores the destination in `active_forward' and passes the
   stream on to ssh_channel_dtcp_open_to_remote().  The returned handle
   aborts the filter; `completion' is called when the filter finishes,
   unless it was aborted. */
SshOperationHandle
ssh_socks_filter_create(SshStream stream, SshCommon common,
                        SshTcpForwardActive active_forward,
                        SshSocksFilterCompletionCB completion,
                        void *completion_context)
{
  SshSocksFilter filter;
  SshOperationHandle op;

  filter = ssh_xcalloc(1, sizeof(*filter));
  filter->fsm = ssh_fsm_create(filter);

  /* Plain configuration copied from the caller. */
  filter->stream = stream;
  filter->common = common;
  filter->fwd = active_forward;
  filter->completion = completion;
  filter->completion_context = completion_context;

  /* I/O buffers shared by the main, reader and writer threads. */
  filter->in_buf = ssh_xbuffer_allocate();
  filter->out_buf = ssh_xbuffer_allocate();

  /* Conditions the threads use to hand work to each other. */
  filter->read_more = ssh_fsm_condition_create(filter->fsm);
  SSH_ASSERT(filter->read_more != NULL);
  filter->in_buf_shrunk = ssh_fsm_condition_create(filter->fsm);
  SSH_ASSERT(filter->in_buf_shrunk != NULL);
  filter->written_some = ssh_fsm_condition_create(filter->fsm);
  SSH_ASSERT(filter->written_some != NULL);
  filter->out_buf_grown = ssh_fsm_condition_create(filter->fsm);
  SSH_ASSERT(filter->out_buf_grown != NULL);

  /* Protocol thread first, then the I/O pump.  filterstream cannot be
     used (its callbacks cannot be set), the FSM streamstub cannot be
     told to release the stream, and packetstream expects a binary
     packet protocol - hence the hand-written reader/writer pair. */
  filter->main = ssh_fsm_thread_create(filter->fsm, recv_methods,
                                       NULL_FNPTR, NULL_FNPTR, NULL);
  filter->writer = ssh_fsm_thread_create(filter->fsm, write_out,
                                         NULL_FNPTR, NULL_FNPTR, NULL);
  filter->reader = ssh_fsm_thread_create(filter->fsm, read_in,
                                         NULL_FNPTR, NULL_FNPTR, NULL);
  SSH_ASSERT(filter->writer != NULL);
  SSH_ASSERT(filter->reader != NULL);

  ssh_stream_set_callback(filter->stream, iocb, filter);

  /* Register the forward in the common object's active channel list. */
  ssh_common_add_active_forward(filter->fwd);

  op = ssh_operation_register(socks_filter_abort, filter);
  SSH_ASSERT(op != NULL);
  filter->upper_handle = op;

  return op;
}

/*** Private functions *************************************************/

static void error_timeout_cb(void *ctx)
{
  SshSocksFilter filter = (SshSocksFilter)ctx;

  if (filter->main)
    ssh_fsm_continue(filter->main);
}

/* Abort callback for the handle returned by ssh_socks_filter_create().
   Detaches the filter from its caller and from the stream, then drives
   the main thread into finish_fsm. */
static void socks_filter_abort(void *context)
{
  SshSocksFilter filter = (SshSocksFilter) context;

  SSH_DEBUG(2, ("Aborting SocksFilter..."));

  /* Cancel any lower-level operation still outstanding. */
  if (filter->handle != NULL)
    {
      ssh_operation_abort(filter->handle);
      filter->handle = NULL;
    }

  /* No further stream notifications to iocb. */
  if (filter->stream != NULL)
    ssh_stream_set_callback(filter->stream, NULL_FNPTR, NULL);

  /* Clear these so finish_fsm neither calls the caller back nor
     unregisters the handle it is aborting. */
  filter->completion = NULL_FNPTR;
  filter->upper_handle = NULL;

  /* A pending send_error timeout must not fire after teardown. */
  ssh_cancel_timeouts(error_timeout_cb, filter);

  ssh_fsm_set_next(filter->main, finish_fsm);
  ssh_fsm_continue(filter->main);
}

/* Destroying the threads and the FSM. */

/* Final state of the main thread: tears down the whole filter.  Kills
   any reader/writer thread still alive, frees buffers and conditions,
   destroys the stream if we still own it, notifies the caller (unless
   aborted) and frees the filter itself. */
SSH_FSM_STEP(finish_fsm)
{
  SshSocksFilter filter = (SshSocksFilter)fsm_context;

  /* NOTE(review): the FSM is destroyed before the remaining threads are
     killed and before this step returns; this presumably relies on
     ssh_fsm_destroy() deferring the real teardown - confirm. */
  ssh_fsm_destroy(filter->fsm);
  /* A thread that finished on EOF has already cleared its pointer. */
  if (filter->reader)
    ssh_fsm_kill_thread(filter->reader);
  if (filter->writer)
    ssh_fsm_kill_thread(filter->writer);

  ssh_buffer_free(filter->in_buf);
  ssh_buffer_free(filter->out_buf);
  ssh_fsm_condition_destroy(filter->read_more);
  ssh_fsm_condition_destroy(filter->written_some);
  ssh_fsm_condition_destroy(filter->in_buf_shrunk);
  ssh_fsm_condition_destroy(filter->out_buf_grown);

  /* NULL if the stream was already destroyed on EOF or handed over to
     the channel code in connect_to_dst. */
  if (filter->stream)
    ssh_stream_destroy(filter->stream);
  
  /* Cleared by socks_filter_abort, so an aborted filter stays silent. */
  if (filter->completion != NULL_FNPTR)
    (*filter->completion)(filter->completion_context);
  
  /* Drop a still-pending send_error timeout. */
  ssh_cancel_timeouts(error_timeout_cb, filter);
  /* Cleared by socks_filter_abort; otherwise unregister the handle
     that ssh_socks_filter_create() returned. */
  if (filter->upper_handle)
    ssh_operation_unregister(filter->upper_handle);
  
  /* NOTE(review): filter->fwd is not freed here even if it was never
     passed on; presumably the common object's active-forward list owns
     it (see ssh_common_add_active_forward) - verify. */
  ssh_xfree(filter);
  SSH_TRACE(2, ("SocksFilter finished."));
  return SSH_FSM_FINISH;
}

/* First state: parse the client's greeting (method selection) from
   in_buf and remember the SOCKS version it uses. */
SSH_FSM_STEP(recv_methods)
{
  SshSocksFilter filter = (SshSocksFilter)fsm_context;
  SocksInfo info;
  SocksError status;

  status = ssh_socks_server_parse_methods(filter->in_buf, &info);

  if (status == SSH_SOCKS_SUCCESS)
    {
      /* The later request must use the same version. */
      filter->socks_version = info->socks_version_number;
      ssh_socks_free(&info);
      SSH_FSM_SET_NEXT(send_methods);
      return SSH_FSM_CONTINUE;
    }

  if (status == SSH_SOCKS_TRY_AGAIN)
    {
      /* Incomplete greeting: let the reader fetch another byte; this
         state is re-run when read_more is signalled. */
      SSH_FSM_CONDITION_SIGNAL(filter->in_buf_shrunk);
      SSH_FSM_CONDITION_WAIT(filter->read_more);
    }

  SSH_DEBUG(2, ("Got err %d from SshSocks.", status));
  SSH_FSM_SET_NEXT(finish_fsm);
  return SSH_FSM_CONTINUE;
}

SSH_FSM_STEP(send_methods)
{
  SshSocksFilter filter = (SshSocksFilter)fsm_context;
  struct SocksInfoRec socksinfo;
  SocksError err;
  
  memset(&socksinfo, 0, sizeof(socksinfo));
  socksinfo.socks_version_number = filter->socks_version;

  /* Spec says a compliant implementation MUST support at least GSSAPI
     and SHOULD support USERNAME/PASSWORD. Because we only intend to
     support connections from localhost, GSSAPI is aiming a bit high. */

  /* XXX implement username/password auth. */

  /* If no authentication is required, send "no authentication required"
     packet. */

  err = ssh_socks_server_generate_method(filter->out_buf, &socksinfo);

  if (err != SSH_SOCKS_SUCCESS)
    {
      SSH_DEBUG(2, ("Got err %d from SshSocks.", err));
      SSH_FSM_SET_NEXT(finish_fsm);
      return SSH_FSM_CONTINUE;
    }
  SSH_FSM_CONDITION_SIGNAL(filter->out_buf_grown);

  SSH_FSM_SET_NEXT(recv_request);
  return SSH_FSM_YIELD;
}

/* Parse the client's connection request from in_buf.  Only CONNECT is
   supported.  Any other command, or a version that differs from the one
   negotiated in recv_methods, is answered via send_error.  On success the
   requested destination is stored in filter->fwd and the success reply
   is sent next. */
SSH_FSM_STEP(recv_request)
{
  SshSocksFilter filter = (SshSocksFilter)fsm_context;
  SocksInfo socksinfo;
  SocksError err;

  err = ssh_socks_server_parse_open(filter->in_buf, &socksinfo);
  if (err == SSH_SOCKS_TRY_AGAIN)
    {
      /* Incomplete request; ask the reader for more input. */
      SSH_FSM_CONDITION_SIGNAL(filter->in_buf_shrunk);
      SSH_FSM_CONDITION_WAIT(filter->read_more);
    }
  else if (err != SSH_SOCKS_SUCCESS)
    {
      SSH_DEBUG(2, ("Got err %d from SshSocks.", err));
      SSH_FSM_SET_NEXT(finish_fsm);
      return SSH_FSM_CONTINUE;
    }

  if (filter->socks_version != socksinfo->socks_version_number)
    {
      SSH_DEBUG(2, ("Version changed in the middle of the transaction "
                    "(using ver %d, client gave %d).",
                    filter->socks_version, socksinfo->socks_version_number));
      /* send_error encodes the reply with filter->socks_version, so the
         failure code must come from that version's code space. */
      if (filter->socks_version == 5)
        filter->error = SSH_SOCKS5_REPLY_FAILURE;
      else
        filter->error = SSH_SOCKS4_REPLY_FAILED_REQUEST;

      ssh_socks_free(&socksinfo);
      SSH_FSM_SET_NEXT(send_error);
      return SSH_FSM_CONTINUE;
    }

  if (((filter->socks_version == 5) &&
       (socksinfo->command_code != SSH_SOCKS5_COMMAND_CODE_CONNECT)) ||
      ((filter->socks_version == 4) &&
       (socksinfo->command_code != SSH_SOCKS4_COMMAND_CODE_CONNECT)))
    {
      SSH_DEBUG(2, ("Invalid command %d.", socksinfo->command_code));
      /* send failure reply. */
      if (filter->socks_version == 5)
        filter->error = SSH_SOCKS5_REPLY_CMD_NOT_SUPPORTED;
      else
        filter->error = SSH_SOCKS4_REPLY_FAILED_REQUEST;

      ssh_socks_free(&socksinfo);
      SSH_FSM_SET_NEXT(send_error);
      return SSH_FSM_CONTINUE;
    }

  SSH_ASSERT(socksinfo->ip != NULL);
  SSH_ASSERT(socksinfo->port != NULL);

  /* Replace any destination previously stored in the forward record. */
  ssh_xfree(filter->fwd->connect_to_host);
  ssh_xfree(filter->fwd->connect_to_port);
  filter->fwd->connect_to_host = ssh_xstrdup(socksinfo->ip);
  filter->fwd->connect_to_port = ssh_xstrdup(socksinfo->port);
  ssh_socks_free(&socksinfo);

  /* The request is consumed; stop the reader from pulling further
     input, which belongs to the connection being forwarded. */
  filter->flushing_data = TRUE;
  SSH_FSM_SET_NEXT(send_reply);
  return SSH_FSM_YIELD;
}

SSH_FSM_STEP(send_reply)
{
  SshSocksFilter filter = (SshSocksFilter)fsm_context;
  SocksInfoStruct socksinfo;
  SocksError err;
  
  SSH_DEBUG(2, ("Sending SUCCESS reply."));

  memset(&socksinfo, 0, sizeof(socksinfo));
  /* At this point, in_buffer should be empty. */
  SSH_ASSERT(ssh_buffer_len(filter->in_buf) == 0);

  if (filter->socks_version == 5)
    socksinfo.command_code = SSH_SOCKS5_REPLY_SUCCESS;
  else
    socksinfo.command_code = SSH_SOCKS4_REPLY_GRANTED;
  socksinfo.socks_version_number = filter->socks_version;
  socksinfo.ip = filter->fwd->connect_from_host;
  socksinfo.port = filter->fwd->port;

  err = ssh_socks_server_generate_reply(filter->out_buf, &socksinfo);
  if (err != SSH_SOCKS_SUCCESS)
    {
      /* XXX  */
      SSH_NOTREACHED;
    }
  SSH_FSM_CONDITION_SIGNAL(filter->out_buf_grown);
  SSH_FSM_SET_NEXT(connect_to_dst);
  return SSH_FSM_YIELD;
}

SSH_FSM_STEP(connect_to_dst)
{
  SshSocksFilter filter = (SshSocksFilter)fsm_context;

  if (ssh_buffer_len(filter->out_buf) > 0)
    {
      SSH_DEBUG(4, ("At least part of reply packet still lingering in "
                    "buffers. Waiting more."));

      /* Because at least mozilla doesn't seem to read replies at all,*/
      /* Should not be necessary, but let's wake the writer, if it has
         for some gone waiting. */
      SSH_FSM_CONDITION_SIGNAL(filter->out_buf_grown);
      SSH_FSM_CONDITION_WAIT(filter->written_some);
    }

  SSH_DEBUG(0, ("Opening a connection to \"%s\", port %s.",
                filter->fwd->connect_to_host, filter->fwd->connect_to_port));

  /* This makes certain, that the iocb is not called again. */
  ssh_stream_set_callback(filter->stream, NULL_FNPTR, NULL);

  /* Send a request to open a channel and connect it to the given port. */
  ssh_channel_dtcp_open_to_remote(filter->fwd->common, filter->stream,
                                  filter->fwd->connect_to_host,
                                  filter->fwd->connect_to_port,
                                  filter->fwd->connect_from_host,
                                  filter->fwd->port);

  filter->fwd = NULL;
  /* The stream is now owned by SshConn. */
  filter->stream = NULL;

  SSH_FSM_SET_NEXT(finish_fsm);
  return SSH_FSM_YIELD;
}

SSH_FSM_STEP(send_error)
{
  SshSocksFilter filter = (SshSocksFilter)fsm_context;
  SocksInfoStruct socksinfo;
  SocksError err;

  memset(&socksinfo, 0, sizeof(socksinfo));

  socksinfo.socks_version_number = filter->socks_version;
  socksinfo.command_code = filter->error;
  socksinfo.ip = "0.0.0.0";
  socksinfo.port = "1";

  err = ssh_socks_server_generate_reply(filter->out_buf, &socksinfo);
  if (err != SSH_SOCKS_SUCCESS)
    {
      SSH_DEBUG(2, ("Failed to generate reply (err = %d).", err));
      SSH_FSM_SET_NEXT(finish_fsm);
      return SSH_FSM_CONTINUE;
    }
  SSH_FSM_CONDITION_SIGNAL(filter->out_buf_grown);

  /*
    From the spec:

      When a reply (REP value other than X'00') indicates a failure, the
      SOCKS server MUST terminate the TCP connection shortly after
      sending the reply.  This must be no more than 10 seconds after
      detecting the condition that caused a failure.
  */

  ssh_register_timeout(3L, 0L, error_timeout_cb, filter);

  SSH_FSM_SET_NEXT(finish_fsm);

  return SSH_FSM_SUSPENDED;
}

/* Reading and writing to the stream. **********************************/

/* Stream notification callback.  Input and output readiness resume the
   reader and writer threads.  On disconnect the filter stops listening
   and resumes the main thread directly into finish_fsm; without the
   resume, a main thread waiting on a condition would never tear the
   filter down.  The same set_next/continue pair is used by
   socks_filter_abort. */
static void iocb(SshStreamNotification notification, void *context)
{
  SshSocksFilter filter = (SshSocksFilter)context;

  switch (notification)
    {
    case SSH_STREAM_INPUT_AVAILABLE:
      ssh_fsm_continue(filter->reader);
      break;
    case SSH_STREAM_CAN_OUTPUT:
      ssh_fsm_continue(filter->writer);
      break;
    case SSH_STREAM_DISCONNECTED:
      /* We don't want any more of these. */
      ssh_stream_set_callback(filter->stream, NULL_FNPTR, NULL);
      ssh_fsm_set_next(filter->main, finish_fsm);
      ssh_fsm_continue(filter->main);
      break;
    default:
      /* Other notifications are not relevant to the filter. */
      break;
    }
}

SSH_FSM_STEP(write_out)
{
  SshSocksFilter filter = (SshSocksFilter)fsm_context;
  size_t ret_val;

  if (ssh_buffer_len(filter->out_buf) == 0)
    {
      SSH_DEBUG(4, ("Waiting for more data."));
      SSH_FSM_CONDITION_WAIT(filter->out_buf_grown);
    }

  SSH_DEBUG_HEXDUMP(0, ("Writing data:"),
                    ssh_buffer_ptr(filter->out_buf),
                    ssh_buffer_len(filter->out_buf));

  ret_val = ssh_stream_write(filter->stream, ssh_buffer_ptr(filter->out_buf),
                             ssh_buffer_len(filter->out_buf));
  if (ret_val < 0)
    {
      SSH_DEBUG(5, ("Write failed, waiting for I/O notification from "
                    "stream layer."));
      return SSH_FSM_SUSPENDED;
    }
  else if (ret_val == 0)
    {
      SSH_DEBUG(2, ("EOF received."));
      /* Wake main thread to stop and free things. */
      ssh_stream_destroy(filter->stream);
      filter->stream = NULL;
      ssh_fsm_set_next(filter->main, finish_fsm);
      filter->writer = NULL;
      return SSH_FSM_FINISH;
    }

  SSH_DEBUG(1, ("Consuming some data (data left in buffer = %ld).",
                ssh_buffer_len(filter->out_buf) - ret_val));
  ssh_buffer_consume(filter->out_buf, ret_val);

  SSH_FSM_CONDITION_SIGNAL(filter->written_some);
  SSH_FSM_CONDITION_WAIT(filter->out_buf_grown);
}

/* Reader thread: move client input into in_buf one byte at a time, so
   that nothing beyond the SOCKS messages is consumed.
   - After each byte it signals read_more and waits until main asks for
     more via in_buf_shrunk.
   - Once flushing_data is set it stops reading.
   - On EOF it destroys the stream, routes the main thread to finish_fsm
     and finishes. */
SSH_FSM_STEP(read_in)
{
  SshSocksFilter filter = (SshSocksFilter)fsm_context;
  unsigned char byte;
  int ret_val;

  if (filter->stream == NULL)
    {
      /* The stream was destroyed by the writer or handed over to the
         channel code; nothing is left to read from. */
      filter->reader = NULL;
      return SSH_FSM_FINISH;
    }

  /* After the request has been parsed, the remaining input belongs to
     the forwarded connection; park here without reading. */
  if (filter->flushing_data)
    SSH_FSM_CONDITION_WAIT(filter->in_buf_shrunk);

  ret_val = ssh_stream_read(filter->stream, &byte, 1);
  if (ret_val < 0)
    {
      SSH_DEBUG(2, ("Read failed, waiting for I/O notification from "
                    "stream layer."));
      return SSH_FSM_SUSPENDED;
    }
  else if (ret_val == 0)
    {
      SSH_DEBUG(2, ("EOF received."));
      ssh_stream_destroy(filter->stream);
      filter->stream = NULL;
      filter->reader = NULL;
      /* Wake main thread to stop and free things. */
      ssh_fsm_set_next(filter->main, finish_fsm);
      ssh_fsm_continue(filter->main);
      return SSH_FSM_FINISH;
    }

  ssh_xbuffer_append(filter->in_buf, &byte, 1);

  /* Dump the byte just received, not the head of in_buf. */
  SSH_DEBUG_HEXDUMP(0, ("Received data:"), &byte, 1);

  SSH_FSM_CONDITION_SIGNAL(filter->read_more);
  SSH_FSM_CONDITION_WAIT(filter->in_buf_shrunk);
}
