Skip to content

Commit 24cd901

Browse files
Sylwester DawidaSylwester Dawida
Sylwester Dawida
authored and
Sylwester Dawida
committed
mode 26 97 % ready
1 parent 0b9600b commit 24cd901

19 files changed

+495
-54
lines changed

python/.idea/workspace.xml

+26-28
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
90 Bytes
Binary file not shown.
685 Bytes
Binary file not shown.
Binary file not shown.

python/datasetLoader.py

+17-12
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,27 @@
66
import cv2
77
import numpy as np
88
import os
9+
import random
910

1011

1112
class myDataset(Dataset):
1213
"""hand symbols dataset."""
1314

14-
def __init__(self, dir, split, test=False, transform=None):
15+
def __init__(self, dir, split, test=False, train=False, validation=False, transform=None):
1516
"""
1617
dir - directory to dataset
1718
"""
1819
self.dir = dir
1920
self.transform = transform
2021
splitter = ""
21-
if (not test):
22-
splitter = '**/*[' + str(split + 1) + '-9].jpg'
22+
if test:
23+
splitter = '**/*0.jpg'
24+
elif train:
25+
splitter = '**/*[2-9].jpg'
26+
elif validation:
27+
splitter = '**/*1.jpg'
2328
else:
24-
splitter = '**/*[0-' + str(split) + '].jpg'
29+
raise Exception("chose one of train/test/validation")
2530
self.pathsList = glob.glob(self.dir + splitter, recursive=True)
2631

2732
def __len__(self):
@@ -125,13 +130,13 @@ def imshow(img):
125130
cv2.imshow("preview", img)
126131

127132

128-
transform = transforms.Compose(
129-
[tf.RandomCrop([400, 400]),
130-
tf.Resize((heigh, width)),
131-
tf.ColorJitter(brightness=0.5, contrast=0.5),
132-
tf.ToTensor(),
133-
tf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
134-
])
133+
transform = transforms.Compose([
134+
tf.myRandomCrop(380, 480),
135+
tf.Resize((heigh, width)),
136+
tf.ColorJitter(brightness=0.5, contrast=0.5),
137+
tf.ToTensor(),
138+
tf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
139+
])
135140
dir = "D:/DataSetNew/"
136141
testDataset = myDataset(dir, split, test=True, transform=transform)
137142
trainDataset = myDataset(dir, split, test=False, transform=transform)
@@ -147,7 +152,7 @@ def imshow(img):
147152
num_workers=4)
148153

149154
trainLoader = torch.utils.data.DataLoader(trainDataset,
150-
batch_size=1, shuffle=False,
155+
batch_size=1, shuffle=True,
151156
num_workers=4)
152157

153158
for data in trainLoader:

python/mainCuda.py

+36-13
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,33 @@ def imshow(img):
1919
cv2.imshow(np.transpose(npimg, (1, 2, 0)))
2020

2121

22-
def test(length=32):
22+
def validate(length=32):
2323
val_acc = val_loss = 0
24-
for i, data in enumerate(testDataset):
24+
for i, data in enumerate(validateDataset):
2525
image, label = data
2626
out = fwd_pass(image.view(-1, 3, heigh, width).to(device), label.to(device))
2727
val_acc += out[0]
2828
val_loss += out[1]
29-
if i == length:
29+
if i == length - 1:
3030
break
3131
return val_acc / length, val_loss / length
3232

3333

34+
def test():
35+
val_acc = val_loss = 0
36+
print("running test")
37+
t = tqdm(total=len(trainDataset)) # Initialise
38+
for i, data in enumerate(testDataset):
39+
t.update(1)
40+
image, label = data
41+
out = fwd_pass(image.view(-1, 3, heigh, width).to(device), label.to(device))
42+
val_acc += out[0]
43+
val_loss += out[1]
44+
t.close()
45+
length = len(testDataset)
46+
return val_acc / length, val_loss / length
47+
48+
3449
def fwd_pass(images, labels, train=False):
3550
if train:
3651
net.zero_grad()
@@ -76,8 +91,8 @@ def train(net, epochs, startingEpoch):
7691

7792
acc, loss = fwd_pass(images, labels, train=True)
7893

79-
if i % 100 == 99:
80-
val_acc, val_loss = test()
94+
if i % 200 == 199:
95+
val_acc, val_loss = validate(64)
8196
# print(val_acc, float(val_loss))
8297
f.write(
8398
f"{MODEL_NAME},{round(time.time(), 3)},{round(float(acc), 3)}, {round(float(loss), 4)}, {round(float(val_acc), 3)}, {round(float(val_loss), 4)}, {epoch}\n")
@@ -90,7 +105,7 @@ def train(net, epochs, startingEpoch):
90105

91106

92107
class globNr(object):
93-
nr = 21
108+
nr = 26
94109

95110

96111
if __name__ == "__main__":
@@ -111,8 +126,8 @@ class globNr(object):
111126
split = 1
112127
width = 60
113128
heigh = 60
114-
epochs = 10
115-
batchSize = 30
129+
epochs = 14
130+
batchSize = 16
116131
startEpoch = 0
117132
drawing = False
118133

@@ -123,21 +138,25 @@ class globNr(object):
123138
])
124139

125140
trainTransform = transforms.Compose([
126-
tf.RandomCrop([400, 400]),
141+
tf.myRandomCrop(380, 480),
127142
tf.Resize((heigh, width)),
128143
tf.ColorJitter(brightness=0.5, contrast=0.5),
129144
tf.ToTensor(),
130145
tf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
131146
])
132147
testLoader = myDataset(dataSetPath, split, test=True, transform=transform)
133-
trainLoader = myDataset(dataSetPath, split, test=False, transform=trainTransform)
134-
print("test {}, train {}".format(len(testLoader), len(trainLoader)))
148+
trainLoader = myDataset(dataSetPath, split, train=True, transform=trainTransform)
149+
validateLoader = myDataset(dataSetPath, split, validation=True, transform=trainTransform)
150+
print("test {}, train {}, val {}".format(len(testLoader), len(trainLoader), len(validateLoader)))
135151
testDataset = torch.utils.data.DataLoader(testLoader,
136-
batch_size=batchSize, shuffle=True)
152+
batch_size=batchSize)
137153

138154
trainDataset = torch.utils.data.DataLoader(trainLoader,
139155
batch_size=batchSize, shuffle=True)
140156

157+
validateDataset = torch.utils.data.DataLoader(trainLoader,
158+
batch_size=batchSize, shuffle=True)
159+
141160
net = Net(width, heigh).to(device)
142161
optimizer = optim.Adam(net.parameters(), lr=0.001)
143162

@@ -156,8 +175,12 @@ class globNr(object):
156175
loss_function = nn.MSELoss()
157176
MODEL_NAME = f"model-{int(time.time())}" # gives a dynamic model name, to just help with things getting messy over time.
158177
print("Model name: " + MODEL_NAME)
159-
print("test dataset size: " + str(len(testDataset)) + " train dataset size: " + str(len(trainDataset)))
178+
print("test dataset size: {}, train dataset size: {}, validation dataset size: {}".format(len(testDataset),
179+
len(trainDataset),
180+
len(validateDataset)))
160181
startEpoch = train(net, epochs, startEpoch)
182+
acc, loss = test()
183+
print("Test (average): acc: {}%, loss: {}".format(acc * 100, loss))
161184
state = {
162185
'logFile': logFile,
163186
'epoch': epochs,

0 commit comments

Comments
 (0)