Skip to content

Commit 215c35e

Browse files
committed
Update LoRA scaling to be driven by the replica count set in the ModelAdapter spec
Signed-off-by: dittops <[email protected]>
1 parent 4adb2fb commit 215c35e

File tree

3 files changed

+190
-40
lines changed

3 files changed

+190
-40
lines changed

pkg/controller/modeladapter/modeladapter_controller.go

Lines changed: 167 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -349,41 +349,9 @@ func (r *ModelAdapterReconciler) DoReconcile(ctx context.Context, req ctrl.Reque
349349

350350
oldInstance := instance.DeepCopy()
351351

352-
// Step 1: Sync Pod instances for ModelAdapter
353-
activePods, err := r.getActivePodsForModelAdapter(ctx, instance)
354-
if err != nil {
355-
return ctrl.Result{}, err
356-
}
357-
358-
activeMap := make(map[string]corev1.Pod, len(activePods))
359-
for _, p := range activePods {
360-
activeMap[p.Name] = p
361-
}
362-
363-
var updatedInstances []string
364-
for _, name := range instance.Status.Instances {
365-
if _, ok := activeMap[name]; ok {
366-
updatedInstances = append(updatedInstances, name)
367-
}
368-
}
369-
instance.Status.Instances = updatedInstances
370-
371-
added := false
372-
for name := range activeMap {
373-
if !StringInSlice(instance.Status.Instances, name) {
374-
instance.Status.Instances = append(instance.Status.Instances, name)
375-
added = true
376-
}
377-
}
378-
379-
if added {
380-
instance.Status.Phase = modelv1alpha1.ModelAdapterScheduled
381-
condition := NewCondition(string(modelv1alpha1.ModelAdapterConditionTypeScheduled), metav1.ConditionTrue,
382-
"Scheduled", fmt.Sprintf("ModelAdapter %s has been allocated to pods %v", klog.KObj(instance), instance.Status.Instances))
383-
if err := r.updateStatus(ctx, instance, condition); err != nil {
384-
return ctrl.Result{}, err
385-
}
386-
return ctrl.Result{Requeue: true}, nil
352+
// Step 1: Reconcile Pod instances for ModelAdapter based on desired replicas
353+
if ctrlResult, err := r.reconcileReplicas(ctx, instance); err != nil || ctrlResult.Requeue || ctrlResult.RequeueAfter > 0 {
354+
return ctrlResult, err
387355
}
388356

389357
// Step 2: Reconcile Loading
@@ -521,11 +489,122 @@ func (r *ModelAdapterReconciler) getActivePodsForModelAdapter(ctx context.Contex
521489
return activePods, nil
522490
}
523491

524-
// schedulePod picks a valid pod to schedule the model adapter
525-
func (r *ModelAdapterReconciler) schedulePod(ctx context.Context, instance *modelv1alpha1.ModelAdapter, activePods []corev1.Pod) (*corev1.Pod, error) {
526-
// Implement your scheduling logic here to select a Pod based on the instance.Spec.PodSelector
527-
// For the sake of example, we will just list the Pods matching the selector and pick the first one
528-
return r.scheduler.SelectPod(ctx, instance.Name, activePods)
492+
// reconcileReplicas ensures the desired number of replicas are scheduled
493+
func (r *ModelAdapterReconciler) reconcileReplicas(ctx context.Context, instance *modelv1alpha1.ModelAdapter) (ctrl.Result, error) {
494+
// Get all active pods matching the selector
495+
activePods, err := r.getActivePodsForModelAdapter(ctx, instance)
496+
if err != nil {
497+
return ctrl.Result{}, err
498+
}
499+
500+
// Create a map of active pods for quick lookup
501+
activeMap := make(map[string]corev1.Pod, len(activePods))
502+
for _, p := range activePods {
503+
activeMap[p.Name] = p
504+
}
505+
506+
// Remove instances that are no longer active
507+
var validInstances []string
508+
for _, name := range instance.Status.Instances {
509+
if _, ok := activeMap[name]; ok {
510+
validInstances = append(validInstances, name)
511+
}
512+
}
513+
instance.Status.Instances = validInstances
514+
515+
// Get desired replicas (default to 1 if not specified)
516+
desiredReplicas := int32(1)
517+
if instance.Spec.Replicas != nil {
518+
desiredReplicas = *instance.Spec.Replicas
519+
}
520+
521+
currentReplicas := int32(len(instance.Status.Instances))
522+
523+
// Scale up if needed
524+
if currentReplicas < desiredReplicas {
525+
// Get pods that are not yet scheduled
526+
unscheduledPods := []corev1.Pod{}
527+
for _, pod := range activePods {
528+
if !StringInSlice(instance.Status.Instances, pod.Name) {
529+
unscheduledPods = append(unscheduledPods, pod)
530+
}
531+
}
532+
533+
// Schedule additional pods
534+
neededReplicas := int(desiredReplicas - currentReplicas)
535+
if len(unscheduledPods) >= neededReplicas {
536+
newPods, err := r.schedulePods(ctx, instance, unscheduledPods, neededReplicas)
537+
if err != nil {
538+
return ctrl.Result{}, err
539+
}
540+
541+
for _, pod := range newPods {
542+
instance.Status.Instances = append(instance.Status.Instances, pod.Name)
543+
}
544+
545+
instance.Status.Phase = modelv1alpha1.ModelAdapterScheduled
546+
condition := NewCondition(string(modelv1alpha1.ModelAdapterConditionTypeScheduled), metav1.ConditionTrue,
547+
"Scheduled", fmt.Sprintf("ModelAdapter %s has been allocated to %d pods: %v", klog.KObj(instance), len(instance.Status.Instances), instance.Status.Instances))
548+
if err := r.updateStatus(ctx, instance, condition); err != nil {
549+
return ctrl.Result{}, err
550+
}
551+
return ctrl.Result{Requeue: true}, nil
552+
} else if len(unscheduledPods) > 0 {
553+
// Not enough pods available, schedule what we can
554+
klog.Warningf("Only %d pods available for model adapter %s, need %d more", len(unscheduledPods), klog.KObj(instance), neededReplicas)
555+
}
556+
} else if currentReplicas > desiredReplicas {
557+
// Scale down - remove excess instances
558+
excessCount := int(currentReplicas - desiredReplicas)
559+
removedInstances := instance.Status.Instances[len(instance.Status.Instances)-excessCount:]
560+
instance.Status.Instances = instance.Status.Instances[:len(instance.Status.Instances)-excessCount]
561+
562+
// Unload adapters from removed instances
563+
for _, podName := range removedInstances {
564+
if err := r.unloadModelAdapterFromPod(ctx, instance, podName); err != nil {
565+
klog.Warningf("Failed to unload adapter from pod %s: %v", podName, err)
566+
}
567+
}
568+
569+
instance.Status.Phase = modelv1alpha1.ModelAdapterScaled
570+
condition := NewCondition(string(modelv1alpha1.ModelAdapterConditionTypeScheduled), metav1.ConditionTrue,
571+
"Scaled", fmt.Sprintf("ModelAdapter %s scaled to %d replicas", klog.KObj(instance), desiredReplicas))
572+
if err := r.updateStatus(ctx, instance, condition); err != nil {
573+
return ctrl.Result{}, err
574+
}
575+
return ctrl.Result{Requeue: true}, nil
576+
}
577+
578+
return ctrl.Result{}, nil
579+
}
580+
581+
// schedulePods selects multiple pods to schedule the model adapter based on the configured scheduler policy
582+
func (r *ModelAdapterReconciler) schedulePods(ctx context.Context, instance *modelv1alpha1.ModelAdapter, availablePods []corev1.Pod, count int) ([]corev1.Pod, error) {
583+
if count <= 0 || len(availablePods) == 0 {
584+
return nil, nil
585+
}
586+
587+
selectedPods := []corev1.Pod{}
588+
remainingPods := append([]corev1.Pod{}, availablePods...)
589+
590+
for i := 0; i < count && len(remainingPods) > 0; i++ {
591+
pod, err := r.scheduler.SelectPod(ctx, instance.Name, remainingPods)
592+
if err != nil {
593+
return nil, err
594+
}
595+
596+
selectedPods = append(selectedPods, *pod)
597+
598+
// Remove selected pod from remaining pods to avoid selecting it again
599+
for j, p := range remainingPods {
600+
if p.Name == pod.Name {
601+
remainingPods = append(remainingPods[:j], remainingPods[j+1:]...)
602+
break
603+
}
604+
}
605+
}
606+
607+
return selectedPods, nil
529608
}
530609

531610
func (r *ModelAdapterReconciler) reconcileLoading(ctx context.Context, instance *modelv1alpha1.ModelAdapter) error {
@@ -726,6 +805,54 @@ func (r *ModelAdapterReconciler) unloadModelAdapter(ctx context.Context, instanc
726805
return nil
727806
}
728807

808+
// unloadModelAdapterFromPod unloads the adapter from a specific pod
809+
func (r *ModelAdapterReconciler) unloadModelAdapterFromPod(ctx context.Context, instance *modelv1alpha1.ModelAdapter, podName string) error {
810+
targetPod := &corev1.Pod{}
811+
if err := r.Get(ctx, types.NamespacedName{Namespace: instance.Namespace, Name: podName}, targetPod); err != nil {
812+
if apierrors.IsNotFound(err) {
813+
klog.Warningf("Failed to find lora Pod instance %s/%s from apiserver, skip unloading", instance.GetNamespace(), podName)
814+
return nil
815+
}
816+
return err
817+
}
818+
819+
payload := map[string]string{
820+
"lora_name": instance.Name,
821+
}
822+
payloadBytes, err := json.Marshal(payload)
823+
if err != nil {
824+
return err
825+
}
826+
827+
urls := BuildURLs(targetPod.Status.PodIP, r.RuntimeConfig)
828+
req, err := http.NewRequest("POST", urls.UnloadAdapterURL, bytes.NewBuffer(payloadBytes))
829+
if err != nil {
830+
return err
831+
}
832+
req.Header.Set("Content-Type", "application/json")
833+
if token, ok := instance.Spec.AdditionalConfig["api-key"]; ok {
834+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
835+
}
836+
837+
httpClient := &http.Client{}
838+
resp, err := httpClient.Do(req)
839+
if err != nil {
840+
return nil // Don't fail on HTTP errors during unload
841+
}
842+
defer func() {
843+
if err := resp.Body.Close(); err != nil {
844+
klog.InfoS("Error closing response body:", err)
845+
}
846+
}()
847+
848+
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
849+
body, _ := io.ReadAll(resp.Body)
850+
klog.Warningf("failed to unload LoRA adapter from pod %s: %s", podName, body)
851+
}
852+
853+
return nil
854+
}
855+
729856
func (r *ModelAdapterReconciler) reconcileService(ctx context.Context, instance *modelv1alpha1.ModelAdapter) (ctrl.Result, error) {
730857
// Retrieve the Service from the Kubernetes cluster with the name and namespace.
731858
found := &corev1.Service{}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
apiVersion: model.aibrix.ai/v1alpha1
2+
kind: ModelAdapter
3+
metadata:
4+
name: sample-lora-multi-replica
5+
namespace: default
6+
spec:
7+
# Specify the number of replicas for the adapter
8+
# The adapter will be loaded on this many pods matching the selector
9+
replicas: 3
10+
# Pod selector to identify which pods can host this adapter
11+
podSelector:
12+
matchLabels:
13+
model.aibrix.ai/name: qwen-coder-1-5b-instruct
14+
adapter.model.aibrix.ai/enabled: "true"
15+
# URL for the LoRA adapter artifact
16+
artifactURL: "huggingface://SomethingNew/lora-adapter-demo"
17+
# Optional: Additional configuration
18+
additionalConfig:
19+
rank: "16"
20+
alpha: "32"

samples/adapter/adapter.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ metadata:
77
model.aibrix.ai/name: "qwen-code-lora"
88
model.aibrix.ai/port: "8000"
99
spec:
10+
# Optional: Number of replicas for the adapter (default: 1)
11+
# Uncomment to load adapter on multiple pods
12+
# replicas: 3
1013
baseModel: qwen-coder-1-5b-instruct
1114
podSelector:
1215
matchLabels:

0 commit comments

Comments
 (0)