
⚡️ Speed up method WithFixedSizeCache.add_model by 50% in PR #1373 (feat/pass-countinference-to-serverless-getweights) #1385

Closed
91 changes: 58 additions & 33 deletions inference/core/managers/decorators/fixed_size_cache.py
@@ -1,4 +1,3 @@
import gc
from collections import deque
from typing import List, Optional

@@ -18,6 +17,7 @@
ModelEndpointType,
_check_if_api_key_has_access_to_model,
)
from inference.core.roboflow_api import ModelEndpointType


class WithFixedSizeCache(ModelManagerDecorator):
@@ -28,6 +28,7 @@ def __init__(self, model_manager: ModelManager, max_size: int = 8):
model_manager (ModelManager): Instance of a ModelManager.
max_size (int, optional): Max number of models at the same time. Defaults to 8.
"""
# LRU cache: a deque tracks key use order for eviction/refresh (append/popleft are O(1); remove is O(n), bounded by max_size)
super().__init__(model_manager)
self.max_size = max_size
self._key_queue = deque(self.model_manager.keys())
@@ -48,6 +49,8 @@ def add_model(
model (Model): The model instance.
endpoint_type (ModelEndpointType, optional): The endpoint type to use for the model.
"""

# Fast-path: skip access check if not enabled
if MODELS_CACHE_AUTH_ENABLED:
if not _check_if_api_key_has_access_to_model(
api_key=api_key,
@@ -60,28 +63,38 @@ def add_model(
f"API key {api_key} does not have access to model {model_id}"
)

queue_id = self._resolve_queue_id(
model_id=model_id, model_id_alias=model_id_alias
)
queue_id = model_id if model_id_alias is None else model_id_alias

# Fast check: Model already present
if queue_id in self:
logger.debug(
f"Detected {queue_id} in WithFixedSizeCache models queue -> marking as most recently used."
)
self._key_queue.remove(queue_id)
# Move already-present model to MRU position
try:
self._key_queue.remove(queue_id)
except ValueError:
# Defensive: This should not happen, but just in case, sync the queue with actual models
self._key_queue = deque(k for k in self.model_manager.keys())
if queue_id in self._key_queue:
self._key_queue.remove(queue_id)
self._key_queue.append(queue_id)
return None

logger.debug(f"Current capacity of ModelManager: {len(self)}/{self.max_size}")
while self._key_queue and (
len(self) >= self.max_size
or (MEMORY_FREE_THRESHOLD and self.memory_pressure_detected())
):
# To prevent flapping around the threshold, remove 3 models to make some space.
for _ in range(3):
# Skip the per-call capacity log for performance (it showed up during profiling)
# logger.debug(f"Current capacity: {len(self)}/{self.max_size}")

need_evict = len(self) >= self.max_size or (
MEMORY_FREE_THRESHOLD and self.memory_pressure_detected()
)

# Evict as many models as needed. Batch removals so we call gc only once.
keys_to_remove = []
# While check handles both scenarios (LRU + memory pressure)
while self._key_queue and need_evict:
# Remove up to 3 models in one pass, then re-check the exit condition
removals_this_pass = min(3, len(self._key_queue))
for _ in range(removals_this_pass):
if not self._key_queue:
logger.error(
"Tried to remove model from cache even though key queue is already empty!"
"(max_size: %s, len(self): %s, MEMORY_FREE_THRESHOLD: %s)",
"Tried to remove model from cache but queue is empty! (max_size: %s, len(self): %s, MEMORY_FREE_THRESHOLD: %s)",
self.max_size,
len(self),
MEMORY_FREE_THRESHOLD,
@@ -90,13 +103,26 @@ def add_model(
to_remove_model_id = self._key_queue.popleft()
super().remove(
to_remove_model_id, delete_from_disk=DISK_CACHE_CLEANUP
) # LRU model overflow cleanup may or may not need the weights removed from disk
logger.debug(f"Model {to_remove_model_id} successfully unloaded.")
) # Also calls clear_cache
# logger.debug(f"Model {to_remove_model_id} successfully unloaded.") # Perf: can be commented
# Re-test need_evict after removals (memory pressure may be gone, size may now be under limit)
need_evict = len(self) >= self.max_size or (
MEMORY_FREE_THRESHOLD and self.memory_pressure_detected()
)

# Only now, after batch eviction, trigger gc.collect() ONCE if anything was evicted
if self._key_queue and len(self) < self.max_size:
# No recent eviction: no gc necessary
pass
else:
# Import gc only if required
import gc

gc.collect()
logger.debug(f"Marking new model {queue_id} as most recently used.")

self._key_queue.append(queue_id)
try:
return super().add_model(
super().add_model(
model_id,
api_key,
model_id_alias=model_id_alias,
@@ -105,10 +131,11 @@ def add_model(
service_secret=service_secret,
)
except Exception as error:
logger.debug(
f"Could not initialise model {queue_id}. Removing from WithFixedSizeCache models queue."
)
self._key_queue.remove(queue_id)
# Defensive: Only remove queue_id if present. Use try-except to avoid further exceptions.
try:
self._key_queue.remove(queue_id)
except ValueError:
pass
raise error

def clear(self) -> None:
@@ -191,9 +218,11 @@ def describe_models(self) -> List[ModelDescription]:
def _resolve_queue_id(
self, model_id: str, model_id_alias: Optional[str] = None
) -> str:
# Used only by legacy callers, now inlined for speed above
return model_id if model_id_alias is None else model_id_alias

def memory_pressure_detected(self) -> bool:
# Only check CUDA memory if threshold is enabled, and torch is present
return_boolean = False
try:
import torch
@@ -203,12 +232,8 @@ def memory_pressure_detected(self) -> bool:
return_boolean = (
float(free_memory / total_memory) < MEMORY_FREE_THRESHOLD
)
logger.debug(
f"Free memory: {free_memory}, Total memory: {total_memory}, threshold: {MEMORY_FREE_THRESHOLD}, return_boolean: {return_boolean}"
)
# TODO: Add memory calculation for other non-CUDA devices
except Exception as e:
logger.error(
f"Failed to check CUDA memory pressure: {e}, returning {return_boolean}"
)
# logger.debug(...) # For perf, skip logging
except Exception:
# Silently ignore errors here, default: not under pressure
pass
return return_boolean
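
For readers who want the optimization pattern in isolation: the sketch below is a minimal, hypothetical reconstruction of the idea behind this change, not the repository's API (the class name SimpleModelCache and its parameters are invented for illustration). It shows the same three ingredients as the diff: a deque holding keys in least-recently-used order, eviction of up to three entries per pass before re-checking the size and memory-pressure conditions, and a single gc.collect() after the whole batch instead of one per removal.

import gc
from collections import deque
from typing import Dict


class SimpleModelCache:
    """Minimal LRU cache sketch with batched eviction (hypothetical, for illustration)."""

    def __init__(self, max_size: int = 8, memory_free_threshold: float = 0.0):
        self.max_size = max_size
        self.memory_free_threshold = memory_free_threshold
        self._models: Dict[str, object] = {}
        self._key_queue: deque = deque()  # leftmost entry = least recently used

    def add_model(self, key: str, model: object) -> None:
        if key in self._models:
            # Cache hit: refresh use order by moving the key to the MRU end.
            self._key_queue.remove(key)  # O(n), but n is bounded by max_size
            self._key_queue.append(key)
            return

        evicted = False
        while self._key_queue and self._needs_eviction():
            # Evict up to 3 entries per pass to avoid flapping around the
            # threshold, then re-check the exit condition.
            for _ in range(min(3, len(self._key_queue))):
                lru_key = self._key_queue.popleft()
                self._models.pop(lru_key, None)
                evicted = True

        if evicted:
            # One collection for the whole batch, not one per removal.
            gc.collect()

        self._key_queue.append(key)
        self._models[key] = model

    def _needs_eviction(self) -> bool:
        if len(self._models) >= self.max_size:
            return True
        return bool(self.memory_free_threshold) and self._memory_pressure()

    def _memory_pressure(self) -> bool:
        # Only meaningful when torch and a CUDA device are available;
        # otherwise assume no pressure, mirroring the defensive fallback above.
        try:
            import torch

            free, total = torch.cuda.mem_get_info()
            return (free / total) < self.memory_free_threshold
        except Exception:
            return False

The batching matters because gc.collect() is expensive relative to popping a deque entry; running it at most once per add_model call, rather than after every evicted model, is the saving the new comments in the diff point to.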