Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tls_codec/benches/quic_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ fn byte_slice(c: &mut Criterion) {
c.bench_function("TLS Serialize VL Byte Slice", |b| {
b.iter_batched_ref(
|| (vec![77u8; N], Vec::with_capacity(8 + N)),
|(long_vec, buf)| VLByteSlice(long_vec).tls_serialize(buf).unwrap(),
|(long_vec, buf)| Serialize::tls_serialize(&VLByteSlice(long_vec), buf).unwrap(),
BatchSize::SmallInput,
)
});
Expand Down
1 change: 1 addition & 0 deletions tls_codec/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use std::io::{Read, Write};
mod arrays;
mod primitives;
mod quic_vec;
mod string;
mod tls_vec;
mod varint;

Expand Down
30 changes: 27 additions & 3 deletions tls_codec/src/quic_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,29 @@ impl Size for VLByteSlice<'_> {
}
}

impl SerializeBytes for ContentLength {
fn tls_serialize(&self) -> Result<Vec<u8>, Error> {
SerializeBytes::tls_serialize(&self.0)
}
}

impl SerializeBytes for VLByteSlice<'_> {
fn tls_serialize(&self) -> Result<Vec<u8>, Error> {
// Get the byte length of the content, make sure it's not too
// large and write it out.
let content_length = self.0.len();

let mut len_bytes =
SerializeBytes::tls_serialize(&ContentLength::from_usize(content_length)?)?;

let mut out = alloc::vec::Vec::with_capacity(content_length + len_bytes.len());

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should make sure content_length + len_bytes.len() <= isize::MAX, or does that happen somewhere already?

out.append(&mut len_bytes);
out.extend(self.0);

Ok(out)
}
}

