refactor: rewrite cfg.go

This commit is contained in:
Yiyang Kang 2023-03-07 16:07:14 +08:00
parent b872b73ebf
commit 6971098753
3 changed files with 130 additions and 49 deletions

4
bot.go
View File

@ -24,7 +24,7 @@ func isFromAdmin(sender *tele.User) bool {
return false return false
} }
_, ok := config.AdminUIDs[sender.ID] _, ok := config.adminUidLookup[sender.ID]
return ok return ok
} }
@ -61,7 +61,7 @@ func initBot() (*tele.Bot, error) {
adminGrp.Handle("/dig", handleDigCmd) adminGrp.Handle("/dig", handleDigCmd)
// adminGrp.Handle("/test", handleTestCmd) // adminGrp.Handle("/test", testCmd)
return b, nil return b, nil
} }

122
cfg.go
View File

@ -1,66 +1,94 @@
package main package main
import ( import (
"os" "net"
"strconv" "strconv"
"strings"
"git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/utils"
"github.com/go-errors/errors" "github.com/go-errors/errors"
"github.com/samber/lo" "github.com/ilyakaznacheev/cleanenv"
) )
type Config struct { type ConfigDef struct {
AdminUIDs map[int64]struct{} AdminUIDs []int64 `env:"TG_ADMIN_UIDS"`
TGBotToken string TGBotToken string `env:"TG_TOKEN" env-required:""`
TGAnnounceCommands bool TGAnnounceCommands bool `env:"TG_ANNOUNCE_CMDS"`
WatchedInterface string // API
MonthlyTrafficLimitGiB int APIMasterKey string `env:"TG_API_MASTER_KEY" env-required:""`
APIListen string `env:"TG_API_LISTEN" env-default:"127.0.0.1:8080"`
// Traffic usage
WatchedInterface string `env:"TG_WATCHED_INTERFACE"`
MonthlyTrafficLimitGiB int `env:"TG_MONTHLY_TRAFFIC_LIMIT_GIB" env-default:"1000"`
// Parsed fields
adminUidLookup map[int64]struct{}
apiListenAddr net.Addr
} }
var config *Config func (c *ConfigDef) Parse() error {
c.adminUidLookup = make(map[int64]struct{}, len(c.AdminUIDs))
func LoadCfg() error { for _, u := range c.AdminUIDs {
cfg := Config{} if u <= 0 {
return errors.New("invalid admin UID: " + strconv.FormatInt(u, 10))
token := os.Getenv("TG_TOKEN") }
if token == "" { c.adminUidLookup[u] = struct{}{}
return errors.New("TG_TOKEN env not set")
} }
cfg.TGBotToken = token
adminUIDsEnv := os.Getenv("TG_ADMIN_UIDS") // validate API listen
adminUIDs := lo.FilterMap(strings.Split(adminUIDsEnv, ","), func(s string, _ int) (string, bool) { ip, port, err := net.SplitHostPort(c.APIListen)
trimmed := strings.TrimSpace(s) if err != nil {
return trimmed, trimmed != "" return errors.WrapPrefix(err, "invalid API listen address", 0)
}) }
cfg.AdminUIDs = make(map[int64]struct{}, len(adminUIDs)) if tIP := net.ParseIP(ip); tIP == nil {
for _, uidStr := range adminUIDs { return errors.New("invalid API listen address: " + ip)
uid, err := strconv.ParseInt(uidStr, 10, 64) }
if tPort, err := strconv.Atoi(port); err != nil || tPort <= 0 || tPort > 65535 {
return errors.New("invalid API listen port: " + port)
}
if c.WatchedInterface == "" {
tIf, err := utils.FindDefaultAdapter()
if err != nil { if err != nil {
return errors.New("invalid admin UID: " + uidStr) return errors.WrapPrefix(err, "Watched interface is not defined and failed to find a default one.", 0)
} }
cfg.AdminUIDs[uid] = struct{}{} c.WatchedInterface = tIf.Name
logger.Infof("Watching default interface %s", c.WatchedInterface)
} }
announceCmdsEnv := os.Getenv("TG_ANNOUNCE_CMDS") return nil
if !lo.Contains([]string{"", "no", "false", "0"}, strings.ToLower(announceCmdsEnv)) { }
cfg.TGAnnounceCommands = true
} func (c *ConfigDef) IsAdmin(id int64) bool {
_, ok := c.adminUidLookup[id]
cfg.WatchedInterface = "eth0" return ok
if iface := os.Getenv("TG_WATCHED_INTERFACE"); iface != "" { }
cfg.WatchedInterface = iface
} func (c *ConfigDef) GetAPIListenAddr() net.Addr {
return c.apiListenAddr
cfg.MonthlyTrafficLimitGiB = 1000 }
if trafficLimitStr := os.Getenv("TG_MONTHLY_TRAFFIC_LIMIT_GIB"); trafficLimitStr != "" {
var err error func configFromEnv() (*ConfigDef, error) {
if cfg.MonthlyTrafficLimitGiB, err = strconv.Atoi(trafficLimitStr); err != nil { cfg := ConfigDef{}
return errors.New("invalid traffic limit: " + trafficLimitStr) err := cleanenv.ReadEnv(&cfg)
} if err != nil {
} return nil, errors.Wrap(err, 0)
}
config = &cfg return &cfg, nil
}
var config *ConfigDef
func LoadCfg() error {
cfg, err := configFromEnv()
if err != nil {
return errors.WrapPrefix(err, "failed to load config from environment", 0)
}
if err = cfg.Parse(); err != nil {
return errors.WrapPrefix(err, "failed to parse config", 0)
}
config = cfg
return nil return nil
} }

53
utils/net.go Normal file
View File

@ -0,0 +1,53 @@
package utils
import (
"net"
"github.com/go-errors/errors"
)
func getLocalIP() (net.IP, error) {
conn, err := net.Dial("udp", "1.0.0.1:53")
if err != nil {
return nil, err
}
defer conn.Close()
host, _, err := net.SplitHostPort(conn.LocalAddr().String())
if err != nil {
return nil, errors.Wrap(err, 0)
}
ip := net.ParseIP(host)
if ip == nil {
return nil, errors.Errorf("failed to parse IP address %q", host)
}
return ip, nil
}
func FindDefaultAdapter() (*net.Interface, error) {
localIP, err := getLocalIP()
if err != nil {
return nil, errors.WrapPrefix(err, "failed to get local address", 0)
}
iflist, err := net.Interfaces()
if err != nil {
return nil, errors.WrapPrefix(err, "failed to get interface list", 0)
}
for _, iface := range iflist {
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, addr := range addrs {
ipnet, ok := addr.(*net.IPNet)
if !ok {
continue
}
if localIP.Equal(ipnet.IP) {
return &iface, nil
}
}
}
return nil, errors.New("failed to find default adapter")
}