diff --git a/bot.go b/bot.go index 8bb8905..11b83fd 100644 --- a/bot.go +++ b/bot.go @@ -24,7 +24,7 @@ func isFromAdmin(sender *tele.User) bool { return false } - _, ok := config.AdminUIDs[sender.ID] + _, ok := config.adminUidLookup[sender.ID] return ok } @@ -61,7 +61,7 @@ func initBot() (*tele.Bot, error) { adminGrp.Handle("/dig", handleDigCmd) - // adminGrp.Handle("/test", handleTestCmd) + // adminGrp.Handle("/test", testCmd) return b, nil } diff --git a/cfg.go b/cfg.go index 7cbf6d4..eec1f2b 100644 --- a/cfg.go +++ b/cfg.go @@ -1,66 +1,94 @@ package main import ( - "os" + "net" "strconv" - "strings" + "git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/utils" "github.com/go-errors/errors" - "github.com/samber/lo" + "github.com/ilyakaznacheev/cleanenv" ) -type Config struct { - AdminUIDs map[int64]struct{} - TGBotToken string - TGAnnounceCommands bool +type ConfigDef struct { + AdminUIDs []int64 `env:"TG_ADMIN_UIDS"` + TGBotToken string `env:"TG_TOKEN" env-required:""` + TGAnnounceCommands bool `env:"TG_ANNOUNCE_CMDS"` - WatchedInterface string - MonthlyTrafficLimitGiB int + // API + APIMasterKey string `env:"TG_API_MASTER_KEY" env-required:""` + APIListen string `env:"TG_API_LISTEN" env-default:"127.0.0.1:8080"` + + // Traffic usage + WatchedInterface string `env:"TG_WATCHED_INTERFACE"` + MonthlyTrafficLimitGiB int `env:"TG_MONTHLY_TRAFFIC_LIMIT_GIB" env-default:"1000"` + + // Parsed fields + adminUidLookup map[int64]struct{} + apiListenAddr net.Addr } -var config *Config - -func LoadCfg() error { - cfg := Config{} - - token := os.Getenv("TG_TOKEN") - if token == "" { - return errors.New("TG_TOKEN env not set") +func (c *ConfigDef) Parse() error { + c.adminUidLookup = make(map[int64]struct{}, len(c.AdminUIDs)) + for _, u := range c.AdminUIDs { + if u <= 0 { + return errors.New("invalid admin UID: " + strconv.FormatInt(u, 10)) + } + c.adminUidLookup[u] = struct{}{} } - cfg.TGBotToken = token - adminUIDsEnv := os.Getenv("TG_ADMIN_UIDS") - adminUIDs := lo.FilterMap(strings.Split(adminUIDsEnv, ","), func(s string, _ int) (string, bool) { - trimmed := strings.TrimSpace(s) - return trimmed, trimmed != "" - }) - cfg.AdminUIDs = make(map[int64]struct{}, len(adminUIDs)) - for _, uidStr := range adminUIDs { - uid, err := strconv.ParseInt(uidStr, 10, 64) + // validate API listen + ip, port, err := net.SplitHostPort(c.APIListen) + if err != nil { + return errors.WrapPrefix(err, "invalid API listen address", 0) + } + if tIP := net.ParseIP(ip); tIP == nil { + return errors.New("invalid API listen address: " + ip) + } + if tPort, err := strconv.Atoi(port); err != nil || tPort <= 0 || tPort > 65535 { + return errors.New("invalid API listen port: " + port) + } + + if c.WatchedInterface == "" { + tIf, err := utils.FindDefaultAdapter() if err != nil { - return errors.New("invalid admin UID: " + uidStr) + return errors.WrapPrefix(err, "Watched interface is not defined and failed to find a default one.", 0) } - cfg.AdminUIDs[uid] = struct{}{} + c.WatchedInterface = tIf.Name + logger.Infof("Watching default interface %s", c.WatchedInterface) } - announceCmdsEnv := os.Getenv("TG_ANNOUNCE_CMDS") - if !lo.Contains([]string{"", "no", "false", "0"}, strings.ToLower(announceCmdsEnv)) { - cfg.TGAnnounceCommands = true - } - - cfg.WatchedInterface = "eth0" - if iface := os.Getenv("TG_WATCHED_INTERFACE"); iface != "" { - cfg.WatchedInterface = iface - } - - cfg.MonthlyTrafficLimitGiB = 1000 - if trafficLimitStr := os.Getenv("TG_MONTHLY_TRAFFIC_LIMIT_GIB"); trafficLimitStr != "" { - var err error - if cfg.MonthlyTrafficLimitGiB, err = strconv.Atoi(trafficLimitStr); err != nil { - return errors.New("invalid traffic limit: " + trafficLimitStr) - } - } - - config = &cfg + return nil +} + +func (c *ConfigDef) IsAdmin(id int64) bool { + _, ok := c.adminUidLookup[id] + return ok +} + +func (c *ConfigDef) GetAPIListenAddr() net.Addr { + return c.apiListenAddr +} + +func configFromEnv() (*ConfigDef, error) { + cfg := ConfigDef{} + err := cleanenv.ReadEnv(&cfg) + if err != nil { + return nil, errors.Wrap(err, 0) + } + return &cfg, nil +} + +var config *ConfigDef + +func LoadCfg() error { + cfg, err := configFromEnv() + if err != nil { + return errors.WrapPrefix(err, "failed to load config from environment", 0) + } + if err = cfg.Parse(); err != nil { + return errors.WrapPrefix(err, "failed to parse config", 0) + } + + config = cfg return nil } diff --git a/utils/net.go b/utils/net.go new file mode 100644 index 0000000..71d4f1b --- /dev/null +++ b/utils/net.go @@ -0,0 +1,53 @@ +package utils + +import ( + "net" + + "github.com/go-errors/errors" +) + +func getLocalIP() (net.IP, error) { + conn, err := net.Dial("udp", "1.0.0.1:53") + if err != nil { + return nil, err + } + defer conn.Close() + + host, _, err := net.SplitHostPort(conn.LocalAddr().String()) + if err != nil { + return nil, errors.Wrap(err, 0) + } + ip := net.ParseIP(host) + if ip == nil { + return nil, errors.Errorf("failed to parse IP address %q", host) + } + return ip, nil +} + +func FindDefaultAdapter() (*net.Interface, error) { + localIP, err := getLocalIP() + if err != nil { + return nil, errors.WrapPrefix(err, "failed to get local address", 0) + } + + iflist, err := net.Interfaces() + if err != nil { + return nil, errors.WrapPrefix(err, "failed to get interface list", 0) + } + for _, iface := range iflist { + addrs, err := iface.Addrs() + if err != nil { + continue + } + for _, addr := range addrs { + ipnet, ok := addr.(*net.IPNet) + if !ok { + continue + } + if localIP.Equal(ipnet.IP) { + return &iface, nil + } + } + } + return nil, errors.New("failed to find default adapter") +}