package hostcmds

import (
	"bufio"
	"bytes"
	"fmt"
	"net"
	"os/exec"
	"regexp"
	"strconv"
	"strings"
	"time"

	"github.com/go-errors/errors"
	"golang.org/x/net/idna"

	"git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/utils"
)

var (
	// digValidDnsTypes is the allow-list of DNS record type mnemonics that
	// NewDigRequest accepts as the optional second word of a request. The
	// list follows Wikipedia's list of DNS record types. Only key membership
	// is consulted in this file; the mnemonics are stored upper-case, so
	// callers must upper-case user input before the lookup.
	digValidDnsTypes = utils.ToLookupMap([]string{
		"A", "AAAA", "AFSDB", "APL", "CAA", "CDNSKEY", "CDS", "CERT", "CNAME", "CSYNC", "DHCID", "DLV", "DNAME",
		"DNSKEY", "DS", "EUI48", "EUI64", "HINFO", "HIP", "HTTPS", "IPSECKEY", "KEY", "KX", "LOC", "MX", "NAPTR", "NS",
		"NSEC", "NSEC3", "NSEC3PARAM", "OPENPGPKEY", "PTR", "RRSIG", "RP", "SIG", "SMIMEA", "SOA", "SRV", "SSHFP",
		"SVCB", "TA", "TKEY", "TLSA", "TSIG", "TXT", "URI", "ZONEMD",
	})

	// digErrInvalidArgs is the single error NewDigRequest returns whenever
	// the request text cannot be turned into a valid query. It deliberately
	// carries no detail about which check failed.
	digErrInvalidArgs = fmt.Errorf("invalid request")

	// digIdnaMapper converts user-supplied (possibly internationalized)
	// names to their ASCII form before they are handed to dig.
	// NOTE(review): StrictDomainName(false) presumably exists so that names
	// with non letter/digit/hyphen characters such as "_dmarc.example.com"
	// are still accepted; confirm before tightening.
	digIdnaMapper = idna.New(idna.MapForLookup(), idna.StrictDomainName(false))
)

// DigRequest describes one lookup to be performed with dig.
//
//   - Name is the query target: the ASCII form of a domain name, or the
//     textual form of an IP address when Reverse is true.
//   - Type is the upper-case record type to request; when empty, no type
//     argument is passed to dig.
//   - Reverse requests a reverse lookup of Name (dig is invoked with "-x").
type DigRequest struct {
	Name    string
	Type    string
	Reverse bool
}

// NewDigRequest parses request text of the form "<name-or-ip> [type]" into a
// DigRequest.
//
// The text is split on whitespace and must contain one or two words. The
// optional second word is upper-cased and must be a key of digValidDnsTypes.
// If the first word parses as an IP address the request becomes a reverse
// lookup and any type is discarded; otherwise the word is converted to its
// ASCII form with digIdnaMapper.
//
// Every failure is reported as digErrInvalidArgs.
func NewDigRequest(req string) (*DigRequest, error) {
	args := strings.Fields(req)
	if len(args) == 0 || len(args) > 2 {
		return nil, digErrInvalidArgs
	}

	ret := &DigRequest{}
	if len(args) == 2 {
		typ := strings.ToUpper(args[1])
		if _, ok := digValidDnsTypes[typ]; !ok {
			return nil, digErrInvalidArgs
		}
		ret.Type = typ
	}

	// An IP literal always means a reverse lookup; a record type makes no
	// sense there, so it is dropped (after having been validated above).
	if ip := net.ParseIP(args[0]); ip != nil {
		ret.Name = ip.String()
		ret.Reverse = true
		ret.Type = ""
		return ret, nil
	}

	name, err := digIdnaMapper.ToASCII(args[0])
	if err != nil {
		return nil, digErrInvalidArgs
	}

	// The name ends up as a bare command-line argument of dig, which treats
	// a leading '-' as a flag, '+' as a query option and '@' as a server
	// selector. Reject those (and an empty name) so request text can never
	// change how dig is invoked. The check runs on the mapped name because
	// IDNA mapping may turn other characters into these ASCII ones.
	if name == "" || strings.ContainsRune("-+@", rune(name[0])) {
		return nil, digErrInvalidArgs
	}

	ret.Name = name
	return ret, nil
}

// ToCmdArgs renders the request as the argument list for the dig binary.
//
// The list always starts with "-u", followed by the name (preceded by "-x"
// for reverse lookups) and, when set, the record type.
func (r *DigRequest) ToCmdArgs() []string {
	// At most four entries: "-u", "-x", name, type.
	const maxArgs = 4
	argv := append(make([]string, 0, maxArgs), "-u")

	if r.Reverse {
		argv = append(argv, "-x", r.Name)
	} else {
		argv = append(argv, r.Name)
	}

	if r.Type == "" {
		return argv
	}
	return append(argv, r.Type)
}

// DigDnsRecord is one resource-record line parsed from dig's output by
// buildDigResponse. The fields hold the five whitespace-separated columns of
// that line in order: owner name, TTL (the integer as printed by dig), class,
// record type, and the remainder of the line as the record data.
type DigDnsRecord struct {
	Name  string
	TTL   int
	Class string
	Type  string
	Data  string
}