#[cfg(feature = "std")]
pub mod rw {
use super::*;
Expand All @@ -554,7 +577,7 @@ pub mod rw {
impl Serialize for ContentLength {
#[inline(always)]
fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, Error> {
self.0.tls_serialize(writer)
Serialize::tls_serialize(&self.0, writer)
}
}

Expand Down Expand Up @@ -598,7 +621,7 @@ pub mod rw {
writer: &mut W,
content_length: usize,
) -> Result<usize, Error> {
ContentLength::from_usize(content_length)?.tls_serialize(writer)
Serialize::tls_serialize(&ContentLength::from_usize(content_length)?, writer)
}

impl<T: Serialize + std::fmt::Debug> Serialize for Vec<T> {
Expand Down Expand Up @@ -654,7 +677,8 @@ mod rw_bytes {
// large and write it out.
let content_length = bytes.len();

let len_len = ContentLength::from_usize(content_length)?.tls_serialize(writer)?;
let len_len =
Serialize::tls_serialize(&ContentLength::from_usize(content_length)?, writer)?;

// Now serialize the elements
writer.write_all(bytes)?;
Expand Down
231 changes: 231 additions & 0 deletions tls_codec/src/string.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
//! This module implements de/serialization for String by storing the UTF-8 representation in a
//! VLByteVec, i.e. a byte vec with a varint Length.

use alloc::string::String;

use crate::{DeserializeBytes, SerializeBytes, Size, VLByteSlice, VLByteVec};

impl Size for String {
fn tls_serialized_len(&self) -> usize {
self.as_bytes().tls_serialized_len()
}
}

impl Size for &str {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add implementations for str as well? Then we can call it on a string directly "hello".tls_serialize_detached().

fn tls_serialized_len(&self) -> usize {
self.as_bytes().tls_serialized_len()
}
}

impl SerializeBytes for String {
fn tls_serialize(&self) -> Result<alloc::vec::Vec<u8>, crate::Error> {
SerializeBytes::tls_serialize(&VLByteSlice(self.as_bytes()))
}
}

impl SerializeBytes for &str {
fn tls_serialize(&self) -> Result<alloc::vec::Vec<u8>, crate::Error> {
SerializeBytes::tls_serialize(&self.as_bytes())
}
}

impl DeserializeBytes for String {
fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), crate::Error>
where
Self: Sized,
{
let (bytes, rest) = VLByteVec::tls_deserialize_bytes(bytes)?;
let text = String::from_utf8(bytes.into())
.map_err(|err| crate::Error::DecodingError(format!("invalid utf8: {err}")))?;

Ok((text, rest))
}
}

#[cfg(feature = "std")]
mod std_only {
use super::*;
use crate::{Deserialize, Serialize};

impl Serialize for String {
fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, crate::Error> {
Serialize::tls_serialize(&VLByteSlice(self.as_bytes()), writer)
}
}

impl Serialize for &str {
fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, crate::Error> {
Serialize::tls_serialize(&self.as_bytes(), writer)
}
}

impl Deserialize for String {
fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, crate::Error>
where
Self: Sized,
{
let bytes = VLByteVec::tls_deserialize(bytes)?;
String::from_utf8(bytes.into())
.map_err(|err| crate::Error::DecodingError(format!("invalid utf8: {err}")))
}
}
}

#[cfg(all(test, feature = "std"))]
mod tests_with_std {
use crate::{Deserialize, Serialize, Size};
use alloc::string::String;

#[test]
fn serialize_multibyte_utf8_string() {
// U+00FC = "ü", encoded as 2 bytes in UTF-8: [0xC3, 0xBC]
let s = String::from("ü");
let buf = s.tls_serialize_detached().unwrap();
assert_eq!(buf, [2, 0xC3, 0xBC]);
assert_eq!(s.tls_serialized_len(), 3);
}

#[test]
fn serialize_empty_string() {
let s = String::new();
let buf = s.tls_serialize_detached().unwrap();
assert_eq!(buf, [0]);
assert_eq!(s.tls_serialized_len(), 1);
}

#[test]
fn serialize_hello_string() {
let s = String::from("hello");
let buf = s.tls_serialize_detached().unwrap();
// length prefix (5) + b"hello"
assert_eq!(buf, [5, b'h', b'e', b'l', b'l', b'o']);
assert_eq!(s.tls_serialized_len(), 6);
}

#[test]
fn roundtrip_deserialize() {
let original = String::from("roundtrip test");
let buf = original.tls_serialize_detached().unwrap();
let deserialized = String::tls_deserialize_exact(&buf).unwrap();
assert_eq!(original, deserialized);
}

#[test]
fn roundtrip_deserialize_longstring() {
let original = String::from_utf8(vec![0x30u8; 300]).unwrap();
let buf = original.tls_serialize_detached().unwrap();
let deserialized = String::tls_deserialize_exact(&buf).unwrap();
assert_eq!(original, deserialized);
}

#[test]
fn roundtrip_deserialize_empty() {
let original = String::new();
let buf = original.tls_serialize_detached().unwrap();
let deserialized = String::tls_deserialize_exact(&buf).unwrap();
assert_eq!(original, deserialized);
}

#[test]
fn deserialize_invalid_utf8() {
// length prefix 2 + two bytes that are not valid UTF-8
let buf: &[u8] = &[2, 0xFF, 0xFE];
let err = String::tls_deserialize_exact(buf).unwrap_err();
assert!(matches!(err, crate::Error::DecodingError(msg) if msg.contains("invalid utf8")));
}
}
#[cfg(test)]
mod tests {
use alloc::string::String;

#[cfg(feature = "std")]
use crate::Serialize;

use crate::{DeserializeBytes, SerializeBytes, Size};

#[test]
fn serialize_empty_str() {
let s = "";

#[cfg(feature = "std")]
{
let mut buf = [0u8; 1];
Serialize::tls_serialize(&s, &mut buf.as_mut_slice()).unwrap();
assert_eq!(buf, [0]);
assert_eq!(s.tls_serialized_len(), 1);
}

let buf = SerializeBytes::tls_serialize(&s).unwrap();
assert_eq!(buf, [0]);
assert_eq!(s.tls_serialized_len(), 1);
}

#[test]
fn serialize_hello_str() {
let s = "hello";
#[cfg(feature = "std")]
{
let mut buf = [0u8; 6];
Serialize::tls_serialize(&s, &mut buf.as_mut_slice()).unwrap();
// length prefix (5) + b"hello"
assert_eq!(buf, [5, b'h', b'e', b'l', b'l', b'o']);
assert_eq!(s.tls_serialized_len(), 6);
}

let buf = SerializeBytes::tls_serialize(&s).unwrap();
// length prefix (5) + b"hello"
assert_eq!(buf, [5, b'h', b'e', b'l', b'l', b'o']);
assert_eq!(s.tls_serialized_len(), 6);
}

#[test]
fn serialize_multibyte_utf8_str() {
// U+00FC = "ü", encoded as 2 bytes in UTF-8: [0xC3, 0xBC]
let s = "ü";
#[cfg(feature = "std")]
{
let mut buf = [0u8; 3];
Serialize::tls_serialize(&s, &mut buf.as_mut_slice()).unwrap();
assert_eq!(buf, [2, 0xC3, 0xBC]);
assert_eq!(s.tls_serialized_len(), 3);
}

let buf = SerializeBytes::tls_serialize(&s).unwrap();
assert_eq!(buf, [2, 0xC3, 0xBC]);
assert_eq!(s.tls_serialized_len(), 3);
}

#[test]
fn deserialize_bytes_hello() {
let input = [5, b'h', b'e', b'l', b'l', b'o'];
let (s, rest) = String::tls_deserialize_bytes(&input).unwrap();
assert_eq!(s, "hello");
assert!(rest.is_empty());
assert_eq!(s.tls_serialized_len(), 6);
}

#[test]
fn deserialize_bytes_with_trailing_data() {
// "hi" (length 2) followed by extra byte 0x99
let input = [2, b'h', b'i', 0x99];
let (s, rest) = String::tls_deserialize_bytes(&input).unwrap();
assert_eq!(s, "hi");
assert_eq!(rest, [0x99]);
}

#[test]
fn deserialize_bytes_invalid_utf8() {
// length prefix 3 + 3 bytes that form an invalid UTF-8 sequence
let input = [3, 0xED, 0xA0, 0x80]; // surrogates are invalid in UTF-8
let err = String::tls_deserialize_exact_bytes(&input).unwrap_err();
assert!(matches!(err, crate::Error::DecodingError(msg) if msg.contains("invalid utf8")));
}

#[test]
fn deserialize_bytes_empty_string() {
let input = [0];
let (s, rest) = String::tls_deserialize_bytes(&input).unwrap();
assert_eq!(s, "");
assert!(rest.is_empty());
}
}
21 changes: 16 additions & 5 deletions tls_codec/src/varint.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{Deserialize, DeserializeBytes, Error, Serialize, Size};
use crate::{Deserialize, DeserializeBytes, Error, Serialize, SerializeBytes, Size};

/// Variable-length encoded unsigned integer as defined in [RFC 9000].
///
Expand Down Expand Up @@ -168,6 +168,16 @@ impl Serialize for TlsVarInt {
}
}

impl SerializeBytes for TlsVarInt {
#[inline]
fn tls_serialize(&self) -> Result<alloc::vec::Vec<u8>, Error> {
let mut bytes = alloc::vec![0u8; 8];
let len = self.write_bytes(&mut bytes)?;
bytes.truncate(len);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you'd use bytes_len() for the length you wouldn't need to truncate. I don't think it makes a real difference though.

Ok(bytes)
}
}

impl Size for TlsVarInt {
#[inline]
fn tls_serialized_len(&self) -> usize {
Expand Down Expand Up @@ -237,10 +247,11 @@ mod tests {

for (value, len, bytes) in TESTS {
let mut buf = Vec::new();
let written = TlsVarInt::try_from(value)
.expect("value too large")
.tls_serialize(&mut buf)
.expect("tls serialize failed");
let written = Serialize::tls_serialize(
&TlsVarInt::try_from(value).expect("value too large"),
&mut buf,
)
.expect("tls serialize failed");
assert_eq!(written, len, "{value}");
assert_eq!(buf.len(), len, "{value}");
assert_eq!(&buf[..], bytes, "{value}");
Expand Down