feat(AI): streaming response

This commit is contained in:
Yiyang Kang 2023-03-19 14:45:20 +08:00
parent 7b2d3c31e5
commit bd5e8112a1
Signed by: kkyy
GPG Key ID: 80FD317ECAF06CC3
3 changed files with 168 additions and 14 deletions

View File

@ -3,6 +3,7 @@ package main
import (
"regexp"
"strings"
"time"
"github.com/samber/lo"
tele "gopkg.in/telebot.v3"
@ -71,10 +72,6 @@ func handleTranslateBtn(c tele.Context) error {
return nil
}
// change the temporary message
if err := c.Edit("Sure, please wait..."); err != nil {
logger.Warnf("failed to alter the temporary message: %v", err)
}
// pretend to be typing
if err := c.Bot().Notify(msg.Chat, tele.Typing); err != nil {
logger.Warnf("failed to send typing action: %v", err)
@ -98,23 +95,65 @@ func handleTranslateBtn(c tele.Context) error {
}
logger.Debugf("Openai chat request: %#+v", req)
resp, err := ai.ChatCompletion(req)
resp, err := ai.ChatCompletionStream(req)
if err != nil {
logger.Errorf("failed to translate: req: %#+v, err: %v", req, err)
_, err := c.Bot().Reply(origMsg, stickerFromID(stickerPanic), tele.Silent)
_, _ = c.Bot().Reply(origMsg, stickerFromID(stickerPanic), tele.Silent)
return err
}
respText := resp.Choices[0].Message.Content
respBuilder := strings.Builder{}
minWait := time.After(1 * time.Second)
for {
var (
nNewChunk int
finished bool
minWaitSatisfied bool
)
Drain:
for {
select {
case chunk, ok := <-resp.Stream:
if !ok {
finished = true
break Drain
}
nNewChunk += 1
respBuilder.WriteString(chunk)
default:
if minWaitSatisfied {
break Drain
}
<-minWait
minWaitSatisfied = true
}
}
if nNewChunk == 0 {
if chunk, ok := <-resp.Stream; !ok {
finished = true
} else {
respBuilder.WriteString(chunk)
}
}
if finished {
break
}
respoText := respBuilder.String() + "\n... (Writting)"
minWait = time.After(691 * time.Millisecond) // renew the timer
if msg, err = c.Bot().Edit(msg, respoText, tele.Silent); err != nil {
logger.Warnf("failed to edit the temporary message: %v", err)
break
}
logger.Debugf("... message edited")
}
respText := respBuilder.String()
retryBtn := translateBtnRetry
retryBtn.Data = targetLang
respMenu := &tele.ReplyMarkup{}
respMenu.Inline(respMenu.Row(retryBtn))
_, err = c.Bot().Reply(origMsg, respText, tele.Silent, respMenu)
// delete the temporary message
if err := c.Delete(); err != nil {
logger.Warnf("failed to delete the temporary message: %v", err)
}
_, err = c.Bot().Edit(msg, respText, tele.Silent, respMenu)
return err
}

View File

@ -44,3 +44,28 @@ type ChatResponse struct {
Usage map[string]int `json:"usage"`
Choices []ChatResponseChoice `json:"choices"`
}
type ChatResponseStream struct {
ID string
Object string
Created int
Model string
Stream chan string
Err error
}
type ChatResponseStreamChunk struct {
ID string
Object string
Created int
Model string
Choices []ChatResponseStreamChoice
}
type ChatResponseStreamChoice struct {
Index int `json:"index"`
FinishReason string `json:"finish_reason"`
Delta struct {
Content string `json:"content"`
} `json:"delta"`
}

View File

@ -1,6 +1,11 @@
package openai
import (
"bufio"
"encoding/json"
"io"
"net/http"
"strings"
"time"
"github.com/go-errors/errors"
@ -13,14 +18,21 @@ type Client struct {
func NewClient(apiKey string) *Client {
cli := resty.New().
SetTransport(&http.Transport{
Proxy: http.ProxyFromEnvironment,
ResponseHeaderTimeout: 10 * time.Second,
}).
SetBaseURL("https://api.openai.com").
SetHeader("Authorization", "Bearer "+apiKey).
SetTimeout(30 * time.Second)
SetTimeout(5 * time.Minute) // hard cap
return &Client{rest: cli}
}
func (c *Client) ChatCompletion(request ChatRequest) (*ChatResponse, error) {
// Note: this function might not work due to the header timeout set on the http client.
// We should probably not use this anyway.
resp, err := c.rest.R().
SetBody(request).
SetHeader("Content-Type", "application/json").
@ -37,3 +49,81 @@ func (c *Client) ChatCompletion(request ChatRequest) (*ChatResponse, error) {
return resp.Result().(*ChatResponse), nil
}
func (c *Client) ChatCompletionStream(request ChatRequest) (*ChatResponseStream, error) {
request.Stream = true
resp, err := c.rest.R().
SetBody(request).
SetHeader("Content-Type", "application/json").
SetDoNotParseResponse(true).
Post(ChatAPIPath)
if err != nil {
return nil, errors.Wrap(err, 0)
}
rbody := resp.RawBody()
if rbody == nil {
return nil, errors.New("no response body")
}
if resp.StatusCode() != 200 {
defer rbody.Close()
var respBodyStr string
if respBody, err := io.ReadAll(rbody); err == nil {
respBodyStr = string(respBody)
} else {
respBodyStr = "failed to read: " + err.Error()
}
return nil, errors.Errorf("status code: %d, body: %q", resp.StatusCode(), respBodyStr)
}
ret := &ChatResponseStream{Stream: make(chan string, 1024)}
go func() {
defer func() {
rbody.Close()
close(ret.Stream)
}()
var contentBegan bool
sc := bufio.NewScanner(rbody)
for sc.Scan() {
line := sc.Bytes()
if string(line[:6]) != "data: " {
continue
}
data := line[6:]
if string(data[:6]) == "[DONE]" {
return
}
var chunk ChatResponseStreamChunk
if err := json.Unmarshal(data, &chunk); err != nil {
ret.Err = errors.WrapPrefix(err, "failed to decode chunk:", 0)
return
}
if ret.ID == "" {
ret.ID = chunk.ID
ret.Object = chunk.Object
ret.Created = chunk.Created
ret.Model = chunk.Model
}
var delta string
if len(chunk.Choices) > 0 {
delta = chunk.Choices[0].Delta.Content
}
if !contentBegan {
if strings.TrimSpace(delta) == "" {
continue
}
contentBegan = true
}
ret.Stream <- delta
}
if sc.Err() != nil {
ret.Err = errors.WrapPrefix(sc.Err(), "read response:", 0)
}
}()
return ret, nil
}