feat: integrate the AI assistant.
This commit is contained in:
		
							parent
							
								
									522d253410
								
							
						
					
					
						commit
						dcb251d2ad
					
				| 
						 | 
				
			
			@ -0,0 +1,250 @@
 | 
			
		|||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
	"unicode"
 | 
			
		||||
 | 
			
		||||
	"git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/openai"
 | 
			
		||||
	"git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/openai/prompts"
 | 
			
		||||
	"github.com/samber/lo"
 | 
			
		||||
	tele "gopkg.in/telebot.v3"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var assistantWritingSign = "\n...📝"
 | 
			
		||||
 | 
			
		||||
func matchAssistantConversation(botUsr *tele.User, msg *tele.Message) []*tele.Message {
 | 
			
		||||
	// A thread must meet the following conditions to be considered a conversation with the assistant (for now):
 | 
			
		||||
	// It has only two participants: the assistant and the user.
 | 
			
		||||
	// Or, it has only one participant, which is the user. In this case, it has to be in a private chat, or the assistant must be mentioned if it's in a group chat.
 | 
			
		||||
	// No bot commands, images, stickers, or other message types involved.
 | 
			
		||||
	participants := map[string]struct{}{}
 | 
			
		||||
	hasAssistant := false
 | 
			
		||||
	hasUser := false
 | 
			
		||||
	assistantMentioned := false
 | 
			
		||||
	isPrivateChat := true
 | 
			
		||||
 | 
			
		||||
	thread, err := getCachedThread(msg)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, msg := range thread {
 | 
			
		||||
		from := msg.Sender
 | 
			
		||||
		chat := msg.Chat
 | 
			
		||||
		if from == nil || chat == nil {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
		participant := strconv.FormatInt(chat.ID, 10) + ":" + strconv.FormatInt(from.ID, 10)
 | 
			
		||||
		participants[participant] = struct{}{}
 | 
			
		||||
 | 
			
		||||
		if chat.Type != tele.ChatPrivate {
 | 
			
		||||
			isPrivateChat = false
 | 
			
		||||
		}
 | 
			
		||||
		if !lo.Contains([]tele.ChatType{tele.ChatPrivate, tele.ChatGroup, tele.ChatSuperGroup}, chat.Type) {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if msg.Text == "" {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if from.ID == botUsr.ID {
 | 
			
		||||
			hasAssistant = true
 | 
			
		||||
			if strings.HasSuffix(msg.Text, assistantWritingSign) {
 | 
			
		||||
				return nil
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			if from.IsBot {
 | 
			
		||||
				return nil
 | 
			
		||||
			}
 | 
			
		||||
			hasUser = true
 | 
			
		||||
			// Only @bot_user_name in the beginning will be counted as "mention"
 | 
			
		||||
			mentionStr := "@" + botUsr.Username
 | 
			
		||||
			if strings.HasPrefix(strings.TrimSpace(msg.Text), mentionStr) {
 | 
			
		||||
				assistantMentioned = true
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(participants) > 2 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	if hasAssistant && hasUser {
 | 
			
		||||
		return thread
 | 
			
		||||
	}
 | 
			
		||||
	if hasUser && (isPrivateChat || assistantMentioned) {
 | 
			
		||||
		return thread
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type assistantStreamedResponseCb func(text string, finished bool) (*tele.Message, error)
 | 
			
		||||
 | 
			
		||||
func assistantStreamedResponse(request openai.ChatRequest, cb assistantStreamedResponseCb) error {
 | 
			
		||||
	logger.Debugf("Openai chat request: %#+v", request)
 | 
			
		||||
	ai := openai.NewClient(config.OpenAIApiKey)
 | 
			
		||||
 | 
			
		||||
	resp, err := ai.ChatCompletionStream(request)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	nErrs := 0
 | 
			
		||||
	go func() {
 | 
			
		||||
		respBuilder := strings.Builder{}
 | 
			
		||||
		minWait := time.After(1 * time.Second)
 | 
			
		||||
		for {
 | 
			
		||||
			var (
 | 
			
		||||
				nNewChunk        int
 | 
			
		||||
				finished         bool
 | 
			
		||||
				minWaitSatisfied bool
 | 
			
		||||
			)
 | 
			
		||||
		Drain:
 | 
			
		||||
			for {
 | 
			
		||||
				select {
 | 
			
		||||
				case chunk, ok := <-resp.Stream:
 | 
			
		||||
					if !ok {
 | 
			
		||||
						finished = true
 | 
			
		||||
						break Drain
 | 
			
		||||
					}
 | 
			
		||||
					nNewChunk += 1
 | 
			
		||||
					respBuilder.WriteString(chunk)
 | 
			
		||||
				default:
 | 
			
		||||
					if minWaitSatisfied {
 | 
			
		||||
						break Drain
 | 
			
		||||
					}
 | 
			
		||||
					<-minWait
 | 
			
		||||
					minWaitSatisfied = true
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if nNewChunk == 0 {
 | 
			
		||||
				if chunk, ok := <-resp.Stream; !ok {
 | 
			
		||||
					finished = true
 | 
			
		||||
				} else {
 | 
			
		||||
					respBuilder.WriteString(chunk)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if finished {
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			respoText := respBuilder.String() + assistantWritingSign
 | 
			
		||||
			minWait = time.After(691 * time.Millisecond) // renew the timer
 | 
			
		||||
 | 
			
		||||
			if _, err := cb(respoText, false); err != nil {
 | 
			
		||||
				logger.Warnw("failed to send partial update", "error", err)
 | 
			
		||||
				nErrs += 1
 | 
			
		||||
				if nErrs > 3 {
 | 
			
		||||
					logger.Warnw("too many errors, aborting")
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			logger.Debugf("... message edited")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		respText := respBuilder.String()
 | 
			
		||||
		if _, err = cb(respText, true); err != nil {
 | 
			
		||||
			logger.Warnw("assistant: failed to send message", "error", err)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func handleAssistantConversation(c tele.Context, thread []*tele.Message) error {
 | 
			
		||||
	me := c.Bot().Me
 | 
			
		||||
	lastMsg := thread[len(thread)-1]
 | 
			
		||||
	if err := cacheMessage(lastMsg); err != nil {
 | 
			
		||||
		logger.Warnw("failed to cache message", "error", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	nBytes := 0                   // Used to estimated number of tokens. For now we treat 3 bytes as 1 token.
 | 
			
		||||
	nBytesMax := (4096 - 512) * 3 // Leave some space for the response
 | 
			
		||||
 | 
			
		||||
	sysMsg := prompts.Assistant()
 | 
			
		||||
	chatReqMsgs := []openai.ChatMessage{
 | 
			
		||||
		{
 | 
			
		||||
			Role:    openai.ChatRoleSystem,
 | 
			
		||||
			Content: sysMsg,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	nBytes += len(sysMsg)
 | 
			
		||||
 | 
			
		||||
	convMsgs := []openai.ChatMessage{}
 | 
			
		||||
	for l := len(thread) - 1; l >= 0; l-- {
 | 
			
		||||
		text := assistantRemoveMention(thread[l].Text, me.Username)
 | 
			
		||||
		textLen := len(text)
 | 
			
		||||
		if textLen+nBytes > nBytesMax {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		nBytes += textLen
 | 
			
		||||
 | 
			
		||||
		role := openai.ChatRoleUser
 | 
			
		||||
		from := thread[l].Sender
 | 
			
		||||
		if from != nil && from.ID == me.ID {
 | 
			
		||||
			role = openai.ChatRoleSystem
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		convMsgs = append(convMsgs, openai.ChatMessage{
 | 
			
		||||
			Role:    role,
 | 
			
		||||
			Content: text,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	if len(convMsgs) == 0 {
 | 
			
		||||
		return c.Reply("Your message is too long (Sorry!)")
 | 
			
		||||
	}
 | 
			
		||||
	for l := len(convMsgs) - 1; l >= 0; l-- {
 | 
			
		||||
		chatReqMsgs = append(chatReqMsgs, convMsgs[l])
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	req := openai.ChatRequest{
 | 
			
		||||
		Model:       openai.ModelGpt0305Turbo,
 | 
			
		||||
		Messages:    chatReqMsgs,
 | 
			
		||||
		Temperature: lo.ToPtr(0.42),
 | 
			
		||||
		MaxTokens:   2048,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	typingNotifyCh := make(chan struct{})
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer close(typingNotifyCh)
 | 
			
		||||
		_ = c.Bot().Notify(lastMsg.Chat, tele.Typing)
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	var replyMsg *tele.Message
 | 
			
		||||
	reqErr := assistantStreamedResponse(req, func(text string, finished bool) (*tele.Message, error) {
 | 
			
		||||
		var err error
 | 
			
		||||
		if replyMsg == nil {
 | 
			
		||||
			<-typingNotifyCh
 | 
			
		||||
			replyMsg, err = c.Bot().Reply(c.Message(), text, tele.Silent)
 | 
			
		||||
		} else {
 | 
			
		||||
			replyMsg, err = c.Bot().Edit(replyMsg, text)
 | 
			
		||||
		}
 | 
			
		||||
		if finished && err == nil {
 | 
			
		||||
			if err := cacheMessage(replyMsg); err != nil {
 | 
			
		||||
				logger.Warnw("failed to cache message", "error", err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		return replyMsg, err
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if reqErr != nil {
 | 
			
		||||
		logger.Errorw("assistant: failed to complete conversation", "error", reqErr)
 | 
			
		||||
	}
 | 
			
		||||
	return reqErr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func assistantRemoveMention(msg, name string) string {
 | 
			
		||||
	mentionStr := "@" + name
 | 
			
		||||
	orig := msg
 | 
			
		||||
	msg = strings.TrimLeftFunc(msg, unicode.IsSpace)
 | 
			
		||||
	if strings.HasPrefix(msg, mentionStr) {
 | 
			
		||||
		return msg[len(mentionStr):]
 | 
			
		||||
	}
 | 
			
		||||
	return orig
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										9
									
								
								bot.go
								
								
								
								
							
							
						
						
									
										9
									
								
								bot.go
								
								
								
								
							| 
						 | 
				
			
			@ -55,8 +55,6 @@ func initBot() (*tele.Bot, error) {
 | 
			
		|||
 | 
			
		||||
	adminGrp.Handle("/dig", handleDigCmd)
 | 
			
		||||
 | 
			
		||||
	// adminGrp.Handle("/test", testCmd)
 | 
			
		||||
 | 
			
		||||
	return b, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -104,8 +102,11 @@ func drawBar(progress float64, length int) string {
 | 
			
		|||
	return string(buf)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func handleGeneralMessage(_ tele.Context) error {
 | 
			
		||||
	// Do nothing for now
 | 
			
		||||
func handleGeneralMessage(c tele.Context) error {
 | 
			
		||||
	if thread := matchAssistantConversation(c.Bot().Me, c.Message()); thread != nil {
 | 
			
		||||
		return handleAssistantConversation(c, thread)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -140,7 +140,7 @@ func handleTranslateBtn(c tele.Context) error {
 | 
			
		|||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		respoText := respBuilder.String() + "\n... (Writting)"
 | 
			
		||||
		respoText := respBuilder.String() + assistantWritingSign
 | 
			
		||||
		minWait = time.After(691 * time.Millisecond) // renew the timer
 | 
			
		||||
		if msg, err = c.Bot().Edit(msg, respoText, tele.Silent); err != nil {
 | 
			
		||||
			logger.Warnf("failed to edit the temporary message: %v", err)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										3
									
								
								main.go
								
								
								
								
							
							
						
						
									
										3
									
								
								main.go
								
								
								
								
							| 
						 | 
				
			
			@ -81,6 +81,9 @@ func main() {
 | 
			
		|||
	} else {
 | 
			
		||||
		loglvl.SetLevel(parsedLvl)
 | 
			
		||||
	}
 | 
			
		||||
	if err := initMsgCache(); err != nil {
 | 
			
		||||
		logger.Fatalw("Failed to initialize message cache", "err", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	runBot()
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,89 @@
 | 
			
		|||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/utils"
 | 
			
		||||
	"github.com/dgraph-io/ristretto"
 | 
			
		||||
	"github.com/eko/gocache/lib/v4/cache"
 | 
			
		||||
	gocache_lib "github.com/eko/gocache/lib/v4/store"
 | 
			
		||||
	ristretto_store "github.com/eko/gocache/store/ristretto/v4"
 | 
			
		||||
	"github.com/go-errors/errors"
 | 
			
		||||
	tele "gopkg.in/telebot.v3"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	msgCache *cache.Cache[tele.Message]
 | 
			
		||||
	ctxVoid  = context.Background()
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func initMsgCache() error {
 | 
			
		||||
	ristrettoCache, err := ristretto.NewCache(&ristretto.Config{
 | 
			
		||||
		NumCounters: 100_000,
 | 
			
		||||
		MaxCost:     20 << 20, // 20 MiB
 | 
			
		||||
		BufferItems: 64,
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	ristrettoStore := ristretto_store.NewRistretto(ristrettoCache)
 | 
			
		||||
 | 
			
		||||
	msgCache = cache.New[tele.Message](ristrettoStore)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func cacheMessage(msg *tele.Message) error {
 | 
			
		||||
	if msg == nil || msg.Chat == nil {
 | 
			
		||||
		return errors.New("failed to get message ID")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	msgId := fmt.Sprintf("%d:%d", msg.Chat.ID, msg.ID)
 | 
			
		||||
	return msgCache.Set(ctxVoid, msgId, *msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getCachedMessageById(id string) (*tele.Message, error) {
 | 
			
		||||
	msg, err := msgCache.Get(ctxVoid, id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if _, ok := err.(*gocache_lib.NotFound); ok {
 | 
			
		||||
			return nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		return nil, errors.Wrap(err, 0)
 | 
			
		||||
	}
 | 
			
		||||
	return &msg, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getCachedMessage(msg *tele.Message) (*tele.Message, error) {
 | 
			
		||||
	if msg == nil || msg.Chat == nil {
 | 
			
		||||
		return nil, errors.New("failed to get message ID")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	msgId := fmt.Sprintf("%d:%d", msg.Chat.ID, msg.ID)
 | 
			
		||||
	return getCachedMessageById(msgId)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getCachedThread(msg *tele.Message) ([]*tele.Message, error) {
 | 
			
		||||
	if msg == nil {
 | 
			
		||||
		return nil, errors.New("empty message given")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	threadR := []*tele.Message{msg}
 | 
			
		||||
	currentMsg := msg
 | 
			
		||||
	for {
 | 
			
		||||
		if currentMsg.ReplyTo == nil {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		parentMsg, err := getCachedMessage(currentMsg.ReplyTo)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		if parentMsg == nil {
 | 
			
		||||
			parentMsg = currentMsg.ReplyTo
 | 
			
		||||
		}
 | 
			
		||||
		threadR = append(threadR, parentMsg)
 | 
			
		||||
		currentMsg = parentMsg
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return utils.Reverse(threadR), nil
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -4,8 +4,15 @@ import (
 | 
			
		|||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func General() string {
 | 
			
		||||
	return "You are a helpful assistant."
 | 
			
		||||
func Assistant() string {
 | 
			
		||||
	return strings.Join([]string{
 | 
			
		||||
		"Misaka is a playful, energetic individual. She is annoyingly talkative.",
 | 
			
		||||
		"Misaka must answer questions as truthfully as possible. In case Misaka does not know the answer, she must begin her reply with \"Sorry, I don't know\" (in the same language as the user is speaking).",
 | 
			
		||||
		"Misaka must use a lot of different emojis in chat 😝🥹.",
 | 
			
		||||
		"Most importantly, Misaka is a helpful assistant.",
 | 
			
		||||
		"",
 | 
			
		||||
		"Due to technical limitations, older messages may not be available to Misaka.",
 | 
			
		||||
	}, "\n")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Translate(targetLang string) string {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -16,3 +16,12 @@ func ToLookupMap[T comparable](s []T) map[T]struct{} {
 | 
			
		|||
	}
 | 
			
		||||
	return m
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Reverse[T any](s []T) []T {
 | 
			
		||||
	length := len(s)
 | 
			
		||||
	reversed := make([]T, length)
 | 
			
		||||
	for i, item := range s {
 | 
			
		||||
		reversed[length-i-1] = item
 | 
			
		||||
	}
 | 
			
		||||
	return reversed
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue