src/api/app.py (1 addition, 0 deletions)
@@ -42,6 +42,7 @@ async def health():
     return {"status": "OK"}
 
 
+
 @app.exception_handler(RequestValidationError)
 async def validation_exception_handler(request, exc):
     return PlainTextResponse(str(exc), status_code=400)
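Context for this hunk: FastAPI normally answers an invalid request body with a 422 JSON payload, and the handler above downgrades that to a plain-text 400. A minimal runnable sketch of the same wiring, where the `app` object and its setup are hypothetical stand-ins rather than the project's actual module:

from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
from fastapi.responses import PlainTextResponse

app = FastAPI()

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    # str(exc) lists each failing field; return it as text/plain with a 400
    # instead of FastAPI's default 422 application/json error body.
    return PlainTextResponse(str(exc), status_code=400)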
src/api/models/bedrock.py (68 additions, 62 deletions)
@@ -5,6 +5,7 @@
 import time
 from abc import ABC
 from typing import AsyncIterable, Iterable, Literal
+from api.models.model_manager import ModelManager
 
 import boto3
 import numpy as np
@@ -75,83 +76,88 @@ def get_inference_region_prefix():
 
 ENCODER = tiktoken.get_encoding("cl100k_base")
 
+# Initialize the model list.
+#bedrock_model_list = list_bedrock_models()
 
-def list_bedrock_models() -> dict:
-    """Automatically getting a list of supported models.
-
-    Returns a model list combines:
-    - ON_DEMAND models.
-    - Cross-Region Inference Profiles (if enabled via Env)
-    """
-    model_list = {}
-    try:
-        profile_list = []
-        if ENABLE_CROSS_REGION_INFERENCE:
-            # List system defined inference profile IDs
-            response = bedrock_client.list_inference_profiles(
-                maxResults=1000,
-                typeEquals='SYSTEM_DEFINED'
-            )
-            profile_list = [p['inferenceProfileId'] for p in response['inferenceProfileSummaries']]
-
-        # List foundation models, only cares about text outputs here.
-        response = bedrock_client.list_foundation_models(
-            byOutputModality='TEXT'
-        )
-
-        for model in response['modelSummaries']:
-            model_id = model.get('modelId', 'N/A')
-            stream_supported = model.get('responseStreamingSupported', True)
-            status = model['modelLifecycle'].get('status', 'ACTIVE')
-
-            # currently, use this to filter out rerank models and legacy models
-            if not stream_supported or status != "ACTIVE":
-                continue
-
-            inference_types = model.get('inferenceTypesSupported', [])
-            input_modalities = model['inputModalities']
-            # Add on-demand model list
-            if 'ON_DEMAND' in inference_types:
-                model_list[model_id] = {
-                    'modalities': input_modalities
-                }
-
-            # Add cross-region inference model list.
-            profile_id = cr_inference_prefix + '.' + model_id
-            if profile_id in profile_list:
-                model_list[profile_id] = {
-                    'modalities': input_modalities
-                }
-
-    except Exception as e:
-        logger.error(f"Unable to list models: {str(e)}")
+class BedrockModel(BaseChatModel):
 
-    if not model_list:
-        # In case stack not updated.
-        model_list[DEFAULT_MODEL] = {
-            'modalities': ["TEXT", "IMAGE"]
-        }
+    #bedrock_model_list = None
+    model_manager = None
+    def __init__(self):
+        super().__init__()
+        self.model_manager = ModelManager()
 
-    return model_list
+    def list_bedrock_models(self) -> dict:
+        """Automatically getting a list of supported models.
+
+        Returns a model list combines:
+        - ON_DEMAND models.
+        - Cross-Region Inference Profiles (if enabled via Env)
+        """
+        #model_list = {}
+        try:
+            profile_list = []
+            if ENABLE_CROSS_REGION_INFERENCE:
+                # List system defined inference profile IDs
+                response = bedrock_client.list_inference_profiles(
+                    maxResults=1000,
+                    typeEquals='SYSTEM_DEFINED'
+                )
+                profile_list = [p['inferenceProfileId'] for p in response['inferenceProfileSummaries']]
 
-# Initialize the model list.
-bedrock_model_list = list_bedrock_models()
+            # List foundation models, only cares about text outputs here.
+            response = bedrock_client.list_foundation_models(
+                byOutputModality='TEXT'
+            )
 
+            for model in response['modelSummaries']:
+                model_id = model.get('modelId', 'N/A')
+                stream_supported = model.get('responseStreamingSupported', True)
+                status = model['modelLifecycle'].get('status', 'ACTIVE')
+
+                # currently, use this to filter out rerank models and legacy models
+                if not stream_supported or status != "ACTIVE":
+                    continue
+
+                inference_types = model.get('inferenceTypesSupported', [])
+                input_modalities = model['inputModalities']
+                # Add on-demand model list
+                if 'ON_DEMAND' in inference_types:
+                    model[model_id] = {
+                        'modalities': input_modalities
+                    }
+                    self.model_manager.add_model(model)
+                    # model_list[model_id] = {
+                    #     'modalities': input_modalities
+                    # }
 
+                # Add cross-region inference model list.
+                profile_id = cr_inference_prefix + '.' + model_id
+                if profile_id in profile_list:
+                    model[profile_id] = {
+                        'modalities': input_modalities
+                    }
+                    self.model_manager.add_model(model)
 
-class BedrockModel(BaseChatModel):
+        except Exception as e:
+            logger.error(e)
+            raise HTTPException(status_code=500, detail=str(e))
 
     def list_models(self) -> list[str]:
         """Always refresh the latest model list"""
-        global bedrock_model_list
-        bedrock_model_list = list_bedrock_models()
-        return list(bedrock_model_list.keys())
+        #global bedrock_model_list
+        self.list_bedrock_models()
+        return list(self.model_manager.get_all_models().keys())
 
     def validate(self, chat_request: ChatRequest):
         """Perform basic validation on requests"""
 
         error = ""
 
+        ###### TODO - failing here as kb and agents are not in the bedrock_model_list
         # check if model is supported
-        if chat_request.model not in bedrock_model_list.keys():
+        if chat_request.model not in self.model_manager.get_all_models().keys():
             error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models"
 
         if error:
@@ -659,7 +665,7 @@ def _parse_content_parts(
 
     @staticmethod
     def is_supported_modality(model_id: str, modality: str = "IMAGE") -> bool:
-        model = bedrock_model_list.get(model_id)
+        model = ModelManager().models.get(model_id)
         modalities = model.get('modalities', [])
         if modality in modalities:
             return True
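Note for reviewers: api/models/model_manager.py is not part of this diff, so the exact ModelManager contract is an assumption. The calls in these hunks (add_model, get_all_models, and the bare ModelManager().models lookup in is_supported_modality) only work together if every instance sees one shared registry; otherwise the fresh ModelManager() above would be empty. A minimal sketch consistent with that, with all names hypothetical:

class ModelManager:
    """Hypothetical sketch of api/models/model_manager.py (not shown in this PR)."""

    _registry: dict = {}  # class-level dict, so all instances share one registry (assumption)

    def __init__(self):
        self.models = ModelManager._registry

    def add_model(self, model: dict) -> None:
        # list_bedrock_models() stuffs "<model_id>: {'modalities': [...]}" entries
        # into the boto3 modelSummary dict before calling this, so pick those out.
        for key, value in model.items():
            if isinstance(value, dict) and 'modalities' in value:
                self.models[key] = value

    def get_all_models(self) -> dict:
        return self.models

Design note: registering models by mutating the boto3 modelSummary dict (model[model_id] = {...}) and handing the whole dict to add_model makes the contract implicit; passing the id and modalities explicitly would be clearer, and the TODO in validate() suggests knowledge bases and agents are not yet registered through this path at all.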