#include <yadns/resolver_options.h>
#include <ares.h>
#include <ares_dns.h>
#include <arpa/nameser.h>
#include <arpa/nameser_compat.h>
#undef DELETE
#include <boost/array.hpp>
#include <boost/unordered_set.hpp>
#include <boost/range/iterator_range.hpp>
#include <boost/iterator/iterator_facade.hpp>
#include <boost/enable_shared_from_this.hpp>
#include <boost/make_shared.hpp>
#include <boost/bind.hpp>
#include <boost/unordered_map.hpp>


namespace detail {

// Deleter functor for objects that were allocated with `new`.
struct delete_op
{
  // Destroys the object v points at.  A null pointer is accepted and
  //   nothing happens for it (`delete` on null is a no-op).
  template <class V>
  void operator()(V* v) const
  {
    delete v;
  }
};

// Deleter functor for strings that have to be released with
//   ares_free_string().
struct ares_free_string_op
{
  // Releases v through ares_free_string(); null pointers are skipped.
  void operator()(char* v) const
  {
    if (!v)
      return;
    ares_free_string(v);
  }
};

// Non-copyable scope guard: stores a pointer-like value and hands it to a
//   default-constructed deleter of type D when the guard is destroyed.
template <class P, class D = delete_op>
struct guard : public boost::noncopyable
{
  // The guarded value; passed to D in the destructor.
  P ptr_;

  // Starts guarding q.  W only has to be convertible to P.
  template <class W>
  explicit guard(W q)
    : ptr_(q)
  {
  }

  ~guard()
  {
    D release;
    release(ptr_);
  }
};

// Guard for a char* that is released through ares_free_string_op.
using guard_ares_name_t = guard<char*, ares_free_string_op>;

} // namespace detail {

// Maps a c-ares status code onto the boost error_code that is reported to
//   callers.  ARES_SUCCESS yields a default-constructed (no error) code;
//   any status without an explicit case is reported as
//   host_not_found_try_again.
inline boost::system::error_code translate_ares_status(int error)
{
  namespace baerr = boost::asio::error;

  boost::system::error_code result;  // stays "no error" for ARES_SUCCESS
  switch (error)
  {
    case ARES_SUCCESS:
      break;

    case ARES_ECANCELLED:
      result = baerr::make_error_code(baerr::operation_aborted);
      break;

    case ARES_ENOMEM:
      result = baerr::make_error_code(baerr::no_memory);
      break;

    case ARES_ENOTIMP:
    case ARES_EFORMERR:
    case ARES_EBADQUERY:
    case ARES_EBADNAME:
    case ARES_EBADFAMILY:
      result = baerr::make_error_code(baerr::no_recovery);
      break;

    case ARES_ENODATA:
    case ARES_ENOTFOUND:
      result = baerr::make_error_code(baerr::host_not_found);
      break;

    case ARES_EFILE:
      result = baerr::make_error_code(baerr::not_found);
      break;

    case ARES_ENOTINITIALIZED:
      result = baerr::make_error_code(baerr::invalid_argument);
      break;

    default:
      result = baerr::make_error_code(baerr::host_not_found_try_again);
      break;
  }
  return result;
}

// Private copy of a DNS reply together with a cursor that walks the
//   answer-section resource records (RRs) of one RR type.
struct ares_resolver_service::basic_response
{
  typedef std::vector<unsigned char> buf_t;
  typedef boost::shared_ptr<buf_t> buf_ptr_t;
  typedef boost::iterator_range<const unsigned char*> range_t;

  // Answer-section records that have not been consumed yet, counting the
  //   record rr_ currently refers to.  0 marks the end position.
  int n_of_rr_;
  // Copy of the complete reply; shared between copies of this object.
  buf_ptr_t buf_;
  // RDATA bytes of the current record (points into *buf_).
  range_t rr_;

  // Copies `reply` and positions the cursor on the first answer record of
  //   type `rrtype`.
  //   ancount      - number of records in the answer section
  //   reply        - the raw DNS message
  //   rdata_offset - offset in `reply` of the first answer record
  //   rrtype       - RR type to look for
  // If no such record exists the object compares equal to a
  //   default-constructed one.
  basic_response(int ancount, const range_t& reply,
      size_t rdata_offset, int rrtype)
    : n_of_rr_(ancount)
    , buf_(new buf_t(reply.begin(), reply.end()))
    , rr_(find_rr(buf_->data()+rdata_offset, buf_->data(),
            buf_->size(), rrtype, &n_of_rr_))
  {
    // Members are initialized in declaration order, so n_of_rr_ and buf_
    //   were ready when find_rr() ran above.
    if (rr_.empty())
      n_of_rr_ = 0;
  }

