Skip to content

Commit 1d3f752

Browse files
authored
Merge pull request #327 from aws-samples/fsdp_sync_module_states_true
FSDP with meta device requires sync_module_states=True
2 parents 8fd4088 + da6f289 commit 1d3f752

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

3.test_cases/10.FSDP/train.py

+4
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,9 @@ def main(args):
161161
model = AutoModelForCausalLM.from_config(model_config)
162162
else:
163163
with torch.device("meta"):
164+
# Instantiating model on `meta` device doesn't consume CPU memory,
165+
# but requires specifying `param_init_fn=...`
166+
# and `sync_module_states=True` in the FSDP constructor.
164167
model = AutoModelForCausalLM.from_config(model_config)
165168

166169
num_params = compute_num_params(model)
@@ -197,6 +200,7 @@ def main(args):
197200
device_id=torch.cuda.current_device(),
198201
use_orig_params=False,
199202
sharding_strategy=sharding_strategy,
203+
sync_module_states=True,
200204
param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
201205
if global_rank != 0 else None,
202206
)

0 commit comments

Comments
 (0)