diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index dfdc10ac6..d45a35520 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -252,7 +252,7 @@ async def _async_serialize( and arr_inp.is_fully_addressable ) # pylint: disable-next=protected-access - if not serialization._spec_has_metadata(tensorstore_spec): + if not serialization.ts_impl._spec_has_metadata(tensorstore_spec): # pylint: disable-next=protected-access tensorstore_spec["metadata"] = serialization._get_metadata(arr_inp) if "dtype" not in tensorstore_spec: @@ -274,14 +274,14 @@ async def _async_serialize( # does no I/O operation and returns the tensorstore object. For every process other than `0`, # we open with `assume_metadata=True`. if jax.process_index() == 0: - await serialization.ts.open( - serialization.ts.Spec(tensorstore_spec), + await serialization.ts_impl.ts.open( + serialization.ts_impl.ts.Spec(tensorstore_spec), create=True, open=True, context=serialization.TS_CONTEXT, ) - t = await serialization.ts.open( - serialization.ts.Spec(tensorstore_spec), + t = await serialization.ts_impl.ts.open( + serialization.ts_impl.ts.Spec(tensorstore_spec), open=True, assume_metadata=True, context=serialization.TS_CONTEXT,