{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| name                   | #elements or shape   |\n",
      "|:-----------------------|:---------------------|\n",
      "| model                  | 37.3K                |\n",
      "| fc                     | 0.2K                 |\n",
      "| fc.0                   | 98                   |\n",
      "| fc.0.weight            | (14, 7)              |\n",
      "| fc.3                   | 98                   |\n",
      "| fc.3.weight            | (7, 14)              |\n",
      "| Linear_More_1          | 18.6K                |\n",
      "| Linear_More_1.weight   | (192, 96)            |\n",
      "| Linear_More_1.bias     | (192,)               |\n",
      "| Linear_More_2          | 18.5K                |\n",
      "| Linear_More_2.weight   | (96, 192)            |\n",
      "| Linear_More_2.bias     | (96,)                |\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "from fvcore.nn import FlopCountAnalysis, parameter_count_table\n",
    "\n",
    "class Model(nn.Module):\n",
    "    \"\"\"\n",
    "    SE-style channel-attention gate over the input series.\n",
    "    (Linear_More_1/Linear_More_2 are defined but never called in\n",
    "    forward, so fvcore reports zero FLOPs for them.)\n",
    "    \"\"\"\n",
    "    def __init__(self, channel=7):\n",
    "        super(Model, self).__init__()\n",
    "\n",
    "        self.avg_pool = nn.AdaptiveAvgPool1d(1)  # squeeze: (B, C, L) -> (B, C, 1)\n",
    "        self.fc = nn.Sequential(\n",
    "            nn.Linear(channel, channel * 2, bias=False),\n",
    "            nn.Dropout(p=0.1),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Linear(channel * 2, channel, bias=False),\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "        self.seq_len = 96\n",
    "        self.pred_len = 96\n",
    "        self.Linear_More_1 = nn.Linear(self.seq_len, self.pred_len * 2)\n",
    "        self.Linear_More_2 = nn.Linear(self.pred_len * 2, self.pred_len)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.gelu = nn.GELU()\n",
    "        self.drop = nn.Dropout(p=0.1)\n",
    "        # Use this line if you want to visualize the weights\n",
    "        # self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))\n",
    "\n",
    "    def forward(self, x):\n",
    "        # x: [Batch, Input length, Channel]\n",
    "        x = x.permute(0, 2, 1)  # (B, L, C) -> (B, C, L)\n",
    "        b, c, l = x.size()\n",
    "        # Squeeze: average over the length dimension, then flatten so the\n",
    "        # fully connected layers can consume it: (B, C, 1) -> (B, C)\n",
    "        y = self.avg_pool(x).view(b, c)\n",
    "        # Excitation: per-channel gate in (0, 1), reshaped for broadcasting\n",
    "        y = self.fc(y).view(b, c, 1)\n",
    "        return (x * y).permute(0, 2, 1)  # back to (B, L, C)\n",
    "\n",
    "model = Model()\n",
    "print(parameter_count_table(model))"
   ]
  },
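  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check (an added sketch, not part of the original run): the attention gate only rescales channels, so the output shape should equal the input shape `[Batch, Input length, Channel]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: verify that the gate preserves the input shape.\n",
    "# (8, 96, 7) mirrors the input used for the FLOP count below.\n",
    "x = torch.randn(8, 96, 7)  # [Batch, Input length, Channel]\n",
    "out = model(x)\n",
    "assert out.shape == x.shape  # the gate multiplies by a (B, C, 1) factor\n",
    "print(out.shape)"
   ]
  },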
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Unsupported operator aten::adaptive_avg_pool1d encountered 1 time(s)\n",
      "Unsupported operator aten::sigmoid encountered 1 time(s)\n",
      "Unsupported operator aten::mul encountered 1 time(s)\n",
      "The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.\n",
      "Linear_More_1, Linear_More_2, drop, gelu, relu\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FLOPs: 1568\n"
     ]
    }
   ],
   "source": [
    "# Compute FLOPs (fvcore traces the model on a sample input)\n",
    "tensor = torch.randn(8, 96, 7)\n",
    "FLOPs = FlopCountAnalysis(model, tensor)\n",
    "print(\"FLOPs:\", FLOPs.total())"
   ]
  },
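  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The 1568 total comes entirely from the two gate Linears, since fvcore counts one multiply-accumulate as one FLOP and skips the operators listed in the warning: 8 × 7 × 14 + 8 × 14 × 7 = 1568. A minimal sketch of the finer-grained views fvcore offers (outputs not recorded here):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: break the same analysis down per operator and per module.\n",
    "from fvcore.nn import flop_count_table\n",
    "print(FLOPs.by_operator())  # expected: all 1568 FLOPs under 'linear'\n",
    "print(FLOPs.by_module())    # FLOPs attributed to each submodule\n",
    "print(flop_count_table(FLOPs))"
   ]
  },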
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| name            | #elements or shape   |\n",
      "|:----------------|:---------------------|\n",
      "| model           | 69.8K                |\n",
      "| Linear          | 69.8K                |\n",
      "| Linear.weight   | (720, 96)            |\n",
      "| Linear.bias     | (720,)               |\n"
     ]
    }
   ],
   "source": [
    "class Model(nn.Module):\n",
    "    \"\"\"\n",
    "    NLinear-style model reduced to a single Linear layer\n",
    "    mapping seq_len to pred_len (shared across channels).\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        super(Model, self).__init__()\n",
    "        self.seq_len = 96\n",
    "        self.pred_len = 720\n",
    "        self.Linear = nn.Linear(self.seq_len, self.pred_len)\n",
    "        # Use this line if you want to visualize the weights\n",
    "        # self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))\n",
    "\n",
    "    def forward(self, x):\n",
    "        # x: [Batch, Input length, Channel]\n",
    "        x = self.Linear(x.permute(0, 2, 1)).permute(0, 2, 1)\n",
    "        return x  # [Batch, Output length, Channel]\n",
    "\n",
    "model = Model()\n",
    "print(parameter_count_table(model))"
   ]
  },
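  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The 69.8K in the table is just the single layer's weight and bias: 96 × 720 + 720 = 69,840 parameters. A one-line cross-check (an added sketch):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: confirm the table by summing parameter counts directly.\n",
    "assert sum(p.numel() for p in model.parameters()) == 96 * 720 + 720  # 69,840 -> 69.8K"
   ]
  },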
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\n计算FLOPs\\n'"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Compute FLOPs for the Linear model (reuses `tensor` from the cell above)\n",
    "result = model(tensor)\n",
    "# FLOPs = FlopCountAnalysis(model, tensor)\n",
    "# print(\"FLOPs:\", FLOPs.total())"
   ]
  },
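  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the commented-out analysis above were run, the single Linear should dominate: under fvcore's one-FLOP-per-multiply-accumulate convention it counts B × C × seq_len × pred_len = 8 × 7 × 96 × 720 = 3,870,720 FLOPs for this input. A sketch (output not recorded here):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: FLOP count for the Linear model on the same (8, 96, 7) input.\n",
    "flops = FlopCountAnalysis(model, torch.randn(8, 96, 7))\n",
    "print(\"FLOPs:\", flops.total())  # expected: 8 * 7 * 96 * 720 = 3,870,720"
   ]
  }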
 ],
 "metadata": {
  "interpreter": {
   "hash": "f57785bf53e86c458d31dd8512073d1ac6cae98f342ec9a1a9a8506681d63dcb"
  },
  "kernelspec": {
   "display_name": "Python 3.7.13 ('openmmlab')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}