This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 3e78ae8

h2o for kv cache compression (#1468)
Signed-off-by: n1ck-guo <[email protected]>
1 parent b00652d commit 3e78ae8

File tree

14 files changed: +3104 -0 lines changed


.github/workflows/script/unitTest/coverage/.optimize-coveragerc

+1
@@ -18,6 +18,7 @@ omit =
     */intel_extension_for_transformers/langchain/**
     */intel_extension_for_transformers/llama_index/**
     */intel_extension_for_transformers/transformers/utils/get_throughput.py
+    */intel_extension_for_transformers/transformers/kv_cache_compression/**
 exclude_lines =
     pragma: no cover
     raise NotImplementedError

docs/h2o.md

+49
@@ -0,0 +1,49 @@
# H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models

1. [Introduction](#introduction)
2. [Usage](#usage)

## Introduction
**Heavy-Hitter Oracle (H2O)** is a novel approach to implementing the KV cache that significantly reduces its memory footprint.

This method is based on the observation that the accumulated attention scores of all tokens in attention blocks follow a power-law distribution, which suggests that a small set of influential tokens, called heavy hitters (H2), is critical during generation. H2 makes it possible to step away from the combinatorial search problem and identify an eviction policy that maintains accuracy.

H2O dynamically keeps a balance of recent and H2 tokens, significantly increasing model throughput while preserving accuracy.

For more information, please refer to the paper [H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models](https://arxiv.org/pdf/2306.14048).

![](./imgs/h2o.png)
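The eviction idea can be illustrated with a small, self-contained sketch. This is a toy model only, not the library's implementation; the function and variable names are illustrative, and `heavy_ratio`/`recent_ratio` simply mirror the `H2OConfig` options shown in the Usage section below. The policy: keep a window of the most recent tokens plus the tokens with the largest accumulated attention scores, and evict the rest.

```python
import torch

def h2o_keep_mask(attn_scores: torch.Tensor, heavy_ratio: float = 0.1, recent_ratio: float = 0.1) -> torch.Tensor:
    """Toy H2O-style eviction. attn_scores is a [seq_len, seq_len] attention map for one head."""
    seq_len = attn_scores.size(-1)
    heavy_budget = int(heavy_ratio * seq_len)
    recent_budget = int(recent_ratio * seq_len)

    # Accumulated attention each cached token has received so far (column sums).
    accumulated = attn_scores.sum(dim=0)

    keep = torch.zeros(seq_len, dtype=torch.bool)
    # Heavy hitters: tokens with the largest accumulated attention scores.
    keep[accumulated.topk(heavy_budget).indices] = True
    # Recent window: always keep the newest tokens.
    if recent_budget > 0:
        keep[-recent_budget:] = True
    return keep  # True = keep this KV-cache entry, False = evict it
```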
## Usage
Using simulation mode:
```python
from intel_extension_for_transformers.transformers.kv_cache_compression import H2OConfig, LlamaForCausalLM

h2o_config = H2OConfig(
    heavy_ratio=heavy_ratio,
    recent_ratio=recent_ratio,
    h2o_min_seqlen=h2o_min_seqlen,
    real_drop=False,
)
user_model = LlamaForCausalLM.from_pretrained(
    args.model,
    prune_config=h2o_config,
    trust_remote_code=args.trust_remote_code)
```
To run the real_drop mode:
```python
from intel_extension_for_transformers.transformers.kv_cache_compression import H2OConfig, LlamaForCausalLM

h2o_config = H2OConfig(
    heavy_ratio=heavy_ratio,
    recent_ratio=recent_ratio,
    h2o_min_seqlen=h2o_min_seqlen,
    real_drop=True,
)
user_model = LlamaForCausalLM.from_pretrained(
    args.model,
    prune_config=h2o_config,
    trust_remote_code=args.trust_remote_code)
```
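In either mode the returned `user_model` behaves like a regular `transformers` causal LM, so generation goes through the standard API. A minimal sketch (the tokenizer setup and prompt here are illustrative; `args.model` is the same checkpoint passed to `from_pretrained` above):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(args.model)
inputs = tokenizer("Once upon a time, there existed a little girl,", return_tensors="pt")
# Generate as usual; H2O manages the KV cache according to the config above.
gen_ids = user_model.generate(inputs.input_ids, max_new_tokens=32)
print(tokenizer.decode(gen_ids[0], skip_special_tokens=True))
```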
Please refer to the [h2o example](../examples/huggingface/pytorch/text-generation/h2o/run_generation.py) for more details.

docs/imgs/h2o.png

352 KB
@@ -0,0 +1,47 @@
# H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models

**Heavy-Hitter Oracle (H2O)** is a novel approach to implementing the KV cache that significantly reduces its memory footprint.

This method is based on the observation that the accumulated attention scores of all tokens in attention blocks follow a power-law distribution, which suggests that a small set of influential tokens, called heavy hitters (H2), is critical during generation. H2 makes it possible to step away from the combinatorial search problem and identify an eviction policy that maintains accuracy.

H2O dynamically keeps a balance of recent and H2 tokens, significantly increasing model throughput while preserving accuracy.

For more information, please refer to the paper [H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models](https://arxiv.org/pdf/2306.14048).

![](./imgs/1.png)
## Usage and Examples
### Evaluation on tasks from the [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness) framework
Using simulation mode:
```bash
python run_generation.py \
    --model meta-llama/Meta-Llama-3-8B \
    --accuracy \
    --batch_size 16 \
    --h2o \
    --heavy_ratio 0.1 \
    --recent_ratio 0.1 \
    --device 0
```
To run the real_drop mode:
```bash
python run_generation.py \
    --model meta-llama/Meta-Llama-3-8B \
    --accuracy \
    --batch_size 16 \
    --h2o \
    --heavy_ratio 0.1 \
    --recent_ratio 0.1 \
    --device 0 \
    --real_drop
```
Get the accuracy of the dense model:
```bash
python run_generation.py \
    --model meta-llama/Meta-Llama-3-8B \
    --accuracy \
    --batch_size 16
```
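As a rough back-of-the-envelope illustration of what the two ratios imply for cache size (this assumes each budget is simply ratio × sequence length, as in the toy sketch in docs/h2o.md; the library's exact accounting may differ):

```python
# Hypothetical sizing example, not output from the library.
seq_len = 4096                        # cached tokens for one attention head
heavy_ratio, recent_ratio = 0.1, 0.1
kept = int(heavy_ratio * seq_len) + int(recent_ratio * seq_len)  # 409 + 409 = 818
print(f"KV entries kept: {kept}/{seq_len} (~{kept / seq_len:.0%})")  # ~20% of the dense cache
```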
@@ -0,0 +1,238 @@
import argparse
import sys
import time
import json
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
from transformers.utils import check_min_version

parser = argparse.ArgumentParser()
parser.add_argument("--model", default=None)
parser.add_argument(
    "--dataset", nargs="?", default="NeelNanda/pile-10k", const="NeelNanda/pile-10k"
)
parser.add_argument(
    "--max_new_tokens", default=32, type=int, help="output max new tokens"
)
parser.add_argument("--output_dir", nargs="?", default="./saved_results")
parser.add_argument("--int8", action="store_true")
parser.add_argument(
    "--int8_bf16_mixed",
    action="store_true",
    help="by default it is int8-fp32 mixed, to enable int8 mixed amp bf16 (work on platforms like SPR)",
)
parser.add_argument(
    "--restore",
    action="store_true",
    help="restore ipex quantized model from output_dir/best_configure.json",
)
parser.add_argument(
    "--peft_model_id", type=str, default=None, help="model_name_or_path of peft model"
)
parser.add_argument("--_commit_hash", default=None, type=str)
parser.add_argument("--trust_remote_code", action="store_true")
parser.add_argument("--use_neural_speed", action="store_true")
# ============Benchmark configs==============
parser.add_argument("--benchmark", action="store_true")
parser.add_argument("--iters", default=100, type=int, help="num iter")
parser.add_argument("--num_warmup", default=10, type=int, help="num warmup")
# ============Accuracy configs==============
parser.add_argument("--accuracy", action="store_true")
parser.add_argument("--batch_size", default=16, type=int, help="batch size num.")
parser.add_argument(
    "--save_accuracy_path", default=None, help="Save accuracy results path."
)
parser.add_argument("--output_excel", default=None, type=str)
parser.add_argument("--eval_bs", default=4, type=int,
                    help="eval batch size")
parser.add_argument("--tasks", nargs='+', default=["winogrande", "copa", "piqa", "rte", "hellaswag",
                    "openbookqa", "lambada_openai", "lambada_standard", "wikitext"], type=str,
                    help="tasks list for accuracy validation")
parser.add_argument("--num_fewshot", default=0, type=int, help="num few shot.")
# ============MixedPrecision configs==============
parser.add_argument("--mixed_precision", action="store_true")

# ============h2o configs==============
parser.add_argument('--h2o', action='store_true')
parser.add_argument('--is_gen', action='store_true')
parser.add_argument('--real_drop', action='store_true')
parser.add_argument("--heavy_ratio", type=float, default=0.1)
parser.add_argument("--recent_ratio", type=float, default=0.1)
parser.add_argument("--device", type=str, default='cpu')
parser.add_argument("--h2o_min_seqlen", type=int, default=0)

args = parser.parse_args()
# transformers version >= 4.32.0 contained the mpt modeling definition.
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpt/modeling_mpt.py
# 4.31.0 for ipex.optimize_transformers
# get model config
if args.peft_model_id:
    from peft import PeftConfig

    peft_config = PeftConfig.from_pretrained(args.peft_model_id)
    if args.model is None:
        args.model = peft_config.base_model_name_or_path
        print("we will use peft base_model_name_or_path to get tokenizer.")

config = AutoConfig.from_pretrained(
    args.model,
    torchscript=False,
    use_cache=True,  # to use kv cache.
    trust_remote_code=args.trust_remote_code,
    _commit_hash=args._commit_hash,
)

# chatglm
if config.model_type == "chatglm":
    AutoModelForCausalLM = AutoModel
# tokenizer
if config.model_type == "llama":
    from transformers import LlamaTokenizer

    # tokenizer = LlamaTokenizer.from_pretrained(args.model)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
else:
    tokenizer = AutoTokenizer.from_pretrained(
        args.model, trust_remote_code=args.trust_remote_code
    )

# use peft
args.model = args.peft_model_id if args.peft_model_id is not None else args.model

# Generation
if args.use_neural_speed:
    generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=1)
else:
    generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=4)

if 'cpu' in args.device:
    device = args.device
else:
    device = f"cuda:{args.device}"

# get optimized model
if args.h2o:
    print('Enable Small Cache Size')
    from intel_extension_for_transformers.transformers.kv_cache_compression import H2OConfig, LlamaForCausalLM
    h2o_config = H2OConfig(
        heavy_ratio=args.heavy_ratio,
        recent_ratio=args.recent_ratio,
        h2o_min_seqlen=args.h2o_min_seqlen,
        real_drop=args.real_drop,
        mean=False,
    )
    user_model = LlamaForCausalLM.from_pretrained(
        args.model,
        prune_config=h2o_config,
        trust_remote_code=args.trust_remote_code)
    print("converted model: ", user_model)
else:
    user_model = AutoModelForCausalLM.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
user_model.to(device)

# save model
# if args.output_dir is not None:
#     tokenizer.save_pretrained(args.output_dir)
#     user_model.save_pretrained(args.output_dir)

if args.benchmark:
    user_model = (
        user_model.eval() if (not (args.int8 or args.int8_bf16_mixed) and hasattr(user_model, "eval")) else user_model
    )
    prompt = "Once upon a time, there existed a little girl, who liked to have adventures. She wanted to go to places and meet new people, and have fun."
    input_size = tokenizer(prompt, return_tensors="pt").input_ids.size(dim=1)
    print("---- Prompt size:", input_size)

    # start
    total_time = 0.0
    num_iter = args.iters
    num_warmup = args.num_warmup
    total_token_num = 0
    eos_token_id = tokenizer.eos_token_id
    with torch.inference_mode(), torch.no_grad():
        for i in range(num_iter):
            tic = time.time()
            if hasattr(tokenizer, "build_chat_input"):
                input_ids = tokenizer.build_chat_input(prompt)["input_ids"]
                input_ids = input_ids.repeat(args.batch_size, 1)
                eos_token_id = [
                    tokenizer.eos_token_id,
                    tokenizer.get_command("<|user|>"),
                    tokenizer.get_command("<|observation|>"),
                ]
            elif hasattr(tokenizer, "build_prompt"):
                build_prompt = tokenizer.build_prompt(prompt)
                input_ids = tokenizer(
                    [build_prompt] * args.batch_size, return_tensors="pt"
                ).input_ids
            else:
                input_ids = tokenizer(
                    [prompt] * args.batch_size, return_tensors="pt"
                ).input_ids
            gen_ids = user_model.generate(
                input_ids,
                max_new_tokens=args.max_new_tokens,
                **generate_kwargs,
                eos_token_id=eos_token_id
            )
            gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
            toc = time.time()
            # note: gen_ids includes the prompt tokens, so subtract input_ids to count generated tokens.
            input_tokens_num = input_ids.numel()
            output_tokens_num = torch.tensor(gen_ids).numel() - input_tokens_num
            print(gen_text, flush=True)
            if i >= num_warmup:
                total_time += toc - tic
                total_token_num += output_tokens_num

    print("\n", "-" * 10, "Summary:", "-" * 10)
    latency = total_time / total_token_num
    print("Inference latency: %.3f sec." % latency)
    throughput = total_token_num / total_time
    print("Throughput: {} samples/sec".format(throughput))

if args.accuracy:
    user_model = (user_model.eval() if (not (args.int8 or args.int8_bf16_mixed) and hasattr(user_model, "eval"))
                  else user_model)
    # from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
    # model_args="pretrained="+args.model+",trust_remote_code="+str(args.trust_remote_code)
    # args.tasks = ",".join(args.tasks)
    # tokenizer.pad_token = tokenizer.eos_token
    # eval_args = LMEvalParser(model = "hf",
    #                          user_model=user_model,
    #                          tokenizer=tokenizer,
    #                          model_args=model_args,
    #                          tasks = args.tasks,
    #                          device = device,
    #                          num_fewshot=args.num_fewshot,
    #                          output_path=args.save_accuracy_path,
    #                          batch_size = args.batch_size)
    # print("using device:", device)
    # results = evaluate(eval_args)

    # original lm_eval
    from lm_eval.evaluator import simple_evaluate
    from lm_eval.tasks import TaskManager
    import lm_eval

    verbosity = 'INFO'
    task_manager = TaskManager(verbosity)
    limit = None
    cache_requests = False
    lm = lm_eval.api.registry.get_model("hf")(
        pretrained=user_model,
        batch_size=args.batch_size,
        max_batch_size=None,
    )
    model_args = "pretrained=" + args.model + ",tokenizer=" + args.model + ",dtype=float32"
    use_cache = None
    results = simple_evaluate(
        model=lm,
        model_args=model_args,
        tasks=args.tasks,
        num_fewshot=args.num_fewshot,
        device=device
    )
    import pprint
    pprint.pprint(results["results"])
