diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 1b129f81c5c0..869aa49b84ca 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -633,7 +633,9 @@ def __init__(self): self._update_thread: Optional[threading.Thread] = None self._update_thread_lock: threading.Lock = threading.Lock() - def _get_stats_actor(self, skip_cache: bool = False) -> Optional[ActorHandle]: + def _get_or_create_stats_actor( + self, skip_cache: bool = False + ) -> Optional[ActorHandle]: if ray._private.worker._global_node is None: raise RuntimeError( "Global node is not initialized. Driver might be not connected to Ray." @@ -650,27 +652,13 @@ def _get_stats_actor(self, skip_cache: bool = False) -> Optional[ActorHandle]: self._stats_actor_handle = ray.get_actor( name=STATS_ACTOR_NAME, namespace=STATS_ACTOR_NAMESPACE ) + self._stats_actor_cluster_id = current_cluster_id except ValueError: - return None - self._stats_actor_cluster_id = current_cluster_id - - return self._stats_actor_handle - - def _get_or_create_stats_actor(self) -> Optional[ActorHandle]: - if ray._private.worker._global_node is None: - raise RuntimeError( - "Global node is not initialized. Driver might be not connected to Ray." - ) - - # NOTE: In some cases (for ex, when registering dataset) actor might be gone - # (for ex, when prior driver disconnects) and therefore to avoid using - # stale handle we force looking up the actor with Ray to determine if - # we should create a new one. 
- actor = self._get_stats_actor(skip_cache=True) - - if actor is None: - self._stats_actor_handle = _get_or_create_stats_actor() - self._stats_actor_cluster_id = ray._private.worker._global_node.cluster_id + # Create an actor if it doesn't exist + self._stats_actor_handle = _get_or_create_stats_actor() + self._stats_actor_cluster_id = ( + ray._private.worker._global_node.cluster_id + ) return self._stats_actor_handle @@ -684,11 +672,7 @@ def _run_update_loop(): while True: if self._last_iteration_stats or self._last_execution_stats: try: - # Do not create _StatsActor if it doesn't exist because - # this thread can be running even after the cluster is - # shutdown. Creating an actor will automatically start - # a new cluster. - stats_actor = self._get_stats_actor() + stats_actor = self._get_or_create_stats_actor() if stats_actor is None: continue stats_actor.update_metrics.remote( @@ -806,7 +790,14 @@ def register_dataset_to_stats_actor( topology: Optional Topology representing the DAG structure to export data_context: The DataContext attached to the dataset """ - self._get_or_create_stats_actor().register_dataset.remote( + + # NOTE: In some cases (for ex, when registering dataset) actor might be gone + # (for ex, when prior driver disconnects) and therefore to avoid using + # stale handle we force looking up the actor with Ray to determine if + # we should create a new one. 
+ stats_actor = self._get_or_create_stats_actor(skip_cache=True) + + stats_actor.register_dataset.remote( ray.get_runtime_context().get_job_id(), dataset_tag, operator_tags, @@ -816,7 +807,13 @@ def register_dataset_to_stats_actor( def get_dataset_id_from_stats_actor(self) -> str: try: - return ray.get(self._get_or_create_stats_actor().get_dataset_id.remote()) + # NOTE: The actor might be gone by the time the dataset id is requested + # (for ex, when prior driver disconnects) and therefore to avoid using + # stale handle we force looking up the actor with Ray to determine if + # we should create a new one. + stats_actor = self._get_or_create_stats_actor(skip_cache=True) + + return ray.get(stats_actor.get_dataset_id.remote()) except Exception: # Getting dataset id from _StatsActor may fail, in this case # fall back to uuid4