feat(AI): streaming response
This commit is contained in:
		
							parent
							
								
									7b2d3c31e5
								
							
						
					
					
						commit
						bd5e8112a1
					
				| 
						 | 
				
			
			@ -3,6 +3,7 @@ package main
 | 
			
		|||
import (
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/samber/lo"
 | 
			
		||||
	tele "gopkg.in/telebot.v3"
 | 
			
		||||
| 
						 | 
				
			
			@ -71,10 +72,6 @@ func handleTranslateBtn(c tele.Context) error {
 | 
			
		|||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// change the temporary message
 | 
			
		||||
	if err := c.Edit("Sure, please wait..."); err != nil {
 | 
			
		||||
		logger.Warnf("failed to alter the temporary message: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	// pretend to be typing
 | 
			
		||||
	if err := c.Bot().Notify(msg.Chat, tele.Typing); err != nil {
 | 
			
		||||
		logger.Warnf("failed to send typing action: %v", err)
 | 
			
		||||
| 
						 | 
				
			
			@ -98,23 +95,65 @@ func handleTranslateBtn(c tele.Context) error {
 | 
			
		|||
	}
 | 
			
		||||
	logger.Debugf("Openai chat request: %#+v", req)
 | 
			
		||||
 | 
			
		||||
	resp, err := ai.ChatCompletion(req)
 | 
			
		||||
	resp, err := ai.ChatCompletionStream(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Errorf("failed to translate: req: %#+v, err: %v", req, err)
 | 
			
		||||
		_, err := c.Bot().Reply(origMsg, stickerFromID(stickerPanic), tele.Silent)
 | 
			
		||||
		_, _ = c.Bot().Reply(origMsg, stickerFromID(stickerPanic), tele.Silent)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	respText := resp.Choices[0].Message.Content
 | 
			
		||||
	respBuilder := strings.Builder{}
 | 
			
		||||
	minWait := time.After(1 * time.Second)
 | 
			
		||||
	for {
 | 
			
		||||
		var (
 | 
			
		||||
			nNewChunk        int
 | 
			
		||||
			finished         bool
 | 
			
		||||
			minWaitSatisfied bool
 | 
			
		||||
		)
 | 
			
		||||
	Drain:
 | 
			
		||||
		for {
 | 
			
		||||
			select {
 | 
			
		||||
			case chunk, ok := <-resp.Stream:
 | 
			
		||||
				if !ok {
 | 
			
		||||
					finished = true
 | 
			
		||||
					break Drain
 | 
			
		||||
				}
 | 
			
		||||
				nNewChunk += 1
 | 
			
		||||
				respBuilder.WriteString(chunk)
 | 
			
		||||
			default:
 | 
			
		||||
				if minWaitSatisfied {
 | 
			
		||||
					break Drain
 | 
			
		||||
				}
 | 
			
		||||
				<-minWait
 | 
			
		||||
				minWaitSatisfied = true
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if nNewChunk == 0 {
 | 
			
		||||
			if chunk, ok := <-resp.Stream; !ok {
 | 
			
		||||
				finished = true
 | 
			
		||||
			} else {
 | 
			
		||||
				respBuilder.WriteString(chunk)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if finished {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		respoText := respBuilder.String() + "\n... (Writting)"
 | 
			
		||||
		minWait = time.After(691 * time.Millisecond) // renew the timer
 | 
			
		||||
		if msg, err = c.Bot().Edit(msg, respoText, tele.Silent); err != nil {
 | 
			
		||||
			logger.Warnf("failed to edit the temporary message: %v", err)
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		logger.Debugf("... message edited")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	respText := respBuilder.String()
 | 
			
		||||
	retryBtn := translateBtnRetry
 | 
			
		||||
	retryBtn.Data = targetLang
 | 
			
		||||
	respMenu := &tele.ReplyMarkup{}
 | 
			
		||||
	respMenu.Inline(respMenu.Row(retryBtn))
 | 
			
		||||
	_, err = c.Bot().Reply(origMsg, respText, tele.Silent, respMenu)
 | 
			
		||||
 | 
			
		||||
	// delete the temporary message
 | 
			
		||||
	if err := c.Delete(); err != nil {
 | 
			
		||||
		logger.Warnf("failed to delete the temporary message: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	_, err = c.Bot().Edit(msg, respText, tele.Silent, respMenu)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -44,3 +44,28 @@ type ChatResponse struct {
 | 
			
		|||
	Usage   map[string]int       `json:"usage"`
 | 
			
		||||
	Choices []ChatResponseChoice `json:"choices"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatResponseStream struct {
 | 
			
		||||
	ID      string
 | 
			
		||||
	Object  string
 | 
			
		||||
	Created int
 | 
			
		||||
	Model   string
 | 
			
		||||
	Stream  chan string
 | 
			
		||||
	Err     error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatResponseStreamChunk struct {
 | 
			
		||||
	ID      string
 | 
			
		||||
	Object  string
 | 
			
		||||
	Created int
 | 
			
		||||
	Model   string
 | 
			
		||||
	Choices []ChatResponseStreamChoice
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatResponseStreamChoice struct {
 | 
			
		||||
	Index        int    `json:"index"`
 | 
			
		||||
	FinishReason string `json:"finish_reason"`
 | 
			
		||||
	Delta        struct {
 | 
			
		||||
		Content string `json:"content"`
 | 
			
		||||
	} `json:"delta"`
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,6 +1,11 @@
 | 
			
		|||
package openai
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/go-errors/errors"
 | 
			
		||||
| 
						 | 
				
			
			@ -13,14 +18,21 @@ type Client struct {
 | 
			
		|||
 | 
			
		||||
func NewClient(apiKey string) *Client {
 | 
			
		||||
	cli := resty.New().
 | 
			
		||||
		SetTransport(&http.Transport{
 | 
			
		||||
			Proxy:                 http.ProxyFromEnvironment,
 | 
			
		||||
			ResponseHeaderTimeout: 10 * time.Second,
 | 
			
		||||
		}).
 | 
			
		||||
		SetBaseURL("https://api.openai.com").
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+apiKey).
 | 
			
		||||
		SetTimeout(30 * time.Second)
 | 
			
		||||
		SetTimeout(5 * time.Minute) // hard cap
 | 
			
		||||
 | 
			
		||||
	return &Client{rest: cli}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) ChatCompletion(request ChatRequest) (*ChatResponse, error) {
 | 
			
		||||
	// Note: this function might not work due to the header timeout set on the http client.
 | 
			
		||||
	// We should probably not use this anyway.
 | 
			
		||||
 | 
			
		||||
	resp, err := c.rest.R().
 | 
			
		||||
		SetBody(request).
 | 
			
		||||
		SetHeader("Content-Type", "application/json").
 | 
			
		||||
| 
						 | 
				
			
			@ -37,3 +49,81 @@ func (c *Client) ChatCompletion(request ChatRequest) (*ChatResponse, error) {
 | 
			
		|||
 | 
			
		||||
	return resp.Result().(*ChatResponse), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) ChatCompletionStream(request ChatRequest) (*ChatResponseStream, error) {
 | 
			
		||||
	request.Stream = true
 | 
			
		||||
	resp, err := c.rest.R().
 | 
			
		||||
		SetBody(request).
 | 
			
		||||
		SetHeader("Content-Type", "application/json").
 | 
			
		||||
		SetDoNotParseResponse(true).
 | 
			
		||||
		Post(ChatAPIPath)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, 0)
 | 
			
		||||
	}
 | 
			
		||||
	rbody := resp.RawBody()
 | 
			
		||||
	if rbody == nil {
 | 
			
		||||
		return nil, errors.New("no response body")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if resp.StatusCode() != 200 {
 | 
			
		||||
		defer rbody.Close()
 | 
			
		||||
		var respBodyStr string
 | 
			
		||||
		if respBody, err := io.ReadAll(rbody); err == nil {
 | 
			
		||||
			respBodyStr = string(respBody)
 | 
			
		||||
		} else {
 | 
			
		||||
			respBodyStr = "failed to read: " + err.Error()
 | 
			
		||||
		}
 | 
			
		||||
		return nil, errors.Errorf("status code: %d, body: %q", resp.StatusCode(), respBodyStr)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ret := &ChatResponseStream{Stream: make(chan string, 1024)}
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer func() {
 | 
			
		||||
			rbody.Close()
 | 
			
		||||
			close(ret.Stream)
 | 
			
		||||
		}()
 | 
			
		||||
 | 
			
		||||
		var contentBegan bool
 | 
			
		||||
		sc := bufio.NewScanner(rbody)
 | 
			
		||||
		for sc.Scan() {
 | 
			
		||||
			line := sc.Bytes()
 | 
			
		||||
			if string(line[:6]) != "data: " {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data := line[6:]
 | 
			
		||||
			if string(data[:6]) == "[DONE]" {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var chunk ChatResponseStreamChunk
 | 
			
		||||
			if err := json.Unmarshal(data, &chunk); err != nil {
 | 
			
		||||
				ret.Err = errors.WrapPrefix(err, "failed to decode chunk:", 0)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if ret.ID == "" {
 | 
			
		||||
				ret.ID = chunk.ID
 | 
			
		||||
				ret.Object = chunk.Object
 | 
			
		||||
				ret.Created = chunk.Created
 | 
			
		||||
				ret.Model = chunk.Model
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var delta string
 | 
			
		||||
			if len(chunk.Choices) > 0 {
 | 
			
		||||
				delta = chunk.Choices[0].Delta.Content
 | 
			
		||||
			}
 | 
			
		||||
			if !contentBegan {
 | 
			
		||||
				if strings.TrimSpace(delta) == "" {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				contentBegan = true
 | 
			
		||||
			}
 | 
			
		||||
			ret.Stream <- delta
 | 
			
		||||
		}
 | 
			
		||||
		if sc.Err() != nil {
 | 
			
		||||
			ret.Err = errors.WrapPrefix(sc.Err(), "read response:", 0)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	return ret, nil
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue