Skip to content

Dynamic scaling + Other stuff (DONT MERGE - SPLIT INTO SMALLER PRS) #737

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/router_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
- name: Install Rust
uses: actions-rs/toolchain@v1
with:
toolchain: 1.79.0
toolchain: 1.83.0
override: true
components: rustfmt, clippy
- name: Install Protoc
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.83 AS chef
WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
Expand Down
35 changes: 25 additions & 10 deletions router/src/health.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use lorax_client::{
Batch, NextTokenChooserParameters, Request, ShardInfo, ShardedClient,
StoppingCriteriaParameters,
input_chunk, Batch, InputChunk, NextTokenChooserParameters, Request, ShardInfo, ShardedClient,
StoppingCriteriaParameters, TokenizedInputs,
};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
Expand Down Expand Up @@ -40,7 +40,12 @@ impl Health {
let generation_liveness_request = Request {
id: LIVENESS_ID,
inputs: "liveness".to_string(),
tokenized_inputs: None,
tokenized_inputs: Some(TokenizedInputs {
ids: vec![75],
input_chunks: vec![InputChunk {
chunk: Some(input_chunk::Chunk::Text("liveness".to_string())),
}],
}),
truncate: 10,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
Expand All @@ -66,7 +71,7 @@ impl Health {
adapter_index: 0,
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
slots: (0..self.shard_info.block_size).collect(),
cache_len: 0,
chunk_len: None,
};
Expand All @@ -84,15 +89,20 @@ impl Health {
pub(crate) async fn check_classification(&mut self) -> bool {
let classify_request = Request {
id: LIVENESS_ID,
inputs: "San Francisco".to_string(),
tokenized_inputs: None,
inputs: "liveness".to_string(),
tokenized_inputs: Some(TokenizedInputs {
ids: vec![75],
input_chunks: vec![InputChunk {
chunk: Some(input_chunk::Chunk::Text("liveness".to_string())),
}],
}),
truncate: 10,
prefill_logprobs: false,
parameters: None,
stopping_parameters: None,
adapter_index: 0,
blocks: vec![0],
slots: (0..16).collect(),
slots: (0..self.shard_info.block_size).collect(),
cache_len: 0,
chunk_len: None,
};
Expand All @@ -109,15 +119,20 @@ impl Health {
pub(crate) async fn check_embeddings(&mut self) -> bool {
let embed_request = Request {
id: LIVENESS_ID,
inputs: "San Francisco".to_string(),
tokenized_inputs: None,
inputs: "liveness".to_string(),
tokenized_inputs: Some(TokenizedInputs {
ids: vec![75],
input_chunks: vec![InputChunk {
chunk: Some(input_chunk::Chunk::Text("liveness".to_string())),
}],
}),
truncate: 10,
prefill_logprobs: false,
parameters: None,
stopping_parameters: None,
adapter_index: 0,
blocks: vec![0],
slots: (0..16).collect(),
slots: (0..self.shard_info.block_size).collect(),
cache_len: 0,
chunk_len: None,
};
Expand Down
6 changes: 4 additions & 2 deletions router/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,14 @@ impl ChatTemplateRenderer {
// if not, we need to append the tools to the last message
let text = if self.use_default_tool_template {
match serde_json::to_string(&tools) {
Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt),
// Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt),
Ok(tools_str) => format!("\n{}\n{}", tools_str, tool_prompt),
Err(e) => return Err(InferError::ToolError(e.to_string())),
}
} else {
// if the `tools` variable is used in the template, we just append the tool_prompt
format!("\n---\n{}", tool_prompt)
// format!("\n---\n{}", tool_prompt)
format!("\n{}", tool_prompt)
};
if let Some(last_message) = messages.last_mut() {
if let Some(content) = &mut last_message.content {
Expand Down
12 changes: 9 additions & 3 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ pub struct Url {
}

#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)]
pub(crate) struct ToolCall {
pub struct ToolCall {
pub id: String,
pub r#type: String,
pub function: ReturnFunctionDefinition,
Expand All @@ -603,6 +603,8 @@ pub struct Message {
#[schema(example = "My name is David and I")]
pub content: Option<MessageContent>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "\"David\"")]
name: Option<String>,
}
Expand Down Expand Up @@ -642,6 +644,8 @@ pub struct TextMessage {
pub role: String,
#[schema(example = "My name is David and I")]
pub content: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
}

impl From<Message> for TextMessage {
Expand All @@ -660,6 +664,7 @@ impl From<Message> for TextMessage {
.join(""),
None => String::new(),
},
tool_calls: value.tool_calls,
}
}
}
Expand Down Expand Up @@ -858,7 +863,8 @@ impl ChatCompletionRequest {
}

/// Default prompt appended after the serialized tool definitions.
///
/// Intentionally empty: the chat template / grammar constraint is expected to
/// steer the model's tool-call format, so no extra instruction text is
/// injected by default.
pub fn default_tool_prompt() -> String {
    String::new()
}

#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
Expand Down Expand Up @@ -951,7 +957,7 @@ pub(crate) struct FunctionDefinition {
}

#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
pub(crate) struct ReturnFunctionDefinition {
pub struct ReturnFunctionDefinition {
#[serde(default)]
pub description: Option<String>,
pub name: String,
Expand Down
169 changes: 117 additions & 52 deletions router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use futures::Stream;
use lorax_client::{ShardInfo, ShardedClient};
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use once_cell::sync::OnceCell;
use regex::Regex;
use reqwest_middleware::ClientBuilder;
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -210,6 +211,112 @@ async fn completions_v1(
}
}

/// Parses a tool call from generated text that is already a JSON value.
///
/// Expected shape: `{"function": {"_name": <tool name>, ...arguments}}`.
/// Returns `(Some(tool_calls), None)` for a real tool invocation, or
/// `(None, Some(content))` when the model selected the special `no_tool`
/// function, i.e. it answered in plain text via the `content` field.
///
/// # Errors
/// Returns `InferError::ToolError` when the `function`, `_name`, or (for
/// `no_tool`) `content` fields are missing, or the arguments cannot be
/// re-serialized to a JSON string.
fn parse_json_tool_call(
    gen_text_value: Value,
) -> Result<(Option<Vec<ToolCall>>, Option<String>), InferError> {
    // The grammar-constrained output nests everything under "function".
    // `ok_or_else` avoids allocating the error String on the success path.
    let function = gen_text_value.get("function").ok_or_else(|| {
        InferError::ToolError("No function found in generated text".to_string())
    })?;

    let name = function
        .get("_name")
        .and_then(Value::as_str)
        .ok_or_else(|| {
            InferError::ToolError("No _name found in generated text".to_string())
        })?
        .to_string();

    // Everything except the "_name" marker is the tool's argument object.
    let mut arguments = function.clone();
    if let Value::Object(ref mut props) = arguments {
        props.remove("_name");
    }
    match name.as_str() {
        "no_tool" => {
            // `no_tool` means the model declined to call a tool; surface its
            // `content` field as the plain-text message body instead.
            let content_message = arguments
                .get("content")
                .and_then(Value::as_str)
                .ok_or_else(|| {
                    InferError::ToolError("No `content` found in generated text".to_string())
                })?
                .to_string();
            Ok((None, Some(content_message)))
        }
        _ => {
            let arguments = serde_json::to_string(&arguments).map_err(|e| {
                InferError::ToolError(format!("Failed to serialize arguments: {}", e))
            })?;
            // Single synthetic call id: only one tool call is produced per
            // generation in this path.
            let tool_calls = vec![ToolCall {
                id: "0".to_string(),
                r#type: "function".to_string(),
                function: ReturnFunctionDefinition {
                    description: None,
                    name,
                    arguments,
                },
            }];
            Ok((Some(tool_calls), None))
        }
    }
}

fn parse_xml_tool_call(gen: &str) -> Result<(Option<Vec<ToolCall>>, Option<String>), InferError> {
let tool_call_regex = Regex::new(r"(?s)<tool_call>(.*?)</tool_call>|<tool_call>(.*)")
.map_err(|e| InferError::ToolError(format!("Failed to create tool call regex: {}", e)))?;
// Check for tool call matches
if let Some(captures) = tool_call_regex.captures(gen) {
// Check for complete tool call (first capture group)
let json_content = if let Some(complete_match) = captures.get(1) {
complete_match.as_str()
}
// Check for incomplete tool call (second capture group)
else if let Some(incomplete_match) = captures.get(2) {
incomplete_match.as_str()
} else {
return Ok((None, Some(gen.to_string())));
};

// Parse the JSON content
let parsed_content: serde_json::Value =
serde_json::from_str(json_content.trim()).map_err(|e| {
InferError::ToolError(format!("Failed to parse tool call JSON content: {}", e))
})?;

// Extract name and arguments
let name = parsed_content["name"]
.as_str()
.ok_or_else(|| InferError::ToolError("Missing 'name' field in tool call".to_string()))?
.to_string();

// Parse the arguments field which may be a JSON string
let arguments = if let Some(args_str) = parsed_content["arguments"].as_str() {
// If arguments is a string, try to parse it as JSON
serde_json::from_str(args_str).unwrap_or(parsed_content["arguments"].clone())
} else {
// If not a string, use the raw value
parsed_content["arguments"].clone()
};

// Create tool call with the extracted content
let tool_calls = vec![ToolCall {
id: "0".to_string(),
r#type: "function".to_string(),
function: ReturnFunctionDefinition {
description: None,
name,
arguments: serde_json::to_string(&arguments).map_err(|e| {
InferError::ToolError(format!("Failed to serialize arguments: {}", e))
})?,
},
}];

Ok((Some(tool_calls), None))
} else {
// If no tool call tags are found, return the original text
Ok((None, Some(gen.to_string())))
}
}

/// OpenAI compatible chat completions endpoint
#[utoipa::path(
post,
Expand Down Expand Up @@ -319,57 +426,14 @@ async fn chat_completions_v1(
let mut choice_content = vec![];
for (_, gen) in generations.iter().enumerate() {
let (tool_calls, output) = if using_tools {
let gen_text_value: Value = serde_json::from_str(&gen).map_err(|e| {
InferError::ToolError(format!(
"Failed to parse generated text: {} {:?}",
e, gen
))
})?;
let function = gen_text_value.get("function").ok_or(InferError::ToolError(
"No function found in generated text".to_string(),
))?;

let name = function
.get("_name")
.and_then(Value::as_str)
.ok_or(InferError::ToolError(
"No _name found in generated text".to_string(),
))?
.to_string();

let mut arguments = function.clone();
if let Value::Object(ref mut props) = arguments {
props.remove("_name");
}
match name.as_str() {
"no_tool" => {
// parse the content message
let content_message = arguments
.get("content")
.and_then(Value::as_str)
.ok_or_else(|| {
InferError::ToolError(
"No `content` found in generated text".to_string(),
)
})?
.to_string();
(None, Some(content_message))
}
_ => {
let arguments = serde_json::to_string(&arguments).map_err(|e| {
InferError::ToolError(format!("Failed to serialize arguments: {}", e))
})?;
let tool_calls = vec![ToolCall {
id: "0".to_string(),
r#type: "function".to_string(),
function: ReturnFunctionDefinition {
description: None,
name,
arguments,
},
}];
(Some(tool_calls), None)
}
let tool_call_result = match serde_json::from_str::<Value>(gen) {
Ok(gen_text_value) => parse_json_tool_call(gen_text_value),
Err(_) => parse_xml_tool_call(gen),
};
match tool_call_result {
Ok((tool_calls, output)) => (tool_calls, output),
// TODO: (magdy) How should we tell the user that the tool call failed?
Err(_) => (None, Some(gen.clone())),
}
} else {
(None, Some(gen.clone()))
Expand Down Expand Up @@ -435,7 +499,8 @@ pub(crate) fn prepare_chat_input(
messages,
Some((updated_tools, tool_prompt.into())),
)?;
return Ok((inputs, grammar, tool_schema.is_some()));
// return Ok((inputs, grammar, tool_schema.is_some()));
return Ok((inputs, grammar, true));
}

// if no response_format or tools are set simply apply the chat template to generate inputs
Expand Down
Loading
Loading