# -*- coding: utf-8 -*-
"""
k-means
=======
This example uses :math:`k`-means clustering for time series. Three variants of
the algorithm are available: standard
Euclidean :math:`k`-means, DBA-:math:`k`-means (for DTW Barycenter
Averaging [1])
and Soft-DTW :math:`k`-means [2].
In the figure below, each row corresponds to the result of a different
clustering. In a row, each sub-figure corresponds to a cluster.
It represents the set
of time series from the training set that were assigned to the considered
cluster (in black) as well as the barycenter of the cluster (in red).
A note on pre-processing
~~~~~~~~~~~~~~~~~~~~~~~~
In this example, time series are preprocessed using
`TimeSeriesScalerMeanVariance`. This scaler is such that each output time
series has zero mean and unit variance.
The assumption here is that the range of a given time series is uninformative
and one only wants to compare shapes in an amplitude-invariant manner (when
time series are multivariate, this also rescales all modalities such that there
will not be a single modality responsible for a large part of the variance).
This means that one cannot scale barycenters back to data range because each
time series is scaled independently and there is hence no such thing as an
overall data range.
[1] F. Petitjean, A. Ketterlin & P. Gancarski. A global averaging method \
for dynamic time warping, with applications to clustering. Pattern \
Recognition, Elsevier, 2011, Vol. 44, Num. 3, pp. 678-693
[2] M. Cuturi, M. Blondel "Soft-DTW: a Differentiable Loss Function for \
Time-Series," ICML 2017.
"""
# Author: Romain Tavenard
# License: BSD 3 clause
import numpy
import matplotlib.pyplot as plt
from tslearn.clustering import TimeSeriesKMeans
from tslearn.datasets import CachedDatasets
from tslearn.preprocessing import TimeSeriesScalerMeanVariance, \
    TimeSeriesResampler
seed = 0
numpy.random.seed(seed)
X_train, y_train, X_test, y_test = CachedDatasets().load_dataset("Trace")
X_train = X_train[y_train < 4] # Keep first 3 classes
numpy.random.shuffle(X_train)
# Keep only 50 time series
X_train = TimeSeriesScalerMeanVariance().fit_transform(X_train[:50])
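# Sanity check (a minimal sketch, not part of the original example): after
# TimeSeriesScalerMeanVariance, each series should have zero mean and unit
# variance along its time axis.
assert numpy.allclose(X_train.mean(axis=1), 0., atol=1e-6)
assert numpy.allclose(X_train.std(axis=1), 1., atol=1e-6)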
# Make time series shorter
X_train = TimeSeriesResampler(sz=40).fit_transform(X_train)
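# X_train should now have shape (n_ts, sz, d) = (50, 40, 1): 50 univariate
# series, each resampled to length 40.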
sz = X_train.shape[1]
# Euclidean k-means
print("Euclidean k-means")
km = TimeSeriesKMeans(n_clusters=3, verbose=True, random_state=seed)
y_pred = km.fit_predict(X_train)
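# Optional inspection (not in the original script): how many series ended up
# in each of the 3 Euclidean clusters.
print("Cluster sizes (Euclidean):", numpy.bincount(y_pred, minlength=3))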
plt.figure()
for yi in range(3):
    plt.subplot(3, 3, yi + 1)
    for xx in X_train[y_pred == yi]:
        plt.plot(xx.ravel(), "k-", alpha=.2)
    plt.plot(km.cluster_centers_[yi].ravel(), "r-")
    plt.xlim(0, sz)
    plt.ylim(-4, 4)
    plt.text(0.55, 0.85, 'Cluster %d' % (yi + 1),
             transform=plt.gca().transAxes)
    if yi == 1:
        plt.title("Euclidean $k$-means")
# DBA-k-means
print("DBA k-means")
dba_km = TimeSeriesKMeans(n_clusters=3,
                          n_init=2,
                          metric="dtw",
                          verbose=True,
                          max_iter_barycenter=10,
                          random_state=seed)
y_pred = dba_km.fit_predict(X_train)
for yi in range(3):
    plt.subplot(3, 3, 4 + yi)
    for xx in X_train[y_pred == yi]:
        plt.plot(xx.ravel(), "k-", alpha=.2)
    plt.plot(dba_km.cluster_centers_[yi].ravel(), "r-")
    plt.xlim(0, sz)
    plt.ylim(-4, 4)
    plt.text(0.55, 0.85, 'Cluster %d' % (yi + 1),
             transform=plt.gca().transAxes)
    if yi == 1:
        plt.title("DBA $k$-means")
# Soft-DTW k-means
print("Soft-DTW k-means")
sdtw_km = TimeSeriesKMeans(n_clusters=3,
                           metric="softdtw",
                           metric_params={"gamma": .01},
                           verbose=True,
                           random_state=seed)
y_pred = sdtw_km.fit_predict(X_train)
for yi in range(3):
    plt.subplot(3, 3, 7 + yi)
    for xx in X_train[y_pred == yi]:
        plt.plot(xx.ravel(), "k-", alpha=.2)
    plt.plot(sdtw_km.cluster_centers_[yi].ravel(), "r-")
    plt.xlim(0, sz)
    plt.ylim(-4, 4)
    plt.text(0.55, 0.85, 'Cluster %d' % (yi + 1),
             transform=plt.gca().transAxes)
    if yi == 1:
        plt.title("Soft-DTW $k$-means")
plt.tight_layout()
plt.show()