diff --git a/docs/reference/building.md b/docs/reference/building.md index 8d350ca7..048f86b8 100644 --- a/docs/reference/building.md +++ b/docs/reference/building.md @@ -352,6 +352,35 @@ encoderfile: transform: "return lp_normalize(output)" ``` +By default, the `table`, `string` and `math` libraries are enabled if the `lua_libs` property is not present. This property allows you to specify a different set of libraries as strings, chosen from: + +* `coroutine` +* `table` +* `io` +* `os` +* `string` +* `utf8` +* `math` +* `package` + +Note that if this property is present, no libraries are loaded by default, so every library the transform uses must be listed explicitly. + +**Inline transform:** +```yaml +encoderfile: + name: my-model + path: ./models/my-model + model_type: embedding + lua_libs: + - table + - string + - math + - os + transform: | + t = os.time() + return lp_normalize(output) +``` + ### Custom Cache Directory Specify a custom cache location: diff --git a/encoderfile/benches/benchmark_transforms.rs b/encoderfile/benches/benchmark_transforms.rs index d3f966dd..378949a8 100644 --- a/encoderfile/benches/benchmark_transforms.rs +++ b/encoderfile/benches/benchmark_transforms.rs @@ -1,4 +1,4 @@ -use encoderfile::transforms::Postprocessor; +use encoderfile::transforms::{DEFAULT_LIBS, Postprocessor}; use ndarray::{Array2, Array3}; use rand::Rng; @@ -22,9 +22,10 @@ fn get_random_3d(x: usize, y: usize, z: usize) -> Array3 { #[divan::bench(args = [(16, 16, 16), (32, 128, 384), (32, 256, 768)])] fn bench_embedding_l2_normalization(bencher: divan::Bencher, (x, y, z): (usize, usize, usize)) { - let engine = encoderfile::transforms::EmbeddingTransform::new(Some( - include_str!("../../transforms/embedding/l2_normalize_embeddings.lua").to_string(), - )) + let engine = encoderfile::transforms::EmbeddingTransform::new( + DEFAULT_LIBS.to_vec(), + Some(include_str!("../../transforms/embedding/l2_normalize_embeddings.lua").to_string()), + ) .unwrap(); let test_tensor = get_random_3d(x, y, z); @@ -36,9 
+37,12 @@ fn bench_embedding_l2_normalization(bencher: divan::Bencher, (x, y, z): (usize, #[divan::bench(args = [(16, 2), (32, 8), (128, 32)])] fn bench_seq_cls_softmax(bencher: divan::Bencher, (x, y): (usize, usize)) { - let engine = encoderfile::transforms::SequenceClassificationTransform::new(Some( - include_str!("../../transforms/sequence_classification/softmax_logits.lua").to_string(), - )) + let engine = encoderfile::transforms::SequenceClassificationTransform::new( + DEFAULT_LIBS.to_vec(), + Some( + include_str!("../../transforms/sequence_classification/softmax_logits.lua").to_string(), + ), + ) .unwrap(); let test_tensor = get_random_2d(x, y); @@ -50,9 +54,10 @@ fn bench_seq_cls_softmax(bencher: divan::Bencher, (x, y): (usize, usize)) { #[divan::bench(args = [(16, 16, 2), (32, 128, 8), (128, 256, 32)])] fn bench_tok_cls_softmax(bencher: divan::Bencher, (x, y, z): (usize, usize, usize)) { - let engine = encoderfile::transforms::TokenClassificationTransform::new(Some( - include_str!("../../transforms/token_classification/softmax_logits.lua").to_string(), - )) + let engine = encoderfile::transforms::TokenClassificationTransform::new( + DEFAULT_LIBS.to_vec(), + Some(include_str!("../../transforms/token_classification/softmax_logits.lua").to_string()), + ) .unwrap(); let test_tensor = get_random_3d(x, y, z); diff --git a/encoderfile/proto/manifest.proto b/encoderfile/proto/manifest.proto index 8c31af3e..77f15c8b 100644 --- a/encoderfile/proto/manifest.proto +++ b/encoderfile/proto/manifest.proto @@ -60,6 +60,10 @@ message EncoderfileManifest { optional Artifact tokenizer = 130; } +message LuaLibs { + repeated string libs = 1; +} + // Transform describes embedded preprocessing/postprocessing steps applied to inputs // prior to inference. // @@ -72,6 +76,8 @@ message Transform { // Transform source code. // Interpretation is defined by the TransformType. 
string transform = 2; + + optional LuaLibs lua_libs = 3; } // Artifact describes a contiguous byte range within the embedded payload. diff --git a/encoderfile/src/build_cli/config.rs b/encoderfile/src/build_cli/config.rs index a004b68c..987dc512 100644 --- a/encoderfile/src/build_cli/config.rs +++ b/encoderfile/src/build_cli/config.rs @@ -1,4 +1,4 @@ -use crate::common::{Config as EmbeddedConfig, ModelConfig, ModelType}; +use crate::common::{Config as EmbeddedConfig, LuaLibs, ModelConfig, ModelType}; use anyhow::{Context, Result, bail}; use schemars::JsonSchema; use std::{ @@ -41,6 +41,7 @@ pub struct EncoderfileConfig { pub cache_dir: Option, pub base_binary_path: Option, pub transform: Option, + pub lua_libs: Option>, pub tokenizer: Option, #[serde(default = "default_validate_transform")] pub validate_transform: bool, @@ -61,6 +62,7 @@ impl EncoderfileConfig { version: self.version.clone(), model_type: self.model_type.clone(), transform: self.transform()?, + lua_libs: None, }; Ok(config) @@ -104,6 +106,11 @@ impl EncoderfileConfig { Ok(transform) } + pub fn lua_libs(&self) -> Result> { + let configlibs = &self.lua_libs.clone().map(LuaLibs::try_from).transpose()?; + Ok(*configlibs) + } + pub fn get_generated_dir(&self) -> PathBuf { let filename_hash = Sha256::digest(self.name.as_bytes()); @@ -356,6 +363,7 @@ mod tests { cache_dir: Some(base.clone()), validate_transform: false, transform: None, + lua_libs: None, tokenizer: None, base_binary_path: None, target: None, diff --git a/encoderfile/src/build_cli/tokenizer.rs b/encoderfile/src/build_cli/tokenizer.rs index 6b755676..9e79993e 100644 --- a/encoderfile/src/build_cli/tokenizer.rs +++ b/encoderfile/src/build_cli/tokenizer.rs @@ -262,6 +262,7 @@ mod tests { output_path: None, cache_dir: None, transform: None, + lua_libs: None, tokenizer: None, validate_transform: false, base_binary_path: None, @@ -301,6 +302,7 @@ mod tests { output_path: None, cache_dir: None, transform: None, + lua_libs: None, tokenizer: 
Some(TokenizerBuildConfig { pad_strategy: Some(TokenizerPadStrategy::Fixed { fixed: 512 }), }), @@ -351,6 +353,7 @@ mod tests { output_path: None, cache_dir: None, transform: None, + lua_libs: None, tokenizer: None, validate_transform: false, base_binary_path: None, diff --git a/encoderfile/src/build_cli/transforms/validation/embedding.rs b/encoderfile/src/build_cli/transforms/validation/embedding.rs index f63d400f..070f9880 100644 --- a/encoderfile/src/build_cli/transforms/validation/embedding.rs +++ b/encoderfile/src/build_cli/transforms/validation/embedding.rs @@ -57,6 +57,7 @@ impl TransformValidatorExt for EmbeddingTransform { mod tests { use crate::build_cli::config::{EncoderfileConfig, ModelPath}; use crate::common::ModelType; + use crate::transforms::DEFAULT_LIBS; use super::*; @@ -69,6 +70,7 @@ mod tests { cache_dir: None, output_path: None, transform: None, + lua_libs: None, validate_transform: true, tokenizer: None, base_binary_path: None, @@ -87,10 +89,13 @@ mod tests { let encoderfile_config = test_encoderfile_config(); let model_config = test_model_config(); - EmbeddingTransform::new(Some("function Postprocess(arr) return arr end".to_string())) - .expect("Failed to create transform") - .validate(&encoderfile_config, &model_config) - .expect("Failed to validate"); + EmbeddingTransform::new( + DEFAULT_LIBS.to_vec(), + Some("function Postprocess(arr) return arr end".to_string()), + ) + .expect("Failed to create transform") + .validate(&encoderfile_config, &model_config) + .expect("Failed to validate"); } #[test] @@ -98,10 +103,12 @@ mod tests { let encoderfile_config = test_encoderfile_config(); let model_config = test_model_config(); - let result = - EmbeddingTransform::new(Some("function Postprocess(arr) return 1 end".to_string())) - .expect("Failed to create transform") - .validate(&encoderfile_config, &model_config); + let result = EmbeddingTransform::new( + DEFAULT_LIBS.to_vec(), + Some("function Postprocess(arr) return 1 end".to_string()), + ) + 
.expect("Failed to create transform") + .validate(&encoderfile_config, &model_config); assert!(result.is_err()); } @@ -111,9 +118,10 @@ mod tests { let encoderfile_config = test_encoderfile_config(); let model_config = test_model_config(); - let result = EmbeddingTransform::new(Some( - "function Postprocess(arr) return arr:sum_axis(1) end".to_string(), - )) + let result = EmbeddingTransform::new( + DEFAULT_LIBS.to_vec(), + Some("function Postprocess(arr) return arr:sum_axis(1) end".to_string()), + ) .expect("Failed to create transform") .validate(&encoderfile_config, &model_config); diff --git a/encoderfile/src/build_cli/transforms/validation/mod.rs b/encoderfile/src/build_cli/transforms/validation/mod.rs index eb494f17..97c16575 100644 --- a/encoderfile/src/build_cli/transforms/validation/mod.rs +++ b/encoderfile/src/build_cli/transforms/validation/mod.rs @@ -1,7 +1,8 @@ use crate::{ common::{ModelConfig, ModelType}, format::assets::{AssetKind, AssetSource, PlannedAsset}, - transforms::TransformSpec, + generated::manifest::LuaLibs as ManifestLuaLibs, + transforms::{TransformSpec, convert_libs}, }; use anyhow::{Context, Result}; @@ -41,9 +42,12 @@ pub trait TransformValidatorExt: TransformSpec { macro_rules! validate_transform { ($transform_type:ident, $transform_str:expr, $encoderfile_config:expr, $model_config:expr) => { - crate::transforms::$transform_type::new(Some($transform_str.clone())) - .with_context(|| utils::validation_err_ctx("Failed to create transform"))? - .validate($encoderfile_config, $model_config) + crate::transforms::$transform_type::new( + convert_libs($encoderfile_config.lua_libs()?.as_ref()), + Some($transform_str.clone()), + ) + .with_context(|| utils::validation_err_ctx("Failed to create transform"))? 
+ .validate($encoderfile_config, $model_config) }; } @@ -87,9 +91,14 @@ pub fn validate_transform<'a>( ), }?; + let lua_libs: Option = encoderfile_config + .lua_libs + .clone() + .map(|libs| ManifestLuaLibs { libs }); let proto = crate::generated::manifest::Transform { transform_type: crate::generated::manifest::TransformType::Lua.into(), transform: transform_str, + lua_libs, }; PlannedAsset::from_asset_source( @@ -101,7 +110,7 @@ pub fn validate_transform<'a>( #[cfg(test)] mod tests { - use crate::transforms::EmbeddingTransform; + use crate::transforms::{DEFAULT_LIBS, EmbeddingTransform}; use crate::build_cli::config::{ModelPath, Transform}; @@ -116,6 +125,7 @@ mod tests { cache_dir: None, output_path: None, transform: None, + lua_libs: None, validate_transform: true, tokenizer: None, base_binary_path: None, @@ -131,7 +141,7 @@ mod tests { #[test] fn test_empty_transform() { - let result = EmbeddingTransform::new(None) + let result = EmbeddingTransform::new(DEFAULT_LIBS.to_vec(), None) .expect("Failed to make embedding transform") .validate(&test_encoderfile_config(), &test_model_config()); @@ -143,7 +153,7 @@ mod tests { let mut config = test_encoderfile_config(); config.validate_transform = false; - EmbeddingTransform::new(None) + EmbeddingTransform::new(DEFAULT_LIBS.to_vec(), None) .expect("Failed to make embedding transform") .validate(&config, &test_model_config()) .expect("Should be ok") @@ -161,6 +171,7 @@ mod tests { cache_dir: None, output_path: None, transform: Some(Transform::Inline(transform_str.to_string())), + lua_libs: None, validate_transform: true, tokenizer: None, base_binary_path: None, @@ -189,6 +200,7 @@ mod tests { cache_dir: None, output_path: None, transform: None, + lua_libs: None, validate_transform: true, tokenizer: None, base_binary_path: None, diff --git a/encoderfile/src/build_cli/transforms/validation/sentence_embedding.rs b/encoderfile/src/build_cli/transforms/validation/sentence_embedding.rs index d6fd1c34..263d3c84 100644 --- 
a/encoderfile/src/build_cli/transforms/validation/sentence_embedding.rs +++ b/encoderfile/src/build_cli/transforms/validation/sentence_embedding.rs @@ -60,6 +60,7 @@ impl TransformValidatorExt for SentenceEmbeddingTransform { mod tests { use crate::build_cli::config::{EncoderfileConfig, ModelPath}; use crate::common::ModelType; + use crate::transforms::DEFAULT_LIBS; use super::*; @@ -73,6 +74,7 @@ mod tests { output_path: None, transform: None, validate_transform: true, + lua_libs: None, tokenizer: None, base_binary_path: None, target: None, @@ -90,9 +92,10 @@ mod tests { let encoderfile_config = test_encoderfile_config(); let model_config = test_model_config(); - SentenceEmbeddingTransform::new(Some( - "function Postprocess(arr, mask) return arr:mean_pool(mask) end".to_string(), - )) + SentenceEmbeddingTransform::new( + DEFAULT_LIBS.to_vec(), + Some("function Postprocess(arr, mask) return arr:mean_pool(mask) end".to_string()), + ) .expect("Failed to create transform") .validate(&encoderfile_config, &model_config) .expect("Failed to validate"); @@ -103,9 +106,10 @@ mod tests { let encoderfile_config = test_encoderfile_config(); let model_config = test_model_config(); - let result = SentenceEmbeddingTransform::new(Some( - "function Postprocess(arr, mask) return 1 end".to_string(), - )) + let result = SentenceEmbeddingTransform::new( + DEFAULT_LIBS.to_vec(), + Some("function Postprocess(arr, mask) return 1 end".to_string()), + ) .expect("Failed to create transform") .validate(&encoderfile_config, &model_config); @@ -117,9 +121,10 @@ mod tests { let encoderfile_config = test_encoderfile_config(); let model_config = test_model_config(); - let result = SentenceEmbeddingTransform::new(Some( - "function Postprocess(arr, mask) return arr end".to_string(), - )) + let result = SentenceEmbeddingTransform::new( + DEFAULT_LIBS.to_vec(), + Some("function Postprocess(arr, mask) return arr end".to_string()), + ) .expect("Failed to create transform") .validate(&encoderfile_config, 
&model_config); diff --git a/encoderfile/src/build_cli/transforms/validation/sequence_classification.rs b/encoderfile/src/build_cli/transforms/validation/sequence_classification.rs index 798e3623..b806c47b 100644 --- a/encoderfile/src/build_cli/transforms/validation/sequence_classification.rs +++ b/encoderfile/src/build_cli/transforms/validation/sequence_classification.rs @@ -56,6 +56,7 @@ impl TransformValidatorExt for SequenceClassificationTransform { mod tests { use crate::build_cli::config::{EncoderfileConfig, ModelPath}; use crate::common::ModelType; + use crate::transforms::DEFAULT_LIBS; use super::*; @@ -68,6 +69,7 @@ mod tests { cache_dir: None, output_path: None, transform: None, + lua_libs: None, validate_transform: true, tokenizer: None, base_binary_path: None, @@ -86,9 +88,10 @@ mod tests { let encoderfile_config = test_encoderfile_config(); let model_config = test_model_config(); - SequenceClassificationTransform::new(Some( - "function Postprocess(arr) return arr end".to_string(), - )) + SequenceClassificationTransform::new( + DEFAULT_LIBS.to_vec(), + Some("function Postprocess(arr) return arr end".to_string()), + ) .expect("Failed to create transform") .validate(&encoderfile_config, &model_config) .expect("Failed to validate"); @@ -99,9 +102,10 @@ mod tests { let encoderfile_config = test_encoderfile_config(); let model_config = test_model_config(); - let result = SequenceClassificationTransform::new(Some( - "function Postprocess(arr) return 1 end".to_string(), - )) + let result = SequenceClassificationTransform::new( + DEFAULT_LIBS.to_vec(), + Some("function Postprocess(arr) return 1 end".to_string()), + ) .expect("Failed to create transform") .validate(&encoderfile_config, &model_config); @@ -113,9 +117,10 @@ mod tests { let encoderfile_config = test_encoderfile_config(); let model_config = test_model_config(); - let result = SequenceClassificationTransform::new(Some( - "function Postprocess(arr) return arr:sum_axis(1) end".to_string(), - )) + let 
result = SequenceClassificationTransform::new( + DEFAULT_LIBS.to_vec(), + Some("function Postprocess(arr) return arr:sum_axis(1) end".to_string()), + ) .expect("Failed to create transform") .validate(&encoderfile_config, &model_config); diff --git a/encoderfile/src/build_cli/transforms/validation/token_classification.rs b/encoderfile/src/build_cli/transforms/validation/token_classification.rs index 3c95eae1..a98dd887 100644 --- a/encoderfile/src/build_cli/transforms/validation/token_classification.rs +++ b/encoderfile/src/build_cli/transforms/validation/token_classification.rs @@ -57,6 +57,7 @@ impl TransformValidatorExt for TokenClassificationTransform { mod tests { use crate::build_cli::config::{EncoderfileConfig, ModelPath}; use crate::common::ModelType; + use crate::transforms::DEFAULT_LIBS; use super::*; @@ -69,6 +70,7 @@ mod tests { cache_dir: None, output_path: None, transform: None, + lua_libs: None, validate_transform: true, tokenizer: None, base_binary_path: None, @@ -87,9 +89,10 @@ mod tests { let encoderfile_config = test_encoderfile_config(); let model_config = test_model_config(); - TokenClassificationTransform::new(Some( - "function Postprocess(arr) return arr end".to_string(), - )) + TokenClassificationTransform::new( + DEFAULT_LIBS.to_vec(), + Some("function Postprocess(arr) return arr end".to_string()), + ) .expect("Failed to create transform") .validate(&encoderfile_config, &model_config) .expect("Failed to validate"); @@ -100,9 +103,10 @@ mod tests { let encoderfile_config = test_encoderfile_config(); let model_config = test_model_config(); - let result = TokenClassificationTransform::new(Some( - "function Postprocess(arr) return 1 end".to_string(), - )) + let result = TokenClassificationTransform::new( + DEFAULT_LIBS.to_vec(), + Some("function Postprocess(arr) return 1 end".to_string()), + ) .expect("Failed to create transform") .validate(&encoderfile_config, &model_config); @@ -114,9 +118,10 @@ mod tests { let encoderfile_config = 
test_encoderfile_config(); let model_config = test_model_config(); - let result = TokenClassificationTransform::new(Some( - "function Postprocess(arr) return arr:sum_axis(1) end".to_string(), - )) + let result = TokenClassificationTransform::new( + DEFAULT_LIBS.to_vec(), + Some("function Postprocess(arr) return arr:sum_axis(1) end".to_string()), + ) .expect("Failed to create transform") .validate(&encoderfile_config, &model_config); diff --git a/encoderfile/src/common/config.rs b/encoderfile/src/common/config.rs index f1c87727..4bc5fba2 100644 --- a/encoderfile/src/common/config.rs +++ b/encoderfile/src/common/config.rs @@ -1,4 +1,5 @@ use super::model_type::ModelType; +use anyhow::{Result, bail}; use serde::{Deserialize, Serialize}; use tokenizers::PaddingParams; @@ -8,6 +9,52 @@ pub struct Config { pub version: String, pub model_type: ModelType, pub transform: Option, + pub lua_libs: Option, +} + +#[derive(Debug, Serialize, Deserialize, Default, Copy, Clone)] +pub struct LuaLibs { + pub coroutine: bool, + pub table: bool, + pub io: bool, + pub os: bool, + pub string: bool, + pub utf8: bool, + // Check if / how this is supported in lua54 + // pub bit: bool, + pub math: bool, + pub package: bool, + // luau + // pub buffer: bool, + // pub vector: bool, + // luajit + // pub jit: bool, + // pub ffi: bool, + pub debug: bool, +} + +impl TryFrom> for LuaLibs { + type Error = anyhow::Error; + fn try_from(value: Vec) -> Result { + let mut resolved = LuaLibs::default(); + + for lib in value { + match lib.as_str() { + "coroutine" => resolved.coroutine = true, + "table" => resolved.table = true, + "io" => resolved.io = true, + "os" => resolved.os = true, + "string" => resolved.string = true, + "utf8" => resolved.utf8 = true, + "math" => resolved.math = true, + "package" => resolved.package = true, + "debug" => resolved.debug = true, + other => bail!("Unknown Lua stdlib: {}", other), + }; + } + + Ok(resolved) + } } #[derive(Debug, Clone, Default, Serialize, Deserialize)] diff 
--git a/encoderfile/src/dev_utils/mod.rs b/encoderfile/src/dev_utils/mod.rs index 5c88fa18..6f9a371a 100644 --- a/encoderfile/src/dev_utils/mod.rs +++ b/encoderfile/src/dev_utils/mod.rs @@ -20,6 +20,7 @@ pub fn get_state(dir: &str) -> AppState { version: "0.0.1".to_string(), model_type: T::enum_val(), transform: None, + lua_libs: None, }; let model_config = get_model_config(dir); diff --git a/encoderfile/src/runtime/loader.rs b/encoderfile/src/runtime/loader.rs index c88598c9..b00959c4 100644 --- a/encoderfile/src/runtime/loader.rs +++ b/encoderfile/src/runtime/loader.rs @@ -5,9 +5,9 @@ use std::io::{Read, Seek}; use ort::session::Session; use crate::{ - common::{Config, ModelConfig, ModelType}, + common::{Config, LuaLibs, ModelConfig, ModelType}, format::{assets::AssetKind, codec::EncoderfileCodec, container::Encoderfile}, - generated::manifest::TransformType, + generated::manifest::{self, TransformType}, runtime::TokenizerService, }; @@ -60,8 +60,8 @@ impl<'a, R: Read + Seek> EncoderfileLoader<'a, R> { } } - pub fn transform(&mut self) -> Result> { - let transform_str = match self + pub fn transform(&mut self) -> Result> { + let transform_proto = match self .encoderfile .open_optional(self.reader, AssetKind::Transform) { @@ -79,22 +79,28 @@ impl<'a, R: Read + Seek> EncoderfileLoader<'a, R> { } }; - Some(transform_proto.transform) + Some(transform_proto) } None => None, }; - Ok(transform_str) + Ok(transform_proto) } pub fn encoderfile_config(&mut self) -> Result { + let transform = self.transform()?; + let protolibs = transform + .as_ref() + .and_then(|t| t.lua_libs.clone()) + .map(|l| l.libs); + let configlibs = protolibs.map(LuaLibs::try_from).transpose()?; let config = Config { name: self.encoderfile.name().to_string(), version: self.encoderfile.version().to_string(), model_type: self.encoderfile.model_type(), - transform: self.transform()?, + transform: transform.map(|t| t.transform), + lua_libs: configlibs, }; - Ok(config) } diff --git 
a/encoderfile/src/runtime/state.rs b/encoderfile/src/runtime/state.rs index 6b084c95..5690d99e 100644 --- a/encoderfile/src/runtime/state.rs +++ b/encoderfile/src/runtime/state.rs @@ -6,6 +6,7 @@ use parking_lot::Mutex; use crate::{ common::{Config, ModelConfig, ModelType, model_type::ModelTypeSpec}, runtime::TokenizerService, + transforms::DEFAULT_LIBS, }; pub type AppState = Arc>; @@ -16,6 +17,7 @@ pub struct EncoderfileState { pub session: Mutex, pub tokenizer: TokenizerService, pub model_config: ModelConfig, + pub lua_libs: Vec, _marker: PhantomData, } @@ -26,11 +28,16 @@ impl EncoderfileState { tokenizer: TokenizerService, model_config: ModelConfig, ) -> EncoderfileState { + let lua_libs = match config.lua_libs { + Some(ref libs) => Vec::::from(libs), + None => DEFAULT_LIBS.to_vec(), + }; EncoderfileState { config, session, tokenizer, model_config, + lua_libs, _marker: PhantomData, } } @@ -39,6 +46,10 @@ impl EncoderfileState { self.config.transform.clone() } + pub fn lua_libs(&self) -> &Vec { + &self.lua_libs + } + pub fn model_type() -> ModelType { T::enum_val() } diff --git a/encoderfile/src/services/embedding.rs b/encoderfile/src/services/embedding.rs index d03de781..53e89f7b 100644 --- a/encoderfile/src/services/embedding.rs +++ b/encoderfile/src/services/embedding.rs @@ -17,7 +17,7 @@ impl Inference for AppState { let encodings = self.tokenizer.encode_text(request.inputs)?; - let transform = EmbeddingTransform::new(self.transform_str())?; + let transform = EmbeddingTransform::new(self.lua_libs.clone(), self.transform_str())?; let results = inference::embedding::embedding(self.session.lock(), &transform, encodings)?; diff --git a/encoderfile/src/services/sentence_embedding.rs b/encoderfile/src/services/sentence_embedding.rs index f04223ad..115c6322 100644 --- a/encoderfile/src/services/sentence_embedding.rs +++ b/encoderfile/src/services/sentence_embedding.rs @@ -17,7 +17,8 @@ impl Inference for AppState { let encodings = 
self.tokenizer.encode_text(request.inputs)?; - let transform = SentenceEmbeddingTransform::new(self.transform_str())?; + let transform = + SentenceEmbeddingTransform::new(self.lua_libs.clone(), self.transform_str())?; let results = inference::sentence_embedding::sentence_embedding( self.session.lock(), diff --git a/encoderfile/src/services/sequence_classification.rs b/encoderfile/src/services/sequence_classification.rs index 0550a3de..52af7313 100644 --- a/encoderfile/src/services/sequence_classification.rs +++ b/encoderfile/src/services/sequence_classification.rs @@ -17,7 +17,8 @@ impl Inference for AppState { let encodings = self.tokenizer.encode_text(request.inputs)?; - let transform = SequenceClassificationTransform::new(self.transform_str())?; + let transform = + SequenceClassificationTransform::new(self.lua_libs.clone(), self.transform_str())?; let results = inference::sequence_classification::sequence_classification( self.session.lock(), diff --git a/encoderfile/src/services/token_classification.rs b/encoderfile/src/services/token_classification.rs index f880f047..2fd12329 100644 --- a/encoderfile/src/services/token_classification.rs +++ b/encoderfile/src/services/token_classification.rs @@ -19,7 +19,8 @@ impl Inference for AppState { let encodings = self.tokenizer.encode_text(request.inputs)?; - let transform = TokenClassificationTransform::new(self.transform_str())?; + let transform = + TokenClassificationTransform::new(self.lua_libs.clone(), self.transform_str())?; let results = inference::token_classification::token_classification( session, diff --git a/encoderfile/src/transforms/engine/embedding.rs b/encoderfile/src/transforms/engine/embedding.rs index 8d102725..c632a716 100644 --- a/encoderfile/src/transforms/engine/embedding.rs +++ b/encoderfile/src/transforms/engine/embedding.rs @@ -49,11 +49,13 @@ impl Postprocessor for Transform { #[cfg(test)] mod tests { use super::*; + use crate::transforms::DEFAULT_LIBS; #[test] fn test_embedding_no_transform() 
{ - let engine = Transform::::new(Some("".to_string())) - .expect("Failed to create Transform"); + let engine = + Transform::::new(DEFAULT_LIBS.to_vec(), Some("".to_string())) + .expect("Failed to create Transform"); let arr = ndarray::Array3::::from_elem((16, 32, 128), 2.0); @@ -64,14 +66,17 @@ mod tests { #[test] fn test_embedding_identity_transform() { - let engine = Transform::::new(Some( - r##" + let engine = Transform::::new( + DEFAULT_LIBS.to_vec(), + Some( + r##" function Postprocess(arr) return arr end "## - .to_string(), - )) + .to_string(), + ), + ) .expect("Failed to create engine"); let arr = ndarray::Array3::::from_elem((16, 32, 128), 2.0); @@ -83,14 +88,17 @@ mod tests { #[test] fn test_embedding_transform_bad_fn() { - let engine = Transform::::new(Some( - r##" + let engine = Transform::::new( + DEFAULT_LIBS.to_vec(), + Some( + r##" function Postprocess(arr) return 1 end "## - .to_string(), - )) + .to_string(), + ), + ) .expect("Failed to create engine"); let arr = ndarray::Array3::::from_elem((16, 32, 128), 2.0); @@ -102,14 +110,17 @@ mod tests { #[test] fn test_bad_dimensionality_transform_postprocessing() { - let engine = Transform::::new(Some( - r##" + let engine = Transform::::new( + DEFAULT_LIBS.to_vec(), + Some( + r##" function Postprocess(x) return x:sum_axis(1) end "## - .to_string(), - )) + .to_string(), + ), + ) .unwrap(); let arr = ndarray::Array3::::from_elem((3, 3, 3), 2.0); diff --git a/encoderfile/src/transforms/engine/mod.rs b/encoderfile/src/transforms/engine/mod.rs index 886b5062..73cee6e2 100644 --- a/encoderfile/src/transforms/engine/mod.rs +++ b/encoderfile/src/transforms/engine/mod.rs @@ -1,8 +1,12 @@ use std::marker::PhantomData; use crate::{ - common::model_type::{self, ModelTypeSpec}, + common::{ + LuaLibs, + model_type::{self, ModelTypeSpec}, + }, error::ApiError, + transforms::DEFAULT_LIBS, }; use super::tensor::Tensor; @@ -13,6 +17,65 @@ mod sentence_embedding; mod sequence_classification; mod token_classification; +impl 
From<&LuaLibs> for Vec { + fn from(value: &LuaLibs) -> Self { + let mut libs = Vec::new(); + if value.coroutine { + libs.push(mlua::StdLib::COROUTINE); + } + if value.table { + libs.push(mlua::StdLib::TABLE); + } + if value.io { + libs.push(mlua::StdLib::IO); + } + if value.os { + libs.push(mlua::StdLib::OS); + } + if value.string { + libs.push(mlua::StdLib::STRING); + } + if value.utf8 { + libs.push(mlua::StdLib::UTF8); + } + if value.math { + libs.push(mlua::StdLib::MATH); + } + if value.package { + libs.push(mlua::StdLib::PACKAGE); + } + // luau settings (https://luau.org/), not included right now + /* + if value.buffer { + libs.push(mlua::StdLib::BUFFER); + } + if value.vector { + libs.push(mlua::StdLib::VECTOR); + } + */ + // luajit settings (https://luajit.org/), not included right now + /* + if value.jit { + libs.push(mlua::StdLib::JIT); + } + if value.ffi { + libs.push(mlua::StdLib::FFI); + } + */ + if value.debug { + libs.push(mlua::StdLib::DEBUG); + } + libs + } +} + +pub fn convert_libs(value: Option<&LuaLibs>) -> Vec { + match value { + Some(libs) => Vec::from(libs), + None => DEFAULT_LIBS.to_vec(), + } +} + macro_rules! 
transform { ($type_name:ident, $mt:ident) => { pub type $type_name = Transform; @@ -49,8 +112,8 @@ impl Transform { } #[tracing::instrument(name = "new_transform", skip_all)] - pub fn new(transform: Option) -> Result { - let lua = new_lua()?; + pub fn new(libs: Vec, transform: Option) -> Result { + let lua = new_lua(libs)?; lua.load(transform.unwrap_or("".to_string())) .exec() @@ -75,9 +138,9 @@ impl TransformSpec for Transform { } } -fn new_lua() -> Result { +fn new_lua(libs: Vec) -> Result { let lua = Lua::new_with( - mlua::StdLib::TABLE | mlua::StdLib::STRING | mlua::StdLib::MATH, + libs.iter().fold(mlua::StdLib::NONE, |acc, lib| acc | *lib), mlua::LuaOptions::default(), ) .map_err(|e| { @@ -109,9 +172,10 @@ fn new_lua() -> Result { #[cfg(test)] mod tests { use super::*; + use crate::transforms::DEFAULT_LIBS; fn new_test_lua() -> Lua { - new_lua().expect("Failed to create new lua") + new_lua(DEFAULT_LIBS.to_vec()).expect("Failed to create new Lua") } #[test] @@ -152,7 +216,7 @@ mod tests { #[test] fn test_cannot_access_environment_or_execute_commands() { - let lua = new_lua().expect("Failed to create new Lua"); + let lua = new_lua(DEFAULT_LIBS.to_vec()).expect("Failed to create new Lua"); // `os.execute` shouldn't exist or be callable let res = lua @@ -279,4 +343,230 @@ mod tests { // shape should be preserved assert_eq!(out.0.shape(), &[3]); } + + enum TestLibItem { + Coroutine, + Io, + Utf8, + Os, + Package, + #[allow(dead_code)] + Debug, + } + + impl TestLibItem { + pub fn test_data(self) -> (String, mlua::StdLib) { + match self { + TestLibItem::Coroutine => ( + r#" + function MyCoroutine() + return Tensor({1, 2, 3}) + end + function MyTest() + local mycor = coroutine.create(MyCoroutine) + local _, tensor = coroutine.resume(mycor) + return tensor + end + "# + .to_string(), + mlua::StdLib::COROUTINE, + ), + TestLibItem::Io => ( + r#" + function MyTest() + local res = Tensor({1, 2, 3}) + io.stderr:write("This is a test of the IO library\n") + return res + end + 
"# + .to_string(), + mlua::StdLib::IO, + ), + TestLibItem::Utf8 => ( + r#" + function MyTest() + local fp_values = {} + for point in utf8.codes("hello") do + table.insert(fp_values, point) + end + return Tensor(fp_values) + end + "# + .to_string(), + mlua::StdLib::UTF8, + ), + TestLibItem::Os => ( + r#" + function MyTest() + local t = os.time() + return Tensor({1, 2, 3}) + end + "# + .to_string(), + mlua::StdLib::OS, + ), + TestLibItem::Package => ( + r#" + function MyTest() + p = package.path + return Tensor({1, 2, 3}) + end + "# + .to_string(), + mlua::StdLib::PACKAGE, + ), + TestLibItem::Debug => ( + r#" + function MyTest() + local info = debug.getinfo(1, "n") + return Tensor({info.currentline}) + end + "# + .to_string(), + mlua::StdLib::DEBUG, + ), + } + } + } + + #[test] + fn test_convert_default_lua_libs() { + let libs = LuaLibs::default(); + let stdlibs: Vec = Vec::from(&libs); + assert!(stdlibs.is_empty()); + } + + #[test] + fn test_convert_no_lua_libs() { + let maybe_libs = None; + let stdlibs: Vec = convert_libs(maybe_libs); + assert!(stdlibs.contains(&mlua::StdLib::TABLE)); + assert!(stdlibs.contains(&mlua::StdLib::STRING)); + assert!(stdlibs.contains(&mlua::StdLib::MATH)); + } + + #[test] + fn test_convert_some_lua_libs() { + let maybe_libs = Some(&LuaLibs { + coroutine: true, + table: false, + io: true, + os: false, + string: true, + utf8: false, + math: true, + package: false, + debug: true, + }); + let stdlibs: Vec = convert_libs(maybe_libs); + assert!(stdlibs.contains(&mlua::StdLib::COROUTINE)); + assert!(stdlibs.contains(&mlua::StdLib::IO)); + assert!(stdlibs.contains(&mlua::StdLib::STRING)); + assert!(stdlibs.contains(&mlua::StdLib::MATH)); + assert!(stdlibs.contains(&mlua::StdLib::DEBUG)); + assert!(!stdlibs.contains(&mlua::StdLib::TABLE)); + assert!(!stdlibs.contains(&mlua::StdLib::OS)); + assert!(!stdlibs.contains(&mlua::StdLib::UTF8)); + assert!(!stdlibs.contains(&mlua::StdLib::PACKAGE)); + } + + fn test_lualib_any_ok((chunk, lib): (String, 
mlua::StdLib)) { + let mut lualibs = DEFAULT_LIBS.to_vec(); + lualibs.push(lib); + let lua = new_lua(lualibs).expect("Failed to create new Lua"); + lua.load(chunk).exec().unwrap(); + + let function = lua + .globals() + .get::("MyTest") + .expect("Failed to get MyTest"); + let res = function.call::(()); + assert!( + res.is_ok(), + "Failed to execute function using library {:?}: {:?}", + lib, + res.err() + ); + } + + fn test_lualib_any_fails((chunk, lib): (String, mlua::StdLib)) { + let lua = new_test_lua(); + lua.load(chunk).exec().unwrap(); + + let function = lua + .globals() + .get::("MyTest") + .expect("Failed to get MyTest"); + let res = function.call::(()); + assert!( + res.is_err(), + "Function should have failed when using library {:?}, but got result: {:?}", + lib, + res.ok() + ); + } + + #[test] + fn test_lualib_coroutine_ok() { + test_lualib_any_ok(TestLibItem::Coroutine.test_data()); + } + + #[test] + fn test_lualib_coroutine_fails() { + test_lualib_any_fails(TestLibItem::Coroutine.test_data()); + } + + #[test] + fn test_lualib_io_ok() { + test_lualib_any_ok(TestLibItem::Io.test_data()); + } + + #[test] + fn test_lualib_io_fails() { + test_lualib_any_fails(TestLibItem::Io.test_data()); + } + + #[test] + fn test_lualib_utf8_ok() { + test_lualib_any_ok(TestLibItem::Utf8.test_data()); + } + + #[test] + fn test_lualib_utf8_fails() { + test_lualib_any_fails(TestLibItem::Utf8.test_data()); + } + + #[test] + fn test_lualib_os_ok() { + test_lualib_any_ok(TestLibItem::Os.test_data()); + } + + #[test] + fn test_lualib_os_fails() { + test_lualib_any_fails(TestLibItem::Os.test_data()); + } + + #[test] + fn test_lualib_package_ok() { + test_lualib_any_ok(TestLibItem::Package.test_data()); + } + + #[test] + fn test_lualib_package_fails() { + test_lualib_any_fails(TestLibItem::Package.test_data()); + } + + // TODO: check lua engine init with the debug lib enabled; + // tests currently fail here + /* + #[test] + fn test_lualib_debug_ok() { + 
test_lualib_any_ok(TestLibItem::Debug.test_data()); + } + + #[test] + fn test_lualib_debug_fails() { + test_lualib_any_fails(TestLibItem::Debug.test_data()); + } + */ } diff --git a/encoderfile/src/transforms/engine/sentence_embedding.rs b/encoderfile/src/transforms/engine/sentence_embedding.rs index a9137ced..bc71c920 100644 --- a/encoderfile/src/transforms/engine/sentence_embedding.rs +++ b/encoderfile/src/transforms/engine/sentence_embedding.rs @@ -63,12 +63,16 @@ impl Postprocessor for Transform { #[cfg(test)] mod tests { use super::*; + use crate::transforms::DEFAULT_LIBS; use ndarray::Axis; #[test] fn test_no_pooling() { - let engine = Transform::::new(Some("".to_string())) - .expect("Failed to create engine"); + let engine = Transform::::new( + DEFAULT_LIBS.to_vec(), + Some("".to_string()), + ) + .expect("Failed to create engine"); let arr = ndarray::Array3::::from_elem((16, 32, 128), 2.0); let mask = ndarray::Array2::::from_elem((16, 32), 1.0); @@ -85,15 +89,18 @@ mod tests { #[test] fn test_successful_pool() { - let engine = Transform::::new(Some( - r##" + let engine = Transform::::new( + DEFAULT_LIBS.to_vec(), + Some( + r##" function Postprocess(arr, mask) -- sum along second axis (lol) return arr:sum_axis(2) end "## - .to_string(), - )) + .to_string(), + ), + ) .expect("Failed to create engine"); let arr = ndarray::Array3::::from_elem((16, 32, 128), 2.0); @@ -108,14 +115,17 @@ mod tests { #[test] fn test_bad_dim_pool() { - let engine = Transform::::new(Some( - r##" + let engine = Transform::::new( + DEFAULT_LIBS.to_vec(), + Some( + r##" function Postprocess(arr, mask) return arr end "## - .to_string(), - )) + .to_string(), + ), + ) .expect("Failed to create engine"); let arr = ndarray::Array3::::from_elem((16, 32, 128), 2.0); @@ -128,14 +138,17 @@ mod tests { #[test] fn test_sentence_embedding_transform_bad_fn() { - let engine = Transform::::new(Some( - r##" + let engine = Transform::::new( + DEFAULT_LIBS.to_vec(), + Some( + r##" function 
Postprocess(arr, mask) return 1 end "## - .to_string(), - )) + .to_string(), + ), + ) .expect("Failed to create engine"); let arr = ndarray::Array3::::from_elem((16, 32, 128), 2.0); @@ -148,14 +161,17 @@ mod tests { #[test] fn test_bad_dimensionality_transform_postprocessing() { - let engine = Transform::::new(Some( - r##" + let engine = Transform::::new( + DEFAULT_LIBS.to_vec(), + Some( + r##" function Postprocess(arr, mask) return arr end "## - .to_string(), - )) + .to_string(), + ), + ) .unwrap(); let arr = ndarray::Array3::::from_elem((3, 3, 3), 2.0); diff --git a/encoderfile/src/transforms/engine/sequence_classification.rs b/encoderfile/src/transforms/engine/sequence_classification.rs index c9d53471..d369d873 100644 --- a/encoderfile/src/transforms/engine/sequence_classification.rs +++ b/encoderfile/src/transforms/engine/sequence_classification.rs @@ -47,11 +47,15 @@ impl Postprocessor for Transform { #[cfg(test)] mod tests { use super::*; + use crate::transforms::DEFAULT_LIBS; #[test] fn test_sequence_cls_no_transform() { - let engine = Transform::::new(Some("".to_string())) - .expect("Failed to create Transform"); + let engine = Transform::::new( + DEFAULT_LIBS.to_vec(), + Some("".to_string()), + ) + .expect("Failed to create Transform"); let arr = ndarray::Array2::::from_elem((16, 2), 2.0); @@ -62,14 +66,17 @@ mod tests { #[test] fn test_seq_cls_transform() { - let engine = Transform::::new(Some( - r##" + let engine = Transform::::new( + DEFAULT_LIBS.to_vec(), + Some( + r##" function Postprocess(arr) return arr end "## - .to_string(), - )) + .to_string(), + ), + ) .expect("Failed to create engine"); let arr = ndarray::Array2::::from_elem((16, 2), 2.0); @@ -81,14 +88,17 @@ mod tests { #[test] fn test_seq_cls_transform_bad_fn() { - let engine = Transform::::new(Some( - r##" + let engine = Transform::::new( + DEFAULT_LIBS.to_vec(), + Some( + r##" function Postprocess(arr) return 1 end "## - .to_string(), - )) + .to_string(), + ), + ) .expect("Failed to create 
engine"); let arr = ndarray::Array2::::from_elem((16, 2), 2.0); @@ -100,14 +110,17 @@ mod tests { #[test] fn test_bad_dimensionality_transform_postprocessing() { - let engine = Transform::::new(Some( - r##" + let engine = Transform::::new( + DEFAULT_LIBS.to_vec(), + Some( + r##" function Postprocess(x) return x:sum_axis(1) end "## - .to_string(), - )) + .to_string(), + ), + ) .unwrap(); let arr = ndarray::Array2::::from_elem((2, 2), 2.0); diff --git a/encoderfile/src/transforms/engine/token_classification.rs b/encoderfile/src/transforms/engine/token_classification.rs index 08046f11..6dc4ab44 100644 --- a/encoderfile/src/transforms/engine/token_classification.rs +++ b/encoderfile/src/transforms/engine/token_classification.rs @@ -47,11 +47,15 @@ impl Postprocessor for Transform { #[cfg(test)] mod tests { use super::*; + use crate::transforms::DEFAULT_LIBS; #[test] fn test_token_cls_no_transform() { - let engine = Transform::::new(Some("".to_string())) - .expect("Failed to create Transform"); + let engine = Transform::::new( + DEFAULT_LIBS.to_vec(), + Some("".to_string()), + ) + .expect("Failed to create Transform"); let arr = ndarray::Array3::::from_elem((32, 16, 2), 2.0); @@ -62,14 +66,17 @@ mod tests { #[test] fn test_token_cls_identity_transform() { - let engine = Transform::::new(Some( - r##" + let engine = Transform::::new( + DEFAULT_LIBS.to_vec(), + Some( + r##" function Postprocess(arr) return arr end "## - .to_string(), - )) + .to_string(), + ), + ) .expect("Failed to create engine"); let arr = ndarray::Array3::::from_elem((16, 32, 2), 2.0); @@ -81,14 +88,17 @@ mod tests { #[test] fn test_token_cls_transform_bad_fn() { - let engine = Transform::::new(Some( - r##" + let engine = Transform::::new( + DEFAULT_LIBS.to_vec(), + Some( + r##" function Postprocess(arr) return 1 end "## - .to_string(), - )) + .to_string(), + ), + ) .expect("Failed to create engine"); let arr = ndarray::Array3::::from_elem((16, 32, 2), 2.0); @@ -100,14 +110,17 @@ mod tests { #[test] fn 
test_bad_dimensionality_transform_postprocessing() { - let engine = Transform::::new(Some( - r##" + let engine = Transform::::new( + DEFAULT_LIBS.to_vec(), + Some( + r##" function Postprocess(x) return x:sum_axis(1) end "## - .to_string(), - )) + .to_string(), + ), + ) .unwrap(); let arr = ndarray::Array3::::from_elem((3, 3, 3), 2.0); diff --git a/encoderfile/src/transforms/mod.rs b/encoderfile/src/transforms/mod.rs index f782d819..6f00903a 100644 --- a/encoderfile/src/transforms/mod.rs +++ b/encoderfile/src/transforms/mod.rs @@ -4,3 +4,9 @@ mod utils; pub use engine::*; pub use tensor::Tensor; + +pub const DEFAULT_LIBS: [mlua::StdLib; 3] = [ + mlua::StdLib::TABLE, + mlua::StdLib::STRING, + mlua::StdLib::MATH, +]; diff --git a/encoderfile/tests/integration/test_build.rs b/encoderfile/tests/integration/test_build.rs index f42580b0..1fca5b93 100644 --- a/encoderfile/tests/integration/test_build.rs +++ b/encoderfile/tests/integration/test_build.rs @@ -26,6 +26,15 @@ encoderfile: path: {:?} model_type: token_classification output_path: {:?} + lua_libs: + - coroutine + - table + - io + - os + - package + - string + - utf8 + - math transform: | --- Applies a softmax across token classification logits. --- Each token classification is normalized independently. @@ -38,7 +47,19 @@ encoderfile: --- Tensor: The input tensor with softmax-normalized embeddings. 
---@param arr Tensor ---@return Tensor + p = package.path + function MyCoroutine() + return Tensor({{1, 2, 3}}) + end function Postprocess(arr) + local mycor = coroutine.create(MyCoroutine) + local _, tensor = coroutine.resume(mycor) + io.stderr:write("This is a test of the IO library\n") + local fp_values = {{}} + for point in utf8.codes("hello") do + table.insert(fp_values, point) + end + local t = os.time() return arr:softmax(3) end "##, @@ -165,11 +186,16 @@ async fn send_http_inference(sample_text: &str, http_port: String) -> Result<()> inputs: vec![sample_text.to_owned()], metadata: None, }; - client + let res = client .post(format!("http://localhost:{http_port}/predict")) .json(&req) .send() .await?; + assert!( + res.status().is_success(), + "HTTP inference request failed with status: {}", + res.status() + ); Ok(()) } diff --git a/encoderfile/tests/test_models.rs b/encoderfile/tests/test_models.rs index 941008dc..44718dc0 100644 --- a/encoderfile/tests/test_models.rs +++ b/encoderfile/tests/test_models.rs @@ -3,7 +3,7 @@ use encoderfile::inference::{ embedding::embedding, sequence_classification::sequence_classification, token_classification::token_classification, }; -use encoderfile::transforms::Transform; +use encoderfile::transforms::{DEFAULT_LIBS, Transform}; #[test] fn test_embedding_model() { @@ -19,7 +19,8 @@ fn test_embedding_model() { let session_lock = state.session.lock(); - let transform = Transform::new(None).expect("Failed to create_transform"); + let transform = + Transform::new(DEFAULT_LIBS.to_vec(), None).expect("Failed to create_transform"); let results = embedding(session_lock, &transform, encodings.clone()).expect("Failed to compute results"); @@ -42,7 +43,8 @@ fn test_embedding_inference_with_bad_model() { let session_lock = state.session.lock(); - let transform = Transform::new(None).expect("Failed to create_transform"); + let transform = + Transform::new(DEFAULT_LIBS.to_vec(), None).expect("Failed to create_transform"); 
embedding(session_lock, &transform, encodings.clone()).expect("Failed to compute results"); } @@ -61,7 +63,8 @@ fn test_sequence_classification_model() { let session_lock = state.session.lock(); - let transform = Transform::new(None).expect("Failed to create_transform"); + let transform = + Transform::new(DEFAULT_LIBS.to_vec(), None).expect("Failed to create_transform"); let results = sequence_classification( session_lock, @@ -89,7 +92,8 @@ fn test_sequence_classification_inference_with_bad_model() { let session_lock = state.session.lock(); - let transform = Transform::new(None).expect("Failed to create_transform"); + let transform = + Transform::new(DEFAULT_LIBS.to_vec(), None).expect("Failed to create_transform"); sequence_classification( session_lock, @@ -114,7 +118,8 @@ fn test_token_classification_model() { let session_lock = state.session.lock(); - let transform = Transform::new(None).expect("Failed to create_transform"); + let transform = + Transform::new(DEFAULT_LIBS.to_vec(), None).expect("Failed to create_transform"); let results = token_classification( session_lock, @@ -142,7 +147,8 @@ fn test_token_classification_inference_with_bad_model() { let session_lock = state.session.lock(); - let transform = Transform::new(None).expect("Failed to create_transform"); + let transform = + Transform::new(DEFAULT_LIBS.to_vec(), None).expect("Failed to create_transform"); token_classification( session_lock, diff --git a/encoderfile/tests/transforms/main.rs b/encoderfile/tests/transforms/main.rs index 791b5f71..d9f58c6c 100644 --- a/encoderfile/tests/transforms/main.rs +++ b/encoderfile/tests/transforms/main.rs @@ -1,5 +1,5 @@ use encoderfile::transforms::{ - EmbeddingTransform, Postprocessor, SequenceClassificationTransform, + DEFAULT_LIBS, EmbeddingTransform, Postprocessor, SequenceClassificationTransform, TokenClassificationTransform, }; use ndarray::{Array2, Array3, Axis}; @@ -7,9 +7,10 @@ use ort::tensor::ArrayExtensions; #[test] fn test_l2_normalization() { - let 
engine = EmbeddingTransform::new(Some( - include_str!("../../../transforms/embedding/l2_normalize_embeddings.lua").to_string(), - )) + let engine = EmbeddingTransform::new( + DEFAULT_LIBS.to_vec(), + Some(include_str!("../../../transforms/embedding/l2_normalize_embeddings.lua").to_string()), + ) .expect("Failed to create engine"); let test_arr = Array3::::from_elem((8, 16, 36), 1.0); @@ -29,9 +30,13 @@ fn test_l2_normalization() { #[test] fn test_softmax_sequence_cls() { - let engine = SequenceClassificationTransform::new(Some( - include_str!("../../../transforms/sequence_classification/softmax_logits.lua").to_string(), - )) + let engine = SequenceClassificationTransform::new( + DEFAULT_LIBS.to_vec(), + Some( + include_str!("../../../transforms/sequence_classification/softmax_logits.lua") + .to_string(), + ), + ) .expect("Failed to create engine"); // run on array of shape [batch_size, n_labels] @@ -48,9 +53,12 @@ fn test_softmax_sequence_cls() { #[test] fn test_softmax_token_cls() { - let engine = TokenClassificationTransform::new(Some( - include_str!("../../../transforms/token_classification/softmax_logits.lua").to_string(), - )) + let engine = TokenClassificationTransform::new( + DEFAULT_LIBS.to_vec(), + Some( + include_str!("../../../transforms/token_classification/softmax_logits.lua").to_string(), + ), + ) .expect("Failed to create engine"); // run on array of shape [batch_size, n_tokens, n_labels]