Skip to content

Commit 23cc2e2

Browse files
xin3hexinhe3
andauthored
add INC_PT_ONLY and INC_TF_ONLY (#2202)
* add INC_PT_ONLY and INC_TF_ONLY * compatible with previous install method --------- Signed-off-by: Xin He <[email protected]> Co-authored-by: Xin He <[email protected]>
1 parent e7f452f commit 23cc2e2

File tree

4 files changed

+32
-22
lines changed

4 files changed

+32
-22
lines changed

neural_compressor/__init__.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,18 @@
1717
"""Intel® Neural Compressor: An open-source Python library supporting popular model compression techniques."""
1818
from .version import __version__
1919

20-
# we need to set a global 'NA' backend, or Model can't be used
21-
from .config import (
22-
DistillationConfig,
23-
PostTrainingQuantConfig,
24-
WeightPruningConfig,
25-
QuantizationAwareTrainingConfig,
26-
MixedPrecisionConfig,
27-
)
28-
from .contrib import *
29-
from .model import *
30-
from .metric import *
31-
from .utils import options
32-
from .utils.utility import set_random_seed, set_tensorboard, set_workspace, set_resume_from
20+
import os
21+
22+
if not (os.environ.get("INC_PT_ONLY", False) or os.environ.get("INC_TF_ONLY", False)):
23+
from .config import (
24+
DistillationConfig,
25+
PostTrainingQuantConfig,
26+
WeightPruningConfig,
27+
QuantizationAwareTrainingConfig,
28+
MixedPrecisionConfig,
29+
)
30+
from .contrib import *
31+
from .model import *
32+
from .metric import *
33+
from .utils import options
34+
from .utils.utility import set_random_seed, set_tensorboard, set_workspace, set_resume_from

setup.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -95,22 +95,25 @@ def get_build_version():
9595

9696

9797
if __name__ == "__main__":
98-
cfg_key = "neural_compressor"
99-
100-
# Temporary implementation of fp8 tensor saving and loading
101-
# Will remove after Habana torch applies below patch:
102-
# https://github.com/pytorch/pytorch/pull/114662
103-
ext_modules = []
104-
cmdclass = {}
105-
98+
# for setuptools>=80.0.0, `INC_PT_ONLY=1 pip install -e .`
99+
if os.environ.get("INC_PT_ONLY", False) and os.environ.get("INC_TF_ONLY", False):
100+
raise ValueError("Both INC_PT_ONLY and INC_TF_ONLY are set. Please set only one.")
101+
if os.environ.get("INC_PT_ONLY", False):
102+
cfg_key = "neural_compressor_pt"
103+
elif os.environ.get("INC_TF_ONLY", False):
104+
cfg_key = "neural_compressor_tf"
105+
else:
106+
cfg_key = "neural_compressor"
107+
# for setuptools < 80.0.0, `python setup.py develop pt`
106108
if "pt" in sys.argv:
107109
sys.argv.remove("pt")
108110
cfg_key = "neural_compressor_pt"
109-
110111
if "tf" in sys.argv:
111112
sys.argv.remove("tf")
112113
cfg_key = "neural_compressor_tf"
113114

115+
ext_modules = []
116+
cmdclass = {}
114117
project_name = PKG_INSTALL_CFG[cfg_key].get("project_name")
115118
include_packages = PKG_INSTALL_CFG[cfg_key].get("include_packages") or {}
116119
package_data = PKG_INSTALL_CFG[cfg_key].get("package_data") or {}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
# Called once at the beginning of the test session
22
def pytest_sessionstart():
3+
import os
34
import habana_frameworks.torch.core as htcore
45
import torch
56

67
htcore.hpu_set_env()
78

89
# Use reproducible results
910
torch.use_deterministic_algorithms(True)
11+
# Ensure that only 3x PyTorch part of INC is imported
12+
os.environ.setdefault("INC_PT_ONLY", "1")
1013

1114
# Fix the seed - just in case
1215
torch.manual_seed(0)

test/3x/torch/quantization/fp8_quant/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# Ensure that the HPU is in lazy mode and weight sharing is disabled
33
os.environ.setdefault("PT_HPU_LAZY_MODE", "1")
44
os.environ.setdefault("PT_HPU_WEIGHT_SHARING", "0")
5+
# Ensure that only 3x PyTorch part of INC is imported
6+
os.environ.setdefault("INC_PT_ONLY", "1")
57

68

79
def pytest_sessionstart():

0 commit comments

Comments
 (0)