  // End position: no buffer, no current record.
  basic_response()
    : n_of_rr_(0)
    , buf_()
    , rr_()
  {
  }

  // Moves to the next answer record of type `rrtype`.  Returns false (and
  //   sets n_of_rr_ to 0) when there is none.
  inline bool advance(int rrtype)
  {
    // The record we are leaving counts as consumed.
    if (--n_of_rr_ > 0)
    {
      rr_ = find_rr(rr_.end(), buf_->data(), buf_->size(), rrtype,
          &n_of_rr_);
      if (!rr_.empty())
        return true;
    }

    n_of_rr_ = 0;
    return false;
  }

  // Two responses are equal when the same number of records remains and
  //   the current RDATA ranges compare equal.
  inline bool operator==(const basic_response& lh) const
  {
    return n_of_rr_ == lh.n_of_rr_ &&
      rr_ == lh.rr_;
  }

  // Search for the first record of type 'atype' in dns response buffer abuf
  //   beginning the search from aptr.
  //   aptr      - start of a resource record inside abuf
  //   abuf/alen - the complete DNS message
  //   atype     - RR type to look for
  //   remaining - optional budget: how many records, starting at aptr, may
  //               be examined.  It is decremented for every record that is
  //               skipped, so on success it still counts the record that
  //               was found.  With a null pointer the search runs until
  //               the data can no longer be parsed.
  // Returns the RDATA range of the record, or an empty range when nothing
  //   was found.
  static range_t find_rr(const unsigned char* aptr,
      const unsigned char* abuf, int alen, int atype, int* remaining = 0)
  {
    const unsigned char* end = abuf + alen;
    const range_t not_found(end, end);
    for (;;)
    {
      // Stop at the end of the budget so that records which follow the
      //   answer section (authority / additional) are never returned.
      if (remaining && *remaining <= 0)
        return not_found;

      // Partly borrowed from c-ares- adig
      long len = 0;
      char* name = 0;

      // Parse the RR name.  Only its encoded length is needed; the
      //   expanded string is released by the guard.
      int status = ares_expand_name(aptr, abuf, alen, &name, &len);
      if (status != ARES_SUCCESS)
        return not_found;
      detail::guard_ares_name_t name_guard(name);
      aptr += len;

      // Make sure there is enough data after the RR name for the fixed
      //   part of the RR.  Distances are compared instead of forming a
      //   pointer that might lie beyond the buffer.
      if (end - aptr < NS_RRFIXEDSZ)
        return not_found;

      // Parse the fixed part of the RR, and advance to the RR data field.
      int type = DNS_RR_TYPE(aptr);
      int dlen = DNS_RR_LEN(aptr);
      aptr += NS_RRFIXEDSZ;

      if (end - aptr < dlen)
        return not_found;
      if (type == atype)
        return range_t(aptr, aptr + dlen);

      // Not the type we want: skip its data and charge it to the budget.
      aptr += dlen;
      if (remaining)
        --*remaining;
    }
  }

};

// basic_response with the RR type fixed at compile time to T, so callers
//   do not have to pass the type around.
template <int T>
struct ares_resolver_service::response : public basic_response
{
  // Default-constructed basic_response (no buffer, nothing remaining).
  response()
  {
  }

  // Wraps an untyped response; its state is copied unchanged.
  response(const basic_response& a)
    : basic_response(a)
  {
  }

  // Forwards to basic_response, supplying T as the RR type to look for.
  response(int ancount, const range_t& reply, size_t rdata_offset)
    : basic_response(ancount, reply, rdata_offset, T)
  {
  }

  // Steps to the next record of type T; false when there is none.
  bool advance()
  {
    return basic_response::advance(T);
  }
};

