From d95b5e1fc012ec0bd0deef7e547720eb7f4e3811 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 13 Feb 2025 16:56:10 +1100
Subject: [PATCH] Added GELU activation functions

---
 basalt/nn/__init__.mojo    |  1 +
 basalt/nn/activations.mojo | 26 ++++++++++++++++++++++++++
 2 files changed, 27 insertions(+)

diff --git a/basalt/nn/__init__.mojo b/basalt/nn/__init__.mojo
index c2a0660..11dc4fa 100644
--- a/basalt/nn/__init__.mojo
+++ b/basalt/nn/__init__.mojo
@@ -11,6 +11,7 @@ from .activations import (
     LogSoftmax,
     ReLU,
     LeakyReLU,
+    GELU,
     Sigmoid,
     Tanh,
 )
diff --git a/basalt/nn/activations.mojo b/basalt/nn/activations.mojo
index 9a83a0f..d679399 100644
--- a/basalt/nn/activations.mojo
+++ b/basalt/nn/activations.mojo
@@ -17,6 +17,32 @@ fn LeakyReLU(
         attributes=AttributeVector(Attribute("negative_slope", negative_slope)),
     )
 
+
+fn GELU(inout g: Graph, input: Symbol) -> Symbol:
+    """Appends the tanh approximation of GELU to the graph.
+
+    Builds 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
+    out of POW, MUL, ADD and TANH ops.
+
+    Args:
+        g: Graph that receives the new ops.
+        input: Symbol the activation is applied to.
+
+    Returns:
+        Symbol produced by the final MUL op.
+    """
+    # Compile-time constants: sqrt(2 / pi) and the cubic-term coefficient.
+    alias SQRT_2_OVER_PI: Float64 = 0.7978845608028654
+    alias GELU_COEFF: Float64 = 0.044715
+
+    var x_cubed = g.op(OP.POW, input, 3.0)
+    var term = g.op(OP.ADD, input, g.op(OP.MUL, GELU_COEFF, x_cubed))
+    var scaled_term = g.op(OP.MUL, SQRT_2_OVER_PI, term)
+    var tanh_result = g.op(OP.TANH, scaled_term)
+    var one_plus_tanh = g.op(OP.ADD, 1.0, tanh_result)
+
+    return g.op(OP.MUL, g.op(OP.MUL, 0.5, input), one_plus_tanh)
+
 
 fn Sigmoid(inout g: Graph, input: Symbol) -> Symbol:
     return g.op(OP.SIGMOID, input)