5
5
import numpy as np
6
6
import torch
7
7
import torch .nn as nn
8
- import torch .nn .functional as F
9
- from torch .nn .parameter import Parameter
10
- from utils .masking import LocalMask
11
8
12
9
13
10
def get_frequency_modes (seq_len , modes = 64 , mode_select_method = 'random' ):
11
+ """
12
+ get modes on frequency domain:
13
+ 'random' means sampling randomly;
14
+ 'else' means sampling the lowest modes;
15
+ """
14
16
modes = min (modes , seq_len // 2 )
15
17
if mode_select_method == 'random' :
16
18
index = list (range (0 , seq_len // 2 ))
@@ -28,8 +30,10 @@ def __init__(self, in_channels, out_channels, seq_len, modes=0, mode_select_meth
28
30
super (FourierBlock , self ).__init__ ()
29
31
print ('fourier enhanced block used!' )
30
32
"""
31
- 1D Fourier layer. It does FFT, linear transform, and Inverse FFT.
33
+ 1D Fourier block. It performs representation learning on frequency domain,
34
+ it does FFT, linear transform, and Inverse FFT.
32
35
"""
36
+ # get modes on frequency domain
33
37
self .index = get_frequency_modes (seq_len , modes = modes , mode_select_method = mode_select_method )
34
38
print ('modes={}, index={}' .format (modes , self .index ))
35
39
@@ -44,63 +48,32 @@ def compl_mul1d(self, input, weights):
44
48
45
49
def forward (self , q , k , v , mask ):
46
50
# size = [B, L, H, E]
47
- k = k
48
- v = v
49
- mask = mask
50
51
B , L , H , E = q .shape
51
52
x = q .permute (0 , 2 , 3 , 1 )
52
- # batchsize = B
53
- # Compute Fourier coeffcients up to factor of e^(- something constant)
53
+ # Compute Fourier coefficients
54
54
x_ft = torch .fft .rfft (x , dim = - 1 )
55
- #out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)
56
- # if len(self.index)==0:
57
- # out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)
58
- # else:
59
- # out_ft = torch.zeros(B, H, E, len(self.index), device=x.device, dtype=torch.cfloat)
55
+ # Perform Fourier neural operations
60
56
out_ft = torch .zeros (B , H , E , L // 2 + 1 , device = x .device , dtype = torch .cfloat )
61
-
62
- # Multiply relevant Fourier modes
63
- # 取guided modes的版本
64
- # print('x shape',x.shape)
65
- # print('out_ft shape',out_ft.shape)
66
- # print('x_ft shape',x_ft.shape)
67
- # print('weight shape',self.weights1.shape)
68
- # print('self index',self.index)
69
57
for wi , i in enumerate (self .index ):
70
58
out_ft [:, :, :, wi ] = self .compl_mul1d (x_ft [:, :, :, i ], self .weights1 [:, :, :, wi ])
71
-
72
- # 取topk的modes版本
73
- # topk = torch.topk(torch.sum(x_ft, dim=[0, 1, 2]).abs(), dim=-1, k=self.modes1)
74
- # energy = (topk[0]**2).sum()
75
- # energy90 = 0
76
- # for index, j in enumerate(topk[0]):
77
- # energy90 += j**2
78
- # if energy90 >= energy * 0.9:
79
- # break
80
- # for i in topk[1][:index]:
81
- # out_ft[:, :, :, i] = self.compl_mul1d(x_ft[:, :, :, i], self.weights1[:, :, :, i])
82
-
83
- # Return to physical space
59
+ # Return to time domain
84
60
x = torch .fft .irfft (out_ft , n = x .size (- 1 ))
85
- #max_len = min(720,x.size(-1))
86
- #x = torch.fft.irfft(out_ft, n=max_len)
87
- # size = [B, L, H, E]
88
61
return (x , None )
89
62
90
63
91
- # ########## Cross Fourier Former ####################
64
+ # ########## Fourier Cross Former ####################
92
65
class FourierCrossAttention (nn .Module ):
93
66
def __init__ (self , in_channels , out_channels , seq_len_q , seq_len_kv , modes = 64 , mode_select_method = 'random' ,
94
67
activation = 'tanh' , policy = 0 ):
95
68
super (FourierCrossAttention , self ).__init__ ()
96
69
print (' fourier enhanced cross attention used!' )
97
-
98
70
"""
99
- 1D Fourier layer. It does FFT, linear transform, and Inverse FFT.
71
+ 1D Fourier Cross Attention layer. It does FFT, linear transform, attention mechanism and Inverse FFT.
100
72
"""
101
73
self .activation = activation
102
74
self .in_channels = in_channels
103
75
self .out_channels = out_channels
76
+ # get modes for queries and keys (& values) on frequency domain
104
77
self .index_q = get_frequency_modes (seq_len_q , modes = modes , mode_select_method = mode_select_method )
105
78
self .index_kv = get_frequency_modes (seq_len_kv , modes = modes , mode_select_method = mode_select_method )
106
79
@@ -118,23 +91,22 @@ def compl_mul1d(self, input, weights):
118
91
119
92
def forward (self , q , k , v , mask ):
120
93
# size = [B, L, H, E]
121
- mask = mask
122
94
B , L , H , E = q .shape
123
- xq = q .permute (0 , 2 , 3 , 1 ) # size = [B, H, E, L]
95
+ xq = q .permute (0 , 2 , 3 , 1 ) # size = [B, H, E, L]
124
96
xk = k .permute (0 , 2 , 3 , 1 )
125
97
xv = v .permute (0 , 2 , 3 , 1 )
126
98
127
- # Compute Fourier coeffcients up to factor of e^(- something constant)
99
+ # Compute Fourier coefficients
128
100
xq_ft_ = torch .zeros (B , H , E , len (self .index_q )+ 1 , device = xq .device , dtype = torch .cfloat )
129
101
xq_ft = torch .fft .rfft (xq , dim = - 1 )
130
102
for i , j in enumerate (self .index_q ):
131
103
xq_ft_ [:, :, :, i ] = xq_ft [:, :, :, j ]
132
-
133
104
xk_ft_ = torch .zeros (B , H , E , len (self .index_kv ), device = xq .device , dtype = torch .cfloat )
134
105
xk_ft = torch .fft .rfft (xk , dim = - 1 )
135
106
for i , j in enumerate (self .index_kv ):
136
107
xk_ft_ [:, :, :, i ] = xk_ft [:, :, :, j ]
137
108
109
+ # perform attention mechanism on frequency domain
138
110
xqk_ft = (torch .einsum ("bhex,bhey->bhxy" , xq_ft_ , xk_ft_ ))
139
111
if self .activation == 'tanh' :
140
112
xqk_ft = xqk_ft .tanh ()
@@ -143,15 +115,13 @@ def forward(self, q, k, v, mask):
143
115
xqk_ft = torch .complex (xqk_ft , torch .zeros_like (xqk_ft ))
144
116
else :
145
117
raise Exception ('{} actiation function is not implemented' .format (self .activation ))
146
-
147
118
xqkv_ft = torch .einsum ("bhxy,bhey->bhex" , xqk_ft , xk_ft_ )
148
119
xqkvw = torch .einsum ("bhex,heox->bhox" , xqkv_ft , self .weights1 )
149
120
out_ft = torch .zeros (B , H , E , L // 2 + 1 , device = xq .device , dtype = torch .cfloat )
150
121
for i , j in enumerate (self .index_q ):
151
122
out_ft [:, :, :, j ] = xqkvw [:, :, :, i ]
123
+ # Return to time domain
152
124
out = torch .fft .irfft (out_ft / self .in_channels / self .out_channels , n = xq .size (- 1 ))
153
- # size = [B, L, H, E]
154
-
155
125
return (out , None )
156
126
157
127
0 commit comments