lib.rs
mod backward;
mod forward;

use burn::tensor::{Tensor, TensorPrimitive, activation, ops::FloatTensor};

/// We create our own Backend trait that extends the Burn backend trait.
pub trait Backend: burn::tensor::backend::Backend {
    fn fused_matmul_add_relu(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        bias: FloatTensor<Self>,
    ) -> FloatTensor<Self>;
}

/// We create our own AutodiffBackend trait that extends the Burn autodiff backend trait.
pub trait AutodiffBackend: Backend + burn::tensor::backend::AutodiffBackend {}

/// We define our custom implementation using the added function on our custom backend.
pub fn matmul_add_relu_custom<B: Backend>(
    lhs: Tensor<B, 3>,
    rhs: Tensor<B, 3>,
    bias: Tensor<B, 3>,
) -> Tensor<B, 3> {
    let output = B::fused_matmul_add_relu(
        lhs.into_primitive().tensor(),
        rhs.into_primitive().tensor(),
        bias.into_primitive().tensor(),
    );

    Tensor::from_primitive(TensorPrimitive::Float(output))
}

/// We define a reference implementation using basic tensor operations.
pub fn matmul_add_relu_reference<B: Backend>(
    lhs: Tensor<B, 3>,
    rhs: Tensor<B, 3>,
    bias: Tensor<B, 3>,
) -> Tensor<B, 3> {
    let x = lhs.matmul(rhs) + bias;

    activation::relu(x)
}