/*

  pkcs1.c

  Author(s):

    Jukka Aittokallio <jai@ssh.com>
    Mika Kojo <mkojo@ssh.fi>

  Copyright (c) 2001 SSH Communications Security, Finland
  All rights reserved.

*/
#include "sshincludes.h"
#include "sshasn1.h"
#include "sshcrypt.h"
#include "pkcs1.h"

#define SSH_DEBUG_MODULE "Pkcs1"

/* Encode `oid` and `data` as the ASN.1 structure

     SEQUENCE {
       SEQUENCE { OBJECT IDENTIFIER oid, NULL },
       OCTET STRING data
     }

   (the layout of a PKCS#1 DigestInfo) and return its encoding.

   oid          - object identifier, in whatever form ssh_asn1_create_node
                  expects for "object-identifier" (presumably a dotted
                  string -- NOTE(review): confirm).
   data         - contents of the OCTET STRING.
   data_len     - length of `data` in bytes.
   ber_ret      - receives the encoded bytes via ssh_asn1_node_get_data.
                  The context is freed afterwards, so this is presumably a
                  caller-owned copy -- NOTE(review): confirm who frees it.
   ber_len_ret  - receives the length of *ber_ret.

   Returns TRUE on success, FALSE if the ASN.1 context cannot be created
   or node creation/encoding fails.  The ASN.1 context is released on
   every path. */
Boolean
ssh_pkcs1_wrap(const unsigned char *oid,
               const unsigned char *data,
               size_t data_len,
               unsigned char **ber_ret,
               size_t *ber_len_ret)
{
  SshAsn1Context asn1_context;
  SshAsn1Status  status;
  SshAsn1Node node;

  if ((asn1_context = ssh_asn1_init()) == NULL)
    {
      SSH_DEBUG(SSH_D_FAIL, ("Couldn't init ASN.1 context."));
      return FALSE;
    }

  status = ssh_asn1_create_node(asn1_context, &node,
                                "(sequence ()"
                                "  (sequence ()"
                                "    (object-identifier ())"
                                "    (null ()))"
                                "  (octet-string ()))",
                                oid,
                                data,
                                data_len);

  if (status != SSH_ASN1_STATUS_OK)
    {
      SSH_DEBUG(SSH_D_FAIL, ("ASN.1 create node failed."));
      /* Release the context on failure too; it was leaked here before. */
      ssh_asn1_free(asn1_context);
      return FALSE;
    }

  status = ssh_asn1_encode_node(asn1_context, node);

  if (status != SSH_ASN1_STATUS_OK)
    {
      SSH_DEBUG(SSH_D_FAIL, ("ASN.1 encode failed."));
      ssh_asn1_free(asn1_context);
      return FALSE;
    }

  /* Fetch the encoding before the context (which owns `node`) is freed. */
  ssh_asn1_node_get_data(node, ber_ret, ber_len_ret);
  ssh_asn1_free(asn1_context);
  return TRUE;
}

/* Decode a BER blob of the form

     SEQUENCE {
       SEQUENCE { OBJECT IDENTIFIER, NULL },
       OCTET STRING
     }

   and hand back its OID and OCTET STRING contents.

   ber, ber_len  - the encoded input.
   oid_ret       - receives the object identifier as produced by
                   ssh_asn1_read_node.
   data_ret      - receives the OCTET STRING contents.
   data_len_ret  - receives the length of *data_ret.

   The values are filled in by ssh_asn1_read_node before the ASN.1
   context is freed, so they presumably outlive it and belong to the
   caller -- NOTE(review): confirm the ownership rules.

   Returns TRUE on success; FALSE if the context cannot be created, the
   blob does not decode, or it does not match the expected structure. */
Boolean
ssh_pkcs1_unwrap(const unsigned char *ber,
                 size_t ber_len,
                 unsigned char **oid_ret,
                 unsigned char **data_ret,
                 size_t *data_len_ret)
{
  SshAsn1Context ctx;
  SshAsn1Node root;
  Boolean result = FALSE;

  ctx = ssh_asn1_init();
  if (ctx == NULL)
    {
      SSH_DEBUG(SSH_D_FAIL, ("Couldn't init ASN.1 context."));
      return FALSE;
    }

  if (ssh_asn1_decode_node(ctx, ber, ber_len, &root)
      != SSH_ASN1_STATUS_OK)
    {
      SSH_DEBUG(SSH_D_FAIL, ("Couldn't ASN.1 decode the PKCS#1 blob.."));
    }
  else if (ssh_asn1_read_node(ctx, root,
                              "(sequence ()"
                              "  (sequence ()"
                              "    (object-identifier ())"
                              "    (null ()))"
                              "  (octet-string ()))",
                              oid_ret,
                              data_ret, data_len_ret)
           != SSH_ASN1_STATUS_OK)
    {
      SSH_DEBUG(SSH_D_FAIL, ("Invalid PKCS#1 BER structure."));
    }
  else
    {
      result = TRUE;
    }

  /* Single release point for the context, whatever the outcome. */
  ssh_asn1_free(ctx);
  return result;
}

/* Format `input` as a PKCS#1 v1.5 block of `len` bytes and store the
   result, as an integer, in `output`.

   The block is laid out below as the `len`-byte big-endian
   representation of `output`. The leading 00 is implicit: it is simply
   the top byte of the number.

     block type 0:  00 00 <input, zero-filled on the left>
     block type 1:  00 01 FF .. FF 00 <input>     (used with signatures)
     block type 2:  00 02 RR .. RR 00 <input>     (used with encryption;
                                                   RR are non-zero random
                                                   bytes)

   output      - receives the padded block.
   input       - the data to pad. It is assumed to fit in `input_len`
                 bytes; this is not checked here -- NOTE(review): confirm
                 at callers.
   input_len   - length of the data in bytes.
   tag_number  - block type: 0, 1 or 2.
   len         - length of the whole block in bytes.

   Returns TRUE on success. Returns FALSE if the block type is unknown,
   or if `len` cannot hold the header, the separator (types 1 and 2) and
   `input_len` data bytes. */
Boolean
ssh_pkcs1_pad(SshMPInt output,
              SshMPInt input,
              unsigned int input_len,
              unsigned int tag_number,
              unsigned int len)
{
  unsigned int i;
  unsigned int overhead, pad_len;

  /* Block type 0 carries only the two header bytes.  Types 1 and 2 also
     need the 0x00 separator between the padding string and the data.
     The old check (len < input_len + 1) let through blocks that came out
     longer than `len`.  The comparison is written as a subtraction so
     that it cannot wrap around. */
  overhead = (tag_number == 0x0) ? 2 : 3;
  if (len < overhead || input_len > len - overhead)
    {
      SSH_DEBUG(SSH_D_FAIL, ("input len is too long: input_len = %u, len = %u",
                             input_len, len));
      return FALSE;
    }

  /* Number of padding bytes (only meaningful for block types 1 and 2). */
  pad_len = len - input_len - overhead;

  /* Set tag number. */
  ssh_mp_set_ui(output, tag_number);

  /* Check the block type. */
  switch (tag_number)
    {
      /* Block type 0: shifting zero leaves zero, so output == input. */
    case 0x0:
      ssh_mp_mul_2exp(output, output, 8 * (len - 2));
      ssh_mp_add(output, output, input);
      break;
      /* Block type 1 (used with signatures). */
    case 0x1:
      for (i = 0; i < pad_len; i++)
        {
          ssh_mp_mul_2exp(output, output, 8);
          ssh_mp_add_ui(output, output, 0xff);
        }
      /* Shift in the 0x00 separator plus room for the data. */
      ssh_mp_mul_2exp(output, output, 8 * (input_len + 1));
      ssh_mp_add(output, output, input);
      break;
      /* Block type 2 (used with encryption). */
    case 0x2:
      for (i = 0; i < pad_len; i++)
        {
          unsigned int byte;

          /* Padding bytes must be non-zero, or the receiver would find
             the separator too early. */
          do
            byte = ssh_random_get_byte();
          while (byte == 0);
          ssh_mp_mul_2exp(output, output, 8);
          ssh_mp_add_ui(output, output, byte);
        }
      ssh_mp_mul_2exp(output, output, 8 * (input_len + 1));
      ssh_mp_add(output, output, input);
      break;
    default:
      SSH_DEBUG(SSH_D_FAIL, ("block type unknown %u.", tag_number));
      return FALSE;
    }

  return TRUE;
}


/* Strip PKCS#1 v1.5 block formatting from `input` and copy out the data.

   Expected layout of the linearized block:
     type 0:  00 00 <data>
     type 1:  00 01 FF .. FF 00 <data>
     type 2:  00 02 <non-zero bytes> 00 <data>

   tag_number            - expected block type: 0, 1 or 2.
   input                 - the padded block as an integer.
   output_buffer         - receives the recovered data.  It must hold at
                           least output_buffer_length - 2 bytes.
   output_buffer_length  - length of the padded block in bytes.  The
                           integer is linearized (ssh_mp_to_buf) into
                           exactly this many bytes.
   return_length         - on success, set to the number of bytes written
                           to output_buffer.

   Returns TRUE on success.  Returns FALSE, leaving *return_length
   untouched, in any of these cases:
     - the buffer is shorter than the header;
     - allocation or linearization fails;
     - the header is not 00 <tag_number>;
     - the 0x00 separator is missing (types 1 and 2);
     - type 1 padding contains a byte other than 0xff;
     - the block type is unknown. */
Boolean
ssh_pkcs1_unpad(unsigned int tag_number,
                SshMPInt input,
                unsigned char *output_buffer,
                size_t output_buffer_length,
                size_t *return_length)
{
  size_t i;
  unsigned char *input_buffer;
  Boolean ok = FALSE;

  /* The header alone is two bytes.  Anything shorter cannot be valid and
     would make the header check below read out of bounds. */
  if (output_buffer_length < 2)
    return FALSE;

  /* Temporary buffer holding the linearized block. */
  if ((input_buffer = ssh_malloc(output_buffer_length)) == NULL)
    return FALSE;

  if (ssh_mp_to_buf(input_buffer, output_buffer_length, input) == 0)
    goto out;

  /* Check for a valid block header: 00 <tag_number>. */
  if (input_buffer[0] != 0 || input_buffer[1] != tag_number)
    goto out;

  switch (tag_number)
    {
      /* Block type 0 has no explicit separator; the data is zero-filled
         on the left.  Everything after the header is returned as is, and
         any further handling is left to the caller. */
    case 0x0:
      memcpy(output_buffer, input_buffer + 2, output_buffer_length - 2);
      *return_length = output_buffer_length - 2;
      ok = TRUE;
      break;

      /* Block types 1 (signatures, 0xff padding) and 2 (encryption,
         non-zero random padding).  The padding string runs up to the
         first 0x00 byte, which separates it from the data. */
    case 0x1:
    case 0x2:
      for (i = 2; i < output_buffer_length; i++)
        {
          if (input_buffer[i] == 0x00)
            break;
          /* Type 1 padding must consist solely of 0xff bytes. */
          if (tag_number == 0x1 && input_buffer[i] != 0xff)
            break;
        }

      /* Without a separator, i == output_buffer_length and the length
         below (output_buffer_length - i - 1) would wrap around, making
         memcpy run past both buffers.  Reject such blocks, and type 1
         blocks whose padding ends in something other than 0x00. */
      if (i >= output_buffer_length || input_buffer[i] != 0x00)
        {
          SSH_DEBUG(SSH_D_NETGARB, ("padding separator not found."));
          goto out;
        }

      /* NOTE(review): PKCS#1 v1.5 also requires at least 8 padding bytes
         (i >= 10).  That is not enforced here. */
      memcpy(output_buffer, input_buffer + i + 1,
             output_buffer_length - i - 1);
      *return_length = output_buffer_length - i - 1;
      ok = TRUE;
      break;

    default:
      SSH_DEBUG(SSH_D_NETGARB, ("block type unknown %u.", tag_number));
      break;
    }

 out:
  /* Wipe the (possibly decrypted) block before releasing it; best
     effort. */
  memset(input_buffer, 0, output_buffer_length);
  ssh_free(input_buffer);
  return ok;
}
