tgbot_misaka_5882f7/assistant.go

286 lines
6.9 KiB
Go

package main
import (
"crypto/md5"
"encoding/base64"
"strconv"
"strings"
"time"
"unicode"
"git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/openai"
"git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/openai/prompts"
"github.com/samber/lo"
tele "gopkg.in/telebot.v3"
)
var (
	// assistantWritingSign is appended to a reply while its text is still being
	// streamed in; a bot message that still ends with it is treated as
	// unfinished and is not accepted as part of an assistant conversation.
	assistantWritingSign = "\n... 📝"
)
// matchAssistantConversation loads the reply thread msg belongs to (via
// getCachedThread) and returns it if it should be treated as a conversation
// with the assistant botUsr; otherwise it returns nil.
//
// A thread qualifies (for now) when all of the following hold:
//   - Its participants are exactly the assistant and one user, or a single
//     user alone. In the latter case the chat must be private, or, in a group
//     chat, the user must mention the assistant as "@<bot username>" at the
//     start of a message.
//   - Every message comes from a private chat, group or supergroup, has a
//     sender, and has non-empty text. No bot commands, images, stickers, or
//     other message types are involved, and no bot other than the assistant
//     takes part.
//   - No assistant message still ends with assistantWritingSign.
func matchAssistantConversation(botUsr *tele.User, msg *tele.Message) []*tele.Message {
	thread, err := getCachedThread(msg)
	if err != nil {
		return nil
	}

	// Only "@bot_user_name" at the beginning of a message counts as a mention.
	mentionStr := "@" + botUsr.Username
	allowedChatTypes := []tele.ChatType{tele.ChatPrivate, tele.ChatGroup, tele.ChatSuperGroup}

	// Participants are keyed by "<chat ID>:<sender ID>".
	participants := map[string]struct{}{}
	hasAssistant := false
	hasUser := false
	assistantMentioned := false
	isPrivateChat := true

	for _, m := range thread {
		for _, ent := range m.Entities {
			if ent.Type == tele.EntityCommand {
				return nil
			}
		}
		from := m.Sender
		chat := m.Chat
		if from == nil || chat == nil {
			return nil
		}
		participants[strconv.FormatInt(chat.ID, 10)+":"+strconv.FormatInt(from.ID, 10)] = struct{}{}
		if chat.Type != tele.ChatPrivate {
			isPrivateChat = false
		}
		if !lo.Contains(allowedChatTypes, chat.Type) {
			return nil
		}
		// Messages without text (presumably media such as photos or stickers)
		// disqualify the thread.
		if m.Text == "" {
			return nil
		}
		if from.ID == botUsr.ID {
			hasAssistant = true
			// A reply still carrying the writing sign has not been completed.
			if strings.HasSuffix(m.Text, assistantWritingSign) {
				return nil
			}
			continue
		}
		if from.IsBot {
			return nil
		}
		hasUser = true
		if strings.HasPrefix(strings.TrimSpace(m.Text), mentionStr) {
			assistantMentioned = true
		}
	}

	if !hasUser || len(participants) > 2 {
		return nil
	}
	if hasAssistant {
		// The assistant plus exactly one user.
		return thread
	}
	// Without the assistant, only a single user may take part; two different
	// users must not turn into an assistant conversation just because one of
	// them mentioned the bot.
	if len(participants) == 1 && (isPrivateChat || assistantMentioned) {
		return thread
	}
	return nil
}
// assistantStreamedResponseCb receives the text generated so far by a streamed
// chat completion. done is false for intermediate updates (whose content ends
// with assistantWritingSign) and true for the final, complete text. A non-nil
// return value reports that delivering content failed.
type assistantStreamedResponseCb func(content string, done bool) error
// assistantStreamedResponse sends request to OpenAI as a streaming chat
// completion and relays the generated text to cb.
//
// Opening the stream is attempted at most twice. If every attempt fails, the
// last error is returned and cb is never called. Otherwise the function
// returns nil immediately and the stream is consumed in a background
// goroutine, which:
//   - repeatedly calls cb(textSoFar+assistantWritingSign, false). It waits at
//     least 1s before the first update and min(nUpdates, 4)+3*nErrs seconds
//     between later ones, because Telegram limits group messages to 20 per
//     minute;
//   - calls cb(fullText, true) once the stream channel is closed.
//
// If more than 5 intermediate updates fail, further intermediate updates are
// skipped. The rest of the stream is still consumed and the final text is
// still delivered, so the reply does not stay stuck ending with the writing
// sign.
//
// TODO interrupt response with context.
func assistantStreamedResponse(request openai.ChatRequest, cb assistantStreamedResponseCb) error {
	logger.Debugw("Openai chat request", "req", request)
	ai := openai.NewClient(config.OpenAIApiKey)

	var (
		resp *openai.ChatResponseStream
		err  error
	)
	const maxAttempts = 2
	for attempt := 1; ; attempt++ {
		resp, err = ai.ChatCompletionStream(request)
		if err == nil {
			break
		}
		logger.Warnw("assistant: failed to get response", "error", err)
		if attempt >= maxAttempts {
			return err
		}
	}

	go func() {
		var (
			respBuilder     strings.Builder
			nErrs, nUpdates int
		)
		minWait := time.After(time.Second)
		for {
			nUpdates++
			var (
				nNewChunk        int
				finished         bool
				minWaitSatisfied bool
			)
			// Collect every chunk that is available. When the stream runs dry,
			// wait once for minWait (or for resp.Done), so that intermediate
			// updates are not sent too often.
		Drain:
			for {
				select {
				case chunk, ok := <-resp.Stream:
					if !ok {
						finished = true
						break Drain
					}
					nNewChunk++
					respBuilder.WriteString(chunk)
				default:
					if minWaitSatisfied {
						break Drain
					}
					select {
					case <-minWait:
					case <-resp.Done:
					}
					minWaitSatisfied = true
				}
			}
			// Nothing new arrived: block until the next chunk or the end of the stream.
			if nNewChunk == 0 {
				if chunk, ok := <-resp.Stream; !ok {
					finished = true
				} else {
					respBuilder.WriteString(chunk)
				}
			}
			if finished {
				break
			}

			// Renew the timer.
			// The Telegram API rate limit for group messages is 20 per minute. So we cannot update messages too frequently.
			minWaitDurSecs := lo.Min([]int{nUpdates, 4}) + nErrs*3
			minWait = time.After(time.Duration(minWaitDurSecs) * time.Second)

			// Send the partial message.
			if err := cb(respBuilder.String()+assistantWritingSign, false); err != nil {
				logger.Warnw("failed to send partial update", "error", err)
				nErrs++
				if nErrs > 5 {
					logger.Warnw("too many errors, skipping further partial updates")
					// Keep reading so the complete text can still be delivered below.
					for chunk := range resp.Stream {
						respBuilder.WriteString(chunk)
					}
					break
				}
				continue
			}
			logger.Debugf("... message edited")
		}

		if err := cb(respBuilder.String(), true); err != nil {
			logger.Warnw("assistant: failed to send message", "error", err)
		}
	}()
	return nil
}
// handleAssistantConversation answers the last message of thread with a reply
// generated by the OpenAI chat API, streaming partial text into the reply as
// it arrives.
//
// thread is presumably ordered oldest-first (TODO confirm); its last element
// is the message being answered. Messages are added to the prompt from newest
// to oldest until the estimated size budget is exhausted. They are then placed
// after the system prompt in thread order. Messages sent by the bot itself get
// the assistant role and all others the user role. A leading
// "@<bot username>" mention is removed from each.
//
// An empty thread is ignored. If the prompt cannot be built or the OpenAI
// stream cannot be opened, the user gets an apology reply and the error of
// sending that reply is returned. Failures while streaming are handled by
// assistantStreamedResponse.
func handleAssistantConversation(c tele.Context, thread []*tele.Message) error {
	if len(thread) == 0 {
		return nil
	}
	me := c.Bot().Me
	lastMsg := thread[len(thread)-1]
	if err := cacheMessage(lastMsg); err != nil {
		logger.Warnw("failed to cache message", "error", err)
	}

	// The prompt size is estimated in bytes, treating 3 bytes as 1 token.
	// NOTE(review): the budget reserves only 512 tokens for the response while
	// MaxTokens below asks for up to 2048; confirm that prompt + response fits
	// the context window of openai.ModelGpt04.
	nBytes := 0
	nBytesMax := (4096 - 512) * 3 // Leave some space for the response

	sysMsg := prompts.Assistant()
	nBytes += len(sysMsg)

	// Walk backwards so that the newest messages are kept when the budget runs out.
	convMsgs := []openai.ChatMessage{}
	for l := len(thread) - 1; l >= 0; l-- {
		text := assistantRemoveMention(thread[l].Text, me.Username)
		textLen := len(text)
		if textLen+nBytes > nBytesMax {
			break
		}
		nBytes += textLen
		role := openai.ChatRoleUser
		if from := thread[l].Sender; from != nil && from.ID == me.ID {
			role = openai.ChatRoleAssistant
		}
		convMsgs = append(convMsgs, openai.ChatMessage{
			Role:    role,
			Content: text,
		})
	}
	if len(convMsgs) == 0 {
		// It turns out that this will never happen because Telegram splits messages when they exceed a certain length.
		return c.Reply("Your message is too long (Sorry!)")
	}

	chatReqMsgs := make([]openai.ChatMessage, 0, len(convMsgs)+1)
	chatReqMsgs = append(chatReqMsgs, openai.ChatMessage{
		Role:    openai.ChatRoleSystem,
		Content: sysMsg,
	})
	// Restore thread order (convMsgs was collected newest-first).
	for l := len(convMsgs) - 1; l >= 0; l-- {
		chatReqMsgs = append(chatReqMsgs, convMsgs[l])
	}

	// Avoid a nil-pointer panic if the sender is unknown.
	userTag := ""
	if lastMsg.Sender != nil {
		userTag = assistantHashUserId(lastMsg.Sender.ID)
	}
	req := openai.ChatRequest{
		Model:       openai.ModelGpt04,
		Messages:    chatReqMsgs,
		Temperature: lo.ToPtr(0.42),
		MaxTokens:   2048,
		User:        userTag,
	}

	typingNotifyCh := setTyping(c)
	typingAwaited := false
	var replyMsg *tele.Message
	reqErr := assistantStreamedResponse(req, func(text string, finished bool) error {
		var (
			sent *tele.Message
			err  error
		)
		if replyMsg == nil {
			// Wait for setTyping's signal before the first reply, but only once:
			// a retry after a failed first reply should not depend on the
			// channel yielding another value.
			if !typingAwaited {
				<-typingNotifyCh
				typingAwaited = true
			}
			sent, err = c.Bot().Reply(lastMsg, text, tele.Silent)
		} else {
			sent, err = c.Bot().Edit(replyMsg, text)
		}
		if err != nil {
			// Keep the previous replyMsg: telebot returns a nil message on
			// failure, and overwriting it would make the next update post a
			// second reply instead of editing the first one.
			return err
		}
		if sent != nil {
			replyMsg = sent
		}
		if finished && replyMsg != nil {
			replyMsg.ReplyTo = lastMsg // nasty bug
			if err := cacheMessage(replyMsg); err != nil {
				logger.Warnw("failed to cache message", "error", err)
			}
		}
		return nil
	})
	if reqErr != nil {
		logger.Errorw("assistant: failed to complete conversation", "error", reqErr)
		return c.Reply("Sorry, there's a technical issue. 😵💫 Please try again later.", tele.Silent)
	}
	return nil
}
// assistantRemoveMention strips a leading "@name" mention from msg.
//
// Whitespace before the mention is ignored. The mention only counts when it
// is not immediately followed by another username character (ASCII letter,
// digit or underscore), so "@name" is not cut out of a longer username such
// as "@name_fan". The text after the mention is returned as-is, including any
// leading whitespace. If msg does not start with the mention, or name is
// empty, msg is returned unchanged.
func assistantRemoveMention(msg, name string) string {
	if name == "" {
		return msg
	}
	mentionStr := "@" + name
	trimmed := strings.TrimLeftFunc(msg, unicode.IsSpace)
	if !strings.HasPrefix(trimmed, mentionStr) {
		return msg
	}
	rest := trimmed[len(mentionStr):]
	if rest != "" {
		switch b := rest[0]; {
		case b == '_', 'a' <= b && b <= 'z', 'A' <= b && b <= 'Z', '0' <= b && b <= '9':
			// The prefix is only part of a longer username, not a mention of name.
			return msg
		}
	}
	return rest
}
// assistantHashUserId turns a Telegram user ID into a stable opaque tag. The
// tag is used as the OpenAI request's User field instead of the raw ID. The ID
// is appended to a fixed salt, hashed with MD5 (used here only to produce an
// opaque tag), and encoded as 22 characters of unpadded URL-safe base64.
func assistantHashUserId(uid int64) string {
	const salt = "RdnuRPqp66vtbc28QRO0ecKSLKXifz7G9UbXLoyCMpw"
	input := strconv.AppendInt([]byte(salt), uid, 10)
	digest := md5.Sum(input)
	// 16 digest bytes encode to exactly 22 unpadded base64 characters.
	return base64.RawURLEncoding.EncodeToString(digest[:])
}