feat: add dig command
Other changes: - use middleware for permission checking - able to respond with emoji - log incoming requests
This commit is contained in:
		
							parent
							
								
									412725c6b9
								
							
						
					
					
						commit
						37dffe56ce
					
				
							
								
								
									
										102
									
								
								bot.go
								
								
								
								
							
							
						
						
									
										102
									
								
								bot.go
								
								
								
								
							| 
						 | 
				
			
			@ -2,17 +2,25 @@ package main
 | 
			
		|||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"html"
 | 
			
		||||
	"math"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"text/tabwriter"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/go-errors/errors"
 | 
			
		||||
	"github.com/samber/lo"
 | 
			
		||||
	tele "gopkg.in/telebot.v3"
 | 
			
		||||
 | 
			
		||||
	"git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/cmds"
 | 
			
		||||
	"git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/stats"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	stickerPanic   = "CAACAgUAAxkBAAMjY3zoraxZGB8Xejyw86bHLSWLjVcAArMIAAL7-nhXNK7dStmRUGsrBA"
 | 
			
		||||
	stickerLoading = "CAACAgUAAxkBAAMmY3zp5UCMVRvy1isFCPHrx-UBWX8AApYHAALP8GhXEm9ZIBjn1v8rBA"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func isFromAdmin(sender *tele.User) bool {
 | 
			
		||||
	if sender == nil {
 | 
			
		||||
		return false
 | 
			
		||||
| 
						 | 
				
			
			@ -33,17 +41,49 @@ func initBot() (*tele.Bot, error) {
 | 
			
		|||
		return nil, errors.Wrap(err, 0)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	b.Use(logMiddleware)
 | 
			
		||||
 | 
			
		||||
	// command routing
 | 
			
		||||
	b.Handle("/start", handleStartCmd)
 | 
			
		||||
	b.Handle("/traffic", handleTrafficCmd)
 | 
			
		||||
 | 
			
		||||
	b.Handle("/me", handleUserInfoCmd)
 | 
			
		||||
	b.Handle("/chat", handleChatInfoCmd)
 | 
			
		||||
	b.Handle("/year_progress", handleYearProgressCmd)
 | 
			
		||||
 | 
			
		||||
	b.Handle(tele.OnText, handleGeneralMessage)
 | 
			
		||||
	b.Handle(tele.OnSticker, handleGeneralMessage)
 | 
			
		||||
 | 
			
		||||
	// admin required
 | 
			
		||||
	adminGrp := b.Group()
 | 
			
		||||
	adminGrp.Use(adminMiddleware)
 | 
			
		||||
	adminGrp.Handle("/traffic", handleTrafficCmd)
 | 
			
		||||
	adminGrp.Handle("/dig", handleDigCmd)
 | 
			
		||||
 | 
			
		||||
	// adminGrp.Handle("/test", handleTestCmd)
 | 
			
		||||
 | 
			
		||||
	return b, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func adminMiddleware(next tele.HandlerFunc) tele.HandlerFunc {
 | 
			
		||||
	return func(c tele.Context) error {
 | 
			
		||||
		if !isFromAdmin(c.Sender()) {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
		return next(c)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func logMiddleware(next tele.HandlerFunc) tele.HandlerFunc {
 | 
			
		||||
	return func(c tele.Context) error {
 | 
			
		||||
		upd := c.Update()
 | 
			
		||||
		defer func() {
 | 
			
		||||
			logger.Infow("Log middleware", "update", upd)
 | 
			
		||||
		}()
 | 
			
		||||
 | 
			
		||||
		return next(c)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func handleStartCmd(c tele.Context) error {
 | 
			
		||||
	if !isFromAdmin(c.Sender()) {
 | 
			
		||||
		return c.Send("Hello, stranger :)")
 | 
			
		||||
| 
						 | 
				
			
			@ -53,18 +93,14 @@ func handleStartCmd(c tele.Context) error {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func handleTrafficCmd(c tele.Context) error {
 | 
			
		||||
	if !isFromAdmin(c.Sender()) {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dailyTraffic, err := stats.VnstatDailyTraffic(config.WatchedInterface)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		_ = c.Reply("im die, thank you forever")
 | 
			
		||||
		_ = c.Reply(stickerFromID(stickerPanic))
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	monthlyTraffic, err := stats.VnstatMonthlyTraffic(config.WatchedInterface)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		_ = c.Reply("im die, thank you forever")
 | 
			
		||||
		_ = c.Reply(stickerFromID(stickerPanic))
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -200,3 +236,55 @@ func handleYearProgressCmd(c tele.Context) error {
 | 
			
		|||
	)
 | 
			
		||||
	return c.Reply(replyText, &tele.SendOptions{ParseMode: tele.ModeHTML})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func handleGeneralMessage(_ tele.Context) error {
 | 
			
		||||
	// Do nothing for now
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func stickerFromID(id string) *tele.Sticker {
 | 
			
		||||
	return &tele.Sticker{
 | 
			
		||||
		File: tele.File{
 | 
			
		||||
			FileID: id,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func handleDigCmd(c tele.Context) error {
 | 
			
		||||
	msg := c.Message()
 | 
			
		||||
	if msg == nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	req, err := cmds.NewDigRequest(msg.Payload)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return c.Reply("Invalid arguments.\nUsage: `/dig <name> [type]`", &tele.SendOptions{ParseMode: tele.ModeMarkdown})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp, err := cmds.Dig(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		_ = c.Reply(stickerFromID(stickerPanic))
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	replyBuf := &strings.Builder{}
 | 
			
		||||
	tw := tabwriter.NewWriter(replyBuf, 0, 0, 2, ' ', 0)
 | 
			
		||||
	// Write header
 | 
			
		||||
	if len(resp.Records) > 0 {
 | 
			
		||||
		_, _ = tw.Write([]byte("Name\tTTL\tType\tData\n"))
 | 
			
		||||
	}
 | 
			
		||||
	// Write data
 | 
			
		||||
	for _, r := range resp.Records {
 | 
			
		||||
		_, _ = fmt.Fprintf(tw, "%s\t%d\t%s\t%s\n", r.Name, r.TTL, r.Type, r.Data)
 | 
			
		||||
	}
 | 
			
		||||
	_ = tw.Flush()
 | 
			
		||||
 | 
			
		||||
	replyText := []string{
 | 
			
		||||
		fmt.Sprintf("<i>Status: <b>%s</b></i>\n", resp.Status),
 | 
			
		||||
		fmt.Sprintf("<i>Query Time: <b>%s</b></i>\n\n", resp.QueryTime),
 | 
			
		||||
		"<pre>",
 | 
			
		||||
		html.EscapeString(replyBuf.String()),
 | 
			
		||||
		"</pre>",
 | 
			
		||||
	}
 | 
			
		||||
	return c.Reply(strings.Join(replyText, ""), &tele.SendOptions{ParseMode: tele.ModeHTML})
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,172 @@
 | 
			
		|||
package cmds
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"os/exec"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/go-errors/errors"
 | 
			
		||||
 | 
			
		||||
	"git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/utils"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	// According to wikipedia
 | 
			
		||||
	digValidDnsTypes = utils.ToLookupMap([]string{
 | 
			
		||||
		"A", "AAAA", "AFSDB", "APL", "CAA", "CDNSKEY", "CDS", "CERT", "CNAME", "CSYNC", "DHCID", "DLV", "DNAME",
 | 
			
		||||
		"DNSKEY", "DS", "EUI48", "EUI64", "HINFO", "HIP", "HTTPS", "IPSECKEY", "KEY", "KX", "LOC", "MX", "NAPTR", "NS",
 | 
			
		||||
		"NSEC", "NSEC3", "NSEC3PARAM", "OPENPGPKEY", "PTR", "RRSIG", "RP", "SIG", "SMIMEA", "SOA", "SRV", "SSHFP",
 | 
			
		||||
		"SVCB", "TA", "TKEY", "TLSA", "TSIG", "TXT", "URI", "ZONEMD",
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	digErrInvalidArgs = fmt.Errorf("invalid request")
 | 
			
		||||
 | 
			
		||||
	digDnsNameRe = regexp.MustCompile(`^([a-z0-9_-]+\.?)+|\.$`)
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type DigRequest struct {
 | 
			
		||||
	Name    string
 | 
			
		||||
	Type    string
 | 
			
		||||
	Reverse bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewDigRequest(req string) (*DigRequest, error) {
 | 
			
		||||
	ret := &DigRequest{}
 | 
			
		||||
 | 
			
		||||
	args := strings.Fields(req)
 | 
			
		||||
	nArgs := len(args)
 | 
			
		||||
	if nArgs == 0 || nArgs > 2 {
 | 
			
		||||
		return nil, digErrInvalidArgs
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if nArgs > 1 {
 | 
			
		||||
		typ := strings.ToUpper(args[1])
 | 
			
		||||
		if _, ok := digValidDnsTypes[typ]; !ok {
 | 
			
		||||
			return nil, digErrInvalidArgs
 | 
			
		||||
		}
 | 
			
		||||
		ret.Type = typ
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ip := net.ParseIP(args[0])
 | 
			
		||||
	if ip != nil {
 | 
			
		||||
		ret.Name = ip.String()
 | 
			
		||||
		ret.Reverse = true
 | 
			
		||||
		ret.Type = ""
 | 
			
		||||
		return ret, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	name := strings.ToLower(args[0])
 | 
			
		||||
	if !digDnsNameRe.Match([]byte(name)) {
 | 
			
		||||
		return nil, digErrInvalidArgs
 | 
			
		||||
	}
 | 
			
		||||
	ret.Name = name
 | 
			
		||||
	return ret, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *DigRequest) ToCmdArgs() []string {
 | 
			
		||||
	args := make([]string, 0, 4)
 | 
			
		||||
	args = append(args, "-u")
 | 
			
		||||
 | 
			
		||||
	if r.Reverse {
 | 
			
		||||
		args = append(args, "-x")
 | 
			
		||||
	}
 | 
			
		||||
	args = append(args, r.Name)
 | 
			
		||||
 | 
			
		||||
	if r.Type != "" {
 | 
			
		||||
		args = append(args, r.Type)
 | 
			
		||||
	}
 | 
			
		||||
	return args
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DigDnsRecord struct {
 | 
			
		||||
	Name  string
 | 
			
		||||
	TTL   int
 | 
			
		||||
	Class string
 | 
			
		||||
	Type  string
 | 
			
		||||
	Data  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DigResponse struct {
 | 
			
		||||
	Status    string
 | 
			
		||||
	QueryTime time.Duration
 | 
			
		||||
	Records   []DigDnsRecord
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	digRespDnsRecordLineRe = regexp.MustCompile(`^([^;\s]\S*)\s+(\d+)\s+([A-Z]+)\s+([A-Z]+)\s+(.*)$`)
 | 
			
		||||
	digRespHeaderLineRe    = regexp.MustCompile(`^;;.*HEADER.*status: ([A-Z]+),.*$`)
 | 
			
		||||
	digResqQueryTimeLineRe = regexp.MustCompile(`^;; Query time: (\d+) usec$`)
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func buildDigResponse(buf []byte) (*DigResponse, error) {
 | 
			
		||||
	sc := bufio.NewScanner(bytes.NewReader(buf))
 | 
			
		||||
	ret := &DigResponse{}
 | 
			
		||||
 | 
			
		||||
	for sc.Scan() {
 | 
			
		||||
		line := sc.Text()
 | 
			
		||||
		if line == "" {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if line[0] == ';' {
 | 
			
		||||
			if ret.Status == "" {
 | 
			
		||||
				m := digRespHeaderLineRe.FindStringSubmatch(line)
 | 
			
		||||
				if len(m) == 2 {
 | 
			
		||||
					ret.Status = m[1]
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			m := digResqQueryTimeLineRe.FindStringSubmatch(line)
 | 
			
		||||
			if len(m) == 2 {
 | 
			
		||||
				usec, err := strconv.ParseInt(m[1], 10, 64)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return nil, errors.WrapPrefix(err, "failed to parse query time", 0)
 | 
			
		||||
				}
 | 
			
		||||
				ret.QueryTime = time.Microsecond * time.Duration(usec)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		m := digRespDnsRecordLineRe.FindStringSubmatch(line)
 | 
			
		||||
		if len(m) == 6 {
 | 
			
		||||
			ttl, err := strconv.Atoi(m[2])
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, errors.WrapPrefix(err, "failed to parse ttl", 0)
 | 
			
		||||
			}
 | 
			
		||||
			ret.Records = append(ret.Records, DigDnsRecord{
 | 
			
		||||
				Name:  m[1],
 | 
			
		||||
				TTL:   ttl,
 | 
			
		||||
				Class: m[3],
 | 
			
		||||
				Type:  m[4],
 | 
			
		||||
				Data:  m[5],
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if ret.Status == "" {
 | 
			
		||||
		return nil, errors.New("failed to parse response: \"status\" is unknown")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ret, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Dig(req *DigRequest) (*DigResponse, error) {
 | 
			
		||||
	cmd := exec.Command("/usr/bin/dig", req.ToCmdArgs()...)
 | 
			
		||||
	cmd.Stdin = nil
 | 
			
		||||
 | 
			
		||||
	buf, err := cmd.Output()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.WrapPrefix(err, "failed to run dig command", 0)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ret, err := buildDigResponse(buf)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.WrapPrefix(err, "failed to parse dig response", 0)
 | 
			
		||||
	}
 | 
			
		||||
	return ret, nil
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,18 @@
 | 
			
		|||
package utils
 | 
			
		||||
 | 
			
		||||
func WaitFor(fn func()) <-chan struct{} {
 | 
			
		||||
	ch := make(chan struct{})
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer close(ch)
 | 
			
		||||
		fn()
 | 
			
		||||
	}()
 | 
			
		||||
	return ch
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ToLookupMap[T comparable](s []T) map[T]struct{} {
 | 
			
		||||
	m := make(map[T]struct{}, len(s))
 | 
			
		||||
	for _, item := range s {
 | 
			
		||||
		m[item] = struct{}{}
 | 
			
		||||
	}
 | 
			
		||||
	return m
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
		Reference in New Issue