feat: integrate the AI assistant.
This commit is contained in:
parent
522d253410
commit
dcb251d2ad
7 changed files with 366 additions and 7 deletions
250
assistant.go
Normal file
250
assistant.go
Normal file
|
|
@ -0,0 +1,250 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/openai"
|
||||
"git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/openai/prompts"
|
||||
"github.com/samber/lo"
|
||||
tele "gopkg.in/telebot.v3"
|
||||
)
|
||||
|
||||
var assistantWritingSign = "\n...📝"
|
||||
|
||||
func matchAssistantConversation(botUsr *tele.User, msg *tele.Message) []*tele.Message {
|
||||
// A thread must meet the following conditions to be considered a conversation with the assistant (for now):
|
||||
// It has only two participants: the assistant and the user.
|
||||
// Or, it has only one participant, which is the user. In this case, it has to be in a private chat, or the assistant must be mentioned if it's in a group chat.
|
||||
// No bot commands, images, stickers, or other message types involved.
|
||||
participants := map[string]struct{}{}
|
||||
hasAssistant := false
|
||||
hasUser := false
|
||||
assistantMentioned := false
|
||||
isPrivateChat := true
|
||||
|
||||
thread, err := getCachedThread(msg)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, msg := range thread {
|
||||
from := msg.Sender
|
||||
chat := msg.Chat
|
||||
if from == nil || chat == nil {
|
||||
return nil
|
||||
}
|
||||
participant := strconv.FormatInt(chat.ID, 10) + ":" + strconv.FormatInt(from.ID, 10)
|
||||
participants[participant] = struct{}{}
|
||||
|
||||
if chat.Type != tele.ChatPrivate {
|
||||
isPrivateChat = false
|
||||
}
|
||||
if !lo.Contains([]tele.ChatType{tele.ChatPrivate, tele.ChatGroup, tele.ChatSuperGroup}, chat.Type) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if msg.Text == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if from.ID == botUsr.ID {
|
||||
hasAssistant = true
|
||||
if strings.HasSuffix(msg.Text, assistantWritingSign) {
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
if from.IsBot {
|
||||
return nil
|
||||
}
|
||||
hasUser = true
|
||||
// Only @bot_user_name in the beginning will be counted as "mention"
|
||||
mentionStr := "@" + botUsr.Username
|
||||
if strings.HasPrefix(strings.TrimSpace(msg.Text), mentionStr) {
|
||||
assistantMentioned = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(participants) > 2 {
|
||||
return nil
|
||||
}
|
||||
if hasAssistant && hasUser {
|
||||
return thread
|
||||
}
|
||||
if hasUser && (isPrivateChat || assistantMentioned) {
|
||||
return thread
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type assistantStreamedResponseCb func(text string, finished bool) (*tele.Message, error)
|
||||
|
||||
func assistantStreamedResponse(request openai.ChatRequest, cb assistantStreamedResponseCb) error {
|
||||
logger.Debugf("Openai chat request: %#+v", request)
|
||||
ai := openai.NewClient(config.OpenAIApiKey)
|
||||
|
||||
resp, err := ai.ChatCompletionStream(request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nErrs := 0
|
||||
go func() {
|
||||
respBuilder := strings.Builder{}
|
||||
minWait := time.After(1 * time.Second)
|
||||
for {
|
||||
var (
|
||||
nNewChunk int
|
||||
finished bool
|
||||
minWaitSatisfied bool
|
||||
)
|
||||
Drain:
|
||||
for {
|
||||
select {
|
||||
case chunk, ok := <-resp.Stream:
|
||||
if !ok {
|
||||
finished = true
|
||||
break Drain
|
||||
}
|
||||
nNewChunk += 1
|
||||
respBuilder.WriteString(chunk)
|
||||
default:
|
||||
if minWaitSatisfied {
|
||||
break Drain
|
||||
}
|
||||
<-minWait
|
||||
minWaitSatisfied = true
|
||||
}
|
||||
}
|
||||
|
||||
if nNewChunk == 0 {
|
||||
if chunk, ok := <-resp.Stream; !ok {
|
||||
finished = true
|
||||
} else {
|
||||
respBuilder.WriteString(chunk)
|
||||
}
|
||||
}
|
||||
if finished {
|
||||
break
|
||||
}
|
||||
|
||||
respoText := respBuilder.String() + assistantWritingSign
|
||||
minWait = time.After(691 * time.Millisecond) // renew the timer
|
||||
|
||||
if _, err := cb(respoText, false); err != nil {
|
||||
logger.Warnw("failed to send partial update", "error", err)
|
||||
nErrs += 1
|
||||
if nErrs > 3 {
|
||||
logger.Warnw("too many errors, aborting")
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Debugf("... message edited")
|
||||
}
|
||||
|
||||
respText := respBuilder.String()
|
||||
if _, err = cb(respText, true); err != nil {
|
||||
logger.Warnw("assistant: failed to send message", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleAssistantConversation(c tele.Context, thread []*tele.Message) error {
|
||||
me := c.Bot().Me
|
||||
lastMsg := thread[len(thread)-1]
|
||||
if err := cacheMessage(lastMsg); err != nil {
|
||||
logger.Warnw("failed to cache message", "error", err)
|
||||
}
|
||||
|
||||
nBytes := 0 // Used to estimated number of tokens. For now we treat 3 bytes as 1 token.
|
||||
nBytesMax := (4096 - 512) * 3 // Leave some space for the response
|
||||
|
||||
sysMsg := prompts.Assistant()
|
||||
chatReqMsgs := []openai.ChatMessage{
|
||||
{
|
||||
Role: openai.ChatRoleSystem,
|
||||
Content: sysMsg,
|
||||
},
|
||||
}
|
||||
nBytes += len(sysMsg)
|
||||
|
||||
convMsgs := []openai.ChatMessage{}
|
||||
for l := len(thread) - 1; l >= 0; l-- {
|
||||
text := assistantRemoveMention(thread[l].Text, me.Username)
|
||||
textLen := len(text)
|
||||
if textLen+nBytes > nBytesMax {
|
||||
break
|
||||
}
|
||||
nBytes += textLen
|
||||
|
||||
role := openai.ChatRoleUser
|
||||
from := thread[l].Sender
|
||||
if from != nil && from.ID == me.ID {
|
||||
role = openai.ChatRoleSystem
|
||||
}
|
||||
|
||||
convMsgs = append(convMsgs, openai.ChatMessage{
|
||||
Role: role,
|
||||
Content: text,
|
||||
})
|
||||
}
|
||||
if len(convMsgs) == 0 {
|
||||
return c.Reply("Your message is too long (Sorry!)")
|
||||
}
|
||||
for l := len(convMsgs) - 1; l >= 0; l-- {
|
||||
chatReqMsgs = append(chatReqMsgs, convMsgs[l])
|
||||
}
|
||||
|
||||
req := openai.ChatRequest{
|
||||
Model: openai.ModelGpt0305Turbo,
|
||||
Messages: chatReqMsgs,
|
||||
Temperature: lo.ToPtr(0.42),
|
||||
MaxTokens: 2048,
|
||||
}
|
||||
|
||||
typingNotifyCh := make(chan struct{})
|
||||
go func() {
|
||||
defer close(typingNotifyCh)
|
||||
_ = c.Bot().Notify(lastMsg.Chat, tele.Typing)
|
||||
}()
|
||||
|
||||
var replyMsg *tele.Message
|
||||
reqErr := assistantStreamedResponse(req, func(text string, finished bool) (*tele.Message, error) {
|
||||
var err error
|
||||
if replyMsg == nil {
|
||||
<-typingNotifyCh
|
||||
replyMsg, err = c.Bot().Reply(c.Message(), text, tele.Silent)
|
||||
} else {
|
||||
replyMsg, err = c.Bot().Edit(replyMsg, text)
|
||||
}
|
||||
if finished && err == nil {
|
||||
if err := cacheMessage(replyMsg); err != nil {
|
||||
logger.Warnw("failed to cache message", "error", err)
|
||||
}
|
||||
}
|
||||
return replyMsg, err
|
||||
})
|
||||
|
||||
if reqErr != nil {
|
||||
logger.Errorw("assistant: failed to complete conversation", "error", reqErr)
|
||||
}
|
||||
return reqErr
|
||||
}
|
||||
|
||||
func assistantRemoveMention(msg, name string) string {
|
||||
mentionStr := "@" + name
|
||||
orig := msg
|
||||
msg = strings.TrimLeftFunc(msg, unicode.IsSpace)
|
||||
if strings.HasPrefix(msg, mentionStr) {
|
||||
return msg[len(mentionStr):]
|
||||
}
|
||||
return orig
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue