Skip to content

Commit dbdb8c9

Browse files
authored
feat: support configuring init container image (#443)
Signed-off-by: rudeigerc <[email protected]>
1 parent d47e44f commit dbdb8c9

File tree

11 files changed

+137
-55
lines changed

11 files changed

+137
-55
lines changed

config/default/configmap.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ metadata:
55
data:
66
config.data: |
77
scheduler-name: default-scheduler
8-
# init-container-image: inftyai/model-loader:v0.0.10
8+
init-container-image: inftyai/model-loader:v0.0.10

pkg/controller/inference/service_controller.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,11 @@ func (r *ServiceReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ct
116116
return ctrl.Result{}, err
117117
}
118118

119-
workloadApplyConfiguration, err := buildWorkloadApplyConfiguration(service, models)
119+
workloadApplyConfiguration, err := buildWorkloadApplyConfiguration(service, models, configs)
120120
if err != nil {
121121
return ctrl.Result{}, err
122122
}
123+
123124
if err := setControllerReferenceForWorkload(service, workloadApplyConfiguration, r.Scheme); err != nil {
124125
return ctrl.Result{}, err
125126
}
@@ -162,7 +163,7 @@ func (r *ServiceReconciler) SetupWithManager(mgr ctrl.Manager) error {
162163
Complete(r)
163164
}
164165

165-
func buildWorkloadApplyConfiguration(service *inferenceapi.Service, models []*coreapi.OpenModel) (*applyconfigurationv1.LeaderWorkerSetApplyConfiguration, error) {
166+
func buildWorkloadApplyConfiguration(service *inferenceapi.Service, models []*coreapi.OpenModel, configs *helper.GlobalConfigs) (*applyconfigurationv1.LeaderWorkerSetApplyConfiguration, error) {
166167
workload := applyconfigurationv1.LeaderWorkerSet(service.Name, service.Namespace)
167168

168169
leaderWorkerTemplate := applyconfigurationv1.LeaderWorkerTemplate()
@@ -193,7 +194,7 @@ func buildWorkloadApplyConfiguration(service *inferenceapi.Service, models []*co
193194
leaderWorkerTemplate.WithWorkerTemplate(&podTemplateSpecApplyConfiguration)
194195

195196
// The core logic to inject additional configurations.
196-
injectModelProperties(leaderWorkerTemplate, models, service)
197+
injectModelProperties(leaderWorkerTemplate, models, service, configs)
197198

198199
spec := applyconfigurationv1.LeaderWorkerSetSpec()
199200
spec.WithLeaderWorkerTemplate(leaderWorkerTemplate)
@@ -215,17 +216,17 @@ func buildWorkloadApplyConfiguration(service *inferenceapi.Service, models []*co
215216
return workload, nil
216217
}
217218

218-
func injectModelProperties(template *applyconfigurationv1.LeaderWorkerTemplateApplyConfiguration, models []*coreapi.OpenModel, service *inferenceapi.Service) {
219+
func injectModelProperties(template *applyconfigurationv1.LeaderWorkerTemplateApplyConfiguration, models []*coreapi.OpenModel, service *inferenceapi.Service, configs *helper.GlobalConfigs) {
219220
isMultiNodesInference := template.LeaderTemplate != nil
220221

221222
for i, model := range models {
222223
source := modelSource.NewModelSourceProvider(model)
223224
// Skip model-loader initContainer if llmaz.io/skip-model-loader annotation is set.
224225
if !helper.SkipModelLoader(service) {
225226
if isMultiNodesInference {
226-
source.InjectModelLoader(template.LeaderTemplate, i)
227+
source.InjectModelLoader(template.LeaderTemplate, i, configs.InitContainerImage)
227228
}
228-
source.InjectModelLoader(template.WorkerTemplate, i)
229+
source.InjectModelLoader(template.WorkerTemplate, i, configs.InitContainerImage)
229230
} else {
230231
if isMultiNodesInference {
231232
source.InjectModelEnvVars(template.LeaderTemplate)

pkg/controller_helper/configmap.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,16 @@ func ParseGlobalConfigmap(cm *corev1.ConfigMap) (*GlobalConfigs, error) {
4040
return nil, fmt.Errorf("failed to unmarshal config.data: %v", err)
4141
}
4242

43+
if err := configs.validate(); err != nil {
44+
return nil, fmt.Errorf("invalid global config: %v", err)
45+
}
46+
4347
return &configs, nil
4448
}
49+
50+
func (c *GlobalConfigs) validate() error {
51+
if c.InitContainerImage == "" {
52+
return fmt.Errorf("init-container-image is required")
53+
}
54+
return nil
55+
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
Copyright 2025 The InftyAI Team.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package helper
18+
19+
import (
20+
"testing"
21+
22+
"github.com/stretchr/testify/assert"
23+
"github.com/stretchr/testify/require"
24+
)
25+
26+
func TestGlobalConfigs_validate(t *testing.T) {
27+
tests := []struct {
28+
name string
29+
config *GlobalConfigs
30+
expectError bool
31+
errorMsg string
32+
}{
33+
{
34+
name: "valid config",
35+
config: &GlobalConfigs{
36+
SchedulerName: "custom-scheduler",
37+
InitContainerImage: "inftyai/model-loader:v0.0.10",
38+
},
39+
expectError: false,
40+
},
41+
{
42+
name: "empty init container image",
43+
config: &GlobalConfigs{
44+
SchedulerName: "custom-scheduler",
45+
InitContainerImage: "",
46+
},
47+
expectError: true,
48+
errorMsg: "init-container-image is required",
49+
},
50+
}
51+
52+
for _, tt := range tests {
53+
t.Run(tt.name, func(t *testing.T) {
54+
err := tt.config.validate()
55+
56+
if tt.expectError {
57+
require.Error(t, err)
58+
assert.Contains(t, err.Error(), tt.errorMsg)
59+
} else {
60+
require.NoError(t, err)
61+
}
62+
})
63+
}
64+
}

pkg/controller_helper/modelsource/modelhub.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ import (
2121
"strings"
2222

2323
coreapplyv1 "k8s.io/client-go/applyconfigurations/core/v1"
24-
25-
"github.com/inftyai/llmaz/pkg"
2624
)
2725

2826
var _ ModelSourceProvider = &ModelHubProvider{}
@@ -62,7 +60,7 @@ func (p *ModelHubProvider) ModelPath(skipModelLoader bool) string {
6260
return CONTAINER_MODEL_PATH + "models--" + strings.ReplaceAll(p.modelID, "/", "--")
6361
}
6462

65-
func (p *ModelHubProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSpecApplyConfiguration, index int) {
63+
func (p *ModelHubProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSpecApplyConfiguration, index int, initContainerImage string) {
6664
initContainerName := MODEL_LOADER_CONTAINER_NAME
6765
if index != 0 {
6866
initContainerName += "-" + strconv.Itoa(index)
@@ -71,7 +69,7 @@ func (p *ModelHubProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSp
7169
// Handle initContainer.
7270
initContainer := coreapplyv1.Container().
7371
WithName(initContainerName).
74-
WithImage(pkg.LOADER_IMAGE).
72+
WithImage(initContainerImage).
7573
WithVolumeMounts(coreapplyv1.VolumeMount().WithName(MODEL_VOLUME_NAME).WithMountPath(CONTAINER_MODEL_PATH))
7674

7775
// We have exactly one container in the template.Spec.Containers.

pkg/controller_helper/modelsource/modelhub_test.go

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@ import (
2727
"github.com/stretchr/testify/assert"
2828
corev1 "k8s.io/api/core/v1"
2929
coreapplyv1 "k8s.io/client-go/applyconfigurations/core/v1"
30-
31-
"github.com/inftyai/llmaz/pkg"
3230
)
3331

3432
func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
@@ -38,10 +36,11 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
3836
ignorePatterns := []string{"*.tmp"}
3937

4038
tests := []struct {
41-
name string
42-
provider *ModelHubProvider
43-
index int
44-
expectMainModel bool
39+
name string
40+
provider *ModelHubProvider
41+
index int
42+
expectMainModel bool
43+
initContainerImage string
4544
}{
4645
{
4746
name: "inject full modelhub with fileName, revision, allow/ignore",
@@ -54,8 +53,9 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
5453
modelAllowPatterns: allowPatterns,
5554
modelIgnorePatterns: ignorePatterns,
5655
},
57-
index: 0,
58-
expectMainModel: true,
56+
index: 0,
57+
expectMainModel: true,
58+
initContainerImage: "model-loader:latest",
5959
},
6060
{
6161
name: "inject with index > 0 skips volume/container mount",
@@ -64,8 +64,20 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
6464
modelID: "some/model",
6565
modelHub: "Huggingface",
6666
},
67-
index: 1,
68-
expectMainModel: false,
67+
index: 1,
68+
expectMainModel: false,
69+
initContainerImage: "model-loader:latest",
70+
},
71+
{
72+
name: "inject with custom initContainerImage",
73+
provider: &ModelHubProvider{
74+
modelName: "llama3",
75+
modelID: "meta/llama-3",
76+
modelHub: "Huggingface",
77+
},
78+
index: 0,
79+
expectMainModel: true,
80+
initContainerImage: "custom-model-loader:latest",
6981
},
7082
}
7183

@@ -83,7 +95,7 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
8395
),
8496
)
8597

86-
tt.provider.InjectModelLoader(template, tt.index)
98+
tt.provider.InjectModelLoader(template, tt.index, tt.initContainerImage)
8799

88100
assert.Len(t, template.Spec.InitContainers, 1)
89101
initContainer := template.Spec.InitContainers[0]
@@ -92,8 +104,9 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
92104
if tt.index != 0 {
93105
expectedName += "-" + strconv.Itoa(tt.index)
94106
}
107+
expectedImage := tt.initContainerImage
95108
assert.Equal(t, expectedName, *initContainer.Name)
96-
assert.Equal(t, pkg.LOADER_IMAGE, *initContainer.Image)
109+
assert.Equal(t, expectedImage, *initContainer.Image)
97110

98111
wantEnv := buildExpectedEnv(tt.provider)
99112
if diff := cmp.Diff(wantEnv, initContainer.Env, envSortOpt); diff != "" {

pkg/controller_helper/modelsource/modelsource.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ type ModelSourceProvider interface {
5959
ModelPath(skipModelLoader bool) string
6060
// InjectModelLoader will inject the model loader to the spec,
6161
// index refers to the suffix of the initContainer name, like model-loader, model-loader-1.
62-
InjectModelLoader(spec *coreapplyv1.PodTemplateSpecApplyConfiguration, index int)
62+
InjectModelLoader(spec *coreapplyv1.PodTemplateSpecApplyConfiguration, index int, initContainerImage string)
6363
// InjectModelEnvVars will inject the model credentials env to the model-runner container.
6464
// This is used when the model-loader initContainer is not injected, and the model loading is handled by the model-runner container.
6565
InjectModelEnvVars(spec *coreapplyv1.PodTemplateSpecApplyConfiguration)

pkg/controller_helper/modelsource/modelsource_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ func TestEnvInjectModelLoader(t *testing.T) {
130130

131131
for _, tt := range tests {
132132
t.Run(tt.name, func(t *testing.T) {
133-
tt.provider.InjectModelLoader(tt.template, 0)
133+
tt.provider.InjectModelLoader(tt.template, 0, "model-loader:latest")
134134
initContainer := tt.template.Spec.InitContainers[0]
135135
assert.Subset(t, initContainer.Env, tt.template.Spec.Containers[0].Env)
136136
})

pkg/controller_helper/modelsource/uri.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ import (
2121
"strings"
2222

2323
coreapplyv1 "k8s.io/client-go/applyconfigurations/core/v1"
24-
25-
"github.com/inftyai/llmaz/pkg"
2624
)
2725

2826
var _ ModelSourceProvider = &URIProvider{}
@@ -80,7 +78,7 @@ func (p *URIProvider) ModelPath(skipModelLoader bool) string {
8078
return CONTAINER_MODEL_PATH + "models--" + splits[len(splits)-1]
8179
}
8280

83-
func (p *URIProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSpecApplyConfiguration, index int) {
81+
func (p *URIProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSpecApplyConfiguration, index int, initContainerImage string) {
8482
// We don't have additional operations for Ollama, just load in runtime.
8583
if p.protocol == Ollama {
8684
return
@@ -112,10 +110,11 @@ func (p *URIProvider) InjectModelLoader(template *coreapplyv1.PodTemplateSpecApp
112110
if index != 0 {
113111
initContainerName += "-" + strconv.Itoa(index)
114112
}
113+
115114
// Handle initContainer.
116115
initContainer := coreapplyv1.Container().
117116
WithName(initContainerName).
118-
WithImage(pkg.LOADER_IMAGE).
117+
WithImage(initContainerImage).
119118
WithVolumeMounts(
120119
coreapplyv1.VolumeMount().
121120
WithName(MODEL_VOLUME_NAME).

pkg/defaults.go

Lines changed: 0 additions & 21 deletions
This file was deleted.

0 commit comments

Comments
 (0)