@@ -19,18 +19,33 @@ def imshow(img):
19
19
cv2 .imshow (np .transpose (npimg , (1 , 2 , 0 )))
20
20
21
21
22
- def test (length = 32 ):
22
+ def validate (length = 32 ):
23
23
val_acc = val_loss = 0
24
- for i , data in enumerate (testDataset ):
24
+ for i , data in enumerate (validateDataset ):
25
25
image , label = data
26
26
out = fwd_pass (image .view (- 1 , 3 , heigh , width ).to (device ), label .to (device ))
27
27
val_acc += out [0 ]
28
28
val_loss += out [1 ]
29
- if i == length :
29
+ if i == length - 1 :
30
30
break
31
31
return val_acc / length , val_loss / length
32
32
33
33
34
+ def test ():
35
+ val_acc = val_loss = 0
36
+ print ("running test" )
37
+ t = tqdm (total = len (trainDataset )) # Initialise
38
+ for i , data in enumerate (testDataset ):
39
+ t .update (1 )
40
+ image , label = data
41
+ out = fwd_pass (image .view (- 1 , 3 , heigh , width ).to (device ), label .to (device ))
42
+ val_acc += out [0 ]
43
+ val_loss += out [1 ]
44
+ t .close ()
45
+ length = len (testDataset )
46
+ return val_acc / length , val_loss / length
47
+
48
+
34
49
def fwd_pass (images , labels , train = False ):
35
50
if train :
36
51
net .zero_grad ()
@@ -76,8 +91,8 @@ def train(net, epochs, startingEpoch):
76
91
77
92
acc , loss = fwd_pass (images , labels , train = True )
78
93
79
- if i % 100 == 99 :
80
- val_acc , val_loss = test ( )
94
+ if i % 200 == 199 :
95
+ val_acc , val_loss = validate ( 64 )
81
96
# print(val_acc, float(val_loss))
82
97
f .write (
83
98
f"{ MODEL_NAME } ,{ round (time .time (), 3 )} ,{ round (float (acc ), 3 )} , { round (float (loss ), 4 )} , { round (float (val_acc ), 3 )} , { round (float (val_loss ), 4 )} , { epoch } \n " )
@@ -90,7 +105,7 @@ def train(net, epochs, startingEpoch):
90
105
91
106
92
107
class globNr (object ):
93
- nr = 21
108
+ nr = 26
94
109
95
110
96
111
if __name__ == "__main__" :
@@ -111,8 +126,8 @@ class globNr(object):
111
126
split = 1
112
127
width = 60
113
128
heigh = 60
114
- epochs = 10
115
- batchSize = 30
129
+ epochs = 14
130
+ batchSize = 16
116
131
startEpoch = 0
117
132
drawing = False
118
133
@@ -123,21 +138,25 @@ class globNr(object):
123
138
])
124
139
125
140
trainTransform = transforms .Compose ([
126
- tf .RandomCrop ([ 400 , 400 ] ),
141
+ tf .myRandomCrop ( 380 , 480 ),
127
142
tf .Resize ((heigh , width )),
128
143
tf .ColorJitter (brightness = 0.5 , contrast = 0.5 ),
129
144
tf .ToTensor (),
130
145
tf .Normalize ((0.5 , 0.5 , 0.5 ), (0.5 , 0.5 , 0.5 ))
131
146
])
132
147
testLoader = myDataset (dataSetPath , split , test = True , transform = transform )
133
- trainLoader = myDataset (dataSetPath , split , test = False , transform = trainTransform )
134
- print ("test {}, train {}" .format (len (testLoader ), len (trainLoader )))
148
+ trainLoader = myDataset (dataSetPath , split , train = True , transform = trainTransform )
149
+ validateLoader = myDataset (dataSetPath , split , validation = True , transform = trainTransform )
150
+ print ("test {}, train {}, val {}" .format (len (testLoader ), len (trainLoader ), len (validateLoader )))
135
151
testDataset = torch .utils .data .DataLoader (testLoader ,
136
- batch_size = batchSize , shuffle = True )
152
+ batch_size = batchSize )
137
153
138
154
trainDataset = torch .utils .data .DataLoader (trainLoader ,
139
155
batch_size = batchSize , shuffle = True )
140
156
157
+ validateDataset = torch .utils .data .DataLoader (trainLoader ,
158
+ batch_size = batchSize , shuffle = True )
159
+
141
160
net = Net (width , heigh ).to (device )
142
161
optimizer = optim .Adam (net .parameters (), lr = 0.001 )
143
162
@@ -156,8 +175,12 @@ class globNr(object):
156
175
loss_function = nn .MSELoss ()
157
176
MODEL_NAME = f"model-{ int (time .time ())} " # gives a dynamic model name, to just help with things getting messy over time.
158
177
print ("Model name: " + MODEL_NAME )
159
- print ("test dataset size: " + str (len (testDataset )) + " train dataset size: " + str (len (trainDataset )))
178
+ print ("test dataset size: {}, train dataset size: {}, validation dataset size: {}" .format (len (testDataset ),
179
+ len (trainDataset ),
180
+ len (validateDataset )))
160
181
startEpoch = train (net , epochs , startEpoch )
182
+ acc , loss = test ()
183
+ print ("Test (average): acc: {}%, loss: {}" .format (acc * 100 , loss ))
161
184
state = {
162
185
'logFile' : logFile ,
163
186
'epoch' : epochs ,
0 commit comments