Skip to content

Commit 917f9cc

Browse files
committed
Merge branch '2.3' into export-network-graph
2 parents 3a3d497 + 6da4295 commit 917f9cc

File tree

8 files changed

+434
-1
lines changed

8 files changed

+434
-1
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
- 2.3.0
2+
- Added BM25 Transformer
23
- Add `dropFeature()` method to the dataset object API
34
- Add neural network architecture visualization via GraphViz
45

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
<?php
2+
3+
namespace Rubix\ML\Benchmarks\Transformers;
4+
5+
use Tensor\Matrix;
6+
use Rubix\ML\Datasets\Unlabeled;
7+
use Rubix\ML\Transformers\BM25Transformer;
8+
9+
/**
10+
* @Groups({"Transformers"})
11+
* @BeforeMethods({"setUp"})
12+
*/
13+
class BM25TransformerBench
14+
{
15+
protected const NUM_SAMPLES = 10000;
16+
17+
/**
18+
* @var \Rubix\ML\Datasets\Unlabeled
19+
*/
20+
protected $dataset;
21+
22+
/**
23+
* @var \Rubix\ML\Transformers\BM25Transformer
24+
*/
25+
protected $transformer;
26+
27+
/**
28+
* @var array<array<mixed>>
29+
*/
30+
protected $aSamples;
31+
32+
/**
33+
* @var array<array<mixed>>
34+
*/
35+
protected $bSamples;
36+
37+
public function setUp() : void
38+
{
39+
$mask = Matrix::rand(self::NUM_SAMPLES, 100)
40+
->greater(0.8);
41+
42+
$samples = Matrix::gaussian(self::NUM_SAMPLES, 100)
43+
->multiply($mask)
44+
->asArray();
45+
46+
$this->dataset = Unlabeled::quick($samples);
47+
48+
$this->transformer = new BM25Transformer();
49+
}
50+
51+
/**
52+
* @Subject
53+
* @Iterations(3)
54+
* @OutputTimeUnit("milliseconds", precision=3)
55+
*/
56+
public function apply() : void
57+
{
58+
$this->dataset->apply($this->transformer);
59+
}
60+
}

docs/preprocessing.md

