Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions docs/reference/building.md
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,35 @@ encoderfile:
transform: "return lp_normalize(output)"
```

By default, the libraries `table`, `string`, and `math` are enabled when the `lua_libs` property is not present. Use `lua_libs` to specify a different set of libraries as a list of strings, chosen from:

* `coroutine`
* `table`
* `io`
* `os`
* `string`
* `utf8`
* `math`
* `package`

Note that if this property is present, no libraries are loaded by default, so every library the transform uses must be listed explicitly.

**Inline transform with a custom `lua_libs` set:**
```yaml
encoderfile:
name: my-model
path: ./models/my-model
model_type: embedding
lua_libs:
- table
- string
- math
- os
transform: |
t = os.time()
return lp_normalize(output)
```

### Custom Cache Directory

Specify a custom cache location:
Expand Down
25 changes: 15 additions & 10 deletions encoderfile/benches/benchmark_transforms.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use encoderfile::transforms::Postprocessor;
use encoderfile::transforms::{DEFAULT_LIBS, Postprocessor};
use ndarray::{Array2, Array3};
use rand::Rng;

Expand All @@ -22,9 +22,10 @@ fn get_random_3d(x: usize, y: usize, z: usize) -> Array3<f32> {

#[divan::bench(args = [(16, 16, 16), (32, 128, 384), (32, 256, 768)])]
fn bench_embedding_l2_normalization(bencher: divan::Bencher, (x, y, z): (usize, usize, usize)) {
let engine = encoderfile::transforms::EmbeddingTransform::new(Some(
include_str!("../../transforms/embedding/l2_normalize_embeddings.lua").to_string(),
))
let engine = encoderfile::transforms::EmbeddingTransform::new(
DEFAULT_LIBS.to_vec(),
Some(include_str!("../../transforms/embedding/l2_normalize_embeddings.lua").to_string()),
)
.unwrap();

let test_tensor = get_random_3d(x, y, z);
Expand All @@ -36,9 +37,12 @@ fn bench_embedding_l2_normalization(bencher: divan::Bencher, (x, y, z): (usize,

#[divan::bench(args = [(16, 2), (32, 8), (128, 32)])]
fn bench_seq_cls_softmax(bencher: divan::Bencher, (x, y): (usize, usize)) {
let engine = encoderfile::transforms::SequenceClassificationTransform::new(Some(
include_str!("../../transforms/sequence_classification/softmax_logits.lua").to_string(),
))
let engine = encoderfile::transforms::SequenceClassificationTransform::new(
DEFAULT_LIBS.to_vec(),
Some(
include_str!("../../transforms/sequence_classification/softmax_logits.lua").to_string(),
),
)
.unwrap();

let test_tensor = get_random_2d(x, y);
Expand All @@ -50,9 +54,10 @@ fn bench_seq_cls_softmax(bencher: divan::Bencher, (x, y): (usize, usize)) {

#[divan::bench(args = [(16, 16, 2), (32, 128, 8), (128, 256, 32)])]
fn bench_tok_cls_softmax(bencher: divan::Bencher, (x, y, z): (usize, usize, usize)) {
let engine = encoderfile::transforms::TokenClassificationTransform::new(Some(
include_str!("../../transforms/token_classification/softmax_logits.lua").to_string(),
))
let engine = encoderfile::transforms::TokenClassificationTransform::new(
DEFAULT_LIBS.to_vec(),
Some(include_str!("../../transforms/token_classification/softmax_logits.lua").to_string()),
)
.unwrap();

let test_tensor = get_random_3d(x, y, z);
Expand Down
6 changes: 6 additions & 0 deletions encoderfile/proto/manifest.proto
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ message EncoderfileManifest {
optional Artifact tokenizer = 130;
}

// LuaLibs carries the set of Lua standard libraries to make available to the
// embedded transform at runtime (e.g. "table", "string", "math", "os").
message LuaLibs {
// Library names; interpretation is left to the Lua runtime that loads them.
repeated string libs = 1;
}

// Transform describes embedded preprocessing/postprocessing steps applied to inputs
// prior to inference.
//
Expand All @@ -72,6 +76,8 @@ message Transform {
// Transform source code.
// Interpretation is defined by the TransformType.
string transform = 2;

optional LuaLibs lua_libs = 3;
}

// Artifact describes a contiguous byte range within the embedded payload.
Expand Down
10 changes: 9 additions & 1 deletion encoderfile/src/build_cli/config.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::common::{Config as EmbeddedConfig, ModelConfig, ModelType};
use crate::common::{Config as EmbeddedConfig, LuaLibs, ModelConfig, ModelType};
use anyhow::{Context, Result, bail};
use schemars::JsonSchema;
use std::{
Expand Down Expand Up @@ -41,6 +41,7 @@ pub struct EncoderfileConfig {
pub cache_dir: Option<PathBuf>,
pub base_binary_path: Option<PathBuf>,
pub transform: Option<Transform>,
pub lua_libs: Option<Vec<String>>,
pub tokenizer: Option<TokenizerBuildConfig>,
#[serde(default = "default_validate_transform")]
pub validate_transform: bool,
Expand All @@ -61,6 +62,7 @@ impl EncoderfileConfig {
version: self.version.clone(),
model_type: self.model_type.clone(),
transform: self.transform()?,
lua_libs: None,
};

Ok(config)
Expand Down Expand Up @@ -104,6 +106,11 @@ impl EncoderfileConfig {
Ok(transform)
}

/// Converts the optional raw `lua_libs` strings from the config into a
/// validated [`LuaLibs`], returning `Ok(None)` when the property is absent.
///
/// # Errors
/// Propagates any error produced by `LuaLibs::try_from` — presumably an
/// unknown library name; TODO confirm against the `TryFrom` impl.
pub fn lua_libs(&self) -> Result<Option<LuaLibs>> {
    // `transpose` turns Option<Result<_, _>> into Result<Option<_>, _> so `?`
    // surfaces the conversion error while keeping the absent case as None.
    // Returning directly avoids the original's borrow of a temporary followed
    // by a move-out-of-reference (`Ok(*configlibs)`), which the borrow checker
    // rejects for non-`Copy` types (E0507).
    Ok(self.lua_libs.clone().map(LuaLibs::try_from).transpose()?)
}

pub fn get_generated_dir(&self) -> PathBuf {
let filename_hash = Sha256::digest(self.name.as_bytes());

Expand Down Expand Up @@ -356,6 +363,7 @@ mod tests {
cache_dir: Some(base.clone()),
validate_transform: false,
transform: None,
lua_libs: None,
tokenizer: None,
base_binary_path: None,
target: None,
Expand Down
3 changes: 3 additions & 0 deletions encoderfile/src/build_cli/tokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ mod tests {
output_path: None,
cache_dir: None,
transform: None,
lua_libs: None,
tokenizer: None,
validate_transform: false,
base_binary_path: None,
Expand Down Expand Up @@ -301,6 +302,7 @@ mod tests {
output_path: None,
cache_dir: None,
transform: None,
lua_libs: None,
tokenizer: Some(TokenizerBuildConfig {
pad_strategy: Some(TokenizerPadStrategy::Fixed { fixed: 512 }),
}),
Expand Down Expand Up @@ -351,6 +353,7 @@ mod tests {
output_path: None,
cache_dir: None,
transform: None,
lua_libs: None,
tokenizer: None,
validate_transform: false,
base_binary_path: None,
Expand Down
30 changes: 19 additions & 11 deletions encoderfile/src/build_cli/transforms/validation/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ impl TransformValidatorExt for EmbeddingTransform {
mod tests {
use crate::build_cli::config::{EncoderfileConfig, ModelPath};
use crate::common::ModelType;
use crate::transforms::DEFAULT_LIBS;

use super::*;

Expand All @@ -69,6 +70,7 @@ mod tests {
cache_dir: None,
output_path: None,
transform: None,
lua_libs: None,
validate_transform: true,
tokenizer: None,
base_binary_path: None,
Expand All @@ -87,21 +89,26 @@ mod tests {
let encoderfile_config = test_encoderfile_config();
let model_config = test_model_config();

EmbeddingTransform::new(Some("function Postprocess(arr) return arr end".to_string()))
.expect("Failed to create transform")
.validate(&encoderfile_config, &model_config)
.expect("Failed to validate");
EmbeddingTransform::new(
DEFAULT_LIBS.to_vec(),
Some("function Postprocess(arr) return arr end".to_string()),
)
.expect("Failed to create transform")
.validate(&encoderfile_config, &model_config)
.expect("Failed to validate");
}

#[test]
fn test_bad_return_type() {
let encoderfile_config = test_encoderfile_config();
let model_config = test_model_config();

let result =
EmbeddingTransform::new(Some("function Postprocess(arr) return 1 end".to_string()))
.expect("Failed to create transform")
.validate(&encoderfile_config, &model_config);
let result = EmbeddingTransform::new(
DEFAULT_LIBS.to_vec(),
Some("function Postprocess(arr) return 1 end".to_string()),
)
.expect("Failed to create transform")
.validate(&encoderfile_config, &model_config);

assert!(result.is_err());
}
Expand All @@ -111,9 +118,10 @@ mod tests {
let encoderfile_config = test_encoderfile_config();
let model_config = test_model_config();

let result = EmbeddingTransform::new(Some(
"function Postprocess(arr) return arr:sum_axis(1) end".to_string(),
))
let result = EmbeddingTransform::new(
DEFAULT_LIBS.to_vec(),
Some("function Postprocess(arr) return arr:sum_axis(1) end".to_string()),
)
.expect("Failed to create transform")
.validate(&encoderfile_config, &model_config);

Expand Down
26 changes: 19 additions & 7 deletions encoderfile/src/build_cli/transforms/validation/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::{
common::{ModelConfig, ModelType},
format::assets::{AssetKind, AssetSource, PlannedAsset},
transforms::TransformSpec,
generated::manifest::LuaLibs as ManifestLuaLibs,
transforms::{TransformSpec, convert_libs},
};
use anyhow::{Context, Result};

Expand Down Expand Up @@ -41,9 +42,12 @@ pub trait TransformValidatorExt: TransformSpec {

macro_rules! validate_transform {
($transform_type:ident, $transform_str:expr, $encoderfile_config:expr, $model_config:expr) => {
crate::transforms::$transform_type::new(Some($transform_str.clone()))
.with_context(|| utils::validation_err_ctx("Failed to create transform"))?
.validate($encoderfile_config, $model_config)
crate::transforms::$transform_type::new(
convert_libs($encoderfile_config.lua_libs()?.as_ref()),
Some($transform_str.clone()),
)
.with_context(|| utils::validation_err_ctx("Failed to create transform"))?
.validate($encoderfile_config, $model_config)
};
}

Expand Down Expand Up @@ -87,9 +91,14 @@ pub fn validate_transform<'a>(
),
}?;

let lua_libs: Option<ManifestLuaLibs> = encoderfile_config
.lua_libs
.clone()
.map(|libs| ManifestLuaLibs { libs });
let proto = crate::generated::manifest::Transform {
transform_type: crate::generated::manifest::TransformType::Lua.into(),
transform: transform_str,
lua_libs,
};

PlannedAsset::from_asset_source(
Expand All @@ -101,7 +110,7 @@ pub fn validate_transform<'a>(

#[cfg(test)]
mod tests {
use crate::transforms::EmbeddingTransform;
use crate::transforms::{DEFAULT_LIBS, EmbeddingTransform};

use crate::build_cli::config::{ModelPath, Transform};

Expand All @@ -116,6 +125,7 @@ mod tests {
cache_dir: None,
output_path: None,
transform: None,
lua_libs: None,
validate_transform: true,
tokenizer: None,
base_binary_path: None,
Expand All @@ -131,7 +141,7 @@ mod tests {

#[test]
fn test_empty_transform() {
let result = EmbeddingTransform::new(None)
let result = EmbeddingTransform::new(DEFAULT_LIBS.to_vec(), None)
.expect("Failed to make embedding transform")
.validate(&test_encoderfile_config(), &test_model_config());

Expand All @@ -143,7 +153,7 @@ mod tests {
let mut config = test_encoderfile_config();
config.validate_transform = false;

EmbeddingTransform::new(None)
EmbeddingTransform::new(DEFAULT_LIBS.to_vec(), None)
.expect("Failed to make embedding transform")
.validate(&config, &test_model_config())
.expect("Should be ok")
Expand All @@ -161,6 +171,7 @@ mod tests {
cache_dir: None,
output_path: None,
transform: Some(Transform::Inline(transform_str.to_string())),
lua_libs: None,
validate_transform: true,
tokenizer: None,
base_binary_path: None,
Expand Down Expand Up @@ -189,6 +200,7 @@ mod tests {
cache_dir: None,
output_path: None,
transform: None,
lua_libs: None,
validate_transform: true,
tokenizer: None,
base_binary_path: None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ impl TransformValidatorExt for SentenceEmbeddingTransform {
mod tests {
use crate::build_cli::config::{EncoderfileConfig, ModelPath};
use crate::common::ModelType;
use crate::transforms::DEFAULT_LIBS;

use super::*;

Expand All @@ -73,6 +74,7 @@ mod tests {
output_path: None,
transform: None,
validate_transform: true,
lua_libs: None,
tokenizer: None,
base_binary_path: None,
target: None,
Expand All @@ -90,9 +92,10 @@ mod tests {
let encoderfile_config = test_encoderfile_config();
let model_config = test_model_config();

SentenceEmbeddingTransform::new(Some(
"function Postprocess(arr, mask) return arr:mean_pool(mask) end".to_string(),
))
SentenceEmbeddingTransform::new(
DEFAULT_LIBS.to_vec(),
Some("function Postprocess(arr, mask) return arr:mean_pool(mask) end".to_string()),
)
.expect("Failed to create transform")
.validate(&encoderfile_config, &model_config)
.expect("Failed to validate");
Expand All @@ -103,9 +106,10 @@ mod tests {
let encoderfile_config = test_encoderfile_config();
let model_config = test_model_config();

let result = SentenceEmbeddingTransform::new(Some(
"function Postprocess(arr, mask) return 1 end".to_string(),
))
let result = SentenceEmbeddingTransform::new(
DEFAULT_LIBS.to_vec(),
Some("function Postprocess(arr, mask) return 1 end".to_string()),
)
.expect("Failed to create transform")
.validate(&encoderfile_config, &model_config);

Expand All @@ -117,9 +121,10 @@ mod tests {
let encoderfile_config = test_encoderfile_config();
let model_config = test_model_config();

let result = SentenceEmbeddingTransform::new(Some(
"function Postprocess(arr, mask) return arr end".to_string(),
))
let result = SentenceEmbeddingTransform::new(
DEFAULT_LIBS.to_vec(),
Some("function Postprocess(arr, mask) return arr end".to_string()),
)
.expect("Failed to create transform")
.validate(&encoderfile_config, &model_config);

Expand Down
Loading