2023-03-20 19:12:41 +09:00
package main
import (
2023-03-21 14:35:25 +09:00
"crypto/md5"
"encoding/base64"
2023-03-20 19:12:41 +09:00
"strconv"
"strings"
"time"
"unicode"
"git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/openai"
"git.gensokyo.cafe/kkyy/tgbot_misaka_5882f7/openai/prompts"
"github.com/samber/lo"
tele "gopkg.in/telebot.v3"
)
2023-03-20 22:45:59 +09:00
// assistantWritingSign is the suffix appended to partial (still streaming)
// assistant replies. A bot message that still ends with it is treated as
// unfinished, and a thread containing such a message is not matched as an
// assistant conversation.
var assistantWritingSign = "\n... 📝"
2023-03-20 19:12:41 +09:00
func matchAssistantConversation ( botUsr * tele . User , msg * tele . Message ) [ ] * tele . Message {
// A thread must meet the following conditions to be considered a conversation with the assistant (for now):
// It has only two participants: the assistant and the user.
// Or, it has only one participant, which is the user. In this case, it has to be in a private chat, or the assistant must be mentioned if it's in a group chat.
// No bot commands, images, stickers, or other message types involved.
participants := map [ string ] struct { } { }
hasAssistant := false
hasUser := false
assistantMentioned := false
isPrivateChat := true
thread , err := getCachedThread ( msg )
if err != nil {
return nil
}
for _ , msg := range thread {
2023-03-20 20:49:59 +09:00
for _ , ent := range msg . Entities {
if ent . Type == tele . EntityCommand {
return nil
}
}
2023-03-20 19:12:41 +09:00
from := msg . Sender
chat := msg . Chat
if from == nil || chat == nil {
return nil
}
participant := strconv . FormatInt ( chat . ID , 10 ) + ":" + strconv . FormatInt ( from . ID , 10 )
participants [ participant ] = struct { } { }
if chat . Type != tele . ChatPrivate {
isPrivateChat = false
}
if ! lo . Contains ( [ ] tele . ChatType { tele . ChatPrivate , tele . ChatGroup , tele . ChatSuperGroup } , chat . Type ) {
return nil
}
if msg . Text == "" {
return nil
}
if from . ID == botUsr . ID {
hasAssistant = true
if strings . HasSuffix ( msg . Text , assistantWritingSign ) {
return nil
}
} else {
if from . IsBot {
return nil
}
hasUser = true
// Only @bot_user_name in the beginning will be counted as "mention"
mentionStr := "@" + botUsr . Username
if strings . HasPrefix ( strings . TrimSpace ( msg . Text ) , mentionStr ) {
assistantMentioned = true
}
}
}
if len ( participants ) > 2 {
return nil
}
if hasAssistant && hasUser {
return thread
}
if hasUser && ( isPrivateChat || assistantMentioned ) {
return thread
}
return nil
}
2023-03-20 22:19:53 +09:00
// assistantStreamedResponseCb receives the response text accumulated so far.
// finished is false for partial updates (the text then ends with
// assistantWritingSign) and true for the final call, which carries the
// complete response. A returned error is logged by the caller of the
// callback; repeated errors on partial updates abort the streaming.
type assistantStreamedResponseCb func ( text string , finished bool ) error
2023-03-20 19:12:41 +09:00
2023-04-21 16:43:05 +09:00
// TODO: support interrupting an in-flight streamed response via context.Context.
2023-03-20 19:12:41 +09:00
func assistantStreamedResponse ( request openai . ChatRequest , cb assistantStreamedResponseCb ) error {
2023-03-20 20:49:59 +09:00
logger . Debugw ( "Openai chat request" , "req" , request )
2023-03-20 19:12:41 +09:00
ai := openai . NewClient ( config . OpenAIApiKey )
2023-03-20 20:50:24 +09:00
var (
resp * openai . ChatResponseStream
err error
)
tries := 2
for {
if tries -- ; tries < 0 {
return err
}
if resp , err = ai . ChatCompletionStream ( request ) ; err == nil {
break
}
logger . Warnw ( "assistant: failed to get response" , "error" , err )
2023-03-20 19:12:41 +09:00
}
2023-03-21 03:03:02 +09:00
var nErrs , nUpdates int
2023-03-20 19:12:41 +09:00
go func ( ) {
respBuilder := strings . Builder { }
2023-03-20 22:19:53 +09:00
minWait := time . After ( time . Second )
2023-03-20 19:12:41 +09:00
for {
2023-03-21 03:03:02 +09:00
nUpdates += 1
2023-03-20 19:12:41 +09:00
var (
nNewChunk int
finished bool
minWaitSatisfied bool
)
Drain :
for {
select {
case chunk , ok := <- resp . Stream :
if ! ok {
finished = true
break Drain
}
nNewChunk += 1
respBuilder . WriteString ( chunk )
default :
if minWaitSatisfied {
break Drain
}
2023-03-21 13:39:49 +09:00
select {
case <- minWait :
case <- resp . Done :
}
2023-03-20 19:12:41 +09:00
minWaitSatisfied = true
}
}
if nNewChunk == 0 {
if chunk , ok := <- resp . Stream ; ! ok {
finished = true
} else {
respBuilder . WriteString ( chunk )
}
}
if finished {
break
}
2023-03-21 03:03:02 +09:00
// Renew the timer.
// The Telegram API rate limit for group messages is 20 per minute. So we cannot update messages too frequently.
minWaitDurSecs := lo . Min ( [ ] int { nUpdates , 4 } ) + nErrs * 3
minWait = time . After ( time . Duration ( minWaitDurSecs ) * time . Second )
2023-03-20 19:12:41 +09:00
2023-03-21 13:39:49 +09:00
// Send the partial message
2023-03-21 03:03:02 +09:00
respText := respBuilder . String ( ) + assistantWritingSign
if err := cb ( respText , false ) ; err != nil {
2023-03-20 19:12:41 +09:00
logger . Warnw ( "failed to send partial update" , "error" , err )
nErrs += 1
2023-03-21 03:03:02 +09:00
if nErrs > 5 {
2023-03-20 19:12:41 +09:00
logger . Warnw ( "too many errors, aborting" )
return
}
continue
}
logger . Debugf ( "... message edited" )
}
respText := respBuilder . String ( )
2023-03-20 22:19:53 +09:00
if err = cb ( respText , true ) ; err != nil {
2023-03-20 19:12:41 +09:00
logger . Warnw ( "assistant: failed to send message" , "error" , err )
}
} ( )
return nil
}
func handleAssistantConversation ( c tele . Context , thread [ ] * tele . Message ) error {
me := c . Bot ( ) . Me
lastMsg := thread [ len ( thread ) - 1 ]
if err := cacheMessage ( lastMsg ) ; err != nil {
logger . Warnw ( "failed to cache message" , "error" , err )
}
nBytes := 0 // Used to estimated number of tokens. For now we treat 3 bytes as 1 token.
nBytesMax := ( 4096 - 512 ) * 3 // Leave some space for the response
sysMsg := prompts . Assistant ( )
chatReqMsgs := [ ] openai . ChatMessage {
{
Role : openai . ChatRoleSystem ,
Content : sysMsg ,
} ,
}
nBytes += len ( sysMsg )
convMsgs := [ ] openai . ChatMessage { }
for l := len ( thread ) - 1 ; l >= 0 ; l -- {
text := assistantRemoveMention ( thread [ l ] . Text , me . Username )
textLen := len ( text )
if textLen + nBytes > nBytesMax {
break
}
nBytes += textLen
role := openai . ChatRoleUser
from := thread [ l ] . Sender
if from != nil && from . ID == me . ID {
2023-03-20 20:49:59 +09:00
role = openai . ChatRoleAssistant
2023-03-20 19:12:41 +09:00
}
convMsgs = append ( convMsgs , openai . ChatMessage {
Role : role ,
Content : text ,
} )
}
if len ( convMsgs ) == 0 {
2023-03-20 22:19:53 +09:00
// It turns out that this will never happen because Telegram splits messages when they exceed a certain length.
2023-03-20 19:12:41 +09:00
return c . Reply ( "Your message is too long (Sorry!)" )
}
for l := len ( convMsgs ) - 1 ; l >= 0 ; l -- {
chatReqMsgs = append ( chatReqMsgs , convMsgs [ l ] )
}
req := openai . ChatRequest {
Model : openai . ModelGpt0305Turbo ,
Messages : chatReqMsgs ,
Temperature : lo . ToPtr ( 0.42 ) ,
MaxTokens : 2048 ,
2023-03-21 14:35:25 +09:00
User : assistantHashUserId ( lastMsg . Sender . ID ) ,
2023-03-20 19:12:41 +09:00
}
2023-03-20 22:45:59 +09:00
typingNotifyCh := setTyping ( c )
2023-03-20 19:12:41 +09:00
var replyMsg * tele . Message
2023-03-20 22:19:53 +09:00
reqErr := assistantStreamedResponse ( req , func ( text string , finished bool ) error {
2023-03-20 19:12:41 +09:00
var err error
if replyMsg == nil {
<- typingNotifyCh
2023-03-20 20:49:59 +09:00
replyMsg , err = c . Bot ( ) . Reply ( lastMsg , text , tele . Silent )
2023-03-20 19:12:41 +09:00
} else {
replyMsg , err = c . Bot ( ) . Edit ( replyMsg , text )
}
if finished && err == nil {
2023-03-20 20:49:59 +09:00
replyMsg . ReplyTo = lastMsg // nasty bug
2023-03-20 19:12:41 +09:00
if err := cacheMessage ( replyMsg ) ; err != nil {
logger . Warnw ( "failed to cache message" , "error" , err )
}
}
2023-03-20 22:19:53 +09:00
return err
2023-03-20 19:12:41 +09:00
} )
if reqErr != nil {
logger . Errorw ( "assistant: failed to complete conversation" , "error" , reqErr )
2023-03-20 20:50:24 +09:00
return c . Reply ( "Sorry, there's a technical issue. 😵💫 Please try again later." , tele . Silent )
2023-03-20 19:12:41 +09:00
}
2023-03-20 20:50:24 +09:00
return nil
2023-03-20 19:12:41 +09:00
}
// assistantRemoveMention strips a leading "@name" mention from msg.
//
// Leading whitespace before the mention is ignored. When msg starts with the
// mention, the text following it is returned (including any whitespace right
// after the mention). Otherwise msg is returned unchanged.
//
// The mention must end at a username boundary: "@bot2 hi" is not a mention of
// "bot", because letters, digits and underscores directly after the name make
// it a different username. An empty name never matches.
func assistantRemoveMention(msg, name string) string {
	if name == "" {
		return msg
	}

	mention := "@" + name
	trimmed := strings.TrimLeftFunc(msg, unicode.IsSpace)
	if !strings.HasPrefix(trimmed, mention) {
		return msg
	}

	rest := trimmed[len(mention):]
	if rest != "" {
		// Usernames consist of ASCII letters, digits and underscores, so
		// inspecting a single byte is enough to detect a longer username.
		switch c := rest[0]; {
		case c == '_',
			'0' <= c && c <= '9',
			'a' <= c && c <= 'z',
			'A' <= c && c <= 'Z':
			return msg
		}
	}
	return rest
}
2023-03-21 14:35:25 +09:00
// assistantHashUserId derives an opaque 22-character identifier from a user
// ID, used as the User field of chat requests so that the raw ID is not sent.
//
// The decimal user ID is appended to a fixed secret string, hashed with MD5
// and encoded as unpadded URL-safe base64 (16 bytes encode to 22 characters).
func assistantHashUserId(uid int64) string {
	const pepper = "RdnuRPqp66vtbc28QRO0ecKSLKXifz7G9UbXLoyCMpw"

	// MD5 is used only to obfuscate the ID here (original note: "Don't judge me").
	digest := md5.Sum([]byte(pepper + strconv.FormatInt(uid, 10)))
	return base64.RawURLEncoding.EncodeToString(digest[:])
}