/**
******************************************************************************
* File Name          : vega_esecure.c
* Description        : 
*
*
******************************************************************************
*/

/* Includes ------------------------------------------------------------------*/
#include <string.h>
#include <stdlib.h>
#include "serv_protocols_lib/serv_protocol_handler.h"

#if defined(VEGA_ESECURE_PRESENT)
/* C code---------------------------------------------------------------------*/

// Debug-output routing for this file.
// When VEGA_SUCURE_DUEBUG_LOG_OUT is defined (the macro is not defined in this
// file; its spelling is kept as-is because it is referenced elsewhere), debug
// text goes to LOG / BIN_BUFF_LOG and a build-time #warning is emitted.
// Otherwise it goes to __PRINTF / __BIN_BUFF_PRINTF.
#if defined(VEGA_SUCURE_DUEBUG_LOG_OUT)
#warning VEGA_SUCURE_DUEBUG_LOG_OUT
#define __DEBUG_PRINTF          LOG
#define __DEBUG_BIN_BUFF_PRINTF BIN_BUFF_LOG
#else //!defined(VEGA_SUCURE_DUEBUG_LOG_OUT)
#define __DEBUG_PRINTF          __PRINTF
#define __DEBUG_BIN_BUFF_PRINTF __BIN_BUFF_PRINTF
#endif //defined(VEGA_SUCURE_DUEBUG_LOG_OUT)

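/*
Handshake (reconstructed from the code below; the original text was mis-encoded):
sides "A" and "B" exchange their current unixtime, 4 bytes each.
The peer's unixtime arrives first and must be within MAX_TIME_DIFF_S of the local
clock. The local unixtime is then sent back. Each side's unixtime seeds the nonce
of its own direction.
*/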

// Logical layout of one encrypted packet, in the order vega_esecure_tx_prepare()
// writes it: the length header and the message are encrypted, then the
// 16-byte authentication tag follows.
// NOTE(review): this type is not referenced in the visible code of this file.
// Because `msg` is a pointer, the struct cannot be overlaid onto received
// bytes; it only documents the format. `__packed` is a compiler extension
// (presumably IAR/ARMCC) - confirm the toolchain.
typedef __packed struct
{
  //->encrypted->
  uint16_t msg_len; // length of msg in bytes
  uint8_t* msg;//msg_len bytes
  //<-encrypted<-
  uint8_t sign[16]; // authentication tag over the encrypted part
}encrypted_packet_t;

/*
 * Log prefix for a connection context: ids 0..5 map to
 * "esecure_ctx1:".."esecure_ctx6:", and any other id maps to "esecure_ctxX:".
 */
static const char* get_pheader(const uint8_t conn_id)
{
  static const char* const prefixes[] =
  {
    "esecure_ctx1:",
    "esecure_ctx2:",
    "esecure_ctx3:",
    "esecure_ctx4:",
    "esecure_ctx5:",
    "esecure_ctx6:",
  };
  
  if(conn_id < sizeof(prefixes)/sizeof(prefixes[0]))
  {
    return prefixes[conn_id];
  }
  
  return "esecure_ctxX:";
}

/*
 * Puts a Vega esecure context back into the handshake state (step 0).
 * Contexts without an esecure block, or of another protocol type, are left untouched.
 */
void vega_esecure_init(conn_ctx_t* ctx)
{
  const int is_vega_secure = (ctx->esecure!=NULL) &&
                             (ctx->protocol_type==VEGA_PROTOCOL_TYPE);
  
  if(is_vega_secure)
  {
    ctx->esecure->step=0;
  }
}

/*
 * Constant-time tag comparison.
 * Returns 1 when the first len bytes of a and b are equal, 0 otherwise.
 * Every byte is always visited, so the running time does not reveal where the
 * first mismatch is. memcmp may return early, which lets a network peer probe
 * the expected tag byte by byte.
 */
static uint8_t vega_esecure_tag_equal(const uint8_t* a, const uint8_t* b, size_t len)
{
  uint8_t diff=0;
  
  for(size_t i=0; i<len; i++)
  {
    diff|=(uint8_t)(a[i]^b[i]);
  }
  
  return (uint8_t)(diff==0);
}

//#pragma optimize=none
/*
 * Receive-side processing of the Vega esecure transport.
 *
 * Consumes up to parse_len bytes from the start of ctx->rx_buff. The state
 * machine is driven by ctx->esecure->step:
 *   step 0 - handshake. Read the peer's 4-byte unixtime and check it against
 *            the local RTC (+-MAX_TIME_DIFF_S). Initialise the rx/tx
 *            ChaCha20-Poly1305 contexts and nonces, reply with the local
 *            4-byte unixtime, then move to step 1.
 *   step 1 - decrypt the packet length header, then move to step 2.
 *   step 2 - once the body and tag are fully buffered, decrypt the body,
 *            verify the tag, pass the plaintext to VegaRxProcessing(), then
 *            return to step 1.
 *
 * Returns 0 if ctx is not a Vega esecure context or parse_len is 0.
 * Otherwise returns the number of bytes consumed from ctx->rx_buff, or an
 * error: a PROTOCOL_CTX_* code or a negative VegaRxProcessing() result.
 * PROTOCOL_CTX_* codes are presumably negative - confirm.
 */
int16_t vega_esecure_rx(conn_ctx_t* ctx, uint16_t parse_len)
{
  // Maximum tolerated difference between the peer's clock and the local clock
  static const uint8_t MAX_TIME_DIFF_S=1*60;
  
  if(ctx->esecure==NULL ||
     ctx->protocol_type!=VEGA_PROTOCOL_TYPE) return 0;
  
  // Steps above 2 are not part of the state machine
  if(ctx->esecure->step>2) return PROTOCOL_CTX_PARSE_DATA_FMT_ERR;
  
  if(!parse_len) return 0;
  
  uint16_t processed_len=0;
  
  // Step 0: handshake - the peer's 4-byte unixtime arrives first
  if(ctx->esecure->step==0 && parse_len>=sizeof(uint32_t))
  {
    const uint32_t curr_utime=ctx->common_const->get_unix_time();
    
    // 1582019105 is 2020-02-18; an earlier local time means the RTC was never set
    if(curr_utime<1582019105UL)
    {
      if(VEGA_SECURE_DEBUG_LEVEL>=LDEBUG_L)
      {
        LOG("%s rtc are not synchronized\n", get_pheader(ctx->conn_ctx_id));
      }
      
      return PROTOCOL_CTX_PARSE_DATA_FMT_ERR;
    }
    
    // An all-zero 32-byte key means no session key has been provisioned
    if(memcmp(ctx->esecure->chachapoly_key, "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", 32) == 0)
    {
      if(VEGA_SECURE_DEBUG_LEVEL>=LDEBUG_L)
      {
        LOG("%s session key is null\n", get_pheader(ctx->conn_ctx_id));
      }
      
      return PROTOCOL_CTX_CRYPT_ERR;
    }
    
    // Copy the timestamp out instead of dereferencing a casted uint32_t pointer.
    // rx_buff is used as a byte buffer (assumed uint8_t and not necessarily
    // 4-byte aligned), so a uint32_t* read would be a misaligned /
    // strict-aliasing violation. Byte order is native, as before.
    uint32_t rx_utime;
    memcpy(&rx_utime, ctx->rx_buff, sizeof(rx_utime));
    
    if(VEGA_SECURE_DEBUG_LEVEL>=LDEBUG_L)
    {
      __DEBUG_PRINTF("%s system time: %u, recv time: %u\n", get_pheader(ctx->conn_ctx_id), curr_utime, rx_utime);
    }
    
    if((curr_utime>=rx_utime && (curr_utime-rx_utime)>MAX_TIME_DIFF_S) ||
       (curr_utime<rx_utime && (rx_utime-curr_utime)>MAX_TIME_DIFF_S))
    {
      if(VEGA_SECURE_DEBUG_LEVEL>=LDEBUG_L)
      {
        // Print the signed clock skew; a raw uint32_t does not match %li
        LOG("%s out of time delta (%li)\n", get_pheader(ctx->conn_ctx_id), (long)(int32_t)(curr_utime-rx_utime));
      }
      
      // Reply with the raw, unencrypted tx nonce: local unixtime, direction 1,
      // counter 0. Presumably this lets the peer see the local time - confirm.
      memset(&ctx->esecure->tx.nonce, 0, sizeof(ctx->esecure->tx.nonce));
      ctx->esecure->tx.nonce.utime=curr_utime;
      ctx->esecure->tx.nonce.direction=1;
      ctx->esecure->tx.nonce.counter=0;
      
      if(0>ctx->common_const->socket_write(ctx->conn_ctx_id, (uint8_t*)&ctx->esecure->tx.nonce, sizeof(ctx->esecure->tx.nonce))) return PROTOCOL_CTX_CONN_ERR;
      
      //todo: (original note lost) blocks the calling task for 1000 ticks before reporting the error
      vTaskDelay(1000);
      
      return PROTOCOL_CTX_PARSE_DATA_FMT_ERR;
    }
    
    // rx direction: nonce seeded with the peer's unixtime, direction 0
    memset(&ctx->esecure->rx.nonce, 0, sizeof(ctx->esecure->rx.nonce));
    ctx->esecure->rx.nonce.utime=rx_utime;
    ctx->esecure->rx.nonce.direction=0;
    ctx->esecure->rx.nonce.counter=0;
    
    mbedtls_chachapoly_free(&ctx->esecure->rx.chachapoly);
    mbedtls_chachapoly_init(&ctx->esecure->rx.chachapoly);
    if(mbedtls_chachapoly_setkey(&ctx->esecure->rx.chachapoly, ctx->esecure->chachapoly_key) != 0) return PROTOCOL_CTX_CRYPT_ERR;
    
    // tx direction: nonce seeded with the local unixtime, direction 1
    memset(&ctx->esecure->tx.nonce, 0, sizeof(ctx->esecure->tx.nonce));
    ctx->esecure->tx.nonce.utime=curr_utime;
    ctx->esecure->tx.nonce.direction=1;
    ctx->esecure->tx.nonce.counter=0;
    
    mbedtls_chachapoly_free(&ctx->esecure->tx.chachapoly);
    mbedtls_chachapoly_init(&ctx->esecure->tx.chachapoly);
    if(mbedtls_chachapoly_setkey(&ctx->esecure->tx.chachapoly, ctx->esecure->chachapoly_key) != 0) return PROTOCOL_CTX_CRYPT_ERR;
    
    // Send the local unixtime back to the peer
    memcpy(&ctx->esecure->tx.buff[0], &curr_utime, sizeof(curr_utime));
    
    if(0>ctx->common_const->socket_write(ctx->conn_ctx_id, &ctx->esecure->tx.buff[0], sizeof(curr_utime))) return PROTOCOL_CTX_CONN_ERR;
    
    if(VEGA_SECURE_DEBUG_LEVEL>=LDEBUG_L)
    {
      __DEBUG_PRINTF("%s sent utime...\n", get_pheader(ctx->conn_ctx_id));
    }
    
    ctx->esecure->step=1;
    
    if(VEGA_SECURE_DEBUG_LEVEL>=LDEBUG_L)
    {
      __DEBUG_PRINTF("%s enter secure exchange...\n", get_pheader(ctx->conn_ctx_id));
    }
    
    processed_len+=sizeof(uint32_t);
    parse_len-=sizeof(uint32_t);
  }
  
  for(;;)
  {
    // Step 1: start a new packet and decrypt its length header
    if(ctx->esecure->step==1 && parse_len>=sizeof(esecure_packet_len_t))
    {
      if(mbedtls_chachapoly_starts(&ctx->esecure->rx.chachapoly, (const uint8_t*)&ctx->esecure->rx.nonce, MBEDTLS_CHACHAPOLY_DECRYPT) != 0) return PROTOCOL_CTX_CRYPT_ERR;
      
      if(mbedtls_chachapoly_update(&ctx->esecure->rx.chachapoly, sizeof(ctx->esecure->rx.msg_len), &ctx->rx_buff[processed_len], (uint8_t*)&ctx->esecure->rx.msg_len) != 0) return PROTOCOL_CTX_CRYPT_ERR;
      
      // The length is not authenticated yet; bound it before using it
      if(ctx->esecure->rx.msg_len>sizeof(ctx->esecure->rx.buff) ||
         ctx->esecure->rx.msg_len<1) return PROTOCOL_CTX_MEM_ERR;
      
      if(VEGA_SECURE_DEBUG_LEVEL>=LDEBUG_L)
      {
        __PRINTF("%s got crypted len, we waiting %u bytes...\n", get_pheader(ctx->conn_ctx_id), ctx->esecure->rx.msg_len);
      }
      
      ctx->esecure->step=2;
      
      processed_len+=sizeof(esecure_packet_len_t);
      parse_len-=sizeof(esecure_packet_len_t);
    }
    
    
    int16_t vega_rx_processed_len=0;
    
    // Step 2: wait until the whole body and tag are buffered, then decrypt and authenticate
    if(ctx->esecure->step==2 && parse_len>=ctx->esecure->rx.msg_len+sizeof(esecure_mac_t))
    {
      if(mbedtls_chachapoly_update(&ctx->esecure->rx.chachapoly, ctx->esecure->rx.msg_len, &ctx->rx_buff[processed_len], ctx->esecure->rx.buff) != 0) return PROTOCOL_CTX_CRYPT_ERR;
      
      esecure_mac_t check_tag;
      
      if(mbedtls_chachapoly_finish(&ctx->esecure->rx.chachapoly, check_tag) != 0 ) return PROTOCOL_CTX_CRYPT_ERR;
      
      const uint8_t* const tag = &ctx->rx_buff[processed_len+ctx->esecure->rx.msg_len];
      
      // Constant-time compare: the received tag is attacker-controlled
      if(!vega_esecure_tag_equal(check_tag, tag, sizeof(check_tag)))
      {
        // Do not leave unauthenticated plaintext in the rx buffer
        memset(ctx->esecure->rx.buff, 0, ctx->esecure->rx.msg_len);
        
        __DEBUG_PRINTF("%s wrong packet sign...\n", get_pheader(ctx->conn_ctx_id));
        
        return PROTOCOL_CTX_CRYPT_ERR;
      }
      
      if(VEGA_SECURE_DEBUG_LEVEL>=LDEBUG_L)
      {
        __PRINTF("%s parse %u decrypted bytes...\n", get_pheader(ctx->conn_ctx_id), ctx->esecure->rx.msg_len);
      }
      
      // Hand the authenticated plaintext to the Vega protocol parser
      ctx->protocol.vega.parse_buff=ctx->esecure->rx.buff;
      ctx->protocol.vega.parse_len=ctx->esecure->rx.msg_len;
      vega_rx_processed_len=VegaRxProcessing(&ctx->protocol.vega);
      
      if(vega_rx_processed_len<0) return vega_rx_processed_len;
      
      // The whole decrypted message is dropped below even if the parser used less of it
      if(vega_rx_processed_len!=ctx->esecure->rx.msg_len)
      {
        if(VEGA_SECURE_DEBUG_LEVEL>=LDEBUG_L)
        {
          __PRINTF("%s warning, decrypted message len not not consistent with parsed len...\n", get_pheader(ctx->conn_ctx_id));
        }
      }
      
      ctx->esecure->rx.nonce.counter++;
      
      if(ctx->esecure->rx.nonce.counter==0)
      {
        return PROTOCOL_CTX_CRYPT_COUNTER_OVF_ERR; // rx nonce counter wrapped
      }
      
      processed_len+=ctx->esecure->rx.msg_len+sizeof(esecure_mac_t);
      parse_len-=ctx->esecure->rx.msg_len+sizeof(esecure_mac_t);
      
      ctx->esecure->step=1;
    }
    
    // Keep going only while a full packet was parsed and input remains.
    // NOTE(review): if VegaRxProcessing() returns 0 for a consumed packet, any
    // remaining buffered packets wait for the next call - confirm that is intended.
    if(parse_len>0 && vega_rx_processed_len>0) continue;
    else                                       break;
  }
  
  return (int16_t)processed_len;
}

/*
 * Encrypts one outgoing message into sctx->tx.buff, using the tx
 * ChaCha20-Poly1305 context and nonce prepared by the handshake.
 *
 * Layout written to sctx->tx.buff:
 *   [encrypted length header (esecure_packet_len_t)]
 *   [encrypted payload (tx_len bytes)]
 *   [authentication tag (esecure_mac_t)]
 *
 * sctx   - session context; the handshake must have completed (step >= 1).
 * tx     - plaintext payload.
 * tx_len - payload length in bytes.
 *
 * Returns the number of bytes written to sctx->tx.buff, or one of these errors:
 *   PROTOCOL_CTX_PARSE_DATA_FMT_ERR    - no session.
 *   PROTOCOL_CTX_MEM_ERR               - payload does not fit tx.buff or the length header.
 *   PROTOCOL_CTX_CRYPT_ERR             - mbedtls failure.
 *   PROTOCOL_CTX_CRYPT_COUNTER_OVF_ERR - tx nonce counter wrapped.
 * The tx nonce counter advances only after the packet was fully produced.
 *
 * NOTE(review): on counter wrap the counter is left at 0. If the caller keeps
 * using the session after that error, the next packet reuses a nonce. Confirm
 * that the caller tears the session down.
 */
int16_t vega_esecure_tx_prepare(vega_esecure_ctx_t* sctx, const uint8_t* tx, const uint16_t tx_len)
{
  if(sctx==NULL || sctx->step<1) return PROTOCOL_CTX_PARSE_DATA_FMT_ERR;
  
  const esecure_packet_len_t packet_len=tx_len;
  
  // The length header must represent tx_len exactly. Otherwise a truncated
  // header/payload would disagree with the length returned to the caller.
  if(packet_len!=tx_len) return PROTOCOL_CTX_MEM_ERR;
  
  if(sizeof(sctx->tx.buff)<(sizeof(esecure_packet_len_t)+packet_len+sizeof(esecure_mac_t))) return PROTOCOL_CTX_MEM_ERR;
  
  if(mbedtls_chachapoly_starts(&sctx->tx.chachapoly, (const uint8_t*)&sctx->tx.nonce, MBEDTLS_CHACHAPOLY_ENCRYPT) != 0) return PROTOCOL_CTX_CRYPT_ERR;
  
  // Encrypted length header, then encrypted payload, then tag: all into tx.buff
  if(mbedtls_chachapoly_update(&sctx->tx.chachapoly, sizeof(packet_len), (const uint8_t*)&packet_len, &sctx->tx.buff[0]) != 0) return PROTOCOL_CTX_CRYPT_ERR;
  
  if(mbedtls_chachapoly_update(&sctx->tx.chachapoly, packet_len, tx, &sctx->tx.buff[sizeof(esecure_packet_len_t)]) != 0) return PROTOCOL_CTX_CRYPT_ERR;
  
  if(mbedtls_chachapoly_finish(&sctx->tx.chachapoly, &sctx->tx.buff[sizeof(esecure_packet_len_t)+packet_len]) != 0) return PROTOCOL_CTX_CRYPT_ERR;
    
  sctx->tx.nonce.counter++;
  
  if(sctx->tx.nonce.counter==0)
  {
    return PROTOCOL_CTX_CRYPT_COUNTER_OVF_ERR; // tx nonce counter wrapped
  }
  
  return (int16_t)(sizeof(esecure_packet_len_t)+packet_len+sizeof(esecure_mac_t));
}

#endif //defined(VEGA_ESECURE_PRESENT)