Skip to content

Commit b6f586b

Browse files
committed
Added Sparse Cosine distance kernel
1 parent 9fac61a commit b6f586b

File tree

7 files changed

+197
-18
lines changed

7 files changed

+197
-18
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
- Rename boosting `estimators` param to `epochs`
1111
- Neural net-based learners can now train for 0 epochs
1212
- Rename Labeled `stratify()` to `stratifyByLabel()`
13+
- Added Sparse Cosine distance kernel
14+
- Cosine distance now optimized for dense vectors
1315

1416
- 1.2.1
1517
- Refactor stratified methods on Labeled dataset

docs/kernels/distance/cosine.md

-3
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@ $$
77
{\displaystyle {\text{Cosine}}=1 - {\mathbf {A} \cdot \mathbf {B} \over \|\mathbf {A} \|\|\mathbf {B} \|}=1 - {\frac {\sum \limits _{i=1}^{n}{A_{i}B_{i}}}{{\sqrt {\sum \limits _{i=1}^{n}{A_{i}^{2}}}}{\sqrt {\sum \limits _{i=1}^{n}{B_{i}^{2}}}}}}}
88
$$
99

10-
!!! note
11-
This distance kernel is optimized for sparse (mainly zeros) coordinate vectors.
12-
1310
**Data Type Compatibility:** Continuous
1411

1512
## Parameters
+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
<span style="float:right;"><a href="https://github.com/RubixML/ML/blob/master/src/Kernels/Distance/SparseCosine.php">[source]</a></span>
2+
3+
# Sparse Cosine
4+
A version of the Cosine distance kernel that is specifically optimized for computing sparse vectors.
5+
6+
**Data Type Compatibility:** Continuous
7+
8+
## Parameters
9+
This kernel does not have any parameters.
10+
11+
## Example
12+
```php
13+
use Rubix\ML\Kernels\Distance\SparseCosine;
14+
15+
$kernel = new SparseCosine();
16+
```

mkdocs.yml

+1
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ nav:
204204
- Manhattan: kernels/distance/manhattan.md
205205
- Minkowski: kernels/distance/minkowski.md
206206
- Safe Euclidean: kernels/distance/safe-euclidean.md
207+
- Sparse Cosine: kernels/distance/sparse-cosine.md
207208
- SVM:
208209
- Linear: kernels/svm/linear.md
209210
- Polynomial: kernels/svm/polynomial.md

src/Kernels/Distance/Cosine.php

+3-15
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
* satisfy the positive semi-definite condition, therefore the Cosine distance
1717
* is a number between 0 and 2.
1818
*
19-
* > **Note:** This distance kernel is optimized for sparse (mainly zeros) coordinate vectors.
20-
*
2119
* @category Machine Learning
2220
* @package Rubix/ML
2321
* @author Andrew DalPino
@@ -54,20 +52,10 @@ public function compute(array $a, array $b) : float
5452
foreach ($a as $i => $valueA) {
5553
$valueB = $b[$i];
5654

57-
if ($valueA != 0 and $valueB != 0) {
58-
$sigma += $valueA * $valueB;
59-
60-
$ssA += $valueA ** 2;
61-
$ssB += $valueB ** 2;
62-
} else {
63-
if ($valueA != 0) {
64-
$ssA += $valueA ** 2;
65-
}
55+
$sigma += $valueA * $valueB;
6656

67-
if ($valueB != 0) {
68-
$ssB += $valueB ** 2;
69-
}
70-
}
57+
$ssA += $valueA ** 2;
58+
$ssB += $valueB ** 2;
7159
}
7260

