Skip to content

Refactor Faiss-based vector format for easier backport #14934

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import org.apache.lucene.index.SegmentWriteState;

/**
* A Faiss-based format to create and search vector indexes, using {@link LibFaissC} to interact
* A Faiss-based format to create and search vector indexes, using {@link FaissLibrary} to interact
* with the native library.
*
* <p>The Faiss index is configured using its flexible <a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,8 @@
import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_EXTENSION;
import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_CURRENT;
import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_START;
import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.FAISS_IO_FLAG_MMAP;
import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.FAISS_IO_FLAG_READ_ONLY;
import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.indexRead;
import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.indexSearch;

import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand All @@ -44,7 +38,6 @@
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataAccessHint;
Expand All @@ -61,16 +54,13 @@
final class FaissKnnVectorsReader extends KnnVectorsReader {
private final FlatVectorsReader rawVectorsReader;
private final IndexInput data;
private final Map<String, IndexEntry> indexMap;
private final Arena arena;
private final Map<String, FaissLibrary.Index> indexMap;
private boolean closed;

public FaissKnnVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader)
throws IOException {
this.rawVectorsReader = rawVectorsReader;
this.indexMap = new HashMap<>();
this.arena = Arena.ofShared();
this.closed = false;

List<FieldMeta> fieldMetaList = new ArrayList<>();
String metaFileName =
Expand Down Expand Up @@ -125,9 +115,11 @@ public FaissKnnVectorsReader(SegmentReadState state, FlatVectorsReader rawVector
CodecUtil.retrieveChecksum(data);

for (FieldMeta fieldMeta : fieldMetaList) {
if (indexMap.put(fieldMeta.fieldInfo.name, loadField(data, arena, fieldMeta)) != null) {
throw new CorruptIndexException("Duplicate field: " + fieldMeta.fieldInfo.name, meta);
if (indexMap.containsKey(fieldMeta.name)) {
throw new CorruptIndexException("Duplicate field: " + fieldMeta.name, meta);
}
IndexInput indexInput = data.slice(fieldMeta.name, fieldMeta.offset, fieldMeta.length);
indexMap.put(fieldMeta.name, FaissLibrary.INSTANCE.readIndex(indexInput));
}
} catch (Throwable t) {
IOUtils.closeWhileSuppressingExceptions(t, this);
Expand All @@ -150,21 +142,7 @@ private static FieldMeta parseNextField(IndexInput meta, SegmentReadState state)
long dataOffset = meta.readLong();
long dataLength = meta.readLong();

return new FieldMeta(fieldInfo, dataOffset, dataLength);
}

@SuppressWarnings("restricted") // TODO: encapsulate the unsafeness into the LibFaissC
private static IndexEntry loadField(IndexInput data, Arena arena, FieldMeta fieldMeta)
throws IOException {
int ioFlags = FAISS_IO_FLAG_MMAP | FAISS_IO_FLAG_READ_ONLY;

// Read index into memory
MemorySegment indexPointer =
indexRead(data.slice(fieldMeta.fieldInfo.name, fieldMeta.offset, fieldMeta.length), ioFlags)
// Ensure timely cleanup
.reinterpret(arena, LibFaissC::freeIndex);

return new IndexEntry(indexPointer, fieldMeta.fieldInfo.getVectorSimilarityFunction());
return new FieldMeta(fieldInfo.name, dataOffset, dataLength);
}

@Override
Expand All @@ -188,9 +166,9 @@ public ByteVectorValues getByteVectorValues(String field) {

@Override
public void search(String field, float[] vector, KnnCollector knnCollector, Bits acceptDocs) {
IndexEntry entry = indexMap.get(field);
if (entry != null) {
indexSearch(entry.indexPointer, entry.function, vector, knnCollector, acceptDocs);
FaissLibrary.Index index = indexMap.get(field);
if (index != null) {
index.search(vector, knnCollector, acceptDocs);
}
}

Expand All @@ -210,12 +188,16 @@ public Map<String, Long> getOffHeapByteSize(FieldInfo fieldInfo) {
@Override
public void close() throws IOException {
if (closed == false) {
// Close all indexes
for (FaissLibrary.Index index : indexMap.values()) {
index.close();
}
indexMap.clear();

IOUtils.close(rawVectorsReader, data);
closed = true;
IOUtils.close(rawVectorsReader, arena::close, data, indexMap::clear);
}
}

private record FieldMeta(FieldInfo fieldInfo, long offset, long length) {}

private record IndexEntry(MemorySegment indexPointer, VectorSimilarityFunction function) {}
private record FieldMeta(String name, long offset, long length) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,8 @@
import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_CODEC_NAME;
import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_EXTENSION;
import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_CURRENT;
import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.FAISS_IO_FLAG_MMAP;
import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.FAISS_IO_FLAG_READ_ONLY;
import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.createIndex;
import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.indexWrite;

import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -43,7 +37,6 @@
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSet;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.IOUtils;
Expand Down Expand Up @@ -154,26 +147,23 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
}
}

@SuppressWarnings("restricted") // TODO: encapsulate the unsafeness into the LibFaissC
private void writeFloatField(
FieldInfo fieldInfo, FloatVectorValues floatVectorValues, IntToIntFunction oldToNewDocId)
throws IOException {
int number = fieldInfo.number;
meta.writeInt(number);

// Write index to temp file and deallocate from memory
try (Arena temp = Arena.ofConfined()) {
VectorSimilarityFunction function = fieldInfo.getVectorSimilarityFunction();
MemorySegment indexPointer =
createIndex(description, indexParams, function, floatVectorValues, oldToNewDocId)
// Ensure timely cleanup
.reinterpret(temp, LibFaissC::freeIndex);

int ioFlags = FAISS_IO_FLAG_MMAP | FAISS_IO_FLAG_READ_ONLY;
try (FaissLibrary.Index index =
FaissLibrary.INSTANCE.createIndex(
description,
indexParams,
fieldInfo.getVectorSimilarityFunction(),
floatVectorValues,
oldToNewDocId)) {

// Write index
long dataOffset = data.getFilePointer();
indexWrite(indexPointer, data, ioFlags);
index.write(data);
long dataLength = data.getFilePointer() - dataOffset;

meta.writeLong(dataOffset);
Expand Down Expand Up @@ -233,7 +223,7 @@ public int size() {

@Override
public FloatVectorValues copy() {
return new BufferedFloatVectorValues(floats, dimension, docIdSet);
throw new AssertionError("Should not be called");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.sandbox.codecs.faiss;

import java.io.Closeable;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.IntToIntFunction;

/**
 * Minimal interface to create and query Faiss indexes.
 *
 * <p>The single shared implementation is exposed through {@link #INSTANCE}; it is located by class
 * name and instantiated through its no-argument constructor using {@link MethodHandles}.
 *
 * @lucene.experimental
 */
