diff --git a/assistant.go b/assistant.go index 6fd09c3..0b218e0 100644 --- a/assistant.go +++ b/assistant.go @@ -134,7 +134,10 @@ func assistantStreamedResponse(request openai.ChatRequest, cb assistantStreamedR if minWaitSatisfied { break Drain } - <-minWait + select { + case <-minWait: + case <-resp.Done: + } minWaitSatisfied = true } } @@ -155,7 +158,7 @@ func assistantStreamedResponse(request openai.ChatRequest, cb assistantStreamedR minWaitDurSecs := lo.Min([]int{nUpdates, 4}) + nErrs*3 minWait = time.After(time.Duration(minWaitDurSecs) * time.Second) - // send the partial message + // Send the partial message respText := respBuilder.String() + assistantWritingSign if err := cb(respText, false); err != nil { logger.Warnw("failed to send partial update", "error", err) diff --git a/openai/chat.go b/openai/chat.go index fc213a8..ac525df 100644 --- a/openai/chat.go +++ b/openai/chat.go @@ -51,6 +51,7 @@ type ChatResponseStream struct { Created int Model string Stream chan string + Done chan struct{} Err error } diff --git a/openai/client.go b/openai/client.go index e7abc7f..1800cf9 100644 --- a/openai/client.go +++ b/openai/client.go @@ -77,11 +77,15 @@ func (c *Client) ChatCompletionStream(request ChatRequest) (*ChatResponseStream, return nil, errors.Errorf("status code: %d, body: %q", resp.StatusCode(), respBodyStr) } - ret := &ChatResponseStream{Stream: make(chan string, 1024)} + ret := &ChatResponseStream{ + Stream: make(chan string, 1024), + Done: make(chan struct{}), + } go func() { defer func() { rbody.Close() close(ret.Stream) + close(ret.Done) }() var contentBegan bool