Skip to content

Commit 7914d39

Browse files
committed
fix dimension bug
1 parent 6f3fe05 commit 7914d39

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

layers/FourierCorrelation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def forward(self, q, k, v, mask):
9797
xv = v.permute(0, 2, 3, 1)
9898

9999
# Compute Fourier coefficients
100-
xq_ft_ = torch.zeros(B, H, E, len(self.index_q)+1, device=xq.device, dtype=torch.cfloat)
100+
xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat)
101101
xq_ft = torch.fft.rfft(xq, dim=-1)
102102
for i, j in enumerate(self.index_q):
103103
xq_ft_[:, :, :, i] = xq_ft[:, :, :, j]

0 commit comments

Comments
 (0)