Skip to content

Commit a4aa215

Browse files
authored
Calculating parameter for deep learning models.
Calculating parameter for deep learning models.
1 parent bd69629 commit a4aa215

File tree

1 file changed

+216
-0
lines changed

1 file changed

+216
-0
lines changed

count_parameter.ipynb

+216
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 16,
6+
"metadata": {},
7+
"outputs": [
8+
{
9+
"name": "stdout",
10+
"output_type": "stream",
11+
"text": [
12+
"| name | #elements or shape |\n",
13+
"|:-----------------------|:---------------------|\n",
14+
"| model | 37.3K |\n",
15+
"| fc | 0.2K |\n",
16+
"| fc.0 | 98 |\n",
17+
"| fc.0.weight | (14, 7) |\n",
18+
"| fc.3 | 98 |\n",
19+
"| fc.3.weight | (7, 14) |\n",
20+
"| Linear_More_1 | 18.6K |\n",
21+
"| Linear_More_1.weight | (192, 96) |\n",
22+
"| Linear_More_1.bias | (192,) |\n",
23+
"| Linear_More_2 | 18.5K |\n",
24+
"| Linear_More_2.weight | (96, 192) |\n",
25+
"| Linear_More_2.bias | (96,) |\n"
26+
]
27+
}
28+
],
29+
"source": [
30+
"import torch\n",
31+
"import torch.nn as nn\n",
32+
"import torch.nn.functional as F\n",
33+
"import numpy as np\n",
34+
"from fvcore.nn import FlopCountAnalysis,parameter_count_table\n",
35+
"class Model(nn.Module):\n",
36+
" \"\"\"\n",
37+
" Just one Linear layer\n",
38+
" \"\"\"\n",
39+
" def __init__(self, channel=7,ratio=1):\n",
40+
" super(Model, self).__init__()\n",
41+
"\n",
42+
" self.avg_pool = nn.AdaptiveAvgPool1d(1) #innovation\n",
43+
" self.fc = nn.Sequential(\n",
44+
" nn.Linear(7,14, bias=False),\n",
45+
" nn.Dropout(p=0.1),\n",
46+
" nn.ReLU(inplace=True) ,\n",
47+
" nn.Linear(14, 7, bias=False),\n",
48+
" nn.Sigmoid()\n",
49+
" )\n",
50+
" self.seq_len = 96\n",
51+
" self.pred_len = 96\n",
52+
" self.Linear_More_1 = nn.Linear(self.seq_len,self.pred_len * 2)\n",
53+
" self.Linear_More_2 = nn.Linear(self.pred_len*2,self.pred_len)\n",
54+
" self.relu = nn.ReLU()\n",
55+
" self.gelu = nn.GELU() \n",
56+
"\n",
57+
" self.drop = nn.Dropout(p=0.1)\n",
58+
" # Use this line if you want to visualize the weights\n",
59+
" # self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))\n",
60+
" def forward(self, x):\n",
61+
" # x: [Batch, Input length, Channel]\n",
62+
" # x = self.Linear(x.permute(0,2,1))\n",
63+
" x = x.permute(0,2,1) # (B,L,C)=》(B,C,L)\n",
64+
" b, c, l = x.size() # (B,C,L)\n",
65+
" # y = self.avg_pool(x) # (B,C,L) 通过avg=》 (B,C,1)\n",
66+
" # print(\"y\",y.shape)\n",
67+
" y = self.avg_pool(x).view(b, c) # (B,C,L) 通过avg=》 (B,C,1)\n",
68+
" # print(\"y\",y.shape)\n",
69+
" #为了丢给Linear学习,需要view把数据展平开\n",
70+
" # y = self.fc(y).view(b, c, 96)\n",
71+
" \n",
72+
" y = self.fc(y).view(b,c,1)\n",
73+
"\n",
74+
" # print(\"y\",y.shape)\n",
75+
" return (x * y).permute(0,2,1)\n",
76+
"model = Model()\n",
77+
"print(parameter_count_table(model))"
78+
]
79+
},
80+
{
81+
"cell_type": "code",
82+
"execution_count": 17,
83+
"metadata": {},
84+
"outputs": [
85+
{
86+
"name": "stderr",
87+
"output_type": "stream",
88+
"text": [
89+
"Unsupported operator aten::adaptive_avg_pool1d encountered 1 time(s)\n",
90+
"Unsupported operator aten::sigmoid encountered 1 time(s)\n",
91+
"Unsupported operator aten::mul encountered 1 time(s)\n",
92+
"The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.\n",
93+
"Linear_More_1, Linear_More_2, drop, gelu, relu\n"
94+
]
95+
},
96+
{
97+
"name": "stdout",
98+
"output_type": "stream",
99+
"text": [
100+
"FLOPs: 1568\n"
101+
]
102+
}
103+
],
104+
"source": [
105+
"'''\n",
106+
"计算FLOPs\n",
107+
"'''\n",
108+
"tensor=torch.randn(8,96,7)\n",
109+
"FLOPs = FlopCountAnalysis(model,tensor)\n",
110+
"print(\"FLOPs:\",FLOPs.total())"
111+
]
112+
},
113+
{
114+
"cell_type": "code",
115+
"execution_count": 20,
116+
"metadata": {},
117+
"outputs": [
118+
{
119+
"name": "stdout",
120+
"output_type": "stream",
121+
"text": [
122+
"| name | #elements or shape |\n",
123+
"|:----------------|:---------------------|\n",
124+
"| model | 69.8K |\n",
125+
"| Linear | 69.8K |\n",
126+
"| Linear.weight | (720, 96) |\n",
127+
"| Linear.bias | (720,) |\n"
128+
]
129+
}
130+
],
131+
"source": [
132+
"class Model(nn.Module):\n",
133+
" \"\"\"\n",
134+
" Normalization-Linear\n",
135+
" \"\"\"\n",
136+
" def __init__(self):\n",
137+
" super(Model, self).__init__()\n",
138+
" self.seq_len = 96\n",
139+
" self.pred_len = 720\n",
140+
" self.Linear = nn.Linear(self.seq_len, self.pred_len)\n",
141+
" # Use this line if you want to visualize the weights\n",
142+
" # self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))\n",
143+
"\n",
144+
" def forward(self, x):\n",
145+
" # x: [Batch, Input length, Channel]\n",
146+
"\n",
147+
" x = self.Linear(x.permute(0,2,1)).permute(0,2,1)\n",
148+
" \n",
149+
" return x # [Batch, Output length, Channel]\n",
150+
"model = Model()\n",
151+
"\n",
152+
"\n",
153+
"\n",
154+
"print(parameter_count_table(model))"
155+
]
156+
},
157+
{
158+
"cell_type": "code",
159+
"execution_count": 42,
160+
"metadata": {},
161+
"outputs": [
162+
{
163+
"data": {
164+
"text/plain": [
165+
"'\\n计算FLOPs\\n'"
166+
]
167+
},
168+
"execution_count": 42,
169+
"metadata": {},
170+
"output_type": "execute_result"
171+
}
172+
],
173+
"source": [
174+
"'''\n",
175+
"计算FLOPs\n",
176+
"'''\n",
177+
"# tensor=torch.randn(,96,7)\n",
178+
"result =model(tensor)\n",
179+
"# FLOPs = FlopCountAnalysis(model,tensor)\n",
180+
"# print(\"FLOPs:\",FLOPs.total())"
181+
]
182+
},
183+
{
184+
"cell_type": "code",
185+
"execution_count": 49,
186+
"metadata": {},
187+
"outputs": [],
188+
"source": []
189+
}
190+
],
191+
"metadata": {
192+
"interpreter": {
193+
"hash": "f57785bf53e86c458d31dd8512073d1ac6cae98f342ec9a1a9a8506681d63dcb"
194+
},
195+
"kernelspec": {
196+
"display_name": "Python 3.7.13 ('openmmlab')",
197+
"language": "python",
198+
"name": "python3"
199+
},
200+
"language_info": {
201+
"codemirror_mode": {
202+
"name": "ipython",
203+
"version": 3
204+
},
205+
"file_extension": ".py",
206+
"mimetype": "text/x-python",
207+
"name": "python",
208+
"nbconvert_exporter": "python",
209+
"pygments_lexer": "ipython3",
210+
"version": "3.7.13"
211+
},
212+
"orig_nbformat": 4
213+
},
214+
"nbformat": 4,
215+
"nbformat_minor": 2
216+
}

0 commit comments

Comments
 (0)