interface FaissLibrary {
  /** Shared implementation, created once by {@code lookup()} when this interface is initialized. */
  FaissLibrary INSTANCE = lookup();

  // TODO: Use vectorized version where available
  // NOTE(review): presumably the name and version of the native Faiss C library that the
  // implementation binds to -- confirm against the implementation class.
  String NAME = "faiss_c";
  String VERSION = "1.11.0";

  /**
   * Instantiates the implementation class through its no-argument constructor.
   *
   * @return a new instance of the implementation, cast to {@link FaissLibrary}
   * @throws LinkageError if the implementation class or its constructor is missing or inaccessible
   * @throws AssertionError if the constructor throws a checked exception
   */
  private static FaissLibrary lookup() {
    final MethodHandles.Lookup caller = MethodHandles.lookup();
    final Class<?> implClass = findImplClass(caller);
    final MethodHandle noArgConstructor = findNoArgConstructor(caller, implClass);

    try {
      return (FaissLibrary) noArgConstructor.invoke();
    } catch (Throwable failure) {
      // Unchecked failures are propagated untouched; anything else is unexpected here.
      if (failure instanceof RuntimeException unchecked) {
        throw unchecked;
      }
      if (failure instanceof Error error) {
        throw error;
      }
      throw new AssertionError("Should not throw checked exceptions", failure);
    }
  }

  /** Resolves the implementation class by name, reporting failures as {@link LinkageError}. */
  private static Class<?> findImplClass(MethodHandles.Lookup caller) {
    try {
      return caller.findClass("org.apache.lucene.sandbox.codecs.faiss.FaissLibraryNativeImpl");
    } catch (ClassNotFoundException | IllegalAccessException e) {
      throw new LinkageError("FaissLibraryNativeImpl class is missing or inaccessible", e);
    }
  }

  /** Resolves the no-argument constructor, reporting failures as {@link LinkageError}. */
  private static MethodHandle findNoArgConstructor(MethodHandles.Lookup caller, Class<?> implClass) {
    try {
      return caller.findConstructor(implClass, MethodType.methodType(void.class));
    } catch (NoSuchMethodException | IllegalAccessException e) {
      throw new LinkageError("FaissLibraryNativeImpl constructor is missing or inaccessible", e);
    }
  }

  /** Handle to a single Faiss index; must be closed when no longer needed. */
  interface Index extends Closeable {
    /**
     * Searches this index.
     *
     * @param query the query vector
     * @param knnCollector collector for the search; presumably receives the matching documents
     * @param acceptDocs presumably restricts which documents may match -- confirm null handling
     *     with the implementation
     */
    void search(float[] query, KnnCollector knnCollector, Bits acceptDocs);

    /**
     * Writes this index to the given output.
     *
     * @param output destination for the index
     */
    void write(IndexOutput output);
  }

  /**
   * Creates a new index from the given vectors.
   *
   * @param description index description (format is defined by the implementation)
   * @param indexParams index parameters (format is defined by the implementation)
   * @param function similarity function for the field
   * @param floatVectorValues vectors to add to the index
   * @param oldToNewDocId doc id mapping; presumably applied to the ids stored in the index
   * @return the created index; the caller is responsible for closing it
   */
  Index createIndex(
      String description,
      String indexParams,
      VectorSimilarityFunction function,
      FloatVectorValues floatVectorValues,
      IntToIntFunction oldToNewDocId);

  /**
   * Reads a previously written index.
   *
   * @param input source to read the index from
   * @return the loaded index; the caller is responsible for closing it
   */
  Index readIndex(IndexInput input);
}
Loading
Loading