enable deep <think> mode for local ollama #286

Open: wants to merge 1 commit into main
36 changes: 27 additions & 9 deletions api/websocket_wiki.py
@@ -403,7 +403,10 @@ async def handle_websocket_chat(websocket: WebSocket):
        conversation_history += f"<turn>\n<user>{turn.user_query.query_str}</user>\n<assistant>{turn.assistant_response.response_str}</assistant>\n</turn>\n"

    # Create the prompt with context
-   prompt = f"/no_think {system_prompt}\n\n"
+   if request.provider == "ollama":
+       prompt = f"/think {system_prompt}\n\n"
+   else:
+       prompt = f"/no_think {system_prompt}\n\n"

Comment on lines +406 to +409 (Contributor, severity: medium)

To improve maintainability and reduce code duplication, define the think/no_think command in a variable and reuse it. The same logic is repeated for simplified_prompt in the exception handler on lines 627-630.

think_mode_command = "/think" if request.provider == "ollama" else "/no_think"
prompt = f"{think_mode_command} {system_prompt}\n\n"


    if conversation_history:
        prompt += f"<conversation_history>\n{conversation_history}</conversation_history>\n\n"
@@ -428,7 +431,7 @@ async def handle_websocket_chat(websocket: WebSocket):
    model_config = get_model_config(request.provider, request.model)["model_kwargs"]

    if request.provider == "ollama":
-       prompt += " /no_think"
+       prompt += " /think"

Contributor comment (severity: high)

/think is appended to the prompt here, but it is already added at the beginning of the prompt on line 407. This produces a prompt for ollama that both starts and ends with a think command, which is likely not the intended behavior and could confuse the model. A similar issue exists for the simplified_prompt on line 642. Remove this line.
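
Roughly, with both of this PR's edits in place, the final Ollama prompt is shaped like the following (illustrative sketch only; the middle of the prompt is abbreviated):

/think {system_prompt}

<conversation_history>...</conversation_history>

<query>
{query}
</query>

Assistant:  /think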


        model = OllamaClient()
        model_kwargs = {
@@ -527,11 +530,18 @@ async def handle_websocket_chat(websocket: WebSocket):
    # Get the response and handle it properly using the previously created api_kwargs
    response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
    # Handle streaming response from Ollama
+   think = False
    async for chunk in response:
        text = getattr(chunk, 'response', None) or getattr(chunk, 'text', None) or str(chunk)
        if text and not text.startswith('model=') and not text.startswith('created_at='):
-           text = text.replace('<think>', '').replace('</think>', '')
-           await websocket.send_text(text)
+           if text == '<think>':
+               think = True
+               logger.info("think enabled")
+           elif text == '</think>':
+               think = False
+           # skip <think>.....</think> in output
+           if not think:
+               await websocket.send_text(text)

Comment on lines +533 to +544 (Contributor, severity: high)

The current logic for handling <think> and </think> tags is brittle because it assumes that these tags will always arrive in their own separate chunks. If a chunk contains a tag along with other text (e.g., </think>Here is the answer.), the logic will fail to process it correctly. This could lead to missing or incomplete model responses.

A more robust implementation using a state machine would handle various chunking scenarios correctly. This logic is also duplicated in the fallback handler (lines 655-664). Consider extracting it into a helper function.

in_think_block = False
async for chunk in response:
    text = getattr(chunk, 'response', None) or getattr(chunk, 'text', None) or str(chunk)
    if not (text and not text.startswith('model=') and not text.startswith('created_at=')):
        continue

    while text:
        if in_think_block:
            end_tag_pos = text.find('</think>')
            if end_tag_pos != -1:
                in_think_block = False
                text = text[end_tag_pos + len('</think>'):]
            else:
                # The rest of the chunk is inside the think block, so we discard it.
                text = ''
        else:  # Not in a think block
            start_tag_pos = text.find('<think>')
            if start_tag_pos != -1:
                # Send content before the <think> tag
                if start_tag_pos > 0:
                    await websocket.send_text(text[:start_tag_pos])
                in_think_block = True
                text = text[start_tag_pos + len('<think>'):]
            else:
                # No <think> tag in the chunk, send it all
                await websocket.send_text(text)
                text = ''
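
As a quick trace of how this handles a tag that arrives glued to other text (hypothetical chunk sequence, not from the PR):

# incoming chunks: ['<think>internal', ' reasoning</think>Here is', ' the answer.']
# sent to client:  'Here is', ' the answer.'   (the think-block content is dropped)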

    # Explicitly close the WebSocket connection after the response is complete
    await websocket.close()
elif request.provider == "openrouter":
Expand Down Expand Up @@ -614,7 +624,10 @@ async def handle_websocket_chat(websocket: WebSocket):
logger.warning("Token limit exceeded, retrying without context")
try:
# Create a simplified prompt without context
simplified_prompt = f"/no_think {system_prompt}\n\n"
if request.provider == "ollama":
simplified_prompt = f"/think {system_prompt}\n\n"
else:
simplified_prompt = f"/no_think {system_prompt}\n\n"
if conversation_history:
simplified_prompt += f"<conversation_history>\n{conversation_history}</conversation_history>\n\n"

@@ -626,7 +639,7 @@ async def handle_websocket_chat(websocket: WebSocket):
        simplified_prompt += f"<query>\n{query}\n</query>\n\nAssistant: "

        if request.provider == "ollama":
-           simplified_prompt += " /no_think"
+           simplified_prompt += " /think"

        # Create new api_kwargs with the simplified prompt
        fallback_api_kwargs = model.convert_inputs_to_api_kwargs(
@@ -639,11 +652,16 @@ async def handle_websocket_chat(websocket: WebSocket):
        fallback_response = await model.acall(api_kwargs=fallback_api_kwargs, model_type=ModelType.LLM)

        # Handle streaming fallback_response from Ollama
-       async for chunk in fallback_response:
+       think = False
+       async for chunk in response:

Contributor comment (severity: critical)

There is a critical bug in this fallback logic. The code iterates over response instead of fallback_response. The fallback_response is the new response that should be streamed, while response is the original one that caused the token limit error. This will likely cause incorrect behavior or an infinite loop.

Additionally, the streaming logic here is a duplicate of the one at lines 533-544. Please consider refactoring the streaming logic into a single helper function to be used in both places.

async for chunk in fallback_response:
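
If the duplication is then factored out as the comment suggests, both streaming paths could share one coroutine. A minimal sketch, reusing the state machine from the earlier comment; the name stream_ollama_visible_text and its placement are assumptions, not part of this PR:

async def stream_ollama_visible_text(response, websocket) -> None:
    """Stream Ollama chunks to the client, dropping anything inside <think>...</think>."""
    in_think_block = False
    async for chunk in response:
        text = getattr(chunk, 'response', None) or getattr(chunk, 'text', None) or str(chunk)
        if not text or text.startswith('model=') or text.startswith('created_at='):
            continue
        while text:
            if in_think_block:
                end = text.find('</think>')
                if end == -1:
                    text = ''  # remainder is hidden reasoning; drop it
                else:
                    in_think_block = False
                    text = text[end + len('</think>'):]
            else:
                start = text.find('<think>')
                if start == -1:
                    await websocket.send_text(text)
                    text = ''
                else:
                    if start > 0:
                        await websocket.send_text(text[:start])
                    in_think_block = True
                    text = text[start + len('<think>'):]

# Both call sites then reduce to:
#     await stream_ollama_visible_text(response, websocket)           # primary path
#     await stream_ollama_visible_text(fallback_response, websocket)  # token-limit fallback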

            text = getattr(chunk, 'response', None) or getattr(chunk, 'text', None) or str(chunk)
            if text and not text.startswith('model=') and not text.startswith('created_at='):
-               text = text.replace('<think>', '').replace('</think>', '')
-               await websocket.send_text(text)
+               if text == '<think>':
+                   think = True
+               elif text == '</think>':
+                   think = False
+               if not think:
+                   await websocket.send_text(text)
    elif request.provider == "openrouter":
        try:
            # Create new api_kwargs with the simplified prompt