This repository was archived by the owner on Oct 25, 2024. It is now read-only.

add gaudi modeling support in itrex #1438

Merged: 28 commits merged into main from gaudi-support on May 24, 2024.
Commits (28)
56a2893  add gaudi modeling support in itrex (ClarkChin08, Mar 29, 2024)
e0613ad  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 29, 2024)
69e81b3  Add test example to itrex (ClarkChin08, Apr 9, 2024)
6454315  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 9, 2024)
72a4910  add fp8 support and fix bugs (ClarkChin08, Apr 15, 2024)
4ad3b04  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 15, 2024)
63f0fc0  Create requirements.txt (airMeng, Apr 24, 2024)
1da9dfb  add ppl measurement in gaudi (ClarkChin08, Apr 25, 2024)
fb2f7cc  fix the ppl acc issue (ClarkChin08, Apr 26, 2024)
360e32f  [Gaudi] Add LLAMA Streaming LLM in Gaudi (#1558) (zhentaoyu, May 22, 2024)
3a934c5  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 22, 2024)
fb2966e  fix the pylint issue (ClarkChin08, May 22, 2024)
de64700  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 22, 2024)
455e5c3  fix the pylint issue (ClarkChin08, May 22, 2024)
83a42b2  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 22, 2024)
80476f1  add optimum-habana when pylint (ClarkChin08, May 22, 2024)
ad793d4  Merge branch 'main' into gaudi-support (VincyZhang, May 22, 2024)
5aedccb  add pylint comment (ClarkChin08, May 22, 2024)
092ffb6  Merge branch 'main' into gaudi-support (VincyZhang, May 22, 2024)
72f313f  add comment to avoid pylint check (ClarkChin08, May 23, 2024)
4642ee0  ignore modeling_gaudi pylint (ClarkChin08, May 23, 2024)
0f714b5  manual fix the pylint (ClarkChin08, May 23, 2024)
22557de  Update pylint.sh (VincyZhang, May 23, 2024)
b8b7de5  Merge branch 'main' into gaudi-support (VincyZhang, May 23, 2024)
17614be  fix line by line pylint (ClarkChin08, May 23, 2024)
e02739f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 23, 2024)
2661a1e  disable before the line (ClarkChin08, May 23, 2024)
cd82c90  pylint check (ClarkChin08, May 24, 2024)
Diff view (changes from 2 of the 28 commits)
@@ -17,7 +17,7 @@
 from optimum.habana.transformers.generation.utils import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
 if "llava" not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
     MODELS_OPTIMIZED_WITH_STATIC_SHAPES.append("llava")
-from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+from intel_extension_for_transformers.transformers.modeling.modeling_gaudi import adapt_transformers_to_gaudi
 adapt_transformers_to_gaudi()

 import torch
@@ -458,7 +458,7 @@ def load_model(
 # Tweak generation so that it runs faster on Gaudi
 # pylint: disable=E0401
 # pylint: disable=E0611
-from optimum.habana.transformers.modeling_utils import (
+from intel_extension_for_transformers.transformers.modeling.modeling_gaudi import (
     adapt_transformers_to_gaudi,
 )

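The two hunks above only change where adapt_transformers_to_gaudi is imported from: it now lives in the new modeling_gaudi package inside ITREX instead of optimum.habana. A minimal sketch of the intended call order is below; the checkpoint name and the .to("hpu") step are illustrative assumptions, not code from this PR.

# Sketch only: assumes a Habana/HPU PyTorch runtime (habana_frameworks.torch) and
# optimum-habana are installed alongside intel-extension-for-transformers.
from optimum.habana.transformers.generation.utils import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
from intel_extension_for_transformers.transformers.modeling.modeling_gaudi import (
    adapt_transformers_to_gaudi,
)

# Register llava for static-shape generation, as the first hunk above does.
if "llava" not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
    MODELS_OPTIMIZED_WITH_STATIC_SHAPES.append("llava")

# Patch transformers with the Gaudi-optimized code paths before any model is built.
adapt_transformers_to_gaudi()

from transformers import AutoModelForCausalLM

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint, not taken from the PR
model = AutoModelForCausalLM.from_pretrained(model_name).eval().to("hpu")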
@@ -0,0 +1,15 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .modeling_utils import adapt_transformers_to_gaudi
@@ -0,0 +1,20 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .configuration_utils import GaudiGenerationConfig
from .stopping_criteria import (
    gaudi_MaxLengthCriteria_call,
    gaudi_MaxNewTokensCriteria_call,
)
from .utils import MODELS_OPTIMIZED_WITH_STATIC_SHAPES, GaudiGenerationMixin
@@ -0,0 +1,68 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers.generation import GenerationConfig


class GaudiGenerationConfig(GenerationConfig):
    """
    This class extends [`transformers.generation.GenerationConfig`](https://github.com/huggingface/transformers/blob/main/src/transformers/generation/configuration_utils.py)
    to add HPU-specific arguments for generation.

    Args:
        trim_logits (`bool`, *optional*):
            Calculate logits only for the last token to save memory in the first step.
        static_shapes (`bool`, *optional*):
            Whether to use static shapes for generation or not. It will run faster on HPUs with static shapes,
            but not all models support it. If not specified, it will automatically be set to `True` if the given
            model supports it.
        ignore_eos (`bool`, *optional*):
            Whether to ignore finished sequences (faster in lazy mode and with HPU graphs) or not (eager mode).
            If not specified, it will automatically be set to `True` if lazy mode is on.
        attn_softmax_bf16 (`bool`, *optional*):
            Whether to run the attention softmax layer in lower precision, provided that the model supports it
            and is also running in lower precision.
        limit_hpu_graphs (`bool`, *optional*):
            Skip HPU graph usage for the first token to save memory.
        reuse_cache (`bool`, *optional*):
            Whether to reuse the key/value cache for decoding. It should save memory.
        bucket_size (`int`, *optional*):
            If negative (default=-1), pad to the maximum length when `static_shapes` is set. Otherwise start with
            `shape = bucket_size * ceil(prompt_len/bucket_size)` and then grow the space by `bucket_size` when needed.
            Only active if `static_shapes` is used. Can't be used with `reuse_cache`.
        bucket_internal (`bool`, *optional*):
            Split the KV sequence into buckets in the decode phase. It improves throughput when `max_new_tokens` is large.
        kv_cache_fp8 (`bool`, *optional*):
            Store the KV cache in float8 when a KV cache is used.
        use_flash_attention (`bool`, *optional*):
            Whether to use the flash attention optimization.
        flash_attention_recompute (`bool`, *optional*):
            Whether to enable recompute when using Habana flash attention.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.trim_logits = kwargs.get("trim_logits", None)
        self.static_shapes = kwargs.get("static_shapes", None)
        self.ignore_eos = kwargs.get("ignore_eos", None)
        self.attn_softmax_bf16 = kwargs.get("attn_softmax_bf16", None)
        self.limit_hpu_graphs = kwargs.get("limit_hpu_graphs", None)
        self.reuse_cache = kwargs.get("reuse_cache", None)
        self.bucket_size = kwargs.get("bucket_size", -1)
        self.bucket_internal = kwargs.get("bucket_internal", None)
        self.reduce_recompile = kwargs.get("reduce_recompile", None)
        self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None)
        self.use_flash_attention = kwargs.get("use_flash_attention", None)
        self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None)
        self.use_fused_rope = kwargs.get("use_fused_rope", None)
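
For reference, a hedged usage sketch of the configuration class above. The import path is assumed from the package layout this PR introduces (the diff view does not show the new file names), and the flag values are examples rather than recommended settings.

from intel_extension_for_transformers.transformers.modeling.modeling_gaudi.generation import (
    GaudiGenerationConfig,  # assumed module path; adjust to the actual package layout
)

generation_config = GaudiGenerationConfig(
    max_new_tokens=128,        # standard GenerationConfig argument
    do_sample=False,
    static_shapes=True,        # faster on HPU when the model supports static shapes
    ignore_eos=True,           # typically paired with lazy mode / HPU graphs
    attn_softmax_bf16=True,    # lower-precision softmax if the model also runs in bf16
    reuse_cache=True,          # reuse the key/value cache during decoding
    bucket_size=128,           # grow the padded length in bucket_size steps instead of padding to max
)

# The HPU-specific flags are picked up from kwargs in __init__ above; the object would
# then be passed as generation_config=... to generate() on a Gaudi-adapted model.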
@@ -0,0 +1,46 @@
# coding=utf-8
# Copyright 2022 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from optimum.utils import logging


logger = logging.get_logger(__name__)


def gaudi_MaxLengthCriteria_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
    # With static shapes on HPU, input_ids is pre-padded, so the caller passes token_idx
    # to indicate the current generation position instead of relying on the tensor length.
    token_idx = kwargs.get("token_idx", None)
    if token_idx is not None:
        return token_idx >= self.max_length
    else:
        cur_len = input_ids.shape[-1]
        is_done = cur_len >= self.max_length
        if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
            logger.warning_once(
                "This is a friendly reminder - the current text generation call will exceed the model's predefined "
                f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
                "exceptions, performance degradation, or nothing at all."
            )
        return is_done


def gaudi_MaxNewTokensCriteria_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
    token_idx = kwargs.get("token_idx", None)
    if token_idx is not None:
        return token_idx >= self.max_length
    else:
        return input_ids.shape[-1] >= self.max_length
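
The two gaudi_*_call functions replace the __call__ methods of the corresponding transformers stopping criteria so that, when token_idx is supplied, completion is judged by the token index rather than by the padded sequence length. One plausible way to attach them, shown only as a sketch since the actual patching happens inside adapt_transformers_to_gaudi, whose body is not part of this diff:

# Sketch under the assumption that patching __call__ on the criteria classes is how the
# overrides are applied; both classes existed in transformers at the time of this PR
# (MaxNewTokensCriteria has since been deprecated upstream).
from transformers.generation.stopping_criteria import MaxLengthCriteria, MaxNewTokensCriteria

MaxLengthCriteria.__call__ = gaudi_MaxLengthCriteria_call
MaxNewTokensCriteria.__call__ = gaudi_MaxNewTokensCriteria_call

# A Gaudi-aware generation loop can then forward the current position:
#     done = stopping_criteria(input_ids, scores, token_idx=token_idx)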