diff --git a/lib/index.d.ts b/lib/index.d.ts index 0bca395..1f7d860 100644 --- a/lib/index.d.ts +++ b/lib/index.d.ts @@ -107,6 +107,12 @@ export class Index { * @return {number} number of IDs removed. */ removeIds(ids: number[]): number + /** + * Reconstructs a vector from the index. + * @param {number} key The key of the vector to reconstruct. + * @return {number[]} The reconstructed vector. + */ + reconstruct(key: number): number[]; } @@ -133,6 +139,12 @@ export class IndexFlatL2 extends Index { * @param {IndexFlatL2} otherIndex The other IndexFlatL2 instance to merge from. */ mergeFrom(otherIndex: IndexFlatL2): void; + /** + * Reconstructs a vector from the index. + * @param {number} key The key of the vector to reconstruct. + * @return {number[]} The reconstructed vector. + */ + reconstruct(key: number): number[]; } /** @@ -158,4 +170,10 @@ export class IndexFlatIP extends Index { * @param {IndexFlatIP} otherIndex The other IndexFlatIP instance to merge from. */ mergeFrom(otherIndex: IndexFlatIP): void; + /** + * Reconstructs a vector from the index. + * @param {number} key The key of the vector to reconstruct. + * @return {number[]} The reconstructed vector. + */ + reconstruct(key: number): number[]; } \ No newline at end of file diff --git a/src/faiss.cc b/src/faiss.cc index f32b6b7..26e6862 100644 --- a/src/faiss.cc +++ b/src/faiss.cc @@ -440,6 +440,52 @@ class IndexBase : public Napi::ObjectWrap return Napi::Buffer::Copy(env, writer->data.data(), writer->data.size()); } + Napi::Value reconstruct(const Napi::CallbackInfo &info) + { + Napi::Env env = info.Env(); + + if (info.Length() != 1) + { + Napi::Error::New(env, "Expected 1 argument, but got " + std::to_string(info.Length()) + ".") + .ThrowAsJavaScriptException(); + return env.Undefined(); + } + if (!info[0].IsNumber()) + { + Napi::TypeError::New(env, "Invalid argument type, must be a number.").ThrowAsJavaScriptException(); + return env.Undefined(); + } + + const idx_t key = info[0].As().Int64Value(); + + if (key < 0 || key >= index_->ntotal) + { + Napi::Error::New(env, "Invalid key: " + std::to_string(key) + " (must be between 0 and " + std::to_string(index_->ntotal - 1) + ").") + .ThrowAsJavaScriptException(); + return env.Undefined(); + } + + try + { + float *reconstructed = new float[index_->d]; + index_->reconstruct(key, reconstructed); + + Napi::Array result = Napi::Array::New(env, index_->d); + for (size_t i = 0; i < index_->d; i++) + { + result[i] = Napi::Number::New(env, reconstructed[i]); + } + + delete[] reconstructed; + return result; + } + catch (const faiss::FaissException &ex) + { + Napi::Error::New(env, ex.what()).ThrowAsJavaScriptException(); + return env.Undefined(); + } + } + protected: std::unique_ptr index_; inline static Napi::FunctionReference *constructor; @@ -467,6 +513,7 @@ class Index : public IndexBase InstanceMethod("mergeFrom", &Index::mergeFrom), InstanceMethod("removeIds", &Index::removeIds), InstanceMethod("toBuffer", &Index::toBuffer), + InstanceMethod("reconstruct", &Index::reconstruct), StaticMethod("read", &Index::read), StaticMethod("fromBuffer", &Index::fromBuffer), StaticMethod("fromFactory", &Index::fromFactory), @@ -502,6 +549,7 @@ class IndexFlatL2 : public IndexBase InstanceMethod("mergeFrom", &IndexFlatL2::mergeFrom), InstanceMethod("removeIds", &IndexFlatL2::removeIds), InstanceMethod("toBuffer", &IndexFlatL2::toBuffer), + InstanceMethod("reconstruct", &IndexFlatL2::reconstruct), StaticMethod("read", &IndexFlatL2::read), StaticMethod("fromBuffer", &IndexFlatL2::fromBuffer), }); @@ -536,6 +584,7 @@ class IndexFlatIP : public IndexBase InstanceMethod("mergeFrom", &IndexFlatIP::mergeFrom), InstanceMethod("removeIds", &IndexFlatIP::removeIds), InstanceMethod("toBuffer", &IndexFlatIP::toBuffer), + InstanceMethod("reconstruct", &IndexFlatIP::reconstruct), StaticMethod("read", &IndexFlatIP::read), StaticMethod("fromBuffer", &IndexFlatIP::fromBuffer), }); diff --git a/test/reconstruct.test.js b/test/reconstruct.test.js new file mode 100644 index 0000000..567c5be --- /dev/null +++ b/test/reconstruct.test.js @@ -0,0 +1,62 @@ +const { IndexFlatL2 } = require('../lib'); + +describe('IndexFlatL2 reconstruct', () => { + test('should reconstruct vectors correctly', () => { + const dimension = 4; + const index = new IndexFlatL2(dimension); + + const vectors = [ + 1.0, 2.0, 3.0, 4.0, + 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0 + ]; + + index.add(vectors); + + expect(index.ntotal()).toBe(3); + + const reconstructed0 = index.reconstruct(0); + expect(reconstructed0).toEqual([1.0, 2.0, 3.0, 4.0]); + + const reconstructed1 = index.reconstruct(1); + expect(reconstructed1).toEqual([5.0, 6.0, 7.0, 8.0]); + + const reconstructed2 = index.reconstruct(2); + expect(reconstructed2).toEqual([9.0, 10.0, 11.0, 12.0]); + }); + + test('should throw error for invalid key', () => { + const dimension = 4; + const index = new IndexFlatL2(dimension); + + const vector = [1.0, 2.0, 3.0, 4.0]; + index.add(vector); + + expect(() => { + index.reconstruct(5); + }).toThrow(); + + expect(() => { + index.reconstruct(-1); + }).toThrow(); + }); + + test('should work with Index class', () => { + const { Index } = require('../lib'); + const dimension = 3; + const index = new Index(dimension); + + const vectors = [ + 1.5, 2.5, 3.5, + 4.5, 5.5, 6.5 + ]; + + index.add(vectors); + + const reconstructed0 = index.reconstruct(0); + expect(reconstructed0).toEqual([1.5, 2.5, 3.5]); + + const reconstructed1 = index.reconstruct(1); + expect(reconstructed1).toEqual([4.5, 5.5, 6.5]); + }); +});