//===========================================================================
// @(#) $Name:$
// @(#) $Id: DwmDnsResolver.hh 10142 2018-01-28 19:11:40Z dwm $
//===========================================================================
//  Copyright (c) Daniel W. McRobb 2018
//  All rights reserved.
//
//  Redistribution and use in source and binary forms, with or without
//  modification, are permitted provided that the following conditions
//  are met:
//
//  1. Redistributions of source code must retain the above copyright
//     notice, this list of conditions and the following disclaimer.
//  2. Redistributions in binary form must reproduce the above copyright
//     notice, this list of conditions and the following disclaimer in the
//     documentation and/or other materials provided with the distribution.
//  3. The names of the authors and copyright holders may not be used to
//     endorse or promote products derived from this software without
//     specific prior written permission.
//
//  IN NO EVENT SHALL DANIEL W. MCROBB BE LIABLE TO ANY PARTY FOR
//  DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES,
//  INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE,
//  EVEN IF DANIEL W. MCROBB HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
//  DAMAGE.
//
//  THE SOFTWARE PROVIDED HEREIN IS ON AN "AS IS" BASIS, AND
//  DANIEL W. MCROBB HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT,
//  UPDATES, ENHANCEMENTS, OR MODIFICATIONS. DANIEL W. MCROBB MAKES NO
//  REPRESENTATIONS AND EXTENDS NO WARRANTIES OF ANY KIND, EITHER
//  IMPLIED OR EXPRESS, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
//  WARRANTIES OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE,
//  OR THAT THE USE OF THIS SOFTWARE WILL NOT INFRINGE ANY PATENT,
//  TRADEMARK OR OTHER RIGHTS.
//===========================================================================

//---------------------------------------------------------------------------
//!  \file DwmDnsResolver.hh
//!  \brief Dwm::Dns::Resolver class definition
//---------------------------------------------------------------------------

#ifndef _DWMDNSRESOLVER_HH_
#define _DWMDNSRESOLVER_HH_

#include <atomic>
#include <type_traits>
#include <vector>

#include "DwmDnsEtcHosts.hh"
#include "DwmDnsNameServer.hh"

namespace Dwm {

  namespace Dns {

    //------------------------------------------------------------------------
    //!  @defgroup resolvergroup Resolver classes
    //------------------------------------------------------------------------
    
    //------------------------------------------------------------------------
    //!  @ingroup resolvergroup
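    //!  Simple DNS resolver class.  Looks up addresses, host names and
    //!  other resource records using the local hosts file (EtcHosts)
    //!  and/or the configured DNS name servers.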
    //------------------------------------------------------------------------
    class Resolver
    {
    public:
      //----------------------------------------------------------------------
      //!  Construct with the given resolver configuration file and the
      //!  preferred order of resolver sources (ala /etc/nssswitch.conf).
      //----------------------------------------------------------------------
      Resolver(const std::string & filename = "/etc/resolv.conf",
               const std::string & order = "files,dns");

      //----------------------------------------------------------------------
      //!  Fetches the IPv6 and IPv4 addresses for the given host @c name.
      //!  Stores the results in @c in6Addrs and @c inAddrs.  Returns true
      //!  if any addresses were found, false if none were found.
      //----------------------------------------------------------------------
      bool GetHostByName(const std::string & name,
                         std::vector<in6_addr> & in6Addrs,
                         std::vector<in_addr> & inAddrs);

      //----------------------------------------------------------------------
      //!  Fetches the host names for the given IPv6 address @c addr.
      //!  Stores the results in @c names and returns true on success.
      //!  Returns false if no host names were found.
      //----------------------------------------------------------------------
      bool GetHostByAddr(const in6_addr & addr,
                         std::vector<std::string> & names);

      //----------------------------------------------------------------------
      //!  Fetches the host names for the given IPv4 address @c addr.
      //!  Stores the results in @c names and returns true on success.
      //!  Returns false if no host names were found.
      //----------------------------------------------------------------------
      bool GetHostByAddr(const in_addr & addr,
                         std::vector<std::string> & names);
      
      //----------------------------------------------------------------------
      //!  Fetches resource records of type @c rrtype for the given domain
      //!  name @c name.  Stores the records that were found in @c results
      //!  and returns true on success.  Returns false if no records were
      //!  found.  If @c doEDNS is @c true, we will include an OPT resource
      //!  record in our query messages to the DNS server and use a packet
      //!  size of 4K (4096 bytes) for UDP.  If @c doDNSSEC is @c true, we
      //!  will also ask the server to do DNSSEC processing.  If
      //!  @c tcpFallback is @c true, we will use TCP for our query if UDP
      //!  fails.
      //----------------------------------------------------------------------
      bool Get(const std::string & name, uint16_t rrtype,
               std::vector<ResourceRecord> & results,
               bool doEDNS = false, bool doDNSSEC = false,
               bool tcpFallback = true);
      
      //----------------------------------------------------------------------
      //!  Fetches resource records whose data type is RRT for the given
      //!  domain name @c name.  Stores the records that were found in
      //!  @c results and returns true on success.  Returns false if no
      //!  records were found.  If @c doEDNS is @c true, we will include an
      //!  OPT resource record in our query messages to the DNS server and
      //!  use a packet size of 4K (4096 bytes) for UDP.  If @c doDNSSEC is
      //!  @c true, we will also ask the server to do DNSSEC processing.  If
      //!  @c tcpFallback is @c true, we will use TCP for our query if UDP
      //!  fails.
      //----------------------------------------------------------------------
      template <typename RRT>
      bool Get(const std::string & name,
               std::vector<ResourceRecord> & results,
               bool doEDNS = false, bool doDNSSEC = false,
               bool tcpFallback = true)
      {
        results.clear();
        if (std::is_same<RRT,RRDataPTR>::value
            || std::is_same<RRT,RRDataA>::value
            || std::is_same<RRT,RRDataAAAA>::value) {
          if (_order.find("files") == 0) {
            EtcHosts  etcHosts;
            etcHosts.Get(name, RRT::k_rrtype, results);
          }
        }
        
        if (results.empty() && (_order.find("dns") != std::string::npos)) {
          std::vector<std::string>  namesToTry;
          GetNamesToTry(name, RRT::k_rrtype, namesToTry);
          for (auto & n : namesToTry) {
            Message  message;
            message.Header().Id(_msgid++);
            message.Header().RecursionDesired(true);
            if (doEDNS) {
              message.EnableEDNS(4096, doDNSSEC);
            }
            Dns::MessageQuestion  question(n, RRT::k_rrtype,
                                           MessageQuestion::k_classIN);
            message.Questions().push_back(question);
            for (auto & ns : _nameservers) {
              if (GetViaUDP<RRT>(ns, message, results)) {
                break;
              }
              else if (tcpFallback) {
                if (GetViaTCP<RRT>(ns, message, results)) {
                  break;
                }
              }
            }
            if (! results.empty()) {
              break;
            }
          }
        }

        if (std::is_same<RRT,RRDataPTR>::value
            || std::is_same<RRT,RRDataA>::value
            || std::is_same<RRT,RRDataAAAA>::value) {
          std::string::size_type  fidx = _order.find("files");
          if (results.empty() && (fidx != std::string::npos) && (fidx > 0)) {
            EtcHosts  etcHosts;
            etcHosts.Get(name, RRT::k_rrtype, results);
          }
        }
        
        return (! results.empty());
      }
      
      //----------------------------------------------------------------------
      //!  Gets resource record data of type RRT for the given domain name
      //!  @c name.  Stores the results in @c results and returns true on
      //!  success.  Returns false if no records were found.  If @c doEDNS
      //!  is @c true, we will include an OPT resource record in our query
      //!  messages to the DNS server and use a packet size of 4K (4096
      //!  bytes) for UDP.  If @c doDNSSEC is @c true, we will also ask the
      //!  server to do DNSSEC processing.  If @c tcpFallback is @c true, we
      //!  will use TCP for our query if UDP fails.
      //----------------------------------------------------------------------
      template <typename RRT>
      bool Get(const std::string & name, std::vector<RRT> & results,
               bool doEDNS = false, bool doDNSSEC = false,
               bool tcpFallback = true)
      {
        results.clear();
        std::vector<ResourceRecord>  rrs;
        bool  rc = Get<RRT>(name, rrs, doEDNS, doDNSSEC, tcpFallback);
        for (auto rr : rrs) {
          RRT  *data = rr.Data<RRT>();
          if (data != nullptr) {
            results.push_back(*data);
          }
        }
        return (! results.empty());
      }

      //----------------------------------------------------------------------
      //!  Returns a const reference to the contained nameservers.
      //----------------------------------------------------------------------
      const std::vector<NameServer> &	NameServers() const
      {
        return _nameservers;
      }
      
      //----------------------------------------------------------------------
      //!  Returns a mutable reference to the contained nameservers.
      //----------------------------------------------------------------------
      std::vector<NameServer> & NameServers()
      {
        return _nameservers;
      }
      
    private:
      std::vector<NameServer>   _nameservers;
      std::vector<std::string>  _searchList;
      std::string               _domain;
      std::string               _order;
      std::atomic<uint16_t>     _msgid;

      //----------------------------------------------------------------------
      //!  
      //----------------------------------------------------------------------
      template <typename RRT>
      bool GetViaUDP(NameServer & ns, Message & querymsg,
                     std::vector<ResourceRecord> & results)
      {
        if (ns.SendMessage(querymsg)) {
          Message  rmsg;
          if (ns.ReceiveMessage(rmsg)) {
            for (auto & answer : rmsg.Answers()) {
              RRT  *rrdata = answer.Data<RRT>();
              if (rrdata != nullptr) {
                results.push_back(answer);
              }
            }
          }
        }
        return (! results.empty());
      }

      //----------------------------------------------------------------------
      //!  
      //----------------------------------------------------------------------
      template <typename RRT>
      bool GetViaTCP(NameServer & ns, Message & querymsg,
                     std::vector<ResourceRecord> & results)
      {
        if (ns.WriteMessage(querymsg)) {
          Message  rmsg;
          if (ns.ReadMessage(rmsg)) {
            for (auto & answer : rmsg.Answers()) {
              RRT  *rrdata = answer.Data<RRT>();
              if (rrdata != nullptr) {
                results.push_back(answer);
              }
            }
          }
        }
        return (! results.empty());
      }
      
      //----------------------------------------------------------------------
      //!  
      //----------------------------------------------------------------------
      void GetNamesToTry(const std::string & name, uint16_t rrtype,
                         std::vector<std::string> & names) const;
    };
    
  }  // namespace Dns

}  // namespace Dwm

#endif  // _DWMDNSRESOLVER_HH_
