refactor: rewrite cfg.go
This commit is contained in:
		
							parent
							
								
									b872b73ebf
								
							
						
					
					
						commit
						6971098753
					
				
							
								
								
									
										4
									
								
								bot.go
								
								
								
								
							
							
						
						
									
										4
									
								
								bot.go
								
								
								
								
							| 
						 | 
					@ -24,7 +24,7 @@ func isFromAdmin(sender *tele.User) bool {
 | 
				
			||||||
		return false
 | 
							return false
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	_, ok := config.AdminUIDs[sender.ID]
 | 
						_, ok := config.adminUidLookup[sender.ID]
 | 
				
			||||||
	return ok
 | 
						return ok
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -61,7 +61,7 @@ func initBot() (*tele.Bot, error) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	adminGrp.Handle("/dig", handleDigCmd)
 | 
						adminGrp.Handle("/dig", handleDigCmd)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// adminGrp.Handle("/test", handleTestCmd)
 | 
						// adminGrp.Handle("/test", testCmd)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return b, nil
 | 
						return b, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										122
									
								
								cfg.go
								
								
								
								
							
							
						
						
									
										122
									
								
								cfg.go
								
								
								
								
							| 
						 | 
					@ -1,66 +1,94 @@
 | 
				
			||||||
package main
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"os"
 | 
						"net"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/utils"
 | 
				
			||||||
	"github.com/go-errors/errors"
 | 
						"github.com/go-errors/errors"
 | 
				
			||||||
	"github.com/samber/lo"
 | 
						"github.com/ilyakaznacheev/cleanenv"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Config struct {
 | 
					type ConfigDef struct {
 | 
				
			||||||
	AdminUIDs          map[int64]struct{}
 | 
						AdminUIDs          []int64 `env:"TG_ADMIN_UIDS"`
 | 
				
			||||||
	TGBotToken         string
 | 
						TGBotToken         string  `env:"TG_TOKEN" env-required:""`
 | 
				
			||||||
	TGAnnounceCommands bool
 | 
						TGAnnounceCommands bool    `env:"TG_ANNOUNCE_CMDS"`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	WatchedInterface       string
 | 
						// API
 | 
				
			||||||
	MonthlyTrafficLimitGiB int
 | 
						APIMasterKey string `env:"TG_API_MASTER_KEY" env-required:""`
 | 
				
			||||||
 | 
						APIListen    string `env:"TG_API_LISTEN" env-default:"127.0.0.1:8080"`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Traffic usage
 | 
				
			||||||
 | 
						WatchedInterface       string `env:"TG_WATCHED_INTERFACE"`
 | 
				
			||||||
 | 
						MonthlyTrafficLimitGiB int    `env:"TG_MONTHLY_TRAFFIC_LIMIT_GIB" env-default:"1000"`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Parsed fields
 | 
				
			||||||
 | 
						adminUidLookup map[int64]struct{}
 | 
				
			||||||
 | 
						apiListenAddr  net.Addr
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var config *Config
 | 
					func (c *ConfigDef) Parse() error {
 | 
				
			||||||
 | 
						c.adminUidLookup = make(map[int64]struct{}, len(c.AdminUIDs))
 | 
				
			||||||
func LoadCfg() error {
 | 
						for _, u := range c.AdminUIDs {
 | 
				
			||||||
	cfg := Config{}
 | 
							if u <= 0 {
 | 
				
			||||||
 | 
								return errors.New("invalid admin UID: " + strconv.FormatInt(u, 10))
 | 
				
			||||||
	token := os.Getenv("TG_TOKEN")
 | 
							}
 | 
				
			||||||
	if token == "" {
 | 
							c.adminUidLookup[u] = struct{}{}
 | 
				
			||||||
		return errors.New("TG_TOKEN env not set")
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	cfg.TGBotToken = token
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	adminUIDsEnv := os.Getenv("TG_ADMIN_UIDS")
 | 
						// validate API listen
 | 
				
			||||||
	adminUIDs := lo.FilterMap(strings.Split(adminUIDsEnv, ","), func(s string, _ int) (string, bool) {
 | 
						ip, port, err := net.SplitHostPort(c.APIListen)
 | 
				
			||||||
		trimmed := strings.TrimSpace(s)
 | 
						if err != nil {
 | 
				
			||||||
		return trimmed, trimmed != ""
 | 
							return errors.WrapPrefix(err, "invalid API listen address", 0)
 | 
				
			||||||
	})
 | 
						}
 | 
				
			||||||
	cfg.AdminUIDs = make(map[int64]struct{}, len(adminUIDs))
 | 
						if tIP := net.ParseIP(ip); tIP == nil {
 | 
				
			||||||
	for _, uidStr := range adminUIDs {
 | 
							return errors.New("invalid API listen address: " + ip)
 | 
				
			||||||
		uid, err := strconv.ParseInt(uidStr, 10, 64)
 | 
						}
 | 
				
			||||||
 | 
						if tPort, err := strconv.Atoi(port); err != nil || tPort <= 0 || tPort > 65535 {
 | 
				
			||||||
 | 
							return errors.New("invalid API listen port: " + port)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if c.WatchedInterface == "" {
 | 
				
			||||||
 | 
							tIf, err := utils.FindDefaultAdapter()
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errors.New("invalid admin UID: " + uidStr)
 | 
								return errors.WrapPrefix(err, "Watched interface is not defined and failed to find a default one.", 0)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		cfg.AdminUIDs[uid] = struct{}{}
 | 
							c.WatchedInterface = tIf.Name
 | 
				
			||||||
 | 
							logger.Infof("Watching default interface %s", c.WatchedInterface)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	announceCmdsEnv := os.Getenv("TG_ANNOUNCE_CMDS")
 | 
						return nil
 | 
				
			||||||
	if !lo.Contains([]string{"", "no", "false", "0"}, strings.ToLower(announceCmdsEnv)) {
 | 
					}
 | 
				
			||||||
		cfg.TGAnnounceCommands = true
 | 
					
 | 
				
			||||||
	}
 | 
					func (c *ConfigDef) IsAdmin(id int64) bool {
 | 
				
			||||||
 | 
						_, ok := c.adminUidLookup[id]
 | 
				
			||||||
	cfg.WatchedInterface = "eth0"
 | 
						return ok
 | 
				
			||||||
	if iface := os.Getenv("TG_WATCHED_INTERFACE"); iface != "" {
 | 
					}
 | 
				
			||||||
		cfg.WatchedInterface = iface
 | 
					
 | 
				
			||||||
	}
 | 
					func (c *ConfigDef) GetAPIListenAddr() net.Addr {
 | 
				
			||||||
 | 
						return c.apiListenAddr
 | 
				
			||||||
	cfg.MonthlyTrafficLimitGiB = 1000
 | 
					}
 | 
				
			||||||
	if trafficLimitStr := os.Getenv("TG_MONTHLY_TRAFFIC_LIMIT_GIB"); trafficLimitStr != "" {
 | 
					
 | 
				
			||||||
		var err error
 | 
					func configFromEnv() (*ConfigDef, error) {
 | 
				
			||||||
		if cfg.MonthlyTrafficLimitGiB, err = strconv.Atoi(trafficLimitStr); err != nil {
 | 
						cfg := ConfigDef{}
 | 
				
			||||||
			return errors.New("invalid traffic limit: " + trafficLimitStr)
 | 
						err := cleanenv.ReadEnv(&cfg)
 | 
				
			||||||
		}
 | 
						if err != nil {
 | 
				
			||||||
	}
 | 
							return nil, errors.Wrap(err, 0)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	config = &cfg
 | 
						return &cfg, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var config *ConfigDef
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func LoadCfg() error {
 | 
				
			||||||
 | 
						cfg, err := configFromEnv()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errors.WrapPrefix(err, "failed to load config from environment", 0)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err = cfg.Parse(); err != nil {
 | 
				
			||||||
 | 
							return errors.WrapPrefix(err, "failed to parse config", 0)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						config = cfg
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,53 @@
 | 
				
			||||||
 | 
					package utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/go-errors/errors"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func getLocalIP() (net.IP, error) {
 | 
				
			||||||
 | 
						conn, err := net.Dial("udp", "1.0.0.1:53")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer conn.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						host, _, err := net.SplitHostPort(conn.LocalAddr().String())
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, errors.Wrap(err, 0)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						ip := net.ParseIP(host)
 | 
				
			||||||
 | 
						if ip == nil {
 | 
				
			||||||
 | 
							return nil, errors.Errorf("failed to parse IP address %q", host)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return ip, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func FindDefaultAdapter() (*net.Interface, error) {
 | 
				
			||||||
 | 
						localIP, err := getLocalIP()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, errors.WrapPrefix(err, "failed to get local address", 0)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						iflist, err := net.Interfaces()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, errors.WrapPrefix(err, "failed to get interface list", 0)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, iface := range iflist {
 | 
				
			||||||
 | 
							addrs, err := iface.Addrs()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							for _, addr := range addrs {
 | 
				
			||||||
 | 
								ipnet, ok := addr.(*net.IPNet)
 | 
				
			||||||
 | 
								if !ok {
 | 
				
			||||||
 | 
									continue
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if localIP.Equal(ipnet.IP) {
 | 
				
			||||||
 | 
									return &iface, nil
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil, errors.New("failed to find default adapter")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Loading…
	
		Reference in New Issue