feat(AI): streaming response
This commit is contained in:
parent
7b2d3c31e5
commit
bd5e8112a1
|
@ -3,6 +3,7 @@ package main
|
||||||
import (
|
import (
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
tele "gopkg.in/telebot.v3"
|
tele "gopkg.in/telebot.v3"
|
||||||
|
@ -71,10 +72,6 @@ func handleTranslateBtn(c tele.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// change the temporary message
|
|
||||||
if err := c.Edit("Sure, please wait..."); err != nil {
|
|
||||||
logger.Warnf("failed to alter the temporary message: %v", err)
|
|
||||||
}
|
|
||||||
// pretend to be typing
|
// pretend to be typing
|
||||||
if err := c.Bot().Notify(msg.Chat, tele.Typing); err != nil {
|
if err := c.Bot().Notify(msg.Chat, tele.Typing); err != nil {
|
||||||
logger.Warnf("failed to send typing action: %v", err)
|
logger.Warnf("failed to send typing action: %v", err)
|
||||||
|
@ -98,23 +95,65 @@ func handleTranslateBtn(c tele.Context) error {
|
||||||
}
|
}
|
||||||
logger.Debugf("Openai chat request: %#+v", req)
|
logger.Debugf("Openai chat request: %#+v", req)
|
||||||
|
|
||||||
resp, err := ai.ChatCompletion(req)
|
resp, err := ai.ChatCompletionStream(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("failed to translate: req: %#+v, err: %v", req, err)
|
logger.Errorf("failed to translate: req: %#+v, err: %v", req, err)
|
||||||
_, err := c.Bot().Reply(origMsg, stickerFromID(stickerPanic), tele.Silent)
|
_, _ = c.Bot().Reply(origMsg, stickerFromID(stickerPanic), tele.Silent)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
respText := resp.Choices[0].Message.Content
|
respBuilder := strings.Builder{}
|
||||||
|
minWait := time.After(1 * time.Second)
|
||||||
|
for {
|
||||||
|
var (
|
||||||
|
nNewChunk int
|
||||||
|
finished bool
|
||||||
|
minWaitSatisfied bool
|
||||||
|
)
|
||||||
|
Drain:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case chunk, ok := <-resp.Stream:
|
||||||
|
if !ok {
|
||||||
|
finished = true
|
||||||
|
break Drain
|
||||||
|
}
|
||||||
|
nNewChunk += 1
|
||||||
|
respBuilder.WriteString(chunk)
|
||||||
|
default:
|
||||||
|
if minWaitSatisfied {
|
||||||
|
break Drain
|
||||||
|
}
|
||||||
|
<-minWait
|
||||||
|
minWaitSatisfied = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if nNewChunk == 0 {
|
||||||
|
if chunk, ok := <-resp.Stream; !ok {
|
||||||
|
finished = true
|
||||||
|
} else {
|
||||||
|
respBuilder.WriteString(chunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if finished {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
respoText := respBuilder.String() + "\n... (Writting)"
|
||||||
|
minWait = time.After(691 * time.Millisecond) // renew the timer
|
||||||
|
if msg, err = c.Bot().Edit(msg, respoText, tele.Silent); err != nil {
|
||||||
|
logger.Warnf("failed to edit the temporary message: %v", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
logger.Debugf("... message edited")
|
||||||
|
}
|
||||||
|
|
||||||
|
respText := respBuilder.String()
|
||||||
retryBtn := translateBtnRetry
|
retryBtn := translateBtnRetry
|
||||||
retryBtn.Data = targetLang
|
retryBtn.Data = targetLang
|
||||||
respMenu := &tele.ReplyMarkup{}
|
respMenu := &tele.ReplyMarkup{}
|
||||||
respMenu.Inline(respMenu.Row(retryBtn))
|
respMenu.Inline(respMenu.Row(retryBtn))
|
||||||
_, err = c.Bot().Reply(origMsg, respText, tele.Silent, respMenu)
|
_, err = c.Bot().Edit(msg, respText, tele.Silent, respMenu)
|
||||||
|
|
||||||
// delete the temporary message
|
|
||||||
if err := c.Delete(); err != nil {
|
|
||||||
logger.Warnf("failed to delete the temporary message: %v", err)
|
|
||||||
}
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,3 +44,28 @@ type ChatResponse struct {
|
||||||
Usage map[string]int `json:"usage"`
|
Usage map[string]int `json:"usage"`
|
||||||
Choices []ChatResponseChoice `json:"choices"`
|
Choices []ChatResponseChoice `json:"choices"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ChatResponseStream struct {
|
||||||
|
ID string
|
||||||
|
Object string
|
||||||
|
Created int
|
||||||
|
Model string
|
||||||
|
Stream chan string
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatResponseStreamChunk struct {
|
||||||
|
ID string
|
||||||
|
Object string
|
||||||
|
Created int
|
||||||
|
Model string
|
||||||
|
Choices []ChatResponseStreamChoice
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatResponseStreamChoice struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
Delta struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
} `json:"delta"`
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,11 @@
|
||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-errors/errors"
|
"github.com/go-errors/errors"
|
||||||
|
@ -13,14 +18,21 @@ type Client struct {
|
||||||
|
|
||||||
func NewClient(apiKey string) *Client {
|
func NewClient(apiKey string) *Client {
|
||||||
cli := resty.New().
|
cli := resty.New().
|
||||||
|
SetTransport(&http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
ResponseHeaderTimeout: 10 * time.Second,
|
||||||
|
}).
|
||||||
SetBaseURL("https://api.openai.com").
|
SetBaseURL("https://api.openai.com").
|
||||||
SetHeader("Authorization", "Bearer "+apiKey).
|
SetHeader("Authorization", "Bearer "+apiKey).
|
||||||
SetTimeout(30 * time.Second)
|
SetTimeout(5 * time.Minute) // hard cap
|
||||||
|
|
||||||
return &Client{rest: cli}
|
return &Client{rest: cli}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) ChatCompletion(request ChatRequest) (*ChatResponse, error) {
|
func (c *Client) ChatCompletion(request ChatRequest) (*ChatResponse, error) {
|
||||||
|
// Note: this function might not work due to the header timeout set on the http client.
|
||||||
|
// We should probably not use this anyway.
|
||||||
|
|
||||||
resp, err := c.rest.R().
|
resp, err := c.rest.R().
|
||||||
SetBody(request).
|
SetBody(request).
|
||||||
SetHeader("Content-Type", "application/json").
|
SetHeader("Content-Type", "application/json").
|
||||||
|
@ -37,3 +49,81 @@ func (c *Client) ChatCompletion(request ChatRequest) (*ChatResponse, error) {
|
||||||
|
|
||||||
return resp.Result().(*ChatResponse), nil
|
return resp.Result().(*ChatResponse), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) ChatCompletionStream(request ChatRequest) (*ChatResponseStream, error) {
|
||||||
|
request.Stream = true
|
||||||
|
resp, err := c.rest.R().
|
||||||
|
SetBody(request).
|
||||||
|
SetHeader("Content-Type", "application/json").
|
||||||
|
SetDoNotParseResponse(true).
|
||||||
|
Post(ChatAPIPath)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, 0)
|
||||||
|
}
|
||||||
|
rbody := resp.RawBody()
|
||||||
|
if rbody == nil {
|
||||||
|
return nil, errors.New("no response body")
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode() != 200 {
|
||||||
|
defer rbody.Close()
|
||||||
|
var respBodyStr string
|
||||||
|
if respBody, err := io.ReadAll(rbody); err == nil {
|
||||||
|
respBodyStr = string(respBody)
|
||||||
|
} else {
|
||||||
|
respBodyStr = "failed to read: " + err.Error()
|
||||||
|
}
|
||||||
|
return nil, errors.Errorf("status code: %d, body: %q", resp.StatusCode(), respBodyStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
ret := &ChatResponseStream{Stream: make(chan string, 1024)}
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
rbody.Close()
|
||||||
|
close(ret.Stream)
|
||||||
|
}()
|
||||||
|
|
||||||
|
var contentBegan bool
|
||||||
|
sc := bufio.NewScanner(rbody)
|
||||||
|
for sc.Scan() {
|
||||||
|
line := sc.Bytes()
|
||||||
|
if string(line[:6]) != "data: " {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := line[6:]
|
||||||
|
if string(data[:6]) == "[DONE]" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var chunk ChatResponseStreamChunk
|
||||||
|
if err := json.Unmarshal(data, &chunk); err != nil {
|
||||||
|
ret.Err = errors.WrapPrefix(err, "failed to decode chunk:", 0)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ret.ID == "" {
|
||||||
|
ret.ID = chunk.ID
|
||||||
|
ret.Object = chunk.Object
|
||||||
|
ret.Created = chunk.Created
|
||||||
|
ret.Model = chunk.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
var delta string
|
||||||
|
if len(chunk.Choices) > 0 {
|
||||||
|
delta = chunk.Choices[0].Delta.Content
|
||||||
|
}
|
||||||
|
if !contentBegan {
|
||||||
|
if strings.TrimSpace(delta) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
contentBegan = true
|
||||||
|
}
|
||||||
|
ret.Stream <- delta
|
||||||
|
}
|
||||||
|
if sc.Err() != nil {
|
||||||
|
ret.Err = errors.WrapPrefix(sc.Err(), "read response:", 0)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue