Skip to content

Commit e33ab50

Browse files
committed
numpy,taco: adjust tensor shifting to shift less for higher order tensors
1 parent 0e06f5b commit e33ab50

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

numpy/util.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,11 @@ def shiftLastMode(self, tensor):
171171
# resultValues[i] = data[i]
172172
# TODO (rohany): Temporarily use a constant as the value.
173173
resultValues[i] = 2
174-
resultCoords[-1][i] = (resultCoords[-1][i] + 1) % tensor.shape[-1]
174+
# For order 2 tensors, always shift the last coordinate. Otherwise, shift only coordinates
175+
# that have even last coordinates. This ensures that there is at least some overlap
176+
# between the original tensor and its shifted counter part.
177+
if tensor.shape[-1] <= 0 or resultCoords[-1][i] % 2 == 0:
178+
resultCoords[-1][i] = (resultCoords[-1][i] + 1) % tensor.shape[-1]
175179
return sparse.COO(resultCoords, resultValues, tensor.shape)
176180

177181
# ScipyTensorShifter shifts all elements in the last mode

taco/bench.h

+6-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,12 @@ taco::Tensor<T> shiftLastMode(std::string name, taco::Tensor<T2> original) {
5050
coords[i] = value.first[i];
5151
}
5252
int lastMode = original.getOrder() - 1;
53-
coords[lastMode] = (coords[lastMode] + 1) % original.getDimension(lastMode);
53+
// For order 2 tensors, always shift the last coordinate. Otherwise, shift only coordinates
54+
// that have even last coordinates. This ensures that there is at least some overlap
55+
// between the original tensor and its shifted counter part.
56+
if (original.getOrder() <= 2 || (coords[lastMode] % 2 == 0)) {
57+
coords[lastMode] = (coords[lastMode] + 1) % original.getDimension(lastMode);
58+
}
5459
// TODO (rohany): Temporarily use a constant value here.
5560
result.insert(coords, T(2));
5661
}

0 commit comments

Comments
 (0)