// CRTP base of the record iterators.
//   D - the concrete iterator; must provide
//       `V do_dereference(const response_t&) const`
//   V - value produced by dereferencing (also used as the reference type)
//   T - RR type the iterator walks over
template <class D, class V, int T>
class ares_resolver_service::basic_iterator
  : public boost::iterator_facade<D, V, boost::forward_traversal_tag, V>
{
  typedef std::vector<char>::iterator iterator_t;

public:
  enum {
    rr_type = T
  };

  typedef response<rr_type> response_t;
  typedef typename response_t::range_t range_t;

  // Iterator holding a default-constructed response.
  basic_iterator()
    : a_()
  {
  }

  // Builds a concrete iterator of type D that refers to response `a`.
  static D create(response_t a)
  {
    D positioned;
    positioned.a_ = a;
    return positioned;
  }

private:
  friend class boost::iterator_core_access;

  // Steps to the next record; when there is none the iterator falls back
  //   to the default-constructed state.
  void increment()
  {
    if (a_.advance())
      return;
    a_ = response_t();
  }

  bool equal(const basic_iterator<D, V, T>& other) const
  {
    return a_ == other.a_;
  }

  // Decoding of the current record is delegated to the derived class.
  V dereference() const
  {
    return static_cast<const D&>(*this).do_dereference(a_);
  }

  response_t a_;
};

// Iterator over A records; each record is presented as dotted-quad text.
class ares_resolver_service::iterator_a
  : public ares_resolver_service::basic_iterator<
    ares_resolver_service::iterator_a, std::string, T_A>
{
public:
  // Formats the current record with inet_ntop().  Yields an empty string
  //   when the RDATA is not exactly 4 bytes long or the conversion fails.
  std::string do_dereference(const response_t& a) const
  {
    if (a.rr_.size() != 4)
      return std::string();

    char text[INET_ADDRSTRLEN];
    if (!inet_ntop(AF_INET, a.rr_.begin(), text, sizeof(text)))
      return std::string("");
    return std::string(text);
  }
};

// Iterator over AAAA records; each record is presented as IPv6 text.
class ares_resolver_service::iterator_aaaa
  : public ares_resolver_service::basic_iterator<
    ares_resolver_service::iterator_aaaa, std::string, T_AAAA>
{
public:
  // Formats the current record with inet_ntop().  Yields an empty string
  //   when the RDATA is not exactly 16 bytes long or the conversion fails.
  std::string do_dereference(const response_t& a) const
  {
    if (a.rr_.size() != 16)
      return std::string();

    // INET6_ADDRSTRLEN is the buffer size defined for IPv6 text; a smaller
    //   hand-picked size risks inet_ntop() reporting "no space", which
    //   would surface here as an empty string.
    char ipv6[INET6_ADDRSTRLEN];
    const char* rv = inet_ntop(AF_INET6, a.rr_.begin(), ipv6, sizeof(ipv6));
    return std::string(rv ? rv : "");
  }
};

// Iterator over PTR records; each record is presented as a host name.
class ares_resolver_service::iterator_ptr
  : public ares_resolver_service::basic_iterator<
    ares_resolver_service::iterator_ptr, std::string, T_PTR>
{
public:
  // Expands the domain name stored in the current record.  Yields an empty
  //   string when ares_expand_name() fails.
  std::string do_dereference(const response_t& a) const
  {
    char* expanded = 0;
    long consumed = 0;
    const int rc = ares_expand_name(a.rr_.begin(), a.buf_->data(),
        a.buf_->size(), &expanded, &consumed);
    // Releases the expanded name on every path out of this function.
    detail::guard_ares_name_t expanded_guard(expanded);

    return rc == ARES_SUCCESS ? std::string(expanded) : std::string();
  }
};

// Iterator over MX records; each record is presented as the pair
//   (16-bit value read from the start of the RDATA, exchange host name).
class ares_resolver_service::iterator_mx
  : public ares_resolver_service::basic_iterator<
    ares_resolver_service::iterator_mx, std::pair<unsigned short, std::string>, T_MX>
{
public:
  // Decodes the current record.  Yields {0, ""} when the RDATA is shorter
  //   than two bytes or the host name cannot be expanded.
  std::pair<unsigned short, std::string> do_dereference(const response_t& a) const
  {
    typedef std::pair<unsigned short, std::string> mx_t;
    const mx_t nothing(static_cast<unsigned short>(0), std::string());

    if (a.rr_.size() < 2)
      return nothing;

    const unsigned short preference = DNS__16BIT(a.rr_.begin());

    // The host name follows directly after the 16-bit value.
    char* exchange = 0;
    long consumed = 0;
    const int rc = ares_expand_name(
        a.rr_.begin() + sizeof(unsigned short),
        a.buf_->data(), a.buf_->size(), &exchange, &consumed);
    detail::guard_ares_name_t exchange_guard(exchange);

    if (rc != ARES_SUCCESS)
      return nothing;
    return mx_t(preference, std::string(exchange));
  }
};

