@@ -15,7 +15,7 @@
 from layers.utils import get_filter


-from layers.FourierCorrelation import FourierBlock, FourierCrossAttention
+# from layers.FourierCorrelation import FourierBlock, FourierCrossAttention


 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -60,10 +60,10 @@ def forward(self, queries, keys, values, attn_mask):
         return (V.contiguous(), None)


-class FourierCrossAttention1(nn.Module):
+class FourierCrossAttentionW(nn.Module):
     def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=16, activation='tanh',
                  mode_select_method='random'):
-        super(FourierCrossAttention1, self).__init__()
+        super(FourierCrossAttentionW, self).__init__()
         print('cross fourier correlation used!')

         """
@@ -75,51 +75,6 @@ def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=16, activation='tanh',
         self.modes1 = modes
         self.activation = activation

-        if modes > 10000:
-            modes2 = modes - 10000
-            self.index_q0 = list(range(0, min(seq_len_q // 4, modes2 // 2)))
-            self.index_q1 = list(range(len(self.index_q0), seq_len_q // 2))
-            np.random.shuffle(self.index_q1)
-            self.index_q1 = self.index_q1[:min(seq_len_q // 4, modes2 // 2)]
-            self.index_q = self.index_q0 + self.index_q1
-            self.index_q.sort()
-
-            self.index_k_v0 = list(range(0, min(seq_len_kv // 4, modes2 // 2)))
-            self.index_k_v1 = list(range(len(self.index_k_v0), seq_len_kv // 2))
-            np.random.shuffle(self.index_k_v1)
-            self.index_k_v1 = self.index_k_v1[:min(seq_len_kv // 4, modes2 // 2)]
-            self.index_k_v = self.index_k_v0 + self.index_k_v1
-            self.index_k_v.sort()
-
-        # elif modes > 1000:
-        #     modes2 = modes - 1000
-        #     self.index_q = list(range(0, seq_len_q // 2))
-        #     np.random.shuffle(self.index_q)
-        #     self.index_q = self.index_q[:modes2]
-        #     self.index_q.sort()
-        #     self.index_k_v = list(range(0, seq_len_kv // 2))
-        #     np.random.shuffle(self.index_k_v)
-        #     self.index_k_v = self.index_k_v[:modes2]
-        #     self.index_k_v.sort()
-        # elif modes < 0:
-        #     modes2 = abs(modes)
-        #     self.index_q = get_dynamic_modes(seq_len_q, modes2)
-        #     self.index_k_v = list(range(0, min(seq_len_kv // 2, modes2)))
-        # else:
-        #     self.index_q = list(range(0, min(seq_len_q // 2, modes)))
-        #     self.index_k_v = list(range(0, min(seq_len_kv // 2, modes)))
-
-        print('index_q={}'.format(self.index_q))
-        print('len mode q={}', len(self.index_q))
-        print('index_k_v={}'.format(self.index_k_v))
-        print('len mode kv={}', len(self.index_k_v))
-
-        self.register_buffer('index_q2', torch.tensor(self.index_q))
-
-        # self.scale = (1 / (in_channels * out_channels))
-        # self.weights1 = nn.Parameter(
-        #     self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index_q), dtype=torch.cfloat))
-
     def forward(self, q, k, v, mask):
         # size = [B, L, H, E]
         mask = mask
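Note: the branching removed above hard-coded the frequency-mode selection; the surviving default behaves like a small helper along the lines of FEDformer's get_frequency_modes (layers/FourierCorrelation.py). A minimal sketch of that random mode selection, with illustrative defaults:

import numpy as np

def get_frequency_modes(seq_len, modes=16, mode_select_method='random'):
    # Keep `modes` of the seq_len // 2 rFFT frequency indices: either the
    # lowest ones, or a random subset; always returned sorted.
    modes = min(modes, seq_len // 2)
    if mode_select_method == 'random':
        index = list(range(0, seq_len // 2))
        np.random.shuffle(index)
        index = index[:modes]
    else:
        index = list(range(0, modes))
    index.sort()
    return index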
@@ -179,81 +134,6 @@ def compl_mul1d(x, weights):
     return torch.einsum("bix,iox->box", x, weights)


-# class Sparsemax(nn.Module):
-#     """Sparsemax function."""
-#
-#     def __init__(self, dim=None):
-#         """Initialize sparsemax activation
-#
-#         Args:
-#             dim (int, optional): The dimension over which to apply the sparsemax function.
-#         """
-#         super(Sparsemax, self).__init__()
-#
-#         self.dim = -1 if dim is None else dim
-#
-#     def forward(self, input):
-#         """Forward function.
-#         Args:
-#             input (torch.Tensor): Input tensor. First dimension should be the batch size
-#         Returns:
-#             torch.Tensor: [batch_size x number_of_logits] Output tensor
-#         """
-#         # Sparsemax currently only handles 2-dim tensors,
-#         # so we reshape to a convenient shape and reshape back after sparsemax
-#         input = input.transpose(0, self.dim)
-#         original_size = input.size()
-#         input = input.reshape(input.size(0), -1)
-#         input = input.transpose(0, 1)
-#         dim = 1
-#
-#         number_of_logits = input.size(dim)
-#
-#         # Translate input by max for numerical stability
-#         input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input)
-#
-#         # Sort input in descending order.
-#         # (NOTE: Can be replaced with linear time selection method described here:
-#         # http://stanford.edu/~jduchi/projects/DuchiShSiCh08.html)
-#         zs = torch.sort(input=input, dim=dim, descending=True)[0]
-#         range = torch.arange(start=1, end=number_of_logits + 1, step=1, device=device, dtype=input.dtype).view(1, -1)
-#         range = range.expand_as(zs)
-#
-#         # Determine sparsity of projection
-#         bound = 1 + range * zs
-#         cumulative_sum_zs = torch.cumsum(zs, dim)
-#         is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type())
-#         k = torch.max(is_gt * range, dim, keepdim=True)[0]
-#
-#         # Compute threshold function
-#         zs_sparse = is_gt * zs
-#
-#         # Compute taus
-#         taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k
-#         taus = taus.expand_as(input)
-#
-#         # Sparsemax
-#         self.output = torch.max(torch.zeros_like(input), input - taus)
-#
-#         # Reshape back to original shape
-#         output = self.output
-#         output = output.transpose(0, 1)
-#         output = output.reshape(original_size)
-#         output = output.transpose(0, self.dim)
-#
-#         return output
-#
-#     def backward(self, grad_output):
-#         """Backward function."""
-#         dim = 1
-#
-#         nonzeros = torch.ne(self.output, 0)
-#         sum = torch.sum(grad_output * nonzeros, dim=dim) / torch.sum(nonzeros, dim=dim)
-#         self.grad_input = nonzeros * (grad_output - sum.expand_as(grad_output))
-#
-#         return self.grad_input
-
-
 class sparseKernelFT1d(nn.Module):
     def __init__(self,
                  k, alpha, c=1,
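For reference, the "bix,iox->box" contraction in compl_mul1d above gives every retained frequency mode its own in-channel-to-out-channel complex mixing matrix. A self-contained shape check (sizes are illustrative only):

import torch

B, I, O, M = 3, 8, 8, 16                      # batch, in/out channels, kept modes
x = torch.randn(B, I, M, dtype=torch.cfloat)  # input spectrum
w = torch.randn(I, O, M, dtype=torch.cfloat)  # one I-by-O weight matrix per mode
out = torch.einsum("bix,iox->box", x, w)
assert out.shape == (B, O, M)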
@@ -411,16 +291,16 @@ def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes, c=64
         G1r[np.abs(G1r) < 1e-8] = 0
         self.max_item = 3

-        self.attn1 = FourierCrossAttention(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
+        self.attn1 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
                                             seq_len_kv=seq_len_kv, modes=modes, activation=activation,
                                             mode_select_method=mode_select_method)
-        self.attn2 = FourierCrossAttention(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
+        self.attn2 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
                                             seq_len_kv=seq_len_kv, modes=modes, activation=activation,
                                             mode_select_method=mode_select_method)
-        self.attn3 = FourierCrossAttention(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
+        self.attn3 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
                                             seq_len_kv=seq_len_kv, modes=modes, activation=activation,
                                             mode_select_method=mode_select_method)
-        self.attn4 = FourierCrossAttention(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
+        self.attn4 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
                                             seq_len_kv=seq_len_kv, modes=modes, activation=activation,
                                             mode_select_method=mode_select_method)

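Since attn1 through attn4 are configured identically, an nn.ModuleList would build them more compactly. A sketch only, since the forward pass presumably references the four named attributes individually:

self.attns = nn.ModuleList([
    FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels,
                           seq_len_q=seq_len_q, seq_len_kv=seq_len_kv,
                           modes=modes, activation=activation,
                           mode_select_method=mode_select_method)
    for _ in range(4)
])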
@@ -443,8 +323,8 @@ def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes, c=64
         self.modes1 = modes

     def forward(self, q, k, v, mask=None):
-        B, N, H, E = q.shape  # (B, N, k)
-        _, S, _, _ = k.shape
+        B, N, H, E = q.shape  # (B, N, H, E) torch.Size([3, 768, 8, 2])
+        _, S, _, _ = k.shape  # (B, S, H, E) torch.Size([3, 96, 8, 2])

         q = q.view(q.shape[0], q.shape[1], -1)
         k = k.view(k.shape[0], k.shape[1], -1)
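The two view calls merge the head and channel axes before the wavelet transform; with the torch.Size values noted in the new comments above:

import torch

q = torch.randn(3, 768, 8, 2)               # (B, N, H, E)
q = q.view(q.shape[0], q.shape[1], -1)      # merge heads and channels
assert q.shape == torch.Size([3, 768, 16])  # (B, N, H*E)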