Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ private void runDictLoad(InspectorTree.Node dictNode) {
try (java.lang.foreign.Arena arena = java.lang.foreign.Arena.ofConfined()) {
int segIdx = values.segments().getFirst();
SegmentSpec spec = tree.segmentSpecs().get(segIdx);
java.lang.foreign.MemorySegment seg = handle.slice(spec.offset(), spec.length());
java.lang.foreign.MemorySegment seg = handle.slice(spec.offset(), spec.length()).unwrapForSubParser("inspector tui flat segment");
io.github.dfa1.vortex.core.array.Array arr =
new io.github.dfa1.vortex.encoding.FlatSegmentDecoder(handle.registry())
.decode(seg, handle.footer().arraySpecs(),
Expand Down Expand Up @@ -760,7 +760,7 @@ private byte[] fetchHex(InspectorTree.Node node) {
return new byte[0];
}
try {
MemorySegment seg = handle.slice(spec.offset(), wanted);
MemorySegment seg = handle.slice(spec.offset(), wanted).unwrapForSubParser("inspector tui hex peek");
byte[] buf = new byte[wanted];
MemorySegment.copy(seg, 0, MemorySegment.ofArray(buf), 0, wanted);
return buf;
Expand Down
123 changes: 123 additions & 0 deletions core/src/main/java/io/github/dfa1/vortex/core/BoundedSegment.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package io.github.dfa1.vortex.core;

import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

/// A memory region with built-in bounds-checking for slicing and reading untrusted input.
///
/// <p>Callers that stay on this type's methods reach {@link MemorySegment#asSlice(long, long)}
/// only through {@link #slice(long, long, String)}, which validates the offset/length against
/// the region size first and throws {@link VortexException} on malformed input — never
/// {@link IndexOutOfBoundsException}. The fixed-width readers ({@link #getByte(long)},
/// {@link #getIntLE(long)}, {@link #getLongLE(long)}) apply the same validation.
///
/// <p>The {@code context} label travels with the type; nested slices receive an explicit
/// child label at the {@link #slice} site. Error messages thus name the on-disk structure
/// ({@code "trailer"}, {@code "postscript blob"}, {@code "encoded buffer 3"}) alongside the
/// offending offset and length.
///
/// <p>The intended way to obtain the raw segment is {@link #unwrapForSubParser(String)}, which
/// both documents the trust transfer and forces a {@code reason} string so every escape-hatch
/// site is greppable for audit. Note that, because this is a record, the implicit public
/// accessor {@code seg()} returns the same raw segment without a reason; audits for trust
/// transfers should grep for both.
///
/// @param seg     the backing memory region; must not be {@code null}. Its lifetime is that of
///                whatever produced the segment (e.g. an {@link java.lang.foreign.Arena Arena})
/// @param context human-readable label naming the on-disk structure this region represents;
///                must not be {@code null}
public record BoundedSegment(MemorySegment seg, String context) {

    private static final ValueLayout.OfByte BYTE = ValueLayout.JAVA_BYTE;
    private static final ValueLayout.OfInt LE_INT =
            ValueLayout.JAVA_INT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
    private static final ValueLayout.OfLong LE_LONG =
            ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);

    /// Validates the record components.
    ///
    /// <p>Failing here keeps a missing segment or label from surfacing later as an unrelated
    /// {@link NullPointerException} on first read, or as a {@code "malformed null: ..."} message.
    ///
    /// @throws NullPointerException if {@code seg} or {@code context} is {@code null}
    public BoundedSegment {
        java.util.Objects.requireNonNull(seg, "seg");
        java.util.Objects.requireNonNull(context, "context");
    }

    /// @return total size of the bounded region in bytes
    public long byteSize() {
        return seg.byteSize();
    }

    /// Returns a sub-region with a fresh context label.
    ///
    /// @param off          start offset in bytes, relative to this region
    /// @param len          slice length in bytes
    /// @param childContext label for the resulting sub-region; must not be {@code null}
    /// @return the bounded sub-region
    /// @throws VortexException      if {@code off} or {@code len} is negative, or if
    ///                              {@code off + len > this.byteSize()}
    /// @throws NullPointerException if {@code childContext} is {@code null}
    public BoundedSegment slice(long off, long len, String childContext) {
        checkRange(off, len);
        return new BoundedSegment(seg.asSlice(off, len), childContext);
    }

    /// Bounds-checked single-byte read.
    ///
    /// @param off byte offset
    /// @return the byte at {@code off}
    /// @throws VortexException if {@code off} is negative or {@code >= this.byteSize()}
    public byte getByte(long off) {
        checkRange(off, 1);
        return seg.get(BYTE, off);
    }

    /// Bounds-checked little-endian 32-bit read (no alignment requirement).
    ///
    /// @param off byte offset of the 4-byte word
    /// @return the int at {@code off}
    /// @throws VortexException if {@code off} is negative or {@code > this.byteSize() - 4}
    public int getIntLE(long off) {
        checkRange(off, 4);
        return seg.get(LE_INT, off);
    }

    /// Bounds-checked little-endian 64-bit read (no alignment requirement).
    ///
    /// @param off byte offset of the 8-byte word
    /// @return the long at {@code off}
    /// @throws VortexException if {@code off} is negative or {@code > this.byteSize() - 8}
    public long getLongLE(long off) {
        checkRange(off, 8);
        return seg.get(LE_LONG, off);
    }

    /// Verifies that {@code [off, off + len)} lies inside this region.
    ///
    /// @param off start offset in bytes
    /// @param len length in bytes
    /// @throws VortexException if either value is negative or the range overruns the region
    private void checkRange(long off, long len) {
        long segSize = seg.byteSize();
        if (off < 0) {
            throw new VortexException("malformed " + context + ": negative offset " + off);
        }
        if (len < 0) {
            throw new VortexException("malformed " + context + ": negative length " + len);
        }
        // Overflow-safe form of `off + len > segSize`. The subtraction is only evaluated when the
        // first operand is false (short-circuit), i.e. when 0 <= len <= segSize, so
        // `segSize - len` stays within [0, segSize] and cannot wrap.
        if (len > segSize || off > segSize - len) {
            throw new VortexException("malformed " + context + ": offset+length "
                    + off + "+" + len + " exceeds segment size " + segSize);
        }
    }

    /// Little-endian {@link ByteBuffer} view of the whole bounded region, used by the
    /// FlatBuffer runtime (which performs its own offset validation against the buffer's
    /// capacity).
    ///
    /// <p>A {@link ByteBuffer} cannot address more than {@link Integer#MAX_VALUE} bytes, and
    /// {@link MemorySegment#asByteBuffer()} reports that case with
    /// {@link UnsupportedOperationException}. An oversized region is rejected here first so
    /// that it surfaces as a {@link VortexException} carrying this region's context label,
    /// like every other size violation on this type.
    ///
    /// @return a {@link ByteBuffer} view in little-endian order
    /// @throws VortexException if this region is larger than {@link Integer#MAX_VALUE} bytes
    public ByteBuffer asByteBufferLE() {
        long segSize = seg.byteSize();
        if (segSize > Integer.MAX_VALUE) {
            throw new VortexException("malformed " + context + ": segment size " + segSize
                    + " exceeds maximum ByteBuffer size " + Integer.MAX_VALUE);
        }
        return seg.asByteBuffer().order(ByteOrder.LITTLE_ENDIAN);
    }

    /// Escape hatch returning the raw {@link MemorySegment} for a downstream parser that
    /// takes its own bounds-checked cursor (currently {@link
    /// io.github.dfa1.vortex.proto.ProtoReader}). The {@code reason} string names the
    /// sub-parser for attribution at the call site; it is not stored or logged here, it only
    /// has to be present.
    ///
    /// <p><strong>Audit point.</strong> Every call to this method is a trust transfer
    /// across the bounds-checking boundary. New call sites must justify in review why
    /// the receiver re-validates the bounds itself.
    ///
    /// @param reason short label naming the sub-parser ({@code "proto reader"},
    ///               {@code "flatbuffer root"}); must not be {@code null}
    /// @return the raw memory segment
    /// @throws NullPointerException if {@code reason} is {@code null}
    public MemorySegment unwrapForSubParser(String reason) {
        // Reject null so the "every call site carries a label" audit rule cannot be bypassed.
        java.util.Objects.requireNonNull(reason, "reason");
        return seg;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ static Array decode(DecodeContext ctx) {
int typeBits = ptype.byteSize() * 8;
long rowCount = ctx.rowCount();

MemorySegment packed = ctx.buffer(0);
MemorySegment packed = ctx.buffer(0).unwrapForSubParser("bitpacked encoding");
MemorySegment output = ctx.arena().allocate(rowCount * ptype.byteSize());
fastlanesUnpackToSeg(packed, bitWidth, offset, typeBits, rowCount, output);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,6 @@ public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) {

@Override
public Array decode(DecodeContext ctx) {
return new BoolArray(ctx.dtype(), ctx.rowCount(), ctx.buffer(0));
return new BoolArray(ctx.dtype(), ctx.rowCount(), ctx.buffer(0).unwrapForSubParser("bool encoding"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) {
@Override
public Array decode(DecodeContext ctx) {
long n = ctx.rowCount();
MemorySegment bytes = ctx.buffer(0);
MemorySegment bytes = ctx.buffer(0).unwrapForSubParser("bytebool encoding");
long packedBytes = (n + 7) >>> 3;
MemorySegment packed = ctx.arena().allocate(packedBytes > 0 ? packedBytes : 1);
for (long i = 0; i < n; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ private static ScalarValue buildScalar(PType ptype, long rawBits) {
private static final class Decoder {

private static Array decode(DecodeContext ctx) {
MemorySegment scalarBuf = ctx.buffer(0);
MemorySegment scalarBuf = ctx.buffer(0).unwrapForSubParser("constant encoding");
ScalarValue scalar;
try {
scalar = ScalarValue.decode(scalarBuf, 0, scalarBuf.byteSize());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ private static Array decode(DecodeContext ctx) {
}
int valuesType = decoded.values_type();
int byteWidth = decimalTypeByteWidth(valuesType);
MemorySegment buffer = ctx.buffer(0);
MemorySegment buffer = ctx.buffer(0).unwrapForSubParser("decimal encoding");
long expected = ctx.rowCount() * byteWidth;
if (buffer.byteSize() < expected) {
throw new VortexException(EncodingId.VORTEX_DECIMAL,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.github.dfa1.vortex.encoding;

import io.github.dfa1.vortex.core.BoundedSegment;
import io.github.dfa1.vortex.core.DType;
import io.github.dfa1.vortex.core.array.Array;

Expand All @@ -24,10 +25,32 @@ public record DecodeContext(
ArrayNode node,
DType dtype,
long rowCount,
MemorySegment[] segmentBuffers,
BoundedSegment[] segmentBuffers,
Registry registry,
SegmentAllocator arena
) {
/// Builds a {@link DecodeContext} from raw {@link MemorySegment} buffers by wrapping each one
/// in a {@link BoundedSegment}. Meant for tests and other callers whose buffer arrays are
/// synthetic and trusted; production decoders get their buffers from
/// {@link FlatSegmentDecoder}, which already slices them out of the parent flat segment as
/// bounded regions.
///
/// @param node     array node describing this encoding's tree structure
/// @param dtype    logical type expected for the decoded array
/// @param rowCount number of logical rows to decode
/// @param rawBufs  raw segment buffers; element {@code i} is labelled {@code "test buffer i"}
/// @param registry encoding registry used for recursive child decoding
/// @param arena    allocator for decode output
/// @return a {@link DecodeContext} whose segment buffers are bounded views of {@code rawBufs}
public static DecodeContext ofRawBuffers(
        ArrayNode node, DType dtype, long rowCount,
        MemorySegment[] rawBufs, Registry registry, SegmentAllocator arena) {
    BoundedSegment[] bounded = new BoundedSegment[rawBufs.length];
    // Fill slot by slot in index order, labelling each wrapper with its position.
    java.util.Arrays.setAll(bounded,
            idx -> new BoundedSegment(rawBufs[idx], "test buffer " + idx));
    return new DecodeContext(node, dtype, rowCount, bounded, registry, arena);
}

/// Recursively decode child {@code i} using this context's dtype and row count.
///
/// @param i zero-based child index within this node's children array
Expand Down Expand Up @@ -78,8 +101,8 @@ public MemorySegment decodeChildSegment(int i, DType dtype, long rowCount) {
/// Return the buffer at position `i` in this node's bufferIndices.
///
/// @param i zero-based index into this node's {@code bufferIndices} array
/// @return the {@link MemorySegment} for the referenced segment buffer
public MemorySegment buffer(int i) {
/// @return the {@link BoundedSegment} for the referenced segment buffer
public BoundedSegment buffer(int i) {
return segmentBuffers[node.bufferIndices()[i]];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ private static Array decodeLegacyJava(DecodeContext ctx, byte codeTypeByte) {
long rowCount = ctx.rowCount();

// Values: always VORTEX_PRIMITIVE leaf, read direct
MemorySegment valuesBuf = ctx.segmentBuffers()[ctx.node().children()[0].bufferIndices()[0]];
MemorySegment valuesBuf = ctx.segmentBuffers()[ctx.node().children()[0].bufferIndices()[0]].unwrapForSubParser("dict encoding values");

// Codes: decode through registry — supports both raw (VORTEX_PRIMITIVE) and cascade (FASTLANES_BITPACKED) children
DType codesDtype = new DType.Primitive(codePType, false);
Expand Down Expand Up @@ -435,9 +435,9 @@ private static Array decodeUtf8DictLegacy(DecodeContext ctx, ByteBuffer meta) {
PType codePType = PType.fromOrdinal(Byte.toUnsignedInt(meta.get(0)));
long n = ctx.rowCount();

MemorySegment dictBytes = ctx.buffer(0);
MemorySegment dictOffsets = ctx.buffer(1);
MemorySegment codes = ctx.buffer(2);
MemorySegment dictBytes = ctx.buffer(0).unwrapForSubParser("dict encoding");
MemorySegment dictOffsets = ctx.buffer(1).unwrapForSubParser("dict encoding");
MemorySegment codes = ctx.buffer(2).unwrapForSubParser("dict encoding");

return VarBinArray.ofDict(ctx.dtype(), n,
dictBytes, dictOffsets, PType.I64,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package io.github.dfa1.vortex.encoding;

import io.github.dfa1.vortex.core.DType;
import io.github.dfa1.vortex.core.ArrayStats;
import io.github.dfa1.vortex.core.BoundedSegment;
import io.github.dfa1.vortex.core.DType;
import io.github.dfa1.vortex.core.array.Array;
import io.github.dfa1.vortex.fbs.Buffer;

Expand Down Expand Up @@ -50,12 +51,13 @@ public Array decode(MemorySegment seg, List<String> encodingSpecs,
var fbArray = io.github.dfa1.vortex.fbs.Array.getRootAsArray(fbBuf);

int numBuffers = fbArray.buffersLength();
MemorySegment[] bufs = new MemorySegment[numBuffers];
BoundedSegment[] bufs = new BoundedSegment[numBuffers];
BoundedSegment region = new BoundedSegment(seg, "flat segment");
long dataOffset = 0;
for (int i = 0; i < numBuffers; i++) {
Buffer bufDesc = fbArray.buffers(i);
dataOffset += bufDesc.padding();
bufs[i] = seg.asSlice(dataOffset, bufDesc.length());
bufs[i] = region.slice(dataOffset, bufDesc.length(), "encoded buffer " + i);
dataOffset += bufDesc.length();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,9 @@ private static Array decode(DecodeContext ctx) {

long n = ctx.rowCount();

MemorySegment symbolsBuf = ctx.buffer(0); // 8 bytes per symbol (LE u64)
MemorySegment symbolLensBuf = ctx.buffer(1); // 1 byte per symbol
MemorySegment compressedBytes = ctx.buffer(2); // FSST-compressed heap
MemorySegment symbolsBuf = ctx.buffer(0).unwrapForSubParser("fsst encoding"); // 8 bytes per symbol (LE u64)
MemorySegment symbolLensBuf = ctx.buffer(1).unwrapForSubParser("fsst encoding"); // 1 byte per symbol
MemorySegment compressedBytes = ctx.buffer(2).unwrapForSubParser("fsst encoding"); // FSST-compressed heap

MemorySegment uncompLensSeg = ctx.decodeChildSegment(0, new DType.Primitive(uncompLenPType, false), n);
MemorySegment codesOffsetsSeg = ctx.decodeChildSegment(1, new DType.Primitive(codesOffPType, false), n + 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ static Array decode(DecodeContext ctx) {

for (int c = 0; c < nChunks; c++) {
PcoChunkInfo chunkInfo = meta.chunks().get(c);
MemorySegment chunkMetaBuf = ctx.buffer(bufIdx++);
MemorySegment chunkMetaBuf = ctx.buffer(bufIdx++).unwrapForSubParser("pco encoding");
PcoChunkMeta chunkMeta = readChunkMeta(chunkMetaBuf, dtypeSize);

int mode = chunkMeta.mode();
Expand All @@ -160,7 +160,7 @@ static Array decode(DecodeContext ctx) {
chunkMeta.ansSizeLog(), chunkMeta.bins());
for (int p = 0; p < chunkInfo.pages().size(); p++) {
int pageN = chunkInfo.pages().get(p).n_values();
MemorySegment pageBuf = ctx.buffer(bufIdx++);
MemorySegment pageBuf = ctx.buffer(bufIdx++).unwrapForSubParser("pco encoding");
rawByteOffset = decodeConv1Page(
primaryTans, chunkMeta.ansSizeLog(),
chunkMeta.conv1Weights().length,
Expand All @@ -186,7 +186,7 @@ static Array decode(DecodeContext ctx) {
long mask = typeMask(dtypeSize);
for (int p = 0; p < chunkInfo.pages().size(); p++) {
int pageN = chunkInfo.pages().get(p).n_values();
MemorySegment pageBuf = ctx.buffer(bufIdx++);
MemorySegment pageBuf = ctx.buffer(bufIdx++).unwrapForSubParser("pco encoding");
rawByteOffset = decodeLookbackPage(
deltaTans, chunkMeta.deltaAnsSizeLog(),
primaryTans, chunkMeta.ansSizeLog(),
Expand All @@ -202,7 +202,7 @@ static Array decode(DecodeContext ctx) {
PcoTansDecoder tans = PcoTansDecoder.build(chunkMeta.ansSizeLog(), chunkMeta.bins());
for (int p = 0; p < chunkInfo.pages().size(); p++) {
int pageN = chunkInfo.pages().get(p).n_values();
MemorySegment pageBuf = ctx.buffer(bufIdx++);
MemorySegment pageBuf = ctx.buffer(bufIdx++).unwrapForSubParser("pco encoding");
rawByteOffset = decodeClassicPage(tans, chunkMeta.ansSizeLog(),
chunkMeta.deltaOrder(), primaryDtypeSize,
pageBuf, pageN, rawLatents, rawByteOffset,
Expand All @@ -225,7 +225,7 @@ static Array decode(DecodeContext ctx) {
long adjByteOffset = 0L;
for (int p = 0; p < chunkInfo.pages().size(); p++) {
int pageN = chunkInfo.pages().get(p).n_values();
MemorySegment pageBuf = ctx.buffer(bufIdx++);
MemorySegment pageBuf = ctx.buffer(bufIdx++).unwrapForSubParser("pco encoding");
decodeIntMultPage(primaryTans, primaryAnsSizeLog, deltaOrder,
secondaryTans, secondaryAnsSizeLog, secondaryDeltaOrder,
dtypeSize, pageBuf, pageN,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ private static byte[] scalarF64(double v) {
private static final class Decoder {

private static Array decode(DecodeContext ctx) {
MemorySegment buf = ctx.buffer(0);
MemorySegment buf = ctx.buffer(0).unwrapForSubParser("primitive encoding");
long n = ctx.rowCount();
DType dt = ctx.dtype();
PType ptype = ((DType.Primitive) dt).ptype();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ private static UnknownArray decodeUnknown(DecodeContext ctx, ArrayNode node) {
};
MemorySegment[] bufs = new MemorySegment[node.bufferIndices().length];
for (int i = 0; i < bufs.length; i++) {
bufs[i] = ctx.buffer(i);
bufs[i] = ctx.buffer(i).unwrapForSubParser("registry");
}
Array[] children = new Array[node.children().length];
for (int i = 0; i < children.length; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ private static Array decode(DecodeContext ctx) {
}
PType valuePtype = ((DType.Primitive) ctx.dtype()).ptype();

MemorySegment fillBuf = ctx.buffer(0);
MemorySegment fillBuf = ctx.buffer(0).unwrapForSubParser("sparse encoding");
ScalarValue fillScalar;
try {
fillScalar = ScalarValue.decode(fillBuf, 0, fillBuf.byteSize());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ private static Array decode(DecodeContext ctx) {
offsets = materialized;
}

MemorySegment bytes = ctx.buffer(0);
MemorySegment bytes = ctx.buffer(0).unwrapForSubParser("varbin encoding");

return new VarBinArray(ctx.dtype(), n, bytes, offsets, offsetsPtype);
}
Expand Down
Loading
Loading