Skip to content

Commit 310a4b2

Browse files
committed
Add Orbax checkpoint logger for AXLearn
1 parent c96e72b commit 310a4b2

8 files changed

+121
-4
lines changed

axlearn/cloud/gcp/measurement.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from typing import Optional, Sequence
2626

2727
import jax
28+
import orbax.checkpoint as ocp
2829
from absl import flags, logging
2930
from ml_goodput_measurement import goodput
3031
from ml_goodput_measurement import monitoring as goodput_monitoring
@@ -134,6 +135,19 @@ def record_event(self, event: measurement.EventType, *args, **kwargs):
134135
)
135136
# pylint: enable=try-except-raise
136137

138+
def create_checkpoint_logger(self) -> Optional[ocp.logging.CloudLogger]:
    """Builds an Orbax cloud logger for checkpoint events.

    Construction is best-effort: any exception raised while building the logger is
    reported as a warning (with traceback) instead of being propagated.

    Returns:
        A `CloudLogger` configured with this recorder's job name and logger name, or
        None if the logger could not be created.
    """
    try:
        logging.info("Creating a Goodput checkpoint logger.")
        logger_options = ocp.logging.CloudLoggerOptions(
            job_name=self._job_name,
            logger_name=self._logger_name,
        )
        checkpoint_logger = ocp.logging.CloudLogger(options=logger_options)
    except Exception as e:  # pylint: disable=broad-exception-caught
        logging.warning("Failed to create Goodput checkpoint logger: %s", e, exc_info=True)
        checkpoint_logger = None
    return checkpoint_logger
