refactor: rewrite cfg.go

This commit is contained in:
Yiyang Kang 2023-03-07 16:07:14 +08:00
parent b872b73ebf
commit 6971098753
3 changed files with 130 additions and 49 deletions

4
bot.go
View File

@ -24,7 +24,7 @@ func isFromAdmin(sender *tele.User) bool {
return false
}
_, ok := config.AdminUIDs[sender.ID]
_, ok := config.adminUidLookup[sender.ID]
return ok
}
@ -61,7 +61,7 @@ func initBot() (*tele.Bot, error) {
adminGrp.Handle("/dig", handleDigCmd)
// adminGrp.Handle("/test", handleTestCmd)
// adminGrp.Handle("/test", testCmd)
return b, nil
}

118
cfg.go
View File

@ -1,66 +1,94 @@
package main
import (
"os"
"net"
"strconv"
"strings"
"git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/utils"
"github.com/go-errors/errors"
"github.com/samber/lo"
"github.com/ilyakaznacheev/cleanenv"
)
type Config struct {
AdminUIDs map[int64]struct{}
TGBotToken string
TGAnnounceCommands bool
type ConfigDef struct {
AdminUIDs []int64 `env:"TG_ADMIN_UIDS"`
TGBotToken string `env:"TG_TOKEN" env-required:""`
TGAnnounceCommands bool `env:"TG_ANNOUNCE_CMDS"`
WatchedInterface string
MonthlyTrafficLimitGiB int
// API
APIMasterKey string `env:"TG_API_MASTER_KEY" env-required:""`
APIListen string `env:"TG_API_LISTEN" env-default:"127.0.0.1:8080"`
// Traffic usage
WatchedInterface string `env:"TG_WATCHED_INTERFACE"`
MonthlyTrafficLimitGiB int `env:"TG_MONTHLY_TRAFFIC_LIMIT_GIB" env-default:"1000"`
// Parsed fields
adminUidLookup map[int64]struct{}
apiListenAddr net.Addr
}
var config *Config
func LoadCfg() error {
cfg := Config{}
token := os.Getenv("TG_TOKEN")
if token == "" {
return errors.New("TG_TOKEN env not set")
func (c *ConfigDef) Parse() error {
c.adminUidLookup = make(map[int64]struct{}, len(c.AdminUIDs))
for _, u := range c.AdminUIDs {
if u <= 0 {
return errors.New("invalid admin UID: " + strconv.FormatInt(u, 10))
}
c.adminUidLookup[u] = struct{}{}
}
cfg.TGBotToken = token
adminUIDsEnv := os.Getenv("TG_ADMIN_UIDS")
adminUIDs := lo.FilterMap(strings.Split(adminUIDsEnv, ","), func(s string, _ int) (string, bool) {
trimmed := strings.TrimSpace(s)
return trimmed, trimmed != ""
})
cfg.AdminUIDs = make(map[int64]struct{}, len(adminUIDs))
for _, uidStr := range adminUIDs {
uid, err := strconv.ParseInt(uidStr, 10, 64)
// validate API listen
ip, port, err := net.SplitHostPort(c.APIListen)
if err != nil {
return errors.New("invalid admin UID: " + uidStr)
return errors.WrapPrefix(err, "invalid API listen address", 0)
}
cfg.AdminUIDs[uid] = struct{}{}
if tIP := net.ParseIP(ip); tIP == nil {
return errors.New("invalid API listen address: " + ip)
}
if tPort, err := strconv.Atoi(port); err != nil || tPort <= 0 || tPort > 65535 {
return errors.New("invalid API listen port: " + port)
}
announceCmdsEnv := os.Getenv("TG_ANNOUNCE_CMDS")
if !lo.Contains([]string{"", "no", "false", "0"}, strings.ToLower(announceCmdsEnv)) {
cfg.TGAnnounceCommands = true
if c.WatchedInterface == "" {
tIf, err := utils.FindDefaultAdapter()
if err != nil {
return errors.WrapPrefix(err, "Watched interface is not defined and failed to find a default one.", 0)
}
c.WatchedInterface = tIf.Name
logger.Infof("Watching default interface %s", c.WatchedInterface)
}
cfg.WatchedInterface = "eth0"
if iface := os.Getenv("TG_WATCHED_INTERFACE"); iface != "" {
cfg.WatchedInterface = iface
}
cfg.MonthlyTrafficLimitGiB = 1000
if trafficLimitStr := os.Getenv("TG_MONTHLY_TRAFFIC_LIMIT_GIB"); trafficLimitStr != "" {
var err error
if cfg.MonthlyTrafficLimitGiB, err = strconv.Atoi(trafficLimitStr); err != nil {
return errors.New("invalid traffic limit: " + trafficLimitStr)
}
}
config = &cfg
return nil
}
func (c *ConfigDef) IsAdmin(id int64) bool {
_, ok := c.adminUidLookup[id]
return ok
}
func (c *ConfigDef) GetAPIListenAddr() net.Addr {
return c.apiListenAddr
}
func configFromEnv() (*ConfigDef, error) {
cfg := ConfigDef{}
err := cleanenv.ReadEnv(&cfg)
if err != nil {
return nil, errors.Wrap(err, 0)
}
return &cfg, nil
}
var config *ConfigDef
func LoadCfg() error {
cfg, err := configFromEnv()
if err != nil {
return errors.WrapPrefix(err, "failed to load config from environment", 0)
}
if err = cfg.Parse(); err != nil {
return errors.WrapPrefix(err, "failed to parse config", 0)
}
config = cfg
return nil
}

53
utils/net.go Normal file
View File

@ -0,0 +1,53 @@
package utils
import (
"net"
"github.com/go-errors/errors"
)
func getLocalIP() (net.IP, error) {
conn, err := net.Dial("udp", "1.0.0.1:53")
if err != nil {
return nil, err
}
defer conn.Close()
host, _, err := net.SplitHostPort(conn.LocalAddr().String())
if err != nil {
return nil, errors.Wrap(err, 0)
}
ip := net.ParseIP(host)
if ip == nil {
return nil, errors.Errorf("failed to parse IP address %q", host)
}
return ip, nil
}
func FindDefaultAdapter() (*net.Interface, error) {
localIP, err := getLocalIP()
if err != nil {
return nil, errors.WrapPrefix(err, "failed to get local address", 0)
}
iflist, err := net.Interfaces()
if err != nil {
return nil, errors.WrapPrefix(err, "failed to get interface list", 0)
}
for _, iface := range iflist {
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, addr := range addrs {
ipnet, ok := addr.(*net.IPNet)
if !ok {
continue
}
if localIP.Equal(ipnet.IP) {
return &iface, nil
}
}
}
return nil, errors.New("failed to find default adapter")
}