// DigResponse is the parsed result of a dig run.
//
//   - Status is the value of the "status:" field of the first HEADER comment
//     line in the output (for example "NOERROR").
//   - QueryTime is taken from the ";; Query time: N usec" line; it stays zero
//     when no such line is present.
//   - Records holds every record line found in the output, in order of
//     appearance. The parser does not track which section a record came from.
type DigResponse struct {
	Status    string
	QueryTime time.Duration
	Records   []DigDnsRecord
}

var (
	// digRespDnsRecordLineRe matches one resource-record line of dig output,
	// "<name> <ttl> <class> <type> <rdata>", capturing the five columns.
	// Lines starting with ';' (comments) or whitespace never match. Class
	// and type mnemonics may contain digits (NSEC3, EUI48, CLASS255,
	// TYPE65) and types may contain a hyphen (NSAP-PTR).
	digRespDnsRecordLineRe = regexp.MustCompile(`^([^;\s]\S*)\s+(\d+)\s+([A-Z][A-Z0-9]*)\s+([A-Z][A-Z0-9-]*)\s+(.*)$`)

	// digRespHeaderLineRe matches the ";; ->>HEADER<<- ..." comment line and
	// captures the response status (the word after "status: ").
	digRespHeaderLineRe = regexp.MustCompile(`^;;.*HEADER.*status: ([A-Z]+),.*$`)

	// digRespQueryTimeLineRe matches the query-time comment line and
	// captures the number. It only recognizes the "usec" unit, which is
	// what dig prints when invoked with "-u".
	digRespQueryTimeLineRe = regexp.MustCompile(`^;; Query time: (\d+) usec$`)
)

// buildDigResponse parses the textual output of dig (as produced with the
// "-u" flag) into a DigResponse.
//
// The output is processed line by line:
//   - the first comment line matching digRespHeaderLineRe supplies Status;
//   - a comment line matching digRespQueryTimeLineRe supplies QueryTime;
//   - every non-comment line matching digRespDnsRecordLineRe is appended to
//     Records; anything else is ignored.
//
// An error is returned when a number cannot be parsed, when reading the
// buffer fails (for example on an over-long line), or when no status line
// was found.
func buildDigResponse(buf []byte) (*DigResponse, error) {
	sc := bufio.NewScanner(bytes.NewReader(buf))
	ret := &DigResponse{}

	for sc.Scan() {
		line := sc.Text()
		if line == "" {
			continue
		}

		if line[0] == ';' {
			// Only the first header line is used for the status.
			if ret.Status == "" {
				if m := digRespHeaderLineRe.FindStringSubmatch(line); len(m) == 2 {
					ret.Status = m[1]
					continue
				}
			}

			if m := digRespQueryTimeLineRe.FindStringSubmatch(line); len(m) == 2 {
				usec, err := strconv.ParseInt(m[1], 10, 64)
				if err != nil {
					return nil, errors.WrapPrefix(err, "failed to parse query time", 0)
				}
				ret.QueryTime = time.Microsecond * time.Duration(usec)
			}

			// The record pattern cannot match a line starting with ';',
			// so there is nothing more to do for a comment line.
			continue
		}

		m := digRespDnsRecordLineRe.FindStringSubmatch(line)
		if len(m) != 6 {
			continue
		}
		ttl, err := strconv.Atoi(m[2])
		if err != nil {
			return nil, errors.WrapPrefix(err, "failed to parse ttl", 0)
		}
		ret.Records = append(ret.Records, DigDnsRecord{
			Name:  m[1],
			TTL:   ttl,
			Class: m[3],
			Type:  m[4],
			Data:  m[5],
		})
	}

	// Scan returns false both at end of input and on failure (such as a line
	// exceeding the scanner's token limit). Without this check a failure
	// would silently yield a truncated result.
	if err := sc.Err(); err != nil {
		return nil, errors.WrapPrefix(err, "failed to read dig output", 0)
	}

	if ret.Status == "" {
		return nil, errors.New("failed to parse response: \"status\" is unknown")
	}

	return ret, nil
}

// Dig executes /usr/bin/dig with the arguments derived from req and returns
// the parsed response.
//
// The command runs without standard input. A failure to run the command
// (including a non-zero exit status) and a failure to parse its standard
// output are both returned as wrapped errors.
func Dig(req *DigRequest) (*DigResponse, error) {
	argv := req.ToCmdArgs()

	// exec.Cmd leaves Stdin nil unless it is set, so the child gets no input.
	stdout, runErr := exec.Command("/usr/bin/dig", argv...).Output()
	if runErr != nil {
		return nil, errors.WrapPrefix(runErr, "failed to run dig command", 0)
	}

	resp, parseErr := buildDigResponse(stdout)
	if parseErr != nil {
		return nil, errors.WrapPrefix(parseErr, "failed to parse dig response", 0)
	}
	return resp, nil
}