tgbot_misaka_5882f7/openai/client.go

134 lines
2.9 KiB
Go
Raw Permalink Normal View History

2023-03-08 03:28:26 +09:00
package openai
import (
2023-03-19 15:45:20 +09:00
"bufio"
"encoding/json"
"io"
"net/http"
"strings"
2023-03-08 03:28:26 +09:00
"time"
"github.com/go-errors/errors"
"github.com/go-resty/resty/v2"
)
// Client talks to the OpenAI HTTP API through a preconfigured resty client.
// Construct it with NewClient; the zero value has no underlying HTTP client.
type Client struct {
	// rest carries the base URL, the bearer-token header and the timeouts
	// that NewClient configures.
	rest *resty.Client
}
func NewClient(apiKey string) *Client {
cli := resty.New().
2023-03-19 15:45:20 +09:00
SetTransport(&http.Transport{
Proxy: http.ProxyFromEnvironment,
ResponseHeaderTimeout: 10 * time.Second,
}).
2023-03-08 03:28:26 +09:00
SetBaseURL("https://api.openai.com").
SetHeader("Authorization", "Bearer "+apiKey).
2023-03-19 15:45:20 +09:00
SetTimeout(5 * time.Minute) // hard cap
2023-03-08 03:28:26 +09:00
return &Client{rest: cli}
}
func (c *Client) ChatCompletion(request ChatRequest) (*ChatResponse, error) {
2023-03-19 15:45:20 +09:00
// Note: this function might not work due to the header timeout set on the http client.
// We should probably not use this anyway.
2023-03-08 03:28:26 +09:00
resp, err := c.rest.R().
SetBody(request).
SetHeader("Content-Type", "application/json").
SetResult(&ChatResponse{}).
Post(ChatAPIPath)
if err != nil {
return nil, errors.Wrap(err, 0)
}
if resp.StatusCode() != 200 {
return nil, errors.Errorf("unexpected status code: %d", resp.StatusCode())
}
return resp.Result().(*ChatResponse), nil
}
2023-03-19 15:45:20 +09:00
func (c *Client) ChatCompletionStream(request ChatRequest) (*ChatResponseStream, error) {
request.Stream = true
resp, err := c.rest.R().
SetBody(request).
SetHeader("Content-Type", "application/json").
SetDoNotParseResponse(true).
Post(ChatAPIPath)
if err != nil {
return nil, errors.Wrap(err, 0)
}
rbody := resp.RawBody()
if rbody == nil {
return nil, errors.New("no response body")
}
if resp.StatusCode() != 200 {
defer rbody.Close()
var respBodyStr string
if respBody, err := io.ReadAll(rbody); err == nil {
respBodyStr = string(respBody)
} else {
respBodyStr = "failed to read: " + err.Error()
}
return nil, errors.Errorf("status code: %d, body: %q", resp.StatusCode(), respBodyStr)
}
ret := &ChatResponseStream{
Stream: make(chan string, 1024),
Done: make(chan struct{}),
}
2023-03-19 15:45:20 +09:00
go func() {
defer func() {
rbody.Close()
close(ret.Stream)
close(ret.Done)
2023-03-19 15:45:20 +09:00
}()
var contentBegan bool
sc := bufio.NewScanner(rbody)
for sc.Scan() {
line := sc.Bytes()
if string(line[:6]) != "data: " {
continue
}
data := line[6:]
if string(data[:6]) == "[DONE]" {
return
}
var chunk ChatResponseStreamChunk
if err := json.Unmarshal(data, &chunk); err != nil {
ret.Err = errors.WrapPrefix(err, "failed to decode chunk:", 0)
return
}
if ret.ID == "" {
ret.ID = chunk.ID
ret.Object = chunk.Object
ret.Created = chunk.Created
ret.Model = chunk.Model
}
var delta string
if len(chunk.Choices) > 0 {
delta = chunk.Choices[0].Delta.Content
}
if !contentBegan {
if strings.TrimSpace(delta) == "" {
continue
}
contentBegan = true
}
ret.Stream <- delta
}
if sc.Err() != nil {
ret.Err = errors.WrapPrefix(sc.Err(), "read response:", 0)
}
}()
return ret, nil
}