Commit b350914

Update Linear_ori.py

1 parent 18e0e0e · commit b350914

1 file changed: +11 −92 lines


models/Linear_ori.py (+11 −92)
```diff
@@ -33,22 +33,14 @@
 
 # self.drop = nn.Dropout(p=0.1)
 # # Use this line if you want to visualize the weights
-# # self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
+#
 # def forward(self, x):
 # # x: [Batch, Input length, Channel]
-# # x = self.Linear(x.permute(0,2,1))
+#
 # x = x.permute(0,2,1) # (B,L,C) => (B,C,L)
 # b, c, l = x.size() # (B,C,L)
-# # y = self.avg_pool(x) # (B,C,L) => (B,C,1) via avg
-# # print("y",y.shape)
 # y = self.avg_pool(x).view(b, c) # (B,C,L) => (B,C,1) via avg
-# # print("y",y.shape)
-# # to feed the Linear layer for learning, flatten the data with view
-# # y = self.fc(y).view(b, c, 96)
 
-# y = self.fc(y).view(b,c,1)
-# # f_weight_np = y.cpu().detach().numpy()
-# z = self.Linear_More_1(x*y+x)
 
 # # np.save('f_weight.npy', f_weight_np)
 # # # np.save('%d f_weight.npy' %epoch, f_weight_np)
```
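The commented-out block above describes a squeeze-and-excitation style channel gate: average-pool each channel, pass the pooled vector through `fc`, then reweight the input. A minimal runnable reconstruction of that path; the class name `ChannelGate` and the `(B, C, L)` shapes are assumptions taken from the comments, and the `fc` bottleneck mirrors the removed legacy class further down in this diff.

```python
import torch
import torch.nn as nn

# Sketch of the commented-out channel-gate path (avg_pool -> fc -> reweight);
# class name and shapes are assumptions, not the repo's API.
class ChannelGate(nn.Module):
    def __init__(self, channel: int):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)  # (B,C,L) -> (B,C,1)
        self.fc = nn.Sequential(                 # bottleneck mirrored from the removed class
            nn.Linear(channel, channel * 4, bias=False),
            nn.Dropout(p=0.1),
            nn.ReLU(inplace=True),
            nn.Linear(channel * 4, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _ = x.size()               # x: (B,C,L)
        y = self.avg_pool(x).view(b, c)  # squeeze: (B,C)
        y = self.fc(y).view(b, c, 1)     # excite:  (B,C,1)
        return x * y + x                 # gated residual, as in the comment `x*y+x`

gate = ChannelGate(channel=7)
out = gate(torch.randn(32, 7, 96))       # -> torch.Size([32, 7, 96])
```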
```diff
@@ -68,9 +60,9 @@ def forward(self, x):
 x_hat = self.layernorm(x)
 bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1)
 return x_hat - bias
-class Model(nn.Module):# before the 2022.11.7 edit, this Model ran end to end # forMultivariate
+class Model(nn.Module):
 
-def __init__(self,configs,channel=96,ratio=1):# channel should be changed to 36 for the ili dataset; channel=input_length
+def __init__(self,configs,channel=96,ratio=1):
 super(Model, self).__init__()
 # self.avg_pool = nn.AdaptiveAvgPool1d(1) #innovation
 self.seq_len = configs.seq_len
```
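The context lines at the top of this hunk show a LayerNorm variant that subtracts the per-channel temporal mean after normalization, so each normalized series averages to zero over time. A shape-annotated sketch of just those three lines; the wrapper class name and the `(B, L, C)` input layout are assumptions.

```python
import torch
import torch.nn as nn

# The three context lines above, wrapped for illustration; the class name
# and the (B, L, C) layout are assumptions.
class DebiasedLayerNorm(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.layernorm = nn.LayerNorm(channels)

    def forward(self, x):
        x_hat = self.layernorm(x)  # normalize over the channel axis
        # per-channel mean over time: (B,C) -> (B,1,C) -> (B,L,C)
        bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1)
        return x_hat - bias        # re-centred so each series averages to zero

out = DebiasedLayerNorm(7)(torch.randn(32, 96, 7))  # -> torch.Size([32, 96, 7])
```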
```diff
@@ -95,7 +87,7 @@ def __init__(self,configs,channel=96,ratio=1):# channel should be changed to 36 for the ili dataset
 
 self.Linear = nn.Linear(self.seq_len, self.pred_len)
 self.Linear_1 = nn.Linear(self.seq_len, self.pred_len)
-# self.dct_norm = nn.LayerNorm([self.channel_num], eps=1e-6)# as a module, normalizing over channels usually works a bit better; for traffic
+# self.dct_norm = nn.LayerNorm([self.channel_num], eps=1e-6)
 self.dct_norm = nn.LayerNorm(self.seq_len, eps=1e-6)#
 # self.my_layer_norm = nn.LayerNorm([96], eps=1e-6)
 def forward(self, x):
```
```diff
@@ -106,13 +98,11 @@ def forward(self, x):
 b, c, l = x.size() # (B,C,L)
 list = []
 
-for i in range(c):#i represent channel, apply DCT to each channel's data separately
+for i in range(c):#i represent channel
 freq=dct.dct(x[:,i,:]) #dct
-# freq=torch.fft.fft(x[:,i,:])#fft: the output length of fft/rfft does not match the input #fft results cannot be fed through a linear layer
-
 # print("freq-shape:",freq.shape)
 list.append(freq)
-## stack the DCT results, then learn frequency features
+
 
 stack_dct=torch.stack(list,dim=1)
 stack_dct = torch.tensor(stack_dct)#(B,L,C)
```
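The kept loop applies `dct.dct` to one channel slice at a time and stacks the results; note that `torch.stack` already returns a tensor, so the following `torch.tensor(stack_dct)` call is redundant (recent PyTorch versions emit a warning for it). Since `torch_dct.dct` transforms the last dimension and broadcasts over the leading ones, the per-channel loop should collapse to a single call; a small check under assumed shapes:

```python
import torch
import torch_dct as dct  # pip install torch-dct

x = torch.randn(32, 7, 96)  # (B, C, L), shapes assumed for the example

# Loop form, as in the diff: one DCT per channel slice, then stack.
loop = torch.stack([dct.dct(x[:, i, :]) for i in range(x.size(1))], dim=1)

# torch_dct.dct operates on the last dimension and broadcasts over the
# leading ones, so the loop should be equivalent to one vectorized call.
vectorized = dct.dct(x)

print(torch.allclose(loop, vectorized, atol=1e-4))  # expected: True
```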
```diff
@@ -121,27 +111,9 @@
 f_weight = self.fc(stack_dct)
 f_weight = self.dct_norm(f_weight)#matters for traffic
 
-#wo-dct
-# f_weight = self.dct_norm(x)
-# f_weight = self.dct_norm(f_weight)#matters for traffic
-# f_weight = self.fc(f_weight)
-# f_weight = self.dct_norm(f_weight)#matters for traffic
-# list_inverse = []
-# for i in range(c):
-# freq_inverse=dct.idct(f_weight[:,i,:]) #dct
-# # freq=torch.fft.fft(x[:,i,:])#fft: the output length of fft/rfft does not match the input #fft results cannot be fed through a linear layer
-
-# # print("freq-shape:",freq.shape)
-# list_inverse.append(freq_inverse)
-# ## stack the DCT results, then learn frequency features
-# stack_dct_inverse=torch.stack(list_inverse,dim=1)
-# stack_dct_inverse = torch.tensor(stack_dct_inverse)#(B,L,C)
-# # f_weight = self.fc(x)
-# stack_dct_inverse = self.dct_norm(stack_dct_inverse)#matters for traffic
-# f_weight_inverse = self.fc_inverse(stack_dct_inverse)
-# f_weight_inverse = self.dct_norm(f_weight_inverse)#matters for traffic
-
-#visualize this frequency tensor (generalized tensor)
+
+
+#visualization for fecam tensor
 f_weight_cpu = f_weight
 
 f_weight_np = f_weight_cpu.cpu().detach().numpy()
```
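The removed `dct_norm` comment (and the one in hunk 3) notes that normalizing over channels sometimes works better as a module (e.g. for traffic), while the kept line normalizes over `seq_len`. A sketch contrasting the two normalization axes; the `B`, `C`, `L` sizes are assumptions:

```python
import torch
import torch.nn as nn

B, C, L = 32, 7, 96
f_weight = torch.randn(B, C, L)  # DCT-domain weights, (B, C, seq_len)

# Kept path: statistics taken over the frequency (seq_len) axis.
norm_freq = nn.LayerNorm(L, eps=1e-6)
out = norm_freq(f_weight)

# Commented-out alternative: statistics over the channel axis instead,
# which needs channels last, hence the permutes.
norm_chan = nn.LayerNorm(C, eps=1e-6)
out_alt = norm_chan(f_weight.permute(0, 2, 1)).permute(0, 2, 1)
```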
```diff
@@ -165,65 +137,12 @@ def forward(self, x):
 # result = self.Linear((x *(f_weight_inverse)))#forL
 result = self.Linear((x *(f_weight)))#forL
 
-# result = result + (1)*torch.mean(result)# for ill: prior knowledge tailored to this dataset; worth exploring further
+# result = result + (1)*torch.mean(result)# for ill
 # result_1 = self.Linear_1(x)
 # result = result + result_1
 # result = self.my_layer_norm(result)
 
 return result.permute(0,2,1)
-# return result
-
-# class Model(nn.Module):# before the 2022.11.7 edit, this Model ran end to end # forMultivariate
-# def __init__(self,configs,channel=96,ratio=1):# channel should be changed to 36 for the ili dataset; channel=input_length
-# super(Model, self).__init__()
-# # self.avg_pool = nn.AdaptiveAvgPool1d(1) #innovation
-# self.seq_len = configs.seq_len
-# self.pred_len = configs.pred_len
-# self.fc = nn.Sequential(
-# nn.Linear(channel, channel*4, bias=False),
-# nn.Dropout(p=0.1),
-# nn.ReLU(inplace=True),
-# nn.Linear( channel*4, channel, bias=False),
-# nn.Sigmoid()
-# )
-# # self.fc_plot = nn.Linear(channel, channel, bias=False)
-# self.mid_Linear = nn.Linear(self.seq_len, self.seq_len)
-
-# self.Linear = nn.Linear(self.seq_len, self.pred_len)
-# self.dct_norm = nn.LayerNorm([7], eps=1e-6)# as a module, normalizing over channels usually works a bit better; for traffic
 
-# # self.my_layer_norm = nn.LayerNorm([96], eps=1e-6)
-# def forward(self, x):
-# x = x.permute(0,2,1) # (B,L,C) => (B,C,L)#forL
-
-
-# # x_t = self.mid_Linear(x)
-# b, c, l = x.size() # (B,C,L)
-# list = []
-
-# for i in range(c):#i represent channel, apply DCT to each channel's data separately
-# freq=dct.dct(x[:,i,:]) #dct
-# # freq=torch.fft.fft(x[:,i,:])#fft: the output length of fft/rfft does not match the input #fft results cannot be fed through a linear layer
-
-# # print("freq-shape:",freq.shape)
-# list.append(freq)
-# ## stack the DCT results, then learn frequency features
-
-
-
-# stack_dct=torch.stack(list,dim=1)
-# stack_dct = torch.tensor(stack_dct)#(B,L,C)
-# f_weight = self.fc(stack_dct)
-
-
-# # f_weight = self.dct_norm(f_weight.permute(0,2,1))#matters for traffic
-# # result = self.Linear(x *(f_weight.permute(0,2,1)))#forL
-
-# result = self.Linear(x *(f_weight))#forL
-
-
-# # result = self.my_layer_norm(result)
-
-# return result.permute(0,2,1)
 
 
```

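Putting the kept lines together, the forward path this commit settles on is: per-channel DCT over the time axis, an `fc` bottleneck with a sigmoid gate on the frequency coefficients, LayerNorm over `seq_len`, elementwise reweighting of the input, then a `Linear(seq_len, pred_len)` head. A condensed, runnable sketch of that path; the class name and the explicit `seq_len`/`pred_len` arguments (in place of `configs`) are assumptions, and `fc` is sized on `seq_len` per the removed `channel=input_length` comment:

```python
import torch
import torch.nn as nn
import torch_dct as dct

# Condensed sketch of the forward path kept by this commit; names and
# explicit arguments are assumptions, not the repo's exact API.
class FreqWeightedLinear(nn.Module):
    def __init__(self, seq_len: int = 96, pred_len: int = 96):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(seq_len, seq_len * 4, bias=False),
            nn.Dropout(p=0.1),
            nn.ReLU(inplace=True),
            nn.Linear(seq_len * 4, seq_len, bias=False),
            nn.Sigmoid(),
        )
        self.Linear = nn.Linear(seq_len, pred_len)
        self.dct_norm = nn.LayerNorm(seq_len, eps=1e-6)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,L,C) -> (B,C,L)
        # per-channel DCT over the time axis, as in the kept loop
        freq = torch.stack([dct.dct(x[:, i, :]) for i in range(x.size(1))], dim=1)
        f_weight = self.dct_norm(self.fc(freq))  # frequency-domain gate
        result = self.Linear(x * f_weight)       # (B,C,pred_len)
        return result.permute(0, 2, 1)           # (B,pred_len,C)

model = FreqWeightedLinear()
y = model(torch.randn(32, 96, 7))  # -> torch.Size([32, 96, 7])
```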