7361
if ($ssA === 0.0 and $ssB === 0.0) {

src/Kernels/Distance/SparseCosine.php

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
<?php
2+
3+
namespace Rubix\ML\Kernels\Distance;
4+
5+
use Rubix\ML\DataType;
6+
7+
/**
8+
* Sparse Cosine
9+
*
10+
* A version of the Cosine distance kernel that is specifically optimized for computing sparse vectors.
11+
*
12+
* @category Machine Learning
13+
* @package Rubix/ML
14+
* @author Andrew DalPino
15+
*/
16+
class SparseCosine implements Distance
17+
{
18+
/**
19+
* Return the data types that this kernel is compatible with.
20+
*
21+
* @internal
22+
*
23+
* @return list<\Rubix\ML\DataType>
24+
*/
25+
public function compatibility() : array
26+
{
27+
return [
28+
DataType::continuous(),
29+
];
30+
}
31+
32+
/**
33+
* Compute the distance between two vectors.
34+
*
35+
* @internal
36+
*
37+
* @param list<int|float> $a
38+
* @param list<int|float> $b
39+
* @return float
40+
*/
41+
public function compute(array $a, array $b) : float
42+
{
43+
$sigma = $ssA = $ssB = 0.0;
44+
45+
foreach ($a as $i => $valueA) {
46+
$valueB = $b[$i];
47+
48+
if ($valueA != 0 and $valueB != 0) {
49+
$sigma += $valueA * $valueB;
50+
51+
$ssA += $valueA ** 2;
52+
$ssB += $valueB ** 2;
53+
} else {
54+
if ($valueA != 0) {
55+
$ssA += $valueA ** 2;
56+
}
57+
58+
if ($valueB != 0) {
59+
$ssB += $valueB ** 2;
60+
}
61+
}
62+
}
63+
64+
if ($ssA === 0.0 and $ssB === 0.0) {
65+
return 0.0;
66+
}
67+
68+
if ($ssA === 0.0 or $ssB === 0.0) {
69+
return 2.0;
70+
}
71+
72+
return 1.0 - ($sigma / sqrt($ssA * $ssB));
73+
}
74+
75+
/**
76+
* Return the string representation of the object.
77+
*
78+
* @internal
79+
*
80+
* @return string
81+
*/
82+
public function __toString() : string
83+
{
84+
return 'Sparse Cosine';
85+
}
86+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
<?php
2+
3+
namespace Rubix\ML\Tests\Kernels\Distance;
4+
5+
use Rubix\ML\Kernels\Distance\SparseCosine;
6+
use Rubix\ML\Kernels\Distance\Distance;
7+
use PHPUnit\Framework\TestCase;
8+
use Generator;
9+
10+
/**
11+
* @group Distances
12+
* @covers \Rubix\ML\Kernels\Distance\Cosine
13+
*/
14+
class SparseCosineTest extends TestCase
15+
{
16+
/**
17+
* @var \Rubix\ML\Kernels\Distance\SparseCosine
18+
*/
19+
protected $kernel;
20+
21+
/**
22+
* @before
23+
*/
24+
protected function setUp() : void
25+
{
26+
$this->kernel = new SparseCosine();
27+
}
28+
29+
/**
30+
* @test
31+
*/
32+
public function build() : void
33+
{
34+
$this->assertInstanceOf(SparseCosine::class, $this->kernel);
35+
$this->assertInstanceOf(Distance::class, $this->kernel);
36+
}
37+
38+
/**
39+
* @test
40+
* @dataProvider computeProvider
41+
*
42+
* @param (int|float)[] $a
43+
* @param (int|float)[] $b
44+
* @param float $expected
45+
*/
46+
public function compute(array $a, array $b, float $expected) : void
47+
{
48+
$distance = $this->kernel->compute($a, $b);
49+
50+
$this->assertGreaterThanOrEqual(0.0, $distance);
51+
$this->assertEquals($expected, $distance);
52+
}
53+
54+
/**
55+
* @return \Generator<array>
56+
*/
57+
public function computeProvider() : Generator
58+
{
59+
yield [
60+
[2, 1, 4, 0], [-2, 1, 8, -2],
61+
0.2593263058537443,
62+
];
63+
64+
yield [
65+
[7.4, -2.5], [0.01, -1],
66+
0.6704765571747832,
67+
];
68+
69+
yield [
70+
[1000, -2000, 3000], [1000, -2000, 3000],
71+
0.0,
72+
];
73+
74+
yield [
75+
[1000, -2000, 3000], [-1000, 2000, -3000],
76+
2.0,
77+
];
78+
79+
yield [
80+
[1.0, 2.0, 3.0], [0.0, 0.0, 0.0],
81+
2.0,
82+
];
83+
84+
yield [
85+
[0.0, 0.0, 0.0], [0.0, 0.0, 0.0],
86+
0.0,
87+
];
88+
}
89+
}

0 commit comments

Comments
 (0)