feat: integrate the AI assistant.
This commit is contained in:
parent
522d253410
commit
dcb251d2ad
|
@ -0,0 +1,250 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
|
"git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/openai"
|
||||||
|
"git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/openai/prompts"
|
||||||
|
"github.com/samber/lo"
|
||||||
|
tele "gopkg.in/telebot.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
var assistantWritingSign = "\n...📝"
|
||||||
|
|
||||||
|
func matchAssistantConversation(botUsr *tele.User, msg *tele.Message) []*tele.Message {
|
||||||
|
// A thread must meet the following conditions to be considered a conversation with the assistant (for now):
|
||||||
|
// It has only two participants: the assistant and the user.
|
||||||
|
// Or, it has only one participant, which is the user. In this case, it has to be in a private chat, or the assistant must be mentioned if it's in a group chat.
|
||||||
|
// No bot commands, images, stickers, or other message types involved.
|
||||||
|
participants := map[string]struct{}{}
|
||||||
|
hasAssistant := false
|
||||||
|
hasUser := false
|
||||||
|
assistantMentioned := false
|
||||||
|
isPrivateChat := true
|
||||||
|
|
||||||
|
thread, err := getCachedThread(msg)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, msg := range thread {
|
||||||
|
from := msg.Sender
|
||||||
|
chat := msg.Chat
|
||||||
|
if from == nil || chat == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
participant := strconv.FormatInt(chat.ID, 10) + ":" + strconv.FormatInt(from.ID, 10)
|
||||||
|
participants[participant] = struct{}{}
|
||||||
|
|
||||||
|
if chat.Type != tele.ChatPrivate {
|
||||||
|
isPrivateChat = false
|
||||||
|
}
|
||||||
|
if !lo.Contains([]tele.ChatType{tele.ChatPrivate, tele.ChatGroup, tele.ChatSuperGroup}, chat.Type) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg.Text == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if from.ID == botUsr.ID {
|
||||||
|
hasAssistant = true
|
||||||
|
if strings.HasSuffix(msg.Text, assistantWritingSign) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if from.IsBot {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
hasUser = true
|
||||||
|
// Only @bot_user_name in the beginning will be counted as "mention"
|
||||||
|
mentionStr := "@" + botUsr.Username
|
||||||
|
if strings.HasPrefix(strings.TrimSpace(msg.Text), mentionStr) {
|
||||||
|
assistantMentioned = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(participants) > 2 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if hasAssistant && hasUser {
|
||||||
|
return thread
|
||||||
|
}
|
||||||
|
if hasUser && (isPrivateChat || assistantMentioned) {
|
||||||
|
return thread
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type assistantStreamedResponseCb func(text string, finished bool) (*tele.Message, error)
|
||||||
|
|
||||||
|
func assistantStreamedResponse(request openai.ChatRequest, cb assistantStreamedResponseCb) error {
|
||||||
|
logger.Debugf("Openai chat request: %#+v", request)
|
||||||
|
ai := openai.NewClient(config.OpenAIApiKey)
|
||||||
|
|
||||||
|
resp, err := ai.ChatCompletionStream(request)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
nErrs := 0
|
||||||
|
go func() {
|
||||||
|
respBuilder := strings.Builder{}
|
||||||
|
minWait := time.After(1 * time.Second)
|
||||||
|
for {
|
||||||
|
var (
|
||||||
|
nNewChunk int
|
||||||
|
finished bool
|
||||||
|
minWaitSatisfied bool
|
||||||
|
)
|
||||||
|
Drain:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case chunk, ok := <-resp.Stream:
|
||||||
|
if !ok {
|
||||||
|
finished = true
|
||||||
|
break Drain
|
||||||
|
}
|
||||||
|
nNewChunk += 1
|
||||||
|
respBuilder.WriteString(chunk)
|
||||||
|
default:
|
||||||
|
if minWaitSatisfied {
|
||||||
|
break Drain
|
||||||
|
}
|
||||||
|
<-minWait
|
||||||
|
minWaitSatisfied = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if nNewChunk == 0 {
|
||||||
|
if chunk, ok := <-resp.Stream; !ok {
|
||||||
|
finished = true
|
||||||
|
} else {
|
||||||
|
respBuilder.WriteString(chunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if finished {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
respoText := respBuilder.String() + assistantWritingSign
|
||||||
|
minWait = time.After(691 * time.Millisecond) // renew the timer
|
||||||
|
|
||||||
|
if _, err := cb(respoText, false); err != nil {
|
||||||
|
logger.Warnw("failed to send partial update", "error", err)
|
||||||
|
nErrs += 1
|
||||||
|
if nErrs > 3 {
|
||||||
|
logger.Warnw("too many errors, aborting")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debugf("... message edited")
|
||||||
|
}
|
||||||
|
|
||||||
|
respText := respBuilder.String()
|
||||||
|
if _, err = cb(respText, true); err != nil {
|
||||||
|
logger.Warnw("assistant: failed to send message", "error", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleAssistantConversation(c tele.Context, thread []*tele.Message) error {
|
||||||
|
me := c.Bot().Me
|
||||||
|
lastMsg := thread[len(thread)-1]
|
||||||
|
if err := cacheMessage(lastMsg); err != nil {
|
||||||
|
logger.Warnw("failed to cache message", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nBytes := 0 // Used to estimated number of tokens. For now we treat 3 bytes as 1 token.
|
||||||
|
nBytesMax := (4096 - 512) * 3 // Leave some space for the response
|
||||||
|
|
||||||
|
sysMsg := prompts.Assistant()
|
||||||
|
chatReqMsgs := []openai.ChatMessage{
|
||||||
|
{
|
||||||
|
Role: openai.ChatRoleSystem,
|
||||||
|
Content: sysMsg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
nBytes += len(sysMsg)
|
||||||
|
|
||||||
|
convMsgs := []openai.ChatMessage{}
|
||||||
|
for l := len(thread) - 1; l >= 0; l-- {
|
||||||
|
text := assistantRemoveMention(thread[l].Text, me.Username)
|
||||||
|
textLen := len(text)
|
||||||
|
if textLen+nBytes > nBytesMax {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
nBytes += textLen
|
||||||
|
|
||||||
|
role := openai.ChatRoleUser
|
||||||
|
from := thread[l].Sender
|
||||||
|
if from != nil && from.ID == me.ID {
|
||||||
|
role = openai.ChatRoleSystem
|
||||||
|
}
|
||||||
|
|
||||||
|
convMsgs = append(convMsgs, openai.ChatMessage{
|
||||||
|
Role: role,
|
||||||
|
Content: text,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if len(convMsgs) == 0 {
|
||||||
|
return c.Reply("Your message is too long (Sorry!)")
|
||||||
|
}
|
||||||
|
for l := len(convMsgs) - 1; l >= 0; l-- {
|
||||||
|
chatReqMsgs = append(chatReqMsgs, convMsgs[l])
|
||||||
|
}
|
||||||
|
|
||||||
|
req := openai.ChatRequest{
|
||||||
|
Model: openai.ModelGpt0305Turbo,
|
||||||
|
Messages: chatReqMsgs,
|
||||||
|
Temperature: lo.ToPtr(0.42),
|
||||||
|
MaxTokens: 2048,
|
||||||
|
}
|
||||||
|
|
||||||
|
typingNotifyCh := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(typingNotifyCh)
|
||||||
|
_ = c.Bot().Notify(lastMsg.Chat, tele.Typing)
|
||||||
|
}()
|
||||||
|
|
||||||
|
var replyMsg *tele.Message
|
||||||
|
reqErr := assistantStreamedResponse(req, func(text string, finished bool) (*tele.Message, error) {
|
||||||
|
var err error
|
||||||
|
if replyMsg == nil {
|
||||||
|
<-typingNotifyCh
|
||||||
|
replyMsg, err = c.Bot().Reply(c.Message(), text, tele.Silent)
|
||||||
|
} else {
|
||||||
|
replyMsg, err = c.Bot().Edit(replyMsg, text)
|
||||||
|
}
|
||||||
|
if finished && err == nil {
|
||||||
|
if err := cacheMessage(replyMsg); err != nil {
|
||||||
|
logger.Warnw("failed to cache message", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return replyMsg, err
|
||||||
|
})
|
||||||
|
|
||||||
|
if reqErr != nil {
|
||||||
|
logger.Errorw("assistant: failed to complete conversation", "error", reqErr)
|
||||||
|
}
|
||||||
|
return reqErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func assistantRemoveMention(msg, name string) string {
|
||||||
|
mentionStr := "@" + name
|
||||||
|
orig := msg
|
||||||
|
msg = strings.TrimLeftFunc(msg, unicode.IsSpace)
|
||||||
|
if strings.HasPrefix(msg, mentionStr) {
|
||||||
|
return msg[len(mentionStr):]
|
||||||
|
}
|
||||||
|
return orig
|
||||||
|
}
|
9
bot.go
9
bot.go
|
@ -55,8 +55,6 @@ func initBot() (*tele.Bot, error) {
|
||||||
|
|
||||||
adminGrp.Handle("/dig", handleDigCmd)
|
adminGrp.Handle("/dig", handleDigCmd)
|
||||||
|
|
||||||
// adminGrp.Handle("/test", testCmd)
|
|
||||||
|
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -104,8 +102,11 @@ func drawBar(progress float64, length int) string {
|
||||||
return string(buf)
|
return string(buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleGeneralMessage(_ tele.Context) error {
|
func handleGeneralMessage(c tele.Context) error {
|
||||||
// Do nothing for now
|
if thread := matchAssistantConversation(c.Bot().Me, c.Message()); thread != nil {
|
||||||
|
return handleAssistantConversation(c, thread)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -140,7 +140,7 @@ func handleTranslateBtn(c tele.Context) error {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
respoText := respBuilder.String() + "\n... (Writting)"
|
respoText := respBuilder.String() + assistantWritingSign
|
||||||
minWait = time.After(691 * time.Millisecond) // renew the timer
|
minWait = time.After(691 * time.Millisecond) // renew the timer
|
||||||
if msg, err = c.Bot().Edit(msg, respoText, tele.Silent); err != nil {
|
if msg, err = c.Bot().Edit(msg, respoText, tele.Silent); err != nil {
|
||||||
logger.Warnf("failed to edit the temporary message: %v", err)
|
logger.Warnf("failed to edit the temporary message: %v", err)
|
||||||
|
|
3
main.go
3
main.go
|
@ -81,6 +81,9 @@ func main() {
|
||||||
} else {
|
} else {
|
||||||
loglvl.SetLevel(parsedLvl)
|
loglvl.SetLevel(parsedLvl)
|
||||||
}
|
}
|
||||||
|
if err := initMsgCache(); err != nil {
|
||||||
|
logger.Fatalw("Failed to initialize message cache", "err", err)
|
||||||
|
}
|
||||||
|
|
||||||
runBot()
|
runBot()
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,89 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/utils"
|
||||||
|
"github.com/dgraph-io/ristretto"
|
||||||
|
"github.com/eko/gocache/lib/v4/cache"
|
||||||
|
gocache_lib "github.com/eko/gocache/lib/v4/store"
|
||||||
|
ristretto_store "github.com/eko/gocache/store/ristretto/v4"
|
||||||
|
"github.com/go-errors/errors"
|
||||||
|
tele "gopkg.in/telebot.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
msgCache *cache.Cache[tele.Message]
|
||||||
|
ctxVoid = context.Background()
|
||||||
|
)
|
||||||
|
|
||||||
|
func initMsgCache() error {
|
||||||
|
ristrettoCache, err := ristretto.NewCache(&ristretto.Config{
|
||||||
|
NumCounters: 100_000,
|
||||||
|
MaxCost: 20 << 20, // 20 MiB
|
||||||
|
BufferItems: 64,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ristrettoStore := ristretto_store.NewRistretto(ristrettoCache)
|
||||||
|
|
||||||
|
msgCache = cache.New[tele.Message](ristrettoStore)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cacheMessage(msg *tele.Message) error {
|
||||||
|
if msg == nil || msg.Chat == nil {
|
||||||
|
return errors.New("failed to get message ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
msgId := fmt.Sprintf("%d:%d", msg.Chat.ID, msg.ID)
|
||||||
|
return msgCache.Set(ctxVoid, msgId, *msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCachedMessageById(id string) (*tele.Message, error) {
|
||||||
|
msg, err := msgCache.Get(ctxVoid, id)
|
||||||
|
if err != nil {
|
||||||
|
if _, ok := err.(*gocache_lib.NotFound); ok {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, errors.Wrap(err, 0)
|
||||||
|
}
|
||||||
|
return &msg, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCachedMessage(msg *tele.Message) (*tele.Message, error) {
|
||||||
|
if msg == nil || msg.Chat == nil {
|
||||||
|
return nil, errors.New("failed to get message ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
msgId := fmt.Sprintf("%d:%d", msg.Chat.ID, msg.ID)
|
||||||
|
return getCachedMessageById(msgId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCachedThread(msg *tele.Message) ([]*tele.Message, error) {
|
||||||
|
if msg == nil {
|
||||||
|
return nil, errors.New("empty message given")
|
||||||
|
}
|
||||||
|
|
||||||
|
threadR := []*tele.Message{msg}
|
||||||
|
currentMsg := msg
|
||||||
|
for {
|
||||||
|
if currentMsg.ReplyTo == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
parentMsg, err := getCachedMessage(currentMsg.ReplyTo)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if parentMsg == nil {
|
||||||
|
parentMsg = currentMsg.ReplyTo
|
||||||
|
}
|
||||||
|
threadR = append(threadR, parentMsg)
|
||||||
|
currentMsg = parentMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
return utils.Reverse(threadR), nil
|
||||||
|
}
|
|
@ -4,8 +4,15 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func General() string {
|
func Assistant() string {
|
||||||
return "You are a helpful assistant."
|
return strings.Join([]string{
|
||||||
|
"Misaka is a playful, energetic individual. She is annoyingly talkative.",
|
||||||
|
"Misaka must answer questions as truthfully as possible. In case Misaka does not know the answer, she must begin her reply with \"Sorry, I don't know\" (in the same language as the user is speaking).",
|
||||||
|
"Misaka must use a lot of different emojis in chat 😝🥹.",
|
||||||
|
"Most importantly, Misaka is a helpful assistant.",
|
||||||
|
"",
|
||||||
|
"Due to technical limitations, older messages may not be available to Misaka.",
|
||||||
|
}, "\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
func Translate(targetLang string) string {
|
func Translate(targetLang string) string {
|
||||||
|
|
|
@ -16,3 +16,12 @@ func ToLookupMap[T comparable](s []T) map[T]struct{} {
|
||||||
}
|
}
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Reverse[T any](s []T) []T {
|
||||||
|
length := len(s)
|
||||||
|
reversed := make([]T, length)
|
||||||
|
for i, item := range s {
|
||||||
|
reversed[length-i-1] = item
|
||||||
|
}
|
||||||
|
return reversed
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue