130 lines
2.8 KiB
Go
130 lines
2.8 KiB
Go
package openai
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-errors/errors"
|
|
"github.com/go-resty/resty/v2"
|
|
)
|
|
|
|
type Client struct {
|
|
rest *resty.Client
|
|
}
|
|
|
|
func NewClient(apiKey string) *Client {
|
|
cli := resty.New().
|
|
SetTransport(&http.Transport{
|
|
Proxy: http.ProxyFromEnvironment,
|
|
ResponseHeaderTimeout: 10 * time.Second,
|
|
}).
|
|
SetBaseURL("https://api.openai.com").
|
|
SetHeader("Authorization", "Bearer "+apiKey).
|
|
SetTimeout(5 * time.Minute) // hard cap
|
|
|
|
return &Client{rest: cli}
|
|
}
|
|
|
|
func (c *Client) ChatCompletion(request ChatRequest) (*ChatResponse, error) {
|
|
// Note: this function might not work due to the header timeout set on the http client.
|
|
// We should probably not use this anyway.
|
|
|
|
resp, err := c.rest.R().
|
|
SetBody(request).
|
|
SetHeader("Content-Type", "application/json").
|
|
SetResult(&ChatResponse{}).
|
|
Post(ChatAPIPath)
|
|
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, 0)
|
|
}
|
|
|
|
if resp.StatusCode() != 200 {
|
|
return nil, errors.Errorf("unexpected status code: %d", resp.StatusCode())
|
|
}
|
|
|
|
return resp.Result().(*ChatResponse), nil
|
|
}
|
|
|
|
func (c *Client) ChatCompletionStream(request ChatRequest) (*ChatResponseStream, error) {
|
|
request.Stream = true
|
|
resp, err := c.rest.R().
|
|
SetBody(request).
|
|
SetHeader("Content-Type", "application/json").
|
|
SetDoNotParseResponse(true).
|
|
Post(ChatAPIPath)
|
|
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, 0)
|
|
}
|
|
rbody := resp.RawBody()
|
|
if rbody == nil {
|
|
return nil, errors.New("no response body")
|
|
}
|
|
|
|
if resp.StatusCode() != 200 {
|
|
defer rbody.Close()
|
|
var respBodyStr string
|
|
if respBody, err := io.ReadAll(rbody); err == nil {
|
|
respBodyStr = string(respBody)
|
|
} else {
|
|
respBodyStr = "failed to read: " + err.Error()
|
|
}
|
|
return nil, errors.Errorf("status code: %d, body: %q", resp.StatusCode(), respBodyStr)
|
|
}
|
|
|
|
ret := &ChatResponseStream{Stream: make(chan string, 1024)}
|
|
go func() {
|
|
defer func() {
|
|
rbody.Close()
|
|
close(ret.Stream)
|
|
}()
|
|
|
|
var contentBegan bool
|
|
sc := bufio.NewScanner(rbody)
|
|
for sc.Scan() {
|
|
line := sc.Bytes()
|
|
if string(line[:6]) != "data: " {
|
|
continue
|
|
}
|
|
data := line[6:]
|
|
if string(data[:6]) == "[DONE]" {
|
|
return
|
|
}
|
|
|
|
var chunk ChatResponseStreamChunk
|
|
if err := json.Unmarshal(data, &chunk); err != nil {
|
|
ret.Err = errors.WrapPrefix(err, "failed to decode chunk:", 0)
|
|
return
|
|
}
|
|
if ret.ID == "" {
|
|
ret.ID = chunk.ID
|
|
ret.Object = chunk.Object
|
|
ret.Created = chunk.Created
|
|
ret.Model = chunk.Model
|
|
}
|
|
|
|
var delta string
|
|
if len(chunk.Choices) > 0 {
|
|
delta = chunk.Choices[0].Delta.Content
|
|
}
|
|
if !contentBegan {
|
|
if strings.TrimSpace(delta) == "" {
|
|
continue
|
|
}
|
|
contentBegan = true
|
|
}
|
|
ret.Stream <- delta
|
|
}
|
|
if sc.Err() != nil {
|
|
ret.Err = errors.WrapPrefix(sc.Err(), "read response:", 0)
|
|
}
|
|
}()
|
|
|
|
return ret, nil
|
|
}
|