
Commit 009043f

refactor: wavelts
1 parent 16a3501 commit 009043f

File tree

1 file changed: +9 −129 lines


layers/MultiWaveletCorrelation.py

+9-129
@@ -15,7 +15,7 @@
 from layers.utils import get_filter


-from layers.FourierCorrelation import FourierBlock, FourierCrossAttention
+# from layers.FourierCorrelation import FourierBlock, # FourierCrossAttention


 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -60,10 +60,10 @@ def forward(self, queries, keys, values, attn_mask):
         return (V.contiguous(), None)


-class FourierCrossAttention1(nn.Module):
+class FourierCrossAttentionW(nn.Module):
     def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=16, activation='tanh',
                  mode_select_method='random'):
-        super(FourierCrossAttention1, self).__init__()
+        super(FourierCrossAttentionW, self).__init__()
         print('corss fourier correlation used!')

         """
@@ -75,51 +75,6 @@ def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=16, a
         self.modes1 = modes
         self.activation = activation

-        if modes > 10000:
-            modes2 = modes - 10000
-            self.index_q0 = list(range(0, min(seq_len_q // 4, modes2 // 2)))
-            self.index_q1 = list(range(len(self.index_q0), seq_len_q // 2))
-            np.random.shuffle(self.index_q1)
-            self.index_q1 = self.index_q1[:min(seq_len_q // 4, modes2 // 2)]
-            self.index_q = self.index_q0 + self.index_q1
-            self.index_q.sort()
-
-            self.index_k_v0 = list(range(0, min(seq_len_kv // 4, modes2 // 2)))
-            self.index_k_v1 = list(range(len(self.index_k_v0), seq_len_kv // 2))
-            np.random.shuffle(self.index_k_v1)
-            self.index_k_v1 = self.index_k_v1[:min(seq_len_kv // 4, modes2 // 2)]
-            self.index_k_v = self.index_k_v0 + self.index_k_v1
-            self.index_k_v.sort()
-
-        # elif modes > 1000:
-        #     modes2 = modes - 1000
-        #     self.index_q = list(range(0, seq_len_q // 2))
-        #     np.random.shuffle(self.index_q)
-        #     self.index_q = self.index_q[:modes2]
-        #     self.index_q.sort()
-        #     self.index_k_v = list(range(0, seq_len_kv // 2))
-        #     np.random.shuffle(self.index_k_v)
-        #     self.index_k_v = self.index_k_v[:modes2]
-        #     self.index_k_v.sort()
-        # elif modes < 0:
-        #     modes2 = abs(modes)
-        #     self.index_q = get_dynamic_modes(seq_len_q, modes2)
-        #     self.index_k_v = list(range(0, min(seq_len_kv // 2, modes2)))
-        # else:
-        #     self.index_q = list(range(0, min(seq_len_q // 2, modes)))
-        #     self.index_k_v = list(range(0, min(seq_len_kv // 2, modes)))
-
-        print('index_q={}'.format(self.index_q))
-        print('len mode q={}', len(self.index_q))
-        print('index_k_v={}'.format(self.index_k_v))
-        print('len mode kv={}', len(self.index_k_v))
-
-        self.register_buffer('index_q2', torch.tensor(self.index_q))
-
-        # self.scale = (1 / (in_channels * out_channels))
-        # self.weights1 = nn.Parameter(
-        #     self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index_q), dtype=torch.cfloat))
-
     def forward(self, q, k, v, mask):
         # size = [B, L, H, E]
         mask = mask
@@ -179,81 +134,6 @@ def compl_mul1d(x, weights):
         return torch.einsum("bix,iox->box", x, weights)


-# class Sparsemax(nn.Module):
-#     """Sparsemax function."""
-#
-#     def __init__(self, dim=None):
-#         """Initialize sparsemax activation
-#
-#         Args:
-#             dim (int, optional): The dimension over which to apply the sparsemax function.
-#         """
-#         super(Sparsemax, self).__init__()
-#
-#         self.dim = -1 if dim is None else dim
-#
-#     def forward(self, input):
-#         """Forward function.
-#         Args:
-#             input (torch.Tensor): Input tensor. First dimension should be the batch size
-#         Returns:
-#             torch.Tensor: [batch_size x number_of_logits] Output tensor
-#         """
-#         # Sparsemax currently only handles 2-dim tensors,
-#         # so we reshape to a convenient shape and reshape back after sparsemax
-#         input = input.transpose(0, self.dim)
-#         original_size = input.size()
-#         input = input.reshape(input.size(0), -1)
-#         input = input.transpose(0, 1)
-#         dim = 1
-#
-#         number_of_logits = input.size(dim)
-#
-#         # Translate input by max for numerical stability
-#         input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input)
-#
-#         # Sort input in descending order.
-#         # (NOTE: Can be replaced with linear time selection method described here:
-#         # http://stanford.edu/~jduchi/projects/DuchiShSiCh08.html)
-#         zs = torch.sort(input=input, dim=dim, descending=True)[0]
-#         range = torch.arange(start=1, end=number_of_logits + 1, step=1, device=device, dtype=input.dtype).view(1, -1)
-#         range = range.expand_as(zs)
-#
-#         # Determine sparsity of projection
-#         bound = 1 + range * zs
-#         cumulative_sum_zs = torch.cumsum(zs, dim)
-#         is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type())
-#         k = torch.max(is_gt * range, dim, keepdim=True)[0]
-#
-#         # Compute threshold function
-#         zs_sparse = is_gt * zs
-#
-#         # Compute taus
-#         taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k
-#         taus = taus.expand_as(input)
-#
-#         # Sparsemax
-#         self.output = torch.max(torch.zeros_like(input), input - taus)
-#
-#         # Reshape back to original shape
-#         output = self.output
-#         output = output.transpose(0, 1)
-#         output = output.reshape(original_size)
-#         output = output.transpose(0, self.dim)
-#
-#         return output
-#
-#     def backward(self, grad_output):
-#         """Backward function."""
-#         dim = 1
-#
-#         nonzeros = torch.ne(self.output, 0)
-#         sum = torch.sum(grad_output * nonzeros, dim=dim) / torch.sum(nonzeros, dim=dim)
-#         self.grad_input = nonzeros * (grad_output - sum.expand_as(grad_output))
-#
-#         return self.grad_input
-
-
 class sparseKernelFT1d(nn.Module):
     def __init__(self,
                  k, alpha, c=1,
@@ -411,16 +291,16 @@ def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes, c=64
         G1r[np.abs(G1r) < 1e-8] = 0
         self.max_item = 3

-        self.attn1 = FourierCrossAttention(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
+        self.attn1 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
                                            seq_len_kv=seq_len_kv, modes=modes, activation=activation,
                                            mode_select_method=mode_select_method)
-        self.attn2 = FourierCrossAttention(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
+        self.attn2 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
                                            seq_len_kv=seq_len_kv, modes=modes, activation=activation,
                                            mode_select_method=mode_select_method)
-        self.attn3 = FourierCrossAttention(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
+        self.attn3 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
                                            seq_len_kv=seq_len_kv, modes=modes, activation=activation,
                                            mode_select_method=mode_select_method)
-        self.attn4 = FourierCrossAttention(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
+        self.attn4 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
                                            seq_len_kv=seq_len_kv, modes=modes, activation=activation,
                                            mode_select_method=mode_select_method)

@@ -443,8 +323,8 @@ def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes, c=64
         self.modes1 = modes

     def forward(self, q, k, v, mask=None):
-        B, N, H, E = q.shape  # (B, N, k)
-        _, S, _, _ = k.shape
+        B, N, H, E = q.shape  # (B, N, H, E) torch.Size([3, 768, 8, 2])
+        _, S, _, _ = k.shape  # (B, S, H, E) torch.Size([3, 96, 8, 2])

         q = q.view(q.shape[0], q.shape[1], -1)
         k = k.view(k.shape[0], k.shape[1], -1)
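
Not part of the commit, but for orientation: a minimal smoke-test sketch of the surviving FourierCrossAttentionW module after the rename. It assumes a FEDformer-style checkout where layers/ is importable from the project root; the batch/head/embedding sizes simply echo the shape comments added in this diff, and the tuple-style return value is an assumption (matching the other attention blocks in this file), not something the commit itself shows.

# Hypothetical smoke test for the renamed FourierCrossAttentionW (not part of this commit).
# Assumes the repository root is on PYTHONPATH so `layers` imports, and that forward()
# returns a (values, attention) pair like the other attention blocks in this file.
import torch

from layers.MultiWaveletCorrelation import FourierCrossAttentionW

B, N, S, H, E = 3, 768, 96, 8, 2  # sizes echoed from the diff's new shape comments (assumed)

attn = FourierCrossAttentionW(in_channels=E, out_channels=E,
                              seq_len_q=N, seq_len_kv=S,
                              modes=16, activation='tanh',
                              mode_select_method='random')

q = torch.randn(B, N, H, E)  # queries: (B, N, H, E)
k = torch.randn(B, S, H, E)  # keys:    (B, S, H, E)
v = torch.randn(B, S, H, E)  # values:  (B, S, H, E)

out = attn(q, k, v, None)    # mask appears unused inside forward() (`mask = mask`)
out = out[0] if isinstance(out, tuple) else out
print(out.shape)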
