Skip to content

Commit 941796f

Browse files
authored
Added support to create multi model deployment. (#1072)
2 parents 02a3ce5 + 2c803e6 commit 941796f

File tree

11 files changed

+727
-42
lines changed

11 files changed

+727
-42
lines changed

ads/aqua/app.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
import os
77
import traceback
88
from dataclasses import fields
9+
from datetime import datetime, timedelta
910
from typing import Dict, Optional, Union
1011

1112
import oci
13+
from cachetools import TTLCache, cached
1214
from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails
1315

1416
from ads import set_auth
@@ -268,6 +270,7 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
268270
logger.info(f"Artifact not found in model {model_id}.")
269271
return False
270272

273+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=1), timer=datetime.now))
271274
def get_config(
272275
self,
273276
model_id: str,

ads/aqua/common/entities.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from typing import Optional
66

7+
from pydantic import Field
8+
79
from ads.aqua.config.utils.serializer import Serializable
810

911

@@ -42,16 +44,22 @@ class AquaMultiModelRef(Serializable):
4244
----------
4345
model_id : str
4446
The unique identifier of the model.
47+
model_name : Optional[str]
48+
The name of the model.
4549
gpu_count : Optional[int]
4650
Number of GPUs required for deployment.
4751
env_var : Optional[Dict[str, Any]]
4852
Optional environment variables to override during deployment.
4953
"""
5054

51-
model_id: str
52-
model_name: Optional[str] = None
53-
gpu_count: Optional[int] = None
54-
env_var: Optional[dict] = None
55+
model_id: str = Field(..., description="The model OCID to deploy.")
56+
model_name: Optional[str] = Field(None, description="The name of model.")
57+
gpu_count: Optional[int] = Field(
58+
None, description="The gpu count allocation for the model."
59+
)
60+
env_var: Optional[dict] = Field(
61+
default_factory=dict, description="The environment variables of the model."
62+
)
5563

5664
class Config:
5765
extra = "ignore"

ads/aqua/common/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class Tags(ExtendedEnum):
3131
AQUA_TAG = "OCI_AQUA"
3232
AQUA_SERVICE_MODEL_TAG = "aqua_service_model"
3333
AQUA_FINE_TUNED_MODEL_TAG = "aqua_fine_tuned_model"
34+
AQUA_MODEL_ID_TAG = "aqua_model_id"
3435
AQUA_MODEL_NAME_TAG = "aqua_model_name"
3536
AQUA_EVALUATION = "aqua_evaluation"
3637
AQUA_FINE_TUNING = "aqua_finetuning"

ads/aqua/constants.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
"""This module defines constants used in ads.aqua module."""
55

@@ -30,14 +30,17 @@
3030
READY_TO_FINE_TUNE_STATUS = "TRUE"
3131
PRIVATE_ENDPOINT_TYPE = "MODEL_DEPLOYMENT"
3232
AQUA_GA_LIST = ["id19sfcrra6z"]
33+
AQUA_MULTI_MODEL_CONFIG = "MULTI_MODEL_CONFIG"
3334
AQUA_MODEL_TYPE_SERVICE = "service"
3435
AQUA_MODEL_TYPE_CUSTOM = "custom"
36+
AQUA_MODEL_TYPE_MULTI = "multi_model"
3537
AQUA_MODEL_ARTIFACT_CONFIG = "config.json"
3638
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME = "_name_or_path"
3739
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE = "model_type"
3840
AQUA_MODEL_ARTIFACT_FILE = "model_file"
3941
HF_METADATA_FOLDER = ".cache/"
4042
HF_LOGIN_DEFAULT_TIMEOUT = 2
43+
MODEL_NAME_DELIMITER = ";"
4144

4245
TRAINING_METRICS_FINAL = "training_metrics_final"
4346
VALIDATION_METRICS_FINAL = "validation_metrics_final"

ads/aqua/model/model.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
FineTuningCustomMetadata,
6666
FineTuningMetricCategories,
6767
ModelCustomMetadataFields,
68+
ModelTask,
6869
ModelType,
6970
)
7071
from ads.aqua.model.entities import (
@@ -254,12 +255,31 @@ def create_multi(
254255
artifact_list = []
255256
display_name_list = []
256257
model_custom_metadata = ModelCustomMetadata()
257-
default_deployment_container = None
258+
# TODO: update it when more deployment containers are supported
259+
default_deployment_container = (
260+
InferenceContainerTypeFamily.AQUA_VLLM_CONTAINER_FAMILY
261+
)
258262

259263
# Process each model
260264
for idx, model in enumerate(models):
261265
source_model = DataScienceModel.from_id(model.model_id)
262266
display_name = source_model.display_name
267+
268+
if not source_model.freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, UNKNOWN):
269+
raise AquaValueError(
270+
f"Invalid selected model {display_name}. "
271+
"Currently only service models are supported for multi model deployment."
272+
)
273+
274+
if (
275+
source_model.freeform_tags.get(Tags.TASK, UNKNOWN)
276+
!= ModelTask.TEXT_GENERATION
277+
):
278+
raise AquaValueError(
279+
f"Invalid or missing {Tags.TASK} tag for selected model {display_name}. "
280+
f"Currently only {ModelTask.TEXT_GENERATION} models are support for multi model deployment."
281+
)
282+
263283
display_name_list.append(display_name)
264284

265285
# Retrieve model artifact
@@ -280,12 +300,10 @@ def create_multi(
280300
),
281301
).value
282302

283-
if idx == 0:
284-
default_deployment_container = deployment_container
285-
elif deployment_container != default_deployment_container:
303+
if default_deployment_container != deployment_container:
286304
raise AquaValueError(
287-
"Deployment container mismatch detected. "
288-
"All selected models must use the same deployment container."
305+
f"Unsupported deployment container '{deployment_container}' for model '{source_model.id}'. "
306+
f"Only '{InferenceContainerTypeFamily.AQUA_VLLM_CONTAINER_FAMILY}' is supported for multi-model deployments."
289307
)
290308

291309
# Add model-specific metadata

0 commit comments

Comments
 (0)