150+
137151
@contextlib.contextmanager
138152
def _maybe_monitor_goodput(self, *args, **kwargs):
139153
"""Monitor cumulative goodput if enabled.

axlearn/cloud/gcp/measurement_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,3 +373,46 @@ def test_maybe_monitor_all(
373373
else:
374374
mock_monitor_instance.start_rolling_window_goodput_uploader.assert_not_called()
375375
mock_monitor_instance.stop_rolling_window_goodput_uploader.assert_not_called()
376+
377+
@mock.patch("jax.process_index", return_value=0)
def test_create_checkpoint_logger_success(self, _):
    """Checks the success path: one CloudLogger is constructed and returned as-is."""
    recorder_cfg = GoodputRecorder.default_config().set(
        name="test-job", upload_dir="/test", upload_interval=30
    )
    recorder = GoodputRecorder(recorder_cfg)

    with mock.patch("orbax.checkpoint.logging.CloudLogger") as cloud_logger_cls:
        expected_instance = cloud_logger_cls.return_value
        created = recorder.create_checkpoint_logger()

        # Exactly one logger is constructed, and that same object is handed back.
        cloud_logger_cls.assert_called_once()
        self.assertIs(created, expected_instance)

        # The options are passed by keyword; inspect the names they carry.
        passed_options = cloud_logger_cls.call_args.kwargs["options"]
        self.assertEqual(passed_options.job_name, "test-job")
        self.assertEqual(passed_options.logger_name, "goodput_logger_test-job")
398+
399+
@mock.patch("jax.process_index", return_value=0)
def test_create_checkpoint_logger_failure(self, _):
    """Checks the failure path: a constructor error yields None and a single warning."""
    recorder = GoodputRecorder(
        GoodputRecorder.default_config().set(
            name="fail-job",
            upload_dir="/test",
            upload_interval=30,
        )
    )

    # Make the CloudLogger constructor raise, and capture the warning call.
    failing_ctor = mock.patch(
        "orbax.checkpoint.logging.CloudLogger", side_effect=RuntimeError("TestError")
    )
    warning_patch = mock.patch.object(logging, "warning")
    with failing_ctor as cloud_logger_cls, warning_patch as warning_spy:
        result = recorder.create_checkpoint_logger()
        self.assertIsNone(result)
        cloud_logger_cls.assert_called_once()
        warning_spy.assert_called_once()
        # The first positional argument of the warning is its format string.
        warning_format = warning_spy.call_args[0][0]
        self.assertIn("Failed to create Goodput checkpoint logger", warning_format)

axlearn/common/checkpointer_orbax.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import tensorflow as tf
1818
from absl import logging
1919

20-
from axlearn.common import utils
20+
from axlearn.common import measurement, utils
2121
from axlearn.common.checkpointer import (
2222
STEP_NUM_DIGITS,
2323
STEP_PREFIX,
@@ -232,6 +232,9 @@ def save_fn_with_summaries(step: int, last_saved_step: Optional[int]) -> bool:
232232
step_prefix=STEP_PREFIX,
233233
step_format_fixed_length=STEP_NUM_DIGITS,
234234
)
235+
self._checkpoint_logger = None
236+
if measurement.global_recorder:
237+
self._checkpoint_logger = measurement.global_recorder.create_checkpoint_logger()
235238
self._manager = ocp.CheckpointManager(
236239
directory=cfg.dir,
237240
options=ocp.CheckpointManagerOptions(
@@ -255,6 +258,7 @@ def save_fn_with_summaries(step: int, last_saved_step: Optional[int]) -> bool:
255258
restore_concurrent_gb=cfg.max_concurrent_restore_gb,
256259
),
257260
},
261+
logger=self._checkpoint_logger,
258262
)
259263

260264
def _get_spec(self, *, step: int, state: Nested[Any]) -> Nested[Any]:

axlearn/common/checkpointer_orbax_emergency.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from jax.experimental.array_serialization import serialization
2828

2929
from axlearn.common import file_system as fs
30-
from axlearn.common import utils, utils_spmd
30+
from axlearn.common import measurement, utils, utils_spmd
3131
from axlearn.common.checkpointer import (
3232
STEP_NUM_DIGITS,
3333
STEP_PREFIX,
@@ -667,6 +667,9 @@ def _composite_save_policy(*, step: int, evaler_summaries: dict[str, Any]):
667667
# See comments of _eval_summaries in `OrbaxCheckpointer`.
668668
self._eval_summaries = None
669669
self._reached_preemption = False
670+
self._checkpoint_logger = None
671+
if measurement.global_recorder:
672+
self._checkpoint_logger = measurement.global_recorder.create_checkpoint_logger()
670673

671674
# pylint: disable-next=redefined-builtin
672675
def ckpt_dir(self, step: int, dir: Optional[str] = None) -> str:
@@ -731,6 +734,7 @@ def _orbax_save_fn(
731734
cleanup_tmp_directories=True,
732735
enable_async_checkpointing=True,
733736
),
737+
logger=self._checkpoint_logger,
734738
)
735739
return self._tensor_manager
736740

axlearn/common/checkpointer_orbax_emergency_test.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import tempfile
1111
from contextlib import ExitStack, closing
1212
from typing import Optional
13+
from unittest import mock
1314

1415
import jax
1516
import numpy as np
@@ -18,7 +19,7 @@
1819
from absl.testing import parameterized
1920
from jax import numpy as jnp
2021

21-
from axlearn.common import utils_spmd
22+
from axlearn.common import measurement, utils_spmd
2223
from axlearn.common.checkpointer_orbax_emergency import (
2324
OrbaxEmergencyCheckpointer,
2425
_dump_process_info,
@@ -299,3 +300,28 @@ def start_processes(reverse_process_id: bool = False):
299300
finally:
300301
for p in processes:
301302
p.kill()
303+
304+
# Stacked `mock.patch` decorators inject their mocks bottom-up: the decorator closest to
# the function supplies the first mock argument. The order below makes each parameter
# name match the function it actually replaces.
@mock.patch("orbax.checkpoint._src.multihost.multihost.initialize_distributed_to_device_ids")
@mock.patch("orbax.checkpoint._src.multihost.multihost.initialize_runtime_to_distributed_ids")
def test_emergency_checkpointer_initializes_logger_from_global_recorder(
    self, mock_init_runtime, mock_init_device_ids
):  # pylint: disable=unused-argument
    """Tests OrbaxEmergencyCheckpointer initializes _checkpoint_logger.

    With `measurement.global_recorder` replaced by a mock, instantiating the checkpointer
    must call `create_checkpoint_logger()` exactly once and store the returned logger on
    `_checkpoint_logger`.

    Args:
        mock_init_runtime: Mock replacing `initialize_runtime_to_distributed_ids`; unused.
        mock_init_device_ids: Mock replacing `initialize_distributed_to_device_ids`; unused.
    """
    # NOTE(review): the two multihost initializers are presumably patched so the
    # checkpointer can be constructed without a distributed runtime; confirm.
    with tempfile.TemporaryDirectory() as temp_dir, mock.patch.object(
        measurement, "global_recorder", mock.MagicMock()
    ) as mock_recorder:
        mock_logger = mock.MagicMock()
        mock_recorder.create_checkpoint_logger.return_value = mock_logger

        cfg = OrbaxEmergencyCheckpointer.default_config().set(
            name="test_logger",
            trainer_dir=temp_dir,
            dir=temp_dir,
            local_dir=temp_dir,
            replica_axis_index=0,
        )

        ckpt: OrbaxEmergencyCheckpointer = cfg.instantiate(parent=None)

        mock_recorder.create_checkpoint_logger.assert_called_once()
        # Identity check: the checkpointer must hold the very object the recorder returned.
        self.assertIs(ckpt._checkpoint_logger, mock_logger)

axlearn/common/checkpointer_orbax_test.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@
1010
import os
1111
import tempfile
1212
from typing import Sequence
13+
from unittest import mock
1314

1415
import jax
1516
import orbax.checkpoint as ocp
1617
from jax import numpy as jnp
1718
from jax.experimental import mesh_utils
1819

19-
from axlearn.common import test_utils
20+
from axlearn.common import measurement, test_utils
2021
from axlearn.common.checkpointer import read_index_file
2122
from axlearn.common.checkpointer_orbax import OrbaxCheckpointer
2223

@@ -52,3 +53,21 @@ def test_index(self):
5253
),
5354
)
5455
self.assertEqual(ref_index, test_index["index"])
56+
57+
def test_initializes_checkpoint_logger_from_global_recorder(self):
    """Checks that OrbaxCheckpointer takes its _checkpoint_logger from global_recorder."""
    # Stand-in recorder whose factory method returns a CloudLogger-shaped mock.
    fake_recorder = mock.MagicMock()
    fake_logger = mock.MagicMock(spec=ocp.logging.CloudLogger)
    fake_recorder.create_checkpoint_logger.return_value = fake_logger

    with tempfile.TemporaryDirectory() as temp_dir, mock.patch.object(
        measurement, "global_recorder", fake_recorder
    ):
        ckpt_cfg = OrbaxCheckpointer.default_config().set(name="test", dir=temp_dir)
        ckpt = ckpt_cfg.instantiate(parent=None)

        # The recorder is asked for a logger exactly once, and the result is stored.
        fake_recorder.create_checkpoint_logger.assert_called_once()
        self.assertEqual(ckpt._checkpoint_logger, fake_logger)

axlearn/common/measurement.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ def record_event(self, event: Event, *args, **kwargs):
102102
def maybe_monitor_all(self):
103103
yield
104104

105+
def create_checkpoint_logger(self) -> Optional[object]:
    """Returns a checkpoint logger for this recorder, if one is available.

    Any logger returned is expected to be fully functional and independent. This
    implementation provides none.

    Returns:
        None.
    """
    return None
108+
105109

106110
_recorders: dict[str, type] = {}
107111
_T = TypeVar("_T")

axlearn/common/measurement_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,6 @@ def test_initialize(self, recorder_type, expected):
9696
# Ensure that maybe_monitor_all does not fail (just enter and exit context).
9797
with measurement.global_recorder.maybe_monitor_all():
9898
pass
99+
100+
# Ensure that create_checkpoint_logger does not crash.
101+
measurement.global_recorder.create_checkpoint_logger()

0 commit comments

Comments
 (0)