feat(AI): streaming response

This commit is contained in:
Yiyang Kang 2023-03-19 14:45:20 +08:00
parent 7b2d3c31e5
commit bd5e8112a1
Signed by: kkyy
GPG key ID: 80FD317ECAF06CC3
3 changed files with 168 additions and 14 deletions

View file

@ -44,3 +44,28 @@ type ChatResponse struct {
Usage map[string]int `json:"usage"`
Choices []ChatResponseChoice `json:"choices"`
}
type ChatResponseStream struct {
ID string
Object string
Created int
Model string
Stream chan string
Err error
}
type ChatResponseStreamChunk struct {
ID string
Object string
Created int
Model string
Choices []ChatResponseStreamChoice
}
type ChatResponseStreamChoice struct {
Index int `json:"index"`
FinishReason string `json:"finish_reason"`
Delta struct {
Content string `json:"content"`
} `json:"delta"`
}

View file

@ -1,6 +1,11 @@
package openai
import (
"bufio"
"encoding/json"
"io"
"net/http"
"strings"
"time"
"github.com/go-errors/errors"
@ -13,14 +18,21 @@ type Client struct {
func NewClient(apiKey string) *Client {
cli := resty.New().
SetTransport(&http.Transport{
Proxy: http.ProxyFromEnvironment,
ResponseHeaderTimeout: 10 * time.Second,
}).
SetBaseURL("https://api.openai.com").
SetHeader("Authorization", "Bearer "+apiKey).
SetTimeout(30 * time.Second)
SetTimeout(5 * time.Minute) // hard cap
return &Client{rest: cli}
}
func (c *Client) ChatCompletion(request ChatRequest) (*ChatResponse, error) {
// Note: this function might not work due to the header timeout set on the http client.
// We should probably not use this anyway.
resp, err := c.rest.R().
SetBody(request).
SetHeader("Content-Type", "application/json").
@ -37,3 +49,81 @@ func (c *Client) ChatCompletion(request ChatRequest) (*ChatResponse, error) {
return resp.Result().(*ChatResponse), nil
}
func (c *Client) ChatCompletionStream(request ChatRequest) (*ChatResponseStream, error) {
request.Stream = true
resp, err := c.rest.R().
SetBody(request).
SetHeader("Content-Type", "application/json").
SetDoNotParseResponse(true).
Post(ChatAPIPath)
if err != nil {
return nil, errors.Wrap(err, 0)
}
rbody := resp.RawBody()
if rbody == nil {
return nil, errors.New("no response body")
}
if resp.StatusCode() != 200 {
defer rbody.Close()
var respBodyStr string
if respBody, err := io.ReadAll(rbody); err == nil {
respBodyStr = string(respBody)
} else {
respBodyStr = "failed to read: " + err.Error()
}
return nil, errors.Errorf("status code: %d, body: %q", resp.StatusCode(), respBodyStr)
}
ret := &ChatResponseStream{Stream: make(chan string, 1024)}
go func() {
defer func() {
rbody.Close()
close(ret.Stream)
}()
var contentBegan bool
sc := bufio.NewScanner(rbody)
for sc.Scan() {
line := sc.Bytes()
if string(line[:6]) != "data: " {
continue
}
data := line[6:]
if string(data[:6]) == "[DONE]" {
return
}
var chunk ChatResponseStreamChunk
if err := json.Unmarshal(data, &chunk); err != nil {
ret.Err = errors.WrapPrefix(err, "failed to decode chunk:", 0)
return
}
if ret.ID == "" {
ret.ID = chunk.ID
ret.Object = chunk.Object
ret.Created = chunk.Created
ret.Model = chunk.Model
}
var delta string
if len(chunk.Choices) > 0 {
delta = chunk.Choices[0].Delta.Content
}
if !contentBegan {
if strings.TrimSpace(delta) == "" {
continue
}
contentBegan = true
}
ret.Stream <- delta
}
if sc.Err() != nil {
ret.Err = errors.WrapPrefix(sc.Err(), "read response:", 0)
}
}()
return ret, nil
}