package main import ( "crypto/md5" "encoding/base64" "strconv" "strings" "time" "unicode" "git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/openai" "git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/openai/prompts" "github.com/samber/lo" tele "gopkg.in/telebot.v3" ) var assistantWritingSign = "\n... 📝" func matchAssistantConversation(botUsr *tele.User, msg *tele.Message) []*tele.Message { // A thread must meet the following conditions to be considered a conversation with the assistant (for now): // It has only two participants: the assistant and the user. // Or, it has only one participant, which is the user. In this case, it has to be in a private chat, or the assistant must be mentioned if it's in a group chat. // No bot commands, images, stickers, or other message types involved. participants := map[string]struct{}{} hasAssistant := false hasUser := false assistantMentioned := false isPrivateChat := true thread, err := getCachedThread(msg) if err != nil { return nil } for _, msg := range thread { for _, ent := range msg.Entities { if ent.Type == tele.EntityCommand { return nil } } from := msg.Sender chat := msg.Chat if from == nil || chat == nil { return nil } participant := strconv.FormatInt(chat.ID, 10) + ":" + strconv.FormatInt(from.ID, 10) participants[participant] = struct{}{} if chat.Type != tele.ChatPrivate { isPrivateChat = false } if !lo.Contains([]tele.ChatType{tele.ChatPrivate, tele.ChatGroup, tele.ChatSuperGroup}, chat.Type) { return nil } if msg.Text == "" { return nil } if from.ID == botUsr.ID { hasAssistant = true if strings.HasSuffix(msg.Text, assistantWritingSign) { return nil } } else { if from.IsBot { return nil } hasUser = true // Only @bot_user_name in the beginning will be counted as "mention" mentionStr := "@" + botUsr.Username if strings.HasPrefix(strings.TrimSpace(msg.Text), mentionStr) { assistantMentioned = true } } } if len(participants) > 2 { return nil } if hasAssistant && hasUser { return thread } if hasUser && (isPrivateChat || assistantMentioned) { return thread } return nil } type assistantStreamedResponseCb func(text string, finished bool) error // TODO interrupt response with context. func assistantStreamedResponse(request openai.ChatRequest, cb assistantStreamedResponseCb) error { logger.Debugw("Openai chat request", "req", request) ai := openai.NewClient(config.OpenAIApiKey) var ( resp *openai.ChatResponseStream err error ) tries := 2 for { if tries--; tries < 0 { return err } if resp, err = ai.ChatCompletionStream(request); err == nil { break } logger.Warnw("assistant: failed to get response", "error", err) } var nErrs, nUpdates int go func() { respBuilder := strings.Builder{} minWait := time.After(time.Second) for { nUpdates += 1 var ( nNewChunk int finished bool minWaitSatisfied bool ) Drain: for { select { case chunk, ok := <-resp.Stream: if !ok { finished = true break Drain } nNewChunk += 1 respBuilder.WriteString(chunk) default: if minWaitSatisfied { break Drain } select { case <-minWait: case <-resp.Done: } minWaitSatisfied = true } } if nNewChunk == 0 { if chunk, ok := <-resp.Stream; !ok { finished = true } else { respBuilder.WriteString(chunk) } } if finished { break } // Renew the timer. // The Telegram API rate limit for group messages is 20 per minute. So we cannot update messages too frequently. minWaitDurSecs := lo.Min([]int{nUpdates, 4}) + nErrs*3 minWait = time.After(time.Duration(minWaitDurSecs) * time.Second) // Send the partial message respText := respBuilder.String() + assistantWritingSign if err := cb(respText, false); err != nil { logger.Warnw("failed to send partial update", "error", err) nErrs += 1 if nErrs > 5 { logger.Warnw("too many errors, aborting") return } continue } logger.Debugf("... message edited") } respText := respBuilder.String() if err = cb(respText, true); err != nil { logger.Warnw("assistant: failed to send message", "error", err) } }() return nil } func handleAssistantConversation(c tele.Context, thread []*tele.Message) error { me := c.Bot().Me lastMsg := thread[len(thread)-1] if err := cacheMessage(lastMsg); err != nil { logger.Warnw("failed to cache message", "error", err) } nBytes := 0 // Used to estimated number of tokens. For now we treat 3 bytes as 1 token. nBytesMax := (4096 - 512) * 3 // Leave some space for the response sysMsg := prompts.Assistant() chatReqMsgs := []openai.ChatMessage{ { Role: openai.ChatRoleSystem, Content: sysMsg, }, } nBytes += len(sysMsg) convMsgs := []openai.ChatMessage{} for l := len(thread) - 1; l >= 0; l-- { text := assistantRemoveMention(thread[l].Text, me.Username) textLen := len(text) if textLen+nBytes > nBytesMax { break } nBytes += textLen role := openai.ChatRoleUser from := thread[l].Sender if from != nil && from.ID == me.ID { role = openai.ChatRoleAssistant } convMsgs = append(convMsgs, openai.ChatMessage{ Role: role, Content: text, }) } if len(convMsgs) == 0 { // It turns out that this will never happen because Telegram splits messages when they exceed a certain length. return c.Reply("Your message is too long (Sorry!)") } for l := len(convMsgs) - 1; l >= 0; l-- { chatReqMsgs = append(chatReqMsgs, convMsgs[l]) } req := openai.ChatRequest{ Model: openai.ModelGpt4O, Messages: chatReqMsgs, Temperature: lo.ToPtr(0.42), MaxTokens: 2048, User: assistantHashUserId(lastMsg.Sender.ID), } typingNotifyCh := setTyping(c) var replyMsg *tele.Message reqErr := assistantStreamedResponse(req, func(text string, finished bool) error { var err error if replyMsg == nil { <-typingNotifyCh replyMsg, err = c.Bot().Reply(lastMsg, text, tele.Silent) } else { replyMsg, err = c.Bot().Edit(replyMsg, text) } if finished && err == nil { replyMsg.ReplyTo = lastMsg // nasty bug if err := cacheMessage(replyMsg); err != nil { logger.Warnw("failed to cache message", "error", err) } } return err }) if reqErr != nil { logger.Errorw("assistant: failed to complete conversation", "error", reqErr) return c.Reply("Sorry, there's a technical issue. 😵💫 Please try again later.", tele.Silent) } return nil } func assistantRemoveMention(msg, name string) string { mentionStr := "@" + name orig := msg msg = strings.TrimLeftFunc(msg, unicode.IsSpace) if strings.HasPrefix(msg, mentionStr) { return msg[len(mentionStr):] } return orig } func assistantHashUserId(uid int64) string { seasoned := []byte("RdnuRPqp66vtbc28QRO0ecKSLKXifz7G9UbXLoyCMpw" + strconv.FormatInt(uid, 10)) hashed := md5.Sum(seasoned) // Don't judge me return base64.URLEncoding.EncodeToString(hashed[:])[:22] }