diff --git a/bot.go b/bot.go index dfc0c56..319eca1 100644 --- a/bot.go +++ b/bot.go @@ -2,17 +2,25 @@ package main import ( "fmt" + "html" "math" "strings" + "text/tabwriter" "time" "github.com/go-errors/errors" "github.com/samber/lo" tele "gopkg.in/telebot.v3" + "git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/cmds" "git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/stats" ) +const ( + stickerPanic = "CAACAgUAAxkBAAMjY3zoraxZGB8Xejyw86bHLSWLjVcAArMIAAL7-nhXNK7dStmRUGsrBA" + stickerLoading = "CAACAgUAAxkBAAMmY3zp5UCMVRvy1isFCPHrx-UBWX8AApYHAALP8GhXEm9ZIBjn1v8rBA" +) + func isFromAdmin(sender *tele.User) bool { if sender == nil { return false @@ -33,17 +41,49 @@ func initBot() (*tele.Bot, error) { return nil, errors.Wrap(err, 0) } + b.Use(logMiddleware) + // command routing b.Handle("/start", handleStartCmd) - b.Handle("/traffic", handleTrafficCmd) b.Handle("/me", handleUserInfoCmd) b.Handle("/chat", handleChatInfoCmd) b.Handle("/year_progress", handleYearProgressCmd) + b.Handle(tele.OnText, handleGeneralMessage) + b.Handle(tele.OnSticker, handleGeneralMessage) + + // admin required + adminGrp := b.Group() + adminGrp.Use(adminMiddleware) + adminGrp.Handle("/traffic", handleTrafficCmd) + adminGrp.Handle("/dig", handleDigCmd) + + // adminGrp.Handle("/test", handleTestCmd) + return b, nil } +func adminMiddleware(next tele.HandlerFunc) tele.HandlerFunc { + return func(c tele.Context) error { + if !isFromAdmin(c.Sender()) { + return nil + } + return next(c) + } +} + +func logMiddleware(next tele.HandlerFunc) tele.HandlerFunc { + return func(c tele.Context) error { + upd := c.Update() + defer func() { + logger.Infow("Log middleware", "update", upd) + }() + + return next(c) + } +} + func handleStartCmd(c tele.Context) error { if !isFromAdmin(c.Sender()) { return c.Send("Hello, stranger :)") @@ -53,18 +93,14 @@ func handleStartCmd(c tele.Context) error { } func handleTrafficCmd(c tele.Context) error { - if !isFromAdmin(c.Sender()) { - return nil - } - dailyTraffic, err := stats.VnstatDailyTraffic(config.WatchedInterface) if err != nil { - _ = c.Reply("im die, thank you forever") + _ = c.Reply(stickerFromID(stickerPanic)) return err } monthlyTraffic, err := stats.VnstatMonthlyTraffic(config.WatchedInterface) if err != nil { - _ = c.Reply("im die, thank you forever") + _ = c.Reply(stickerFromID(stickerPanic)) return err } @@ -200,3 +236,55 @@ func handleYearProgressCmd(c tele.Context) error { ) return c.Reply(replyText, &tele.SendOptions{ParseMode: tele.ModeHTML}) } + +func handleGeneralMessage(_ tele.Context) error { + // Do nothing for now + return nil +} + +func stickerFromID(id string) *tele.Sticker { + return &tele.Sticker{ + File: tele.File{ + FileID: id, + }, + } +} + +func handleDigCmd(c tele.Context) error { + msg := c.Message() + if msg == nil { + return nil + } + + req, err := cmds.NewDigRequest(msg.Payload) + if err != nil { + return c.Reply("Invalid arguments.\nUsage: `/dig [type]`", &tele.SendOptions{ParseMode: tele.ModeMarkdown}) + } + + resp, err := cmds.Dig(req) + if err != nil { + _ = c.Reply(stickerFromID(stickerPanic)) + return err + } + + replyBuf := &strings.Builder{} + tw := tabwriter.NewWriter(replyBuf, 0, 0, 2, ' ', 0) + // Write header + if len(resp.Records) > 0 { + _, _ = tw.Write([]byte("Name\tTTL\tType\tData\n")) + } + // Write data + for _, r := range resp.Records { + _, _ = fmt.Fprintf(tw, "%s\t%d\t%s\t%s\n", r.Name, r.TTL, r.Type, r.Data) + } + _ = tw.Flush() + + replyText := []string{ + fmt.Sprintf("Status: %s\n", resp.Status), + fmt.Sprintf("Query Time: %s\n\n", resp.QueryTime), + "
",
+		html.EscapeString(replyBuf.String()),
+		"
", + } + return c.Reply(strings.Join(replyText, ""), &tele.SendOptions{ParseMode: tele.ModeHTML}) +} diff --git a/cmds/dig.go b/cmds/dig.go new file mode 100644 index 0000000..e61c1ab --- /dev/null +++ b/cmds/dig.go @@ -0,0 +1,172 @@ +package cmds + +import ( + "bufio" + "bytes" + "fmt" + "net" + "os/exec" + "regexp" + "strconv" + "strings" + "time" + + "github.com/go-errors/errors" + + "git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/utils" +) + +var ( + // According to wikipedia + digValidDnsTypes = utils.ToLookupMap([]string{ + "A", "AAAA", "AFSDB", "APL", "CAA", "CDNSKEY", "CDS", "CERT", "CNAME", "CSYNC", "DHCID", "DLV", "DNAME", + "DNSKEY", "DS", "EUI48", "EUI64", "HINFO", "HIP", "HTTPS", "IPSECKEY", "KEY", "KX", "LOC", "MX", "NAPTR", "NS", + "NSEC", "NSEC3", "NSEC3PARAM", "OPENPGPKEY", "PTR", "RRSIG", "RP", "SIG", "SMIMEA", "SOA", "SRV", "SSHFP", + "SVCB", "TA", "TKEY", "TLSA", "TSIG", "TXT", "URI", "ZONEMD", + }) + + digErrInvalidArgs = fmt.Errorf("invalid request") + + digDnsNameRe = regexp.MustCompile(`^([a-z0-9_-]+\.?)+|\.$`) +) + +type DigRequest struct { + Name string + Type string + Reverse bool +} + +func NewDigRequest(req string) (*DigRequest, error) { + ret := &DigRequest{} + + args := strings.Fields(req) + nArgs := len(args) + if nArgs == 0 || nArgs > 2 { + return nil, digErrInvalidArgs + } + + if nArgs > 1 { + typ := strings.ToUpper(args[1]) + if _, ok := digValidDnsTypes[typ]; !ok { + return nil, digErrInvalidArgs + } + ret.Type = typ + } + + ip := net.ParseIP(args[0]) + if ip != nil { + ret.Name = ip.String() + ret.Reverse = true + ret.Type = "" + return ret, nil + } + + name := strings.ToLower(args[0]) + if !digDnsNameRe.Match([]byte(name)) { + return nil, digErrInvalidArgs + } + ret.Name = name + return ret, nil +} + +func (r *DigRequest) ToCmdArgs() []string { + args := make([]string, 0, 4) + args = append(args, "-u") + + if r.Reverse { + args = append(args, "-x") + } + args = append(args, r.Name) + + if r.Type != "" { + args = append(args, r.Type) + } + return args +} + +type DigDnsRecord struct { + Name string + TTL int + Class string + Type string + Data string +} + +type DigResponse struct { + Status string + QueryTime time.Duration + Records []DigDnsRecord +} + +var ( + digRespDnsRecordLineRe = regexp.MustCompile(`^([^;\s]\S*)\s+(\d+)\s+([A-Z]+)\s+([A-Z]+)\s+(.*)$`) + digRespHeaderLineRe = regexp.MustCompile(`^;;.*HEADER.*status: ([A-Z]+),.*$`) + digResqQueryTimeLineRe = regexp.MustCompile(`^;; Query time: (\d+) usec$`) +) + +func buildDigResponse(buf []byte) (*DigResponse, error) { + sc := bufio.NewScanner(bytes.NewReader(buf)) + ret := &DigResponse{} + + for sc.Scan() { + line := sc.Text() + if line == "" { + continue + } + if line[0] == ';' { + if ret.Status == "" { + m := digRespHeaderLineRe.FindStringSubmatch(line) + if len(m) == 2 { + ret.Status = m[1] + continue + } + } + + m := digResqQueryTimeLineRe.FindStringSubmatch(line) + if len(m) == 2 { + usec, err := strconv.ParseInt(m[1], 10, 64) + if err != nil { + return nil, errors.WrapPrefix(err, "failed to parse query time", 0) + } + ret.QueryTime = time.Microsecond * time.Duration(usec) + continue + } + } + + m := digRespDnsRecordLineRe.FindStringSubmatch(line) + if len(m) == 6 { + ttl, err := strconv.Atoi(m[2]) + if err != nil { + return nil, errors.WrapPrefix(err, "failed to parse ttl", 0) + } + ret.Records = append(ret.Records, DigDnsRecord{ + Name: m[1], + TTL: ttl, + Class: m[3], + Type: m[4], + Data: m[5], + }) + } + } + if ret.Status == "" { + return nil, errors.New("failed to parse response: \"status\" is unknown") + } + + return ret, nil +} + +func Dig(req *DigRequest) (*DigResponse, error) { + cmd := exec.Command("/usr/bin/dig", req.ToCmdArgs()...) + cmd.Stdin = nil + + buf, err := cmd.Output() + if err != nil { + return nil, errors.WrapPrefix(err, "failed to run dig command", 0) + } + + ret, err := buildDigResponse(buf) + if err != nil { + return nil, errors.WrapPrefix(err, "failed to parse dig response", 0) + } + return ret, nil +} diff --git a/utils/utils.go b/utils/utils.go new file mode 100644 index 0000000..fff069f --- /dev/null +++ b/utils/utils.go @@ -0,0 +1,18 @@ +package utils + +func WaitFor(fn func()) <-chan struct{} { + ch := make(chan struct{}) + go func() { + defer close(ch) + fn() + }() + return ch +} + +func ToLookupMap[T comparable](s []T) map[T]struct{} { + m := make(map[T]struct{}, len(s)) + for _, item := range s { + m[item] = struct{}{} + } + return m +}