// Iterator over TXT records; the character-strings of one record are
//   presented concatenated into a single std::string.
class ares_resolver_service::iterator_txt
  : public ares_resolver_service::basic_iterator<
    ares_resolver_service::iterator_txt, std::string, T_TXT>
{
public:
  // Decodes the current record.  Decoding stops at the first chunk that
  //   is truncated or cannot be expanded; whatever was collected up to
  //   that point is returned.
  std::string do_dereference(const response_t& a) const
  {
    const unsigned char* p = a.rr_.begin();
    std::string rv;
    while (p < a.rr_.end())
    {
      // The first byte of a chunk is its length; the chunk must fit into
      //   the record (length byte + payload).
      long len = *p;
      if (a.rr_.end() - p < len + 1)
        return rv;

      // Start from null so that the guard never releases an indeterminate
      //   pointer if ares_expand_string() fails without writing to it.
      unsigned char* chunk = 0;
      int status = ares_expand_string(p, a.buf_->data(),
          a.buf_->size(), &chunk, &len);
      detail::guard_ares_name_t chunk_guard(reinterpret_cast<char*>(chunk));

      if (status != ARES_SUCCESS)
        return rv;

      // Appended as a C string, i.e. up to the first NUL byte.
      rv += reinterpret_cast<const char*>(chunk);

      // len was updated by ares_expand_string(); step by that amount to
      //   reach the next chunk.
      p += len;
    }
    return rv;
  }
};

// Type-erased part of a pending query.  The concrete query<> template
//   supplies invoke_handler(); everything else is shared.
struct ares_resolver_service::basic_query
  : public boost::enable_shared_from_this<ares_resolver_service::basic_query>
{
  // Resolver state this query runs on.
  implementation_type d_;

  explicit basic_query(implementation_type& d);
  virtual ~basic_query();

  // Schedules start_helper() with the same arguments on the resolver's
  //   strand.
  void start(const std::string& name,
      int dnsclass,
      int type,
      ares_callback cb);

  // Submits the query to c-ares.
  void start_helper(const std::string& name,
      int dnsclass,
      int type,
      ares_callback cb);

  // Delivers the outcome to the user's handler (implemented by query<>).
  virtual void invoke_handler(const boost::system::error_code& ec,
      basic_response a) = 0;

  // Updates the resolver's bookkeeping, then calls invoke_handler().
  inline void invoke_handler_helper(const boost::system::error_code& ec,
      basic_response a = basic_response());
};

// Per-resolver state: owns the c-ares channel and drives it from the
//   io_service, using readiness notifications on the channel's sockets
//   and a deadline timer.
struct ares_resolver_service::helper
  : public boost::enable_shared_from_this<ares_resolver_service::helper>
{
  typedef boost::asio::posix::stream_descriptor stream_descriptor_t;
  typedef boost::shared_ptr<stream_descriptor_t> stream_descriptor_ptr_t;

  // What is tracked for every socket c-ares told us about.
  struct socket_state
  {
    // Descriptor (a dup() of the c-ares socket) used to wait for readiness.
    stream_descriptor_ptr_t stream_;
    // Set by socket_change_state_cb(); lets process_fd() notice that new
    //   waits were already issued while it was inside ares_process_fd().
    bool updated_ = false;
  };

  typedef boost::unordered_map<ares_socket_t, socket_state> fds_set_t;

  // Reason process_fd() was invoked.
  enum class process_state_t
  {
    read,
    write,
    timeout
  };

  int q_;                 // number of queries currently in flight
  ares_options options_;  // options handed to ares_init_options()
  int optmask_;           // ARES_OPT_* bits that are valid in options_
  ares_channel channel_;  // null until init() succeeded
  boost::asio::deadline_timer t_;
  boost::asio::io_service::strand strand_;
  fds_set_t fds_;         // sockets currently reported as open by c-ares

  helper(boost::asio::io_service& ios)
    : q_(0)
    , options_()
    , optmask_(0)
    , channel_(nullptr)
    , t_(ios)
    , strand_(ios)
    , fds_()
  {
  }

  // Creates the c-ares channel on first use; later calls do nothing.
  // Throws boost::system::system_error when ares_init_options() fails.
  void init()
  {
    if (channel_)
      return;

    // Have c-ares report socket state changes to socket_change_state_cb().
    optmask_ |= ARES_OPT_SOCK_STATE_CB;

    options_.sock_state_cb = &helper::socket_change_state_cb;
    options_.sock_state_cb_data = this;

    int err = ares_init_options(&channel_, &options_, optmask_);
    if (err != ARES_SUCCESS)
      boost::throw_exception(
        boost::system::system_error(
          translate_ares_status(err), "ares_init_options"));
  }

  ~helper()
  {
    assert (q_ == 0);
    // init() may never have run, in which case there is no channel.
    if (channel_)
      ares_destroy(channel_);
  }

  // Records timeout, retry count and the EDNS flag.  The values are read
  //   by init(), which returns early once a channel exists.
  void set_options(const resolver_options& options)
  {
    options_.timeout = options.timeout;
    optmask_ |= ARES_OPT_TIMEOUTMS;

    options_.tries = options.tries;
    optmask_ |= ARES_OPT_TRIES;

    if (options.use_edns)
    {
      options_.flags |= ARES_FLAG_EDNS;
      optmask_ |= ARES_OPT_FLAGS;
    }
  }

  // Makes sure the channel exists and (re)arms the timeout timer.
  void start()
  {
    init();
    set_timer();
  }

  void cancel()
  {
    // Cancel all outstanding async operations
    t_.cancel();
    fds_.clear();

    if (channel_)
      ares_cancel(channel_);
  }

  // Completion handler for socket readiness and for the timer.
  //   ec        - result of the asynchronous wait
  //   socket_fd - c-ares socket concerned (ARES_SOCKET_BAD for the timer)
  //   state     - which kind of event fired
  void process_fd(const boost::system::error_code& ec,
      ares_socket_t socket_fd, process_state_t state)
  {
    if (ec == boost::asio::error::operation_aborted || q_ == 0)
      return;

    if (state == process_state_t::timeout)
    {
      // Re-arm first, then let c-ares process without any ready socket.
      set_timer();
      ares_process_fd(channel_, ARES_SOCKET_BAD, ARES_SOCKET_BAD);

      return;
    }

    // Reset the relevance flag
    // The socket state can be changed inside ares_process_fd by socket_change_state_cb
    // And after the call we may not need any action
    auto it = fds_.find(socket_fd);
    if (it != fds_.end())
      it->second.updated_ = false;

    if (state == process_state_t::read)
      ares_process_fd(channel_, socket_fd, ARES_SOCKET_BAD);
    else
      ares_process_fd(channel_, ARES_SOCKET_BAD, socket_fd);

    // The connection could have already been closed
    it = fds_.find(socket_fd);
    if (it == fds_.end())
      return;

    auto& socket_state = it->second;
    if (socket_state.updated_)
      return;

    // If the state of the socket has not changed
    // repeat the request for the same action that brought us here
    if (state == process_state_t::read)
      socket_state.stream_->async_read_some(boost::asio::null_buffers(),
          strand_.wrap(
            boost::bind(&helper::process_fd,
                shared_from_this(), _1, socket_fd, process_state_t::read)));
    else
      socket_state.stream_->async_write_some(boost::asio::null_buffers(),
          strand_.wrap(
            boost::bind(&helper::process_fd,
                shared_from_this(), _1, socket_fd, process_state_t::write)));
  }

  // Computes the delay for the timeout timer: 5 seconds unless a timeout
  //   was configured through set_options().
  struct timeval get_ares_process_timeout() const
  {
    struct timeval tv;
    tv.tv_sec = 5;  // default c-ares timeout
    tv.tv_usec = 0;

    if (optmask_ & ARES_OPT_TIMEOUTMS)
    {
      tv.tv_sec = options_.timeout / 1000;
      tv.tv_usec = (options_.timeout % 1000) * 1000;
    }

    // With exactly one query in flight c-ares is asked for its own
    //   timeout; tv serves both as the upper bound and as the output.
    // NOTE(review): why this is limited to q_ == 1 is not visible here.
    if (q_ == 1)
      ares_timeout(channel_, &tv, &tv);

    return tv;
  }

  // (Re)arms t_; expiry is delivered to process_fd() with state `timeout`.
  void set_timer()
  {
    struct timeval tv = get_ares_process_timeout();
    t_.expires_from_now(boost::posix_time::seconds(tv.tv_sec) + boost::posix_time::microseconds(tv.tv_usec));
    t_.async_wait(strand_.wrap(
          boost::bind(&helper::process_fd, shared_from_this(), _1, ARES_SOCKET_BAD, process_state_t::timeout)));
  }

  // Socket state callback registered with c-ares in init().
  //   data      - the helper that owns the channel
  //   socket_fd - socket whose state changed
  //   readable  - non-zero: wait for the socket to become readable
  //   writable  - non-zero: wait for the socket to become writable
  // Throws boost::system::system_error when dup() fails.
  static void socket_change_state_cb(void* data, ares_socket_t socket_fd, int readable, int writable)
  {
    helper* self = static_cast<helper*>(data);
    auto it = self->fds_.find(socket_fd);

    // If readable = 0 and writable = 0 - socket was closed
    // This path deliberately works on the raw pointer: no handler is
    //   bound here, and shared_from_this() would throw if the
    //   notification arrived while the helper is being destroyed
    //   (ares_destroy() is called from the destructor).
    if (!readable && !writable)
    {
      if (it != self->fds_.end())
        self->fds_.erase(it);

      return;
    }

    // The handlers bound below must keep the helper alive.
    auto helper_ptr = self->shared_from_this();

    // Socket was opened
    if (it == helper_ptr->fds_.end())
    {
      auto elem = helper_ptr->fds_.insert({socket_fd, {}});
      it = elem.first;
    }

    auto& state = it->second;
    state.updated_ = true;

    // Wait on a duplicate, so that closing our stream_descriptor does not
    //   close the descriptor c-ares itself is using.
    int dupfd = dup(socket_fd);

    if (dupfd == -1)
    {
      boost::system::error_code ec(errno,
          boost::asio::error::get_system_category());
      boost::throw_exception(
        boost::system::system_error(ec, "dup"));
    }

    // Replacing stream_ drops the previous descriptor for this socket.
    state.stream_ = boost::make_shared<stream_descriptor_t>(
        helper_ptr->strand_.get_io_service(), dupfd);

    // Waiting for incoming data
    if (readable)
      state.stream_->async_read_some(boost::asio::null_buffers(),
          helper_ptr->strand_.wrap(
            boost::bind(&helper::process_fd,
                helper_ptr, _1, socket_fd, process_state_t::read)));

    // We have data ready to send (TCP-only)
    if (writable)
      state.stream_->async_write_some(boost::asio::null_buffers(),
          helper_ptr->strand_.wrap(
            boost::bind(&helper::process_fd,
                helper_ptr, _1, socket_fd, process_state_t::write)));
  }
};

