-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathtest05_hamerly.jl
132 lines (99 loc) · 3.95 KB
/
test05_hamerly.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
module TestHamerly
using ParallelKMeans
using ParallelKMeans: chunk_initialize, double_argmax
using Test
using StableRNGs
using Distances
# Verifies the bounds that `chunk_initialize` stores in freshly created Hamerly containers.
@testset "initialize" begin
    # Four 2-d points stored column-wise; three of them double as the centroids.
    points = [1.0 2.0 4.0 6.0; 2.0 1.0 5.0 6.0]
    centers = [1.0 4.0 6.0; 2.0 5.0 6.0]
    n_features, n_points = size(points)

    # NOTE(review): the trailing `1` and the `nothing` argument are presumably the thread
    # count / chunk index and the (absent) weights -- confirm against the implementation.
    state = ParallelKMeans.create_containers(Hamerly(), points, 3, n_features, n_points, 1)
    chunk_initialize(
        Hamerly(), state, centers, points, nothing, Euclidean(), 1:n_points, 1
    )

    # The expected numbers equal the squared Euclidean distance from each point to its
    # second-nearest centroid (`lb`) and to its nearest centroid (`ub`).
    expected_lb = [18.0, 20.0, 5.0, 5.0]
    expected_ub = [0.0, 2.0, 0.0, 0.0]
    @test state.lb == expected_lb
    @test state.ub == expected_ub
end
# Pins the 4-tuple returned by `double_argmax` for a vector with a single largest entry.
@testset "double argmax" begin
    # NOTE(review): judging by this case the tuple is (index of the maximum, index of the
    # runner-up, maximum, runner-up value) -- confirm against the implementation.
    scores = [0.5, 0.0, 0.0]
    outcome = double_argmax(scores)
    @test outcome == (1, 2, 0.5, 0.0)
end
# Single-threaded Hamerly on seeded random data: once capped at four iterations, once
# allowed to run to convergence.
@testset "singlethread linear separation" begin
    # For an equal number of iterations the result should coincide with the Lloyd one.
    generator = StableRNG(2020)
    data = rand(generator, 3, 100)

    # `generator` is not advanced after sampling `data`: every fit receives its own copy,
    # so both runs start from the same random state.
    run_hamerly(iteration_cap) = kmeans(
        Hamerly(), data, 3;
        n_threads = 1, tol = 1e-10, max_iters = iteration_cap, verbose = false,
        rng = deepcopy(generator),
    )

    expected_cost = 14.133433380466027

    # Capped at four iterations: the run must report that it has not converged.
    truncated = run_hamerly(4)
    @test truncated.totalcost ≈ expected_cost
    @test !truncated.converged
    @test truncated.iterations == 4

    # With a generous cap the run converges after five iterations at the same cost.
    complete = run_hamerly(1000)
    @test complete.totalcost ≈ expected_cost
    @test complete.converged
    @test complete.iterations == 5
end
# Same scenario as the single-threaded case, but requesting `n_threads = 2`.
@testset "multithread linear separation quasi two threads" begin
    stream = StableRNG(2020)
    samples = rand(stream, 3, 100)
    # Snapshot of the generator state taken after the data has been sampled, so the
    # second run can restart from exactly the same point.
    snapshot = deepcopy(stream)

    # Keyword arguments common to both runs.
    shared = (n_threads = 2, tol = 1e-10, verbose = false)
    reference_cost = 14.133433380466027

    # Capped at four iterations: the run must report that it has not converged.
    short_run = kmeans(Hamerly(), samples, 3; shared..., max_iters = 4, rng = stream)
    @test short_run.totalcost ≈ reference_cost
    @test !short_run.converged
    @test short_run.iterations == 4

    # Uncapped (in practice): converges after five iterations at the same cost.
    long_run = kmeans(
        Hamerly(), samples, 3; shared..., max_iters = 1000, rng = deepcopy(snapshot)
    )
    @test long_run.totalcost ≈ reference_cost
    @test long_run.converged
    @test long_run.iterations == 5
end
# Hamerly on Float32 input: the total cost must stay Float32 and the outcome must be the
# same for one and for two threads.
@testset "Hamerly Float32 support" begin
    base_rng = StableRNG(2020)
    # Sample in Float64 first and convert afterwards, so the data are the Float32
    # roundings of the Float64 draws.
    data = Float32.(rand(base_rng, 3, 100))

    for thread_count in (1, 2)
        # Each run gets its own copy of the post-sampling generator state.
        fitted = kmeans(
            Hamerly(), data, 3;
            n_threads = thread_count, tol = 1e-6, verbose = false,
            rng = deepcopy(base_rng),
        )

        @test fitted.totalcost isa Float32
        @test fitted.totalcost ≈ 14.133433f0
        @test fitted.converged
        @test fitted.iterations == 5
    end
end
# Weighted clustering: Hamerly (one and two threads) must reproduce the Lloyd baseline.
@testset "Hamerly weights support" begin
    source = StableRNG(2020)
    data = rand(source, 3, 100)
    point_weights = rand(source, 100)

    # Keyword arguments common to all three fits; every fit also receives its own copy of
    # the generator state left after sampling the data and the weights.
    common = (weights = point_weights, tol = 1e-10, verbose = false)

    baseline = kmeans(Lloyd(), data, 10; common..., rng = deepcopy(source))

    # Default threading.
    plain = kmeans(Hamerly(), data, 10; common..., rng = deepcopy(source))
    @test plain.totalcost ≈ baseline.totalcost
    @test plain.converged
    @test plain.iterations == baseline.iterations

    # Two threads requested explicitly.
    threaded = kmeans(
        Hamerly(), data, 10; common..., n_threads = 2, rng = deepcopy(source)
    )
    @test threaded.totalcost ≈ baseline.totalcost
    @test threaded.converged
    @test threaded.iterations == baseline.iterations
end
# Exercises Hamerly with a non-Euclidean metric (Cityblock): first a hand-checkable 1-d
# case, then a comparison against Lloyd on seeded random data.
@testset "Hamerly metric support" begin
    rng = StableRNG(2020)
    # Three 1-d points (a 1×3 matrix).
    X = [1.0 2.0 4.0;]

    res = kmeans(Hamerly(), X, 2; tol = 1e-16, metric = Cityblock(), rng = rng)

    # Points 1.0 and 2.0 share the cluster centred at 1.5 and 4.0 forms its own cluster,
    # so the total Cityblock cost is |1 - 1.5| + |2 - 1.5| = 1. All of these values are
    # exactly representable, hence the exact comparisons.
    @test res.assignments == [2, 2, 1]
    @test res.centers == [4.0 1.5]
    @test res.totalcost == 1.0
    @test res.converged

    rng = StableRNG(2020)
    # Draw the data from the seeded generator rather than the global RNG, so the test
    # input is identical on every run and any failure is reproducible.
    X = rand(rng, 3, 100)
    # Remember the generator state after sampling so both algorithms start from it.
    rng_orig = deepcopy(rng)
    baseline = kmeans(Lloyd(), X, 2; tol = 1e-16, metric = Cityblock(), rng = rng)

    rng = deepcopy(rng_orig)
    res = kmeans(Hamerly(), X, 2; tol = 1e-16, metric = Cityblock(), rng = rng)

    # Hamerly must agree with the Lloyd baseline on cost, convergence flag and
    # iteration count.
    @test res.totalcost ≈ baseline.totalcost
    @test res.converged == baseline.converged
    @test res.iterations == baseline.iterations
end
end # module