refactor: rewrite cfg.go
This commit is contained in:
parent
b872b73ebf
commit
6971098753
4
bot.go
4
bot.go
|
@ -24,7 +24,7 @@ func isFromAdmin(sender *tele.User) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ok := config.AdminUIDs[sender.ID]
|
_, ok := config.adminUidLookup[sender.ID]
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -61,7 +61,7 @@ func initBot() (*tele.Bot, error) {
|
||||||
|
|
||||||
adminGrp.Handle("/dig", handleDigCmd)
|
adminGrp.Handle("/dig", handleDigCmd)
|
||||||
|
|
||||||
// adminGrp.Handle("/test", handleTestCmd)
|
// adminGrp.Handle("/test", testCmd)
|
||||||
|
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
122
cfg.go
122
cfg.go
|
@ -1,66 +1,94 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
|
|
||||||
|
"git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/utils"
|
||||||
"github.com/go-errors/errors"
|
"github.com/go-errors/errors"
|
||||||
"github.com/samber/lo"
|
"github.com/ilyakaznacheev/cleanenv"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type ConfigDef struct {
|
||||||
AdminUIDs map[int64]struct{}
|
AdminUIDs []int64 `env:"TG_ADMIN_UIDS"`
|
||||||
TGBotToken string
|
TGBotToken string `env:"TG_TOKEN" env-required:""`
|
||||||
TGAnnounceCommands bool
|
TGAnnounceCommands bool `env:"TG_ANNOUNCE_CMDS"`
|
||||||
|
|
||||||
WatchedInterface string
|
// API
|
||||||
MonthlyTrafficLimitGiB int
|
APIMasterKey string `env:"TG_API_MASTER_KEY" env-required:""`
|
||||||
|
APIListen string `env:"TG_API_LISTEN" env-default:"127.0.0.1:8080"`
|
||||||
|
|
||||||
|
// Traffic usage
|
||||||
|
WatchedInterface string `env:"TG_WATCHED_INTERFACE"`
|
||||||
|
MonthlyTrafficLimitGiB int `env:"TG_MONTHLY_TRAFFIC_LIMIT_GIB" env-default:"1000"`
|
||||||
|
|
||||||
|
// Parsed fields
|
||||||
|
adminUidLookup map[int64]struct{}
|
||||||
|
apiListenAddr net.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
var config *Config
|
func (c *ConfigDef) Parse() error {
|
||||||
|
c.adminUidLookup = make(map[int64]struct{}, len(c.AdminUIDs))
|
||||||
func LoadCfg() error {
|
for _, u := range c.AdminUIDs {
|
||||||
cfg := Config{}
|
if u <= 0 {
|
||||||
|
return errors.New("invalid admin UID: " + strconv.FormatInt(u, 10))
|
||||||
token := os.Getenv("TG_TOKEN")
|
}
|
||||||
if token == "" {
|
c.adminUidLookup[u] = struct{}{}
|
||||||
return errors.New("TG_TOKEN env not set")
|
|
||||||
}
|
}
|
||||||
cfg.TGBotToken = token
|
|
||||||
|
|
||||||
adminUIDsEnv := os.Getenv("TG_ADMIN_UIDS")
|
// validate API listen
|
||||||
adminUIDs := lo.FilterMap(strings.Split(adminUIDsEnv, ","), func(s string, _ int) (string, bool) {
|
ip, port, err := net.SplitHostPort(c.APIListen)
|
||||||
trimmed := strings.TrimSpace(s)
|
if err != nil {
|
||||||
return trimmed, trimmed != ""
|
return errors.WrapPrefix(err, "invalid API listen address", 0)
|
||||||
})
|
}
|
||||||
cfg.AdminUIDs = make(map[int64]struct{}, len(adminUIDs))
|
if tIP := net.ParseIP(ip); tIP == nil {
|
||||||
for _, uidStr := range adminUIDs {
|
return errors.New("invalid API listen address: " + ip)
|
||||||
uid, err := strconv.ParseInt(uidStr, 10, 64)
|
}
|
||||||
|
if tPort, err := strconv.Atoi(port); err != nil || tPort <= 0 || tPort > 65535 {
|
||||||
|
return errors.New("invalid API listen port: " + port)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.WatchedInterface == "" {
|
||||||
|
tIf, err := utils.FindDefaultAdapter()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("invalid admin UID: " + uidStr)
|
return errors.WrapPrefix(err, "Watched interface is not defined and failed to find a default one.", 0)
|
||||||
}
|
}
|
||||||
cfg.AdminUIDs[uid] = struct{}{}
|
c.WatchedInterface = tIf.Name
|
||||||
|
logger.Infof("Watching default interface %s", c.WatchedInterface)
|
||||||
}
|
}
|
||||||
|
|
||||||
announceCmdsEnv := os.Getenv("TG_ANNOUNCE_CMDS")
|
return nil
|
||||||
if !lo.Contains([]string{"", "no", "false", "0"}, strings.ToLower(announceCmdsEnv)) {
|
}
|
||||||
cfg.TGAnnounceCommands = true
|
|
||||||
}
|
func (c *ConfigDef) IsAdmin(id int64) bool {
|
||||||
|
_, ok := c.adminUidLookup[id]
|
||||||
cfg.WatchedInterface = "eth0"
|
return ok
|
||||||
if iface := os.Getenv("TG_WATCHED_INTERFACE"); iface != "" {
|
}
|
||||||
cfg.WatchedInterface = iface
|
|
||||||
}
|
func (c *ConfigDef) GetAPIListenAddr() net.Addr {
|
||||||
|
return c.apiListenAddr
|
||||||
cfg.MonthlyTrafficLimitGiB = 1000
|
}
|
||||||
if trafficLimitStr := os.Getenv("TG_MONTHLY_TRAFFIC_LIMIT_GIB"); trafficLimitStr != "" {
|
|
||||||
var err error
|
func configFromEnv() (*ConfigDef, error) {
|
||||||
if cfg.MonthlyTrafficLimitGiB, err = strconv.Atoi(trafficLimitStr); err != nil {
|
cfg := ConfigDef{}
|
||||||
return errors.New("invalid traffic limit: " + trafficLimitStr)
|
err := cleanenv.ReadEnv(&cfg)
|
||||||
}
|
if err != nil {
|
||||||
}
|
return nil, errors.Wrap(err, 0)
|
||||||
|
}
|
||||||
config = &cfg
|
return &cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var config *ConfigDef
|
||||||
|
|
||||||
|
func LoadCfg() error {
|
||||||
|
cfg, err := configFromEnv()
|
||||||
|
if err != nil {
|
||||||
|
return errors.WrapPrefix(err, "failed to load config from environment", 0)
|
||||||
|
}
|
||||||
|
if err = cfg.Parse(); err != nil {
|
||||||
|
return errors.WrapPrefix(err, "failed to parse config", 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
config = cfg
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,53 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/go-errors/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getLocalIP() (net.IP, error) {
|
||||||
|
conn, err := net.Dial("udp", "1.0.0.1:53")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
host, _, err := net.SplitHostPort(conn.LocalAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, 0)
|
||||||
|
}
|
||||||
|
ip := net.ParseIP(host)
|
||||||
|
if ip == nil {
|
||||||
|
return nil, errors.Errorf("failed to parse IP address %q", host)
|
||||||
|
}
|
||||||
|
return ip, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func FindDefaultAdapter() (*net.Interface, error) {
|
||||||
|
localIP, err := getLocalIP()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.WrapPrefix(err, "failed to get local address", 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
iflist, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.WrapPrefix(err, "failed to get interface list", 0)
|
||||||
|
}
|
||||||
|
for _, iface := range iflist {
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, addr := range addrs {
|
||||||
|
ipnet, ok := addr.(*net.IPNet)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if localIP.Equal(ipnet.IP) {
|
||||||
|
return &iface, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, errors.New("failed to find default adapter")
|
||||||
|
}
|
Loading…
Reference in New Issue