// Concrete query: remembers the user's completion handler and knows which
//   iterator type (and therefore which RR type) the result is for.
template <class Handler, class Iterator>
struct ares_resolver_service::query
  : public ares_resolver_service::basic_query
{
  Handler handler_;

  query(implementation_type& impl, Handler handler)
    : basic_query(impl)
    , handler_(handler)
  {}

  // Posts handler_(ec, iterator) to the io_service; the iterator is built
  //   from response `a`.
  void invoke_handler(const boost::system::error_code& ec, basic_response a)
  {
    typedef typename Iterator::response_t typed_response_t;

    const Iterator first = Iterator::create(typed_response_t(a));
    d_->strand_.get_io_service().post(boost::bind(handler_, ec, first));
  }

  // Starts an Internet-class lookup of `name` for the iterator's RR type.
  inline void start(const std::string& name)
  {
    ares_callback on_reply = &callback<Iterator::rr_type>::op;
    basic_query::start(name, C_IN, Iterator::rr_type, on_reply);
  }
};

// Initializes the c-ares library for this service.
// Throws boost::system::system_error when ares_library_init() fails.
inline ares_resolver_service::ares_resolver_service(
  boost::asio::io_service& owner)
  : service_base<ares_resolver_service>(owner)
{
  // ares_library_init() reports failure through its return value, so the
  //   error is built from that status rather than from errno.
  const int status = ares_library_init(ARES_LIB_INIT_ALL);
  if (status != ARES_SUCCESS)
  {
    boost::throw_exception(
      boost::system::system_error(
        translate_ares_status(status), "ares_library_init"));
  }
}

