Skip to content

Commit 44e025c

Browse files
authored
[AQUA] Added support for deploy stack model. (#1223)
2 parents 43d0a37 + 190e2d8 commit 44e025c

File tree

7 files changed

+367
-64
lines changed

7 files changed

+367
-64
lines changed

ads/aqua/model/model.py

Lines changed: 118 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -141,19 +141,19 @@ class AquaModelApp(AquaApp):
141141
@telemetry(entry_point="plugin=model&action=create", name="aqua")
142142
def create(
143143
self,
144-
model_id: Union[str, AquaMultiModelRef],
144+
model: Union[str, AquaMultiModelRef],
145145
project_id: Optional[str] = None,
146146
compartment_id: Optional[str] = None,
147147
freeform_tags: Optional[Dict] = None,
148148
defined_tags: Optional[Dict] = None,
149149
**kwargs,
150-
) -> DataScienceModel:
150+
) -> Union[DataScienceModel, DataScienceModelGroup]:
151151
"""
152-
Creates a custom Aqua model from a service model.
152+
Creates a custom Aqua model or model group from a service model.
153153

154154
Parameters
155155
----------
156-
model_id : Union[str, AquaMultiModelRef]
156+
model : Union[str, AquaMultiModelRef]
157157
The model ID as a string or an AquaMultiModelRef instance to be deployed.
158158
project_id : Optional[str]
159159
The project ID for the custom model.
@@ -167,28 +167,18 @@ def create(
167167

168168
Returns
169169
-------
170-
DataScienceModel
171-
The instance of DataScienceModel.
170+
Union[DataScienceModel, DataScienceModelGroup]
171+
The instance of DataScienceModel or DataScienceModelGroup.
172172
"""
173-
model_id = (
174-
model_id.model_id if isinstance(model_id, AquaMultiModelRef) else model_id
175-
)
176-
service_model = DataScienceModel.from_id(model_id)
173+
fine_tune_weights = []
174+
if isinstance(model, AquaMultiModelRef):
175+
fine_tune_weights = model.fine_tune_weights
176+
model = model.model_id
177+
178+
service_model = DataScienceModel.from_id(model)
177179
target_project = project_id or PROJECT_OCID
178180
target_compartment = compartment_id or COMPARTMENT_OCID
179181

