2
2
# Copyright (c) 2025 Oracle and/or its affiliates.
3
3
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
5
+ import json
5
6
from typing import List , Optional
6
7
7
8
from pydantic import BaseModel , Field
8
9
9
10
from ads .aqua .common .entities import ComputeShapeSummary
10
- from ads .aqua .shaperecommend .constants import QUANT_MAPPING
11
+ from ads .aqua .modeldeployment .config_loader import AquaDeploymentConfig
12
+ from ads .aqua .shaperecommend .constants import (
13
+ DEFAULT_WEIGHT_SIZE ,
14
+ MAX_MODEL_LEN_FLAG ,
15
+ QUANT_FLAG ,
16
+ QUANT_MAPPING ,
17
+ VLLM_ENV_KEY ,
18
+ VLLM_PARAMS_KEY ,
19
+ )
11
20
from ads .aqua .shaperecommend .estimator import MemoryEstimator
12
21
from ads .config import COMPARTMENT_OCID
13
22
@@ -30,6 +39,10 @@ class RequestRecommend(BaseModel):
30
39
COMPARTMENT_OCID , description = "The OCID of user's compartment"
31
40
)
32
41
42
+ deployment_config : Optional [AquaDeploymentConfig ] = Field (
43
+ {}, description = "The deployment configuration for model (only available for service models)."
44
+ )
45
+
33
46
class Config :
34
47
protected_namespaces = ()
35
48
@@ -42,7 +55,7 @@ class DeploymentParams(BaseModel): # noqa: N801
42
55
quantization : Optional [str ] = Field (
43
56
None , description = "Type of quantization (e.g. 4bit)."
44
57
)
45
- max_model_len : int = Field (... , description = "Maximum length of input sequence." )
58
+ max_model_len : Optional [ int ] = Field (None , description = "Maximum length of input sequence." )
46
59
params : str = Field (
47
60
..., description = "Runtime parameters for deployment with vLLM, etc."
48
61
)
@@ -68,11 +81,12 @@ class ModelConfig(BaseModel):
68
81
The configuration for a model based on specific set of deployment parameters and memory capacity of shape.
69
82
"""
70
83
71
- model_details : ModelDetail = Field (..., description = "Details about the model." )
72
84
deployment_params : DeploymentParams = Field (
73
85
..., description = "Parameters for deployment."
74
86
)
75
- recommendation : str = Field (..., description = "GPU recommendation for the model." )
87
+ model_details : Optional [ModelDetail ] = Field (None , description = "Details about the model." )
88
+
89
+ recommendation : Optional [str ] = Field ("" , description = "GPU recommendation for the model." )
76
90
77
91
class Config :
78
92
protected_namespaces = ()
@@ -231,3 +245,62 @@ class ShapeRecommendationReport(BaseModel):
231
245
None ,
232
246
description = "Details for troubleshooting if no shapes fit the current model." ,
233
247
)
248
+
249
+
250
+ @classmethod
251
+ def from_deployment_config (cls , deployment_config : AquaDeploymentConfig , model_name : str , valid_shapes : List [ComputeShapeSummary ]) -> "ShapeRecommendationReport" :
252
+ """
253
+ For service models, pre-set deployment configurations (AquaDeploymentConfig) are available.
254
+ Derives ShapeRecommendationReport from AquaDeploymentConfig (if service model & available)
255
+ """
256
+
257
+ recs = []
258
+ # may need to sort?
259
+ for shape in valid_shapes :
260
+ current_config = deployment_config .configuration .get (shape .name )
261
+ if current_config :
262
+ quantization = None
263
+ max_model_len = None
264
+ recommendation = ""
265
+ current_params = current_config .parameters .get (VLLM_PARAMS_KEY )
266
+ current_env = current_config .env .get (VLLM_ENV_KEY )
267
+
268
+ if current_params :
269
+ param_list = current_params .split ()
270
+
271
+ if QUANT_FLAG in param_list and (idx := param_list .index (QUANT_FLAG )) + 1 < len (param_list ):
272
+ quantization = param_list [idx + 1 ]
273
+
274
+ if MAX_MODEL_LEN_FLAG in param_list and (idx := param_list .index (MAX_MODEL_LEN_FLAG )) + 1 < len (param_list ):
275
+ max_model_len = param_list [idx + 1 ]
276
+ max_model_len = int (max_model_len )
277
+
278
+ if current_env :
279
+ recommendation += f"ENV: { json .dumps (current_env )} \n \n "
280
+
281
+ recommendation += "Model fits well within the allowed compute shape."
282
+
283
+ deployment_params = DeploymentParams (
284
+ quantization = quantization if quantization else DEFAULT_WEIGHT_SIZE ,
285
+ max_model_len = max_model_len ,
286
+ params = current_params if current_params else "" ,
287
+ )
288
+
289
+ # TODO: calculate memory footprint based on params??
290
+ # TODO: add --env vars not just params, current_config.env
291
+ # are there multiple configurations in the SMM configs per shape??
292
+ configuration = [ModelConfig (
293
+ deployment_params = deployment_params ,
294
+ recommendation = recommendation ,
295
+ )]
296
+
297
+ recs .append (ShapeReport (
298
+ shape_details = shape ,
299
+ configurations = configuration
300
+ )
301
+ )
302
+
303
+ return ShapeRecommendationReport (
304
+ display_name = model_name ,
305
+ recommendations = recs
306
+ )
0 commit comments