-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy patheval.py
67 lines (50 loc) · 2.16 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import sys
import importlib
from data import Test_Dataset
from thop import profile
import torch
import time
from progress.bar import Bar
import os
from collections import OrderedDict
import cv2
from PIL import Image
from util import *
import numpy as np
import argparse
from base.metric import *
def _parse_args(argv=None):
    """Parse command-line options into a plain config dict.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        dict with keys ``data_path``, ``vals``, ``size``, ``pre_path`` and
        ``orig_size`` (always True). The whole dict is handed to Test_Dataset.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', default='../dataset/',
                        help='Dataset root directory (forwarded to Test_Dataset via config)')
    parser.add_argument('--vals', default='all',
                        help='Comma-separated test sets to evaluate, or "all"')
    # type=int so a value given on the command line has the same type as the default.
    parser.add_argument('--size', default=320, type=int,
                        help='Image size (forwarded to Test_Dataset via config)')
    parser.add_argument('--pre_path', default='./maps',
                        help='Root of saved predictions, read from <pre_path>/scrn/<dataset>/<name>.png')
    config = vars(parser.parse_args(argv))
    # NOTE(review): presumably makes Test_Dataset keep ground truth at its
    # original resolution — confirm against Test_Dataset.
    config['orig_size'] = True
    return config


def _dataset_names(spec):
    """Expand the --vals spec into a list of dataset names.

    'all' selects the five default test sets. Any other value is split on
    commas, with surrounding whitespace stripped and empty entries dropped.
    """
    if spec == 'all':
        return ['PASCAL-S', 'ECSSD', 'HKU-IS', 'DUTS-TE', 'DUT-OMRON']
    return [name.strip() for name in spec.split(',') if name.strip()]


def _evaluate_dataset(val, pred_dir, config):
    """Score every saved prediction of one test set against its ground truth.

    Args:
        val: Test-set name, passed to Test_Dataset.
        pred_dir: Directory holding ``<name>.png`` prediction maps.
        config: Option dict from ``_parse_args``.

    Returns:
        Whatever ``MetricRecorder.show(bit_num=3)`` returns; ``main`` unpacks
        it as ``(mae, (maxf, meanf, ...), sm, em, wfm)``.
    """
    test_set = Test_Dataset(name=val, config=config)
    total = test_set.size
    recorder = MetricRecorder(total)
    bar = Bar('Dataset {:10}:'.format(val), max=total)
    for j in range(total):
        _, gt, name = test_set.load_data(j)
        # Context manager closes the file; convert() loads the pixels first.
        with Image.open(os.path.join(pred_dir, name + '.png')) as img:
            pred_img = img.convert('L')
        # The first two axes of gt are taken as (height, width);
        # PIL's resize expects (width, height).
        height, width = gt.shape[:2]
        pred = np.array(pred_img.resize((width, height)))
        pred, gt = normalize_pil(pred, gt)
        recorder.update(pre=pred, gt=gt)
        # Set the suffix on this bar instance (not the Bar class), 1-based count.
        bar.suffix = '{}/{}'.format(j + 1, total)
        bar.next()
    # Emits the bar's trailing newline and restores the cursor if it was hidden.
    bar.finish()
    return recorder.show(bit_num=3)


def main(argv=None):
    """Evaluate saved prediction maps on each requested test set and print metrics.

    Args:
        argv: Optional argument list (defaults to ``sys.argv[1:]``), so the
            evaluation can also be driven programmatically.
    """
    config = _parse_args(argv)
    for val in _dataset_names(config['vals']):
        pred_dir = os.path.join(config['pre_path'], 'scrn', val)
        if not os.path.isdir(pred_dir):
            # Report rather than silently ignore a requested dataset.
            print('Skipping {}: no prediction directory {}'.format(val, pred_dir),
                  file=sys.stderr)
            continue
        mae, (maxf, meanf, *_), sm, em, wfm = _evaluate_dataset(val, pred_dir, config)
        # Include the dataset name: the progress bar goes to stderr, so a
        # redirected stdout would otherwise hold unlabeled metric lines.
        print('Dataset {:10}: Max-F: {:.3f}, Mean-F: {:.3f}, Fbw: {:.3f}, '
              'MAE: {:.3f}, SM: {:.3f}, EM: {:.3f}.'.format(
                  val, maxf, meanf, wfm, mae, sm, em))


if __name__ == "__main__":
    main()