Skip to content

Commit 639eafc

Browse files
committed
taco: add tensor mode shifting logic to c++
1 parent b5f4d66 commit 639eafc

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

taco/bench.h

+16
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,20 @@
4040

4141
taco::TensorBase loadRandomTensor(std::string name, std::vector<int> dims, float sparsity, taco::Format format);
4242

43+
template<typename T>
44+
taco::Tensor<T> shiftLastMode(std::string name, taco::Tensor<T> original) {
45+
taco::Tensor<T> result(name, original.getDimensions(), original.getFormat());
46+
std::vector<int> coords(original.getOrder());
47+
for (auto& value : taco::iterate<T>(original)) {
48+
for (int i = 0; i < original.getOrder(); i++) {
49+
coords[i] = value.first[i];
50+
}
51+
int lastMode = original.getOrder() - 1;
52+
coords[lastMode] = (coords[lastMode] + 1) % original.getDimension(lastMode);
53+
result.insert(coords, value.second);
54+
}
55+
result.pack();
56+
return result;
57+
}
58+
4359
#endif //TACO_BENCH_BENCH_H

0 commit comments

Comments
 (0)