Skip to content

Commit c2edd87

Browse files
committed
Verified metric support for Hamerly
1 parent 4f444a9 commit c2edd87

File tree

6 files changed

+49
-17
lines changed

6 files changed

+49
-17
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ ________________________________________________________________________________
6666
- Support for multi-theading implementation of K-Means clustering algorithm.
6767
- Kmeans++ initialization for faster and better convergence.
6868
- Implementation of all the variants of the K-Means algorithm.
69+
- Supported interface as an [MLJ](https://github.com/alan-turing-institute/MLJ.jl#available-models) model.
6970

7071
_________________________________________________________________________________________________________
7172

docs/src/index.md

+7-4
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,15 @@ pkg> add ParallelKMeans
5151
The few (and selected) brave ones can simply grab the current experimental features by simply adding the experimental branch to your development environment after invoking the package manager with `]`:
5252

5353
```julia
54-
dev git@github.com:PyDataBlog/ParallelKMeans.jl.git
54+
pkg> add ParallelKMeans#experimental
5555
```
5656

57-
Don't forget to checkout the experimental branch and you are good to go with bleeding edge features and breakages!
57+
You are good to go with bleeding edge features and breakages!
5858

59-
```bash
60-
git checkout experimental
59+
To revert to a stable version, you can simply run:
60+
61+
```julia
62+
pkg> free ParallelKMeans
6163
```
6264

6365
## Features
@@ -78,6 +80,7 @@ git checkout experimental
7880
- [X] Support of MLJ Random generation hyperparameter.
7981
- [ ] Support for other distance metrics supported by [Distances.jl](https://github.com/JuliaStats/Distances.jl#supported-distances).
8082
- [ ] Implementation of [Geometric methods to accelerate k-means algorithm](http://cs.baylor.edu/~hamerly/papers/sdm2016_rysavy_hamerly.pdf).
83+
- [ ] Support of MLJ Random generation hyperparameter.
8184
- [ ] Native support for tabular data inputs outside of MLJModels' interface.
8285
- [ ] Refactoring and finalization of API design.
8386
- [ ] GPU support.

src/elkan.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ function chunk_update_bounds(alg, containers, centroids, r, idx)
266266
stale = containers.stale
267267
labels = containers.labels
268268
T = eltype(centroids)
269-
269+
# TODO: Add metric support with multiple dispatch
270270
@inbounds for i in r
271271
for j in axes(centroids, 2)
272272
lb[j, i] = lb[j, i] > p[j] ? lb[j, i] + p[j] - T(2)*sqrt(abs(lb[j, i]*p[j])) : zero(T)

src/hamerly.jl

+30-8
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ function kmeans!(alg::Hamerly, containers, X, k, weights;
4343
collect_containers(alg, containers, n_threads)
4444

4545
J = sum(containers.ub)
46-
move_centers(alg, containers, centroids)
46+
move_centers(alg, containers, centroids, metric)
4747

4848
r1, r2, pr1, pr2 = double_argmax(p)
49-
@parallelize n_threads ncol chunk_update_bounds(alg, containers, r1, r2, pr1, pr2)
49+
@parallelize n_threads ncol chunk_update_bounds(alg, containers, r1, r2, pr1, pr2, metric)
5050

5151
if verbose
5252
# Show progress and terminate if J stops decreasing as specified by the tolerance level.
@@ -241,15 +241,14 @@ end
241241
Calculates new positions of centers and distance they have moved. Results are stored
242242
in `centroids` and `p` respectively.
243243
"""
244-
function move_centers(::Hamerly, containers, centroids)
244+
function move_centers(::Hamerly, containers, centroids, metric)
245245
centroids_new = containers.centroids_new[end]
246246
p = containers.p
247247
T = eltype(centroids)
248248

249249
@inbounds for i in axes(centroids, 2)
250-
d = zero(T)
250+
d = distance(metric, centroids, centroids_new, i, i)
251251
for j in axes(centroids, 1)
252-
d += (centroids[j, i] - centroids_new[j, i])^2
253252
centroids[j, i] = centroids_new[j, i]
254253
end
255254
p[i] = d
@@ -258,11 +257,12 @@ end
258257

259258

260259
"""
261-
chunk_update_bounds(alg::Hamerly, containers, r1, r2, pr1, pr2, r, idx)
260+
chunk_update_bounds(alg::Hamerly, containers, r1, r2, pr1, pr2, metric::Euclidean, r, idx)
262261
263-
Updates upper and lower bounds of point distance to the centers, with regard to the centers movement.
262+
Updates upper and lower bounds of point distance to the centers, with regard to the centers movement
263+
when metric is Euclidean.
264264
"""
265-
function chunk_update_bounds(alg::Hamerly, containers, r1, r2, pr1, pr2, r, idx)
265+
function chunk_update_bounds(alg::Hamerly, containers, r1, r2, pr1, pr2, metric::Euclidean, r, idx)
266266
p = containers.p
267267
ub = containers.ub
268268
lb = containers.lb
@@ -297,6 +297,28 @@ function chunk_update_bounds(alg::Hamerly, containers, r1, r2, pr1, pr2, r, idx)
297297
end
298298

299299

300+
"""
301+
chunk_update_bounds(alg::Hamerly, containers, r1, r2, pr1, pr2, metric::Metric, r, idx)
302+
303+
Updates upper and lower bounds of point distance to the centers, with regard to the centers movement
304+
when metric is Euclidean.
305+
"""
306+
function chunk_update_bounds(alg::Hamerly, containers, r1, r2, pr1, pr2, metric::Metric, r, idx)
307+
p = containers.p
308+
ub = containers.ub
309+
lb = containers.lb
310+
labels = containers.labels
311+
T = eltype(containers.ub)
312+
# Using notation from original paper, `u` is upper bound and `a` is `labels`, so
313+
# `u[i] -> u[i] + p[a[i]]`
314+
@inbounds for i in r
315+
label = labels[i]
316+
ub[i] += p[label]
317+
lb[i] -= r1 == label ? pr2 : pr1
318+
end
319+
end
320+
321+
300322
"""
301323
double_argmax(p)
302324

src/lloyd.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ function kmeans!(alg::Lloyd, containers, X, k, weights;
4646
end
4747

4848
J_previous = J
49-
niters += 1
49+
niters += 1 # TODO: Investigate the potential bug in number of iterations
5050
end
5151
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids, weights, metric)
5252
totalcost = sum(containers.sum_of_squares)

test/test05_hamerly.jl

+9-3
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,15 @@ end
118118
Random.seed!(2020)
119119
X = rand(3, 100)
120120

121-
res = kmeans(Hamerly(), X, 2, tol = 1e-16, metric=Cityblock())
122-
@test res.totalcost 62.04045252895372
123-
@test res.converged
121+
baseline = kmeans(Lloyd(), X, 2, tol = 1e-16, metric=Cityblock())
122+
123+
Random.seed!(2020)
124+
X = rand(3, 100)
125+
126+
res = kmeans(Hamerly(), X, 2; tol = 1e-16, metric=Cityblock())
127+
@test res.totalcost baseline.totalcost
128+
@test res.converged == baseline.converged
129+
@test res.iterations == baseline.iterations
124130
end
125131

126132
end # module

0 commit comments

Comments
 (0)