Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions lib/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ export class Index {
* @return {number} number of IDs removed.
*/
removeIds(ids: number[]): number
/**
* Reconstructs a vector from the index.
* @param {number} key The key of the vector to reconstruct.
* @return {number[]} The reconstructed vector.
*/
reconstruct(key: number): number[];

}

Expand All @@ -133,6 +139,12 @@ export class IndexFlatL2 extends Index {
* @param {IndexFlatL2} otherIndex The other IndexFlatL2 instance to merge from.
*/
mergeFrom(otherIndex: IndexFlatL2): void;
/**
* Reconstructs a vector from the index.
* @param {number} key The key of the vector to reconstruct.
* @return {number[]} The reconstructed vector.
*/
reconstruct(key: number): number[];
}

/**
Expand All @@ -158,4 +170,10 @@ export class IndexFlatIP extends Index {
* @param {IndexFlatIP} otherIndex The other IndexFlatIP instance to merge from.
*/
mergeFrom(otherIndex: IndexFlatIP): void;
/**
* Reconstructs a vector from the index.
* @param {number} key The key of the vector to reconstruct.
* @return {number[]} The reconstructed vector.
*/
reconstruct(key: number): number[];
}
49 changes: 49 additions & 0 deletions src/faiss.cc
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,52 @@ class IndexBase : public Napi::ObjectWrap<T>
return Napi::Buffer<uint8_t>::Copy(env, writer->data.data(), writer->data.size());
}

Napi::Value reconstruct(const Napi::CallbackInfo &info)
{
Napi::Env env = info.Env();

if (info.Length() != 1)
{
Napi::Error::New(env, "Expected 1 argument, but got " + std::to_string(info.Length()) + ".")
.ThrowAsJavaScriptException();
return env.Undefined();
}
if (!info[0].IsNumber())
{
Napi::TypeError::New(env, "Invalid argument type, must be a number.").ThrowAsJavaScriptException();
return env.Undefined();
}

const idx_t key = info[0].As<Napi::Number>().Int64Value();

if (key < 0 || key >= index_->ntotal)
{
Napi::Error::New(env, "Invalid key: " + std::to_string(key) + " (must be between 0 and " + std::to_string(index_->ntotal - 1) + ").")
.ThrowAsJavaScriptException();
return env.Undefined();
}

try
{
float *reconstructed = new float[index_->d];
index_->reconstruct(key, reconstructed);

Napi::Array result = Napi::Array::New(env, index_->d);
for (size_t i = 0; i < index_->d; i++)
{
result[i] = Napi::Number::New(env, reconstructed[i]);
}

delete[] reconstructed;
return result;
}
catch (const faiss::FaissException &ex)
{
Napi::Error::New(env, ex.what()).ThrowAsJavaScriptException();
return env.Undefined();
}
}

protected:
std::unique_ptr<faiss::Index> index_;
inline static Napi::FunctionReference *constructor;
Expand Down Expand Up @@ -467,6 +513,7 @@ class Index : public IndexBase<Index, faiss::IndexFlatL2>
InstanceMethod("mergeFrom", &Index::mergeFrom),
InstanceMethod("removeIds", &Index::removeIds),
InstanceMethod("toBuffer", &Index::toBuffer),
InstanceMethod("reconstruct", &Index::reconstruct),
StaticMethod("read", &Index::read),
StaticMethod("fromBuffer", &Index::fromBuffer),
StaticMethod("fromFactory", &Index::fromFactory),
Expand Down Expand Up @@ -502,6 +549,7 @@ class IndexFlatL2 : public IndexBase<IndexFlatL2, faiss::IndexFlatL2>
InstanceMethod("mergeFrom", &IndexFlatL2::mergeFrom),
InstanceMethod("removeIds", &IndexFlatL2::removeIds),
InstanceMethod("toBuffer", &IndexFlatL2::toBuffer),
InstanceMethod("reconstruct", &IndexFlatL2::reconstruct),
StaticMethod("read", &IndexFlatL2::read),
StaticMethod("fromBuffer", &IndexFlatL2::fromBuffer),
});
Expand Down Expand Up @@ -536,6 +584,7 @@ class IndexFlatIP : public IndexBase<IndexFlatIP, faiss::IndexFlatIP>
InstanceMethod("mergeFrom", &IndexFlatIP::mergeFrom),
InstanceMethod("removeIds", &IndexFlatIP::removeIds),
InstanceMethod("toBuffer", &IndexFlatIP::toBuffer),
InstanceMethod("reconstruct", &IndexFlatIP::reconstruct),
StaticMethod("read", &IndexFlatIP::read),
StaticMethod("fromBuffer", &IndexFlatIP::fromBuffer),
});
Expand Down
62 changes: 62 additions & 0 deletions test/reconstruct.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
const { IndexFlatL2 } = require('../lib');

describe('IndexFlatL2 reconstruct', () => {
test('should reconstruct vectors correctly', () => {
const dimension = 4;
const index = new IndexFlatL2(dimension);

const vectors = [
1.0, 2.0, 3.0, 4.0,
5.0, 6.0, 7.0, 8.0,
9.0, 10.0, 11.0, 12.0
];

index.add(vectors);

expect(index.ntotal()).toBe(3);

const reconstructed0 = index.reconstruct(0);
expect(reconstructed0).toEqual([1.0, 2.0, 3.0, 4.0]);

const reconstructed1 = index.reconstruct(1);
expect(reconstructed1).toEqual([5.0, 6.0, 7.0, 8.0]);

const reconstructed2 = index.reconstruct(2);
expect(reconstructed2).toEqual([9.0, 10.0, 11.0, 12.0]);
});

test('should throw error for invalid key', () => {
const dimension = 4;
const index = new IndexFlatL2(dimension);

const vector = [1.0, 2.0, 3.0, 4.0];
index.add(vector);

expect(() => {
index.reconstruct(5);
}).toThrow();

expect(() => {
index.reconstruct(-1);
}).toThrow();
});

test('should work with Index class', () => {
const { Index } = require('../lib');
const dimension = 3;
const index = new Index(dimension);

const vectors = [
1.5, 2.5, 3.5,
4.5, 5.5, 6.5
];

index.add(vectors);

const reconstructed0 = index.reconstruct(0);
expect(reconstructed0).toEqual([1.5, 2.5, 3.5]);

const reconstructed1 = index.reconstruct(1);
expect(reconstructed1).toEqual([4.5, 5.5, 6.5]);
});
});