Skip to content

Commit 7ddbc83

Browse files
author
Rohan Yadav
committed
taco: adjust suitesparse benchmark for when taco can't read input tensor
taco only reads matrix market tensors in the coordinate format, so handle errors when some of the matrix market tensors have different formats.
1 parent c5f76be commit 7ddbc83

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

taco/ufuncs.cpp

+14-2
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,10 @@ struct UfuncInputCache {
169169
}
170170

171171
// Otherwise, we missed the cache. Load in the target tensor and process it.
172-
this->lastPath = path;
173172
this->lastLoaded = taco::read(path, format);
173+
// We assign lastPath after lastLoaded so that if taco::read throws an exception
174+
// then lastPath isn't updated to the new path.
175+
this->lastPath = path;
174176
this->inputTensor = castToType<int64_t>("A", this->lastLoaded);
175177
this->otherTensor = shiftLastMode<int64_t, int64_t>("B", this->inputTensor);
176178
return std::make_pair(this->inputTensor, this->otherTensor);
@@ -288,6 +290,9 @@ struct SuiteSparseTensors {
288290
SuiteSparseTensors ssTensors;
289291

290292
static void bench_suitesparse_ufunc(benchmark::State& state, Func op) {
293+
// Counters must be present in every run to get reported to the CSV.
294+
state.counters["dimx"] = 0;
295+
state.counters["dimy"] = 0;
291296
if (ssTensors.tensors.size() == 0) {
292297
state.error_occurred();
293298
return;
@@ -300,7 +305,14 @@ static void bench_suitesparse_ufunc(benchmark::State& state, Func op) {
300305
state.SetLabel(tensorName);
301306

302307
taco::Tensor<int64_t> ssTensor, other;
303-
std::tie(ssTensor, other) = inputCache.getUfuncInput(tensorPath, CSR);
308+
try {
309+
std::tie(ssTensor, other) = inputCache.getUfuncInput(tensorPath, CSR);
310+
} catch (TacoException& e) {
311+
// Counters don't show up in the generated CSV if we used SkipWithError, so
312+
// just add in the label that this run is skipped.
313+
state.SetLabel(tensorName+"-SKIPPED-FAILED-READ");
314+
return;
315+
}
304316

305317
state.counters["dimx"] = ssTensor.getDimension(0);
306318
state.counters["dimy"] = ssTensor.getDimension(1);

0 commit comments

Comments
 (0)