@@ -141,19 +141,19 @@ class AquaModelApp(AquaApp):
141
141
@telemetry(entry_point="plugin=model&action=create", name="aqua")
142
142
def create(
143
143
self,
144
- model_id : Union[str, AquaMultiModelRef],
144
+ model : Union[str, AquaMultiModelRef],
145
145
project_id: Optional[str] = None,
146
146
compartment_id: Optional[str] = None,
147
147
freeform_tags: Optional[Dict] = None,
148
148
defined_tags: Optional[Dict] = None,
149
149
**kwargs,
150
- ) -> DataScienceModel:
150
+ ) -> Union[ DataScienceModel, DataScienceModelGroup] :
151
151
"""
152
- Creates a custom Aqua model from a service model.
152
+ Creates a custom Aqua model or model group from a service model.
153
153
154
154
Parameters
155
155
----------
156
- model_id : Union[str, AquaMultiModelRef]
156
+ model : Union[str, AquaMultiModelRef]
157
157
The model ID as a string or a AquaMultiModelRef instance to be deployed.
158
158
project_id : Optional[str]
159
159
The project ID for the custom model.
@@ -167,28 +167,18 @@ def create(
167
167
168
168
Returns
169
169
-------
170
- DataScienceModel
171
- The instance of DataScienceModel.
170
+ Union[ DataScienceModel, DataScienceModelGroup]
171
+ The instance of DataScienceModel or DataScienceModelGroup .
172
172
"""
173
- model_id = (
174
- model_id.model_id if isinstance(model_id, AquaMultiModelRef) else model_id
175
- )
176
- service_model = DataScienceModel.from_id(model_id)
173
+ fine_tune_weights = []
174
+ if isinstance(model, AquaMultiModelRef):
175
+ fine_tune_weights = model.fine_tune_weights
176
+ model = model.model_id
177
+
178
+ service_model = DataScienceModel.from_id(model)
177
179
target_project = project_id or PROJECT_OCID
178
180
target_compartment = compartment_id or COMPARTMENT_OCID
179
181
180
- # Skip model copying if it is registered model or fine-tuned model
181
- if (
182
- service_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None) is not None
183
- or service_model.freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG)
184
- is not None
185
- ):
186
- logger.info(
187
- f"Aqua Model {model_id} already exists in the user's compartment."
188
- "Skipped copying."
189
- )
190
- return service_model
191
-
192
182
# combine tags
193
183
combined_freeform_tags = {
194
184
**(service_model.freeform_tags or {}),
@@ -199,29 +189,112 @@ def create(
199
189
**(defined_tags or {}),
200
190
}
201
191
192
+ custom_model = None
193
+ if fine_tune_weights:
194
+ custom_model = self._create_model_group(
195
+ model_id=model,
196
+ compartment_id=target_compartment,
197
+ project_id=target_project,
198
+ freeform_tags=combined_freeform_tags,
199
+ defined_tags=combined_defined_tags,
200
+ fine_tune_weights=fine_tune_weights,
201
+ service_model=service_model,
202
+ )
203
+
204
+ logger.info(
205
+ f"Aqua Model Group {custom_model.id} created with the service model {model}."
206
+ )
207
+ else:
208
+ # Skip model copying if it is registered model or fine-tuned model
209
+ if (
210
+ Tags.BASE_MODEL_CUSTOM in service_model.freeform_tags
211
+ or Tags.AQUA_FINE_TUNED_MODEL_TAG in service_model.freeform_tags
212
+ ):
213
+ logger.info(
214
+ f"Aqua Model {model} already exists in the user's compartment."
215
+ "Skipped copying."
216
+ )
217
+ return service_model
218
+
219
+ custom_model = self._create_model(
220
+ compartment_id=target_compartment,
221
+ project_id=target_project,
222
+ freeform_tags=combined_freeform_tags,
223
+ defined_tags=combined_defined_tags,
224
+ service_model=service_model,
225
+ **kwargs,
226
+ )
227
+ logger.info(
228
+ f"Aqua Model {custom_model.id} created with the service model {model}."
229
+ )
230
+
231
+ # Track unique models that were created in the user's compartment
232
+ self.telemetry.record_event_async(
233
+ category="aqua/service/model",
234
+ action="create",
235
+ detail=service_model.display_name,
236
+ )
237
+
238
+ return custom_model
239
+
240
+ def _create_model(
241
+ self,
242
+ compartment_id: str,
243
+ project_id: str,
244
+ freeform_tags: Dict,
245
+ defined_tags: Dict,
246
+ service_model: DataScienceModel,
247
+ **kwargs,
248
+ ):
249
+ """Creates a data science model by reference."""
202
250
custom_model = (
203
251
DataScienceModel()
204
- .with_compartment_id(target_compartment )
205
- .with_project_id(target_project )
252
+ .with_compartment_id(compartment_id )
253
+ .with_project_id(project_id )
206
254
.with_model_file_description(json_dict=service_model.model_file_description)
207
255
.with_display_name(service_model.display_name)
208
256
.with_description(service_model.description)
209
- .with_freeform_tags(**combined_freeform_tags )
210
- .with_defined_tags(**combined_defined_tags )
257
+ .with_freeform_tags(**freeform_tags )
258
+ .with_defined_tags(**defined_tags )
211
259
.with_custom_metadata_list(service_model.custom_metadata_list)
212
260
.with_defined_metadata_list(service_model.defined_metadata_list)
213
261
.with_provenance_metadata(service_model.provenance_metadata)
214
262
.create(model_by_reference=True, **kwargs)
215
263
)
216
- logger.info(
217
- f"Aqua Model {custom_model.id} created with the service model {model_id}."
218
- )
219
264
220
- # Track unique models that were created in the user's compartment
221
- self.telemetry.record_event_async(
222
- category="aqua/service/model",
223
- action="create",
224
- detail=service_model.display_name,
265
+ return custom_model
266
+
267
+ def _create_model_group(
268
+ self,
269
+ model_id: str,
270
+ compartment_id: str,
271
+ project_id: str,
272
+ freeform_tags: Dict,
273
+ defined_tags: Dict,
274
+ fine_tune_weights: List,
275
+ service_model: DataScienceModel,
276
+ ):
277
+ """Creates a data science model group."""
278
+ custom_model = (
279
+ DataScienceModelGroup()
280
+ .with_compartment_id(compartment_id)
281
+ .with_project_id(project_id)
282
+ .with_display_name(service_model.display_name)
283
+ .with_description(service_model.description)
284
+ .with_freeform_tags(**freeform_tags)
285
+ .with_defined_tags(**defined_tags)
286
+ .with_custom_metadata_list(service_model.custom_metadata_list)
287
+ .with_base_model_id(model_id)
288
+ .with_member_models(
289
+ [
290
+ {
291
+ "inference_key": fine_tune_weight.model_name,
292
+ "model_id": fine_tune_weight.model_id,
293
+ }
294
+ for fine_tune_weight in fine_tune_weights
295
+ ]
296
+ )
297
+ .create()
225
298
)
226
299
227
300
return custom_model
@@ -271,6 +344,16 @@ def create_multi(
271
344
DataScienceModelGroup
272
345
Instance of DataScienceModelGroup object.
273
346
"""
347
+ member_model_ids = [{"model_id": model.model_id} for model in models]
348
+ for model in models:
349
+ if model.fine_tune_weights:
350
+ member_model_ids.extend(
351
+ [
352
+ {"model_id": fine_tune_model.model_id}
353
+ for fine_tune_model in model.fine_tune_weights
354
+ ]
355
+ )
356
+
274
357
custom_model_group = (
275
358
DataScienceModelGroup()
276
359
.with_compartment_id(compartment_id)
@@ -281,7 +364,7 @@ def create_multi(
281
364
.with_defined_tags(**(defined_tags or {}))
282
365
.with_custom_metadata_list(model_custom_metadata)
283
366
# TODO: add member model inference key
284
- .with_member_models([{"model_id": model.model_id for model in models}] )
367
+ .with_member_models(member_model_ids )
285
368
)
286
369
custom_model_group.create()
287
370
0 commit comments