From c5373fa9a7b14c58c7e14fdea53c36b75669dbea Mon Sep 17 00:00:00 2001
From: Davide Angelocola
Date: Thu, 11 Jun 2026 20:49:08 +0200
Subject: [PATCH 1/2] docs(readme): drop alpha banner, add perf line, trim lifecycle box
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Drop the "Alpha — APIs will change without notice" banner; project is
  past 0.5.0 and the wire format is stable.
- Add a one-line perf claim in the intro paragraph pointing at the
  benchmarks section in docs/explanation.md.
- Shrink the Lifecycle blockquote from 6 lines to 3, push the long-form
  rules to docs/explanation.md#memory-model.
- Replace em-dash separator with parens on the trailing example line.

Co-Authored-By: Claude Opus 4.7
---
 README.md | 18 +++++++-----------
 1 file changed, 7 insertions(+), 11 deletions(-)

diff --git a/README.md b/README.md
index 85887cf1..a11a5faf 100644
--- a/README.md
+++ b/README.md
@@ -4,11 +4,10 @@
 [![Maven Central](https://img.shields.io/maven-central/v/io.github.dfa1.vortex/vortex-reader.svg)](https://central.sonatype.com/artifact/io.github.dfa1.vortex/vortex-reader)
 [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/license/Apache-2.0)

-> **Alpha** — not production-ready. APIs will change without notice.
-
 Pure-Java reader/writer for the [Vortex](https://github.com/vortex-data/vortex) columnar file
 format. 100% Java, no JNI, no `sun.misc.Unsafe`. Uses the FFM API (`MemorySegment`/`Arena`, Java 25+)
-for zero-copy memory-mapped reads.
+for zero-copy memory-mapped reads. Read benchmarks match or beat the Rust JNI on the workloads
+tested (Apple M5, JDK 25); see [docs/explanation.md#benchmarks](docs/explanation.md#benchmarks).

 | Project | Language | Notes |
 |---------------------------------------------------------------------|----------|-----------------------------------------|
@@ -50,15 +49,12 @@ try (VortexReader vf = VortexReader.open(Path.of("data/example.vortex"));
 }
 ```

-> **Lifecycle.** `ScanIterator` implements `Iterator` and `Chunk` implements
-> `AutoCloseable`. Each chunk owns a confined `Arena`; closing it releases the
-> decoded buffers. Calling `iter.next()` while a prior chunk is still open throws
-> `IllegalStateException`. Use try-with-resources, or
-> `iter.forEachRemaining(c -> ...)` which closes each chunk for you. See
-> [docs/explanation.md#memory-model](docs/explanation.md#memory-model).
+> **Lifecycle.** `Chunk` owns a confined `Arena` — close it (try-with-resources
+> or `iter.forEachRemaining`) to release the decoded buffers. Full lifecycle
+> rules: [docs/explanation.md#memory-model](docs/explanation.md#memory-model).

-For more examples — writing, projection, filtering, custom encodings, and the CLI —
-see the documentation below.
+For more examples (writing, projection, filtering, custom encodings, CLI) see
+the documentation below.

 ## Documentation

From e8fefd939087d8140f0b0f412d52576fb38fa1e2 Mon Sep 17 00:00:00 2001
From: Davide Angelocola
Date: Thu, 11 Jun 2026 21:43:14 +0200
Subject: [PATCH 2/2] =?UTF-8?q?refactor:=20ADR=200001=20=E2=80=94=20split?=
 =?UTF-8?q?=20read=20and=20write=20runtimes?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Architectural split of the encoding pipeline into separate reader and
writer modules per ADR 0001. After this PR a read-only deployment only
carries decoders, and a write-only deployment only carries encoders.

## Core changes

- Split bifunctional `Encoding` interface into read-side `EncodingDecoder`
  and write-side `EncodingEncoder` (Phase 1). Existing `Encoding` extends
  both, preserving source compatibility for in-tree consumers.
- `Registry` learns two new SPIs (`EncodingDecoder`, `EncodingEncoder`)
  alongside the legacy `Encoding` SPI; lookup routes to standalone impls
  with bifunctional fallback during the lift.
- Added `ReadRegistry` and `WriteRegistry` delegating facades that scope
  the read/write API surface respectively.
- Promoted internals required by lifted decoders:
  BitpackedEncoding.FL_ORDER, AlpEncoding F10/IF10 tables + I64/I32 DType
  constants, AlpRdEncoding U16/U32/U64 DType constants, Pco
  format/offset-bits constants, PcoBin, PcoTansDecoder + decode methods,
  LeBitReader + read methods, ArrayNode/KnownArrayNode/UnknownArrayNode,
  PTypeIO.set/copyArray, DeltaEncoding helpers, RleEncoding.FL_CHUNK_SIZE.

## reader module

- 33 standalone `*EncodingDecoder` implementations covering every
  registered encoding ID, registered via
  `META-INF/services/io.github.dfa1.vortex.encoding.EncodingDecoder`.
- VortexReader/VortexHttpReader gain a typed `decodeFlatSegment` accessor.

## writer module

- 33 standalone `*EncodingEncoder` implementations (Pco encode is a stub,
  matching prior behavior).
- Service manifest for `EncodingEncoder`.
- Test-jar dependency on core for shared fixtures.

## Test migration

- core/ publishes a test-jar so reader/ and writer/ test trees can reuse
  shared fixtures (DTypes, EncodeTestHelper, TestRegistry, TestSegments,
  TestDecodeContexts) without duplication.
- All 33 *EncodingTest classes moved out of core/encoding/ into
  writer/encode/*EncodingEncoderTest or reader/decode/*EncodingDecoderTest
  per their primary role.
- Cross-cutting regression tests dispatched:
    BitpackedEncodingPatchesTest, BitpackedConstantPatchesBroadcastTest,
    RandomAccessTest, CascadingCompressorTest → writer/encode/
    SegmentBroadcastTest (extracted from PatchesBroadcastRegressionTest)
      stays in core/encoding as it tests the SegmentBroadcast API.
- core/encoding/ test tree now contains only shared fixtures and tests for
  core machinery (EncodingId, LeBitReader, PcoTansDecoder, Registry,
  SegmentBroadcast).

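
As a rough illustration of the lookup order described under Core changes
(standalone implementation first, bifunctional fallback second), here is a
small self-contained Java sketch. Every type in it (`SketchDecoder`,
`SketchEncoder`, `SketchEncoding`, `SketchRegistry`, `SketchReadRegistry`
and the two demo encodings) is a hypothetical stand-in invented for this
note; the real `EncodingDecoder`, `EncodingEncoder`, `Registry` and
`ReadRegistry` signatures are in the diff and will differ. In particular,
the sketch assumes explicit registration where the real registry discovers
implementations through `ServiceLoader`.

```java
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

// Hypothetical read-side contract (not the real EncodingDecoder signature).
interface SketchDecoder {
    String id();
    String decode(byte[] encoded);
}

// Hypothetical write-side contract (not the real EncodingEncoder signature).
interface SketchEncoder {
    String id();
    byte[] encode(String value);
}

// Legacy bifunctional shape: extends both halves, so old implementations keep compiling.
interface SketchEncoding extends SketchDecoder, SketchEncoder {
}

// Stands in for a decoder that has already been lifted into the reader module.
final class LiftedUtf8Decoder implements SketchDecoder {
    @Override public String id() { return "sketch.utf8"; }
    @Override public String decode(byte[] encoded) {
        return new String(encoded, StandardCharsets.UTF_8);
    }
}

// Stands in for an encoding that still exists only in bifunctional form.
final class LegacyReversedEncoding implements SketchEncoding {
    @Override public String id() { return "sketch.reversed"; }
    @Override public byte[] encode(String value) {
        return new StringBuilder(value).reverse().toString().getBytes(StandardCharsets.UTF_8);
    }
    @Override public String decode(byte[] encoded) {
        return new StringBuilder(new String(encoded, StandardCharsets.UTF_8)).reverse().toString();
    }
}

final class SketchRegistry {
    private final Map<String, SketchDecoder> standaloneDecoders = new HashMap<>();
    private final Map<String, SketchEncoding> bifunctional = new HashMap<>();

    void registerDecoder(SketchDecoder decoder) { standaloneDecoders.put(decoder.id(), decoder); }
    void registerEncoding(SketchEncoding encoding) { bifunctional.put(encoding.id(), encoding); }

    // A standalone decoder wins; the bifunctional implementation is only a fallback.
    Optional<SketchDecoder> decoderFor(String id) {
        SketchDecoder decoder = standaloneDecoders.get(id);
        if (decoder == null) {
            decoder = bifunctional.get(id);
        }
        return Optional.ofNullable(decoder);
    }
}

// Read-only facade: re-exposes decoder lookup and nothing from the write side.
final class SketchReadRegistry {
    private final SketchRegistry delegate;

    SketchReadRegistry(SketchRegistry delegate) { this.delegate = delegate; }

    Optional<SketchDecoder> decoderFor(String id) { return delegate.decoderFor(id); }
}
```

The facade is deliberately thin: it narrows what read-side callers can see
without duplicating registry state, which is one plausible reading of
"delegating facades" above.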
## CSV regression fix

Lifted BitpackedEncodingDecoder ports the null/empty-metadata fallback
from BitpackedEncoding (proto3 elides default-valued fields, so
BitPackedMetadata(0, 0, null) serializes to 0 bytes — treat absent
metadata as all-defaults rather than rejecting). CsvImporterTest's
schema-override path triggers this via the FSST cascade.

Co-Authored-By: Claude Opus 4.7
---
 README.md | 18 +- .../vortex/cli/tui/VortexInspectorTui.java | 5 +- core/pom.xml | 19 + .../dfa1/vortex/encoding/AlpEncoding.java | 12 +- .../dfa1/vortex/encoding/AlpRdEncoding.java | 6 +- .../dfa1/vortex/encoding/ArrayNode.java | 2 +- .../vortex/encoding/BitpackedEncoding.java | 2 +- .../dfa1/vortex/encoding/DeltaEncoding.java | 16 +- .../github/dfa1/vortex/encoding/Encoding.java | 49 +- .../dfa1/vortex/encoding/EncodingDecoder.java | 28 + .../dfa1/vortex/encoding/EncodingEncoder.java | 41 + .../dfa1/vortex/encoding/KnownArrayNode.java | 2 +- .../dfa1/vortex/encoding/LeBitReader.java | 21 +- .../github/dfa1/vortex/encoding/PTypeIO.java | 4 +- .../github/dfa1/vortex/encoding/PcoBin.java | 2 +- .../dfa1/vortex/encoding/PcoEncoding.java | 10 +- .../dfa1/vortex/encoding/PcoTansDecoder.java | 12 +- .../dfa1/vortex/encoding/ReadRegistry.java | 48 ++ .../github/dfa1/vortex/encoding/Registry.java | 81 +- .../dfa1/vortex/encoding/RleEncoding.java | 2 +- .../vortex/encoding/UnknownArrayNode.java | 2 +- .../dfa1/vortex/encoding/WriteRegistry.java | 45 + .../dfa1/vortex/extension/Extension.java | 41 +- .../vortex/extension/ExtensionEncoder.java | 40 + .../vortex/encoding/AlpRdEncodingTest.java | 52 -- .../encoding/BitpackedEncodingTest.java | 98 --- .../vortex/encoding/BoolEncodingTest.java | 69 -- .../vortex/encoding/ByteBoolEncodingTest.java | 92 -- .../github/dfa1/vortex/encoding/DTypes.java | 48 +- .../DecimalBytePartsEncodingTest.java | 77 -- .../vortex/encoding/DecimalEncodingTest.java | 107 --- .../vortex/encoding/DeltaEncodingTest.java | 129 --- .../vortex/encoding/DictEncodingTest.java |
195 ----- .../vortex/encoding/EncodeTestHelper.java | 15 +- .../vortex/encoding/MaskedEncodingTest.java | 199 ----- .../vortex/encoding/PatchedEncodingTest.java | 214 ----- .../PatchesBroadcastRegressionTest.java | 144 ---- .../vortex/encoding/SegmentBroadcastTest.java | 54 ++ .../vortex/encoding/TestDecodeContexts.java | 14 +- .../dfa1/vortex/encoding/TestRegistry.java | 41 +- .../dfa1/vortex/encoding/TestSegments.java | 14 +- .../vortex/encoding/VariantEncodingTest.java | 158 ---- docs/compatibility.md | 27 + pom.xml | 6 + reader/pom.xml | 8 + .../dfa1/vortex/reader/ScanIterator.java | 4 +- .../dfa1/vortex/reader/VortexHandle.java | 14 + .../dfa1/vortex/reader/VortexHttpReader.java | 11 + .../dfa1/vortex/reader/VortexReader.java | 11 + .../reader/decode/AlpEncodingDecoder.java | 167 ++++ .../reader/decode/AlpRdEncodingDecoder.java | 175 ++++ .../decode/BitpackedEncodingDecoder.java | 526 ++++++++++++ .../reader/decode/BoolEncodingDecoder.java | 37 + .../decode/ByteBoolEncodingDecoder.java | 46 + .../reader/decode/ChunkedEncodingDecoder.java | 128 +++ .../decode/ConstantEncodingDecoder.java | 175 ++++ .../decode/DateTimePartsEncodingDecoder.java | 60 ++ .../DecimalBytePartsEncodingDecoder.java | 66 ++ .../reader/decode/DecimalEncodingDecoder.java | 69 ++ .../reader/decode/DeltaEncodingDecoder.java | 154 ++++ .../reader/decode/DictEncodingDecoder.java | 404 +++++++++ .../reader/decode/ExtEncodingDecoder.java | 40 + .../decode/FixedSizeListEncodingDecoder.java | 54 ++ .../FrameOfReferenceEncodingDecoder.java | 132 +++ .../reader/decode/FsstEncodingDecoder.java | 117 +++ .../reader/decode/ListEncodingDecoder.java | 66 ++ .../decode/ListViewEncodingDecoder.java | 69 ++ .../reader/decode/MaskedEncodingDecoder.java | 55 ++ .../reader/decode/NullEncodingDecoder.java | 31 + .../reader/decode/PatchedEncodingDecoder.java | 126 +++ .../reader/decode/PcoEncodingDecoder.java | 784 ++++++++++++++++++ .../decode/PrimitiveEncodingDecoder.java | 64 ++ 
.../reader/decode/RleEncodingDecoder.java | 229 +++++ .../reader/decode/RunEndEncodingDecoder.java | 273 ++++++ .../decode/SequenceEncodingDecoder.java | 142 ++++ .../reader/decode/SparseEncodingDecoder.java | 261 ++++++ .../reader/decode/StructEncodingDecoder.java | 108 +++ .../reader/decode/VarBinEncodingDecoder.java | 67 ++ .../decode/VarBinViewEncodingDecoder.java | 83 ++ .../reader/decode/VariantEncodingDecoder.java | 129 +++ .../reader/decode/ZigZagEncodingDecoder.java | 97 +++ .../reader/decode/ZstdEncodingDecoder.java | 258 ++++++ ...ithub.dfa1.vortex.encoding.EncodingDecoder | 33 + .../decode/ByteBoolEncodingDecoderTest.java | 57 ++ .../decode/NullEncodingDecoderTest.java | 35 + .../decode/PatchedEncodingDecoderTest.java | 167 ++++ .../reader/decode/PcoEncodingDecoderTest.java | 434 ++-------- .../decode/VariantEncodingDecoderTest.java | 143 ++++ writer/pom.xml | 7 + .../writer/encode/AlpEncodingEncoder.java | 327 ++++++++ .../writer/encode/AlpRdEncodingEncoder.java | 327 ++++++++ .../encode/BitpackedEncodingEncoder.java | 227 +++++ .../writer/encode/BoolEncodingEncoder.java | 75 ++ .../encode/ByteBoolEncodingEncoder.java | 38 + .../writer/encode/ChunkedEncodingEncoder.java | 93 +++ .../encode/ConstantEncodingEncoder.java | 103 +++ .../encode/DateTimePartsEncodingEncoder.java | 158 ++++ .../DecimalBytePartsEncodingEncoder.java | 49 ++ .../writer/encode/DecimalEncodingEncoder.java | 79 ++ .../writer/encode/DeltaEncodingEncoder.java | 192 +++++ .../writer/encode/DictEncodingEncoder.java | 317 +++++++ .../writer/encode/ExtEncodingEncoder.java | 79 ++ .../encode/FixedSizeListEncodingEncoder.java | 71 ++ .../FrameOfReferenceEncodingEncoder.java | 167 ++++ .../writer/encode/FsstEncodingEncoder.java | 158 ++++ .../writer/encode/ListEncodingEncoder.java | 88 ++ .../encode/ListViewEncodingEncoder.java | 97 +++ .../writer/encode/MaskedEncodingEncoder.java | 79 ++ .../writer/encode/NullEncodingEncoder.java | 34 + .../writer/encode/PatchedEncodingEncoder.java | 31 + 
.../writer/encode/PcoEncodingEncoder.java | 32 + .../encode/PrimitiveEncodingEncoder.java | 292 +++++++ .../writer/encode/RleEncodingEncoder.java | 255 ++++++ .../writer/encode/RunEndEncodingEncoder.java | 108 +++ .../encode/SequenceEncodingEncoder.java | 138 +++ .../writer/encode/SparseEncodingEncoder.java | 149 ++++ .../writer/encode/StructEncodingEncoder.java | 74 ++ .../writer/encode/VarBinEncodingEncoder.java | 84 ++ .../encode/VarBinViewEncodingEncoder.java | 84 ++ .../writer/encode/VariantEncodingEncoder.java | 31 + .../writer/encode/ZigZagEncodingEncoder.java | 84 ++ .../writer/encode/ZstdEncodingEncoder.java | 159 ++++ ...ithub.dfa1.vortex.encoding.EncodingEncoder | 33 + .../writer/encode/AlpEncodingEncoderTest.java | 189 ++--- .../encode/AlpRdEncodingEncoderTest.java | 57 ++ ...BitpackedConstantPatchesBroadcastTest.java | 83 ++ .../encode/BitpackedEncodingEncoderTest.java | 87 ++ .../encode}/BitpackedEncodingPatchesTest.java | 38 +- .../encode/BoolEncodingEncoderTest.java | 72 ++ .../encode/ByteBoolEncodingEncoderTest.java | 52 ++ .../encode}/CascadingCompressorTest.java | 36 +- .../encode/ChunkedEncodingEncoderTest.java | 104 +-- .../encode/ConstantEncodingEncoderTest.java | 142 ++-- .../DateTimePartsEncodingEncoderTest.java | 125 +-- .../DecimalBytePartsEncodingEncoderTest.java | 79 ++ .../encode/DecimalEncodingEncoderTest.java | 98 +++ .../encode/DeltaEncodingEncoderTest.java | 109 +++ .../encode/DictEncodingEncoderTest.java | 161 ++++ .../writer/encode/ExtEncodingEncoderTest.java | 74 +- .../FixedSizeListEncodingEncoderTest.java | 71 +- .../FrameOfReferenceEncodingEncoderTest.java | 133 +-- .../encode/FsstEncodingEncoderTest.java | 172 ++-- .../encode/ListEncodingEncoderTest.java | 103 +-- .../encode/ListViewEncodingEncoderTest.java | 97 +-- .../encode/MaskedEncodingEncoderTest.java | 161 ++++ .../encode/NullEncodingEncoderTest.java | 47 +- .../writer/encode/PcoEncodingEncoderTest.java | 22 + .../encode/PrimitiveEncodingEncoderTest.java | 99 +-- 
.../writer/encode}/RandomAccessTest.java | 44 +- .../writer/encode/RleEncodingEncoderTest.java | 223 ++--- .../encode/RunEndEncodingEncoderTest.java | 89 +- .../encode/SequenceEncodingEncoderTest.java | 133 +-- .../encode/SparseEncodingEncoderTest.java | 210 ++--- .../encode/StructEncodingEncoderTest.java | 90 +- .../encode/VarBinEncodingEncoderTest.java | 114 +-- .../encode/VarBinViewEncodingEncoderTest.java | 78 +- .../encode/ZigZagEncodingEncoderTest.java | 67 +- .../encode/ZstdEncodingEncoderTest.java | 151 +--- 158 files changed, 12498 insertions(+), 3801 deletions(-) create mode 100644 core/src/main/java/io/github/dfa1/vortex/encoding/EncodingDecoder.java create mode 100644 core/src/main/java/io/github/dfa1/vortex/encoding/EncodingEncoder.java create mode 100644 core/src/main/java/io/github/dfa1/vortex/encoding/ReadRegistry.java create mode 100644 core/src/main/java/io/github/dfa1/vortex/encoding/WriteRegistry.java create mode 100644 core/src/main/java/io/github/dfa1/vortex/extension/ExtensionEncoder.java delete mode 100644 core/src/test/java/io/github/dfa1/vortex/encoding/AlpRdEncodingTest.java delete mode 100644 core/src/test/java/io/github/dfa1/vortex/encoding/BitpackedEncodingTest.java delete mode 100644 core/src/test/java/io/github/dfa1/vortex/encoding/BoolEncodingTest.java delete mode 100644 core/src/test/java/io/github/dfa1/vortex/encoding/ByteBoolEncodingTest.java delete mode 100644 core/src/test/java/io/github/dfa1/vortex/encoding/DecimalBytePartsEncodingTest.java delete mode 100644 core/src/test/java/io/github/dfa1/vortex/encoding/DecimalEncodingTest.java delete mode 100644 core/src/test/java/io/github/dfa1/vortex/encoding/DeltaEncodingTest.java delete mode 100644 core/src/test/java/io/github/dfa1/vortex/encoding/DictEncodingTest.java delete mode 100644 core/src/test/java/io/github/dfa1/vortex/encoding/MaskedEncodingTest.java delete mode 100644 core/src/test/java/io/github/dfa1/vortex/encoding/PatchedEncodingTest.java delete mode 100644 
core/src/test/java/io/github/dfa1/vortex/encoding/PatchesBroadcastRegressionTest.java create mode 100644 core/src/test/java/io/github/dfa1/vortex/encoding/SegmentBroadcastTest.java delete mode 100644 core/src/test/java/io/github/dfa1/vortex/encoding/VariantEncodingTest.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/AlpEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/AlpRdEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/BitpackedEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/BoolEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/ByteBoolEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/ChunkedEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/ConstantEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/DateTimePartsEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/DecimalBytePartsEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/DecimalEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/DeltaEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/DictEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/ExtEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/FixedSizeListEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/FrameOfReferenceEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/FsstEncodingDecoder.java create mode 100644 
reader/src/main/java/io/github/dfa1/vortex/reader/decode/ListEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/ListViewEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/MaskedEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/NullEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/PatchedEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/PcoEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/PrimitiveEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/RleEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/RunEndEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/SequenceEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/SparseEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/StructEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/VarBinEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/VarBinViewEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/VariantEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/ZigZagEncodingDecoder.java create mode 100644 reader/src/main/java/io/github/dfa1/vortex/reader/decode/ZstdEncodingDecoder.java create mode 100644 reader/src/main/resources/META-INF/services/io.github.dfa1.vortex.encoding.EncodingDecoder create mode 100644 reader/src/test/java/io/github/dfa1/vortex/reader/decode/ByteBoolEncodingDecoderTest.java create mode 100644 
reader/src/test/java/io/github/dfa1/vortex/reader/decode/NullEncodingDecoderTest.java create mode 100644 reader/src/test/java/io/github/dfa1/vortex/reader/decode/PatchedEncodingDecoderTest.java rename core/src/test/java/io/github/dfa1/vortex/encoding/PcoEncodingTest.java => reader/src/test/java/io/github/dfa1/vortex/reader/decode/PcoEncodingDecoderTest.java (51%) create mode 100644 reader/src/test/java/io/github/dfa1/vortex/reader/decode/VariantEncodingDecoderTest.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/AlpEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/AlpRdEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/BitpackedEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/BoolEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/ByteBoolEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/ChunkedEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/ConstantEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/DateTimePartsEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/DecimalBytePartsEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/DecimalEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/DeltaEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/DictEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/ExtEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/FixedSizeListEncodingEncoder.java create mode 100644 
writer/src/main/java/io/github/dfa1/vortex/writer/encode/FrameOfReferenceEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/FsstEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/ListEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/ListViewEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/MaskedEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/NullEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/PatchedEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/PcoEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/PrimitiveEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/RleEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/RunEndEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/SequenceEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/SparseEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/StructEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/VarBinEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/VarBinViewEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/VariantEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/ZigZagEncodingEncoder.java create mode 100644 writer/src/main/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoder.java create mode 100644 
writer/src/main/resources/META-INF/services/io.github.dfa1.vortex.encoding.EncodingEncoder rename core/src/test/java/io/github/dfa1/vortex/encoding/AlpEncodingTest.java => writer/src/test/java/io/github/dfa1/vortex/writer/encode/AlpEncodingEncoderTest.java (52%) create mode 100644 writer/src/test/java/io/github/dfa1/vortex/writer/encode/AlpRdEncodingEncoderTest.java create mode 100644 writer/src/test/java/io/github/dfa1/vortex/writer/encode/BitpackedConstantPatchesBroadcastTest.java create mode 100644 writer/src/test/java/io/github/dfa1/vortex/writer/encode/BitpackedEncodingEncoderTest.java rename {core/src/test/java/io/github/dfa1/vortex/encoding => writer/src/test/java/io/github/dfa1/vortex/writer/encode}/BitpackedEncodingPatchesTest.java (68%) create mode 100644 writer/src/test/java/io/github/dfa1/vortex/writer/encode/BoolEncodingEncoderTest.java create mode 100644 writer/src/test/java/io/github/dfa1/vortex/writer/encode/ByteBoolEncodingEncoderTest.java rename {core/src/test/java/io/github/dfa1/vortex/encoding => writer/src/test/java/io/github/dfa1/vortex/writer/encode}/CascadingCompressorTest.java (86%) rename core/src/test/java/io/github/dfa1/vortex/encoding/ChunkedEncodingTest.java => writer/src/test/java/io/github/dfa1/vortex/writer/encode/ChunkedEncodingEncoderTest.java (64%) rename core/src/test/java/io/github/dfa1/vortex/encoding/ConstantEncodingTest.java => writer/src/test/java/io/github/dfa1/vortex/writer/encode/ConstantEncodingEncoderTest.java (57%) rename core/src/test/java/io/github/dfa1/vortex/encoding/DateTimePartsEncodingTest.java => writer/src/test/java/io/github/dfa1/vortex/writer/encode/DateTimePartsEncodingEncoderTest.java (66%) create mode 100644 writer/src/test/java/io/github/dfa1/vortex/writer/encode/DecimalBytePartsEncodingEncoderTest.java create mode 100644 writer/src/test/java/io/github/dfa1/vortex/writer/encode/DecimalEncodingEncoderTest.java create mode 100644 
writer/src/test/java/io/github/dfa1/vortex/writer/encode/DeltaEncodingEncoderTest.java create mode 100644 writer/src/test/java/io/github/dfa1/vortex/writer/encode/DictEncodingEncoderTest.java rename core/src/test/java/io/github/dfa1/vortex/encoding/ExtEncodingTest.java => writer/src/test/java/io/github/dfa1/vortex/writer/encode/ExtEncodingEncoderTest.java (71%) rename core/src/test/java/io/github/dfa1/vortex/encoding/FixedSizeListEncodingTest.java => writer/src/test/java/io/github/dfa1/vortex/writer/encode/FixedSizeListEncodingEncoderTest.java (67%) rename core/src/test/java/io/github/dfa1/vortex/encoding/FrameOfReferenceEncodingTest.java => writer/src/test/java/io/github/dfa1/vortex/writer/encode/FrameOfReferenceEncodingEncoderTest.java (61%) rename core/src/test/java/io/github/dfa1/vortex/encoding/FsstEncodingTest.java => writer/src/test/java/io/github/dfa1/vortex/writer/encode/FsstEncodingEncoderTest.java (60%) rename core/src/test/java/io/github/dfa1/vortex/encoding/ListEncodingTest.java => writer/src/test/java/io/github/dfa1/vortex/writer/encode/ListEncodingEncoderTest.java (63%) rename core/src/test/java/io/github/dfa1/vortex/encoding/ListViewEncodingTest.java => writer/src/test/java/io/github/dfa1/vortex/writer/encode/ListViewEncodingEncoderTest.java (68%) create mode 100644 writer/src/test/java/io/github/dfa1/vortex/writer/encode/MaskedEncodingEncoderTest.java rename core/src/test/java/io/github/dfa1/vortex/encoding/NullEncodingTest.java => writer/src/test/java/io/github/dfa1/vortex/writer/encode/NullEncodingEncoderTest.java (51%) create mode 100644 writer/src/test/java/io/github/dfa1/vortex/writer/encode/PcoEncodingEncoderTest.java rename core/src/test/java/io/github/dfa1/vortex/encoding/PrimitiveEncodingTest.java => writer/src/test/java/io/github/dfa1/vortex/writer/encode/PrimitiveEncodingEncoderTest.java (65%) rename {core/src/test/java/io/github/dfa1/vortex/encoding => writer/src/test/java/io/github/dfa1/vortex/writer/encode}/RandomAccessTest.java (54%) 
rename core/src/test/java/io/github/dfa1/vortex/encoding/RleEncodingTest.java => writer/src/test/java/io/github/dfa1/vortex/writer/encode/RleEncodingEncoderTest.java (60%) rename core/src/test/java/io/github/dfa1/vortex/encoding/RunEndEncodingTest.java => writer/src/test/java/io/github/dfa1/vortex/writer/encode/RunEndEncodingEncoderTest.java (63%) rename core/src/test/java/io/github/dfa1/vortex/encoding/SequenceEncodingTest.java => writer/src/test/java/io/github/dfa1/vortex/writer/encode/SequenceEncodingEncoderTest.java (70%) rename core/src/test/java/io/github/dfa1/vortex/encoding/SparseEncodingTest.java => writer/src/test/java/io/github/dfa1/vortex/writer/encode/SparseEncodingEncoderTest.java (67%) rename core/src/test/java/io/github/dfa1/vortex/encoding/StructEncodingTest.java => writer/src/test/java/io/github/dfa1/vortex/writer/encode/StructEncodingEncoderTest.java (65%) rename core/src/test/java/io/github/dfa1/vortex/encoding/VarBinEncodingTest.java => writer/src/test/java/io/github/dfa1/vortex/writer/encode/VarBinEncodingEncoderTest.java (52%) rename core/src/test/java/io/github/dfa1/vortex/encoding/VarBinViewEncodingTest.java => writer/src/test/java/io/github/dfa1/vortex/writer/encode/VarBinViewEncodingEncoderTest.java (71%) rename core/src/test/java/io/github/dfa1/vortex/encoding/ZigZagEncodingTest.java => writer/src/test/java/io/github/dfa1/vortex/writer/encode/ZigZagEncodingEncoderTest.java (69%) rename core/src/test/java/io/github/dfa1/vortex/encoding/ZstdEncodingTest.java => writer/src/test/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoderTest.java (73%) diff --git a/README.md b/README.md index a11a5faf..85887cf1 100644 --- a/README.md +++ b/README.md @@ -4,10 +4,11 @@ [![Maven Central](https://img.shields.io/maven-central/v/io.github.dfa1.vortex/vortex-reader.svg)](https://central.sonatype.com/artifact/io.github.dfa1.vortex/vortex-reader) 
 [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/license/Apache-2.0)

+> **Alpha** — not production-ready. APIs will change without notice.
+
 Pure-Java reader/writer for the [Vortex](https://github.com/vortex-data/vortex) columnar file
 format. 100% Java, no JNI, no `sun.misc.Unsafe`. Uses the FFM API (`MemorySegment`/`Arena`, Java 25+)
-for zero-copy memory-mapped reads. Read benchmarks match or beat the Rust JNI on the workloads
-tested (Apple M5, JDK 25); see [docs/explanation.md#benchmarks](docs/explanation.md#benchmarks).
+for zero-copy memory-mapped reads.

 | Project | Language | Notes |
 |---------------------------------------------------------------------|----------|-----------------------------------------|
@@ -49,12 +50,15 @@ try (VortexReader vf = VortexReader.open(Path.of("data/example.vortex"));
 }
 ```

-> **Lifecycle.** `Chunk` owns a confined `Arena` — close it (try-with-resources
-> or `iter.forEachRemaining`) to release the decoded buffers. Full lifecycle
-> rules: [docs/explanation.md#memory-model](docs/explanation.md#memory-model).
+> **Lifecycle.** `ScanIterator` implements `Iterator` and `Chunk` implements
+> `AutoCloseable`. Each chunk owns a confined `Arena`; closing it releases the
+> decoded buffers. Calling `iter.next()` while a prior chunk is still open throws
+> `IllegalStateException`. Use try-with-resources, or
+> `iter.forEachRemaining(c -> ...)` which closes each chunk for you. See
+> [docs/explanation.md#memory-model](docs/explanation.md#memory-model).

-For more examples (writing, projection, filtering, custom encodings, CLI) see
-the documentation below.
+For more examples — writing, projection, filtering, custom encodings, and the CLI —
+see the documentation below.

 ## Documentation

diff --git a/cli/src/main/java/io/github/dfa1/vortex/cli/tui/VortexInspectorTui.java b/cli/src/main/java/io/github/dfa1/vortex/cli/tui/VortexInspectorTui.java
index a6c58272..9e282b03 100644
--- a/cli/src/main/java/io/github/dfa1/vortex/cli/tui/VortexInspectorTui.java
+++ b/cli/src/main/java/io/github/dfa1/vortex/cli/tui/VortexInspectorTui.java
@@ -571,11 +571,8 @@ private void runDictLoad(InspectorTree.Node dictNode) {
         try (java.lang.foreign.Arena arena = java.lang.foreign.Arena.ofConfined()) {
             int segIdx = values.segments().getFirst();
             SegmentSpec spec = tree.segmentSpecs().get(segIdx);
-            java.lang.foreign.MemorySegment seg = handle.slice(spec.offset(), spec.length());
             io.github.dfa1.vortex.core.array.Array arr =
-                    new io.github.dfa1.vortex.encoding.FlatSegmentDecoder(handle.registry())
-                            .decode(seg, handle.footer().arraySpecs(),
-                                    dtype, values.rowCount(), arena);
+                    handle.decodeFlatSegment(spec, dtype, values.rowCount(), arena);
             int n = (int) Math.min(arr.length(), DATA_PREVIEW_ROWS);
             List out = new ArrayList<>(n);
             for (int i = 0; i < n; i++) {
diff --git a/core/pom.xml b/core/pom.xml
index b4b90110..03d6e9dd 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -47,6 +47,25 @@ + + + + + org.apache.maven.plugins + maven-jar-plugin + + + publish-test-jar + + test-jar + + + + + + + + + io.github.dfa1.vortex + vortex-inspector + 0.6.0 + +```
+
+`./mvnw -pl core,reader,inspector verify` builds the read-only artifact set
+without the writer module on the classpath. None of the writer-side encoder
+implementations are loaded; `ServiceLoader` resolves only the
+standalone decoders in `reader`, falling back to the bifunctional `Encoding`
+implementations in `core` for encoding families not yet lifted (ADR 0001
+Phases 2–3).
+ ## Known wire-format gaps | Item | Introduced | Java status | diff --git a/pom.xml b/pom.xml index ae1c9507..5ed9a397 100644 --- a/pom.xml +++ b/pom.xml @@ -85,6 +85,12 @@ vortex-core ${project.version} + + io.github.dfa1.vortex + vortex-core + ${project.version} + test-jar + io.github.dfa1.vortex vortex-reader diff --git a/reader/pom.xml b/reader/pom.xml index f2faece3..ed08aa8e 100644 --- a/reader/pom.xml +++ b/reader/pom.xml @@ -25,6 +25,14 @@ flatbuffers-java + + + io.github.dfa1.vortex + vortex-core + test-jar + test + + org.junit.jupiter junit-jupiter diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/ScanIterator.java b/reader/src/main/java/io/github/dfa1/vortex/reader/ScanIterator.java index 6561bbc0..905114af 100644 --- a/reader/src/main/java/io/github/dfa1/vortex/reader/ScanIterator.java +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/ScanIterator.java @@ -23,7 +23,6 @@ import io.github.dfa1.vortex.core.array.VarBinArray; import io.github.dfa1.vortex.encoding.EncodingId; import io.github.dfa1.vortex.encoding.Registry; -import io.github.dfa1.vortex.encoding.FlatSegmentDecoder; import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; @@ -478,8 +477,7 @@ private Array decodeFlat(Layout flat, DType dtype, SegmentAllocator arena) { } int segIdx = flat.segments().getFirst(); SegmentSpec spec = file.footer().segmentSpecs().get(segIdx); - MemorySegment seg = file.slice(spec.offset(), spec.length()); - return new FlatSegmentDecoder(registry).decode(seg, file.footer().arraySpecs(), dtype, flat.rowCount(), arena); + return file.decodeFlatSegment(spec, dtype, flat.rowCount(), arena); } private Array decodeDictLayout(Layout dictLayout, DType dtype, SegmentAllocator arena) { diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/VortexHandle.java b/reader/src/main/java/io/github/dfa1/vortex/reader/VortexHandle.java index 163ff16d..f300a1c0 100644 --- a/reader/src/main/java/io/github/dfa1/vortex/reader/VortexHandle.java +++ 
b/reader/src/main/java/io/github/dfa1/vortex/reader/VortexHandle.java @@ -3,10 +3,13 @@ import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.Footer; import io.github.dfa1.vortex.core.Layout; +import io.github.dfa1.vortex.core.SegmentSpec; +import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.encoding.Registry; import java.io.Closeable; import java.lang.foreign.MemorySegment; +import java.lang.foreign.SegmentAllocator; /// Common interface for handles to a Vortex file, regardless of storage backend. /// @@ -23,6 +26,17 @@ public interface VortexHandle extends Closeable { long fileSize(); + /// Typed accessor for the common pattern "slice a flat segment by its {@link SegmentSpec} + /// and decode the encoded array contained therein." Replaces the raw {@link #slice} + /// escape hatch for read-side consumers (scan, inspector, TUI). + /// + /// @param spec the segment spec to read from + /// @param dtype logical type of the decoded array + /// @param rowCount number of logical rows in the segment + /// @param arena allocator for decode output; lifetime matches the caller's chunk epoch + /// @return the decoded array + Array decodeFlatSegment(SegmentSpec spec, DType dtype, long rowCount, SegmentAllocator arena); + /// Returns a read-only view of bytes `[offset, offset+length)` within the file. /// Writes through the returned segment throw `UnsupportedOperationException`. 
/// diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/VortexHttpReader.java b/reader/src/main/java/io/github/dfa1/vortex/reader/VortexHttpReader.java index f7040e2b..1a6ab7af 100644 --- a/reader/src/main/java/io/github/dfa1/vortex/reader/VortexHttpReader.java +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/VortexHttpReader.java @@ -217,6 +217,17 @@ public long fileSize() { // ── HTTP helpers ────────────────────────────────────────────────────────── + @Override + public io.github.dfa1.vortex.core.array.Array decodeFlatSegment( + io.github.dfa1.vortex.core.SegmentSpec spec, + DType dtype, long rowCount, + java.lang.foreign.SegmentAllocator arenaOut + ) { + MemorySegment seg = slice(spec.offset(), spec.length()); + return new io.github.dfa1.vortex.encoding.FlatSegmentDecoder(registry) + .decode(seg, footer.arraySpecs(), dtype, rowCount, arenaOut); + } + /// Fetches bytes `[offset, offset+length)` via HTTP Range and returns them /// as an off-heap [MemorySegment] tied to this reader's [Arena]. @Override diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/VortexReader.java b/reader/src/main/java/io/github/dfa1/vortex/reader/VortexReader.java index f386eeb0..782de836 100644 --- a/reader/src/main/java/io/github/dfa1/vortex/reader/VortexReader.java +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/VortexReader.java @@ -233,6 +233,17 @@ private ArrayStats readFlatStats(Layout flat) { return ArrayStats.fromFbs(root.stats()); } + @Override + public io.github.dfa1.vortex.core.array.Array decodeFlatSegment( + io.github.dfa1.vortex.core.SegmentSpec spec, + DType dtype, long rowCount, + java.lang.foreign.SegmentAllocator arena + ) { + MemorySegment seg = fileSegment.asSlice(spec.offset(), spec.length()).asReadOnly(); + return new io.github.dfa1.vortex.encoding.FlatSegmentDecoder(registry) + .decode(seg, footer.arraySpecs(), dtype, rowCount, arena); + } + /// Zero-copy read-only slice of the memory-mapped file. 
@Override public MemorySegment slice(long offset, long length) { diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/AlpEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/AlpEncodingDecoder.java new file mode 100644 index 00000000..931e6e3c --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/AlpEncodingDecoder.java @@ -0,0 +1,167 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.encoding.AlpEncoding; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; +import io.github.dfa1.vortex.proto.ALPMetadata; +import io.github.dfa1.vortex.proto.PatchesMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.alp}. +public final class AlpEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public AlpEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_ALP; + } + + @Override + public boolean accepts(DType dtype) { + if (!(dtype instanceof DType.Primitive p)) { + return false; + } + return p.ptype() == PType.F64 || p.ptype() == PType.F32; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + ALPMetadata meta; + if (rawMeta == null || !rawMeta.hasRemaining()) { + meta = new ALPMetadata(0, 0, null); + } else { + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + meta = ALPMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_ALP, "invalid metadata", e); + } + } + + if (!(ctx.dtype() instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.VORTEX_ALP, "expected primitive dtype, got " + ctx.dtype()); + } + + int expE = meta.exp_e(); + int expF = meta.exp_f(); + PType ptype = p.ptype(); + long n = ctx.rowCount(); + + return switch (ptype) { + case F64 -> decodeF64(ctx, meta, expE, expF, n); + case F32 -> decodeF32(ctx, meta, expE, expF, n); + default -> throw new VortexException(EncodingId.VORTEX_ALP, "unsupported dtype " + ptype); + }; + } + + private static Array decodeF64(DecodeContext ctx, ALPMetadata meta, int expE, int expF, long n) { + double df = AlpEncoding.F10_F64[expF]; + double de = AlpEncoding.IF10_F64[expE]; + + MemorySegment src = ctx.decodeChildSegment(0, AlpEncoding.I64_DTYPE, n); + MemorySegment buf = src.isReadOnly() ? 
ctx.arena().allocate(n * 8, 8) : src; + if (src.isReadOnly()) { + long srcCap = SegmentBroadcast.capacity(src, 8); + if (srcCap == n) { + for (long i = 0; i < n; i++) { + buf.setAtIndex(PTypeIO.LE_DOUBLE, i, (double) src.getAtIndex(PTypeIO.LE_LONG, i) * df * de); + } + } else { + for (long i = 0; i < n; i++) { + buf.setAtIndex(PTypeIO.LE_DOUBLE, i, (double) src.getAtIndex(PTypeIO.LE_LONG, i % srcCap) * df * de); + } + } + } else { + for (long i = 0; i < n; i++) { + buf.setAtIndex(PTypeIO.LE_DOUBLE, i, (double) buf.getAtIndex(PTypeIO.LE_LONG, i) * df * de); + } + } + + if (meta.patches() != null) { + applyPatches(ctx, meta.patches(), buf, 8); + } + + return new DoubleArray(ctx.dtype(), n, buf.asReadOnly()); + } + + private static Array decodeF32(DecodeContext ctx, ALPMetadata meta, int expE, int expF, long n) { + float df = AlpEncoding.F10_F32[expF]; + float de = AlpEncoding.IF10_F32[expE]; + + MemorySegment src32 = ctx.decodeChildSegment(0, AlpEncoding.I32_DTYPE, n); + MemorySegment buf32 = src32.isReadOnly() ? 
ctx.arena().allocate(n * 4, 4) : src32; + if (src32.isReadOnly()) { + long srcCap = SegmentBroadcast.capacity(src32, 4); + if (srcCap == n) { + for (long i = 0; i < n; i++) { + buf32.setAtIndex(PTypeIO.LE_FLOAT, i, (float) src32.getAtIndex(PTypeIO.LE_INT, i) * df * de); + } + } else { + for (long i = 0; i < n; i++) { + buf32.setAtIndex(PTypeIO.LE_FLOAT, i, (float) src32.getAtIndex(PTypeIO.LE_INT, i % srcCap) * df * de); + } + } + } else { + for (long i = 0; i < n; i++) { + buf32.setAtIndex(PTypeIO.LE_FLOAT, i, (float) buf32.getAtIndex(PTypeIO.LE_INT, i) * df * de); + } + } + + if (meta.patches() != null) { + applyPatches(ctx, meta.patches(), buf32, 4); + } + + return new FloatArray(ctx.dtype(), n, buf32.asReadOnly()); + } + + private static void applyPatches(DecodeContext ctx, PatchesMetadata pm, MemorySegment out, int elemBytes) { + long numPatches = pm.len(); + long offset = pm.offset(); + PType idxPtype = PType.fromOrdinal(pm.indices_ptype().value()); + int idxBytes = idxPtype.byteSize(); + + MemorySegment idxSeg = ctx.decodeChildSegment(1, new DType.Primitive(idxPtype, false), numPatches); + MemorySegment valSeg = ctx.decodeChildSegment(2, ctx.dtype(), numPatches); + + long idxCap = SegmentBroadcast.capacity(idxSeg, idxBytes); + long valCap = SegmentBroadcast.capacity(valSeg, elemBytes); + if (idxCap >= numPatches && valCap >= numPatches) { + for (long i = 0; i < numPatches; i++) { + long absIdx = readUnsigned(idxSeg, i * idxBytes, idxPtype) - offset; + MemorySegment.copy(valSeg, i * elemBytes, out, absIdx * elemBytes, elemBytes); + } + } else { + for (long i = 0; i < numPatches; i++) { + long absIdx = readUnsigned(idxSeg, (i % idxCap) * idxBytes, idxPtype) - offset; + MemorySegment.copy(valSeg, (i % valCap) * elemBytes, out, absIdx * elemBytes, elemBytes); + } + } + } + + private static long readUnsigned(MemorySegment seg, long off, PType ptype) { + return switch (ptype) { + case U8 -> Byte.toUnsignedLong(seg.get(ValueLayout.JAVA_BYTE, off)); + case U16 -> 
Short.toUnsignedLong(seg.get(PTypeIO.LE_SHORT, off)); + case U32 -> Integer.toUnsignedLong(seg.get(PTypeIO.LE_INT, off)); + case U64 -> seg.get(PTypeIO.LE_LONG, off); + default -> throw new VortexException(EncodingId.VORTEX_ALP, "non-unsigned patch index ptype " + ptype); + }; + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/AlpRdEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/AlpRdEncodingDecoder.java new file mode 100644 index 00000000..28fb670b --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/AlpRdEncodingDecoder.java @@ -0,0 +1,175 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.encoding.AlpRdEncoding; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; +import io.github.dfa1.vortex.proto.ALPRDMetadata; +import io.github.dfa1.vortex.proto.PatchesMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.alprd}. +public final class AlpRdEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public AlpRdEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_ALPRD; + } + + @Override + public boolean accepts(DType dtype) { + if (!(dtype instanceof DType.Primitive p)) { + return false; + } + return p.ptype() == PType.F32 || p.ptype() == PType.F64; + } + + @Override + public Array decode(DecodeContext ctx) { + ALPRDMetadata meta = parseMeta(ctx); + + if (!(ctx.dtype() instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.VORTEX_ALPRD, + "expected primitive dtype, got " + ctx.dtype()); + } + + int rightBitWidth = meta.right_bit_width(); + int dictLen = meta.dict_len(); + short[] dict = new short[dictLen]; + for (int i = 0; i < dictLen; i++) { + dict[i] = (short) (meta.dict().get(i) & 0xFFFF); + } + + long n = ctx.rowCount(); + PType ptype = p.ptype(); + + return switch (ptype) { + case F64 -> decodeF64(ctx, meta, dict, rightBitWidth, n); + case F32 -> decodeF32(ctx, meta, dict, rightBitWidth, n); + default -> throw new VortexException(EncodingId.VORTEX_ALPRD, "unsupported dtype " + ptype); + }; + } + + private static Array decodeF64(DecodeContext ctx, ALPRDMetadata meta, short[] dict, int rightBitWidth, long n) { + MemorySegment leftSeg = ctx.decodeChildSegment(0, AlpRdEncoding.U16_DTYPE, n); + MemorySegment rightSeg = ctx.decodeChildSegment(1, AlpRdEncoding.U64_DTYPE, n); + long leftCap = SegmentBroadcast.capacity(leftSeg, 2); + long rightCap = SegmentBroadcast.capacity(rightSeg, 8); + MemorySegment out = ctx.arena().allocate(n * Long.BYTES, Long.BYTES); + + for (long i = 0; i < n; i++) { + int code = Short.toUnsignedInt(leftSeg.getAtIndex(PTypeIO.LE_SHORT, i % leftCap)); + long leftBits = (long) (dict[code] & 0xFFFF) << rightBitWidth; + long rightBits = rightSeg.getAtIndex(PTypeIO.LE_LONG, i % rightCap); + out.setAtIndex(PTypeIO.LE_LONG, i, leftBits | rightBits); + } + + if (meta.patches() != null) { + applyPatchesF64(ctx, meta.patches(), out, rightSeg, rightCap, rightBitWidth); + } + + 
return new DoubleArray(ctx.dtype(), n, out.asReadOnly()); + } + + private static Array decodeF32(DecodeContext ctx, ALPRDMetadata meta, short[] dict, int rightBitWidth, long n) { + MemorySegment leftSeg = ctx.decodeChildSegment(0, AlpRdEncoding.U16_DTYPE, n); + MemorySegment rightSeg = ctx.decodeChildSegment(1, AlpRdEncoding.U32_DTYPE, n); + long leftCap = SegmentBroadcast.capacity(leftSeg, 2); + long rightCap = SegmentBroadcast.capacity(rightSeg, 4); + MemorySegment out = ctx.arena().allocate(n * Integer.BYTES, Integer.BYTES); + + for (long i = 0; i < n; i++) { + int code = Short.toUnsignedInt(leftSeg.getAtIndex(PTypeIO.LE_SHORT, i % leftCap)); + int leftBits = (dict[code] & 0xFFFF) << rightBitWidth; + int rightBits = rightSeg.getAtIndex(PTypeIO.LE_INT, i % rightCap); + out.setAtIndex(PTypeIO.LE_INT, i, leftBits | rightBits); + } + + if (meta.patches() != null) { + applyPatchesF32(ctx, meta.patches(), out, rightSeg, rightCap, rightBitWidth); + } + + return new FloatArray(ctx.dtype(), n, out.asReadOnly()); + } + + private static void applyPatchesF64(DecodeContext ctx, PatchesMetadata pm, + MemorySegment out, MemorySegment rightSeg, long rightCap, int rightBitWidth) { + long numPatches = pm.len(); + long offset = pm.offset(); + PType idxPtype = PType.fromOrdinal(pm.indices_ptype().value()); + + MemorySegment idxSeg = ctx.decodeChildSegment(2, new DType.Primitive(idxPtype, false), numPatches); + MemorySegment valSeg = ctx.decodeChildSegment(3, AlpRdEncoding.U16_DTYPE, numPatches); + int idxBytes = idxPtype.byteSize(); + long valCap = SegmentBroadcast.capacity(valSeg, 2); + + for (long j = 0; j < numPatches; j++) { + long absIdx = readUnsigned(idxSeg, SegmentBroadcast.elementOffset(idxSeg, j, idxBytes), idxPtype) - offset; + short actualLeftU16 = valSeg.getAtIndex(PTypeIO.LE_SHORT, j % valCap); + long leftBits = (long) (actualLeftU16 & 0xFFFF) << rightBitWidth; + long rightBits = rightSeg.getAtIndex(PTypeIO.LE_LONG, absIdx % rightCap); + 
out.setAtIndex(PTypeIO.LE_LONG, absIdx, leftBits | rightBits); + } + } + + private static void applyPatchesF32(DecodeContext ctx, PatchesMetadata pm, + MemorySegment out, MemorySegment rightSeg, long rightCap, int rightBitWidth) { + long numPatches = pm.len(); + long offset = pm.offset(); + PType idxPtype = PType.fromOrdinal(pm.indices_ptype().value()); + + MemorySegment idxSeg = ctx.decodeChildSegment(2, new DType.Primitive(idxPtype, false), numPatches); + MemorySegment valSeg = ctx.decodeChildSegment(3, AlpRdEncoding.U16_DTYPE, numPatches); + int idxBytes = idxPtype.byteSize(); + long valCap = SegmentBroadcast.capacity(valSeg, 2); + + for (long j = 0; j < numPatches; j++) { + long absIdx = readUnsigned(idxSeg, SegmentBroadcast.elementOffset(idxSeg, j, idxBytes), idxPtype) - offset; + short actualLeftU16 = valSeg.getAtIndex(PTypeIO.LE_SHORT, j % valCap); + int leftBits = (actualLeftU16 & 0xFFFF) << rightBitWidth; + int rightBits = rightSeg.getAtIndex(PTypeIO.LE_INT, absIdx % rightCap); + out.setAtIndex(PTypeIO.LE_INT, (int) absIdx, leftBits | rightBits); + } + } + + private static long readUnsigned(MemorySegment seg, long off, PType ptype) { + return switch (ptype) { + case U8 -> Byte.toUnsignedLong(seg.get(ValueLayout.JAVA_BYTE, off)); + case U16 -> Short.toUnsignedLong(seg.get(PTypeIO.LE_SHORT, off)); + case U32 -> Integer.toUnsignedLong(seg.get(PTypeIO.LE_INT, off)); + case U64 -> seg.get(PTypeIO.LE_LONG, off); + default -> throw new VortexException(EncodingId.VORTEX_ALPRD, + "non-unsigned patch index ptype " + ptype); + }; + } + + private static ALPRDMetadata parseMeta(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + if (rawMeta == null || !rawMeta.hasRemaining()) { + return new ALPRDMetadata(0, 0, java.util.List.of(), + io.github.dfa1.vortex.proto.PType.fromValue(PType.U16.ordinal()), null); + } + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + return ALPRDMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch 
(IOException e) { + throw new VortexException(EncodingId.VORTEX_ALPRD, "invalid metadata", e); + } + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/BitpackedEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/BitpackedEncodingDecoder.java new file mode 100644 index 00000000..a0a584d5 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/BitpackedEncodingDecoder.java @@ -0,0 +1,526 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.encoding.BitpackedEncoding; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; +import io.github.dfa1.vortex.proto.BitPackedMetadata; +import io.github.dfa1.vortex.proto.PatchesMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code fastlanes.bitpacked}. +public final class BitpackedEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public BitpackedEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.FASTLANES_BITPACKED; + } + + @Override + public boolean accepts(DType dtype) { + if (!(dtype instanceof DType.Primitive p)) { + return false; + } + return switch (p.ptype()) { + case I8, I16, I32, I64, U8, U16, U32, U64 -> true; + default -> false; + }; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + // proto3 elides default-valued fields, so BitPackedMetadata(0, 0, null) serialises + // to a 0-byte payload and the writer skips the empty vector. Treat absent metadata + // as all-defaults rather than rejecting — happens when bit_width=0 (constant + // residuals nested under FoR / RLE). + BitPackedMetadata meta; + if (rawMeta == null || !rawMeta.hasRemaining()) { + meta = new BitPackedMetadata(0, 0, null); + } else { + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + meta = BitPackedMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.FASTLANES_BITPACKED, "invalid metadata", e); + } + } + + int bitWidth = meta.bit_width(); + int offset = meta.offset(); + PType ptype = ((DType.Primitive) ctx.dtype()).ptype(); + int typeBits = ptype.byteSize() * 8; + long rowCount = ctx.rowCount(); + + MemorySegment packed = ctx.buffer(0); + MemorySegment output = ctx.arena().allocate(rowCount * ptype.byteSize()); + fastlanesUnpackToSeg(packed, bitWidth, offset, typeBits, rowCount, output); + + if (meta.patches() != null) { + applyPatches(ctx, meta.patches(), output, ptype.byteSize()); + } + + return switch (ptype) { + case I64, U64 -> new LongArray(ctx.dtype(), rowCount, output); + case I32, U32 -> new IntArray(ctx.dtype(), rowCount, output); + case I16, U16 -> new ShortArray(ctx.dtype(), rowCount, output); + case I8, U8 -> new ByteArray(ctx.dtype(), rowCount, output); + default -> throw new 
VortexException(EncodingId.FASTLANES_BITPACKED, "unsupported ptype " + ptype); + }; + } + + private static void fastlanesUnpackToSeg( + MemorySegment buf, int bitWidth, int offset, int typeBits, long rowCount, + MemorySegment output) { + if (bitWidth == 0) { + return; + } + switch (typeBits) { + case 8 -> unpackLoop8(buf, bitWidth, offset, rowCount, output); + case 16 -> unpackLoop16(buf, bitWidth, offset, rowCount, output); + case 32 -> unpackLoop32(buf, bitWidth, offset, rowCount, output); + case 64 -> unpackLoop64(buf, bitWidth, offset, rowCount, output); + default -> + throw new VortexException(EncodingId.FASTLANES_BITPACKED, "unsupported typeBits: " + typeBits); + } + } + + private static void unpackLoop8(MemorySegment buf, int bitWidth, int offset, long rowCount, MemorySegment out) { + final int lanes = 128; + long totalElems = rowCount + offset; + int blockCount = (int) ((totalElems + 1023) / 1024); + long bitMask = (1L << bitWidth) - 1L; + + int[] shifts = new int[8]; + int[] remainingBits = new int[8]; + int[] currentBits = new int[8]; + long[] loMasks = new long[8]; + long[] hiMasks = new long[8]; + long[] currWordByteBase = new long[8]; + long[] nextWordByteBase = new long[8]; + long[] outRowByteOff = new long[8]; + for (int row = 0; row < 8; row++) { + int currWord = (row * bitWidth) / 8; + int nextWord = ((row + 1) * bitWidth) / 8; + shifts[row] = (row * bitWidth) % 8; + int rem = (nextWord > currWord) ? ((row + 1) * bitWidth) % 8 : 0; + remainingBits[row] = rem; + int curr = bitWidth - rem; + currentBits[row] = curr; + loMasks[row] = rem > 0 ? (1L << curr) - 1L : 0L; + hiMasks[row] = rem > 0 ? (1L << rem) - 1L : 0L; + currWordByteBase[row] = (long) lanes * currWord; + nextWordByteBase[row] = rem > 0 ? 
(long) lanes * nextWord : 0L; + int o = row / 8; + int s = row % 8; + outRowByteOff[row] = BitpackedEncoding.FL_ORDER[o] * 16 + s * 128; + } + + long blockByteOff = 0L; + long blockByteStride = 128L * bitWidth; + for (int block = 0; block < blockCount; block++, blockByteOff += blockByteStride) { + int blockLogicStart = block * 1024 - offset; + boolean fullBlock = blockLogicStart >= 0 && (long) blockLogicStart + 1023L < rowCount; + + if (fullBlock) { + for (int row = 0; row < 8; row++) { + int shift = shifts[row]; + int rem = remainingBits[row]; + int curr = currentBits[row]; + long outBase = blockLogicStart + outRowByteOff[row]; + long wordBase = blockByteOff + currWordByteBase[row]; + if (rem > 0) { + long hiBase = blockByteOff + nextWordByteBase[row]; + long loMask = loMasks[row]; + long hiMask = hiMasks[row]; + for (int lane = 0; lane < lanes; lane++) { + long lo = (Byte.toUnsignedLong(buf.get(ValueLayout.JAVA_BYTE, wordBase + lane)) >>> shift) & loMask; + long hi = Byte.toUnsignedLong(buf.get(ValueLayout.JAVA_BYTE, hiBase + lane)) & hiMask; + out.set(ValueLayout.JAVA_BYTE, outBase + lane, (byte) (lo | (hi << curr))); + } + } else { + for (int lane = 0; lane < lanes; lane++) { + out.set(ValueLayout.JAVA_BYTE, outBase + lane, + (byte) ((Byte.toUnsignedLong(buf.get(ValueLayout.JAVA_BYTE, wordBase + lane)) >>> shift) & bitMask)); + } + } + } + } else { + for (int row = 0; row < 8; row++) { + int shift = shifts[row]; + int rem = remainingBits[row]; + int curr = currentBits[row]; + int o = row / 8; + int s = row % 8; + int baseIdx = blockLogicStart + BitpackedEncoding.FL_ORDER[o] * 16 + s * 128; + long wordBase = blockByteOff + currWordByteBase[row]; + long hiBase = rem > 0 ? 
blockByteOff + nextWordByteBase[row] : 0L; + long loMask = loMasks[row]; + long hiMask = hiMasks[row]; + for (int lane = 0; lane < lanes; lane++) { + int logicalIdx = baseIdx + lane; + if (logicalIdx < 0 || logicalIdx >= rowCount) { + continue; + } + long src = Byte.toUnsignedLong(buf.get(ValueLayout.JAVA_BYTE, wordBase + lane)); + long value; + if (rem > 0) { + long lo = (src >>> shift) & loMask; + long hi = Byte.toUnsignedLong(buf.get(ValueLayout.JAVA_BYTE, hiBase + lane)) & hiMask; + value = lo | (hi << curr); + } else { + value = (src >>> shift) & bitMask; + } + out.set(ValueLayout.JAVA_BYTE, logicalIdx, (byte) value); + } + } + } + } + } + + private static void unpackLoop16(MemorySegment buf, int bitWidth, int offset, long rowCount, MemorySegment out) { + final int lanes = 64; + long totalElems = rowCount + offset; + int blockCount = (int) ((totalElems + 1023) / 1024); + long bitMask = (1L << bitWidth) - 1L; + + int[] shifts = new int[16]; + int[] remainingBits = new int[16]; + int[] currentBits = new int[16]; + long[] loMasks = new long[16]; + long[] hiMasks = new long[16]; + long[] currWordByteBase = new long[16]; + long[] nextWordByteBase = new long[16]; + long[] outRowByteOff = new long[16]; + for (int row = 0; row < 16; row++) { + int currWord = (row * bitWidth) / 16; + int nextWord = ((row + 1) * bitWidth) / 16; + shifts[row] = (row * bitWidth) % 16; + int rem = (nextWord > currWord) ? ((row + 1) * bitWidth) % 16 : 0; + remainingBits[row] = rem; + int curr = bitWidth - rem; + currentBits[row] = curr; + loMasks[row] = rem > 0 ? (1L << curr) - 1L : 0L; + hiMasks[row] = rem > 0 ? (1L << rem) - 1L : 0L; + currWordByteBase[row] = (long) lanes * currWord * 2L; + nextWordByteBase[row] = rem > 0 ? 
(long) lanes * nextWord * 2L : 0L; + int o = row / 8; + int s = row % 8; + outRowByteOff[row] = (long) (BitpackedEncoding.FL_ORDER[o] * 16 + s * 128) * 2L; + } + + long blockByteOff = 0L; + long blockByteStride = 128L * bitWidth; + for (int block = 0; block < blockCount; block++, blockByteOff += blockByteStride) { + int blockLogicStart = block * 1024 - offset; + boolean fullBlock = blockLogicStart >= 0 && (long) blockLogicStart + 1023L < rowCount; + long blockOutByteBase = (long) blockLogicStart * 2L; + + if (fullBlock) { + for (int row = 0; row < 16; row++) { + int shift = shifts[row]; + int rem = remainingBits[row]; + int curr = currentBits[row]; + long outBase = blockOutByteBase + outRowByteOff[row]; + long wordBase = blockByteOff + currWordByteBase[row]; + if (rem > 0) { + long hiBase = blockByteOff + nextWordByteBase[row]; + long loMask = loMasks[row]; + long hiMask = hiMasks[row]; + long laneOff = 0L; + for (int lane = 0; lane < lanes; lane++, laneOff += 2L) { + long lo = (Short.toUnsignedLong(buf.get(PTypeIO.LE_SHORT, wordBase + laneOff)) >>> shift) & loMask; + long hi = Short.toUnsignedLong(buf.get(PTypeIO.LE_SHORT, hiBase + laneOff)) & hiMask; + out.set(PTypeIO.LE_SHORT, outBase + laneOff, (short) (lo | (hi << curr))); + } + } else { + long laneOff = 0L; + for (int lane = 0; lane < lanes; lane++, laneOff += 2L) { + out.set(PTypeIO.LE_SHORT, outBase + laneOff, + (short) ((Short.toUnsignedLong(buf.get(PTypeIO.LE_SHORT, wordBase + laneOff)) >>> shift) & bitMask)); + } + } + } + } else { + for (int row = 0; row < 16; row++) { + int shift = shifts[row]; + int rem = remainingBits[row]; + int curr = currentBits[row]; + int o = row / 8; + int s = row % 8; + int baseIdx = blockLogicStart + BitpackedEncoding.FL_ORDER[o] * 16 + s * 128; + long wordBase = blockByteOff + currWordByteBase[row]; + long hiBase = rem > 0 ? 
blockByteOff + nextWordByteBase[row] : 0L; + long loMask = loMasks[row]; + long hiMask = hiMasks[row]; + for (int lane = 0; lane < lanes; lane++) { + int logicalIdx = baseIdx + lane; + if (logicalIdx < 0 || logicalIdx >= rowCount) { + continue; + } + long src = Short.toUnsignedLong(buf.get(PTypeIO.LE_SHORT, wordBase + (long) lane * 2)); + long value; + if (rem > 0) { + long lo = (src >>> shift) & loMask; + long hi = Short.toUnsignedLong(buf.get(PTypeIO.LE_SHORT, hiBase + (long) lane * 2)) & hiMask; + value = lo | (hi << curr); + } else { + value = (src >>> shift) & bitMask; + } + out.set(PTypeIO.LE_SHORT, (long) logicalIdx * 2, (short) value); + } + } + } + } + } + + private static void unpackLoop32(MemorySegment buf, int bitWidth, int offset, long rowCount, MemorySegment out) { + final int lanes = 32; + long totalElems = rowCount + offset; + int blockCount = (int) ((totalElems + 1023) / 1024); + long bitMask = (1L << bitWidth) - 1L; + + int[] shifts = new int[32]; + int[] remainingBits = new int[32]; + int[] currentBits = new int[32]; + long[] loMasks = new long[32]; + long[] hiMasks = new long[32]; + long[] currWordByteBase = new long[32]; + long[] nextWordByteBase = new long[32]; + long[] outRowByteOff = new long[32]; + for (int row = 0; row < 32; row++) { + int currWord = (row * bitWidth) / 32; + int nextWord = ((row + 1) * bitWidth) / 32; + shifts[row] = (row * bitWidth) % 32; + int rem = (nextWord > currWord) ? ((row + 1) * bitWidth) % 32 : 0; + remainingBits[row] = rem; + int curr = bitWidth - rem; + currentBits[row] = curr; + loMasks[row] = rem > 0 ? (1L << curr) - 1L : 0L; + hiMasks[row] = rem > 0 ? (1L << rem) - 1L : 0L; + currWordByteBase[row] = (long) lanes * currWord * 4L; + nextWordByteBase[row] = rem > 0 ? 
(long) lanes * nextWord * 4L : 0L; + int o = row / 8; + int s = row % 8; + outRowByteOff[row] = (long) (BitpackedEncoding.FL_ORDER[o] * 16 + s * 128) * 4L; + } + + long blockByteOff = 0L; + long blockByteStride = 128L * bitWidth; + for (int block = 0; block < blockCount; block++, blockByteOff += blockByteStride) { + int blockLogicStart = block * 1024 - offset; + boolean fullBlock = blockLogicStart >= 0 && (long) blockLogicStart + 1023L < rowCount; + long blockOutByteBase = (long) blockLogicStart * 4L; + + if (fullBlock) { + for (int row = 0; row < 32; row++) { + int shift = shifts[row]; + int rem = remainingBits[row]; + int curr = currentBits[row]; + long outBase = blockOutByteBase + outRowByteOff[row]; + long wordBase = blockByteOff + currWordByteBase[row]; + if (rem > 0) { + long hiBase = blockByteOff + nextWordByteBase[row]; + long loMask = loMasks[row]; + long hiMask = hiMasks[row]; + long laneOff = 0L; + for (int lane = 0; lane < lanes; lane++, laneOff += 4L) { + long lo = (Integer.toUnsignedLong(buf.get(PTypeIO.LE_INT, wordBase + laneOff)) >>> shift) & loMask; + long hi = Integer.toUnsignedLong(buf.get(PTypeIO.LE_INT, hiBase + laneOff)) & hiMask; + out.set(PTypeIO.LE_INT, outBase + laneOff, (int) (lo | (hi << curr))); + } + } else { + long laneOff = 0L; + for (int lane = 0; lane < lanes; lane++, laneOff += 4L) { + out.set(PTypeIO.LE_INT, outBase + laneOff, + (int) ((Integer.toUnsignedLong(buf.get(PTypeIO.LE_INT, wordBase + laneOff)) >>> shift) & bitMask)); + } + } + } + } else { + for (int row = 0; row < 32; row++) { + int shift = shifts[row]; + int rem = remainingBits[row]; + int curr = currentBits[row]; + int o = row / 8; + int s = row % 8; + int baseIdx = blockLogicStart + BitpackedEncoding.FL_ORDER[o] * 16 + s * 128; + long wordBase = blockByteOff + currWordByteBase[row]; + long hiBase = rem > 0 ? 
blockByteOff + nextWordByteBase[row] : 0L; + long loMask = loMasks[row]; + long hiMask = hiMasks[row]; + for (int lane = 0; lane < lanes; lane++) { + int logicalIdx = baseIdx + lane; + if (logicalIdx < 0 || logicalIdx >= rowCount) { + continue; + } + long src = Integer.toUnsignedLong(buf.get(PTypeIO.LE_INT, wordBase + (long) lane * 4)); + long value; + if (rem > 0) { + long lo = (src >>> shift) & loMask; + long hi = Integer.toUnsignedLong(buf.get(PTypeIO.LE_INT, hiBase + (long) lane * 4)) & hiMask; + value = lo | (hi << curr); + } else { + value = (src >>> shift) & bitMask; + } + out.set(PTypeIO.LE_INT, (long) logicalIdx * 4, (int) value); + } + } + } + } + } + + private static void unpackLoop64(MemorySegment buf, int bitWidth, int offset, long rowCount, MemorySegment out) { + final int lanes = 16; + long totalElems = rowCount + offset; + int blockCount = (int) ((totalElems + 1023) / 1024); + long bitMask = bitWidth == 64 ? -1L : (1L << bitWidth) - 1L; + + int[] shifts = new int[64]; + int[] remainingBits = new int[64]; + int[] currentBits = new int[64]; + long[] loMasks = new long[64]; + long[] hiMasks = new long[64]; + long[] currWordByteBase = new long[64]; + long[] nextWordByteBase = new long[64]; + long[] outRowByteOff = new long[64]; + for (int row = 0; row < 64; row++) { + int currWord = (row * bitWidth) / 64; + int nextWord = ((row + 1) * bitWidth) / 64; + shifts[row] = (row * bitWidth) % 64; + int rem = (nextWord > currWord) ? ((row + 1) * bitWidth) % 64 : 0; + remainingBits[row] = rem; + int curr = bitWidth - rem; + currentBits[row] = curr; + loMasks[row] = rem > 0 ? (1L << curr) - 1L : 0L; + hiMasks[row] = rem > 0 ? (1L << rem) - 1L : 0L; + currWordByteBase[row] = (long) lanes * currWord * 8L; + nextWordByteBase[row] = rem > 0 ? 
(long) lanes * nextWord * 8L : 0L; + int o = row / 8; + int s = row % 8; + outRowByteOff[row] = (long) (BitpackedEncoding.FL_ORDER[o] * 16 + s * 128) * 8L; + } + + long blockByteOff = 0L; + long blockByteStride = 128L * bitWidth; + for (int block = 0; block < blockCount; block++, blockByteOff += blockByteStride) { + int blockLogicStart = block * 1024 - offset; + boolean fullBlock = blockLogicStart >= 0 && (long) blockLogicStart + 1023L < rowCount; + long blockOutByteBase = (long) blockLogicStart * 8L; + + if (fullBlock) { + for (int row = 0; row < 64; row++) { + int shift = shifts[row]; + int rem = remainingBits[row]; + int curr = currentBits[row]; + long outBase = blockOutByteBase + outRowByteOff[row]; + long wordBase = blockByteOff + currWordByteBase[row]; + if (rem > 0) { + long hiBase = blockByteOff + nextWordByteBase[row]; + long loMask = loMasks[row]; + long hiMask = hiMasks[row]; + long laneOff = 0L; + for (int lane = 0; lane < lanes; lane++, laneOff += 8L) { + long lo = (buf.get(PTypeIO.LE_LONG, wordBase + laneOff) >>> shift) & loMask; + long hi = buf.get(PTypeIO.LE_LONG, hiBase + laneOff) & hiMask; + out.set(PTypeIO.LE_LONG, outBase + laneOff, lo | (hi << curr)); + } + } else { + long laneOff = 0L; + for (int lane = 0; lane < lanes; lane++, laneOff += 8L) { + out.set(PTypeIO.LE_LONG, outBase + laneOff, + (buf.get(PTypeIO.LE_LONG, wordBase + laneOff) >>> shift) & bitMask); + } + } + } + } else { + for (int row = 0; row < 64; row++) { + int shift = shifts[row]; + int rem = remainingBits[row]; + int curr = currentBits[row]; + int o = row / 8; + int s = row % 8; + int baseIdx = blockLogicStart + BitpackedEncoding.FL_ORDER[o] * 16 + s * 128; + long wordBase = blockByteOff + currWordByteBase[row]; + long hiBase = rem > 0 ? 
blockByteOff + nextWordByteBase[row] : 0L; + long loMask = loMasks[row]; + long hiMask = hiMasks[row]; + for (int lane = 0; lane < lanes; lane++) { + int logicalIdx = baseIdx + lane; + if (logicalIdx < 0 || logicalIdx >= rowCount) { + continue; + } + long src = buf.get(PTypeIO.LE_LONG, wordBase + (long) lane * 8); + long value; + if (rem > 0) { + long lo = (src >>> shift) & loMask; + long hi = buf.get(PTypeIO.LE_LONG, hiBase + (long) lane * 8) & hiMask; + value = lo | (hi << curr); + } else { + value = (src >>> shift) & bitMask; + } + out.set(PTypeIO.LE_LONG, (long) logicalIdx * 8, value); + } + } + } + } + } + + private static void applyPatches(DecodeContext ctx, PatchesMetadata pm, + MemorySegment out, int elemBytes) { + long numPatches = pm.len(); + if (numPatches == 0) { + return; + } + long offset = pm.offset(); + PType idxPtype = ptypeFromProto(pm.indices_ptype()); + + MemorySegment idxSeg = ctx.decodeChildSegment(0, new DType.Primitive(idxPtype, false), numPatches); + MemorySegment valSeg = ctx.decodeChildSegment(1, ctx.dtype(), numPatches); + + int idxBytes = idxPtype.byteSize(); + long n = ctx.rowCount(); + for (long i = 0; i < numPatches; i++) { + long absIdx = readUnsignedIdx(idxSeg, SegmentBroadcast.elementOffset(idxSeg, i, idxBytes), idxPtype) - offset; + if (absIdx < 0 || absIdx >= n) { + throw new VortexException(EncodingId.FASTLANES_BITPACKED, + "patch index " + absIdx + " out of range [0," + n + ")"); + } + MemorySegment.copy(valSeg, SegmentBroadcast.elementOffset(valSeg, i, elemBytes), + out, absIdx * elemBytes, elemBytes); + } + } + + private static long readUnsignedIdx(MemorySegment seg, long off, PType ptype) { + return switch (ptype) { + case U8 -> Byte.toUnsignedLong(seg.get(ValueLayout.JAVA_BYTE, off)); + case U16 -> Short.toUnsignedLong(seg.get(PTypeIO.LE_SHORT, off)); + case U32 -> Integer.toUnsignedLong(seg.get(PTypeIO.LE_INT, off)); + case U64 -> seg.get(PTypeIO.LE_LONG, off); + default -> throw new 
VortexException(EncodingId.FASTLANES_BITPACKED, + "non-unsigned patch index ptype " + ptype); + }; + } + + private static PType ptypeFromProto(io.github.dfa1.vortex.proto.PType proto) { + return PType.fromOrdinal(proto.value()); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/BoolEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/BoolEncodingDecoder.java new file mode 100644 index 00000000..7b972dc8 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/BoolEncodingDecoder.java @@ -0,0 +1,37 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; + +/// Read-only decoder for {@code vortex.bool} (bit-packed boolean arrays, LSB first). +/// +///
ADR 0001 Phase 2: first encoding lifted into a standalone {@link EncodingDecoder} +/// implementation in the {@code reader} module. The corresponding write-side encode +/// path continues to live on {@link io.github.dfa1.vortex.encoding.BoolEncoding} in +/// {@code core}; that file is peeled into a {@code BoolEncodingEncoder} in +/// {@code writer} during Phase 3. +public final class BoolEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public BoolEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_BOOL; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Bool; + } + + @Override + public Array decode(DecodeContext ctx) { + return new BoolArray(ctx.dtype(), ctx.rowCount(), ctx.buffer(0)); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ByteBoolEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ByteBoolEncodingDecoder.java new file mode 100644 index 00000000..9c1447b5 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ByteBoolEncodingDecoder.java @@ -0,0 +1,46 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; + +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; + +/// Read-only decoder for {@code vortex.bytebool} — packs the input byte buffer into the +/// bit-packed {@link BoolArray} layout used by {@code vortex.bool}. +public final class ByteBoolEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public ByteBoolEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_BYTEBOOL; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Bool; + } + + @Override + public Array decode(DecodeContext ctx) { + long n = ctx.rowCount(); + MemorySegment bytes = ctx.buffer(0); + long packedBytes = (n + 7) >>> 3; + MemorySegment packed = ctx.arena().allocate(packedBytes > 0 ? packedBytes : 1); + for (long i = 0; i < n; i++) { + if (bytes.get(ValueLayout.JAVA_BYTE, i) != 0) { + long byteIdx = i >>> 3; + byte cur = packed.get(ValueLayout.JAVA_BYTE, byteIdx); + packed.set(ValueLayout.JAVA_BYTE, byteIdx, (byte) (cur | (1 << (i & 7)))); + } + } + return new BoolArray(ctx.dtype(), n, packed); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ChunkedEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ChunkedEncodingDecoder.java new file mode 100644 index 00000000..aad37f82 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ChunkedEncodingDecoder.java @@ -0,0 +1,128 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.core.array.StructArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import 
io.github.dfa1.vortex.encoding.SegmentBroadcast; + +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SegmentAllocator; +import java.lang.foreign.ValueLayout; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.List; + +/// Read-only decoder for {@code vortex.chunked}. +public final class ChunkedEncodingDecoder implements EncodingDecoder { + + private static final ValueLayout.OfLong LE_LONG = + ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public ChunkedEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_CHUNKED; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive || dtype instanceof DType.Struct; + } + + @Override + public Array decode(DecodeContext ctx) { + int nchildren = ctx.node().children().length; + if (nchildren < 1) { + throw new VortexException(EncodingId.VORTEX_CHUNKED, + "needs at least one child (chunk offsets)"); + } + int nchunks = nchildren - 1; + long[] offsets = readOffsets(ctx, nchunks); + + DType dtype = ctx.dtype(); + List<Array> chunks = new ArrayList<>(nchunks); + for (int i = 0; i < nchunks; i++) { + long chunkLen = offsets[i + 1] - offsets[i]; + chunks.add(ctx.decodeChild(i + 1, dtype, chunkLen)); + } + + return concat(chunks, dtype, ctx.rowCount(), ctx.arena()); + } + + private static long[] readOffsets(DecodeContext ctx, int nchunks) { + DType u64 = new DType.Primitive(PType.U64, false); + MemorySegment offsetsBuf = ctx.decodeChildSegment(0, u64, nchunks + 1L); + long cap = SegmentBroadcast.capacity(offsetsBuf, 8); + long[] offsets = new long[nchunks + 1]; + for (int i = 0; i <= nchunks; i++) { + offsets[i] = offsetsBuf.get(LE_LONG, (i % cap) * 8); + } + return offsets; + } + + private static Array concat(List<Array> chunks, DType dtype, long totalRows, SegmentAllocator arena) { + if (dtype instanceof DType.Primitive
pt) { + return concatPrimitive(chunks, pt, dtype, totalRows, arena); + } + if (dtype instanceof DType.Struct struct) { + return concatStruct(chunks, struct, totalRows, arena); + } + throw new VortexException(EncodingId.VORTEX_CHUNKED, + "concat not supported for dtype: " + dtype); + } + + private static Array concatPrimitive( + List<Array> chunks, DType.Primitive pt, DType dtype, long totalRows, SegmentAllocator arena + ) { + PType ptype = pt.ptype(); + MemorySegment combined = arena.allocate(totalRows * ptype.byteSize()); + long byteOffset = 0; + for (Array chunk : chunks) { + MemorySegment src = ArraySegments.of(chunk); + MemorySegment.copy(src, 0, combined, byteOffset, src.byteSize()); + byteOffset += src.byteSize(); + } + MemorySegment ro = combined.asReadOnly(); + return switch (ptype) { + case I64, U64 -> new LongArray(dtype, totalRows, ro); + case I32, U32 -> new IntArray(dtype, totalRows, ro); + case F64 -> new DoubleArray(dtype, totalRows, ro); + case F32 -> new FloatArray(dtype, totalRows, ro); + case I16, U16 -> new ShortArray(dtype, totalRows, ro); + case I8, U8 -> new ByteArray(dtype, totalRows, ro); + default -> throw new VortexException(EncodingId.VORTEX_CHUNKED, + "unsupported ptype for concat: " + ptype); + }; + } + + private static StructArray concatStruct( + List<Array> chunks, DType.Struct struct, long totalRows, SegmentAllocator arena + ) { + int nfields = struct.fieldTypes().size(); + List<Array> concatFields = new ArrayList<>(nfields); + for (int f = 0; f < nfields; f++) { + DType fieldDtype = struct.fieldTypes().get(f); + List<Array> fieldChunks = new ArrayList<>(chunks.size()); + for (Array chunk : chunks) { + fieldChunks.add(((StructArray) chunk).field(f)); + } + concatFields.add(concat(fieldChunks, fieldDtype, totalRows, arena)); + } + return new StructArray(struct, totalRows, concatFields); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ConstantEncodingDecoder.java
b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ConstantEncodingDecoder.java new file mode 100644 index 00000000..becb6d90 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ConstantEncodingDecoder.java @@ -0,0 +1,175 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.core.array.GenericArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.NullArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.ScalarValue; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; + +/// Read-only decoder for {@code vortex.constant}. +public final class ConstantEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public ConstantEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_CONSTANT; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive; + } + + @Override + public Array decode(DecodeContext ctx) { + MemorySegment scalarBuf = ctx.buffer(0); + ScalarValue scalar; + try { + scalar = ScalarValue.decode(scalarBuf, 0, scalarBuf.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_CONSTANT, "invalid scalar value", e); + } + + long n = ctx.rowCount(); + + if (ctx.dtype() instanceof DType.Null) { + return new NullArray(ctx.dtype(), n); + } + + if (ctx.dtype() instanceof DType.Utf8 || ctx.dtype() instanceof DType.Binary) { + return decodeString(ctx, scalar, n); + } + + if (ctx.dtype() instanceof DType.Bool) { + return decodeBool(ctx, scalar, n); + } + + if (ctx.dtype() instanceof DType.Decimal) { + return decodeDecimal(ctx, scalar, n); + } + + if (ctx.dtype() instanceof DType.Extension ext) { + var storageCtx = new DecodeContext(ctx.node(), ext.storageDType(), ctx.rowCount(), + ctx.segmentBuffers(), ctx.registry(), ctx.arena()); + Array storage = decode(storageCtx); + return new GenericArray(ctx.dtype(), n, ArraySegments.of(storage)); + } + + if (!(ctx.dtype() instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.VORTEX_CONSTANT, "unsupported dtype " + ctx.dtype()); + } + + PType ptype = p.ptype(); + int elemBytes = ptype.byteSize(); + long rawBits = scalarToRawBits(scalar, ptype); + + MemorySegment outSeg = ctx.arena().allocate(elemBytes); + ByteBuffer out = outSeg.asByteBuffer().order(ByteOrder.LITTLE_ENDIAN); + writeRaw(out, ptype, rawBits); + + MemorySegment ro = outSeg.asReadOnly(); + return switch (ptype) { + case I64, U64 -> new LongArray(ctx.dtype(), n, ro); + case I32, U32 -> new IntArray(ctx.dtype(), n, ro); + case F64 -> new DoubleArray(ctx.dtype(), n, ro); + case F32 -> new FloatArray(ctx.dtype(), n, ro); + case I16, U16 
-> new ShortArray(ctx.dtype(), n, ro); + case I8, U8 -> new ByteArray(ctx.dtype(), n, ro); + default -> throw new VortexException(EncodingId.VORTEX_CONSTANT, "unsupported ptype " + ptype); + }; + } + + private static Array decodeDecimal(DecodeContext ctx, ScalarValue scalar, long n) { + byte[] elemBytes = scalar.bytes_value(); + int elemLen = elemBytes.length; + MemorySegment outSeg = ctx.arena().allocate(n * elemLen); + MemorySegment elemSeg = MemorySegment.ofArray(elemBytes); + for (long i = 0; i < n; i++) { + MemorySegment.copy(elemSeg, 0L, outSeg, i * elemLen, elemLen); + } + return new GenericArray(ctx.dtype(), n, outSeg.asReadOnly()); + } + + private static Array decodeBool(DecodeContext ctx, ScalarValue scalar, long n) { + boolean value = scalar.bool_value() != null && scalar.bool_value(); + long numBytes = (n + 7) >>> 3; + MemorySegment seg = ctx.arena().allocate(numBytes); + if (value) { + for (long i = 0; i < numBytes; i++) { + seg.set(ValueLayout.JAVA_BYTE, i, (byte) 0xFF); + } + } + return new BoolArray(ctx.dtype(), n, seg.asReadOnly()); + } + + private static Array decodeString(DecodeContext ctx, ScalarValue scalar, long n) { + byte[] strBytes = scalar.string_value() != null + ? scalar.string_value().getBytes(StandardCharsets.UTF_8) + : (scalar.bytes_value() != null ? 
scalar.bytes_value() : new byte[0]); + + int strLen = strBytes.length; + + MemorySegment bytesSeg = ctx.arena().allocate((long) n * strLen); + for (long i = 0; i < n; i++) { + MemorySegment.copy(MemorySegment.ofArray(strBytes), 0L, bytesSeg, i * strLen, strLen); + } + + MemorySegment offsetsSeg = ctx.arena().allocate((n + 1) * 4L, 4); + for (long i = 0; i <= n; i++) { + offsetsSeg.setAtIndex(PTypeIO.LE_INT, i, (int) (i * strLen)); + } + + return new VarBinArray(ctx.dtype(), n, bytesSeg.asReadOnly(), offsetsSeg.asReadOnly(), PType.I32); + } + + private static long scalarToRawBits(ScalarValue scalar, PType ptype) { + if (scalar.int64_value() != null) { + return scalar.int64_value(); + } + if (scalar.uint64_value() != null) { + return scalar.uint64_value(); + } + if (scalar.f32_value() != null) { + return Float.floatToRawIntBits(scalar.f32_value()); + } + if (scalar.f64_value() != null) { + return Double.doubleToRawLongBits(scalar.f64_value()); + } + return 0L; + } + + private static void writeRaw(ByteBuffer buf, PType ptype, long rawBits) { + switch (ptype.byteSize()) { + case 1 -> buf.put((byte) rawBits); + case 2 -> buf.putShort((short) rawBits); + case 4 -> buf.putInt((int) rawBits); + case 8 -> buf.putLong(rawBits); + default -> throw new VortexException(EncodingId.VORTEX_CONSTANT, "unsupported ptype " + ptype); + } + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DateTimePartsEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DateTimePartsEncodingDecoder.java new file mode 100644 index 00000000..637fb01e --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DateTimePartsEncodingDecoder.java @@ -0,0 +1,60 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import 
io.github.dfa1.vortex.core.array.GenericArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.proto.DateTimePartsMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.datetimeparts}. +public final class DateTimePartsEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public DateTimePartsEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_DATETIMEPARTS; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Extension; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer meta = ctx.metadata(); + if (meta == null || meta.remaining() == 0) { + throw new VortexException(EncodingId.VORTEX_DATETIMEPARTS, "missing metadata"); + } + DateTimePartsMetadata decoded; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(meta.duplicate()); + decoded = DateTimePartsMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_DATETIMEPARTS, "invalid metadata: " + e.getMessage()); + } + + PType daysPtype = PType.fromOrdinal(decoded.days_ptype().value()); + PType secondsPtype = PType.fromOrdinal(decoded.seconds_ptype().value()); + PType subsecondsPtype = PType.fromOrdinal(decoded.subseconds_ptype().value()); + boolean nullable = ctx.dtype().nullable(); + + Array days = ctx.decodeChild(0, new DType.Primitive(daysPtype, nullable), ctx.rowCount()); + Array seconds = ctx.decodeChild(1, new DType.Primitive(secondsPtype, false), ctx.rowCount()); + Array subseconds = ctx.decodeChild(2, new DType.Primitive(subsecondsPtype, false), ctx.rowCount()); + + return new GenericArray(ctx.dtype(), ctx.rowCount(), new 
MemorySegment[0], + new Array[]{days, seconds, subseconds}); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DecimalBytePartsEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DecimalBytePartsEncodingDecoder.java new file mode 100644 index 00000000..51551b53 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DecimalBytePartsEncodingDecoder.java @@ -0,0 +1,66 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.GenericArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.proto.DecimalBytePartsMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.decimal_byte_parts}. +public final class DecimalBytePartsEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public DecimalBytePartsEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_DECIMAL_BYTE_PARTS; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Decimal; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer meta = ctx.metadata(); + if (meta == null || meta.remaining() == 0) { + throw new VortexException(EncodingId.VORTEX_DECIMAL_BYTE_PARTS, "missing metadata"); + } + DecimalBytePartsMetadata decoded; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(meta.duplicate()); + decoded = DecimalBytePartsMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_DECIMAL_BYTE_PARTS, "invalid metadata: " + e.getMessage()); + } + + int lowerPartCount = decoded.lower_part_count(); + if (lowerPartCount != 0) { + throw new VortexException(EncodingId.VORTEX_DECIMAL_BYTE_PARTS, + "lower_part_count > 0 not supported, got " + lowerPartCount); + } + + PType mspPtype = PType.fromOrdinal(decoded.zeroth_child_ptype().value()); + boolean nullable = ctx.dtype().nullable(); + DType mspDtype = new DType.Primitive(mspPtype, nullable); + ArrayNode mspNode = ctx.node().children()[0]; + DecodeContext mspCtx = new DecodeContext( + mspNode, mspDtype, ctx.rowCount(), + ctx.segmentBuffers(), ctx.registry(), ctx.arena()); + Array mspArray = ctx.registry().decode(mspCtx); + return new GenericArray(ctx.dtype(), ctx.rowCount(), new MemorySegment[0], + new Array[]{mspArray}); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DecimalEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DecimalEncodingDecoder.java new file mode 100644 index 00000000..d9fc1d41 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DecimalEncodingDecoder.java @@ -0,0 +1,69 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import 
io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.GenericArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.proto.DecimalMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.decimal}. +public final class DecimalEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public DecimalEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_DECIMAL; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Decimal; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer meta = ctx.metadata(); + if (meta == null || meta.remaining() == 0) { + throw new VortexException(EncodingId.VORTEX_DECIMAL, "missing metadata"); + } + DecimalMetadata decoded; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(meta.duplicate()); + decoded = DecimalMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_DECIMAL, "invalid metadata: " + e.getMessage()); + } + int valuesType = decoded.values_type(); + int byteWidth = decimalTypeByteWidth(valuesType); + MemorySegment buffer = ctx.buffer(0); + long expected = ctx.rowCount() * byteWidth; + if (buffer.byteSize() < expected) { + throw new VortexException(EncodingId.VORTEX_DECIMAL, + "buffer too small: expected %d bytes, got %d".formatted(expected, buffer.byteSize())); + } + return new GenericArray(ctx.dtype(), ctx.rowCount(), buffer); + } + + private static int decimalTypeByteWidth(int valuesType) { + return switch (valuesType) { + case 0 -> 1; + case 1 -> 2; + case 2 -> 4; + case 3 -> 8; 
+ case 4 -> 16; + case 5 -> 32; + default -> throw new VortexException(EncodingId.VORTEX_DECIMAL, + "unknown DecimalType value: " + valuesType); + }; + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DeltaEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DeltaEncodingDecoder.java new file mode 100644 index 00000000..a5c1abd0 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DeltaEncodingDecoder.java @@ -0,0 +1,154 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.DeltaEncoding; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; +import io.github.dfa1.vortex.proto.DeltaMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code fastlanes.delta}. +public final class DeltaEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public DeltaEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.FASTLANES_DELTA; + } + + @Override + public boolean accepts(DType dtype) { + if (!(dtype instanceof DType.Primitive p)) { + return false; + } + return switch (p.ptype()) { + case I8, I16, I32, I64, U8, U16, U32, U64 -> true; + default -> false; + }; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + DeltaMetadata meta; + if (rawMeta == null || !rawMeta.hasRemaining()) { + meta = new DeltaMetadata(0L, 0); + } else { + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + meta = DeltaMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.FASTLANES_DELTA, "invalid metadata", e); + } + } + + PType ptype = ((DType.Primitive) ctx.dtype()).ptype(); + long rowCount = ctx.rowCount(); + int typeBits = DeltaEncoding.typeBits(ptype); + int lanes = DeltaEncoding.lanes(ptype); + long mask = DeltaEncoding.typeMask(ptype); + + long deltasLen = meta.deltas_len(); + int offset = meta.offset(); + + if (deltasLen == 0L) { + MemorySegment empty = ctx.arena().allocate(0); + return switch (ptype) { + case I64, U64 -> new LongArray(ctx.dtype(), 0L, empty); + case I32, U32 -> new IntArray(ctx.dtype(), 0L, empty); + case I16, U16 -> new ShortArray(ctx.dtype(), 0L, empty); + case I8, U8 -> new ByteArray(ctx.dtype(), 0L, empty); + default -> throw new VortexException(EncodingId.FASTLANES_DELTA, "unsupported ptype: " + ptype); + }; + } + + long basesLen = (deltasLen / DeltaEncoding.FL_CHUNK_SIZE) * lanes; + DType dtype = ctx.dtype(); + + long[] basesAll = readLongs(ctx.decodeChildSegment(0, dtype, basesLen), (int) basesLen, ptype); + long[] deltasAll = readLongs(ctx.decodeChildSegment(1, dtype, deltasLen), (int) deltasLen, ptype); + + int numChunks = (int) (deltasLen / DeltaEncoding.FL_CHUNK_SIZE); + long[] decoded = new long[(int) deltasLen]; + long[] 
untransposedChunk = new long[DeltaEncoding.FL_CHUNK_SIZE]; + long[] chunkBases = new long[lanes]; + long[] chunkDeltas = new long[DeltaEncoding.FL_CHUNK_SIZE]; + long[] chunkUndelta = new long[DeltaEncoding.FL_CHUNK_SIZE]; + + for (int chunk = 0; chunk < numChunks; chunk++) { + int basesOff = chunk * lanes; + int deltaOff = chunk * DeltaEncoding.FL_CHUNK_SIZE; + + System.arraycopy(basesAll, basesOff, chunkBases, 0, lanes); + System.arraycopy(deltasAll, deltaOff, chunkDeltas, 0, DeltaEncoding.FL_CHUNK_SIZE); + + undeltaChunk(chunkDeltas, chunkBases, lanes, typeBits, mask, chunkUndelta); + + for (int i = 0; i < DeltaEncoding.FL_CHUNK_SIZE; i++) { + untransposedChunk[DeltaEncoding.transposeIndex(i)] = chunkUndelta[i]; + } + System.arraycopy(untransposedChunk, 0, decoded, deltaOff, DeltaEncoding.FL_CHUNK_SIZE); + } + + long[] result = new long[(int) rowCount]; + System.arraycopy(decoded, offset, result, 0, (int) rowCount); + + MemorySegment seg = DeltaEncoding.fromLongs(result, ptype, ctx.arena()); + return switch (ptype) { + case I64, U64 -> new LongArray(ctx.dtype(), rowCount, seg); + case I32, U32 -> new IntArray(ctx.dtype(), rowCount, seg); + case I16, U16 -> new ShortArray(ctx.dtype(), rowCount, seg); + case I8, U8 -> new ByteArray(ctx.dtype(), rowCount, seg); + default -> throw new VortexException(EncodingId.FASTLANES_DELTA, "unsupported ptype: " + ptype); + }; + } + + private static void undeltaChunk(long[] deltas, long[] bases, int lanes, int typeBits, long mask, long[] out) { + for (int lane = 0; lane < lanes; lane++) { + long prev = bases[lane] & mask; + for (int row = 0; row < typeBits; row++) { + int idx = DeltaEncoding.iterateIndex(row, lane); + long next = ((deltas[idx] & mask) + prev) & mask; + out[idx] = next; + prev = next; + } + } + } + + private static long[] readLongs(MemorySegment buf, int count, PType ptype) { + long[] out = new long[count]; + int elemSize = ptype.byteSize(); + long cap = SegmentBroadcast.capacity(buf, elemSize); + for (int i = 0; 
i < count; i++) { + long off = (i % cap) * elemSize; + out[i] = switch (ptype) { + case I8 -> buf.get(ValueLayout.JAVA_BYTE, off); + case U8 -> Byte.toUnsignedLong(buf.get(ValueLayout.JAVA_BYTE, off)); + case I16 -> buf.get(PTypeIO.LE_SHORT, off); + case U16 -> Short.toUnsignedLong(buf.get(PTypeIO.LE_SHORT, off)); + case I32 -> buf.get(PTypeIO.LE_INT, off); + case U32 -> Integer.toUnsignedLong(buf.get(PTypeIO.LE_INT, off)); + case I64, U64 -> buf.get(PTypeIO.LE_LONG, off); + default -> throw new VortexException(EncodingId.FASTLANES_DELTA, "unsupported ptype: " + ptype); + }; + } + return out; + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DictEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DictEncodingDecoder.java new file mode 100644 index 00000000..58090b84 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/DictEncodingDecoder.java @@ -0,0 +1,404 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; +import io.github.dfa1.vortex.proto.DictMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import 
java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.dict}. +public final class DictEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public DictEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_DICT; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive || dtype instanceof DType.Utf8; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer meta = ctx.metadata(); + + if (ctx.dtype() instanceof DType.Utf8) { + if (ctx.node().children().length == 0) { + if (meta == null || !meta.hasRemaining()) { + throw new VortexException(EncodingId.VORTEX_DICT, "missing metadata for legacy utf8 dict"); + } + return decodeUtf8DictLegacy(ctx, meta); + } + if (meta == null || !meta.hasRemaining()) { + throw new VortexException(EncodingId.VORTEX_DICT, "missing metadata for utf8 dict"); + } + return decodeUtf8DictProto(ctx, meta.duplicate()); + } + + if (meta == null || !meta.hasRemaining()) { + throw new VortexException(EncodingId.VORTEX_DICT, "missing metadata"); + } + + if (meta.remaining() == 1) { + return decodeLegacyJava(ctx, meta.get(0)); + } + return decodeRustProto(ctx, meta.duplicate()); + } + + private static Array decodeLegacyJava(DecodeContext ctx, byte codeTypeByte) { + PType codePType = PType.fromOrdinal(Byte.toUnsignedInt(codeTypeByte)); + PType valPType = ((DType.Primitive) ctx.dtype()).ptype(); + int elemSize = valPType.byteSize(); + long rowCount = ctx.rowCount(); + + MemorySegment valuesBuf = ctx.segmentBuffers()[ctx.node().children()[0].bufferIndices()[0]]; + + DType codesDtype = new DType.Primitive(codePType, false); + MemorySegment codesBuf = ctx.decodeChildSegment(1, codesDtype, rowCount); + + MemorySegment out = ctx.arena().allocate(rowCount * (long) elemSize); + switch (codePType) { + case U8 -> expandU8(codesBuf, valuesBuf, out, rowCount, elemSize); + case U16 
-> expandU16(codesBuf, valuesBuf, out, rowCount, elemSize); + case U32 -> expandU32(codesBuf, valuesBuf, out, rowCount, elemSize); + default -> { + for (long i = 0; i < rowCount; i++) { + long code = readCode(codesBuf, codePType, i); + MemorySegment.copy(valuesBuf, code * elemSize, out, i * elemSize, elemSize); + } + } + } + return typedArray(ctx.dtype(), valPType, rowCount, out.asReadOnly()); + } + + private static Array decodeRustProto(DecodeContext ctx, ByteBuffer metaBuf) { + DictMetadata meta; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(metaBuf); + meta = DictMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_DICT, "invalid proto metadata", e); + } + + PType codePType = PType.fromOrdinal(meta.codes_ptype().value()); + long valuesLen = meta.values_len(); + long rowCount = ctx.rowCount(); + PType valPType = ((DType.Primitive) ctx.dtype()).ptype(); + int elemSize = valPType.byteSize(); + + DType codesDtype = new DType.Primitive(codePType, false); + MemorySegment codesBuf = ctx.decodeChildSegment(0, codesDtype, rowCount); + MemorySegment valuesBuf = ctx.decodeChildSegment(1, ctx.dtype(), valuesLen); + + MemorySegment out = ctx.arena().allocate(rowCount * (long) elemSize); + switch (codePType) { + case U8 -> expandU8(codesBuf, valuesBuf, out, rowCount, elemSize); + case U16 -> expandU16(codesBuf, valuesBuf, out, rowCount, elemSize); + case U32 -> expandU32(codesBuf, valuesBuf, out, rowCount, elemSize); + default -> throw new VortexException(EncodingId.VORTEX_DICT, "unexpected code type: " + codePType); + } + return typedArray(ctx.dtype(), valPType, rowCount, out.asReadOnly()); + } + + private static Array decodeUtf8DictLegacy(DecodeContext ctx, ByteBuffer meta) { + PType codePType = PType.fromOrdinal(Byte.toUnsignedInt(meta.get(0))); + long n = ctx.rowCount(); + + MemorySegment dictBytes = ctx.buffer(0); + MemorySegment dictOffsets = ctx.buffer(1); + MemorySegment codes = 
ctx.buffer(2); + + return VarBinArray.ofDict(ctx.dtype(), n, + dictBytes, dictOffsets, PType.I64, + codes, codePType); + } + + private static Array decodeUtf8DictProto(DecodeContext ctx, ByteBuffer metaBuf) { + DictMetadata meta; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(metaBuf); + meta = DictMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_DICT, "invalid utf8 dict proto metadata", e); + } + PType codePType = PType.fromOrdinal(meta.codes_ptype().value()); + long dictSize = meta.values_len(); + long n = ctx.rowCount(); + + DType codesDtype = new DType.Primitive(codePType, false); + MemorySegment codesBuf = ctx.decodeChildSegment(0, codesDtype, n); + + Array valuesArr = ctx.decodeChild(1, ctx.dtype(), dictSize); + VarBinArray varBinValues = (VarBinArray) valuesArr; + MemorySegment dictBytes = varBinValues.bytesSegment(); + MemorySegment dictOffsets = varBinValues.offsetsSegment(); + + return VarBinArray.ofDict(ctx.dtype(), n, + dictBytes, dictOffsets, PType.I64, + codesBuf, codePType); + } + + private static long readCode(MemorySegment buf, PType codePType, long i) { + long cap = SegmentBroadcast.capacity(buf, codePType.byteSize()); + long idx = i % cap; + return switch (codePType) { + case U8 -> Byte.toUnsignedLong(buf.get(ValueLayout.JAVA_BYTE, idx)); + case U16 -> Short.toUnsignedLong(buf.get(PTypeIO.LE_SHORT, idx * 2)); + case U32 -> Integer.toUnsignedLong(buf.get(PTypeIO.LE_INT, idx * 4)); + default -> throw new VortexException(EncodingId.VORTEX_DICT, "unexpected code type: " + codePType); + }; + } + + private static void expandU8(MemorySegment codes, MemorySegment values, MemorySegment out, long rowCount, int elemSize) { + long codesCap = SegmentBroadcast.capacity(codes, 1); + long valuesCap = SegmentBroadcast.capacity(values, elemSize); + boolean fast = codesCap >= rowCount && valuesCap > 1; + switch (elemSize) { + case 8 -> { + if (fast) { + for (long i = 0; i < 
rowCount; i++) { + long code = Byte.toUnsignedLong(codes.get(ValueLayout.JAVA_BYTE, i)); + out.setAtIndex(PTypeIO.LE_LONG, i, values.getAtIndex(PTypeIO.LE_LONG, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Byte.toUnsignedLong(codes.get(ValueLayout.JAVA_BYTE, i % codesCap)); + out.setAtIndex(PTypeIO.LE_LONG, i, values.getAtIndex(PTypeIO.LE_LONG, code % valuesCap)); + } + } + } + case 4 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = Byte.toUnsignedLong(codes.get(ValueLayout.JAVA_BYTE, i)); + out.setAtIndex(PTypeIO.LE_INT, i, values.getAtIndex(PTypeIO.LE_INT, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Byte.toUnsignedLong(codes.get(ValueLayout.JAVA_BYTE, i % codesCap)); + out.setAtIndex(PTypeIO.LE_INT, i, values.getAtIndex(PTypeIO.LE_INT, code % valuesCap)); + } + } + } + case 2 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = Byte.toUnsignedLong(codes.get(ValueLayout.JAVA_BYTE, i)); + out.setAtIndex(PTypeIO.LE_SHORT, i, values.getAtIndex(PTypeIO.LE_SHORT, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Byte.toUnsignedLong(codes.get(ValueLayout.JAVA_BYTE, i % codesCap)); + out.setAtIndex(PTypeIO.LE_SHORT, i, values.getAtIndex(PTypeIO.LE_SHORT, code % valuesCap)); + } + } + } + case 1 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = Byte.toUnsignedLong(codes.get(ValueLayout.JAVA_BYTE, i)); + out.set(ValueLayout.JAVA_BYTE, i, values.get(ValueLayout.JAVA_BYTE, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Byte.toUnsignedLong(codes.get(ValueLayout.JAVA_BYTE, i % codesCap)); + out.set(ValueLayout.JAVA_BYTE, i, values.get(ValueLayout.JAVA_BYTE, code % valuesCap)); + } + } + } + default -> { + if (fast) { + for (long i = 0, outOff = 0; i < rowCount; i++, outOff += elemSize) { + long code = Byte.toUnsignedLong(codes.get(ValueLayout.JAVA_BYTE, i)); + MemorySegment.copy(values, code 
* elemSize, out, outOff, elemSize); + } + } else { + for (long i = 0, outOff = 0; i < rowCount; i++, outOff += elemSize) { + long code = Byte.toUnsignedLong(codes.get(ValueLayout.JAVA_BYTE, i % codesCap)); + MemorySegment.copy(values, (code % valuesCap) * elemSize, out, outOff, elemSize); + } + } + } + } + } + + private static void expandU16(MemorySegment codes, MemorySegment values, MemorySegment out, long rowCount, int elemSize) { + long codesCap = SegmentBroadcast.capacity(codes, 2); + long valuesCap = SegmentBroadcast.capacity(values, elemSize); + boolean fast = codesCap >= rowCount && valuesCap > 1; + switch (elemSize) { + case 8 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = Short.toUnsignedLong(codes.get(PTypeIO.LE_SHORT, i * 2)); + out.setAtIndex(PTypeIO.LE_LONG, i, values.getAtIndex(PTypeIO.LE_LONG, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Short.toUnsignedLong(codes.get(PTypeIO.LE_SHORT, (i % codesCap) * 2)); + out.setAtIndex(PTypeIO.LE_LONG, i, values.getAtIndex(PTypeIO.LE_LONG, code % valuesCap)); + } + } + } + case 4 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = Short.toUnsignedLong(codes.get(PTypeIO.LE_SHORT, i * 2)); + out.setAtIndex(PTypeIO.LE_INT, i, values.getAtIndex(PTypeIO.LE_INT, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Short.toUnsignedLong(codes.get(PTypeIO.LE_SHORT, (i % codesCap) * 2)); + out.setAtIndex(PTypeIO.LE_INT, i, values.getAtIndex(PTypeIO.LE_INT, code % valuesCap)); + } + } + } + case 2 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = Short.toUnsignedLong(codes.get(PTypeIO.LE_SHORT, i * 2)); + out.setAtIndex(PTypeIO.LE_SHORT, i, values.getAtIndex(PTypeIO.LE_SHORT, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Short.toUnsignedLong(codes.get(PTypeIO.LE_SHORT, (i % codesCap) * 2)); + out.setAtIndex(PTypeIO.LE_SHORT, i, values.getAtIndex(PTypeIO.LE_SHORT, code 
% valuesCap)); + } + } + } + case 1 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = Short.toUnsignedLong(codes.get(PTypeIO.LE_SHORT, i * 2)); + out.set(ValueLayout.JAVA_BYTE, i, values.get(ValueLayout.JAVA_BYTE, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Short.toUnsignedLong(codes.get(PTypeIO.LE_SHORT, (i % codesCap) * 2)); + out.set(ValueLayout.JAVA_BYTE, i, values.get(ValueLayout.JAVA_BYTE, code % valuesCap)); + } + } + } + default -> { + if (fast) { + for (long i = 0, outOff = 0; i < rowCount; i++, outOff += elemSize) { + long code = Short.toUnsignedLong(codes.get(PTypeIO.LE_SHORT, i * 2)); + MemorySegment.copy(values, code * elemSize, out, outOff, elemSize); + } + } else { + for (long i = 0, outOff = 0; i < rowCount; i++, outOff += elemSize) { + long code = Short.toUnsignedLong(codes.get(PTypeIO.LE_SHORT, (i % codesCap) * 2)); + MemorySegment.copy(values, (code % valuesCap) * elemSize, out, outOff, elemSize); + } + } + } + } + } + + private static void expandU32(MemorySegment codes, MemorySegment values, MemorySegment out, long rowCount, int elemSize) { + long codesCap = SegmentBroadcast.capacity(codes, 4); + long valuesCap = SegmentBroadcast.capacity(values, elemSize); + boolean fast = codesCap >= rowCount && valuesCap > 1; + switch (elemSize) { + case 8 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = Integer.toUnsignedLong(codes.get(PTypeIO.LE_INT, i * 4)); + out.setAtIndex(PTypeIO.LE_LONG, i, values.getAtIndex(PTypeIO.LE_LONG, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Integer.toUnsignedLong(codes.get(PTypeIO.LE_INT, (i % codesCap) * 4)); + out.setAtIndex(PTypeIO.LE_LONG, i, values.getAtIndex(PTypeIO.LE_LONG, code % valuesCap)); + } + } + } + case 4 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = Integer.toUnsignedLong(codes.get(PTypeIO.LE_INT, i * 4)); + out.setAtIndex(PTypeIO.LE_INT, i, 
values.getAtIndex(PTypeIO.LE_INT, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Integer.toUnsignedLong(codes.get(PTypeIO.LE_INT, (i % codesCap) * 4)); + out.setAtIndex(PTypeIO.LE_INT, i, values.getAtIndex(PTypeIO.LE_INT, code % valuesCap)); + } + } + } + case 2 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = Integer.toUnsignedLong(codes.get(PTypeIO.LE_INT, i * 4)); + out.setAtIndex(PTypeIO.LE_SHORT, i, values.getAtIndex(PTypeIO.LE_SHORT, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Integer.toUnsignedLong(codes.get(PTypeIO.LE_INT, (i % codesCap) * 4)); + out.setAtIndex(PTypeIO.LE_SHORT, i, values.getAtIndex(PTypeIO.LE_SHORT, code % valuesCap)); + } + } + } + case 1 -> { + if (fast) { + for (long i = 0; i < rowCount; i++) { + long code = Integer.toUnsignedLong(codes.get(PTypeIO.LE_INT, i * 4)); + out.set(ValueLayout.JAVA_BYTE, i, values.get(ValueLayout.JAVA_BYTE, code)); + } + } else { + for (long i = 0; i < rowCount; i++) { + long code = Integer.toUnsignedLong(codes.get(PTypeIO.LE_INT, (i % codesCap) * 4)); + out.set(ValueLayout.JAVA_BYTE, i, values.get(ValueLayout.JAVA_BYTE, code % valuesCap)); + } + } + } + default -> { + if (fast) { + for (long i = 0, outOff = 0; i < rowCount; i++, outOff += elemSize) { + long code = Integer.toUnsignedLong(codes.get(PTypeIO.LE_INT, i * 4)); + MemorySegment.copy(values, code * elemSize, out, outOff, elemSize); + } + } else { + for (long i = 0, outOff = 0; i < rowCount; i++, outOff += elemSize) { + long code = Integer.toUnsignedLong(codes.get(PTypeIO.LE_INT, (i % codesCap) * 4)); + MemorySegment.copy(values, (code % valuesCap) * elemSize, out, outOff, elemSize); + } + } + } + } + } + + private static Array typedArray(DType dtype, PType ptype, long n, MemorySegment seg) { + return switch (ptype) { + case I64, U64 -> new LongArray(dtype, n, seg); + case I32, U32 -> new IntArray(dtype, n, seg); + case F64 -> new DoubleArray(dtype, n, seg); + case 
F32 -> new FloatArray(dtype, n, seg); + case I16, U16 -> new ShortArray(dtype, n, seg); + case I8, U8 -> new ByteArray(dtype, n, seg); + default -> throw new VortexException(EncodingId.VORTEX_DICT, "unsupported ptype " + ptype); + }; + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ExtEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ExtEncodingDecoder.java new file mode 100644 index 00000000..efb7f308 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ExtEncodingDecoder.java @@ -0,0 +1,40 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; + +/// Read-only decoder for {@code vortex.ext} — unwraps the storage-array child. +public final class ExtEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public ExtEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_EXT; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Extension; + } + + @Override + public Array decode(DecodeContext ctx) { + if (!(ctx.dtype() instanceof DType.Extension ext)) { + throw new VortexException(EncodingId.VORTEX_EXT, "expected extension dtype, got " + ctx.dtype()); + } + long n = ctx.rowCount(); + ArrayNode childNode = ctx.node().children()[0]; + DecodeContext childCtx = new DecodeContext( + childNode, ext.storageDType(), n, + ctx.segmentBuffers(), ctx.registry(), ctx.arena()); + return ctx.registry().decode(childCtx); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/FixedSizeListEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/FixedSizeListEncodingDecoder.java new file mode 100644 index 00000000..693346f6 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/FixedSizeListEncodingDecoder.java @@ -0,0 +1,54 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.FixedSizeListArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; + +/// Read-only decoder for {@code vortex.fixed_size_list}. +public final class FixedSizeListEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public FixedSizeListEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_FIXED_SIZE_LIST; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.FixedSizeList; + } + + @Override + public Array decode(DecodeContext ctx) { + if (!(ctx.dtype() instanceof DType.FixedSizeList fsl)) { + throw new VortexException(EncodingId.VORTEX_FIXED_SIZE_LIST, + "expected DType.FixedSizeList, got " + ctx.dtype()); + } + + int nchildren = ctx.node().children().length; + if (nchildren < 1 || nchildren > 2) { + throw new VortexException(EncodingId.VORTEX_FIXED_SIZE_LIST, + "expected 1 or 2 children, got " + nchildren); + } + + long outerLen = ctx.rowCount(); + long elemLen = outerLen * fsl.fixedSize(); + DType elementType = fsl.elementType(); + + ArrayNode elemNode = ctx.node().children()[0]; + var elemCtx = new DecodeContext( + elemNode, elementType, elemLen, + ctx.segmentBuffers(), ctx.registry(), ctx.arena()); + Array elements = ctx.registry().decode(elemCtx); + + return new FixedSizeListArray(fsl, outerLen, elements); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/FrameOfReferenceEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/FrameOfReferenceEncodingDecoder.java new file mode 100644 index 00000000..03d4feac --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/FrameOfReferenceEncodingDecoder.java @@ -0,0 +1,132 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.IntArray; +import 
io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.ScalarValue; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SegmentAllocator; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code fastlanes.for} (Frame of Reference). +public final class FrameOfReferenceEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public FrameOfReferenceEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.FASTLANES_FOR; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive p && !p.ptype().isFloating(); + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + if (rawMeta == null || !rawMeta.hasRemaining()) { + throw new VortexException(EncodingId.FASTLANES_FOR, "missing metadata"); + } + ScalarValue scalar; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + scalar = ScalarValue.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.FASTLANES_FOR, "invalid metadata", e); + } + + Array encoded = ctx.decodeChild(0); + + BoolArray validity = null; + Array rawEncoded = encoded; + if (encoded instanceof MaskedArray masked) { + rawEncoded = masked.inner(); + validity = masked.validity(); + } + + if (!(ctx.dtype() instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.FASTLANES_FOR, "expected primitive dtype, got " + ctx.dtype()); + } + + long ref = 
referenceValue(scalar); + if (ref == 0L) { + return validity != null ? new MaskedArray(rawEncoded, validity) : rawEncoded; + } + + MemorySegment src = ArraySegments.of(rawEncoded); + long n = ctx.rowCount(); + MemorySegment dst = applyReference(src, n, p.ptype(), ref, ctx.arena()); + Array result = switch (p.ptype()) { + case I64, U64 -> new LongArray(ctx.dtype(), n, dst); + case I32, U32 -> new IntArray(ctx.dtype(), n, dst); + case F64 -> new DoubleArray(ctx.dtype(), n, dst); + case I16, U16 -> new ShortArray(ctx.dtype(), n, dst); + case I8, U8 -> new ByteArray(ctx.dtype(), n, dst); + default -> throw new VortexException(EncodingId.FASTLANES_FOR, "unsupported ptype " + p.ptype()); + }; + return validity != null ? new MaskedArray(result, validity) : result; + } + + private static long referenceValue(ScalarValue scalar) { + if (scalar.int64_value() != null) { + return scalar.int64_value(); + } + if (scalar.uint64_value() != null) { + return scalar.uint64_value(); + } + return 0L; + } + + private static MemorySegment applyReference(MemorySegment src, long n, PType ptype, long ref, SegmentAllocator arena) { + int wordBytes = ptype.byteSize(); + MemorySegment dst = arena.allocate(n * wordBytes); + switch (ptype) { + case I8, U8 -> { + for (long off = 0, end = n; off < end; off++) { + byte v = src.get(ValueLayout.JAVA_BYTE, off); + dst.set(ValueLayout.JAVA_BYTE, off, (byte) (v + (byte) ref)); + } + } + case I16, U16 -> { + for (long off = 0, end = n * 2; off < end; off += 2) { + short v = src.get(PTypeIO.LE_SHORT, off); + dst.set(PTypeIO.LE_SHORT, off, (short) (v + (short) ref)); + } + } + case I32, U32 -> { + for (long off = 0, end = n * 4; off < end; off += 4) { + int v = src.get(PTypeIO.LE_INT, off); + dst.set(PTypeIO.LE_INT, off, v + (int) ref); + } + } + case I64, U64 -> { + for (long off = 0, end = n * 8; off < end; off += 8) { + long v = src.get(PTypeIO.LE_LONG, off); + dst.set(PTypeIO.LE_LONG, off, v + ref); + } + } + default -> throw new 
VortexException(EncodingId.FASTLANES_FOR, "unsupported ptype " + ptype); + } + return dst; + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/FsstEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/FsstEncodingDecoder.java new file mode 100644 index 00000000..fef40200 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/FsstEncodingDecoder.java @@ -0,0 +1,117 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; +import io.github.dfa1.vortex.proto.FSSTMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.fsst}. +public final class FsstEncodingDecoder implements EncodingDecoder { + + private static final int ESCAPE = 0xFF; + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public FsstEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_FSST; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Utf8 || dtype instanceof DType.Binary; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + if (rawMeta == null) { + throw new VortexException(EncodingId.VORTEX_FSST, "missing metadata"); + } + FSSTMetadata meta; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + meta = FSSTMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_FSST, "invalid metadata", e); + } + + PType uncompLenPType = PType.fromOrdinal(meta.uncompressed_lengths_ptype().value()); + PType codesOffPType = PType.fromOrdinal(meta.codes_offsets_ptype().value()); + + long n = ctx.rowCount(); + + MemorySegment symbolsBuf = ctx.buffer(0); + MemorySegment symbolLensBuf = ctx.buffer(1); + MemorySegment compressedBytes = ctx.buffer(2); + + MemorySegment uncompLensSeg = ctx.decodeChildSegment(0, new DType.Primitive(uncompLenPType, false), n); + MemorySegment codesOffsetsSeg = ctx.decodeChildSegment(1, new DType.Primitive(codesOffPType, false), n + 1); + long uncompLensCap = SegmentBroadcast.capacity(uncompLensSeg, uncompLenPType.byteSize()); + long codesOffCap = SegmentBroadcast.capacity(codesOffsetsSeg, codesOffPType.byteSize()); + + long totalUncompressed = 0L; + for (long i = 0; i < n; i++) { + totalUncompressed += readUnsigned(uncompLensSeg, i % uncompLensCap, uncompLenPType); + } + + MemorySegment outBytes = ctx.arena().allocate(totalUncompressed); + MemorySegment outOffsets = ctx.arena().allocate((n + 1) * 4L, 4); + outOffsets.setAtIndex(PTypeIO.LE_INT, 0, 0); + + long outPos = 0L; + for (long i = 0; i < n; i++) { + long cStart = readUnsigned(codesOffsetsSeg, i % codesOffCap, codesOffPType); + long cEnd = readUnsigned(codesOffsetsSeg, (i + 1) % 
codesOffCap, codesOffPType); + outPos = decompressString(compressedBytes, symbolsBuf, symbolLensBuf, + cStart, cEnd, outBytes, outPos); + outOffsets.setAtIndex(PTypeIO.LE_INT, i + 1, (int) outPos); + } + + return new VarBinArray(ctx.dtype(), n, outBytes.asReadOnly(), outOffsets.asReadOnly(), PType.I32); + } + + private static long decompressString( + MemorySegment compressed, MemorySegment symbols, MemorySegment symLens, + long start, long end, MemorySegment out, long outPos + ) { + for (long j = start; j < end; j++) { + int b = Byte.toUnsignedInt(compressed.get(ValueLayout.JAVA_BYTE, j)); + if (b == ESCAPE) { + out.set(ValueLayout.JAVA_BYTE, outPos++, compressed.get(ValueLayout.JAVA_BYTE, ++j)); + } else { + int symLen = Byte.toUnsignedInt(symLens.get(ValueLayout.JAVA_BYTE, b)); + long sym = symbols.getAtIndex(PTypeIO.LE_LONG, b); + for (int k = 0; k < symLen; k++) { + out.set(ValueLayout.JAVA_BYTE, outPos++, (byte) (sym >>> (k * 8))); + } + } + } + return outPos; + } + + private static long readUnsigned(MemorySegment seg, long idx, PType ptype) { + return switch (ptype) { + case U8 -> Byte.toUnsignedLong(seg.get(ValueLayout.JAVA_BYTE, idx)); + case U16 -> Short.toUnsignedLong(seg.get(PTypeIO.LE_SHORT, idx * 2)); + case U32 -> Integer.toUnsignedLong(seg.getAtIndex(PTypeIO.LE_INT, idx)); + case I32 -> seg.getAtIndex(PTypeIO.LE_INT, idx); + case I64, U64 -> seg.getAtIndex(PTypeIO.LE_LONG, idx); + default -> throw new VortexException(EncodingId.VORTEX_FSST, "unsupported ptype " + ptype); + }; + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ListEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ListEncodingDecoder.java new file mode 100644 index 00000000..5523030a --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ListEncodingDecoder.java @@ -0,0 +1,66 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; 
+import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ListArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.proto.ListMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; + +/// Read-only decoder for {@code vortex.list}. +public final class ListEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public ListEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_LIST; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.List; + } + + @Override + public Array decode(DecodeContext ctx) { + if (!(ctx.dtype() instanceof DType.List listDtype)) { + throw new VortexException(EncodingId.VORTEX_LIST, + "expected DType.List, got " + ctx.dtype()); + } + + int nchildren = ctx.node().children().length; + if (nchildren < 2 || nchildren > 3) { + throw new VortexException(EncodingId.VORTEX_LIST, + "expected 2 or 3 children, got " + nchildren); + } + + ListMetadata meta; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(ctx.metadata().duplicate()); + meta = ListMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_LIST, "invalid metadata", e); + } + + long elementsLen = meta.elements_len(); + PType offsetPtype = PType.fromOrdinal(meta.offset_ptype().value()); + long outerLen = ctx.rowCount(); + + DType elementDtype = listDtype.elementType(); + DType offsetsDtype = new DType.Primitive(offsetPtype, false); + + Array elements = ctx.decodeChild(0, elementDtype, elementsLen); + Array offsets = ctx.decodeChild(1, offsetsDtype, outerLen + 1); + + return new ListArray(listDtype, outerLen, elements, 
offsets); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ListViewEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ListViewEncodingDecoder.java new file mode 100644 index 00000000..5293e19a --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ListViewEncodingDecoder.java @@ -0,0 +1,69 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ListViewArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.proto.ListViewMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; + +/// Read-only decoder for {@code vortex.listview}. +public final class ListViewEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public ListViewEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_LISTVIEW; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.List; + } + + @Override + public Array decode(DecodeContext ctx) { + if (!(ctx.dtype() instanceof DType.List listDtype)) { + throw new VortexException(EncodingId.VORTEX_LISTVIEW, + "expected DType.List, got " + ctx.dtype()); + } + + int nchildren = ctx.node().children().length; + if (nchildren < 3 || nchildren > 4) { + throw new VortexException(EncodingId.VORTEX_LISTVIEW, + "expected 3 or 4 children, got " + nchildren); + } + + ListViewMetadata meta; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(ctx.metadata().duplicate()); + meta = ListViewMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_LISTVIEW, "invalid metadata", e); + } + + long elementsLen = meta.elements_len(); + PType offsetPtype = PType.fromOrdinal(meta.offset_ptype().value()); + PType sizePtype = PType.fromOrdinal(meta.size_ptype().value()); + long outerLen = ctx.rowCount(); + + DType elementDtype = listDtype.elementType(); + DType offsetsDtype = new DType.Primitive(offsetPtype, false); + DType sizesDtype = new DType.Primitive(sizePtype, false); + + Array elements = ctx.decodeChild(0, elementDtype, elementsLen); + Array offsets = ctx.decodeChild(1, offsetsDtype, outerLen); + Array sizes = ctx.decodeChild(2, sizesDtype, outerLen); + + return new ListViewArray(listDtype, outerLen, elements, offsets, sizes); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/MaskedEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/MaskedEncodingDecoder.java new file mode 100644 index 00000000..633a09b1 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/MaskedEncodingDecoder.java @@ -0,0 +1,55 @@ +package io.github.dfa1.vortex.reader.decode; + 
+import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; + +/// Read-only decoder for {@code vortex.masked} — payload child + optional validity bitmap child. +public final class MaskedEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public MaskedEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_MASKED; + } + + @Override + public boolean accepts(DType dtype) { + return false; + } + + @Override + public Array decode(DecodeContext ctx) { + if (ctx.node().bufferIndices().length != 0) { + throw new VortexException(EncodingId.VORTEX_MASKED, + "expected 0 buffers, got " + ctx.node().bufferIndices().length); + } + int numChildren = ctx.node().children().length; + if (numChildren < 1 || numChildren > 2) { + throw new VortexException(EncodingId.VORTEX_MASKED, + "expected 1 or 2 children, got " + numChildren); + } + + Array child = ctx.decodeChild(0, ctx.dtype().withNullable(false), ctx.rowCount()); + + BoolArray validity = null; + if (numChildren == 2) { + Array validityArray = ctx.decodeChild(1, new DType.Bool(false), ctx.rowCount()); + if (!(validityArray instanceof BoolArray ba)) { + throw new VortexException(EncodingId.VORTEX_MASKED, + "validity child decoded to unexpected type: " + validityArray.getClass().getSimpleName()); + } + validity = ba; + } + + return new MaskedArray(child, validity); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/NullEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/NullEncodingDecoder.java new file mode 100644 
index 00000000..dc84b926 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/NullEncodingDecoder.java @@ -0,0 +1,31 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.NullArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; + +/// Read-only decoder for {@code vortex.null} (all-null arrays). +public final class NullEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public NullEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_NULL; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Null; + } + + @Override + public Array decode(DecodeContext ctx) { + return new NullArray(ctx.dtype(), ctx.rowCount()); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/PatchedEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/PatchedEncodingDecoder.java new file mode 100644 index 00000000..b42b6a68 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/PatchedEncodingDecoder.java @@ -0,0 +1,126 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import 
io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; +import io.github.dfa1.vortex.proto.PatchedMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.patched}. +public final class PatchedEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public PatchedEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_PATCHED; + } + + @Override + public boolean accepts(DType dtype) { + return false; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + if (rawMeta == null || !rawMeta.hasRemaining()) { + throw new VortexException(EncodingId.VORTEX_PATCHED, "missing metadata"); + } + + long nPatches; + long nLanes; + long offset; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + PatchedMetadata meta = PatchedMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + nPatches = Integer.toUnsignedLong(meta.n_patches()); + nLanes = Integer.toUnsignedLong(meta.n_lanes()); + offset = Integer.toUnsignedLong(meta.offset()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_PATCHED, "invalid metadata", e); + } + + if (nLanes == 0) { + throw new VortexException(EncodingId.VORTEX_PATCHED, "n_lanes must be > 0"); + } + + if (!(ctx.dtype() instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.VORTEX_PATCHED, + "expected primitive dtype, got " + ctx.dtype()); + } + + PType ptype = p.ptype(); + long n = ctx.rowCount(); + long nChunks = (n + offset + 1023) / 1024; + int elemBytes = ptype.byteSize(); + + MemorySegment innerSeg = ctx.decodeChildSegment(0, ctx.dtype(), 
n); + MemorySegment laneOffsetsSeg = ctx.decodeChildSegment(1, + new DType.Primitive(PType.U32, false), nChunks * nLanes + 1); + MemorySegment patchIndicesSeg = ctx.decodeChildSegment(2, + new DType.Primitive(PType.U16, false), nPatches); + MemorySegment patchValuesSeg = ctx.decodeChildSegment(3, ctx.dtype(), nPatches); + + MemorySegment out = ctx.arena().allocate(n * elemBytes); + SegmentBroadcast.broadcastCopy(innerSeg, out, n, elemBytes); + + if (nPatches > 0) { + applyPatches(out, n, nChunks, nLanes, offset, elemBytes, + laneOffsetsSeg, patchIndicesSeg, patchValuesSeg); + } + + return switch (ptype) { + case I8, U8 -> new ByteArray(ctx.dtype(), n, out); + case I16, U16 -> new ShortArray(ctx.dtype(), n, out); + case I32, U32 -> new IntArray(ctx.dtype(), n, out); + case I64, U64 -> new LongArray(ctx.dtype(), n, out); + case F32 -> new FloatArray(ctx.dtype(), n, out); + case F64 -> new DoubleArray(ctx.dtype(), n, out); + default -> throw new VortexException(EncodingId.VORTEX_PATCHED, + "unsupported ptype: " + ptype); + }; + } + + private static void applyPatches( + MemorySegment out, long n, long nChunks, long nLanes, long offset, int elemBytes, + MemorySegment laneOffsets, MemorySegment patchIndices, MemorySegment patchValues + ) { + long laneCap = SegmentBroadcast.capacity(laneOffsets, 4); + long idxCap = SegmentBroadcast.capacity(patchIndices, 2); + long valCap = SegmentBroadcast.capacity(patchValues, elemBytes); + for (long chunk = 0; chunk < nChunks; chunk++) { + long start = Integer.toUnsignedLong( + laneOffsets.getAtIndex(PTypeIO.LE_INT, (chunk * nLanes) % laneCap)); + long stop = Integer.toUnsignedLong( + laneOffsets.getAtIndex(PTypeIO.LE_INT, (chunk * nLanes + nLanes) % laneCap)); + + for (long i = start; i < stop; i++) { + long physicalIdx = chunk * 1024 + + Short.toUnsignedLong(patchIndices.getAtIndex(PTypeIO.LE_SHORT, i % idxCap)); + if (physicalIdx < offset || physicalIdx >= offset + n) { + continue; + } + long outputIdx = physicalIdx - offset; + 
MemorySegment.copy(patchValues, (i % valCap) * elemBytes, out, outputIdx * elemBytes, elemBytes); + } + } + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/PcoEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/PcoEncodingDecoder.java new file mode 100644 index 00000000..ed459efb --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/PcoEncodingDecoder.java @@ -0,0 +1,784 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.LeBitReader; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.PcoBin; +import io.github.dfa1.vortex.encoding.PcoEncoding; +import io.github.dfa1.vortex.encoding.PcoTansDecoder; +import io.github.dfa1.vortex.proto.PcoChunkInfo; +import io.github.dfa1.vortex.proto.PcoMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SegmentAllocator; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/// Read-only decoder for {@code vortex.pco} — port of pcodec. 
+public final class PcoEncodingDecoder implements EncodingDecoder { + + private static final ValueLayout.OfLong LE_LONG = + ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public PcoEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_PCO; + } + + @Override + public boolean accepts(DType dtype) { + return false; + } + + @Override + public Array decode(DecodeContext ctx) { + PcoMetadata meta = parseMeta(ctx); + validateHeader(meta); + + DType dtype = ctx.dtype(); + if (!(dtype instanceof DType.Primitive dt)) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco decode requires Primitive dtype, got: " + dtype); + } + PType ptype = dt.ptype(); + int dtypeSize = dtypeSize(ptype); + + long n = ctx.rowCount(); + + BoolArray validity = null; + long validCount = n; + if (ctx.node().children().length > 0) { + Array validityArr = ctx.decodeChild(0, new DType.Bool(false), n); + if (!(validityArr instanceof BoolArray ba)) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco validity child must be Bool, got: " + validityArr.getClass().getSimpleName()); + } + validity = ba; + validCount = 0; + for (long i = 0; i < n; i++) { + if (validity.getBoolean(i)) { + validCount++; + } + } + } + + MemorySegment rawLatents = ctx.arena().allocate(validCount * Long.BYTES); + + int nChunks = meta.chunks().size(); + int bufIdx = 0; + long rawByteOffset = 0L; + + long[] batchLowers1 = new long[PcoTansDecoder.BATCH_N]; + int[] batchOffsetBits1 = new int[PcoTansDecoder.BATCH_N]; + long[] batchLowers2 = new long[PcoTansDecoder.BATCH_N]; + int[] batchOffsetBits2 = new int[PcoTansDecoder.BATCH_N]; + + for (int c = 0; c < nChunks; c++) { + PcoChunkInfo chunkInfo = meta.chunks().get(c); + MemorySegment chunkMetaBuf = ctx.buffer(bufIdx++); + PcoChunkMeta chunkMeta = readChunkMeta(chunkMetaBuf, dtypeSize); + + int mode = chunkMeta.mode(); + int 
deltaVariant = chunkMeta.deltaVariant(); + long chunkStartOffset = rawByteOffset; + + int chunkN = 0; + for (int p = 0; p < chunkInfo.pages().size(); p++) { + chunkN += chunkInfo.pages().get(p).n_values(); + } + + if (deltaVariant == 3) { + PcoTansDecoder primaryTans = PcoTansDecoder.build( + chunkMeta.ansSizeLog(), chunkMeta.bins()); + for (int p = 0; p < chunkInfo.pages().size(); p++) { + int pageN = chunkInfo.pages().get(p).n_values(); + MemorySegment pageBuf = ctx.buffer(bufIdx++); + rawByteOffset = decodeConv1Page( + primaryTans, chunkMeta.ansSizeLog(), + chunkMeta.conv1Weights().length, + chunkMeta.conv1Quantization(), chunkMeta.conv1Bias(), + chunkMeta.conv1Weights(), + dtypeSize, pageBuf, pageN, + rawLatents, rawByteOffset, + batchLowers1, batchOffsetBits1); + } + } else if (deltaVariant == 2) { + if (mode != 0) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco Lookback delta with non-Classic mode " + mode + " not yet implemented"); + } + PcoTansDecoder deltaTans = PcoTansDecoder.build( + chunkMeta.deltaAnsSizeLog(), chunkMeta.deltaBins()); + PcoTansDecoder primaryTans = PcoTansDecoder.build( + chunkMeta.ansSizeLog(), chunkMeta.bins()); + int stateN = 1 << chunkMeta.stateNLog(); + int windowN = 1 << chunkMeta.windowNLog(); + long mid = typeMid(dtypeSize); + long mask = typeMask(dtypeSize); + for (int p = 0; p < chunkInfo.pages().size(); p++) { + int pageN = chunkInfo.pages().get(p).n_values(); + MemorySegment pageBuf = ctx.buffer(bufIdx++); + rawByteOffset = decodeLookbackPage( + deltaTans, chunkMeta.deltaAnsSizeLog(), + primaryTans, chunkMeta.ansSizeLog(), + stateN, windowN, mid, mask, + dtypeSize, pageBuf, pageN, + rawLatents, rawByteOffset, ctx.arena(), + batchLowers1, batchOffsetBits1, + batchLowers2, batchOffsetBits2); + } + } else if (mode == 0 || mode == 4) { + int primaryDtypeSize = (mode == 4) ? 
32 : dtypeSize; + PcoTansDecoder tans = PcoTansDecoder.build(chunkMeta.ansSizeLog(), chunkMeta.bins()); + for (int p = 0; p < chunkInfo.pages().size(); p++) { + int pageN = chunkInfo.pages().get(p).n_values(); + MemorySegment pageBuf = ctx.buffer(bufIdx++); + rawByteOffset = decodeClassicPage(tans, chunkMeta.ansSizeLog(), + chunkMeta.deltaOrder(), primaryDtypeSize, + pageBuf, pageN, rawLatents, rawByteOffset, + batchLowers1, batchOffsetBits1); + } + if (mode == 4) { + combineDict(chunkMeta.dict(), chunkN, rawLatents, chunkStartOffset); + } + } else { + long base = chunkMeta.base(); + int primaryAnsSizeLog = chunkMeta.ansSizeLog(); + int secondaryAnsSizeLog = chunkMeta.secondaryAnsSizeLog(); + PcoTansDecoder primaryTans = PcoTansDecoder.build(primaryAnsSizeLog, chunkMeta.bins()); + PcoTansDecoder secondaryTans = PcoTansDecoder.build(secondaryAnsSizeLog, chunkMeta.secondaryBins()); + int deltaOrder = chunkMeta.deltaOrder(); + int secondaryDeltaOrder = chunkMeta.secondaryUsesDelta() ? deltaOrder : 0; + + MemorySegment rawAdjs = ctx.arena().allocate((long) chunkN * Long.BYTES); + long adjByteOffset = 0L; + for (int p = 0; p < chunkInfo.pages().size(); p++) { + int pageN = chunkInfo.pages().get(p).n_values(); + MemorySegment pageBuf = ctx.buffer(bufIdx++); + decodeIntMultPage(primaryTans, primaryAnsSizeLog, deltaOrder, + secondaryTans, secondaryAnsSizeLog, secondaryDeltaOrder, + dtypeSize, pageBuf, pageN, + rawLatents, rawByteOffset, + rawAdjs, adjByteOffset, + batchLowers1, batchOffsetBits1, + batchLowers2, batchOffsetBits2); + rawByteOffset += (long) pageN * Long.BYTES; + adjByteOffset += (long) pageN * Long.BYTES; + } + + if (mode == 1) { + long mask = typeMask(dtypeSize); + for (int i = 0; i < chunkN; i++) { + long off = chunkStartOffset + (long) i * Long.BYTES; + long mult = rawLatents.get(LE_LONG, off); + long adj = rawAdjs.get(LE_LONG, (long) i * Long.BYTES); + rawLatents.set(LE_LONG, off, (mult * base + adj) & mask); + } + } else if (mode == 2) { + 
combineFloatMult(ptype, base, chunkN, rawLatents, chunkStartOffset, rawAdjs); + } else { + combineFloatQuant(ptype, chunkMeta.quantizeK(), chunkN, rawLatents, chunkStartOffset, rawAdjs); + } + } + } + + int elemBytes = ptype.byteSize(); + MemorySegment compactOut = ctx.arena().allocate(validCount * elemBytes); + for (long i = 0; i < validCount; i++) { + long latent = rawLatents.get(LE_LONG, i * Long.BYTES); + PTypeIO.set(compactOut, i * elemBytes, ptype, fromLatentOrdered(latent, ptype)); + } + + if (validity == null) { + return toArray(dtype, n, compactOut); + } + + MemorySegment fullOut = ctx.arena().allocate(n * elemBytes); + long srcOff = 0; + for (long i = 0; i < n; i++) { + if (validity.getBoolean(i)) { + MemorySegment.copy(compactOut, srcOff, fullOut, i * elemBytes, elemBytes); + srcOff += elemBytes; + } + } + DType nonNullDtype = new DType.Primitive(ptype, false); + return new MaskedArray(toArray(nonNullDtype, n, fullOut), validity); + } + + private static long decodeClassicPage(PcoTansDecoder tans, int ansSizeLog, int deltaOrder, + int primaryDtypeSize, MemorySegment pageBuf, int pageN, + MemorySegment rawLatents, long rawByteOffset, + long[] batchLowers, int[] batchOffsetBits) { + LeBitReader pageReader = new LeBitReader(pageBuf); + + long[] moments = new long[deltaOrder]; + for (int m = 0; m < deltaOrder; m++) { + moments[m] = pageReader.readBits(primaryDtypeSize); + } + + int[] stateIdxs = new int[PcoTansDecoder.ANS_INTERLEAVING]; + for (int i = 0; i < PcoTansDecoder.ANS_INTERLEAVING; i++) { + stateIdxs[i] = (int) pageReader.readBits(ansSizeLog); + } + pageReader.alignToByte(); + + int decodedN = pageN - deltaOrder; + tans.decodePage(pageReader, stateIdxs, decodedN, rawLatents, rawByteOffset, + batchLowers, batchOffsetBits); + + if (deltaOrder > 0) { + applyConsecutiveDelta(rawLatents, rawByteOffset, pageN, moments, primaryDtypeSize); + } + + return rawByteOffset + (long) pageN * Long.BYTES; + } + + private static void decodeIntMultPage( + 
PcoTansDecoder primaryTans, int primaryAnsSizeLog, int deltaOrder, + PcoTansDecoder secondaryTans, int secondaryAnsSizeLog, int secondaryDeltaOrder, + int dtypeSize, MemorySegment pageBuf, int pageN, + MemorySegment rawMults, long multsOffset, + MemorySegment rawAdjs, long adjsOffset, + long[] batchLowersP, int[] batchOffsetBitsP, + long[] batchLowersS, int[] batchOffsetBitsS) { + LeBitReader pageReader = new LeBitReader(pageBuf); + + long[] primaryMoments = new long[deltaOrder]; + for (int m = 0; m < deltaOrder; m++) { + primaryMoments[m] = pageReader.readBits(dtypeSize); + } + int[] primaryStateIdxs = new int[PcoTansDecoder.ANS_INTERLEAVING]; + for (int i = 0; i < PcoTansDecoder.ANS_INTERLEAVING; i++) { + primaryStateIdxs[i] = (int) pageReader.readBits(primaryAnsSizeLog); + } + + long[] secondaryMoments = new long[secondaryDeltaOrder]; + for (int m = 0; m < secondaryDeltaOrder; m++) { + secondaryMoments[m] = pageReader.readBits(dtypeSize); + } + int[] secondaryStateIdxs = new int[PcoTansDecoder.ANS_INTERLEAVING]; + for (int i = 0; i < PcoTansDecoder.ANS_INTERLEAVING; i++) { + secondaryStateIdxs[i] = (int) pageReader.readBits(secondaryAnsSizeLog); + } + + pageReader.alignToByte(); + + int nRemaining = pageN; + long primaryPos = multsOffset; + long secondaryPos = adjsOffset; + + while (nRemaining > 0) { + int batchN = Math.min(nRemaining, PcoTansDecoder.BATCH_N); + int primaryPreDeltaN = Math.clamp(nRemaining - deltaOrder, 0, batchN); + int secondaryPreDeltaN = Math.clamp(nRemaining - secondaryDeltaOrder, 0, batchN); + + primaryTans.decodeBatch(pageReader, primaryStateIdxs, primaryPreDeltaN, + batchLowersP, batchOffsetBitsP, rawMults, primaryPos); + secondaryTans.decodeBatch(pageReader, secondaryStateIdxs, secondaryPreDeltaN, + batchLowersS, batchOffsetBitsS, rawAdjs, secondaryPos); + + primaryPos += (long) batchN * Long.BYTES; + secondaryPos += (long) batchN * Long.BYTES; + nRemaining -= batchN; + } + + if (deltaOrder > 0) { + applyConsecutiveDelta(rawMults, 
multsOffset, pageN, primaryMoments, dtypeSize); + } + if (secondaryDeltaOrder > 0) { + applyConsecutiveDelta(rawAdjs, adjsOffset, pageN, secondaryMoments, dtypeSize); + } + } + + private static long decodeLookbackPage( + PcoTansDecoder deltaTans, int deltaAnsSizeLog, + PcoTansDecoder primaryTans, int primaryAnsSizeLog, + int stateN, int windowN, long mid, long mask, + int dtypeSize, MemorySegment pageBuf, int pageN, + MemorySegment rawLatents, long latentsOffset, + SegmentAllocator arena, + long[] batchLowersD, int[] batchOffsetBitsD, + long[] batchLowersP, int[] batchOffsetBitsP) { + if (pageN < stateN) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco corrupt lookback page: stateN " + stateN + " exceeds pageN " + pageN); + } + LeBitReader pageReader = new LeBitReader(pageBuf); + + int[] deltaStateIdxs = new int[PcoTansDecoder.ANS_INTERLEAVING]; + for (int i = 0; i < PcoTansDecoder.ANS_INTERLEAVING; i++) { + deltaStateIdxs[i] = (int) pageReader.readBits(deltaAnsSizeLog); + } + + long[] initialState = new long[stateN]; + for (int m = 0; m < stateN; m++) { + initialState[m] = pageReader.readBits(dtypeSize); + } + int[] primaryStateIdxs = new int[PcoTansDecoder.ANS_INTERLEAVING]; + for (int i = 0; i < PcoTansDecoder.ANS_INTERLEAVING; i++) { + primaryStateIdxs[i] = (int) pageReader.readBits(primaryAnsSizeLog); + } + pageReader.alignToByte(); + + int decodeN = pageN - stateN; + if (decodeN > 1 << 23) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco corrupt lookback page: decodeN " + decodeN + " exceeds max 8388608"); + } + MemorySegment rawLookbacks = arena.allocate((long) decodeN * Long.BYTES); + MemorySegment rawResiduals = arena.allocate((long) decodeN * Long.BYTES); + + int remaining = decodeN; + long dPos = 0L; + long pPos = 0L; + while (remaining > 0) { + int batchN = Math.min(remaining, PcoTansDecoder.BATCH_N); + deltaTans.decodeBatch(pageReader, deltaStateIdxs, batchN, + batchLowersD, batchOffsetBitsD, rawLookbacks, dPos); + 
primaryTans.decodeBatch(pageReader, primaryStateIdxs, batchN, + batchLowersP, batchOffsetBitsP, rawResiduals, pPos); + dPos += (long) batchN * Long.BYTES; + pPos += (long) batchN * Long.BYTES; + remaining -= batchN; + } + + for (int i = 0; i < decodeN; i++) { + long off = (long) i * Long.BYTES; + rawResiduals.set(LE_LONG, off, (rawResiduals.get(LE_LONG, off) ^ mid) & mask); + } + + for (int i = 0; i < stateN; i++) { + rawLatents.set(LE_LONG, latentsOffset + (long) i * Long.BYTES, initialState[i] & mask); + } + + if (stateN > windowN) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco corrupt lookback: stateN " + stateN + " exceeds windowN " + windowN); + } + long[] window = new long[windowN + decodeN]; + for (int i = 0; i < stateN; i++) { + window[windowN - stateN + i] = initialState[i] & mask; + } + for (int i = 0; i < decodeN; i++) { + int lb = (int) rawLookbacks.get(LE_LONG, (long) i * Long.BYTES); + if (lb < 1 || lb > windowN) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco corrupt lookback index " + lb + " not in [1, " + windowN + "]"); + } + long decoded = (rawResiduals.get(LE_LONG, (long) i * Long.BYTES) + window[windowN + i - lb]) & mask; + window[windowN + i] = decoded; + rawLatents.set(LE_LONG, latentsOffset + (long) (stateN + i) * Long.BYTES, decoded); + } + + return latentsOffset + (long) pageN * Long.BYTES; + } + + private static long decodeConv1Page( + PcoTansDecoder tans, int ansSizeLog, + int order, int quantization, long bias, long[] weights, + int dtypeSize, MemorySegment pageBuf, int pageN, + MemorySegment rawLatents, long latentsOffset, + long[] batchLowers, int[] batchOffsetBits) { + LeBitReader pageReader = new LeBitReader(pageBuf); + + long[] state = new long[order]; + for (int i = 0; i < order; i++) { + state[i] = pageReader.readBits(dtypeSize); + } + int[] stateIdxs = new int[PcoTansDecoder.ANS_INTERLEAVING]; + for (int i = 0; i < PcoTansDecoder.ANS_INTERLEAVING; i++) { + stateIdxs[i] = (int) 
pageReader.readBits(ansSizeLog); + } + pageReader.alignToByte(); + + int decodeN = pageN - order; + long mid = typeMid(dtypeSize); + long mask = typeMask(dtypeSize); + + for (int i = 0; i < order; i++) { + rawLatents.set(LE_LONG, latentsOffset + (long) i * Long.BYTES, state[i]); + } + + tans.decodePage(pageReader, stateIdxs, decodeN, rawLatents, + latentsOffset + (long) order * Long.BYTES, + batchLowers, batchOffsetBits); + + for (int i = order; i < pageN; i++) { + long off = latentsOffset + (long) i * Long.BYTES; + rawLatents.set(LE_LONG, off, (rawLatents.get(LE_LONG, off) ^ mid) & mask); + } + + for (int i = order; i < pageN; i++) { + long pred = predictConv1(rawLatents, latentsOffset, i, order, + weights, bias, quantization, mask, dtypeSize); + long off = latentsOffset + (long) i * Long.BYTES; + rawLatents.set(LE_LONG, off, (rawLatents.get(LE_LONG, off) + pred) & mask); + } + + return latentsOffset + (long) pageN * Long.BYTES; + } + + private static long predictConv1(MemorySegment seg, long baseOff, int pos, int order, + long[] weights, long bias, int quantization, long mask, int dtypeSize) { + long s = (dtypeSize == 16) ? (int) bias : bias; + for (int k = 0; k < order; k++) { + long w = (dtypeSize == 16) ? (int) weights[k] : weights[k]; + long l = seg.get(LE_LONG, baseOff + (long) (pos - order + k) * Long.BYTES); + s += w * l; + } + if (s < 0) { + s = 0; + } + return (s >> quantization) & mask; + } + + private static long fromLatentOrdered(long latent, PType ptype) { + return switch (ptype) { + case I16 -> latent ^ 0x8000L; + case I32 -> latent ^ 0x80000000L; + case I64 -> latent ^ Long.MIN_VALUE; + case F32 -> { + long l32 = latent & 0xFFFFFFFFL; + yield (l32 & 0x80000000L) != 0 ? l32 ^ 0x80000000L : l32 ^ 0xFFFFFFFFL; + } + case F64 -> (latent & Long.MIN_VALUE) != 0 ? 
latent ^ Long.MIN_VALUE : ~latent; + default -> latent; + }; + } + + private static void applyConsecutiveDelta(MemorySegment rawLatents, long offset, + int pageN, long[] moments, int dtypeSize) { + long mid = typeMid(dtypeSize); + long mask = typeMask(dtypeSize); + + for (int i = 0; i < pageN; i++) { + long byteOff = offset + (long) i * Long.BYTES; + rawLatents.set(LE_LONG, byteOff, (rawLatents.get(LE_LONG, byteOff) ^ mid) & mask); + } + + for (int m = moments.length - 1; m >= 0; m--) { + long moment = moments[m] & mask; + for (int i = 0; i < pageN; i++) { + long byteOff = offset + (long) i * Long.BYTES; + long tmp = rawLatents.get(LE_LONG, byteOff); + rawLatents.set(LE_LONG, byteOff, moment); + moment = (moment + tmp) & mask; + } + } + } + + private static long typeMid(int dtypeSize) { + return switch (dtypeSize) { + case 64 -> Long.MIN_VALUE; + case 32 -> 0x80000000L; + case 16 -> 0x8000L; + default -> throw new VortexException(EncodingId.VORTEX_PCO, + "pco: invalid dtypeSize " + dtypeSize); + }; + } + + private static long typeMask(int dtypeSize) { + return switch (dtypeSize) { + case 64 -> -1L; + case 32 -> 0xFFFFFFFFL; + case 16 -> 0xFFFFL; + default -> throw new VortexException(EncodingId.VORTEX_PCO, + "pco: invalid dtypeSize " + dtypeSize); + }; + } + + private static int dtypeSize(PType ptype) { + return switch (ptype) { + case I16, U16 -> 16; + case I32, U32, F32 -> 32; + case I64, U64, F64 -> 64; + default -> throw new VortexException(EncodingId.VORTEX_PCO, + "pco: unsupported ptype " + ptype); + }; + } + + private static int bitsToEncodeOffsetBits(int dtypeSize) { + return switch (dtypeSize) { + case 64 -> PcoEncoding.BITS_TO_ENCODE_OFFSET_BITS_64; + case 32 -> PcoEncoding.BITS_TO_ENCODE_OFFSET_BITS_32; + case 16 -> PcoEncoding.BITS_TO_ENCODE_OFFSET_BITS_16; + default -> throw new VortexException(EncodingId.VORTEX_PCO, + "pco: invalid dtypeSize " + dtypeSize); + }; + } + + private static Array toArray(DType dtype, long n, MemorySegment out) { + PType 
ptype = ((DType.Primitive) dtype).ptype(); + return switch (ptype) { + case I16, U16 -> new ShortArray(dtype, n, out); + case I32, U32 -> new IntArray(dtype, n, out); + case F32 -> new FloatArray(dtype, n, out); + case I64, U64 -> new LongArray(dtype, n, out); + case F64 -> new DoubleArray(dtype, n, out); + default -> throw new VortexException(EncodingId.VORTEX_PCO, + "pco: unsupported ptype " + ptype); + }; + } + + private static float intFloatFromLatentF32(long l) { + long mid = 0x80000000L; + boolean negative = (l < mid); + long absInt = negative ? (0x7FFFFFFFL - l) : (l ^ 0x80000000L); + long gpi = 1L << 24; + float absFloat = (absInt < gpi) ? (float) absInt + : Float.intBitsToFloat(0x4B800000 + (int) (absInt - gpi)); + return negative ? -absFloat : absFloat; + } + + private static double intFloatFromLatentF64(long l) { + boolean negative = (l >= 0); + long absInt = negative ? (Long.MAX_VALUE - l) : (l ^ Long.MIN_VALUE); + long gpi = 1L << 53; + double absFloat = (absInt < gpi) ? (double) absInt + : Double.longBitsToDouble(0x4340000000000000L + (absInt - gpi)); + return negative ? 
-absFloat : absFloat; + } + + private static long toLatentOrderedF32(float f) { + int bits = Float.floatToRawIntBits(f); + if ((bits & 0x80000000) != 0) { + return (~bits) & 0xFFFFFFFFL; + } else { + return (bits ^ 0x80000000) & 0xFFFFFFFFL; + } + } + + private static long toLatentOrderedF64(double d) { + long bits = Double.doubleToRawLongBits(d); + if ((bits & Long.MIN_VALUE) != 0) { + return ~bits; + } else { + return bits ^ Long.MIN_VALUE; + } + } + + private static void combineFloatMult(PType ptype, long baseLatent, int chunkN, + MemorySegment rawLatents, long multsOffset, MemorySegment rawAdjs) { + if (ptype == PType.F32) { + float baseFloat = Float.intBitsToFloat((int) fromLatentOrdered(baseLatent, PType.F32)); + for (int i = 0; i < chunkN; i++) { + long off = multsOffset + (long) i * Long.BYTES; + long mult = rawLatents.get(LE_LONG, off); + long adj = rawAdjs.get(LE_LONG, (long) i * Long.BYTES); + long unadjusted = toLatentOrderedF32(intFloatFromLatentF32(mult) * baseFloat); + rawLatents.set(LE_LONG, off, (unadjusted + adj) & 0xFFFFFFFFL); + } + } else { + double baseDouble = Double.longBitsToDouble(fromLatentOrdered(baseLatent, PType.F64)); + for (int i = 0; i < chunkN; i++) { + long off = multsOffset + (long) i * Long.BYTES; + long mult = rawLatents.get(LE_LONG, off); + long adj = rawAdjs.get(LE_LONG, (long) i * Long.BYTES); + long unadjusted = toLatentOrderedF64(intFloatFromLatentF64(mult) * baseDouble); + rawLatents.set(LE_LONG, off, unadjusted + adj); + } + } + } + + private static void combineFloatQuant(PType ptype, int k, int chunkN, + MemorySegment rawLatents, long multsOffset, MemorySegment rawAdjs) { + if (ptype == PType.F32) { + long signCutoff = 0x80000000L >>> k; + long lowestKBitsMax = (1L << k) - 1L; + for (int i = 0; i < chunkN; i++) { + long off = multsOffset + (long) i * Long.BYTES; + long quantum = rawLatents.get(LE_LONG, off); + long adj = rawAdjs.get(LE_LONG, (long) i * Long.BYTES); + long lowestKBits = (quantum >= signCutoff) ? 
adj : (lowestKBitsMax - adj); + rawLatents.set(LE_LONG, off, (quantum << k) + lowestKBits); + } + } else { + long signCutoff = Long.MIN_VALUE >>> k; + long lowestKBitsMax = (1L << k) - 1L; + for (int i = 0; i < chunkN; i++) { + long off = multsOffset + (long) i * Long.BYTES; + long quantum = rawLatents.get(LE_LONG, off); + long adj = rawAdjs.get(LE_LONG, (long) i * Long.BYTES); + boolean isPos = Long.compareUnsigned(quantum, signCutoff) >= 0; + long lowestKBits = isPos ? adj : (lowestKBitsMax - adj); + rawLatents.set(LE_LONG, off, (quantum << k) + lowestKBits); + } + } + } + + private static void combineDict(long[] dict, int chunkN, + MemorySegment rawLatents, long offset) { + for (int i = 0; i < chunkN; i++) { + long off = offset + (long) i * Long.BYTES; + int idx = (int) rawLatents.get(LE_LONG, off); + if (idx < 0 || idx >= dict.length) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco dict index " + idx + " out of range [0, " + dict.length + ")"); + } + rawLatents.set(LE_LONG, off, dict[idx]); + } + } + + private static PcoChunkMeta readChunkMeta(MemorySegment buf, int dtypeSize) { + LeBitReader r = new LeBitReader(buf); + + int modeNibble = (int) r.readBits(4); + long base = 0L; + int quantizeK = 0; + long[] dict = null; + if (modeNibble == 1 || modeNibble == 2) { + base = r.readBits(dtypeSize); + } else if (modeNibble == 3) { + quantizeK = (int) r.readBits(8); + } else if (modeNibble == 4) { + int nUnique = (int) r.readBits(25); + if (nUnique > 1 << 16) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco dict nUnique " + nUnique + " exceeds max 65536"); + } + r.alignToByte(); + dict = new long[nUnique]; + for (int i = 0; i < nUnique; i++) { + dict[i] = r.readBits(dtypeSize); + } + } else if (modeNibble != 0) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco mode " + modeNibble + " not yet implemented " + + "(Classic=0, IntMult=1, FloatMult=2, FloatQuant=3, Dict=4 supported)"); + } + + int deltaVariant = (int) r.readBits(4); + int 
deltaOrder = 0; + boolean secondaryUsesDelta = false; + int windowNLog = 0; + int stateNLog = 0; + int conv1Quantization = 0; + long conv1Bias = 0L; + long[] conv1Weights = new long[0]; + if (deltaVariant == 0) { + // NoOp + } else if (deltaVariant == 1) { + deltaOrder = (int) r.readBits(3); + secondaryUsesDelta = r.readBits(1) != 0; + } else if (deltaVariant == 2) { + windowNLog = 1 + (int) r.readBits(5); + if (windowNLog > 24) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco lookback windowNLog " + windowNLog + " exceeds max 24"); + } + stateNLog = (int) r.readBits(4); + secondaryUsesDelta = r.readBits(1) != 0; + } else if (deltaVariant == 3) { + if (dtypeSize == 64) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco Conv1 delta not supported for 64-bit dtypes (I64/U64/F64)"); + } + conv1Quantization = (int) r.readBits(5); + conv1Bias = r.readBits(64) ^ Long.MIN_VALUE; + int conv1Order = 1 + (int) r.readBits(5); + conv1Weights = new long[conv1Order]; + for (int i = 0; i < conv1Order; i++) { + conv1Weights[i] = (int) (r.readBits(32) ^ 0x80000000L); + } + } else { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco delta variant " + deltaVariant + " not yet implemented " + + "(NoOp=0, Consecutive=1, Lookback=2, Conv1=3 supported)"); + } + + int deltaAnsSizeLog = 0; + PcoBin[] deltaBins = new PcoBin[0]; + if (deltaVariant == 2) { + deltaAnsSizeLog = (int) r.readBits(4); + int nDeltaBins = (int) r.readBits(15); + deltaBins = readBins(r, nDeltaBins, deltaAnsSizeLog, 32); + } + + int primaryDtypeSize = (modeNibble == 4) ? 
32 : dtypeSize; + int ansSizeLog = (int) r.readBits(4); + int nBins = (int) r.readBits(15); + PcoBin[] bins = readBins(r, nBins, ansSizeLog, primaryDtypeSize); + + int secondaryAnsSizeLog = 0; + PcoBin[] secondaryBins = new PcoBin[0]; + if (modeNibble == 1 || modeNibble == 2 || modeNibble == 3) { + secondaryAnsSizeLog = (int) r.readBits(4); + int nSecondaryBins = (int) r.readBits(15); + secondaryBins = readBins(r, nSecondaryBins, secondaryAnsSizeLog, dtypeSize); + } + r.alignToByte(); + + return new PcoChunkMeta(modeNibble, base, quantizeK, dict, + deltaVariant, deltaOrder, secondaryUsesDelta, + windowNLog, stateNLog, deltaAnsSizeLog, deltaBins, + conv1Quantization, conv1Bias, conv1Weights, + ansSizeLog, bins, secondaryAnsSizeLog, secondaryBins); + } + + private static PcoBin[] readBins(LeBitReader r, int nBins, int ansSizeLog, int dtypeSize) { + PcoBin[] bins = new PcoBin[nBins]; + int offsetBitsWidth = bitsToEncodeOffsetBits(dtypeSize); + for (int b = 0; b < nBins; b++) { + int weight = (int) r.readBits(ansSizeLog) + 1; + long lower = r.readBits(dtypeSize); + int offsetBits = (int) r.readBits(offsetBitsWidth); + bins[b] = new PcoBin(weight, lower, offsetBits); + } + return bins; + } + + private static PcoMetadata parseMeta(DecodeContext ctx) { + ByteBuffer raw = ctx.metadata(); + if (raw == null) { + throw new VortexException(EncodingId.VORTEX_PCO, "missing PcoMetadata"); + } + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(raw.duplicate()); + return PcoMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_PCO, + "invalid PcoMetadata: " + e.getMessage()); + } + } + + private static void validateHeader(PcoMetadata meta) { + byte[] header = meta.header(); + if (header.length < 2) { + throw new VortexException(EncodingId.VORTEX_PCO, + "pco header too short: " + header.length + " bytes"); + } + if (header[0] != PcoEncoding.PCO_FORMAT_MAJOR || header[1] != PcoEncoding.PCO_FORMAT_MINOR) { + 
throw new VortexException(EncodingId.VORTEX_PCO, + String.format("unsupported pco format version %02x.%02x (expected %02x.%02x)", + header[0] & 0xFF, header[1] & 0xFF, + PcoEncoding.PCO_FORMAT_MAJOR & 0xFF, PcoEncoding.PCO_FORMAT_MINOR & 0xFF)); + } + } + + private record PcoChunkMeta(int mode, long base, int quantizeK, long[] dict, + int deltaVariant, int deltaOrder, boolean secondaryUsesDelta, + int windowNLog, int stateNLog, int deltaAnsSizeLog, PcoBin[] deltaBins, + int conv1Quantization, long conv1Bias, long[] conv1Weights, + int ansSizeLog, PcoBin[] bins, + int secondaryAnsSizeLog, PcoBin[] secondaryBins) { + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/PrimitiveEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/PrimitiveEncodingDecoder.java new file mode 100644 index 00000000..7bcc4731 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/PrimitiveEncodingDecoder.java @@ -0,0 +1,64 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.Float16Array; +import io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; + +import java.lang.foreign.MemorySegment; + +/// Read-only decoder for {@code vortex.primitive} — raw little-endian primitive arrays. 
+public final class PrimitiveEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public PrimitiveEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_PRIMITIVE; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive; + } + + @Override + public Array decode(DecodeContext ctx) { + MemorySegment buf = ctx.buffer(0); + long n = ctx.rowCount(); + DType dt = ctx.dtype(); + PType ptype = ((DType.Primitive) dt).ptype(); + Array values = switch (ptype) { + case I64, U64 -> new LongArray(dt, n, buf); + case I32, U32 -> new IntArray(dt, n, buf); + case F64 -> new DoubleArray(dt, n, buf); + case F32 -> new FloatArray(dt, n, buf); + case I16, U16 -> new ShortArray(dt, n, buf); + case I8, U8 -> new ByteArray(dt, n, buf); + case F16 -> new Float16Array(dt, n, buf); + }; + if (ctx.node().children().length == 1) { + Array va = ctx.decodeChild(0, new DType.Bool(false), n); + if (!(va instanceof BoolArray validity)) { + throw new VortexException(EncodingId.VORTEX_PRIMITIVE, + "validity child decoded to unexpected type: " + va.getClass().getSimpleName()); + } + return new MaskedArray(values, validity); + } + return values; + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/RleEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/RleEncodingDecoder.java new file mode 100644 index 00000000..012cec1f --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/RleEncodingDecoder.java @@ -0,0 +1,229 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.core.array.BoolArray; +import 
io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.Float16Array; +import io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.RleEncoding; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; +import io.github.dfa1.vortex.proto.RLEMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SegmentAllocator; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code fastlanes.rle}. +public final class RleEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public RleEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.FASTLANES_RLE; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive p && !p.ptype().isFloating(); + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + RLEMetadata meta; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + meta = RLEMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.FASTLANES_RLE, "invalid metadata", e); + } + + long valuesLen = meta.values_len(); + long indicesLen = meta.indices_len(); + PType indicesPtype = PType.fromOrdinal(meta.indices_ptype().value()); + long offsetsLen = meta.values_idx_offsets_len(); + PType offsetsPtype = PType.fromOrdinal(meta.values_idx_offsets_ptype().value()); + int offset = (int) meta.offset(); + + long rowCount = ctx.rowCount(); + if (rowCount == 0 || indicesLen == 0) { + return emptyArray(ctx); + } + + if (!(ctx.dtype() instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.FASTLANES_RLE, "expected Primitive dtype, got " + ctx.dtype()); + } + PType ptype = p.ptype(); + + DType valuesDtype = new DType.Primitive(ptype, false); + DType indicesDtype = new DType.Primitive(indicesPtype, false); + DType offsetsDtype = new DType.Primitive(offsetsPtype, false); + + Array indicesRaw = ctx.decodeChild(1, indicesDtype, indicesLen); + + BoolArray indicesValidity = null; + Array indicesArr = indicesRaw; + if (indicesRaw instanceof MaskedArray masked) { + indicesArr = masked.inner(); + indicesValidity = masked.validity(); + } + + long[] values = readLongs(ctx.decodeChildSegment(0, valuesDtype, valuesLen), (int) valuesLen, ptype); + int[] indices = readIndices(ArraySegments.of(indicesArr), (int) indicesLen, indicesPtype); + long[] valuesIdxOffsets = readUnsignedLongs(ctx.decodeChildSegment(2, offsetsDtype, offsetsLen), (int) 
offsetsLen, offsetsPtype); + + int numChunks = (int) (indicesLen / RleEncoding.FL_CHUNK_SIZE); + int chunkEnd = (int) ((offset + rowCount + RleEncoding.FL_CHUNK_SIZE - 1) / RleEncoding.FL_CHUNK_SIZE); + chunkEnd = Math.min(chunkEnd, numChunks); + + long[] decoded = new long[chunkEnd * RleEncoding.FL_CHUNK_SIZE]; + long firstOffset = valuesLen > 0 ? valuesIdxOffsets[0] : 0L; + + for (int chunkIdx = 0; chunkIdx < chunkEnd; chunkIdx++) { + long valueIdxOffset = valuesIdxOffsets[chunkIdx] - firstOffset; + long nextValueIdxOffset = (chunkIdx + 1 < numChunks) + ? (valuesIdxOffsets[chunkIdx + 1] - firstOffset) + : valuesLen; + int numChunkValues = (int) (nextValueIdxOffset - valueIdxOffset); + + int chunkBase = chunkIdx * RleEncoding.FL_CHUNK_SIZE; + if (numChunkValues <= 1) { + long fillVal = numChunkValues == 1 ? values[(int) valueIdxOffset] : 0L; + for (int i = 0; i < RleEncoding.FL_CHUNK_SIZE; i++) { + decoded[chunkBase + i] = fillVal; + } + } else { + for (int i = 0; i < RleEncoding.FL_CHUNK_SIZE; i++) { + int idx = indices[chunkBase + i]; + if (idx >= numChunkValues) { + idx = numChunkValues - 1; + } + decoded[chunkBase + i] = values[(int) valueIdxOffset + idx]; + } + } + } + + MemorySegment seg = fromLongs(decoded, offset, (int) rowCount, ptype, ctx.arena()); + Array result = toArray(ctx.dtype(), rowCount, seg, ptype); + if (indicesValidity == null) { + return result; + } + int validityBytes = (int) ((rowCount + 7) / 8); + MemorySegment validityBuf = ctx.arena().allocate(validityBytes); + for (long j = 0; j < rowCount; j++) { + if (indicesValidity.getBoolean(offset + j)) { + int byteIdx = (int) (j >>> 3); + byte current = validityBuf.get(ValueLayout.JAVA_BYTE, byteIdx); + validityBuf.set(ValueLayout.JAVA_BYTE, byteIdx, (byte) (current | (1 << (j & 7)))); + } + } + BoolArray outputValidity = new BoolArray(new DType.Bool(false), rowCount, validityBuf); + return new MaskedArray(result, outputValidity); + } + + private static Array emptyArray(DecodeContext ctx) { + 
MemorySegment empty = ctx.arena().allocate(0); + DType dt = ctx.dtype(); + PType ptype = ((DType.Primitive) dt).ptype(); + return toArray(dt, 0L, empty, ptype); + } + + private static Array toArray(DType dtype, long n, MemorySegment seg, PType ptype) { + return switch (ptype) { + case I64, U64 -> new LongArray(dtype, n, seg); + case I32, U32 -> new IntArray(dtype, n, seg); + case I16, U16 -> new ShortArray(dtype, n, seg); + case I8, U8 -> new ByteArray(dtype, n, seg); + case F64 -> new DoubleArray(dtype, n, seg); + case F32 -> new FloatArray(dtype, n, seg); + case F16 -> new Float16Array(dtype, n, seg); + }; + } + + private static long[] readLongs(MemorySegment buf, int count, PType ptype) { + long[] out = new long[count]; + int elemSize = ptype.byteSize(); + long cap = SegmentBroadcast.capacity(buf, elemSize); + for (int i = 0; i < count; i++) { + long off = (i % cap) * elemSize; + out[i] = switch (ptype) { + case I8 -> buf.get(ValueLayout.JAVA_BYTE, off); + case U8 -> Byte.toUnsignedLong(buf.get(ValueLayout.JAVA_BYTE, off)); + case I16 -> buf.get(PTypeIO.LE_SHORT, off); + case U16, F16 -> Short.toUnsignedLong(buf.get(PTypeIO.LE_SHORT, off)); + case I32 -> buf.get(PTypeIO.LE_INT, off); + case U32 -> Integer.toUnsignedLong(buf.get(PTypeIO.LE_INT, off)); + case I64, U64 -> buf.get(PTypeIO.LE_LONG, off); + case F32 -> Integer.toUnsignedLong(buf.get(PTypeIO.LE_INT, off)); + case F64 -> buf.get(PTypeIO.LE_LONG, off); + }; + } + return out; + } + + private static int[] readIndices(MemorySegment buf, int count, PType indicesPtype) { + int[] out = new int[count]; + int elemSize = indicesPtype.byteSize(); + long cap = SegmentBroadcast.capacity(buf, elemSize); + switch (indicesPtype) { + case U8 -> { + for (int i = 0; i < count; i++) { + out[i] = Byte.toUnsignedInt(buf.get(ValueLayout.JAVA_BYTE, i % cap)); + } + } + case U16 -> { + for (int i = 0; i < count; i++) { + out[i] = Short.toUnsignedInt(buf.get(PTypeIO.LE_SHORT, (i % cap) * 2)); + } + } + default -> + throw new 
VortexException(EncodingId.FASTLANES_RLE, "unsupported indices ptype: " + indicesPtype); + } + return out; + } + + private static long[] readUnsignedLongs(MemorySegment buf, int count, PType ptype) { + long[] out = new long[count]; + int elemSize = ptype.byteSize(); + long cap = SegmentBroadcast.capacity(buf, elemSize); + for (int i = 0; i < count; i++) { + long off = (i % cap) * elemSize; + out[i] = switch (ptype) { + case U8 -> Byte.toUnsignedLong(buf.get(ValueLayout.JAVA_BYTE, off)); + case U16 -> Short.toUnsignedLong(buf.get(PTypeIO.LE_SHORT, off)); + case U32 -> Integer.toUnsignedLong(buf.get(PTypeIO.LE_INT, off)); + case U64 -> buf.get(PTypeIO.LE_LONG, off); + default -> + throw new VortexException(EncodingId.FASTLANES_RLE, "unsupported offsets ptype: " + ptype); + }; + } + return out; + } + + private static MemorySegment fromLongs(long[] decoded, int offset, int count, PType ptype, SegmentAllocator arena) { + int elemSize = ptype.byteSize(); + MemorySegment seg = arena.allocate((long) count * elemSize); + for (int i = 0; i < count; i++) { + PTypeIO.set(seg, (long) i * elemSize, ptype, decoded[offset + i]); + } + return seg; + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/RunEndEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/RunEndEncodingDecoder.java new file mode 100644 index 00000000..fa9d43c3 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/RunEndEncodingDecoder.java @@ -0,0 +1,273 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.IntArray; +import 
io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; +import io.github.dfa1.vortex.proto.RunEndMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SegmentAllocator; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.runend}. +public final class RunEndEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public RunEndEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_RUNEND; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive p && !p.ptype().isFloating(); + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + if (rawMeta == null) { + throw new VortexException(EncodingId.VORTEX_RUNEND, "missing metadata"); + } + + RunEndMetadata meta; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + meta = RunEndMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_RUNEND, "invalid metadata", e); + } + + PType endsPtype = PType.fromOrdinal(meta.ends_ptype().value()); + long numRuns = meta.num_runs(); + long offset = meta.offset(); + + long n = ctx.rowCount(); + DType endsDtype = new DType.Primitive(endsPtype, false); + Array endsArr = ctx.decodeChild(0, endsDtype, numRuns); + + if (ctx.dtype() instanceof DType.Utf8 || ctx.dtype() instanceof DType.Binary) { + Array valuesArr = ctx.decodeChild(1, 
ctx.dtype(), numRuns); + return expandStrings(endsArr, (VarBinArray) valuesArr, endsPtype, numRuns, offset, n, ctx.dtype(), ctx.arena()); + } + + if (ctx.dtype() instanceof DType.Bool) { + Array valuesArr = ctx.decodeChild(1, ctx.dtype(), numRuns); + return expandBool(endsArr, (BoolArray) valuesArr, endsPtype, numRuns, offset, n, ctx.dtype(), ctx.arena()); + } + + if (!(ctx.dtype() instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.VORTEX_RUNEND, "expected primitive dtype, got " + ctx.dtype()); + } + PType valuePtype = p.ptype(); + + return expand(ArraySegments.of(endsArr), ctx.decodeChildSegment(1, ctx.dtype(), numRuns), + endsPtype, valuePtype, numRuns, offset, n, ctx.dtype(), ctx.arena()); + } + + private static Array expand( + MemorySegment endsSeg, MemorySegment valuesSeg, + PType endsPtype, PType valuePtype, + long numRuns, long offset, long n, + DType dtype, SegmentAllocator arena + ) { + MemorySegment out = arena.allocate(n * valuePtype.byteSize()); + switch (valuePtype) { + case I8, U8 -> expandByte(endsSeg, valuesSeg, endsPtype, numRuns, offset, n, out); + case I16, U16 -> expandShort(endsSeg, valuesSeg, endsPtype, numRuns, offset, n, out); + case I32, U32 -> expandInt(endsSeg, valuesSeg, endsPtype, numRuns, offset, n, out); + case I64, U64 -> expandLong(endsSeg, valuesSeg, endsPtype, numRuns, offset, n, out); + default -> throw new VortexException(EncodingId.VORTEX_RUNEND, "unsupported ptype " + valuePtype); + } + MemorySegment ro = out.asReadOnly(); + return switch (valuePtype) { + case I64, U64 -> new LongArray(dtype, n, ro); + case I32, U32 -> new IntArray(dtype, n, ro); + case I16, U16 -> new ShortArray(dtype, n, ro); + case I8, U8 -> new ByteArray(dtype, n, ro); + default -> throw new VortexException(EncodingId.VORTEX_RUNEND, "unsupported ptype " + valuePtype); + }; + } + + private static void expandByte(MemorySegment endsSeg, MemorySegment valuesSeg, + PType endsPtype, long numRuns, long offset, long n, MemorySegment out) { + 
long endsCap = SegmentBroadcast.capacity(endsSeg, endsPtype.byteSize()); + long valCap = SegmentBroadcast.capacity(valuesSeg, 1); + long logicalPos = 0L, outPos = 0L; + for (long run = 0; run < numRuns && outPos < n; run++) { + long runEnd = readUnsigned(endsSeg, run % endsCap, endsPtype); + byte rawValue = valuesSeg.get(ValueLayout.JAVA_BYTE, run % valCap); + long writeEnd = Math.min(runEnd, offset + n); + for (long lp = Math.max(logicalPos, offset); lp < writeEnd; lp++, outPos++) { + out.set(ValueLayout.JAVA_BYTE, outPos, rawValue); + } + logicalPos = runEnd; + } + } + + private static void expandShort(MemorySegment endsSeg, MemorySegment valuesSeg, + PType endsPtype, long numRuns, long offset, long n, MemorySegment out) { + long endsCap = SegmentBroadcast.capacity(endsSeg, endsPtype.byteSize()); + long valCap = SegmentBroadcast.capacity(valuesSeg, 2); + long logicalPos = 0L, outPos = 0L; + for (long run = 0; run < numRuns && outPos < n; run++) { + long runEnd = readUnsigned(endsSeg, run % endsCap, endsPtype); + short rawValue = valuesSeg.get(PTypeIO.LE_SHORT, (run % valCap) * 2); + long writeEnd = Math.min(runEnd, offset + n); + for (long lp = Math.max(logicalPos, offset); lp < writeEnd; lp++, outPos++) { + out.set(PTypeIO.LE_SHORT, outPos * 2, rawValue); + } + logicalPos = runEnd; + } + } + + private static void expandInt(MemorySegment endsSeg, MemorySegment valuesSeg, + PType endsPtype, long numRuns, long offset, long n, MemorySegment out) { + long endsCap = SegmentBroadcast.capacity(endsSeg, endsPtype.byteSize()); + long valCap = SegmentBroadcast.capacity(valuesSeg, 4); + long logicalPos = 0L, outPos = 0L; + for (long run = 0; run < numRuns && outPos < n; run++) { + long runEnd = readUnsigned(endsSeg, run % endsCap, endsPtype); + int rawValue = valuesSeg.get(PTypeIO.LE_INT, (run % valCap) * 4); + long writeEnd = Math.min(runEnd, offset + n); + for (long lp = Math.max(logicalPos, offset); lp < writeEnd; lp++, outPos++) { + out.set(PTypeIO.LE_INT, outPos * 4, 
rawValue); + } + logicalPos = runEnd; + } + } + + private static void expandLong(MemorySegment endsSeg, MemorySegment valuesSeg, + PType endsPtype, long numRuns, long offset, long n, MemorySegment out) { + long endsCap = SegmentBroadcast.capacity(endsSeg, endsPtype.byteSize()); + long valCap = SegmentBroadcast.capacity(valuesSeg, 8); + long logicalPos = 0L, outPos = 0L; + for (long run = 0; run < numRuns && outPos < n; run++) { + long runEnd = readUnsigned(endsSeg, run % endsCap, endsPtype); + long rawValue = valuesSeg.get(PTypeIO.LE_LONG, (run % valCap) * 8); + long writeEnd = Math.min(runEnd, offset + n); + for (long lp = Math.max(logicalPos, offset); lp < writeEnd; lp++, outPos++) { + out.set(PTypeIO.LE_LONG, outPos * 8, rawValue); + } + logicalPos = runEnd; + } + } + + private static Array expandBool( + Array endsArr, BoolArray valuesArr, + PType endsPtype, long numRuns, long offset, long n, + DType dtype, SegmentAllocator arena + ) { + MemorySegment endsSeg = ArraySegments.of(endsArr); + long endsCap = SegmentBroadcast.capacity(endsSeg, endsPtype.byteSize()); + long numBytes = (n + 7) >>> 3; + MemorySegment out = arena.allocate(numBytes); + + long outIdx = 0; + long logicalPos = 0; + for (long run = 0; run < numRuns && outIdx < n; run++) { + long runEnd = readUnsigned(endsSeg, run % endsCap, endsPtype); + boolean val = valuesArr.getBoolean(run); + long lo = Math.max(logicalPos, offset); + long hi = Math.min(runEnd, offset + n); + for (long lp = lo; lp < hi; lp++, outIdx++) { + if (val) { + long byteIdx = outIdx >>> 3; + byte cur = out.get(ValueLayout.JAVA_BYTE, byteIdx); + out.set(ValueLayout.JAVA_BYTE, byteIdx, (byte) (cur | (1 << (outIdx & 7)))); + } + } + logicalPos = runEnd; + } + return new BoolArray(dtype, n, out.asReadOnly()); + } + + private static Array expandStrings( + Array endsArr, VarBinArray valuesArr, + PType endsPtype, long numRuns, long offset, long n, + DType dtype, SegmentAllocator arena + ) { + MemorySegment endsSeg = 
ArraySegments.of(endsArr); + long endsCap = SegmentBroadcast.capacity(endsSeg, endsPtype.byteSize()); + MemorySegment valBytes = valuesArr.bytesSegment(); + MemorySegment valOffsets = valuesArr.offsetsSegment(); + PType valOffPtype = valuesArr.offsetsPtype(); + + long totalBytes = 0; + long logicalPos = 0; + for (long run = 0; run < numRuns; run++) { + long runEnd = readUnsigned(endsSeg, run % endsCap, endsPtype); + long lo = Math.max(logicalPos, offset); + long hi = Math.min(runEnd, offset + n); + long count = Math.max(0, hi - lo); + long strLen = readVarBinOffset(valOffsets, run + 1, valOffPtype) + - readVarBinOffset(valOffsets, run, valOffPtype); + totalBytes += count * strLen; + logicalPos = runEnd; + } + + MemorySegment outBytes = arena.allocate(totalBytes > 0 ? totalBytes : 1); + MemorySegment outOffsets = arena.allocate((n + 1) * 4L, 4); + outOffsets.setAtIndex(PTypeIO.LE_INT, 0, 0); + + long bytePos = 0; + long outIdx = 0; + logicalPos = 0; + for (long run = 0; run < numRuns && outIdx < n; run++) { + long runEnd = readUnsigned(endsSeg, run % endsCap, endsPtype); + long lo = Math.max(logicalPos, offset); + long hi = Math.min(runEnd, offset + n); + if (hi > lo) { + long strStart = readVarBinOffset(valOffsets, run, valOffPtype); + long strEnd = readVarBinOffset(valOffsets, run + 1, valOffPtype); + long strLen = strEnd - strStart; + for (long lp = lo; lp < hi; lp++, outIdx++) { + if (strLen > 0) { + MemorySegment.copy(valBytes, strStart, outBytes, bytePos, strLen); + bytePos += strLen; + } + outOffsets.setAtIndex(PTypeIO.LE_INT, outIdx + 1, (int) bytePos); + } + } + logicalPos = runEnd; + } + + return new VarBinArray(dtype, n, outBytes.asReadOnly(), outOffsets.asReadOnly(), PType.I32); + } + + private static long readUnsigned(MemorySegment seg, long i, PType ptype) { + return switch (ptype) { + case U8 -> Byte.toUnsignedLong(seg.get(ValueLayout.JAVA_BYTE, i)); + case U16 -> Short.toUnsignedLong(seg.get(PTypeIO.LE_SHORT, i * 2)); + case U32 -> 
Integer.toUnsignedLong(seg.get(PTypeIO.LE_INT, i * 4)); + case U64 -> seg.get(PTypeIO.LE_LONG, i * 8); + default -> throw new VortexException(EncodingId.VORTEX_RUNEND, "non-unsigned ends ptype " + ptype); + }; + } + + private static long readVarBinOffset(MemorySegment seg, long i, PType ptype) { + return switch (ptype) { + case I32, U32 -> Integer.toUnsignedLong(seg.getAtIndex(PTypeIO.LE_INT, i)); + case I64, U64 -> seg.getAtIndex(PTypeIO.LE_LONG, i); + default -> throw new VortexException(EncodingId.VORTEX_RUNEND, "unsupported offset ptype " + ptype); + }; + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/SequenceEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/SequenceEncodingDecoder.java new file mode 100644 index 00000000..b06f80a8 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/SequenceEncodingDecoder.java @@ -0,0 +1,142 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.Float16Array; +import io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.ScalarValue; +import io.github.dfa1.vortex.proto.SequenceMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SegmentAllocator; +import java.lang.foreign.ValueLayout; 
+import java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.sequence} — {@code A[i] = base + i * multiplier}. +public final class SequenceEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public SequenceEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_SEQUENCE; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer metaBuf = ctx.metadata(); + if (metaBuf == null || !metaBuf.hasRemaining()) { + throw new VortexException(EncodingId.VORTEX_SEQUENCE, "missing metadata"); + } + SequenceMetadata meta; + try { + MemorySegment seg = MemorySegment.ofBuffer(metaBuf.duplicate()); + meta = SequenceMetadata.decode(seg, 0, seg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_SEQUENCE, "invalid metadata", e); + } + + if (!(ctx.dtype() instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.VORTEX_SEQUENCE, "expected primitive dtype, got " + ctx.dtype()); + } + + long n = ctx.rowCount(); + PType pt = p.ptype(); + return switch (pt) { + case I8, I16, I32, I64, U8, U16, U32, U64 -> decodeInteger(meta, pt, n, ctx.dtype(), ctx.arena()); + case F32 -> decodeF32(meta, n, ctx.dtype(), ctx.arena()); + case F64 -> decodeF64(meta, n, ctx.dtype(), ctx.arena()); + case F16 -> decodeF16(meta, n, ctx.dtype(), ctx.arena()); + }; + } + + private static Array decodeInteger( + SequenceMetadata meta, PType pt, long n, DType dtype, SegmentAllocator arena + ) { + long base = signedValue(meta.base()); + long mul = signedValue(meta.multiplier()); + int elemBytes = pt.byteSize(); + MemorySegment seg = arena.allocate(n * elemBytes); + for (long i = 0; i < n; i++) { + long v = base + i * mul; + switch (pt) { + case I8, U8 -> seg.set(ValueLayout.JAVA_BYTE, i, (byte) v); + case I16, U16 -> 
seg.setAtIndex(PTypeIO.LE_SHORT, i, (short) v); + case I32, U32 -> seg.setAtIndex(PTypeIO.LE_INT, i, (int) v); + case I64, U64 -> seg.setAtIndex(PTypeIO.LE_LONG, i, v); + default -> throw new IllegalStateException("unreachable"); + } + } + return switch (pt) { + case I64, U64 -> new LongArray(dtype, n, seg); + case I32, U32 -> new IntArray(dtype, n, seg); + case I16, U16 -> new ShortArray(dtype, n, seg); + case I8, U8 -> new ByteArray(dtype, n, seg); + default -> throw new VortexException(EncodingId.VORTEX_SEQUENCE, "unsupported ptype " + pt); + }; + } + + private static Array decodeF32(SequenceMetadata meta, long n, DType dtype, SegmentAllocator arena) { + float base = meta.base().f32_value(); + float mul = meta.multiplier().f32_value(); + MemorySegment seg = arena.allocate(n * 4L); + for (long i = 0; i < n; i++) { + seg.setAtIndex(PTypeIO.LE_FLOAT, i, base + i * mul); + } + return new FloatArray(dtype, n, seg); + } + + private static Array decodeF64(SequenceMetadata meta, long n, DType dtype, SegmentAllocator arena) { + double base = meta.base().f64_value(); + double mul = meta.multiplier().f64_value(); + MemorySegment seg = arena.allocate(n * 8L); + for (long i = 0; i < n; i++) { + seg.setAtIndex(PTypeIO.LE_DOUBLE, i, base + i * mul); + } + return new DoubleArray(dtype, n, seg); + } + + private static Array decodeF16(SequenceMetadata meta, long n, DType dtype, SegmentAllocator arena) { + short baseShort = (short) meta.base().f16_value().longValue(); + short mulShort = (short) meta.multiplier().f16_value().longValue(); + float base = Float.float16ToFloat(baseShort); + float mul = Float.float16ToFloat(mulShort); + MemorySegment seg = arena.allocate(n * 2L); + for (long i = 0; i < n; i++) { + seg.setAtIndex(PTypeIO.LE_SHORT, i, Float.floatToFloat16(base + i * mul)); + } + return new Float16Array(dtype, n, seg); + } + + private static long signedValue(ScalarValue sv) { + if (sv == null) { + return 0L; + } + if (sv.int64_value() != null) { + return sv.int64_value(); 
+ } + if (sv.uint64_value() != null) { + return sv.uint64_value(); + } + return 0L; + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/SparseEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/SparseEncodingDecoder.java new file mode 100644 index 00000000..15a80c74 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/SparseEncodingDecoder.java @@ -0,0 +1,261 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; +import io.github.dfa1.vortex.proto.PatchesMetadata; +import io.github.dfa1.vortex.proto.ScalarValue; +import io.github.dfa1.vortex.proto.SparseMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/// Read-only decoder for {@code vortex.sparse}. +public final class SparseEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public SparseEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_SPARSE; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + if (rawMeta == null || !rawMeta.hasRemaining()) { + throw new VortexException(EncodingId.VORTEX_SPARSE, "missing metadata"); + } + SparseMetadata sparseMeta; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + sparseMeta = SparseMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_SPARSE, "invalid metadata", e); + } + + PatchesMetadata patches = sparseMeta.patches(); + long numPatches = patches.len(); + long offset = patches.offset(); + PType indicesPtype = PType.fromOrdinal(patches.indices_ptype().value()); + + long n = ctx.rowCount(); + + if (ctx.dtype() instanceof DType.Utf8 || ctx.dtype() instanceof DType.Binary) { + return decodeVarBin(ctx, n, numPatches, offset, indicesPtype); + } + + if (ctx.dtype() instanceof DType.Bool) { + return decodeBool(ctx, n, numPatches, offset, indicesPtype); + } + + if (!(ctx.dtype() instanceof DType.Primitive)) { + throw new VortexException(EncodingId.VORTEX_SPARSE, "expected primitive dtype, got " + ctx.dtype()); + } + PType valuePtype = ((DType.Primitive) ctx.dtype()).ptype(); + + MemorySegment fillBuf = ctx.buffer(0); + ScalarValue fillScalar; + try { + fillScalar = ScalarValue.decode(fillBuf, 0, fillBuf.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_SPARSE, "invalid fill value", e); + } + + int elemBytes = valuePtype.byteSize(); + MemorySegment out = ctx.arena().allocate(n * elemBytes); + fillSegment(out, n, valuePtype, fillScalar); + + if (numPatches > 0) { + DType indicesDtype = new DType.Primitive(indicesPtype, false); + applyPatches(out, n, valuePtype, + 
ctx.decodeChildSegment(0, indicesDtype, numPatches), + ctx.decodeChildSegment(1, ctx.dtype(), numPatches), + indicesPtype, numPatches, offset); + } + + return switch (valuePtype) { + case I64, U64 -> new LongArray(ctx.dtype(), n, out); + case I32, U32 -> new IntArray(ctx.dtype(), n, out); + case F64 -> new DoubleArray(ctx.dtype(), n, out); + case F32 -> new FloatArray(ctx.dtype(), n, out); + case I16, U16 -> new ShortArray(ctx.dtype(), n, out); + case I8, U8 -> new ByteArray(ctx.dtype(), n, out); + default -> throw new VortexException(EncodingId.VORTEX_SPARSE, "unsupported ptype " + valuePtype); + }; + } + + private static Array decodeBool( + DecodeContext ctx, long n, long numPatches, long offset, PType indicesPtype + ) { + long numBytes = (n + 7) >>> 3; + MemorySegment out = ctx.arena().allocate(numBytes); + if (numPatches > 0) { + DType indicesDtype = new DType.Primitive(indicesPtype, false); + MemorySegment idxSeg = ctx.decodeChildSegment(0, indicesDtype, numPatches); + BoolArray bools = (BoolArray) ctx.decodeChild(1, ctx.dtype(), numPatches); + int idxBytes = indicesPtype.byteSize(); + for (long i = 0; i < numPatches; i++) { + if (bools.getBoolean(i)) { + long pos = readUnsignedIdx(idxSeg, SegmentBroadcast.elementOffset(idxSeg, i, idxBytes), indicesPtype) - offset; + long byteIdx = pos >>> 3; + byte cur = out.get(ValueLayout.JAVA_BYTE, byteIdx); + out.set(ValueLayout.JAVA_BYTE, byteIdx, (byte) (cur | (1 << (pos & 7)))); + } + } + } + return new BoolArray(ctx.dtype(), n, out); + } + + private static Array decodeVarBin( + DecodeContext ctx, long n, long numPatches, long offset, PType indicesPtype + ) { + MemorySegment outOffsets = ctx.arena().allocate((n + 1) * 4L, 4); + if (numPatches == 0) { + MemorySegment outBytes = ctx.arena().allocate(1); + return new VarBinArray(ctx.dtype(), n, outBytes, outOffsets, PType.I32); + } + + DType indicesDtype = new DType.Primitive(indicesPtype, false); + MemorySegment idxSeg = ctx.decodeChildSegment(0, indicesDtype, 
numPatches); + VarBinArray varBin = (VarBinArray) ctx.decodeChild(1, ctx.dtype(), numPatches); + MemorySegment valBytes = varBin.bytesSegment(); + MemorySegment valOffsets = varBin.offsetsSegment(); + PType valOffPtype = varBin.offsetsPtype(); + + int idxBytes = indicesPtype.byteSize(); + long totalBytes = 0; + for (long i = 0; i < numPatches; i++) { + totalBytes += readVarBinOffset(valOffsets, i + 1, valOffPtype) + - readVarBinOffset(valOffsets, i, valOffPtype); + } + + MemorySegment outBytes = ctx.arena().allocate(Math.max(1, totalBytes)); + long patchCursor = 0; + long bytePos = 0; + for (long pos = 0; pos < n; pos++) { + if (patchCursor < numPatches) { + long patchPos = readUnsignedIdx(idxSeg, SegmentBroadcast.elementOffset(idxSeg, patchCursor, idxBytes), indicesPtype) - offset; + if (patchPos == pos) { + long strStart = readVarBinOffset(valOffsets, patchCursor, valOffPtype); + long strEnd = readVarBinOffset(valOffsets, patchCursor + 1, valOffPtype); + long strLen = strEnd - strStart; + if (strLen > 0) { + MemorySegment.copy(valBytes, strStart, outBytes, bytePos, strLen); + bytePos += strLen; + } + patchCursor++; + } + } + outOffsets.setAtIndex(PTypeIO.LE_INT, pos + 1, (int) bytePos); + } + + return new VarBinArray(ctx.dtype(), n, outBytes, outOffsets, PType.I32); + } + + private static long readVarBinOffset(MemorySegment seg, long i, PType ptype) { + return switch (ptype) { + case I32, U32 -> Integer.toUnsignedLong(seg.getAtIndex(PTypeIO.LE_INT, i)); + case I64, U64 -> seg.getAtIndex(PTypeIO.LE_LONG, i); + default -> throw new VortexException(EncodingId.VORTEX_SPARSE, "unsupported offset ptype " + ptype); + }; + } + + private static void fillSegment(MemorySegment out, long n, PType ptype, ScalarValue scalar) { + long fillLong = scalarToLong(scalar); + ByteBuffer bb = out.asByteBuffer().order(ByteOrder.LITTLE_ENDIAN); + for (long i = 0; i < n; i++) { + writeElem(bb, ptype, fillLong); + } + } + + private static void applyPatches( + MemorySegment out, long n, 
PType valuePtype, + MemorySegment idxSeg, MemorySegment valSeg, + PType idxPtype, long numPatches, long offset + ) { + int elemBytes = valuePtype.byteSize(); + int idxBytes = idxPtype.byteSize(); + ByteBuffer outBuf = out.asByteBuffer().order(ByteOrder.LITTLE_ENDIAN); + for (long i = 0; i < numPatches; i++) { + long idx = readUnsignedIdx(idxSeg, SegmentBroadcast.elementOffset(idxSeg, i, idxBytes), idxPtype) - offset; + if (idx < 0 || idx >= n) { + throw new VortexException(EncodingId.VORTEX_SPARSE, + "patch index " + idx + " out of range [0," + n + ")"); + } + long val = readElem(valSeg, SegmentBroadcast.elementOffset(valSeg, i, elemBytes), valuePtype); + outBuf.position((int) (idx * elemBytes)); + writeElem(outBuf, valuePtype, val); + } + } + + private static long readUnsignedIdx(MemorySegment seg, long off, PType ptype) { + return switch (ptype) { + case U8 -> Byte.toUnsignedLong(seg.get(ValueLayout.JAVA_BYTE, off)); + case U16 -> Short.toUnsignedLong(seg.get(PTypeIO.LE_SHORT, off)); + case U32 -> Integer.toUnsignedLong(seg.get(PTypeIO.LE_INT, off)); + case U64 -> seg.get(PTypeIO.LE_LONG, off); + default -> throw new VortexException(EncodingId.VORTEX_SPARSE, "non-unsigned index ptype " + ptype); + }; + } + + private static long readElem(MemorySegment seg, long off, PType ptype) { + return switch (ptype) { + case I8, U8 -> Byte.toUnsignedLong(seg.get(ValueLayout.JAVA_BYTE, off)); + case I16, U16 -> Short.toUnsignedLong(seg.get(PTypeIO.LE_SHORT, off)); + case I32, U32, F32 -> Integer.toUnsignedLong(seg.get(PTypeIO.LE_INT, off)); + case I64, U64, F64 -> seg.get(PTypeIO.LE_LONG, off); + default -> throw new UnsupportedOperationException("vortex.sparse: unsupported ptype " + ptype); + }; + } + + private static void writeElem(ByteBuffer bb, PType ptype, long bits) { + switch (ptype) { + case I8, U8 -> bb.put((byte) bits); + case I16, U16 -> bb.putShort((short) bits); + case I32, U32, F32 -> bb.putInt((int) bits); + case I64, U64, F64 -> bb.putLong(bits); + default -> 
throw new UnsupportedOperationException("vortex.sparse: unsupported ptype " + ptype); + } + } + + private static long scalarToLong(ScalarValue scalar) { + if (scalar.int64_value() != null) { + return scalar.int64_value(); + } + if (scalar.uint64_value() != null) { + return scalar.uint64_value(); + } + if (scalar.f32_value() != null) { + return Float.floatToRawIntBits(scalar.f32_value()); + } + if (scalar.f64_value() != null) { + return Double.doubleToRawLongBits(scalar.f64_value()); + } + return 0L; + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/StructEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/StructEncodingDecoder.java new file mode 100644 index 00000000..f0d2ba14 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/StructEncodingDecoder.java @@ -0,0 +1,108 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.core.array.StructArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; + +import java.util.ArrayList; +import java.util.List; + +/// Read-only decoder for {@code vortex.struct}. +public final class StructEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public StructEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_STRUCT; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Struct; + } + + @Override + public Array decode(DecodeContext ctx) { + int numChildren = ctx.node().children().length; + + if (ctx.dtype() instanceof DType.Struct structDtype) { + int nfields = structDtype.fieldTypes().size(); + if (numChildren != nfields && numChildren != nfields + 1) { + throw new VortexException(EncodingId.VORTEX_STRUCT, + "expected %d or %d children for struct dtype, got %d" + .formatted(nfields, nfields + 1, numChildren)); + } + boolean hasValidity = (numChildren == nfields + 1); + int fieldOffset = hasValidity ? 1 : 0; + + BoolArray structValidity = null; + if (hasValidity) { + ArrayNode validityNode = ctx.node().children()[0]; + var validityCtx = new DecodeContext(validityNode, new DType.Bool(false), + ctx.rowCount(), ctx.segmentBuffers(), ctx.registry(), ctx.arena()); + Array va = ctx.registry().decode(validityCtx); + if (!(va instanceof BoolArray ba)) { + throw new VortexException(EncodingId.VORTEX_STRUCT, + "struct validity decoded to unexpected type: " + va.getClass().getSimpleName()); + } + structValidity = ba; + } + + if (nfields == 1) { + DType fieldDtype = structDtype.fieldTypes().getFirst(); + ArrayNode fieldNode = ctx.node().children()[fieldOffset]; + var fieldCtx = new DecodeContext(fieldNode, fieldDtype.withNullable(false), + ctx.rowCount(), ctx.segmentBuffers(), ctx.registry(), ctx.arena()); + Array field = ctx.registry().decode(fieldCtx); + return structValidity != null ? 
new MaskedArray(field, structValidity) : field; + } + + List<Array> fieldArrays = new ArrayList<>(nfields); + for (int i = 0; i < nfields; i++) { + ArrayNode fieldNode = ctx.node().children()[fieldOffset + i]; + DType fieldDtype = structDtype.fieldTypes().get(i); + var fieldCtx = new DecodeContext(fieldNode, fieldDtype.withNullable(false), + ctx.rowCount(), ctx.segmentBuffers(), ctx.registry(), ctx.arena()); + Array field = ctx.registry().decode(fieldCtx); + fieldArrays.add(structValidity != null ? new MaskedArray(field, structValidity) : field); + } + return new StructArray(structDtype, ctx.rowCount(), fieldArrays); + } + + if (numChildren == 1) { + ArrayNode valuesNode = ctx.node().children()[0]; + var valuesCtx = new DecodeContext( + valuesNode, ctx.dtype(), ctx.rowCount(), + ctx.segmentBuffers(), ctx.registry(), ctx.arena()); + return ctx.registry().decode(valuesCtx); + } else if (numChildren == 2) { + ArrayNode validityNode = ctx.node().children()[0]; + var validityCtx = new DecodeContext(validityNode, new DType.Bool(false), + ctx.rowCount(), ctx.segmentBuffers(), ctx.registry(), ctx.arena()); + Array va = ctx.registry().decode(validityCtx); + if (!(va instanceof BoolArray validity)) { + throw new VortexException(EncodingId.VORTEX_STRUCT, + "scalar wrapper validity decoded to unexpected type: " + va.getClass().getSimpleName()); + } + ArrayNode valuesNode = ctx.node().children()[1]; + var valuesCtx = new DecodeContext( + valuesNode, ctx.dtype().withNullable(false), ctx.rowCount(), + ctx.segmentBuffers(), ctx.registry(), ctx.arena()); + Array values = ctx.registry().decode(valuesCtx); + return new MaskedArray(values, validity); + } else { + throw new VortexException(EncodingId.VORTEX_STRUCT, + "unexpected child count " + numChildren + " for scalar wrapper"); + } + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/VarBinEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/VarBinEncodingDecoder.java new file mode 100644
index 00000000..46b424bb --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/VarBinEncodingDecoder.java @@ -0,0 +1,67 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; +import io.github.dfa1.vortex.proto.VarBinMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.varbin}. +public final class VarBinEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public VarBinEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_VARBIN; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Utf8 || dtype instanceof DType.Binary; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + if (rawMeta == null) { + throw new VortexException(EncodingId.VORTEX_VARBIN, "missing metadata"); + } + VarBinMetadata meta; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + meta = VarBinMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_VARBIN, "invalid metadata", e); + } + + PType offsetsPtype = PType.fromOrdinal(meta.offsets_ptype().value()); + DType offsetsDtype = new DType.Primitive(offsetsPtype, false); + long n = ctx.rowCount(); + + MemorySegment offsets = ctx.decodeChildSegment(0, offsetsDtype, n + 1); + + int offBytes = offsetsPtype.byteSize(); + long offCap = SegmentBroadcast.capacity(offsets, offBytes); + if (offCap < n + 1) { + MemorySegment materialized = ctx.arena().allocate((n + 1) * (long) offBytes, offBytes); + SegmentBroadcast.broadcastCopy(offsets, materialized, n + 1, offBytes); + offsets = materialized; + } + + MemorySegment bytes = ctx.buffer(0); + + return new VarBinArray(ctx.dtype(), n, bytes, offsets, offsetsPtype); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/VarBinViewEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/VarBinViewEncodingDecoder.java new file mode 100644 index 00000000..c7840d4c --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/VarBinViewEncodingDecoder.java @@ -0,0 +1,83 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import 
io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; + +import java.lang.foreign.MemorySegment; + +/// Read-only decoder for {@code vortex.varbinview} (Apache Arrow StringView/BinaryView). +public final class VarBinViewEncodingDecoder implements EncodingDecoder { + + private static final int MAX_INLINED_SIZE = 12; + private static final int VIEW_SIZE = 16; + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public VarBinViewEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_VARBINVIEW; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Utf8 || dtype instanceof DType.Binary; + } + + @Override + public Array decode(DecodeContext ctx) { + if (!(ctx.dtype() instanceof DType.Utf8 || ctx.dtype() instanceof DType.Binary)) { + throw new VortexException(EncodingId.VORTEX_VARBINVIEW, + "expected Utf8/Binary dtype, got " + ctx.dtype()); + } + + int numBufs = ctx.node().bufferIndices().length; + if (numBufs < 1) { + throw new VortexException(EncodingId.VORTEX_VARBINVIEW, + "expected at least 1 buffer (views), got 0"); + } + + MemorySegment viewsBuf = ctx.buffer(numBufs - 1); + MemorySegment[] dataBufs = new MemorySegment[numBufs - 1]; + for (int i = 0; i < dataBufs.length; i++) { + dataBufs[i] = ctx.buffer(i); + } + + long n = ctx.rowCount(); + + long totalBytes = 0; + for (long i = 0; i < n; i++) { + long size = Integer.toUnsignedLong(viewsBuf.get(PTypeIO.LE_INT, i * VIEW_SIZE)); + totalBytes += size; + } + + MemorySegment outBytes = ctx.arena().allocate(totalBytes > 0 ? 
totalBytes : 1); + MemorySegment outOffsets = ctx.arena().allocate((n + 1) * Long.BYTES, Long.BYTES); + + long bytePos = 0; + outOffsets.setAtIndex(PTypeIO.LE_LONG, 0, 0L); + for (long i = 0; i < n; i++) { + long viewOff = i * VIEW_SIZE; + long size = Integer.toUnsignedLong(viewsBuf.get(PTypeIO.LE_INT, viewOff)); + if (size <= MAX_INLINED_SIZE) { + MemorySegment.copy(viewsBuf, viewOff + 4, outBytes, bytePos, size); + } else { + int bufferIndex = viewsBuf.get(PTypeIO.LE_INT, viewOff + 8); + long srcOffset = Integer.toUnsignedLong(viewsBuf.get(PTypeIO.LE_INT, viewOff + 12)); + MemorySegment.copy(dataBufs[bufferIndex], srcOffset, outBytes, bytePos, size); + } + bytePos += size; + outOffsets.setAtIndex(PTypeIO.LE_LONG, i + 1, bytePos); + } + + return new VarBinArray(ctx.dtype(), n, outBytes.asReadOnly(), outOffsets, PType.I64); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/VariantEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/VariantEncodingDecoder.java new file mode 100644 index 00000000..4c7ec102 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/VariantEncodingDecoder.java @@ -0,0 +1,129 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.VariantArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.proto.VariantMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +/// Read-only decoder for {@code vortex.variant}. 
+public final class VariantEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public VariantEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_VARIANT; + } + + @Override + public boolean accepts(DType dtype) { + return false; + } + + @Override + public Array decode(DecodeContext ctx) { + DType shreddedDtype = parseShreddedDtype(ctx.metadata()); + + int numChildren = ctx.node().children().length; + if (numChildren < 1 || numChildren > 2) { + throw new VortexException(EncodingId.VORTEX_VARIANT, + "expected 1 or 2 children, got " + numChildren); + } + + Array coreStorage = ctx.decodeChild(0, ctx.dtype(), ctx.rowCount()); + + Array shredded = null; + if (shreddedDtype != null && numChildren >= 2) { + shredded = ctx.decodeChild(1, shreddedDtype, ctx.rowCount()); + } + + return new VariantArray(ctx.dtype(), ctx.rowCount(), coreStorage, shredded); + } + + private static DType parseShreddedDtype(ByteBuffer rawMeta) { + if (rawMeta == null || !rawMeta.hasRemaining()) { + return null; + } + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + VariantMetadata meta = VariantMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + if (meta.shredded_dtype() == null) { + return null; + } + return dtypeFromProto(meta.shredded_dtype()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_VARIANT, "invalid metadata", e); + } + } + + static DType dtypeFromProto(io.github.dfa1.vortex.proto.DType proto) { + if (proto.null_() != null) { + return new DType.Null(true); + } + if (proto.bool() != null) { + return new DType.Bool(proto.bool().nullable()); + } + if (proto.primitive() != null) { + return new DType.Primitive( + PType.values()[proto.primitive().type().value()], + proto.primitive().nullable()); + } + if (proto.decimal() != null) { + return new DType.Decimal( + (byte) proto.decimal().precision(), + (byte) 
proto.decimal().scale(), + proto.decimal().nullable()); + } + if (proto.utf8() != null) { + return new DType.Utf8(proto.utf8().nullable()); + } + if (proto.binary() != null) { + return new DType.Binary(proto.binary().nullable()); + } + if (proto.struct() != null) { + var s = proto.struct(); + var names = new ArrayList<String>(s.names().size()); + var types = new ArrayList<DType>(s.dtypes().size()); + names.addAll(s.names()); + for (io.github.dfa1.vortex.proto.DType child : s.dtypes()) { + types.add(dtypeFromProto(child)); + } + return new DType.Struct(List.copyOf(names), List.copyOf(types), s.nullable()); + } + if (proto.list() != null) { + return new DType.List( + dtypeFromProto(proto.list().element_type()), + proto.list().nullable()); + } + if (proto.fixed_size_list() != null) { + return new DType.FixedSizeList( + dtypeFromProto(proto.fixed_size_list().element_type()), + proto.fixed_size_list().size(), + proto.fixed_size_list().nullable()); + } + if (proto.extension() != null) { + return new DType.Extension( + proto.extension().id(), + dtypeFromProto(proto.extension().storage_dtype()), + ByteBuffer.wrap(proto.extension().metadata() != null ? 
proto.extension().metadata() : new byte[0]).asReadOnlyBuffer(), + false); + } + if (proto.variant() != null) { + return new DType.Variant(proto.variant().nullable()); + } + throw new VortexException("unsupported proto DType"); + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ZigZagEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ZigZagEncodingDecoder.java new file mode 100644 index 00000000..18f6f678 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ZigZagEncodingDecoder.java @@ -0,0 +1,97 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.SegmentBroadcast; + +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; + +/// Read-only decoder for {@code vortex.zigzag} — zigzag-decoded signed integers. +public final class ZigZagEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public ZigZagEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_ZIGZAG; + } + + @Override + public boolean accepts(DType dtype) { + if (!(dtype instanceof DType.Primitive p)) { + return false; + } + PType pt = p.ptype(); + return pt == PType.I8 || pt == PType.I16 || pt == PType.I32 || pt == PType.I64; + } + + @Override + public Array decode(DecodeContext ctx) { + if (!(ctx.dtype() instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.VORTEX_ZIGZAG, "expected primitive dtype, got " + ctx.dtype()); + } + PType signed = p.ptype(); + PType unsigned = toUnsigned(signed); + long n = ctx.rowCount(); + + MemorySegment src = ctx.decodeChildSegment(0, new DType.Primitive(unsigned, false), n); + int elemBytes = signed.byteSize(); + long srcCap = SegmentBroadcast.capacity(src, elemBytes); + MemorySegment dst = ctx.arena().allocate(n * elemBytes); + + return switch (signed) { + case I8 -> { + for (long i = 0; i < n; i++) { + int u = Byte.toUnsignedInt(src.get(ValueLayout.JAVA_BYTE, i % srcCap)); + dst.set(ValueLayout.JAVA_BYTE, i, (byte) ((u >>> 1) ^ -(u & 1))); + } + yield new ByteArray(ctx.dtype(), n, dst); + } + case I16 -> { + for (long i = 0; i < n; i++) { + int u = Short.toUnsignedInt(src.get(PTypeIO.LE_SHORT, (i % srcCap) * 2)); + dst.set(PTypeIO.LE_SHORT, i * 2, (short) ((u >>> 1) ^ -(u & 1))); + } + yield new ShortArray(ctx.dtype(), n, dst); + } + case I32 -> { + for (long i = 0; i < n; i++) { + int u = src.get(PTypeIO.LE_INT, (i % srcCap) * 4); + dst.set(PTypeIO.LE_INT, i * 4, (u >>> 1) ^ -(u & 1)); + } + yield new IntArray(ctx.dtype(), n, dst); + } + case I64 -> { + for (long i = 0; i < n; i++) { + long u = src.get(PTypeIO.LE_LONG, (i % srcCap) * 8); + dst.set(PTypeIO.LE_LONG, i * 8, (u >>> 1) ^ -(u & 1)); + } + yield new LongArray(ctx.dtype(), n, dst); + } + default -> throw new VortexException(EncodingId.VORTEX_ZIGZAG, "unreachable"); + }; + } + + private static PType toUnsigned(PType signed) { + 
return switch (signed) { + case I8 -> PType.U8; + case I16 -> PType.U16; + case I32 -> PType.U32; + case I64 -> PType.U64; + default -> throw new VortexException(EncodingId.VORTEX_ZIGZAG, "not a signed integer: " + signed); + }; + } +} diff --git a/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ZstdEncodingDecoder.java b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ZstdEncodingDecoder.java new file mode 100644 index 00000000..806a1944 --- /dev/null +++ b/reader/src/main/java/io/github/dfa1/vortex/reader/decode/ZstdEncodingDecoder.java @@ -0,0 +1,258 @@ +package io.github.dfa1.vortex.reader.decode; + +import com.github.luben.zstd.ZstdDecompressCtx; +import io.airlift.compress.v3.zstd.ZstdDecompressor; +import io.airlift.compress.v3.zstd.ZstdJavaDecompressor; +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.core.array.ByteArray; +import io.github.dfa1.vortex.core.array.DoubleArray; +import io.github.dfa1.vortex.core.array.Float16Array; +import io.github.dfa1.vortex.core.array.FloatArray; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingDecoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.ZstdMetadata; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; + +/// Read-only decoder for {@code vortex.zstd}. 
+public final class ZstdEncodingDecoder implements EncodingDecoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public ZstdEncodingDecoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_ZSTD; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive || dtype instanceof DType.Utf8 || dtype instanceof DType.Binary; + } + + @Override + public Array decode(DecodeContext ctx) { + ByteBuffer rawMeta = ctx.metadata(); + if (rawMeta == null) { + throw new VortexException(EncodingId.VORTEX_ZSTD, "missing metadata"); + } + ZstdMetadata meta; + try { + MemorySegment metaSeg = MemorySegment.ofBuffer(rawMeta.duplicate()); + meta = ZstdMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + } catch (IOException e) { + throw new VortexException(EncodingId.VORTEX_ZSTD, "invalid metadata", e); + } + boolean hasDictionary = meta.dictionary_size() != 0; + + BoolArray validity = null; + if (ctx.node().children().length > 0) { + Array validityArray = ctx.decodeChild(0, new DType.Bool(false), ctx.rowCount()); + if (!(validityArray instanceof BoolArray ba)) { + throw new VortexException(EncodingId.VORTEX_ZSTD, + "validity child decoded to unexpected type: " + validityArray.getClass().getSimpleName()); + } + validity = ba; + } + + int frameCount = meta.frames().size(); + long totalUncompressed = 0; + for (int i = 0; i < frameCount; i++) { + totalUncompressed += meta.frames().get(i).uncompressed_size(); + } + + MemorySegment decompressed = hasDictionary + ? 
decompressFramesWithDict(ctx, meta, frameCount, totalUncompressed) + : decompressFrames(ctx, meta, frameCount, totalUncompressed); + + if (validity == null) { + return buildArray(ctx.dtype(), ctx.rowCount(), decompressed, ctx); + } else { + return buildNullableArray(ctx.dtype(), ctx.rowCount(), decompressed, validity, ctx); + } + } + + private static Array buildNullableArray( + DType dtype, long rowCount, MemorySegment validValues, BoolArray validity, DecodeContext ctx + ) { + Array child; + if (dtype instanceof DType.Primitive dt) { + child = buildScatteredPrimitive(dt, rowCount, validValues, validity, ctx); + } else if (dtype instanceof DType.Utf8 || dtype instanceof DType.Binary) { + child = buildScatteredVarBin(dtype, rowCount, validValues, validity, ctx); + } else { + throw new VortexException(EncodingId.VORTEX_ZSTD, "unsupported nullable dtype: " + dtype); + } + return new MaskedArray(child, validity); + } + + private static Array buildScatteredPrimitive( + DType.Primitive dt, long rowCount, MemorySegment validValues, BoolArray validity, DecodeContext ctx + ) { + int byteSize = dt.ptype().byteSize(); + MemorySegment out = ctx.arena().allocate(rowCount * byteSize); + long readPos = 0; + for (long i = 0; i < rowCount; i++) { + if (validity.getBoolean(i)) { + MemorySegment.copy(validValues, readPos, out, i * byteSize, byteSize); + readPos += byteSize; + } + } + DType.Primitive nonNull = new DType.Primitive(dt.ptype(), false); + return buildPrimitive(nonNull, rowCount, out); + } + + private static VarBinArray buildScatteredVarBin( + DType dtype, long rowCount, MemorySegment validValues, BoolArray validity, DecodeContext ctx + ) { + long totalDataBytes = 0; + long scanPos = 0; + for (long i = 0; i < rowCount; i++) { + if (validity.getBoolean(i)) { + int len = validValues.get(PTypeIO.LE_INT, scanPos); + scanPos += 4L + len; + totalDataBytes += len; + } + } + + MemorySegment values = ctx.arena().allocate(totalDataBytes > 0 ? 
totalDataBytes : 1); + MemorySegment offsets = ctx.arena().allocate((rowCount + 1) * 4L, 4); + offsets.setAtIndex(PTypeIO.LE_INT, 0, 0); + + long readPos = 0; + long dataPos = 0; + for (long i = 0; i < rowCount; i++) { + if (validity.getBoolean(i)) { + int len = validValues.get(PTypeIO.LE_INT, readPos); + readPos += 4; + MemorySegment.copy(validValues, readPos, values, dataPos, len); + readPos += len; + dataPos += len; + } + offsets.setAtIndex(PTypeIO.LE_INT, i + 1, (int) dataPos); + } + + return new VarBinArray(dtype.withNullable(false), rowCount, values, offsets, PType.I32); + } + + private static MemorySegment decompressFramesWithDict( + DecodeContext ctx, + ZstdMetadata meta, + int frameCount, + long totalUncompressed + ) { + MemorySegment out = ctx.arena().allocate(totalUncompressed); + byte[] dictBytes = ctx.buffer(0).toArray(ValueLayout.JAVA_BYTE); + try (ZstdDecompressCtx zctx = new ZstdDecompressCtx()) { + zctx.loadDict(dictBytes); + long outOffset = 0; + for (int i = 0; i < frameCount; i++) { + byte[] compressed = ctx.buffer(i + 1).toArray(ValueLayout.JAVA_BYTE); + int uncompSize = (int) meta.frames().get(i).uncompressed_size(); + byte[] temp = new byte[uncompSize]; + int written = zctx.decompressByteArray(temp, 0, uncompSize, compressed, 0, compressed.length); + if (written != uncompSize) { + throw new VortexException(EncodingId.VORTEX_ZSTD, + "frame " + i + ": expected " + uncompSize + " bytes, got " + written); + } + MemorySegment.copy(MemorySegment.ofArray(temp), 0, out, outOffset, uncompSize); + outOffset += uncompSize; + } + } catch (VortexException e) { + throw e; + } catch (Exception e) { + throw new VortexException(EncodingId.VORTEX_ZSTD, "dict decompression failed", e); + } + return out; + } + + private static MemorySegment decompressFrames( + DecodeContext ctx, + ZstdMetadata meta, + int frameCount, + long totalUncompressed + ) { + MemorySegment out = ctx.arena().allocate(totalUncompressed); + ZstdDecompressor decompressor = new 
ZstdJavaDecompressor(); + long outOffset = 0; + for (int i = 0; i < frameCount; i++) { + MemorySegment frameSeg = ctx.buffer(i); + byte[] compressed = frameSeg.toArray(ValueLayout.JAVA_BYTE); + int uncompSize = (int) meta.frames().get(i).uncompressed_size(); + byte[] temp = new byte[uncompSize]; + int written = decompressor.decompress(compressed, 0, compressed.length, temp, 0, uncompSize); + if (written != uncompSize) { + throw new VortexException(EncodingId.VORTEX_ZSTD, + "frame " + i + ": expected " + uncompSize + " bytes, got " + written); + } + MemorySegment.copy(MemorySegment.ofArray(temp), 0, out, outOffset, uncompSize); + outOffset += uncompSize; + } + return out; + } + + private static Array buildArray(DType dtype, long n, MemorySegment decompressed, DecodeContext ctx) { + if (dtype instanceof DType.Primitive dt) { + return buildPrimitive(dt, n, decompressed); + } + if (dtype instanceof DType.Utf8 || dtype instanceof DType.Binary) { + return buildVarBin(dtype, n, decompressed, ctx); + } + throw new VortexException(EncodingId.VORTEX_ZSTD, "unsupported dtype: " + dtype); + } + + private static Array buildPrimitive(DType.Primitive dt, long n, MemorySegment decompressed) { + PType ptype = dt.ptype(); + return switch (ptype) { + case I64, U64 -> new LongArray(dt, n, decompressed); + case I32, U32 -> new IntArray(dt, n, decompressed); + case F64 -> new DoubleArray(dt, n, decompressed); + case F32 -> new FloatArray(dt, n, decompressed); + case I16, U16 -> new ShortArray(dt, n, decompressed); + case I8, U8 -> new ByteArray(dt, n, decompressed); + case F16 -> new Float16Array(dt, n, decompressed); + }; + } + + private static VarBinArray buildVarBin(DType dtype, long n, MemorySegment decompressed, DecodeContext ctx) { + long totalDataBytes = 0; + long pos = 0; + for (long i = 0; i < n; i++) { + int len = decompressed.get(PTypeIO.LE_INT, pos); + pos += 4 + len; + totalDataBytes += len; + } + + MemorySegment values = ctx.arena().allocate(totalDataBytes); + 
MemorySegment offsets = ctx.arena().allocate((n + 1) * 4L, 4); + offsets.setAtIndex(PTypeIO.LE_INT, 0, 0); + + pos = 0; + long dataPos = 0; + for (long i = 0; i < n; i++) { + int len = decompressed.get(PTypeIO.LE_INT, pos); + pos += 4; + MemorySegment.copy(decompressed, pos, values, dataPos, len); + pos += len; + dataPos += len; + offsets.setAtIndex(PTypeIO.LE_INT, i + 1, (int) dataPos); + } + + return new VarBinArray(dtype, n, values, offsets, PType.I32); + } +} diff --git a/reader/src/main/resources/META-INF/services/io.github.dfa1.vortex.encoding.EncodingDecoder b/reader/src/main/resources/META-INF/services/io.github.dfa1.vortex.encoding.EncodingDecoder new file mode 100644 index 00000000..781a8a16 --- /dev/null +++ b/reader/src/main/resources/META-INF/services/io.github.dfa1.vortex.encoding.EncodingDecoder @@ -0,0 +1,33 @@ +io.github.dfa1.vortex.reader.decode.AlpEncodingDecoder +io.github.dfa1.vortex.reader.decode.AlpRdEncodingDecoder +io.github.dfa1.vortex.reader.decode.BitpackedEncodingDecoder +io.github.dfa1.vortex.reader.decode.BoolEncodingDecoder +io.github.dfa1.vortex.reader.decode.ByteBoolEncodingDecoder +io.github.dfa1.vortex.reader.decode.ChunkedEncodingDecoder +io.github.dfa1.vortex.reader.decode.ConstantEncodingDecoder +io.github.dfa1.vortex.reader.decode.DateTimePartsEncodingDecoder +io.github.dfa1.vortex.reader.decode.DecimalBytePartsEncodingDecoder +io.github.dfa1.vortex.reader.decode.DecimalEncodingDecoder +io.github.dfa1.vortex.reader.decode.DeltaEncodingDecoder +io.github.dfa1.vortex.reader.decode.DictEncodingDecoder +io.github.dfa1.vortex.reader.decode.ExtEncodingDecoder +io.github.dfa1.vortex.reader.decode.FixedSizeListEncodingDecoder +io.github.dfa1.vortex.reader.decode.FrameOfReferenceEncodingDecoder +io.github.dfa1.vortex.reader.decode.FsstEncodingDecoder +io.github.dfa1.vortex.reader.decode.ListEncodingDecoder +io.github.dfa1.vortex.reader.decode.ListViewEncodingDecoder +io.github.dfa1.vortex.reader.decode.MaskedEncodingDecoder 
+io.github.dfa1.vortex.reader.decode.NullEncodingDecoder +io.github.dfa1.vortex.reader.decode.PatchedEncodingDecoder +io.github.dfa1.vortex.reader.decode.PcoEncodingDecoder +io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder +io.github.dfa1.vortex.reader.decode.RleEncodingDecoder +io.github.dfa1.vortex.reader.decode.RunEndEncodingDecoder +io.github.dfa1.vortex.reader.decode.SequenceEncodingDecoder +io.github.dfa1.vortex.reader.decode.SparseEncodingDecoder +io.github.dfa1.vortex.reader.decode.StructEncodingDecoder +io.github.dfa1.vortex.reader.decode.VarBinEncodingDecoder +io.github.dfa1.vortex.reader.decode.VariantEncodingDecoder +io.github.dfa1.vortex.reader.decode.VarBinViewEncodingDecoder +io.github.dfa1.vortex.reader.decode.ZigZagEncodingDecoder +io.github.dfa1.vortex.reader.decode.ZstdEncodingDecoder diff --git a/reader/src/test/java/io/github/dfa1/vortex/reader/decode/ByteBoolEncodingDecoderTest.java b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/ByteBoolEncodingDecoderTest.java new file mode 100644 index 00000000..f9219d5a --- /dev/null +++ b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/ByteBoolEncodingDecoderTest.java @@ -0,0 +1,57 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; + +class ByteBoolEncodingDecoderTest { + + static Stream cases() { + 
return Stream.of( + Arguments.of("all false", new byte[]{0, 0, 0}, new boolean[]{false, false, false}), + Arguments.of("all true", new byte[]{1, 42, (byte) 0xFF}, new boolean[]{true, true, true}), + Arguments.of("mixed", new byte[]{0, 1, 0, 1}, new boolean[]{false, true, false, true}), + Arguments.of("empty", new byte[]{}, new boolean[]{}) + ); + } + + private static DecodeContext buildCtx(byte[] byteValues) { + MemorySegment buf = MemorySegment.ofArray(byteValues); + ArrayNode node = ArrayNode.of(EncodingId.VORTEX_BYTEBOOL, null, new ArrayNode[0], new int[]{0}, null); + Registry registry = TestRegistry.ofDecoders(new ByteBoolEncodingDecoder()); + return new DecodeContext(node, DTypes.BOOL, byteValues.length, new MemorySegment[]{buf}, registry, + Arena.ofAuto()); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("cases") + void decode_byteBool_packsToBitArray(String name, byte[] input, boolean[] expected) { + // Given + DecodeContext ctx = buildCtx(input); + var sut = new ByteBoolEncodingDecoder(); + + // When + var result = sut.decode(ctx); + + // Then + assertThat(result).isInstanceOf(BoolArray.class); + assertThat(result.length()).isEqualTo(expected.length); + BoolArray boolArr = (BoolArray) result; + for (int i = 0; i < expected.length; i++) { + assertThat(boolArr.getBoolean(i)).as("index %d", i).isEqualTo(expected[i]); + } + } +} diff --git a/reader/src/test/java/io/github/dfa1/vortex/reader/decode/NullEncodingDecoderTest.java b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/NullEncodingDecoderTest.java new file mode 100644 index 00000000..ce653d15 --- /dev/null +++ b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/NullEncodingDecoderTest.java @@ -0,0 +1,35 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.array.NullArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import 
io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.Registry; +import org.junit.jupiter.api.Test; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; + +import static org.assertj.core.api.Assertions.assertThat; + +class NullEncodingDecoderTest { + + @Test + void decode_nullArray_returnsNullArrayWithCorrectLength() { + // Given + long rowCount = 42L; + ArrayNode node = ArrayNode.of(EncodingId.VORTEX_NULL, null, new ArrayNode[0], new int[0], null); + DecodeContext ctx = new DecodeContext(node, DTypes.NULL, rowCount, new MemorySegment[0], + Registry.empty(), Arena.ofAuto()); + var sut = new NullEncodingDecoder(); + + // When + var result = sut.decode(ctx); + + // Then + assertThat(result).isInstanceOf(NullArray.class); + assertThat(result.length()).isEqualTo(rowCount); + assertThat(result.dtype()).isEqualTo(DTypes.NULL); + } +} diff --git a/reader/src/test/java/io/github/dfa1/vortex/reader/decode/PatchedEncodingDecoderTest.java b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/PatchedEncodingDecoderTest.java new file mode 100644 index 00000000..50f16685 --- /dev/null +++ b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/PatchedEncodingDecoderTest.java @@ -0,0 +1,167 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.proto.PatchedMetadata; +import org.junit.jupiter.api.Test; +import 
org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class PatchedEncodingDecoderTest { + + private static final PatchedEncodingDecoder SUT = new PatchedEncodingDecoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(SUT, new PrimitiveEncodingDecoder()); + + private static ByteBuffer patchedMeta(int nPatches, int nLanes, int offset) { + return ByteBuffer.wrap(new PatchedMetadata(nPatches, nLanes, offset).encode()); + } + + private static MemorySegment i32Segment(int... values) { + MemorySegment seg = MemorySegment.ofArray(new byte[values.length * 4]); + ByteBuffer bb = seg.asByteBuffer().order(ByteOrder.LITTLE_ENDIAN); + for (int v : values) { + bb.putInt(v); + } + return seg; + } + + private static MemorySegment u32Segment(int... values) { + return i32Segment(values); + } + + private static MemorySegment u16Segment(short... values) { + MemorySegment seg = MemorySegment.ofArray(new byte[values.length * 2]); + ByteBuffer bb = seg.asByteBuffer().order(ByteOrder.LITTLE_ENDIAN); + for (short v : values) { + bb.putShort(v); + } + return seg; + } + + private static MemorySegment i64Segment(long... 
values) { + MemorySegment seg = MemorySegment.ofArray(new byte[values.length * 8]); + ByteBuffer bb = seg.asByteBuffer().order(ByteOrder.LITTLE_ENDIAN); + for (long v : values) { + bb.putLong(v); + } + return seg; + } + + private static Array decode(int n, int[] innerI32, int[] laneOffsets, short[] patchIndices, int[] patchValues) { + return decode(new DType.Primitive(PType.I32, false), n, + i32Segment(innerI32), u32Segment(laneOffsets), + u16Segment(patchIndices), i32Segment(patchValues), + laneOffsets.length - 1); + } + + private static Array decode(DType dtype, int n, + MemorySegment inner, MemorySegment laneOffsets, + MemorySegment patchIndices, MemorySegment patchValues, + int nLanes) { + int nPatches = (int) (patchIndices.byteSize() / 2); + ByteBuffer meta = patchedMeta(nPatches, nLanes, 0); + + MemorySegment[] segments = {inner, laneOffsets, patchIndices, patchValues}; + + ArrayNode innerNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{0}, null); + ArrayNode laneNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{1}, null); + ArrayNode idxNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{2}, null); + ArrayNode valNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{3}, null); + + ArrayNode patchedNode = ArrayNode.of(EncodingId.VORTEX_PATCHED, meta, + new ArrayNode[]{innerNode, laneNode, idxNode, valNode}, new int[]{}, null); + + DecodeContext ctx = new DecodeContext(patchedNode, dtype, n, segments, REGISTRY, Arena.ofAuto()); + return SUT.decode(ctx); + } + + @Test + void decode_noPatches_returnsInnerUnchanged() { + int n = 4; + int[] inner = {10, 20, 30, 40}; + Array sut = decode(n, inner, new int[]{0, 0}, new short[]{}, new int[]{}); + + assertThat(sut).isInstanceOf(IntArray.class); + MemorySegment seg = ArraySegments.of(sut); + for (int i = 0; i < n; i++) { + assertThat(seg.getAtIndex(PTypeIO.LE_INT, i)).as("index %d", 
i).isEqualTo(inner[i]); + } + } + + @Test + void decode_singlePatch_overwrites() { + Array sut = decode(4, new int[]{10, 20, 30, 40}, new int[]{0, 1}, new short[]{2}, new int[]{99}); + + MemorySegment seg = ArraySegments.of(sut); + assertThat(seg.getAtIndex(PTypeIO.LE_INT, 0)).isEqualTo(10); + assertThat(seg.getAtIndex(PTypeIO.LE_INT, 1)).isEqualTo(20); + assertThat(seg.getAtIndex(PTypeIO.LE_INT, 2)).isEqualTo(99); + assertThat(seg.getAtIndex(PTypeIO.LE_INT, 3)).isEqualTo(40); + } + + @Test + void decode_multiplePatches_allApplied() { + Array sut = decode(4, new int[]{0, 0, 0, 0}, new int[]{0, 2}, new short[]{0, 3}, new int[]{1, 7}); + + MemorySegment seg = ArraySegments.of(sut); + assertThat(seg.getAtIndex(PTypeIO.LE_INT, 0)).isEqualTo(1); + assertThat(seg.getAtIndex(PTypeIO.LE_INT, 1)).isEqualTo(0); + assertThat(seg.getAtIndex(PTypeIO.LE_INT, 2)).isEqualTo(0); + assertThat(seg.getAtIndex(PTypeIO.LE_INT, 3)).isEqualTo(7); + } + + @ParameterizedTest + @ValueSource(ints = {1, 2, 1020, 1023, 1024, 1025, 2048}) + void decode_variousLengths_noPatches(int n) { + int[] inner = new int[n]; + Array sut = decode(n, inner, new int[]{0, 0}, new short[]{}, new int[]{}); + + MemorySegment seg = ArraySegments.of(sut); + for (int i = 0; i < n; i++) { + assertThat(seg.getAtIndex(PTypeIO.LE_INT, i)).as("index %d", i).isZero(); + } + } + + @Test + void decode_i64_singlePatch() { + DType dtype = new DType.Primitive(PType.I64, false); + Array sut = decode(dtype, 3, i64Segment(100L, 200L, 300L), u32Segment(0, 1), + u16Segment((short) 1), i64Segment(999L), 1); + + assertThat(sut).isInstanceOf(LongArray.class); + MemorySegment seg = ArraySegments.of(sut); + assertThat(seg.getAtIndex(PTypeIO.LE_LONG, 0)).isEqualTo(100L); + assertThat(seg.getAtIndex(PTypeIO.LE_LONG, 1)).isEqualTo(999L); + assertThat(seg.getAtIndex(PTypeIO.LE_LONG, 2)).isEqualTo(300L); + } + + @Test + void decode_missingMetadata_throws() { + ArrayNode innerNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new 
ArrayNode[0], new int[]{0}, null); + ArrayNode patchedNode = ArrayNode.of(EncodingId.VORTEX_PATCHED, null, + new ArrayNode[]{innerNode, innerNode, innerNode, innerNode}, new int[]{}, null); + MemorySegment seg = i32Segment(1, 2, 3); + DecodeContext ctx = new DecodeContext(patchedNode, new DType.Primitive(PType.I32, false), 3, + new MemorySegment[]{seg}, Registry.empty(), Arena.ofAuto()); + + assertThatThrownBy(() -> SUT.decode(ctx)).hasMessageContaining("missing metadata"); + } +} diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/PcoEncodingTest.java b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/PcoEncodingDecoderTest.java similarity index 51% rename from core/src/test/java/io/github/dfa1/vortex/encoding/PcoEncodingTest.java rename to reader/src/test/java/io/github/dfa1/vortex/reader/decode/PcoEncodingDecoderTest.java index dc462dcc..39ef7c93 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/PcoEncodingTest.java +++ b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/PcoEncodingDecoderTest.java @@ -1,13 +1,19 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.reader.decode; -import io.github.dfa1.vortex.proto.PcoChunkInfo; -import io.github.dfa1.vortex.proto.PcoMetadata; -import io.github.dfa1.vortex.proto.PcoPageInfo; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.PType; import io.github.dfa1.vortex.core.VortexException; import io.github.dfa1.vortex.core.array.LongArray; import io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PcoEncoding; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.proto.PcoChunkInfo; +import io.github.dfa1.vortex.proto.PcoMetadata; +import 
io.github.dfa1.vortex.proto.PcoPageInfo; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -25,7 +31,9 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -class PcoEncodingTest { +class PcoEncodingDecoderTest { + + private static final PcoEncodingDecoder SUT = new PcoEncodingDecoder(); private static ByteBuffer validMetaBuffer() { PcoMetadata meta = new PcoMetadata(new byte[]{PcoEncoding.PCO_FORMAT_MAJOR, PcoEncoding.PCO_FORMAT_MINOR}, java.util.List.of()); @@ -38,8 +46,6 @@ private static DecodeContext ctxWith(ByteBuffer meta, DType dtype, long rowCount return new DecodeContext(node, dtype, rowCount, buffers, Registry.empty(), Arena.ofAuto()); } - /// Build a nullable DecodeContext: validity buffer at index 0, pco buffers at indices 1..N. - /// Validity is a bit-packed Bool array (LSB-first, 1=valid). private static DecodeContext ctxWithValidity(ByteBuffer meta, DType dtype, long rowCount, MemorySegment validityBuf, MemorySegment[] pcoBuffers) { MemorySegment[] allBuffers = new MemorySegment[1 + pcoBuffers.length]; @@ -56,7 +62,7 @@ private static DecodeContext ctxWithValidity(ByteBuffer meta, DType dtype, long ArrayNode pcoNode = ArrayNode.of(EncodingId.VORTEX_PCO, meta, new ArrayNode[]{validityNode}, pcoBufferIndices, null); - Registry registry = TestRegistry.of(new BoolEncoding()); + Registry registry = TestRegistry.ofDecoders(new BoolEncodingDecoder()); return new DecodeContext(pcoNode, dtype, rowCount, allBuffers, registry, Arena.ofAuto()); } @@ -76,7 +82,6 @@ private static MemorySegment segmentOf(byte... bytes) { return seg; } - /// Build a PcoMetadata proto with one chunk containing one page of {@code nValues} values. 
private static ByteBuffer metaWithOneChunk(int nValues) { PcoMetadata meta = new PcoMetadata( new byte[]{PcoEncoding.PCO_FORMAT_MAJOR, PcoEncoding.PCO_FORMAT_MINOR}, @@ -84,22 +89,10 @@ private static ByteBuffer metaWithOneChunk(int nValues) { return ByteBuffer.wrap(meta.encode()); } - /// Chunk-meta bytes for Classic mode, Consecutive delta at {@code order}, ansSizeLog=0, nBins=0. - /// - /// Bit layout (LSB-first per byte): - /// byte0: mode_nibble=0, delta_nibble=1 - /// byte1: order (3b), secondary_uses_delta=0 (1b), ansSizeLog=0 (4b) - /// byte2–3: nBins=0 (15b), align padding private static MemorySegment chunkMetaConsecutive(int order) { - return segmentOf( - (byte) 0x10, // mode=0, delta_variant=1 - (byte) order, // order[2:0], secondary=0, ansSizeLog=0 (order ≤ 7) - (byte) 0x00, // nBins bits16-23 = 0 - (byte) 0x00 // nBins bits24-30 = 0, padding - ); + return segmentOf((byte) 0x10, (byte) order, (byte) 0x00, (byte) 0x00); } - /// Page bytes: {@code order} LE-U64 moments, then 4 zero ANS-state slots (0 bits each). private static MemorySegment pageWithMoments(long... moments) { byte[] buf = new byte[moments.length * Long.BYTES]; java.nio.ByteBuffer bb = java.nio.ByteBuffer.wrap(buf).order(java.nio.ByteOrder.LITTLE_ENDIAN); @@ -109,43 +102,32 @@ private static MemorySegment pageWithMoments(long... moments) { return segmentOf(buf); } - /// Build chunk meta for Classic mode + Conv1 delta by packing bits LSB-first. - /// - /// Layout: mode(4b)=0, delta(4b)=3, quantization(5b), bias_latent(64b), - /// order-1(5b), weights[order*32b], ansSizeLog(4b)=0, nBins(15b)=0, align. - /// bias_latent = bias ^ Long.MIN_VALUE; each weight_latent = weight ^ 0x80000000L. 
private static MemorySegment chunkMetaConv1(int quantization, long biasLatent, int order, long[] weightLatents) { java.util.BitSet bits = new java.util.BitSet(); int pos = 0; - // mode nibble = 0 (Classic) pos += 4; - // delta nibble = 3 (Conv1): bits = 0b0011 bits.set(pos); bits.set(pos + 1); pos += 4; - // quantization (5 bits) for (int i = 0; i < 5; i++) { if (((quantization >> i) & 1) != 0) { bits.set(pos); } pos++; } - // bias latent (64 bits, LSB first) for (int i = 0; i < 64; i++) { if (((biasLatent >> i) & 1L) != 0L) { bits.set(pos); } pos++; } - // order-1 (5 bits) for (int i = 0; i < 5; i++) { if ((((order - 1) >> i) & 1) != 0) { bits.set(pos); } pos++; } - // weight latents (order × 32 bits) for (long wl : weightLatents) { for (int i = 0; i < 32; i++) { if (((wl >> i) & 1L) != 0L) { @@ -154,9 +136,7 @@ private static MemorySegment chunkMetaConv1(int quantization, long biasLatent, pos++; } } - // ansSizeLog (4 bits) = 0 pos += 4; - // nBins (15 bits) = 0 pos += 15; int byteLen = (pos + 7) / 8; byte[] buf = new byte[byteLen]; @@ -168,20 +148,10 @@ private static MemorySegment chunkMetaConv1(int quantization, long biasLatent, return segmentOf(buf); } - /// Chunk-meta bytes for Classic mode + Lookback delta with windowNLog=1 (windowN=2), stateNLog=0 (stateN=1), - /// deltaAnsSizeLog=0, primaryAnsSizeLog=0, no bins. - /// - /// Bit layout: - /// byte0: mode=0[3:0], delta=2[7:4] → 0x20 - /// bytes 1-6: windowNLog-1(5b)=0, stateNLog(4b)=0, secondary(1b)=0, - /// deltaAnsSizeLog(4b)=0, nDeltaBins(15b)=0, - /// primaryAnsSizeLog(4b)=0, nBins(15b)=0, align → all 0x00 private static MemorySegment chunkMetaLookback() { return segmentOf((byte) 0x20, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00); } - /// Page bytes for Lookback with stateN=1, U64, deltaAnsSizeLog=0, primaryAnsSizeLog=0. - /// Format: 8 bytes (one 64-bit initial state). No ANS state bits (sizeLog=0). No decoded bits. 
private static MemorySegment lookbackPage(long initialState) { byte[] buf = new byte[Long.BYTES]; java.nio.ByteBuffer.wrap(buf).order(java.nio.ByteOrder.LITTLE_ENDIAN).putLong(initialState); @@ -189,28 +159,10 @@ private static MemorySegment lookbackPage(long initialState) { } @Nested - class EncodingIdTest { - + class EncodingIdNested { @Test void encodingId_isVortexPco() { - // Given / When / Then - assertThat(new PcoEncoding().encodingId()).isEqualTo(EncodingId.VORTEX_PCO); - } - } - - @Nested - class Encode { - - @Test - void encode_throwsVortexException() { - // Given - var sut = new PcoEncoding(); - DType dtype = new DType.Primitive(PType.I64, false); - - // When / Then - assertThatThrownBy(() -> sut.encode(dtype, new long[]{1L, 2L, 3L}, EncodeTestHelper.testCtx())) - .isInstanceOf(VortexException.class) - .hasMessageContaining("not implemented"); + assertThat(SUT.encodingId()).isEqualTo(EncodingId.VORTEX_PCO); } } @@ -219,51 +171,35 @@ class Decode { @Test void decode_nullMetadata_throwsMissingMeta() { - // Given - var sut = new PcoEncoding(); DecodeContext ctx = ctxWith(null, new DType.Primitive(PType.I64, false), 0, new MemorySegment[0]); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("missing PcoMetadata"); } @Test void decode_invalidHeaderVersion_throwsUnsupported() { - // Given - var sut = new PcoEncoding(); PcoMetadata meta = new PcoMetadata(new byte[]{0x03, 0x00}, java.util.List.of()); DecodeContext ctx = ctxWith(ByteBuffer.wrap(meta.encode()), new DType.Primitive(PType.I64, false), 0, new MemorySegment[0]); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("unsupported pco format version 03.00"); } @Test void decode_nonPrimitiveDtype_throws() { - // Given - var sut = new PcoEncoding(); DecodeContext ctx = 
ctxWith(validMetaBuffer(), new DType.Utf8(false), 0, new MemorySegment[0]); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("Primitive dtype"); } @Test void decode_unsupportedPtype_throws() { - // Given — F16 not supported by pco - var sut = new PcoEncoding(); DecodeContext ctx = ctxWith(validMetaBuffer(), new DType.Primitive(PType.F16, false), 0, new MemorySegment[0]); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("unsupported ptype"); } @@ -271,53 +207,25 @@ void decode_unsupportedPtype_throws() { @ParameterizedTest @EnumSource(value = PType.class, names = {"I16", "U16", "I32", "U32", "F32", "I64", "U64", "F64"}) void decode_zeroChunks_returnsEmptyArray(PType ptype) { - // Given — valid metadata with 0 chunks, 0 rows, any supported ptype - var sut = new PcoEncoding(); - DecodeContext ctx = ctxWith(validMetaBuffer(), new DType.Primitive(ptype, false), 0, - new MemorySegment[0]); - - // When - var result = sut.decode(ctx); - - // Then + DecodeContext ctx = ctxWith(validMetaBuffer(), new DType.Primitive(ptype, false), 0, new MemorySegment[0]); + var result = SUT.decode(ctx); assertThat(result.length()).isZero(); } @Test void decode_consecutiveDelta_order1_singleValue_decodes() { - // Given — U64 sequence [42] encoded with Classic mode, Consecutive delta order=1. - // With pageN=1 and order=1, decodedN=0: the single output value is the moment itself. 
- var sut = new PcoEncoding(); - DecodeContext ctx = ctxWith( - metaWithOneChunk(1), - new DType.Primitive(PType.U64, false), - 1, + DecodeContext ctx = ctxWith(metaWithOneChunk(1), new DType.Primitive(PType.U64, false), 1, new MemorySegment[]{chunkMetaConsecutive(1), pageWithMoments(42L)}); - - // When - var result = sut.decode(ctx); - - // Then + var result = SUT.decode(ctx); assertThat(result.length()).isEqualTo(1); assertThat(((LongArray) result).getLong(0)).isEqualTo(42L); } @Test void decode_consecutiveDelta_order2_twoValues_decodes() { - // Given — U64 sequence [10, 17] encoded with Consecutive delta order=2. - // With pageN=2, order=2: decodedN=0; moments=[m0=10, m1=delta1=7]. - // Expected reconstruction: [m0, m0+m1] = [10, 17]. - var sut = new PcoEncoding(); - DecodeContext ctx = ctxWith( - metaWithOneChunk(2), - new DType.Primitive(PType.U64, false), - 2, + DecodeContext ctx = ctxWith(metaWithOneChunk(2), new DType.Primitive(PType.U64, false), 2, new MemorySegment[]{chunkMetaConsecutive(2), pageWithMoments(10L, 7L)}); - - // When - var result = sut.decode(ctx); - - // Then + var result = SUT.decode(ctx); assertThat(result.length()).isEqualTo(2); assertThat(((LongArray) result).getLong(0)).isEqualTo(10L); assertThat(((LongArray) result).getLong(1)).isEqualTo(17L); @@ -325,22 +233,12 @@ void decode_consecutiveDelta_order2_twoValues_decodes() { @Test void decode_multiPage_singleChunk_decodes() { - // Given — 1 chunk, 2 pages each containing 1 value (Consecutive order=1). - // buffers: [chunkMeta, page0, page1]; page0 moment=10→value 10, page1 moment=20→value 20. 
- var sut = new PcoEncoding(); PcoMetadata meta = new PcoMetadata( new byte[]{PcoEncoding.PCO_FORMAT_MAJOR, PcoEncoding.PCO_FORMAT_MINOR}, java.util.List.of(new PcoChunkInfo(java.util.List.of(new PcoPageInfo(1), new PcoPageInfo(1))))); - DecodeContext ctx = ctxWith( - ByteBuffer.wrap(meta.encode()), - new DType.Primitive(PType.U64, false), - 2, + DecodeContext ctx = ctxWith(ByteBuffer.wrap(meta.encode()), new DType.Primitive(PType.U64, false), 2, new MemorySegment[]{chunkMetaConsecutive(1), pageWithMoments(10L), pageWithMoments(20L)}); - - // When - var result = sut.decode(ctx); - - // Then + var result = SUT.decode(ctx); assertThat(result.length()).isEqualTo(2); assertThat(((LongArray) result).getLong(0)).isEqualTo(10L); assertThat(((LongArray) result).getLong(1)).isEqualTo(20L); @@ -348,26 +246,15 @@ void decode_multiPage_singleChunk_decodes() { @Test void decode_multiChunk_decodes() { - // Given — 2 chunks each with 1 page containing 1 value (Consecutive order=1). - // buffers: [chunkMeta0, page0, chunkMeta1, page1]; values=[100, 200]. 
- var sut = new PcoEncoding(); PcoMetadata meta = new PcoMetadata( new byte[]{PcoEncoding.PCO_FORMAT_MAJOR, PcoEncoding.PCO_FORMAT_MINOR}, java.util.List.of( new PcoChunkInfo(java.util.List.of(new PcoPageInfo(1))), new PcoChunkInfo(java.util.List.of(new PcoPageInfo(1))))); - DecodeContext ctx = ctxWith( - ByteBuffer.wrap(meta.encode()), - new DType.Primitive(PType.U64, false), - 2, - new MemorySegment[]{ - chunkMetaConsecutive(1), pageWithMoments(100L), + DecodeContext ctx = ctxWith(ByteBuffer.wrap(meta.encode()), new DType.Primitive(PType.U64, false), 2, + new MemorySegment[]{chunkMetaConsecutive(1), pageWithMoments(100L), chunkMetaConsecutive(1), pageWithMoments(200L)}); - - // When - var result = sut.decode(ctx); - - // Then + var result = SUT.decode(ctx); assertThat(result.length()).isEqualTo(2); assertThat(((LongArray) result).getLong(0)).isEqualTo(100L); assertThat(((LongArray) result).getLong(1)).isEqualTo(200L); @@ -379,27 +266,16 @@ class DecodeNullable { @Test void decode_nullable_someNulls_scattersCorrectly() { - // Given — U64 sequence: 3 total rows, validity=[true,false,true], valid values=[100,200]. - // Validity bits LSB-first: bit0=1, bit1=0, bit2=1 → byte 0x05. - // Pco encodes only valid values: 1 chunk, 2 pages of nValues=1 (Consecutive order=1). 
- var sut = new PcoEncoding(); PcoMetadata meta = new PcoMetadata( new byte[]{PcoEncoding.PCO_FORMAT_MAJOR, PcoEncoding.PCO_FORMAT_MINOR}, java.util.List.of(new PcoChunkInfo(java.util.List.of(new PcoPageInfo(1), new PcoPageInfo(1))))); - MemorySegment validityBuf = segmentOf((byte) 0x05); // bits: 1,0,1 + MemorySegment validityBuf = segmentOf((byte) 0x05); DecodeContext ctx = ctxWithValidity( - ByteBuffer.wrap(meta.encode()), - new DType.Primitive(PType.U64, true), - 3, - validityBuf, + ByteBuffer.wrap(meta.encode()), new DType.Primitive(PType.U64, true), 3, validityBuf, new MemorySegment[]{chunkMetaConsecutive(1), pageWithMoments(100L), pageWithMoments(200L)}); + var result = SUT.decode(ctx); - // When - var result = sut.decode(ctx); - - // Then — MaskedArray with 3 slots; positions 0 and 2 valid, position 1 null assertThat(result).isInstanceOf(MaskedArray.class); - assertThat(result.length()).isEqualTo(3); MaskedArray masked = (MaskedArray) result; assertThat(masked.isValid(0)).isTrue(); assertThat(masked.isValid(1)).isFalse(); @@ -410,23 +286,12 @@ void decode_nullable_someNulls_scattersCorrectly() { @Test void decode_nullable_allNull_returnsAllZeroed() { - // Given — 2 total rows, validity=[false,false], validCount=0. Pco has 0 chunks. - // Validity bits LSB-first: 0x00. 
- var sut = new PcoEncoding(); MemorySegment validityBuf = segmentOf((byte) 0x00); - DecodeContext ctx = ctxWithValidity( - validMetaBuffer(), - new DType.Primitive(PType.U64, true), - 2, - validityBuf, - new MemorySegment[0]); + DecodeContext ctx = ctxWithValidity(validMetaBuffer(), new DType.Primitive(PType.U64, true), 2, + validityBuf, new MemorySegment[0]); + var result = SUT.decode(ctx); - // When - var result = sut.decode(ctx); - - // Then — MaskedArray, length 2, both null, values zeroed assertThat(result).isInstanceOf(MaskedArray.class); - assertThat(result.length()).isEqualTo(2); MaskedArray masked = (MaskedArray) result; assertThat(masked.isValid(0)).isFalse(); assertThat(masked.isValid(1)).isFalse(); @@ -436,23 +301,12 @@ void decode_nullable_allNull_returnsAllZeroed() { @Test void decode_nullable_allValid_returnsMaskedWithAllValues() { - // Given — 2 total rows, validity=[true,true], valid values=[10,20]. - // Validity bits: 0x03. - var sut = new PcoEncoding(); - MemorySegment validityBuf = segmentOf((byte) 0x03); // bits: 1,1 - DecodeContext ctx = ctxWithValidity( - metaWithOneChunk(2), - new DType.Primitive(PType.U64, true), - 2, - validityBuf, - new MemorySegment[]{chunkMetaConsecutive(2), pageWithMoments(10L, 10L)}); - - // When - var result = sut.decode(ctx); + MemorySegment validityBuf = segmentOf((byte) 0x03); + DecodeContext ctx = ctxWithValidity(metaWithOneChunk(2), new DType.Primitive(PType.U64, true), 2, + validityBuf, new MemorySegment[]{chunkMetaConsecutive(2), pageWithMoments(10L, 10L)}); + var result = SUT.decode(ctx); - // Then — MaskedArray, all valid, values [10, 20] assertThat(result).isInstanceOf(MaskedArray.class); - assertThat(result.length()).isEqualTo(2); MaskedArray masked = (MaskedArray) result; assertThat(masked.isValid(0)).isTrue(); assertThat(masked.isValid(1)).isTrue(); @@ -466,27 +320,14 @@ class DecodeConv1 { @Test void decode_conv1_order1_zeroPrediction_statePassedThrough() { - // Given — I32, pageN=2, order=1, bias=0, 
weight=0 → prediction always 0. - // State raw = value ^ 0x80000000: for value=5, state_raw=0x80000005. - // Residual from degenerate tANS=0 → decoded = 0 ^ mid_i32(0x80000000) = 0x80000000 - // → fromLatentOrdered(0x80000000, I32) = 0x80000000 ^ 0x80000000 = 0. - // Expected output: [5, 0]. - var sut = new PcoEncoding(); - long biasLatent = Long.MIN_VALUE; // encodes bias=0: raw ^ MIN_VALUE = 0 - long weightLatent = 0x80000000L; // encodes weight=0: (int)(raw ^ 0x80000000) = 0 + long biasLatent = Long.MIN_VALUE; + long weightLatent = 0x80000000L; MemorySegment chunkMeta = chunkMetaConv1(0, biasLatent, 1, new long[]{weightLatent}); - // Page: 1 × 32-bit state = 0x80000005 (encodes value 5), no residual bits. MemorySegment page = segmentOf((byte) 0x05, (byte) 0x00, (byte) 0x00, (byte) 0x80); - DecodeContext ctx = ctxWith( - metaWithOneChunk(2), - new DType.Primitive(PType.I32, false), - 2, + DecodeContext ctx = ctxWith(metaWithOneChunk(2), new DType.Primitive(PType.I32, false), 2, new MemorySegment[]{chunkMeta, page}); + var result = SUT.decode(ctx); - // When - var result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(2); assertThat(((io.github.dfa1.vortex.core.array.IntArray) result).getInt(0)).isEqualTo(5); assertThat(((io.github.dfa1.vortex.core.array.IntArray) result).getInt(1)).isZero(); @@ -498,59 +339,29 @@ class DecodeLookback { @Test void decode_lookback_corruptIndexZero_throwsVortexException() { - // Given — Classic+Lookback, windowN=2, stateN=1, degenerate ANS (0 bins). - // Degenerate tANS always outputs lower=0; lb=0 is out of [1, windowN=2]. - // pageN=2: stateN=1 initial value + 1 decoded value with lb=0. 
- var sut = new PcoEncoding(); - DecodeContext ctx = ctxWith( - metaWithOneChunk(2), - new DType.Primitive(PType.U64, false), - 2, + DecodeContext ctx = ctxWith(metaWithOneChunk(2), new DType.Primitive(PType.U64, false), 2, new MemorySegment[]{chunkMetaLookback(), lookbackPage(0L)}); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("corrupt lookback index 0"); } @Test void decode_lookback_stateNExceedsPageN_throwsVortexException() { - // Given — stateN=2 (stateNLog=1) but pageN=1 → decodeN = 1-2 = -1 → corrupt. - // chunkMeta bit layout (all LE, LSB-first): - // byte0: mode=0[3:0], delta=2[7:4] → 0x20 - // byte1: windowNLog-1(5b)=0, stateNLog[0](1b)=1 → 0x20 - // byte2: stateNLog[1..3](3b)=0, secondary(1b)=0, deltaAnsSizeLog(4b)=0 → 0x00 - // bytes3-6: nDeltaBins(15b)=0, primaryAnsSizeLog(4b)=0, nBins(15b)=0 → 0x00 - var sut = new PcoEncoding(); MemorySegment chunkMeta = segmentOf( (byte) 0x20, (byte) 0x20, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00); - DecodeContext ctx = ctxWith( - metaWithOneChunk(1), - new DType.Primitive(PType.U64, false), - 1, + DecodeContext ctx = ctxWith(metaWithOneChunk(1), new DType.Primitive(PType.U64, false), 1, new MemorySegment[]{chunkMeta, segmentOf((byte) 0x00)}); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("stateN"); } @Test void decode_lookback_singleInitialValue_returnsIt() { - // Given — pageN=1, stateN=1, decodeN=0: only the initial state value; no decoded values. 
- var sut = new PcoEncoding(); - DecodeContext ctx = ctxWith( - metaWithOneChunk(1), - new DType.Primitive(PType.U64, false), - 1, + DecodeContext ctx = ctxWith(metaWithOneChunk(1), new DType.Primitive(PType.U64, false), 1, new MemorySegment[]{chunkMetaLookback(), lookbackPage(42L)}); - - // When - var result = sut.decode(ctx); - - // Then + var result = SUT.decode(ctx); assertThat(result.length()).isEqualTo(1); assertThat(((LongArray) result).getLong(0)).isEqualTo(42L); } @@ -558,21 +369,12 @@ void decode_lookback_singleInitialValue_returnsIt() { @Nested class DecodeLookbackDecodeN { - @Test void lookback_decodeNExceedsMax_throwsVortexException() { - // Given — stateN=1, pageN=(1<<23)+2 → decodeN=(1<<23)+1 > cap 1<<23. - // Check fires before arena.allocate; page only needs 8 bytes (1×64-bit initial state). - var sut = new PcoEncoding(); int pageN = (1 << 23) + 2; - DecodeContext ctx = ctxWith( - metaWithOneChunk(pageN), - new DType.Primitive(PType.U64, false), - pageN, + DecodeContext ctx = ctxWith(metaWithOneChunk(pageN), new DType.Primitive(PType.U64, false), pageN, new MemorySegment[]{chunkMetaLookback(), segmentOf(new byte[8])}); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("decodeN"); } @@ -580,28 +382,14 @@ void lookback_decodeNExceedsMax_throwsVortexException() { @Nested class DecodeLookbackStateNWindow { - @Test void lookback_stateNExceedsWindowN_throwsVortexException() { - // Given — windowNLog=1 (windowN=2), stateNLog=2 (stateN=4) → stateN > windowN. 
- // pageN=4 (≥ stateN=4 passes the pageN sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("stateN"); } @@ -609,25 +397,13 @@ void lookback_stateNExceedsWindowN_throwsVortexException() { @Nested class DecodeLookbackWindowNLog { - @Test void lookback_windowNLogExceedsMax_throwsVortexException() { - // Given — mode=0 (Classic), delta=2 (Lookback), windowNLog=25 > max 24. - // Bit layout (LSB-first after byte0): - // byte0: mode=0[3:0], delta=2[7:4] → 0x20 - // byte1: windowNLog-1(5b)=24=0b11000 → 0x18; stateNLog bits start at bit5 - // bytes2-6: all 0x00 - var sut = new PcoEncoding(); MemorySegment chunkMeta = segmentOf( (byte) 0x20, (byte) 0x18, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00); - DecodeContext ctx = ctxWith( - metaWithOneChunk(1), - new DType.Primitive(PType.U64, false), - 1, + DecodeContext ctx = ctxWith(metaWithOneChunk(1), new DType.Primitive(PType.U64, false), 1, new MemorySegment[]{chunkMeta, segmentOf((byte) 0x00)}); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("windowNLog"); } @@ -635,30 +411,17 @@ void lookback_windowNLogExceedsMax_throwsVortexException() { @Nested class DecodeDict { - @Test void dict_nUniqueExceedsMax_throwsVortexException() { - // Given — mode=4 (Dict), nUnique=65537 > max 65536. 
- // Bit layout (LSB-first): mode[3:0]=4, nUnique[28:4]=65537 - // combined = 4 | (65537 << 4) = 0x100014 - // bytes: 0x14, 0x00, 0x10, 0x00, 0x00 - var sut = new PcoEncoding(); - MemorySegment chunkMeta = segmentOf( - (byte) 0x14, (byte) 0x00, (byte) 0x10, (byte) 0x00, (byte) 0x00); - DecodeContext ctx = ctxWith( - metaWithOneChunk(1), - new DType.Primitive(PType.U64, false), - 1, + MemorySegment chunkMeta = segmentOf((byte) 0x14, (byte) 0x00, (byte) 0x10, (byte) 0x00, (byte) 0x00); + DecodeContext ctx = ctxWith(metaWithOneChunk(1), new DType.Primitive(PType.U64, false), 1, new MemorySegment[]{chunkMeta, segmentOf((byte) 0x00)}); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("nUnique"); } } - /// Adversarial coverage: malformed inputs must throw VortexException — never AIOOBE, NPE, or OOM. @Nested class Adversarial { @@ -680,115 +443,64 @@ static Stream pageBytesProvider() { }).limit(50); } - /// Random chunk-meta bytes — any exception must be a VortexException, not a JVM crash exception. @ParameterizedTest @MethodSource("chunkMetaBytesProvider") void randomChunkMetaBytes_neverThrowsJvmException(byte[] chunkMetaBytes) { - // Given — valid pco header + 1 chunk with 1 page of 1 value; garbage chunk-meta bytes. - var sut = new PcoEncoding(); - DecodeContext ctx = ctxWith( - metaWithOneChunk(1), - new DType.Primitive(PType.U64, false), - 1, + DecodeContext ctx = ctxWith(metaWithOneChunk(1), new DType.Primitive(PType.U64, false), 1, new MemorySegment[]{segmentOf(chunkMetaBytes), segmentOf((byte) 0x00)}); - - // When / Then — either succeeds or throws VortexException; never AIOOBE/NPE/OOM try { - sut.decode(ctx); + SUT.decode(ctx); } catch (VortexException ignored) { - // expected — malformed input } } - /// Random page bytes after a valid Classic-mode chunk meta — must not crash the JVM. 
@ParameterizedTest @MethodSource("pageBytesProvider") void randomPageBytes_classicMode_neverThrowsJvmException(byte[] pageBytes) { - // Given — Classic mode, delta=NoOp, ansSizeLog=0, nBins=0 chunk meta. - var sut = new PcoEncoding(); - // byte0: mode=0 (bits3:0), deltaVariant=0 (bits7:4) → 0x00 - // byte1: ansSizeLog=0 (bits3:0), nBins low bits = 0 - // bytes 2-3: nBins high bits = 0 - DecodeContext ctx = ctxWith( - metaWithOneChunk(1), - new DType.Primitive(PType.U64, false), - 1, + DecodeContext ctx = ctxWith(metaWithOneChunk(1), new DType.Primitive(PType.U64, false), 1, new MemorySegment[]{segmentOf((byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00), segmentOf(pageBytes)}); - - // When / Then try { - sut.decode(ctx); + SUT.decode(ctx); } catch (VortexException ignored) { - // expected — malformed page data } } - /// Invalid mode nibbles (5–15) must produce a VortexException naming the mode number. @ParameterizedTest @ValueSource(ints = {5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}) void invalidModeNibble_throwsVortexException(int modeNibble) { - // Given — chunk meta with unsupported mode nibble in bits[3:0]. - var sut = new PcoEncoding(); - // bits[3:0] = modeNibble, delta nibble doesn't matter (won't be reached) byte modeByte = (byte) (modeNibble & 0x0F); - DecodeContext ctx = ctxWith( - metaWithOneChunk(1), - new DType.Primitive(PType.U64, false), - 1, - new MemorySegment[]{ - segmentOf(modeByte, (byte) 0x00, (byte) 0x00, (byte) 0x00), + DecodeContext ctx = ctxWith(metaWithOneChunk(1), new DType.Primitive(PType.U64, false), 1, + new MemorySegment[]{segmentOf(modeByte, (byte) 0x00, (byte) 0x00, (byte) 0x00), segmentOf((byte) 0x00)}); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("pco mode " + modeNibble); } - /// Invalid delta variants (4–15) must produce a VortexException naming the variant number. 
@ParameterizedTest @ValueSource(ints = {4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}) void invalidDeltaVariant_throwsVortexException(int deltaVariant) { - // Given — Classic mode (nibble=0) + invalid delta nibble in bits[7:4]. - var sut = new PcoEncoding(); - // byte0: bits[3:0]=mode=0, bits[7:4]=deltaVariant byte modeDeltaByte = (byte) ((deltaVariant & 0x0F) << 4); - DecodeContext ctx = ctxWith( - metaWithOneChunk(1), - new DType.Primitive(PType.U64, false), - 1, - new MemorySegment[]{ - segmentOf(modeDeltaByte, (byte) 0x00, (byte) 0x00, (byte) 0x00), + DecodeContext ctx = ctxWith(metaWithOneChunk(1), new DType.Primitive(PType.U64, false), 1, + new MemorySegment[]{segmentOf(modeDeltaByte, (byte) 0x00, (byte) 0x00, (byte) 0x00), segmentOf((byte) 0x00)}); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("delta variant " + deltaVariant); } - /// Conv1 delta with 64-bit dtype must throw VortexException; pcodec only supports 16/32-bit Conv1. @ParameterizedTest @EnumSource(value = PType.class, names = {"I64", "U64", "F64"}) void conv1Delta_with64BitDtype_throwsVortexException(PType ptype) { - // Given — Conv1 delta variant (nibble=3 in bits[7:4]), Classic mode (nibble=0 in bits[3:0]). - // byte0: bits[3:0]=0 (Classic), bits[7:4]=3 (Conv1) → 0x30 - // Remaining bytes: conv1 bit fields (don't matter — error fires before parsing them). 
- var sut = new PcoEncoding(); - DecodeContext ctx = ctxWith( - metaWithOneChunk(1), - new DType.Primitive(ptype, false), - 1, + DecodeContext ctx = ctxWith(metaWithOneChunk(1), new DType.Primitive(ptype, false), 1, new MemorySegment[]{ segmentOf((byte) 0x30, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00), segmentOf((byte) 0x00)}); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> SUT.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("Conv1"); } diff --git a/reader/src/test/java/io/github/dfa1/vortex/reader/decode/VariantEncodingDecoderTest.java b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/VariantEncodingDecoderTest.java new file mode 100644 index 00000000..ed59d7a0 --- /dev/null +++ b/reader/src/test/java/io/github/dfa1/vortex/reader/decode/VariantEncodingDecoderTest.java @@ -0,0 +1,143 @@ +package io.github.dfa1.vortex.reader.decode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.NullArray; +import io.github.dfa1.vortex.core.array.VariantArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.proto.Primitive; +import io.github.dfa1.vortex.proto.VariantMetadata; +import org.junit.jupiter.api.Test; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class 
VariantEncodingDecoderTest { + + private static final DType VARIANT_DTYPE = new DType.Variant(false); + private static final int N = 3; + + private static final VariantEncodingDecoder SUT = new VariantEncodingDecoder(); + + private static ByteBuffer variantMetaWithShredded(io.github.dfa1.vortex.proto.DType shredded) { + return ByteBuffer.wrap(new VariantMetadata(shredded).encode()); + } + + private static ArrayNode primitiveChildNode(int segIdx) { + return ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{segIdx}, null); + } + + private static ArrayNode nullChildNode() { + return ArrayNode.of(EncodingId.VORTEX_NULL, null, new ArrayNode[0], new int[]{}, null); + } + + private static MemorySegment i32Segment(int... values) { + MemorySegment seg = MemorySegment.ofArray(new byte[values.length * 4]); + ByteBuffer bb = seg.asByteBuffer().order(ByteOrder.LITTLE_ENDIAN); + for (int v : values) { + bb.putInt(v); + } + return seg; + } + + @org.junit.jupiter.api.Test + void decode_withoutShredded_returnsCoreStorageOnly() { + ArrayNode coreNode = nullChildNode(); + ArrayNode variantNode = ArrayNode.of(EncodingId.VORTEX_VARIANT, null, + new ArrayNode[]{coreNode}, new int[]{}, null); + + Registry registry = TestRegistry.ofDecoders(SUT, new NullEncodingDecoder()); + DecodeContext ctx = new DecodeContext(variantNode, VARIANT_DTYPE, N, + new MemorySegment[0], registry, Arena.ofAuto()); + + Array result = SUT.decode(ctx); + + assertThat(result).isInstanceOf(VariantArray.class); + VariantArray va = (VariantArray) result; + assertThat(va.dtype()).isEqualTo(VARIANT_DTYPE); + assertThat(va.length()).isEqualTo(N); + assertThat(va.coreStorage()).isInstanceOf(NullArray.class); + assertThat(va.shredded()).isNull(); + } + + @Test + void decode_withShredded_decodesSecondChild() { + io.github.dfa1.vortex.proto.DType shreddedProto = io.github.dfa1.vortex.proto.DType.ofPrimitive( + new Primitive(io.github.dfa1.vortex.proto.PType.I32, false)); + ByteBuffer meta = 
variantMetaWithShredded(shreddedProto); + + ArrayNode coreNode = nullChildNode(); + ArrayNode shreddedNode = primitiveChildNode(0); + ArrayNode variantNode = ArrayNode.of(EncodingId.VORTEX_VARIANT, meta, + new ArrayNode[]{coreNode, shreddedNode}, new int[]{}, null); + + MemorySegment[] segments = {i32Segment(1, 2, 3)}; + Registry registry = TestRegistry.ofDecoders(SUT, new NullEncodingDecoder(), new PrimitiveEncodingDecoder()); + DecodeContext ctx = new DecodeContext(variantNode, VARIANT_DTYPE, N, + segments, registry, Arena.ofAuto()); + + Array result = SUT.decode(ctx); + + assertThat(result).isInstanceOf(VariantArray.class); + VariantArray va = (VariantArray) result; + assertThat(va.shredded()).isNotNull(); + assertThat(va.shredded().dtype()).isEqualTo(new DType.Primitive(PType.I32, false)); + assertThat(va.shredded().length()).isEqualTo(N); + } + + @Test + void decode_emptyMetadata_noShredded() { + ArrayNode coreNode = nullChildNode(); + ArrayNode variantNode = ArrayNode.of(EncodingId.VORTEX_VARIANT, ByteBuffer.allocate(0), + new ArrayNode[]{coreNode}, new int[]{}, null); + + Registry registry = TestRegistry.ofDecoders(SUT, new NullEncodingDecoder()); + DecodeContext ctx = new DecodeContext(variantNode, VARIANT_DTYPE, N, + new MemorySegment[0], registry, Arena.ofAuto()); + + Array result = SUT.decode(ctx); + + VariantArray va = (VariantArray) result; + assertThat(va.shredded()).isNull(); + } + + @Test + void decode_nullableDtype_preservedOnResult() { + DType nullableVariant = new DType.Variant(true); + ArrayNode coreNode = nullChildNode(); + ArrayNode variantNode = ArrayNode.of(EncodingId.VORTEX_VARIANT, null, + new ArrayNode[]{coreNode}, new int[]{}, null); + + Registry registry = TestRegistry.ofDecoders(SUT, new NullEncodingDecoder()); + DecodeContext ctx = new DecodeContext(variantNode, nullableVariant, N, + new MemorySegment[0], registry, Arena.ofAuto()); + + VariantArray va = (VariantArray) SUT.decode(ctx); + + 
assertThat(va.dtype()).isEqualTo(nullableVariant); + assertThat(va.dtype().nullable()).isTrue(); + } + + @Test + void decode_wrongChildCount_throws() { + ArrayNode variantNode = ArrayNode.of(EncodingId.VORTEX_VARIANT, null, + new ArrayNode[0], new int[]{}, null); + + Registry registry = TestRegistry.ofDecoders(SUT); + DecodeContext ctx = new DecodeContext(variantNode, VARIANT_DTYPE, N, + new MemorySegment[0], registry, Arena.ofAuto()); + + assertThatThrownBy(() -> SUT.decode(ctx)) + .hasMessageContaining("expected 1 or 2 children"); + } +} diff --git a/writer/pom.xml b/writer/pom.xml index a96b858f..b8e44d85 100644 --- a/writer/pom.xml +++ b/writer/pom.xml @@ -25,6 +25,13 @@ flatbuffers-java + + + io.github.dfa1.vortex + vortex-core + test-jar + test + io.github.dfa1.vortex vortex-reader diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/AlpEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/AlpEncodingEncoder.java new file mode 100644 index 00000000..584bd252 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/AlpEncodingEncoder.java @@ -0,0 +1,327 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.encoding.AlpEncoding; +import io.github.dfa1.vortex.encoding.CascadeStep; +import io.github.dfa1.vortex.encoding.ChildSlot; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.ALPMetadata; +import io.github.dfa1.vortex.proto.PatchesMetadata; +import io.github.dfa1.vortex.proto.ScalarValue; + +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import 
java.util.List; + +/// Write-only encoder for {@code vortex.alp}. +public final class AlpEncodingEncoder implements EncodingEncoder { + + private static final int MAX_EXPONENT_F64 = 18; + private static final int MAX_EXPONENT_F32 = 10; + private static final int SAMPLE_SIZE = 512; + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public AlpEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_ALP; + } + + @Override + public boolean accepts(DType dtype) { + if (!(dtype instanceof DType.Primitive p)) { + return false; + } + return p.ptype() == PType.F64 || p.ptype() == PType.F32; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + PType ptype = ((DType.Primitive) dtype).ptype(); + return switch (ptype) { + case F64 -> encodeF64((double[]) data, ctx); + case F32 -> encodeF32((float[]) data, ctx); + default -> throw new UnsupportedOperationException("ALP encode not supported for " + ptype); + }; + } + + @Override + public CascadeStep encodeCascade(DType dtype, Object data, EncodeContext ctx) { + PType ptype = ((DType.Primitive) dtype).ptype(); + if (ptype == PType.F64) { + return encodeCascadeF64((double[]) data, ctx); + } + return CascadeStep.terminal(encode(dtype, data, ctx)); + } + + private static int[] findExponentsF64(double[] values) { + int sampleLen = Math.min(SAMPLE_SIZE, values.length); + int bestExpE = 0, bestExpF = 0, bestExceptions = sampleLen + 1; + + outer: + for (int expE = 0; expE <= MAX_EXPONENT_F64; expE++) { + for (int expF = 0; expF <= MAX_EXPONENT_F64; expF++) { + double ef = AlpEncoding.F10_F64[expE]; + double iff = AlpEncoding.IF10_F64[expF]; + double df = AlpEncoding.F10_F64[expF]; + double de = AlpEncoding.IF10_F64[expE]; + int exceptions = 0; + for (int i = 0; i < sampleLen; i++) { + double enc = values[i] * ef * iff; + if (!Double.isFinite(enc) || (double) Math.round(enc) * df * de != values[i]) { + exceptions++; + } + } 
+ if (exceptions < bestExceptions) { + bestExceptions = exceptions; + bestExpE = expE; + bestExpF = expF; + if (bestExceptions == 0) { + break outer; + } + } + } + } + return new int[]{bestExpE, bestExpF}; + } + + private static AlpF64Data computeF64(double[] values) { + int n = values.length; + int[] exps = findExponentsF64(values); + int expE = exps[0], expF = exps[1]; + double ef = AlpEncoding.F10_F64[expE]; + double iff = AlpEncoding.IF10_F64[expF]; + double df = AlpEncoding.F10_F64[expF]; + double de = AlpEncoding.IF10_F64[expE]; + + long[] encodedArr = new long[n]; + var patchIndices = new ArrayList<Integer>(); + var patchValues = new ArrayList<Double>(); + + double min = Double.MAX_VALUE, max = -Double.MAX_VALUE; + for (int i = 0; i < n; i++) { + double v = values[i]; + double enc = v * ef * iff; + long encoded; + if (Double.isFinite(enc) && (double) (encoded = Math.round(enc)) * df * de == v) { + encodedArr[i] = encoded; + } else { + encodedArr[i] = 0L; + patchIndices.add(i); + patchValues.add(v); + } + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + } + + byte[] statsMin = n > 0 ? scalarF64(min) : null; + byte[] statsMax = n > 0 ? 
scalarF64(max) : null; + return new AlpF64Data(expE, expF, encodedArr, patchIndices, patchValues, statsMin, statsMax); + } + + private static EncodeResult encodeF64(double[] values, EncodeContext ctx) { + AlpF64Data d = computeF64(values); + int n = values.length; + + MemorySegment encodedBuf = ctx.arena().allocate((long) n * 8, 8); + for (int i = 0; i < n; i++) { + encodedBuf.setAtIndex(PTypeIO.LE_LONG, i, d.encodedArr()[i]); + } + + EncodeNode encodedNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + + if (d.patchIndices().isEmpty()) { + byte[] metaBytes = new ALPMetadata(d.expE(), d.expF(), null).encode(); + EncodeNode root = new EncodeNode(EncodingId.VORTEX_ALP, + ByteBuffer.wrap(metaBytes), new EncodeNode[]{encodedNode}, new int[0]); + return new EncodeResult(root, List.of(encodedBuf), d.statsMin(), d.statsMax()); + } + + int numPatches = d.patchIndices().size(); + MemorySegment idxBuf = ctx.arena().allocate((long) numPatches * 4, 4); + MemorySegment valBuf = ctx.arena().allocate((long) numPatches * 8, 8); + for (int i = 0; i < numPatches; i++) { + idxBuf.setAtIndex(PTypeIO.LE_INT, i, d.patchIndices().get(i)); + valBuf.setAtIndex(PTypeIO.LE_DOUBLE, i, d.patchValues().get(i)); + } + + PatchesMetadata patches = buildPatchesMeta(numPatches); + byte[] metaBytes = new ALPMetadata(d.expE(), d.expF(), patches).encode(); + + EncodeNode idxNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 1); + EncodeNode valNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 2); + EncodeNode root = new EncodeNode(EncodingId.VORTEX_ALP, + ByteBuffer.wrap(metaBytes), + new EncodeNode[]{encodedNode, idxNode, valNode}, + new int[0]); + return new EncodeResult(root, List.of(encodedBuf, idxBuf, valBuf), d.statsMin(), d.statsMax()); + } + + private static CascadeStep encodeCascadeF64(double[] values, EncodeContext ctx) { + AlpF64Data d = computeF64(values); + if (d.patchIndices().isEmpty()) { + byte[] metaBytes = new ALPMetadata(d.expE(), d.expF(), null).encode(); + EncodeNode 
partialRoot = new EncodeNode(EncodingId.VORTEX_ALP, + ByteBuffer.wrap(metaBytes), new EncodeNode[1], new int[0]); + ChildSlot slot = new ChildSlot(AlpEncoding.I64_DTYPE, d.encodedArr(), 0); + return new CascadeStep(partialRoot, List.of(), List.of(slot), d.statsMin(), d.statsMax(), true); + } + + int numPatches = d.patchIndices().size(); + MemorySegment idxBuf = ctx.arena().allocate((long) numPatches * 4, 4); + MemorySegment valBuf = ctx.arena().allocate((long) numPatches * 8, 8); + for (int i = 0; i < numPatches; i++) { + idxBuf.setAtIndex(PTypeIO.LE_INT, i, d.patchIndices().get(i)); + valBuf.setAtIndex(PTypeIO.LE_DOUBLE, i, d.patchValues().get(i)); + } + + PatchesMetadata patches = buildPatchesMeta(numPatches); + byte[] metaBytes = new ALPMetadata(d.expE(), d.expF(), patches).encode(); + + EncodeNode idxNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + EncodeNode valNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 1); + EncodeNode partialRoot = new EncodeNode(EncodingId.VORTEX_ALP, + ByteBuffer.wrap(metaBytes), new EncodeNode[]{null, idxNode, valNode}, new int[0]); + ChildSlot slot = new ChildSlot(AlpEncoding.I64_DTYPE, d.encodedArr(), 0); + return new CascadeStep(partialRoot, List.of(idxBuf, valBuf), List.of(slot), d.statsMin(), d.statsMax(), true); + } + + private static int[] findExponentsF32(float[] values) { + int sampleLen = Math.min(SAMPLE_SIZE, values.length); + int bestExpE = 0, bestExpF = 0, bestExceptions = sampleLen + 1; + + outer: + for (int expE = 0; expE <= MAX_EXPONENT_F32; expE++) { + for (int expF = 0; expF <= MAX_EXPONENT_F32; expF++) { + float ef = AlpEncoding.F10_F32[expE]; + float iff = AlpEncoding.IF10_F32[expF]; + float df = AlpEncoding.F10_F32[expF]; + float de = AlpEncoding.IF10_F32[expE]; + int exceptions = 0; + for (int i = 0; i < sampleLen; i++) { + float enc = values[i] * ef * iff; + if (!Float.isFinite(enc) || (float) Math.round(enc) * df * de != values[i]) { + exceptions++; + } + } + if (exceptions < bestExceptions) { + 
bestExceptions = exceptions; + bestExpE = expE; + bestExpF = expF; + if (bestExceptions == 0) { + break outer; + } + } + } + } + return new int[]{bestExpE, bestExpF}; + } + + private static EncodeResult encodeF32(float[] values, EncodeContext ctx) { + int n = values.length; + int[] exps = findExponentsF32(values); + int expE = exps[0], expF = exps[1]; + float ef = AlpEncoding.F10_F32[expE]; + float iff = AlpEncoding.IF10_F32[expF]; + float df = AlpEncoding.F10_F32[expF]; + float de = AlpEncoding.IF10_F32[expE]; + + int[] encodedArr = new int[n]; + var patchIndices = new ArrayList<Integer>(); + var patchValues = new ArrayList<Float>(); + + float min = Float.MAX_VALUE, max = -Float.MAX_VALUE; + for (int i = 0; i < n; i++) { + float v = values[i]; + float enc = v * ef * iff; + int encoded; + if (Float.isFinite(enc) && (float) (encoded = Math.round(enc)) * df * de == v) { + encodedArr[i] = encoded; + } else { + encodedArr[i] = 0; + patchIndices.add(i); + patchValues.add(v); + } + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + } + + byte[] statsMin = n > 0 ? scalarF32(min) : null; + byte[] statsMax = n > 0 ? 
scalarF32(max) : null; + + MemorySegment encodedBuf = ctx.arena().allocate((long) n * 4, 4); + for (int i = 0; i < n; i++) { + encodedBuf.setAtIndex(PTypeIO.LE_INT, i, encodedArr[i]); + } + + EncodeNode encodedNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + + if (patchIndices.isEmpty()) { + byte[] metaBytes = new ALPMetadata(expE, expF, null).encode(); + EncodeNode root = new EncodeNode(EncodingId.VORTEX_ALP, + ByteBuffer.wrap(metaBytes), new EncodeNode[]{encodedNode}, new int[0]); + return new EncodeResult(root, List.of(encodedBuf), statsMin, statsMax); + } + + int numPatches = patchIndices.size(); + MemorySegment idxBuf = ctx.arena().allocate((long) numPatches * 4, 4); + MemorySegment valBuf = ctx.arena().allocate((long) numPatches * 4, 4); + for (int i = 0; i < numPatches; i++) { + idxBuf.setAtIndex(PTypeIO.LE_INT, i, patchIndices.get(i)); + valBuf.setAtIndex(PTypeIO.LE_FLOAT, i, patchValues.get(i)); + } + + PatchesMetadata patches = new PatchesMetadata( + numPatches, + 0L, + io.github.dfa1.vortex.proto.PType.fromValue(PType.U32.ordinal()), + null, null, null); + byte[] metaBytes = new ALPMetadata(expE, expF, patches).encode(); + + EncodeNode idxNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 1); + EncodeNode valNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 2); + EncodeNode root = new EncodeNode(EncodingId.VORTEX_ALP, + ByteBuffer.wrap(metaBytes), + new EncodeNode[]{encodedNode, idxNode, valNode}, + new int[0]); + return new EncodeResult(root, List.of(encodedBuf, idxBuf, valBuf), statsMin, statsMax); + } + + private static PatchesMetadata buildPatchesMeta(int numPatches) { + return new PatchesMetadata( + numPatches, + 0L, + io.github.dfa1.vortex.proto.PType.fromValue(PType.U32.ordinal()), + null, null, null); + } + + private static byte[] scalarF64(double v) { + return ScalarValue.ofF64Value(v).encode(); + } + + private static byte[] scalarF32(float v) { + return ScalarValue.ofF32Value(v).encode(); + } + + private record AlpF64Data(int expE, 
int expF, long[] encodedArr, + List<Integer> patchIndices, List<Double> patchValues, + byte[] statsMin, byte[] statsMax) { + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/AlpRdEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/AlpRdEncodingEncoder.java new file mode 100644 index 00000000..7af71f58 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/AlpRdEncodingEncoder.java @@ -0,0 +1,327 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.encoding.AlpRdEncoding; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.Encoding; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.proto.ALPRDMetadata; +import io.github.dfa1.vortex.proto.PatchesMetadata; + +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/// Write-only encoder for {@code vortex.alprd}. +public final class AlpRdEncodingEncoder implements EncodingEncoder { + + private static final int SAMPLE_SIZE = 512; + private static final int MAX_CUT = 16; + private static final int MAX_DICT_SIZE = 8; + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public AlpRdEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_ALPRD; + } + + @Override + public boolean accepts(DType dtype) { + if (!(dtype instanceof DType.Primitive p)) { + return false; + } + return p.ptype() == PType.F32 || p.ptype() == PType.F64; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + PType ptype = ((DType.Primitive) dtype).ptype(); + return switch (ptype) { + case F64 -> encodeF64((double[]) data, ctx); + case F32 -> encodeF32((float[]) data, ctx); + default -> throw new UnsupportedOperationException("ALP-RD encode not supported for " + ptype); + }; + } + + private static EncodeResult encodeF64(double[] values, EncodeContext ctx) { + int n = values.length; + if (n == 0) { + return emptyResult(AlpRdEncoding.U64_DTYPE, ctx); + } + + int sampleLen = Math.min(SAMPLE_SIZE, n); + Dictionary64 best = findBestDictionaryF64(values, sampleLen); + + Map<Short, Short> lookup = buildLookup(best.dict); + long rightMask = -1L >>> (64 - best.rightBitWidth); + + short[] leftCodes = new short[n]; + long[] rightParts = new long[n]; + List<Long> excPos = new ArrayList<>(); + List<Short> excVals = new ArrayList<>(); + + for (int i = 0; i < n; i++) { + long bits = Double.doubleToRawLongBits(values[i]); + short leftU16 = (short) (bits >>> best.rightBitWidth); + rightParts[i] = bits & rightMask; + Short code = lookup.get(leftU16); + if (code != null) { + leftCodes[i] = code; + } else { + leftCodes[i] = 0; + excPos.add((long) i); + excVals.add(leftU16); + } + } + + return buildEncodeResult( + best.dict, best.rightBitWidth, leftCodes, rightParts, + AlpRdEncoding.U64_DTYPE, excPos, excVals, ctx); + } + + private static Dictionary64 findBestDictionaryF64(double[] values, int sampleLen) { + double bestEstSize = Double.MAX_VALUE; + int bestRightBw = 48; + short[] bestDict = new short[]{0}; + + for (int p = 1; p <= MAX_CUT; p++) { + int rightBw = 64 - p; + Map<Short, Integer> counts = new HashMap<>(); + for (int i = 0; i < 
sampleLen; i++) { + long bits = Double.doubleToRawLongBits(values[i]); + short leftU16 = (short) (bits >>> rightBw); + counts.merge(leftU16, 1, Integer::sum); + } + short[] dict = topKByCount(counts); + int excCount = countExceptionsF64(values, sampleLen, dict, rightBw); + int maxCode = dict.length - 1; + int leftBw = maxCode == 0 ? 1 : (Integer.SIZE - Integer.numberOfLeadingZeros(maxCode)); + double estSize = rightBw + leftBw + (double) (excCount * 32) / sampleLen; + if (estSize < bestEstSize) { + bestEstSize = estSize; + bestRightBw = rightBw; + bestDict = dict; + } + } + return new Dictionary64(bestDict, bestRightBw); + } + + private static int countExceptionsF64(double[] values, int sampleLen, short[] dict, int rightBw) { + Map<Short, Boolean> dictSet = new HashMap<>(); + for (short d : dict) { + dictSet.put(d, Boolean.TRUE); + } + int count = 0; + for (int i = 0; i < sampleLen; i++) { + long bits = Double.doubleToRawLongBits(values[i]); + short leftU16 = (short) (bits >>> rightBw); + if (!dictSet.containsKey(leftU16)) { + count++; + } + } + return count; + } + + private static EncodeResult encodeF32(float[] values, EncodeContext ctx) { + int n = values.length; + if (n == 0) { + return emptyResult(AlpRdEncoding.U32_DTYPE, ctx); + } + + int sampleLen = Math.min(SAMPLE_SIZE, n); + Dictionary32 best = findBestDictionaryF32(values, sampleLen); + + Map<Short, Short> lookup = buildLookup(best.dict); + int rightMask = -1 >>> (32 - best.rightBitWidth); + + short[] leftCodes = new short[n]; + int[] rightParts = new int[n]; + List<Long> excPos = new ArrayList<>(); + List<Short> excVals = new ArrayList<>(); + + for (int i = 0; i < n; i++) { + int bits = Float.floatToRawIntBits(values[i]); + short leftU16 = (short) (bits >>> best.rightBitWidth); + rightParts[i] = bits & rightMask; + Short code = lookup.get(leftU16); + if (code != null) { + leftCodes[i] = code; + } else { + leftCodes[i] = 0; + excPos.add((long) i); + excVals.add(leftU16); + } + } + + return buildEncodeResult( + best.dict, best.rightBitWidth, 
leftCodes, rightParts, + AlpRdEncoding.U32_DTYPE, excPos, excVals, ctx); + } + + private static Dictionary32 findBestDictionaryF32(float[] values, int sampleLen) { + double bestEstSize = Double.MAX_VALUE; + int bestRightBw = 16; + short[] bestDict = new short[]{0}; + + for (int p = 1; p <= MAX_CUT; p++) { + int rightBw = 32 - p; + Map<Short, Integer> counts = new HashMap<>(); + for (int i = 0; i < sampleLen; i++) { + int bits = Float.floatToRawIntBits(values[i]); + short leftU16 = (short) (bits >>> rightBw); + counts.merge(leftU16, 1, Integer::sum); + } + short[] dict = topKByCount(counts); + int excCount = countExceptionsF32(values, sampleLen, dict, rightBw); + int maxCode = dict.length - 1; + int leftBw = maxCode == 0 ? 1 : (Integer.SIZE - Integer.numberOfLeadingZeros(maxCode)); + double estSize = rightBw + leftBw + (double) (excCount * 32) / sampleLen; + if (estSize < bestEstSize) { + bestEstSize = estSize; + bestRightBw = rightBw; + bestDict = dict; + } + } + return new Dictionary32(bestDict, bestRightBw); + } + + private static int countExceptionsF32(float[] values, int sampleLen, short[] dict, int rightBw) { + Map<Short, Boolean> dictSet = new HashMap<>(); + for (short d : dict) { + dictSet.put(d, Boolean.TRUE); + } + int count = 0; + for (int i = 0; i < sampleLen; i++) { + int bits = Float.floatToRawIntBits(values[i]); + short leftU16 = (short) (bits >>> rightBw); + if (!dictSet.containsKey(leftU16)) { + count++; + } + } + return count; + } + + private static short[] topKByCount(Map<Short, Integer> counts) { + List<Map.Entry<Short, Integer>> sorted = new ArrayList<>(counts.entrySet()); + sorted.sort((a, b) -> b.getValue() - a.getValue()); + int dictSize = Math.min(sorted.size(), MAX_DICT_SIZE); + short[] dict = new short[dictSize]; + for (int i = 0; i < dictSize; i++) { + dict[i] = sorted.get(i).getKey(); + } + return dict; + } + + private static Map<Short, Short> buildLookup(short[] dict) { + Map<Short, Short> lookup = new HashMap<>(); + for (short i = 0; i < dict.length; i++) { + lookup.put(dict[i], i); + } + return lookup; + } + + private static 
EncodeResult buildEncodeResult( + short[] dict, int rightBitWidth, + short[] leftCodes, Object rightPartsData, DType rightDtype, + List<Long> excPos, List<Short> excVals, EncodeContext ctx) { + + Encoding bp = ctx.lookupEncoding(EncodingId.FASTLANES_BITPACKED); + EncodeResult leftResult = bp.encode(AlpRdEncoding.U16_DTYPE, leftCodes, ctx); + EncodeResult rightResult = bp.encode(rightDtype, rightPartsData, ctx); + + List<MemorySegment> allBuffers = new ArrayList<>(leftResult.buffers()); + int leftBufCount = allBuffers.size(); + allBuffers.addAll(rightResult.buffers()); + + EncodeNode leftNode = EncodeNode.remapBufferIndices(leftResult.rootNode(), 0); + EncodeNode rightNode = EncodeNode.remapBufferIndices(rightResult.rootNode(), leftBufCount); + + List<Integer> dictList = new ArrayList<>(dict.length); + for (short d : dict) { + dictList.add(d & 0xFFFF); + } + + EncodeNode[] children; + PatchesMetadata patchesMeta = null; + if (excPos.isEmpty()) { + children = new EncodeNode[]{leftNode, rightNode}; + } else { + long[] excPosArr = excPos.stream().mapToLong(Long::longValue).toArray(); + short[] excValsArr = new short[excVals.size()]; + for (int i = 0; i < excVals.size(); i++) { + excValsArr[i] = excVals.get(i); + } + + EncodeResult idxResult = bp.encode(AlpRdEncoding.U64_DTYPE, excPosArr, ctx); + EncodeResult valResult = bp.encode(AlpRdEncoding.U16_DTYPE, excValsArr, ctx); + + int idxOffset = allBuffers.size(); + allBuffers.addAll(idxResult.buffers()); + int idxBufCount = idxResult.buffers().size(); + allBuffers.addAll(valResult.buffers()); + + EncodeNode idxNode = EncodeNode.remapBufferIndices(idxResult.rootNode(), idxOffset); + EncodeNode valNode = EncodeNode.remapBufferIndices(valResult.rootNode(), idxOffset + idxBufCount); + + patchesMeta = new PatchesMetadata( + (long) excPos.size(), + 0L, + io.github.dfa1.vortex.proto.PType.fromValue(PType.U64.ordinal()), + null, null, null); + children = new EncodeNode[]{leftNode, rightNode, idxNode, valNode}; + } + + byte[] metaBytes = new ALPRDMetadata( + 
rightBitWidth, + dict.length, + dictList, + io.github.dfa1.vortex.proto.PType.fromValue(PType.U16.ordinal()), + patchesMeta + ).encode(); + EncodeNode root = new EncodeNode( + EncodingId.VORTEX_ALPRD, ByteBuffer.wrap(metaBytes), children, new int[]{}); + return new EncodeResult(root, List.copyOf(allBuffers), null, null); + } + + private static EncodeResult emptyResult(DType rightDtype, EncodeContext ctx) { + Encoding bp = ctx.lookupEncoding(EncodingId.FASTLANES_BITPACKED); + EncodeResult leftResult = bp.encode(AlpRdEncoding.U16_DTYPE, new short[0], ctx); + EncodeResult rightResult = bp.encode(rightDtype, + rightDtype.equals(AlpRdEncoding.U32_DTYPE) ? new int[0] : new long[0], ctx); + + List<MemorySegment> allBuffers = new ArrayList<>(leftResult.buffers()); + int leftBufCount = allBuffers.size(); + allBuffers.addAll(rightResult.buffers()); + + EncodeNode leftNode = EncodeNode.remapBufferIndices(leftResult.rootNode(), 0); + EncodeNode rightNode = EncodeNode.remapBufferIndices(rightResult.rootNode(), leftBufCount); + + byte[] metaBytes = new ALPRDMetadata( + 48, + 0, + List.of(), + io.github.dfa1.vortex.proto.PType.fromValue(PType.U16.ordinal()), + null).encode(); + + EncodeNode root = new EncodeNode( + EncodingId.VORTEX_ALPRD, ByteBuffer.wrap(metaBytes), + new EncodeNode[]{leftNode, rightNode}, new int[]{}); + return new EncodeResult(root, List.copyOf(allBuffers), null, null); + } + + private record Dictionary64(short[] dict, int rightBitWidth) { + } + + private record Dictionary32(short[] dict, int rightBitWidth) { + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/BitpackedEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/BitpackedEncodingEncoder.java new file mode 100644 index 00000000..7742f6c3 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/BitpackedEncodingEncoder.java @@ -0,0 +1,227 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import 
io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.BitpackedEncoding; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.BitPackedMetadata; +import io.github.dfa1.vortex.proto.ScalarValue; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; +import java.util.List; + +/// Write-only encoder for {@code fastlanes.bitpacked}. +public final class BitpackedEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public BitpackedEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.FASTLANES_BITPACKED; + } + + @Override + public boolean accepts(DType dtype) { + if (!(dtype instanceof DType.Primitive p)) { + return false; + } + return switch (p.ptype()) { + case I8, I16, I32, I64, U8, U16, U32, U64 -> true; + default -> false; + }; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + PType ptype = ((DType.Primitive) dtype).ptype(); + long[] longs = toLongs(data, ptype); + int n = longs.length; + int typeBits = ptype.byteSize() * 8; + long typeMask = typeMask(typeBits); + boolean unsign = isUnsigned(ptype); + + long signedMin = 0L; + long signedMax = 0L; + long maxUnsigned = 0L; + int bitWidth = 0; + + if (n > 0) { + signedMin = longs[0]; + signedMax = longs[0]; + for (long v : longs) { + if (unsign ? Long.compareUnsigned(v, signedMin) < 0 : v < signedMin) { + signedMin = v; + } + if (unsign ? 
Long.compareUnsigned(v, signedMax) > 0 : v > signedMax) { + signedMax = v; + } + long uv = v & typeMask; + if (Long.compareUnsigned(uv, maxUnsigned) > 0) { + maxUnsigned = uv; + } + } + bitWidth = maxUnsigned == 0L ? 0 : (Long.SIZE - Long.numberOfLeadingZeros(maxUnsigned)); + } + + MemorySegment packed = packFastLanes(longs, n, bitWidth, typeBits, ctx.arena()); + + byte[] metaBytes = new BitPackedMetadata(bitWidth, 0, null).encode(); + + byte[] statsMin = n > 0 ? statsBytes(ptype, signedMin) : null; + byte[] statsMax = n > 0 ? statsBytes(ptype, signedMax) : null; + + EncodeNode root = new EncodeNode(EncodingId.FASTLANES_BITPACKED, ByteBuffer.wrap(metaBytes), + new EncodeNode[0], new int[]{0}); + return new EncodeResult(root, List.of(packed), statsMin, statsMax); + } + + private static MemorySegment packFastLanes(long[] values, int n, int bitWidth, int typeBits, Arena arena) { + if (bitWidth == 0 || n == 0) { + return MemorySegment.ofArray(new byte[0]); + } + int lanes = 1024 / typeBits; + int wordBytes = typeBits / 8; + int blockCount = (n + 1023) / 1024; + long typeMask = typeMask(typeBits); + MemorySegment seg = arena.allocate((long) blockCount * 128 * bitWidth); + + for (int block = 0; block < blockCount; block++) { + int blockByteOff = block * 128 * bitWidth; + int blockStart = block * 1024; + + for (int row = 0; row < typeBits; row++) { + int currWord = (row * bitWidth) / typeBits; + int nextWord = ((row + 1) * bitWidth) / typeBits; + int shift = (row * bitWidth) % typeBits; + int remainingBits = (nextWord > currWord) ? ((row + 1) * bitWidth) % typeBits : 0; + int currentBits = bitWidth - remainingBits; + + for (int lane = 0; lane < lanes; lane++) { + int o = row / 8; + int s = row % 8; + int logicalIdx = blockStart + BitpackedEncoding.FL_ORDER[o] * 16 + s * 128 + lane; + long value = (logicalIdx < n) ? 
(values[logicalIdx] & typeMask) : 0L; + + int wordOff = blockByteOff + (lanes * currWord + lane) * wordBytes; + long existing = readWordFromSeg(seg, wordOff, typeBits); + existing |= (value << shift) & typeMask; + writeWordToSeg(seg, wordOff, existing, typeBits); + + if (remainingBits > 0) { + int hiWordOff = blockByteOff + (lanes * nextWord + lane) * wordBytes; + long existingHi = readWordFromSeg(seg, hiWordOff, typeBits); + existingHi |= (value >>> currentBits) & typeMask; + writeWordToSeg(seg, hiWordOff, existingHi, typeBits); + } + } + } + } + return seg; + } + + private static long[] toLongs(Object data, PType ptype) { + return switch (ptype) { + case I8 -> { + byte[] arr = (byte[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = arr[i]; + } + yield r; + } + case U8 -> { + byte[] arr = (byte[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Byte.toUnsignedLong(arr[i]); + } + yield r; + } + case I16 -> { + short[] arr = (short[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = arr[i]; + } + yield r; + } + case U16 -> { + short[] arr = (short[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Short.toUnsignedLong(arr[i]); + } + yield r; + } + case I32 -> { + int[] arr = (int[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = arr[i]; + } + yield r; + } + case U32 -> { + int[] arr = (int[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Integer.toUnsignedLong(arr[i]); + } + yield r; + } + case I64, U64 -> (long[]) data; + default -> throw new VortexException(EncodingId.FASTLANES_BITPACKED, "unsupported ptype: " + ptype); + }; + } + + private static long typeMask(int typeBits) { + return typeBits == 64 ? 
-1L : (1L << typeBits) - 1L; + } + + private static boolean isUnsigned(PType ptype) { + return switch (ptype) { + case U8, U16, U32, U64 -> true; + default -> false; + }; + } + + private static byte[] statsBytes(PType ptype, long value) { + if (isUnsigned(ptype)) { + return ScalarValue.ofUint64Value(value).encode(); + } + return ScalarValue.ofInt64Value(value).encode(); + } + + private static long readWordFromSeg(MemorySegment seg, int off, int typeBits) { + return switch (typeBits) { + case 8 -> Byte.toUnsignedLong(seg.get(ValueLayout.JAVA_BYTE, off)); + case 16 -> Short.toUnsignedLong(seg.get(PTypeIO.LE_SHORT, off)); + case 32 -> Integer.toUnsignedLong(seg.get(PTypeIO.LE_INT, off)); + case 64 -> seg.get(PTypeIO.LE_LONG, off); + default -> + throw new VortexException(EncodingId.FASTLANES_BITPACKED, "unsupported typeBits: " + typeBits); + }; + } + + private static void writeWordToSeg(MemorySegment seg, int off, long value, int typeBits) { + switch (typeBits) { + case 8 -> seg.set(ValueLayout.JAVA_BYTE, off, (byte) value); + case 16 -> seg.set(PTypeIO.LE_SHORT, off, (short) value); + case 32 -> seg.set(PTypeIO.LE_INT, off, (int) value); + case 64 -> seg.set(PTypeIO.LE_LONG, off, value); + default -> + throw new VortexException(EncodingId.FASTLANES_BITPACKED, "unsupported typeBits: " + typeBits); + } + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/BoolEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/BoolEncodingEncoder.java new file mode 100644 index 00000000..cfd0501d --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/BoolEncodingEncoder.java @@ -0,0 +1,75 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import 
io.github.dfa1.vortex.proto.ScalarValue; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; + +/// Write-only encoder for {@code vortex.bool} (bit-packed boolean arrays, LSB first). +/// +///
ADR 0001 Phase 3: first encoding lifted into a standalone {@link EncodingEncoder} +/// implementation in the {@code writer} module. The corresponding read-side decode path +/// lives on {@link io.github.dfa1.vortex.reader.decode.BoolEncodingDecoder} in +/// {@code reader}. +public final class BoolEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public BoolEncodingEncoder() { + } + + private static MemorySegment encodeBool(boolean[] data, Arena arena) { + long packedBytes = (data.length + 7L) / 8; + if (packedBytes == 0) { + return MemorySegment.NULL; + } + MemorySegment seg = arena.allocate(packedBytes); + for (int i = 0; i < data.length; i++) { + if (data[i]) { + long byteIdx = i / 8; + byte cur = seg.get(ValueLayout.JAVA_BYTE, byteIdx); + seg.set(ValueLayout.JAVA_BYTE, byteIdx, (byte) (cur | (1 << (i % 8)))); + } + } + return seg; + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_BOOL; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Bool; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + boolean[] bools = (boolean[]) data; + boolean hasTrue = false; + boolean hasFalse = false; + for (boolean b : bools) { + if (b) { + hasTrue = true; + } else { + hasFalse = true; + } + if (hasTrue && hasFalse) { + break; + } + } + byte[] statsMin = bools.length > 0 + ? ScalarValue.ofBoolValue(!hasFalse).encode() + : null; + byte[] statsMax = bools.length > 0 + ? 
ScalarValue.ofBoolValue(hasTrue).encode() + : null; + return EncodeResult.simple(encodingId(), encodeBool(bools, ctx.arena()), statsMin, statsMax); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ByteBoolEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ByteBoolEncodingEncoder.java new file mode 100644 index 00000000..f496dddc --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ByteBoolEncodingEncoder.java @@ -0,0 +1,38 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; + +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; + +/// Write-only encoder for {@code vortex.bytebool} — one byte per boolean element. +public final class ByteBoolEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public ByteBoolEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_BYTEBOOL; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Bool; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + boolean[] bools = (boolean[]) data; + MemorySegment seg = ctx.arena().allocate(bools.length); + for (int i = 0; i < bools.length; i++) { + seg.set(ValueLayout.JAVA_BYTE, i, bools[i] ? 
(byte) 1 : (byte) 0); + } + return EncodeResult.simple(encodingId(), seg); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ChunkedEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ChunkedEncodingEncoder.java new file mode 100644 index 00000000..c8c20865 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ChunkedEncodingEncoder.java @@ -0,0 +1,93 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.BoolEncoding; +import io.github.dfa1.vortex.encoding.ByteBoolEncoding; +import io.github.dfa1.vortex.encoding.ChunkedData; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.Encoding; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.NullEncoding; +import io.github.dfa1.vortex.encoding.PrimitiveEncoding; +import io.github.dfa1.vortex.encoding.StructEncoding; +import io.github.dfa1.vortex.encoding.VarBinEncoding; + +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +/// Write-only encoder for {@code vortex.chunked}. +public final class ChunkedEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public ChunkedEncodingEncoder() { + } + + private static final List<Encoding> FALLBACK = List.of( + new PrimitiveEncoding(), new VarBinEncoding(), new BoolEncoding(), + new NullEncoding(), new ByteBoolEncoding(), new StructEncoding()); + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_CHUNKED; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive || dtype instanceof DType.Struct; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + ChunkedData cd = (ChunkedData) data; + List chunks = cd.chunks(); + long[] chunkLengths = cd.chunkLengths(); + int nchunks = chunks.size(); + if (nchunks == 0) { + throw new VortexException(EncodingId.VORTEX_CHUNKED, "at least one chunk required"); + } + + long[] offsets = new long[nchunks + 1]; + offsets[0] = 0; + for (int i = 0; i < nchunks; i++) { + offsets[i + 1] = offsets[i] + chunkLengths[i]; + } + + DType u64 = new DType.Primitive(PType.U64, false); + EncodeResult offsetsResult = ctx.lookupEncoding(EncodingId.VORTEX_PRIMITIVE).encode(u64, offsets, ctx); + + List<MemorySegment> allBuffers = new ArrayList<>(offsetsResult.buffers()); + EncodeNode[] children = new EncodeNode[nchunks + 1]; + children[0] = offsetsResult.rootNode(); + + Encoding inner = findEncoding(dtype); + for (int i = 0; i < nchunks; i++) { + EncodeResult chunkResult = inner.encode(dtype, chunks.get(i), ctx); + int bufOffset = allBuffers.size(); + children[i + 1] = EncodeNode.remapBufferIndices(chunkResult.rootNode(), bufOffset); + allBuffers.addAll(chunkResult.buffers()); + } + + EncodeNode root = new EncodeNode( + EncodingId.VORTEX_CHUNKED, + ByteBuffer.wrap(new byte[0]), + children, + new int[]{}); + return new EncodeResult(root, List.copyOf(allBuffers), null, null); + } + + private static Encoding findEncoding(DType dtype) { + for (Encoding enc : FALLBACK) { + if (enc.accepts(dtype)) { + return enc; + } + } + throw new UnsupportedOperationException("no fallback encoding 
for dtype: " + dtype); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ConstantEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ConstantEncodingEncoder.java new file mode 100644 index 00000000..95864d27 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ConstantEncodingEncoder.java @@ -0,0 +1,103 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.CascadeStep; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.proto.ScalarValue; + +import java.lang.foreign.MemorySegment; + +/// Write-only encoder for {@code vortex.constant}. +public final class ConstantEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public ConstantEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_CONSTANT; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + if (!(dtype instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.VORTEX_CONSTANT, "encode only supports Primitive dtype, got " + dtype); + } + PType ptype = p.ptype(); + if (!isConstant(data, ptype)) { + throw new VortexException(EncodingId.VORTEX_CONSTANT, "not a constant array"); + } + long firstRaw = readFirstRaw(data, ptype); + ScalarValue scalar = buildScalar(ptype, firstRaw); + return EncodeResult.simple(EncodingId.VORTEX_CONSTANT, MemorySegment.ofArray(scalar.encode())); + } + + @Override + public CascadeStep encodeCascade(DType dtype, Object data, EncodeContext encodeCtx) { + if (!isConstant(data, ((DType.Primitive) dtype).ptype())) { + return CascadeStep.notApplicable(); + } + return CascadeStep.terminal(encode(dtype, data, encodeCtx)); + } + + private static long readFirstRaw(Object data, PType ptype) { + return switch (ptype) { + case I8, U8 -> ((byte[]) data).length > 0 ? ((byte[]) data)[0] : 0L; + case I16, U16 -> ((short[]) data).length > 0 ? ((short[]) data)[0] : 0L; + case I32, U32 -> ((int[]) data).length > 0 ? ((int[]) data)[0] : 0L; + case I64, U64 -> ((long[]) data).length > 0 ? ((long[]) data)[0] : 0L; + case F32 -> ((float[]) data).length > 0 ? Float.floatToRawIntBits(((float[]) data)[0]) : 0L; + case F64 -> ((double[]) data).length > 0 ? 
Double.doubleToRawLongBits(((double[]) data)[0]) : 0L; + default -> throw new VortexException(EncodingId.VORTEX_CONSTANT, "unsupported ptype: " + ptype); + }; + } + + private static boolean isConstant(Object data, PType ptype) { + long firstRaw = readFirstRaw(data, ptype); + int len = switch (ptype) { + case I8, U8 -> ((byte[]) data).length; + case I16, U16 -> ((short[]) data).length; + case I32, U32 -> ((int[]) data).length; + case I64, U64 -> ((long[]) data).length; + case F32 -> ((float[]) data).length; + case F64 -> ((double[]) data).length; + default -> throw new VortexException(EncodingId.VORTEX_CONSTANT, "unsupported ptype: " + ptype); + }; + for (int i = 1; i < len; i++) { + long raw = switch (ptype) { + case I8, U8 -> ((byte[]) data)[i]; + case I16, U16 -> ((short[]) data)[i]; + case I32, U32 -> ((int[]) data)[i]; + case I64, U64 -> ((long[]) data)[i]; + case F32 -> Float.floatToRawIntBits(((float[]) data)[i]); + case F64 -> Double.doubleToRawLongBits(((double[]) data)[i]); + default -> throw new VortexException(EncodingId.VORTEX_CONSTANT, "unsupported ptype: " + ptype); + }; + if (raw != firstRaw) { + return false; + } + } + return true; + } + + private static ScalarValue buildScalar(PType ptype, long rawBits) { + return switch (ptype) { + case U8, U16, U32, U64 -> ScalarValue.ofUint64Value(rawBits); + case I8, I16, I32, I64 -> ScalarValue.ofInt64Value(rawBits); + case F32 -> ScalarValue.ofF32Value(Float.intBitsToFloat((int) rawBits)); + case F64 -> ScalarValue.ofF64Value(Double.longBitsToDouble(rawBits)); + default -> throw new VortexException(EncodingId.VORTEX_CONSTANT, "unsupported ptype: " + ptype); + }; + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DateTimePartsEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DateTimePartsEncodingEncoder.java new file mode 100644 index 00000000..690a4f6a --- /dev/null +++ 
b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DateTimePartsEncodingEncoder.java @@ -0,0 +1,158 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.CascadeStep; +import io.github.dfa1.vortex.encoding.ChildSlot; +import io.github.dfa1.vortex.encoding.DateTimePartsData; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.Encoding; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.TimeUnit; +import io.github.dfa1.vortex.proto.DateTimePartsMetadata; + +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +/// Write-only encoder for {@code vortex.datetimeparts}. +public final class DateTimePartsEncodingEncoder implements EncodingEncoder { + + private static final long SECONDS_PER_DAY = 86_400L; + private static final DType I64 = new DType.Primitive(PType.I64, false); + private static final DType I64_NULLABLE = new DType.Primitive(PType.I64, true); + private static final io.github.dfa1.vortex.proto.PType I64_PROTO = + io.github.dfa1.vortex.proto.PType.fromValue(PType.I64.ordinal()); + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public DateTimePartsEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_DATETIMEPARTS; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Extension; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + DType.Extension ext = (DType.Extension) dtype; + DateTimePartsData d = (DateTimePartsData) data; + + ByteBuffer extMeta = ext.metadata(); + if (extMeta == null || extMeta.remaining() < 3) { + throw new VortexException(EncodingId.VORTEX_DATETIMEPARTS, + "extension metadata missing or too short"); + } + byte[] extBytes = new byte[extMeta.remaining()]; + extMeta.duplicate().get(extBytes); + TimeUnit unit = TimeUnit.fromTag(extBytes[0]); + + long divisor = unit.divisor(); + long ticksPerDay = SECONDS_PER_DAY * divisor; + int n = d.timestamps().length; + + long[] days = new long[n]; + long[] seconds = new long[n]; + long[] subseconds = new long[n]; + + for (int i = 0; i < n; i++) { + long ts = d.timestamps()[i]; + long dval = ts / ticksPerDay; + long rem = ts % ticksPerDay; + if (rem < 0) { + rem += ticksPerDay; + dval--; + } + days[i] = dval; + seconds[i] = rem / divisor; + subseconds[i] = rem % divisor; + } + + DType daysDtype = d.nullable() ? 
I64_NULLABLE : I64; + + Encoding primEnc = ctx.lookupEncoding(EncodingId.VORTEX_PRIMITIVE); + EncodeResult daysResult = primEnc.encode(daysDtype, days, ctx); + EncodeResult secondsResult = primEnc.encode(I64, seconds, ctx); + EncodeResult subsecondsResult = primEnc.encode(I64, subseconds, ctx); + + List<MemorySegment> allBuffers = new ArrayList<>(); + allBuffers.addAll(daysResult.buffers()); + allBuffers.addAll(secondsResult.buffers()); + allBuffers.addAll(subsecondsResult.buffers()); + + int off1 = daysResult.buffers().size(); + int off2 = off1 + secondsResult.buffers().size(); + + EncodeNode daysNode = EncodeNode.remapBufferIndices(daysResult.rootNode(), 0); + EncodeNode secondsNode = EncodeNode.remapBufferIndices(secondsResult.rootNode(), off1); + EncodeNode subsecondsNode = EncodeNode.remapBufferIndices(subsecondsResult.rootNode(), off2); + + byte[] metaBytes = new DateTimePartsMetadata(I64_PROTO, I64_PROTO, I64_PROTO).encode(); + + EncodeNode root = new EncodeNode( + EncodingId.VORTEX_DATETIMEPARTS, + ByteBuffer.wrap(metaBytes), + new EncodeNode[]{daysNode, secondsNode, subsecondsNode}, + new int[]{}); + return new EncodeResult(root, List.copyOf(allBuffers), null, null); + } + + @Override + public CascadeStep encodeCascade(DType dtype, Object data, EncodeContext encodeCtx) { + if (!(data instanceof DateTimePartsData d)) { + return CascadeStep.notApplicable(); + } + DType.Extension ext = (DType.Extension) dtype; + ByteBuffer extMeta = ext.metadata(); + byte[] extBytes = new byte[extMeta.remaining()]; + extMeta.duplicate().get(extBytes); + TimeUnit unit = TimeUnit.fromTag(extBytes[0]); + + long divisor = unit.divisor(); + long ticksPerDay = SECONDS_PER_DAY * divisor; + int n = d.timestamps().length; + + long[] days = new long[n]; + long[] seconds = new long[n]; + long[] subseconds = new long[n]; + + for (int i = 0; i < n; i++) { + long ts = d.timestamps()[i]; + long dval = ts / ticksPerDay; + long rem = ts % ticksPerDay; + if (rem < 0) { + rem += ticksPerDay; + dval--; + } + 
days[i] = dval; + seconds[i] = rem / divisor; + subseconds[i] = rem % divisor; + } + + byte[] metaBytes = new DateTimePartsMetadata(I64_PROTO, I64_PROTO, I64_PROTO).encode(); + + EncodeNode partialRoot = new EncodeNode( + EncodingId.VORTEX_DATETIMEPARTS, + ByteBuffer.wrap(metaBytes), + new EncodeNode[3], + new int[0]); + + DType daysDtype = d.nullable() ? I64_NULLABLE : I64; + List children = List.of( + new ChildSlot(daysDtype, days, 0), + new ChildSlot(I64, seconds, 1), + new ChildSlot(I64, subseconds, 2)); + + return new CascadeStep(partialRoot, List.of(), children, null, null, true); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DecimalBytePartsEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DecimalBytePartsEncodingEncoder.java new file mode 100644 index 00000000..a5b97cb7 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DecimalBytePartsEncodingEncoder.java @@ -0,0 +1,49 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.proto.DecimalBytePartsMetadata; + +import java.nio.ByteBuffer; +import java.util.List; + +/// Write-only encoder for {@code vortex.decimal_byte_parts}. +public final class DecimalBytePartsEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public DecimalBytePartsEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_DECIMAL_BYTE_PARTS; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Decimal; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + DType.Decimal d = (DType.Decimal) dtype; + long[] longs = (long[]) data; + DType mspDtype = new DType.Primitive(PType.I64, d.nullable()); + EncodeResult mspResult = ctx.lookupEncoding(EncodingId.VORTEX_PRIMITIVE).encode(mspDtype, longs, ctx); + + DecimalBytePartsMetadata proto = new DecimalBytePartsMetadata( + io.github.dfa1.vortex.proto.PType.fromValue(PType.I64.ordinal()), + 0); + ByteBuffer metaBuf = ByteBuffer.wrap(proto.encode()); + + EncodeNode mspNode = EncodeNode.remapBufferIndices(mspResult.rootNode(), 0); + EncodeNode root = new EncodeNode( + EncodingId.VORTEX_DECIMAL_BYTE_PARTS, metaBuf, new EncodeNode[]{mspNode}, new int[]{}); + return new EncodeResult(root, List.copyOf(mspResult.buffers()), null, null); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DecimalEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DecimalEncodingEncoder.java new file mode 100644 index 00000000..2e34e3f9 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DecimalEncodingEncoder.java @@ -0,0 +1,79 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.proto.DecimalMetadata; + +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.List; + +/// Write-only 
encoder for {@code vortex.decimal}. +public final class DecimalEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public DecimalEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_DECIMAL; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Decimal; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + DType.Decimal d = (DType.Decimal) dtype; + MemorySegment seg = (MemorySegment) data; + int valuesType = valuesType(d.precision()); + int bw = byteWidth(valuesType); + if (seg.byteSize() % bw != 0) { + throw new VortexException(EncodingId.VORTEX_DECIMAL, + "buffer size %d not multiple of byteWidth %d".formatted(seg.byteSize(), bw)); + } + ByteBuffer metaBuf = ByteBuffer.wrap(new DecimalMetadata(valuesType).encode()); + EncodeNode node = new EncodeNode(EncodingId.VORTEX_DECIMAL, metaBuf, new EncodeNode[0], new int[]{0}); + return new EncodeResult(node, List.of(seg), null, null); + } + + private static int valuesType(byte precision) { + if (precision <= 2) { + return 0; + } + if (precision <= 4) { + return 1; + } + if (precision <= 9) { + return 2; + } + if (precision <= 18) { + return 3; + } + if (precision <= 38) { + return 4; + } + return 5; + } + + private static int byteWidth(int valuesType) { + return switch (valuesType) { + case 0 -> 1; + case 1 -> 2; + case 2 -> 4; + case 3 -> 8; + case 4 -> 16; + case 5 -> 32; + default -> throw new VortexException(EncodingId.VORTEX_DECIMAL, + "unknown valuesType: " + valuesType); + }; + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DeltaEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DeltaEncodingEncoder.java new file mode 100644 index 00000000..a940920b --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DeltaEncodingEncoder.java @@ -0,0 
+1,192 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.DeltaEncoding; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.proto.DeltaMetadata; +import io.github.dfa1.vortex.proto.ScalarValue; + +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.List; + +/// Write-only encoder for {@code fastlanes.delta}. +public final class DeltaEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public DeltaEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.FASTLANES_DELTA; + } + + @Override + public boolean accepts(DType dtype) { + if (!(dtype instanceof DType.Primitive p)) { + return false; + } + return switch (p.ptype()) { + case I8, I16, I32, I64, U8, U16, U32, U64 -> true; + default -> false; + }; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + PType ptype = ((DType.Primitive) dtype).ptype(); + long[] longs = toLongs(data, ptype); + int n = longs.length; + int typeBits = DeltaEncoding.typeBits(ptype); + int lanes = DeltaEncoding.lanes(ptype); + long mask = DeltaEncoding.typeMask(ptype); + boolean unsign = isUnsigned(ptype); + + long minVal = 0L, maxVal = 0L; + if (n > 0) { + minVal = longs[0]; + maxVal = longs[0]; + for (int i = 1; i < n; i++) { + long v = longs[i]; + if (unsign ? Long.compareUnsigned(v, minVal) < 0 : v < minVal) { + minVal = v; + } + if (unsign ? 
Long.compareUnsigned(v, maxVal) > 0 : v > maxVal) { + maxVal = v; + } + } + } + + int numChunks = n == 0 ? 0 : (n + DeltaEncoding.FL_CHUNK_SIZE - 1) / DeltaEncoding.FL_CHUNK_SIZE; + long paddedLen = (long) numChunks * DeltaEncoding.FL_CHUNK_SIZE; + int basesLen = numChunks * lanes; + + long[] basesAll = new long[basesLen]; + long[] deltasAll = new long[(int) paddedLen]; + long[] chunkBuf = new long[DeltaEncoding.FL_CHUNK_SIZE]; + long[] transposed = new long[DeltaEncoding.FL_CHUNK_SIZE]; + long[] chunkBases = new long[lanes]; + long[] chunkDelta = new long[DeltaEncoding.FL_CHUNK_SIZE]; + + for (int chunk = 0; chunk < numChunks; chunk++) { + int start = chunk * DeltaEncoding.FL_CHUNK_SIZE; + int end = Math.min(start + DeltaEncoding.FL_CHUNK_SIZE, n); + for (int i = start; i < end; i++) { + chunkBuf[i - start] = longs[i] & mask; + } + for (int i = end - start; i < DeltaEncoding.FL_CHUNK_SIZE; i++) { + chunkBuf[i] = 0L; + } + for (int i = 0; i < DeltaEncoding.FL_CHUNK_SIZE; i++) { + transposed[i] = chunkBuf[DeltaEncoding.transposeIndex(i)]; + } + int basesOff = chunk * lanes; + System.arraycopy(transposed, 0, basesAll, basesOff, lanes); + System.arraycopy(basesAll, basesOff, chunkBases, 0, lanes); + deltaChunk(transposed, chunkBases, lanes, typeBits, mask, chunkDelta); + System.arraycopy(chunkDelta, 0, deltasAll, chunk * DeltaEncoding.FL_CHUNK_SIZE, DeltaEncoding.FL_CHUNK_SIZE); + } + + MemorySegment basesSeg = DeltaEncoding.fromLongs(basesAll, ptype, ctx.arena()); + MemorySegment deltasSeg = DeltaEncoding.fromLongs(deltasAll, ptype, ctx.arena()); + + byte[] metaBytes = new DeltaMetadata(paddedLen, 0).encode(); + + byte[] statsMin = n > 0 ? statsBytes(ptype, minVal) : null; + byte[] statsMax = n > 0 ? 
statsBytes(ptype, maxVal) : null; + + EncodeNode basesNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + EncodeNode deltasNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 1); + EncodeNode root = new EncodeNode(EncodingId.FASTLANES_DELTA, ByteBuffer.wrap(metaBytes), + new EncodeNode[]{basesNode, deltasNode}, new int[0]); + return new EncodeResult(root, List.of(basesSeg, deltasSeg), statsMin, statsMax); + } + + private static void deltaChunk(long[] transposed, long[] bases, int lanes, int typeBits, long mask, long[] out) { + for (int lane = 0; lane < lanes; lane++) { + long prev = bases[lane] & mask; + for (int row = 0; row < typeBits; row++) { + int idx = DeltaEncoding.iterateIndex(row, lane); + long next = transposed[idx] & mask; + out[idx] = (next - prev) & mask; + prev = next; + } + } + } + + private static long[] toLongs(Object data, PType ptype) { + return switch (ptype) { + case I8 -> { + byte[] arr = (byte[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = arr[i]; + } + yield r; + } + case U8 -> { + byte[] arr = (byte[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Byte.toUnsignedLong(arr[i]); + } + yield r; + } + case I16 -> { + short[] arr = (short[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = arr[i]; + } + yield r; + } + case U16 -> { + short[] arr = (short[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Short.toUnsignedLong(arr[i]); + } + yield r; + } + case I32 -> { + int[] arr = (int[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = arr[i]; + } + yield r; + } + case U32 -> { + int[] arr = (int[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Integer.toUnsignedLong(arr[i]); + } + yield r; + } + case I64, U64 -> (long[]) data; + default -> throw new 
VortexException(EncodingId.FASTLANES_DELTA, "unsupported ptype: " + ptype); + }; + } + + private static boolean isUnsigned(PType ptype) { + return switch (ptype) { + case U8, U16, U32, U64 -> true; + default -> false; + }; + } + + private static byte[] statsBytes(PType ptype, long value) { + if (isUnsigned(ptype)) { + return ScalarValue.ofUint64Value(value).encode(); + } + return ScalarValue.ofInt64Value(value).encode(); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DictEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DictEncodingEncoder.java new file mode 100644 index 00000000..5bd4dc9c --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/DictEncodingEncoder.java @@ -0,0 +1,317 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.CascadeStep; +import io.github.dfa1.vortex.encoding.ChildSlot; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.DictMetadata; +import io.github.dfa1.vortex.proto.ScalarValue; +import io.github.dfa1.vortex.proto.VarBinMetadata; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.LinkedHashMap; +import java.util.List; + +/// Write-only encoder for {@code vortex.dict}. +public final class DictEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public DictEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_DICT; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive || dtype instanceof DType.Utf8; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + if (dtype instanceof DType.Utf8) { + return encodeUtf8((String[]) data, ctx); + } + DictData d = buildDictData(dtype, data, ctx); + PType codePType = d.codePType(); + int codeBytes = codePType.byteSize(); + + MemorySegment codesBuf = ctx.arena().allocate((long) d.len() * codeBytes); + for (int i = 0; i < d.len(); i++) { + writeCodeToSeg(codesBuf, codePType, i, readCodeFromArr(d.codesArr(), codePType, i)); + } + + ByteBuffer meta = ByteBuffer.allocate(1).put(0, (byte) codePType.ordinal()); + EncodeNode valuesNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + EncodeNode codesNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 1); + EncodeNode rootNode = new EncodeNode( + EncodingId.VORTEX_DICT, meta, + new EncodeNode[]{valuesNode, codesNode}, + new int[0]); + + return new EncodeResult(rootNode, List.of(d.valuesBuf(), codesBuf), null, null); + } + + @Override + public CascadeStep encodeCascade(DType dtype, Object data, EncodeContext ctx) { + if (dtype instanceof DType.Utf8) { + return CascadeStep.terminal(encodeUtf8((String[]) data, ctx)); + } + DictData d = buildDictData(dtype, data, ctx); + PType codePType = d.codePType(); + + ByteBuffer meta = ByteBuffer.allocate(1).put(0, (byte) codePType.ordinal()); + EncodeNode valuesNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + EncodeNode partialRoot = new EncodeNode( + EncodingId.VORTEX_DICT, meta, + new EncodeNode[]{valuesNode, null}, + new int[0]); + + DType codesDtype = new DType.Primitive(codePType, false); + ChildSlot slot = new ChildSlot(codesDtype, d.codesArr(), 1); + return new CascadeStep(partialRoot, List.of(d.valuesBuf()), List.of(slot), null, null, true); 
+ } + + private static EncodeResult encodeUtf8(String[] strings, EncodeContext ctx) { + int n = strings.length; + + var valueMap = new LinkedHashMap<String, Integer>(); + for (String s : strings) { + valueMap.computeIfAbsent(s, _ -> valueMap.size()); + } + + int dictSize = valueMap.size(); + PType codePType = codePType(dictSize); + int codeBytes = codePType.byteSize(); + + byte[][] dictByteArrays = new byte[dictSize][]; + int j = 0; + long totalDictBytes = 0; + for (String s : valueMap.keySet()) { + dictByteArrays[j] = s.getBytes(StandardCharsets.UTF_8); + totalDictBytes += dictByteArrays[j].length; + j++; + } + + Arena arena = ctx.arena(); + MemorySegment dictBytesBuf = arena.allocate(totalDictBytes > 0 ? totalDictBytes : 1); + MemorySegment dictOffsetsBuf = arena.allocate((long) (dictSize + 1) * Long.BYTES, Long.BYTES); + + long pos = 0; + dictOffsetsBuf.setAtIndex(PTypeIO.LE_LONG, 0, 0L); + for (int i = 0; i < dictSize; i++) { + MemorySegment.copy(MemorySegment.ofArray(dictByteArrays[i]), 0, dictBytesBuf, pos, dictByteArrays[i].length); + pos += dictByteArrays[i].length; + dictOffsetsBuf.setAtIndex(PTypeIO.LE_LONG, i + 1, pos); + } + + MemorySegment codesBuf = arena.allocate((long) n * codeBytes); + for (int i = 0; i < n; i++) { + writeCodeToSeg(codesBuf, codePType, i, valueMap.get(strings[i])); + } + + byte[] metaBytes = new DictMetadata( + dictSize, + io.github.dfa1.vortex.proto.PType.fromValue(codePType.ordinal()), + null, + null + ).encode(); + + byte[] varBinMetaBytes = new VarBinMetadata( + io.github.dfa1.vortex.proto.PType.fromValue(PType.I64.ordinal()) + ).encode(); + + EncodeNode offsetsNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 1); + EncodeNode valuesNode = new EncodeNode(EncodingId.VORTEX_VARBIN, + ByteBuffer.wrap(varBinMetaBytes), + new EncodeNode[]{offsetsNode}, + new int[]{0}); + EncodeNode codesNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 2); + EncodeNode root = new EncodeNode( + EncodingId.VORTEX_DICT, ByteBuffer.wrap(metaBytes), + new 
EncodeNode[]{codesNode, valuesNode}, + new int[0]); + + String minStr = valueMap.keySet().stream().min(String::compareTo).orElse(null); + String maxStr = valueMap.keySet().stream().max(String::compareTo).orElse(null); + byte[] statsMin = minStr != null ? ScalarValue.ofStringValue(minStr).encode() : null; + byte[] statsMax = maxStr != null ? ScalarValue.ofStringValue(maxStr).encode() : null; + return new EncodeResult(root, List.of(dictBytesBuf, dictOffsetsBuf, codesBuf), statsMin, statsMax); + } + + private static DictData buildDictData(DType dtype, Object data, EncodeContext ctx) { + PType ptype = ((DType.Primitive) dtype).ptype(); + var valueMap = new LinkedHashMap(); + int len = arrayLength(data, ptype); + for (int i = 0; i < len; i++) { + Object v = readElement(data, ptype, i); + valueMap.computeIfAbsent(v, _ -> valueMap.size()); + } + + int dictSize = valueMap.size(); + PType codePType = codePType(dictSize); + int codeBytes = codePType.byteSize(); + + Object uniqueArray = buildUniqueArray(ptype, valueMap.keySet(), dictSize); + MemorySegment valuesBuf = PTypeIO.copyArray(ptype, uniqueArray, dictSize); + + MemorySegment codesBuf = ctx.arena().allocate((long) len * codeBytes); + for (int i = 0; i < len; i++) { + Object v = readElement(data, ptype, i); + int code = valueMap.get(v); + writeCodeToSeg(codesBuf, codePType, i, code); + } + + Object codesArr = switch (codePType) { + case U8 -> { + byte[] a = new byte[len]; + for (int i = 0; i < len; i++) { + a[i] = codesBuf.get(ValueLayout.JAVA_BYTE, i); + } + yield a; + } + case U16 -> { + short[] a = new short[len]; + for (int i = 0; i < len; i++) { + a[i] = codesBuf.get(PTypeIO.LE_SHORT, (long) i * 2); + } + yield a; + } + default -> { + int[] a = new int[len]; + for (int i = 0; i < len; i++) { + a[i] = codesBuf.get(PTypeIO.LE_INT, (long) i * 4); + } + yield a; + } + }; + return new DictData(valuesBuf, codesArr, codePType, len); + } + + private static PType codePType(int dictSize) { + if (dictSize <= 256) { + return 
PType.U8; + } + if (dictSize <= 65536) { + return PType.U16; + } + return PType.U32; + } + + private static int arrayLength(Object data, PType ptype) { + return switch (ptype) { + case I8, U8 -> ((byte[]) data).length; + case I16, U16 -> ((short[]) data).length; + case I32, U32 -> ((int[]) data).length; + case I64, U64 -> ((long[]) data).length; + case F32 -> ((float[]) data).length; + case F64 -> ((double[]) data).length; + case F16 -> ((short[]) data).length; + }; + } + + private static Object readElement(Object data, PType ptype, int i) { + return switch (ptype) { + case I8, U8 -> ((byte[]) data)[i]; + case I16, U16, F16 -> ((short[]) data)[i]; + case I32, U32 -> ((int[]) data)[i]; + case I64, U64 -> ((long[]) data)[i]; + case F32 -> ((float[]) data)[i]; + case F64 -> ((double[]) data)[i]; + }; + } + + private static Object buildUniqueArray(PType ptype, Iterable uniques, int dictSize) { + return switch (ptype) { + case I8, U8 -> { + byte[] a = new byte[dictSize]; + int i = 0; + for (Object v : uniques) { + a[i++] = (Byte) v; + } + yield a; + } + case I16, U16 -> { + short[] a = new short[dictSize]; + int i = 0; + for (Object v : uniques) { + a[i++] = (Short) v; + } + yield a; + } + case I32, U32 -> { + int[] a = new int[dictSize]; + int i = 0; + for (Object v : uniques) { + a[i++] = (Integer) v; + } + yield a; + } + case I64, U64 -> { + long[] a = new long[dictSize]; + int i = 0; + for (Object v : uniques) { + a[i++] = (Long) v; + } + yield a; + } + case F32 -> { + float[] a = new float[dictSize]; + int i = 0; + for (Object v : uniques) { + a[i++] = (Float) v; + } + yield a; + } + case F64 -> { + double[] a = new double[dictSize]; + int i = 0; + for (Object v : uniques) { + a[i++] = (Double) v; + } + yield a; + } + case F16 -> { + short[] a = new short[dictSize]; + int i = 0; + for (Object v : uniques) { + a[i++] = (Short) v; + } + yield a; + } + }; + } + + private static void writeCodeToSeg(MemorySegment seg, PType codePType, int idx, int code) { + switch 
(codePType) { + case U8 -> seg.set(ValueLayout.JAVA_BYTE, idx, (byte) code); + case U16 -> seg.set(PTypeIO.LE_SHORT, (long) idx * 2, (short) code); + case U32 -> seg.set(PTypeIO.LE_INT, (long) idx * 4, code); + default -> throw new VortexException(EncodingId.VORTEX_DICT, "unexpected code type: " + codePType); + } + } + + private static int readCodeFromArr(Object arr, PType codePType, int i) { + return switch (codePType) { + case U8 -> Byte.toUnsignedInt(((byte[]) arr)[i]); + case U16 -> Short.toUnsignedInt(((short[]) arr)[i]); + default -> ((int[]) arr)[i]; + }; + } + + private record DictData(MemorySegment valuesBuf, Object codesArr, PType codePType, int len) { + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ExtEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ExtEncodingEncoder.java new file mode 100644 index 00000000..bc6ae474 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ExtEncodingEncoder.java @@ -0,0 +1,79 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.NullableData; +import io.github.dfa1.vortex.encoding.CascadeStep; +import io.github.dfa1.vortex.encoding.ChildSlot; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.Encoding; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.FixedSizeListEncoding; +import io.github.dfa1.vortex.encoding.MaskedEncoding; +import io.github.dfa1.vortex.encoding.PrimitiveEncoding; + +import java.util.List; + +/// Write-only encoder for {@code vortex.ext} — wraps a storage-array encode in an ext node. 
+public final class ExtEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public ExtEncodingEncoder() { + } + + private static final List STORAGE_FALLBACK = List.of( + new PrimitiveEncoding(), + new FixedSizeListEncoding()); + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_EXT; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Extension; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + if (!(dtype instanceof DType.Extension ext)) { + throw new VortexException(EncodingId.VORTEX_EXT, "expected extension dtype, got " + dtype); + } + DType storage = ext.storageDType(); + EncodeResult childResult; + if (data instanceof NullableData) { + childResult = new MaskedEncoding().encode(storage, data, ctx); + } else { + Encoding storageEncoding = null; + for (Encoding enc : STORAGE_FALLBACK) { + if (enc.accepts(storage)) { + storageEncoding = enc; + break; + } + } + if (storageEncoding == null) { + throw new VortexException(EncodingId.VORTEX_EXT, "no storage encoding for " + storage); + } + childResult = storageEncoding.encode(storage, data, ctx); + } + EncodeNode root = new EncodeNode(EncodingId.VORTEX_EXT, null, new EncodeNode[]{childResult.rootNode()}, new int[0]); + return new EncodeResult(root, childResult.buffers(), childResult.statsMin(), childResult.statsMax()); + } + + @Override + public CascadeStep encodeCascade(DType dtype, Object data, EncodeContext ctx) { + if (!(dtype instanceof DType.Extension ext)) { + throw new VortexException(EncodingId.VORTEX_EXT, "expected extension dtype, got " + dtype); + } + if (data instanceof NullableData) { + return CascadeStep.terminal(encode(dtype, data, ctx)); + } + EncodeNode partialRoot = new EncodeNode(EncodingId.VORTEX_EXT, null, new EncodeNode[1], new int[0]); + ChildSlot slot = new ChildSlot(ext.storageDType(), data, 0); + return new 
CascadeStep(partialRoot, List.of(), List.of(slot), null, null, true); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/FixedSizeListEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/FixedSizeListEncodingEncoder.java new file mode 100644 index 00000000..485649ad --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/FixedSizeListEncodingEncoder.java @@ -0,0 +1,71 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.encoding.BoolEncoding; +import io.github.dfa1.vortex.encoding.ByteBoolEncoding; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.Encoding; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.FixedSizeListData; +import io.github.dfa1.vortex.encoding.NullEncoding; +import io.github.dfa1.vortex.encoding.PrimitiveEncoding; +import io.github.dfa1.vortex.encoding.VarBinEncoding; + +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +/// Write-only encoder for {@code vortex.fixed_size_list}. +public final class FixedSizeListEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public FixedSizeListEncodingEncoder() { + } + + private static final List<Encoding> FALLBACK = List.of( + new PrimitiveEncoding(), new VarBinEncoding(), new BoolEncoding(), + new NullEncoding(), new ByteBoolEncoding()); + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_FIXED_SIZE_LIST; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.FixedSizeList; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + DType.FixedSizeList fsl = (DType.FixedSizeList) dtype; + FixedSizeListData fsd = (FixedSizeListData) data; + DType elementType = fsl.elementType(); + Encoding inner = findEncoding(elementType); + + EncodeResult elemResult = inner.encode(elementType, fsd.elements(), ctx); + + List<MemorySegment> allBuffers = new ArrayList<>(elemResult.buffers()); + EncodeNode elemNode = EncodeNode.remapBufferIndices(elemResult.rootNode(), 0); + + EncodeNode root = new EncodeNode( + EncodingId.VORTEX_FIXED_SIZE_LIST, + ByteBuffer.wrap(new byte[0]), + new EncodeNode[]{elemNode}, + new int[]{}); + return new EncodeResult(root, List.copyOf(allBuffers), null, null); + } + + private static Encoding findEncoding(DType dtype) { + for (Encoding enc : FALLBACK) { + if (enc.accepts(dtype)) { + return enc; + } + } + throw new UnsupportedOperationException("no fallback encoding for dtype: " + dtype); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/FrameOfReferenceEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/FrameOfReferenceEncodingEncoder.java new file mode 100644 index 00000000..47db9bde --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/FrameOfReferenceEncodingEncoder.java @@ -0,0 +1,167 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.CascadeStep; 
+import io.github.dfa1.vortex.encoding.ChildSlot; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.ScalarValue; + +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.List; + +/// Write-only encoder for {@code fastlanes.for} (Frame of Reference). +public final class FrameOfReferenceEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public FrameOfReferenceEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.FASTLANES_FOR; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive p && !p.ptype().isFloating(); + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + if (!(dtype instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.FASTLANES_FOR, "expected primitive dtype, got " + dtype); + } + PType ptype = p.ptype(); + long[] longs = toLongs(data, ptype); + int n = longs.length; + + long ref = computeRef(longs, n); + MemorySegment residuals = toResidualBuffer(longs, ref, ptype, ctx); + ByteBuffer meta = buildForMeta(ref, ptype); + + EncodeNode child = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + EncodeNode root = new EncodeNode(EncodingId.FASTLANES_FOR, meta, new EncodeNode[]{child}, new int[0]); + return new EncodeResult(root, List.of(residuals), null, null); + } + + @Override + public CascadeStep encodeCascade(DType dtype, Object data, EncodeContext encodeCtx) { + if (!(dtype instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.FASTLANES_FOR, "expected primitive dtype, got " + dtype); + } + PType ptype 
= p.ptype(); + long[] longs = toLongs(data, ptype); + int n = longs.length; + + long ref = computeRef(longs, n); + ByteBuffer meta = buildForMeta(ref, ptype); + + EncodeNode partialRoot = new EncodeNode(EncodingId.FASTLANES_FOR, meta, new EncodeNode[1], new int[0]); + ChildSlot slot = new ChildSlot(dtype, residualsAsNativeArray(longs, ref, ptype), 0); + return new CascadeStep(partialRoot, List.of(), List.of(slot), null, null, true); + } + + private static long computeRef(long[] longs, int n) { + long ref = n > 0 ? longs[0] : 0L; + for (long v : longs) { + if (v < ref) { + ref = v; + } + } + return ref; + } + + private static ByteBuffer buildForMeta(long ref, PType ptype) { + boolean unsigned = switch (ptype) { + case U8, U16, U32, U64 -> true; + default -> false; + }; + ScalarValue scalar = unsigned ? ScalarValue.ofUint64Value(ref) : ScalarValue.ofInt64Value(ref); + return ByteBuffer.wrap(scalar.encode()); + } + + private static Object residualsAsNativeArray(long[] longs, long ref, PType ptype) { + int n = longs.length; + return switch (ptype) { + case I8, U8 -> { + byte[] r = new byte[n]; + for (int i = 0; i < n; i++) { + r[i] = (byte) (longs[i] - ref); + } + yield r; + } + case I16, U16 -> { + short[] r = new short[n]; + for (int i = 0; i < n; i++) { + r[i] = (short) (longs[i] - ref); + } + yield r; + } + case I32, U32 -> { + int[] r = new int[n]; + for (int i = 0; i < n; i++) { + r[i] = (int) (longs[i] - ref); + } + yield r; + } + case I64, U64 -> { + long[] r = new long[n]; + for (int i = 0; i < n; i++) { + r[i] = longs[i] - ref; + } + yield r; + } + default -> throw new VortexException(EncodingId.FASTLANES_FOR, "unsupported ptype: " + ptype); + }; + } + + private static long[] toLongs(Object data, PType ptype) { + return switch (ptype) { + case I8, U8 -> { + byte[] arr = (byte[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = ptype == PType.U8 ? 
Byte.toUnsignedLong(arr[i]) : arr[i]; + } + yield r; + } + case I16, U16 -> { + short[] arr = (short[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = ptype == PType.U16 ? Short.toUnsignedLong(arr[i]) : arr[i]; + } + yield r; + } + case I32, U32 -> { + int[] arr = (int[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = ptype == PType.U32 ? Integer.toUnsignedLong(arr[i]) : arr[i]; + } + yield r; + } + case I64, U64 -> (long[]) data; + default -> throw new VortexException(EncodingId.FASTLANES_FOR, "unsupported ptype: " + ptype); + }; + } + + private static MemorySegment toResidualBuffer(long[] longs, long ref, PType ptype, EncodeContext ctx) { + int n = longs.length; + int elemBytes = ptype.byteSize(); + MemorySegment seg = ctx.arena().allocate((long) n * elemBytes, elemBytes); + for (int i = 0; i < n; i++) { + long r = longs[i] - ref; + PTypeIO.set(seg, (long) i * elemBytes, ptype, r); + } + return seg; + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/FsstEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/FsstEncodingEncoder.java new file mode 100644 index 00000000..5841237a --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/FsstEncodingEncoder.java @@ -0,0 +1,158 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.FSSTMetadata; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; 
+import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; + +/// Write-only encoder for {@code vortex.fsst}. +public final class FsstEncodingEncoder implements EncodingEncoder { + + private static final int MAX_SYMBOLS = 255; + private static final int BIGRAM_COUNT = 65536; + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public FsstEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_FSST; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Utf8 || dtype instanceof DType.Binary; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + String[] strings = (String[]) data; + int n = strings.length; + + byte[][] byteArrays = new byte[n][]; + for (int i = 0; i < n; i++) { + byteArrays[i] = strings[i].getBytes(StandardCharsets.UTF_8); + } + + int[] freq = new int[BIGRAM_COUNT]; + for (byte[] b : byteArrays) { + for (int i = 0; i + 1 < b.length; i++) { + freq[(Byte.toUnsignedInt(b[i]) << 8) | Byte.toUnsignedInt(b[i + 1])]++; + } + } + long[] ranked = new long[BIGRAM_COUNT]; + for (int i = 0; i < BIGRAM_COUNT; i++) { + ranked[i] = ((long) freq[i] << 16) | i; + } + Arrays.sort(ranked); + + int numSymbols = 0; + int[] codeForBigram = new int[BIGRAM_COUNT]; + Arrays.fill(codeForBigram, -1); + long[] symbolValues = new long[MAX_SYMBOLS]; + for (int rank = BIGRAM_COUNT - 1; rank >= 0 && numSymbols < MAX_SYMBOLS; rank--) { + int bg = (int) (ranked[rank] & 0xFFFF); + if (freq[bg] == 0) { + break; + } + codeForBigram[bg] = numSymbols; + int hi = bg >>> 8; + int lo = bg & 0xFF; + symbolValues[numSymbols] = hi | ((long) lo << 8); + numSymbols++; + } + + byte[][] compressed = new byte[n][]; + for (int i = 0; i < n; i++) { + compressed[i] = compressString(byteArrays[i], codeForBigram); + } + + Arena arena = ctx.arena(); + + MemorySegment symBuf = arena.allocate(Math.max(numSymbols * 8L, 1), 8); + 
for (int i = 0; i < numSymbols; i++) { + symBuf.setAtIndex(PTypeIO.LE_LONG, i, symbolValues[i]); + } + + MemorySegment symLenBuf = arena.allocate(Math.max(numSymbols, 1)); + for (int i = 0; i < numSymbols; i++) { + symLenBuf.set(ValueLayout.JAVA_BYTE, i, (byte) 2); + } + + int totalCompressed = 0; + for (byte[] c : compressed) { + totalCompressed += c.length; + } + MemorySegment compBuf = arena.allocate(Math.max(totalCompressed, 1)); + long pos = 0; + for (byte[] c : compressed) { + MemorySegment.copy(MemorySegment.ofArray(c), 0, compBuf, pos, c.length); + pos += c.length; + } + + MemorySegment uncompLenBuf = arena.allocate(Math.max(n * 4L, 1), 4); + for (int i = 0; i < n; i++) { + uncompLenBuf.setAtIndex(PTypeIO.LE_INT, i, byteArrays[i].length); + } + + MemorySegment codesOffBuf = arena.allocate((long) (n + 1) * 4, 4); + long off = 0; + codesOffBuf.setAtIndex(PTypeIO.LE_INT, 0, 0); + for (int i = 0; i < n; i++) { + off += compressed[i].length; + codesOffBuf.setAtIndex(PTypeIO.LE_INT, i + 1, (int) off); + } + + byte[] metaBytes = new FSSTMetadata( + io.github.dfa1.vortex.proto.PType.fromValue(PType.I32.ordinal()), + io.github.dfa1.vortex.proto.PType.fromValue(PType.I32.ordinal()) + ).encode(); + + EncodeNode uncompLensNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 3); + EncodeNode codesOffNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 4); + EncodeNode root = new EncodeNode( + EncodingId.VORTEX_FSST, + ByteBuffer.wrap(metaBytes), + new EncodeNode[]{uncompLensNode, codesOffNode}, + new int[]{0, 1, 2}); + + return new EncodeResult(root, + List.of(symBuf, symLenBuf, compBuf, uncompLenBuf, codesOffBuf), + null, null); + } + + private static byte[] compressString(byte[] input, int[] codeForBigram) { + byte[] out = new byte[input.length * 2]; + int outLen = 0; + int i = 0; + while (i < input.length) { + if (i + 1 < input.length) { + int bg = (Byte.toUnsignedInt(input[i]) << 8) | Byte.toUnsignedInt(input[i + 1]); + int code = codeForBigram[bg]; + if (code >= 0) { 
+ out[outLen++] = (byte) code; + i += 2; + continue; + } + } + out[outLen++] = (byte) 0xFF; + out[outLen++] = input[i]; + i++; + } + return Arrays.copyOf(out, outLen); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ListEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ListEncodingEncoder.java new file mode 100644 index 00000000..7a554060 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ListEncodingEncoder.java @@ -0,0 +1,88 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.encoding.BoolEncoding; +import io.github.dfa1.vortex.encoding.ByteBoolEncoding; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.Encoding; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.ListData; +import io.github.dfa1.vortex.encoding.NullEncoding; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.PrimitiveEncoding; +import io.github.dfa1.vortex.encoding.VarBinEncoding; +import io.github.dfa1.vortex.proto.ListMetadata; + +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +/// Write-only encoder for {@code vortex.list}. +public final class ListEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public ListEncodingEncoder() { + } + + private static final List<Encoding> FALLBACK = List.of( + new PrimitiveEncoding(), new VarBinEncoding(), new BoolEncoding(), + new NullEncoding(), new ByteBoolEncoding()); + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_LIST; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.List; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + DType.List listDtype = (DType.List) dtype; + ListData ld = (ListData) data; + DType elementType = listDtype.elementType(); + Encoding elemEncoding = findEncoding(elementType); + EncodeResult elemResult = elemEncoding.encode(elementType, ld.elements(), ctx); + + List<MemorySegment> allBuffers = new ArrayList<>(elemResult.buffers()); + int elemBufCount = allBuffers.size(); + EncodeNode elemNode = EncodeNode.remapBufferIndices(elemResult.rootNode(), 0); + + long nOffsets = ld.outerLen() + 1; + MemorySegment offsetsBuf = ctx.arena().allocate(nOffsets * Long.BYTES, Long.BYTES); + for (int i = 0; i < nOffsets; i++) { + offsetsBuf.setAtIndex(PTypeIO.LE_LONG, i, ld.offsets()[i]); + } + allBuffers.add(offsetsBuf); + EncodeNode offsetsNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, elemBufCount); + + long elementsLen = ld.offsets()[(int) ld.outerLen()]; + byte[] metaBytes = new ListMetadata( + elementsLen, + io.github.dfa1.vortex.proto.PType.fromValue(PType.I64.ordinal()) + ).encode(); + + EncodeNode root = new EncodeNode( + EncodingId.VORTEX_LIST, + ByteBuffer.wrap(metaBytes), + new EncodeNode[]{elemNode, offsetsNode}, + new int[]{}); + return new EncodeResult(root, List.copyOf(allBuffers), null, null); + } + + private static Encoding findEncoding(DType dtype) { + for (Encoding enc : FALLBACK) { + if (enc.accepts(dtype)) { + return enc; + } + } + throw new UnsupportedOperationException("no fallback encoding for dtype: " + dtype); + } +} diff --git 
a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ListViewEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ListViewEncodingEncoder.java new file mode 100644 index 00000000..bc09fc9b --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ListViewEncodingEncoder.java @@ -0,0 +1,97 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.encoding.BoolEncoding; +import io.github.dfa1.vortex.encoding.ByteBoolEncoding; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.Encoding; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.ListViewData; +import io.github.dfa1.vortex.encoding.NullEncoding; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.PrimitiveEncoding; +import io.github.dfa1.vortex.encoding.VarBinEncoding; +import io.github.dfa1.vortex.proto.ListViewMetadata; + +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +/// Write-only encoder for {@code vortex.listview}. +public final class ListViewEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public ListViewEncodingEncoder() { + } + + private static final List<Encoding> FALLBACK = List.of( + new PrimitiveEncoding(), new VarBinEncoding(), new BoolEncoding(), + new NullEncoding(), new ByteBoolEncoding()); + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_LISTVIEW; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.List; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + DType.List listDtype = (DType.List) dtype; + ListViewData lvd = (ListViewData) data; + DType elementType = listDtype.elementType(); + Encoding elemEncoding = findEncoding(elementType); + EncodeResult elemResult = elemEncoding.encode(elementType, lvd.elements(), ctx); + + List<MemorySegment> allBuffers = new ArrayList<>(elemResult.buffers()); + int elemBufCount = allBuffers.size(); + EncodeNode elemNode = EncodeNode.remapBufferIndices(elemResult.rootNode(), 0); + + long n = lvd.outerLen(); + + MemorySegment offsetsBuf = ctx.arena().allocate(n * Integer.BYTES, Integer.BYTES); + for (int i = 0; i < n; i++) { + offsetsBuf.setAtIndex(PTypeIO.LE_INT, i, lvd.offsets()[i]); + } + allBuffers.add(offsetsBuf); + EncodeNode offsetsNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, elemBufCount); + + MemorySegment sizesBuf = ctx.arena().allocate(n * Integer.BYTES, Integer.BYTES); + for (int i = 0; i < n; i++) { + sizesBuf.setAtIndex(PTypeIO.LE_INT, i, lvd.sizes()[i]); + } + allBuffers.add(sizesBuf); + EncodeNode sizesNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, elemBufCount + 1); + + long elementsLen = java.lang.reflect.Array.getLength(lvd.elements()); + byte[] metaBytes = new ListViewMetadata( + elementsLen, + io.github.dfa1.vortex.proto.PType.fromValue(PType.I32.ordinal()), + io.github.dfa1.vortex.proto.PType.fromValue(PType.I32.ordinal()) + ).encode(); + + EncodeNode root = new EncodeNode( + EncodingId.VORTEX_LISTVIEW, + ByteBuffer.wrap(metaBytes), + new EncodeNode[]{elemNode, offsetsNode, sizesNode}, + 
new int[]{}); + return new EncodeResult(root, List.copyOf(allBuffers), null, null); + } + + private static Encoding findEncoding(DType dtype) { + for (Encoding enc : FALLBACK) { + if (enc.accepts(dtype)) { + return enc; + } + } + throw new UnsupportedOperationException("no fallback encoding for dtype: " + dtype); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/MaskedEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/MaskedEncodingEncoder.java new file mode 100644 index 00000000..cd0ed2ed --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/MaskedEncodingEncoder.java @@ -0,0 +1,79 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.NullableData; +import io.github.dfa1.vortex.encoding.BoolEncoding; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.Encoding; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.FixedSizeListEncoding; +import io.github.dfa1.vortex.encoding.PrimitiveEncoding; +import io.github.dfa1.vortex.encoding.VarBinEncoding; + +import java.lang.foreign.MemorySegment; +import java.util.ArrayList; +import java.util.List; + +/// Write-only encoder for {@code vortex.masked}. Wraps the payload encode in a values + validity +/// pair driven by a {@link NullableData} carrier. +public final class MaskedEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public MaskedEncodingEncoder() { + } + + private static final List<Encoding> INNER_FALLBACK = List.of( + new PrimitiveEncoding(), + new VarBinEncoding(), + new FixedSizeListEncoding()); + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_MASKED; + } + + @Override + public boolean accepts(DType dtype) { + return false; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + if (!(data instanceof NullableData nd)) { + throw new VortexException(EncodingId.VORTEX_MASKED, + "expected NullableData, got " + (data == null ? "null" : data.getClass().getName())); + } + DType nonNullable = dtype.withNullable(false); + Encoding inner = pickInner(nonNullable); + EncodeResult valuesResult = inner.encode(nonNullable, nd.values(), ctx); + EncodeResult validityResult = new BoolEncoding().encode(new DType.Bool(false), nd.validity(), ctx); + + int valuesBufCount = valuesResult.buffers().size(); + EncodeNode validityNode = EncodeNode.remapBufferIndices(validityResult.rootNode(), valuesBufCount); + + List<MemorySegment> buffers = new ArrayList<>(valuesBufCount + validityResult.buffers().size()); + buffers.addAll(valuesResult.buffers()); + buffers.addAll(validityResult.buffers()); + + EncodeNode root = new EncodeNode( + EncodingId.VORTEX_MASKED, + null, + new EncodeNode[]{valuesResult.rootNode(), validityNode}, + new int[0]); + return new EncodeResult(root, buffers, valuesResult.statsMin(), valuesResult.statsMax()); + } + + private static Encoding pickInner(DType nonNullable) { + for (Encoding e : INNER_FALLBACK) { + if (e.accepts(nonNullable)) { + return e; + } + } + throw new VortexException(EncodingId.VORTEX_MASKED, + "no inner encoding for " + nonNullable); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/NullEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/NullEncodingEncoder.java new file mode 100644 index 00000000..d8f427f2 --- /dev/null +++ 
b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/NullEncodingEncoder.java @@ -0,0 +1,34 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; + +import java.util.List; + +/// Write-only encoder for {@code vortex.null} (all-null arrays). +public final class NullEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public NullEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_NULL; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Null; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + EncodeNode root = new EncodeNode(EncodingId.VORTEX_NULL, null, new EncodeNode[0], new int[0]); + return new EncodeResult(root, List.of(), null, null); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/PatchedEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/PatchedEncodingEncoder.java new file mode 100644 index 00000000..0e968582 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/PatchedEncodingEncoder.java @@ -0,0 +1,31 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; + +/// Write-only encoder for {@code vortex.patched} — currently throws (not implemented). 
+public final class PatchedEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public PatchedEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_PATCHED; + } + + @Override + public boolean accepts(DType dtype) { + return false; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + throw new VortexException(EncodingId.VORTEX_PATCHED, "encode not yet implemented"); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/PcoEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/PcoEncodingEncoder.java new file mode 100644 index 00000000..78b37af9 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/PcoEncodingEncoder.java @@ -0,0 +1,32 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; + +/// Write-side stub for {@code vortex.pco} — encode is not yet implemented. +public final class PcoEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public PcoEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_PCO; + } + + @Override + public boolean accepts(DType dtype) { + return false; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + throw new VortexException(EncodingId.VORTEX_PCO, + "encode not implemented — pco encode port pending"); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/PrimitiveEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/PrimitiveEncodingEncoder.java new file mode 100644 index 00000000..21cb1f5b --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/PrimitiveEncodingEncoder.java @@ -0,0 +1,292 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.ScalarValue; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; + +/// Write-only encoder for {@code vortex.primitive} — raw little-endian primitive arrays. +public final class PrimitiveEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public PrimitiveEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_PRIMITIVE; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + PType ptype = ((DType.Primitive) dtype).ptype(); + MemorySegment seg = encodePrimitive(ptype, data, ctx.arena()); + byte[] min = null; + byte[] max = null; + byte[][] stats = computeStats(ptype, data); + if (stats != null) { + min = stats[0]; + max = stats[1]; + } + return EncodeResult.simple(EncodingId.VORTEX_PRIMITIVE, seg, min, max); + } + + private static MemorySegment encodePrimitive(PType ptype, Object data, Arena arena) { + return switch (ptype) { + case I8, U8 -> MemorySegment.ofArray((byte[]) data); + case I16, U16, F16 -> { + short[] arr = (short[]) data; + MemorySegment seg = arena.allocate((long) arr.length * 2, 2); + for (int i = 0; i < arr.length; i++) { + seg.setAtIndex(PTypeIO.LE_SHORT, i, arr[i]); + } + yield seg; + } + case I32, U32 -> { + int[] arr = (int[]) data; + MemorySegment seg = arena.allocate((long) arr.length * 4, 4); + for (int i = 0; i < arr.length; i++) { + seg.setAtIndex(PTypeIO.LE_INT, i, arr[i]); + } + yield seg; + } + case I64, U64 -> { + long[] arr = (long[]) data; + MemorySegment seg = arena.allocate((long) arr.length * 8, 8); + for (int i = 0; i < arr.length; i++) { + seg.setAtIndex(PTypeIO.LE_LONG, i, arr[i]); + } + yield seg; + } + case F32 -> { + float[] arr = (float[]) data; + MemorySegment seg = arena.allocate((long) arr.length * 4, 4); + for (int i = 0; i < arr.length; i++) { + seg.setAtIndex(PTypeIO.LE_FLOAT, i, arr[i]); + } + yield seg; + } + case F64 -> { + double[] arr = (double[]) data; + MemorySegment seg = arena.allocate((long) arr.length * 8, 8); + for (int i = 0; i < arr.length; i++) { + seg.setAtIndex(PTypeIO.LE_DOUBLE, i, arr[i]); + } + yield seg; + } + }; + } + + private static byte[][] 
computeStats(PType ptype, Object data) { + return switch (ptype) { + case I8 -> { + byte[] arr = (byte[]) data; + if (arr.length == 0) { + yield null; + } + long min = arr[0], max = arr[0]; + for (byte v : arr) { + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + } + yield new byte[][]{scalarI64(min), scalarI64(max)}; + } + case I16 -> { + short[] arr = (short[]) data; + if (arr.length == 0) { + yield null; + } + long min = arr[0], max = arr[0]; + for (short v : arr) { + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + } + yield new byte[][]{scalarI64(min), scalarI64(max)}; + } + case I32 -> { + int[] arr = (int[]) data; + if (arr.length == 0) { + yield null; + } + long min = arr[0], max = arr[0]; + for (int v : arr) { + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + } + yield new byte[][]{scalarI64(min), scalarI64(max)}; + } + case I64 -> { + long[] arr = (long[]) data; + if (arr.length == 0) { + yield null; + } + long min = arr[0], max = arr[0]; + for (long v : arr) { + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + } + yield new byte[][]{scalarI64(min), scalarI64(max)}; + } + case U8 -> { + byte[] arr = (byte[]) data; + if (arr.length == 0) { + yield null; + } + long min = Byte.toUnsignedInt(arr[0]), max = Byte.toUnsignedInt(arr[0]); + for (byte v : arr) { + long uv = Byte.toUnsignedInt(v); + if (uv < min) { + min = uv; + } + if (uv > max) { + max = uv; + } + } + yield new byte[][]{scalarU64(min), scalarU64(max)}; + } + case U16 -> { + short[] arr = (short[]) data; + if (arr.length == 0) { + yield null; + } + long min = Short.toUnsignedInt(arr[0]), max = Short.toUnsignedInt(arr[0]); + for (short v : arr) { + long uv = Short.toUnsignedInt(v); + if (uv < min) { + min = uv; + } + if (uv > max) { + max = uv; + } + } + yield new byte[][]{scalarU64(min), scalarU64(max)}; + } + case U32 -> { + int[] arr = (int[]) data; + if (arr.length == 0) { + yield null; + } + long min = Integer.toUnsignedLong(arr[0]), 
max = Integer.toUnsignedLong(arr[0]); + for (int v : arr) { + long uv = Integer.toUnsignedLong(v); + if (uv < min) { + min = uv; + } + if (uv > max) { + max = uv; + } + } + yield new byte[][]{scalarU64(min), scalarU64(max)}; + } + case U64 -> { + long[] arr = (long[]) data; + if (arr.length == 0) { + yield null; + } + long min = arr[0], max = arr[0]; + for (long v : arr) { + if (Long.compareUnsigned(v, min) < 0) { + min = v; + } + if (Long.compareUnsigned(v, max) > 0) { + max = v; + } + } + yield new byte[][]{scalarU64(min), scalarU64(max)}; + } + case F32 -> { + float[] arr = (float[]) data; + if (arr.length == 0) { + yield null; + } + float min = arr[0], max = arr[0]; + for (float v : arr) { + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + } + yield new byte[][]{scalarF32(min), scalarF32(max)}; + } + case F64 -> { + double[] arr = (double[]) data; + if (arr.length == 0) { + yield null; + } + double min = arr[0], max = arr[0]; + for (double v : arr) { + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + } + yield new byte[][]{scalarF64(min), scalarF64(max)}; + } + case F16 -> { + short[] arr = (short[]) data; + if (arr.length == 0) { + yield null; + } + float min = Float.float16ToFloat(arr[0]), max = Float.float16ToFloat(arr[0]); + for (short v : arr) { + float fv = Float.float16ToFloat(v); + if (fv < min) { + min = fv; + } + if (fv > max) { + max = fv; + } + } + yield new byte[][]{scalarF32(min), scalarF32(max)}; + } + }; + } + + private static byte[] scalarI64(long v) { + return ScalarValue.ofInt64Value(v).encode(); + } + + private static byte[] scalarU64(long v) { + return ScalarValue.ofUint64Value(v).encode(); + } + + private static byte[] scalarF32(float v) { + return ScalarValue.ofF32Value(v).encode(); + } + + private static byte[] scalarF64(double v) { + return ScalarValue.ofF64Value(v).encode(); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/RleEncodingEncoder.java 
b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/RleEncodingEncoder.java new file mode 100644 index 00000000..512757cd --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/RleEncodingEncoder.java @@ -0,0 +1,255 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.RleEncoding; +import io.github.dfa1.vortex.proto.RLEMetadata; + +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SegmentAllocator; +import java.nio.ByteBuffer; +import java.util.List; + +/// Write-only encoder for {@code fastlanes.rle}. +public final class RleEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public RleEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.FASTLANES_RLE; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive p && !p.ptype().isFloating(); + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + if (!(dtype instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.FASTLANES_RLE, "encode only supports Primitive dtype, got " + dtype); + } + PType ptype = p.ptype(); + long[] longs = toLongs(data, ptype); + int n = longs.length; + + if (n == 0) { + return encodeEmpty(ctx); + } + + int numChunks = (n + RleEncoding.FL_CHUNK_SIZE - 1) / RleEncoding.FL_CHUNK_SIZE; + int paddedLen = numChunks * RleEncoding.FL_CHUNK_SIZE; + + long[] globalValues = new long[paddedLen]; + short[] globalIndices = new short[paddedLen]; + long[] valuesIdxOffsets = new long[numChunks]; + + long[] chunkInput = new long[RleEncoding.FL_CHUNK_SIZE]; + long[] chunkValues = new long[RleEncoding.FL_CHUNK_SIZE]; + short[] chunkIndices = new short[RleEncoding.FL_CHUNK_SIZE]; + + int globalValuesCount = 0; + + for (int chunk = 0; chunk < numChunks; chunk++) { + int chunkStart = chunk * RleEncoding.FL_CHUNK_SIZE; + int chunkEnd = Math.min(chunkStart + RleEncoding.FL_CHUNK_SIZE, n); + int chunkLen = chunkEnd - chunkStart; + + System.arraycopy(longs, chunkStart, chunkInput, 0, chunkLen); + long lastVal = longs[chunkEnd - 1]; + for (int i = chunkLen; i < RleEncoding.FL_CHUNK_SIZE; i++) { + chunkInput[i] = lastVal; + } + + int numChunkValues = rleEncode(chunkInput, chunkValues, chunkIndices); + + valuesIdxOffsets[chunk] = globalValuesCount; + System.arraycopy(chunkValues, 0, globalValues, globalValuesCount, numChunkValues); + globalValuesCount += numChunkValues; + + System.arraycopy(chunkIndices, 0, globalIndices, chunkStart, RleEncoding.FL_CHUNK_SIZE); + } + + MemorySegment valuesSeg = fromLongs(globalValues, globalValuesCount, ptype, ctx.arena()); 
+ MemorySegment indicesSeg = toIndicesSeg(globalIndices, paddedLen, ctx.arena()); + MemorySegment offsetsSeg = fromLongsU64(valuesIdxOffsets, numChunks, ctx.arena()); + + PType indicesPtype = PType.U16; + PType offsetsPtype = PType.U64; + + byte[] metaBytes = new RLEMetadata( + globalValuesCount, + paddedLen, + io.github.dfa1.vortex.proto.PType.fromValue(indicesPtype.ordinal()), + numChunks, + io.github.dfa1.vortex.proto.PType.fromValue(offsetsPtype.ordinal()), + 0L + ).encode(); + + EncodeNode valuesNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + EncodeNode indicesNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 1); + EncodeNode offsetsNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 2); + EncodeNode root = new EncodeNode( + EncodingId.FASTLANES_RLE, + ByteBuffer.wrap(metaBytes), + new EncodeNode[]{valuesNode, indicesNode, offsetsNode}, + new int[0]); + return new EncodeResult(root, List.of(valuesSeg, indicesSeg, offsetsSeg), null, null); + } + + private static int rleEncode(long[] input, long[] chunkValues, short[] chunkIndices) { + short posVal = 0; + int valIdx = 1; + long prev = input[0]; + chunkValues[0] = prev; + chunkIndices[0] = 0; + + for (int i = 1; i < RleEncoding.FL_CHUNK_SIZE; i++) { + long cur = input[i]; + if (cur != prev) { + chunkValues[valIdx] = cur; + valIdx++; + posVal++; + prev = cur; + } + chunkIndices[i] = posVal; + } + return valIdx; + } + + private static EncodeResult encodeEmpty(EncodeContext ctx) { + MemorySegment empty = ctx.arena().allocate(0); + PType indicesPtype = PType.U16; + PType offsetsPtype = PType.U64; + byte[] metaBytes = new RLEMetadata( + 0L, + 0L, + io.github.dfa1.vortex.proto.PType.fromValue(indicesPtype.ordinal()), + 0L, + io.github.dfa1.vortex.proto.PType.fromValue(offsetsPtype.ordinal()), + 0L + ).encode(); + EncodeNode valuesNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + EncodeNode indicesNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 1); + EncodeNode offsetsNode = 
EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 2); + EncodeNode root = new EncodeNode( + EncodingId.FASTLANES_RLE, + ByteBuffer.wrap(metaBytes), + new EncodeNode[]{valuesNode, indicesNode, offsetsNode}, + new int[0]); + return new EncodeResult(root, List.of(empty, empty, empty), null, null); + } + + private static long[] toLongs(Object data, PType ptype) { + return switch (ptype) { + case I8 -> { + byte[] arr = (byte[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = arr[i]; + } + yield r; + } + case U8 -> { + byte[] arr = (byte[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Byte.toUnsignedLong(arr[i]); + } + yield r; + } + case I16 -> { + short[] arr = (short[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = arr[i]; + } + yield r; + } + case U16 -> { + short[] arr = (short[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Short.toUnsignedLong(arr[i]); + } + yield r; + } + case I32 -> { + int[] arr = (int[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = arr[i]; + } + yield r; + } + case U32 -> { + int[] arr = (int[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Integer.toUnsignedLong(arr[i]); + } + yield r; + } + case I64, U64 -> (long[]) data; + case F32 -> { + float[] arr = (float[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Float.floatToRawIntBits(arr[i]); + } + yield r; + } + case F64 -> { + double[] arr = (double[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Double.doubleToRawLongBits(arr[i]); + } + yield r; + } + case F16 -> { + short[] arr = (short[]) data; + long[] r = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + r[i] = Short.toUnsignedLong(arr[i]); + } + yield r; + } + }; + } + + 
private static MemorySegment fromLongs(long[] values, int count, PType ptype, SegmentAllocator arena) { + int elemSize = ptype.byteSize(); + MemorySegment seg = arena.allocate((long) count * elemSize); + for (int i = 0; i < count; i++) { + PTypeIO.set(seg, (long) i * elemSize, ptype, values[i]); + } + return seg; + } + + private static MemorySegment fromLongsU64(long[] values, int count, SegmentAllocator arena) { + MemorySegment seg = arena.allocate((long) count * 8); + for (int i = 0; i < count; i++) { + seg.setAtIndex(PTypeIO.LE_LONG, i, values[i]); + } + return seg; + } + + private static MemorySegment toIndicesSeg(short[] indices, int count, SegmentAllocator arena) { + MemorySegment seg = arena.allocate((long) count * 2); + for (int i = 0; i < count; i++) { + seg.setAtIndex(PTypeIO.LE_SHORT, i, indices[i]); + } + return seg; + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/RunEndEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/RunEndEncodingEncoder.java new file mode 100644 index 00000000..b8829b55 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/RunEndEncodingEncoder.java @@ -0,0 +1,108 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.RunEndMetadata; + +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +/// Write-only encoder for {@code vortex.runend}. 
+public final class RunEndEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public RunEndEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_RUNEND; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive p && !p.ptype().isFloating(); + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + if (!(dtype instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.VORTEX_RUNEND, "encode only supports Primitive dtype, got " + dtype); + } + PType ptype = p.ptype(); + int n = arrayLength(data, ptype); + + List<Integer> ends = new ArrayList<>(); + List<Long> values = new ArrayList<>(); + if (n > 0) { + long runVal = readLong(data, ptype, 0); + for (int i = 1; i < n; i++) { + long cur = readLong(data, ptype, i); + if (cur != runVal) { + ends.add(i); + values.add(runVal); + runVal = cur; + } + } + ends.add(n); + values.add(runVal); + } + + int numRuns = ends.size(); + + MemorySegment endsBuf = ctx.arena().allocate((long) numRuns * 4, 4); + for (int i = 0; i < numRuns; i++) { + endsBuf.setAtIndex(PTypeIO.LE_INT, i, ends.get(i)); + } + + int elemBytes = ptype.byteSize(); + MemorySegment valuesBuf = ctx.arena().allocate((long) numRuns * elemBytes, elemBytes); + for (int i = 0; i < numRuns; i++) { + PTypeIO.set(valuesBuf, (long) i * elemBytes, ptype, values.get(i)); + } + + byte[] metaBytes = new RunEndMetadata( + io.github.dfa1.vortex.proto.PType.fromValue(PType.U32.ordinal()), + numRuns, + 0L + ).encode(); + + EncodeNode endsNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + EncodeNode valuesNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 1); + EncodeNode root = new EncodeNode(EncodingId.VORTEX_RUNEND, ByteBuffer.wrap(metaBytes), + new EncodeNode[]{endsNode, valuesNode}, new int[0]); + return new EncodeResult(root, List.of(endsBuf, valuesBuf), null, null); + } + + 
private static int arrayLength(Object data, PType ptype) { + return switch (ptype) { + case I8, U8 -> ((byte[]) data).length; + case I16, U16 -> ((short[]) data).length; + case I32, U32 -> ((int[]) data).length; + case I64, U64 -> ((long[]) data).length; + default -> throw new VortexException(EncodingId.VORTEX_RUNEND, "unsupported ptype: " + ptype); + }; + } + + private static long readLong(Object data, PType ptype, int i) { + return switch (ptype) { + case I8 -> ((byte[]) data)[i]; + case U8 -> Byte.toUnsignedLong(((byte[]) data)[i]); + case I16 -> ((short[]) data)[i]; + case U16 -> Short.toUnsignedLong(((short[]) data)[i]); + case I32 -> ((int[]) data)[i]; + case U32 -> Integer.toUnsignedLong(((int[]) data)[i]); + case I64, U64 -> ((long[]) data)[i]; + default -> throw new VortexException(EncodingId.VORTEX_RUNEND, "unsupported ptype: " + ptype); + }; + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/SequenceEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/SequenceEncodingEncoder.java new file mode 100644 index 00000000..00e19ea3 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/SequenceEncodingEncoder.java @@ -0,0 +1,138 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.proto.ScalarValue; +import io.github.dfa1.vortex.proto.SequenceMetadata; + +import java.nio.ByteBuffer; +import java.util.List; + +/// Write-only encoder for {@code vortex.sequence} — arithmetic sequences as (base, multiplier). 
+public final class SequenceEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public SequenceEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_SEQUENCE; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + if (!(dtype instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.VORTEX_SEQUENCE, "encode only supports Primitive dtype, got " + dtype); + } + PType pt = p.ptype(); + return switch (pt) { + case I8, I16, I32, I64, U8, U16, U32, U64 -> encodeInteger(pt, data); + case F32 -> encodeF32((float[]) data); + case F64 -> encodeF64((double[]) data); + case F16 -> encodeF16((short[]) data); + }; + } + + private static EncodeResult encodeInteger(PType pt, Object data) { + int n = intArrayLength(pt, data); + long base = 0; + long multiplier = 0; + if (n > 0) { + base = readLong(pt, data, 0); + multiplier = n > 1 ? readLong(pt, data, 1) - base : 0; + for (int i = 2; i < n; i++) { + long expected = base + (long) i * multiplier; + if (readLong(pt, data, i) != expected) { + throw new VortexException(EncodingId.VORTEX_SEQUENCE, "not an arithmetic sequence at index " + i); + } + } + } + ScalarValue baseScalar = buildIntScalar(pt, base); + ScalarValue mulScalar = buildIntScalar(pt, multiplier); + return buildResult(baseScalar, mulScalar); + } + + private static EncodeResult encodeF32(float[] data) { + float base = data.length > 0 ? data[0] : 0f; + float mul = data.length > 1 ? 
data[1] - base : 0f; + for (int i = 2; i < data.length; i++) { + if (data[i] != base + (float) i * mul) { + throw new VortexException(EncodingId.VORTEX_SEQUENCE, "not an arithmetic sequence at index " + i); + } + } + return buildResult(ScalarValue.ofF32Value(base), ScalarValue.ofF32Value(mul)); + } + + private static EncodeResult encodeF64(double[] data) { + double base = data.length > 0 ? data[0] : 0.0; + double mul = data.length > 1 ? data[1] - base : 0.0; + for (int i = 2; i < data.length; i++) { + if (data[i] != base + (double) i * mul) { + throw new VortexException(EncodingId.VORTEX_SEQUENCE, "not an arithmetic sequence at index " + i); + } + } + return buildResult(ScalarValue.ofF64Value(base), ScalarValue.ofF64Value(mul)); + } + + private static EncodeResult encodeF16(short[] data) { + short baseShort = data.length > 0 ? data[0] : 0; + float baseF = Float.float16ToFloat(baseShort); + float mulF = data.length > 1 ? Float.float16ToFloat(data[1]) - baseF : 0f; + short mulShort = Float.floatToFloat16(mulF); + for (int i = 2; i < data.length; i++) { + short expected = Float.floatToFloat16(baseF + (float) i * mulF); + if (data[i] != expected) { + throw new VortexException(EncodingId.VORTEX_SEQUENCE, "not an arithmetic sequence at index " + i); + } + } + return buildResult( + ScalarValue.ofF16Value(Short.toUnsignedLong(baseShort)), + ScalarValue.ofF16Value(Short.toUnsignedLong(mulShort))); + } + + private static EncodeResult buildResult(ScalarValue base, ScalarValue mul) { + SequenceMetadata meta = new SequenceMetadata(base, mul); + ByteBuffer metaBuf = ByteBuffer.wrap(meta.encode()); + EncodeNode node = new EncodeNode(EncodingId.VORTEX_SEQUENCE, metaBuf, new EncodeNode[0], new int[]{}); + return new EncodeResult(node, List.of(), null, null); + } + + private static ScalarValue buildIntScalar(PType pt, long value) { + return switch (pt) { + case U8, U16, U32, U64 -> ScalarValue.ofUint64Value(value); + default -> ScalarValue.ofInt64Value(value); + }; + } + + private 
static int intArrayLength(PType pt, Object data) { + return switch (pt) { + case I8, U8 -> ((byte[]) data).length; + case I16, U16 -> ((short[]) data).length; + case I32, U32 -> ((int[]) data).length; + case I64, U64 -> ((long[]) data).length; + default -> throw new VortexException(EncodingId.VORTEX_SEQUENCE, "unsupported ptype: " + pt); + }; + } + + private static long readLong(PType pt, Object data, int i) { + return switch (pt) { + case I8, U8 -> ((byte[]) data)[i]; + case I16, U16 -> ((short[]) data)[i]; + case I32, U32 -> ((int[]) data)[i]; + case I64, U64 -> ((long[]) data)[i]; + default -> throw new VortexException(EncodingId.VORTEX_SEQUENCE, "unsupported ptype: " + pt); + }; + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/SparseEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/SparseEncodingEncoder.java new file mode 100644 index 00000000..2c5e142e --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/SparseEncodingEncoder.java @@ -0,0 +1,149 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.PatchesMetadata; +import io.github.dfa1.vortex.proto.ScalarValue; +import io.github.dfa1.vortex.proto.SparseMetadata; + +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +/// Write-only encoder for {@code vortex.sparse}. 
+public final class SparseEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public SparseEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_SPARSE; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + if (!(dtype instanceof DType.Primitive p)) { + throw new VortexException(EncodingId.VORTEX_SPARSE, + "encode only supports Primitive dtype, got " + dtype); + } + PType ptype = p.ptype(); + int n = arrayLength(data, ptype); + + List<Integer> patchIdx = new ArrayList<>(); + List<Long> patchBits = new ArrayList<>(); + for (int i = 0; i < n; i++) { + long bits = readBits(data, ptype, i); + if (bits != 0L) { + patchIdx.add(i); + patchBits.add(bits); + } + } + + int numPatches = patchIdx.size(); + PType idxPtype = chooseIdxPtype(n); + + ScalarValue fillScalar = zeroScalar(ptype); + byte[] fillBytes = fillScalar.encode(); + MemorySegment fillBuf = ctx.arena().allocate(fillBytes.length); + MemorySegment.copy(MemorySegment.ofArray(fillBytes), 0, fillBuf, 0, fillBytes.length); + + MemorySegment idxBuf = buildIdxBuf(patchIdx, idxPtype, numPatches, ctx); + MemorySegment valBuf = buildValBuf(patchBits, ptype, numPatches, ctx); + + PatchesMetadata patchesMeta = new PatchesMetadata( + numPatches, + 0L, + io.github.dfa1.vortex.proto.PType.fromValue(idxPtype.ordinal()), + null, + null, + null + ); + byte[] metaBytes = new SparseMetadata(patchesMeta).encode(); + + EncodeNode idxNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 1); + EncodeNode valNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 2); + EncodeNode root = new EncodeNode(EncodingId.VORTEX_SPARSE, ByteBuffer.wrap(metaBytes), + new EncodeNode[]{idxNode, valNode}, new int[]{0}); + return new EncodeResult(root, List.of(fillBuf, idxBuf, valBuf), null, null); + } + + private 
static int arrayLength(Object data, PType ptype) { + return switch (ptype) { + case I8, U8 -> ((byte[]) data).length; + case I16, U16 -> ((short[]) data).length; + case I32, U32 -> ((int[]) data).length; + case I64, U64 -> ((long[]) data).length; + case F32 -> ((float[]) data).length; + case F64 -> ((double[]) data).length; + default -> throw new VortexException(EncodingId.VORTEX_SPARSE, "unsupported ptype: " + ptype); + }; + } + + private static long readBits(Object data, PType ptype, int i) { + return switch (ptype) { + case I8 -> ((byte[]) data)[i]; + case U8 -> Byte.toUnsignedLong(((byte[]) data)[i]); + case I16 -> ((short[]) data)[i]; + case U16 -> Short.toUnsignedLong(((short[]) data)[i]); + case I32 -> ((int[]) data)[i]; + case U32 -> Integer.toUnsignedLong(((int[]) data)[i]); + case I64, U64 -> ((long[]) data)[i]; + case F32 -> Float.floatToRawIntBits(((float[]) data)[i]); + case F64 -> Double.doubleToRawLongBits(((double[]) data)[i]); + default -> throw new VortexException(EncodingId.VORTEX_SPARSE, "unsupported ptype: " + ptype); + }; + } + + private static PType chooseIdxPtype(int n) { + if (n <= 0xFF) { + return PType.U8; + } else if (n <= 0xFFFF) { + return PType.U16; + } else { + return PType.U32; + } + } + + private static ScalarValue zeroScalar(PType ptype) { + return switch (ptype) { + case I8, I16, I32, I64 -> ScalarValue.ofInt64Value(0L); + case U8, U16, U32, U64 -> ScalarValue.ofUint64Value(0L); + case F32 -> ScalarValue.ofF32Value(0.0f); + case F64 -> ScalarValue.ofF64Value(0.0); + default -> throw new VortexException(EncodingId.VORTEX_SPARSE, "unsupported ptype: " + ptype); + }; + } + + private static MemorySegment buildIdxBuf(List<Integer> patchIdx, PType idxPtype, int numPatches, EncodeContext ctx) { + int elemBytes = idxPtype.byteSize(); + MemorySegment seg = ctx.arena().allocate(Math.max(1L, (long) numPatches * elemBytes), elemBytes); + for (int i = 0; i < numPatches; i++) { + PTypeIO.set(seg, (long) i * elemBytes, idxPtype, patchIdx.get(i)); + } + 
return seg; + } + + private static MemorySegment buildValBuf(List<Long> patchBits, PType ptype, int numPatches, EncodeContext ctx) { + int elemBytes = ptype.byteSize(); + MemorySegment seg = ctx.arena().allocate(Math.max(1L, (long) numPatches * elemBytes), elemBytes); + for (int i = 0; i < numPatches; i++) { + PTypeIO.set(seg, (long) i * elemBytes, ptype, patchBits.get(i)); + } + return seg; + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/StructEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/StructEncodingEncoder.java new file mode 100644 index 00000000..d0a1cfd7 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/StructEncodingEncoder.java @@ -0,0 +1,74 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.BoolEncoding; +import io.github.dfa1.vortex.encoding.ByteBoolEncoding; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.Encoding; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.NullEncoding; +import io.github.dfa1.vortex.encoding.PrimitiveEncoding; +import io.github.dfa1.vortex.encoding.StructData; +import io.github.dfa1.vortex.encoding.VarBinEncoding; + +import java.lang.foreign.MemorySegment; +import java.util.ArrayList; +import java.util.List; + +/// Write-only encoder for {@code vortex.struct}. +public final class StructEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public StructEncodingEncoder() { + } + + private static final List<Encoding> FALLBACK = List.of( + new PrimitiveEncoding(), new VarBinEncoding(), new BoolEncoding(), + new NullEncoding(), new ByteBoolEncoding()); + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_STRUCT; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Struct; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + DType.Struct sdtype = (DType.Struct) dtype; + StructData sd = (StructData) data; + List fields = sd.fieldArrays(); + List<DType> fieldTypes = sdtype.fieldTypes(); + if (fields.size() != fieldTypes.size()) { + throw new VortexException(EncodingId.VORTEX_STRUCT, + "fieldArrays length %d != fieldTypes length %d".formatted(fields.size(), fieldTypes.size())); + } + List<MemorySegment> allBuffers = new ArrayList<>(); + EncodeNode[] children = new EncodeNode[fields.size()]; + for (int i = 0; i < fields.size(); i++) { + DType fieldDtype = fieldTypes.get(i); + EncodeResult fieldResult = findEncoding(fieldDtype).encode(fieldDtype, fields.get(i), ctx); + int bufOffset = allBuffers.size(); + children[i] = EncodeNode.remapBufferIndices(fieldResult.rootNode(), bufOffset); + allBuffers.addAll(fieldResult.buffers()); + } + EncodeNode root = new EncodeNode(EncodingId.VORTEX_STRUCT, null, children, new int[0]); + return new EncodeResult(root, List.copyOf(allBuffers), null, null); + } + + private static Encoding findEncoding(DType dtype) { + for (Encoding enc : FALLBACK) { + if (enc.accepts(dtype)) { + return enc; + } + } + throw new UnsupportedOperationException("no fallback encoding for dtype: " + dtype); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/VarBinEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/VarBinEncodingEncoder.java new file mode 100644 index 00000000..b1731029 --- /dev/null +++ 
b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/VarBinEncodingEncoder.java @@ -0,0 +1,84 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.ScalarValue; +import io.github.dfa1.vortex.proto.VarBinMetadata; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.List; + +/// Write-only encoder for {@code vortex.varbin}. +public final class VarBinEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public VarBinEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_VARBIN; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Utf8 || dtype instanceof DType.Binary; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + String[] strings = (String[]) data; + int n = strings.length; + + byte[][] byteArrays = new byte[n][]; + int totalBytes = 0; + for (int i = 0; i < n; i++) { + byteArrays[i] = strings[i].getBytes(StandardCharsets.UTF_8); + totalBytes += byteArrays[i].length; + } + + Arena arena = ctx.arena(); + MemorySegment bytesBuf = arena.allocate(totalBytes > 0 ? 
totalBytes : 1); + MemorySegment offsetsBuf = arena.allocate((long) (n + 1) * Long.BYTES, Long.BYTES); + + long pos = 0; + offsetsBuf.setAtIndex(PTypeIO.LE_LONG, 0, 0L); + for (int i = 0; i < n; i++) { + MemorySegment.copy(MemorySegment.ofArray(byteArrays[i]), 0, bytesBuf, pos, byteArrays[i].length); + pos += byteArrays[i].length; + offsetsBuf.setAtIndex(PTypeIO.LE_LONG, i + 1, pos); + } + + byte[] metaBytes = new VarBinMetadata(io.github.dfa1.vortex.proto.PType.fromValue(PType.I64.ordinal())).encode(); + + String minStr = null; + String maxStr = null; + for (String s : strings) { + if (s == null) { + continue; + } + if (minStr == null || s.compareTo(minStr) < 0) { + minStr = s; + } + if (maxStr == null || s.compareTo(maxStr) > 0) { + maxStr = s; + } + } + byte[] statsMin = minStr != null ? ScalarValue.ofStringValue(minStr).encode() : null; + byte[] statsMax = maxStr != null ? ScalarValue.ofStringValue(maxStr).encode() : null; + + EncodeNode offsetsNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 1); + EncodeNode root = new EncodeNode(EncodingId.VORTEX_VARBIN, ByteBuffer.wrap(metaBytes), + new EncodeNode[]{offsetsNode}, new int[]{0}); + return new EncodeResult(root, List.of(bytesBuf, offsetsBuf), statsMin, statsMax); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/VarBinViewEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/VarBinViewEncodingEncoder.java new file mode 100644 index 00000000..99971563 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/VarBinViewEncodingEncoder.java @@ -0,0 +1,84 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import 
io.github.dfa1.vortex.encoding.PTypeIO; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.nio.charset.StandardCharsets; +import java.util.List; + +/// Write-only encoder for {@code vortex.varbinview}. +public final class VarBinViewEncodingEncoder implements EncodingEncoder { + + private static final int MAX_INLINED_SIZE = 12; + private static final int VIEW_SIZE = 16; + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. + public VarBinViewEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_VARBINVIEW; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Utf8 || dtype instanceof DType.Binary; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + String[] strings = (String[]) data; + int n = strings.length; + + byte[][] bytes = new byte[n][]; + int totalDataBytes = 0; + for (int i = 0; i < n; i++) { + bytes[i] = strings[i].getBytes(StandardCharsets.UTF_8); + if (bytes[i].length > MAX_INLINED_SIZE) { + totalDataBytes += bytes[i].length; + } + } + + Arena arena = ctx.arena(); + boolean hasDataBuf = totalDataBytes > 0; + MemorySegment dataBuf = arena.allocate(hasDataBuf ? totalDataBytes : 1); + MemorySegment viewsBuf = arena.allocate(n > 0 ? 
(long) n * VIEW_SIZE : 1); + + int dataOffset = 0; + for (int i = 0; i < n; i++) { + byte[] b = bytes[i]; + long viewOff = (long) i * VIEW_SIZE; + viewsBuf.set(PTypeIO.LE_INT, viewOff, b.length); + if (b.length <= MAX_INLINED_SIZE) { + MemorySegment.copy(MemorySegment.ofArray(b), 0, viewsBuf, viewOff + 4, b.length); + } else { + MemorySegment.copy(MemorySegment.ofArray(b), 0, viewsBuf, viewOff + 4, 4); + viewsBuf.set(PTypeIO.LE_INT, viewOff + 8, 0); + viewsBuf.set(PTypeIO.LE_INT, viewOff + 12, dataOffset); + MemorySegment.copy(MemorySegment.ofArray(b), 0, dataBuf, dataOffset, b.length); + dataOffset += b.length; + } + } + + int[] bufIndices; + List buffers; + if (hasDataBuf) { + bufIndices = new int[]{0, 1}; + buffers = List.of(dataBuf, viewsBuf); + } else { + bufIndices = new int[]{0}; + buffers = List.of(viewsBuf); + } + + EncodeNode root = new EncodeNode(EncodingId.VORTEX_VARBINVIEW, null, new EncodeNode[0], bufIndices); + return new EncodeResult(root, buffers, null, null); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/VariantEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/VariantEncodingEncoder.java new file mode 100644 index 00000000..368d6501 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/VariantEncodingEncoder.java @@ -0,0 +1,31 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; + +/// Write-only encoder for {@code vortex.variant} — currently throws (not implemented). +public final class VariantEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public VariantEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_VARIANT; + } + + @Override + public boolean accepts(DType dtype) { + return false; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + throw new VortexException(EncodingId.VORTEX_VARIANT, "encode not yet implemented"); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ZigZagEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ZigZagEncodingEncoder.java new file mode 100644 index 00000000..a1d69f7b --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ZigZagEncodingEncoder.java @@ -0,0 +1,84 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; + +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.util.List; + +/// Write-only encoder for {@code vortex.zigzag} — signed integers as zigzag-encoded unsigned values. +public final class ZigZagEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public ZigZagEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_ZIGZAG; + } + + @Override + public boolean accepts(DType dtype) { + if (!(dtype instanceof DType.Primitive p)) { + return false; + } + PType pt = p.ptype(); + return pt == PType.I8 || pt == PType.I16 || pt == PType.I32 || pt == PType.I64; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + PType signed = ((DType.Primitive) dtype).ptype(); + MemorySegment seg = switch (signed) { + case I8 -> { + byte[] arr = (byte[]) data; + MemorySegment s = ctx.arena().allocate(arr.length); + for (int i = 0; i < arr.length; i++) { + byte v = arr[i]; + s.set(ValueLayout.JAVA_BYTE, i, (byte) ((v << 1) ^ (v >> 7))); + } + yield s; + } + case I16 -> { + short[] arr = (short[]) data; + MemorySegment s = ctx.arena().allocate((long) arr.length * 2, 2); + for (int i = 0; i < arr.length; i++) { + short v = arr[i]; + s.setAtIndex(PTypeIO.LE_SHORT, i, (short) ((v << 1) ^ (v >> 15))); + } + yield s; + } + case I32 -> { + int[] arr = (int[]) data; + MemorySegment s = ctx.arena().allocate((long) arr.length * 4, 4); + for (int i = 0; i < arr.length; i++) { + int v = arr[i]; + s.setAtIndex(PTypeIO.LE_INT, i, (v << 1) ^ (v >> 31)); + } + yield s; + } + case I64 -> { + long[] arr = (long[]) data; + MemorySegment s = ctx.arena().allocate((long) arr.length * 8, 8); + for (int i = 0; i < arr.length; i++) { + long v = arr[i]; + s.setAtIndex(PTypeIO.LE_LONG, i, (v << 1) ^ (v >> 63)); + } + yield s; + } + default -> throw new VortexException(EncodingId.VORTEX_ZIGZAG, "unsupported ptype: " + signed); + }; + EncodeNode child = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + EncodeNode root = new EncodeNode(EncodingId.VORTEX_ZIGZAG, null, new EncodeNode[]{child}, new int[0]); + return new EncodeResult(root, List.of(seg), null, null); + } +} diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoder.java 
b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoder.java new file mode 100644 index 00000000..7b3b75a1 --- /dev/null +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoder.java @@ -0,0 +1,159 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.airlift.compress.v3.zstd.ZstdCompressor; +import io.airlift.compress.v3.zstd.ZstdJavaCompressor; +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.proto.ZstdFrameMetadata; +import io.github.dfa1.vortex.proto.ZstdMetadata; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; + +/// Write-only encoder for {@code vortex.zstd}. +public final class ZstdEncodingEncoder implements EncodingEncoder { + + /// Public no-arg constructor required by {@link java.util.ServiceLoader}. 
+ public ZstdEncodingEncoder() { + } + + @Override + public EncodingId encodingId() { + return EncodingId.VORTEX_ZSTD; + } + + @Override + public boolean accepts(DType dtype) { + return dtype instanceof DType.Primitive || dtype instanceof DType.Utf8 || dtype instanceof DType.Binary; + } + + @Override + public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + if (dtype instanceof DType.Primitive dt) { + return encodePrimitive(dt, data); + } + if (dtype instanceof DType.Utf8 || dtype instanceof DType.Binary) { + return encodeVarBin((String[]) data); + } + throw new VortexException(EncodingId.VORTEX_ZSTD, "unsupported dtype: " + dtype); + } + + private static EncodeResult encodePrimitive(DType.Primitive dt, Object data) { + MemorySegment raw = primitiveToLeBytes(dt.ptype(), data, Arena.ofAuto()); + long n = primitiveLength(dt.ptype(), data); + byte[] rawBytes = raw.toArray(ValueLayout.JAVA_BYTE); + return buildResult(rawBytes, n); + } + + private static EncodeResult encodeVarBin(String[] strings) { + byte[] raw = buildLengthPrefixed(strings); + return buildResult(raw, strings.length); + } + + private static EncodeResult buildResult(byte[] raw, long n) { + byte[] compressed = compress(raw); + byte[] meta = new ZstdMetadata( + 0, + java.util.List.of(new ZstdFrameMetadata(raw.length, n)) + ).encode(); + EncodeNode root = new EncodeNode(EncodingId.VORTEX_ZSTD, ByteBuffer.wrap(meta), + new EncodeNode[0], new int[]{0}); + return new EncodeResult(root, List.of(MemorySegment.ofArray(compressed)), null, null); + } + + private static byte[] compress(byte[] input) { + ZstdCompressor compressor = new ZstdJavaCompressor(); + byte[] out = new byte[compressor.maxCompressedLength(input.length)]; + int len = compressor.compress(input, 0, input.length, out, 0, out.length); + return Arrays.copyOf(out, len); + } + + private static MemorySegment primitiveToLeBytes(PType ptype, Object data, Arena arena) { + return switch (ptype) { + case I8, U8 -> 
MemorySegment.ofArray((byte[]) data); + case I16, U16, F16 -> { + short[] arr = (short[]) data; + MemorySegment seg = arena.allocate((long) arr.length * 2, 2); + for (int i = 0; i < arr.length; i++) { + seg.setAtIndex(PTypeIO.LE_SHORT, i, arr[i]); + } + yield seg; + } + case I32, U32 -> { + int[] arr = (int[]) data; + MemorySegment seg = arena.allocate((long) arr.length * 4, 4); + for (int i = 0; i < arr.length; i++) { + seg.setAtIndex(PTypeIO.LE_INT, i, arr[i]); + } + yield seg; + } + case I64, U64 -> { + long[] arr = (long[]) data; + MemorySegment seg = arena.allocate((long) arr.length * 8, 8); + for (int i = 0; i < arr.length; i++) { + seg.setAtIndex(PTypeIO.LE_LONG, i, arr[i]); + } + yield seg; + } + case F32 -> { + float[] arr = (float[]) data; + MemorySegment seg = arena.allocate((long) arr.length * 4, 4); + for (int i = 0; i < arr.length; i++) { + seg.setAtIndex(PTypeIO.LE_FLOAT, i, arr[i]); + } + yield seg; + } + case F64 -> { + double[] arr = (double[]) data; + MemorySegment seg = arena.allocate((long) arr.length * 8, 8); + for (int i = 0; i < arr.length; i++) { + seg.setAtIndex(PTypeIO.LE_DOUBLE, i, arr[i]); + } + yield seg; + } + }; + } + + private static long primitiveLength(PType ptype, Object data) { + return switch (ptype) { + case I8, U8 -> ((byte[]) data).length; + case I16, U16, F16 -> ((short[]) data).length; + case I32, U32 -> ((int[]) data).length; + case F32 -> ((float[]) data).length; + case I64, U64 -> ((long[]) data).length; + case F64 -> ((double[]) data).length; + }; + } + + private static byte[] buildLengthPrefixed(String[] strings) { + int total = 0; + byte[][] encoded = new byte[strings.length][]; + for (int i = 0; i < strings.length; i++) { + encoded[i] = strings[i].getBytes(StandardCharsets.UTF_8); + total += 4 + encoded[i].length; + } + try (Arena scratch = Arena.ofConfined()) { + MemorySegment seg = scratch.allocate(total > 0 ? 
total : 1); + long pos = 0; + for (byte[] bytes : encoded) { + seg.set(PTypeIO.LE_INT, pos, bytes.length); + pos += 4; + MemorySegment.copy(MemorySegment.ofArray(bytes), 0, seg, pos, bytes.length); + pos += bytes.length; + } + return seg.asSlice(0, total).toArray(ValueLayout.JAVA_BYTE); + } + } +} diff --git a/writer/src/main/resources/META-INF/services/io.github.dfa1.vortex.encoding.EncodingEncoder b/writer/src/main/resources/META-INF/services/io.github.dfa1.vortex.encoding.EncodingEncoder new file mode 100644 index 00000000..19a1dab3 --- /dev/null +++ b/writer/src/main/resources/META-INF/services/io.github.dfa1.vortex.encoding.EncodingEncoder @@ -0,0 +1,33 @@ +io.github.dfa1.vortex.writer.encode.AlpEncodingEncoder +io.github.dfa1.vortex.writer.encode.AlpRdEncodingEncoder +io.github.dfa1.vortex.writer.encode.BitpackedEncodingEncoder +io.github.dfa1.vortex.writer.encode.BoolEncodingEncoder +io.github.dfa1.vortex.writer.encode.ByteBoolEncodingEncoder +io.github.dfa1.vortex.writer.encode.ChunkedEncodingEncoder +io.github.dfa1.vortex.writer.encode.ConstantEncodingEncoder +io.github.dfa1.vortex.writer.encode.DateTimePartsEncodingEncoder +io.github.dfa1.vortex.writer.encode.DecimalBytePartsEncodingEncoder +io.github.dfa1.vortex.writer.encode.DecimalEncodingEncoder +io.github.dfa1.vortex.writer.encode.DeltaEncodingEncoder +io.github.dfa1.vortex.writer.encode.DictEncodingEncoder +io.github.dfa1.vortex.writer.encode.ExtEncodingEncoder +io.github.dfa1.vortex.writer.encode.FixedSizeListEncodingEncoder +io.github.dfa1.vortex.writer.encode.FrameOfReferenceEncodingEncoder +io.github.dfa1.vortex.writer.encode.FsstEncodingEncoder +io.github.dfa1.vortex.writer.encode.ListEncodingEncoder +io.github.dfa1.vortex.writer.encode.ListViewEncodingEncoder +io.github.dfa1.vortex.writer.encode.MaskedEncodingEncoder +io.github.dfa1.vortex.writer.encode.NullEncodingEncoder +io.github.dfa1.vortex.writer.encode.PatchedEncodingEncoder +io.github.dfa1.vortex.writer.encode.PcoEncodingEncoder 
+io.github.dfa1.vortex.writer.encode.PrimitiveEncodingEncoder +io.github.dfa1.vortex.writer.encode.RleEncodingEncoder +io.github.dfa1.vortex.writer.encode.RunEndEncodingEncoder +io.github.dfa1.vortex.writer.encode.SequenceEncodingEncoder +io.github.dfa1.vortex.writer.encode.SparseEncodingEncoder +io.github.dfa1.vortex.writer.encode.StructEncodingEncoder +io.github.dfa1.vortex.writer.encode.VarBinEncodingEncoder +io.github.dfa1.vortex.writer.encode.VariantEncodingEncoder +io.github.dfa1.vortex.writer.encode.VarBinViewEncodingEncoder +io.github.dfa1.vortex.writer.encode.ZigZagEncodingEncoder +io.github.dfa1.vortex.writer.encode.ZstdEncodingEncoder diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/AlpEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/AlpEncodingEncoderTest.java similarity index 52% rename from core/src/test/java/io/github/dfa1/vortex/encoding/AlpEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/AlpEncodingEncoderTest.java index 5878e433..e95d5942 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/AlpEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/AlpEncodingEncoderTest.java @@ -1,10 +1,21 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; import io.github.dfa1.vortex.proto.ALPMetadata; import 
io.github.dfa1.vortex.proto.PatchesMetadata; +import io.github.dfa1.vortex.reader.decode.AlpEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -17,24 +28,22 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.within; -class AlpEncodingTest { +class AlpEncodingEncoderTest { + private static final AlpEncodingEncoder ENCODER = new AlpEncodingEncoder(); + private static final AlpEncodingDecoder DECODER = new AlpEncodingDecoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); @Nested class Decode { private static DecodeContext buildAlpCtxF64( - int expE, int expF, - long[] encodedVals, - long[] patchIndices, - double[] patchValues + int expE, int expF, long[] encodedVals, + long[] patchIndices, double[] patchValues ) { PatchesMetadata pm = patchIndices != null - ? new PatchesMetadata( - (long) patchIndices.length, - 0L, - io.github.dfa1.vortex.proto.PType.U32, - null, null, null) + ? 
new PatchesMetadata((long) patchIndices.length, 0L, + io.github.dfa1.vortex.proto.PType.U32, null, null, null) : null; byte[] metaBytes = new ALPMetadata(expE, expF, pm).encode(); @@ -56,24 +65,16 @@ private static DecodeContext buildAlpCtxF64( for (long v : patchIndices) { ib.putInt((int) v); } - byte[] valBuf = new byte[patchValues.length * 8]; ByteBuffer vb = ByteBuffer.wrap(valBuf).order(ByteOrder.LITTLE_ENDIAN); for (double v : patchValues) { vb.putDouble(v); } - - ArrayNode idxNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, - new ArrayNode[0], new int[]{1}, ArrayStats.empty()); - ArrayNode valNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, - new ArrayNode[0], new int[]{2}, ArrayStats.empty()); - + ArrayNode idxNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{1}, ArrayStats.empty()); + ArrayNode valNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{2}, ArrayStats.empty()); children = new ArrayNode[]{encNode, idxNode, valNode}; segments = new MemorySegment[]{ - MemorySegment.ofArray(encBuf), - MemorySegment.ofArray(idxBuf), - MemorySegment.ofArray(valBuf) - }; + MemorySegment.ofArray(encBuf), MemorySegment.ofArray(idxBuf), MemorySegment.ofArray(valBuf)}; } else { children = new ArrayNode[]{encNode}; segments = new MemorySegment[]{MemorySegment.ofArray(encBuf)}; @@ -82,91 +83,59 @@ private static DecodeContext buildAlpCtxF64( ArrayNode alpNode = ArrayNode.of(EncodingId.VORTEX_ALP, ByteBuffer.wrap(metaBytes), children, new int[0], ArrayStats.empty()); - Registry registry = TestRegistry.of(new AlpEncoding(), new PrimitiveEncoding()); - - return new DecodeContext(alpNode, DTypes.F64, encodedVals.length, segments, registry, java.lang.foreign.Arena.global()); + return new DecodeContext(alpNode, DTypes.F64, encodedVals.length, segments, REGISTRY, java.lang.foreign.Arena.global()); } - private static DecodeContext buildAlpCtxF32( - int expE, int expF, - int[] encodedVals - ) { + private 
static DecodeContext buildAlpCtxF32(int expE, int expF, int[] encodedVals) { byte[] metaBytes = new ALPMetadata(expE, expF, null).encode(); - byte[] encBuf = new byte[encodedVals.length * 4]; ByteBuffer bb = ByteBuffer.wrap(encBuf).order(ByteOrder.LITTLE_ENDIAN); for (int v : encodedVals) { bb.putInt(v); } - - ArrayNode encNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, - new ArrayNode[0], new int[]{0}, ArrayStats.empty()); - - ArrayNode alpNode = ArrayNode.of(EncodingId.VORTEX_ALP, - ByteBuffer.wrap(metaBytes), new ArrayNode[]{encNode}, new int[0], ArrayStats.empty()); - + ArrayNode encNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{0}, ArrayStats.empty()); + ArrayNode alpNode = ArrayNode.of(EncodingId.VORTEX_ALP, ByteBuffer.wrap(metaBytes), + new ArrayNode[]{encNode}, new int[0], ArrayStats.empty()); MemorySegment[] segments = {MemorySegment.ofArray(encBuf)}; - - Registry registry = TestRegistry.of(new AlpEncoding(), new PrimitiveEncoding()); - - return new DecodeContext(alpNode, DTypes.F32, encodedVals.length, segments, registry, java.lang.foreign.Arena.global()); + return new DecodeContext(alpNode, DTypes.F32, encodedVals.length, segments, REGISTRY, java.lang.foreign.Arena.global()); } @Test void decode_f64_noPatches() { - // Given — encode 1.23 with exp_e=2, exp_f=1: encoded = round(1.23 * 100 / 10) = 12 - // decode: 12 * 10^1 * 10^-2 = 12 * 10 * 0.01 = 1.2 - // Use exp_e=0, exp_f=2: encoded = round(1.23 * 1 / 100) ... 
let's use known values - // encode: value * F10[e] * IF10[f] then round; decode: encoded * F10[f] * IF10[e] - // Use e=2, f=0: encoded = round(1.23 * 100 * 1.0) = 123; decode = 123 * 1.0 * 0.01 = 1.23 int expE = 2, expF = 0; long[] encoded = {123L, 456L, 789L}; double[] expected = {1.23, 4.56, 7.89}; DecodeContext ctx = buildAlpCtxF64(expE, expF, encoded, null, null); - AlpEncoding sut = new AlpEncoding(); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(encoded.length); - var layout = PTypeIO.LE_DOUBLE; for (int i = 0; i < expected.length; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 8)) + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, (long) i * 8)) .as("index %d", i).isCloseTo(expected[i], within(1e-9)); } } @Test void decode_f64_withPatches() { - // Given — 5 values, encoded = [100, 0, 200, 0, 300] with e=2,f=0 → [1.0, 0.0, 2.0, 0.0, 3.0] - // patches at [1, 3] with real values [Double.NaN, Double.POSITIVE_INFINITY] int expE = 2, expF = 0; long[] encoded = {100L, 0L, 200L, 0L, 300L}; long[] patchIndices = {1L, 3L}; double[] patchValues = {Double.NaN, Double.POSITIVE_INFINITY}; DecodeContext ctx = buildAlpCtxF64(expE, expF, encoded, patchIndices, patchValues); - AlpEncoding sut = new AlpEncoding(); - - // When - Array result = sut.decode(ctx); - - // Then - var layout = PTypeIO.LE_DOUBLE; - assertThat(ArraySegments.of(result).get(layout, 0L)).isCloseTo(1.0, within(1e-9)); - assertThat(ArraySegments.of(result).get(layout, 8L)).isNaN(); - assertThat(ArraySegments.of(result).get(layout, 16L)).isCloseTo(2.0, within(1e-9)); - assertThat(ArraySegments.of(result).get(layout, 24L)).isInfinite(); - assertThat(ArraySegments.of(result).get(layout, 32L)).isCloseTo(3.0, within(1e-9)); + Array result = DECODER.decode(ctx); + + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, 0L)).isCloseTo(1.0, within(1e-9)); + 
assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, 8L)).isNaN(); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, 16L)).isCloseTo(2.0, within(1e-9)); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, 24L)).isInfinite(); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, 32L)).isCloseTo(3.0, within(1e-9)); } @ParameterizedTest @CsvSource({"0,0", "1,0", "2,1", "3,2", "4,3"}) void decode_f64_exponentCombinations(int expE, int expF) { - // Given — encode 42 with given exponents, then verify round-trip - // encode: encoded = round(42.0 * F10[e] * IF10[f]) double value = 42.0; double[] f10 = {1e0, 1e1, 1e2, 1e3, 1e4}; double[] if10 = {1e-0, 1e-1, 1e-2, 1e-3, 1e-4}; @@ -174,34 +143,23 @@ void decode_f64_exponentCombinations(int expE, int expF) { long[] encoded = {encVal}; DecodeContext ctx = buildAlpCtxF64(expE, expF, encoded, null, null); - AlpEncoding sut = new AlpEncoding(); - - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then - var layout = PTypeIO.LE_DOUBLE; - double decoded = ArraySegments.of(result).get(layout, 0L); + double decoded = ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, 0L); assertThat(decoded).isCloseTo(value, within(1e-6)); } @Test void decode_f32_noPatches() { - // Given — e=1, f=0: decode = encoded * 1.0 * 0.1 int expE = 1, expF = 0; int[] encoded = {10, 25, 100}; float[] expected = {1.0f, 2.5f, 10.0f}; DecodeContext ctx = buildAlpCtxF32(expE, expF, encoded); - AlpEncoding sut = new AlpEncoding(); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then - var layout = PTypeIO.LE_FLOAT; for (int i = 0; i < expected.length; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 4)) + assertThat(ArraySegments.of(result).get(PTypeIO.LE_FLOAT, (long) i * 4)) .as("index %d", i).isCloseTo(expected[i], within(1e-6f)); } } @@ -212,80 +170,55 @@ class Encode { @Test void encode_f32_roundTrip_noPatches() { - // Given float[] 
values = {1.0f, 2.5f, 3.75f, 10.0f, 0.1f}; - AlpEncoding sut = new AlpEncoding(); - - Registry registry = TestRegistry.withPrimitive(sut); - // When - EncodeResult encoded = sut.encode(DTypes.F32, values, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, values.length, DTypes.F32, registry); - Array result = sut.decode(ctx); + EncodeResult encoded = ENCODER.encode(DTypes.F32, values, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, values.length, DTypes.F32, REGISTRY); + Array result = DECODER.decode(ctx); - // Then - var layout = PTypeIO.LE_FLOAT; for (int i = 0; i < values.length; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 4)) + assertThat(ArraySegments.of(result).get(PTypeIO.LE_FLOAT, (long) i * 4)) .as("index %d", i).isCloseTo(values[i], within(1e-6f)); } } @Test void encode_f32_roundTrip_withPatches() { - // Given — Float.NaN and Float.POSITIVE_INFINITY can't be ALP-encoded; must become patches float[] values = {1.0f, Float.NaN, 2.5f, Float.POSITIVE_INFINITY, 3.0f}; - AlpEncoding sut = new AlpEncoding(); - - Registry registry = TestRegistry.withPrimitive(sut); - - // When - EncodeResult encoded = sut.encode(DTypes.F32, values, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, values.length, DTypes.F32, registry); - Array result = sut.decode(ctx); - - // Then - var layout = PTypeIO.LE_FLOAT; - assertThat(ArraySegments.of(result).get(layout, 0L)).isCloseTo(1.0f, within(1e-6f)); - assertThat(ArraySegments.of(result).get(layout, 4L)).isNaN(); - assertThat(ArraySegments.of(result).get(layout, 8L)).isCloseTo(2.5f, within(1e-6f)); - assertThat(ArraySegments.of(result).get(layout, 12L)).isInfinite(); - assertThat(ArraySegments.of(result).get(layout, 16L)).isCloseTo(3.0f, within(1e-6f)); + + EncodeResult encoded = ENCODER.encode(DTypes.F32, values, EncodeTestHelper.testCtx()); + DecodeContext ctx = 
EncodeTestHelper.toDecodeContext(encoded, values.length, DTypes.F32, REGISTRY); + Array result = DECODER.decode(ctx); + + assertThat(ArraySegments.of(result).get(PTypeIO.LE_FLOAT, 0L)).isCloseTo(1.0f, within(1e-6f)); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_FLOAT, 4L)).isNaN(); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_FLOAT, 8L)).isCloseTo(2.5f, within(1e-6f)); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_FLOAT, 12L)).isInfinite(); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_FLOAT, 16L)).isCloseTo(3.0f, within(1e-6f)); } @Test void encode_f64_roundTrip_noPatches() { - // Given double[] values = {1.23, 4.56, 7.89, 0.001, 100.0}; - AlpEncoding sut = new AlpEncoding(); - - Registry registry = TestRegistry.withPrimitive(sut); - // When - EncodeResult encoded = sut.encode(DTypes.F64, values, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, values.length, DTypes.F64, registry); - Array result = sut.decode(ctx); + EncodeResult encoded = ENCODER.encode(DTypes.F64, values, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, values.length, DTypes.F64, REGISTRY); + Array result = DECODER.decode(ctx); - // Then - var layout = PTypeIO.LE_DOUBLE; for (int i = 0; i < values.length; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 8)) + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, (long) i * 8)) .as("index %d", i).isCloseTo(values[i], within(1e-9)); } } @Test void encode_f64_metadata_expE_isNonZero() throws Exception { - // Given — 2-decimal values force ALP to pick exp_e=2 (×100); if tag drifts, exp_e reads as 0 double[] values = {1.23, 4.56, 7.89}; - AlpEncoding sut = new AlpEncoding(); - // When - EncodeResult result = sut.encode(DTypes.F64, values, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.F64, values, EncodeTestHelper.testCtx()); MemorySegment metaSeg = 
MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); ALPMetadata meta = ALPMetadata.decode(metaSeg, 0, metaSeg.byteSize()); - // Then assertThat(meta.exp_e()).isGreaterThan(0); } } diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/AlpRdEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/AlpRdEncodingEncoderTest.java new file mode 100644 index 00000000..314829c7 --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/AlpRdEncodingEncoderTest.java @@ -0,0 +1,57 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.proto.ALPRDMetadata; +import io.github.dfa1.vortex.reader.decode.AlpRdEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.BitpackedEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.within; + +class AlpRdEncodingEncoderTest { + + @Test + void encode_f64_roundTrip() { + // Given + double[] values = {0.1, 0.2, 0.3, 0.4, 0.5}; + var encoder = new AlpRdEncodingEncoder(); + var decoder = new AlpRdEncodingDecoder(); + Registry registry = TestRegistry.ofDecoders(decoder, new BitpackedEncodingDecoder(), new PrimitiveEncodingDecoder()); + + // When + EncodeResult encoded = encoder.encode(DTypes.F64, values, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, values.length, DTypes.F64, registry); + var result = 
decoder.decode(ctx); + + // Then + for (int i = 0; i < values.length; i++) { + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, (long) i * 8)) + .as("index %d", i).isCloseTo(values[i], within(1e-9)); + } + } + + @Test + void encode_f64_metadata_rightBitWidth_isNonZero() throws Exception { + // Given — ALPRD splits F64 into left+right parts; right_bit_width>0 means real encoding happened + // if tag drifts, right_bit_width reads as 0 (proto3 default) and right parts are all zero + double[] values = {0.1, 0.2, 0.3, 0.4, 0.5}; + var sut = new AlpRdEncodingEncoder(); + + // When + EncodeResult result = sut.encode(DTypes.F64, values, EncodeTestHelper.testCtx()); + var metaSeg = java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); + ALPRDMetadata meta = ALPRDMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + + // Then + assertThat(meta.right_bit_width()).isGreaterThan(0); + } +} diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/BitpackedConstantPatchesBroadcastTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/BitpackedConstantPatchesBroadcastTest.java new file mode 100644 index 00000000..ad5b7a56 --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/BitpackedConstantPatchesBroadcastTest.java @@ -0,0 +1,83 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.proto.BitPackedMetadata; +import io.github.dfa1.vortex.proto.PatchesMetadata; +import io.github.dfa1.vortex.proto.ScalarValue; +import 
io.github.dfa1.vortex.reader.decode.BitpackedEncodingDecoder; +import org.junit.jupiter.api.Test; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +import static org.assertj.core.api.Assertions.assertThat; + +/// Regression for the IOOB crash in `BitpackedEncoding.applyPatches` (and sibling +/// `SparseEncoding`, `AlpEncoding`, `PatchedEncoding`, etc.) when a patches child is +/// encoded with [io.github.dfa1.vortex.encoding.ConstantEncoding]. +class BitpackedConstantPatchesBroadcastTest { + + @Test + void bitpackedDecode_withConstantPatchesValues_broadcastsValueAcrossPatches() { + long n = 10; + long numPatches = 3; + long constantPatchValue = 42L; + + byte[] packed = new byte[128]; + + ScalarValue idxScalar = ScalarValue.ofUint64Value(2L); + byte[] idxScalarBytes = idxScalar.encode(); + + ScalarValue valScalar = ScalarValue.ofInt64Value(constantPatchValue); + byte[] valScalarBytes = valScalar.encode(); + + PatchesMetadata patches = new PatchesMetadata(numPatches, 0, + io.github.dfa1.vortex.proto.PType.U32, null, null, null); + BitPackedMetadata meta = new BitPackedMetadata(1, 0, patches); + ByteBuffer metaBuf = ByteBuffer.wrap(meta.encode()).order(ByteOrder.LITTLE_ENDIAN); + + try (Arena arena = Arena.ofConfined()) { + MemorySegment packedSeg = arena.allocate(packed.length, 8); + MemorySegment.copy(MemorySegment.ofArray(packed), 0, packedSeg, 0, packed.length); + MemorySegment idxBufSeg = arena.allocate(idxScalarBytes.length, 1); + MemorySegment.copy(MemorySegment.ofArray(idxScalarBytes), 0, idxBufSeg, 0, idxScalarBytes.length); + MemorySegment valBufSeg = arena.allocate(valScalarBytes.length, 1); + MemorySegment.copy(MemorySegment.ofArray(valScalarBytes), 0, valBufSeg, 0, valScalarBytes.length); + + ArrayNode idxChild = ArrayNode.of(EncodingId.VORTEX_CONSTANT, null, + new ArrayNode[0], new int[]{1}, null); + ArrayNode valChild = ArrayNode.of(EncodingId.VORTEX_CONSTANT, null, + 
new ArrayNode[0], new int[]{2}, null); + ArrayNode root = ArrayNode.of(EncodingId.FASTLANES_BITPACKED, metaBuf, + new ArrayNode[]{idxChild, valChild}, new int[]{0}, null); + + DType dtype = new DType.Primitive(PType.I64, false); + Registry registry = Registry.loadAll(); + DecodeContext ctx = new DecodeContext(root, dtype, n, + new MemorySegment[]{packedSeg, idxBufSeg, valBufSeg}, + registry, Arena.ofAuto()); + + Array result = new BitpackedEncodingDecoder().decode(ctx); + + assertThat(result.length()).isEqualTo(n); + MemorySegment data = ArraySegments.of(result); + assertThat(data.getAtIndex(PTypeIO.LE_LONG, 2)).isEqualTo(constantPatchValue); + for (long i = 0; i < n; i++) { + if (i == 2) { + continue; + } + assertThat(data.getAtIndex(PTypeIO.LE_LONG, i)).as("non-patched index %d", i).isZero(); + } + } + } +} diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/BitpackedEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/BitpackedEncodingEncoderTest.java new file mode 100644 index 00000000..66194ee5 --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/BitpackedEncodingEncoderTest.java @@ -0,0 +1,87 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.proto.BitPackedMetadata; +import io.github.dfa1.vortex.reader.decode.BitpackedEncodingDecoder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import 
org.junit.jupiter.params.provider.MethodSource; + +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; + +/// Property: encode then decode is lossless for unsigned integer types. +class BitpackedEncodingEncoderTest { + + private static final BitpackedEncodingEncoder ENCODER = new BitpackedEncodingEncoder(); + private static final BitpackedEncodingDecoder DECODER = new BitpackedEncodingDecoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(DECODER); + + static Stream<Arguments> u32Arrays() { + return Stream.of( + Arguments.of("empty", new int[]{}), + Arguments.of("single", new int[]{0}), + Arguments.of("all-zeros", new int[]{0, 0, 0, 0, 0}), + Arguments.of("small-values", new int[]{1, 2, 3, 4, 5, 6, 7}), + Arguments.of("mixed", new int[]{0, 7, 63, 255, 1023, 65535}), + Arguments.of("max-unsigned", new int[]{-1, -1, -1}) // 0xFFFFFFFF + ); + } + + static Stream<Arguments> u64Arrays() { + return Stream.of( + Arguments.of("empty", new long[]{}), + Arguments.of("single", new long[]{0L}), + Arguments.of("small-values", new long[]{1L, 2L, 3L, 4L, 5L}), + Arguments.of("large-values", new long[]{0L, 0xFFFFL, 0xFFFFFFL, 0xFFFFFFFFL}) + ); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("u32Arrays") + void encodeDecode_u32_isLossless(String name, int[] data) { + EncodeResult encoded = ENCODER.encode(DTypes.U32, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.U32, REGISTRY); + Array result = DECODER.decode(ctx); + + assertThat(result.length()).isEqualTo(data.length); + for (int i = 0; i < data.length; i++) { + assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(data[i]); + } + } + + @ParameterizedTest(name = "{0}") + @MethodSource("u64Arrays") + void encodeDecode_u64_isLossless(String name, long[] data) { + EncodeResult encoded = ENCODER.encode(DTypes.U64, data, EncodeTestHelper.testCtx()); + DecodeContext ctx
= EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.U64, REGISTRY); + Array result = DECODER.decode(ctx); + + assertThat(result.length()).isEqualTo(data.length); + for (int i = 0; i < data.length; i++) { + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).as("index %d", i).isEqualTo(data[i]); + } + } + + @Test + void encode_i32_metadata_bitWidth_isNonZero() throws Exception { + // Given — max value 5 needs 3 bits; if tag drifts, bit_width reads as 0 (proto3 default) + int[] data = {1, 2, 3, 4, 5}; + + EncodeResult result = ENCODER.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + var metaSeg = java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); + BitPackedMetadata meta = BitPackedMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + + assertThat(meta.bit_width()).isGreaterThan(0); + } +} diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/BitpackedEncodingPatchesTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/BitpackedEncodingPatchesTest.java similarity index 68% rename from core/src/test/java/io/github/dfa1/vortex/encoding/BitpackedEncodingPatchesTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/BitpackedEncodingPatchesTest.java index 0ab812af..ebb0f288 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/BitpackedEncodingPatchesTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/BitpackedEncodingPatchesTest.java @@ -1,10 +1,21 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; -import io.github.dfa1.vortex.proto.BitPackedMetadata; -import io.github.dfa1.vortex.proto.PatchesMetadata; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import 
io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.proto.BitPackedMetadata; +import io.github.dfa1.vortex.proto.PatchesMetadata; +import io.github.dfa1.vortex.reader.decode.BitpackedEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -17,21 +28,21 @@ class BitpackedEncodingPatchesTest { + private static final BitpackedEncodingEncoder ENCODER = new BitpackedEncodingEncoder(); + private static final BitpackedEncodingDecoder DECODER = new BitpackedEncodingDecoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); + @Nested class Decode { @Test void decode_appliesPatches_overridingBitPackedValues() { - // Given — bit-pack [10, 20, 30, 40, 50] via the production encoder (bitWidth = 6), - // then attach synthetic patches metadata that rewrites indices [1, 3] with [777, 999]. int[] base = {10, 20, 30, 40, 50}; - BitpackedEncoding sut = new BitpackedEncoding(); - EncodeResult packed = sut.encode(DTypes.I32, base, EncodeTestHelper.testCtx()); + EncodeResult packed = ENCODER.encode(DTypes.I32, base, EncodeTestHelper.testCtx()); MemorySegment packedSeg = packed.buffers().getFirst(); byte[] packedBytes = packedSeg.toArray(java.lang.foreign.ValueLayout.JAVA_BYTE); - // Build new BitPackedMetadata that re-uses the packed bytes but advertises patches. 
PatchesMetadata patches = new PatchesMetadata(2L, 0L, io.github.dfa1.vortex.proto.PType.U32, null, null, null); byte[] metaBytes = new BitPackedMetadata(6, 0, patches).encode(); @@ -48,8 +59,7 @@ void decode_appliesPatches_overridingBitPackedValues() { ArrayNode bpNode = ArrayNode.of(EncodingId.FASTLANES_BITPACKED, ByteBuffer.wrap(metaBytes), new ArrayNode[]{idxNode, valNode}, - new int[]{0}, - ArrayStats.empty()); + new int[]{0}, ArrayStats.empty()); MemorySegment[] segments = { MemorySegment.ofArray(packedBytes), @@ -57,15 +67,11 @@ void decode_appliesPatches_overridingBitPackedValues() { MemorySegment.ofArray(valBuf) }; - Registry registry = TestRegistry.of(new BitpackedEncoding(), new PrimitiveEncoding()); - DecodeContext ctx = new DecodeContext( - bpNode, DTypes.I32, base.length, segments, registry, Arena.global()); + bpNode, DTypes.I32, base.length, segments, REGISTRY, Arena.global()); - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(base.length); assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, 0L)).isEqualTo(10); assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, 4L)).isEqualTo(777); diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/BoolEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/BoolEncodingEncoderTest.java new file mode 100644 index 00000000..1bad615d --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/BoolEncodingEncoderTest.java @@ -0,0 +1,72 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import 
io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.reader.decode.BoolEncodingDecoder; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.lang.foreign.ValueLayout; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; + +/// Property: encode then decode is lossless for boolean arrays of all lengths (including non-multiples of 8). +class BoolEncodingEncoderTest { + + static Stream<boolean[]> boolArrays() { + return Stream.of( + new boolean[]{}, + new boolean[]{false}, + new boolean[]{true}, + new boolean[]{false, true, false, true, false, true, false, true}, + new boolean[]{true, true, true, false, false, false, true, false, true}, + new boolean[]{false, false, false, false, false, false, false}, + new boolean[]{true, true, true, true, true, true, true, true, true} + ); + } + + @ParameterizedTest + @MethodSource("boolArrays") + void encodeDecode_isLossless(boolean[] data) { + // Given + var encoder = new BoolEncodingEncoder(); + var decoder = new BoolEncodingDecoder(); + Registry registry = TestRegistry.ofDecoders(decoder); + + // When + EncodeResult encoded = encoder.encode(DTypes.BOOL, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.BOOL, registry); + Array result = decoder.decode(ctx); + + // Then + assertThat(result).isInstanceOf(BoolArray.class); + assertThat(result.length()).isEqualTo(data.length); + for (int i = 0; i < data.length; i++) { + byte byteVal = ArraySegments.of(result).get(ValueLayout.JAVA_BYTE, i / 8); + boolean decoded = ((byteVal >>> (i % 8)) & 1) == 1; + assertThat(decoded).as("index %d", i).isEqualTo(data[i]); + } + } + + @ParameterizedTest + @MethodSource("boolArrays") + void encodedSize_isPackedBits(boolean[] data) { + // Given + var sut = new BoolEncodingEncoder(); + + // When + EncodeResult encoded =
sut.encode(DTypes.BOOL, data, EncodeTestHelper.testCtx()); + + // Then — bit-packed: ceiling(n/8) bytes, always ≤ n bytes raw + long totalBytes = encoded.buffers().stream().mapToLong(java.lang.foreign.MemorySegment::byteSize).sum(); + assertThat(totalBytes).isEqualTo((long) (data.length + 7) / 8); + } +} diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ByteBoolEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ByteBoolEncodingEncoderTest.java new file mode 100644 index 00000000..af85901c --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ByteBoolEncodingEncoderTest.java @@ -0,0 +1,52 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.BoolArray; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.reader.decode.ByteBoolEncodingDecoder; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; + +class ByteBoolEncodingEncoderTest { + + static Stream boolArrays() { + return Stream.of( + new boolean[]{}, + new boolean[]{false}, + new boolean[]{true}, + new boolean[]{true, false, true, false, true}, + new boolean[]{false, false, false, false}, + new boolean[]{true, true, true, true, true, true, true, true, true} + ); + } + + @ParameterizedTest + @MethodSource("boolArrays") + void encodeDecode_isLossless(boolean[] data) { + // Given + var encoder = new ByteBoolEncodingEncoder(); + var decoder = new ByteBoolEncodingDecoder(); + Registry registry = 
TestRegistry.ofDecoders(decoder); + + // When + EncodeResult encoded = encoder.encode(DTypes.BOOL, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.BOOL, registry); + Array result = decoder.decode(ctx); + + // Then + assertThat(result.length()).isEqualTo(data.length); + BoolArray boolArr = (BoolArray) result; + for (int i = 0; i < data.length; i++) { + assertThat(boolArr.getBoolean(i)).as("index %d", i).isEqualTo(data[i]); + } + } +} diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/CascadingCompressorTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/CascadingCompressorTest.java similarity index 86% rename from core/src/test/java/io/github/dfa1/vortex/encoding/CascadingCompressorTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/CascadingCompressorTest.java index 5f09856c..91432eab 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/CascadingCompressorTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/CascadingCompressorTest.java @@ -1,7 +1,22 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.array.DoubleArray; import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.encoding.AlpEncoding; +import io.github.dfa1.vortex.encoding.BitpackedEncoding; +import io.github.dfa1.vortex.encoding.CascadingCompressor; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.DictEncoding; +import io.github.dfa1.vortex.encoding.EncodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.Encoding; +import io.github.dfa1.vortex.encoding.EncodingId; +import 
io.github.dfa1.vortex.encoding.FrameOfReferenceEncoding; +import io.github.dfa1.vortex.encoding.PrimitiveEncoding; +import io.github.dfa1.vortex.encoding.Registry; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -27,17 +42,14 @@ class Compress { @Test void depth0_picksTerminalOnly_forF64() { - // Given double[] values = new double[4096]; for (int i = 0; i < values.length; i++) { values[i] = i * 1.5; } CascadingCompressor sut = new CascadingCompressor(ALL_CODECS); - // When EncodeResult result = sut.encode(DTypes.F64, values, ctx(0)); - // Then - result should be a valid non-null encode result assertThat(result).isNotNull(); assertThat(result.rootNode()).isNotNull(); assertThat(result.buffers()).isNotEmpty(); @@ -45,38 +57,31 @@ void depth0_picksTerminalOnly_forF64() { @Test void depth1_alpPlusBitpacked_producesSmallResult_forF64() { - // Given: OHLC-style prices — ALP-encodable, small range → bitpackable residuals double[] values = new double[4096]; for (int i = 0; i < values.length; i++) { values[i] = 100.0 + (i % 50) * 0.01; } CascadingCompressor sut = new CascadingCompressor(ALL_CODECS); - // When EncodeResult result = sut.encode(DTypes.F64, values, ctx(1)); - // Then - cascaded result should be smaller than raw primitive (4096 * 8 = 32768 bytes) long totalBytes = result.buffers().stream().mapToLong(MemorySegment::byteSize).sum(); assertThat(totalBytes).isLessThan(4096L * 8); } @Test void excludedEncodings_areSkipped() { - // Given: exclude AlpEncoding — should not be selected for DTypes.F64 double[] values = new double[512]; for (int i = 0; i < values.length; i++) { values[i] = i * 1.5; } - // Depth=1 but ALP excluded via context EncodeContext encodeCtx = new EncodeContext( Arena.ofAuto(), Registry.of(ALL_CODECS), 1, Set.of(EncodingId.VORTEX_ALP), 42L, 64, 0.1); CascadingCompressor sut = new CascadingCompressor(ALL_CODECS); - // When EncodeResult result = sut.encode(DTypes.F64, values, encodeCtx); - // Then - ALP node should not 
appear in the tree assertThat(containsEncoding(result.rootNode(), EncodingId.VORTEX_ALP)).isFalse(); } @@ -98,20 +103,17 @@ class RoundTrip { @Test void alpBitpacked_f64() { - // Given double[] values = new double[1024]; for (int i = 0; i < values.length; i++) { values[i] = 150.0 + (i % 100) * 0.01; } CascadingCompressor sut = new CascadingCompressor(ALL_CODECS); - // When Registry registry = Registry.of(ALL_CODECS); EncodeResult result = sut.encode(DTypes.F64, values, EncodeContext.ofDepth(1, Arena.ofAuto(), registry)); DecodeContext decodeCtx = EncodeTestHelper.toDecodeContext(result, values.length, DTypes.F64, registry); DoubleArray decoded = (DoubleArray) registry.decode(decodeCtx); - // Then for (int i = 0; i < values.length; i++) { assertThat(decoded.getDouble(i)).isEqualTo(values[i]); } @@ -119,20 +121,17 @@ void alpBitpacked_f64() { @Test void forBitpacked_i64() { - // Given: integers in narrow range → FoR reduces to small residuals → bitpackable long[] values = new long[1024]; for (int i = 0; i < values.length; i++) { values[i] = 1_000_000L + (i % 200); } CascadingCompressor sut = new CascadingCompressor(ALL_CODECS); - // When Registry registry = Registry.of(ALL_CODECS); EncodeResult result = sut.encode(DTypes.I64, values, EncodeContext.ofDepth(1, Arena.ofAuto(), registry)); DecodeContext decodeCtx = EncodeTestHelper.toDecodeContext(result, values.length, DTypes.I64, registry); LongArray decoded = (LongArray) registry.decode(decodeCtx); - // Then for (int i = 0; i < values.length; i++) { assertThat(decoded.getLong(i)).isEqualTo(values[i]); } @@ -140,21 +139,18 @@ void forBitpacked_i64() { @Test void dictBitpacked_i32() { - // Given: low-cardinality int column int[] values = new int[2048]; for (int i = 0; i < values.length; i++) { values[i] = i % 10; } CascadingCompressor sut = new CascadingCompressor(ALL_CODECS); - // When Registry registry = Registry.of(ALL_CODECS); EncodeResult result = sut.encode(DTypes.I32, values, EncodeContext.ofDepth(1, 
Arena.ofAuto(), registry)); DecodeContext decodeCtx = EncodeTestHelper.toDecodeContext(result, values.length, DTypes.I32, registry); io.github.dfa1.vortex.core.array.IntArray decoded = (io.github.dfa1.vortex.core.array.IntArray) registry.decode(decodeCtx); - // Then for (int i = 0; i < values.length; i++) { assertThat(decoded.getInt(i)).isEqualTo(values[i]); } diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/ChunkedEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ChunkedEncodingEncoderTest.java similarity index 64% rename from core/src/test/java/io/github/dfa1/vortex/encoding/ChunkedEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/ChunkedEncodingEncoderTest.java index 59030e0f..e39e470b 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/ChunkedEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ChunkedEncodingEncoderTest.java @@ -1,10 +1,21 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.PType; import io.github.dfa1.vortex.core.VortexException; import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.ChunkedData; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.reader.decode.ChunkedEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -17,11 +28,16 @@ 
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -class ChunkedEncodingTest { +class ChunkedEncodingEncoderTest { private static final ValueLayout.OfLong LE_LONG = ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); + private static final ChunkedEncodingEncoder ENCODER = new ChunkedEncodingEncoder(); + private static final ChunkedEncodingDecoder DECODER = new ChunkedEncodingDecoder(); + private static final PrimitiveEncodingEncoder PRIM_ENCODER = new PrimitiveEncodingEncoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); + private static ArrayNode toArrayNode(EncodeNode enc) { ArrayNode[] children = new ArrayNode[enc.children().length]; for (int i = 0; i < children.length; i++) { @@ -39,23 +55,15 @@ class Encode { @Test void roundTrip_twoChunks_i64_preservesValues() { - // Given long[] chunk0 = {10L, 20L, 30L}; long[] chunk1 = {40L, 50L}; DType i64 = new DType.Primitive(PType.I64, false); - var sut = new ChunkedEncoding(); - Registry registry = Registry.builder() - .register(sut) - .register(new PrimitiveEncoding()) - .build(); ChunkedData data = new ChunkedData(List.of(chunk0, chunk1), new long[]{3, 2}); - // When - EncodeResult encoded = sut.encode(i64, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, 5L, i64, registry); - Array result = sut.decode(ctx); + EncodeResult encoded = ENCODER.encode(i64, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, 5L, i64, REGISTRY); + Array result = DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(5); assertThat(ArraySegments.of(result).get(LE_LONG, 0L)).isEqualTo(10L); assertThat(ArraySegments.of(result).get(LE_LONG, 8L)).isEqualTo(20L); @@ -66,28 +74,19 @@ void roundTrip_twoChunks_i64_preservesValues() { @Test void 
encodeNode_hasNoDirectBuffers_offsetsAsFirstChild() { - // Given long[] chunk0 = {1L, 2L}; DType i64 = new DType.Primitive(PType.I64, false); - var sut = new ChunkedEncoding(); ChunkedData data = new ChunkedData(List.of(chunk0), new long[]{2}); - // When - EncodeResult result = sut.encode(i64, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(i64, data, EncodeTestHelper.testCtx()); - // Then assertThat(result.rootNode().bufferIndices()).isEmpty(); - assertThat(result.rootNode().children()).hasSize(2); // offsets + 1 chunk - assertThat(result.buffers()).hasSize(2); // offsets buf + chunk buf + assertThat(result.rootNode().children()).hasSize(2); + assertThat(result.buffers()).hasSize(2); } @Test void mismatchedLengths_throws() { - // Given - var sut = new ChunkedEncoding(); - DType i64 = new DType.Primitive(PType.I64, false); - - // When / Then assertThatThrownBy(() -> new ChunkedData(List.of(new long[]{1L}), new long[]{1, 2})) .isInstanceOf(IllegalArgumentException.class); } @@ -98,37 +97,21 @@ class Decode { @Test void roundTrip_twoChunks_concatenatesValues() { - // Given long[] chunk0 = {10L, 20L, 30L}; long[] chunk1 = {40L, 50L}; DType i64 = new DType.Primitive(PType.I64, false); DType u64 = new DType.Primitive(PType.U64, false); - var sut = new ChunkedEncoding(); - Registry registry = Registry.builder() - .register(sut) - .register(new PrimitiveEncoding()) - .build(); - - // Build chunk_offsets segment: [0, 3, 5] as U64 LE - EncodeResult offsetsResult = new PrimitiveEncoding().encode(u64, new long[]{0L, 3L, 5L}, EncodeTestHelper.testCtx()); - // Build chunk0 segment - EncodeResult chunk0Result = new PrimitiveEncoding().encode(i64, chunk0, EncodeTestHelper.testCtx()); - // Build chunk1 segment - EncodeResult chunk1Result = new PrimitiveEncoding().encode(i64, chunk1, EncodeTestHelper.testCtx()); + EncodeResult offsetsResult = PRIM_ENCODER.encode(u64, new long[]{0L, 3L, 5L}, EncodeTestHelper.testCtx()); + EncodeResult chunk0Result = 
PRIM_ENCODER.encode(i64, chunk0, EncodeTestHelper.testCtx()); + EncodeResult chunk1Result = PRIM_ENCODER.encode(i64, chunk1, EncodeTestHelper.testCtx()); - // Collect all buffers: [offsets_buf, chunk0_buf, chunk1_buf] MemorySegment[] allBufs = { offsetsResult.buffers().getFirst(), chunk0Result.buffers().getFirst(), chunk1Result.buffers().getFirst() }; - // Build ArrayNode tree: - // root: ChunkedEncoding, children=[offsetsNode, chunk0Node, chunk1Node], buffers=[] - // offsetsNode: PrimitiveEncoding, bufferIndices=[0] - // chunk0Node: PrimitiveEncoding, bufferIndices=[1] - // chunk1Node: PrimitiveEncoding, bufferIndices=[2] ArrayNode offsetsNode = toArrayNode(offsetsResult.rootNode()); ArrayNode chunk0Node = toArrayNode(remapped(chunk0Result.rootNode(), 1)); ArrayNode chunk1Node = toArrayNode(remapped(chunk1Result.rootNode(), 2)); @@ -137,12 +120,9 @@ void roundTrip_twoChunks_concatenatesValues() { new ArrayNode[]{offsetsNode, chunk0Node, chunk1Node}, new int[]{}, null); - DecodeContext ctx = new DecodeContext(root, i64, 5L, allBufs, registry, Arena.ofAuto()); + DecodeContext ctx = new DecodeContext(root, i64, 5L, allBufs, REGISTRY, Arena.ofAuto()); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(5); assertThat(ArraySegments.of(result).get(LE_LONG, 0L)).isEqualTo(10L); assertThat(ArraySegments.of(result).get(LE_LONG, 8L)).isEqualTo(20L); @@ -153,18 +133,12 @@ void roundTrip_twoChunks_concatenatesValues() { @Test void singleChunk_returnsSameValues() { - // Given long[] data = {1L, 2L, 3L}; DType i64 = new DType.Primitive(PType.I64, false); DType u64 = new DType.Primitive(PType.U64, false); - Registry registry = Registry.builder() - .register(new ChunkedEncoding()) - .register(new PrimitiveEncoding()) - .build(); - - EncodeResult offsetsResult = new PrimitiveEncoding().encode(u64, new long[]{0L, 3L}, EncodeTestHelper.testCtx()); - EncodeResult chunkResult = new 
PrimitiveEncoding().encode(i64, data, EncodeTestHelper.testCtx()); + EncodeResult offsetsResult = PRIM_ENCODER.encode(u64, new long[]{0L, 3L}, EncodeTestHelper.testCtx()); + EncodeResult chunkResult = PRIM_ENCODER.encode(i64, data, EncodeTestHelper.testCtx()); MemorySegment[] allBufs = { offsetsResult.buffers().getFirst(), @@ -176,12 +150,9 @@ void singleChunk_returnsSameValues() { new ArrayNode[]{toArrayNode(offsetsResult.rootNode()), toArrayNode(remapped(chunkResult.rootNode(), 1))}, new int[]{}, null); - DecodeContext ctx = new DecodeContext(root, i64, 3L, allBufs, registry, Arena.ofAuto()); - - // When - Array result = new ChunkedEncoding().decode(ctx); + DecodeContext ctx = new DecodeContext(root, i64, 3L, allBufs, REGISTRY, Arena.ofAuto()); + Array result = DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(3); for (int i = 0; i < 3; i++) { assertThat(ArraySegments.of(result).get(LE_LONG, (long) i * 8)).isEqualTo(data[i]); @@ -190,16 +161,11 @@ void singleChunk_returnsSameValues() { @Test void noChildren_throws() { - // Given DType i64 = new DType.Primitive(PType.I64, false); - Registry registry = Registry.builder() - .register(new ChunkedEncoding()) - .build(); ArrayNode root = ArrayNode.of(EncodingId.VORTEX_CHUNKED, null, new ArrayNode[]{}, new int[]{}, null); - DecodeContext ctx = new DecodeContext(root, i64, 0L, new MemorySegment[]{}, registry, Arena.ofAuto()); + DecodeContext ctx = new DecodeContext(root, i64, 0L, new MemorySegment[]{}, REGISTRY, Arena.ofAuto()); - // When / Then - assertThatThrownBy(() -> new ChunkedEncoding().decode(ctx)) + assertThatThrownBy(() -> DECODER.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("at least one child"); } diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/ConstantEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ConstantEncodingEncoderTest.java similarity index 57% rename from 
core/src/test/java/io/github/dfa1/vortex/encoding/ConstantEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/ConstantEncodingEncoderTest.java index c0342f93..80db0bc4 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/ConstantEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ConstantEncodingEncoderTest.java @@ -1,4 +1,4 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.ArraySegments; @@ -8,6 +8,14 @@ import io.github.dfa1.vortex.core.array.IntArray; import io.github.dfa1.vortex.core.array.LongArray; import io.github.dfa1.vortex.core.array.ShortArray; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.reader.decode.ConstantEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -21,7 +29,11 @@ /// Property: encode then decode is lossless for constant (all-equal) arrays. /// Property: decode allocates O(1) memory regardless of rowCount. 
-class ConstantEncodingTest { +class ConstantEncodingEncoderTest { + + private static final ConstantEncodingEncoder ENCODER = new ConstantEncodingEncoder(); + private static final ConstantEncodingDecoder DECODER = new ConstantEncodingDecoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(DECODER); @Nested class Decode { @@ -31,18 +43,14 @@ void decode_largeRowCount_bufferStaysConstantSize() { // Given — 10M rows would allocate 80 MB if the decoder materializes every element; // the correct impl stores exactly one element regardless of logical length. long rowCount = 10_000_000L; - var sut = new ConstantEncoding(); - Registry registry = TestRegistry.of(sut); // When - EncodeResult encoded = sut.encode(DTypes.I64, new long[]{42L}, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I64, registry); - Array result = sut.decode(ctx); + EncodeResult encoded = ENCODER.encode(DTypes.I64, new long[]{42L}, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I64, REGISTRY); + Array result = DECODER.decode(ctx); // Then assertThat(result.length()).isEqualTo(rowCount); - // Constant encoding must not materialize the full array: the backing buffer must - // hold exactly one element. Before fix: buffer is rowCount * 8 bytes. 
assertThat(ArraySegments.of(result).byteSize()) .as("constant encoding must not allocate O(rowCount) memory") .isEqualTo(Long.BYTES); @@ -75,65 +83,43 @@ static Stream i64ConstantArrays() { @ParameterizedTest @MethodSource("i32ConstantArrays") void encodeDecode_i32_isLossless(int[] data) { - // Given - var sut = new ConstantEncoding(); - Registry registry = TestRegistry.of(sut); - var le = PTypeIO.LE_INT; - - // When - EncodeResult encoded = sut.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I32, registry); - Array result = sut.decode(ctx); + EncodeResult encoded = ENCODER.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I32, REGISTRY); + Array result = DECODER.decode(ctx); - // Then — buffer holds one element; logical length is n assertThat(result.length()).isEqualTo(data.length); assertThat(ArraySegments.of(result).byteSize()).isEqualTo(Integer.BYTES); - assertThat(ArraySegments.of(result).get(le, 0L)).isEqualTo(data[0]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, 0L)).isEqualTo(data[0]); } @ParameterizedTest @MethodSource("i64ConstantArrays") void encodeDecode_i64_isLossless(long[] data) { - // Given - var sut = new ConstantEncoding(); - Registry registry = TestRegistry.of(sut); - var le = PTypeIO.LE_LONG; - - // When - EncodeResult encoded = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I64, registry); - Array result = sut.decode(ctx); + EncodeResult encoded = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I64, REGISTRY); + Array result = DECODER.decode(ctx); - // Then — buffer holds one element; logical length is n assertThat(result.length()).isEqualTo(data.length); 
assertThat(ArraySegments.of(result).byteSize()).isEqualTo(Long.BYTES); - assertThat(ArraySegments.of(result).get(le, 0L)).isEqualTo(data[0]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, 0L)).isEqualTo(data[0]); } } /// ConstantEncoding stores 1 element in the buffer but reports length=N. /// Primitive Array accessors must broadcast that single element across every - /// logical index, not OOB. Regression-guard for commit ed658b7 (added the - /// broadcast semantic) and 051a794 (fast-path branch-split — preserve broadcast - /// only on the slow path where `elementCount != length`). + /// logical index, not OOB. @Nested class Broadcast { @ParameterizedTest @ValueSource(longs = {1, 2, 10, 1_000, 131_072, 1_000_000L}) void i64_getLong_returnsConstantAtEveryIndex(long rowCount) { - // Given long constant = 0xDEADBEEFCAFEBABEL; - var sut = new ConstantEncoding(); - Registry registry = TestRegistry.of(sut); - EncodeResult encoded = sut.encode(DTypes.I64, new long[]{constant}, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I64, registry); + EncodeResult encoded = ENCODER.encode(DTypes.I64, new long[]{constant}, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I64, REGISTRY); - // When - LongArray result = (LongArray) sut.decode(ctx); + LongArray result = (LongArray) DECODER.decode(ctx); - // Then — getLong at first, last, and arbitrary midpoints all return the constant. - // Catches: missing modulo (OOB or wrong value) and accidental skip of broadcast branch. assertThat(result.length()).isEqualTo(rowCount); assertThat(result.getLong(0)).isEqualTo(constant); assertThat(result.getLong(rowCount - 1)).isEqualTo(constant); @@ -144,39 +130,26 @@ void i64_getLong_returnsConstantAtEveryIndex(long rowCount) { @Test void i64_fold_broadcastsAcrossAllRows() { - // Given — fold is the hot path for column aggregates. 
Must use the broadcast - // branch when elementCount != length, otherwise fold returns wrong sum. long rowCount = 1_000_000L; long constant = 7L; - var sut = new ConstantEncoding(); - Registry registry = TestRegistry.of(sut); - EncodeResult encoded = sut.encode(DTypes.I64, new long[]{constant}, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I64, registry); + EncodeResult encoded = ENCODER.encode(DTypes.I64, new long[]{constant}, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I64, REGISTRY); - // When - LongArray result = (LongArray) sut.decode(ctx); + LongArray result = (LongArray) DECODER.decode(ctx); long sum = result.fold(0L, Long::sum); - // Then — every row contributes the constant; total is rowCount * constant. - // Pre-fix (no modulo) the fold would OOB on row 1; post-bug-fix without branch-split - // the result is correct but ~25% slower. assertThat(sum).isEqualTo(rowCount * constant); } @Test void i32_getInt_broadcastsAcrossEveryIndex() { - // Given long rowCount = 10_000L; int constant = -123_456; - var sut = new ConstantEncoding(); - Registry registry = TestRegistry.of(sut); - EncodeResult encoded = sut.encode(DTypes.I32, new int[]{constant}, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I32, registry); + EncodeResult encoded = ENCODER.encode(DTypes.I32, new int[]{constant}, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I32, REGISTRY); - // When - IntArray result = (IntArray) sut.decode(ctx); + IntArray result = (IntArray) DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(rowCount); assertThat(result.getInt(0)).isEqualTo(constant); assertThat(result.getInt(rowCount - 1)).isEqualTo(constant); @@ -185,40 +158,29 @@ void i32_getInt_broadcastsAcrossEveryIndex() { @Test void 
f64_getDouble_broadcastsAcrossEveryIndex() { - // Given long rowCount = 10_000L; double constant = 3.141592653589793; - var sut = new ConstantEncoding(); - Registry registry = TestRegistry.of(sut); - EncodeResult encoded = sut.encode(DTypes.F64, new double[]{constant}, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.F64, registry); + EncodeResult encoded = ENCODER.encode(DTypes.F64, new double[]{constant}, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.F64, REGISTRY); - // When - DoubleArray result = (DoubleArray) sut.decode(ctx); + DoubleArray result = (DoubleArray) DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(rowCount); assertThat(result.getDouble(0)).isEqualTo(constant); assertThat(result.getDouble(rowCount - 1)).isEqualTo(constant); - // Iterative double sum drifts (~1e-10 per 10K rows) — use tolerance, not strict equality. assertThat(result.fold(0.0, Double::sum)) .isCloseTo(rowCount * constant, org.assertj.core.data.Offset.offset(1e-6)); } @Test void f32_getFloat_broadcastsAcrossEveryIndex() { - // Given long rowCount = 10_000L; float constant = 2.71828f; - var sut = new ConstantEncoding(); - Registry registry = TestRegistry.of(sut); - EncodeResult encoded = sut.encode(DTypes.F32, new float[]{constant}, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.F32, registry); + EncodeResult encoded = ENCODER.encode(DTypes.F32, new float[]{constant}, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.F32, REGISTRY); - // When - FloatArray result = (FloatArray) sut.decode(ctx); + FloatArray result = (FloatArray) DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(rowCount); assertThat(result.getFloat(0)).isEqualTo(constant); assertThat(result.getFloat(rowCount - 1)).isEqualTo(constant); @@ 
-226,18 +188,13 @@ void f32_getFloat_broadcastsAcrossEveryIndex() { @Test void i16_getShort_broadcastsAcrossEveryIndex() { - // Given long rowCount = 10_000L; short constant = (short) -12345; - var sut = new ConstantEncoding(); - Registry registry = TestRegistry.of(sut); - EncodeResult encoded = sut.encode(DTypes.I16, new short[]{constant}, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I16, registry); + EncodeResult encoded = ENCODER.encode(DTypes.I16, new short[]{constant}, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I16, REGISTRY); - // When - ShortArray result = (ShortArray) sut.decode(ctx); + ShortArray result = (ShortArray) DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(rowCount); assertThat(result.getShort(0)).isEqualTo(constant); assertThat(result.getShort(rowCount - 1)).isEqualTo(constant); @@ -245,18 +202,13 @@ void i16_getShort_broadcastsAcrossEveryIndex() { @Test void i8_getByte_broadcastsAcrossEveryIndex() { - // Given long rowCount = 10_000L; byte constant = (byte) -42; - var sut = new ConstantEncoding(); - Registry registry = TestRegistry.of(sut); - EncodeResult encoded = sut.encode(DTypes.I8, new byte[]{constant}, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I8, registry); + EncodeResult encoded = ENCODER.encode(DTypes.I8, new byte[]{constant}, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, rowCount, DTypes.I8, REGISTRY); - // When - ByteArray result = (ByteArray) sut.decode(ctx); + ByteArray result = (ByteArray) DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(rowCount); assertThat(result.getByte(0)).isEqualTo(constant); assertThat(result.getByte(rowCount - 1)).isEqualTo(constant); diff --git 
a/core/src/test/java/io/github/dfa1/vortex/encoding/DateTimePartsEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DateTimePartsEncodingEncoderTest.java similarity index 66% rename from core/src/test/java/io/github/dfa1/vortex/encoding/DateTimePartsEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/DateTimePartsEncodingEncoderTest.java index 68e8851b..6a6d3b9d 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/DateTimePartsEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DateTimePartsEncodingEncoderTest.java @@ -1,11 +1,24 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; -import io.github.dfa1.vortex.proto.DateTimePartsMetadata; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.PType; import io.github.dfa1.vortex.core.array.GenericArray; import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DateTimePartsData; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.encoding.TimeUnit; +import io.github.dfa1.vortex.proto.DateTimePartsMetadata; +import io.github.dfa1.vortex.reader.decode.DateTimePartsEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -19,16 +32,19 @@ import static org.assertj.core.api.Assertions.assertThat; import 
static org.assertj.core.api.Assertions.assertThatThrownBy; -class DateTimePartsEncodingTest { +class DateTimePartsEncodingEncoderTest { + + private static final DateTimePartsEncodingEncoder ENCODER = new DateTimePartsEncodingEncoder(); + private static final DateTimePartsEncodingDecoder DECODER = new DateTimePartsEncodingDecoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); private static final DType EXT_TIMESTAMP_MS = timestampDType(TimeUnit.Milliseconds); private static final DType EXT_TIMESTAMP_NS = timestampDType(TimeUnit.Nanoseconds); private static DType timestampDType(TimeUnit unit) { - // Rust hand-rolled: byte[0]=unit tag, bytes[1-2]=tz_len u16 LE (0 = no tz) ByteBuffer meta = ByteBuffer.allocate(3).order(ByteOrder.LITTLE_ENDIAN); meta.put((byte) unit.ordinal()); - meta.putShort((short) 0); // no timezone + meta.putShort((short) 0); meta.flip(); return new DType.Extension("vortex.timestamp", new DType.Primitive(PType.I64, false), meta, false); @@ -42,45 +58,26 @@ private static ArrayNode toArrayNode(EncodeNode node) { return ArrayNode.of(node.encodingId(), node.metadata(), children, node.bufferIndices(), ArrayStats.empty()); } - private static Registry registry() { - return Registry.builder() - .register(new DateTimePartsEncoding()) - .register(new PrimitiveEncoding()) - .build(); - } - @Nested class Encode { @Test void accepts_extensionDtype_true() { - // Given - DateTimePartsEncoding sut = new DateTimePartsEncoding(); - - // When / Then - assertThat(sut.accepts(EXT_TIMESTAMP_MS)).isTrue(); + assertThat(ENCODER.accepts(EXT_TIMESTAMP_MS)).isTrue(); } @Test void accepts_primitiveDtype_false() { - // Given - DateTimePartsEncoding sut = new DateTimePartsEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.I64)).isFalse(); + assertThat(ENCODER.accepts(DTypes.I64)).isFalse(); } @Test void encode_producesThreeChildren_noBuffersAtRoot() { - // Given long[] timestamps = {0L, 86_400_000L}; 
DateTimePartsData data = new DateTimePartsData(timestamps, false); - DateTimePartsEncoding sut = new DateTimePartsEncoding(); - // When - EncodeResult result = sut.encode(EXT_TIMESTAMP_MS, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(EXT_TIMESTAMP_MS, data, EncodeTestHelper.testCtx()); - // Then assertThat(result.rootNode().encodingId()).isEqualTo(EncodingId.VORTEX_DATETIMEPARTS); assertThat(result.rootNode().bufferIndices()).isEmpty(); assertThat(result.rootNode().children()).hasSize(3); @@ -88,14 +85,11 @@ void encode_producesThreeChildren_noBuffersAtRoot() { @Test void encode_missingMetadata_throws() { - // Given DType noMeta = new DType.Extension("vortex.timestamp", new DType.Primitive(PType.I64, false), null, false); - DateTimePartsEncoding sut = new DateTimePartsEncoding(); DateTimePartsData data = new DateTimePartsData(new long[]{0L}, false); - // When / Then - assertThatThrownBy(() -> sut.encode(noMeta, data, EncodeTestHelper.testCtx())) + assertThatThrownBy(() -> ENCODER.encode(noMeta, data, EncodeTestHelper.testCtx())) .hasMessageContaining("extension metadata missing"); } } @@ -105,22 +99,17 @@ class Decode { @Test void roundTrip_milliseconds_preservesDaysSecondsSubseconds() { - // Given - // 1970-01-02 01:02:03.456 UTC in millis long msPerDay = 86_400_000L; long ts = msPerDay + (3723L * 1000L) + 456L; long[] timestamps = {ts}; DateTimePartsData data = new DateTimePartsData(timestamps, false); - DateTimePartsEncoding sut = new DateTimePartsEncoding(); - // When - EncodeResult result = sut.encode(EXT_TIMESTAMP_MS, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(EXT_TIMESTAMP_MS, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), EXT_TIMESTAMP_MS, 1, bufs, registry(), Arena.global()); - GenericArray decoded = (GenericArray) sut.decode(ctx); + toArrayNode(result.rootNode()), 
EXT_TIMESTAMP_MS, 1, bufs, REGISTRY, Arena.global()); + GenericArray decoded = (GenericArray) DECODER.decode(ctx); - // Then assertThat(decoded.length()).isEqualTo(1); LongArray days = (LongArray) decoded.child(0); LongArray seconds = (LongArray) decoded.child(1); @@ -132,22 +121,17 @@ void roundTrip_milliseconds_preservesDaysSecondsSubseconds() { @Test void roundTrip_nanoseconds_preservesSubsecondPrecision() { - // Given - // 1970-01-02 01:02:03.456789123 UTC in nanos long nsPerDay = 86_400_000_000_000L; long ts = nsPerDay + (3723L * 1_000_000_000L) + 456_789_123L; long[] timestamps = {ts}; DateTimePartsData data = new DateTimePartsData(timestamps, false); - DateTimePartsEncoding sut = new DateTimePartsEncoding(); - // When - EncodeResult result = sut.encode(EXT_TIMESTAMP_NS, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(EXT_TIMESTAMP_NS, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), EXT_TIMESTAMP_NS, 1, bufs, registry(), Arena.global()); - GenericArray decoded = (GenericArray) sut.decode(ctx); + toArrayNode(result.rootNode()), EXT_TIMESTAMP_NS, 1, bufs, REGISTRY, Arena.global()); + GenericArray decoded = (GenericArray) DECODER.decode(ctx); - // Then LongArray days = (LongArray) decoded.child(0); LongArray seconds = (LongArray) decoded.child(1); LongArray subseconds = (LongArray) decoded.child(2); @@ -158,19 +142,15 @@ void roundTrip_nanoseconds_preservesSubsecondPrecision() { @Test void roundTrip_epoch_allZero() { - // Given long[] timestamps = {0L}; DateTimePartsData data = new DateTimePartsData(timestamps, false); - DateTimePartsEncoding sut = new DateTimePartsEncoding(); - // When - EncodeResult result = sut.encode(EXT_TIMESTAMP_MS, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(EXT_TIMESTAMP_MS, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = 
result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), EXT_TIMESTAMP_MS, 1, bufs, registry(), Arena.global()); - GenericArray decoded = (GenericArray) sut.decode(ctx); + toArrayNode(result.rootNode()), EXT_TIMESTAMP_MS, 1, bufs, REGISTRY, Arena.global()); + GenericArray decoded = (GenericArray) DECODER.decode(ctx); - // Then LongArray days = (LongArray) decoded.child(0); LongArray seconds = (LongArray) decoded.child(1); LongArray subseconds = (LongArray) decoded.child(2); @@ -181,20 +161,16 @@ void roundTrip_epoch_allZero() { @Test void roundTrip_multipleTimestamps_allRowsPreserved() { - // Given long msPerDay = 86_400_000L; long[] timestamps = {0L, msPerDay, msPerDay + 1000L, msPerDay + 1001L}; DateTimePartsData data = new DateTimePartsData(timestamps, false); - DateTimePartsEncoding sut = new DateTimePartsEncoding(); - // When - EncodeResult result = sut.encode(EXT_TIMESTAMP_MS, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(EXT_TIMESTAMP_MS, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), EXT_TIMESTAMP_MS, 4, bufs, registry(), Arena.global()); - GenericArray decoded = (GenericArray) sut.decode(ctx); + toArrayNode(result.rootNode()), EXT_TIMESTAMP_MS, 4, bufs, REGISTRY, Arena.global()); + GenericArray decoded = (GenericArray) DECODER.decode(ctx); - // Then assertThat(decoded.length()).isEqualTo(4); LongArray days = (LongArray) decoded.child(0); assertThat(days.getLong(0)).isEqualTo(0L); @@ -209,20 +185,16 @@ void roundTrip_multipleTimestamps_allRowsPreserved() { @ParameterizedTest @EnumSource(value = TimeUnit.class, names = {"Nanoseconds", "Microseconds", "Milliseconds", "Seconds"}) void roundTrip_allUnits_epochIsZero(TimeUnit unit) { - // Given DType dtype = timestampDType(unit); long[] timestamps = {0L}; DateTimePartsData data = new 
DateTimePartsData(timestamps, false); - DateTimePartsEncoding sut = new DateTimePartsEncoding(); - // When - EncodeResult result = sut.encode(dtype, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), dtype, 1, bufs, registry(), Arena.global()); - GenericArray decoded = (GenericArray) sut.decode(ctx); + toArrayNode(result.rootNode()), dtype, 1, bufs, REGISTRY, Arena.global()); + GenericArray decoded = (GenericArray) DECODER.decode(ctx); - // Then LongArray days = (LongArray) decoded.child(0); LongArray seconds = (LongArray) decoded.child(1); LongArray subseconds = (LongArray) decoded.child(2); @@ -233,21 +205,16 @@ void roundTrip_allUnits_epochIsZero(TimeUnit unit) { @Test void encode_metadata_ptypes_areI64() throws Exception { - // Given — DateTimeParts always encodes days/seconds/subseconds as I64 (ordinal=7) - // if any tag drifts, the corresponding ptype reads as 0 (U8) which is proto3 default long[] timestamps = {0L, 86_400_000L}; DateTimePartsData data = new DateTimePartsData(timestamps, false); - DateTimePartsEncoding sut = new DateTimePartsEncoding(); - // When - EncodeResult result = sut.encode(EXT_TIMESTAMP_MS, data, EncodeTestHelper.testCtx()); - DateTimePartsMetadata meta = - DateTimePartsMetadata.decode(java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()), 0, java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()).byteSize()); + EncodeResult result = ENCODER.encode(EXT_TIMESTAMP_MS, data, EncodeTestHelper.testCtx()); + var metaSeg = java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); + DateTimePartsMetadata meta = DateTimePartsMetadata.decode(metaSeg, 0, metaSeg.byteSize()); - // Then - assertThat(meta.days_ptype().value()).isEqualTo(7); // I64 - 
assertThat(meta.seconds_ptype().value()).isEqualTo(7); // I64 - assertThat(meta.subseconds_ptype().value()).isEqualTo(7); // I64 + assertThat(meta.days_ptype().value()).isEqualTo(7); + assertThat(meta.seconds_ptype().value()).isEqualTo(7); + assertThat(meta.subseconds_ptype().value()).isEqualTo(7); } } } diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DecimalBytePartsEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DecimalBytePartsEncodingEncoderTest.java new file mode 100644 index 00000000..4df6c1b1 --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DecimalBytePartsEncodingEncoderTest.java @@ -0,0 +1,79 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.core.array.GenericArray; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.proto.DecimalBytePartsMetadata; +import io.github.dfa1.vortex.reader.decode.DecimalBytePartsEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +class DecimalBytePartsEncodingEncoderTest { + + @Test + void roundTrip_longArray_preservesMspValues() { + // Given + long[] values = {1L, -2L, 3L}; + DType dtype = new DType.Decimal((byte) 18, (byte) 0, false); + var encoder = new DecimalBytePartsEncodingEncoder(); + var decoder = new DecimalBytePartsEncodingDecoder(); + Registry registry = TestRegistry.ofDecoders(decoder, new PrimitiveEncodingDecoder()); + + // When + 
EncodeResult encoded = encoder.encode(dtype, values, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, values.length, dtype, registry); + GenericArray result = (GenericArray) decoder.decode(ctx); + + // Then + assertThat(result.length()).isEqualTo(values.length); + Array msp = result.child(0); + assertThat(msp.length()).isEqualTo(values.length); + for (int i = 0; i < values.length; i++) { + assertThat(ArraySegments.of(msp).get(PTypeIO.LE_LONG, (long) i * 8)).isEqualTo(values[i]); + } + } + + @Test + void encodeNode_hasNoBuffers_andOneMspChild() { + // Given + long[] values = {10L, 20L}; + DType dtype = new DType.Decimal((byte) 18, (byte) 0, false); + var sut = new DecimalBytePartsEncodingEncoder(); + + // When + EncodeResult result = sut.encode(dtype, values, EncodeTestHelper.testCtx()); + + // Then + assertThat(result.rootNode().bufferIndices()).isEmpty(); + assertThat(result.rootNode().children()).hasSize(1); + assertThat(result.buffers()).hasSize(1); + } + + @Test + void metadata_zerothChildPtype_isI64_lowerPartCountIsZero() throws Exception { + // Given + long[] values = {42L}; + DType dtype = new DType.Decimal((byte) 18, (byte) 0, false); + var sut = new DecimalBytePartsEncodingEncoder(); + + // When + EncodeResult result = sut.encode(dtype, values, EncodeTestHelper.testCtx()); + + // Then + byte[] metaBytes = new byte[result.rootNode().metadata().remaining()]; + result.rootNode().metadata().duplicate().get(metaBytes); + DecimalBytePartsMetadata meta = + DecimalBytePartsMetadata.decode(java.lang.foreign.MemorySegment.ofArray(metaBytes), 0, metaBytes.length); + assertThat(meta.zeroth_child_ptype().value()).isEqualTo(7); // I64 ordinal + assertThat(meta.lower_part_count()).isEqualTo(0); + } +} diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DecimalEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DecimalEncodingEncoderTest.java new file mode 100644 index 
00000000..d09f370a --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DecimalEncodingEncoderTest.java @@ -0,0 +1,98 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.encoding.TestSegments; +import io.github.dfa1.vortex.proto.DecimalMetadata; +import io.github.dfa1.vortex.reader.decode.DecimalEncodingDecoder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class DecimalEncodingEncoderTest { + + private static final DecimalEncodingEncoder ENCODER = new DecimalEncodingEncoder(); + private static final DecimalEncodingDecoder DECODER = new DecimalEncodingDecoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(DECODER); + + @Test + void roundTrip_i64Precision_preservesBuffer() { + long[] values = {100L, -200L, 300L}; + MemorySegment input = TestSegments.leLongs(values); + DType dtype = new DType.Decimal((byte) 18, (byte) 2, false); + + EncodeResult encoded = ENCODER.encode(dtype, input, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, values.length, dtype, REGISTRY); + Array result = DECODER.decode(ctx); + + 
assertThat(result.length()).isEqualTo(values.length); + for (int i = 0; i < values.length; i++) { + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).isEqualTo(values[i]); + } + } + + @Test + void accepts_decimalDtype_true_primitiveReturnsFalse() { + assertThat(ENCODER.accepts(new DType.Decimal((byte) 18, (byte) 2, false))).isTrue(); + assertThat(ENCODER.accepts(new DType.Primitive(PType.I64, false))).isFalse(); + } + + @ParameterizedTest(name = "precision={0} → valuesType={1}") + @CsvSource({ + "1, 0", + "2, 0", + "3, 1", + "4, 1", + "5, 2", + "9, 2", + "10, 3", + "18, 3", + "19, 4", + "38, 4", + "39, 5", + }) + void valuesType_matchesPrecisionBoundaries(int precision, int expectedValuesType) throws Exception { + int byteWidth = switch (expectedValuesType) { + case 0 -> 1; + case 1 -> 2; + case 2 -> 4; + case 3 -> 8; + case 4 -> 16; + default -> 32; + }; + MemorySegment input = Arena.ofAuto().allocate(byteWidth); + DType dtype = new DType.Decimal((byte) precision, (byte) 0, false); + + EncodeResult encoded = ENCODER.encode(dtype, input, EncodeTestHelper.testCtx()); + + byte[] metaBytes = new byte[encoded.rootNode().metadata().remaining()]; + encoded.rootNode().metadata().duplicate().get(metaBytes); + DecimalMetadata meta = DecimalMetadata.decode(java.lang.foreign.MemorySegment.ofArray(metaBytes), 0, metaBytes.length); + assertThat(meta.values_type()).isEqualTo(expectedValuesType); + } + + @Test + void invalidBufferSize_throws() { + MemorySegment input = Arena.ofAuto().allocate(7); + DType dtype = new DType.Decimal((byte) 18, (byte) 0, false); + + assertThatThrownBy(() -> ENCODER.encode(dtype, input, EncodeTestHelper.testCtx())) + .isInstanceOf(VortexException.class) + .hasMessageContaining("not multiple of byteWidth"); + } +} diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DeltaEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DeltaEncodingEncoderTest.java new file mode 100644 index 
00000000..da473633 --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DeltaEncodingEncoderTest.java @@ -0,0 +1,109 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.proto.DeltaMetadata; +import io.github.dfa1.vortex.reader.decode.DeltaEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.lang.foreign.MemorySegment; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; + +class DeltaEncodingEncoderTest { + + private static final DeltaEncodingEncoder ENCODER = new DeltaEncodingEncoder(); + private static final DeltaEncodingDecoder DECODER = new DeltaEncodingDecoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); + + static Stream<long[]> i64Arrays() { + return Stream.of( + new long[]{0}, + new long[]{Long.MIN_VALUE}, + new long[]{0, 1, 2, 3, 4, 5, 6, 7}, + new long[]{100, 200, 300, 400, 500}, + new long[]{-100, -50, 0, 50, 100}, + new long[]{1000, 999, 998, 997, 996} + ); + } + + static Stream<int[]> i32Arrays() { + return Stream.of( + new int[]{0}, + new int[]{Integer.MIN_VALUE}, + new int[]{0, 1, 2, 3, 4, 5, 6, 7}, + new int[]{10, 20, 30, 40, 50}, + new int[]{-5, -4, -3, -2, -1, 0} + ); + } + + static Stream<Arguments>
monotoneI64Arrays() { + return Stream.of( + Arguments.of("ascending-1", new long[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}), + Arguments.of("ascending-100", new long[]{0, 100, 200, 300, 400, 500, 600, 700, 800, 900}), + Arguments.of("descending", new long[]{1000, 999, 998, 997, 996, 995, 994, 993, 992, 991}) + ); + } + + @ParameterizedTest + @MethodSource("i64Arrays") + void encodeDecode_i64_isLossless(long[] data) { + EncodeResult encoded = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I64, REGISTRY); + Array result = DECODER.decode(ctx); + + assertThat(result.length()).isEqualTo(data.length); + for (int i = 0; i < data.length; i++) { + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).as("index %d", i).isEqualTo(data[i]); + } + } + + @ParameterizedTest + @MethodSource("i32Arrays") + void encodeDecode_i32_isLossless(int[] data) { + EncodeResult encoded = ENCODER.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I32, REGISTRY); + Array result = DECODER.decode(ctx); + + assertThat(result.length()).isEqualTo(data.length); + for (int i = 0; i < data.length; i++) { + assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(data[i]); + } + } + + @ParameterizedTest(name = "{0}") + @MethodSource("monotoneI64Arrays") + void encodeDecode_monotoneI64_isLossless(String name, long[] data) { + EncodeResult encoded = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I64, REGISTRY); + Array result = DECODER.decode(ctx); + + assertThat(result.length()).isEqualTo(data.length); + for (int i = 0; i < data.length; i++) { + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).as("index %d", 
i).isEqualTo(data[i]); + } + } + + @Test + void encode_i64_metadata_deltasLen_isNonZero() throws Exception { + long[] data = {10L, 20L, 30L, 40L, 50L}; + + EncodeResult result = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + MemorySegment metaSeg = MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); + DeltaMetadata meta = DeltaMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + + assertThat(meta.deltas_len()).isGreaterThan(0); + } +} diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DictEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DictEncodingEncoderTest.java new file mode 100644 index 00000000..b8d0fc4f --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/DictEncodingEncoderTest.java @@ -0,0 +1,161 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.proto.DictMetadata; +import io.github.dfa1.vortex.reader.decode.DictEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.VarBinEncodingDecoder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.lang.foreign.MemorySegment; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; + +class 
DictEncodingEncoderTest { + + private static final DictEncodingEncoder ENCODER = new DictEncodingEncoder(); + private static final DictEncodingDecoder DECODER = new DictEncodingDecoder(); + private static final Registry PRIM_REG = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); + private static final Registry UTF8_REG = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder(), new VarBinEncodingDecoder()); + + static Stream i32Arrays() { + return Stream.of( + new int[]{0}, + new int[]{1, 2, 3}, + new int[]{0, 1, 2, 0, 1, 2, 0, 1, 2}, + new int[]{42, 42, 42, 42, 42}, + new int[]{Integer.MIN_VALUE, Integer.MAX_VALUE, 0, Integer.MIN_VALUE, Integer.MAX_VALUE} + ); + } + + static Stream i64Arrays() { + return Stream.of( + new long[]{0L}, + new long[]{Long.MIN_VALUE, Long.MAX_VALUE, 0L, Long.MIN_VALUE}, + new long[]{1L, 2L, 1L, 2L, 1L, 2L, 1L, 2L} + ); + } + + static Stream repetitiveI32Arrays() { + return Stream.of( + Arguments.of("binary-100", repeat(new int[]{0, 1}, 50)), + Arguments.of("single-value-50", repeat(new int[]{42}, 50)), + Arguments.of("three-values-60", repeat(new int[]{10, 20, 30}, 20)) + ); + } + + static Stream utf8Arrays() { + return Stream.of( + Arguments.of("single", new String[]{"hello"}), + Arguments.of("all-unique", new String[]{"a", "b", "c"}), + Arguments.of("repeated", new String[]{"AAPL", "GOOG", "AAPL", "MSFT", "GOOG", "AAPL"}), + Arguments.of("unicode", new String[]{"café", "naïve", "café", "résumé", "naïve"}), + Arguments.of("empty-string", new String[]{"", "x", "", "y", ""}) + ); + } + + private static int[] repeat(int[] pattern, int times) { + int[] result = new int[pattern.length * times]; + for (int i = 0; i < times; i++) { + System.arraycopy(pattern, 0, result, i * pattern.length, pattern.length); + } + return result; + } + + @SuppressWarnings("SameParameterValue") + private static String[] repeat(String[] pattern, int times) { + String[] result = new String[pattern.length * times]; + for (int i = 0; i < 
times; i++) { + System.arraycopy(pattern, 0, result, i * pattern.length, pattern.length); + } + return result; + } + + @ParameterizedTest + @MethodSource("i32Arrays") + void encodeDecode_i32_isLossless(int[] data) { + EncodeResult encoded = ENCODER.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I32, PRIM_REG); + Array result = DECODER.decode(ctx); + + assertThat(result.length()).isEqualTo(data.length); + for (int i = 0; i < data.length; i++) { + assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(data[i]); + } + } + + @ParameterizedTest + @MethodSource("i64Arrays") + void encodeDecode_i64_isLossless(long[] data) { + EncodeResult encoded = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I64, PRIM_REG); + Array result = DECODER.decode(ctx); + + assertThat(result.length()).isEqualTo(data.length); + for (int i = 0; i < data.length; i++) { + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).as("index %d", i).isEqualTo(data[i]); + } + } + + @ParameterizedTest(name = "{0}") + @MethodSource("repetitiveI32Arrays") + void encodedSize_lowCardinality_compressesWellVsRaw(String name, int[] data) { + EncodeResult encoded = ENCODER.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + + long encodedBytes = encoded.buffers().stream().mapToLong(MemorySegment::byteSize).sum(); + long rawBytes = (long) data.length * 4; + assertThat(encodedBytes).isLessThan(rawBytes); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("utf8Arrays") + void encodeDecode_utf8_isLossless(String name, String[] data) { + EncodeResult encoded = ENCODER.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.UTF8, UTF8_REG); + Array result = DECODER.decode(ctx); + + 
assertThat(result).isInstanceOf(VarBinArray.class); + VarBinArray arr = (VarBinArray) result; + assertThat(arr.length()).isEqualTo(data.length); + for (int i = 0; i < data.length; i++) { + String actual = arr.getString(i); + assertThat(actual).as("index %d", i).isEqualTo(data[i]); + } + } + + @Test + void encodedSize_lowCardinalityUtf8_compressesWellVsRaw() { + String[] symbols = {"AAPL", "GOOG", "MSFT"}; + String[] data = repeat(symbols, 1000); + + EncodeResult encoded = ENCODER.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + + long encodedBytes = encoded.buffers().stream().mapToLong(MemorySegment::byteSize).sum(); + long rawBytes = 3000L * 5; + assertThat(encodedBytes).isLessThan(rawBytes); + } + + @Test + void encode_utf8_metadata_valuesLen_matchesUniqueCount() throws Exception { + String[] data = {"apple", "banana", "apple", "banana", "apple"}; + + EncodeResult result = ENCODER.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + var metaSeg = java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); + DictMetadata meta = DictMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + + assertThat(meta.values_len()).isEqualTo(2); + } +} diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/ExtEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ExtEncodingEncoderTest.java similarity index 71% rename from core/src/test/java/io/github/dfa1/vortex/encoding/ExtEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/ExtEncodingEncoderTest.java index 5db5bed7..588d478f 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/ExtEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ExtEncodingEncoderTest.java @@ -1,8 +1,21 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.PType; import io.github.dfa1.vortex.core.array.LongArray; 
+import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.CascadeStep; +import io.github.dfa1.vortex.encoding.ChildSlot; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.encoding.TestSegments; +import io.github.dfa1.vortex.reader.decode.ExtEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -11,59 +24,49 @@ import static org.assertj.core.api.Assertions.assertThat; -class ExtEncodingTest { +class ExtEncodingEncoderTest { + + private static final ExtEncodingEncoder ENCODER = new ExtEncodingEncoder(); + private static final ExtEncodingDecoder DECODER = new ExtEncodingDecoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(new PrimitiveEncodingDecoder(), DECODER); @Nested class Encode { @Test void accepts_extensionDtype_returnsTrue() { - // Given DType extDType = new DType.Extension("vortex.timestamp", new DType.Primitive(PType.I64, false), null, false); - var sut = new ExtEncoding(); - - // When / Then - assertThat(sut.accepts(extDType)).isTrue(); + assertThat(ENCODER.accepts(extDType)).isTrue(); } @Test void accepts_primitiveDtype_returnsFalse() { - // Given - var sut = new ExtEncoding(); - - // When / Then - assertThat(sut.accepts(new DType.Primitive(PType.I64, false))).isFalse(); + assertThat(ENCODER.accepts(new DType.Primitive(PType.I64, false))).isFalse(); } @Test void encode_extensionWrappingI64_roundTrips() { - // Given long[] data = {100L, 200L, 300L, 400L}; DType storageDType = new DType.Primitive(PType.I64, false); DType extDType = new 
DType.Extension("vortex.timestamp", storageDType, null, false); - var sut = new ExtEncoding(); - // When - EncodeResult result = sut.encode(extDType, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(extDType, data, EncodeTestHelper.testCtx()); - // Then — root is ext, child is primitive assertThat(result.rootNode().encodingId()).isEqualTo(EncodingId.VORTEX_EXT); assertThat(result.rootNode().children()).hasSize(1); assertThat(result.rootNode().children()[0].encodingId()).isEqualTo(EncodingId.VORTEX_PRIMITIVE); - // Decode back - Registry registry = TestRegistry.of(new PrimitiveEncoding(), new ExtEncoding()); ArrayNode rootNode = encodeNodeToArrayNode(result.rootNode()); DecodeContext ctx = new DecodeContext( rootNode, extDType, data.length, result.buffers().toArray(MemorySegment[]::new), - registry, Arena.ofAuto()); - var decoded = sut.decode(ctx); + REGISTRY, Arena.ofAuto()); + var decoded = DECODER.decode(ctx); assertThat(decoded).isInstanceOf(LongArray.class); + LongArray longArray = (LongArray) decoded; for (int i = 0; i < data.length; i++) { - LongArray longArray = (LongArray) decoded; assertThat(longArray.getLong(i)).isEqualTo(data[i]); } } @@ -82,16 +85,12 @@ class Cascade { @Test void encodeCascade_exposesStorageAsOpenChild() { - // Given — extension wraps an I64 storage column long[] data = {100L, 200L, 300L, 400L}; DType storageDType = new DType.Primitive(PType.I64, false); DType extDType = new DType.Extension("vortex.timestamp", storageDType, null, false); - var sut = new ExtEncoding(); - // When - CascadeStep step = sut.encodeCascade(extDType, data, EncodeTestHelper.testCtx()); + CascadeStep step = ENCODER.encodeCascade(extDType, data, EncodeTestHelper.testCtx()); - // Then — non-terminal step with one open slot for the storage child assertThat(step.applicable()).isTrue(); assertThat(step.isTerminal()).isFalse(); assertThat(step.openChildren()).hasSize(1); @@ -105,12 +104,8 @@ void encodeCascade_exposesStorageAsOpenChild() { 
@Test void encodeCascade_rejectsNonExtensionDtype() { - // Given - var sut = new ExtEncoding(); - - // When / Then org.assertj.core.api.Assertions.assertThatThrownBy(() -> - sut.encodeCascade(new DType.Primitive(PType.I64, false), new long[]{1L}, + ENCODER.encodeCascade(new DType.Primitive(PType.I64, false), new long[]{1L}, EncodeTestHelper.testCtx())) .isInstanceOf(io.github.dfa1.vortex.core.VortexException.class) .hasMessageContaining("expected extension dtype"); @@ -122,33 +117,24 @@ class Decode { @Test void decode_extensionWrappingI64_returnsStorageArray() { - // Given long[] values = {10L, 20L, 30L, 40L}; MemorySegment buf = TestSegments.leLongs(values); DType storageDType = new DType.Primitive(PType.I64, false); DType extDType = new DType.Extension("vortex.timestamp", storageDType, null, false); - // child node: vortex.primitive with buffer index 0 ArrayNode primitiveNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{0}, null); - // parent node: vortex.ext, no buffers, one child ArrayNode extNode = ArrayNode.of(EncodingId.VORTEX_EXT, null, new ArrayNode[]{primitiveNode}, new int[0], null); - Registry registry = TestRegistry.of(new PrimitiveEncoding(), new ExtEncoding()); - DecodeContext ctx = new DecodeContext( - extNode, extDType, values.length, new MemorySegment[]{buf}, registry, Arena.ofAuto()); - - var sut = new ExtEncoding(); + extNode, extDType, values.length, new MemorySegment[]{buf}, REGISTRY, Arena.ofAuto()); - // When - var result = sut.decode(ctx); + var result = DECODER.decode(ctx); - // Then assertThat(result).isInstanceOf(LongArray.class); assertThat(result.length()).isEqualTo(values.length); + LongArray longArray = (LongArray) result; for (int i = 0; i < values.length; i++) { - LongArray longArray = (LongArray) result; assertThat(longArray.getLong(i)).isEqualTo(values[i]); } } diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/FixedSizeListEncodingTest.java 
b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/FixedSizeListEncodingEncoderTest.java similarity index 67% rename from core/src/test/java/io/github/dfa1/vortex/encoding/FixedSizeListEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/FixedSizeListEncodingEncoderTest.java index 4745c9a2..9d05c56f 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/FixedSizeListEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/FixedSizeListEncodingEncoderTest.java @@ -1,9 +1,21 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.array.FixedSizeListArray; import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.FixedSizeListData; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.reader.decode.FixedSizeListEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -13,8 +25,11 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -class FixedSizeListEncodingTest { +class FixedSizeListEncodingEncoderTest { + private static final FixedSizeListEncodingEncoder ENCODER = new FixedSizeListEncodingEncoder(); + private static final FixedSizeListEncodingDecoder DECODER = new FixedSizeListEncodingDecoder(); + 
private static final Registry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); private static ArrayNode toArrayNode(EncodeNode node) { ArrayNode[] children = new ArrayNode[node.children().length]; @@ -24,47 +39,28 @@ private static ArrayNode toArrayNode(EncodeNode node) { return ArrayNode.of(node.encodingId(), node.metadata(), children, node.bufferIndices(), ArrayStats.empty()); } - private static Registry registry() { - return Registry.builder() - .register(new FixedSizeListEncoding()) - .register(new PrimitiveEncoding()) - .build(); - } - @Nested class Encode { @Test void accepts_fixedSizeListDtype_true() { - // Given - FixedSizeListEncoding sut = new FixedSizeListEncoding(); DType.FixedSizeList dtype = new DType.FixedSizeList(DTypes.I32, 3, false); - - // When / Then - assertThat(sut.accepts(dtype)).isTrue(); + assertThat(ENCODER.accepts(dtype)).isTrue(); } @Test void accepts_primitiveDtype_false() { - // Given - FixedSizeListEncoding sut = new FixedSizeListEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.I32)).isFalse(); + assertThat(ENCODER.accepts(DTypes.I32)).isFalse(); } @Test void encode_producesOneChild_noBuffers() { - // Given DType.FixedSizeList dtype = new DType.FixedSizeList(DTypes.I32, 2, false); int[] elements = {1, 2, 3, 4}; FixedSizeListData data = new FixedSizeListData(elements, 2); - FixedSizeListEncoding sut = new FixedSizeListEncoding(); - // When - EncodeResult result = sut.encode(dtype, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); - // Then assertThat(result.rootNode().encodingId()).isEqualTo(EncodingId.VORTEX_FIXED_SIZE_LIST); assertThat(result.rootNode().bufferIndices()).isEmpty(); assertThat(result.rootNode().children()).hasSize(1); @@ -76,20 +72,16 @@ class Decode { @Test void roundTrip_i32Elements_preservesValues() { - // Given DType.FixedSizeList dtype = new DType.FixedSizeList(DTypes.I32, 3, false); int[] elements = {10, 
20, 30, 40, 50, 60}; FixedSizeListData data = new FixedSizeListData(elements, 2); - FixedSizeListEncoding sut = new FixedSizeListEncoding(); - // When - EncodeResult result = sut.encode(dtype, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), dtype, 2, bufs, registry(), Arena.global()); - FixedSizeListArray decoded = (FixedSizeListArray) sut.decode(ctx); + toArrayNode(result.rootNode()), dtype, 2, bufs, REGISTRY, Arena.global()); + FixedSizeListArray decoded = (FixedSizeListArray) DECODER.decode(ctx); - // Then assertThat(decoded.length()).isEqualTo(2); assertThat(decoded.fixedSize()).isEqualTo(3); IntArray elems = (IntArray) decoded.elements(); @@ -101,20 +93,16 @@ void roundTrip_i32Elements_preservesValues() { @Test void roundTrip_fixedSizeOne_preservesValues() { - // Given DType.FixedSizeList dtype = new DType.FixedSizeList(DTypes.I32, 1, false); int[] elements = {7, 8, 9}; FixedSizeListData data = new FixedSizeListData(elements, 3); - FixedSizeListEncoding sut = new FixedSizeListEncoding(); - // When - EncodeResult result = sut.encode(dtype, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), dtype, 3, bufs, registry(), Arena.global()); - FixedSizeListArray decoded = (FixedSizeListArray) sut.decode(ctx); + toArrayNode(result.rootNode()), dtype, 3, bufs, REGISTRY, Arena.global()); + FixedSizeListArray decoded = (FixedSizeListArray) DECODER.decode(ctx); - // Then assertThat(decoded.length()).isEqualTo(3); assertThat(decoded.fixedSize()).isEqualTo(1); IntArray elems = (IntArray) decoded.elements(); @@ -125,14 +113,11 @@ void 
roundTrip_fixedSizeOne_preservesValues() { @Test void decode_wrongDtype_throws() { - // Given - FixedSizeListEncoding sut = new FixedSizeListEncoding(); ArrayNode node = ArrayNode.of(EncodingId.VORTEX_FIXED_SIZE_LIST, null, new ArrayNode[0], new int[0], ArrayStats.empty()); - DecodeContext ctx = new DecodeContext(node, DTypes.I32, 0, new MemorySegment[0], registry(), Arena.global()); + DecodeContext ctx = new DecodeContext(node, DTypes.I32, 0, new MemorySegment[0], REGISTRY, Arena.global()); - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> DECODER.decode(ctx)) .hasMessageContaining("DType.FixedSizeList"); } } diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/FrameOfReferenceEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/FrameOfReferenceEncodingEncoderTest.java similarity index 61% rename from core/src/test/java/io/github/dfa1/vortex/encoding/FrameOfReferenceEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/FrameOfReferenceEncodingEncoderTest.java index 6ea12e5e..47269c4c 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/FrameOfReferenceEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/FrameOfReferenceEncodingEncoderTest.java @@ -1,4 +1,4 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.DType; @@ -6,7 +6,19 @@ import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.ArraySegments; import io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import 
io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; import io.github.dfa1.vortex.proto.ScalarValue; +import io.github.dfa1.vortex.reader.decode.BoolEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.FrameOfReferenceEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -20,8 +32,11 @@ import static org.assertj.core.api.Assertions.assertThat; -class FrameOfReferenceEncodingTest { +class FrameOfReferenceEncodingEncoderTest { + private static final FrameOfReferenceEncodingEncoder ENCODER = new FrameOfReferenceEncodingEncoder(); + private static final FrameOfReferenceEncodingDecoder DECODER = new FrameOfReferenceEncodingDecoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); @Nested class Decode { @@ -43,115 +58,74 @@ private static DecodeContext buildForContext( } ArrayNode childNode = ArrayNode.of( - EncodingId.VORTEX_PRIMITIVE, - null, - new ArrayNode[0], - new int[]{0}, - ArrayStats.empty() - ); - + EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{0}, ArrayStats.empty()); ArrayNode forNode = ArrayNode.of( - EncodingId.FASTLANES_FOR, - ByteBuffer.wrap(metaBytes), - new ArrayNode[]{childNode}, - new int[0], - ArrayStats.empty() - ); + EncodingId.FASTLANES_FOR, ByteBuffer.wrap(metaBytes), + new ArrayNode[]{childNode}, new int[0], ArrayStats.empty()); MemorySegment[] segments = {MemorySegment.ofArray(childBytes)}; - - Registry registry = TestRegistry.of(new FrameOfReferenceEncoding(), new PrimitiveEncoding()); - - return new DecodeContext(forNode, dtype, residuals.length, segments, registry, java.lang.foreign.Arena.global()); + return new DecodeContext(forNode, dtype, residuals.length, segments, REGISTRY, java.lang.foreign.Arena.global()); 
} @Test void decode_i64_addsReferenceToResiduals() { - // Given long reference = 1000L; long[] residuals = {0, 1, 2, 3, 4}; long[] expected = {1000, 1001, 1002, 1003, 1004}; DecodeContext ctx = buildForContext(DTypes.I64, reference, residuals, PType.I64); - FrameOfReferenceEncoding sut = new FrameOfReferenceEncoding(); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(residuals.length); - var layout = PTypeIO.LE_LONG; for (int i = 0; i < expected.length; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 8)) - .as("index %d", i) - .isEqualTo(expected[i]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)) + .as("index %d", i).isEqualTo(expected[i]); } } @Test void decode_i32_addsReferenceToResiduals() { - // Given long reference = -100L; long[] residuals = {0, 5, 10, 15}; int[] expected = {-100, -95, -90, -85}; DecodeContext ctx = buildForContext(DTypes.I32, reference, residuals, PType.I32); - FrameOfReferenceEncoding sut = new FrameOfReferenceEncoding(); - - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(residuals.length); - var layout = PTypeIO.LE_INT; for (int i = 0; i < expected.length; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 4)) - .as("index %d", i) - .isEqualTo(expected[i]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)) + .as("index %d", i).isEqualTo(expected[i]); } } @Test void decode_zeroReference_returnsChildUnchanged() { - // Given — reference == 0, should skip the add entirely long[] residuals = {7, 8, 9}; DecodeContext ctx = buildForContext(DTypes.I64, 0L, residuals, PType.I64); - FrameOfReferenceEncoding sut = new FrameOfReferenceEncoding(); - - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then — values unchanged - var layout = PTypeIO.LE_LONG; for (int i = 0; 
i < residuals.length; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 8)).isEqualTo(residuals[i]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).isEqualTo(residuals[i]); } } @ParameterizedTest @ValueSource(longs = {Long.MIN_VALUE, Long.MAX_VALUE, -1L, 1L}) void decode_wrappingAdd_i64(long reference) { - // Given — wrapping arithmetic: MAX + 1 wraps to MIN long[] residuals = {1L}; DecodeContext ctx = buildForContext(DTypes.I64, reference, residuals, PType.I64); - FrameOfReferenceEncoding sut = new FrameOfReferenceEncoding(); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then - var layout = PTypeIO.LE_LONG; - long got = ArraySegments.of(result).get(layout, 0L); + long got = ArraySegments.of(result).get(PTypeIO.LE_LONG, 0L); assertThat(got).isEqualTo(residuals[0] + reference); } @Test void decode_nullableResiduals_returnsMaskedArrayWithCorrectValues() { - // Given — 4 I32 residuals; positions 1 and 3 are null (validity: 0b00000101 = 0x05) - // Residuals: [0, 0, 5, 0], reference: 100 → valid outputs: [100, ?, 105, ?] 
long reference = 100L; long[] residuals = {0, 0, 5, 0}; - MemorySegment validitySeg = MemorySegment.ofArray(new byte[]{0x05}); // bits 0,2 + MemorySegment validitySeg = MemorySegment.ofArray(new byte[]{0x05}); byte[] residualBytes = new byte[residuals.length * 4]; ByteBuffer bb = ByteBuffer.wrap(residualBytes).order(ByteOrder.LITTLE_ENDIAN); @@ -167,26 +141,23 @@ void decode_nullableResiduals_returnsMaskedArrayWithCorrectValues() { ArrayNode forNode = ArrayNode.of( EncodingId.FASTLANES_FOR, ByteBuffer.wrap(metaBytes), new ArrayNode[]{primNode}, new int[0], ArrayStats.empty()); - Registry registry = TestRegistry.of(new FrameOfReferenceEncoding(), new PrimitiveEncoding(), new BoolEncoding()); + Registry registry = TestRegistry.ofDecoders( + new FrameOfReferenceEncodingDecoder(), new PrimitiveEncodingDecoder(), new BoolEncodingDecoder()); MemorySegment[] segments = {MemorySegment.ofArray(residualBytes), validitySeg}; DecodeContext ctx = new DecodeContext( forNode, DTypes.I32, residuals.length, segments, registry, java.lang.foreign.Arena.global()); - FrameOfReferenceEncoding sut = new FrameOfReferenceEncoding(); - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then — MaskedArray; reference added to valid positions only assertThat(result).isInstanceOf(MaskedArray.class); MaskedArray masked = (MaskedArray) result; assertThat(masked.isValid(0)).isTrue(); assertThat(masked.isValid(1)).isFalse(); assertThat(masked.isValid(2)).isTrue(); assertThat(masked.isValid(3)).isFalse(); - var layout = PTypeIO.LE_INT; - assertThat(ArraySegments.of(masked.inner()).get(layout, 0L)).isEqualTo(100); - assertThat(ArraySegments.of(masked.inner()).get(layout, 8L)).isEqualTo(105); + assertThat(ArraySegments.of(masked.inner()).get(PTypeIO.LE_INT, 0L)).isEqualTo(100); + assertThat(ArraySegments.of(masked.inner()).get(PTypeIO.LE_INT, 8L)).isEqualTo(105); } } @@ -215,40 +186,26 @@ static Stream i32Arrays() { @ParameterizedTest @MethodSource("i64Arrays") void 
encodeDecode_i64_isLossless(long[] data) { - // Given - var sut = new FrameOfReferenceEncoding(); - Registry registry = TestRegistry.withPrimitive(sut); - var le = PTypeIO.LE_LONG; - - // When - EncodeResult encoded = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I64, registry); - Array result = sut.decode(ctx); + EncodeResult encoded = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I64, REGISTRY); + Array result = DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(data.length); for (int i = 0; i < data.length; i++) { - assertThat(ArraySegments.of(result).get(le, (long) i * 8)).as("index %d", i).isEqualTo(data[i]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).as("index %d", i).isEqualTo(data[i]); } } @ParameterizedTest @MethodSource("i32Arrays") void encodeDecode_i32_isLossless(int[] data) { - // Given - var sut = new FrameOfReferenceEncoding(); - Registry registry = TestRegistry.withPrimitive(sut); - var le = PTypeIO.LE_INT; - - // When - EncodeResult encoded = sut.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I32, registry); - Array result = sut.decode(ctx); + EncodeResult encoded = ENCODER.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I32, REGISTRY); + Array result = DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(data.length); for (int i = 0; i < data.length; i++) { - assertThat(ArraySegments.of(result).get(le, (long) i * 4)).as("index %d", i).isEqualTo(data[i]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(data[i]); } } } diff --git 
a/core/src/test/java/io/github/dfa1/vortex/encoding/FsstEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/FsstEncodingEncoderTest.java similarity index 60% rename from core/src/test/java/io/github/dfa1/vortex/encoding/FsstEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/FsstEncodingEncoderTest.java index cdd8eb20..b5c356a6 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/FsstEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/FsstEncodingEncoderTest.java @@ -1,9 +1,21 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; -import io.github.dfa1.vortex.proto.FSSTMetadata; import io.github.dfa1.vortex.core.PType; import io.github.dfa1.vortex.core.VortexException; import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.proto.FSSTMetadata; +import io.github.dfa1.vortex.reader.decode.FsstEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -20,7 +32,11 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -class FsstEncodingTest { +class FsstEncodingEncoderTest { + + private static final FsstEncodingEncoder ENCODER = new FsstEncodingEncoder(); + private static final 
FsstEncodingDecoder DECODER = new FsstEncodingDecoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); @Nested class Encode { @@ -56,50 +72,30 @@ private static String[] repeat(String s, int n) { @Test void accepts_utf8_true() { - // Given - var sut = new FsstEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.UTF8)).isTrue(); + assertThat(ENCODER.accepts(DTypes.UTF8)).isTrue(); } @Test void accepts_binary_true() { - // Given - var sut = new FsstEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.BINARY)).isTrue(); + assertThat(ENCODER.accepts(DTypes.BINARY)).isTrue(); } @Test void accepts_primitive_false() { - // Given - var sut = new FsstEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.I32)).isFalse(); + assertThat(ENCODER.accepts(DTypes.I32)).isFalse(); } @ParameterizedTest(name = "{0}") - @MethodSource("io.github.dfa1.vortex.encoding.FsstEncodingTest$Encode#stringArrays") + @MethodSource("stringArrays") void encode_thenDecode_roundtripsAllStrings(String name, String[] values) { - // Given - var sut = new FsstEncoding(); Arena arena = Arena.ofAuto(); - // When - EncodeResult result = sut.encode(DTypes.UTF8, values, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.UTF8, values, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); ArrayNode node = toArrayNode(result.rootNode()); - Registry registry = Registry.builder() - .register(new PrimitiveEncoding()) - .register(sut) - .build(); - DecodeContext ctx = new DecodeContext(node, DTypes.UTF8, values.length, bufs, registry, arena); - var decoded = (VarBinArray) sut.decode(ctx); - - // Then + DecodeContext ctx = new DecodeContext(node, DTypes.UTF8, values.length, bufs, REGISTRY, arena); + var decoded = (VarBinArray) DECODER.decode(ctx); + assertThat(decoded.length()).isEqualTo(values.length); for (int i = 0; i < values.length; i++) { 
assertThat(decoded.getString(i)).as("index %d", i).isEqualTo(values[i]); @@ -116,31 +112,26 @@ private static DecodeContext buildCtx( ) { Arena arena = Arena.ofAuto(); - // Buffer 0: symbol table (8 bytes per symbol, LE u64) MemorySegment symBuf = arena.allocate(Math.max(symbols.length * 8L, 1), 8); for (int i = 0; i < symbols.length; i++) { symBuf.setAtIndex(PTypeIO.LE_LONG, i, symbols[i]); } - // Buffer 1: symbol lengths (1 byte per symbol) MemorySegment symLenBuf = arena.allocate(Math.max(symLens.length, 1)); for (int i = 0; i < symLens.length; i++) { symLenBuf.set(ValueLayout.JAVA_BYTE, i, symLens[i]); } - // Buffer 2: compressed bytes MemorySegment compBuf = arena.allocate(Math.max(compressed.length, 1)); for (int i = 0; i < compressed.length; i++) { compBuf.set(ValueLayout.JAVA_BYTE, i, compressed[i]); } - // Buffer 3: uncompressed_lengths (I32) MemorySegment uncompLenBuf = arena.allocate((long) uncompLens.length * Integer.BYTES, Integer.BYTES); for (int i = 0; i < uncompLens.length; i++) { uncompLenBuf.setAtIndex(PTypeIO.LE_INT, i, uncompLens[i]); } - // Buffer 4: codes_offsets (I32, n+1 elements) MemorySegment codesOffBuf = arena.allocate((long) codesOffsets.length * Integer.BYTES, Integer.BYTES); for (int i = 0; i < codesOffsets.length; i++) { codesOffBuf.setAtIndex(PTypeIO.LE_INT, i, codesOffsets[i]); @@ -158,103 +149,47 @@ private static DecodeContext buildCtx( EncodingId.VORTEX_FSST, ByteBuffer.wrap(metaBytes), new ArrayNode[]{uncompLensNode, codesOffNode}, new int[]{0, 1, 2}, null); - return new DecodeContext(root, DTypes.UTF8, n, segs, buildRegistry(), arena); - } - - private static Registry buildRegistry() { - return Registry.builder().register(new PrimitiveEncoding()).build(); + return new DecodeContext(root, DTypes.UTF8, n, segs, REGISTRY, arena); } @Test void decode_singleByteSymbol_decompressesCorrectly() { - // Given: symbol 0 = 'A' (LE u64 = 0x41, length 1); string "AA" → codes [0, 0] - var sut = new FsstEncoding(); DecodeContext ctx = 
buildCtx( - new long[]{0x41L}, - new byte[]{1}, - new byte[]{0x00, 0x00}, - new int[]{2}, - new int[]{0, 2}, - 1 - ); - - // When - var result = sut.decode(ctx); - - // Then - assertThat(result).isInstanceOf(VarBinArray.class); + new long[]{0x41L}, new byte[]{1}, new byte[]{0x00, 0x00}, + new int[]{2}, new int[]{0, 2}, 1); + var result = DECODER.decode(ctx); VarBinArray vba = (VarBinArray) result; - assertThat(vba.length()).isEqualTo(1); assertThat(vba.getBytes(0)).isEqualTo("AA".getBytes(StandardCharsets.UTF_8)); } @Test void decode_escapeByte_decompressesCorrectly() { - // Given: no symbols; string "XY" → ESCAPE 'X' ESCAPE 'Y' - var sut = new FsstEncoding(); DecodeContext ctx = buildCtx( - new long[0], - new byte[0], - new byte[]{(byte) 0xFF, 0x58, (byte) 0xFF, 0x59}, - new int[]{2}, - new int[]{0, 4}, - 1 - ); - - // When - var result = sut.decode(ctx); - - // Then - assertThat(result).isInstanceOf(VarBinArray.class); + new long[0], new byte[0], new byte[]{(byte) 0xFF, 0x58, (byte) 0xFF, 0x59}, + new int[]{2}, new int[]{0, 4}, 1); + var result = DECODER.decode(ctx); VarBinArray vba = (VarBinArray) result; - assertThat(vba.length()).isEqualTo(1); assertThat(vba.getBytes(0)).isEqualTo("XY".getBytes(StandardCharsets.UTF_8)); } @Test void decode_multiByteSymbol_decompressesCorrectly() { - // Given: symbol 0 = "ab" (LE u64 = 0x6261, length 2); string "ab" → code [0] - var sut = new FsstEncoding(); DecodeContext ctx = buildCtx( - new long[]{0x6261L}, - new byte[]{2}, - new byte[]{0x00}, - new int[]{2}, - new int[]{0, 1}, - 1 - ); - - // When - var result = sut.decode(ctx); - - // Then - assertThat(result).isInstanceOf(VarBinArray.class); + new long[]{0x6261L}, new byte[]{2}, new byte[]{0x00}, + new int[]{2}, new int[]{0, 1}, 1); + var result = DECODER.decode(ctx); VarBinArray vba = (VarBinArray) result; - assertThat(vba.length()).isEqualTo(1); assertThat(vba.getBytes(0)).isEqualTo("ab".getBytes(StandardCharsets.UTF_8)); } @Test void 
decode_multipleStrings_decompressesAll() { - // Given: symbol 0 = 'H'; strings ["H", "HH", "!"] where "!" uses ESCAPE - // compressed: [0x00] | [0x00, 0x00] | [0xFF, 0x21] - var sut = new FsstEncoding(); DecodeContext ctx = buildCtx( - new long[]{0x48L}, - new byte[]{1}, + new long[]{0x48L}, new byte[]{1}, new byte[]{0x00, 0x00, 0x00, (byte) 0xFF, 0x21}, - new int[]{1, 2, 1}, - new int[]{0, 1, 3, 5}, - 3 - ); - - // When - var result = sut.decode(ctx); - - // Then - assertThat(result).isInstanceOf(VarBinArray.class); + new int[]{1, 2, 1}, new int[]{0, 1, 3, 5}, 3); + var result = DECODER.decode(ctx); VarBinArray vba = (VarBinArray) result; - assertThat(vba.length()).isEqualTo(3); assertThat(vba.getBytes(0)).isEqualTo("H".getBytes(StandardCharsets.UTF_8)); assertThat(vba.getBytes(1)).isEqualTo("HH".getBytes(StandardCharsets.UTF_8)); assertThat(vba.getBytes(2)).isEqualTo("!".getBytes(StandardCharsets.UTF_8)); @@ -262,15 +197,9 @@ void decode_multipleStrings_decompressesAll() { @Test void decode_missingMetadata_throwsVortexException() { - // Given - var sut = new FsstEncoding(); ArrayNode node = ArrayNode.of(EncodingId.VORTEX_FSST, null, new ArrayNode[0], new int[0], null); - DecodeContext ctx = new DecodeContext(node, DTypes.UTF8, 0, new MemorySegment[0], - buildRegistry(), Arena.ofAuto()); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) - .isInstanceOf(VortexException.class); + DecodeContext ctx = new DecodeContext(node, DTypes.UTF8, 0, new MemorySegment[0], REGISTRY, Arena.ofAuto()); + assertThatThrownBy(() -> DECODER.decode(ctx)).isInstanceOf(VortexException.class); } } @@ -279,19 +208,12 @@ class Metadata { @Test void encode_metadata_ptypes_areI32() throws Exception { - // Given — FsstEncoding always uses I32 (ordinal=6) for both ptype fields - // if either tag drifts, the ptype reads as 0 (U8) which is proto3 default String[] data = {"hello", "world", "hello", "fsst"}; - var sut = new FsstEncoding(); - - // When - EncodeResult result = 
sut.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); - FSSTMetadata meta = - FSSTMetadata.decode(java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()), 0, java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()).byteSize()); - - // Then - assertThat(meta.uncompressed_lengths_ptype().value()).isEqualTo(6); // I32 - assertThat(meta.codes_offsets_ptype().value()).isEqualTo(6); // I32 + EncodeResult result = ENCODER.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + var metaSeg = java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); + FSSTMetadata meta = FSSTMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + assertThat(meta.uncompressed_lengths_ptype().value()).isEqualTo(6); + assertThat(meta.codes_offsets_ptype().value()).isEqualTo(6); } } } diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/ListEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ListEncodingEncoderTest.java similarity index 63% rename from core/src/test/java/io/github/dfa1/vortex/encoding/ListEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/ListEncodingEncoderTest.java index 975b556b..09db6ab2 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/ListEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ListEncodingEncoderTest.java @@ -1,10 +1,22 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; -import io.github.dfa1.vortex.proto.ListMetadata; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.array.IntArray; import io.github.dfa1.vortex.core.array.ListArray; import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import 
io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.ListData; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.proto.ListMetadata; +import io.github.dfa1.vortex.reader.decode.ListEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -14,8 +26,11 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -class ListEncodingTest { +class ListEncodingEncoderTest { + private static final ListEncodingEncoder ENCODER = new ListEncodingEncoder(); + private static final ListEncodingDecoder DECODER = new ListEncodingDecoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); private static ArrayNode toArrayNode(EncodeNode node) { ArrayNode[] children = new ArrayNode[node.children().length]; @@ -25,46 +40,27 @@ private static ArrayNode toArrayNode(EncodeNode node) { return ArrayNode.of(node.encodingId(), node.metadata(), children, node.bufferIndices(), ArrayStats.empty()); } - private static Registry registry() { - return Registry.builder() - .register(new ListEncoding()) - .register(new PrimitiveEncoding()) - .build(); - } - @Nested class Encode { @Test void accepts_listDtype_true() { - // Given - ListEncoding sut = new ListEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.LIST_I32)).isTrue(); + assertThat(ENCODER.accepts(DTypes.LIST_I32)).isTrue(); } @Test void accepts_primitiveDtype_false() { - // Given - ListEncoding sut = new ListEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.I32)).isFalse(); + assertThat(ENCODER.accepts(DTypes.I32)).isFalse(); } @Test void 
encode_producesTwoChildren_noBuffers() { - // Given int[] elements = {1, 2, 3, 4, 5}; long[] offsets = {0, 2, 5}; ListData data = new ListData(elements, offsets, 2); - ListEncoding sut = new ListEncoding(); - // When - EncodeResult result = sut.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); - // Then assertThat(result.rootNode().encodingId()).isEqualTo(EncodingId.VORTEX_LIST); assertThat(result.rootNode().bufferIndices()).isEmpty(); assertThat(result.rootNode().children()).hasSize(2); @@ -76,20 +72,16 @@ class Decode { @Test void roundTrip_i32Elements_preservesValues() { - // Given int[] elements = {10, 20, 30, 40, 50}; long[] offsets = {0, 2, 3, 5}; ListData data = new ListData(elements, offsets, 3); - ListEncoding sut = new ListEncoding(); - // When - EncodeResult result = sut.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), DTypes.LIST_I32, 3, bufs, registry(), Arena.global()); - ListArray decoded = (ListArray) sut.decode(ctx); + toArrayNode(result.rootNode()), DTypes.LIST_I32, 3, bufs, REGISTRY, Arena.global()); + ListArray decoded = (ListArray) DECODER.decode(ctx); - // Then assertThat(decoded.length()).isEqualTo(3); IntArray decodedElems = (IntArray) decoded.elements(); assertThat(decodedElems.length()).isEqualTo(5); @@ -105,20 +97,16 @@ void roundTrip_i32Elements_preservesValues() { @Test void roundTrip_emptyLists_preservesOffsets() { - // Given int[] elements = {}; long[] offsets = {0, 0, 0}; ListData data = new ListData(elements, offsets, 2); - ListEncoding sut = new ListEncoding(); - // When - EncodeResult result = sut.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); + EncodeResult result = 
ENCODER.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), DTypes.LIST_I32, 2, bufs, registry(), Arena.global()); - ListArray decoded = (ListArray) sut.decode(ctx); + toArrayNode(result.rootNode()), DTypes.LIST_I32, 2, bufs, REGISTRY, Arena.global()); + ListArray decoded = (ListArray) DECODER.decode(ctx); - // Then assertThat(decoded.length()).isEqualTo(2); assertThat(decoded.elements().length()).isEqualTo(0); assertThat(decoded.offsets().length()).isEqualTo(3); @@ -126,20 +114,16 @@ void roundTrip_emptyLists_preservesOffsets() { @Test void roundTrip_singleList_preservesValues() { - // Given int[] elements = {7, 8, 9}; long[] offsets = {0, 3}; ListData data = new ListData(elements, offsets, 1); - ListEncoding sut = new ListEncoding(); - // When - EncodeResult result = sut.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), DTypes.LIST_I32, 1, bufs, registry(), Arena.global()); - ListArray decoded = (ListArray) sut.decode(ctx); + toArrayNode(result.rootNode()), DTypes.LIST_I32, 1, bufs, REGISTRY, Arena.global()); + ListArray decoded = (ListArray) DECODER.decode(ctx); - // Then assertThat(decoded.length()).isEqualTo(1); IntArray decodedElems = (IntArray) decoded.elements(); assertThat(decodedElems.length()).isEqualTo(3); @@ -150,31 +134,23 @@ void roundTrip_singleList_preservesValues() { @Test void decode_wrongDtype_throws() { - // Given - ListEncoding sut = new ListEncoding(); ArrayNode node = ArrayNode.of(EncodingId.VORTEX_LIST, null, new ArrayNode[0], new int[0], ArrayStats.empty()); - DecodeContext ctx = new DecodeContext(node, DTypes.I32, 0, new MemorySegment[0], registry(), 
Arena.global()); + DecodeContext ctx = new DecodeContext(node, DTypes.I32, 0, new MemorySegment[0], REGISTRY, Arena.global()); - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) - .hasMessageContaining("DType.List"); + assertThatThrownBy(() -> DECODER.decode(ctx)).hasMessageContaining("DType.List"); } @Test void decode_wrongChildCount_throws() { - // Given - ListEncoding sut = new ListEncoding(); ArrayNode child = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[0], ArrayStats.empty()); ArrayNode node = ArrayNode.of(EncodingId.VORTEX_LIST, java.nio.ByteBuffer.wrap(new byte[0]), new ArrayNode[]{child}, new int[0], ArrayStats.empty()); - DecodeContext ctx = new DecodeContext(node, DTypes.LIST_I32, 0, new MemorySegment[0], registry(), Arena.global()); + DecodeContext ctx = new DecodeContext(node, DTypes.LIST_I32, 0, new MemorySegment[0], REGISTRY, Arena.global()); - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) - .hasMessageContaining("expected 2 or 3 children"); + assertThatThrownBy(() -> DECODER.decode(ctx)).hasMessageContaining("expected 2 or 3 children"); } } @@ -183,19 +159,14 @@ class Metadata { @Test void encode_metadata_elementsLen_matchesElementCount() throws Exception { - // Given — 5 elements total across 2 outer lists - // if tag drifts, elements_len reads as 0 and decode allocates wrong-sized arrays int[] elements = {1, 2, 3, 4, 5}; long[] offsets = {0L, 2L, 5L}; ListData data = new ListData(elements, offsets, 2); - ListEncoding sut = new ListEncoding(); - // When - EncodeResult result = sut.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); - ListMetadata meta = - ListMetadata.decode(java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()), 0, java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()).byteSize()); + EncodeResult result = ENCODER.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); + var metaSeg = 
java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); + ListMetadata meta = ListMetadata.decode(metaSeg, 0, metaSeg.byteSize()); - // Then assertThat(meta.elements_len()).isEqualTo(5); } } diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/ListViewEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ListViewEncodingEncoderTest.java similarity index 68% rename from core/src/test/java/io/github/dfa1/vortex/encoding/ListViewEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/ListViewEncodingEncoderTest.java index 539ccfa5..4af35bc9 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/ListViewEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ListViewEncodingEncoderTest.java @@ -1,9 +1,22 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; -import io.github.dfa1.vortex.proto.ListViewMetadata; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.array.IntArray; import io.github.dfa1.vortex.core.array.ListViewArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.ListViewData; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestDecodeContexts; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.proto.ListViewMetadata; +import io.github.dfa1.vortex.reader.decode.ListViewEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -13,8 +26,11 @@ import 
static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -class ListViewEncodingTest { +class ListViewEncodingEncoderTest { + private static final ListViewEncodingEncoder ENCODER = new ListViewEncodingEncoder(); + private static final ListViewEncodingDecoder DECODER = new ListViewEncodingDecoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); private static ArrayNode toArrayNode(EncodeNode node) { ArrayNode[] children = new ArrayNode[node.children().length]; @@ -24,44 +40,28 @@ private static ArrayNode toArrayNode(EncodeNode node) { return ArrayNode.of(node.encodingId(), node.metadata(), children, node.bufferIndices(), ArrayStats.empty()); } - private static Registry registry() { - return TestRegistry.of(new ListViewEncoding(), new PrimitiveEncoding()); - } - @Nested class Encode { @Test void accepts_listDtype_true() { - // Given - ListViewEncoding sut = new ListViewEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.LIST_I32)).isTrue(); + assertThat(ENCODER.accepts(DTypes.LIST_I32)).isTrue(); } @Test void accepts_primitiveDtype_false() { - // Given - ListViewEncoding sut = new ListViewEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.I32)).isFalse(); + assertThat(ENCODER.accepts(DTypes.I32)).isFalse(); } @Test void encode_producesThreeChildren_noBuffers() { - // Given int[] elements = {1, 2, 3, 4, 5}; int[] offsets = {0, 2}; int[] sizes = {2, 3}; ListViewData data = new ListViewData(elements, offsets, sizes, 2); - ListViewEncoding sut = new ListViewEncoding(); - // When - EncodeResult result = sut.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); - // Then assertThat(result.rootNode().encodingId()).isEqualTo(EncodingId.VORTEX_LISTVIEW); assertThat(result.rootNode().bufferIndices()).isEmpty(); 
assertThat(result.rootNode().children()).hasSize(3); @@ -73,21 +73,17 @@ class Decode { @Test void roundTrip_i32Elements_preservesValues() { - // Given int[] elements = {10, 20, 30, 40, 50}; int[] offsets = {0, 2, 3}; int[] sizes = {2, 1, 2}; ListViewData data = new ListViewData(elements, offsets, sizes, 3); - ListViewEncoding sut = new ListViewEncoding(); - // When - EncodeResult result = sut.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), DTypes.LIST_I32, 3, bufs, registry(), Arena.global()); - ListViewArray decoded = (ListViewArray) sut.decode(ctx); + toArrayNode(result.rootNode()), DTypes.LIST_I32, 3, bufs, REGISTRY, Arena.global()); + ListViewArray decoded = (ListViewArray) DECODER.decode(ctx); - // Then assertThat(decoded.length()).isEqualTo(3); IntArray decodedElems = (IntArray) decoded.elements(); assertThat(decodedElems.length()).isEqualTo(5); @@ -102,21 +98,17 @@ void roundTrip_i32Elements_preservesValues() { @Test void roundTrip_emptyLists_preservesZeroSizes() { - // Given int[] elements = {}; int[] offsets = {0, 0}; int[] sizes = {0, 0}; ListViewData data = new ListViewData(elements, offsets, sizes, 2); - ListViewEncoding sut = new ListViewEncoding(); - // When - EncodeResult result = sut.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), DTypes.LIST_I32, 2, bufs, registry(), Arena.global()); - ListViewArray decoded = (ListViewArray) sut.decode(ctx); + toArrayNode(result.rootNode()), DTypes.LIST_I32, 2, bufs, REGISTRY, Arena.global()); + ListViewArray decoded = 
(ListViewArray) DECODER.decode(ctx); - // Then assertThat(decoded.length()).isEqualTo(2); assertThat(decoded.elements().length()).isEqualTo(0); assertThat(decoded.offsets().length()).isEqualTo(2); @@ -125,21 +117,17 @@ void roundTrip_emptyLists_preservesZeroSizes() { @Test void roundTrip_singleList_preservesValues() { - // Given int[] elements = {7, 8, 9}; int[] offsets = {0}; int[] sizes = {3}; ListViewData data = new ListViewData(elements, offsets, sizes, 1); - ListViewEncoding sut = new ListViewEncoding(); - // When - EncodeResult result = sut.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), DTypes.LIST_I32, 1, bufs, registry(), Arena.global()); - ListViewArray decoded = (ListViewArray) sut.decode(ctx); + toArrayNode(result.rootNode()), DTypes.LIST_I32, 1, bufs, REGISTRY, Arena.global()); + ListViewArray decoded = (ListViewArray) DECODER.decode(ctx); - // Then assertThat(decoded.length()).isEqualTo(1); IntArray decodedElems = (IntArray) decoded.elements(); assertThat(decodedElems.length()).isEqualTo(3); @@ -150,31 +138,23 @@ void roundTrip_singleList_preservesValues() { @Test void decode_wrongDtype_throws() { - // Given - ListViewEncoding sut = new ListViewEncoding(); ArrayNode node = ArrayNode.of(EncodingId.VORTEX_LISTVIEW, null, new ArrayNode[0], new int[0], ArrayStats.empty()); - DecodeContext ctx = TestDecodeContexts.of(node, DTypes.I32).registry(registry()).build(); + DecodeContext ctx = TestDecodeContexts.of(node, DTypes.I32).registry(REGISTRY).build(); - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) - .hasMessageContaining("DType.List"); + assertThatThrownBy(() -> DECODER.decode(ctx)).hasMessageContaining("DType.List"); } @Test void decode_wrongChildCount_throws() { - // Given - ListViewEncoding sut = 
new ListViewEncoding(); ArrayNode child = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[0], ArrayStats.empty()); ArrayNode node = ArrayNode.of(EncodingId.VORTEX_LISTVIEW, java.nio.ByteBuffer.wrap(new byte[0]), new ArrayNode[]{child}, new int[0], ArrayStats.empty()); - DecodeContext ctx = TestDecodeContexts.of(node, DTypes.LIST_I32).registry(registry()).build(); + DecodeContext ctx = TestDecodeContexts.of(node, DTypes.LIST_I32).registry(REGISTRY).build(); - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) - .hasMessageContaining("expected 3 or 4 children"); + assertThatThrownBy(() -> DECODER.decode(ctx)).hasMessageContaining("expected 3 or 4 children"); } } @@ -183,21 +163,16 @@ class Metadata { @Test void encode_metadata_elementsLen_matchesElementCount() throws Exception { - // Given — 5 elements across 2 outer lists - // if tag drifts, elements_len reads as 0 and decode allocates wrong-sized arrays int[] elements = {1, 2, 3, 4, 5}; int[] offsets = {0, 2}; int[] sizes = {2, 3}; ListViewData data = new ListViewData(elements, offsets, sizes, 2); - ListViewEncoding sut = new ListViewEncoding(); - // When - EncodeResult result = sut.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.LIST_I32, data, EncodeTestHelper.testCtx()); java.nio.ByteBuffer metaBuf = result.rootNode().metadata().duplicate(); java.lang.foreign.MemorySegment metaSeg = java.lang.foreign.MemorySegment.ofBuffer(metaBuf); ListViewMetadata meta = ListViewMetadata.decode(metaSeg, 0, metaSeg.byteSize()); - // Then assertThat(meta.elements_len()).isEqualTo(5); } } diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/MaskedEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/MaskedEncodingEncoderTest.java new file mode 100644 index 00000000..9f9d3c9b --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/MaskedEncodingEncoderTest.java @@ -0,0 
+1,161 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.core.array.Array; +import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.reader.decode.BoolEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.MaskedEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; +import org.junit.jupiter.api.Test; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.util.ArrayList; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class MaskedEncodingEncoderTest { + + private static final MaskedEncodingDecoder DECODER = new MaskedEncodingDecoder(); + private static final PrimitiveEncodingEncoder PRIM_ENCODER = new PrimitiveEncodingEncoder(); + private static final BoolEncodingEncoder BOOL_ENCODER = new BoolEncodingEncoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder(), new BoolEncodingDecoder()); + + private static EncodeResult maskedResult(int[] values, boolean[] validity) { + DType i32 = new DType.Primitive(PType.I32, false); + EncodeResult childResult = PRIM_ENCODER.encode(i32, values, EncodeTestHelper.testCtx()); + + List<MemorySegment> allBuffers = new ArrayList<>(childResult.buffers()); +
EncodeNode[] children; + + if (validity == null) { + children = new EncodeNode[]{childResult.rootNode()}; + } else { + DType boolDtype = new DType.Bool(false); + EncodeResult validityResult = BOOL_ENCODER.encode(boolDtype, validity, EncodeTestHelper.testCtx()); + EncodeNode remapped = EncodeNode.remapBufferIndices( + validityResult.rootNode(), childResult.buffers().size()); + allBuffers.addAll(validityResult.buffers()); + children = new EncodeNode[]{childResult.rootNode(), remapped}; + } + + EncodeNode maskedNode = new EncodeNode( + EncodingId.VORTEX_MASKED, null, children, new int[]{}); + return new EncodeResult(maskedNode, allBuffers, null, null); + } + + @Test + void oneChild_noValidity_allPositionsValid() { + DType i32Nullable = new DType.Primitive(PType.I32, true); + EncodeResult ctx = maskedResult(new int[]{10, 20, 30}, null); + + Array result = DECODER.decode(EncodeTestHelper.toDecodeContext(ctx, 3L, i32Nullable, REGISTRY)); + + assertThat(result).isInstanceOf(MaskedArray.class); + MaskedArray masked = (MaskedArray) result; + assertThat(masked.length()).isEqualTo(3); + assertThat(masked.isValid(0)).isTrue(); + assertThat(masked.isValid(1)).isTrue(); + assertThat(masked.isValid(2)).isTrue(); + } + + @Test + void twoChildren_withValidity_masksNulls() { + DType i32Nullable = new DType.Primitive(PType.I32, true); + EncodeResult ctx = maskedResult(new int[]{1, 2, 3, 4, 5}, + new boolean[]{true, false, true, false, true}); + + Array result = DECODER.decode(EncodeTestHelper.toDecodeContext(ctx, 5L, i32Nullable, REGISTRY)); + + MaskedArray masked = (MaskedArray) result; + assertThat(masked.length()).isEqualTo(5); + assertThat(masked.isValid(0)).isTrue(); + assertThat(masked.isValid(1)).isFalse(); + assertThat(masked.isValid(2)).isTrue(); + assertThat(masked.isValid(3)).isFalse(); + assertThat(masked.isValid(4)).isTrue(); + } + + @Test + void dtype_isNullable() { + DType i32Nullable = new DType.Primitive(PType.I32, true); + EncodeResult ctx = maskedResult(new 
int[]{1, 2, 3}, null); + + Array result = DECODER.decode(EncodeTestHelper.toDecodeContext(ctx, 3L, i32Nullable, REGISTRY)); + + assertThat(result.dtype().nullable()).isTrue(); + } + + @Test + void inner_containsChildValues() { + DType i32Nullable = new DType.Primitive(PType.I32, true); + EncodeResult ctx = maskedResult(new int[]{7, 8, 9}, null); + + MaskedArray result = (MaskedArray) DECODER.decode(EncodeTestHelper.toDecodeContext(ctx, 3L, i32Nullable, REGISTRY)); + IntArray inner = (IntArray) result.inner(); + + assertThat(ArraySegments.of(inner).get(PTypeIO.LE_INT, 0L)).isEqualTo(7); + assertThat(ArraySegments.of(inner).get(PTypeIO.LE_INT, 4L)).isEqualTo(8); + assertThat(ArraySegments.of(inner).get(PTypeIO.LE_INT, 8L)).isEqualTo(9); + } + + @Test + void buffersPresentThrows() { + DType i32Nullable = new DType.Primitive(PType.I32, true); + + EncodeNode childNode = EncodeNode.leaf(EncodingId.VORTEX_PRIMITIVE, 0); + EncodeNode maskedNode = new EncodeNode( + EncodingId.VORTEX_MASKED, null, + new EncodeNode[]{childNode}, new int[]{1}); + MemorySegment dummyBuf = Arena.ofAuto().allocate(4); + EncodeResult result = new EncodeResult(maskedNode, List.of(dummyBuf, dummyBuf), null, null); + + assertThatThrownBy(() -> DECODER.decode(EncodeTestHelper.toDecodeContext(result, 1L, i32Nullable, REGISTRY))) + .isInstanceOf(VortexException.class) + .hasMessageContaining("expected 0 buffers"); + } + + @Test + void zeroChildrenThrows() { + DType i32Nullable = new DType.Primitive(PType.I32, true); + + EncodeNode maskedNode = new EncodeNode( + EncodingId.VORTEX_MASKED, null, new EncodeNode[]{}, new int[]{}); + EncodeResult result = new EncodeResult(maskedNode, List.of(), null, null); + + assertThatThrownBy(() -> DECODER.decode(EncodeTestHelper.toDecodeContext(result, 0L, i32Nullable, REGISTRY))) + .isInstanceOf(VortexException.class) + .hasMessageContaining("expected 1 or 2 children"); + } + + @Test + void threeChildrenThrows() { + DType i32Nullable = new DType.Primitive(PType.I32, 
true); + + DType i32 = new DType.Primitive(PType.I32, false); + EncodeResult childResult = PRIM_ENCODER.encode(i32, new int[]{1}, EncodeTestHelper.testCtx()); + EncodeNode childNode = childResult.rootNode(); + EncodeNode maskedNode = new EncodeNode( + EncodingId.VORTEX_MASKED, null, + new EncodeNode[]{childNode, childNode, childNode}, new int[]{}); + List<MemorySegment> bufs = new ArrayList<>(childResult.buffers()); + EncodeResult result = new EncodeResult(maskedNode, bufs, null, null); + + assertThatThrownBy(() -> DECODER.decode(EncodeTestHelper.toDecodeContext(result, 1L, i32Nullable, REGISTRY))) + .isInstanceOf(VortexException.class) + .hasMessageContaining("expected 1 or 2 children"); + } +} diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/NullEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/NullEncodingEncoderTest.java similarity index 51% rename from core/src/test/java/io/github/dfa1/vortex/encoding/NullEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/NullEncodingEncoderTest.java index 8646cb9a..03148744 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/NullEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/NullEncodingEncoderTest.java @@ -1,6 +1,14 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.array.NullArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.reader.decode.NullEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -9,7 +17,7 @@ import static org.assertj.core.api.Assertions.assertThat;
-class NullEncodingTest { +class NullEncodingEncoderTest { @Nested class Encode { @@ -17,7 +25,7 @@ class Encode { @Test void encode_producesEmptyNode() { // Given - var sut = new NullEncoding(); + var sut = new NullEncodingEncoder(); // When EncodeResult result = sut.encode(DTypes.NULL, null, EncodeTestHelper.testCtx()); @@ -32,44 +40,19 @@ void encode_producesEmptyNode() { void encode_thenDecode_roundTrips() { // Given long rowCount = 10L; - var sut = new NullEncoding(); + var encoder = new NullEncodingEncoder(); + var decoder = new NullEncodingDecoder(); // When - EncodeResult encoded = sut.encode(DTypes.NULL, null, EncodeTestHelper.testCtx()); + EncodeResult encoded = encoder.encode(DTypes.NULL, null, EncodeTestHelper.testCtx()); ArrayNode node = ArrayNode.of(encoded.rootNode().encodingId(), null, new ArrayNode[0], new int[0], null); DecodeContext ctx = new DecodeContext(node, DTypes.NULL, rowCount, new MemorySegment[0], Registry.empty(), Arena.ofAuto()); // Then - var decoded = sut.decode(ctx); + var decoded = decoder.decode(ctx); assertThat(decoded).isInstanceOf(NullArray.class); assertThat(decoded.length()).isEqualTo(rowCount); } } - - @Nested - class Decode { - - private static DecodeContext buildNullCtx(long rowCount) { - ArrayNode node = ArrayNode.of(EncodingId.VORTEX_NULL, null, new ArrayNode[0], new int[0], null); - Registry registry = Registry.builder().register(new NullEncoding()).build(); - return new DecodeContext(node, DTypes.NULL, rowCount, new MemorySegment[0], registry, Arena.ofAuto()); - } - - @Test - void decode_nullArray_returnsNullArrayWithCorrectLength() { - // Given - long rowCount = 42L; - DecodeContext ctx = buildNullCtx(rowCount); - var sut = new NullEncoding(); - - // When - var result = sut.decode(ctx); - - // Then - assertThat(result).isInstanceOf(NullArray.class); - assertThat(result.length()).isEqualTo(rowCount); - assertThat(result.dtype()).isEqualTo(DTypes.NULL); - } - } } diff --git 
a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/PcoEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/PcoEncodingEncoderTest.java new file mode 100644 index 00000000..3779e503 --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/PcoEncodingEncoderTest.java @@ -0,0 +1,22 @@ +package io.github.dfa1.vortex.writer.encode; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.core.VortexException; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class PcoEncodingEncoderTest { + + @Test + void encode_throwsVortexException() { + var sut = new PcoEncodingEncoder(); + DType dtype = new DType.Primitive(PType.I64, false); + + assertThatThrownBy(() -> sut.encode(dtype, new long[]{1L, 2L, 3L}, EncodeTestHelper.testCtx())) + .isInstanceOf(VortexException.class) + .hasMessageContaining("not implemented"); + } +} diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/PrimitiveEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/PrimitiveEncodingEncoderTest.java similarity index 65% rename from core/src/test/java/io/github/dfa1/vortex/encoding/PrimitiveEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/PrimitiveEncodingEncoderTest.java index 8cb6ea37..c30c92e4 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/PrimitiveEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/PrimitiveEncodingEncoderTest.java @@ -1,4 +1,4 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.DType; @@ -7,6 +7,17 @@ import io.github.dfa1.vortex.core.array.ArraySegments; import io.github.dfa1.vortex.core.array.IntArray; import 
io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.encoding.TestSegments; +import io.github.dfa1.vortex.reader.decode.BoolEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -18,8 +29,11 @@ import static org.assertj.core.api.Assertions.assertThat; -/// Property: encode then decode is lossless for all primitive types and array sizes. -class PrimitiveEncodingTest { +class PrimitiveEncodingEncoderTest { + + private static final PrimitiveEncodingEncoder ENCODER = new PrimitiveEncodingEncoder(); + private static final PrimitiveEncodingDecoder DECODER = new PrimitiveEncodingDecoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(DECODER); @Nested class Encode { @@ -57,78 +71,52 @@ static Stream doubleArrays() { @ParameterizedTest @MethodSource("longArrays") void encodeDecode_i64_isLossless(long[] data) { - // Given DType dtype = new DType.Primitive(PType.I64, false); - var sut = new PrimitiveEncoding(); - Registry registry = TestRegistry.of(sut); - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, registry); - Array result = sut.decode(ctx); + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, REGISTRY); + Array result = 
DECODER.decode(ctx); - // Then — roundtrip lossless assertThat(result.length()).isEqualTo(data.length); - var le = PTypeIO.LE_LONG; for (int i = 0; i < data.length; i++) { - assertThat(ArraySegments.of(result).get(le, (long) i * 8)).isEqualTo(data[i]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).isEqualTo(data[i]); } } @ParameterizedTest @MethodSource("intArrays") void encodeDecode_i32_isLossless(int[] data) { - // Given DType dtype = new DType.Primitive(PType.I32, false); - var sut = new PrimitiveEncoding(); - Registry registry = TestRegistry.of(sut); - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, registry); - Array result = sut.decode(ctx); + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, REGISTRY); + Array result = DECODER.decode(ctx); - // Then — roundtrip lossless assertThat(result.length()).isEqualTo(data.length); - var le = PTypeIO.LE_INT; for (int i = 0; i < data.length; i++) { - assertThat(ArraySegments.of(result).get(le, (long) i * 4)).isEqualTo(data[i]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).isEqualTo(data[i]); } } @ParameterizedTest @MethodSource("doubleArrays") void encodeDecode_f64_isLossless(double[] data) { - // Given DType dtype = new DType.Primitive(PType.F64, false); - var sut = new PrimitiveEncoding(); - Registry registry = TestRegistry.of(sut); + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, REGISTRY); + Array result = DECODER.decode(ctx); - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, registry); - 
Array result = sut.decode(ctx); - - // Then — roundtrip lossless assertThat(result.length()).isEqualTo(data.length); - var le = PTypeIO.LE_DOUBLE; for (int i = 0; i < data.length; i++) { - assertThat(ArraySegments.of(result).get(le, (long) i * 8)).isEqualTo(data[i]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, (long) i * 8)).isEqualTo(data[i]); } } @ParameterizedTest @MethodSource("longArrays") void encodedSize_equalsBytesInBuffer(long[] data) { - // Given DType dtype = new DType.Primitive(PType.I64, false); - var sut = new PrimitiveEncoding(); - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); - // Then — no compression: wire size = n * elemBytes - long totalBytes = encoded.buffers().stream().mapToLong(java.lang.foreign.MemorySegment::byteSize).sum(); + long totalBytes = encoded.buffers().stream().mapToLong(MemorySegment::byteSize).sum(); assertThat(totalBytes).isEqualTo((long) data.length * 8); } } @@ -138,17 +126,16 @@ class Decode { @Test void decode_withValidityChild_returnsMaskedArray() { - // Given — 4 I32 values; positions 1 and 3 are null (validity bitmap: 0b00000101 = 0x05) - int[] raw = {10, 0, 20, 0}; // garbage at null positions + int[] raw = {10, 0, 20, 0}; MemorySegment valuesSeg = TestSegments.leInts(raw); - MemorySegment validitySeg = MemorySegment.ofArray(new byte[]{0x05}); // bits 0,2 set + MemorySegment validitySeg = MemorySegment.ofArray(new byte[]{0x05}); ArrayNode validityNode = ArrayNode.of( EncodingId.VORTEX_BOOL, null, new ArrayNode[0], new int[]{1}, ArrayStats.empty()); ArrayNode primNode = ArrayNode.of( EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[]{validityNode}, new int[]{0}, ArrayStats.empty()); - Registry registry = TestRegistry.of(new PrimitiveEncoding(), new BoolEncoding()); + Registry registry = TestRegistry.ofDecoders(new PrimitiveEncodingDecoder(), new BoolEncodingDecoder()); DType dtype = 
new DType.Primitive(PType.I32, false); DecodeContext ctx = new DecodeContext( @@ -156,12 +143,8 @@ void decode_withValidityChild_returnsMaskedArray() { new MemorySegment[]{valuesSeg, validitySeg}, registry, Arena.global()); - PrimitiveEncoding sut = new PrimitiveEncoding(); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then — returns MaskedArray; only valid positions are usable assertThat(result).isInstanceOf(MaskedArray.class); MaskedArray masked = (MaskedArray) result; assertThat(masked.inner()).isInstanceOf(IntArray.class); @@ -176,27 +159,19 @@ void decode_withValidityChild_returnsMaskedArray() { @Test void decode_noValidityChild_returnsPlainArray() { - // Given — 3 I32 values; no validity child int[] raw = {1, 2, 3}; MemorySegment valuesSeg = TestSegments.leInts(raw); ArrayNode primNode = ArrayNode.of( EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{0}, ArrayStats.empty()); - Registry registry = TestRegistry.of(new PrimitiveEncoding()); - DType dtype = new DType.Primitive(PType.I32, false); DecodeContext ctx = new DecodeContext( primNode, dtype, raw.length, new MemorySegment[]{valuesSeg}, - registry, Arena.global()); - - PrimitiveEncoding sut = new PrimitiveEncoding(); - - // When - Array result = sut.decode(ctx); + REGISTRY, Arena.global()); - // Then — plain array, not MaskedArray + Array result = DECODER.decode(ctx); assertThat(result).isInstanceOf(IntArray.class); } } diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/RandomAccessTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/RandomAccessTest.java similarity index 54% rename from core/src/test/java/io/github/dfa1/vortex/encoding/RandomAccessTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/RandomAccessTest.java index a2a6266c..30b77e6d 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/RandomAccessTest.java +++ 
b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/RandomAccessTest.java @@ -1,8 +1,18 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingEncoder; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.reader.decode.BitpackedEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.FrameOfReferenceEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; @@ -13,17 +23,6 @@ import static org.assertj.core.api.Assertions.assertThat; /// Verifies that decoded arrays support correct random-order element access. -/// -/// The O(1) random-access claim in docs/explanation.md holds only if every -/// position decodes independently. This test catches encoder bugs where -/// value[N] accidentally depends on value[N-1] — e.g. an unapplied delta, -/// an off-by-one in residual accumulation, or a stateful read cursor that -/// advances forward and gives wrong values when accessed out of order. -/// -/// Data: {@code value[i] = i * 1_000_003 + 7} — every position unique, -/// prime multiplier prevents accidental aliasing. -/// Access orders: reverse (N-1…0) and seeded random — the combination -/// catches symmetric bugs that reverse alone might miss. 
class RandomAccessTest { private static final int N = 1024; @@ -31,37 +30,36 @@ class RandomAccessTest { static Stream<Arguments> encodings() { return Stream.of( - Arguments.of("Primitive", new PrimitiveEncoding(), DTypes.I64), - Arguments.of("BitpackedU64", new BitpackedEncoding(), DTypes.U64), - Arguments.of("FrameOfReference", new FrameOfReferenceEncoding(), DTypes.I64) + Arguments.of("Primitive", new PrimitiveEncodingEncoder(), new PrimitiveEncodingDecoder(), DTypes.I64), + Arguments.of("BitpackedU64", new BitpackedEncodingEncoder(), new BitpackedEncodingDecoder(), DTypes.U64), + Arguments.of("FrameOfReference", new FrameOfReferenceEncodingEncoder(), new FrameOfReferenceEncodingDecoder(), DTypes.I64) ); } @ParameterizedTest(name = "{0}") @MethodSource("encodings") - void randomOrderAccess_matchesForwardOrder(String name, Encoding sut, DType dtype) { - // Given — unique value per position, no two rows the same + void randomOrderAccess_matchesForwardOrder(String name, + EncodingEncoder encoder, io.github.dfa1.vortex.encoding.EncodingDecoder decoder, DType dtype) { long[] original = new long[N]; for (int i = 0; i < N; i++) { original[i] = (long) i * 1_000_003L + 7L; } - Registry registry = TestRegistry.withPrimitive(sut); - EncodeResult encoded = sut.encode(dtype, original, EncodeTestHelper.testCtx()); + Registry registry = decoder instanceof PrimitiveEncodingDecoder + ?
TestRegistry.ofDecoders(decoder) + : TestRegistry.ofDecoders(decoder, new PrimitiveEncodingDecoder()); + EncodeResult encoded = encoder.encode(dtype, original, EncodeTestHelper.testCtx()); DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, N, dtype, registry); - // When - Array array = sut.decode(ctx); + Array array = decoder.decode(ctx); LongArray result = (LongArray) array; - // Then — reverse order for (int i = N - 1; i >= 0; i--) { assertThat(result.getLong(i)) .as("reverse access at index %d", i) .isEqualTo(original[i]); } - // Then — random order (seeded for reproducibility) Random rng = new Random(SEED); for (int check = 0; check < N; check++) { int i = rng.nextInt(N); diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/RleEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/RleEncodingEncoderTest.java similarity index 60% rename from core/src/test/java/io/github/dfa1/vortex/encoding/RleEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/RleEncodingEncoderTest.java index f7adef03..436d15c0 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/RleEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/RleEncodingEncoderTest.java @@ -1,6 +1,5 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; -import io.github.dfa1.vortex.proto.RLEMetadata; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.PType; @@ -8,6 +7,21 @@ import io.github.dfa1.vortex.core.array.ArraySegments; import io.github.dfa1.vortex.core.array.IntArray; import io.github.dfa1.vortex.core.array.MaskedArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import 
io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.KnownArrayNode; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.proto.RLEMetadata; +import io.github.dfa1.vortex.reader.decode.BoolEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.RleEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -20,14 +34,11 @@ import static org.assertj.core.api.Assertions.assertThat; -class RleEncodingTest { +class RleEncodingEncoderTest { - private static Registry registry() { - return Registry.builder() - .register(new RleEncoding()) - .register(new PrimitiveEncoding()) - .build(); - } + private static final RleEncodingEncoder ENCODER = new RleEncodingEncoder(); + private static final RleEncodingDecoder DECODER = new RleEncodingDecoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); private static KnownArrayNode toArrayNode(EncodeNode enc) { ArrayNode[] children = new ArrayNode[enc.children().length]; @@ -42,53 +53,35 @@ class Encode { @Test void roundTrip_empty_i32() { - // Given - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I32, false); - - // When - EncodeResult encoded = sut.encode(dtype, new int[0], EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, 0, dtype, registry()); - Array result = sut.decode(ctx); - - // Then + EncodeResult encoded = ENCODER.encode(dtype, new int[0], EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, 0, dtype, REGISTRY); + Array result = DECODER.decode(ctx); assertThat(result.length()).isZero(); } @Test 
void roundTrip_singleElement_i32() { - // Given - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I32, false); int[] data = {42}; - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, registry()); - Array result = sut.decode(ctx); - - // Then + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, REGISTRY); + Array result = DECODER.decode(ctx); assertThat(result.length()).isEqualTo(1); assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, 0)).isEqualTo(42); } @Test void roundTrip_constantArray_i32() { - // Given - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I32, false); int n = 2048; int[] data = new int[n]; for (int i = 0; i < n; i++) { data[i] = 99; } - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, n, dtype, registry()); - Array result = sut.decode(ctx); - - // Then + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, n, dtype, REGISTRY); + Array result = DECODER.decode(ctx); assertThat(result.length()).isEqualTo(n); for (int i = 0; i < n; i++) { assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(99); @@ -97,41 +90,27 @@ void roundTrip_constantArray_i32() { @Test void roundTrip_classicRunLengthData_i32() { - // Given - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I32, false); int[] data = {1, 1, 1, 2, 2, 3}; - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, registry()); - Array result = 
sut.decode(ctx); - - // Then - assertThat(result.length()).isEqualTo(data.length); - int[] expected = {1, 1, 1, 2, 2, 3}; - for (int i = 0; i < expected.length; i++) { - assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(expected[i]); + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, REGISTRY); + Array result = DECODER.decode(ctx); + for (int i = 0; i < data.length; i++) { + assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(data[i]); } } @Test void roundTrip_multipleChunks_i32() { - // Given: spans 3 chunks - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I32, false); int n = 3000; int[] data = new int[n]; for (int i = 0; i < n; i++) { data[i] = i / 100; } - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, n, dtype, registry()); - Array result = sut.decode(ctx); - - // Then + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, n, dtype, REGISTRY); + Array result = DECODER.decode(ctx); assertThat(result.length()).isEqualTo(n); for (int i = 0; i < n; i++) { assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(i / 100); @@ -140,18 +119,11 @@ void roundTrip_multipleChunks_i32() { @Test void roundTrip_i64() { - // Given - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I64, false); long[] data = {100L, 100L, 200L, 300L, 300L, 300L}; - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, registry()); - Array result = sut.decode(ctx); - - // Then - 
assertThat(result.length()).isEqualTo(data.length); + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, REGISTRY); + Array result = DECODER.decode(ctx); for (int i = 0; i < data.length; i++) { assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).as("index %d", i).isEqualTo(data[i]); } @@ -160,20 +132,14 @@ void roundTrip_i64() { @ParameterizedTest @ValueSource(ints = {1, 512, 1023, 1024, 1025, 2048, 2049}) void roundTrip_variousLengths_i32(int n) { - // Given - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I32, false); int[] data = new int[n]; for (int i = 0; i < n; i++) { data[i] = i / 50; } - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, n, dtype, registry()); - Array result = sut.decode(ctx); - - // Then + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, n, dtype, REGISTRY); + Array result = DECODER.decode(ctx); assertThat(result.length()).isEqualTo(n); for (int i = 0; i < n; i++) { assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(i / 50); @@ -182,42 +148,27 @@ void roundTrip_variousLengths_i32(int n) { @Test void roundTrip_allDifferent_u16() { - // Given: worst case — every consecutive value is unique (no compression) - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.U16, false); short[] data = new short[256]; for (int i = 0; i < 256; i++) { data[i] = (short) i; } - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, registry()); - Array result = sut.decode(ctx); - - // Then - 
assertThat(result.length()).isEqualTo(data.length); - var le = PTypeIO.LE_SHORT; + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, REGISTRY); + Array result = DECODER.decode(ctx); for (int i = 0; i < data.length; i++) { - assertThat(Short.toUnsignedInt(ArraySegments.of(result).get(le, (long) i * 2))) + assertThat(Short.toUnsignedInt(ArraySegments.of(result).get(PTypeIO.LE_SHORT, (long) i * 2))) .as("index %d", i).isEqualTo(i); } } @Test void roundTrip_negativeValues_i32() { - // Given - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I32, false); int[] data = {-3, -3, -1, -1, 0, 0, 5}; - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, registry()); - Array result = sut.decode(ctx); - - // Then - assertThat(result.length()).isEqualTo(data.length); + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, dtype, REGISTRY); + Array result = DECODER.decode(ctx); for (int i = 0; i < data.length; i++) { assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(data[i]); } @@ -229,67 +180,46 @@ class Decode { @Test void decode_exactlyOneChunk_correctLength() { - // Given - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I32, false); int[] data = new int[1024]; for (int i = 0; i < 1024; i++) { data[i] = i / 10; } - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, 1024, dtype, registry()); - Array result = sut.decode(ctx); - - // Then + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = 
EncodeTestHelper.toDecodeContext(encoded, 1024, dtype, REGISTRY); + Array result = DECODER.decode(ctx); assertThat(result.length()).isEqualTo(1024); } @Test void decode_crossesChunkBoundary_correctValues() { - // Given: values span the chunk boundary at element 1024 - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I32, false); int n = 2048; int[] data = new int[n]; for (int i = 0; i < n; i++) { data[i] = i / 100; } - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, n, dtype, registry()); - Array result = sut.decode(ctx); - - // Then — verify values near the chunk boundary + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, n, dtype, REGISTRY); + Array result = DECODER.decode(ctx); for (int i = 1000; i < 1048; i++) { - assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)) - .as("index %d", i).isEqualTo(i / 100); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(i / 100); } } @Test void decode_nullableIndices_returnsMaskedArrayWithCorrectValidity() { - // Given — encode [10, 10, 20, 20]; then inject a validity bitmap into the indices node - // so positions 1 and 3 are null. Valid outputs: [10, null, 20, null]. 
- var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I32, false); int[] data = {10, 10, 20, 20}; - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); - // Original buffers: [values_buf, indices_buf, offsets_buf] List<MemorySegment> originalBufs = new ArrayList<>(encoded.buffers()); - // Validity bitmap: bits 0 and 2 set → 0b00000101 = 0x05 → positions 0,2 valid MemorySegment validityBuf = MemorySegment.ofArray(new byte[]{0x05}); - originalBufs.add(validityBuf); // index 3 + originalBufs.add(validityBuf); MemorySegment[] segments = originalBufs.toArray(new MemorySegment[0]); - // Rebuild the ArrayNode tree from the encode result KnownArrayNode origRoot = toArrayNode(encoded.rootNode()); - // RLE root children: [values(0), indices(1), offsets(2)] KnownArrayNode origIndices = (KnownArrayNode) origRoot.children()[1]; - // Wrap indices with a validity child pointing to buffer 3 ArrayNode validityNode = ArrayNode.of( EncodingId.VORTEX_BOOL, null, new ArrayNode[0], new int[]{3}, ArrayStats.empty()); ArrayNode nullableIndices = ArrayNode.of( @@ -300,17 +230,11 @@ void decode_nullableIndices_returnsMaskedArrayWithCorrectValidity() { new ArrayNode[]{origRoot.children()[0], nullableIndices, origRoot.children()[2]}, origRoot.bufferIndices(), ArrayStats.empty()); - Registry reg = Registry.builder() - .register(sut) - .register(new PrimitiveEncoding()) - .register(new BoolEncoding()) - .build(); + Registry reg = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder(), new BoolEncodingDecoder()); DecodeContext ctx = new DecodeContext(root, dtype, data.length, segments, reg, Arena.ofAuto()); - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then — MaskedArray; valid at positions 0 and 2 assertThat(result).isInstanceOf(MaskedArray.class); MaskedArray masked = (MaskedArray) result; assertThat(masked.isValid(0)).isTrue(); @@
-324,21 +248,15 @@ void decode_nullableIndices_returnsMaskedArrayWithCorrectValidity() { @Test void decode_partialLastChunk_correctLength() { - // Given: 1500 elements — two chunks (1024 full + 476 partial) - var sut = new RleEncoding(); DType dtype = new DType.Primitive(PType.I32, false); int n = 1500; int[] data = new int[n]; for (int i = 0; i < n; i++) { data[i] = i / 100; } - - // When - EncodeResult encoded = sut.encode(dtype, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, n, dtype, registry()); - Array result = sut.decode(ctx); - - // Then + EncodeResult encoded = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, n, dtype, REGISTRY); + Array result = DECODER.decode(ctx); assertThat(result.length()).isEqualTo(n); for (int i = 0; i < n; i++) { assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(i / 100); @@ -347,19 +265,14 @@ void decode_partialLastChunk_correctLength() { @Test void encode_i32_metadata_valuesLen_matchesRunCount() throws Exception { - // Given — 2 distinct runs; if tag drifts, values_len reads as 0 (proto3 default) int[] data = {1, 1, 1, 2, 2, 2}; - RleEncoding sut = new RleEncoding(); - - // When - EncodeResult result = sut.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); - RLEMetadata meta = - RLEMetadata.decode(java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()), 0, java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()).byteSize()); + EncodeResult result = ENCODER.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + var metaSeg = MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); + RLEMetadata meta = RLEMetadata.decode(metaSeg, 0, metaSeg.byteSize()); - // Then assertThat(meta.values_len()).isEqualTo(2); assertThat(meta.indices_len()).isGreaterThan(0); - 
assertThat(meta.indices_ptype().value()).isEqualTo(1); // U16 + assertThat(meta.indices_ptype().value()).isEqualTo(1); } } } diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/RunEndEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/RunEndEncodingEncoderTest.java similarity index 63% rename from core/src/test/java/io/github/dfa1/vortex/encoding/RunEndEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/RunEndEncodingEncoderTest.java index 1fb239ee..e099ab82 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/RunEndEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/RunEndEncodingEncoderTest.java @@ -1,11 +1,22 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; -import io.github.dfa1.vortex.proto.RunEndMetadata; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.PType; import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.ArraySegments; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.proto.RunEndMetadata; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.RunEndEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -18,7 +29,11 @@ import static org.assertj.core.api.Assertions.assertThat; -class RunEndEncodingTest { +class 
RunEndEncodingEncoderTest { + + private static final RunEndEncodingEncoder ENCODER = new RunEndEncodingEncoder(); + private static final RunEndEncodingDecoder DECODER = new RunEndEncodingDecoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); @Nested class Decode { @@ -27,8 +42,8 @@ private static DecodeContext buildCtx( DType dtype, long rowCount, long[] ends, long[] values, PType endsPtype, long offset ) { - byte[] metaBytes = new RunEndMetadata(io.github.dfa1.vortex.proto.PType.fromValue(endsPtype.ordinal()), ends.length, offset) - .encode(); + byte[] metaBytes = new RunEndMetadata( + io.github.dfa1.vortex.proto.PType.fromValue(endsPtype.ordinal()), ends.length, offset).encode(); byte[] endsBuf = toLEBytes(ends, endsPtype); byte[] valBuf = toLEBytes(values, PType.I64); @@ -47,9 +62,7 @@ private static DecodeContext buildCtx( MemorySegment.ofArray(valBuf) }; - Registry registry = TestRegistry.of(new RunEndEncoding(), new PrimitiveEncoding()); - - return new DecodeContext(reNode, dtype, rowCount, segments, registry, java.lang.foreign.Arena.global()); + return new DecodeContext(reNode, dtype, rowCount, segments, REGISTRY, java.lang.foreign.Arena.global()); } private static byte[] toLEBytes(long[] values, PType ptype) { @@ -70,61 +83,41 @@ private static byte[] toLEBytes(long[] values, PType ptype) { @Test void decode_singleRun_fillsAllElements() { - // Given — 1 run: ends=[5], values=[42]; output = [42, 42, 42, 42, 42] long[] ends = {5L}; long[] values = {42L}; DecodeContext ctx = buildCtx(DTypes.I64, 5, ends, values, PType.U32, 0L); - RunEndEncoding sut = new RunEndEncoding(); - - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(5L); - var layout = PTypeIO.LE_LONG; for (int i = 0; i < 5; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 8)).isEqualTo(42L); + 
assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).isEqualTo(42L); } } @Test void decode_multipleRuns_expandsCorrectly() { - // Given — runs: [0,2)=10, [2,5)=20, [5,7)=30 - // ends=[2,5,7], values=[10,20,30] long[] ends = {2L, 5L, 7L}; long[] values = {10L, 20L, 30L}; DecodeContext ctx = buildCtx(DTypes.I64, 7, ends, values, PType.U32, 0L); - RunEndEncoding sut = new RunEndEncoding(); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then long[] expected = {10, 10, 20, 20, 20, 30, 30}; - var layout = PTypeIO.LE_LONG; for (int i = 0; i < expected.length; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 8)) + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)) .as("index %d", i).isEqualTo(expected[i]); } } @Test void decode_withOffset_skipsLogicalElements() { - // Given — logical array: [0,3)=10, [3,6)=20; offset=2, rowCount=3 - // output elements [2..5): [10, 20, 20] long[] ends = {3L, 6L}; long[] values = {10L, 20L}; DecodeContext ctx = buildCtx(DTypes.I64, 3, ends, values, PType.U32, 2L); - RunEndEncoding sut = new RunEndEncoding(); - - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then long[] expected = {10L, 20L, 20L}; - var layout = PTypeIO.LE_LONG; for (int i = 0; i < expected.length; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 8)) + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)) .as("index %d", i).isEqualTo(expected[i]); } } @@ -146,37 +139,25 @@ static Stream i64Arrays() { @ParameterizedTest @MethodSource("i64Arrays") void encodeDecode_i64_isLossless(long[] data) { - // Given - var sut = new RunEndEncoding(); - Registry registry = TestRegistry.withPrimitive(sut); - var le = PTypeIO.LE_LONG; + EncodeResult encoded = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I64, REGISTRY); 
+ Array result = DECODER.decode(ctx); - // When - EncodeResult encoded = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I64, registry); - Array result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(data.length); for (int i = 0; i < data.length; i++) { - assertThat(ArraySegments.of(result).get(le, (long) i * 8)).as("index %d", i).isEqualTo(data[i]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).as("index %d", i).isEqualTo(data[i]); } } @Test void encode_i64_metadata_numRuns_andEndsPtype() throws Exception { - // Given — 3 runs; if tag drifts, num_runs reads as 0 or ends_ptype reads as 0 (U8) long[] data = {1L, 1L, 2L, 2L, 3L}; - RunEndEncoding sut = new RunEndEncoding(); - - // When - EncodeResult result = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); - RunEndMetadata meta = - RunEndMetadata.decode(java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()), 0, java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()).byteSize()); + EncodeResult result = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + var metaSeg = java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); + RunEndMetadata meta = RunEndMetadata.decode(metaSeg, 0, metaSeg.byteSize()); - // Then assertThat(meta.num_runs()).isEqualTo(3); - assertThat(meta.ends_ptype().value()).isEqualTo(2); // U32 + assertThat(meta.ends_ptype().value()).isEqualTo(2); } } } diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/SequenceEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/SequenceEncodingEncoderTest.java similarity index 70% rename from core/src/test/java/io/github/dfa1/vortex/encoding/SequenceEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/SequenceEncodingEncoderTest.java index 
7e5981b4..079dad1e 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/SequenceEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/SequenceEncodingEncoderTest.java @@ -1,4 +1,4 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.VortexException; @@ -8,8 +8,16 @@ import io.github.dfa1.vortex.core.array.FloatArray; import io.github.dfa1.vortex.core.array.IntArray; import io.github.dfa1.vortex.core.array.LongArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.Registry; import io.github.dfa1.vortex.proto.ScalarValue; import io.github.dfa1.vortex.proto.SequenceMetadata; +import io.github.dfa1.vortex.reader.decode.SequenceEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -24,7 +32,10 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -class SequenceEncodingTest { +class SequenceEncodingEncoderTest { + + private static final SequenceEncodingEncoder ENCODER = new SequenceEncodingEncoder(); + private static final SequenceEncodingDecoder DECODER = new SequenceEncodingDecoder(); @Nested class Encode { @@ -37,22 +48,16 @@ private static DecodeContext encodeResultToCtx(EncodeResult result, DType dtype, @Test void encodingId_isVortexSequence() { - // Given / When / Then - assertThat(new SequenceEncoding().encodingId()).isEqualTo(EncodingId.VORTEX_SEQUENCE); + assertThat(ENCODER.encodingId()).isEqualTo(EncodingId.VORTEX_SEQUENCE); } @Test void 
encode_i64_roundTrips() { - // Given - var sut = new SequenceEncoding(); long[] data = {10L, 12L, 14L, 16L}; - - // When - EncodeResult result = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); DecodeContext ctx = encodeResultToCtx(result, DTypes.I64, data.length); - LongArray decoded = (LongArray) sut.decode(ctx); + LongArray decoded = (LongArray) DECODER.decode(ctx); - // Then for (int i = 0; i < data.length; i++) { assertThat(decoded.getLong(i)).as("index %d", i).isEqualTo(data[i]); } @@ -60,16 +65,11 @@ void encode_i64_roundTrips() { @Test void encode_f64_roundTrips() { - // Given - var sut = new SequenceEncoding(); double[] data = {1.0, 1.5, 2.0, 2.5}; - - // When - EncodeResult result = sut.encode(DTypes.F64, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.F64, data, EncodeTestHelper.testCtx()); DecodeContext ctx = encodeResultToCtx(result, DTypes.F64, data.length); - DoubleArray decoded = (DoubleArray) sut.decode(ctx); + DoubleArray decoded = (DoubleArray) DECODER.decode(ctx); - // Then for (int i = 0; i < data.length; i++) { assertThat(decoded.getDouble(i)).as("index %d", i).isEqualTo(data[i]); } @@ -77,16 +77,11 @@ void encode_f64_roundTrips() { @Test void encode_f16_roundTrips() { - // Given — 0.0, 1.0, 2.0 as F16 bit patterns - var sut = new SequenceEncoding(); short[] data = {Float.floatToFloat16(0.0f), Float.floatToFloat16(1.0f), Float.floatToFloat16(2.0f)}; - - // When - EncodeResult result = sut.encode(DTypes.F16, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.F16, data, EncodeTestHelper.testCtx()); DecodeContext ctx = encodeResultToCtx(result, DTypes.F16, data.length); - Float16Array decoded = (Float16Array) sut.decode(ctx); + Float16Array decoded = (Float16Array) DECODER.decode(ctx); - // Then for (int i = 0; i < data.length; i++) { assertThat(decoded.getFloat(i)).as("index %d", 
i).isEqualTo(Float.float16ToFloat(data[i])); } @@ -94,22 +89,14 @@ void encode_f16_roundTrips() { @Test void encode_nonArithmeticSequence_throwsVortexException() { - // Given - var sut = new SequenceEncoding(); long[] data = {1L, 2L, 4L}; - - // When / Then - assertThatThrownBy(() -> sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx())) + assertThatThrownBy(() -> ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx())) .isInstanceOf(VortexException.class); } @Test void encode_nonPrimitiveDtype_throwsVortexException() { - // Given - var sut = new SequenceEncoding(); - - // When / Then - assertThatThrownBy(() -> sut.encode(new DType.Utf8(false), new long[]{1L}, EncodeTestHelper.testCtx())) + assertThatThrownBy(() -> ENCODER.encode(new DType.Utf8(false), new long[]{1L}, EncodeTestHelper.testCtx())) .isInstanceOf(VortexException.class); } } @@ -136,10 +123,8 @@ static Stream i32Sequences() { } private static DecodeContext makeCtx(byte[] meta, DType dtype, long n) { - ArrayNode node = ArrayNode.of( - EncodingId.VORTEX_SEQUENCE, - ByteBuffer.wrap(meta), - new ArrayNode[0], new int[0], null); + ArrayNode node = ArrayNode.of(EncodingId.VORTEX_SEQUENCE, + ByteBuffer.wrap(meta), new ArrayNode[0], new int[0], null); return new DecodeContext(node, dtype, n, new MemorySegment[0], Registry.empty(), Arena.ofAuto()); } @@ -164,17 +149,12 @@ private static byte[] f16Meta(short baseShort, short mulShort) { @ParameterizedTest @MethodSource("i64Sequences") void decode_i64_generatesCorrectSequence(long base, long mul, long[] expected) { - // Given - var sut = new SequenceEncoding(); DecodeContext ctx = makeCtx(intMeta(base, mul), DTypes.I64, expected.length); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(expected.length); + LongArray longArray = (LongArray) result; for (int i = 0; i < expected.length; i++) { - LongArray longArray = (LongArray) result; 
assertThat(longArray.getLong(i)).as("index %d", i).isEqualTo(expected[i]); } } @@ -182,31 +162,21 @@ void decode_i64_generatesCorrectSequence(long base, long mul, long[] expected) { @ParameterizedTest @MethodSource("i32Sequences") void decode_i32_generatesCorrectSequence(long base, long mul, int[] expected) { - // Given - var sut = new SequenceEncoding(); DecodeContext ctx = makeCtx(intMeta(base, mul), DTypes.I32, expected.length); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(expected.length); + IntArray intArray = (IntArray) result; for (int i = 0; i < expected.length; i++) { - IntArray longArray = (IntArray) result; - assertThat(longArray.getInt(i)).as("index %d", i).isEqualTo(expected[i]); + assertThat(intArray.getInt(i)).as("index %d", i).isEqualTo(expected[i]); } } @Test void decode_f64_generatesCorrectSequence() { - // Given - var sut = new SequenceEncoding(); DecodeContext ctx = makeCtx(f64Meta(1.0, 0.5), DTypes.F64, 4); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(4); DoubleArray doubleArray = (DoubleArray) result; assertThat(doubleArray.getDouble(0)).isEqualTo(1.0); @@ -217,14 +187,9 @@ void decode_f64_generatesCorrectSequence() { @Test void decode_f32_generatesCorrectSequence() { - // Given - var sut = new SequenceEncoding(); DecodeContext ctx = makeCtx(f32Meta(0.0f, 1.0f), DTypes.F32, 3); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(3); FloatArray floatArray = (FloatArray) result; assertThat(floatArray.getFloat(0)).isEqualTo(0.0f); @@ -234,42 +199,27 @@ void decode_f32_generatesCorrectSequence() { @Test void decode_emptySequence_returnsZeroLengthArray() { - // Given - var sut = new SequenceEncoding(); DecodeContext ctx = makeCtx(intMeta(0, 1), DTypes.I64, 0); - - // When - Array result = 
sut.decode(ctx); - - // Then + Array result = DECODER.decode(ctx); assertThat(result.length()).isZero(); } @Test void decode_missingMetadata_throwsVortexException() { - // Given - var sut = new SequenceEncoding(); ArrayNode node = ArrayNode.of(EncodingId.VORTEX_SEQUENCE, null, new ArrayNode[0], new int[0], null); DecodeContext ctx = new DecodeContext(node, DTypes.I64, 3, new MemorySegment[0], Registry.empty(), Arena.ofAuto()); - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) - .isInstanceOf(VortexException.class); + assertThatThrownBy(() -> DECODER.decode(ctx)).isInstanceOf(VortexException.class); } @Test void decode_f16_generatesCorrectSequence() { - // Given — 0.0, 1.0, 2.0 as F16 bit patterns via the direct-metadata path - var sut = new SequenceEncoding(); short baseShort = Float.floatToFloat16(0.0f); short mulShort = Float.floatToFloat16(1.0f); byte[] meta = f16Meta(baseShort, mulShort); DecodeContext ctx = makeCtx(meta, DTypes.F16, 3); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(3); Float16Array f16Array = (Float16Array) result; assertThat(f16Array.getFloat(0)).isEqualTo(0.0f); @@ -279,14 +229,9 @@ void decode_f16_generatesCorrectSequence() { @Test void decode_nonPrimitiveDtype_throwsVortexException() { - // Given - var sut = new SequenceEncoding(); DType utf8 = new DType.Utf8(false); DecodeContext ctx = makeCtx(intMeta(0, 1), utf8, 3); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) - .isInstanceOf(VortexException.class); + assertThatThrownBy(() -> DECODER.decode(ctx)).isInstanceOf(VortexException.class); } } @@ -295,17 +240,11 @@ class Metadata { @Test void encode_i64_metadata_base_andMultiplier_areSet() throws Exception { - // Given — arithmetic sequence {10, 12, 14, 16} → base=10, multiplier=2 - // if tag drifts, base/multiplier messages are missing (hasBase() == false) long[] data = {10L, 12L, 14L, 16L}; - SequenceEncoding sut = new 
SequenceEncoding(); - - // When - EncodeResult result = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); MemorySegment metaSeg = MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); SequenceMetadata meta = SequenceMetadata.decode(metaSeg, 0, metaSeg.byteSize()); - // Then assertThat(meta.base()).isNotNull(); assertThat(meta.multiplier()).isNotNull(); assertThat(meta.base().int64_value()).isEqualTo(10L); diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/SparseEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/SparseEncodingEncoderTest.java similarity index 67% rename from core/src/test/java/io/github/dfa1/vortex/encoding/SparseEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/SparseEncodingEncoderTest.java index dee032b9..87a16e30 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/SparseEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/SparseEncodingEncoderTest.java @@ -1,8 +1,5 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; -import io.github.dfa1.vortex.proto.PatchesMetadata; -import io.github.dfa1.vortex.proto.SparseMetadata; -import io.github.dfa1.vortex.proto.VarBinMetadata; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.PType; @@ -10,12 +7,29 @@ import io.github.dfa1.vortex.core.array.ArraySegments; import io.github.dfa1.vortex.core.array.BoolArray; import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import 
io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.proto.NullValue; +import io.github.dfa1.vortex.proto.PatchesMetadata; +import io.github.dfa1.vortex.proto.ScalarValue; +import io.github.dfa1.vortex.proto.SparseMetadata; +import io.github.dfa1.vortex.proto.VarBinMetadata; +import io.github.dfa1.vortex.reader.decode.BoolEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.SparseEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.VarBinEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; -import io.github.dfa1.vortex.proto.NullValue; -import io.github.dfa1.vortex.proto.ScalarValue; import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; @@ -26,8 +40,11 @@ import static org.assertj.core.api.Assertions.assertThat; -class SparseEncodingTest { +class SparseEncodingEncoderTest { + private static final SparseEncodingEncoder ENCODER = new SparseEncodingEncoder(); + private static final SparseEncodingDecoder DECODER = new SparseEncodingDecoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); @Nested class Encode { @@ -45,76 +62,48 @@ private static Array decodeResult(EncodeResult encoded, DType dtype, int n) { ArrayNode sparseNode = ArrayNode.of(root.encodingId(), root.metadata(), new ArrayNode[]{idxNode, valNode}, root.bufferIndices(), ArrayStats.empty()); - Registry registry = TestRegistry.of(new SparseEncoding(), new PrimitiveEncoding()); - - DecodeContext ctx = new DecodeContext(sparseNode, dtype, n, segments, registry, Arena.global()); - return new SparseEncoding().decode(ctx); + DecodeContext ctx = 
new DecodeContext(sparseNode, dtype, n, segments, REGISTRY, Arena.global()); + return DECODER.decode(ctx); } @Test void encode_allZeros_noPatches() throws java.io.IOException { - // Given long[] data = {0L, 0L, 0L, 0L}; - SparseEncoding sut = new SparseEncoding(); - - // When - EncodeResult result = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); - - // Then - java.nio.ByteBuffer metaBuf = result.rootNode().metadata().duplicate(); - java.lang.foreign.MemorySegment metaSeg = java.lang.foreign.MemorySegment.ofBuffer(metaBuf); + EncodeResult result = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + var metaSeg = java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); SparseMetadata meta = SparseMetadata.decode(metaSeg, 0, metaSeg.byteSize()); assertThat(meta.patches().len()).isZero(); } @Test void encode_withNonZero_createsPatches() throws java.io.IOException { - // Given — [0, 10, 0, 50, 0] long[] data = {0L, 10L, 0L, 50L, 0L}; - SparseEncoding sut = new SparseEncoding(); - - // When - EncodeResult result = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); - - // Then - java.nio.ByteBuffer metaBuf = result.rootNode().metadata().duplicate(); - java.lang.foreign.MemorySegment metaSeg = java.lang.foreign.MemorySegment.ofBuffer(metaBuf); + EncodeResult result = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + var metaSeg = java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); SparseMetadata meta = SparseMetadata.decode(metaSeg, 0, metaSeg.byteSize()); assertThat(meta.patches().len()).isEqualTo(2); } @Test void encode_roundTrip_i64() { - // Given — sparse long array long[] data = {0L, 0L, 42L, 0L, 99L, 0L, 0L, 7L}; - SparseEncoding sut = new SparseEncoding(); - - // When - EncodeResult encoded = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + EncodeResult encoded = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); Array decoded = 
decodeResult(encoded, DTypes.I64, data.length); - // Then - var layout = PTypeIO.LE_LONG; for (int i = 0; i < data.length; i++) { - assertThat(ArraySegments.of(decoded).get(layout, (long) i * 8)) + assertThat(ArraySegments.of(decoded).get(PTypeIO.LE_LONG, (long) i * 8)) .as("index %d", i).isEqualTo(data[i]); } } @Test void encode_roundTrip_f64() { - // Given double[] data = {0.0, 3.14, 0.0, 0.0, 2.72}; - SparseEncoding sut = new SparseEncoding(); - - // When - EncodeResult encoded = sut.encode(DTypes.F64, data, EncodeTestHelper.testCtx()); + EncodeResult encoded = ENCODER.encode(DTypes.F64, data, EncodeTestHelper.testCtx()); Array decoded = decodeResult(encoded, DTypes.F64, data.length); - // Then - var layout = PTypeIO.LE_DOUBLE; for (int i = 0; i < data.length; i++) { - assertThat(ArraySegments.of(decoded).get(layout, (long) i * 8)) + assertThat(ArraySegments.of(decoded).get(PTypeIO.LE_DOUBLE, (long) i * 8)) .as("index %d", i).isEqualTo(data[i]); } } @@ -122,16 +111,9 @@ void encode_roundTrip_f64() { @ParameterizedTest @ValueSource(ints = {0, 1, 100}) void encode_empty_or_allZero_noPatches(int size) throws java.io.IOException { - // Given long[] data = new long[size]; - SparseEncoding sut = new SparseEncoding(); - - // When - EncodeResult result = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); - - // Then - java.nio.ByteBuffer metaBuf = result.rootNode().metadata().duplicate(); - java.lang.foreign.MemorySegment metaSeg = java.lang.foreign.MemorySegment.ofBuffer(metaBuf); + EncodeResult result = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + var metaSeg = java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); SparseMetadata meta = SparseMetadata.decode(metaSeg, 0, metaSeg.byteSize()); assertThat(meta.patches().len()).isZero(); } @@ -152,14 +134,11 @@ private static DecodeContext buildSparseCtxWithOffset( long[] patchIndices, long[] patchValues, long offset ) { byte[] fillBytes = 
ScalarValue.ofInt64Value(fillLong).encode(); - byte[] metaBytes = buildSparseMetaBytes(patchIndices.length, offset, idxPtype); - byte[] idxBuf = toLEBytes(patchIndices, idxPtype); byte[] valBuf = toLEBytes(patchValues, PType.I64); - return buildCtx(dtype, rowCount, fillBytes, metaBytes, idxBuf, valBuf, - new DType.Primitive(idxPtype, false)); + return buildCtx(dtype, rowCount, fillBytes, metaBytes, idxBuf, valBuf); } private static DecodeContext buildSparseCtxF64( @@ -168,20 +147,14 @@ private static DecodeContext buildSparseCtxF64( ) { byte[] fillBytes = ScalarValue.ofF64Value(fillDouble).encode(); byte[] metaBytes = buildSparseMetaBytes(patchIndices.length, 0L, PType.U32); - byte[] idxBuf = toLEBytes(patchIndices, PType.U32); byte[] valBuf = f64LEBytes(patchValues); - return buildCtx(dtype, rowCount, fillBytes, metaBytes, idxBuf, valBuf, - new DType.Primitive(PType.U32, false)); + return buildCtx(dtype, rowCount, fillBytes, metaBytes, idxBuf, valBuf); } - private static DecodeContext buildCtx( - DType dtype, long rowCount, - byte[] fillBytes, byte[] metaBytes, - byte[] idxBuf, byte[] valBuf, - DType idxDtype - ) { + private static DecodeContext buildCtx(DType dtype, long rowCount, + byte[] fillBytes, byte[] metaBytes, byte[] idxBuf, byte[] valBuf) { ArrayNode idxNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{1}, ArrayStats.empty()); ArrayNode valNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, @@ -189,24 +162,20 @@ private static DecodeContext buildCtx( ArrayNode sparseNode = ArrayNode.of(EncodingId.VORTEX_SPARSE, ByteBuffer.wrap(metaBytes), new ArrayNode[]{idxNode, valNode}, - new int[]{0}, - ArrayStats.empty()); + new int[]{0}, ArrayStats.empty()); MemorySegment[] segments = { MemorySegment.ofArray(fillBytes), MemorySegment.ofArray(idxBuf), MemorySegment.ofArray(valBuf) }; - - Registry registry = TestRegistry.of(new SparseEncoding(), new PrimitiveEncoding()); - - return new DecodeContext(sparseNode, dtype, rowCount, 
segments, registry, java.lang.foreign.Arena.global()); + return new DecodeContext(sparseNode, dtype, rowCount, segments, REGISTRY, java.lang.foreign.Arena.global()); } private static byte[] buildSparseMetaBytes(long numPatches, long offset, PType idxPtype) { - PatchesMetadata patchesMeta = new PatchesMetadata(numPatches, offset, io.github.dfa1.vortex.proto.PType.fromValue(idxPtype.ordinal()), null, null, null); - return new SparseMetadata(patchesMeta) - .encode(); + PatchesMetadata patchesMeta = new PatchesMetadata(numPatches, offset, + io.github.dfa1.vortex.proto.PType.fromValue(idxPtype.ordinal()), null, null, null); + return new SparseMetadata(patchesMeta).encode(); } private static byte[] toLEBytes(long[] values, PType ptype) { @@ -245,114 +214,75 @@ private static byte[] intLEBytes(int[] values) { @Test void decode_noPatches_returnsFillValue() { - // Given — 5 elements, fill=99, no patches long fill = 99L; DecodeContext ctx = buildSparseCtx(DTypes.I64, 5, fill, PType.U32, new long[0], new long[0]); - SparseEncoding sut = new SparseEncoding(); - - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(5L); - var layout = PTypeIO.LE_LONG; for (int i = 0; i < 5; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 8)) + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)) .as("index %d", i).isEqualTo(fill); } } @Test void decode_withPatches_overwritesAtIndices() { - // Given — 8 elements, fill=0, patches at indices [1, 5] with values [10, 50] long fill = 0L; long[] patchIndices = {1L, 5L}; long[] patchValues = {10L, 50L}; DecodeContext ctx = buildSparseCtx(DTypes.I64, 8, fill, PType.U32, patchIndices, patchValues); - SparseEncoding sut = new SparseEncoding(); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then - var layout = PTypeIO.LE_LONG; long[] expected = {0, 10, 0, 0, 0, 50, 0, 0}; for (int i = 0; i < 
expected.length; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 8)) + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)) .as("index %d", i).isEqualTo(expected[i]); } } @Test void decode_f64_fillAndPatches() { - // Given — 4 F64 elements, fill=NaN bits, patch at index 2 with value 3.14 double fillVal = Double.NaN; double patchVal = 3.14; DecodeContext ctx = buildSparseCtxF64(DTypes.F64, 4, fillVal, new long[]{2L}, new double[]{patchVal}); - SparseEncoding sut = new SparseEncoding(); - - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then - var layout = PTypeIO.LE_DOUBLE; - assertThat(ArraySegments.of(result).get(layout, 0L)).isNaN(); - assertThat(ArraySegments.of(result).get(layout, 8L)).isNaN(); - assertThat(ArraySegments.of(result).get(layout, 16L)).isEqualTo(3.14); - assertThat(ArraySegments.of(result).get(layout, 24L)).isNaN(); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, 0L)).isNaN(); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, 8L)).isNaN(); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, 16L)).isEqualTo(3.14); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_DOUBLE, 24L)).isNaN(); } @Test void decode_offsetSubtracted() { - // Given — offset=10, patch index=12 → absolute position = 12 - 10 = 2 long[] patchIndices = {12L}; long[] patchValues = {777L}; DecodeContext ctx = buildSparseCtxWithOffset(DTypes.I64, 5, 0L, PType.U32, patchIndices, patchValues, 10L); - SparseEncoding sut = new SparseEncoding(); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then - var layout = PTypeIO.LE_LONG; - assertThat(ArraySegments.of(result).get(layout, 16L)).isEqualTo(777L); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, 16L)).isEqualTo(777L); } - // regression: NULL_VALUE fill caused "unexpected scalar kind NULL_VALUE" on nullable cols @Test void decode_nullValueFill_treatedAsZero() { - // Given — 
fill encoded as ScalarValue.NULL_VALUE (as Rust writes for nullable cols) byte[] nullFill = ScalarValue.ofNullValue(NullValue.NULL_VALUE).encode(); byte[] meta = buildSparseMetaBytes(0, 0L, PType.U32); - DecodeContext ctx = buildCtx(DTypes.I64, 4, nullFill, meta, new byte[0], new byte[0], - new DType.Primitive(PType.U32, false)); - SparseEncoding sut = new SparseEncoding(); - - // When - Array result = sut.decode(ctx); + DecodeContext ctx = buildCtx(DTypes.I64, 4, nullFill, meta, new byte[0], new byte[0]); + Array result = DECODER.decode(ctx); - // Then - var layout = PTypeIO.LE_LONG; for (int i = 0; i < 4; i++) { - assertThat(ArraySegments.of(result).get(layout, (long) i * 8)).as("index %d", i).isZero(); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).as("index %d", i).isZero(); } } - // regression: Utf8 dtype caused "expected primitive dtype, got Utf8[nullable=true]" @Test void decode_utf8_noPatches_allEmpty() { - // Given — Utf8 sparse, no patches → all positions empty (null fill) DType utf8 = new DType.Utf8(true); byte[] nullFill = ScalarValue.ofNullValue(NullValue.NULL_VALUE).encode(); byte[] meta = buildSparseMetaBytes(0, 0L, PType.U32); - DecodeContext ctx = buildCtx(utf8, 3, nullFill, meta, new byte[0], new byte[0], - new DType.Primitive(PType.U32, false)); - SparseEncoding sut = new SparseEncoding(); - - // When - Array result = sut.decode(ctx); + DecodeContext ctx = buildCtx(utf8, 3, nullFill, meta, new byte[0], new byte[0]); + Array result = DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(3L); VarBinArray varBin = (VarBinArray) result; for (int i = 0; i < 3; i++) { @@ -360,10 +290,8 @@ void decode_utf8_noPatches_allEmpty() { } } - // regression: Utf8 dtype caused "expected primitive dtype, got Utf8[nullable=true]" @Test void decode_utf8_withPatches_writesStringsAtIndices() { - // Given — 5 Utf8 elements, patches at [1]="hi" and [3]="bye" DType utf8 = new DType.Utf8(true); byte[] nullFill = 
ScalarValue.ofNullValue(NullValue.NULL_VALUE).encode(); byte[] meta = buildSparseMetaBytes(2, 0L, PType.U32); @@ -384,7 +312,7 @@ void decode_utf8_withPatches_writesStringsAtIndices() { ByteBuffer.wrap(meta), new ArrayNode[]{idxNode, valNode}, new int[]{0}, ArrayStats.empty()); - Registry registry = TestRegistry.of(new SparseEncoding(), new PrimitiveEncoding(), new VarBinEncoding()); + Registry registry = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder(), new VarBinEncodingDecoder()); MemorySegment[] segments = { MemorySegment.ofArray(nullFill), @@ -393,12 +321,9 @@ void decode_utf8_withPatches_writesStringsAtIndices() { MemorySegment.ofArray(offsets), }; DecodeContext ctx = new DecodeContext(sparseNode, utf8, 5, segments, registry, Arena.global()); - SparseEncoding sut = new SparseEncoding(); - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then VarBinArray varBin = (VarBinArray) result; assertThat(varBin.length()).isEqualTo(5L); assertThat(varBin.getByteLength(0)).isZero(); @@ -408,10 +333,8 @@ void decode_utf8_withPatches_writesStringsAtIndices() { assertThat(varBin.getByteLength(4)).isZero(); } - // regression: Bool dtype caused "expected primitive dtype, got Bool[nullable=true]" @Test void decode_bool_withPatches_setsBitsAtIndices() { - // Given — 6 Bool elements, patches at [2]=true and [5]=true DType bool = new DType.Bool(true); byte[] nullFill = ScalarValue.ofNullValue(NullValue.NULL_VALUE).encode(); byte[] meta = buildSparseMetaBytes(2, 0L, PType.U32); @@ -426,7 +349,7 @@ void decode_bool_withPatches_setsBitsAtIndices() { ByteBuffer.wrap(meta), new ArrayNode[]{idxNode, valNode}, new int[]{0}, ArrayStats.empty()); - Registry registry = TestRegistry.of(new SparseEncoding(), new PrimitiveEncoding(), new BoolEncoding()); + Registry registry = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder(), new BoolEncodingDecoder()); MemorySegment[] segments = { MemorySegment.ofArray(nullFill), @@ -434,12 
+357,9 @@ void decode_bool_withPatches_setsBitsAtIndices() { MemorySegment.ofArray(boolBits), }; DecodeContext ctx = new DecodeContext(sparseNode, bool, 6, segments, registry, Arena.global()); - SparseEncoding sut = new SparseEncoding(); - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then BoolArray boolArr = (BoolArray) result; assertThat(boolArr.length()).isEqualTo(6L); assertThat(boolArr.getBoolean(0)).isFalse(); diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/StructEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/StructEncodingEncoderTest.java similarity index 65% rename from core/src/test/java/io/github/dfa1/vortex/encoding/StructEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/StructEncodingEncoderTest.java index 777445f5..e88662d4 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/StructEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/StructEncodingEncoderTest.java @@ -1,4 +1,4 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.ArrayStats; import io.github.dfa1.vortex.core.DType; @@ -7,6 +7,21 @@ import io.github.dfa1.vortex.core.array.LongArray; import io.github.dfa1.vortex.core.array.MaskedArray; import io.github.dfa1.vortex.core.array.StructArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.StructData; +import io.github.dfa1.vortex.encoding.TestRegistry; +import 
io.github.dfa1.vortex.encoding.TestSegments; +import io.github.dfa1.vortex.reader.decode.BoolEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.StructEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -16,7 +31,11 @@ import static org.assertj.core.api.Assertions.assertThat; -class StructEncodingTest { +class StructEncodingEncoderTest { + + private static final StructEncodingEncoder ENCODER = new StructEncodingEncoder(); + private static final StructEncodingDecoder DECODER = new StructEncodingDecoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); private static ArrayNode toArrayNode(EncodeNode node) { ArrayNode[] children = new ArrayNode[node.children().length]; @@ -31,35 +50,25 @@ class Encode { @Test void accepts_structDtype_trueForStruct_falseForPrimitive() { - // Given - StructEncoding sut = new StructEncoding(); - DType.Struct structDtype = new DType.Struct( - List.of("x"), List.of(DTypes.I64), false); - - // When / Then - assertThat(sut.accepts(structDtype)).isTrue(); - assertThat(sut.accepts(DTypes.I64)).isFalse(); + DType.Struct structDtype = new DType.Struct(List.of("x"), List.of(DTypes.I64), false); + assertThat(ENCODER.accepts(structDtype)).isTrue(); + assertThat(ENCODER.accepts(DTypes.I64)).isFalse(); } @Test void roundTrip_twoI64Fields_preservesValues() { - // Given long[] ids = {1L, 2L, 3L}; long[] values = {10L, 20L, 30L}; DType.Struct dtype = new DType.Struct( List.of("id", "value"), List.of(DTypes.I64, DTypes.I64), false); StructData data = new StructData(List.of(ids, values)); - StructEncoding sut = new StructEncoding(); - // When - EncodeResult result = sut.encode(dtype, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(dtype, data, EncodeTestHelper.testCtx()); - // Then — decode round-trip MemorySegment[] bufs = 
result.buffers().toArray(MemorySegment[]::new); - Registry registry = TestRegistry.of(new StructEncoding(), new PrimitiveEncoding()); DecodeContext ctx = new DecodeContext( - toArrayNode(result.rootNode()), dtype, ids.length, bufs, registry, Arena.global()); - StructArray decoded = (StructArray) sut.decode(ctx); + toArrayNode(result.rootNode()), dtype, ids.length, bufs, REGISTRY, Arena.global()); + StructArray decoded = (StructArray) DECODER.decode(ctx); assertThat(decoded.length()).isEqualTo(ids.length); assertThat(decoded.fieldCount()).isEqualTo(2); @@ -73,32 +82,25 @@ void roundTrip_twoI64Fields_preservesValues() { @Test void singleField_encodeResult_hasOneChildAndNoBuffers() { - // Given long[] data = {7L, 14L, 21L}; DType.Struct dtype = new DType.Struct(List.of("v"), List.of(DTypes.I64), false); - StructEncoding sut = new StructEncoding(); - // When - EncodeResult result = sut.encode(dtype, new StructData(List.of(data)), EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(dtype, new StructData(List.of(data)), EncodeTestHelper.testCtx()); - // Then — struct node wraps one field child with remapped buffers assertThat(result.rootNode().encodingId()).isEqualTo(EncodingId.VORTEX_STRUCT); assertThat(result.rootNode().children()).hasSize(1); assertThat(result.rootNode().bufferIndices()).isEmpty(); - assertThat(result.buffers()).hasSize(1); // one buffer for the DTypes.I64 field + assertThat(result.buffers()).hasSize(1); } @Test void fieldCountMismatch_throwsVortexException() { - // Given DType.Struct dtype = new DType.Struct(List.of("a", "b"), List.of(DTypes.I64, DTypes.I64), false); - StructData data = new StructData(List.of(new long[]{1L})); // only 1 field, dtype has 2 - StructEncoding sut = new StructEncoding(); + StructData data = new StructData(List.of(new long[]{1L})); - // When / Then org.junit.jupiter.api.Assertions.assertThrows( io.github.dfa1.vortex.core.VortexException.class, - () -> sut.encode(dtype, data, EncodeTestHelper.testCtx())); + 
() -> ENCODER.encode(dtype, data, EncodeTestHelper.testCtx())); } } @@ -115,27 +117,18 @@ private static ArrayNode boolNode(int bufferIdx) { new int[]{bufferIdx}, ArrayStats.empty()); } - private static DecodeContext buildStructCtx(ArrayNode structNode, MemorySegment[] segs, long rowCount) { - Registry registry = TestRegistry.of(new StructEncoding(), new PrimitiveEncoding()); - return new DecodeContext(structNode, DTypes.I64, rowCount, segs, registry, Arena.global()); - } - @Test void decode_nonNullableWrapper_oneChild_returnsValues() { - // Given — struct{values: DTypes.I64} (non-nullable, 1 child) long[] data = {10L, 20L, 30L}; MemorySegment seg = TestSegments.leLongs(data); ArrayNode valuesNode = primitiveNode(0); ArrayNode structNode = ArrayNode.of(EncodingId.VORTEX_STRUCT, null, new ArrayNode[]{valuesNode}, new int[0], ArrayStats.empty()); - DecodeContext ctx = buildStructCtx(structNode, new MemorySegment[]{seg}, data.length); - StructEncoding sut = new StructEncoding(); + DecodeContext ctx = new DecodeContext(structNode, DTypes.I64, data.length, + new MemorySegment[]{seg}, REGISTRY, Arena.global()); + Array result = DECODER.decode(ctx); - // When - Array result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(data.length); for (int i = 0; i < data.length; i++) { assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).isEqualTo(data[i]); @@ -144,28 +137,23 @@ void decode_nonNullableWrapper_oneChild_returnsValues() { @Test void decode_nullableWrapper_twoChildren_returnsMaskedArray() { - // Given — struct{validity: Bool, values: DTypes.I64} (nullable, 2 children) long[] data = {7L, 14L, 21L}; - MemorySegment validitySeg = MemorySegment.ofArray(new byte[]{(byte) 0xFF}); // all valid + MemorySegment validitySeg = MemorySegment.ofArray(new byte[]{(byte) 0xFF}); MemorySegment valuesSeg = TestSegments.leLongs(data); - ArrayNode validityNode = boolNode(0); // slot 0 = validity bitmap - ArrayNode valuesNode = primitiveNode(1); // 
slot 1 = actual values + ArrayNode validityNode = boolNode(0); + ArrayNode valuesNode = primitiveNode(1); ArrayNode structNode = ArrayNode.of(EncodingId.VORTEX_STRUCT, null, new ArrayNode[]{validityNode, valuesNode}, new int[0], ArrayStats.empty()); - Registry registry = TestRegistry.of(new StructEncoding(), new PrimitiveEncoding(), new BoolEncoding()); + Registry registry = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder(), new BoolEncodingDecoder()); DecodeContext ctx = new DecodeContext( structNode, DTypes.I64, data.length, new MemorySegment[]{validitySeg, valuesSeg}, registry, Arena.global()); - StructEncoding sut = new StructEncoding(); - - // When - Array result = sut.decode(ctx); + Array result = DECODER.decode(ctx); - // Then — validity preserved; values accessible via inner array assertThat(result).isInstanceOf(MaskedArray.class); MaskedArray masked = (MaskedArray) result; LongArray values = (LongArray) masked.inner(); diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/VarBinEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/VarBinEncodingEncoderTest.java similarity index 52% rename from core/src/test/java/io/github/dfa1/vortex/encoding/VarBinEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/VarBinEncodingEncoderTest.java index ab6952b7..94fcf410 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/VarBinEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/VarBinEncodingEncoderTest.java @@ -1,10 +1,20 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; -import io.github.dfa1.vortex.proto.VarBinMetadata; import io.github.dfa1.vortex.core.DType; import io.github.dfa1.vortex.core.PType; import io.github.dfa1.vortex.core.VortexException; import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import 
io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.proto.VarBinMetadata; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.VarBinEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -15,70 +25,53 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -class VarBinEncodingTest { +class VarBinEncodingEncoderTest { + + private static final VarBinEncodingEncoder ENCODER = new VarBinEncodingEncoder(); + private static final VarBinEncodingDecoder DECODER = new VarBinEncodingDecoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); @Nested class Encode { - private static Registry buildRegistry() { - return Registry.builder() - .register(new VarBinEncoding()) - .register(new PrimitiveEncoding()) - .build(); - } - @Test void encodingId_isVortexVarbin() { - // Given / When / Then - assertThat(new VarBinEncoding().encodingId()).isEqualTo(EncodingId.VORTEX_VARBIN); + assertThat(ENCODER.encodingId()).isEqualTo(EncodingId.VORTEX_VARBIN); } @Test void accepts_utf8Dtype_returnsTrue() { - // Given / When / Then - assertThat(new VarBinEncoding().accepts(DTypes.UTF8)).isTrue(); + assertThat(ENCODER.accepts(DTypes.UTF8)).isTrue(); } @Test void accepts_binaryDtype_returnsTrue() { - // Given / When / Then - assertThat(new VarBinEncoding().accepts(DTypes.BINARY)).isTrue(); + assertThat(ENCODER.accepts(DTypes.BINARY)).isTrue(); } @Test void accepts_primitiveDtype_returnsFalse() { - // Given / When / Then - assertThat(new VarBinEncoding().accepts(new 
DType.Primitive(PType.I64, false))).isFalse(); + assertThat(ENCODER.accepts(new DType.Primitive(PType.I64, false))).isFalse(); } @Test void encode_singleString_roundTrips() { - // Given - var sut = new VarBinEncoding(); String[] data = {"hello"}; + EncodeResult result = ENCODER.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(result, data.length, DTypes.UTF8, REGISTRY); + VarBinArray decoded = (VarBinArray) DECODER.decode(ctx); - // When - EncodeResult result = sut.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(result, data.length, DTypes.UTF8, buildRegistry()); - VarBinArray decoded = (VarBinArray) sut.decode(ctx); - - // Then assertThat(decoded.length()).isEqualTo(1); assertThat(decoded.getBytes(0)).isEqualTo("hello".getBytes(StandardCharsets.UTF_8)); } @Test void encode_multipleStrings_roundTrips() { - // Given - var sut = new VarBinEncoding(); String[] data = {"foo", "bar", "baz"}; + EncodeResult result = ENCODER.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(result, data.length, DTypes.UTF8, REGISTRY); + VarBinArray decoded = (VarBinArray) DECODER.decode(ctx); - // When - EncodeResult result = sut.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(result, data.length, DTypes.UTF8, buildRegistry()); - VarBinArray decoded = (VarBinArray) sut.decode(ctx); - - // Then assertThat(decoded.length()).isEqualTo(3); for (int i = 0; i < data.length; i++) { assertThat(decoded.getBytes(i)).isEqualTo(data[i].getBytes(StandardCharsets.UTF_8)); @@ -87,16 +80,11 @@ void encode_multipleStrings_roundTrips() { @Test void encode_unicodeString_roundTrips() { - // Given - var sut = new VarBinEncoding(); String[] data = {"héllo", "wörld", "日本語"}; + EncodeResult result = ENCODER.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + 
DecodeContext ctx = EncodeTestHelper.toDecodeContext(result, data.length, DTypes.UTF8, REGISTRY); + VarBinArray decoded = (VarBinArray) DECODER.decode(ctx); - // When - EncodeResult result = sut.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(result, data.length, DTypes.UTF8, buildRegistry()); - VarBinArray decoded = (VarBinArray) sut.decode(ctx); - - // Then assertThat(decoded.length()).isEqualTo(3); for (int i = 0; i < data.length; i++) { assertThat(decoded.getBytes(i)).isEqualTo(data[i].getBytes(StandardCharsets.UTF_8)); @@ -105,16 +93,11 @@ void encode_unicodeString_roundTrips() { @Test void encode_emptyStringInArray_roundTrips() { - // Given - var sut = new VarBinEncoding(); String[] data = {"a", "", "b"}; + EncodeResult result = ENCODER.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(result, data.length, DTypes.UTF8, REGISTRY); + VarBinArray decoded = (VarBinArray) DECODER.decode(ctx); - // When - EncodeResult result = sut.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(result, data.length, DTypes.UTF8, buildRegistry()); - VarBinArray decoded = (VarBinArray) sut.decode(ctx); - - // Then assertThat(decoded.length()).isEqualTo(3); assertThat(decoded.getBytes(0)).isEqualTo(new byte[]{'a'}); assertThat(decoded.getBytes(1)).isEmpty(); @@ -123,16 +106,11 @@ void encode_emptyStringInArray_roundTrips() { @Test void encode_emptyArray_producesZeroLengthResult() { - // Given - var sut = new VarBinEncoding(); String[] data = {}; + EncodeResult result = ENCODER.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(result, data.length, DTypes.UTF8, REGISTRY); + VarBinArray decoded = (VarBinArray) DECODER.decode(ctx); - // When - EncodeResult result = sut.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = 
EncodeTestHelper.toDecodeContext(result, data.length, DTypes.UTF8, buildRegistry()); - VarBinArray decoded = (VarBinArray) sut.decode(ctx); - - // Then assertThat(decoded.length()).isZero(); } } @@ -142,14 +120,10 @@ class Decode { @Test void decode_missingMetadata_throwsVortexException() { - // Given - var sut = new VarBinEncoding(); ArrayNode node = ArrayNode.of(EncodingId.VORTEX_VARBIN, null, new ArrayNode[0], new int[0], null); DecodeContext ctx = new DecodeContext(node, DTypes.UTF8, 3, new MemorySegment[0], Registry.empty(), Arena.ofAuto()); - - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> DECODER.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("missing metadata"); } @@ -160,18 +134,12 @@ class Metadata { @Test void encode_utf8_metadata_offsetsPtype_isI64() throws Exception { - // Given — VarBinEncoding always uses I64 (ordinal=7) for offsets - // if tag drifts, offsets_ptype reads as 0 (U8) and decode fails or produces garbage String[] data = {"hello", "world"}; - var sut = new VarBinEncoding(); - - // When - EncodeResult result = sut.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); - VarBinMetadata meta = - VarBinMetadata.decode(java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()), 0, java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()).byteSize()); + EncodeResult result = ENCODER.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + var metaSeg = java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); + VarBinMetadata meta = VarBinMetadata.decode(metaSeg, 0, metaSeg.byteSize()); - // Then - assertThat(meta.offsets_ptype().value()).isEqualTo(7); // I64 + assertThat(meta.offsets_ptype().value()).isEqualTo(7); } } } diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/VarBinViewEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/VarBinViewEncodingEncoderTest.java 
similarity index 71% rename from core/src/test/java/io/github/dfa1/vortex/encoding/VarBinViewEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/VarBinViewEncodingEncoderTest.java index 457070a0..15ca985d 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/VarBinViewEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/VarBinViewEncodingEncoderTest.java @@ -1,6 +1,16 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.reader.decode.VarBinViewEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -14,58 +24,43 @@ import static org.assertj.core.api.Assertions.assertThat; -/// Property: decode reconstructs every string value exactly, -/// regardless of inlined vs reference layout. 
-class VarBinViewEncodingTest { +class VarBinViewEncodingEncoderTest { + + private static final VarBinViewEncodingEncoder ENCODER = new VarBinViewEncodingEncoder(); + private static final VarBinViewEncodingDecoder DECODER = new VarBinViewEncodingDecoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(DECODER); @Nested class Encode { @Test void accepts_utf8_true() { - // Given - var sut = new VarBinViewEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.UTF8)).isTrue(); + assertThat(ENCODER.accepts(DTypes.UTF8)).isTrue(); } @Test void accepts_binary_true() { - // Given - var sut = new VarBinViewEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.BINARY)).isTrue(); + assertThat(ENCODER.accepts(DTypes.BINARY)).isTrue(); } @Test void accepts_primitive_false() { - // Given - var sut = new VarBinViewEncoding(); - - // When / Then - assertThat(sut.accepts(DTypes.I32)).isFalse(); + assertThat(ENCODER.accepts(DTypes.I32)).isFalse(); } @ParameterizedTest(name = "{0}") - @MethodSource("io.github.dfa1.vortex.encoding.VarBinViewEncodingTest$Decode#stringArrays") + @MethodSource("io.github.dfa1.vortex.writer.encode.VarBinViewEncodingEncoderTest$Decode#stringArrays") void encode_thenDecode_roundtripsAllStrings(String name, String[] values) { - // Given - var sut = new VarBinViewEncoding(); Arena arena = Arena.ofAuto(); - // When - EncodeResult result = sut.encode(DTypes.UTF8, values, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.UTF8, values, EncodeTestHelper.testCtx()); MemorySegment[] bufs = result.buffers().toArray(MemorySegment[]::new); ArrayNode node = ArrayNode.of( EncodingId.VORTEX_VARBINVIEW, null, new ArrayNode[0], result.rootNode().bufferIndices(), null); - Registry registry = TestRegistry.of(sut); - DecodeContext ctx = new DecodeContext(node, DTypes.UTF8, values.length, bufs, registry, arena); - var decoded = (VarBinArray) sut.decode(ctx); + DecodeContext ctx = new DecodeContext(node, DTypes.UTF8, 
values.length, bufs, REGISTRY, arena); + var decoded = (VarBinArray) DECODER.decode(ctx); - // Then assertThat(decoded.length()).isEqualTo(values.length); for (int i = 0; i < values.length; i++) { assertThat(decoded.getString(i)).as("index %d", i).isEqualTo(values[i]); @@ -81,8 +76,8 @@ static Stream<Arguments> stringArrays() { Arguments.of("empty-array", new String[0]), Arguments.of("single-empty-string", new String[]{""}), Arguments.of("short-strings", new String[]{"hi", "ok", "no"}), - Arguments.of("exactly-12-bytes", new String[]{"123456789012"}), // max inlined - Arguments.of("just-over-12-bytes", new String[]{"1234567890123"}), // min reference + Arguments.of("exactly-12-bytes", new String[]{"123456789012"}), + Arguments.of("just-over-12-bytes", new String[]{"1234567890123"}), Arguments.of("long-strings", new String[]{"the quick brown fox jumps over the lazy dog"}), Arguments.of("mixed-lengths", new String[]{"a", "hello", "this is a longer string than twelve"}), Arguments.of("repeated-short", repeat("ab", 50)), @@ -99,13 +94,11 @@ private static String[] repeat(String s, int n) { } @ParameterizedTest(name = "{0}") - @MethodSource("io.github.dfa1.vortex.encoding.VarBinViewEncodingTest$Decode#stringArrays") + @MethodSource("stringArrays") void decode_roundtrip_returnsAllStrings(String name, String[] values) { - // Given Arena arena = Arena.ofAuto(); long n = values.length; - // Encode all long strings into one data buffer byte[][] bytesArr = new byte[values.length][]; int dataBufLen = 0; for (int i = 0; i < values.length; i++) { @@ -124,18 +117,15 @@ void decode_roundtrip_returnsAllStrings(String name, String[] values) { long viewOff = (long) i * 16; views.set(PTypeIO.LE_INT, viewOff, b.length); if (b.length <= 12) { - // inlined: data at viewOff+4 MemorySegment.copy(MemorySegment.ofArray(b), 0, views, viewOff + 4, b.length); } else { - // reference: buffer_index=0, offset=dataOffset - views.set(PTypeIO.LE_INT, viewOff + 8, 0); // buffer_index - views.set(PTypeIO.LE_INT,
viewOff + 12, dataOffset); // offset + views.set(PTypeIO.LE_INT, viewOff + 8, 0); + views.set(PTypeIO.LE_INT, viewOff + 12, dataOffset); MemorySegment.copy(MemorySegment.ofArray(b), 0, dataBuf, dataOffset, b.length); dataOffset += b.length; } } - // bufferIndices: [0=dataBuf, 1=views] when data buffer needed, else just [0=views] int[] bufIndices; MemorySegment[] segBufs; if (dataBufLen > 0) { @@ -149,23 +139,17 @@ void decode_roundtrip_returnsAllStrings(String name, String[] values) { ArrayNode node = ArrayNode.of(EncodingId.VORTEX_VARBINVIEW, null, new ArrayNode[0], bufIndices, null); - Registry registry = TestRegistry.of(new VarBinViewEncoding()); - - DecodeContext ctx = new DecodeContext(node, DTypes.UTF8, n, segBufs, registry, arena); - var sut = new VarBinViewEncoding(); + DecodeContext ctx = new DecodeContext(node, DTypes.UTF8, n, segBufs, REGISTRY, arena); - // When - var result = sut.decode(ctx); + var result = DECODER.decode(ctx); - // Then assertThat(result).isInstanceOf(VarBinArray.class); assertThat(result.length()).isEqualTo(n); + VarBinArray varBinArray = (VarBinArray) result; for (int i = 0; i < values.length; i++) { - VarBinArray varBinArray = (VarBinArray) result; String decoded = varBinArray.getString(i); assertThat(decoded).as("index %d", i).isEqualTo(values[i]); } } } - } diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/ZigZagEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ZigZagEncodingEncoderTest.java similarity index 69% rename from core/src/test/java/io/github/dfa1/vortex/encoding/ZigZagEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/ZigZagEncodingEncoderTest.java index eec143e7..c01c5e5c 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/ZigZagEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ZigZagEncodingEncoderTest.java @@ -1,8 +1,19 @@ -package io.github.dfa1.vortex.encoding; +package 
io.github.dfa1.vortex.writer.encode; import io.github.dfa1.vortex.core.array.Array; import io.github.dfa1.vortex.core.array.ArraySegments; import io.github.dfa1.vortex.core.array.IntArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.PTypeIO; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.reader.decode.PrimitiveEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.ZigZagEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -17,15 +28,17 @@ import static org.assertj.core.api.Assertions.assertThat; -class ZigZagEncodingTest { +class ZigZagEncodingEncoderTest { + private static final ZigZagEncodingEncoder ENCODER = new ZigZagEncodingEncoder(); + private static final ZigZagEncodingDecoder DECODER = new ZigZagEncodingDecoder(); + private static final Registry REGISTRY = TestRegistry.ofDecoders(DECODER, new PrimitiveEncodingDecoder()); @Nested class Decode { static Stream<Arguments> i32Cases() { return Stream.of( - // zigzag: 0→0, 1→-1, 2→1, 3→-2, 4→2 Arguments.of("zeros", new int[]{0, 0, 0}, new int[]{0, 0, 0}), Arguments.of("mixed", new int[]{0, 1, 2, 3, 4}, new int[]{0, -1, 1, -2, 2}), Arguments.of("large", new int[]{Integer.MAX_VALUE & ~1, (Integer.MAX_VALUE & ~1) | 1}, @@ -44,22 +57,16 @@ private static DecodeContext buildI32Ctx(int[] encodedUnsigned) { ArrayNode primitiveNode = ArrayNode.of(EncodingId.VORTEX_PRIMITIVE, null, new ArrayNode[0], new int[]{0}, null); ArrayNode zigzagNode = ArrayNode.of(EncodingId.VORTEX_ZIGZAG, null, new ArrayNode[]{primitiveNode}, new int[0], null); - Registry
registry = TestRegistry.of(new ZigZagEncoding(), new PrimitiveEncoding()); return new DecodeContext(zigzagNode, DTypes.I32, encodedUnsigned.length, - new MemorySegment[]{seg}, registry, Arena.ofAuto()); + new MemorySegment[]{seg}, REGISTRY, Arena.ofAuto()); } @ParameterizedTest(name = "{0}") @MethodSource("i32Cases") void decode_i32_zigzagDecodesCorrectly(String name, int[] encoded, int[] expected) { - // Given DecodeContext ctx = buildI32Ctx(encoded); - var sut = new ZigZagEncoding(); + var result = DECODER.decode(ctx); - // When - var result = sut.decode(ctx); - - // Then assertThat(result).isInstanceOf(IntArray.class); assertThat(result.length()).isEqualTo(expected.length); MemorySegment seg = ArraySegments.of(result); @@ -71,14 +78,8 @@ void decode_i32_zigzagDecodesCorrectly(String name, int[] encoded, int[] expecte @Test void decode_empty_returnsEmptyArray() { - // Given DecodeContext ctx = buildI32Ctx(new int[]{}); - var sut = new ZigZagEncoding(); - - // When - var result = sut.decode(ctx); - - // Then + var result = DECODER.decode(ctx); assertThat(result.length()).isZero(); } } @@ -108,40 +109,26 @@ static Stream i64RoundtripArrays() { @ParameterizedTest @MethodSource("i32RoundtripArrays") void encodeDecode_i32_isLossless(int[] data) { - // Given - var sut = new ZigZagEncoding(); - Registry registry = TestRegistry.withPrimitive(sut); - var le = PTypeIO.LE_INT; + EncodeResult encoded = ENCODER.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I32, REGISTRY); + Array result = DECODER.decode(ctx); - // When - EncodeResult encoded = sut.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I32, registry); - Array result = sut.decode(ctx); - - // Then assertThat(result.length()).isEqualTo(data.length); for (int i = 0; i < data.length; i++) { - assertThat(ArraySegments.of(result).get(le, (long) 
i * 4)).as("index %d", i).isEqualTo(data[i]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_INT, (long) i * 4)).as("index %d", i).isEqualTo(data[i]); } } @ParameterizedTest @MethodSource("i64RoundtripArrays") void encodeDecode_i64_isLossless(long[] data) { - // Given - var sut = new ZigZagEncoding(); - Registry registry = TestRegistry.withPrimitive(sut); - var le = PTypeIO.LE_LONG; - - // When - EncodeResult encoded = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); - DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I64, registry); - Array result = sut.decode(ctx); + EncodeResult encoded = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = EncodeTestHelper.toDecodeContext(encoded, data.length, DTypes.I64, REGISTRY); + Array result = DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(data.length); for (int i = 0; i < data.length; i++) { - assertThat(ArraySegments.of(result).get(le, (long) i * 8)).as("index %d", i).isEqualTo(data[i]); + assertThat(ArraySegments.of(result).get(PTypeIO.LE_LONG, (long) i * 8)).as("index %d", i).isEqualTo(data[i]); } } } diff --git a/core/src/test/java/io/github/dfa1/vortex/encoding/ZstdEncodingTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoderTest.java similarity index 73% rename from core/src/test/java/io/github/dfa1/vortex/encoding/ZstdEncodingTest.java rename to writer/src/test/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoderTest.java index b1b0eaff..f943632c 100644 --- a/core/src/test/java/io/github/dfa1/vortex/encoding/ZstdEncodingTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoderTest.java @@ -1,7 +1,5 @@ -package io.github.dfa1.vortex.encoding; +package io.github.dfa1.vortex.writer.encode; -import io.github.dfa1.vortex.proto.ZstdFrameMetadata; -import io.github.dfa1.vortex.proto.ZstdMetadata; import com.github.luben.zstd.ZstdCompressCtx; 
import io.airlift.compress.v3.zstd.ZstdCompressor; import io.airlift.compress.v3.zstd.ZstdJavaCompressor; @@ -12,6 +10,19 @@ import io.github.dfa1.vortex.core.array.LongArray; import io.github.dfa1.vortex.core.array.MaskedArray; import io.github.dfa1.vortex.core.array.VarBinArray; +import io.github.dfa1.vortex.encoding.ArrayNode; +import io.github.dfa1.vortex.encoding.DTypes; +import io.github.dfa1.vortex.encoding.DecodeContext; +import io.github.dfa1.vortex.encoding.EncodeNode; +import io.github.dfa1.vortex.encoding.EncodeResult; +import io.github.dfa1.vortex.encoding.EncodeTestHelper; +import io.github.dfa1.vortex.encoding.EncodingId; +import io.github.dfa1.vortex.encoding.Registry; +import io.github.dfa1.vortex.encoding.TestRegistry; +import io.github.dfa1.vortex.proto.ZstdFrameMetadata; +import io.github.dfa1.vortex.proto.ZstdMetadata; +import io.github.dfa1.vortex.reader.decode.BoolEncodingDecoder; +import io.github.dfa1.vortex.reader.decode.ZstdEncodingDecoder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -27,24 +38,21 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -class ZstdEncodingTest { +class ZstdEncodingEncoderTest { + + private static final ZstdEncodingEncoder ENCODER = new ZstdEncodingEncoder(); + private static final ZstdEncodingDecoder DECODER = new ZstdEncodingDecoder(); + private static final BoolEncodingEncoder BOOL_ENCODER = new BoolEncodingEncoder(); @Nested class Encode { @Test void encode_i32_roundTrips() { - // Given - var sut = new ZstdEncoding(); int[] data = {10, 20, 30, 40}; - - // When - EncodeResult result = sut.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); DecodeContext ctx = EncodeTestHelper.toDecodeContext(result, data.length, DTypes.I32, Registry.empty()); - IntArray decoded = (IntArray) sut.decode(ctx); - - // Then - 
assertThat(decoded.length()).isEqualTo(data.length); + IntArray decoded = (IntArray) DECODER.decode(ctx); for (int i = 0; i < data.length; i++) { assertThat(decoded.getInt(i)).as("index %d", i).isEqualTo(data[i]); } @@ -52,17 +60,10 @@ void encode_i32_roundTrips() { @Test void encode_i64_roundTrips() { - // Given - var sut = new ZstdEncoding(); long[] data = {100L, 200L, 300L}; - - // When - EncodeResult result = sut.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.I64, data, EncodeTestHelper.testCtx()); DecodeContext ctx = EncodeTestHelper.toDecodeContext(result, data.length, DTypes.I64, Registry.empty()); - LongArray decoded = (LongArray) sut.decode(ctx); - - // Then - assertThat(decoded.length()).isEqualTo(data.length); + LongArray decoded = (LongArray) DECODER.decode(ctx); for (int i = 0; i < data.length; i++) { assertThat(decoded.getLong(i)).as("index %d", i).isEqualTo(data[i]); } @@ -70,17 +71,10 @@ void encode_i64_roundTrips() { @Test void encode_utf8_roundTrips() { - // Given - var sut = new ZstdEncoding(); String[] data = {"hello", "world", "zstd"}; - - // When - EncodeResult result = sut.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + EncodeResult result = ENCODER.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); DecodeContext ctx = EncodeTestHelper.toDecodeContext(result, data.length, DTypes.UTF8, Registry.empty()); - VarBinArray decoded = (VarBinArray) sut.decode(ctx); - - // Then - assertThat(decoded.length()).isEqualTo(data.length); + VarBinArray decoded = (VarBinArray) DECODER.decode(ctx); for (int i = 0; i < data.length; i++) { assertThat(decoded.getString(i)).as("index %d", i).isEqualTo(data[i]); } @@ -88,26 +82,16 @@ void encode_utf8_roundTrips() { @Test void encode_emptyArray_roundTrips() { - // Given - var sut = new ZstdEncoding(); int[] data = {}; - - // When - EncodeResult result = sut.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + EncodeResult result = 
ENCODER.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); DecodeContext ctx = EncodeTestHelper.toDecodeContext(result, data.length, DTypes.I32, Registry.empty()); - IntArray decoded = (IntArray) sut.decode(ctx); - - // Then + IntArray decoded = (IntArray) DECODER.decode(ctx); assertThat(decoded.length()).isZero(); } @Test void encode_unsupportedDtype_throwsVortexException() { - // Given - var sut = new ZstdEncoding(); - - // When / Then - assertThatThrownBy(() -> sut.encode(new DType.Null(false), null, EncodeTestHelper.testCtx())) + assertThatThrownBy(() -> ENCODER.encode(new DType.Null(false), null, EncodeTestHelper.testCtx())) .isInstanceOf(VortexException.class); } } @@ -118,7 +102,6 @@ class Decode { private static DecodeContext makeDictCtx( byte[] meta, DType dtype, long n, byte[] dictBytes, byte[]... compressedFrames ) { - // buffer[0] = dict, buffer[1..] = frames MemorySegment[] segments = new MemorySegment[1 + compressedFrames.length]; segments[0] = MemorySegment.ofArray(dictBytes); int[] bufIndices = new int[1 + compressedFrames.length]; @@ -133,7 +116,6 @@ private static DecodeContext makeDictCtx( } private static byte[] makeDictFor(byte[]... samples) { - // Repeat samples to meet zstd's minimum training data requirement (~1 KB) int total = 0; for (byte[] s : samples) { total += s.length; @@ -160,8 +142,7 @@ private static byte[] compressWithDict(byte[] data, byte[] dictBytes) { private static DecodeContext makeNullableCtx( byte[] meta, DType dtype, long n, boolean[] validityBits, byte[]... 
compressedFrames ) { - BoolEncoding boolEncoding = new BoolEncoding(); - EncodeResult validityResult = boolEncoding.encode(new DType.Bool(false), validityBits, EncodeTestHelper.testCtx()); + EncodeResult validityResult = BOOL_ENCODER.encode(new DType.Bool(false), validityBits, EncodeTestHelper.testCtx()); EncodeNode remappedValidity = EncodeNode.remapBufferIndices( validityResult.rootNode(), compressedFrames.length); @@ -177,7 +158,7 @@ private static DecodeContext makeNullableCtx( ArrayNode node = ArrayNode.of(EncodingId.VORTEX_ZSTD, ByteBuffer.wrap(meta), new ArrayNode[]{validityNode}, bufIndices, null); - Registry registry = Registry.builder().register(new BoolEncoding()).build(); + Registry registry = TestRegistry.ofDecoders(new BoolEncodingDecoder()); return new DecodeContext(node, dtype, n, allSegments.toArray(new MemorySegment[0]), registry, Arena.ofAuto()); @@ -191,18 +172,6 @@ private static ArrayNode toArrayNode(EncodeNode enc) { return ArrayNode.of(enc.encodingId(), enc.metadata(), children, enc.bufferIndices(), null); } - private static DecodeContext makeCtx(byte[] meta, DType dtype, long n, byte[]... 
compressedFrames) { - MemorySegment[] segments = new MemorySegment[compressedFrames.length]; - int[] bufIndices = new int[compressedFrames.length]; - for (int i = 0; i < compressedFrames.length; i++) { - segments[i] = MemorySegment.ofArray(compressedFrames[i]); - bufIndices[i] = i; - } - ArrayNode node = ArrayNode.of(EncodingId.VORTEX_ZSTD, ByteBuffer.wrap(meta), - new ArrayNode[0], bufIndices, null); - return new DecodeContext(node, dtype, n, segments, Registry.empty(), Arena.ofAuto()); - } - private static byte[] compress(byte[] input) { ZstdCompressor compressor = new ZstdJavaCompressor(); byte[] out = new byte[compressor.maxCompressedLength(input.length)]; @@ -226,14 +195,6 @@ private static byte[] toLeBytes(int[] values) { return buf.array(); } - private static byte[] toLeBytes(long[] values) { - ByteBuffer buf = ByteBuffer.allocate(values.length * 8).order(ByteOrder.LITTLE_ENDIAN); - for (long v : values) { - buf.putLong(v); - } - return buf.array(); - } - private static byte[] toLengthPrefixed(String[] strings) { int total = 0; for (String s : strings) { @@ -250,8 +211,6 @@ private static byte[] toLengthPrefixed(String[] strings) { @Test void decode_withDictionary_utf8_roundTrips() { - // Given - var sut = new ZstdEncoding(); String[] strings = {"hello", "world", "zstd"}; byte[] raw = toLengthPrefixed(strings); byte[] dictBytes = makeDictFor(raw); @@ -260,10 +219,8 @@ void decode_withDictionary_utf8_roundTrips() { java.util.List.of(new ZstdFrameMetadata(raw.length, strings.length))).encode(); DecodeContext ctx = makeDictCtx(meta, DTypes.UTF8, strings.length, dictBytes, compressed); - // When - VarBinArray result = (VarBinArray) sut.decode(ctx); + VarBinArray result = (VarBinArray) DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(strings.length); for (int i = 0; i < strings.length; i++) { assertThat(result.getString(i)).as("index %d", i).isEqualTo(strings[i]); @@ -272,8 +229,6 @@ void decode_withDictionary_utf8_roundTrips() { @Test void 
decode_withDictionary_multipleFrames_roundTrips() { - // Given - var sut = new ZstdEncoding(); int[] frame0 = {1, 2, 3}; int[] frame1 = {4, 5}; byte[] raw0 = toLeBytes(frame0); @@ -286,10 +241,8 @@ void decode_withDictionary_multipleFrames_roundTrips() { new ZstdFrameMetadata(raw1.length, frame1.length))).encode(); DecodeContext ctx = makeDictCtx(meta, DTypes.I32, 5, dictBytes, comp0, comp1); - // When - IntArray result = (IntArray) sut.decode(ctx); + IntArray result = (IntArray) DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(5); for (int i = 0; i < 3; i++) { assertThat(result.getInt(i)).isEqualTo(frame0[i]); @@ -301,11 +254,7 @@ void decode_withDictionary_multipleFrames_roundTrips() { @Test void decode_nullable_primitive_scattersValuesCorrectly() { - // Given - var sut = new ZstdEncoding(); - // validity: [true, false, true, false] — positions 0,2 are valid boolean[] validityBits = {true, false, true, false}; - // only valid values compressed: 10, 30 byte[] raw = toLeBytes(new int[]{10, 30}); byte[] compressed = compress(raw); DType i32Nullable = new DType.Primitive(PType.I32, true); @@ -313,10 +262,8 @@ void decode_nullable_primitive_scattersValuesCorrectly() { metaNoDict(new long[]{raw.length}, new long[]{2}), i32Nullable, 4, validityBits, compressed); - // When - MaskedArray result = (MaskedArray) sut.decode(ctx); + MaskedArray result = (MaskedArray) DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(4); assertThat(result.isValid(0)).isTrue(); assertThat(result.isValid(1)).isFalse(); @@ -329,11 +276,7 @@ void decode_nullable_primitive_scattersValuesCorrectly() { @Test void decode_nullable_utf8_scattersValuesCorrectly() { - // Given - var sut = new ZstdEncoding(); - // validity: [true, false, true] — positions 0,2 are valid boolean[] validityBits = {true, false, true}; - // only valid strings compressed byte[] raw = toLengthPrefixed(new String[]{"hello", "world"}); byte[] compressed = compress(raw); DType utf8Nullable = 
new DType.Utf8(true); @@ -341,10 +284,8 @@ void decode_nullable_utf8_scattersValuesCorrectly() { metaNoDict(new long[]{raw.length}, new long[]{2}), utf8Nullable, 3, validityBits, compressed); - // When - MaskedArray result = (MaskedArray) sut.decode(ctx); + MaskedArray result = (MaskedArray) DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(3); assertThat(result.isValid(0)).isTrue(); assertThat(result.isValid(1)).isFalse(); @@ -356,10 +297,7 @@ void decode_nullable_utf8_scattersValuesCorrectly() { @Test void decode_allNull_returnsEmptyMaskedArray() { - // Given - var sut = new ZstdEncoding(); boolean[] validityBits = {false, false, false}; - // no valid values — zero-length compressed buffer byte[] raw = new byte[0]; byte[] compressed = compress(raw); DType i32Nullable = new DType.Primitive(PType.I32, true); @@ -367,10 +305,8 @@ void decode_allNull_returnsEmptyMaskedArray() { metaNoDict(new long[]{raw.length}, new long[]{0}), i32Nullable, 3, validityBits, compressed); - // When - MaskedArray result = (MaskedArray) sut.decode(ctx); + MaskedArray result = (MaskedArray) DECODER.decode(ctx); - // Then assertThat(result.length()).isEqualTo(3); assertThat(result.isValid(0)).isFalse(); assertThat(result.isValid(1)).isFalse(); @@ -379,14 +315,11 @@ void decode_allNull_returnsEmptyMaskedArray() { @Test void decode_missingMetadata_throwsVortexException() { - // Given - var sut = new ZstdEncoding(); ArrayNode node = ArrayNode.of(EncodingId.VORTEX_ZSTD, null, new ArrayNode[0], new int[0], null); DecodeContext ctx = new DecodeContext(node, DTypes.I32, 0, new MemorySegment[0], Registry.empty(), Arena.ofAuto()); - // When / Then - assertThatThrownBy(() -> sut.decode(ctx)) + assertThatThrownBy(() -> DECODER.decode(ctx)) .isInstanceOf(VortexException.class) .hasMessageContaining("missing metadata"); } @@ -397,20 +330,14 @@ class Metadata { @Test void encode_i32_metadata_framesCount_isNonZero() throws Exception { - // Given — any non-empty encode produces at least 
one zstd frame - // if tag drifts, frames list is empty and decode silently produces no data int[] data = new int[100]; for (int i = 0; i < data.length; i++) { data[i] = i; } - ZstdEncoding sut = new ZstdEncoding(); - - // When - EncodeResult result = sut.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); - ZstdMetadata meta = - ZstdMetadata.decode(java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()), 0, java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()).byteSize()); + EncodeResult result = ENCODER.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + var metaSeg = java.lang.foreign.MemorySegment.ofBuffer(result.rootNode().metadata().duplicate()); + ZstdMetadata meta = ZstdMetadata.decode(metaSeg, 0, metaSeg.byteSize()); - // Then assertThat(meta.frames().size()).isGreaterThan(0); } }