From bd5e8112a1368a522a1947361f9be98e87f9f9d8 Mon Sep 17 00:00:00 2001 From: Yiyang Kang Date: Sun, 19 Mar 2023 14:45:20 +0800 Subject: [PATCH] feat(AI): streaming response --- botcmd_translate.go | 65 +++++++++++++++++++++++++------- openai/chat.go | 25 ++++++++++++ openai/client.go | 92 ++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 168 insertions(+), 14 deletions(-) diff --git a/botcmd_translate.go b/botcmd_translate.go index 3f8ab65..5e85854 100644 --- a/botcmd_translate.go +++ b/botcmd_translate.go @@ -3,6 +3,7 @@ package main import ( "regexp" "strings" + "time" "github.com/samber/lo" tele "gopkg.in/telebot.v3" @@ -71,10 +72,6 @@ func handleTranslateBtn(c tele.Context) error { return nil } - // change the temporary message - if err := c.Edit("Sure, please wait..."); err != nil { - logger.Warnf("failed to alter the temporary message: %v", err) - } // pretend to be typing if err := c.Bot().Notify(msg.Chat, tele.Typing); err != nil { logger.Warnf("failed to send typing action: %v", err) @@ -98,23 +95,65 @@ func handleTranslateBtn(c tele.Context) error { } logger.Debugf("Openai chat request: %#+v", req) - resp, err := ai.ChatCompletion(req) + resp, err := ai.ChatCompletionStream(req) if err != nil { logger.Errorf("failed to translate: req: %#+v, err: %v", req, err) - _, err := c.Bot().Reply(origMsg, stickerFromID(stickerPanic), tele.Silent) + _, _ = c.Bot().Reply(origMsg, stickerFromID(stickerPanic), tele.Silent) return err } - respText := resp.Choices[0].Message.Content + respBuilder := strings.Builder{} + minWait := time.After(1 * time.Second) + for { + var ( + nNewChunk int + finished bool + minWaitSatisfied bool + ) + Drain: + for { + select { + case chunk, ok := <-resp.Stream: + if !ok { + finished = true + break Drain + } + nNewChunk += 1 + respBuilder.WriteString(chunk) + default: + if minWaitSatisfied { + break Drain + } + <-minWait + minWaitSatisfied = true + } + } + + if nNewChunk == 0 { + if chunk, ok := <-resp.Stream; !ok { + finished = true + } else { + respBuilder.WriteString(chunk) + } + } + if finished { + break + } + + respoText := respBuilder.String() + "\n... (Writting)" + minWait = time.After(691 * time.Millisecond) // renew the timer + if msg, err = c.Bot().Edit(msg, respoText, tele.Silent); err != nil { + logger.Warnf("failed to edit the temporary message: %v", err) + break + } + logger.Debugf("... message edited") + } + + respText := respBuilder.String() retryBtn := translateBtnRetry retryBtn.Data = targetLang respMenu := &tele.ReplyMarkup{} respMenu.Inline(respMenu.Row(retryBtn)) - _, err = c.Bot().Reply(origMsg, respText, tele.Silent, respMenu) - - // delete the temporary message - if err := c.Delete(); err != nil { - logger.Warnf("failed to delete the temporary message: %v", err) - } + _, err = c.Bot().Edit(msg, respText, tele.Silent, respMenu) return err } diff --git a/openai/chat.go b/openai/chat.go index 882477a..fc213a8 100644 --- a/openai/chat.go +++ b/openai/chat.go @@ -44,3 +44,28 @@ type ChatResponse struct { Usage map[string]int `json:"usage"` Choices []ChatResponseChoice `json:"choices"` } + +type ChatResponseStream struct { + ID string + Object string + Created int + Model string + Stream chan string + Err error +} + +type ChatResponseStreamChunk struct { + ID string + Object string + Created int + Model string + Choices []ChatResponseStreamChoice +} + +type ChatResponseStreamChoice struct { + Index int `json:"index"` + FinishReason string `json:"finish_reason"` + Delta struct { + Content string `json:"content"` + } `json:"delta"` +} diff --git a/openai/client.go b/openai/client.go index 6b34807..e7abc7f 100644 --- a/openai/client.go +++ b/openai/client.go @@ -1,6 +1,11 @@ package openai import ( + "bufio" + "encoding/json" + "io" + "net/http" + "strings" "time" "github.com/go-errors/errors" @@ -13,14 +18,21 @@ type Client struct { func NewClient(apiKey string) *Client { cli := resty.New(). + SetTransport(&http.Transport{ + Proxy: http.ProxyFromEnvironment, + ResponseHeaderTimeout: 10 * time.Second, + }). SetBaseURL("https://api.openai.com"). SetHeader("Authorization", "Bearer "+apiKey). - SetTimeout(30 * time.Second) + SetTimeout(5 * time.Minute) // hard cap return &Client{rest: cli} } func (c *Client) ChatCompletion(request ChatRequest) (*ChatResponse, error) { + // Note: this function might not work due to the header timeout set on the http client. + // We should probably not use this anyway. + resp, err := c.rest.R(). SetBody(request). SetHeader("Content-Type", "application/json"). @@ -37,3 +49,81 @@ func (c *Client) ChatCompletion(request ChatRequest) (*ChatResponse, error) { return resp.Result().(*ChatResponse), nil } + +func (c *Client) ChatCompletionStream(request ChatRequest) (*ChatResponseStream, error) { + request.Stream = true + resp, err := c.rest.R(). + SetBody(request). + SetHeader("Content-Type", "application/json"). + SetDoNotParseResponse(true). + Post(ChatAPIPath) + + if err != nil { + return nil, errors.Wrap(err, 0) + } + rbody := resp.RawBody() + if rbody == nil { + return nil, errors.New("no response body") + } + + if resp.StatusCode() != 200 { + defer rbody.Close() + var respBodyStr string + if respBody, err := io.ReadAll(rbody); err == nil { + respBodyStr = string(respBody) + } else { + respBodyStr = "failed to read: " + err.Error() + } + return nil, errors.Errorf("status code: %d, body: %q", resp.StatusCode(), respBodyStr) + } + + ret := &ChatResponseStream{Stream: make(chan string, 1024)} + go func() { + defer func() { + rbody.Close() + close(ret.Stream) + }() + + var contentBegan bool + sc := bufio.NewScanner(rbody) + for sc.Scan() { + line := sc.Bytes() + if string(line[:6]) != "data: " { + continue + } + data := line[6:] + if string(data[:6]) == "[DONE]" { + return + } + + var chunk ChatResponseStreamChunk + if err := json.Unmarshal(data, &chunk); err != nil { + ret.Err = errors.WrapPrefix(err, "failed to decode chunk:", 0) + return + } + if ret.ID == "" { + ret.ID = chunk.ID + ret.Object = chunk.Object + ret.Created = chunk.Created + ret.Model = chunk.Model + } + + var delta string + if len(chunk.Choices) > 0 { + delta = chunk.Choices[0].Delta.Content + } + if !contentBegan { + if strings.TrimSpace(delta) == "" { + continue + } + contentBegan = true + } + ret.Stream <- delta + } + if sc.Err() != nil { + ret.Err = errors.WrapPrefix(sc.Err(), "read response:", 0) + } + }() + + return ret, nil +}