Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added support for streaming in API #247

Merged
merged 3 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 91 additions & 18 deletions gptme/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import atexit
import io
import logging
from contextlib import redirect_stdout
from datetime import datetime
from importlib import resources
Expand All @@ -18,11 +19,14 @@

from ..commands import execute_cmd
from ..dirs import get_logs_dir
from ..llm import reply
from ..llm import _stream
from ..logmanager import LogManager, get_user_conversations, prepare_messages
from ..message import Message
from ..models import get_model
from ..tools import execute_msg
from ..tools.base import ToolUse

logger = logging.getLogger(__name__)

api = flask.Blueprint("api", __name__)

Expand Down Expand Up @@ -94,15 +98,14 @@ def confirm_func(msg: str) -> bool:
def api_conversation_generate(logfile: str):
# get model or use server default
req_json = flask.request.json or {}
stream = req_json.get("stream", False) # Default to no streaming (backward compat)
model = req_json.get("model", get_model().model)

# load conversation
manager = LogManager.load(logfile, branch=req_json.get("branch", "main"))

# if prompt is a user-command, execute it
if manager.log[-1].role == "user":
# TODO: capture output of command and return it

f = io.StringIO()
print("Begin capturing stdout, to pass along command output.")
with redirect_stdout(f):
Expand All @@ -118,21 +121,91 @@ def api_conversation_generate(logfile: str):
# performs reduction/context trimming, if necessary
msgs = prepare_messages(manager.log.messages)

# generate response
# TODO: add support for streaming
msg = reply(msgs, model=model, stream=True)
msg = msg.replace(quiet=True)

# log response and run tools
resp_msgs = []
manager.append(msg)
resp_msgs.append(msg)
for reply_msg in execute_msg(msg, confirm_func):
manager.append(reply_msg)
resp_msgs.append(reply_msg)

return flask.jsonify(
[{"role": msg.role, "content": msg.content} for msg in resp_msgs]
if not msgs:
logger.error("No messages to process")
return flask.jsonify({"error": "No messages to process"})

if not stream:
# Non-streaming response
try:
# Get complete response
output = "".join(_stream(msgs, model))

# Store the message
msg = Message("assistant", output)
msg = msg.replace(quiet=True)
manager.append(msg)

# Execute any tools
reply_msgs = list(execute_msg(msg, confirm_func))
for reply_msg in reply_msgs:
manager.append(reply_msg)

# Return all messages
response = [{"role": "assistant", "content": output, "stored": True}]
response.extend(
{"role": msg.role, "content": msg.content, "stored": True}
for msg in reply_msgs
)
return flask.jsonify(response)

except Exception as e:
logger.exception("Error during generation")
return flask.jsonify({"error": str(e)})

# Streaming response
def generate():
    # SSE generator for the streaming response path.
    # Emits Server-Sent Events of the form "data: <json>\n\n":
    #   * one event per streamed token ({'role': 'assistant', 'content': <char>, 'stored': False})
    #   * one event per tool-output message after storage ({'role': ..., 'content': ..., 'stored': True})
    #   * a single {'error': ...} event on failure
    # Closes over msgs/model/logfile/manager/confirm_func from the enclosing
    # request handler (not visible in this view).
    # Start with an empty message
    output = ""
    try:
        logger.info(f"Starting generation for conversation {logfile}")

        # Prepare messages for the model
        # NOTE(review): the enclosing handler appears to check `if not msgs`
        # before defining this generator, making this guard redundant — confirm.
        if not msgs:
            logger.error("No messages to process")
            yield f"data: {flask.json.dumps({'error': 'No messages to process'})}\n\n"
            return

        # Stream tokens from the model
        # _stream yields chunks; the inner genexp flattens them to single
        # characters so each SSE event carries exactly one character.
        logger.debug(f"Starting token stream with model {model}")
        for char in (char for chunk in _stream(msgs, model) for char in chunk):
            output += char
            # Send each token as a JSON event
            yield f"data: {flask.json.dumps({'role': 'assistant', 'content': char, 'stored': False})}\n\n"

            # Check for complete tool uses
            # Re-parses the accumulated output on every character — O(n^2)
            # over the response length; presumably acceptable for typical
            # response sizes.  Stops generation early so the tool can run.
            tooluses = list(ToolUse.iter_from_content(output))
            if tooluses and any(tooluse.is_runnable for tooluse in tooluses):
                logger.debug("Found runnable tool use, breaking stream")
                break

        # Store the complete message
        logger.debug(f"Storing complete message: {output[:100]}...")
        msg = Message("assistant", output)
        msg = msg.replace(quiet=True)
        manager.append(msg)

        # Execute any tools and stream their output
        # Tool results are persisted before being sent, hence stored=True.
        for reply_msg in execute_msg(msg, confirm_func):
            logger.debug(
                f"Tool output: {reply_msg.role} - {reply_msg.content[:100]}..."
            )
            manager.append(reply_msg)
            yield f"data: {flask.json.dumps({'role': reply_msg.role, 'content': reply_msg.content, 'stored': True})}\n\n"

    except Exception as e:
        # Surface the failure to the client as a terminal SSE error event.
        logger.exception("Error during generation")
        yield f"data: {flask.json.dumps({'error': str(e)})}\n\n"
    finally:
        logger.info("Generation complete")

return flask.Response(
generate(),
mimetype="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no", # Disable buffering in nginx
},
)


