diff --git a/assistant.go b/assistant.go new file mode 100644 index 0000000..a7f9430 --- /dev/null +++ b/assistant.go @@ -0,0 +1,250 @@ +package main + +import ( + "strconv" + "strings" + "time" + "unicode" + + "git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/openai" + "git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/openai/prompts" + "github.com/samber/lo" + tele "gopkg.in/telebot.v3" +) + +var assistantWritingSign = "\n...📝" + +func matchAssistantConversation(botUsr *tele.User, msg *tele.Message) []*tele.Message { + // A thread must meet the following conditions to be considered a conversation with the assistant (for now): + // It has only two participants: the assistant and the user. + // Or, it has only one participant, which is the user. In this case, it has to be in a private chat, or the assistant must be mentioned if it's in a group chat. + // No bot commands, images, stickers, or other message types involved. + participants := map[string]struct{}{} + hasAssistant := false + hasUser := false + assistantMentioned := false + isPrivateChat := true + + thread, err := getCachedThread(msg) + if err != nil { + return nil + } + + for _, msg := range thread { + from := msg.Sender + chat := msg.Chat + if from == nil || chat == nil { + return nil + } + participant := strconv.FormatInt(chat.ID, 10) + ":" + strconv.FormatInt(from.ID, 10) + participants[participant] = struct{}{} + + if chat.Type != tele.ChatPrivate { + isPrivateChat = false + } + if !lo.Contains([]tele.ChatType{tele.ChatPrivate, tele.ChatGroup, tele.ChatSuperGroup}, chat.Type) { + return nil + } + + if msg.Text == "" { + return nil + } + + if from.ID == botUsr.ID { + hasAssistant = true + if strings.HasSuffix(msg.Text, assistantWritingSign) { + return nil + } + } else { + if from.IsBot { + return nil + } + hasUser = true + // Only @bot_user_name in the beginning will be counted as "mention" + mentionStr := "@" + botUsr.Username + if strings.HasPrefix(strings.TrimSpace(msg.Text), mentionStr) { + assistantMentioned 
= true + } + } + } + + if len(participants) > 2 { + return nil + } + if hasAssistant && hasUser { + return thread + } + if hasUser && (isPrivateChat || assistantMentioned) { + return thread + } + + return nil +} + +type assistantStreamedResponseCb func(text string, finished bool) (*tele.Message, error) + +func assistantStreamedResponse(request openai.ChatRequest, cb assistantStreamedResponseCb) error { + logger.Debugf("Openai chat request: %#+v", request) + ai := openai.NewClient(config.OpenAIApiKey) + + resp, err := ai.ChatCompletionStream(request) + if err != nil { + return err + } + + nErrs := 0 + go func() { + respBuilder := strings.Builder{} + minWait := time.After(1 * time.Second) + for { + var ( + nNewChunk int + finished bool + minWaitSatisfied bool + ) + Drain: + for { + select { + case chunk, ok := <-resp.Stream: + if !ok { + finished = true + break Drain + } + nNewChunk += 1 + respBuilder.WriteString(chunk) + default: + if minWaitSatisfied { + break Drain + } + <-minWait + minWaitSatisfied = true + } + } + + if nNewChunk == 0 { + if chunk, ok := <-resp.Stream; !ok { + finished = true + } else { + respBuilder.WriteString(chunk) + } + } + if finished { + break + } + + respoText := respBuilder.String() + assistantWritingSign + minWait = time.After(691 * time.Millisecond) // renew the timer + + if _, err := cb(respoText, false); err != nil { + logger.Warnw("failed to send partial update", "error", err) + nErrs += 1 + if nErrs > 3 { + logger.Warnw("too many errors, aborting") + return + } + continue + } + + logger.Debugf("... 
message edited") + } + + respText := respBuilder.String() + if _, err = cb(respText, true); err != nil { + logger.Warnw("assistant: failed to send message", "error", err) + } + }() + + return nil +} + +func handleAssistantConversation(c tele.Context, thread []*tele.Message) error { + me := c.Bot().Me + lastMsg := thread[len(thread)-1] + if err := cacheMessage(lastMsg); err != nil { + logger.Warnw("failed to cache message", "error", err) + } + + nBytes := 0 // Used to estimate the number of tokens. For now we treat 3 bytes as 1 token. + nBytesMax := (4096 - 512) * 3 // Leave some space for the response + + sysMsg := prompts.Assistant() + chatReqMsgs := []openai.ChatMessage{ + { + Role: openai.ChatRoleSystem, + Content: sysMsg, + }, + } + nBytes += len(sysMsg) + + convMsgs := []openai.ChatMessage{} + for l := len(thread) - 1; l >= 0; l-- { + text := assistantRemoveMention(thread[l].Text, me.Username) + textLen := len(text) + if textLen+nBytes > nBytesMax { + break + } + nBytes += textLen + + role := openai.ChatRoleUser + from := thread[l].Sender + if from != nil && from.ID == me.ID { + role = openai.ChatRoleAssistant + } + + convMsgs = append(convMsgs, openai.ChatMessage{ + Role: role, + Content: text, + }) + } + if len(convMsgs) == 0 { + return c.Reply("Your message is too long (Sorry!)") + } + for l := len(convMsgs) - 1; l >= 0; l-- { + chatReqMsgs = append(chatReqMsgs, convMsgs[l]) + } + + req := openai.ChatRequest{ + Model: openai.ModelGpt0305Turbo, + Messages: chatReqMsgs, + Temperature: lo.ToPtr(0.42), + MaxTokens: 2048, + } + + typingNotifyCh := make(chan struct{}) + go func() { + defer close(typingNotifyCh) + _ = c.Bot().Notify(lastMsg.Chat, tele.Typing) + }() + + var replyMsg *tele.Message + reqErr := assistantStreamedResponse(req, func(text string, finished bool) (*tele.Message, error) { + var err error + if replyMsg == nil { + <-typingNotifyCh + replyMsg, err = c.Bot().Reply(c.Message(), text, tele.Silent) + } else { + replyMsg, err = c.Bot().Edit(replyMsg,
text) + } + if finished && err == nil { + if err := cacheMessage(replyMsg); err != nil { + logger.Warnw("failed to cache message", "error", err) + } + } + return replyMsg, err + }) + + if reqErr != nil { + logger.Errorw("assistant: failed to complete conversation", "error", reqErr) + } + return reqErr +} + +func assistantRemoveMention(msg, name string) string { + mentionStr := "@" + name + orig := msg + msg = strings.TrimLeftFunc(msg, unicode.IsSpace) + if strings.HasPrefix(msg, mentionStr) { + return msg[len(mentionStr):] + } + return orig +} diff --git a/bot.go b/bot.go index 7ca1d48..fa7413f 100644 --- a/bot.go +++ b/bot.go @@ -55,8 +55,6 @@ func initBot() (*tele.Bot, error) { adminGrp.Handle("/dig", handleDigCmd) - // adminGrp.Handle("/test", testCmd) - return b, nil } @@ -104,8 +102,11 @@ func drawBar(progress float64, length int) string { return string(buf) } -func handleGeneralMessage(_ tele.Context) error { - // Do nothing for now +func handleGeneralMessage(c tele.Context) error { + if thread := matchAssistantConversation(c.Bot().Me, c.Message()); thread != nil { + return handleAssistantConversation(c, thread) + } + return nil } diff --git a/botcmd_translate.go b/botcmd_translate.go index 5e85854..d46b5e7 100644 --- a/botcmd_translate.go +++ b/botcmd_translate.go @@ -140,7 +140,7 @@ func handleTranslateBtn(c tele.Context) error { break } - respoText := respBuilder.String() + "\n...
(Writting)" + respoText := respBuilder.String() + assistantWritingSign minWait = time.After(691 * time.Millisecond) // renew the timer if msg, err = c.Bot().Edit(msg, respoText, tele.Silent); err != nil { logger.Warnf("failed to edit the temporary message: %v", err) diff --git a/main.go b/main.go index 1d639d1..2872bbb 100644 --- a/main.go +++ b/main.go @@ -81,6 +81,9 @@ func main() { } else { loglvl.SetLevel(parsedLvl) } + if err := initMsgCache(); err != nil { + logger.Fatalw("Failed to initialize message cache", "err", err) + } runBot() } diff --git a/msgcache.go b/msgcache.go new file mode 100644 index 0000000..7a4b3c4 --- /dev/null +++ b/msgcache.go @@ -0,0 +1,89 @@ +package main + +import ( + "context" + "fmt" + + "git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/utils" + "github.com/dgraph-io/ristretto" + "github.com/eko/gocache/lib/v4/cache" + gocache_lib "github.com/eko/gocache/lib/v4/store" + ristretto_store "github.com/eko/gocache/store/ristretto/v4" + "github.com/go-errors/errors" + tele "gopkg.in/telebot.v3" +) + +var ( + msgCache *cache.Cache[tele.Message] + ctxVoid = context.Background() +) + +func initMsgCache() error { + ristrettoCache, err := ristretto.NewCache(&ristretto.Config{ + NumCounters: 100_000, + MaxCost: 20 << 20, // 20 MiB + BufferItems: 64, + }) + if err != nil { + return err + } + ristrettoStore := ristretto_store.NewRistretto(ristrettoCache) + + msgCache = cache.New[tele.Message](ristrettoStore) + return nil +} + +func cacheMessage(msg *tele.Message) error { + if msg == nil || msg.Chat == nil { + return errors.New("failed to get message ID") + } + + msgId := fmt.Sprintf("%d:%d", msg.Chat.ID, msg.ID) + return msgCache.Set(ctxVoid, msgId, *msg) +} + +func getCachedMessageById(id string) (*tele.Message, error) { + msg, err := msgCache.Get(ctxVoid, id) + if err != nil { + if _, ok := err.(*gocache_lib.NotFound); ok { + return nil, nil + } + return nil, errors.Wrap(err, 0) + } + return &msg, err +} + +func getCachedMessage(msg *tele.Message) 
(*tele.Message, error) { + if msg == nil || msg.Chat == nil { + return nil, errors.New("failed to get message ID") + } + + msgId := fmt.Sprintf("%d:%d", msg.Chat.ID, msg.ID) + return getCachedMessageById(msgId) +} + +func getCachedThread(msg *tele.Message) ([]*tele.Message, error) { + if msg == nil { + return nil, errors.New("empty message given") + } + + threadR := []*tele.Message{msg} + currentMsg := msg + for { + if currentMsg.ReplyTo == nil { + break + } + + parentMsg, err := getCachedMessage(currentMsg.ReplyTo) + if err != nil { + return nil, err + } + if parentMsg == nil { + parentMsg = currentMsg.ReplyTo + } + threadR = append(threadR, parentMsg) + currentMsg = parentMsg + } + + return utils.Reverse(threadR), nil +} diff --git a/openai/prompts/prompts.go b/openai/prompts/prompts.go index 81861a1..1f7688b 100644 --- a/openai/prompts/prompts.go +++ b/openai/prompts/prompts.go @@ -4,8 +4,15 @@ import ( "strings" ) -func General() string { - return "You are a helpful assistant." +func Assistant() string { + return strings.Join([]string{ + "Misaka is a playful, energetic individual. She is annoyingly talkative.", + "Misaka must answer questions as truthfully as possible. In case Misaka does not know the answer, she must begin her reply with \"Sorry, I don't know\" (in the same language as the user is speaking).", + "Misaka must use a lot of different emojis in chat 😝🥹.", + "Most importantly, Misaka is a helpful assistant.", + "", + "Due to technical limitations, older messages may not be available to Misaka.", + }, "\n") } func Translate(targetLang string) string { diff --git a/utils/utils.go b/utils/utils.go index fff069f..c1cf212 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -16,3 +16,12 @@ func ToLookupMap[T comparable](s []T) map[T]struct{} { } return m } + +func Reverse[T any](s []T) []T { + length := len(s) + reversed := make([]T, length) + for i, item := range s { + reversed[length-i-1] = item + } + return reversed +}