// Service shutdown hook: calls ares_library_cleanup(), the counterpart of
//   the ares_library_init() call made in the constructor.
inline void ares_resolver_service::shutdown_service()
{
  ares_library_cleanup();
}

// Asks the implementation to cancel its outstanding work.  helper::cancel
//   is dispatched through the implementation's strand; an empty
//   implementation is ignored.
inline void ares_resolver_service::cancel(implementation_type& impl)
{
  if (!impl)
    return;

  impl->strand_.dispatch(boost::bind(&helper::cancel, impl));
}

// Drops this handle's reference to the implementation object.
inline void ares_resolver_service::destroy(implementation_type& impl)
{
  impl.reset();
}

// Gives `impl` a fresh helper bound to this service's io_service.
inline void ares_resolver_service::construct(implementation_type& impl)
{
  impl.reset(new helper(get_io_service()));
}

// Forwards resolver settings to the implementation; an empty
//   implementation is ignored.
inline void ares_resolver_service::set_options(
    implementation_type& impl,
    const resolver_options& settings)
{
  if (!impl)
    return;

  impl->set_options(settings);
}

// Completion callback handed to ares_query() for queries of RR type T.
template <int T>
struct ares_resolver_service::callback
{
  // arg    - heap-allocated basic_query_ptr_t created by start_helper();
  //          it is deleted when this function returns
  // status - c-ares status of the query
  // abuf   - raw DNS reply, alen bytes long
  // The outcome is passed on through basic_query::invoke_handler_helper().
  static void op(void* arg, int status, int /*timeouts*/, unsigned char *abuf,
      int alen)
  {
    basic_query_ptr_t* q = static_cast<basic_query_ptr_t*>(arg);
    detail::guard<basic_query_ptr_t*> g(q);

    if (status != ARES_SUCCESS)
      return (**q).invoke_handler_helper(translate_ares_status(status));

    // A reply must at least contain the fixed-size DNS header.
    if (alen < NS_HFIXEDSZ)
      return (**q).invoke_handler_helper(
        boost::asio::error::host_not_found_try_again);

    int qdcount = DNS_HEADER_QDCOUNT(abuf);
    int ancount = DNS_HEADER_ANCOUNT(abuf);

    // Skip the questions section.
    const unsigned char* const end = abuf + alen;
    unsigned char* aptr = abuf + NS_HFIXEDSZ;
    for (int i=0; i<qdcount; ++i)
    {
      // Partly borrowed from c-ares- adig.
      char *name = 0;
      long len = 0;

      // Parse the question name.  Only its encoded length is needed; the
      //   expanded string is released by the guard.
      int name_status = ares_expand_name(aptr, abuf, alen, &name, &len);
      detail::guard_ares_name_t name_guard(name);

      if (name_status != ARES_SUCCESS)
        return (**q).invoke_handler_helper(
          boost::asio::error::host_not_found_try_again);
      aptr += len;

      // Make sure there's enough data after the name for the fixed part
      //   of the question.
      if (end - aptr < NS_QFIXEDSZ)
      {
        return (**q).invoke_handler_helper(
          boost::asio::error::host_not_found_try_again);
      }
      aptr += NS_QFIXEDSZ;
    }

    // aptr now points at the first record after the questions.
    typename response<T>::range_t buf(abuf, abuf+alen);
    response<T> res(ancount, buf, aptr-abuf);

    // A response equal to the default-constructed one holds no record of
    //   type T; that is reported as host_not_found.
    return (**q).invoke_handler_helper(
      (res == response<T>()
          ? boost::asio::error::host_not_found
          : boost::system::error_code())
      , res);
  }
};

// Stores a copy of the resolver implementation handle `d` in d_.
inline ares_resolver_service::basic_query::basic_query(implementation_type& d)
  : d_(d)
{
}

// Nothing to release explicitly; declared virtual in the class so that
//   derived queries can be destroyed through a basic_query pointer.
inline ares_resolver_service::basic_query::~basic_query()
{
}