Expand Down
96 changes: 74 additions & 22 deletions gptme/server/static/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -239,30 +239,82 @@ new Vue({
},
async generate() {
this.generating = true;
const req = await fetch(
`${apiRoot}/${this.selectedConversation}/generate`,
{
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ branch: this.branch }),
let currentMessage = {
role: "assistant",
content: "",
timestamp: new Date().toISOString(),
};
this.chatLog.push(currentMessage);

try {
// Create EventSource with POST method using fetch
const response = await fetch(
`${apiRoot}/${this.selectedConversation}/generate`,
{
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ branch: this.branch }),
}
);

if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
);
this.generating = false;
if (!req.ok) {
this.error = req.statusText;
return;
}
// req.json() can contain (not stored) responses to /commands,
// or the result of the generation.
// if it's unsaved results of a command, we need to display it
const data = await req.json();
if (data.length == 1 && data[0].stored === false) {
this.cmdout = data[0].content;

const reader = response.body.getReader();
const decoder = new TextDecoder();

while (true) {
const {value, done} = await reader.read();
if (done) break;

const chunk = decoder.decode(value);
// Parse SSE data
const lines = chunk.split('\n');
for (const line of lines) {
if (line.startsWith('data: ')) {
const data = JSON.parse(line.slice(6));

if (data.error) {
this.error = data.error;
break;
}

if (data.stored === false) {
// Streaming token from assistant
currentMessage.content += data.content;
currentMessage.html = this.mdToHtml(currentMessage.content);
this.scrollToBottom();
} else {
// Tool output or stored message
if (data.role === "system") {
this.cmdout = data.content;
} else {
// Add as a new message
const newMsg = {
role: data.role,
content: data.content,
timestamp: new Date().toISOString(),
html: this.mdToHtml(data.content),
};
this.chatLog.push(newMsg);
}
}
}
}
}

// After streaming is complete, reload to ensure we have the server's state
this.generating = false;
await this.selectConversation(this.selectedConversation, this.branch);
} catch (error) {
this.error = error.toString();
this.generating = false;
// Remove the temporary message on error
this.chatLog.pop();
}
// reload conversation
await this.selectConversation(this.selectedConversation, this.branch);
},
changeBranch(branch) {
this.branch = branch;
Expand Down
51 changes: 44 additions & 7 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,51 @@ def test_api_conversation_generate(conv: str, client: FlaskClient):
)
assert response.status_code == 200

# Test regular (non-streaming) response
response = client.post(
f"/api/conversations/{conv}/generate",
json={"model": get_model().model},
json={"model": get_model().model, "stream": False},
)
assert response.status_code == 200
msgs = response.get_json()
assert len(msgs) >= 1
assert len(msgs) <= 2
assert msgs[0]["role"] == "assistant"
if len(msgs) == 2:
assert msgs[1]["role"] == "system"
data = response.get_data(as_text=True)
assert data # Ensure we got some response
msgs_resps = response.get_json()
assert msgs_resps is not None # Ensure we got valid JSON
# Assistant message + possible tool output
assert len(msgs_resps) >= 1

# First message should be the assistant's response
assert msgs_resps[0]["role"] == "assistant"


@pytest.mark.slow
def test_api_conversation_generate_stream(conv: str, client: FlaskClient):
    """Generate a streaming (SSE) response and validate the event payloads.

    Posts a user message, requests generation with ``stream=True``, and
    checks that the response is a well-formed Server-Sent Events stream
    whose data payloads are JSON objects with the expected keys.
    """
    import json  # local import: only needed to validate SSE payloads here

    # Ask the assistant to generate a test response
    response = client.post(
        f"/api/conversations/{conv}",
        json={"role": "user", "content": "hello, just testing"},
    )
    assert response.status_code == 200

    # Test streaming response
    response = client.post(
        f"/api/conversations/{conv}/generate",
        json={"model": get_model().model, "stream": True},
        headers={"Accept": "text/event-stream"},
    )
    assert response.status_code == 200
    assert "text/event-stream" in response.headers["Content-Type"]

    # Read and validate the streamed response
    chunks = list(response.iter_encoded())
    assert len(chunks) > 0

    # Each chunk should be a Server-Sent Event carrying a JSON payload
    events = []
    for chunk in chunks:
        chunk_str = chunk.decode("utf-8")
        assert chunk_str.startswith("data: ")
        payload = chunk_str[len("data: ") :].strip()
        # Skip empty chunks (heartbeats).
        # Bug fix: the old check compared the *stripped* string against
        # "data: " (with trailing space), which could never match, so
        # heartbeats were not skipped and the `assert data` below could fail.
        if not payload:
            continue
        event = json.loads(payload)  # must be valid JSON
        events.append(event)
        # Every event is either an error or a token/tool-output message
        assert "error" in event or ("role" in event and "content" in event)

    # Validate content, not just presence: at least one assistant token
    # should have been streamed (addresses review feedback on this test).
    assert any(
        event.get("role") == "assistant" for event in events if "error" not in event
    )
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test only checks for non-empty data but does not validate the content of the streamed data. Consider checking for expected roles and content in the streamed data.

Loading