diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
index b479eb5bccf..f9b0757842d 100644
--- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py
+++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -644,7 +644,6 @@ def forward(self, x: torch.Tensor):
         if x0.device.type == "xpu":
             # GPU logic
             try:
-                import intel_extension_for_pytorch
                 import xe_linear
                 from ipex_llm.transformers.models.utils import use_xmx
             except ModuleNotFoundError:
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 734c7a3e723..0c6f6208d3d 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -346,8 +346,7 @@ def use_decoding_fast_path(proj,
 def use_xmx(x: torch.Tensor, qtype: int):
     device = get_xpu_device_type(x)
     return (
-        os.environ.get("BIGDL_LLM_XMX_DISABLED", "0") != "1"
-        and device in ["arc", "flex", "pvc"]
+        device in ["arc", "flex", "pvc"]
         and qtype in [SYM_INT4, SYM_INT8, FP8E4, FP8E5]
         and (
             (device == "pvc" and 1 < x.size(0) <= 16)
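
Note: the second hunk drops the environment-variable guard from use_xmx, so the XMX decision depends only on the device type and quantization type checks shown above. A minimal sketch of the check being removed, for reference; the helper name xmx_env_enabled is hypothetical and only the env-var expression comes from the diff:

    import os

    def xmx_env_enabled() -> bool:
        # Sketch of the guard removed from use_xmx: before this patch, setting
        # BIGDL_LLM_XMX_DISABLED=1 short-circuited the XMX path regardless of
        # device or qtype; after the patch this variable is no longer consulted.
        return os.environ.get("BIGDL_LLM_XMX_DISABLED", "0") != "1"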