@@ -106,19 +106,21 @@ struct UnionLeftCompAlgebra {
106
106
}
107
107
};
108
108
109
- // Logical Not Op and Algebra
110
- struct Power {
111
- ir::Expr operator ()(const std::vector<ir::Expr> &v) {
112
- return ir::Literal (1 , v.get )
113
- }
114
- };
115
109
116
110
struct CompAlgebra {
117
111
IterationAlgebra operator ()(const std::vector<IndexExpr>& regions) {
118
112
return Complement (regions[0 ]);
119
113
}
120
114
};
121
115
116
+ struct NestedXorAlgebra {
117
+ IterationAlgebra operator ()(const std::vector<IndexExpr> & regions) {
118
+ IterationAlgebra intersect2 = Union (Intersect (regions[2 ], Union (regions[0 ], regions[1 ])), Intersect (regions[0 ], Union (regions[2 ], regions[1 ])));
119
+ IterationAlgebra intersect3 = Intersect (Intersect (regions[0 ], regions[1 ]), regions[2 ]);
120
+ IterationAlgebra unionComplement = Complement (Union (Union (regions[0 ], regions[1 ]), regions[2 ]));
121
+ return Union (Complement (Union (intersect2, unionComplement)), intersect3);
122
+ }
123
+ };
122
124
123
125
template <int I, class ...Ts>
124
126
decltype (auto ) get(Ts&&... ts) {
@@ -212,28 +214,34 @@ Func rightShift("right_shift", RightShift(), leftIncAlgebra());
212
214
// Element-wise ops: each pairs the same scalar implementation (GeneralAdd)
// with an iteration algebra that restricts WHERE the op is computed.
Func xorOp("logical_xor", GeneralAdd(), xorAlgebra());
Func andOp("logical_and", GeneralAdd(), andAlgebra());
Func orOp("logical_or", GeneralAdd(), orAlgebra());
// Three-argument op using NestedXorAlgebra — presumably a single fused
// equivalent of chaining two binary XORs; verify against the algebra.
Func nestedXorOp("fused_xor", GeneralAdd(), NestedXorAlgebra());
216
218
// Benchmarks the fused three-way XOR op over random sparse dim x dim
// matrices in format `f`.  Only result.compute() is timed; tensor setup,
// expression compilation, and result verification are excluded via
// PauseTiming/ResumeTiming.
static void bench_ufunc_fused(benchmark::State& state, const Format& f) {
  int dim = state.range(0);
  auto sparsity = 0.01;
  // Inputs are cast to int64_t so the XOR-style op runs on integer data.
  Tensor<int64_t> matrix = castToType<int64_t>("A", loadRandomTensor("A", {dim, dim}, sparsity, f));
  Tensor<int64_t> matrix1 = castToType<int64_t>("B", loadRandomTensor("B", {dim, dim}, sparsity, f, 1 /* variant */));
  Tensor<int64_t> matrix2 = castToType<int64_t>("C", loadRandomTensor("C", {dim, dim}, sparsity, f, 2 /* variant */));

  for (auto _ : state) {
    state.PauseTiming();
    Tensor<int64_t> result("result", {dim, dim}, f);
    IndexVar i("i"), j("j");
    result(i, j) = nestedXorOp(matrix(i, j), matrix1(i, j), matrix2(i, j));
    result.setAssembleWhileCompute(true);
    result.compile();
    state.ResumeTiming();

    result.compute();

    // FIX: the NNZ verification below previously ran inside the timed
    // region, so removeExplicitZeros(), the iteration, and the stdout
    // print all inflated the reported compute time.  Keep it untimed.
    state.PauseTiming();
    result = result.removeExplicitZeros(result.getFormat());
    int nnz = 0;
    for (auto& it : iterate<int64_t>(result)) {
      (void)it;  // only counting entries
      nnz++;
    }
    std::cout << "Result NNZ = " << nnz << std::endl;
    state.ResumeTiming();
  }
}
235
// Register the fused-XOR benchmark on CSR matrices of increasing dimension.
TACO_BENCH_ARGS(bench_ufunc_fused, csr, CSR)
  ->ArgsProduct({{5000, 10000, 20000}});
237
245
238
246
// UfuncInputCache is a cache for the input to ufunc benchmarks. These benchmarks
239
247
// operate on a tensor loaded from disk and the same tensor shifted slightly. Since
@@ -362,6 +370,7 @@ FOREACH_FROSTT_TENSOR(DECLARE_FROSTT_UFUNC_BENCH)
362
370
// Selects which fused element-wise expression bench_frostt_ufunc_fused
// assigns to the result tensor (see the switch statements in that function).
enum FusedUfuncOp {
  XOR_AND = 1,  // presumably andOp(xorOp(a, b), c) — mirrors XOR_OR below
  XOR_OR = 2,   // orOp(xorOp(a, b), c)
  XOR_XOR = 3,  // nestedXorOp(a, b, c): single fused three-way XOR
};
366
375
367
376
static void bench_frostt_ufunc_fused (benchmark::State& state, std::string tnsPath, FusedUfuncOp op) {
@@ -397,6 +406,10 @@ static void bench_frostt_ufunc_fused(benchmark::State& state, std::string tnsPat
397
406
result (i, j, k) = orOp (xorOp (frosttTensor (i, j, k), other (i, j, k)), third (i, j, k));
398
407
break ;
399
408
}
409
+ case XOR_XOR: {
410
+ result (i, j, k) = nestedXorOp (frosttTensor (i, j, k), other (i, j, k), third (i, j, k));
411
+ break ;
412
+ }
400
413
default :
401
414
state.SkipWithError (" invalid fused op" );
402
415
return ;
@@ -414,6 +427,10 @@ static void bench_frostt_ufunc_fused(benchmark::State& state, std::string tnsPat
414
427
result (i, j, k, l) = orOp (xorOp (frosttTensor (i, j, k, l), other (i, j, k, l)), third (i, j, k, l));
415
428
break ;
416
429
}
430
+ case XOR_XOR: {
431
+ result (i, j, k, l) = nestedXorOp (frosttTensor (i, j, k, l), other (i, j, k, l), third (i, j, k, l));
432
+ break ;
433
+ }
417
434
default :
418
435
state.SkipWithError (" invalid fused op" );
419
436
return ;
@@ -431,6 +448,10 @@ static void bench_frostt_ufunc_fused(benchmark::State& state, std::string tnsPat
431
448
result (i, j, k, l, m) = orOp (xorOp (frosttTensor (i, j, k, l, m), other (i, j, k, l, m)), third (i, j, k, l, m));
432
449
break ;
433
450
}
451
+ case XOR_XOR: {
452
+ result (i, j, k, l, m) = nestedXorOp (frosttTensor (i, j, k, l, m), other (i, j, k, l, m), third (i, j, k, l, m));
453
+ break ;
454
+ }
434
455
default :
435
456
state.SkipWithError (" invalid fused op" );
436
457
return ;
@@ -445,12 +466,14 @@ static void bench_frostt_ufunc_fused(benchmark::State& state, std::string tnsPat
445
466
state.ResumeTiming ();
446
467
447
468
result.compute ();
469
+
448
470
}
449
471
}
450
472
451
473
#define DECLARE_FROSTT_FUSED_UFUNC_BENCH (name, path ) \
452
474
TACO_BENCH_ARGS (bench_frostt_ufunc_fused, name/xorAndFused, path, XOR_AND); \
453
475
TACO_BENCH_ARGS (bench_frostt_ufunc_fused, name/xorOrFused, path, XOR_OR); \
476
+ // TACO_BENCH_ARGS(bench_frostt_ufunc_fused, name/xorXorFused, path, XOR_XOR); \
454
477
455
478
FOREACH_FROSTT_TENSOR (DECLARE_FROSTT_FUSED_UFUNC_BENCH)
456
479
0 commit comments