We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents 8fd4088 + da6f289 commit 1d3f752Copy full SHA for 1d3f752
3.test_cases/10.FSDP/train.py
@@ -161,6 +161,9 @@ def main(args):
161
model = AutoModelForCausalLM.from_config(model_config)
162
else:
163
with torch.device("meta"):
164
+ # Instantiating model on `meta` device doesn't consume CPU memory,
165
+ # but requires specifying `param_init_fn=...`
166
+ # and `sync_module_states=True` in FSDP c-tor.
167
168
169
num_params = compute_num_params(model)
@@ -197,6 +200,7 @@ def main(args):
197
200
device_id=torch.cuda.current_device(),
198
201
use_orig_params=False,
199
202
sharding_strategy=sharding_strategy,
203
+ sync_module_states=True,
204
param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
205
if global_rank != 0 else None,
206
)
0 commit comments