
Commit a93b763

Implement Swish Beta
1 parent e3e9181 commit a93b763

7 files changed: +65 −41 lines changed


CHANGELOG.md

+1-1
@@ -2,7 +2,7 @@
 - Switch back to original fork of Tensor
 - Added `maxBins` hyper-parameter to CART-based learners
 - Added stream Deduplicator extractor
-- Added SiLU activation function from Extras package
+- Added the Swish activation function
 
 - 1.2.2
 - Allow empty dataset objects in `stack()`

docs/neural-network/activation-functions/silu.md

-18
This file was deleted.

docs/neural-network/activation-functions/swish.md

+22

@@ -0,0 +1,22 @@
+<span style="float:right;"><a href="https://github.com/RubixML/ML/blob/master/src/NeuralNet/ActivationFunctions/Swish.php">[source]</a></span>
+
+# Swish
+Swish is a smooth and non-monotonic rectified activation function. The inputs are weighted by the [Sigmoid](sigmoid.md) activation function acting as a self-gating mechanism. In addition, the `beta` parameter allows you to adjust the gate such that you can interpolate between the ReLU function and the linear function as `beta` goes from 0 to infinity.
+
+## Parameters
+| # | Name | Default | Type | Description |
+|---|---|---|---|---|
+| 1 | beta | 1.0 | float | The parameter that adjusts the slope of the sigmoid gating mechanism. |
+
+## Example
+```php
+use Rubix\ML\NeuralNet\ActivationFunctions\Swish;
+
+$activationFunction = new Swish(1.0);
+```
+
+### References
+>- S. Elwing et al. (2017). Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning.
+>- P. Ramachandran et al. (2017). Swish: A Self-gated Activation Function.
+>- P. Ramachandran et al. (2017). Searching for Activation Functions.
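
The docs entry above boils down to a single formula: swish(x) = x · sigmoid(beta · x). As a quick orientation for readers of this diff, here is a standalone scalar sketch (not part of the commit; the `swish()` helper and its scalar signature are illustrative assumptions, since the real class operates on `Tensor\Matrix` objects):

```php
<?php

// Standalone illustration (not part of this commit) of the formula the new
// Swish class computes: swish(x) = x * sigmoid(beta * x).
function swish(float $x, float $beta = 1.0) : float
{
    return $x * (1.0 / (1.0 + exp(-$beta * $x)));
}

echo swish(2.0, 1.0) . PHP_EOL;   // ~1.762 — smooth gate, slightly below ReLU's 2.0
echo swish(2.0, 0.0) . PHP_EOL;   // 1.0 — at beta = 0 the output is x / 2 (scaled linear)
echo swish(-2.0, 10.0) . PHP_EOL; // ~0 — large beta suppresses negative inputs like ReLU
```

At `beta = 0` the gate is stuck at 0.5, giving the scaled linear function x / 2, and as `beta` grows the gate saturates so the curve approaches ReLU, which is the interpolation described in the paragraph above.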

mkdocs.yml

+1-1
@@ -166,10 +166,10 @@ nav:
 - ReLU: neural-network/activation-functions/relu.md
 - SELU: neural-network/activation-functions/selu.md
 - Sigmoid: neural-network/activation-functions/sigmoid.md
-- SiLU: neural-network/activation-functions/silu.md
 - Softmax: neural-network/activation-functions/softmax.md
 - Soft Plus: neural-network/activation-functions/soft-plus.md
 - Soft Sign: neural-network/activation-functions/softsign.md
+- Swish: neural-network/activation-functions/swish.md
 - Thresholded ReLU: neural-network/activation-functions/thresholded-relu.md
 - Cost Functions:
 - Cross Entropy: neural-network/cost-functions/cross-entropy.md

src/NeuralNet/ActivationFunctions/SiLU.php renamed to src/NeuralNet/ActivationFunctions/Swish.php

+30-10
@@ -3,34 +3,53 @@
 namespace Rubix\ML\NeuralNet\ActivationFunctions;
 
 use Tensor\Matrix;
+use InvalidArgumentException;
 
 /**
- * SiLU
+ * Swish
  *
- * *Sigmoid-weighted Linear Unit* (SiLU) is a smooth and non-monotonic rectified activation function. The inputs
- * are weighted by the [Sigmoid](sigmoid.md) activation function acting as a self-gating mechanism. In addition,
- * an inherent global minimum functions as an implicit regularizer.
+ * Swish is a smooth and non-monotonic rectified activation function. The inputs are weighted by the [Sigmoid](sigmoid.md)
+ * activation function acting as a self-gating mechanism. In addition, the `beta` parameter allows you to adjust the gate
+ * such that you can interpolate between the ReLU function and the linear function as `beta` goes from 0 to infinity.
  *
  * References:
  * [1] S. Elwing et al. (2017). Sigmoid-Weighted Linear Units for Neural Network Function
  * Approximation in Reinforcement Learning.
- * [2] P. Ramachandran er al. (2017). Swish: A Self-gated Activation Function.
+ * [2] P. Ramachandran et al. (2017). Swish: A Self-gated Activation Function.
+ * [3] P. Ramachandran et al. (2017). Searching for Activation Functions.
  *
  * @category Machine Learning
  * @package Rubix/ML
  * @author Andrew DalPino
  */
-class SiLU implements ActivationFunction
+class Swish implements ActivationFunction
 {
+    /**
+     * The parameter that adjusts the slope of the sigmoid gating mechanism.
+     *
+     * @var float
+     */
+    protected float $beta;
+
     /**
      * The sigmoid activation function.
      *
      * @var \Rubix\ML\NeuralNet\ActivationFunctions\Sigmoid
      */
-    protected $sigmoid;
+    protected \Rubix\ML\NeuralNet\ActivationFunctions\Sigmoid $sigmoid;
 
-    public function __construct()
+    /**
+     * @param float $beta
+     * @throws \InvalidArgumentException
+     */
+    public function __construct(float $beta = 1.0)
     {
+        if ($beta < 0.0) {
+            throw new InvalidArgumentException('Beta must be greater than'
+                . " 0, $beta given.");
+        }
+
+        $this->beta = $beta;
         $this->sigmoid = new Sigmoid();
     }
 
@@ -42,7 +61,8 @@ public function __construct()
      */
     public function compute(Matrix $z) : Matrix
    {
-        return $this->sigmoid->compute($z)->multiply($z);
+        return $this->sigmoid->compute($z->multiply($this->beta))
+            ->multiply($z);
     }
 
     /**
@@ -68,6 +88,6 @@ public function differentiate(Matrix $z, Matrix $computed) : Matrix
      */
     public function __toString() : string
     {
-        return 'SiLU';
+        return "Swish (beta: {$this->beta})";
     }
 }
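
The hunks above only touch `compute()` and `__toString()`; the body of `differentiate()` is not shown in this diff. For reference, the derivative follows from the product rule as swish'(x) = beta · swish(x) + sigmoid(beta · x) · (1 − beta · swish(x)). The scalar sketch below is a math aid only; the `swishPrime()` helper is hypothetical and is not the library's implementation:

```php
<?php

// Scalar sketch of the Swish derivative (math reference only; the commit's
// differentiate() body is not shown in this diff, so this makes no claim
// about the library's internals).
function swishPrime(float $x, float $beta = 1.0) : float
{
    $sigma = 1.0 / (1.0 + exp(-$beta * $x)); // sigmoid(beta * x)
    $swish = $x * $sigma;                    // swish(x) = x * sigmoid(beta * x)

    return $beta * $swish + $sigma * (1.0 - $beta * $swish);
}

echo swishPrime(0.0) . PHP_EOL;  // 0.5 — slope at the origin for beta = 1
echo swishPrime(10.0) . PHP_EOL; // ~1.0 — approaches 1 for large positive inputs
```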

tests/NeuralNet/ActivationFunctions/SiLUTest.php renamed to tests/NeuralNet/ActivationFunctions/SwishTest.php

+6-6
@@ -3,19 +3,19 @@
 namespace Rubix\ML\Tests\NeuralNet\ActivationFunctions;
 
 use Tensor\Matrix;
-use Rubix\ML\NeuralNet\ActivationFunctions\SiLU;
+use Rubix\ML\NeuralNet\ActivationFunctions\Swish;
 use Rubix\ML\NeuralNet\ActivationFunctions\ActivationFunction;
 use PHPUnit\Framework\TestCase;
 use Generator;
 
 /**
  * @group ActivationFunctions
- * @covers \Rubix\ML\NeuralNet\ActivationFunctions\SiLU
+ * @covers \Rubix\ML\NeuralNet\ActivationFunctions\Swish
  */
-class SiLUTest extends TestCase
+class SwishTest extends TestCase
 {
     /**
-     * @var \Rubix\ML\NeuralNet\ActivationFunctions\SiLU
+     * @var \Rubix\ML\NeuralNet\ActivationFunctions\Swish
      */
     protected $activationFn;
 
@@ -24,15 +24,15 @@ class SiLUTest extends TestCase
      */
     protected function setUp() : void
     {
-        $this->activationFn = new SiLU();
+        $this->activationFn = new Swish(1.0);
     }
 
     /**
      * @test
      */
    public function build() : void
    {
-        $this->assertInstanceOf(SiLU::class, $this->activationFn);
+        $this->assertInstanceOf(Swish::class, $this->activationFn);
         $this->assertInstanceOf(ActivationFunction::class, $this->activationFn);
    }

tests/Regressors/MLPRegressorTest.php

+5-5
@@ -21,7 +21,7 @@
 use Rubix\ML\Transformers\ZScaleStandardizer;
 use Rubix\ML\CrossValidation\Metrics\RSquared;
 use Rubix\ML\NeuralNet\CostFunctions\LeastSquares;
-use Rubix\ML\NeuralNet\ActivationFunctions\LeakyReLU;
+use Rubix\ML\NeuralNet\ActivationFunctions\Swish;
 use Rubix\ML\Exceptions\InvalidArgumentException;
 use Rubix\ML\Exceptions\RuntimeException;
 use PHPUnit\Framework\TestCase;
@@ -84,9 +84,9 @@ protected function setUp() : void
 
         $this->estimator = new MLPRegressor([
             new Dense(10),
-            new Activation(new LeakyReLU()),
+            new Activation(new Swish()),
             new Dense(10),
-            new Activation(new LeakyReLU()),
+            new Activation(new Swish()),
         ], 10, new Adam(0.01), 1e-4, 100, 1e-3, 3, 0.1, new LeastSquares(), new RMSE());
 
         $this->metric = new RSquared();
@@ -147,9 +147,9 @@ public function params() : void
         $expected = [
             'hidden layers' => [
                 new Dense(10),
-                new Activation(new LeakyReLU()),
+                new Activation(new Swish()),
                 new Dense(10),
-                new Activation(new LeakyReLU()),
+                new Activation(new Swish()),
             ],
             'batch size' => 10,
             'optimizer' => new Adam(0.01),

0 commit comments
