Skip to content

Commit c8f5867

Browse files
committed
fix: serve not using defined model
1 parent 4ac360c commit c8f5867

File tree

6 files changed

+281
-53
lines changed

6 files changed

+281
-53
lines changed

pdm.lock

Lines changed: 262 additions & 47 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ dependencies = [
3838
"pydantic",
3939
"Pillow",
4040
"open-clip-torch",
41-
"torch",
41+
"torch==2.0.0",
4242
"typer",
4343
]
4444

src/clip_api_service/service.py renamed to src/clip_api_service/_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class RankOutput(BaseItem):
4343
probabilities: List[List[float]]
4444
cosine_similarities: List[List[float]]
4545

46-
bento_model = init_model()
46+
bento_model = init_model("__model_name__")
4747
logit_scale = np.exp(bento_model.info.metadata.get("logit_scale", 4.60517))
4848

4949
clip_runner = get_clip_runner(bento_model)

src/clip_api_service/build.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,19 @@ def build_bento(model_name: str | None = None, use_gpu: bool = False):
2020
if model_name:
2121
os.environ[MODEL_ENV_VAR_KEY] = model_name
2222

23+
# Generate a service file
24+
# TODO: Refactor the code gen to a mode generalize function
25+
# TODO: Use temporary file instead of writing to the source file
26+
src_service_file_path = os.path.join(os.path.dirname(__file__), "_service.py")
27+
target_service_file_path = os.path.join(os.path.dirname(__file__), "service.py")
28+
29+
with open(src_service_file_path, 'r') as src_file, open(target_service_file_path, 'w') as target_file:
30+
data = src_file.read().replace('__model_name__', model_name)
31+
target_file.write(data)
32+
2333
# Get parent directory path
2434
build_ctx = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
2535
bentoml.bentos.build_bentofile(bento_file, build_ctx=build_ctx)
36+
37+
# Remove the generated service file
38+
os.remove(target_service_file_path)

src/clip_api_service/models/openclip.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,14 +165,14 @@ def __init__(self, bento_model: bentoml.Model):
165165
def encode_text(self, texts: list[str]) -> npt.NDArray:
166166
texts_encodings = self.tokenizer(texts)
167167
with torch.inference_mode():
168-
text_embeddings = self.model.encode_text(texts_encodings)
168+
text_embeddings = self.model.encode_text(texts_encodings.to(self.device))
169169
return text_embeddings.cpu().detach().numpy()
170170

171171
@bentoml.Runnable.method(batchable=True)
172172
def encode_image(self, images: list[Image.Image]) -> npt.NDArray:
173173
image_encodings = torch.stack([self.processor(image) for image in images])
174174
with torch.inference_mode():
175-
image_embeddings = self.model.encode_image(image_encodings)
175+
image_embeddings = self.model.encode_image(image_encodings.to(self.device))
176176
return image_embeddings.cpu().detach().numpy()
177177

178178

src/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pprint
22

3-
from clip_api_service.service import svc
3+
from clip_api_service._service import svc
44

55
input_sample = {
66
"queries": [
@@ -28,7 +28,7 @@
2828
def test():
2929
import asyncio
3030

31-
from clip_api_service.service import RankInput
31+
from clip_api_service._service import RankInput
3232

3333
rank_input = RankInput.parse_obj(input_sample)
3434

0 commit comments

Comments
 (0)