// Hands the query over to the resolver's strand: start_helper() is
//   dispatched there with the same arguments.  The bound call holds a
//   shared_ptr to this query.
inline void ares_resolver_service::basic_query::start(const std::string& name,
    int dnsclass, int type, ares_callback cb)
{
  auto self = shared_from_this();
  d_->strand_.dispatch(
    boost::bind(&basic_query::start_helper, self, name, dnsclass, type, cb));
}

// Submits the query to c-ares.
//   name   - domain name to look up
//   rrtype - RR type to ask for (the class is always C_IN)
//   cb     - completion callback; it receives a heap-allocated
//            basic_query_ptr_t as its argument and is responsible for
//            deleting it
// Throws whatever helper::init() throws when the channel cannot be
//   created.
inline void ares_resolver_service::basic_query::start_helper(
  const std::string& name, int /*dnsclass*/, int rrtype, ares_callback cb)
{
  // init() can throw.  Run it before anything is allocated or counted, so
  //   that a failure neither leaks the callback argument nor leaves q_
  //   incremented.
  d_->init();

  basic_query_ptr_t* q = new basic_query_ptr_t(shared_from_this());

  int saved_q = ++d_->q_;
  ares_query(d_->channel_, name.c_str(), C_IN, rrtype, cb, q);

  // If the callback already ran inside ares_query() it has decremented
  //   q_; the timer is only (re)armed while the query is still pending.
  if (saved_q == d_->q_)
    d_->start();
}

// Marks this query as finished and delivers its outcome.  When it was the
//   last query in flight, the timer wait and the socket waits are dropped
//   first.
inline void ares_resolver_service::basic_query::invoke_handler_helper(
  const boost::system::error_code& ec, basic_response a)
{
  const bool was_last_query = (--d_->q_ == 0);
  if (was_last_query)
  {
    // Cancel all outstanding async operations
    d_->t_.cancel();
    d_->fds_.clear();
  }

  invoke_handler(ec, a);
}

// Starts an asynchronous A lookup of `domain`; `handler` is invoked with
//   an error_code and an iterator_a.
template<typename Handler>
void ares_resolver_service::async_resolve_a(implementation_type& impl,
    const std::string& domain, Handler handler)
{
  typedef query<Handler, iterator_a> a_query_t;

  // start() binds a shared_ptr to the query, so the local may go away.
  boost::shared_ptr<a_query_t> request(new a_query_t(impl, handler));
  request->start(domain);
}

// Starts an asynchronous AAAA lookup of `domain`; `handler` is invoked
//   with an error_code and an iterator_aaaa.
template<typename Handler>
void ares_resolver_service::async_resolve_aaaa(implementation_type& impl,
    const std::string& domain, Handler handler)
{
  typedef query<Handler, iterator_aaaa> aaaa_query_t;

  // start() binds a shared_ptr to the query, so the local may go away.
  boost::shared_ptr<aaaa_query_t> request(new aaaa_query_t(impl, handler));
  request->start(domain);
}

// Starts an asynchronous PTR lookup of `domain`; `handler` is invoked
//   with an error_code and an iterator_ptr.
template<typename Handler>
void ares_resolver_service::async_resolve_ptr(implementation_type& impl,
    const std::string& domain, Handler handler)
{
  typedef query<Handler, iterator_ptr> ptr_query_t;

  // start() binds a shared_ptr to the query, so the local may go away.
  boost::shared_ptr<ptr_query_t> request(new ptr_query_t(impl, handler));
  request->start(domain);
}

// Starts an asynchronous MX lookup of `domain`; `handler` is invoked with
//   an error_code and an iterator_mx.
template<typename Handler>
void ares_resolver_service::async_resolve_mx(implementation_type& impl,
    const std::string& domain, Handler handler)
{
  typedef query<Handler, iterator_mx> mx_query_t;

  // start() binds a shared_ptr to the query, so the local may go away.
  boost::shared_ptr<mx_query_t> request(new mx_query_t(impl, handler));
  request->start(domain);
}

// Starts an asynchronous TXT lookup of `domain`; `handler` is invoked
//   with an error_code and an iterator_txt.
template<typename Handler>
void ares_resolver_service::async_resolve_txt(implementation_type& impl,
    const std::string& domain, Handler handler)
{
  typedef query<Handler, iterator_txt> txt_query_t;

  // start() binds a shared_ptr to the query, so the local may go away.
  boost::shared_ptr<txt_query_t> request(new txt_query_t(impl, handler));
  request->start(domain);
}
