-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathSubtensor.cuh
45 lines (35 loc) · 1.28 KB
/
Subtensor.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#pragma once
#include "Utils.cuh"
#include "Ops.cuh"
// Builds a TDyn view over the storage of `base`.
//
// The view shares base's data pointer; constness is cast away so that a
// non-const TDyn::DataType* can be stored. Its starting offset is base's
// offset plus, for every dimension d whose index entry idx[d] is set
// (non-null), stride(d) * *idx[d]. Every dimension's stride is copied from
// base, whether or not an index was given for it.
//
// TDyn    : view type; must expose DataType, mData, mOffset, mStride[] and nDim().
// TBase   : source tensor exposing data(), offset() and stride(d).
// TIdx    : indexable collection of optional index pointers.
// NOTE(review): idx is assumed to hold at least TDyn::nDim() entries — confirm.
// NOTE(review): extents of the view are not set here; presumably they are
//               fixed by the TDyn type itself — confirm.
template <typename TDyn, typename TBase, typename TIdx>
_dev TDyn dynamicSubtensor(TBase &base, const TIdx &idx) {
    typedef typename TDyn::DataType Elem;

    TDyn view;
    view.mData = const_cast<Elem *>(base.data());
    view.mOffset = base.offset();

    for (idx_t d = 0; d < view.nDim(); ++d) {
        // Dimensions with an index entry shift the starting element.
        if (idx[d]) {
            view.mOffset += base.stride(d) * (*(idx[d]));
        }
        view.mStride[d] = base.stride(d);
    }
    return view;
}
// Const-base overload of dynamicSubtensor.
//
// Forwards to the non-const overload by casting away the constness of
// `base`, and returns the resulting view const-qualified.
// NOTE(review): the returned view's mData still points into base's storage
//               (the non-const overload stores it as a non-const pointer), so
//               whether writes are prevented depends on TDyn honouring const —
//               confirm.
template <typename TDyn, typename TBase, typename TIdx>
_dev const TDyn dynamicSubtensor(const TBase &base, const TIdx &idx) {
    // const_cast must target a reference type; casting to the class type
    // TBase itself is ill-formed. TDyn is named explicitly because it cannot
    // be deduced from the arguments. A non-const lvalue binds better to the
    // `TBase &` overload than to this `const TBase &` one, so this forwards
    // to the non-const version rather than recursing.
    return dynamicSubtensor<TDyn>(const_cast<TBase &>(base), idx);
}
// Copies a dynamic subtensor of `baseSrc` into `trgt`.
//
// A TDynSrc view of baseSrc is built from the optional per-dimension indices
// in `srcIdx` (see dynamicSubtensor). TTrgt::elemwise1Ary then applies the
// identity operator (IdEOp_t), element by element, from that view into trgt.
// TDynSrc cannot be deduced and must be given explicitly by the caller.
template <typename TTrgt, typename TBaseSrc, typename TDynSrc, idx_t nSrcIdxs>
_dev void copyFromDynamicSubtensor(TTrgt &trgt,
                                   const TBaseSrc &baseSrc,
                                   const Array<idx_t *, nSrcIdxs> &srcIdx)
{
    IdEOp_t identity;
    TDynSrc srcView = dynamicSubtensor<TDynSrc>(baseSrc, srcIdx);
    TTrgt::elemwise1Ary(identity, trgt, srcView);
}
// Copies `src` into a dynamic subtensor of `baseTrgt`.
//
// A TDynTrgt view of baseTrgt is built from the optional per-dimension
// indices in `trgtIdx` (see dynamicSubtensor). TDynTrgt::elemwise1Ary then
// applies the identity operator (IdEOp_t), element by element, from src into
// that view. Writes land in baseTrgt's storage because the view shares its
// data pointer. TDynTrgt cannot be deduced and must be given explicitly.
template <typename TBaseTrgt, typename TDynTrgt, idx_t nTrgtIdxs, typename TSrc>
_dev void copyToDynamicSubtensor(TBaseTrgt &baseTrgt,
                                 const Array<idx_t *, nTrgtIdxs> &trgtIdx,
                                 const TSrc &src)
{
    IdEOp_t identity;
    TDynTrgt trgtView = dynamicSubtensor<TDynTrgt>(baseTrgt, trgtIdx);
    TDynTrgt::elemwise1Ary(identity, trgtView, src);
}