diff --git a/paimon-vector/pom.xml b/paimon-vector/pom.xml index 9fd20e915dd0..7f54c7dabc11 100644 --- a/paimon-vector/pom.xml +++ b/paimon-vector/pom.xml @@ -32,7 +32,7 @@ under the License. Paimon : Vector Index - 0.1.0 + 0.1.0-SNAPSHOT diff --git a/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexWriter.java b/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexWriter.java index 89aa004bee21..2bbffdb12e75 100644 --- a/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexWriter.java +++ b/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexWriter.java @@ -62,11 +62,15 @@ public class NativeVectorGlobalIndexWriter implements GlobalIndexSingleColumnWri private static final int IO_BUFFER_SIZE = 8 * 1024 * 1024; private static final int ADD_BATCH_SIZE = 10000; + private static final int TRAIN_BATCH_SIZE = 4096; + static final int MAX_FLOAT_ARRAY_LENGTH = Integer.MAX_VALUE - 8; + private static final long TRAIN_MEMORY_WARNING_BYTES = 4L * 1024 * 1024 * 1024; private final GlobalIndexFileWriter fileWriter; private final String identifier; private final Map nativeOptions; private final int dim; + private final int trainMaxSamples; private File tempVectorFile; private FileChannel writeChannel; @@ -84,11 +88,30 @@ public NativeVectorGlobalIndexWriter( DataType fieldType, Map options, String identifier) { + this( + fileWriter, + fieldType, + options, + identifier, + NativeVectorGlobalIndexerFactory.DEFAULT_TRAIN_MAX_SAMPLES); + } + + public NativeVectorGlobalIndexWriter( + GlobalIndexFileWriter fileWriter, + DataType fieldType, + Map options, + String identifier, + int trainMaxSamples) { this.fileWriter = fileWriter; this.identifier = identifier; validateFieldType(fieldType); this.nativeOptions = options; this.dim = Integer.parseInt(options.get("dimension")); + if (trainMaxSamples <= 0) { + throw new IllegalArgumentException( + "trainMaxSamples must be a positive integer: " + trainMaxSamples); + } + this.trainMaxSamples = trainMaxSamples; this.count = 0; this.closed = false; this.recordSizeInBytes = checkedRecordSize(dim, IO_BUFFER_SIZE); @@ -272,8 +295,10 @@ private String fileNamePrefix() { } private void trainFromTempFile(VectorIndexWriter writer) throws IOException { - int trainCount = (int) count; - float[] trainData = new float[trainCount * dim]; + int trainCount = trainingVectorCount(count, trainMaxSamples); + int trainBatchSize = vectorBatchSize(TRAIN_BATCH_SIZE, dim); + float[] batchVectors = new float[trainBatchSize * dim]; + logTrainingMemoryEstimate(trainCount); try (RandomAccessFile raf = new RandomAccessFile(tempVectorFile, "r"); FileChannel channel = raf.getChannel()) { @@ -281,16 +306,47 @@ private void trainFromTempFile(VectorIndexWriter writer) throws IOException { readBuf.order(ByteOrder.nativeOrder()); readBuf.limit(0); - for (int i = 0; i < trainCount; i++) { + long selected = 0; + long nextSampleIndex = 0; + int batchCount = 0; + + for (long recordIndex = 0; + recordIndex < count && selected < trainCount; + recordIndex++) { ensureAvailable(readBuf, channel, recordSizeInBytes); readBuf.getLong(); // skip rowId - for (int d = 0; d < dim; d++) { - trainData[i * dim + d] = readBuf.getFloat(); + if (recordIndex == nextSampleIndex) { + for (int d = 0; d < dim; d++) { + batchVectors[batchCount * dim + d] = readBuf.getFloat(); + } + selected++; + batchCount++; + if (batchCount == trainBatchSize) { + writer.addTrainingVectors(batchVectors, batchCount); + batchCount = 0; + } + if (selected < trainCount) { + nextSampleIndex = sampleIndex(selected, count, trainCount); + } + } else { + readBuf.position(readBuf.position() + dim * Float.BYTES); } } + + if (batchCount > 0) { + writer.addTrainingVectors( + Arrays.copyOf(batchVectors, batchCount * dim), batchCount); + } + if (selected != trainCount) { + throw new IOException( + "Expected to select " + + trainCount + + " training vectors, but selected " + + selected); + } } - writer.train(trainData, trainCount); + writer.finishTraining(); } private void addVectorsFromTempFile(VectorIndexWriter writer) throws IOException { @@ -386,6 +442,73 @@ private static int checkedRecordSize(int dim, int bufferCapacity) { return (int) recordSize; } + static int trainingVectorCount(long vectorCount, int trainMaxSamples) { + if (vectorCount <= 0) { + return 0; + } + if (trainMaxSamples <= 0) { + throw new IllegalArgumentException( + "trainMaxSamples must be a positive integer: " + trainMaxSamples); + } + return (int) Math.min(vectorCount, (long) trainMaxSamples); + } + + static int vectorBatchSize(int requestedBatchSize, int dim) { + if (requestedBatchSize <= 0) { + throw new IllegalArgumentException( + "requestedBatchSize must be a positive integer: " + requestedBatchSize); + } + if (dim <= 0) { + throw new IllegalArgumentException("dim must be a positive integer: " + dim); + } + int maxBatchSize = MAX_FLOAT_ARRAY_LENGTH / dim; + if (maxBatchSize <= 0) { + throw new IllegalStateException( + "Vector dimension " + dim + " exceeds Java float array capacity"); + } + return Math.min(requestedBatchSize, maxBatchSize); + } + + private void logTrainingMemoryEstimate(int trainCount) { + long rawBytes = saturatedMultiply(saturatedMultiply(trainCount, dim), Float.BYTES); + long estimatedPeakBytes = saturatedMultiply(rawBytes, 2); + if (estimatedPeakBytes >= TRAIN_MEMORY_WARNING_BYTES) { + LOG.warn( + "{} training uses {} samples out of {} vectors (dim={}). Estimated native " + + "training peak is at least {} bytes (~{} GiB) before OPQ and " + + "temporary buffers.", + identifier, + trainCount, + count, + dim, + estimatedPeakBytes, + String.format("%.2f", estimatedPeakBytes / 1024.0 / 1024.0 / 1024.0)); + } else { + LOG.info( + "{} training uses {} samples out of {} vectors (dim={})", + identifier, + trainCount, + count, + dim); + } + } + + private static long saturatedMultiply(long left, long right) { + if (left == 0 || right == 0) { + return 0; + } + if (left > Long.MAX_VALUE / right) { + return Long.MAX_VALUE; + } + return left * right; + } + + private static long sampleIndex(long sampleOrdinal, long vectorCount, int trainCount) { + long quotient = vectorCount / trainCount; + long remainder = vectorCount % trainCount; + return sampleOrdinal * quotient + sampleOrdinal * remainder / trainCount; + } + @Override public void close() { if (!closed) { diff --git a/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexer.java b/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexer.java index f45a97d34ae5..43829545f367 100644 --- a/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexer.java +++ b/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexer.java @@ -39,17 +39,36 @@ public class NativeVectorGlobalIndexer implements VectorGlobalIndexer { private final DataType fieldType; private final Map options; private final String identifier; + private final int trainMaxSamples; public NativeVectorGlobalIndexer( DataType fieldType, Map options, String identifier) { + this( + fieldType, + options, + identifier, + NativeVectorGlobalIndexerFactory.DEFAULT_TRAIN_MAX_SAMPLES); + } + + public NativeVectorGlobalIndexer( + DataType fieldType, + Map options, + String identifier, + int trainMaxSamples) { this.fieldType = fieldType; this.options = Objects.requireNonNull(options, "options must not be null"); this.identifier = Objects.requireNonNull(identifier, "identifier must not be null"); + if (trainMaxSamples <= 0) { + throw new IllegalArgumentException( + "trainMaxSamples must be a positive integer: " + trainMaxSamples); + } + this.trainMaxSamples = trainMaxSamples; } @Override public GlobalIndexWriter createWriter(GlobalIndexFileWriter fileWriter) { - return new NativeVectorGlobalIndexWriter(fileWriter, fieldType, options, identifier); + return new NativeVectorGlobalIndexWriter( + fileWriter, fieldType, options, identifier, trainMaxSamples); } @Override diff --git a/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexerFactory.java b/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexerFactory.java index 8e4daa030fad..759e9339349b 100644 --- a/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexerFactory.java +++ b/paimon-vector/src/main/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexerFactory.java @@ -32,6 +32,8 @@ public abstract class NativeVectorGlobalIndexerFactory implements GlobalIndexerFactory { private static final int DEFAULT_DIMENSION = 128; + static final String TRAIN_MAX_SAMPLES_OPTION = "train.max-samples"; + static final int DEFAULT_TRAIN_MAX_SAMPLES = 65536; @Override public GlobalIndexer create(DataField field, Options options) { @@ -39,7 +41,8 @@ public GlobalIndexer create(DataField field, Options options) { return new NativeVectorGlobalIndexer( field.type(), nativeOptions(field.type(), options, identifier, field.name()), - identifier); + identifier, + trainMaxSamples(options, identifier, field.name())); } static Map nativeOptions( @@ -78,6 +81,43 @@ static Map nativeOptions( return nativeOptions; } + static int trainMaxSamples(Options tableOptions, String identifier, String fieldName) { + String optionPrefix = identifier + "."; + String fieldPrefix = "fields." + fieldName + "."; + String indexKey = optionPrefix + TRAIN_MAX_SAMPLES_OPTION; + String fieldKey = fieldPrefix + TRAIN_MAX_SAMPLES_OPTION; + Map tableOptionsMap = tableOptions.toMap(); + + String key = null; + String value = null; + if (tableOptionsMap.containsKey(indexKey)) { + key = indexKey; + value = tableOptionsMap.get(indexKey); + } + if (tableOptionsMap.containsKey(fieldKey)) { + key = fieldKey; + value = tableOptionsMap.get(fieldKey); + } + if (value == null) { + return DEFAULT_TRAIN_MAX_SAMPLES; + } + + try { + int parsed = Integer.parseInt(value.trim()); + if (parsed > 0) { + return parsed; + } + throw invalidTrainMaxSamples(key, value); + } catch (NumberFormatException e) { + throw invalidTrainMaxSamples(key, value); + } + } + + private static IllegalArgumentException invalidTrainMaxSamples(String key, String value) { + return new IllegalArgumentException( + "Invalid value for '" + key + "': " + value + ". Must be a positive integer."); + } + private static String nativeOptionKey(String optionKey) { switch (optionKey) { case "index.dimension": diff --git a/paimon-vector/src/test/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexTest.java b/paimon-vector/src/test/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexTest.java index 72303f5bf036..5588cea5d6d9 100644 --- a/paimon-vector/src/test/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexTest.java +++ b/paimon-vector/src/test/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexTest.java @@ -180,6 +180,31 @@ public void testAllNullReturnsEmpty() { assertThat(results).isEmpty(); } + @Test + public void testTrainingVectorCountUsesOnlyConfiguredSampleLimit() { + int oldJavaArrayLimitFor1024Dim = + NativeVectorGlobalIndexWriter.MAX_FLOAT_ARRAY_LENGTH / 1024; + int requestedSamples = oldJavaArrayLimitFor1024Dim + 1; + + assertThat( + NativeVectorGlobalIndexWriter.trainingVectorCount( + requestedSamples, requestedSamples)) + .isEqualTo(requestedSamples); + assertThat(NativeVectorGlobalIndexWriter.trainingVectorCount(10_000L, 64)).isEqualTo(64); + } + + @Test + public void testTrainingBatchSizeProtectsSingleJavaArrayAllocation() { + assertThat(NativeVectorGlobalIndexWriter.vectorBatchSize(4096, 128)).isEqualTo(4096); + assertThat( + NativeVectorGlobalIndexWriter.vectorBatchSize( + 4096, NativeVectorGlobalIndexWriter.MAX_FLOAT_ARRAY_LENGTH)) + .isEqualTo(1); + assertThatThrownBy(() -> NativeVectorGlobalIndexWriter.vectorBatchSize(4096, 0)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("positive integer"); + } + @Test public void testMetaSerializationIsEmptyMap() throws IOException { VectorIndexMeta meta = new VectorIndexMeta(); diff --git a/paimon-vector/src/test/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexerFactoryTest.java b/paimon-vector/src/test/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexerFactoryTest.java index 92c56b648522..b8e013b3464b 100644 --- a/paimon-vector/src/test/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexerFactoryTest.java +++ b/paimon-vector/src/test/java/org/apache/paimon/vector/index/NativeVectorGlobalIndexerFactoryTest.java @@ -217,4 +217,79 @@ public void testFieldLevelVectorOptionsCoexistWithCoreFieldOptions() { .containsEntry("nlist", "256") .doesNotContainKey("aggregate-function"); } + + @Test + public void testTrainMaxSamplesDefaultAndOverrides() { + Options options = new Options(); + + assertThat( + NativeVectorGlobalIndexerFactory.trainMaxSamples( + options, IvfFlatVectorGlobalIndexerFactory.IDENTIFIER, "vec")) + .isEqualTo(NativeVectorGlobalIndexerFactory.DEFAULT_TRAIN_MAX_SAMPLES); + + options.setString("ivf-flat.train.max-samples", "1024"); + assertThat( + NativeVectorGlobalIndexerFactory.trainMaxSamples( + options, IvfFlatVectorGlobalIndexerFactory.IDENTIFIER, "vec")) + .isEqualTo(1024); + + options.setString("fields.vec.train.max-samples", "2048"); + assertThat( + NativeVectorGlobalIndexerFactory.trainMaxSamples( + options, IvfFlatVectorGlobalIndexerFactory.IDENTIFIER, "vec")) + .isEqualTo(2048); + + assertThat( + NativeVectorGlobalIndexerFactory.trainMaxSamples( + options, IvfFlatVectorGlobalIndexerFactory.IDENTIFIER, "other")) + .isEqualTo(1024); + } + + @Test + public void testInvalidTrainMaxSamples() { + Options options = new Options(); + options.setString("ivf-flat.train.max-samples", "0"); + + assertThatThrownBy( + () -> + NativeVectorGlobalIndexerFactory.trainMaxSamples( + options, + IvfFlatVectorGlobalIndexerFactory.IDENTIFIER, + "vec")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("ivf-flat.train.max-samples") + .hasMessageContaining("positive integer"); + + options.setString("fields.vec.train.max-samples", "bad"); + assertThatThrownBy( + () -> + NativeVectorGlobalIndexerFactory.trainMaxSamples( + options, + IvfFlatVectorGlobalIndexerFactory.IDENTIFIER, + "vec")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("fields.vec.train.max-samples") + .hasMessageContaining("positive integer"); + } + + @Test + public void testTrainMaxSamplesIsNotNativeOption() { + Options options = new Options(); + options.setString("ivf-flat.dimension", "32"); + options.setString("ivf-flat.nlist", "128"); + options.setString("ivf-flat.train.max-samples", "1024"); + options.setString("fields.vec.train.max-samples", "2048"); + + Map nativeOptions = + NativeVectorGlobalIndexerFactory.nativeOptions( + new ArrayType(new FloatType()), + options, + IvfFlatVectorGlobalIndexerFactory.IDENTIFIER, + "vec"); + + assertThat(nativeOptions) + .containsEntry("dimension", "32") + .containsEntry("nlist", "128") + .doesNotContainKey("train.max-samples"); + } }