package openai import ( "bufio" "encoding/json" "io" "net/http" "strings" "time" "github.com/go-errors/errors" "github.com/go-resty/resty/v2" ) type Client struct { rest *resty.Client } func NewClient(apiKey string) *Client { cli := resty.New(). SetTransport(&http.Transport{ Proxy: http.ProxyFromEnvironment, ResponseHeaderTimeout: 10 * time.Second, }). SetBaseURL("https://api.openai.com"). SetHeader("Authorization", "Bearer "+apiKey). SetTimeout(5 * time.Minute) // hard cap return &Client{rest: cli} } func (c *Client) ChatCompletion(request ChatRequest) (*ChatResponse, error) { // Note: this function might not work due to the header timeout set on the http client. // We should probably not use this anyway. resp, err := c.rest.R(). SetBody(request). SetHeader("Content-Type", "application/json"). SetResult(&ChatResponse{}). Post(ChatAPIPath) if err != nil { return nil, errors.Wrap(err, 0) } if resp.StatusCode() != 200 { return nil, errors.Errorf("unexpected status code: %d", resp.StatusCode()) } return resp.Result().(*ChatResponse), nil } func (c *Client) ChatCompletionStream(request ChatRequest) (*ChatResponseStream, error) { request.Stream = true resp, err := c.rest.R(). SetBody(request). SetHeader("Content-Type", "application/json"). SetDoNotParseResponse(true). Post(ChatAPIPath) if err != nil { return nil, errors.Wrap(err, 0) } rbody := resp.RawBody() if rbody == nil { return nil, errors.New("no response body") } if resp.StatusCode() != 200 { defer rbody.Close() var respBodyStr string if respBody, err := io.ReadAll(rbody); err == nil { respBodyStr = string(respBody) } else { respBodyStr = "failed to read: " + err.Error() } return nil, errors.Errorf("status code: %d, body: %q", resp.StatusCode(), respBodyStr) } ret := &ChatResponseStream{ Stream: make(chan string, 1024), Done: make(chan struct{}), } go func() { defer func() { rbody.Close() close(ret.Stream) close(ret.Done) }() var contentBegan bool sc := bufio.NewScanner(rbody) for sc.Scan() { line := sc.Bytes() if string(line[:6]) != "data: " { continue } data := line[6:] if string(data[:6]) == "[DONE]" { return } var chunk ChatResponseStreamChunk if err := json.Unmarshal(data, &chunk); err != nil { ret.Err = errors.WrapPrefix(err, "failed to decode chunk:", 0) return } if ret.ID == "" { ret.ID = chunk.ID ret.Object = chunk.Object ret.Created = chunk.Created ret.Model = chunk.Model } var delta string if len(chunk.Choices) > 0 { delta = chunk.Choices[0].Delta.Content } if !contentBegan { if strings.TrimSpace(delta) == "" { continue } contentBegan = true } ret.Stream <- delta } if sc.Err() != nil { ret.Err = errors.WrapPrefix(sc.Err(), "read response:", 0) } }() return ret, nil }