
Commit 78b7a1d

[OPT] Load fp16 weights on CPU before moving to GPU
1 parent 33e0860 commit 78b7a1d

6 files changed, +27 -12 lines changed

flash_attn/models/gpt.py

Lines changed: 3 additions & 3 deletions
@@ -166,9 +166,10 @@ def from_pretrained(cls, model_name, config, *args, strict=True, device=None, dt
         """
         # Instantiate model.
         model = cls(config, *args, device=device, dtype=dtype, **kwargs)
-        # If we're going to shard the model, then don't load fp32 weights to GPU.
+        # Load state_dict in cpu because we already initialized the model in GPU, and we don't
+        # want extra stuff taking up more GPU memory
         state_dict = state_dict_from_pretrained(
-            model_name, device=device if world_size == 1 else None, dtype=dtype
+            model_name, device='cpu', dtype=dtype
         )
         if model_name.startswith('gpt2'):
             state_dict = remap_state_dict_gpt2(state_dict, config)
@@ -178,7 +179,6 @@ def from_pretrained(cls, model_name, config, *args, strict=True, device=None, dt
             raise NotImplementedError(f'Model {model_name} not supported')
         if world_size > 1:
             state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
-            state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
         load_return = model.load_state_dict(state_dict, strict=strict)
         logger.info(load_return)
         return model
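
The change above reorders weight loading so the checkpoint never occupies GPU memory on top of the already-initialized model. Below is a minimal sketch of that loading order, not code from the repo: state_dict_from_pretrained is imported from the file changed later in this commit, while shard_state_dict_tp is assumed to be importable (its definition is not shown in this diff) and load_pretrained_sketch is an illustrative name.

import torch
from flash_attn.utils.pretrained import state_dict_from_pretrained

def load_pretrained_sketch(cls, config, model_name, device='cuda', dtype=torch.float16,
                           world_size=1, rank=0):
    # 1. Allocate the (fp16) model directly on the target device.
    model = cls(config, device=device, dtype=dtype)
    # 2. Load the checkpoint on CPU so the raw weights don't take up extra GPU memory.
    state_dict = state_dict_from_pretrained(model_name, device='cpu', dtype=dtype)
    # 3. With tensor parallelism, shard on CPU so each rank only ever copies its slice.
    #    (shard_state_dict_tp is assumed in scope; not shown in this diff.)
    if world_size > 1:
        state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
    # 4. load_state_dict copies the CPU tensors into the pre-allocated GPU parameters,
    #    so the explicit .to(device=device) pass over the state_dict is no longer needed.
    model.load_state_dict(state_dict)
    return model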

flash_attn/models/opt.py

Lines changed: 2 additions & 0 deletions
@@ -43,6 +43,8 @@ def key_mapping_emb(key):
     # LayerNorm
     def key_mapping_ln(key):
         key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
+        # The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
+        key = re.sub(r'^transformer.layer_norm.', r'transformer.ln_f.', key)
         key = re.sub(r'^transformer.layers.(\d+).self_attn_layer_norm.',
                      r'transformer.layers.\1.norm1.', key)
         key = re.sub(r'^transformer.layers.(\d+).final_layer_norm.',
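
The added regex covers checkpoints (such as OPT-175B) where the final LayerNorm is stored under 'decoder.layer_norm' rather than 'decoder.final_layer_norm'. A hedged sketch of how such a key-mapping function is typically applied across a whole state dict; apply_key_mapping is illustrative, not a helper from the repo.

import re
from collections import OrderedDict

def key_mapping_ln(key):
    # Most OPT checkpoints name the final LayerNorm 'final_layer_norm' ...
    key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
    # ... but the OPT-175B checkpoint uses 'layer_norm', so handle both spellings.
    key = re.sub(r'^transformer.layer_norm.', r'transformer.ln_f.', key)
    return key

def apply_key_mapping(state_dict, mapping_fn):
    # Rebuild the dict with renamed keys; tensor values are untouched.
    return OrderedDict((mapping_fn(k), v) for k, v in state_dict.items())

# e.g. {'transformer.layer_norm.weight': w} -> {'transformer.ln_f.weight': w}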

flash_attn/utils/generation.py

Lines changed: 11 additions & 4 deletions
@@ -196,7 +196,7 @@ class DecodingCGCache:
 
 @torch.inference_mode()
 def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1,
-                       dtype=None):
+                       dtype=None, n_warmups=2):
     if cache is None:
         cache = DecodingCGCache()
     param_example = next(iter(model.parameters()))
@@ -228,7 +228,8 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
         if s_type not in cache.callables:
             seqlen = min(max(seqlen_og, seqlen_type_to_seqlen(s_type)), max_seqlen)
             cache.callables[s_type] = capture_graph(
-                model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool
+                model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool,
+                n_warmups=n_warmups
             )
 
     def dispatch(input_ids, position_ids, seqlen):
@@ -239,7 +240,8 @@ def dispatch(input_ids, position_ids, seqlen):
     return cache
 
 
-def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None):
+def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None,
+                  n_warmups=2):
     assert max_seqlen >= seqlen_og
     device = next(iter(model.parameters())).device
     input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
@@ -250,10 +252,15 @@ def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, me
     s = torch.cuda.Stream()
     s.wait_stream(torch.cuda.current_stream())
     with torch.cuda.stream(s):
-        for _ in range(2):
+        for _ in range(n_warmups):
             logits = model(input_ids, position_ids=position_ids,
                            inference_params=inference_params).logits[:, -1]
         s.synchronize()
+        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
+        # which requires that graph launch and non-captured launch to not overlap (I think,
+        # that's how I interpret the documentation). I'm not sure if this is required.
+        if torch.distributed.is_initialized():
+            torch.distributed.barrier()
     torch.cuda.current_stream().wait_stream(s)
     # Captures the graph
     # To allow capture, automatically sets a side stream as the current stream in the context
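
capture_graph warms up the model on a side stream before recording a CUDA graph; this diff makes the warmup count configurable and inserts a barrier so captured and non-captured NCCL launches don't overlap. Below is a hedged, simplified sketch of that warmup-then-capture pattern for a generic callable; capture_with_warmups is illustrative and not the repo's function, and it assumes fn's positional arguments are the static input tensors.

import torch

def capture_with_warmups(fn, *args, n_warmups=2, mempool=None):
    # Warm up on a side stream so lazy initialization (cuBLAS handles, autotuning,
    # NCCL communicators, ...) happens outside the capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            out = fn(*args)
        s.synchronize()
        # Mirrors the barrier added in the diff: with NCCL_GRAPH_MIXING_SUPPORT=0,
        # graph launches and non-captured launches should not overlap.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(s)

    # Record one iteration into a CUDA graph, optionally sharing a memory pool.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        out = fn(*args)

    def replay(*new_args):
        # Copy new inputs into the captured (static) input tensors, then replay.
        for static, new in zip(args, new_args):
            static.copy_(new)
        graph.replay()
        return out

    return replay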

flash_attn/utils/pretrained.py

Lines changed: 6 additions & 2 deletions
@@ -7,6 +7,8 @@
 
 
 def state_dict_from_pretrained(model_name, device=None, dtype=None):
+    # If not fp32, then we don't want to load directly to the GPU
+    mapped_device = 'cpu' if dtype not in [torch.float32, None] else device
     is_sharded = False
     resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
                                         _raise_exceptions_for_missing_entries=False)
@@ -25,9 +27,11 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
         )
         state_dict = {}
         for sharded_file in resolved_archive_file:
-            state_dict.update(torch.load(sharded_file, map_location=device))
+            state_dict.update(torch.load(sharded_file, map_location=mapped_device))
     else:
         state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
+    # Convert dtype before moving to GPU to save memory
     if dtype is not None:
-        state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
+        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
+    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
    return state_dict
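
The idea is to keep non-fp32 loads on CPU, downcast there, and only then move to the requested device, so the full-precision weights never occupy GPU memory. A minimal standalone sketch of that order; load_checkpoint_sketch is illustrative and mirrors the logic above without the HF hub and sharded-checkpoint handling.

import torch

def load_checkpoint_sketch(path, device='cuda', dtype=torch.float16):
    # Non-fp32 targets are mapped to CPU first; an fp32 target can go straight
    # to the requested device since no conversion pass is needed.
    mapped_device = 'cpu' if dtype not in [torch.float32, None] else device
    state_dict = torch.load(path, map_location=mapped_device)
    if dtype is not None:
        # Downcasting on CPU means the GPU only ever sees the fp16 copy.
        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
    # Final move to the requested device (a no-op if device is None or 'cpu').
    return {k: v.to(device=device) for k, v in state_dict.items()}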

tests/models/test_gpt_generation.py

Lines changed: 3 additions & 3 deletions
@@ -114,7 +114,7 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
 
 
 @pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b", "facebook/opt-6.7b"])
-# @pytest.mark.parametrize('model_name', ["facebook/opt-6.7b"])
+# @pytest.mark.parametrize('model_name', ["facebook/opt-125m"])
 def test_greedy_decode_opt(model_name):
     """Check that our implementation of OPT generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
@@ -145,7 +145,7 @@ def test_greedy_decode_opt(model_name):
 
     input_ids = tokenizer("Hello, my dog is cute and",
                           return_tensors="pt").input_ids.to(device=device)
-    max_length = 30
+    max_length = 60
     # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
     # max_length = input_ids.shape[1] + 40
 
@@ -192,7 +192,7 @@ def test_greedy_decode_opt(model_name):
     print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
     if verbose:
         print(out_cg.sequences)
-        print(tokenizer.batch_decode(out.sequences.tolist()))
+        print(tokenizer.batch_decode(out_cg.sequences.tolist()))
 
     del model
 
tests/models/test_gpt_generation_parallel.py

Lines changed: 2 additions & 0 deletions
@@ -129,3 +129,5 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
     assert torch.all(out.sequences == out_hf.sequences)
 
     assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
+
+    parallel_state.destroy_model_parallel()
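
destroy_model_parallel() tears down the tensor-model-parallel state so the next parametrized test case starts clean instead of reusing stale process-group state. A hedged sketch of pairing initialization and teardown, assuming apex's parallel_state API as used by these tests; the exact initialize_model_parallel signature and the env-based process-group setup are assumptions, and run_with_model_parallel is illustrative.

import torch
from apex.transformer import parallel_state

def run_with_model_parallel(world_size, fn):
    # Assumes RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are already set in the environment.
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    # Assumed apex signature: the first argument is the tensor-model-parallel size.
    parallel_state.initialize_model_parallel(world_size)
    try:
        fn()
    finally:
        # Tear down so a subsequent test can re-initialize with a different world size.
        parallel_state.destroy_model_parallel()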
