
Commit 58bfc00

initial code for GPU Shape Recommender
1 parent dc1f21b commit 58bfc00

File tree

9 files changed (+938, -5 lines)


ads/aqua/cli.py

Lines changed: 9 additions & 5 deletions
@@ -14,6 +14,7 @@
 from ads.aqua.finetuning import AquaFineTuningApp
 from ads.aqua.model import AquaModelApp
 from ads.aqua.modeldeployment import AquaDeploymentApp
+from ads.aqua.shaperecommend.recommend import AquaRecommendApp
 from ads.common.utils import LOG_LEVELS
 
 
@@ -29,6 +30,7 @@ class AquaCommand:
     fine_tuning = AquaFineTuningApp
     deployment = AquaDeploymentApp
     evaluation = AquaEvaluationApp
+    recommend = AquaRecommendApp
 
     def __init__(
         self,
@@ -94,18 +96,20 @@ def _validate_value(flag, value):
                 "If you intend to chain a function call to the result, please separate the "
                 "flag and the subsequent function call with separator `-`."
             )
-
+
     @staticmethod
     def install():
         """Install ADS Aqua Extension from wheel file. Set enviroment variable `AQUA_EXTENSTION_PATH` to change the wheel file path.
 
-        Return
+        Return
         ------
         int:
             Installatation status.
         """
         import subprocess
 
-        wheel_file_path = os.environ.get("AQUA_EXTENSTION_PATH", "/ads/extension/adsjupyterlab_aqua_extension*.whl")
-        status = subprocess.run(f"pip install {wheel_file_path}",shell=True)
-        return status.check_returncode
+        wheel_file_path = os.environ.get(
+            "AQUA_EXTENSTION_PATH", "/ads/extension/adsjupyterlab_aqua_extension*.whl"
+        )
+        status = subprocess.run(f"pip install {wheel_file_path}", shell=True)
+        return status.check_returncode
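
With this change, AquaRecommendApp is exposed on the Aqua CLI as the `recommend` command. A minimal sketch of driving the same app from Python, assuming which_gpu() accepts its inputs as keyword arguments the way the REST handler forwards its JSON body (the model_id parameter name is an assumption, not confirmed by this commit):

# Hypothetical usage sketch; parameter names are assumptions.
from ads.aqua.shaperecommend.recommend import AquaRecommendApp

app = AquaRecommendApp()
shapes = app.which_gpu(model_id="ocid1.datasciencemodel.oc1..<unique_id>")
print(shapes)  # expected: list of eligible GPU compute shapes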

ads/aqua/extension/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
 )
 from ads.aqua.extension.evaluation_handler import __handlers__ as __eval_handlers__
 from ads.aqua.extension.finetune_handler import __handlers__ as __finetune_handlers__
+from ads.aqua.extension.gpu_recommend_handler import __handlers__ as __gpu_handlers__
 from ads.aqua.extension.model_handler import __handlers__ as __model_handlers__
 from ads.aqua.extension.ui_handler import __handlers__ as __ui_handlers__
 from ads.aqua.extension.ui_websocket_handler import __handlers__ as __ws_handlers__
@@ -24,6 +25,7 @@
     + __ui_handlers__
     + __eval_handlers__
     + __ws_handlers__
+    + __gpu_handlers__
 )

ads/aqua/extension/gpu_recommend_handler.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
from tornado.web import HTTPError

from ads.aqua.common.decorator import handle_exceptions
from ads.aqua.extension.base_handler import AquaAPIhandler
from ads.aqua.extension.errors import Errors
from ads.aqua.shaperecommend.recommend import AquaRecommendApp
from ads.config import COMPARTMENT_OCID


class AquaRecommendHandler(AquaAPIhandler):
    """
    Handler for Aqua GPU Recommendation REST APIs.

    Methods
    -------
    get(self, id: Union[str, List[str]])
        Retrieves a list of AQUA deployments or model info or logs by ID.
    post(self, *args, **kwargs)
        Obtains the eligible compute shapes that would fit the specified model, context length, model weights, and quantization level.

    Raises
    ------
    HTTPError: For various failure scenarios such as invalid input format, missing data, etc.
    """

    @handle_exceptions
    def post(self, *args, **kwargs):  # noqa: ARG002
        """
        Lists the eligible GPU compute shapes for the specified model.

        Returns
        -------
        List[ComputeShapeSummary]:
            The list of the model deployment shapes.
        """
        try:
            input_data = self.get_json_body()
            # input_data["compartment_id"] = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
        except Exception as ex:
            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex

        if not input_data:
            raise HTTPError(400, Errors.NO_INPUT_DATA)

        self.finish(AquaRecommendApp().which_gpu(**input_data))


__handlers__ = [
    ("gpu-shape-recommendation/?([^/]*)", AquaRecommendHandler),
]
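
A minimal sketch of how a client might exercise the new route once the extension is installed; the server URL, route prefix, and payload keys are assumptions, not part of this commit:

# Hypothetical request against the new handler; adjust the URL/prefix and
# payload keys to match the deployed extension and AquaRecommendApp.which_gpu().
import json
from urllib.request import Request, urlopen

payload = {"model_id": "ocid1.datasciencemodel.oc1..<unique_id>"}
req = Request(
    "http://localhost:8888/aqua/gpu-shape-recommendation",  # assumed base URL/prefix
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urlopen(req) as resp:
    print(resp.read().decode())  # expected: eligible compute shape summaries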

ads/aqua/shaperecommend/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
#!/usr/bin/env python
# Copyright (c) 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
from ads.aqua.shaperecommend.recommend import AquaGPURecommendApp

__all__ = ["AquaGPURecommendApp"]

ads/aqua/shaperecommend/constants.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
#!/usr/bin/env python
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

"""
aqua.shaperecommend.constants
~~~~~~~~~~~~~~

This module contains constants used in Aqua GPU Recommendation for Models.

LLAMA_REQUIRED_FIELDS refer to fields necessary for calculating model memory for GQA Architecture Models

MOE_REQUIRED_FIELDS refer to fields necessary for Mixture of Experts (MoE) Architecture Models

NEXT_QUANT suggests the next quantization level based on the current quantization (if applied) or the model weights (if no quantization yet)
"""
LLAMA_REQUIRED_FIELDS = [
    "num_hidden_layers", "hidden_size", "num_attention_heads",
    "num_key_value_heads", "head_dim", "intermediate_size", "vocab_size"
]

MOE_REQUIRED_FIELDS = LLAMA_REQUIRED_FIELDS + [
    "num_local_experts", "intermediate_size"
]

NEXT_QUANT = {
    "float32": ["bfloat16", "float16", "int8"],
    "bfloat16": ["float16", "int8"],
    "float16": ["int8"],
    "int8": ["8bit", "4bit (Not Recommended)"],
    "8bit": ["4bit (Not Recommended)"],
    "4bit": ["No smaller quantization available"]
}
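
An illustrative lookup against the NEXT_QUANT table above, showing how a caller could suggest smaller quantization options for a model currently served in bfloat16:

# Suggest the next quantization levels for a bfloat16 model.
from ads.aqua.shaperecommend.constants import NEXT_QUANT

current_weight_dtype = "bfloat16"
print(NEXT_QUANT.get(current_weight_dtype, []))  # ['float16', 'int8']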

ads/aqua/shaperecommend/estimator.py

Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
#!/usr/bin/env python
# Copyright (c) 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
from typing import Optional

from pydantic import BaseModel, Field

from ads.aqua.app import logger
from ads.aqua.shaperecommend.constants import LLAMA_REQUIRED_FIELDS, MOE_REQUIRED_FIELDS
from ads.aqua.shaperecommend.llm_config import LLMConfig


class MemoryEstimator(BaseModel):
    """
    The generic estimator for Transformer Architecture models (OPT/Bloom).
    Used as a fallback estimator if the identified model is not a MoE or GQA Architecture Model.
    Has properties to estimate the KV Cache size, Model size, and total footprint (KV Cache + Model size).
    """

    llm_config: LLMConfig = Field(
        ...,
        description="The model's config.json file with the necessary parameters for model size and KV cache estimation.",
    )
    batch_size: int = (
        1  # we assume that estimation for batch sizes is not supported yet
    )
    seq_len: Optional[int] = Field(
        4096, description="The max-seq-len to estimate the size of the KV cache."
    )

    @property
    def kv_cache_memory(self) -> float:
        """
        Estimates the KV cache size (in GB) using the LLM config.json parameters.

        Uses num_attention_heads (assumes no GQA; each attention head has its own query, key, value) for estimation.
        """
        seq_len = self.seq_len or self.llm_config.max_seq_len
        c = self.llm_config
        kv_cache_dtype_bytes = (
            c.bytes_per_parameter
        )  # vLLM uses model's weight/quantization applied to KV cache

        total_bytes = (
            self.batch_size
            * c.num_hidden_layers
            * 2
            * c.num_attention_heads
            * seq_len
            * c.head_dim
            * kv_cache_dtype_bytes
        )
        return total_bytes / 1e9

    @property
    def model_memory(self) -> float:
        """
        Estimates the model size (in GB) based on the estimated model parameter count and model weights.

        Model Parameter estimation: Standard decoder-only, untied/tied embeddings possible.
        """
        c = self.llm_config
        embedding_count = 1 if getattr(c, "tie_word_embeddings", True) else 2
        embedding_params = (
            embedding_count * c.vocab_size * c.hidden_size
        )  # input and output untied
        layer_params = 12 * c.num_hidden_layers * (c.hidden_size**2)  # GPT-style
        num_params = layer_params + embedding_params

        return num_params * c.bytes_per_parameter / 1e9

    # @property
    # def model_overhead(self) -> float:
    #     overhead = max(1, math.ceil(0.0 * self.model_memory))
    #     return overhead

    @property
    def total_memory(self) -> float:
        """
        Computes the total memory footprint of the model (KV cache & model size from estimated parameters).
        """
        return self.model_memory + self.kv_cache_memory


# Specialized estimators:
class LlamaMemoryEstimator(MemoryEstimator):
    """
    Estimator for GQA-type architectures. Handles tied (memory savings) and untied embeddings,
    and uses grouped attention (GQA) for more efficient KV cache memory estimation.

    KV cache: uses num_key_value_heads (assumes GQA)
    Model Parameter estimation: Standard decoder-only, untied/tied embeddings possible
    """

    @property
    def model_memory(self) -> float:
        """
        Returns estimated model parameter memory (in GB), accurately accounting
        for Llama-style attention and MLP, and tied or untied embeddings.
        """
        c = self.llm_config

        embedding_params, attn_params = self._calc_attn_embed_params()

        # MLP params
        gate_proj = c.hidden_size * c.intermediate_size
        up_proj = c.hidden_size * c.intermediate_size
        down_proj = c.intermediate_size * c.hidden_size
        mlp_params = gate_proj + up_proj + down_proj

        # Total per-layer
        layer_params = attn_params + mlp_params
        # Total params
        num_params = c.num_hidden_layers * layer_params + embedding_params
        return num_params * c.bytes_per_parameter / 1e9

    @property
    def kv_cache_memory(self) -> float:
        """
        Returns estimated KV cache memory in GB for GQA models.

        Grouped Query Attention uses num_key_value_heads, where groups of Q heads share a K and V projection.
        num_key_value_heads < num_attention_heads, which reduces the KV Cache size.
        """
        c = self.llm_config
        seq_len = self.seq_len or getattr(c, "max_seq_len", 2048)
        kv_cache_dtype_bytes = c.bytes_per_parameter
        kv_heads = c.num_key_value_heads

        total_bytes = (
            self.batch_size
            * c.num_hidden_layers
            * 2
            * kv_heads
            * seq_len
            * c.head_dim
            * kv_cache_dtype_bytes
        )
        return total_bytes / 1e9

    def _calc_attn_embed_params(self) -> tuple:
        """
        Returns the embedding parameter count and attention parameter count for Llama-family (GQA) models.
        """
        c = self.llm_config

        # Embedding parameters
        # assume tied embeddings unless tie_word_embeddings = False
        embedding_count = 1 if getattr(c, "tie_word_embeddings", True) else 2
        embedding_params = embedding_count * c.vocab_size * c.hidden_size

        q_proj = c.hidden_size * c.hidden_size
        k_proj = c.hidden_size * (c.num_key_value_heads * c.head_dim)
        v_proj = c.hidden_size * (c.num_key_value_heads * c.head_dim)
        o_proj = c.hidden_size * c.hidden_size
        attn_params = q_proj + k_proj + v_proj + o_proj

        return embedding_params, attn_params


class MixtureMemoryEstimator(LlamaMemoryEstimator):
    """
    Estimator for Mixture-of-Experts (MoE) architectures (e.g., Mixtral, MoE Llama).
    Adds the extra expert MLP block parameter count to the LlamaMemoryEstimator logic.
    """

    @property
    def model_memory(self) -> float:
        """
        Accounts for the increase in model parameters due to additional expert MLP blocks in MoE Models.

        Returns the estimated memory size of the MoE Model (in GB).
        """
        c = self.llm_config
        # Attention parameter count (Llama-style)
        embedding_params, attn_params = self._calc_attn_embed_params()

        # MoE MLP params per layer
        moe_params_per_layer = (
            c.num_local_experts * 3 * c.hidden_size * c.intermediate_size
        )
        total_params = (
            c.num_hidden_layers * (attn_params + moe_params_per_layer)
            + embedding_params
        )

        # Convert to GB
        return total_params * c.bytes_per_parameter / 1e9


def get_estimator(llm_config, **kwargs) -> MemoryEstimator:
    """
    Selects the correct estimator based on the parameters defined in the config.json.
    See constants.py for the LLMConfig parameters required by specific estimators.
    Uses MemoryEstimator as a fallback if parameters needed for GQA and MoE Architectures are missing.

    Returns the appropriate MemoryEstimator based on the fields defined by the model's config.json (as represented by LLMConfig).
    """
    if all(
        hasattr(llm_config, f) and getattr(llm_config, f) is not None
        for f in MOE_REQUIRED_FIELDS
    ):
        return MixtureMemoryEstimator(llm_config=llm_config, **kwargs)
    elif all(
        hasattr(llm_config, f) and getattr(llm_config, f) is not None
        for f in LLAMA_REQUIRED_FIELDS
    ):
        return LlamaMemoryEstimator(llm_config=llm_config, **kwargs)
    else:
        logger.warning(
            "Falling back to generic GPT estimator: required fields missing from config.json file in model."
        )
        return MemoryEstimator(llm_config=llm_config, **kwargs)
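
For intuition, a standalone sketch of the KV-cache arithmetic these estimators use, with Llama-2-7B-style values (32 layers, 32 KV heads, head_dim 128, 4k context, fp16 weights); the numbers are illustrative and not taken from this commit:

# batch * layers * 2 (K and V) * kv_heads * seq_len * head_dim * bytes_per_param
batch_size = 1
num_hidden_layers = 32
num_key_value_heads = 32      # no GQA: KV heads == attention heads
seq_len = 4096
head_dim = 128
bytes_per_parameter = 2       # fp16 / bf16

total_bytes = (
    batch_size * num_hidden_layers * 2 * num_key_value_heads
    * seq_len * head_dim * bytes_per_parameter
)
print(f"{total_bytes / 1e9:.2f} GB")  # ~2.15 GB of KV cache at 4k context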