+1
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ The library provides a number of transformers for Natural Language Processing (N
110110

111111
| Transformer | Supervised | [Stateful](transformers/api.md#stateful) | [Elastic](transformers/api.md#elastic) |
112112
|---|---|---|---|
113+
| [BM25 Transformer](transformers/bm25-transformer.md) | |||
113114
| [Regex Filter](transformers/regex-filter.md) | | | |
114115
| [Text Normalizer](transformers/text-normalizer.md) | | | |
115116
| [Multibyte Text Normalizer](transformers/multibyte-text-normalizer.md) | | | |

docs/transformers/bm25-transformer.md

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
<span style="float:right;"><a href="https://github.com/RubixML/ML/blob/master/src/Transformers/BM25Transformer.php">[source]</a></span>
2+
3+
# BM25 Transformer
4+
BM25 is a sublinear term weighting scheme that takes term frequency (TF), document frequency (DF), and document length into account. It is similar to [TF-IDF](tf-idf-transformer.md) but with variable sublinearity and the addition of document length normalization.
5+
6+
> **Note:** BM25 Transformer assumes that its inputs are token frequency vectors such as those created by [Word Count Vectorizer](word-count-vectorizer.md).
7+
8+
**Interfaces:** [Transformer](api.md#transformer), [Stateful](api.md#stateful), [Elastic](api.md#elastic)
9+
10+
**Data Type Compatibility:** Continuous only
11+
12+
## Parameters
13+
| # | Param | Default | Type | Description |
14+
|---|---|---|---|---|
15+
| 1 | dampening | 1.2 | float | The term frequency (TF) dampening factor i.e. the `K1` parameter in the formula. Lower values will cause the TF to saturate quicker. |
16+
| 2 | normalization | 0.75 | float | The importance of document length in normalizing the term frequency i.e. the `b` parameter in the formula. |
17+
18+
## Example
19+
```php
20+
use Rubix\ML\Transformers\BM25Transformer;
21+
22+
$transformer = new BM25Transformer(1.2, 0.75);
23+
```
24+
25+
## Additional Methods
26+
Return the document frequencies calculated during fitting:
27+
```php
28+
public dfs() : ?array
29+
```
30+
31+
Return the average number of tokens per document:
32+
```php
33+
public averageDocumentLength() : ?float
34+
```
35+
36+
### References
37+
>- S. Robertson et al. (2009). The Probabilistic Relevance Framework: BM25 and Beyond.
38+
>- K. Sparck Jones et al. (2000). A probabilistic model of information retrieval: development and comparative experiments.

mkdocs.yml

+1
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ nav:
139139
- KNN Imputer: transformers/knn-imputer.md
140140
- Missing Data Imputer: transformers/missing-data-imputer.md
141141
- Natural Language:
142+
- BM25 Transformer: transformers/bm25-transformer.md
142143
- Regex Filter: transformers/regex-filter.md
143144
- Text Normalizer: transformers/text-normalizer.md
144145
- Multibyte Text Normalizer: transformers/multibyte-text-normalizer.md

src/Transformers/BM25Transformer.php

+241
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
<?php
2+
3+
namespace Rubix\ML\Transformers;
4+
5+
use Rubix\ML\DataType;
6+
use Rubix\ML\Datasets\Dataset;
7+
use Rubix\ML\Specifications\SamplesAreCompatibleWithTransformer;
8+
use Rubix\ML\Exceptions\InvalidArgumentException;
9+
use Rubix\ML\Exceptions\RuntimeException;
10+
11+
use function array_fill;
12+
use function array_sum;
13+
use function log;
14+
15+
/**
16+
* BM25 Transformer
17+
*
18+
* BM25 is a sublinear term weighting scheme that takes term frequency (TF), document frequency (DF),
19+
* and document length into account. It is similar to [TF-IDF](tf-idf-transformer.md) but with variable
20+
* sublinearity and the addition of document length normalization.
21+
*
22+
* > **Note**: BM25 Transformer assumes that its inputs are made up of token frequency
23+
* vectors such as those created by the Word Count or Token Hashing Vectorizer.
24+
*
25+
* References:
26+
* [1] S. Robertson et al. (2009). The Probabilistic Relevance Framework: BM25 and Beyond.
27+
* [2] K. Sparck Jones et al. (2000). A probabilistic model of information retrieval:
28+
* development and comparative experiments.
29+
*
30+
* @category Machine Learning
31+
* @package Rubix/ML
32+
* @author Andrew DalPino
33+
*/
34+
class BM25Transformer implements Transformer, Stateful, Elastic
35+
{
36+
/**
37+
* The term frequency (TF) dampening factor i.e. the `K1` parameter in the formula.
38+
* Lower values will cause the TF to saturate quicker.
39+
*
40+
* @var float
41+
*/
42+
protected float $dampening;
43+
44+
/**
45+
* The importance of document length in normalizing the term frequency i.e. the `b`
46+
* parameter in the formula
47+
*
48+
* @var float
49+
*/
50+
protected float $normalization;
51+
52+
/**
53+
* The document frequencies of each word i.e. the number of times a word appeared in
54+
* a document given the entire corpus.
55+
*
56+
* @var int[]|null
57+
*/
58+
protected ?array $dfs = null;
59+
60+
/**
61+
* The inverse document frequency values for each feature column.
62+
*
63+
* @var float[]|null
64+
*/
65+
protected ?array $idfs = null;
66+
67+
/**
68+
* The number of tokens fitted so far.
69+
*
70+
* @var int|null
71+
*/
72+
protected ?int $totalTokens;
73+
74+
/**
75+
* The number of documents (samples) that have been fitted so far.
76+
*
77+
* @var int|null
78+
*/
79+
protected ?int $n;
80+
81+
/**
82+
* The average token count per document.
83+
*
84+
* @var float|null
85+
*/
86+
protected ?float $averageDocumentLength;
87+
88+
/**
89+
* @param float $dampening
90+
* @param float $normalization
91+
* @throws \Rubix\ML\Exceptions\InvalidArgumentException
92+
*/
93+
public function __construct(float $dampening = 1.2, float $normalization = 0.75)
94+
{
95+
if ($dampening < 0.0) {
96+
throw new InvalidArgumentException('Dampening must be greater'
97+
. " than 0, $dampening given.");
98+
}
99+
100+
if ($normalization < 0.0 or $normalization > 1.0) {
101+
throw new InvalidArgumentException('Normalization must be between'
102+
. " 0 and 1, $normalization given.");
103+
}
104+
105+
$this->dampening = $dampening;
106+
$this->normalization = $normalization;
107+
}
108+
109+
/**
110+
* Return the data types that this transformer is compatible with.
111+
*
112+
* @return \Rubix\ML\DataType[]
113+
*/
114+
public function compatibility() : array
115+
{
116+
return [
117+
DataType::continuous(),
118+
];
119+
}
120+
121+
/**
122+
* Is the transformer fitted?
123+
*
124+
* @return bool
125+
*/
126+
public function fitted() : bool
127+
{
128+
return $this->idfs and $this->averageDocumentLength;
129+
}
130+
131+
/**
132+
* Return the document frequencies calculated during fitting.
133+
*
134+
* @return int[]|null
135+
*/
136+
public function dfs() : ?array
137+
{
138+
return $this->dfs;
139+
}
140+
141+
/**
142+
* Return the average number of tokens per document.
143+
*
144+
* @return float|null
145+
*/
146+
public function averageDocumentLength() : ?float
147+
{
148+
return $this->averageDocumentLength;
149+
}
150+
151+
/**
152+
* Fit the transformer to a dataset.
153+
*
154+
* @param \Rubix\ML\Datasets\Dataset $dataset
155+
*/
156+
public function fit(Dataset $dataset) : void
157+
{
158+
$this->dfs = array_fill(0, $dataset->numFeatures(), 1);
159+
$this->totalTokens = 0;
160+
$this->n = 1;
161+
162+
$this->update($dataset);
163+
}
164+
165+
/**
166+
* Update the fitting of the transformer.
167+
*
168+
* @param \Rubix\ML\Datasets\Dataset $dataset
169+
* @throws \Rubix\ML\Exceptions\InvalidArgumentException
170+
*/
171+
public function update(Dataset $dataset) : void
172+
{
173+
SamplesAreCompatibleWithTransformer::with($dataset, $this)->check();
174+
175+
if ($this->dfs === null or $this->n === null) {
176+
$this->fit($dataset);
177+
178+
return;
179+
}
180+
181+
foreach ($dataset->samples() as $sample) {
182+
foreach ($sample as $column => $tf) {
183+
if ($tf > 0) {
184+
++$this->dfs[$column];
185+
186+
$this->totalTokens += $tf;
187+
}
188+
}
189+
}
190+
191+
$this->n += $dataset->numSamples();
192+
193+
$this->averageDocumentLength = $this->totalTokens / $this->n;
194+
195+
$idfs = [];
196+
197+
foreach ($this->dfs as $df) {
198+
$idfs[] = log(1.0 + ($this->n - $df + 0.5) / ($df + 0.5));
199+
}
200+
201+
$this->idfs = $idfs;
202+
}
203+
204+
/**
205+
* Transform the dataset in place.
206+
*
207+
* @param array<array<mixed>> $samples
208+
* @throws \Rubix\ML\Exceptions\RuntimeException
209+
*/
210+
public function transform(array &$samples) : void
211+
{
212+
if ($this->idfs === null or $this->averageDocumentLength === null) {
213+
throw new RuntimeException('Transformer has not been fitted.');
214+
}
215+
216+
foreach ($samples as &$sample) {
217+
$delta = array_sum($sample) / $this->averageDocumentLength;
218+
219+
$delta = 1.0 - $this->normalization + $this->normalization * $delta;
220+
221+
$delta *= $this->dampening;
222+
223+
foreach ($sample as $column => &$tf) {
224+
if ($tf > 0) {
225+
$tf /= $tf + $delta;
226+
$tf *= $this->idfs[$column];
227+
}
228+
}
229+
}
230+
}
231+
232+
/**
233+
* Return the string representation of the object.
234+
*
235+
* @return string
236+
*/
237+
public function __toString() : string
238+
{
239+
return "BM25 Transformer (dampening: {$this->dampening}, normalization: {$this->normalization})";
240+
}
241+
}

src/Transformers/TfIdfTransformer.php

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
* and is offset by the frequency of the word in the corpus.
2525
*
2626
* > **Note**: TF-IDF Transformer assumes that its input is made up of term frequency
27-
* vectors such as those created by Word Count Vectorizer.
27+
* vectors such as those created by Word Count or Token Hashing Vectorizer.
2828
*
2929
* References:
3030
* [1] S. Robertson. (2003). Understanding Inverse Document Frequency: On theoretical

0 commit comments

Comments
 (0)