package main import ( "strconv" "strings" "time" "unicode" "git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/openai" "git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/openai/prompts" "github.com/samber/lo" tele "gopkg.in/telebot.v3" ) var assistantWritingSign = "\n...📝" func matchAssistantConversation(botUsr *tele.User, msg *tele.Message) []*tele.Message { // A thread must meet the following conditions to be considered a conversation with the assistant (for now): // It has only two participants: the assistant and the user. // Or, it has only one participant, which is the user. In this case, it has to be in a private chat, or the assistant must be mentioned if it's in a group chat. // No bot commands, images, stickers, or other message types involved. participants := map[string]struct{}{} hasAssistant := false hasUser := false assistantMentioned := false isPrivateChat := true thread, err := getCachedThread(msg) if err != nil { return nil } for _, msg := range thread { from := msg.Sender chat := msg.Chat if from == nil || chat == nil { return nil } participant := strconv.FormatInt(chat.ID, 10) + ":" + strconv.FormatInt(from.ID, 10) participants[participant] = struct{}{} if chat.Type != tele.ChatPrivate { isPrivateChat = false } if !lo.Contains([]tele.ChatType{tele.ChatPrivate, tele.ChatGroup, tele.ChatSuperGroup}, chat.Type) { return nil } if msg.Text == "" { return nil } if from.ID == botUsr.ID { hasAssistant = true if strings.HasSuffix(msg.Text, assistantWritingSign) { return nil } } else { if from.IsBot { return nil } hasUser = true // Only @bot_user_name in the beginning will be counted as "mention" mentionStr := "@" + botUsr.Username if strings.HasPrefix(strings.TrimSpace(msg.Text), mentionStr) { assistantMentioned = true } } } if len(participants) > 2 { return nil } if hasAssistant && hasUser { return thread } if hasUser && (isPrivateChat || assistantMentioned) { return thread } return nil } type assistantStreamedResponseCb func(text string, finished bool) (*tele.Message, error) func assistantStreamedResponse(request openai.ChatRequest, cb assistantStreamedResponseCb) error { logger.Debugf("Openai chat request: %#+v", request) ai := openai.NewClient(config.OpenAIApiKey) resp, err := ai.ChatCompletionStream(request) if err != nil { return err } nErrs := 0 go func() { respBuilder := strings.Builder{} minWait := time.After(1 * time.Second) for { var ( nNewChunk int finished bool minWaitSatisfied bool ) Drain: for { select { case chunk, ok := <-resp.Stream: if !ok { finished = true break Drain } nNewChunk += 1 respBuilder.WriteString(chunk) default: if minWaitSatisfied { break Drain } <-minWait minWaitSatisfied = true } } if nNewChunk == 0 { if chunk, ok := <-resp.Stream; !ok { finished = true } else { respBuilder.WriteString(chunk) } } if finished { break } respoText := respBuilder.String() + assistantWritingSign minWait = time.After(691 * time.Millisecond) // renew the timer if _, err := cb(respoText, false); err != nil { logger.Warnw("failed to send partial update", "error", err) nErrs += 1 if nErrs > 3 { logger.Warnw("too many errors, aborting") return } continue } logger.Debugf("... message edited") } respText := respBuilder.String() if _, err = cb(respText, true); err != nil { logger.Warnw("assistant: failed to send message", "error", err) } }() return nil } func handleAssistantConversation(c tele.Context, thread []*tele.Message) error { me := c.Bot().Me lastMsg := thread[len(thread)-1] if err := cacheMessage(lastMsg); err != nil { logger.Warnw("failed to cache message", "error", err) } nBytes := 0 // Used to estimated number of tokens. For now we treat 3 bytes as 1 token. nBytesMax := (4096 - 512) * 3 // Leave some space for the response sysMsg := prompts.Assistant() chatReqMsgs := []openai.ChatMessage{ { Role: openai.ChatRoleSystem, Content: sysMsg, }, } nBytes += len(sysMsg) convMsgs := []openai.ChatMessage{} for l := len(thread) - 1; l >= 0; l-- { text := assistantRemoveMention(thread[l].Text, me.Username) textLen := len(text) if textLen+nBytes > nBytesMax { break } nBytes += textLen role := openai.ChatRoleUser from := thread[l].Sender if from != nil && from.ID == me.ID { role = openai.ChatRoleSystem } convMsgs = append(convMsgs, openai.ChatMessage{ Role: role, Content: text, }) } if len(convMsgs) == 0 { return c.Reply("Your message is too long (Sorry!)") } for l := len(convMsgs) - 1; l >= 0; l-- { chatReqMsgs = append(chatReqMsgs, convMsgs[l]) } req := openai.ChatRequest{ Model: openai.ModelGpt0305Turbo, Messages: chatReqMsgs, Temperature: lo.ToPtr(0.42), MaxTokens: 2048, } typingNotifyCh := make(chan struct{}) go func() { defer close(typingNotifyCh) _ = c.Bot().Notify(lastMsg.Chat, tele.Typing) }() var replyMsg *tele.Message reqErr := assistantStreamedResponse(req, func(text string, finished bool) (*tele.Message, error) { var err error if replyMsg == nil { <-typingNotifyCh replyMsg, err = c.Bot().Reply(c.Message(), text, tele.Silent) } else { replyMsg, err = c.Bot().Edit(replyMsg, text) } if finished && err == nil { if err := cacheMessage(replyMsg); err != nil { logger.Warnw("failed to cache message", "error", err) } } return replyMsg, err }) if reqErr != nil { logger.Errorw("assistant: failed to complete conversation", "error", reqErr) } return reqErr } func assistantRemoveMention(msg, name string) string { mentionStr := "@" + name orig := msg msg = strings.TrimLeftFunc(msg, unicode.IsSpace) if strings.HasPrefix(msg, mentionStr) { return msg[len(mentionStr):] } return orig }