feat(AI): streaming response
This commit is contained in:
parent
7b2d3c31e5
commit
bd5e8112a1
|
@ -3,6 +3,7 @@ package main
|
|||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/samber/lo"
|
||||
tele "gopkg.in/telebot.v3"
|
||||
|
@ -71,10 +72,6 @@ func handleTranslateBtn(c tele.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// change the temporary message
|
||||
if err := c.Edit("Sure, please wait..."); err != nil {
|
||||
logger.Warnf("failed to alter the temporary message: %v", err)
|
||||
}
|
||||
// pretend to be typing
|
||||
if err := c.Bot().Notify(msg.Chat, tele.Typing); err != nil {
|
||||
logger.Warnf("failed to send typing action: %v", err)
|
||||
|
@ -98,23 +95,65 @@ func handleTranslateBtn(c tele.Context) error {
|
|||
}
|
||||
logger.Debugf("Openai chat request: %#+v", req)
|
||||
|
||||
resp, err := ai.ChatCompletion(req)
|
||||
resp, err := ai.ChatCompletionStream(req)
|
||||
if err != nil {
|
||||
logger.Errorf("failed to translate: req: %#+v, err: %v", req, err)
|
||||
_, err := c.Bot().Reply(origMsg, stickerFromID(stickerPanic), tele.Silent)
|
||||
_, _ = c.Bot().Reply(origMsg, stickerFromID(stickerPanic), tele.Silent)
|
||||
return err
|
||||
}
|
||||
|
||||
respText := resp.Choices[0].Message.Content
|
||||
respBuilder := strings.Builder{}
|
||||
minWait := time.After(1 * time.Second)
|
||||
for {
|
||||
var (
|
||||
nNewChunk int
|
||||
finished bool
|
||||
minWaitSatisfied bool
|
||||
)
|
||||
Drain:
|
||||
for {
|
||||
select {
|
||||
case chunk, ok := <-resp.Stream:
|
||||
if !ok {
|
||||
finished = true
|
||||
break Drain
|
||||
}
|
||||
nNewChunk += 1
|
||||
respBuilder.WriteString(chunk)
|
||||
default:
|
||||
if minWaitSatisfied {
|
||||
break Drain
|
||||
}
|
||||
<-minWait
|
||||
minWaitSatisfied = true
|
||||
}
|
||||
}
|
||||
|
||||
if nNewChunk == 0 {
|
||||
if chunk, ok := <-resp.Stream; !ok {
|
||||
finished = true
|
||||
} else {
|
||||
respBuilder.WriteString(chunk)
|
||||
}
|
||||
}
|
||||
if finished {
|
||||
break
|
||||
}
|
||||
|
||||
respoText := respBuilder.String() + "\n... (Writting)"
|
||||
minWait = time.After(691 * time.Millisecond) // renew the timer
|
||||
if msg, err = c.Bot().Edit(msg, respoText, tele.Silent); err != nil {
|
||||
logger.Warnf("failed to edit the temporary message: %v", err)
|
||||
break
|
||||
}
|
||||
logger.Debugf("... message edited")
|
||||
}
|
||||
|
||||
respText := respBuilder.String()
|
||||
retryBtn := translateBtnRetry
|
||||
retryBtn.Data = targetLang
|
||||
respMenu := &tele.ReplyMarkup{}
|
||||
respMenu.Inline(respMenu.Row(retryBtn))
|
||||
_, err = c.Bot().Reply(origMsg, respText, tele.Silent, respMenu)
|
||||
|
||||
// delete the temporary message
|
||||
if err := c.Delete(); err != nil {
|
||||
logger.Warnf("failed to delete the temporary message: %v", err)
|
||||
}
|
||||
_, err = c.Bot().Edit(msg, respText, tele.Silent, respMenu)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -44,3 +44,28 @@ type ChatResponse struct {
|
|||
Usage map[string]int `json:"usage"`
|
||||
Choices []ChatResponseChoice `json:"choices"`
|
||||
}
|
||||
|
||||
type ChatResponseStream struct {
|
||||
ID string
|
||||
Object string
|
||||
Created int
|
||||
Model string
|
||||
Stream chan string
|
||||
Err error
|
||||
}
|
||||
|
||||
type ChatResponseStreamChunk struct {
|
||||
ID string
|
||||
Object string
|
||||
Created int
|
||||
Model string
|
||||
Choices []ChatResponseStreamChoice
|
||||
}
|
||||
|
||||
type ChatResponseStreamChoice struct {
|
||||
Index int `json:"index"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"delta"`
|
||||
}
|
||||
|
|
|
@ -1,6 +1,11 @@
|
|||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-errors/errors"
|
||||
|
@ -13,14 +18,21 @@ type Client struct {
|
|||
|
||||
func NewClient(apiKey string) *Client {
|
||||
cli := resty.New().
|
||||
SetTransport(&http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
}).
|
||||
SetBaseURL("https://api.openai.com").
|
||||
SetHeader("Authorization", "Bearer "+apiKey).
|
||||
SetTimeout(30 * time.Second)
|
||||
SetTimeout(5 * time.Minute) // hard cap
|
||||
|
||||
return &Client{rest: cli}
|
||||
}
|
||||
|
||||
func (c *Client) ChatCompletion(request ChatRequest) (*ChatResponse, error) {
|
||||
// Note: this function might not work due to the header timeout set on the http client.
|
||||
// We should probably not use this anyway.
|
||||
|
||||
resp, err := c.rest.R().
|
||||
SetBody(request).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
|
@ -37,3 +49,81 @@ func (c *Client) ChatCompletion(request ChatRequest) (*ChatResponse, error) {
|
|||
|
||||
return resp.Result().(*ChatResponse), nil
|
||||
}
|
||||
|
||||
func (c *Client) ChatCompletionStream(request ChatRequest) (*ChatResponseStream, error) {
|
||||
request.Stream = true
|
||||
resp, err := c.rest.R().
|
||||
SetBody(request).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetDoNotParseResponse(true).
|
||||
Post(ChatAPIPath)
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, 0)
|
||||
}
|
||||
rbody := resp.RawBody()
|
||||
if rbody == nil {
|
||||
return nil, errors.New("no response body")
|
||||
}
|
||||
|
||||
if resp.StatusCode() != 200 {
|
||||
defer rbody.Close()
|
||||
var respBodyStr string
|
||||
if respBody, err := io.ReadAll(rbody); err == nil {
|
||||
respBodyStr = string(respBody)
|
||||
} else {
|
||||
respBodyStr = "failed to read: " + err.Error()
|
||||
}
|
||||
return nil, errors.Errorf("status code: %d, body: %q", resp.StatusCode(), respBodyStr)
|
||||
}
|
||||
|
||||
ret := &ChatResponseStream{Stream: make(chan string, 1024)}
|
||||
go func() {
|
||||
defer func() {
|
||||
rbody.Close()
|
||||
close(ret.Stream)
|
||||
}()
|
||||
|
||||
var contentBegan bool
|
||||
sc := bufio.NewScanner(rbody)
|
||||
for sc.Scan() {
|
||||
line := sc.Bytes()
|
||||
if string(line[:6]) != "data: " {
|
||||
continue
|
||||
}
|
||||
data := line[6:]
|
||||
if string(data[:6]) == "[DONE]" {
|
||||
return
|
||||
}
|
||||
|
||||
var chunk ChatResponseStreamChunk
|
||||
if err := json.Unmarshal(data, &chunk); err != nil {
|
||||
ret.Err = errors.WrapPrefix(err, "failed to decode chunk:", 0)
|
||||
return
|
||||
}
|
||||
if ret.ID == "" {
|
||||
ret.ID = chunk.ID
|
||||
ret.Object = chunk.Object
|
||||
ret.Created = chunk.Created
|
||||
ret.Model = chunk.Model
|
||||
}
|
||||
|
||||
var delta string
|
||||
if len(chunk.Choices) > 0 {
|
||||
delta = chunk.Choices[0].Delta.Content
|
||||
}
|
||||
if !contentBegan {
|
||||
if strings.TrimSpace(delta) == "" {
|
||||
continue
|
||||
}
|
||||
contentBegan = true
|
||||
}
|
||||
ret.Stream <- delta
|
||||
}
|
||||
if sc.Err() != nil {
|
||||
ret.Err = errors.WrapPrefix(sc.Err(), "read response:", 0)
|
||||
}
|
||||
}()
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue