Commit 953d654
Split woq kernel file to speed up compilation (#3382)
* tmp commit 20240926
* Put qlinear_woq_affine_impl and more in separate header files
* Unify calling qlinear_woq_affine_impl with or without weight zero points
* Rename WoqTppKrnl.cpp -> WoqLinearKrnl.cpp
* Improve ISA control macros in woq.h
* Move quantize_per_block to woq.h
* Split WOQ kernel file
* clang-format
* Further split int8 gemm kernel files
* Remove unused woq.cpp
* refine function notes
* Refine file names
1 parent 1e2711c commit 953d654

17 files changed: +6366 −5399 lines

csrc/cpu/aten/Linear.cpp (+1)

@@ -6,6 +6,7 @@
 #include "WeightPack.h"
 #include "autocast/autocast_mode.h"
 #include "ideep/IDeepConversions.h"
+#include "utils/woq_defines.h"

 namespace torch_ipex {
 namespace cpu {

csrc/cpu/aten/Linear.h (+41 −20)

@@ -234,6 +234,32 @@ using woq_tpp_gemm_kernel_fn = at::Tensor (*)(
     int64_t,
     const c10::optional<at::Tensor>&);

+using woq_gemm_kernel_fn = at::Tensor (*)(
+    const at::Tensor&,
+    const at::Tensor&,
+    const std::vector<at::Tensor>&,
+    const std::vector<at::Tensor>&,
+    const std::vector<at::Tensor>&,
+    const int,
+    int64_t,
+    const std::vector<at::Tensor>&,
+    int64_t,
+    int64_t);
+
+using woq_int8_gemm_kernel_fn = at::Tensor (*)(
+    const at::Tensor&,
+    const at::Tensor&,
+    const std::vector<at::Tensor>&,
+    const std::vector<at::Tensor>&,
+    const std::vector<at::Tensor>&,
+    const int,
+    int64_t,
+    const std::vector<at::Tensor>&,
+    int64_t,
+    int64_t,
+    int64_t,
+    const c10::optional<at::Tensor>&);
+
 using woq_tpp_gemm_packB_fn =
     at::Tensor (*)(const at::Tensor&, int, size_t, size_t, int64_t);

@@ -254,33 +280,28 @@ using dequant_nf4_fn = at::Tensor (*)(
     c10::ScalarType);

 IPEX_DECLARE_DISPATCH(woq_tpp_gemm_kernel_fn, woq_tpp_gemm_kernel_stub);
+IPEX_DECLARE_DISPATCH(woq_gemm_kernel_fn, woq_fp32_gemm_kernel_stub);
+IPEX_DECLARE_DISPATCH(woq_gemm_kernel_fn, woq_fp16_gemm_kernel_stub);
+IPEX_DECLARE_DISPATCH(woq_gemm_kernel_fn, woq_bf16_gemm_kernel_stub);
+IPEX_DECLARE_DISPATCH(
+    woq_int8_gemm_kernel_fn,
+    woq_int8_gemm_pre_tensor_kernel_stub);
+IPEX_DECLARE_DISPATCH(
+    woq_int8_gemm_kernel_fn,
+    woq_int8_gemm_pre_k_block_kernel_stub);
+IPEX_DECLARE_DISPATCH(
+    woq_int8_gemm_kernel_fn,
+    woq_int8_gemm_pre_m_block_kernel_stub);
+IPEX_DECLARE_DISPATCH(
+    woq_int8_gemm_kernel_fn,
+    woq_int8_gemm_pre_m_k_block_kernel_stub);
 IPEX_DECLARE_DISPATCH(woq_tpp_gemm_packB_fn, woq_tpp_gemm_packB_stub);
 IPEX_DECLARE_DISPATCH(woq_tpp_gemm_unpackB_fn, woq_tpp_gemm_unpackB_stub);
 IPEX_DECLARE_DISPATCH(
     woq_dequant_int4_to_int8_packed_fn,
     woq_dequant_int4_to_int8_packed_stub);
 IPEX_DECLARE_DISPATCH(dequant_nf4_fn, dequant_nf4_stub);

-// Fusion types
-#define WOQ_FUSE_NONE 0x0
-// Unary post ops
-#define WOQ_FUSE_GELU_ERF 0x1
-#define WOQ_FUSE_GELU_TANH 0x2
-#define WOQ_FUSE_RELU 0x3
-#define WOQ_FUSE_SILU 0x4
-// Binary post ops
-#define WOQ_FUSE_ADD 0x10
-#define WOQ_FUSE_ADD_ADD 0x20
-#define WOQ_FUSE_MUL 0x30
-
-// weight quant mode
-#define QUANT_W_PER_CHANNEL 0
-#define QUANT_W_PER_K_BLOCK 1
-#define QUANT_W_PER_CHANNEL_SYM 2
-#define QUANT_W_PER_K_BLOCK_SYM 3
-
-#define WOQ_N_BLOCK_SIZE 32
-
 #endif

 } // namespace cpu
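The new stubs are declared with the same IPEX_DECLARE_DISPATCH mechanism as the existing WOQ kernels, so a caller invokes them like any other dispatch stub. A minimal, hypothetical caller-side sketch (not part of this commit; the kCPU device-type key and the argument names are assumptions based on the woq_gemm_kernel_fn signature above):

    // Hypothetical call site (illustrative only): invoke the bf16 compute stub
    // with arguments matching woq_gemm_kernel_fn.
    at::Tensor y = woq_bf16_gemm_kernel_stub(
        kCPU,           // device-type key used by the dispatch stubs (assumed)
        x,              // activation, 2D [M, K]
        qw,             // quantized weight, 4D blocked or 2D plain
        scales_list,    // scale tensors (fp32/fp16/bf16 variants)
        zp_list,        // zero-point tensors
        bias_list,      // bias tensors
        qw_type,        // weight dtype tag (int8, int4, ...)
        fusion_type,    // post-op fusion tag
        others_list,    // extra inputs for binary post ops
        quant_w_mode,   // per-channel / per-K-block, sym / asym
        quant_block_k); // K block size for block-wise quantized weight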
New file (+170)

@@ -0,0 +1,170 @@
+// weight-only quantization gemm kernel (int8, int4 etc.)
+#ifdef USE_LIBXSMM
+#include <ATen/ATen.h>
+#include <ATen/Tensor.h>
+#include <ATen/cpu/vec/functional.h>
+#include <ATen/cpu/vec/vec.h>
+#include "aten/utils/woq.h"
+
+#ifdef __GNUC__
+#include <features.h>
+#if __GNUC_PREREQ(12, 3)
+#define COMPILER_PREREQ_MET
+#endif
+#endif
+
+namespace torch_ipex {
+namespace cpu {
+namespace {
+
+using TensorList = std::vector<at::Tensor>;
+
+// We only build optimized kernels if AVX512_FP16 is supported and gcc>=12.3
+#if defined(CPU_CAPABILITY_AVX512_FP16) && defined(COMPILER_PREREQ_MET)
+
+// We separate GEMM kernel in different files to avoid long compile time
+/**
+ * @brief quantized linear with quantized weight but activation in floating
+ * point. Compute in bfloat16.
+ *
+ * @param x input activation in floating point format, 2D plain format [M,K]
+ * @param qw weight in affine quantized format, could be 4-bit or 8-bit
+ * quantized in 4D blocked format [Nc,Kc,Kb,Nb] or 2D plain format [N,K].
+ * @param scales_list a list of fp32/fp16/bf16 scales tensors
+ * @param zp_list a list of fp32/fp16/bf16/int8 zero points tensors
+ * @param bias_list a list of fp32/fp16/bf16 bias tensors
+ * @param qw_type weight dtype, such as int8, int4, etc.
+ * @param fusion_type fusion type, such as gelu, add, etc.
+ * @param others_list a list of other inputs for post ops, such as binary add,
+ * etc.
+ * @param quant_w_mode quantization mode for weight
+ * @param quant_block_k block size for quantization
+ * @return at::Tensor output in same dtype as `x`, 2D plain format [M,N]
+ */
+at::Tensor woq_gemm_bf16(
+    const at::Tensor& x,
+    const at::Tensor& qw,
+    const TensorList& scales_list,
+    const TensorList& zp_list,
+    const TensorList& bias_list,
+    const int qw_type,
+    int64_t fusion_type,
+    const TensorList& others_list,
+    int64_t quant_w_mode = 0,
+    int64_t quant_block_k = 0) {
+  const int64_t k_splits = 0;
+  quant_block_k = std::max(0L, quant_block_k);
+  // int8_idx is only valid with zp_list when lowp_mode == LOWP_MODE_INT8
+  constexpr size_t fp32_idx = 0, fp16_idx = 1, bf16_idx = 2, int8_idx = 3;
+  auto biases = bias_list.empty()
+      ? TensorList({at::Tensor(), at::Tensor(), at::Tensor()})
+      : bias_list;
+  const bool is_4bit_flag = is_4bit(qw_type);
+  const bool asym_quant_w = is_asymmetric_quant_w(quant_w_mode);
+  if (qw_type == WOQ_DTYPE_NF4) {
+    TORCH_CHECK(
+        !asym_quant_w, "WOQ: symmetric quantization is required for NF4");
+  }
+  if (qw.dim() == 4) {
+    auto w_sizes = qw.sizes();
+    auto K = x.size(-1);
+    auto M = x.numel() / K;
+    auto N = w_sizes[0] * w_sizes[3];
+    if (is_4bit_flag) {
+      N *= 2;
+    }
+    auto out_sizes = x.sizes().vec();
+    out_sizes.back() = N;
+    auto y = at::empty(out_sizes, x.options());
+    product_dispatcher<
+        std::tuple<at::ScalarType, long>,
+        std::tuple<
+            enumerate_dispatcher<
+                at::ScalarType,
+                at::kFloat,
+                at::kBFloat16,
+                at::kHalf>,
+            range_dispatcher<long, 0, 3>>>::
+        call(
+            std::make_tuple(x.scalar_type(), quant_w_mode),
+            [&](auto tuple) {
+              auto act_dtype = std::get<0>(tuple);
+              auto quant_w_mode_ = std::get<1>(tuple);
+              using act_type =
+                  typename c10::impl::ScalarTypeToCPPType<act_dtype>::type;
+              qlinear_woq_affine_impl<
+                  act_type,
+                  bfloat16,
+                  /*TGemmOut*/ float,
+                  act_type,
+                  bfloat16,
+                  bfloat16,
+                  UNQUANT_A,
+                  quant_w_mode_>(
+                  x,
+                  qw,
+                  scales_list[bf16_idx],
+                  biases[fp32_idx],
+                  y,
+                  qw_type,
+                  k_splits,
+                  fusion_type,
+                  others_list,
+                  quant_block_k,
+                  zp_list[bf16_idx]);
+            },
+            [](auto tuple) { failing_fallback(); });
+    return y;
+  } else {
+    return woq_gemm_ref_impl(
+        x,
+        qw,
+        scales_list,
+        zp_list,
+        bias_list,
+        qw_type,
+        at::kBFloat16,
+        fusion_type,
+        others_list,
+        quant_w_mode,
+        quant_block_k);
+  }
+}
+
+#else // defined(CPU_CAPABILITY_AVX512_FP16) && defined(COMPILER_PREREQ_MET)
+
+at::Tensor woq_gemm_bf16(
+    const at::Tensor& x,
+    const at::Tensor& qw,
+    const TensorList& scales_list,
+    const TensorList& zp_list,
+    const TensorList& bias_list,
+    const int qw_type,
+    int64_t fusion_type,
+    const TensorList& others_list,
+    int64_t quant_w_mode = 0,
+    int64_t quant_block_k = 0) {
+  return woq_gemm_ref_impl(
+      x,
+      qw,
+      scales_list,
+      zp_list,
+      bias_list,
+      qw_type,
+      at::kBFloat16,
+      fusion_type,
+      others_list,
+      quant_w_mode,
+      quant_block_k);
+}
+
+#endif // defined(CPU_CAPABILITY_AVX512_FP16) && defined(COMPILER_PREREQ_MET)
+
+} // namespace
+
+IPEX_REGISTER_DISPATCH(woq_bf16_gemm_kernel_stub, &woq_gemm_bf16);
+
+} // namespace cpu
+} // namespace torch_ipex
+
+#endif
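The optimized path in this file is gated on both an ISA check (CPU_CAPABILITY_AVX512_FP16) and a compiler-version check. A small standalone sketch of that gate (my own illustration, not code from the commit) showing how __GNUC_PREREQ(12, 3) decides whether the optimized kernel body would be compiled:

    // Standalone illustration of the compiler gate used above:
    // COMPILER_PREREQ_MET is only defined when building with gcc >= 12.3,
    // so older toolchains compile only the reference fallback.
    #include <cstdio>

    #ifdef __GNUC__
    #include <features.h> // provides __GNUC_PREREQ on glibc-based systems
    #if __GNUC_PREREQ(12, 3)
    #define COMPILER_PREREQ_MET
    #endif
    #endif

    int main() {
    #ifdef COMPILER_PREREQ_MET
      std::printf("gcc >= 12.3: optimized WOQ kernel path would be built\n");
    #else
      std::printf("older compiler: reference path (woq_gemm_ref_impl) is used\n");
    #endif
      return 0;
    }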
