diff --git a/paimon-vector/pom.xml b/paimon-vector/pom.xml
index 9fd20e915dd0..7f54c7dabc11 100644
--- a/paimon-vector/pom.xml
+++ b/paimon-vector/pom.xml
@@ -32,7 +32,7 @@ under the License.
Paimon : Vector Index
- 0.1.0
+ 0.1.0-SNAPSHOT
diff --git a/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexWriter.java b/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexWriter.java
index 89aa004bee21..2bbffdb12e75 100644
--- a/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexWriter.java
+++ b/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexWriter.java
@@ -62,11 +62,15 @@ public class NativeVectorGlobalIndexWriter implements GlobalIndexSingleColumnWri
private static final int IO_BUFFER_SIZE = 8 * 1024 * 1024;
private static final int ADD_BATCH_SIZE = 10000;
+ // Requested number of vectors per training batch; vectorBatchSize() may lower it for large dims.
+ private static final int TRAIN_BATCH_SIZE = 4096;
+ // Cap on the element count of a single float[] batch, enforced by vectorBatchSize().
+ // NOTE(review): the "- 8" is presumably headroom for JVM array size limits - confirm.
+ static final int MAX_FLOAT_ARRAY_LENGTH = Integer.MAX_VALUE - 8;
+ // 4 GiB: estimate at or above which logTrainingMemoryEstimate() logs at WARN instead of INFO.
+ private static final long TRAIN_MEMORY_WARNING_BYTES = 4L * 1024 * 1024 * 1024;
private final GlobalIndexFileWriter fileWriter;
private final String identifier;
private final Map nativeOptions;
private final int dim;
+ // Cap on the number of stored vectors sampled for training; the constructor rejects values <= 0.
+ private final int trainMaxSamples;
private File tempVectorFile;
private FileChannel writeChannel;
@@ -84,11 +88,30 @@ public NativeVectorGlobalIndexWriter(
DataType fieldType,
Map options,
String identifier) {
+ this(
+ fileWriter,
+ fieldType,
+ options,
+ identifier,
+ NativeVectorGlobalIndexerFactory.DEFAULT_TRAIN_MAX_SAMPLES);
+ }
+
+ public NativeVectorGlobalIndexWriter(
+ GlobalIndexFileWriter fileWriter,
+ DataType fieldType,
+ Map options,
+ String identifier,
+ int trainMaxSamples) {
this.fileWriter = fileWriter;
this.identifier = identifier;
validateFieldType(fieldType);
this.nativeOptions = options;
this.dim = Integer.parseInt(options.get("dimension"));
+ if (trainMaxSamples <= 0) {
+ throw new IllegalArgumentException(
+ "trainMaxSamples must be a positive integer: " + trainMaxSamples);
+ }
+ this.trainMaxSamples = trainMaxSamples;
this.count = 0;
this.closed = false;
this.recordSizeInBytes = checkedRecordSize(dim, IO_BUFFER_SIZE);
@@ -272,8 +295,10 @@ private String fileNamePrefix() {
}
private void trainFromTempFile(VectorIndexWriter writer) throws IOException {
- int trainCount = (int) count;
- float[] trainData = new float[trainCount * dim];
+ int trainCount = trainingVectorCount(count, trainMaxSamples);
+ int trainBatchSize = vectorBatchSize(TRAIN_BATCH_SIZE, dim);
+ float[] batchVectors = new float[trainBatchSize * dim];
+ logTrainingMemoryEstimate(trainCount);
try (RandomAccessFile raf = new RandomAccessFile(tempVectorFile, "r");
FileChannel channel = raf.getChannel()) {
@@ -281,16 +306,47 @@ private void trainFromTempFile(VectorIndexWriter writer) throws IOException {
readBuf.order(ByteOrder.nativeOrder());
readBuf.limit(0);
- for (int i = 0; i < trainCount; i++) {
+ long selected = 0;
+ long nextSampleIndex = 0;
+ int batchCount = 0;
+
+ for (long recordIndex = 0;
+ recordIndex < count && selected < trainCount;
+ recordIndex++) {
ensureAvailable(readBuf, channel, recordSizeInBytes);
readBuf.getLong(); // skip rowId
- for (int d = 0; d < dim; d++) {
- trainData[i * dim + d] = readBuf.getFloat();
+ if (recordIndex == nextSampleIndex) {
+ for (int d = 0; d < dim; d++) {
+ batchVectors[batchCount * dim + d] = readBuf.getFloat();
+ }
+ selected++;
+ batchCount++;
+ if (batchCount == trainBatchSize) {
+ writer.addTrainingVectors(batchVectors, batchCount);
+ batchCount = 0;
+ }
+ if (selected < trainCount) {
+ nextSampleIndex = sampleIndex(selected, count, trainCount);
+ }
+ } else {
+ readBuf.position(readBuf.position() + dim * Float.BYTES);
}
}
+
+ if (batchCount > 0) {
+ writer.addTrainingVectors(
+ Arrays.copyOf(batchVectors, batchCount * dim), batchCount);
+ }
+ if (selected != trainCount) {
+ throw new IOException(
+ "Expected to select "
+ + trainCount
+ + " training vectors, but selected "
+ + selected);
+ }
}
- writer.train(trainData, trainCount);
+ writer.finishTraining();
}
private void addVectorsFromTempFile(VectorIndexWriter writer) throws IOException {
@@ -386,6 +442,73 @@ private static int checkedRecordSize(int dim, int bufferCapacity) {
return (int) recordSize;
}
+ /**
+  * Determines how many vectors are handed to training.
+  *
+  * @param vectorCount number of vectors available; any value {@code <= 0} yields {@code 0}
+  * @param trainMaxSamples configured cap on training samples
+  * @return the smaller of {@code vectorCount} and {@code trainMaxSamples}, or {@code 0} when
+  *     there are no vectors
+  * @throws IllegalArgumentException if vectors exist and {@code trainMaxSamples <= 0}
+  */
+ static int trainingVectorCount(long vectorCount, int trainMaxSamples) {
+     if (vectorCount > 0) {
+         if (trainMaxSamples > 0) {
+             // The result never exceeds trainMaxSamples, so narrowing to int is lossless.
+             return vectorCount < trainMaxSamples ? (int) vectorCount : trainMaxSamples;
+         }
+         throw new IllegalArgumentException(
+                 "trainMaxSamples must be a positive integer: " + trainMaxSamples);
+     }
+     return 0;
+ }
+
+ /**
+  * Limits a requested batch size so that {@code batchSize * dim} floats fit in one array of at
+  * most {@code MAX_FLOAT_ARRAY_LENGTH} elements.
+  *
+  * @param requestedBatchSize desired number of vectors per batch, must be positive
+  * @param dim vector dimension, must be positive
+  * @return {@code requestedBatchSize}, or the largest batch size that fits if that is smaller
+  * @throws IllegalArgumentException if either argument is not positive
+  * @throws IllegalStateException if not even a single vector of {@code dim} floats fits
+  */
+ static int vectorBatchSize(int requestedBatchSize, int dim) {
+     if (requestedBatchSize < 1) {
+         throw new IllegalArgumentException(
+                 "requestedBatchSize must be a positive integer: " + requestedBatchSize);
+     }
+     if (dim < 1) {
+         throw new IllegalArgumentException("dim must be a positive integer: " + dim);
+     }
+     // Integer division: how many whole vectors fit below the array length cap.
+     final int capacityLimit = MAX_FLOAT_ARRAY_LENGTH / dim;
+     if (capacityLimit < 1) {
+         throw new IllegalStateException(
+                 "Vector dimension " + dim + " exceeds Java float array capacity");
+     }
+     return requestedBatchSize <= capacityLimit ? requestedBatchSize : capacityLimit;
+ }
+
+ /**
+  * Logs how many vectors are sampled for training, together with a rough size estimate.
+  *
+  * The raw sample size is {@code trainCount * dim * Float.BYTES} and the reported estimate is
+  * twice that; both products saturate at {@code Long.MAX_VALUE}. An estimate at or above
+  * {@code TRAIN_MEMORY_WARNING_BYTES} is logged at WARN, anything smaller at INFO.
+  *
+  * @param trainCount number of vectors selected for training
+  */
+ private void logTrainingMemoryEstimate(int trainCount) {
+     long rawBytes = saturatedMultiply(saturatedMultiply(trainCount, dim), Float.BYTES);
+     // NOTE(review): the factor 2 presumably accounts for a native-side copy - confirm.
+     long estimatedPeakBytes = saturatedMultiply(rawBytes, 2);
+     if (estimatedPeakBytes >= TRAIN_MEMORY_WARNING_BYTES) {
+         // Locale.ROOT keeps the decimal separator stable regardless of the JVM default locale.
+         String estimatedGiB =
+                 String.format(
+                         java.util.Locale.ROOT,
+                         "%.2f",
+                         estimatedPeakBytes / 1024.0 / 1024.0 / 1024.0);
+         LOG.warn(
+                 "{} training uses {} samples out of {} vectors (dim={}). Estimated native "
+                         + "training peak is at least {} bytes (~{} GiB) before OPQ and "
+                         + "temporary buffers.",
+                 identifier,
+                 trainCount,
+                 count,
+                 dim,
+                 estimatedPeakBytes,
+                 estimatedGiB);
+     } else {
+         LOG.info(
+                 "{} training uses {} samples out of {} vectors (dim={})",
+                 identifier,
+                 trainCount,
+                 count,
+                 dim);
+     }
+ }
+
+ /**
+  * Multiplies two longs, clamping the result instead of wrapping on overflow.
+  *
+  * @param left first factor
+  * @param right second factor
+  * @return {@code left * right} when it fits in a long; otherwise {@code Long.MAX_VALUE} for a
+  *     positive overflow and {@code Long.MIN_VALUE} for a negative one
+  */
+ private static long saturatedMultiply(long left, long right) {
+     try {
+         return Math.multiplyExact(left, right);
+     } catch (ArithmeticException overflow) {
+         // Operands with differing signs overflow towards negative infinity.
+         return ((left ^ right) < 0) ? Long.MIN_VALUE : Long.MAX_VALUE;
+     }
+ }
+
+ /**
+  * Maps the {@code sampleOrdinal}-th training sample to a record index so that
+  * {@code trainCount} samples are spread across {@code vectorCount} records.
+  *
+  * For non-negative arguments the result equals
+  * {@code floor(sampleOrdinal * vectorCount / trainCount)}, computed from the quotient and
+  * remainder of {@code vectorCount / trainCount}.
+  *
+  * @param sampleOrdinal zero-based position of the sample among the selected ones
+  * @param vectorCount total number of records
+  * @param trainCount number of samples to select; must be non-zero
+  * @return record index for the given sample
+  */
+ private static long sampleIndex(long sampleOrdinal, long vectorCount, int trainCount) {
+     final long wholeStride = vectorCount / trainCount;
+     final long leftover = vectorCount % trainCount;
+     final long strideOffset = sampleOrdinal * wholeStride;
+     // Distributes the remainder evenly over the samples using integer arithmetic.
+     final long leftoverOffset = (sampleOrdinal * leftover) / trainCount;
+     return strideOffset + leftoverOffset;
+ }
+
@Override
public void close() {
if (!closed) {
diff --git a/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexer.java b/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexer.java
index f45a97d34ae5..43829545f367 100644
--- a/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexer.java
+++ b/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexer.java
@@ -39,17 +39,36 @@ public class NativeVectorGlobalIndexer implements VectorGlobalIndexer {
private final DataType fieldType;
private final Map options;
private final String identifier;
+ // Training sample cap that createWriter() passes to each NativeVectorGlobalIndexWriter.
+ private final int trainMaxSamples;
+ /**
+  * Creates an indexer that uses
+  * {@code NativeVectorGlobalIndexerFactory.DEFAULT_TRAIN_MAX_SAMPLES} as the training sample
+  * cap, delegating to the four-argument constructor.
+  */
public NativeVectorGlobalIndexer(
DataType fieldType, Map options, String identifier) {
+ this(
+ fieldType,
+ options,
+ identifier,
+ NativeVectorGlobalIndexerFactory.DEFAULT_TRAIN_MAX_SAMPLES);
+ }
+
+ /**
+  * Creates an indexer with an explicit training sample cap.
+  *
+  * @param fieldType type of the indexed field, stored as given
+  * @param options option map, must not be null
+  * @param identifier index identifier, must not be null
+  * @param trainMaxSamples training sample cap, must be positive
+  * @throws NullPointerException if {@code options} or {@code identifier} is null
+  * @throws IllegalArgumentException if {@code trainMaxSamples <= 0}
+  */
+ public NativeVectorGlobalIndexer(
+ DataType fieldType,
+ Map options,
+ String identifier,
+ int trainMaxSamples) {
this.fieldType = fieldType;
this.options = Objects.requireNonNull(options, "options must not be null");
this.identifier = Objects.requireNonNull(identifier, "identifier must not be null");
+ // Reject a non-positive cap here rather than when a writer is created later.
+ if (trainMaxSamples <= 0) {
+ throw new IllegalArgumentException(
+ "trainMaxSamples must be a positive integer: " + trainMaxSamples);
+ }
+ this.trainMaxSamples = trainMaxSamples;
}
@Override
public GlobalIndexWriter createWriter(GlobalIndexFileWriter fileWriter) {
- return new NativeVectorGlobalIndexWriter(fileWriter, fieldType, options, identifier);
+ // Pass the configured training sample cap through to the writer.
+ return new NativeVectorGlobalIndexWriter(
+ fileWriter, fieldType, options, identifier, trainMaxSamples);
}
@Override
diff --git a/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexerFactory.java b/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexerFactory.java
index 8e4daa030fad..759e9339349b 100644
--- a/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexerFactory.java
+++ b/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexerFactory.java
@@ -32,6 +32,8 @@
public abstract class NativeVectorGlobalIndexerFactory implements GlobalIndexerFactory {
private static final int DEFAULT_DIMENSION = 128;
+ // Option suffix for the training sample cap; trainMaxSamples() looks it up behind the
+ // "(identifier)." and "fields.(fieldName)." prefixes.
+ static final String TRAIN_MAX_SAMPLES_OPTION = "train.max-samples";
+ // Cap returned by trainMaxSamples() when neither key carries a value.
+ static final int DEFAULT_TRAIN_MAX_SAMPLES = 65536;
@Override
public GlobalIndexer create(DataField field, Options options) {
@@ -39,7 +41,8 @@ public GlobalIndexer create(DataField field, Options options) {
return new NativeVectorGlobalIndexer(
field.type(),
nativeOptions(field.type(), options, identifier, field.name()),
- identifier);
+ identifier,
+ trainMaxSamples(options, identifier, field.name()));
}
static Map nativeOptions(
@@ -78,6 +81,43 @@ static Map nativeOptions(
return nativeOptions;
}
+ /**
+  * Resolves the training sample cap for one indexed field.
+  *
+  * Two keys are consulted: {@code identifier + ".train.max-samples"} and
+  * {@code "fields." + fieldName + ".train.max-samples"}. When both are present the field-level
+  * key wins.
+  *
+  * @param tableOptions table options to read from
+  * @param identifier index identifier used as the index-level key prefix
+  * @param fieldName field name used in the field-level key prefix
+  * @return the parsed positive value, or {@code DEFAULT_TRAIN_MAX_SAMPLES} when no value is
+  *     configured
+  * @throws IllegalArgumentException if the configured value is not a positive integer; for
+  *     unparsable text the {@link NumberFormatException} is attached as the cause
+  */
+ static int trainMaxSamples(Options tableOptions, String identifier, String fieldName) {
+     String indexKey = identifier + "." + TRAIN_MAX_SAMPLES_OPTION;
+     String fieldKey = "fields." + fieldName + "." + TRAIN_MAX_SAMPLES_OPTION;
+     Map<String, String> tableOptionsMap = tableOptions.toMap();
+     // The field-level key takes precedence over the index-level key.
+     String key = tableOptionsMap.containsKey(fieldKey) ? fieldKey : indexKey;
+     String value = tableOptionsMap.get(key);
+     if (value == null) {
+         return DEFAULT_TRAIN_MAX_SAMPLES;
+     }
+     int parsed;
+     try {
+         parsed = Integer.parseInt(value.trim());
+     } catch (NumberFormatException e) {
+         // Keep the parse failure as the cause so the root problem stays visible in traces.
+         IllegalArgumentException invalid = invalidTrainMaxSamples(key, value);
+         invalid.initCause(e);
+         throw invalid;
+     }
+     if (parsed <= 0) {
+         throw invalidTrainMaxSamples(key, value);
+     }
+     return parsed;
+ }
+
+ /**
+  * Builds (without throwing) the exception reported for an unusable train.max-samples value.
+  *
+  * @param key option key the value was read from
+  * @param value offending raw value
+  * @return exception whose message names the key and the value
+  */
+ private static IllegalArgumentException invalidTrainMaxSamples(String key, String value) {
+     StringBuilder message = new StringBuilder();
+     message.append("Invalid value for '").append(key).append("': ");
+     message.append(value).append(". Must be a positive integer.");
+     return new IllegalArgumentException(message.toString());
+ }
+
private static String nativeOptionKey(String optionKey) {
switch (optionKey) {
case "index.dimension":
diff --git a/paimon-vector/src/test/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexTest.java b/paimon-vector/src/test/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexTest.java
index 72303f5bf036..5588cea5d6d9 100644
--- a/paimon-vector/src/test/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexTest.java
+++ b/paimon-vector/src/test/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexTest.java
@@ -180,6 +180,31 @@ public void testAllNullReturnsEmpty() {
assertThat(results).isEmpty();
}
+ @Test
+ public void testTrainingVectorCountUsesOnlyConfiguredSampleLimit() {
+     // One sample more than MAX_FLOAT_ARRAY_LENGTH / 1024, i.e. more 1024-dim vectors than a
+     // single float[] of that length could hold.
+     int samplesBeyondArrayLimit =
+             NativeVectorGlobalIndexWriter.MAX_FLOAT_ARRAY_LENGTH / 1024 + 1;
+     int uncapped =
+             NativeVectorGlobalIndexWriter.trainingVectorCount(
+                     samplesBeyondArrayLimit, samplesBeyondArrayLimit);
+     assertThat(uncapped).isEqualTo(samplesBeyondArrayLimit);
+     // With more vectors than the configured cap, the cap is returned.
+     int capped = NativeVectorGlobalIndexWriter.trainingVectorCount(10_000L, 64);
+     assertThat(capped).isEqualTo(64);
+ }
+
+ @Test
+ public void testTrainingBatchSizeProtectsSingleJavaArrayAllocation() {
+     // A small dimension leaves the requested batch size untouched.
+     int smallDimBatch = NativeVectorGlobalIndexWriter.vectorBatchSize(4096, 128);
+     assertThat(smallDimBatch).isEqualTo(4096);
+     // A dimension equal to the array length cap leaves room for exactly one vector.
+     int hugeDimBatch =
+             NativeVectorGlobalIndexWriter.vectorBatchSize(
+                     4096, NativeVectorGlobalIndexWriter.MAX_FLOAT_ARRAY_LENGTH);
+     assertThat(hugeDimBatch).isEqualTo(1);
+     // A non-positive dimension is rejected.
+     assertThatThrownBy(() -> NativeVectorGlobalIndexWriter.vectorBatchSize(4096, 0))
+             .isInstanceOf(IllegalArgumentException.class)
+             .hasMessageContaining("positive integer");
+ }
+
@Test
public void testMetaSerializationIsEmptyMap() throws IOException {
VectorIndexMeta meta = new VectorIndexMeta();
diff --git a/paimon-vector/src/test/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexerFactoryTest.java b/paimon-vector/src/test/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexerFactoryTest.java
index 92c56b648522..b8e013b3464b 100644
--- a/paimon-vector/src/test/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexerFactoryTest.java
+++ b/paimon-vector/src/test/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexerFactoryTest.java
@@ -217,4 +217,79 @@ public void testFieldLevelVectorOptionsCoexistWithCoreFieldOptions() {
.containsEntry("nlist", "256")
.doesNotContainKey("aggregate-function");
}
+
+ // Checks the resolution order of train.max-samples: default, index-level key, field-level key.
+ @Test
+ public void testTrainMaxSamplesDefaultAndOverrides() {
+ Options options = new Options();
+
+ // No key set: the default applies.
+ assertThat(
+ NativeVectorGlobalIndexerFactory.trainMaxSamples(
+ options, IvfFlatVectorGlobalIndexerFactory.IDENTIFIER, "vec"))
+ .isEqualTo(NativeVectorGlobalIndexerFactory.DEFAULT_TRAIN_MAX_SAMPLES);
+
+ // Index-level key only.
+ options.setString("ivf-flat.train.max-samples", "1024");
+ assertThat(
+ NativeVectorGlobalIndexerFactory.trainMaxSamples(
+ options, IvfFlatVectorGlobalIndexerFactory.IDENTIFIER, "vec"))
+ .isEqualTo(1024);
+
+ // The field-level key overrides the index-level key for field "vec".
+ options.setString("fields.vec.train.max-samples", "2048");
+ assertThat(
+ NativeVectorGlobalIndexerFactory.trainMaxSamples(
+ options, IvfFlatVectorGlobalIndexerFactory.IDENTIFIER, "vec"))
+ .isEqualTo(2048);
+
+ // A field without its own key still resolves to the index-level value.
+ assertThat(
+ NativeVectorGlobalIndexerFactory.trainMaxSamples(
+ options, IvfFlatVectorGlobalIndexerFactory.IDENTIFIER, "other"))
+ .isEqualTo(1024);
+ }
+
+ // Checks that non-positive and non-numeric values are rejected with a message naming the key.
+ @Test
+ public void testInvalidTrainMaxSamples() {
+ Options options = new Options();
+ // Zero at index level is rejected; the message names the index-level key.
+ options.setString("ivf-flat.train.max-samples", "0");
+
+ assertThatThrownBy(
+ () ->
+ NativeVectorGlobalIndexerFactory.trainMaxSamples(
+ options,
+ IvfFlatVectorGlobalIndexerFactory.IDENTIFIER,
+ "vec"))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("ivf-flat.train.max-samples")
+ .hasMessageContaining("positive integer");
+
+ // A non-numeric field-level value is rejected; the message names the field-level key.
+ options.setString("fields.vec.train.max-samples", "bad");
+ assertThatThrownBy(
+ () ->
+ NativeVectorGlobalIndexerFactory.trainMaxSamples(
+ options,
+ IvfFlatVectorGlobalIndexerFactory.IDENTIFIER,
+ "vec"))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("fields.vec.train.max-samples")
+ .hasMessageContaining("positive integer");
+ }
+
+ // Checks that train.max-samples does not appear as a key in the native option map.
+ @Test
+ public void testTrainMaxSamplesIsNotNativeOption() {
+ Options options = new Options();
+ options.setString("ivf-flat.dimension", "32");
+ options.setString("ivf-flat.nlist", "128");
+ // Set the sample cap at both index level and field level.
+ options.setString("ivf-flat.train.max-samples", "1024");
+ options.setString("fields.vec.train.max-samples", "2048");
+
+ Map nativeOptions =
+ NativeVectorGlobalIndexerFactory.nativeOptions(
+ new ArrayType(new FloatType()),
+ options,
+ IvfFlatVectorGlobalIndexerFactory.IDENTIFIER,
+ "vec");
+
+ // The other index options are forwarded; the sample cap key is not.
+ assertThat(nativeOptions)
+ .containsEntry("dimension", "32")
+ .containsEntry("nlist", "128")
+ .doesNotContainKey("train.max-samples");
+ }
}