Skip to content

Commit 26e08a2

Browse files
committed
fixed docstrings and unused imports
1 parent 4461af7 commit 26e08a2

File tree

3 files changed

+9
-17
lines changed

3 files changed

+9
-17
lines changed
Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
21
from tornado.web import HTTPError
32

43
from ads.aqua.common.decorator import handle_exceptions
54
from ads.aqua.extension.base_handler import AquaAPIhandler
65
from ads.aqua.extension.errors import Errors
76
from ads.aqua.shaperecommend.recommend import AquaRecommendApp
8-
from ads.config import COMPARTMENT_OCID
97

108

119
class AquaRecommendHandler(AquaAPIhandler):
@@ -14,8 +12,6 @@ class AquaRecommendHandler(AquaAPIhandler):
1412
1513
Methods
1614
-------
17-
get(self, id: Union[str, List[str]])
18-
Retrieves a list of AQUA deployments or model info or logs by ID.
1915
post(self, *args, **kwargs)
2016
Obtains the eligible compute shapes that would fit the specifed model, context length, model weights, and quantization level.
2117
@@ -27,16 +23,15 @@ class AquaRecommendHandler(AquaAPIhandler):
2723
@handle_exceptions
2824
def post(self, *args, **kwargs): # noqa: ARG002
2925
"""
30-
Lists the eligible GPU compute shapes for the specifed model.
26+
Obtains the eligible compute shapes that would fit the specifed model, context length, model weights, and quantization level.
3127
3228
Returns
3329
-------
34-
List[ComputeShapeSummary]:
35-
The list of the model deployment shapes.
30+
ShapeRecommendationReport
31+
Report containing shape recommendations and troubleshooting advice, if any.
3632
"""
3733
try:
3834
input_data = self.get_json_body()
39-
# input_data["compartment_id"] = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
4035
except Exception as ex:
4136
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
4237

@@ -45,6 +40,7 @@ def post(self, *args, **kwargs): # noqa: ARG002
4540

4641
self.finish(AquaRecommendApp().which_gpu(**input_data))
4742

43+
4844
__handlers__ = [
4945
("recommendation/?([^/]*)", AquaRecommendHandler),
5046
]

ads/aqua/shaperecommend/constants.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,8 @@
3737
TEXT_GENERATION = "text_generation"
3838
SAFETENSORS = "safetensors"
3939

40-
#TODO:
41-
SHAPES_METADATA = "/Users/elizjo/tmp/accelerated-data-science/ads/aqua/resources/gpu_shapes_index.json"
42-
4340
TROUBLESHOOT_MSG = "The selected model is too large to fit on standard GPU shapes with the current configuration.\nAs troubleshooting, we have suggested the two largest available GPU shapes using the smallest quantization level ('4bit') to maximize chances of fitting the model. "
4441

45-
TEXT_MODEL = "text-generation"
4642

4743
QUANT_MAPPING = {
4844
"float32": 4,

ads/aqua/shaperecommend/recommend.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
)
1919
from ads.aqua.shaperecommend.constants import (
2020
SAFETENSORS,
21-
SHAPES_METADATA,
2221
TEXT_GENERATION,
2322
TROUBLESHOOT_MSG,
2423
)
@@ -126,6 +125,9 @@ def get_model_config(ocid: str):
126125
AquaValueError
127126
If the OCID is not for a Data Science model, or if the model type is not supported,
128127
or if required files/tags are not present.
128+
129+
AquaRecommendationError
130+
If the model OCID provided is not supported (only text-generation decoder models in safetensor format supported).
129131
"""
130132
resource_type = get_resource_type(ocid)
131133

@@ -176,9 +178,7 @@ def get_model_config(ocid: str):
176178
return data
177179

178180
@staticmethod
179-
def valid_compute_shapes(
180-
file: str = SHAPES_METADATA,
181-
) -> List["ComputeShapeSummary"]:
181+
def valid_compute_shapes() -> List["ComputeShapeSummary"]:
182182
"""
183183
Returns a filtered list of GPU-only ComputeShapeSummary objects by reading and parsing a JSON file.
184184
@@ -306,7 +306,7 @@ def summarize_shapes_for_seq_lens(
306306

307307
troubleshoot_msg = ""
308308

309-
if len(recommendations) > 5:
309+
if len(recommendations) > 2:
310310
recommendations = ShapeReport.pareto_front(recommendations)
311311

312312
if not recommendations:

0 commit comments

Comments
 (0)