From 899491e8677b8557fb673263d737a0aa9a2267c9 Mon Sep 17 00:00:00 2001 From: Yiyang Kang Date: Wed, 8 Mar 2023 02:28:26 +0800 Subject: [PATCH] feat: add translate command --- bot.go | 12 ++--- botcmd_translate.go | 111 ++++++++++++++++++++++++++++++++++++++ cfg.go | 3 ++ main.go | 7 +-- openai/chat.go | 46 ++++++++++++++++ openai/client.go | 39 ++++++++++++++ openai/models.go | 7 +++ openai/prompts/prompts.go | 14 +++++ stickers.go | 6 +++ 9 files changed, 236 insertions(+), 9 deletions(-) create mode 100644 botcmd_translate.go create mode 100644 openai/chat.go create mode 100644 openai/client.go create mode 100644 openai/models.go create mode 100644 openai/prompts/prompts.go create mode 100644 stickers.go diff --git a/bot.go b/bot.go index 11b83fd..a808480 100644 --- a/bot.go +++ b/bot.go @@ -14,11 +14,6 @@ import ( "git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/cmds" ) -const ( - stickerPanic = "CAACAgUAAxkBAAMjY3zoraxZGB8Xejyw86bHLSWLjVcAArMIAAL7-nhXNK7dStmRUGsrBA" - stickerLoading = "CAACAgUAAxkBAAMmY3zp5UCMVRvy1isFCPHrx-UBWX8AApYHAALP8GhXEm9ZIBjn1v8rBA" -) - func isFromAdmin(sender *tele.User) bool { if sender == nil { return false @@ -43,12 +38,17 @@ func initBot() (*tele.Bot, error) { // command routing b.Handle("/start", handleStartCmd) - b.Handle("/me", handleUserInfoCmd) b.Handle("/chat", handleChatInfoCmd) + b.Handle("/year_progress", handleYearProgressCmd) b.Handle("/xr", handleExchangeRateCmd) + b.Handle("/tr", handleTranslateCmd) + for _, tbtn := range translateBtns { + b.Handle(tbtn, handleTranslateBtn) + } + b.Handle(tele.OnText, handleGeneralMessage) b.Handle(tele.OnSticker, handleGeneralMessage) diff --git a/botcmd_translate.go b/botcmd_translate.go new file mode 100644 index 0000000..82a38b2 --- /dev/null +++ b/botcmd_translate.go @@ -0,0 +1,111 @@ +package main + +import ( + "regexp" + "strings" + + tele "gopkg.in/telebot.v3" + + "git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/openai" + "git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/openai/prompts" +) + +var ( + translateMenu = &tele.ReplyMarkup{} + + translateBtnZhTw = translateMenu.Data("繁中", "btn_tr_zhtw", "Taiwanese Chinese") + translateBtnZhCn = translateMenu.Data("简中", "btn_tr_zhcn", "Mandarin Chinese") + translateBtnEn = translateMenu.Data("English", "btn_tr_en", "English") + translateBtnJa = translateMenu.Data("日本語", "btn_tr_ja", "Japanese") + + translateBtns = []*tele.Btn{ + &translateBtnZhTw, &translateBtnZhCn, &translateBtnEn, &translateBtnJa, + } + + translateCmdRe = regexp.MustCompile(`^\s*\/tr(anslate)?(@\S*)?\s*`) +) + +func init() { + translateMenu.Inline( + translateMenu.Row(translateBtnZhTw, translateBtnZhCn), + translateMenu.Row(translateBtnEn, translateBtnJa), + ) +} + +func handleTranslateCmd(c tele.Context) error { + msg := c.Message() + if msg == nil { + return nil + } + if msg.ReplyTo != nil { + msg = msg.ReplyTo + } + + payload := strings.TrimSpace(translateCmdRe.ReplaceAllString(msg.Text, "")) + if payload == "" { + return c.Reply("Usage: `/tr `", + &tele.SendOptions{ParseMode: tele.ModeMarkdown}, + tele.Silent, + ) + } + logger.Infof("trimmed message: %q", payload) + + _, err := c.Bot().Reply(msg, "Sure. To what language?", tele.Silent, translateMenu) + return err +} + +func handleTranslateBtn(c tele.Context) error { + msg := c.Message() + if msg == nil || msg.ReplyTo == nil { + return nil + } + origMsg := msg.ReplyTo + targetLang := c.Data() + txt := origMsg.Text + payload := strings.TrimSpace(translateCmdRe.ReplaceAllString(txt, "")) + + if targetLang == "" || payload == "" { + return nil + } + + // change the temporary message + if err := c.Edit("Sure, please wait..."); err != nil { + logger.Warnf("failed to alter the temporary message: %v", err) + } + // pretend to be typing + if err := c.Bot().Notify(msg.Chat, tele.Typing); err != nil { + logger.Warnf("failed to send typing action: %v", err) + } + + ai := openai.NewClient(config.OpenAIApiKey) + + req := openai.ChatRequest{ + Model: openai.ModelGpt0305Turbo, + Messages: []openai.ChatMessage{ + { + Role: openai.ChatRoleSystem, + Content: prompts.Translate(targetLang), + }, + { + Role: openai.ChatRoleUser, + Content: payload, + }, + }, + } + + resp, err := ai.ChatCompletion(req) + if err != nil { + logger.Errorf("failed to translate: req: %#+v, err: %v", req, err) + _, err := c.Bot().Reply(origMsg, stickerFromID(stickerPanic), tele.Silent) + return err + } + + respText := resp.Choices[0].Message.Content + _, err = c.Bot().Reply(origMsg, respText, tele.Silent) + + // delete the temporary message + if err := c.Delete(); err != nil { + logger.Warnf("failed to delete the temporary message: %v", err) + } + return err +} diff --git a/cfg.go b/cfg.go index eec1f2b..15a34e4 100644 --- a/cfg.go +++ b/cfg.go @@ -22,6 +22,9 @@ type ConfigDef struct { WatchedInterface string `env:"TG_WATCHED_INTERFACE"` MonthlyTrafficLimitGiB int `env:"TG_MONTHLY_TRAFFIC_LIMIT_GIB" env-default:"1000"` + // AI + OpenAIApiKey string `env:"TG_OPENAI_API_KEY"` + // Parsed fields adminUidLookup map[int64]struct{} apiListenAddr net.Addr diff --git a/main.go b/main.go index a3c87c3..d8fab53 100644 --- a/main.go +++ b/main.go @@ -6,10 +6,10 @@ import ( "syscall" "time" - "git.gensokyo.cafe/kkyy/mycurrencynet" "go.uber.org/zap" tele "gopkg.in/telebot.v3" + "git.gensokyo.cafe/kkyy/mycurrencynet" "git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/utils" ) @@ -55,10 +55,11 @@ func runBot() { logger.Info("Announcing commands...") if err = bot.SetCommands([]tele.Command{ + {Text: "tr", Description: "Translate text"}, + {Text: "xr", Description: "Currency exchange rates"}, + {Text: "year_progress", Description: "Time doesn't wait."}, {Text: "traffic", Description: "Show traffic usage."}, {Text: "dig", Description: "Diggy diggy dig."}, - {Text: "year_progress", Description: "Time doesn't wait."}, - {Text: "xr", Description: "Currency exchange rates"}, }); err != nil { logger.Fatalw("Failed to announce commands", "err", err) } diff --git a/openai/chat.go b/openai/chat.go new file mode 100644 index 0000000..882477a --- /dev/null +++ b/openai/chat.go @@ -0,0 +1,46 @@ +package openai + +const ChatAPIPath = "/v1/chat/completions" + +type ChatRole string + +const ( + ChatRoleSystem ChatRole = "system" + ChatRoleAssistant ChatRole = "assistant" + ChatRoleUser ChatRole = "user" +) + +type ChatMessage struct { + Role ChatRole `json:"role"` + Content string `json:"content"` +} + +type ChatRequest struct { + Model string `json:"model"` + Messages []ChatMessage `json:"messages"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Stream bool `json:"stream,omitempty"` + Stop []string `json:"stop,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + LogitBias map[string]float64 `json:"logit_bias,omitempty"` + User string `json:"user,omitempty"` +} + +type ChatResponseChoice struct { + Message ChatMessage `json:"message"` + FinishReason string `json:"finish_reason"` + Index int `json:"index"` +} + +type ChatResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + Model string `json:"model"` + Usage map[string]int `json:"usage"` + Choices []ChatResponseChoice `json:"choices"` +} diff --git a/openai/client.go b/openai/client.go new file mode 100644 index 0000000..6b34807 --- /dev/null +++ b/openai/client.go @@ -0,0 +1,39 @@ +package openai + +import ( + "time" + + "github.com/go-errors/errors" + "github.com/go-resty/resty/v2" +) + +type Client struct { + rest *resty.Client +} + +func NewClient(apiKey string) *Client { + cli := resty.New(). + SetBaseURL("https://api.openai.com"). + SetHeader("Authorization", "Bearer "+apiKey). + SetTimeout(30 * time.Second) + + return &Client{rest: cli} +} + +func (c *Client) ChatCompletion(request ChatRequest) (*ChatResponse, error) { + resp, err := c.rest.R(). + SetBody(request). + SetHeader("Content-Type", "application/json"). + SetResult(&ChatResponse{}). + Post(ChatAPIPath) + + if err != nil { + return nil, errors.Wrap(err, 0) + } + + if resp.StatusCode() != 200 { + return nil, errors.Errorf("unexpected status code: %d", resp.StatusCode()) + } + + return resp.Result().(*ChatResponse), nil +} diff --git a/openai/models.go b/openai/models.go new file mode 100644 index 0000000..957190e --- /dev/null +++ b/openai/models.go @@ -0,0 +1,7 @@ +package openai + +const ( + ModelTextDavinciEdit001 = "text-davinci-edit-001" + ModelGpt0305Turbo = "gpt-3.5-turbo" + ModelGpt0305Turbo0301 = "gpt-3.5-turbo-0301" +) diff --git a/openai/prompts/prompts.go b/openai/prompts/prompts.go new file mode 100644 index 0000000..43f5ef5 --- /dev/null +++ b/openai/prompts/prompts.go @@ -0,0 +1,14 @@ +package prompts + +import "fmt" + +func General() string { + return "You are a helpful assistant." +} + +func Translate(targetLang string) string { + return fmt.Sprintf( + "You are a helpful assistant. Your task is to help translate the following text to %s. You should not interpret the text. You should structure the translated text to look natural in native %s, while keeping the meaning unchanged.", + targetLang, targetLang, + ) +} diff --git a/stickers.go b/stickers.go new file mode 100644 index 0000000..a0d849f --- /dev/null +++ b/stickers.go @@ -0,0 +1,6 @@ +package main + +const ( + stickerPanic = "CAACAgUAAxkBAAMjY3zoraxZGB8Xejyw86bHLSWLjVcAArMIAAL7-nhXNK7dStmRUGsrBA" + stickerLoading = "CAACAgUAAxkBAAMmY3zp5UCMVRvy1isFCPHrx-UBWX8AApYHAALP8GhXEm9ZIBjn1v8rBA" +)