Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paimon-vector/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ under the License.
<name>Paimon : Vector Index</name>

<properties>
<paimon-vector-index-java.version>0.1.0</paimon-vector-index-java.version>
<paimon-vector-index-java.version>0.1.0-SNAPSHOT</paimon-vector-index-java.version>
</properties>

<dependencies>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,15 @@ public class NativeVectorGlobalIndexWriter implements GlobalIndexSingleColumnWri

private static final int IO_BUFFER_SIZE = 8 * 1024 * 1024;
private static final int ADD_BATCH_SIZE = 10000;
private static final int TRAIN_BATCH_SIZE = 4096;
static final int MAX_FLOAT_ARRAY_LENGTH = Integer.MAX_VALUE - 8;
private static final long TRAIN_MEMORY_WARNING_BYTES = 4L * 1024 * 1024 * 1024;

private final GlobalIndexFileWriter fileWriter;
private final String identifier;
private final Map<String, String> nativeOptions;
private final int dim;
private final int trainMaxSamples;

private File tempVectorFile;
private FileChannel writeChannel;
Expand All @@ -84,11 +88,30 @@ public NativeVectorGlobalIndexWriter(
DataType fieldType,
Map<String, String> options,
String identifier) {
this(
fileWriter,
fieldType,
options,
identifier,
NativeVectorGlobalIndexerFactory.DEFAULT_TRAIN_MAX_SAMPLES);
}

