
Commit aa385c9

feat(server,engine): use nim as runtime (#872)
1 parent 16b51ca commit aa385c9

File tree

23 files changed: +668 -38 lines changed

common/pkg/api/completions.go

Lines changed: 49 additions & 1 deletion
@@ -35,18 +35,23 @@ func ConvertCreateChatCompletionRequestToProto(body []byte) ([]byte, error) {
 }
 
 // ConvertCreateChatCompletionRequestToOpenAI converts the request to the OpenAI format.
-func ConvertCreateChatCompletionRequestToOpenAI(body []byte) ([]byte, error) {
+func ConvertCreateChatCompletionRequestToOpenAI(body []byte, needStringFormat bool) ([]byte, error) {
 	fs := []convertF{
 		// The order of the functions is the opposite of the ConvertCreateChatCompletionRequestToProto.
 		//
 		// We don't have a function that corresponds to convertContentStringToArray as the conversion
 		// doesn't break the OpenAI API spec.
 		convertEncodedTopP,
+		convertEncodedTopP,
 		convertEncodedTemperature,
 		convertEncodedChatTemplateKwargs,
 		convertEncodedFunctionParameters,
 		convertToolChoiceObject,
 	}
+	if needStringFormat {
+		// NIM expects the content field to be a string.
+		fs = append([]convertF{convertContentArrayToString}, fs...)
+	}
 	return applyConvertFuncs(body, fs)
 }
 
@@ -265,3 +270,46 @@ func convertContentStringToArray(r map[string]interface{}) error {
 	}
 	return nil
 }
+
+// convertContentArrayToString converts the content array back to a string for OpenAI format compatibility.
+func convertContentArrayToString(r map[string]interface{}) error {
+	msgs, ok := r["messages"]
+	if !ok {
+		return nil
+	}
+
+	for _, msg := range msgs.([]interface{}) {
+		m := msg.(map[string]interface{})
+		content, ok := m["content"]
+		if !ok {
+			continue
+		}
+
+		// If content is already a string, no conversion needed
+		if _, ok := content.(string); ok {
+			continue
+		}
+
+		// If content is an array, convert it to a string format OpenAI expects
+		if contentArr, ok := content.([]interface{}); ok && len(contentArr) > 0 {
+			// For text-only content, extract just the text
+			if len(contentArr) == 1 {
+				if contentItem, ok := contentArr[0].(map[string]interface{}); ok {
+					if contentType, ok := contentItem["type"].(string); ok && contentType == contentTypeText {
+						if text, ok := contentItem["text"].(string); ok {
+							m["content"] = text
+							continue
+						}
+					} else {
+						// TODO(guangrui): Handle non-text content.
+						return fmt.Errorf("unsupported content type: %s", contentType)
+					}
+				}
+			} else {
+				// TODO(guangrui): Handle more complex content arrays.
+				return fmt.Errorf("content array with multiple items is not supported")
+			}
+		}
+	}
+	return nil
+}
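
To make the new conversion concrete, here is a minimal, self-contained sketch of the flattening it performs (illustrative only; it mirrors the happy path of convertContentArrayToString rather than reusing the package's code):

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	in := []byte(`{"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello, world!"}]}]}`)

	var req map[string]interface{}
	if err := json.Unmarshal(in, &req); err != nil {
		panic(err)
	}

	for _, msg := range req["messages"].([]interface{}) {
		m := msg.(map[string]interface{})
		arr, ok := m["content"].([]interface{})
		if !ok || len(arr) != 1 {
			continue // plain strings pass through; the real helper errors on multi-item arrays
		}
		if item, ok := arr[0].(map[string]interface{}); ok && item["type"] == "text" {
			m["content"] = item["text"] // flatten the single text item to a plain string
		}
	}

	out, _ := json.Marshal(req)
	fmt.Println(string(out))
	// Output: {"messages":[{"content":"Hello, world!","role":"user"}]}
}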

common/pkg/api/completions_test.go

Lines changed: 74 additions & 0 deletions
@@ -544,3 +544,77 @@ func TestConvertContentStringToArray(t *testing.T) {
 		})
 	}
 }
+
+func TestConvertContentArrayToString(t *testing.T) {
+	tcs := []struct {
+		name    string
+		body    string
+		want    string
+		wantErr bool
+	}{
+		{
+			name: "no messages field",
+			body: `{"model": "gpt-4"}`,
+			want: `{"model": "gpt-4"}`,
+		},
+		{
+			name: "content already string",
+			body: `{"messages": [{"role": "user", "content": "Hello, world!"}]}`,
+			want: `{"messages": [{"role": "user", "content": "Hello, world!"}]}`,
+		},
+		{
+			name: "no content field",
+			body: `{"messages": [{"role": "user"}]}`,
+			want: `{"messages": [{"role": "user"}]}`,
+		},
+		{
+			name: "single text content in array",
+			body: `{"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello, world!"}]}]}`,
+			want: `{"messages": [{"role": "user", "content": "Hello, world!"}]}`,
+		},
+		{
+			name: "multiple messages mixed formats",
+			body: `{"messages": [
+				{"role": "system", "content": "You are a helpful assistant."},
+				{"role": "user", "content": [{"type": "text", "text": "What's the weather?"}]}
+			]}`,
+			want: `{"messages": [
+				{"role": "system", "content": "You are a helpful assistant."},
+				{"role": "user", "content": "What's the weather?"}
+			]}`,
+		},
+		{
+			name:    "non-text content type",
+			body:    `{"messages": [{"role": "user", "content": [{"type": "image", "image_url": {"url": "https://example.com/image.jpg"}}]}]}`,
+			wantErr: true,
+		},
+		{
+			name:    "multiple content items",
+			body:    `{"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}, {"type": "text", "text": "world"}]}]}`,
+			wantErr: true,
+		},
+	}
+
+	for _, tc := range tcs {
+		t.Run(tc.name, func(t *testing.T) {
+			got, err := applyConvertFuncs([]byte(tc.body), []convertF{convertContentArrayToString})
+
+			if tc.wantErr {
+				assert.Error(t, err)
+				return
+			}
+
+			assert.NoError(t, err)
+
+			// Compare as parsed JSON to ignore formatting differences
+			var gotJSON, wantJSON map[string]interface{}
+			err = json.Unmarshal(got, &gotJSON)
+			assert.NoError(t, err)
+
+			err = json.Unmarshal([]byte(tc.want), &wantJSON)
+			assert.NoError(t, err)
+
+			assert.Equal(t, wantJSON, gotJSON)
+		})
+	}
+}
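
Assuming a checkout of the repository, the new table test runs on its own with: go test -run TestConvertContentArrayToString ./common/pkg/api/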

deployments/engine/templates/_helpers.tpl

Lines changed: 9 additions & 0 deletions
@@ -58,6 +58,15 @@ Create the name of the service account to use
 {{ default (include "inference-manager-engine.fullname" .) .Values.serviceAccount.name }}
 {{- end -}}
 
+{{/*
+For inline NGC key, create image pull secret
+*/}}
+{{- define "inference-manager-engine.generatedImagePullSecret" -}}
+{{- if .Values.nim.ngcApiKey }}
+{{- printf "{\"auths\":{\"nvcr.io\":{\"username\":\"$oauthtoken\",\"password\":\"%s\"}}}" .Values.nim.ngcApiKey | b64enc }}
+{{- end }}
+{{- end }}
+
 {{/*
 Do nothing, just for validation.
 */}}
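
Before base64 encoding, the printf above renders the standard dockerconfigjson payload, with the literal username $oauthtoken that NGC expects: {"auths":{"nvcr.io":{"username":"$oauthtoken","password":"<ngcApiKey>"}}} (the placeholder stands for the configured key).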

deployments/engine/templates/configmap.yaml

Lines changed: 8 additions & 0 deletions
@@ -60,6 +60,14 @@ data:
     vllm:
       dynamicLoRALoading: {{ .Values.vllm.dynamicLoRALoading }}
       loggingLevel: {{ .Values.vllm.loggingLevel }}
+    {{- with .Values.nim }}
+    nim:
+      ngcApiKey: {{ .ngcApiKey | b64enc }}
+      {{- with .models }}
+      models:
+        {{- toYaml . | nindent 8 }}
+      {{- end }}
+    {{- end }}
     model:
       default:
         runtimeName: {{ .Values.model.default.runtimeName }}
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+{{- if .Values.nim.ngcApiKey -}}
+apiVersion: v1
+kind: Secret
+metadata:
+  name: ngc-secret
+  labels:
+    {{- include "inference-manager-engine.labels" . | nindent 4 }}
+type: kubernetes.io/dockerconfigjson
+data:
+  .dockerconfigjson: {{ template "inference-manager-engine.generatedImagePullSecret" . }}
+
+---
+
+apiVersion: v1
+kind: Secret
+metadata:
+  name: ngc-api
+  labels:
+    {{- include "inference-manager-engine.labels" . | nindent 4 }}
+type: Opaque
+data:
+  NGC_API_KEY: {{ .Values.nim.ngcApiKey | b64enc }}
+{{- end -}}
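
As a sanity check after install (command is illustrative; add namespace and release flags as appropriate), kubectl get secret ngc-secret -o jsonpath='{.data.\.dockerconfigjson}' | base64 -d should print the auths JSON shown above.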

deployments/engine/values.schema.json

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

deployments/engine/values.yaml

Lines changed: 30 additions & 0 deletions
@@ -277,6 +277,36 @@ vllm:
   # Logging level of VLLM.
   loggingLevel: ERROR
 
+# nim is the settings for using NVIDIA NIM (NVIDIA Inference Microservices) as the serving engine.
+nim:
+  # The NIM API key to use for accessing the NIM API.
+  # +docs:type=string
+  ngcApiKey: ""
+  # The NIM models to use.
+  # For example:
+  # models:
+  #   meta/llama-3.1-8b-instruct:
+  #     image: nvcr.io/nim/meta/llama-3.1-8b-instruct:1.3.3
+  #     imagePullPolicy: IfNotPresent
+  #     modelName: meta/llama-3.1-8b-instruct
+  #     modelVersion: 1.3.3
+  #     openaiPort: 8000
+  #     logLevel: DEBUG
+  #     resources:
+  #       requests:
+  #         cpu: 0
+  #         memory: 0
+  #       limits:
+  #         cpu: 0
+  #         memory: 0
+  #         nvidia.com/gpu: 1
+  #     volume:
+  #       storageClassName: "standard"
+  #       size: "50Gi"
+  #       accessMode: "ReadWriteOnce"
+  # +docs:type=property
+  models: {}
+
 autoscaler:
   # If set to true, the request base autoscaler will be enabled.
   # NOTE: In ollama dynamic-model-loading mode, volume sharing is required.
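
A minimal values override that turns this on might look as follows (illustrative; the file name and the trimmed-down model settings are assumptions based on the commented example above):

# my-nim-values.yaml (hypothetical override file)
nim:
  ngcApiKey: "<your NGC API key>"
  models:
    meta/llama-3.1-8b-instruct:
      image: nvcr.io/nim/meta/llama-3.1-8b-instruct:1.3.3
      modelName: meta/llama-3.1-8b-instruct
      openaiPort: 8000
      resources:
        limits:
          nvidia.com/gpu: 1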

deployments/server/templates/configmap.yaml

Lines changed: 5 additions & 0 deletions
@@ -47,3 +47,8 @@ data:
     gracefulShutdownTimeout: {{ .Values.gracefulShutdownTimeout }}
     serverPodLabelKey: app.kubernetes.io/name
     serverPodLabelValue: {{ include "inference-manager-server.name" . }}
+    {{- with .Values.nimModels }}
+    nimModels:
+      {{- toYaml . | nindent 4 }}
+    {{- end }}
+

deployments/server/values.schema.json

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

deployments/server/values.yaml

Lines changed: 6 additions & 0 deletions
@@ -89,6 +89,12 @@ vectorStoreManagerServerAddr: vector-store-manager-server-grpc:8081
 # The address of the vector-store-manager-server to call internal vector-store APIs.
 vectorStoreManagerInternalServerAddr: vector-store-manager-server-internal-grpc:8083
 
+# The array of model names to be served by NIM backend.
+# For example:
+# nimModels:
+# - meta/llama-3.1-8b-instruct
+# +docs:type=property
+nimModels: []
 
 engineHeartbeat:
   # Set to true to enable heartbeats from the server to engines.
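
This list plausibly feeds the server's decision to call ConvertCreateChatCompletionRequestToOpenAI with needStringFormat=true for NIM-served models, since NIM expects string content; that wiring is an inference from the pieces in this diff.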

engine/cmd/run.go

Lines changed: 47 additions & 22 deletions
@@ -173,6 +173,7 @@ func run(ctx context.Context, c *config.Config, ns string, lv int) error {
 		modelManager runtime.ModelManager
 	)
 
+	nimModels := make(map[string]bool)
 	errCh := make(chan error)
 	if c.Ollama.DynamicModelLoading {
 		pullerAddr := fmt.Sprintf("%s:%d", ollamaClient.GetName(""), c.Runtime.PullerPort)
@@ -185,27 +186,43 @@ func run(ctx context.Context, c *config.Config, ns string, lv int) error {
 		modelManager = ollamaManager
 
 	} else {
+		clients := map[string]runtime.Client{
+			config.RuntimeNameOllama: ollamaClient,
+			config.RuntimeNameVLLM: runtime.NewVLLMClient(
+				mgr.GetClient(),
+				ns,
+				owner,
+				&c.Runtime,
+				processedConfig,
+				modelClient,
+				&c.VLLM,
+			),
+			config.RuntimeNameTriton: runtime.NewTritonClient(
+				mgr.GetClient(),
+				ns,
+				owner,
+				&c.Runtime,
+				processedConfig,
+			),
+		}
+
+		nimClients := make(map[string]runtime.Client)
+		for _, model := range c.NIM.Models {
+			nimClients[model.ModelName] = runtime.NewNIMClient(
+				mgr.GetClient(),
+				ns,
+				owner,
+				&c.Runtime,
+				&c.NIM,
+				&model,
+			)
+			nimModels[model.ModelName] = true
+		}
+
 		rtClientFactory := &clientFactory{
-			config: c,
-			clients: map[string]runtime.Client{
-				config.RuntimeNameOllama: ollamaClient,
-				config.RuntimeNameVLLM: runtime.NewVLLMClient(
-					mgr.GetClient(),
-					ns,
-					owner,
-					&c.Runtime,
-					processedConfig,
-					modelClient,
-					&c.VLLM,
-				),
-				config.RuntimeNameTriton: runtime.NewTritonClient(
-					mgr.GetClient(),
-					ns,
-					owner,
-					&c.Runtime,
-					processedConfig,
-				),
-			},
+			config:     c,
+			clients:    clients,
+			nimClients: nimClients,
 		}
 
 		rtManager := runtime.NewManager(
@@ -215,6 +232,7 @@ func run(ctx context.Context, c *config.Config, ns string, lv int) error {
 		modelClient,
 		c.VLLM.DynamicLoRALoading,
 		c.Runtime.PullerPort,
+		nimModels,
 	)
 	if err := rtManager.SetupWithManager(mgr, leaderElection); err != nil {
 		return err
@@ -278,6 +296,7 @@ func run(ctx context.Context, c *config.Config, ns string, lv int) error {
 		logger,
 		collector,
 		c.GracefulShutdownTimeout,
+		nimModels,
 	)
 	if err := p.SetupWithManager(mgr, leaderElection); err != nil {
 		return err
@@ -323,11 +342,17 @@ func (f *grpcClientFactory) Create() (processor.ProcessTasksClient, func(), error) {
 }
 
 type clientFactory struct {
-	config  *config.Config
-	clients map[string]runtime.Client
+	config     *config.Config
+	clients    map[string]runtime.Client
+	nimClients map[string]runtime.Client
 }
 
 func (f *clientFactory) New(modelID string) (runtime.Client, error) {
+	// skip processing model config if the model is served by NIM runtime.
+	if _, ok := f.config.NIM.Models[modelID]; ok {
+		return f.nimClients[modelID], nil
+	}
+
 	mci := config.NewProcessedModelConfig(f.config).ModelConfigItem(modelID)
 	c, ok := f.clients[mci.RuntimeName]
 	if !ok {
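
Since the factory change is the heart of the routing, here is a compilable miniature of that dispatch (types and names simplified; not the engine's actual code):

package sketch

import "fmt"

// Client stands in for the engine's runtime.Client interface (simplified).
type Client interface{ IsReady() bool }

// resolveClient mirrors the dispatch order of clientFactory.New above: a model
// served by NIM gets its dedicated per-model client and skips the
// processed-model-config lookup; everything else is routed by runtime name.
func resolveClient(nimClients, clients map[string]Client, modelID, runtimeName string) (Client, error) {
	if c, ok := nimClients[modelID]; ok {
		return c, nil // NIM-managed model: one client per model
	}
	c, ok := clients[runtimeName] // e.g. ollama, vllm, triton
	if !ok {
		return nil, fmt.Errorf("unsupported runtime: %s", runtimeName)
	}
	return c, nil
}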

0 commit comments