diff --git a/encoderfile/src/transforms/tensor/mod.rs b/encoderfile/src/transforms/tensor/mod.rs index 0520cb07..561b77f1 100644 --- a/encoderfile/src/transforms/tensor/mod.rs +++ b/encoderfile/src/transforms/tensor/mod.rs @@ -1,10 +1,9 @@ use super::utils::table_to_vec; use mlua::prelude::*; -use ndarray::{Array1, ArrayD, Axis}; -use ndarray_stats::QuantileExt; +use ndarray::ArrayD; +use ops::arithm::{add, div, mul, sub}; -#[cfg(test)] -mod tests; +mod ops; #[derive(Debug, Clone, PartialEq)] pub struct Tensor(pub ArrayD); @@ -15,6 +14,11 @@ impl Tensor { } } +#[cfg(test)] +fn load_env() -> Lua { + Lua::new() +} + impl FromLua for Tensor { fn from_lua(value: LuaValue, _lua: &Lua) -> Result { match value { @@ -88,414 +92,126 @@ impl LuaUserData for Tensor { } } -impl Tensor { - #[tracing::instrument(skip_all)] - pub fn layer_norm(&self, axis: isize, eps: f32) -> Result { - // normalize over axis - let axis = self.axis1(axis)?; - let mean = self - .0 - .mean_axis(axis) - .ok_or(LuaError::external( - "Failed to mean_axis Tensor: Axis length must be > 0.", - ))? - .insert_axis(axis); - - // no bias: ddof = 0.0 - let var = self.0.var_axis(axis, 0.0); - let std = (var + eps).mapv(f32::sqrt).insert_axis(axis); - - // y = (x − mean(x)) / sqrt(var(x) + eps) - Ok(Tensor(((&self.0 - &mean) / &std).into_dyn())) - } - - #[tracing::instrument(skip_all)] - pub fn truncate_axis(&self, axis: isize, len: usize) -> Result { - let axis = self.axis1(axis)?; - - let actual_len = self.0.len_of(axis).min(len); - - let mut slice_spec = Vec::with_capacity(self.0.ndim()); - - for i in 0..self.0.ndim() { - if Axis(i) == axis { - slice_spec.push(ndarray::SliceInfoElem::Slice { - start: 0, - end: Some(actual_len as isize), - step: 1, - }); - } else { - slice_spec.push(ndarray::SliceInfoElem::Slice { - start: 0, - end: None, - step: 1, - }); - } - } - - Ok(Tensor(self.0.slice(&slice_spec[..]).to_owned())) - } - - #[tracing::instrument(skip_all)] - pub fn clamp(&self, min: Option, max: Option) -> Result { - let input = self - .0 - .as_slice() - .ok_or_else(|| LuaError::external("Array must be contiguous"))?; - - let mut out = ArrayD::::zeros(self.0.raw_dim()); - let out_slice = out - .as_slice_mut() - .ok_or_else(|| LuaError::external("Failed to fetch output slice"))?; - - // NaN bound policy: if any bound is NaN, everything becomes NaN. 
For IEEE-754 compliance :d - if min.is_some_and(f32::is_nan) || max.is_some_and(f32::is_nan) { - for dst in out_slice.iter_mut() { - *dst = f32::NAN; - } - return Ok(Self(out)); - } - - match (min, max) { - (Some(lo), Some(hi)) => { - for (dst, &src) in out_slice.iter_mut().zip(input.iter()) { - *dst = src.max(lo).min(hi); - } - } - (Some(lo), None) => { - for (dst, &src) in out_slice.iter_mut().zip(input.iter()) { - *dst = src.max(lo); - } - } - (None, Some(hi)) => { - for (dst, &src) in out_slice.iter_mut().zip(input.iter()) { - *dst = src.min(hi); - } - } - (None, None) => { - out_slice.copy_from_slice(input); - } - } - - Ok(Self(out)) - } - - #[tracing::instrument(skip_all)] - pub fn mean_pool(&self, Tensor(mask): Tensor) -> Result { - assert_eq!(self.0.ndim(), mask.ndim() + 1); - - let ndim = self.0.ndim(); - - // Expand mask by adding the last axis back - let mut mask_expanded = mask.clone(); - mask_expanded = mask_expanded.insert_axis(Axis(ndim - 1)); - - // Broadcast mask to full data shape - let mask_broadcast = mask_expanded - .broadcast(self.0.shape()) - .ok_or(LuaError::external(format!( - "cannot broadcast shape {:?} to {:?}", - mask_expanded.shape(), - self.0.shape() - )))?; - - // Multiply and sum over sequence dims (axes 1..ndim-1) - let weighted = &self.0 * &mask_broadcast; - - // All axes except the last one and the batch axis - let mut axes_to_reduce = Vec::new(); - for ax in 1..(ndim - 1) { - axes_to_reduce.push(ax); - } - - // Sum weighted values - let mut sum = weighted.clone(); - for ax in axes_to_reduce.iter().rev() { - sum = sum.sum_axis(Axis(*ax)); - } - - // Sum mask the same way -> counts - let mut count = mask_expanded.clone(); - for ax in axes_to_reduce.iter().rev() { - count = count.sum_axis(Axis(*ax)); - } - - // Final: divide elementwise - Ok(Self(&sum / &count)) - } - - #[tracing::instrument(skip_all)] - fn fold_axis(&self, axis: isize, acc: f32, func: LuaFunction) -> Result { - let axis = self.axis1(axis)?; - - let mut out = Vec::new(); - - for subview in self.0.axis_iter(axis) { - let mut acc = acc; - - for &x in subview.iter() { - acc = func.call((acc, x)).map_err(LuaError::external)?; - } - - out.push(acc); - } - - let result = Array1::from_shape_vec(out.len(), out) - .expect("Failed to recast results") - .into_dyn(); - - Ok(Tensor(result)) - } - - #[tracing::instrument] - fn map_axis(&self, axis: isize, func: LuaFunction) -> Result { - let axis = self.axis1(axis)?; - - // Pre-size by number of subviews, NOT tensor length. 
- let out_len = self.0.shape()[axis.0]; - let mut out = Vec::with_capacity(out_len); - - for subview in self.0.axis_iter(axis) { - // Only ONE allocation: convert subview into Tensor for Lua - let tensor_arg = Tensor(subview.to_owned().into_dyn()); - let mapped: Tensor = func.call(tensor_arg).map_err(LuaError::external)?; - out.push(mapped.0); // store raw ArrayD, not Tensor - } +#[test] +fn test_from_lua_create_table() { + let lua = load_env(); - // Stack views without re-wrapping as Tensor - let views: Vec<_> = out.iter().map(|a| a.view()).collect(); + let tbl: LuaTable = lua + .load("return {{1, 1, 1}, {1, 1, 1}, {1, 1, 1}}") + .eval() + .unwrap(); - let stacked = ndarray::stack(axis, &views) - .map_err(|e| LuaError::external(format!("stack error: {e}")))?; + let tensor = Tensor::from_lua(LuaValue::Table(tbl), &lua).expect("Failed to create tensor"); - Ok(Tensor(stacked)) - } - - #[tracing::instrument(skip_all)] - fn sum(&self) -> Result { - Ok(self.0.sum()) - } - - #[tracing::instrument(skip_all)] - fn sum_axis(&self, axis: isize) -> Result { - Ok(Self(self.0.sum_axis(self.axis1(axis)?))) - } - - #[tracing::instrument(skip_all)] - fn min(&self) -> Result { - self.0 - .min() - .copied() - .map_err(|e| LuaError::external(format!("Min max error: {e}"))) - } - - #[tracing::instrument(skip_all)] - fn max(&self) -> Result { - self.0 - .max() - .copied() - .map_err(|e| LuaError::external(format!("Min max error: {e}"))) - } - - #[tracing::instrument(skip_all)] - fn exp(&self) -> Result { - Ok(Self(self.0.exp())) - } - - #[tracing::instrument(skip_all)] - fn lp_normalize(&self, p: f32, axis: isize) -> Result { - if self.0.is_empty() { - return Err(LuaError::external("Cannot normalize an empty tensor")); - } - if p == 0.0 { - return Err(LuaError::external("p cannot equal 0.0")); - } - - let axis = self.axis1(axis)?; - let arr = &self.0; + assert_eq!(tensor.0.ndim(), 2); + assert_eq!(tensor.0.shape(), [3, 3]); +} - // Compute Lp norm along axis - let norms = arr.map_axis(axis, |subview| { - subview - .iter() - .map(|&v| v.abs().powf(p)) - .sum::() - .powf(1.0 / p) - }); +#[test] +fn test_from_lua_empty_table() { + let lua = load_env(); - // Avoid division by zero using in-place broadcast clamp - let norms = norms.mapv(|x| if x < 1e-12 { 1e-12 } else { x }); + let tbl: LuaTable = lua.load("return {}").eval().unwrap(); - // Broadcast division using ndarray’s broadcasting API - let normalized = arr / &norms.insert_axis(axis); + let Tensor(tensor) = Tensor::from_lua(LuaValue::Table(tbl), &lua).unwrap(); - Ok(Self(normalized)) - } + assert!(tensor.is_empty()); + assert_eq!(tensor.ndim(), 1); +} - fn axis1(&self, axis: isize) -> Result { - if axis <= 0 { - return Err(LuaError::external("Axis must be >= 1.")); - } +#[test] +fn test_from_lua_ragged() { + let lua = load_env(); - let axis_index = (axis - 1) as usize; + let tbl: LuaTable = lua + .load("return {{1, 1, 1}, {1, 1, 1}, {1, 1}}") + .eval() + .unwrap(); - if axis_index >= self.0.ndim() { - return Err(LuaError::external("Axis out of range.")); - } + let tensor = Tensor::from_lua(LuaValue::Table(tbl), &lua); - Ok(Axis(axis_index)) - } + assert!(tensor.is_err()); +} - #[tracing::instrument(skip_all)] - fn transpose(&self) -> Result { - Ok(Self(self.0.t().to_owned())) - } +#[test] +fn test_from_lua_bad_type() { + let lua = load_env(); - #[tracing::instrument(skip_all)] - fn len(&self) -> usize { - self.0.len() - } + let tbl: LuaString = lua.load("return \"i am not a table\"").eval().unwrap(); - #[tracing::instrument(skip_all)] - fn std(&self, ddof: 
f32) -> Result { - Ok(self.0.std(ddof)) - } + let tensor = Tensor::from_lua(LuaValue::String(tbl), &lua); - #[tracing::instrument(skip_all)] - fn mean(&self) -> Result, LuaError> { - Ok(self.0.mean()) - } + assert!(tensor.is_err()); +} - #[tracing::instrument(skip_all)] - fn ndim(&self) -> Result { - Ok(self.0.ndim()) - } +#[test] +fn test_from_lua_bad_type_err() { + let lua = load_env(); - #[tracing::instrument(skip_all)] - fn softmax(&self, axis: isize) -> Result { - let axis = self.axis1(axis)?; + let val = LuaValue::Boolean(false); - let max_vals = self.0.map_axis(axis, |row| { - row.iter().fold(f32::NEG_INFINITY, |m, &v| m.max(v)) - }); + let tensor = Tensor::from_lua(val, &lua); - let z = &self.0 - &max_vals.insert_axis(axis); + assert!(tensor.is_err()); +} - let numerator = z.mapv(|x| x.exp()); +#[test] +fn test_eq_simple() { + use ndarray::Array2; - let denom = numerator.map_axis(axis, |row| row.sum()); + let lua = load_env(); - Ok(Tensor(numerator / &denom.insert_axis(axis))) - } -} + let arr1 = Tensor(Array2::::ones((3, 3)).into_dyn()); + let arr2 = arr1.clone(); -#[tracing::instrument(skip_all)] -fn add(Tensor(this): &Tensor, other: LuaValue) -> Result { - let new = match other { - LuaValue::UserData(user_data) => { - let Tensor(oth) = user_data.borrow::()?.to_owned(); - - if !is_broadcastable(this.shape(), oth.shape()) { - return Err(LuaError::external(format!( - "Shape {:?} not broadcastable to {:?}", - this.shape(), - oth.shape() - ))); - } + assert!(arr1 == arr2); - this + oth - } - LuaValue::Number(n) => this + (n as f32), - LuaValue::Integer(i) => this + (i as f32), - _ => return Err(LuaError::external("Expected either number or Tensor.")), - }; + let result: bool = lua + .load("return function(x, y) return x == y end") + .eval::() + .unwrap() + .call((arr1, arr2)) + .expect("Failed to evaluate"); - Ok(Tensor(new)) + assert!(result); } -#[tracing::instrument(skip_all)] -fn sub(Tensor(this): &Tensor, other: LuaValue) -> Result { - let new = match other { - LuaValue::UserData(user_data) => { - let Tensor(oth) = user_data.borrow::()?.to_owned(); - - if !is_broadcastable(oth.shape(), this.shape()) { - return Err(LuaError::external(format!( - "Shape {:?} not broadcastable to {:?}", - this.shape(), - oth.shape() - ))); - } +#[test] +fn test_neq_simple() { + use ndarray::Array2; - this - oth - } - LuaValue::Number(n) => this - (n as f32), - LuaValue::Integer(i) => this - (i as f32), - _ => return Err(LuaError::external("Expected either number or Tensor.")), - }; + let lua = load_env(); - Ok(Tensor(new)) -} + let arr1 = Tensor(Array2::::ones((3, 3)).into_dyn()); + let arr2 = Tensor(Array2::::zeros((3, 3)).into_dyn()); -#[tracing::instrument(skip_all)] -fn mul(Tensor(this): &Tensor, other: LuaValue) -> Result { - let new = match other { - LuaValue::UserData(user_data) => { - let Tensor(oth) = user_data.borrow::()?.to_owned(); - - if !is_broadcastable(this.shape(), oth.shape()) { - return Err(LuaError::external(format!( - "Shape {:?} not broadcastable to {:?}", - this.shape(), - oth.shape() - ))); - } + assert!(arr1 != arr2); - this * oth - } - LuaValue::Number(n) => this * (n as f32), - LuaValue::Integer(i) => this * (i as f32), - _ => return Err(LuaError::external("Expected either number or Tensor.")), - }; + let result: bool = lua + .load("return function(x, y) return x == y end") + .eval::() + .unwrap() + .call((arr1, arr2)) + .expect("Failed to evaluate"); - Ok(Tensor(new)) + assert!(!result); } -#[tracing::instrument(skip_all)] -fn div(Tensor(this): &Tensor, other: LuaValue) 
-> Result { - let new = match other { - LuaValue::UserData(user_data) => { - let Tensor(oth) = user_data.borrow::()?.to_owned(); - - if !is_broadcastable(oth.shape(), this.shape()) { - return Err(LuaError::external(format!( - "Shape {:?} not broadcastable to {:?}", - this.shape(), - oth.shape() - ))); - } - - this / oth - } - LuaValue::Number(n) => this / (n as f32), - LuaValue::Integer(i) => this / (i as f32), - _ => return Err(LuaError::external("Expected either number or Tensor.")), - }; +#[test] +fn test_to_string() { + use ndarray::Array2; - Ok(Tensor(new)) -} + let lua = load_env(); -#[tracing::instrument(skip_all)] -fn is_broadcastable(a: &[usize], b: &[usize]) -> bool { - let ndim = a.len().max(b.len()); + let vec = Tensor(Array2::::ones((3, 3)).into_dyn()); + let vec_str_gold = vec.0.to_string(); - for i in 0..ndim { - let ad = *a.get(a.len().wrapping_sub(i + 1)).unwrap_or(&1); - let bd = *b.get(b.len().wrapping_sub(i + 1)).unwrap_or(&1); + let vec_str: String = lua + .globals() + .get::("tostring") + .unwrap() + .call(vec) + .unwrap(); - if ad != bd && ad != 1 && bd != 1 { - return false; - } - } - true + assert_eq!(vec_str, vec_str_gold); } diff --git a/encoderfile/src/transforms/tensor/ops/arithm.rs b/encoderfile/src/transforms/tensor/ops/arithm.rs new file mode 100644 index 00000000..470961de --- /dev/null +++ b/encoderfile/src/transforms/tensor/ops/arithm.rs @@ -0,0 +1,389 @@ +use super::Tensor; +use super::properties::is_broadcastable; +use mlua::prelude::*; + +impl Tensor { + #[tracing::instrument(skip_all)] + pub fn exp(&self) -> Result { + Ok(Self(self.0.exp())) + } +} + +#[tracing::instrument(skip_all)] +pub fn add(Tensor(this): &Tensor, other: LuaValue) -> Result { + let new = match other { + LuaValue::UserData(user_data) => { + let Tensor(oth) = user_data.borrow::()?.to_owned(); + + if !is_broadcastable(this.shape(), oth.shape()) { + return Err(LuaError::external(format!( + "Shape {:?} not broadcastable to {:?}", + this.shape(), + oth.shape() + ))); + } + + this + oth + } + LuaValue::Number(n) => this + (n as f32), + LuaValue::Integer(i) => this + (i as f32), + _ => return Err(LuaError::external("Expected either number or Tensor.")), + }; + + Ok(Tensor(new)) +} + +#[tracing::instrument(skip_all)] +pub fn sub(Tensor(this): &Tensor, other: LuaValue) -> Result { + let new = match other { + LuaValue::UserData(user_data) => { + let Tensor(oth) = user_data.borrow::()?.to_owned(); + + if !is_broadcastable(oth.shape(), this.shape()) { + return Err(LuaError::external(format!( + "Shape {:?} not broadcastable to {:?}", + this.shape(), + oth.shape() + ))); + } + + this - oth + } + LuaValue::Number(n) => this - (n as f32), + LuaValue::Integer(i) => this - (i as f32), + _ => return Err(LuaError::external("Expected either number or Tensor.")), + }; + + Ok(Tensor(new)) +} + +#[tracing::instrument(skip_all)] +pub fn mul(Tensor(this): &Tensor, other: LuaValue) -> Result { + let new = match other { + LuaValue::UserData(user_data) => { + let Tensor(oth) = user_data.borrow::()?.to_owned(); + + if !is_broadcastable(this.shape(), oth.shape()) { + return Err(LuaError::external(format!( + "Shape {:?} not broadcastable to {:?}", + this.shape(), + oth.shape() + ))); + } + + this * oth + } + LuaValue::Number(n) => this * (n as f32), + LuaValue::Integer(i) => this * (i as f32), + _ => return Err(LuaError::external("Expected either number or Tensor.")), + }; + + Ok(Tensor(new)) +} + +#[tracing::instrument(skip_all)] +pub fn div(Tensor(this): &Tensor, other: LuaValue) -> Result { + let new = match 
other { + LuaValue::UserData(user_data) => { + let Tensor(oth) = user_data.borrow::()?.to_owned(); + + if !is_broadcastable(oth.shape(), this.shape()) { + return Err(LuaError::external(format!( + "Shape {:?} not broadcastable to {:?}", + this.shape(), + oth.shape() + ))); + } + + this / oth + } + LuaValue::Number(n) => this / (n as f32), + LuaValue::Integer(i) => this / (i as f32), + _ => return Err(LuaError::external("Expected either number or Tensor.")), + }; + + Ok(Tensor(new)) +} + +#[cfg(test)] +mod tests { + use super::*; + use mlua::prelude::{Lua, LuaValue}; + + fn tensor(data: Vec, shape: &[usize]) -> Tensor { + Tensor(ndarray::ArrayD::from_shape_vec(shape, data).unwrap()) + } + + fn lua_number(n: f64) -> LuaValue { + LuaValue::Number(n) + } + + fn lua_tensor(t: Tensor, lua: &Lua) -> LuaValue { + mlua::Value::UserData(lua.create_userdata(t).unwrap()) + } + + macro_rules! generate_ops_test { + ($mod_name:ident, $op:tt, $rust_fn:ident, $lua_op:expr) => { + mod $mod_name { + + #[test] + fn test_binding() { + use crate::transforms::tensor::load_env; + use super::Tensor; + use super::$rust_fn; + use ndarray::Array2; + use mlua::prelude::{LuaValue, LuaFunction}; + + let lua = load_env(); + let arr1 = Tensor(Array2::::ones((3, 3)).into_dyn()); + let arr2 = arr1.clone(); + + let gold_val = $rust_fn( + &arr1, + LuaValue::UserData(lua.create_userdata(arr2.clone()).unwrap()) + ).expect("Failed to compute"); + + let result: Tensor = lua.load(format!("return function(x, y) return x {} y end", $lua_op)) + .eval::() + .unwrap() + .call((arr1, arr2)) + .expect("Binding failed"); + + assert_eq!(result, gold_val); + } + + #[test] + fn test_tensor() { + use crate::transforms::tensor::load_env; + use super::Tensor; + use ndarray::Array2; + use mlua::prelude::LuaValue; + use super::$rust_fn; + + let lua = load_env(); + let arr1 = Tensor(Array2::::ones((3, 3)).into_dyn()); + let arr2 = arr1.clone(); + + let val = LuaValue::UserData(lua.create_userdata(arr1.clone()).unwrap()); + let result = $rust_fn(&arr1, val).unwrap(); + + let gold = &arr1.0 $op &arr2.0; + + assert_eq!(gold, result.0); + } + + #[test] + fn test_number() { + use super::Tensor; + use ndarray::Array2; + use mlua::prelude::LuaValue; + use super::$rust_fn; + + let arr1 = Tensor(Array2::::ones((3, 3)).into_dyn()); + + let gold_sum = &arr1.0 $op Array2::::from_elem((3, 3), 5.0); + + let result = $rust_fn(&arr1, LuaValue::Number(5.0)).unwrap(); + + assert_eq!(gold_sum, result.0); + } + + #[test] + fn test_integer() { + use super::Tensor; + use ndarray::Array2; + use mlua::prelude::LuaValue; + use super::$rust_fn; + + let arr1 = Tensor(Array2::::ones((3, 3)).into_dyn()); + + let gold_sum = &arr1.0 $op Array2::::from_elem((3, 3), 5.0); + + let result = $rust_fn(&arr1, LuaValue::Integer(5)).unwrap(); + + assert_eq!(gold_sum, result.0); + } + + #[test] + fn test_bad_dtype() { + use super::Tensor; + use ndarray::Array2; + use mlua::prelude::{LuaValue, LuaError}; + use super::$rust_fn; + + let arr1 = Tensor(Array2::::ones((3, 3)).into_dyn()); + + let result: Result = $rust_fn(&arr1, LuaValue::Boolean(false)); + + assert!(result.is_err()); + } + } + } + } + + generate_ops_test!( + test_addition, +, add, "+" + ); + + generate_ops_test!( + test_subtraction, -, sub, "-" + ); + + generate_ops_test!( + test_multiplication, *, mul, "*" + ); + + generate_ops_test!( + test_division, /, div, "/" + ); + + #[test] + fn test_add_broadcast_success() { + let lua = Lua::new(); + + // (2, 3) + (3,) → OK via broadcasting + let a = tensor(vec![1., 2., 3., 4., 5., 
6.], &[2, 3]); + let b = tensor(vec![10., 20., 30.], &[3]); + + let res = add(&a, lua_tensor(b, &lua)).unwrap(); + assert_eq!( + res.0, + ndarray::arr2(&[[11., 22., 33.], [14., 25., 36.]]).into_dyn() + ); + } + + #[test] + fn test_add_broadcast_failure() { + let lua = Lua::new(); + + // (2, 3) + (2,) → NOT broadcastable because trailing dims mismatch + let a = tensor(vec![1., 2., 3., 4., 5., 6.], &[2, 3]); + let b = tensor(vec![1., 2.], &[2]); + + let err = add(&a, lua_tensor(b, &lua)).unwrap_err(); + let msg = format!("{err}"); + assert!(msg.contains("not broadcastable"), "Got: {msg}"); + } + + #[test] + fn test_sub_broadcast_success() { + let lua = Lua::new(); + + // (3, 1) - (3,) → OK (result is (3,3)) + let a = tensor(vec![1., 2., 3.], &[3, 1]); + let b = tensor(vec![1., 10., 100.], &[3]); + + let res = sub(&a, lua_tensor(b, &lua)).unwrap(); + assert_eq!( + res.0, + ndarray::arr2(&[[0., -9., -99.], [1., -8., -98.], [2., -7., -97.]]).into_dyn() + ); + } + + #[test] + fn test_sub_broadcast_failure() { + let lua = Lua::new(); + + // (3,2) - (3,) → failure: trailing dim (2 vs 3) + let a = tensor(vec![1., 2., 3., 4., 5., 6.], &[3, 2]); + let b = tensor(vec![1., 2., 3.], &[3]); + + let err = sub(&a, lua_tensor(b, &lua)).unwrap_err(); + assert!(format!("{err}").contains("not broadcastable")); + } + + #[test] + fn test_mul_broadcast_success() { + // (2,3) * scalar → always OK + let a = tensor(vec![1., 2., 3., 4., 5., 6.], &[2, 3]); + let res = mul(&a, lua_number(2.0)).unwrap(); + + assert_eq!( + res.0, + ndarray::arr2(&[[2., 4., 6.], [8., 10., 12.]]).into_dyn() + ); + } + + #[test] + fn test_mul_broadcast_shape_success() { + let lua = Lua::new(); + + // (4,1) * (1,3) → → (4,3) + let a = tensor(vec![1., 2., 3., 4.], &[4, 1]); + let b = tensor(vec![10., 20., 30.], &[1, 3]); + + let res = mul(&a, lua_tensor(b, &lua)).unwrap(); + + assert_eq!( + res.0, + ndarray::arr2(&[ + [10., 20., 30.], + [20., 40., 60.], + [30., 60., 90.], + [40., 80., 120.] 
+ ]) + .into_dyn() + ); + } + + #[test] + fn test_mul_broadcast_fail() { + let lua = Lua::new(); + + // (2,2) * (3,) → cannot broadcast trailing dims + let a = tensor(vec![1., 2., 3., 4.], &[2, 2]); + let b = tensor(vec![1., 2., 3.], &[3]); + + let err = mul(&a, lua_tensor(b, &lua)).unwrap_err(); + assert!(format!("{err}").contains("not broadcastable")); + } + + #[test] + fn test_div_broadcast_success() { + let lua = Lua::new(); + + // (3,3) / (3,) → OK + let a = tensor((1..=9).map(|x| x as f32).collect(), &[3, 3]); + let b = tensor(vec![1., 2., 3.], &[3]); + + let res = div(&a, lua_tensor(b, &lua)).unwrap(); + + assert_eq!( + res.0, + ndarray::arr2(&[ + [1.0 / 1., 2.0 / 2., 3.0 / 3.], + [4.0 / 1., 5.0 / 2., 6.0 / 3.], + [7.0 / 1., 8.0 / 2., 9.0 / 3.], + ]) + .into_dyn() + ); + } + + #[test] + fn test_div_broadcast_fail() { + let lua = Lua::new(); + + // (2,3) vs (2,) again → nope + let a = tensor(vec![1., 2., 3., 4., 5., 6.], &[2, 3]); + let b = tensor(vec![1., 2.], &[2]); + + let err = div(&a, lua_tensor(b, &lua)).unwrap_err(); + assert!(format!("{err}").contains("not broadcastable")); + } + + #[test] + fn test_exp() { + use ndarray::Array2; + + let arr = Array2::ones((3, 3)).into_dyn(); + let tensor = Tensor(arr.clone()); + assert_eq!(tensor.exp().unwrap(), Tensor(arr.mapv(f32::exp))); + } + + #[test] + fn test_exp_empty() { + let tensor = Tensor(ndarray::array![[[]]].into_dyn()); + let Tensor(exp) = tensor.exp().unwrap(); + assert!(exp.is_empty()); + } +} diff --git a/encoderfile/src/transforms/tensor/ops/axes.rs b/encoderfile/src/transforms/tensor/ops/axes.rs new file mode 100644 index 00000000..6ce7e9bb --- /dev/null +++ b/encoderfile/src/transforms/tensor/ops/axes.rs @@ -0,0 +1,19 @@ +use super::Tensor; +use mlua::prelude::*; +use ndarray::Axis; + +impl Tensor { + pub fn axis1(&self, axis: isize) -> Result { + if axis <= 0 { + return Err(LuaError::external("Axis must be >= 1.")); + } + + let axis_index = (axis - 1) as usize; + + if axis_index >= self.0.ndim() { + return Err(LuaError::external("Axis out of range.")); + } + + Ok(Axis(axis_index)) + } +} diff --git a/encoderfile/src/transforms/tensor/ops/clamp.rs b/encoderfile/src/transforms/tensor/ops/clamp.rs new file mode 100644 index 00000000..a9827f4b --- /dev/null +++ b/encoderfile/src/transforms/tensor/ops/clamp.rs @@ -0,0 +1,151 @@ +use super::Tensor; +use mlua::prelude::*; +use ndarray::ArrayD; + +impl Tensor { + #[tracing::instrument(skip_all)] + pub fn clamp(&self, min: Option, max: Option) -> Result { + let input = self + .0 + .as_slice() + .ok_or_else(|| LuaError::external("Array must be contiguous"))?; + + let mut out = ArrayD::::zeros(self.0.raw_dim()); + let out_slice = out + .as_slice_mut() + .ok_or_else(|| LuaError::external("Failed to fetch output slice"))?; + + // NaN bound policy: if any bound is NaN, everything becomes NaN. 
This mirrors IEEE-754 NaN propagation.
+        if min.is_some_and(f32::is_nan) || max.is_some_and(f32::is_nan) {
+            for dst in out_slice.iter_mut() {
+                *dst = f32::NAN;
+            }
+            return Ok(Self(out));
+        }
+
+        match (min, max) {
+            (Some(lo), Some(hi)) => {
+                for (dst, &src) in out_slice.iter_mut().zip(input.iter()) {
+                    *dst = src.max(lo).min(hi);
+                }
+            }
+            (Some(lo), None) => {
+                for (dst, &src) in out_slice.iter_mut().zip(input.iter()) {
+                    *dst = src.max(lo);
+                }
+            }
+            (None, Some(hi)) => {
+                for (dst, &src) in out_slice.iter_mut().zip(input.iter()) {
+                    *dst = src.min(hi);
+                }
+            }
+            (None, None) => {
+                out_slice.copy_from_slice(input);
+            }
+        }
+
+        Ok(Self(out))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_clamp_correctness() {
+        let tensor = Tensor(ndarray::array!([-5.0, -1.0, 0.0, 1.0, 5.0]).into_dyn());
+        let result = tensor
+            .clamp(Some(-1.0), Some(1.0))
+            .expect("Failed to clamp");
+        let expected = Tensor(ndarray::array!([-1.0, -1.0, 0.0, 1.0, 1.0]).into_dyn());
+        assert_eq!(result.0, expected.0);
+    }
+
+    #[test]
+    fn test_clamp_lower_bound_only() {
+        let tensor = Tensor(ndarray::array!([-3.0, 0.0, 2.0]).into_dyn());
+        let result = tensor
+            .clamp(Some(0.0), None)
+            .expect("Failed to clamp tensor");
+        let expected = Tensor(ndarray::array!([0.0, 0.0, 2.0]).into_dyn());
+        assert_eq!(result.0, expected.0);
+    }
+
+    #[test]
+    fn test_clamp_upper_bound_only() {
+        let tensor = Tensor(ndarray::array!([-3.0, 0.0, 2.0, 5.0]).into_dyn());
+        let result = tensor
+            .clamp(None, Some(2.0))
+            .expect("Failed to clamp tensor");
+        let expected = Tensor(ndarray::array!([-3.0, 0.0, 2.0, 2.0]).into_dyn());
+        assert_eq!(result.0, expected.0);
+    }
+
+    #[test]
+    fn test_clamp_infinite_bounds() {
+        let tensor = Tensor(ndarray::array!([-3.0, 0.0, 2.0, 5.0]).into_dyn());
+        let result = tensor
+            .clamp(Some(f32::NEG_INFINITY), Some(f32::INFINITY))
+            .expect("Failed to clamp tensor");
+        let expected = Tensor(ndarray::array!([-3.0, 0.0, 2.0, 5.0]).into_dyn());
+        assert_eq!(result.0, expected.0);
+    }
+
+    #[test]
+    fn test_clamp_multidimensional() {
+        let tensor =
+            Tensor(ndarray::array!([[-3.0, 3.0], [0.0, 0.0], [2.0, 2.0], [5.0, 5.0]]).into_dyn());
+        let expected_shape = tensor.0.shape().to_owned();
+
+        let result = tensor
+            .clamp(Some(0.0), Some(1.0))
+            .expect("Failed to clamp tensor");
+
+        let expected =
+            Tensor(ndarray::array!([[0.0, 1.0], [0.0, 0.0], [1.0, 1.0], [1.0, 1.0]]).into_dyn());
+
+        assert_eq!(result.0.shape(), expected_shape.as_slice());
+        assert_eq!(result.0, expected.0);
+    }
+
+    #[test]
+    fn test_clamp_identity() {
+        let tensor = Tensor(ndarray::array!([-3.0, 0.0, 2.0, 5.0]).into_dyn());
+        let result = tensor.clamp(None, None).expect("Failed to clamp tensor");
+        assert_eq!(result.0, tensor.0);
+    }
+
+    #[test]
+    fn test_clamp_min_equals_max() {
+        let tensor = Tensor(ndarray::array!([0.0, 3.0, 10.0]).into_dyn());
+        let result = tensor
+            .clamp(Some(3.0), Some(3.0))
+            .expect("Failed to clamp tensor");
+        let expected = Tensor(ndarray::array!([3.0, 3.0, 3.0]).into_dyn());
+        assert_eq!(result.0, expected.0);
+    }
+
+    #[test]
+    fn test_clamp_inverted_bounds() {
+        let tensor = Tensor(ndarray::array!([0.0, 3.0, 10.0]).into_dyn());
+        let result = tensor
+            .clamp(Some(5.0), Some(2.0))
+            .expect("Failed to clamp tensor");
+        let expected = Tensor(ndarray::array!([2.0, 2.0, 2.0]).into_dyn());
+        assert_eq!(result.0, expected.0);
+    }
+
+    #[test]
+    fn test_clamp_nan() {
+        // Clamping with NaN bounds poisons the entire tensor,
+        // so there are no surprises downstream.
+        let tensor = Tensor(ndarray::array!([0.0, 3.0, 10.0]).into_dyn());
+        let result = tensor
+            .clamp(Some(f32::NAN), Some(f32::NAN))
+            .expect("Failed to clamp tensor");
+        let expected = Tensor(ndarray::array!([f32::NAN, f32::NAN, f32::NAN]).into_dyn());
+        for (a, b) in result.0.iter().zip(expected.0.iter()) {
+            assert!(a.is_nan() && b.is_nan());
+        }
+    }
+}
diff --git a/encoderfile/src/transforms/tensor/ops/fold_axis.rs b/encoderfile/src/transforms/tensor/ops/fold_axis.rs
new file mode 100644
index 00000000..f5d07e9b
--- /dev/null
+++ b/encoderfile/src/transforms/tensor/ops/fold_axis.rs
@@ -0,0 +1,77 @@
+use super::Tensor;
+use mlua::prelude::*;
+use ndarray::Array1;
+
+impl Tensor {
+    #[tracing::instrument(skip_all)]
+    pub fn fold_axis(&self, axis: isize, acc: f32, func: LuaFunction) -> Result<Tensor, LuaError> {
+        let axis = self.axis1(axis)?;
+
+        let mut out = Vec::new();
+
+        for subview in self.0.axis_iter(axis) {
+            let mut acc = acc;
+
+            for &x in subview.iter() {
+                acc = func.call((acc, x)).map_err(LuaError::external)?;
+            }
+
+            out.push(acc);
+        }
+
+        let result = Array1::from_shape_vec(out.len(), out)
+            .expect("Failed to recast results")
+            .into_dyn();
+
+        Ok(Tensor(result))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn fold_axis_sum_rows() -> LuaResult<()> {
+        use crate::transforms::tensor::load_env;
+        let lua = load_env();
+        let arr = ndarray::array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]].into_dyn();
+        let t = Tensor(arr);
+
+        let func: LuaFunction = lua
+            .load(
+                r#"
+            return function(acc, x) return acc + x end
+        "#,
+            )
+            .eval()?;
+
+        let res = t.fold_axis(1, 0.0, func)?; // fold each row
+        let v = res.0.into_dimensionality::<ndarray::Ix1>().unwrap();
+
+        assert_eq!(v.as_slice().unwrap(), &[6.0, 15.0]);
+        Ok(())
+    }
+
+    #[test]
+    fn fold_axis_product() -> LuaResult<()> {
+        use crate::transforms::tensor::load_env;
+        let lua = load_env();
+        let arr = ndarray::array![[1.0, 2.0], [3.0, 4.0]].into_dyn();
+        let t = Tensor(arr);
+
+        let func: LuaFunction = lua
+            .load(
+                r#"
+            return function(acc, x) return acc * x end
+        "#,
+            )
+            .eval()?;
+
+        let res = t.fold_axis(1, 1.0, func)?; // multiply across each row
+        let v = res.0.into_dimensionality::<ndarray::Ix1>().unwrap();
+
+        assert_eq!(v.as_slice().unwrap(), &[2.0, 12.0]);
+        Ok(())
+    }
+}
diff --git a/encoderfile/src/transforms/tensor/ops/layer_norm.rs b/encoderfile/src/transforms/tensor/ops/layer_norm.rs
new file mode 100644
index 00000000..ea1fbaa8
--- /dev/null
+++ b/encoderfile/src/transforms/tensor/ops/layer_norm.rs
@@ -0,0 +1,87 @@
+use super::Tensor;
+use mlua::prelude::*;
+
+impl Tensor {
+    #[tracing::instrument(skip_all)]
+    pub fn layer_norm(&self, axis: isize, eps: f32) -> Result<Tensor, LuaError> {
+        // normalize over axis
+        let axis = self.axis1(axis)?;
+        let mean = self
+            .0
+            .mean_axis(axis)
+            .ok_or(LuaError::external(
+                "Failed to mean_axis Tensor: Axis length must be > 0.",
+            ))?
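+            // mean_axis drops the reduced axis; re-inserting it (length 1)
+            // lets the mean broadcast against the original shape below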
+ .insert_axis(axis); + + // no bias: ddof = 0.0 + let var = self.0.var_axis(axis, 0.0); + let std = (var + eps).mapv(f32::sqrt).insert_axis(axis); + + // y = (x − mean(x)) / sqrt(var(x) + eps) + Ok(Tensor(((&self.0 - &mean) / &std).into_dyn())) + } +} + +#[cfg(test)] +mod tests { + use super::super::arithm; + use super::*; + + #[test] + fn test_layer_norm_correctness() { + let input = Tensor(ndarray::array![[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]].into_dyn()); + + let result = input + .layer_norm(2, 1e-5) + .expect("Failed to compute layer_norm"); + + for row in result.0.rows() { + let m = row.mean().unwrap(); + let v = row.var(0.0); + // mean should be roughly equal to 0 + assert!((m - 0.0).abs() < 1e-5); + // variance tolerances are always a bit looser, but should roughly equal 1.0 + assert!((v - 1.0).abs() < 1e-4); + } + } + + #[test] + fn test_layer_norm_epsilon_behavior() { + let input = Tensor(ndarray::array![[5.0, 5.0, 5.0]].into_dyn()); + let result = input + .layer_norm(2, 1e-5) + .expect("Failed to compute layer_norm"); + + // nothing should blow up or be NaNs + assert!(result.0.iter().all(|v| v.is_finite())); + } + + #[test] + fn test_layer_norm_dimensionality() { + use ndarray::Array3; + let input = Tensor(Array3::from_elem([10, 10, 10], 3.0).into_dyn()); + let result = input + .layer_norm(2, 1e-5) + .expect("Failed to compute layer_norm"); + assert_eq!(input.0.dim(), result.0.dim()) + } + + #[test] + fn test_layer_norm_translation() { + // layer_norm should be invariant to additive bias per row + let input_1 = Tensor(ndarray::array![[1.0, 2.0, 3.0, 4.0, 5.0]].into_dyn()); + let input_2 = + arithm::add(&input_1, LuaValue::Number(5.0)).expect("Scalar transformation failed"); + let layer_norm_1 = input_1 + .layer_norm(2, 1e-5) + .expect("Failed to compute layer_norm for input_1"); + let layer_norm_2 = input_2 + .layer_norm(2, 1e-5) + .expect("Failed to compute layer_norm for input_2"); + + for (a, b) in layer_norm_1.0.iter().zip(layer_norm_2.0.iter()) { + assert!((a - b).abs() < 1e-4, "mismatch: {a} vs {b}"); + } + } +} diff --git a/encoderfile/src/transforms/tensor/ops/lp_normalize.rs b/encoderfile/src/transforms/tensor/ops/lp_normalize.rs new file mode 100644 index 00000000..ff16f1b3 --- /dev/null +++ b/encoderfile/src/transforms/tensor/ops/lp_normalize.rs @@ -0,0 +1,65 @@ +use super::Tensor; +use mlua::prelude::*; + +impl Tensor { + #[tracing::instrument(skip_all)] + pub fn lp_normalize(&self, p: f32, axis: isize) -> Result { + if self.0.is_empty() { + return Err(LuaError::external("Cannot normalize an empty tensor")); + } + if p == 0.0 { + return Err(LuaError::external("p cannot equal 0.0")); + } + + let axis = self.axis1(axis)?; + let arr = &self.0; + + // Compute Lp norm along axis + let norms = arr.map_axis(axis, |subview| { + subview + .iter() + .map(|&v| v.abs().powf(p)) + .sum::() + .powf(1.0 / p) + }); + + // Avoid division by zero using in-place broadcast clamp + let norms = norms.mapv(|x| if x < 1e-12 { 1e-12 } else { x }); + + // Broadcast division using ndarray’s broadcasting API + let normalized = arr / &norms.insert_axis(axis); + + Ok(Self(normalized)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_lp_norm_empty() { + use ndarray::ArrayD; + let arr: ArrayD = ndarray::array![[[]]].into_dyn(); + + assert!(arr.is_empty()); + assert!(Tensor(arr).lp_normalize(1.0, 1).is_err()) + } + + #[test] + fn test_lp_norm_zero() { + use ndarray::{Array3, ArrayD}; + let arr: ArrayD = Array3::ones((3, 3, 3)).into_dyn(); + + 
assert!(Tensor(arr).lp_normalize(0.0, 1).is_err()) + } + + #[test] + fn test_lp_norm_nonexistent_dim() { + use ndarray::{Array3, ArrayD}; + let arr: ArrayD = Array3::ones((3, 3, 3)).into_dyn(); + + assert!(Tensor(arr.clone()).lp_normalize(1.0, 0).is_err()); // lua starts with 1 + assert!(Tensor(arr.clone()).lp_normalize(1.0, 4).is_err()); + } +} diff --git a/encoderfile/src/transforms/tensor/ops/map_axis.rs b/encoderfile/src/transforms/tensor/ops/map_axis.rs new file mode 100644 index 00000000..70057534 --- /dev/null +++ b/encoderfile/src/transforms/tensor/ops/map_axis.rs @@ -0,0 +1,69 @@ +use super::Tensor; +use mlua::prelude::*; + +impl Tensor { + #[tracing::instrument] + pub fn map_axis(&self, axis: isize, func: LuaFunction) -> Result { + let axis = self.axis1(axis)?; + + // Pre-size by number of subviews, NOT tensor length. + let out_len = self.0.shape()[axis.0]; + let mut out = Vec::with_capacity(out_len); + + for subview in self.0.axis_iter(axis) { + // Only ONE allocation: convert subview into Tensor for Lua + let tensor_arg = Tensor(subview.to_owned().into_dyn()); + let mapped: Tensor = func.call(tensor_arg).map_err(LuaError::external)?; + out.push(mapped.0); // store raw ArrayD, not Tensor + } + + // Stack views without re-wrapping as Tensor + let views: Vec<_> = out.iter().map(|a| a.view()).collect(); + + let stacked = ndarray::stack(axis, &views) + .map_err(|e| LuaError::external(format!("stack error: {e}")))?; + + Ok(Tensor(stacked)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_map_axis_zero_transform() { + use crate::transforms::tensor::load_env; + use ndarray::Array3; + let lua = load_env(); + let tensor = Tensor(Array3::::from_elem((3, 6, 9), 1.0).into_dyn()); + + let func = lua + .load("return function(x) return x end") + .eval::() + .unwrap(); + + let result = tensor.map_axis(3, func).expect("Failed to map axis"); + + assert_eq!(tensor, result); + } + + #[test] + fn test_map_axis_double_values() { + use crate::transforms::tensor::load_env; + use ndarray::Array3; + let lua = load_env(); + let tensor = Tensor( + Array3::::from_shape_fn((2, 2, 2), |(i, j, k)| (i + j + k) as f32).into_dyn(), + ); + + let func = lua + .load("return function(x) return x * 2 end") + .eval::() + .unwrap(); + + let result = tensor.map_axis(3, func).expect("Failed to map axis"); + + assert_eq!(result.0, tensor.0 * 2.0); + } +} diff --git a/encoderfile/src/transforms/tensor/ops/mean_pool.rs b/encoderfile/src/transforms/tensor/ops/mean_pool.rs new file mode 100644 index 00000000..37069750 --- /dev/null +++ b/encoderfile/src/transforms/tensor/ops/mean_pool.rs @@ -0,0 +1,166 @@ +use super::Tensor; +use mlua::prelude::*; +use ndarray::Axis; + +impl Tensor { + #[tracing::instrument(skip_all)] + pub fn mean_pool(&self, Tensor(mask): Tensor) -> Result { + assert_eq!(self.0.ndim(), mask.ndim() + 1); + + let ndim = self.0.ndim(); + + // Expand mask by adding the last axis back + let mut mask_expanded = mask.clone(); + mask_expanded = mask_expanded.insert_axis(Axis(ndim - 1)); + + // Broadcast mask to full data shape + let mask_broadcast = mask_expanded + .broadcast(self.0.shape()) + .ok_or(LuaError::external(format!( + "cannot broadcast shape {:?} to {:?}", + mask_expanded.shape(), + self.0.shape() + )))?; + + // Multiply and sum over sequence dims (axes 1..ndim-1) + let weighted = &self.0 * &mask_broadcast; + + // All axes except the last one and the batch axis + let mut axes_to_reduce = Vec::new(); + for ax in 1..(ndim - 1) { + axes_to_reduce.push(ax); + } + + // Sum weighted 
values + let mut sum = weighted.clone(); + for ax in axes_to_reduce.iter().rev() { + sum = sum.sum_axis(Axis(*ax)); + } + + // Sum mask the same way -> counts + let mut count = mask_expanded.clone(); + for ax in axes_to_reduce.iter().rev() { + count = count.sum_axis(Axis(*ax)); + } + + // Final: divide elementwise + Ok(Self(&sum / &count)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn mean_pool_single_vector_no_mask() { + // shape: (batch=1, seq=1, dim=3) + let x = Tensor(ndarray::array![[[1.0, 2.0, 3.0]]].into_dyn()); + let mask = Tensor(ndarray::array![[1.0]].into_dyn()); + + let pooled = x.mean_pool(mask).unwrap(); + assert_eq!(pooled.0, ndarray::array![[1.0, 2.0, 3.0]].into_dyn()); + } + + #[test] + fn mean_pool_two_tokens_equal_weight() { + // shape: (1, 2, 3) + let x = Tensor(ndarray::array![[[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]]].into_dyn()); + + let mask = Tensor(ndarray::array![[1.0, 1.0]].into_dyn()); + + let pooled = x.mean_pool(mask).unwrap(); + let expected = ndarray::array![[2.0, 2.0, 2.0]].into_dyn(); + + assert_allclose(&pooled.0, &expected); + } + + #[test] + fn mean_pool_ignores_masked_tokens() { + // shape: (1, 3, 2) + // Only the first and last token should count. + let x = Tensor( + ndarray::array![[ + [10.0, 0.0], + [99.0, 99.0], // masked out + [20.0, 0.0] + ]] + .into_dyn(), + ); + + let mask = Tensor(ndarray::array![[1.0, 0.0, 1.0]].into_dyn()); + + let pooled = x.mean_pool(mask).unwrap(); + let expected = ndarray::array![[(10.0 + 20.0) / 2.0, 0.0]].into_dyn(); + + assert_allclose(&pooled.0, &expected); + } + + #[test] + fn mean_pool_batch_mode() { + // shape: (2, 2, 2) + let x = Tensor( + ndarray::array![ + [[1.0, 1.0], [3.0, 3.0]], // batch 0 + [[2.0, 4.0], [4.0, 2.0]], // batch 1 + ] + .into_dyn(), + ); + + let mask = Tensor(ndarray::array![[1.0, 1.0], [1.0, 0.0],].into_dyn()); + + let pooled = x.mean_pool(mask).unwrap(); + + let expected = + ndarray::array![[(1.0 + 3.0) / 2.0, (1.0 + 3.0) / 2.0], [2.0, 4.0]].into_dyn(); + + assert_allclose(&pooled.0, &expected); + } + + #[test] + fn mean_pool_mask_broadcasting() { + let x = Tensor( + ndarray::array![[ + [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], + [[4.0, 4.0], [5.0, 5.0], [6.0, 6.0]] + ]] + .into_dyn(), + ); + + let mask = Tensor(ndarray::array![[[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]].into_dyn()); + + let pooled = x.mean_pool(mask).unwrap(); + + // Compute manually: + // First inner seq: avg of [1,2] and [4,5] + // Second inner seq isn't separate — everything is reduced together. + // + // Values included: + // 1.0, 2.0, 4.0, 5.0 (mask=1) + // and the same duplicated for the second feature. 
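+        // Included per-feature values: 1, 2, 4, 5 → (1 + 2 + 4 + 5) / 4 = 3.0.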
+        let expected = ndarray::array![[3.0, 3.0]].into_dyn(); // (1,2)
+
+        assert_allclose(&pooled.0, &expected);
+    }
+
+    #[cfg(test)]
+    pub fn assert_allclose(a: &ndarray::ArrayD<f32>, b: &ndarray::ArrayD<f32>) {
+        let tol = 1e-6;
+        assert_eq!(
+            a.shape(),
+            b.shape(),
+            "shape mismatch: {:?} vs {:?}",
+            a.shape(),
+            b.shape()
+        );
+        let a_slice = a.as_slice().unwrap();
+        let b_slice = b.as_slice().unwrap();
+        for (i, (x, y)) in a_slice.iter().zip(b_slice.iter()).enumerate() {
+            let diff = (x - y).abs();
+            assert!(
+                diff <= tol,
+                "mismatch at index {i}: {x} vs {y} (diff {diff})"
+            );
+        }
+    }
+}
diff --git a/encoderfile/src/transforms/tensor/ops/mod.rs b/encoderfile/src/transforms/tensor/ops/mod.rs
new file mode 100644
index 00000000..3b2eb4fd
--- /dev/null
+++ b/encoderfile/src/transforms/tensor/ops/mod.rs
@@ -0,0 +1,15 @@
+use super::Tensor;
+
+pub mod arithm;
+pub mod axes;
+pub mod clamp;
+pub mod fold_axis;
+pub mod layer_norm;
+pub mod lp_normalize;
+pub mod map_axis;
+pub mod mean_pool;
+pub mod properties;
+pub mod softmax;
+pub mod sum_axis;
+pub mod transpose;
+pub mod truncate_axis;
diff --git a/encoderfile/src/transforms/tensor/ops/properties.rs b/encoderfile/src/transforms/tensor/ops/properties.rs
new file mode 100644
index 00000000..113f6a0b
--- /dev/null
+++ b/encoderfile/src/transforms/tensor/ops/properties.rs
@@ -0,0 +1,215 @@
+use super::Tensor;
+use mlua::prelude::*;
+use ndarray_stats::QuantileExt;
+
+impl Tensor {
+    #[tracing::instrument(skip_all)]
+    pub fn sum(&self) -> Result<f32, LuaError> {
+        Ok(self.0.sum())
+    }
+
+    #[tracing::instrument(skip_all)]
+    pub fn min(&self) -> Result<f32, LuaError> {
+        self.0
+            .min()
+            .copied()
+            .map_err(|e| LuaError::external(format!("Min max error: {e}")))
+    }
+
+    #[tracing::instrument(skip_all)]
+    pub fn max(&self) -> Result<f32, LuaError> {
+        self.0
+            .max()
+            .copied()
+            .map_err(|e| LuaError::external(format!("Min max error: {e}")))
+    }
+
+    #[tracing::instrument(skip_all)]
+    pub fn len(&self) -> usize {
+        self.0.len()
+    }
+
+    // Clippy's `len_without_is_empty` lint is aimed at Rust callers, but
+    // `len` is exposed here as a Lua method, so the lint's contract does
+    // not strictly apply; the allow is kept below for reference.
+ // #[allow(clippy::len_without_is_empty)] + #[tracing::instrument(skip_all)] + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + #[tracing::instrument(skip_all)] + pub fn std(&self, ddof: f32) -> Result { + Ok(self.0.std(ddof)) + } + + #[tracing::instrument(skip_all)] + pub fn mean(&self) -> Result, LuaError> { + Ok(self.0.mean()) + } + + #[tracing::instrument(skip_all)] + pub fn ndim(&self) -> Result { + Ok(self.0.ndim()) + } +} + +#[tracing::instrument(skip_all)] +pub fn is_broadcastable(a: &[usize], b: &[usize]) -> bool { + let ndim = a.len().max(b.len()); + + for i in 0..ndim { + let ad = *a.get(a.len().wrapping_sub(i + 1)).unwrap_or(&1); + let bd = *b.get(b.len().wrapping_sub(i + 1)).unwrap_or(&1); + + if ad != bd && ad != 1 && bd != 1 { + return false; + } + } + true +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_min() { + use ndarray::Array2; + + let tensor = Tensor(Array2::ones((3, 3)).into_dyn()); + assert_eq!(tensor.min().unwrap(), 1.0); + } + + #[test] + fn test_min_empty() { + let tensor = Tensor(ndarray::array![[[]]].into_dyn()); + assert!(tensor.min().is_err()) + } + + #[test] + fn test_max() { + use ndarray::Array2; + + let tensor = Tensor(Array2::ones((3, 3)).into_dyn()); + assert_eq!(tensor.max().unwrap(), 1.0); + } + + #[test] + fn test_max_empty() { + let tensor = Tensor(ndarray::array![[[]]].into_dyn()); + assert!(tensor.max().is_err()) + } + + #[test] + fn test_len() { + use crate::transforms::tensor::load_env; + use ndarray::Array2; + + let lua = load_env(); + let tensor = Tensor(Array2::zeros((3, 3)).into_dyn()); + let tensor_len = tensor.len(); + + let len = lua + .load("return function(x) return #x end") + .eval::() + .expect("Bad function") + .call::(tensor) + .expect("Function failed"); + + assert_eq!(tensor_len, len); + } + + #[test] + fn test_ndim() { + use crate::transforms::tensor::load_env; + use ndarray::Array2; + + let lua = load_env(); + let tensor = Tensor(Array2::zeros((3, 3)).into_dyn()); + + let ndim = lua + .load("return function(x) return x:ndim() end") + .eval::() + .unwrap() + .call::(tensor) + .unwrap(); + + assert_eq!(ndim, 2); + } + + #[test] + fn test_ndim_0() { + use crate::transforms::tensor::load_env; + use ndarray::Array0; + + let lua = load_env(); + let tensor = Tensor(Array0::::zeros(()).into_dyn()); + + let ndim = lua + .load("return function(x) return x:ndim() end") + .eval::() + .unwrap() + .call::(tensor) + .unwrap(); + + assert_eq!(ndim, 0); + } + + #[test] + fn test_mean() { + use ndarray::Array2; + + let tensor = Tensor(Array2::ones((3, 3)).into_dyn()); + + assert_eq!( + tensor.mean().expect("Failed to calculate mean"), + tensor.0.mean() + ); + } + + #[test] + fn test_std() { + use ndarray::Array2; + + let tensor = Tensor(Array2::ones((3, 3)).into_dyn()); + + assert_eq!( + tensor.std(1.0).expect("Failed to calculate mean"), + tensor.0.std(1.0) + ); + } + + #[test] + fn test_sum() { + use ndarray::Array2; + + let tensor = Tensor(Array2::::from_elem((3, 3), 2.0).into_dyn()); + let expected = 2.0 * 9.0; // 3x3 of 2.0 + assert_eq!(tensor.sum().unwrap(), expected); + } + + #[test] + fn test_sum_empty() { + let tensor = Tensor(ndarray::ArrayD::::zeros(vec![0])); + assert_eq!(tensor.sum().unwrap(), 0.0); + } + + #[test] + fn test_sum_with_lua_binding() { + use crate::transforms::tensor::load_env; + use ndarray::Array2; + + let lua = load_env(); + let tensor = Tensor(Array2::::from_elem((3, 3), 2.0).into_dyn()); + + let func = lua + .load("return function(x) return x:sum() end") + .eval::() + .unwrap(); 
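+        // The closure goes through the Lua method binding (x:sum())
+        // rather than calling Tensor::sum directly, exercising the
+        // userdata registration end-to-end.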
+ + let result: f32 = func.call(tensor.clone()).unwrap(); + assert_eq!(result, tensor.sum().unwrap()); + } +} diff --git a/encoderfile/src/transforms/tensor/ops/softmax.rs b/encoderfile/src/transforms/tensor/ops/softmax.rs new file mode 100644 index 00000000..20993589 --- /dev/null +++ b/encoderfile/src/transforms/tensor/ops/softmax.rs @@ -0,0 +1,90 @@ +use super::Tensor; +use mlua::prelude::*; + +impl Tensor { + #[tracing::instrument(skip_all)] + pub fn softmax(&self, axis: isize) -> Result { + let axis = self.axis1(axis)?; + + let max_vals = self.0.map_axis(axis, |row| { + row.iter().fold(f32::NEG_INFINITY, |m, &v| m.max(v)) + }); + + let z = &self.0 - &max_vals.insert_axis(axis); + + let numerator = z.mapv(|x| x.exp()); + + let denom = numerator.map_axis(axis, |row| row.sum()); + + Ok(Tensor(numerator / &denom.insert_axis(axis))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_softmax() { + use ndarray::{ArrayD, s}; + let arr: ArrayD = ndarray::array![[1.0, 2.0, 3.0], [1.0, 2.0, 3.0],].into_dyn(); + + // softmax along second axis + // remember — function is 1-indexed + + let Tensor(softmaxed) = Tensor(arr).softmax(2).expect("Failed to softmax"); + + let arr1 = softmaxed.slice(s![0, ..]); + let arr2 = softmaxed.slice(s![1, ..]); + + assert_eq!(arr1, arr2); + } + + #[test] + fn test_softmax_rows_sum_to_one() { + use ndarray::Axis; + let arr = ndarray::array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]].into_dyn(); + + // iterate over second axis in lua-land, first axis in rust land + let Tensor(sm) = Tensor(arr).softmax(2).unwrap(); + + // should iterate over 0th axis + for row in sm.axis_iter(Axis(0)) { + let sum = row.sum(); + assert!((sum - 1.0).abs() < 1e-6); + } + } + + #[test] + fn test_softmax_large_negative_values() { + use ndarray::Axis; + let arr = ndarray::array![[-1000.0, -1000.0, -1000.0]].into_dyn(); + + let Tensor(sm) = Tensor(arr).softmax(1).unwrap(); + let sum: f32 = sm.sum_axis(Axis(0))[0]; + + assert!((sum - 1.0).abs() < 1e-6); + assert!(!sm.iter().any(|x| x.is_nan())); + } + + #[test] + fn test_softmax_peaked_distribution() { + let arr = ndarray::array![[0.0, 0.0, 100.0]].into_dyn(); + + let Tensor(sm) = Tensor(arr).softmax(1).unwrap(); + + assert!(sm[[0, 2]] > 0.999); + } + + #[test] + fn test_softmax_fail() { + use ndarray::ArrayD; + let arr: ArrayD = ndarray::array![[1.0, 2.0], [4.0, 5.0],].into_dyn(); + + let ts = Tensor(arr.clone()); + + assert!(ts.softmax(-1).is_err()); + assert!(ts.softmax(0).is_err()); + assert!(ts.softmax(3).is_err()); + } +} diff --git a/encoderfile/src/transforms/tensor/ops/sum_axis.rs b/encoderfile/src/transforms/tensor/ops/sum_axis.rs new file mode 100644 index 00000000..8286f23a --- /dev/null +++ b/encoderfile/src/transforms/tensor/ops/sum_axis.rs @@ -0,0 +1,79 @@ +use super::Tensor; +use mlua::prelude::*; + +impl Tensor { + #[tracing::instrument(skip_all)] + pub fn sum_axis(&self, axis: isize) -> Result { + Ok(Self(self.0.sum_axis(self.axis1(axis)?))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sum_axis_columns() { + use ndarray::{Array2, Axis}; + let tensor = Tensor( + Array2::::from_shape_vec((2, 3), vec![1., 2., 3., 4., 5., 6.]) + .unwrap() + .into_dyn(), + ); + let result = tensor.sum_axis(2).unwrap(); + let expected = Tensor(ndarray::array![6., 15.].into_dyn()); + assert_eq!(result, expected); + + let expected = tensor.0.sum_axis(Axis(1)); + assert_eq!(result, Tensor(expected)); + } + + #[test] + fn test_sum_axis_rows() { + use ndarray::{Array2, Axis}; + let tensor = Tensor( + 
Array2::<f32>::from_shape_vec((2, 3), vec![1., 2., 3., 4., 5., 6.])
+                .unwrap()
+                .into_dyn(),
+        );
+        let result = tensor.sum_axis(1).unwrap();
+        let expected = Tensor(ndarray::array![5., 7., 9.].into_dyn());
+        assert_eq!(result, expected);
+
+        let expected = tensor.0.sum_axis(Axis(0));
+        assert_eq!(result, Tensor(expected));
+    }
+
+    #[test]
+    fn test_sum_axis_invalid() {
+        use ndarray::Array2;
+        let tensor = Tensor(
+            Array2::<f32>::from_shape_vec((2, 3), vec![1., 2., 3., 4., 5., 6.])
+                .unwrap()
+                .into_dyn(),
+        );
+        let result = tensor.sum_axis(3); // invalid axis (too large)
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_sum_axis_with_lua_binding() {
+        use crate::transforms::tensor::load_env;
+        use ndarray::Array2;
+        let lua = load_env();
+        let tensor = Tensor(
+            Array2::<f32>::from_shape_vec((2, 3), vec![1., 2., 3., 4., 5., 6.])
+                .unwrap()
+                .into_dyn(),
+        );
+
+        let func = lua
+            .load("return function(x) return x:sum_axis(2) end")
+            .eval::<LuaFunction>()
+            .unwrap();
+
+        let result: Tensor = func.call(tensor.clone()).unwrap();
+        let expected = Tensor(ndarray::array![6., 15.].into_dyn());
+        assert_eq!(result, expected);
+    }
+}
diff --git a/encoderfile/src/transforms/tensor/ops/transpose.rs b/encoderfile/src/transforms/tensor/ops/transpose.rs
new file mode 100644
index 00000000..1fb66d36
--- /dev/null
+++ b/encoderfile/src/transforms/tensor/ops/transpose.rs
@@ -0,0 +1,23 @@
+use super::Tensor;
+use mlua::prelude::*;
+
+impl Tensor {
+    #[tracing::instrument(skip_all)]
+    pub fn transpose(&self) -> Result<Tensor, LuaError> {
+        Ok(Self(self.0.t().to_owned()))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_transpose() {
+        use ndarray::ArrayD;
+        let arr: ArrayD<f32> = ndarray::array![[1.0, 2.0], [4.0, 5.0],].into_dyn();
+        let transpose = arr.t().into_owned();
+
+        assert_eq!(Tensor(arr).transpose().unwrap().0, transpose)
+    }
+}
diff --git a/encoderfile/src/transforms/tensor/ops/truncate_axis.rs b/encoderfile/src/transforms/tensor/ops/truncate_axis.rs
new file mode 100644
index 00000000..92e232be
--- /dev/null
+++ b/encoderfile/src/transforms/tensor/ops/truncate_axis.rs
@@ -0,0 +1,65 @@
+use super::Tensor;
+use mlua::prelude::*;
+use ndarray::Axis;
+
+impl Tensor {
+    #[tracing::instrument(skip_all)]
+    pub fn truncate_axis(&self, axis: isize, len: usize) -> Result<Tensor, LuaError> {
+        let axis = self.axis1(axis)?;
+
+        let actual_len = self.0.len_of(axis).min(len);
+
+        let mut slice_spec = Vec::with_capacity(self.0.ndim());
+
+        for i in 0..self.0.ndim() {
+            if Axis(i) == axis {
+                slice_spec.push(ndarray::SliceInfoElem::Slice {
+                    start: 0,
+                    end: Some(actual_len as isize),
+                    step: 1,
+                });
+            } else {
+                slice_spec.push(ndarray::SliceInfoElem::Slice {
+                    start: 0,
+                    end: None,
+                    step: 1,
+                });
+            }
+        }
+
+        Ok(Tensor(self.0.slice(&slice_spec[..]).to_owned()))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_truncate_axis_correctness() {
+        use ndarray::Array3;
+        let tensor = Tensor(Array3::from_elem([3, 3, 3], 1.0).into_dyn());
+
+        // truncate along the 2nd axis (3rd in Lua land) to length 2
+        let result = tensor
+            .truncate_axis(3, 2)
+            .expect("Failed to truncate tensor");
+        let expected = Tensor(Array3::from_elem([3, 3, 2], 1.0).into_dyn());
+
+        assert_eq!(result, expected);
+    }
+
+    #[test]
+    fn test_truncate_axis_out_of_bounds() {
+        use ndarray::Array3;
+        let tensor = Tensor(Array3::from_elem([3, 3, 3], 1.0).into_dyn());
+
+        // `len` (500) exceeds the axis length (3), so the slice is
+        // clamped to the full extent and the tensor comes back unchanged
+        let result = tensor
+            .truncate_axis(3, 500)
+            .expect("Failed to truncate tensor");
+        let expected = Tensor(Array3::from_elem([3, 3, 3], 1.0).into_dyn());
+
+        
assert_eq!(result, expected); + } +} diff --git a/encoderfile/src/transforms/tensor/tests/linalg.rs b/encoderfile/src/transforms/tensor/tests/linalg.rs deleted file mode 100644 index bd7f3c78..00000000 --- a/encoderfile/src/transforms/tensor/tests/linalg.rs +++ /dev/null @@ -1,93 +0,0 @@ -use super::Tensor; -use ndarray::{Array3, ArrayD, Axis, array, s}; - -#[test] -fn test_transpose() { - let arr: ArrayD = ndarray::array![[1.0, 2.0], [4.0, 5.0],].into_dyn(); - let transpose = arr.t().into_owned(); - - assert_eq!(Tensor(arr).transpose().unwrap().0, transpose) -} - -#[test] -fn test_softmax() { - let arr: ArrayD = ndarray::array![[1.0, 2.0, 3.0], [1.0, 2.0, 3.0],].into_dyn(); - - // softmax along second axis - // remember — function is 1-indexed - - let Tensor(softmaxed) = Tensor(arr).softmax(2).expect("Failed to softmax"); - - let arr1 = softmaxed.slice(s![0, ..]); - let arr2 = softmaxed.slice(s![1, ..]); - - assert_eq!(arr1, arr2); -} - -#[test] -fn test_softmax_rows_sum_to_one() { - let arr = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]].into_dyn(); - - // iterate over second axis in lua-land, first axis in rust land - let Tensor(sm) = Tensor(arr).softmax(2).unwrap(); - - // should iterate over 0th axis - for row in sm.axis_iter(Axis(0)) { - let sum = row.sum(); - assert!((sum - 1.0).abs() < 1e-6); - } -} - -#[test] -fn test_softmax_large_negative_values() { - let arr = array![[-1000.0, -1000.0, -1000.0]].into_dyn(); - - let Tensor(sm) = Tensor(arr).softmax(1).unwrap(); - let sum: f32 = sm.sum_axis(Axis(0))[0]; - - assert!((sum - 1.0).abs() < 1e-6); - assert!(!sm.iter().any(|x| x.is_nan())); -} - -#[test] -fn test_softmax_peaked_distribution() { - let arr = array![[0.0, 0.0, 100.0]].into_dyn(); - - let Tensor(sm) = Tensor(arr).softmax(1).unwrap(); - - assert!(sm[[0, 2]] > 0.999); -} - -#[test] -fn test_softmax_fail() { - let arr: ArrayD = ndarray::array![[1.0, 2.0], [4.0, 5.0],].into_dyn(); - - let ts = Tensor(arr.clone()); - - assert!(ts.softmax(-1).is_err()); - assert!(ts.softmax(0).is_err()); - assert!(ts.softmax(3).is_err()); -} - -#[test] -fn test_lp_norm_empty() { - let arr: ArrayD = ndarray::array![[[]]].into_dyn(); - - assert!(arr.is_empty()); - assert!(Tensor(arr).lp_normalize(1.0, 1).is_err()) -} - -#[test] -fn test_lp_norm_zero() { - let arr: ArrayD = Array3::ones((3, 3, 3)).into_dyn(); - - assert!(Tensor(arr).lp_normalize(0.0, 1).is_err()) -} - -#[test] -fn test_lp_norm_nonexistent_dim() { - let arr: ArrayD = Array3::ones((3, 3, 3)).into_dyn(); - - assert!(Tensor(arr.clone()).lp_normalize(1.0, 0).is_err()); // lua starts with 1 - assert!(Tensor(arr.clone()).lp_normalize(1.0, 4).is_err()); -} diff --git a/encoderfile/src/transforms/tensor/tests/mod.rs b/encoderfile/src/transforms/tensor/tests/mod.rs deleted file mode 100644 index 99018ced..00000000 --- a/encoderfile/src/transforms/tensor/tests/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -use super::*; - -mod linalg; -mod ops; -mod tensor; - -fn load_env() -> Lua { - Lua::new() -} diff --git a/encoderfile/src/transforms/tensor/tests/ops.rs b/encoderfile/src/transforms/tensor/tests/ops.rs deleted file mode 100644 index e490b69d..00000000 --- a/encoderfile/src/transforms/tensor/tests/ops.rs +++ /dev/null @@ -1,845 +0,0 @@ -use super::{Tensor, add, div, load_env, mul, sub}; -use mlua::prelude::*; -use ndarray::{Array0, Array2, Array3, Axis, array}; - -#[test] -fn test_layer_norm_correctness() { - let input = Tensor(ndarray::array![[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]].into_dyn()); - - let result = input - .layer_norm(2, 1e-5) - .expect("Failed 
to compute layer_norm"); - - for row in result.0.rows() { - let m = row.mean().unwrap(); - let v = row.var(0.0); - // mean should be roughly equal to 0 - assert!((m - 0.0).abs() < 1e-5); - // variance tolerances are always a bit looser, but should roughly equal 1.0 - assert!((v - 1.0).abs() < 1e-4); - } -} - -#[test] -fn test_layer_norm_epsilon_behavior() { - let input = Tensor(ndarray::array![[5.0, 5.0, 5.0]].into_dyn()); - let result = input - .layer_norm(2, 1e-5) - .expect("Failed to compute layer_norm"); - - // nothing should blow up or be NaNs - assert!(result.0.iter().all(|v| v.is_finite())); -} - -#[test] -fn test_layer_norm_dimensionality() { - let input = Tensor(Array3::from_elem([10, 10, 10], 3.0).into_dyn()); - let result = input - .layer_norm(2, 1e-5) - .expect("Failed to compute layer_norm"); - assert_eq!(input.0.dim(), result.0.dim()) -} - -#[test] -fn test_layer_norm_translation() { - // layer_norm should be invariant to additive bias per row - let input_1 = Tensor(ndarray::array![[1.0, 2.0, 3.0, 4.0, 5.0]].into_dyn()); - let input_2 = add(&input_1, LuaValue::Number(5.0)).expect("Scalar transformation failed"); - let layer_norm_1 = input_1 - .layer_norm(2, 1e-5) - .expect("Failed to compute layer_norm for input_1"); - let layer_norm_2 = input_2 - .layer_norm(2, 1e-5) - .expect("Failed to compute layer_norm for input_2"); - - for (a, b) in layer_norm_1.0.iter().zip(layer_norm_2.0.iter()) { - assert!((a - b).abs() < 1e-4, "mismatch: {a} vs {b}"); - } -} - -#[test] -fn test_truncate_axis_correctness() { - let tensor = Tensor(Array3::from_elem([3, 3, 3], 1.0).into_dyn()); - - // truncate along 2rd axis (3rd in lua land) to 2 - let result = tensor - .truncate_axis(3, 2) - .expect("Failed to truncate tensor"); - let expected = Tensor(Array3::from_elem([3, 3, 2], 1.0).into_dyn()); - - assert_eq!(result, expected); -} - -#[test] -fn test_truncate_axis_out_of_bounds() { - let tensor = Tensor(Array3::from_elem([3, 3, 3], 1.0).into_dyn()); - - // should return the same thing - let result = tensor - .truncate_axis(3, 500) - .expect("Failed to truncate tensor"); - let expected = Tensor(Array3::from_elem([3, 3, 3], 1.0).into_dyn()); - - assert_eq!(result, expected); -} - -#[test] -fn test_clamp_correctness() { - let tensor = Tensor(ndarray::array!([-5.0, -1.0, 0.0, 1.0, 5.0]).into_dyn()); - let result = tensor - .clamp(Some(-1.0), Some(1.0)) - .expect("Failed to clamp"); - let expected = Tensor(ndarray::array!([-1.0, -1.0, 0.0, 1.0, 1.0]).into_dyn()); - assert_eq!(result.0, expected.0); -} - -#[test] -fn test_clamp_lower_bound_only() { - let tensor = Tensor(ndarray::array!([-3.0, 0.0, 2.0]).into_dyn()); - let result = tensor - .clamp(Some(0.0), None) - .expect("Failed to clamp tensor"); - let expected = Tensor(ndarray::array!([0.0, 0.0, 2.0]).into_dyn()); - assert_eq!(result.0, expected.0); -} - -#[test] -fn test_clamp_upper_bound_only() { - let tensor = Tensor(ndarray::array!([-3.0, 0.0, 2.0, 5.0]).into_dyn()); - let result = tensor - .clamp(None, Some(2.0)) - .expect("Failed to clamp tensor"); - let expected = Tensor(ndarray::array!([-3.0, 0.0, 2.0, 2.0]).into_dyn()); - assert_eq!(result.0, expected.0); -} - -#[test] -fn test_clamp_infinite_bounds() { - let tensor = Tensor(ndarray::array!([-3.0, 0.0, 2.0, 5.0]).into_dyn()); - let result = tensor - .clamp(Some(f32::NEG_INFINITY), Some(f32::INFINITY)) - .expect("Failed to clamp tensor"); - let expected = Tensor(ndarray::array!([-3.0, 0.0, 2.0, 5.0]).into_dyn()); - assert_eq!(result.0, expected.0); -} - -#[test] -fn 
-#[test]
-fn test_clamp_nan() {
-    // clamping with NaN bounds nukes the entire tensor. Just so that we have no surprises later ;)
-    let tensor = Tensor(ndarray::array!([0.0, 3.0, 10.0]).into_dyn());
-    let result = tensor
-        .clamp(Some(f32::NAN), Some(f32::NAN))
-        .expect("Failed to clamp tensor");
-    let expected = Tensor(ndarray::array!([f32::NAN, f32::NAN, f32::NAN]).into_dyn());
-    for (a, b) in result.0.iter().zip(expected.0.iter()) {
-        assert!(a.is_nan() && b.is_nan());
-    }
-}
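// Sketch, not part of the recorded diff: the IEEE-754 quirk the NaN policy
// sidesteps. std's f32::max and f32::min return the non-NaN argument, so a
// NaN bound would otherwise be silently ignored instead of surfacing;
// poisoning the whole tensor makes the bad bound impossible to miss.
fn main() {
    assert_eq!(1.0_f32.max(f32::NAN), 1.0); // a NaN bound would vanish here
    assert_eq!(1.0_f32.min(f32::NAN), 1.0); // ...and here
}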
-#[test]
-fn test_min() {
-    let tensor = Tensor(Array2::ones((3, 3)).into_dyn());
-    assert_eq!(tensor.min().unwrap(), 1.0);
-}
-
-#[test]
-fn test_min_empty() {
-    let tensor = Tensor(ndarray::array![[[]]].into_dyn());
-    assert!(tensor.min().is_err())
-}
-
-#[test]
-fn test_max() {
-    let tensor = Tensor(Array2::ones((3, 3)).into_dyn());
-    assert_eq!(tensor.max().unwrap(), 1.0);
-}
-
-#[test]
-fn test_max_empty() {
-    let tensor = Tensor(ndarray::array![[[]]].into_dyn());
-    assert!(tensor.max().is_err())
-}
-
-#[test]
-fn test_exp() {
-    let arr = Array2::ones((3, 3)).into_dyn();
-    let tensor = Tensor(arr.clone());
-    assert_eq!(tensor.exp().unwrap(), Tensor(arr.mapv(f32::exp)));
-}
-
-#[test]
-fn test_exp_empty() {
-    let tensor = Tensor(ndarray::array![[[]]].into_dyn());
-    let Tensor(exp) = tensor.exp().unwrap();
-    assert!(exp.is_empty());
-}
-
-#[test]
-fn test_len() {
-    let lua = load_env();
-    let tensor = Tensor(Array2::zeros((3, 3)).into_dyn());
-    let tensor_len = tensor.len();
-
-    let len = lua
-        .load("return function(x) return #x end")
-        .eval::<LuaFunction>()
-        .expect("Bad function")
-        .call::<usize>(tensor)
-        .expect("Function failed");
-
-    assert_eq!(tensor_len, len);
-}
-
-#[test]
-fn test_ndim() {
-    let lua = load_env();
-    let tensor = Tensor(Array2::zeros((3, 3)).into_dyn());
-
-    let ndim = lua
-        .load("return function(x) return x:ndim() end")
-        .eval::<LuaFunction>()
-        .unwrap()
-        .call::<usize>(tensor)
-        .unwrap();
-
-    assert_eq!(ndim, 2);
-}
-
-#[test]
-fn test_ndim_0() {
-    let lua = load_env();
-    let tensor = Tensor(Array0::<f32>::zeros(()).into_dyn());
-
-    let ndim = lua
-        .load("return function(x) return x:ndim() end")
-        .eval::<LuaFunction>()
-        .unwrap()
-        .call::<usize>(tensor)
-        .unwrap();
-
-    assert_eq!(ndim, 0);
-}
-
-macro_rules! generate_ops_test {
-    ($mod_name:ident, $op:tt, $rust_fn:ident, $lua_op:expr) => {
-        mod $mod_name {
-            use super::*;
-
-            #[test]
-            fn test_binding() {
-                let lua = load_env();
-                let arr1 = Tensor(Array2::<f32>::ones((3, 3)).into_dyn());
-                let arr2 = arr1.clone();
-
-                let gold_val = $rust_fn(
-                    &arr1,
-                    LuaValue::UserData(lua.create_userdata(arr2.clone()).unwrap())
-                ).expect("Failed to compute");
-
-                let result: Tensor = lua.load(format!("return function(x, y) return x {} y end", $lua_op))
-                    .eval::<LuaFunction>()
-                    .unwrap()
-                    .call((arr1, arr2))
-                    .expect("Binding failed");
-
-                assert_eq!(result, gold_val);
-            }
-
-            #[test]
-            fn test_tensor() {
-                let lua = load_env();
-                let arr1 = Tensor(Array2::<f32>::ones((3, 3)).into_dyn());
-                let arr2 = arr1.clone();
-
-                let val = LuaValue::UserData(lua.create_userdata(arr1.clone()).unwrap());
-                let result = $rust_fn(&arr1, val).unwrap();
-
-                let gold = &arr1.0 $op &arr2.0;
-
-                assert_eq!(gold, result.0);
-            }
-
-            #[test]
-            fn test_number() {
-                let arr1 = Tensor(Array2::<f32>::ones((3, 3)).into_dyn());
-
-                let gold_sum = &arr1.0 $op Array2::<f32>::from_elem((3, 3), 5.0);
-
-                let result = $rust_fn(&arr1, LuaValue::Number(5.0)).unwrap();
-
-                assert_eq!(gold_sum, result.0);
-            }
-
-            #[test]
-            fn test_integer() {
-                let arr1 = Tensor(Array2::<f32>::ones((3, 3)).into_dyn());
-
-                let gold_sum = &arr1.0 $op Array2::<f32>::from_elem((3, 3), 5.0);
-
-                let result = $rust_fn(&arr1, LuaValue::Integer(5)).unwrap();
-
-                assert_eq!(gold_sum, result.0);
-            }
-
-            #[test]
-            fn test_bad_dtype() {
-                let arr1 = Tensor(Array2::<f32>::ones((3, 3)).into_dyn());
-
-                let result: Result<Tensor, LuaError> = $rust_fn(&arr1, LuaValue::Boolean(false));
-
-                assert!(result.is_err());
-            }
-        }
-    }
-}
-
-generate_ops_test!(
-    test_addition, +, add, "+"
-);
-
-generate_ops_test!(
-    test_subtraction, -, sub, "-"
-);
-
-generate_ops_test!(
-    test_multiplication, *, mul, "*"
-);
-
-generate_ops_test!(
-    test_division, /, div, "/"
-);
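// Sketch, not part of the recorded diff: the identity the generated
// test_number/test_integer cases lean on, restated in plain ndarray terms:
// combining with a scalar matches combining with a same-shaped constant array.
use ndarray::Array2;

fn main() {
    let a = Array2::<f32>::ones((3, 3));
    assert_eq!(&a + 5.0, &a + &Array2::<f32>::from_elem((3, 3), 5.0));
    assert_eq!(&a * 5.0, &a * &Array2::<f32>::from_elem((3, 3), 5.0));
}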
-#[test]
-fn test_eq_simple() {
-    let lua = load_env();
-
-    let arr1 = Tensor(Array2::<f32>::ones((3, 3)).into_dyn());
-    let arr2 = arr1.clone();
-
-    assert!(arr1 == arr2);
-
-    let result: bool = lua
-        .load("return function(x, y) return x == y end")
-        .eval::<LuaFunction>()
-        .unwrap()
-        .call((arr1, arr2))
-        .expect("Failed to evaluate");
-
-    assert!(result);
-}
-
-#[test]
-fn test_neq_simple() {
-    let lua = load_env();
-
-    let arr1 = Tensor(Array2::<f32>::ones((3, 3)).into_dyn());
-    let arr2 = Tensor(Array2::<f32>::zeros((3, 3)).into_dyn());
-
-    assert!(arr1 != arr2);
-
-    let result: bool = lua
-        .load("return function(x, y) return x == y end")
-        .eval::<LuaFunction>()
-        .unwrap()
-        .call((arr1, arr2))
-        .expect("Failed to evaluate");
-
-    assert!(!result);
-}
-
-#[test]
-fn test_to_string() {
-    let lua = load_env();
-
-    let vec = Tensor(Array2::<f32>::ones((3, 3)).into_dyn());
-    let vec_str_gold = vec.0.to_string();
-
-    let vec_str: String = lua
-        .globals()
-        .get::<LuaFunction>("tostring")
-        .unwrap()
-        .call(vec)
-        .unwrap();
-
-    assert_eq!(vec_str, vec_str_gold);
-}
-
-#[test]
-fn test_mean() {
-    let tensor = Tensor(Array2::ones((3, 3)).into_dyn());
-
-    assert_eq!(
-        tensor.mean().expect("Failed to calculate mean"),
-        tensor.0.mean()
-    );
-}
-
-#[test]
-fn test_std() {
-    let tensor = Tensor(Array2::ones((3, 3)).into_dyn());
-
-    assert_eq!(
-        tensor.std(1.0).expect("Failed to calculate std"),
-        tensor.0.std(1.0)
-    );
-}
-
-#[test]
-fn test_sum() {
-    let tensor = Tensor(Array2::<f32>::from_elem((3, 3), 2.0).into_dyn());
-    let expected = 2.0 * 9.0; // 3x3 of 2.0
-    assert_eq!(tensor.sum().unwrap(), expected);
-}
-
-#[test]
-fn test_sum_empty() {
-    let tensor = Tensor(ndarray::ArrayD::<f32>::zeros(vec![0]));
-    assert_eq!(tensor.sum().unwrap(), 0.0);
-}
-
-#[test]
-fn test_sum_axis_columns() {
-    let tensor = Tensor(
-        Array2::<f32>::from_shape_vec((2, 3), vec![1., 2., 3., 4., 5., 6.])
-            .unwrap()
-            .into_dyn(),
-    );
-    let result = tensor.sum_axis(2).unwrap();
-    let expected = Tensor(ndarray::array![6., 15.].into_dyn());
-    assert_eq!(result, expected);
-
-    let expected = tensor.0.sum_axis(Axis(1));
-    assert_eq!(result, Tensor(expected));
-}
-
-#[test]
-fn test_sum_axis_rows() {
-    let tensor = Tensor(
-        Array2::<f32>::from_shape_vec((2, 3), vec![1., 2., 3., 4., 5., 6.])
-            .unwrap()
-            .into_dyn(),
-    );
-    let result = tensor.sum_axis(1).unwrap();
-    let expected = Tensor(ndarray::array![5., 7., 9.].into_dyn());
-    assert_eq!(result, expected);
-
-    let expected = tensor.0.sum_axis(Axis(0));
-    assert_eq!(result, Tensor(expected));
-}
-
-#[test]
-fn test_sum_axis_invalid() {
-    let tensor = Tensor(
-        Array2::<f32>::from_shape_vec((2, 3), vec![1., 2., 3., 4., 5., 6.])
-            .unwrap()
-            .into_dyn(),
-    );
-    let result = tensor.sum_axis(3); // invalid axis (too large)
-    assert!(result.is_err());
-}
-
-#[test]
-fn test_sum_axis_with_lua_binding() {
-    let lua = load_env();
-    let tensor = Tensor(
-        Array2::<f32>::from_shape_vec((2, 3), vec![1., 2., 3., 4., 5., 6.])
-            .unwrap()
-            .into_dyn(),
-    );
-
-    let func = lua
-        .load("return function(x) return x:sum_axis(2) end")
-        .eval::<LuaFunction>()
-        .unwrap();
-
-    let result: Tensor = func.call(tensor.clone()).unwrap();
-    let expected = Tensor(ndarray::array![6., 15.].into_dyn());
-    assert_eq!(result, expected);
-}
-
-#[test]
-fn test_sum_with_lua_binding() {
-    let lua = load_env();
-    let tensor = Tensor(Array2::<f32>::from_elem((3, 3), 2.0).into_dyn());
-
-    let func = lua
-        .load("return function(x) return x:sum() end")
-        .eval::<LuaFunction>()
-        .unwrap();
-
-    let result: f32 = func.call(tensor.clone()).unwrap();
-    assert_eq!(result, tensor.sum().unwrap());
-}
-
-#[test]
-fn test_map_axis_zero_transform() {
-    let lua = load_env();
-    let tensor = Tensor(Array3::<f32>::from_elem((3, 6, 9), 1.0).into_dyn());
-
-    let func = lua
-        .load("return function(x) return x end")
-        .eval::<LuaFunction>()
-        .unwrap();
-
-    let result = tensor.map_axis(3, func).expect("Failed to map axis");
-
-    assert_eq!(tensor, result);
-}
-
-#[test]
-fn test_map_axis_double_values() {
-    let lua = load_env();
-    let tensor =
-        Tensor(Array3::<f32>::from_shape_fn((2, 2, 2), |(i, j, k)| (i + j + k) as f32).into_dyn());
-
-    let func = lua
-        .load("return function(x) return x * 2 end")
-        .eval::<LuaFunction>()
-        .unwrap();
-
-    let result = tensor.map_axis(3, func).expect("Failed to map axis");
-
-    assert_eq!(result.0, tensor.0 * 2.0);
-}
-
-#[test]
-fn fold_axis_sum_rows() -> LuaResult<()> {
-    let lua = load_env();
-    let arr = ndarray::array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]].into_dyn();
-    let t = Tensor(arr);
-
-    let func: LuaFunction = lua
-        .load(
-            r#"
-            return function(acc, x) return acc + x end
-            "#,
-        )
-        .eval()?;
-
-    let res = t.fold_axis(1, 0.0, func)?; // fold each row
-    let v = res.0.into_dimensionality::<ndarray::Ix1>().unwrap();
-
-    assert_eq!(v.as_slice().unwrap(), &[6.0, 15.0]);
-    Ok(())
-}
-
-#[test]
-fn fold_axis_product() -> LuaResult<()> {
-    let lua = Lua::new();
-    let arr = ndarray::array![[1.0, 2.0], [3.0, 4.0]].into_dyn();
-    let t = Tensor(arr);
-
-    let func: LuaFunction = lua
-        .load(
-            r#"
-            return function(acc, x) return acc * x end
-            "#,
-        )
-        .eval()?;
-
-    let res = t.fold_axis(1, 1.0, func)?; // multiply across each row
-    let v = res.0.into_dimensionality::<ndarray::Ix1>().unwrap();
-
-    assert_eq!(v.as_slice().unwrap(), &[2.0, 12.0]);
-    Ok(())
-}
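// Sketch, not part of the recorded diff: what fold_axis(1, init, f) computes,
// restated with a plain Rust closure standing in for the Lua callback. Lua
// axis 1 is Rust Axis(0), so each row is folded left to right from init.
use ndarray::{Axis, array};

fn main() {
    let arr = array![[1.0_f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
    let sums: Vec<f32> = arr
        .axis_iter(Axis(0))
        .map(|row| row.iter().fold(0.0, |acc, &x| acc + x))
        .collect();
    assert_eq!(sums, vec![6.0, 15.0]); // matches fold_axis_sum_rows above
}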
-fn tensor(data: Vec<f32>, shape: &[usize]) -> Tensor {
-    Tensor(ndarray::ArrayD::from_shape_vec(shape, data).unwrap())
-}
-
-fn lua_number(n: f64) -> LuaValue {
-    LuaValue::Number(n)
-}
-
-fn lua_tensor(t: Tensor, lua: &Lua) -> LuaValue {
-    mlua::Value::UserData(lua.create_userdata(t).unwrap())
-}
-
-#[test]
-fn test_add_broadcast_success() {
-    let lua = Lua::new();
-
-    // (2, 3) + (3,) → OK via broadcasting
-    let a = tensor(vec![1., 2., 3., 4., 5., 6.], &[2, 3]);
-    let b = tensor(vec![10., 20., 30.], &[3]);
-
-    let res = add(&a, lua_tensor(b, &lua)).unwrap();
-    assert_eq!(
-        res.0,
-        ndarray::arr2(&[[11., 22., 33.], [14., 25., 36.]]).into_dyn()
-    );
-}
-
-#[test]
-fn test_add_broadcast_failure() {
-    let lua = Lua::new();
-
-    // (2, 3) + (2,) → NOT broadcastable because the trailing dims mismatch
-    let a = tensor(vec![1., 2., 3., 4., 5., 6.], &[2, 3]);
-    let b = tensor(vec![1., 2.], &[2]);
-
-    let err = add(&a, lua_tensor(b, &lua)).unwrap_err();
-    let msg = format!("{err}");
-    assert!(msg.contains("not broadcastable"), "Got: {msg}");
-}
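// Sketch, not part of the recorded diff: the NumPy-style rule behind the
// "not broadcastable" errors tested here. Shapes are aligned from the
// trailing dimension, and each aligned pair must be equal or 1; ndarray's
// broadcast() makes both outcomes above easy to see.
use ndarray::Array1;

fn main() {
    // (3,) against (2, 3): trailing dims match, so the view exists
    assert!(Array1::<f32>::zeros(3).broadcast((2, 3)).is_some());
    // (2,) against (2, 3): the trailing pair is 2 vs 3, so it does not
    assert!(Array1::<f32>::zeros(2).broadcast((2, 3)).is_none());
}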
-#[test]
-fn test_sub_broadcast_success() {
-    let lua = Lua::new();
-
-    // (3, 1) - (3,) → OK (result is (3,3))
-    let a = tensor(vec![1., 2., 3.], &[3, 1]);
-    let b = tensor(vec![1., 10., 100.], &[3]);
-
-    let res = sub(&a, lua_tensor(b, &lua)).unwrap();
-    assert_eq!(
-        res.0,
-        ndarray::arr2(&[[0., -9., -99.], [1., -8., -98.], [2., -7., -97.]]).into_dyn()
-    );
-}
-
-#[test]
-fn test_sub_broadcast_failure() {
-    let lua = Lua::new();
-
-    // (3,2) - (3,) → failure: trailing dim (2 vs 3)
-    let a = tensor(vec![1., 2., 3., 4., 5., 6.], &[3, 2]);
-    let b = tensor(vec![1., 2., 3.], &[3]);
-
-    let err = sub(&a, lua_tensor(b, &lua)).unwrap_err();
-    assert!(format!("{err}").contains("not broadcastable"));
-}
-
-#[test]
-fn test_mul_broadcast_success() {
-    // (2,3) * scalar → always OK
-    let a = tensor(vec![1., 2., 3., 4., 5., 6.], &[2, 3]);
-    let res = mul(&a, lua_number(2.0)).unwrap();
-
-    assert_eq!(
-        res.0,
-        ndarray::arr2(&[[2., 4., 6.], [8., 10., 12.]]).into_dyn()
-    );
-}
-
-#[test]
-fn test_mul_broadcast_shape_success() {
-    let lua = Lua::new();
-
-    // (4,1) * (1,3) → (4,3)
-    let a = tensor(vec![1., 2., 3., 4.], &[4, 1]);
-    let b = tensor(vec![10., 20., 30.], &[1, 3]);
-
-    let res = mul(&a, lua_tensor(b, &lua)).unwrap();
-
-    assert_eq!(
-        res.0,
-        ndarray::arr2(&[
-            [10., 20., 30.],
-            [20., 40., 60.],
-            [30., 60., 90.],
-            [40., 80., 120.]
-        ])
-        .into_dyn()
-    );
-}
-
-#[test]
-fn test_mul_broadcast_fail() {
-    let lua = Lua::new();
-
-    // (2,2) * (3,) → cannot broadcast trailing dims
-    let a = tensor(vec![1., 2., 3., 4.], &[2, 2]);
-    let b = tensor(vec![1., 2., 3.], &[3]);
-
-    let err = mul(&a, lua_tensor(b, &lua)).unwrap_err();
-    assert!(format!("{err}").contains("not broadcastable"));
-}
-
-#[test]
-fn test_div_broadcast_success() {
-    let lua = Lua::new();
-
-    // (3,3) / (3,) → OK
-    let a = tensor((1..=9).map(|x| x as f32).collect(), &[3, 3]);
-    let b = tensor(vec![1., 2., 3.], &[3]);
-
-    let res = div(&a, lua_tensor(b, &lua)).unwrap();
-
-    assert_eq!(
-        res.0,
-        ndarray::arr2(&[
-            [1.0 / 1., 2.0 / 2., 3.0 / 3.],
-            [4.0 / 1., 5.0 / 2., 6.0 / 3.],
-            [7.0 / 1., 8.0 / 2., 9.0 / 3.],
-        ])
-        .into_dyn()
-    );
-}
-
-#[test]
-fn test_div_broadcast_fail() {
-    let lua = Lua::new();
-
-    // (2,3) vs (2,) again → not broadcastable
-    let a = tensor(vec![1., 2., 3., 4., 5., 6.], &[2, 3]);
-    let b = tensor(vec![1., 2.], &[2]);
-
-    let err = div(&a, lua_tensor(b, &lua)).unwrap_err();
-    assert!(format!("{err}").contains("not broadcastable"));
-}
-
-#[test]
-fn mean_pool_single_vector_no_mask() {
-    // shape: (batch=1, seq=1, dim=3)
-    let x = Tensor(array![[[1.0, 2.0, 3.0]]].into_dyn());
-    let mask = Tensor(array![[1.0]].into_dyn());
-
-    let pooled = x.mean_pool(mask).unwrap();
-    assert_eq!(pooled.0, array![[1.0, 2.0, 3.0]].into_dyn());
-}
-
-#[test]
-fn mean_pool_two_tokens_equal_weight() {
-    // shape: (1, 2, 3)
-    let x = Tensor(array![[[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]]].into_dyn());
-
-    let mask = Tensor(array![[1.0, 1.0]].into_dyn());
-
-    let pooled = x.mean_pool(mask).unwrap();
-    let expected = array![[2.0, 2.0, 2.0]].into_dyn();
-
-    assert_allclose(&pooled.0, &expected);
-}
-
-#[test]
-fn mean_pool_ignores_masked_tokens() {
-    // shape: (1, 3, 2)
-    // Only the first and last token should count.
-    let x = Tensor(
-        array![[
-            [10.0, 0.0],
-            [99.0, 99.0], // masked out
-            [20.0, 0.0]
-        ]]
-        .into_dyn(),
-    );
-
-    let mask = Tensor(array![[1.0, 0.0, 1.0]].into_dyn());
-
-    let pooled = x.mean_pool(mask).unwrap();
-    let expected = array![[(10.0 + 20.0) / 2.0, 0.0]].into_dyn();
-
-    assert_allclose(&pooled.0, &expected);
-}
-
-#[test]
-fn mean_pool_batch_mode() {
-    // shape: (2, 2, 2)
-    let x = Tensor(
-        array![
-            [[1.0, 1.0], [3.0, 3.0]], // batch 0
-            [[2.0, 4.0], [4.0, 2.0]], // batch 1
-        ]
-        .into_dyn(),
-    );
-
-    let mask = Tensor(array![[1.0, 1.0], [1.0, 0.0],].into_dyn());
-
-    let pooled = x.mean_pool(mask).unwrap();
-
-    let expected = array![[(1.0 + 3.0) / 2.0, (1.0 + 3.0) / 2.0], [2.0, 4.0]].into_dyn();
-
-    assert_allclose(&pooled.0, &expected);
-}
-
-#[test]
-fn mean_pool_mask_broadcasting() {
-    let x = Tensor(
-        array![[
-            [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],
-            [[4.0, 4.0], [5.0, 5.0], [6.0, 6.0]]
-        ]]
-        .into_dyn(),
-    );
-
-    let mask = Tensor(array![[[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]].into_dyn());
-
-    let pooled = x.mean_pool(mask).unwrap();
-
-    // Compute manually:
-    // First inner seq: avg of [1,2] and [4,5].
-    // The second inner seq isn't separate; everything is reduced together.
-    //
-    // Values included (mask = 1):
-    //   1.0, 2.0, 4.0, 5.0
-    // and the same duplicated for the second feature.
-    let expected = array![[3.0, 3.0]].into_dyn(); // (1,2)
-
-    assert_allclose(&pooled.0, &expected);
-}
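// Sketch, not part of the recorded diff: the masked average these tests pin
// down, written out by hand for the mean_pool_ignores_masked_tokens case:
// sum(mask * x) / sum(mask) over the sequence axis.
use ndarray::{Axis, array};

fn main() {
    let x = array![[10.0_f32, 0.0], [99.0, 99.0], [20.0, 0.0]]; // (seq, dim)
    let mask = array![1.0_f32, 0.0, 1.0]; // (seq,)
    let mask_col = mask.clone().insert_axis(Axis(1)); // (seq, 1), broadcasts over dim
    let pooled = (&x * &mask_col).sum_axis(Axis(0)) / mask.sum();
    assert_eq!(pooled, array![15.0, 0.0]); // the masked 99s never contribute
}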
-fn assert_allclose(a: &ndarray::ArrayD<f32>, b: &ndarray::ArrayD<f32>) {
-    let tol = 1e-6;
-    assert_eq!(
-        a.shape(),
-        b.shape(),
-        "shape mismatch: {:?} vs {:?}",
-        a.shape(),
-        b.shape()
-    );
-    let a_slice = a.as_slice().unwrap();
-    let b_slice = b.as_slice().unwrap();
-    for (i, (x, y)) in a_slice.iter().zip(b_slice.iter()).enumerate() {
-        let diff = (x - y).abs();
-        assert!(
-            diff <= tol,
-            "mismatch at index {i}: {x} vs {y} (diff {diff})"
-        );
-    }
-}
diff --git a/encoderfile/src/transforms/tensor/tests/tensor.rs b/encoderfile/src/transforms/tensor/tests/tensor.rs
deleted file mode 100644
index a48adb95..00000000
--- a/encoderfile/src/transforms/tensor/tests/tensor.rs
+++ /dev/null
@@ -1,64 +0,0 @@
-use super::*;
-
-#[test]
-fn test_from_lua_create_table() {
-    let lua = load_env();
-
-    let tbl: LuaTable = lua
-        .load("return {{1, 1, 1}, {1, 1, 1}, {1, 1, 1}}")
-        .eval()
-        .unwrap();
-
-    let tensor = Tensor::from_lua(LuaValue::Table(tbl), &lua).expect("Failed to create tensor");
-
-    assert_eq!(tensor.0.ndim(), 2);
-    assert_eq!(tensor.0.shape(), [3, 3]);
-}
-
-#[test]
-fn test_from_lua_empty_table() {
-    let lua = load_env();
-
-    let tbl: LuaTable = lua.load("return {}").eval().unwrap();
-
-    let Tensor(tensor) = Tensor::from_lua(LuaValue::Table(tbl), &lua).unwrap();
-
-    assert!(tensor.is_empty());
-    assert_eq!(tensor.ndim(), 1);
-}
-
-#[test]
-fn test_from_lua_ragged() {
-    let lua = load_env();
-
-    let tbl: LuaTable = lua
-        .load("return {{1, 1, 1}, {1, 1, 1}, {1, 1}}")
-        .eval()
-        .unwrap();
-
-    let tensor = Tensor::from_lua(LuaValue::Table(tbl), &lua);
-
-    assert!(tensor.is_err());
-}
-
-#[test]
-fn test_from_lua_bad_type() {
-    let lua = load_env();
-
-    let tbl: LuaString = lua.load("return \"i am not a table\"").eval().unwrap();
-
-    let tensor = Tensor::from_lua(LuaValue::String(tbl), &lua);
-
-    assert!(tensor.is_err());
-}
-
-#[test]
-fn test_from_lua_bad_type_err() {
-    let lua = load_env();
-
-    let val = LuaValue::Boolean(false);
-
-    let tensor = Tensor::from_lua(val, &lua);
-
-    assert!(tensor.is_err());
-}
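// Sketch, not part of the recorded diff: the rectangularity check the ragged
// test above pins down, imitated for the 2-D case with plain mlua tables.
// This assumes mlua's Table::sequence_values; a real conversion would also
// recurse for deeper nesting.
use mlua::prelude::*;

fn main() -> LuaResult<()> {
    let lua = Lua::new();
    let tbl: LuaTable = lua.load("return {{1, 1, 1}, {1, 1}}").eval()?;

    // pull each row out as a Vec<f32>
    let rows = tbl
        .sequence_values::<LuaTable>()
        .map(|row| -> LuaResult<Vec<f32>> { row?.sequence_values::<f32>().collect() })
        .collect::<LuaResult<Vec<_>>>()?;

    // ragged rows mean there is no rectangular shape to infer,
    // so a Tensor conversion has to reject the table
    let ragged = rows.windows(2).any(|w| w[0].len() != w[1].len());
    assert!(ragged);
    Ok(())
}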