180-
# Skip model copying if it is registered model or fine-tuned model
181-
if (
182-
service_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None) is not None
183-
or service_model.freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG)
184-
is not None
185-
):
186-
logger.info(
187-
f"Aqua Model {model_id} already exists in the user's compartment."
188-
"Skipped copying."
189-
)
190-
return service_model
191-
192182
# combine tags
193183
combined_freeform_tags = {
194184
**(service_model.freeform_tags or {}),
@@ -199,29 +189,112 @@ def create(
199189
**(defined_tags or {}),
200190
}
201191

192+
custom_model = None
193+
if fine_tune_weights:
194+
custom_model = self._create_model_group(
195+
model_id=model,
196+
compartment_id=target_compartment,
197+
project_id=target_project,
198+
freeform_tags=combined_freeform_tags,
199+
defined_tags=combined_defined_tags,
200+
fine_tune_weights=fine_tune_weights,
201+
service_model=service_model,
202+
)
203+
204+
logger.info(
205+
f"Aqua Model Group {custom_model.id} created with the service model {model}."
206+
)
207+
else:
208+
# Skip model copying if it is registered model or fine-tuned model
209+
if (
210+
Tags.BASE_MODEL_CUSTOM in service_model.freeform_tags
211+
or Tags.AQUA_FINE_TUNED_MODEL_TAG in service_model.freeform_tags
212+
):
213+
logger.info(
214+
f"Aqua Model {model} already exists in the user's compartment."
215+
"Skipped copying."
216+
)
217+
return service_model
218+
219+
custom_model = self._create_model(
220+
compartment_id=target_compartment,
221+
project_id=target_project,
222+
freeform_tags=combined_freeform_tags,
223+
defined_tags=combined_defined_tags,
224+
service_model=service_model,
225+
**kwargs,
226+
)
227+
logger.info(
228+
f"Aqua Model {custom_model.id} created with the service model {model}."
229+
)
230+
231+
# Track unique models that were created in the user's compartment
232+
self.telemetry.record_event_async(
233+
category="aqua/service/model",
234+
action="create",
235+
detail=service_model.display_name,
236+
)
237+
238+
return custom_model
239+
240+
def _create_model(
    self,
    compartment_id: str,
    project_id: str,
    freeform_tags: Dict,
    defined_tags: Dict,
    service_model: DataScienceModel,
    **kwargs,
):
    """
    Creates a data science model by reference from a service model.

    The new model copies the display name, description, model file
    description and the custom, defined and provenance metadata of
    `service_model`; compartment, project and tags come from the arguments.

    Parameters
    ----------
    compartment_id : str
        The compartment ID set on the new model.
    project_id : str
        The project ID set on the new model.
    freeform_tags : Dict
        Freeform tags, unpacked as keyword arguments into `with_freeform_tags`.
    defined_tags : Dict
        Defined tags, unpacked as keyword arguments into `with_defined_tags`.
    service_model : DataScienceModel
        The source model whose attributes are mirrored.
    **kwargs
        Forwarded unchanged to `DataScienceModel.create`.

    Returns
    -------
    The object returned by `DataScienceModel.create(model_by_reference=True, ...)`.
    """
    # Assemble the full model specification first; nothing is created yet.
    model_builder = (
        DataScienceModel()
        .with_compartment_id(compartment_id)
        .with_project_id(project_id)
        .with_model_file_description(json_dict=service_model.model_file_description)
        .with_display_name(service_model.display_name)
        .with_description(service_model.description)
        .with_freeform_tags(**freeform_tags)
        .with_defined_tags(**defined_tags)
        .with_custom_metadata_list(service_model.custom_metadata_list)
        .with_defined_metadata_list(service_model.defined_metadata_list)
        .with_provenance_metadata(service_model.provenance_metadata)
    )

    # Create the model by reference and hand the result back to the caller.
    return model_builder.create(model_by_reference=True, **kwargs)
266+
267+
def _create_model_group(
    self,
    model_id: str,
    compartment_id: str,
    project_id: str,
    freeform_tags: Dict,
    defined_tags: Dict,
    fine_tune_weights: List,
    service_model: DataScienceModel,
):
    """
    Creates a data science model group for a base model and its fine-tuned weights.

    The group takes its display name, description and custom metadata from
    `service_model`, uses `model_id` as the base model ID, and registers one
    member per entry of `fine_tune_weights`.

    Parameters
    ----------
    model_id : str
        The ID passed to `with_base_model_id`.
    compartment_id : str
        The compartment ID set on the model group.
    project_id : str
        The project ID set on the model group.
    freeform_tags : Dict
        Freeform tags, unpacked as keyword arguments into `with_freeform_tags`.
    defined_tags : Dict
        Defined tags, unpacked as keyword arguments into `with_defined_tags`.
    fine_tune_weights : List
        Items exposing `model_name` and `model_id` attributes; each becomes
        one member model entry.
    service_model : DataScienceModel
        The source model whose display name, description and custom
        metadata are reused.

    Returns
    -------
    The object returned by `DataScienceModelGroup.create()`.
    """
    group_builder = (
        DataScienceModelGroup()
        .with_compartment_id(compartment_id)
        .with_project_id(project_id)
        .with_display_name(service_model.display_name)
        .with_description(service_model.description)
        .with_freeform_tags(**freeform_tags)
        .with_defined_tags(**defined_tags)
        .with_custom_metadata_list(service_model.custom_metadata_list)
        .with_base_model_id(model_id)
    )

    # Each fine-tuned weight is addressed by its model name as inference key.
    members = []
    for weight in fine_tune_weights:
        members.append(
            {
                "inference_key": weight.model_name,
                "model_id": weight.model_id,
            }
        )

    return group_builder.with_member_models(members).create()
@@ -271,6 +344,16 @@ def create_multi(
271344
DataScienceModelGroup
272345
Instance of DataScienceModelGroup object.
273346
"""
347+
member_model_ids = [{"model_id": model.model_id} for model in models]
348+
for model in models:
349+
if model.fine_tune_weights:
350+
member_model_ids.extend(
351+
[
352+
{"model_id": fine_tune_model.model_id}
353+
for fine_tune_model in model.fine_tune_weights
354+
]
355+
)
356+
274357
custom_model_group = (
275358
DataScienceModelGroup()
276359
.with_compartment_id(compartment_id)
@@ -281,7 +364,7 @@ def create_multi(
281364
.with_defined_tags(**(defined_tags or {}))
282365
.with_custom_metadata_list(model_custom_metadata)
283366
# TODO: add member model inference key
284-
.with_member_models([{"model_id": model.model_id for model in models}])
367+
.with_member_models(member_model_ids)
285368
)
286369
custom_model_group.create()
287370

ads/aqua/modeldeployment/constants.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,12 @@
99
This module contains constants used in Aqua Model Deployment.
1010
"""
1111

12+
from ads.common.extended_enum import ExtendedEnum
13+
1214
DEFAULT_WAIT_TIME = 12000
1315
DEFAULT_POLL_INTERVAL = 10
16+
17+
18+
class DeploymentType(ExtendedEnum):
    """Names of the model deployment types used by Aqua Model Deployment."""

    # Compared against `CreateModelDeploymentDetails.deployment_type` when a
    # deployment is created.
    STACKED = "STACKED"
    # NOTE(review): presumably the multi-model deployment flow; no comparison
    # against this value is visible here -- confirm against the callers.
    MULTI = "MULTI"

ads/aqua/modeldeployment/deployment.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,13 @@
2424
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
2525
from ads.aqua.common.utils import (
2626
DEFINED_METADATA_TO_FILE_MAP,
27+
build_params_string,
2728
build_pydantic_error_message,
2829
find_restricted_params,
2930
get_combined_params,
3031
get_container_params_type,
3132
get_ocid_substring,
33+
get_params_dict,
3234
get_params_list,
3335
get_preferred_compatible_family,
3436
get_resource_name,
@@ -61,7 +63,11 @@
6163
ModelDeploymentConfigSummary,
6264
MultiModelDeploymentConfigLoader,
6365
)
64-
from ads.aqua.modeldeployment.constants import DEFAULT_POLL_INTERVAL, DEFAULT_WAIT_TIME
66+
from ads.aqua.modeldeployment.constants import (
67+
DEFAULT_POLL_INTERVAL,
68+
DEFAULT_WAIT_TIME,
69+
DeploymentType,
70+
)
6571
from ads.aqua.modeldeployment.entities import (
6672
AquaDeployment,
6773
AquaDeploymentDetail,
@@ -76,6 +82,7 @@
7682
AQUA_DEPLOYMENT_CONTAINER_CMD_VAR_METADATA_NAME,
7783
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME,
7884
AQUA_DEPLOYMENT_CONTAINER_URI_METADATA_NAME,
85+
AQUA_MODEL_DEPLOYMENT_FOLDER,
7986
AQUA_TELEMETRY_BUCKET,
8087
AQUA_TELEMETRY_BUCKET_NS,
8188
COMPARTMENT_OCID,
@@ -162,6 +169,7 @@ def create(
162169
cmd_var (Optional[List[str]]): Command variables for the container runtime.
163170
freeform_tags (Optional[Dict]): Freeform tags for model deployment.
164171
defined_tags (Optional[Dict]): Defined tags for model deployment.
172+
deployment_type (Optional[str]): The type of model deployment, e.g. `STACKED` or `MULTI` (see `DeploymentType`).
165173

166174
Returns
167175
-------
@@ -206,13 +214,26 @@ def create(
206214

207215
# Create an AquaModelApp instance once to perform the deployment creation.
208216
model_app = AquaModelApp()
209-
if create_deployment_details.model_id:
217+
if (
218+
create_deployment_details.model_id
219+
or create_deployment_details.deployment_type == DeploymentType.STACKED
220+
):
221+
model = create_deployment_details.model_id
222+
if not model:
223+
if len(create_deployment_details.models) != 1:
224+
raise AquaValueError(
225+
"Invalid 'models' provided. Only one base model is required for model stack deployment."
226+
)
227+
model = create_deployment_details.models[0]
228+
229+
service_model_id = model if isinstance(model, str) else model.model_id
210230
logger.debug(
211-
f"Single model ({create_deployment_details.model_id}) provided. "
231+
f"Single model ({service_model_id}) provided. "
212232
"Delegating to single model creation method."
213233
)
234+
214235
aqua_model = model_app.create(
215-
model_id=create_deployment_details.model_id,
236+
model=model,
216237
compartment_id=compartment_id,
217238
project_id=project_id,
218239
freeform_tags=freeform_tags,
@@ -231,6 +252,7 @@ def create(
231252
create_deployment_details=create_deployment_details,
232253
container_config=container_config,
233254
)
255+
# TODO: add multi model validation from deployment_type
234256
else:
235257
# Collect all unique model IDs (including fine-tuned models)
236258
source_model_ids = list(
@@ -685,7 +707,7 @@ def _build_model_group_config(
685707

686708
def _create(
687709
self,
688-
aqua_model: DataScienceModel,
710+
aqua_model: Union[DataScienceModel, DataScienceModelGroup],
689711
create_deployment_details: CreateModelDeploymentDetails,
690712
container_config: Dict,
691713
) -> AquaDeployment:
@@ -719,7 +741,10 @@ def _create(
719741
tags.update({Tags.TASK: aqua_model.freeform_tags.get(Tags.TASK, UNKNOWN)})
720742

721743
# Set up info to get deployment config
722-
config_source_id = create_deployment_details.model_id
744+
config_source_id = (
745+
create_deployment_details.model_id
746+
or create_deployment_details.models[0].model_id
747+
)
723748
model_name = aqua_model.display_name
724749

725750
# set up env and cmd var
@@ -870,6 +895,20 @@ def _create(
870895
deployment_params = get_combined_params(config_params, user_params)
871896

872897
params = f"{params} {deployment_params}".strip()
898+
899+
if isinstance(aqua_model, DataScienceModelGroup):
900+
env_var.update({"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true"})
901+
env_var.update(
902+
{"MODEL": f"{AQUA_MODEL_DEPLOYMENT_FOLDER}{aqua_model.base_model_id}/"}
903+
)
904+
905+
params_dict = get_params_dict(params)
906+
# updates `--served-model-name` with service model id
907+
params_dict.update({"--served-model-name": aqua_model.base_model_id})
908+
# adds `--enable_lora` to parameters
909+
params_dict.update({"--enable_lora": UNKNOWN})
910+
params = build_params_string(params_dict)
911+
873912
if params:
874913
env_var.update({"PARAMS": params})
875914
env_vars = container_spec.env_vars if container_spec else []

ads/aqua/modeldeployment/entities.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,9 @@ class CreateModelDeploymentDetails(BaseModel):
325325
defined_tags: Optional[Dict] = Field(
326326
None, description="Defined tags for model deployment."
327327
)
328+
deployment_type: Optional[str] = Field(
329+
None, description="The type of model deployment."
330+
)
328331

329332
@model_validator(mode="before")
330333
@classmethod

0 commit comments

Comments
 (0)