
Commit 05112d5

fix conflict

2 parents: cad0b90 + 0af96fb

File tree: 6 files changed, +206 −348 lines


exp/exp_main.py (+1 −2)

@@ -1,7 +1,6 @@
 from data_provider.data_factory import data_provider
 from exp.exp_basic import Exp_Basic
-from models import FEDformer, Informer, Autoformer, Transformer  # Logformer, Reformer,Transformer_sin,Autoformer_sin
-# from models.reformer_pytorch.reformer_pytorch import Reformer
+from models import FEDformer, Autoformer, Informer, Transformer
 from utils.tools import EarlyStopping, adjust_learning_rate, visual
 from utils.metrics import metric

layers/AutoCorrelation.py (+1 −75)
@@ -1,13 +1,8 @@
 import time
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-import matplotlib.pyplot as plt
 import numpy as np
 import math
-from math import sqrt
-import os
-# from pytorch_wavelets import DWTForward, DWTInverse, DWT1DForward, DWT1DInverse
 from torch.nn.functional import interpolate
 
 
@@ -38,36 +33,6 @@ def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1,
         self.dropout = nn.Dropout(attention_dropout)
         self.agg = None
         self.use_wavelet = configs.wavelet
-        # if self.use_wavelet:
-        #     J = 3
-        #     self.dwt1d = DWT1DForward(J=J, wave='db4')
-        #     self.dwt1div = DWT1DInverse(wave='db4')
-        #     self.j_list = [1, 2, 4, 8, 8]
-        #     print('DWTCorrelation used, J={}, j_list={}'.format(J, self.j_list))
-
-    # @decor_time
-    def time_delay_agg_mzq(self, values, corr):
-        head = values.shape[1]
-        channel = values.shape[2]
-        length = values.shape[3]
-        S = length
-        # # else:
-        values = values.transpose(2, 3)
-        corr = corr.transpose(2, 3)
-        top_k = int(round(self.factor * np.log(S)))
-        # Rk = Rk.real
-        # if version == 3:
-        # V.size = [B, S, h]
-        # S = V.shape[1]
-        V_broad = torch.cat((values, values), dim=-2)  # size=[B, H, 2*S, h]
-        V_rolled = V_broad.unfold(-2, S, 1)  # size=[B, H, S+1, h, S]
-        # Rk.size = [B, S, h]
-        Rk_kthsmallest = torch.kthvalue(corr, k=S - top_k, dim=-2, keepdim=True)  # size=[B, H, 1, h]
-        mask = corr > torch.repeat_interleave(Rk_kthsmallest[0], repeats=S, dim=-2)
-        corr = torch.softmax(corr * mask, dim=-1)  # size = [B, H, S, h]
-        output = torch.einsum('beshi,besh->beih', V_rolled[:, :, 1:, :], corr)  # .transpose(1, 2)
-        # [B, H, S+1, h, S] * [B, H, S, h]
-        return output.transpose(2, 3)  # size=[batch, seq_len, h_dim]
 
     # @decor_time
     def time_delay_agg_training(self, values, corr):
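
Note on the deleted time_delay_agg_mzq variant: instead of rolling the values once per retained delay, it materializes every circular shift at once via the cat + unfold trick, then masks all but the top-k correlation scores with kthvalue. The identity the unfold step relies on, in miniature (an illustrative snippet, not code from this repo):

import torch

v = torch.arange(4.)                      # [S]
doubled = torch.cat((v, v))               # [2*S]
shifts = doubled.unfold(0, 4, 1)[:4]      # [S, S]; row i equals torch.roll(v, -i)
assert all(torch.equal(shifts[i], torch.roll(v, -i)) for i in range(4))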
@@ -166,8 +131,6 @@ def forward(self, queries, keys, values, attn_mask):
             keys = keys.reshape([B, L, -1])
             Ql, Qh_list = self.dwt1d(queries.transpose(1, 2))  # [B, H*D, L]
             Kl, Kh_list = self.dwt1d(keys.transpose(1, 2))
-            # n = queries.shape[1]
-            # B = queries.shape[0]
             qs = [queries.transpose(1, 2)] + Qh_list + [Ql]  # [B, H*D, L]
             ks = [keys.transpose(1, 2)] + Kh_list + [Kl]
             q_list = []
@@ -186,16 +149,11 @@ def forward(self, queries, keys, values, attn_mask):
 
             # time delay agg
             if self.training:
-                # if self.agg == 'thuml':
                 V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)  # [B, L, H, E], [B, H, E, L] -> [B, L, H, E]
-                # elif self.agg == 'mzq':
-                #     V = self.time_delay_agg_mzq(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
             else:
                 V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
-
         else:
             V_list = []
-            j_list = self.j_list
             queries = queries.reshape([B, L, -1])
             keys = keys.reshape([B, L, -1])
             values = values.reshape([B, L, -1])
@@ -261,36 +219,4 @@ def forward(self, queries, keys, values, attn_mask):
             )
 
         out = out.view(B, L, -1)
-        return self.out_projection(out), attn
-
-
-if __name__ == '__main__':
-    class Configs(object):
-        wavelet = 2
-
-    configs = Configs()
-    B = 3
-    H = 2
-    S = 240
-    d = 16
-    x = torch.randn([B, S, H, d])
-    model1 = AutoCorrelation(configs=configs)
-    model1.training = 1
-    model1.factor = 3
-    # model1.agg = 'thuml'
-    #
-    # model2 = AutoCorrelation()
-    # model2.training = 1
-    # model2.factor = 3
-    # model2.agg = 'mzq'
-    out1 = model1.forward(x, x, x, 1)
-    # out2 = model2.forward(x, x, x, 1)
-    # diff = out1[0] - out2[0]
-
-    # for S in 96, 480, 2400:
-    #     print('========{}========='.format(S))
-    #     x = torch.randn([B, S, H, d])
-    #     for i in range(0, 3):
-    #         out1 = model1.forward(x, x, x, 1)
-    #         out2 = model2.forward(x, x, x, 1)
-    a = 1
+        return self.out_projection(out), attn
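
For context on the aggregation path this commit keeps: time_delay_agg_training implements Autoformer-style autocorrelation aggregation, scoring circular delays of the value sequence and blending the top-k rolled copies under a softmax over their scores. A minimal sketch of that mechanism (assuming the [B, H, E, L] layout used above; the factor * log(L) rule for top_k mirrors the deleted variant):

import math
import torch

def time_delay_agg_sketch(values, corr, factor=1):
    # values, corr: [B, H, E, L]; corr[..., tau] scores a circular shift by tau
    B, H, E, L = values.shape
    top_k = max(1, int(factor * math.log(L)))
    # pick the top-k delays by correlation averaged over batch/head/channel
    delays = torch.topk(corr.mean(dim=(0, 1, 2)), top_k).indices      # [top_k]
    probs = torch.softmax(corr.mean(dim=(1, 2))[:, delays], dim=-1)   # [B, top_k]
    out = torch.zeros_like(values)
    for i in range(top_k):
        # roll the series by each retained delay and blend by its weight
        rolled = torch.roll(values, shifts=-int(delays[i]), dims=-1)
        out = out + rolled * probs[:, i].view(B, 1, 1, 1)
    return out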

layers/FourierCorrelation.py (+18 −48)
@@ -5,12 +5,14 @@
 import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn.parameter import Parameter
-from utils.masking import LocalMask
 
 
 def get_frequency_modes(seq_len, modes=64, mode_select_method='random'):
+    """
+    get modes on frequency domain:
+    'random' means sampling randomly;
+    'else' means sampling the lowest modes;
+    """
     modes = min(modes, seq_len//2)
     if mode_select_method == 'random':
         index = list(range(0, seq_len // 2))
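
The hunk ends inside get_frequency_modes; to make the two policies in the new docstring concrete, here is a self-contained sketch (the sample-then-sort detail is an assumption based on the docstring, not shown in this diff):

import numpy as np

def get_frequency_modes_sketch(seq_len, modes=64, mode_select_method='random'):
    # rfft of a length-L series yields L // 2 + 1 coefficients; cap the budget
    modes = min(modes, seq_len // 2)
    if mode_select_method == 'random':
        # 'random': keep `modes` frequencies sampled without replacement
        index = list(np.random.choice(seq_len // 2, size=modes, replace=False))
    else:
        # otherwise: keep the `modes` lowest frequencies
        index = list(range(modes))
    index.sort()
    return index

# e.g. get_frequency_modes_sketch(96, modes=8, mode_select_method='low')
# -> [0, 1, 2, 3, 4, 5, 6, 7]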
@@ -28,8 +30,10 @@ def __init__(self, in_channels, out_channels, seq_len, modes=0, mode_select_meth
         super(FourierBlock, self).__init__()
         print('fourier enhanced block used!')
         """
-        1D Fourier layer. It does FFT, linear transform, and Inverse FFT.
+        1D Fourier block. It performs representation learning on frequency domain,
+        it does FFT, linear transform, and Inverse FFT.
         """
+        # get modes on frequency domain
         self.index = get_frequency_modes(seq_len, modes=modes, mode_select_method=mode_select_method)
         print('modes={}, index={}'.format(modes, self.index))
 
@@ -44,63 +48,32 @@ def compl_mul1d(self, input, weights):
 
     def forward(self, q, k, v, mask):
         # size = [B, L, H, E]
-        k = k
-        v = v
-        mask = mask
         B, L, H, E = q.shape
         x = q.permute(0, 2, 3, 1)
-        # batchsize = B
-        # Compute Fourier coeffcients up to factor of e^(- something constant)
+        # Compute Fourier coefficients
         x_ft = torch.fft.rfft(x, dim=-1)
-        #out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)
-        # if len(self.index)==0:
-        #     out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)
-        # else:
-        #     out_ft = torch.zeros(B, H, E, len(self.index), device=x.device, dtype=torch.cfloat)
+        # Perform Fourier neural operations
         out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)
-
-        # Multiply relevant Fourier modes
-        # guided-modes version
-        # print('x shape',x.shape)
-        # print('out_ft shape',out_ft.shape)
-        # print('x_ft shape',x_ft.shape)
-        # print('weight shape',self.weights1.shape)
-        # print('self index',self.index)
         for wi, i in enumerate(self.index):
             out_ft[:, :, :, wi] = self.compl_mul1d(x_ft[:, :, :, i], self.weights1[:, :, :, wi])
-
-        # top-k modes version
-        # topk = torch.topk(torch.sum(x_ft, dim=[0, 1, 2]).abs(), dim=-1, k=self.modes1)
-        # energy = (topk[0]**2).sum()
-        # energy90 = 0
-        # for index, j in enumerate(topk[0]):
-        #     energy90 += j**2
-        #     if energy90 >= energy * 0.9:
-        #         break
-        # for i in topk[1][:index]:
-        #     out_ft[:, :, :, i] = self.compl_mul1d(x_ft[:, :, :, i], self.weights1[:, :, :, i])
-
-        # Return to physical space
+        # Return to time domain
         x = torch.fft.irfft(out_ft, n=x.size(-1))
-        #max_len = min(720,x.size(-1))
-        #x = torch.fft.irfft(out_ft, n=max_len)
-        # size = [B, L, H, E]
         return (x, None)
 
 
-# ########## Cross Fourier Former ####################
+# ########## Fourier Cross Former ####################
 class FourierCrossAttention(nn.Module):
     def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=64, mode_select_method='random',
                  activation='tanh', policy=0):
         super(FourierCrossAttention, self).__init__()
         print(' fourier enhanced cross attention used!')
-
         """
-        1D Fourier layer. It does FFT, linear transform, and Inverse FFT.
+        1D Fourier Cross Attention layer. It does FFT, linear transform, attention mechanism and Inverse FFT.
         """
         self.activation = activation
         self.in_channels = in_channels
         self.out_channels = out_channels
+        # get modes for queries and keys (& values) on frequency domain
         self.index_q = get_frequency_modes(seq_len_q, modes=modes, mode_select_method=mode_select_method)
         self.index_kv = get_frequency_modes(seq_len_kv, modes=modes, mode_select_method=mode_select_method)
 
@@ -118,23 +91,22 @@ def compl_mul1d(self, input, weights):
 
     def forward(self, q, k, v, mask):
         # size = [B, L, H, E]
-        mask = mask
         B, L, H, E = q.shape
-        xq = q.permute(0, 2, 3, 1) # size = [B, H, E, L]
+        xq = q.permute(0, 2, 3, 1)  # size = [B, H, E, L]
         xk = k.permute(0, 2, 3, 1)
         xv = v.permute(0, 2, 3, 1)
 
-        # Compute Fourier coeffcients up to factor of e^(- something constant)
+        # Compute Fourier coefficients
         xq_ft_ = torch.zeros(B, H, E, len(self.index_q)+1, device=xq.device, dtype=torch.cfloat)
         xq_ft = torch.fft.rfft(xq, dim=-1)
         for i, j in enumerate(self.index_q):
             xq_ft_[:, :, :, i] = xq_ft[:, :, :, j]
-
         xk_ft_ = torch.zeros(B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat)
         xk_ft = torch.fft.rfft(xk, dim=-1)
         for i, j in enumerate(self.index_kv):
             xk_ft_[:, :, :, i] = xk_ft[:, :, :, j]
 
+        # perform attention mechanism on frequency domain
         xqk_ft = (torch.einsum("bhex,bhey->bhxy", xq_ft_, xk_ft_))
         if self.activation == 'tanh':
             xqk_ft = xqk_ft.tanh()
@@ -143,15 +115,13 @@ def forward(self, q, k, v, mask):
             xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))
         else:
             raise Exception('{} actiation function is not implemented'.format(self.activation))
-
         xqkv_ft = torch.einsum("bhxy,bhey->bhex", xqk_ft, xk_ft_)
         xqkvw = torch.einsum("bhex,heox->bhox", xqkv_ft, self.weights1)
         out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
         for i, j in enumerate(self.index_q):
             out_ft[:, :, :, j] = xqkvw[:, :, :, i]
+        # Return to time domain
         out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1))
-        # size = [B, L, H, E]
-
         return (out, None)
 
 
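
Taken together, the cross-attention forward is ordinary attention with the sequence axis swapped for retained frequency modes: scores between the complex query and key spectra, an activation, aggregation of the key spectrum, then a learned complex mixing before the inverse FFT. The core, condensed (shapes follow the comments in the diff; only the tanh branch is shown):

import torch

def fourier_cross_attention_core(xq_ft_, xk_ft_, weights):
    # xq_ft_: [B, H, E, Mq] complex query modes; xk_ft_: [B, H, E, Mk]
    # weights: [H, E, E, Mq] learned complex mixing
    scores = torch.einsum('bhex,bhey->bhxy', xq_ft_, xk_ft_)  # [B, H, Mq, Mk]
    scores = scores.tanh()                     # complex tanh, as in the diff
    agg = torch.einsum('bhxy,bhey->bhex', scores, xk_ft_)     # [B, H, E, Mq]
    return torch.einsum('bhex,heox->bhox', agg, weights)      # [B, H, E, Mq]

The result is scattered back onto the full rfft grid at index_q and inverse-transformed, so the attention cost scales with the number of retained modes rather than with sequence length.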
