Skip to content

Commit e6bab05

Browse files
authored
add GPTJ linear shapes to the list to fallback TPP to oneDNN (#3368)
1 parent 68d6d51 commit e6bab05

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

intel_extension_for_pytorch/nn/utils/_weight_prepack.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ def TPPLinear_weight_prepack(m, bk=None, bc=None, layer_dtype=torch.float32):
5959
#
6060
# For long term, mark as TODO, we will tune TPP block layout/loop order to make it on par with oneDNN.
6161

62-
fallback_ic_shape_list = [13824, 11008]
63-
fallback_oc_shape_list = [4096, 5120]
62+
fallback_ic_shape_list = [13824, 11008, 16384, 4096]
63+
fallback_oc_shape_list = [4096, 5120, 16384, 12288]
6464

6565

6666
def Apply_TPPLinear_weight_prepack(m, dtype, device="cpu"):

tests/cpu/test_tpp_linear.py

+1
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def test_tpp_linear_fallback_env_set(self):
146146
assert model.mlp.use_tpp is True
147147
self.assertEqual(out, ref_out)
148148
_disable_tpp()
149+
os.environ["BF16_OPTIMIZED_THROUGHPUT"] = "0"
149150

150151
def test_tpp_linear_fallback_flag(self):
151152
x1 = torch.rand(1, 1, 4097)

0 commit comments

Comments
 (0)