public NativeVectorGlobalIndexWriter(
GlobalIndexFileWriter fileWriter,
DataType fieldType,
Map<String, String> options,
String identifier,
int trainMaxSamples) {
this.fileWriter = fileWriter;
this.identifier = identifier;
validateFieldType(fieldType);
this.nativeOptions = options;
this.dim = Integer.parseInt(options.get("dimension"));
if (trainMaxSamples <= 0) {
throw new IllegalArgumentException(
"trainMaxSamples must be a positive integer: " + trainMaxSamples);
}
this.trainMaxSamples = trainMaxSamples;
this.count = 0;
this.closed = false;
this.recordSizeInBytes = checkedRecordSize(dim, IO_BUFFER_SIZE);
Expand Down Expand Up @@ -272,25 +295,58 @@ private String fileNamePrefix() {
}

private void trainFromTempFile(VectorIndexWriter writer) throws IOException {
int trainCount = (int) count;
float[] trainData = new float[trainCount * dim];
int trainCount = trainingVectorCount(count, trainMaxSamples);
int trainBatchSize = vectorBatchSize(TRAIN_BATCH_SIZE, dim);
float[] batchVectors = new float[trainBatchSize * dim];
logTrainingMemoryEstimate(trainCount);

try (RandomAccessFile raf = new RandomAccessFile(tempVectorFile, "r");
FileChannel channel = raf.getChannel()) {
ByteBuffer readBuf = ByteBuffer.allocateDirect(IO_BUFFER_SIZE);
readBuf.order(ByteOrder.nativeOrder());
readBuf.limit(0);

for (int i = 0; i < trainCount; i++) {
long selected = 0;
long nextSampleIndex = 0;
int batchCount = 0;

for (long recordIndex = 0;
recordIndex < count && selected < trainCount;
recordIndex++) {
ensureAvailable(readBuf, channel, recordSizeInBytes);
readBuf.getLong(); // skip rowId
for (int d = 0; d < dim; d++) {
trainData[i * dim + d] = readBuf.getFloat();
if (recordIndex == nextSampleIndex) {
for (int d = 0; d < dim; d++) {
batchVectors[batchCount * dim + d] = readBuf.getFloat();
}
selected++;
batchCount++;
if (batchCount == trainBatchSize) {
writer.addTrainingVectors(batchVectors, batchCount);
batchCount = 0;
}
if (selected < trainCount) {
nextSampleIndex = sampleIndex(selected, count, trainCount);
}
} else {
readBuf.position(readBuf.position() + dim * Float.BYTES);
}
}

if (batchCount > 0) {
writer.addTrainingVectors(
Arrays.copyOf(batchVectors, batchCount * dim), batchCount);
}
if (selected != trainCount) {
throw new IOException(
"Expected to select "
+ trainCount
+ " training vectors, but selected "
+ selected);
}
}

writer.train(trainData, trainCount);
writer.finishTraining();
}

private void addVectorsFromTempFile(VectorIndexWriter writer) throws IOException {
Expand Down Expand Up @@ -386,6 +442,73 @@ private static int checkedRecordSize(int dim, int bufferCapacity) {
return (int) recordSize;
}

static int trainingVectorCount(long vectorCount, int trainMaxSamples) {
if (vectorCount <= 0) {
return 0;
}
if (trainMaxSamples <= 0) {
throw new IllegalArgumentException(
"trainMaxSamples must be a positive integer: " + trainMaxSamples);
}
return (int) Math.min(vectorCount, (long) trainMaxSamples);
}

static int vectorBatchSize(int requestedBatchSize, int dim) {
if (requestedBatchSize <= 0) {
throw new IllegalArgumentException(
"requestedBatchSize must be a positive integer: " + requestedBatchSize);
}
if (dim <= 0) {
throw new IllegalArgumentException("dim must be a positive integer: " + dim);
}
int maxBatchSize = MAX_FLOAT_ARRAY_LENGTH / dim;
if (maxBatchSize <= 0) {
throw new IllegalStateException(
"Vector dimension " + dim + " exceeds Java float array capacity");
}
return Math.min(requestedBatchSize, maxBatchSize);
}

private void logTrainingMemoryEstimate(int trainCount) {
long rawBytes = saturatedMultiply(saturatedMultiply(trainCount, dim), Float.BYTES);
long estimatedPeakBytes = saturatedMultiply(rawBytes, 2);
if (estimatedPeakBytes >= TRAIN_MEMORY_WARNING_BYTES) {
LOG.warn(
"{} training uses {} samples out of {} vectors (dim={}). Estimated native "
+ "training peak is at least {} bytes (~{} GiB) before OPQ and "
+ "temporary buffers.",
identifier,
trainCount,
count,
dim,
estimatedPeakBytes,
String.format("%.2f", estimatedPeakBytes / 1024.0 / 1024.0 / 1024.0));
} else {
LOG.info(
"{} training uses {} samples out of {} vectors (dim={})",
identifier,
trainCount,
count,
dim);
}
}

private static long saturatedMultiply(long left, long right) {
if (left == 0 || right == 0) {
return 0;
}
if (left > Long.MAX_VALUE / right) {
return Long.MAX_VALUE;
}
return left * right;
}

private static long sampleIndex(long sampleOrdinal, long vectorCount, int trainCount) {
long quotient = vectorCount / trainCount;
long remainder = vectorCount % trainCount;
return sampleOrdinal * quotient + sampleOrdinal * remainder / trainCount;
}

@Override
public void close() {
if (!closed) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,36 @@ public class NativeVectorGlobalIndexer implements VectorGlobalIndexer {
private final DataType fieldType;
private final Map<String, String> options;
private final String identifier;
private final int trainMaxSamples;

public NativeVectorGlobalIndexer(
DataType fieldType, Map<String, String> options, String identifier) {
this(
fieldType,
options,
identifier,
NativeVectorGlobalIndexerFactory.DEFAULT_TRAIN_MAX_SAMPLES);
}

public NativeVectorGlobalIndexer(
DataType fieldType,
Map<String, String> options,
String identifier,
int trainMaxSamples) {
this.fieldType = fieldType;
this.options = Objects.requireNonNull(options, "options must not be null");
this.identifier = Objects.requireNonNull(identifier, "identifier must not be null");
if (trainMaxSamples <= 0) {
throw new IllegalArgumentException(
"trainMaxSamples must be a positive integer: " + trainMaxSamples);
}
this.trainMaxSamples = trainMaxSamples;
}

@Override
public GlobalIndexWriter createWriter(GlobalIndexFileWriter fileWriter) {
return new NativeVectorGlobalIndexWriter(fileWriter, fieldType, options, identifier);
return new NativeVectorGlobalIndexWriter(
fileWriter, fieldType, options, identifier, trainMaxSamples);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,17 @@
public abstract class NativeVectorGlobalIndexerFactory implements GlobalIndexerFactory {

private static final int DEFAULT_DIMENSION = 128;
static final String TRAIN_MAX_SAMPLES_OPTION = "train.max-samples";
static final int DEFAULT_TRAIN_MAX_SAMPLES = 65536;

@Override
public GlobalIndexer create(DataField field, Options options) {
String identifier = identifier();
return new NativeVectorGlobalIndexer(
field.type(),
nativeOptions(field.type(), options, identifier, field.name()),
identifier);
identifier,
trainMaxSamples(options, identifier, field.name()));
}

static Map<String, String> nativeOptions(
Expand Down Expand Up @@ -78,6 +81,43 @@ static Map<String, String> nativeOptions(
return nativeOptions;
}

static int trainMaxSamples(Options tableOptions, String identifier, String fieldName) {
String optionPrefix = identifier + ".";
String fieldPrefix = "fields." + fieldName + ".";
String indexKey = optionPrefix + TRAIN_MAX_SAMPLES_OPTION;
String fieldKey = fieldPrefix + TRAIN_MAX_SAMPLES_OPTION;
Map<String, String> tableOptionsMap = tableOptions.toMap();

String key = null;
String value = null;
if (tableOptionsMap.containsKey(indexKey)) {
key = indexKey;
value = tableOptionsMap.get(indexKey);
}
if (tableOptionsMap.containsKey(fieldKey)) {
key = fieldKey;
value = tableOptionsMap.get(fieldKey);
}
if (value == null) {
return DEFAULT_TRAIN_MAX_SAMPLES;
}

try {
int parsed = Integer.parseInt(value.trim());
if (parsed > 0) {
return parsed;
}
throw invalidTrainMaxSamples(key, value);
} catch (NumberFormatException e) {
throw invalidTrainMaxSamples(key, value);
}
}

private static IllegalArgumentException invalidTrainMaxSamples(String key, String value) {
return new IllegalArgumentException(
"Invalid value for '" + key + "': " + value + ". Must be a positive integer.");
}

private static String nativeOptionKey(String optionKey) {
switch (optionKey) {
case "index.dimension":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,31 @@ public void testAllNullReturnsEmpty() {
assertThat(results).isEmpty();
}

@Test
public void testTrainingVectorCountUsesOnlyConfiguredSampleLimit() {
int oldJavaArrayLimitFor1024Dim =
NativeVectorGlobalIndexWriter.MAX_FLOAT_ARRAY_LENGTH / 1024;
int requestedSamples = oldJavaArrayLimitFor1024Dim + 1;

assertThat(
NativeVectorGlobalIndexWriter.trainingVectorCount(
requestedSamples, requestedSamples))
.isEqualTo(requestedSamples);
assertThat(NativeVectorGlobalIndexWriter.trainingVectorCount(10_000L, 64)).isEqualTo(64);
}

@Test
public void testTrainingBatchSizeProtectsSingleJavaArrayAllocation() {
assertThat(NativeVectorGlobalIndexWriter.vectorBatchSize(4096, 128)).isEqualTo(4096);
assertThat(
NativeVectorGlobalIndexWriter.vectorBatchSize(
4096, NativeVectorGlobalIndexWriter.MAX_FLOAT_ARRAY_LENGTH))
.isEqualTo(1);
assertThatThrownBy(() -> NativeVectorGlobalIndexWriter.vectorBatchSize(4096, 0))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("positive integer");
}

@Test
public void testMetaSerializationIsEmptyMap() throws IOException {
VectorIndexMeta meta = new VectorIndexMeta();
Expand Down
Loading
Loading