diff --git a/Cargo.lock b/Cargo.lock index f144305f..123d0e59 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -72,7 +72,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dbb4e440d04be07da1f1bf44fb4495ebd58669372fe0cffa6e48595ac5bd88a3" dependencies = [ "android_log-sys", - "env_filter", + "env_filter 0.1.3", "log", ] @@ -85,6 +85,56 @@ dependencies = [ "libc", ] +[[package]] +name = "anstream" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" + +[[package]] +name = "anstyle-parse" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys 0.61.2", +] + [[package]] name = "anyhow" version = "1.0.99" @@ -109,6 +159,18 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "async-channel" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "924ed96dd52d1b75e9c1a3e6275715fd320f5f9439fb5a4a11fa51f4221158d2" +dependencies = [ + "concurrent-queue", + "event-listener-strategy", + "futures-core", + "pin-project-lite", +] + [[package]] name = "atomic" version = "0.5.3" @@ -293,6 +355,21 @@ dependencies = [ "thiserror", ] +[[package]] +name = "bit-set" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" + [[package]] name = "bitcoin-io" version = "0.1.3" @@ -326,9 +403,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.9.3" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34efbcccd345379ca2868b2b2c9d3782e9cc58ba87bc7d79d5b53d9c9ae6f25d" +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" [[package]] name = "blake2" @@ -479,7 +556,7 @@ dependencies = [ "num-traits", "serde", "wasm-bindgen", - "windows-link", + "windows-link 0.1.3", ] [[package]] @@ -493,6 +570,22 @@ dependencies = [ "zeroize", ] +[[package]] +name = "colorchoice" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" + +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils", + "loom", +] + [[package]] name = "console" version = "0.15.11" @@ -502,7 +595,7 @@ dependencies = [ "encode_unicode", "libc", "once_cell", - "windows-sys", + "windows-sys 0.59.0", ] [[package]] @@ -532,6 +625,17 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "core-models" +version = "0.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "657f625ff361906f779745d08375ae3cc9fef87a35fba5f22874cf773010daf4" +dependencies = [ + "hax-lib", + "pastey", + "rand 0.9.2", +] + [[package]] name = "cpufeatures" version = "0.2.17" @@ -565,6 +669,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "crunchy" version = "0.2.4" @@ -678,6 +788,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "diatomic-waker" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab03c107fafeb3ee9f5925686dbb7a73bc76e3932abb0d2b365cb64b169cf04c" + [[package]] name = "digest" version = "0.10.7" @@ -803,12 +919,73 @@ dependencies = [ "regex", ] +[[package]] +name = "env_filter" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e90c2accc4b07a8456ea0debdc2e7587bdd890680d71173a15d4ae604f6eef" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0621c04f2196ac3f488dd583365b9c09be011a4ab8b9f37248ffcc8f6198b56a" +dependencies = [ + "anstream", + "anstyle", + "env_filter 1.0.1", + "jiff", + "log", +] + [[package]] name = "equivalent" version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "event-listener" +version = "5.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13b66accf52311f30a0db42147dadea9850cb48cd070028831ae5f5d4b856ab" +dependencies = [ + "concurrent-queue", + "loom", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" +dependencies = [ + "event-listener", + "pin-project-lite", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "ff" version = "0.13.1" @@ -867,6 +1044,12 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -941,6 +1124,19 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +[[package]] +name = "futures-lite" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad" +dependencies = [ + "fastrand", + "futures-core", + "futures-io", + "parking", + "pin-project-lite", +] + [[package]] name = "futures-macro" version = "0.3.31" @@ -982,6 +1178,21 @@ dependencies = [ "slab", ] +[[package]] +name = "generator" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52f04ae4152da20c76fe800fa48659201d5cf627c5149ca0b707b69d7eef6cf9" +dependencies = [ + "cc", + "cfg-if", + "libc", + "log", + "rustversion", + "windows-link 0.2.1", + "windows-result", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -1081,6 +1292,43 @@ version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +[[package]] +name = "hax-lib" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "543f93241d32b3f00569201bfce9d7a93c92c6421b23c77864ac929dc947b9fc" +dependencies = [ + "hax-lib-macros", + "num-bigint", + "num-traits", +] + +[[package]] +name = "hax-lib-macros" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8755751e760b11021765bb04cb4a6c4e24742688d9f3aa14c2079638f537b0f" +dependencies = [ + "hax-lib-macros-types", + "proc-macro-error2", + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "hax-lib-macros-types" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f177c9ae8ea456e2f71ff3c1ea47bf4464f772a05133fcbba56cd5ba169035a2" +dependencies = [ + "proc-macro2", + "quote", + "serde", + "serde_json", + "uuid", +] + [[package]] name = "hermit-abi" version = "0.5.2" @@ -1295,6 +1543,12 @@ dependencies = [ "libc", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" + [[package]] name = "itertools" version = "0.11.0" @@ -1310,6 +1564,30 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +[[package]] +name = "jiff" +version = "0.2.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a3546dc96b6d42c5f24902af9e2538e82e39ad350b0c766eb3fbf2d8f3d8359" +dependencies = [ + "jiff-static", + "log", + "portable-atomic", + "portable-atomic-util", + "serde_core", +] + +[[package]] +name = "jiff-static" +version = "0.2.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a8c8b344124222efd714b73bb41f8b5120b27a7cc1c75593a6ff768d9d05aa4" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "jobserver" version = "0.1.33" @@ -1356,12 +1634,96 @@ version = "0.2.175" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" +[[package]] +name = "libcrux-aesgcm" +version = "0.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99f2a019dab4097585a7d4f5b9deebe46cd1e628b16a5bc4cb0ce35e1da334e6" +dependencies = [ + "libcrux-intrinsics", + "libcrux-platform", + "libcrux-secrets", + "libcrux-traits", +] + +[[package]] +name = "libcrux-intrinsics" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1b5db005ff8001e026b73a6842ee81bbef8ec5ff0e1915a67ae65fd2a9fafa5" +dependencies = [ + "core-models", + "hax-lib", +] + +[[package]] +name = "libcrux-ml-kem" +version = "0.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aca7de713c6dddcf7aaf76e8ef9dc0097c8d7ce23a8eadf04c8761734714e184" +dependencies = [ + "hax-lib", + "libcrux-intrinsics", + "libcrux-platform", + "libcrux-secrets", + "libcrux-sha3", + "libcrux-traits", + "rand 0.9.2", + "tls_codec", +] + +[[package]] +name = "libcrux-platform" +version = "0.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d9e21d7ed31a92ac539bd69a8c970b183ee883872d2d19ce27036e24cb8ecc4" +dependencies = [ + "libc", +] + +[[package]] +name = "libcrux-secrets" +version = "0.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ce650f3041b44ba40d4263852347d007cd2cd9d1cc856a6f6c8b2e10c3fd40b" +dependencies = [ + "hax-lib", +] + +[[package]] +name = "libcrux-sha3" +version = "0.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c50f6e04a184511b782c5cc1eb6a227c6d36f2c935e93d698655a93a99696b5" +dependencies = [ + "hax-lib", + "libcrux-intrinsics", + "libcrux-platform", + "libcrux-traits", +] + +[[package]] +name = "libcrux-traits" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812e4fa89f3f5e34b47f928b22b1b78395a0d4ec23b1f583db635f128159d65f" +dependencies = [ + "libcrux-secrets", + "rand 0.9.2", +] + [[package]] name = "libm" version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + [[package]] name = "litemap" version = "0.8.0" @@ -1380,9 +1742,31 @@ dependencies = [ [[package]] name = "log" -version = "0.4.27" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "loom" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca" +dependencies = [ + "cfg-if", + "generator", + "scoped-tls", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "matchers" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] [[package]] name = "md-5" @@ -1446,7 +1830,7 @@ checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" dependencies = [ "libc", "wasi 0.11.1+wasi-snapshot-preview1", - "windows-sys", + "windows-sys 0.59.0", ] [[package]] @@ -1469,6 +1853,25 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-bigint-dig" version = "0.8.6" @@ -1540,6 +1943,18 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "once_cell_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + +[[package]] +name = "oneshot" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ce411919553d3f9fa53a0880544cda985a112117a0444d5ff1e870a893d6ea" + [[package]] name = "opaque-debug" version = "0.3.1" @@ -1595,6 +2010,15 @@ dependencies = [ "sha2", ] +[[package]] +name = "parking" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" +dependencies = [ + "loom", +] + [[package]] name = "parking_lot_core" version = "0.9.11" @@ -1625,6 +2049,12 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "pastey" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b867cad97c0791bbd3aaa6472142568c6c9e8f71937e98379f584cfb0cf35bec" + [[package]] name = "pbkdf2" version = "0.12.2" @@ -1742,6 +2172,15 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" +[[package]] +name = "portable-atomic-util" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a106d1259c23fac8e543272398ae0e3c0b8d33c88ed73d0cc71b0f1d902618" +dependencies = [ + "portable-atomic", +] + [[package]] name = "potential_utf" version = "0.1.2" @@ -1810,6 +2249,28 @@ dependencies = [ "elliptic-curve", ] +[[package]] +name = "proc-macro-error-attr2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" +dependencies = [ + "proc-macro2", + "quote", +] + +[[package]] +name = "proc-macro-error2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" +dependencies = [ + "proc-macro-error-attr2", + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "proc-macro2" version = "1.0.101" @@ -1819,6 +2280,25 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "proptest" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b45fcc2344c680f5025fe57779faef368840d0bd1f42f216291f0dc4ace4744" +dependencies = [ + "bit-set", + "bit-vec", + "bitflags", + "num-traits", + "rand 0.9.2", + "rand_chacha 0.9.0", + "rand_xorshift", + "regex-syntax", + "rusty-fork", + "tempfile", + "unarray", +] + [[package]] name = "provenance-mark" version = "0.16.0" @@ -1863,6 +2343,54 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "ql-fsm" +version = "0.1.0" +dependencies = [ + "bytes", + "indexmap", + "proptest", + "ql-wire", +] + +[[package]] +name = "ql-rpc" +version = "0.1.0" +dependencies = [ + "bytes", + "trait-variant", +] + +[[package]] +name = "ql-runtime" +version = "0.1.0" +dependencies = [ + "async-channel", + "bytes", + "diatomic-waker", + "env_logger", + "event-listener", + "futures-lite", + "log", + "loom", + "oneshot", + "ql-fsm", + "ql-rpc", + "ql-wire", + "tokio", +] + +[[package]] +name = "ql-wire" +version = "0.1.0" +dependencies = [ + "bytes", + "getrandom 0.2.16", + "libcrux-aesgcm", + "libcrux-ml-kem", + "sha2", +] + [[package]] name = "quantum-link-macros" version = "0.1.0" @@ -1872,6 +2400,12 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "quick-error" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + [[package]] name = "quote" version = "1.0.40" @@ -1955,6 +2489,15 @@ dependencies = [ "getrandom 0.3.3", ] +[[package]] +name = "rand_xorshift" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "513962919efc330f829edb2535844d1b912b0fbe2ca165d613e4e8788bb05a5a" +dependencies = [ + "rand_core 0.9.3", +] + [[package]] name = "rand_xoshiro" version = "0.6.0" @@ -1975,9 +2518,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.11.1" +version = "1.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" dependencies = [ "aho-corasick", "memchr", @@ -1987,9 +2530,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.9" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" dependencies = [ "aho-corasick", "memchr", @@ -2087,12 +2630,37 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + [[package]] name = "rustversion" version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" +[[package]] +name = "rusty-fork" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc6bf79ff24e648f6da1f8d1f011e9cac26491b619e6b9280f2b47f1774e6ee2" +dependencies = [ + "fnv", + "quick-error", + "tempfile", + "wait-timeout", +] + [[package]] name = "ryu" version = "1.0.20" @@ -2108,6 +2676,12 @@ dependencies = [ "cipher", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scopeguard" version = "1.2.0" @@ -2167,18 +2741,28 @@ checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" [[package]] name = "serde" -version = "1.0.219" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", @@ -2219,6 +2803,15 @@ dependencies = [ "digest", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "shlex" version = "1.3.0" @@ -2382,6 +2975,19 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "tempfile" +version = "3.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" +dependencies = [ + "fastrand", + "getrandom 0.3.3", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + [[package]] name = "thiserror" version = "2.0.17" @@ -2402,6 +3008,15 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", +] + [[package]] name = "threadpool" version = "1.8.1" @@ -2436,6 +3051,27 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tls_codec" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de2e01245e2bb89d6f05801c564fa27624dbd7b1846859876c7dad82e90bf6b" +dependencies = [ + "tls_codec_derive", + "zeroize", +] + +[[package]] +name = "tls_codec_derive" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d2e76690929402faae40aebdda620a2c0e25dd6d3b9afe48867dfd95991f4bd" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "tokio" version = "1.47.1" @@ -2448,6 +3084,78 @@ dependencies = [ "mio", "pin-project-lite", "slab", + "tokio-macros", +] + +[[package]] +name = "tokio-macros" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "pin-project-lite", + "tracing-core", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex-automata", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", +] + +[[package]] +name = "trait-variant" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70977707304198400eb4835a78f6a9f928bf41bba420deb8fdb175cd965d77a7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", ] [[package]] @@ -2456,6 +3164,12 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" +[[package]] +name = "unarray" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" + [[package]] name = "unicode-ident" version = "1.0.18" @@ -2511,22 +3225,44 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "uuid" version = "1.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" dependencies = [ + "getrandom 0.3.3", "js-sys", "wasm-bindgen", ] +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + [[package]] name = "version_check" version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "wait-timeout" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ac3b126d3914f9849036f826e054cbabdc8519970b8998ddaf3b5bd3c65f11" +dependencies = [ + "libc", +] + [[package]] name = "wasi" version = "0.11.1+wasi-snapshot-preview1" @@ -2631,7 +3367,7 @@ checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" dependencies = [ "windows-implement", "windows-interface", - "windows-link", + "windows-link 0.1.3", "windows-result", "windows-strings", ] @@ -2664,13 +3400,19 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + [[package]] name = "windows-result" version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" dependencies = [ - "windows-link", + "windows-link 0.1.3", ] [[package]] @@ -2679,7 +3421,7 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" dependencies = [ - "windows-link", + "windows-link 0.1.3", ] [[package]] @@ -2691,6 +3433,15 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link 0.2.1", +] + [[package]] name = "windows-targets" version = "0.52.6" diff --git a/Cargo.toml b/Cargo.toml index 0fd0e755..a8a932ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,15 @@ [workspace] resolver = "2" -members = ["api", "backup-shard", "btp", "quantum-link-macros"] +members = [ + "api", + "backup-shard", + "btp", + "ql-fsm", + "ql-rpc", + "ql-runtime", + "ql-wire", + "quantum-link-macros", +] [workspace.package] homepage = "https://github.com/Foundation-Devices/foundation-api" @@ -24,6 +33,10 @@ backup-shard = { path = "backup-shard" } btp = { path = "btp" } foundation-api = { path = "api" } quantum-link-macros = { path = "quantum-link-macros" } +ql-protocol = { path = "ql-protocol" } +ql-fsm = { path = "ql-fsm" } +ql-rpc = { path = "ql-rpc" } +ql-wire = { path = "ql-wire" } [patch.crates-io] pqcrypto-traits = { git = "https://github.com/Foundation-Devices/pqcrypto", rev = "ebadf71214f67cb970242fa1053b4acb65767737" } diff --git a/QL_V2.md b/QL_V2.md new file mode 100644 index 00000000..0062c7e6 --- /dev/null +++ b/QL_V2.md @@ -0,0 +1,376 @@ +# QuantumLink V2 + +QuantumLink V2 is a peer-to-peer protocol for authenticated encrypted sessions carrying multiplexed duplex byte streams. + +It operates on whole QL records. Packetization, fragmentation, batching, and reassembly belong to the transport adapter, not to QLv2 itself. + +## Design goals +1. [Ephemeral peer sessions](#handshake): short-lived keys for encryption +2. [Forward secrecy](#security-properties): losing a long-term private key does not reveal old session data +3. [Minimal authenticated header](#record-and-frame-wire-format): keep routing visible, but authenticated +4. [QL-level reliability](#acknowledgment-and-retransmission): `ack` means received, decrypted, and accepted +5. [Duplex byte streams](#streams): avoid cross-stream head-of-line blocking and keep backpressure local +6. [Efficient wire format](#record-and-frame-wire-format): keep steady-state traffic compact +7. [Hardware-backed cryptography](#security-properties): allow platform-specific crypto implementations +8. Shared core state machine: keep implementation consistent across platforms + +## Non-goals + +QLv2 is not: + +- a packet framing format +- a generic reliability layer for arbitrary raw datagrams +- a globally ordered message bus + +## Core terms + +- `peer`: one QLv2 endpoint +- `QID`: a stable 16-byte peer identifier +- `peer bundle`: public peer information: `version`, `qid`, `capabilities`, and ML-KEM public key +- `pairing token`: an out-of-band secret that authorizes an `XX` pairing attempt +- `pairing_id`: the visible identifier derived from a pairing token and carried on `XX` records +- `session`: one live encrypted channel with directional keys and directional connection IDs +- `record`: one complete QLv2 wire unit +- `frame`: one logical item inside a session record +- `stream`: one duplex byte stream inside a session +- `route_id`: the application route carried once on the first initiator `StreamData` frame for a stream +- `stream origin`: the peer that opened the stream +- `origin lane`: bytes sent by the stream origin +- `return lane`: bytes sent back toward the stream origin + +## Record And Frame Wire Format + +QLv2 has two record types: + +- `handshake record`: used only during setup +- `session record`: used after the handshake completes + +Handshake records are large because they carry ML-KEM material. Session records are small and can carry multiple frames, including frames for different streams. + +All whole-record sizes below include the outer 2-byte record header: `version` plus `record type`. + +QLv2 uses QUIC-style variable-length integers for several steady-state fields. A varint is 1, 2, 4, or 8 bytes and can represent values in the range `0..2^62-1`. This keeps small values compact while allowing very large record and stream number spaces. + +Today, varints are used for: + +- session record `seq` +- `Ack.largest_acked` +- `Ack.block_count` +- `Ack.first_range_len` +- `Ack.gap` +- `Ack.range_len` +- `StreamData.stream_id` +- `StreamData.offset` +- `StreamData.route_id` when present +- `StreamData.bytes_len` +- `StreamWindow.stream_id` +- `StreamWindow.maximum_offset` +- `StreamClose.stream_id` + +### Handshake records + +QLv2 has two routed known-peer handshakes and one pairing handshake: + +- `IK` and `KK` carry a visible `sender` and `recipient` QID +- `XX` carries a visible `pairing_id` + +#### IK + +Used when the initiator already knows the responder bundle. + +| Record | Size | Purpose | +| --- | ---: | --- | +| `IK1` | 4785 bytes | start a handshake toward a known responder | +| `IK2` | 3195 bytes | complete `IK` and establish the session | + +#### KK + +Used when both peers already know each other. + +| Record | Size | Purpose | +| --- | ---: | --- | +| `KK1` | 3179 bytes | start a handshake between already-known peers | +| `KK2` | 3195 bytes | complete `KK` and establish the session | + +#### XX + +Used when the initiator has received an out of band pairing token, and neither peer knows each other. + +| Record | Size | Purpose | +| --- | ---: | --- | +| `XX1` | 1595 bytes | start pairing | +| `XX2` | 3201 bytes | send responder static identity and ciphertext | +| `XX3` | 3217 bytes | send initiator static identity and ciphertext | +| `XX4` | 1611 bytes | complete `XX` and establish the session | + +### Session records + +`session record size = 35..42 + sum(frame sizes)` + +There is no explicit AEAD nonce on the wire. The record `seq` is used to derive the nonce. + +| Fixed part | Size | Purpose | +| --- | ---: | --- | +| version | 1 byte | protocol version | +| record type | 1 byte | identifies a session record | +| `connection_id` | 16 bytes | route the record to the current session | +| `seq` | 1..8 bytes | varint record identity for ack and retransmit | +| AEAD auth tag | 16 bytes | authenticate the encrypted body | +| fixed overhead total | 35..42 bytes | overhead before any frames | + +The visible session header is authenticated as AEAD AAD but is not encrypted. + +### Session frames + +| Frame | Size | Purpose | +| --- | ---: | --- | +| `Ping` | 1 byte | keep the session alive when idle | +| `Unpair` | 1 byte | forget the currently bound peer and abort the session | +| `Ack` | `4+` bytes | acknowledge received session records with ACK ranges | +| `StreamWindow` | `3..17` bytes | extend per-stream send credit | +| `StreamClose` | `5..12` bytes | abort one stream lane or both lanes | +| `Close` | 3 bytes | close the whole session | +| `StreamData` | `5..34 + payload_len` bytes | carry stream bytes, optional opener route, and optional `fin` | + +`StreamData` is the main steady-state frame: + +`1 kind + varint(stream_id) + varint(offset) + 1 flags + optional varint(route_id) + varint(bytes_len) + payload_len` + +The flags byte carries: + +- `fin` +- `header present` + +Some useful minimum whole-record sizes for single-frame records: + +| Record | Size | Meaning | +| --- | ---: | --- | +| `Ping` only | 36 bytes | idle keepalive | +| `Unpair` only | 36 bytes | peer unpair | +| `Ack` only | 39 bytes | smallest selective ack | +| `Close` only | 38 bytes | session shutdown | +| empty `StreamData` without route header | 40 bytes | empty data or empty `fin` on an existing stream | +| empty opener `StreamData` with a 1-byte `route_id` | 41 bytes | open a new stream without payload bytes | + +## Handshake + +QLv2 currently supports three Noise-style handshake patterns: + +- `IK`: 2 messages, initiator already knows the responder bundle +- `KK`: 2 messages, both peers already know each other +- `XX`: 4 messages, peers authenticate through an out-of-band pairing token and exchange static identity during the handshake + +The handshake covers peer authentication and session establishment. + +Each successful handshake does five things: + +1. authenticate which peer we are talking to +2. derive a fresh transmit key and receive key +3. derive a directional transmit `connection_id` and receive `connection_id` +4. bind transport parameters into the transcript +5. produce a `handshake_hash` for the completed exchange + +Today the only transport parameter is: + +- initial per-stream receive window + +Future transport parameters could include session-wide byte credit or record-size limits. + +Each handshake attempt carries: + +- `handshake_id`: identifies one attempt and lets stale replies be ignored +- transport parameters + +`valid_until` is not currently part of the wire format. Handshake attempts instead expire by local timer. + +### Pattern summary + +- `IK` lets the responder learn the initiator during handshake completion. The initiator still needs the responder bundle before it can start. +- `KK` requires both peers to already know each other. +- `XX` requires the responder to be armed for pairing and to recognize the visible `pairing_id` derived from the expected pairing token. + +### Handshake rules + +- attempts are identified by `handshake_id` +- handshake messages are not retransmitted in place +- simultaneous starts must converge deterministically +- if `IK` and `KK` race, `IK` wins +- same-pattern races break ties by ordering the initial ephemeral public keys +- `XX` requires out-of-band authorization and uses visible `pairing_id` for lookup + +### Session establishment points + +- `IK` and `KK` complete after message 2 (1 RT) +- `XX` completes after 4 messages (2 RTT) + +## Session Model + +After the handshake, peers exchange encrypted session records. + +Each session record has: + +- one visible `connection_id` +- one visible `seq` +- one encrypted body containing one or more frames + +One session record may carry: + +- only control frames +- only stream data +- a mixture of frames for multiple streams + +This is the core steady-state model: records are the encrypted transport unit, frames are the logical items inside them. + +## Acknowledgment And Retransmission + +`Ack` is record-level, not stream-level. + +An `Ack` means the peer: + +- received that session record +- decrypted it with the current session key +- accepted its `seq` + +The ACK wire format is range-based, not bitmap-based. It carries: + +- `largest_acked` +- `block_count` +- `first_range_len` +- zero or more `(gap, range_len)` blocks + +Ranges are encoded from highest sequence numbers down to lowest sequence numbers. + +Receivers track a recent accepted record window so they can: + +- reject duplicates +- ignore records that are too old +- emit selective ACK ranges + +Pending ACK state is also range-based. If there are too many disjoint ranges, older low ranges may be dropped. An emitted ACK may also be truncated by the remaining record budget. + +Retransmission works at the frame level: + +- every emitted session record gets a fresh `seq` +- retransmit timers start only after the local transport confirms that it accepted the write +- if a record is considered lost, the FSM restores its frames +- those frames are packed into a new record with a new `seq` + +QLv2 does not resend the same logical record identity. + +There is no explicit `Nack` frame. Loss is inferred from timeout or from later ACK state that no longer includes a record. + +Pure ACK-only records are fire-and-forget: they are not themselves retransmitted. + +Example: + +`seq = 10` + +| Frame | Contents | +| --- | --- | +| `StreamData` | `stream_id=4 offset=0 bytes="hello"` | + +The sender receives more bytes for that stream before `seq = 10` is acked: + +| Pending new frame | Contents | +| --- | --- | +| `StreamData` | `stream_id=4 offset=5 bytes=" world"` | + +If `seq = 10` is considered lost, its frame is restored and packed again with a new record sequence: + +`seq = 11` + +| Frame | Contents | +| --- | --- | +| `StreamData` | `stream_id=4 offset=0 bytes="hello"` | +| `StreamData` | `stream_id=4 offset=5 bytes=" world"` | + +## Streams + +Streams are the application primitive. + +A stream has two independent lanes: + +- origin lane +- return lane + +Important properties: + +- either peer can open a stream +- stream IDs are split by parity derived from QID ordering, so both peers can open streams without collision +- stream IDs increase monotonically within each parity namespace and must not repeat within a session +- ordering is preserved within a stream lane +- different streams can make progress independently +- record loss on one stream does not block unrelated streams + +There is no separate open frame. + +Locally, opening a stream allocates: + +- a new `stream_id` +- an application `route_id` + +On the wire, the stream opener carries that `route_id` once, in the first initiator `StreamData` frame at `offset = 0`, using the optional `StreamHeader`. + +`StreamData` carries: + +- `stream_id` +- `offset` +- optional `StreamHeader { route_id }` +- `fin` +- bytes + +`StreamHeader` is only valid on the first initiator `StreamData` frame for a stream, at `offset = 0`. + +`fin` is graceful completion of one lane. It says "no more bytes on this lane" without aborting the other lane. + +## Flow Control + +Flow control is per stream. + +During the handshake, each peer advertises an initial per-stream receive window. That becomes the initial send credit the remote peer can use on each stream. + +`StreamWindow` extends that credit by advertising a larger absolute `maximum_offset`. + +In practice, a stream is writable only when both are true: + +- local send buffering has room +- peer-advertised stream credit allows more bytes + +Receive credit advances when the local application commits read bytes, not merely when bytes become readable. That is when the FSM emits a `StreamWindow` update. + +## Close And Liveness + +`StreamClose` aborts a stream early. Semantically it can target: + +- the origin lane +- the return lane +- both lanes + +`Close` aborts the whole session. + +`Unpair` is stronger than `Close`: + +- it forgets the currently bound peer locally +- it aborts the active session immediately +- it may emit one final outbound `Unpair` frame +- reconnect does not resume until a peer is paired again + +Idle sessions may send `Ping`. The peer does not answer with another ping; normal record acknowledgment is enough. + +Sessions also have local timers for: + +- handshake timeout +- delayed ack emission +- session record retransmit timeout +- keepalive ping interval +- peer silence timeout + +If peer silence exceeds the configured timeout, the session closes with timeout. + +## Security Properties + +The current handshake family is ML-KEM-based and post-quantum focused. + +Session payloads are encrypted and authenticated. The session header stays visible so the receiver can route the record, but it is still authenticated as AEAD AAD. + +QLv2 also provides forward secrecy in the following sense: even if an attacker later obtains a peer's long-term ML-KEM private key, they still cannot decrypt messages from earlier completed sessions. diff --git a/api/src/api/quantum_link.rs b/api/src/api/quantum_link.rs index a9033ed3..0d956897 100644 --- a/api/src/api/quantum_link.rs +++ b/api/src/api/quantum_link.rs @@ -239,6 +239,8 @@ impl QuantumLinkIdentity { #[cfg(test)] mod tests { + use dcbor::CBOREncodable; + use crate::{ api::{ message::{QuantumLinkMessage, PROTOCOL_VERSION}, @@ -247,6 +249,7 @@ mod tests { fx::ExchangeRate, message::EnvoyMessage, quantum_link::{ARIDCache, QlError, QuantumLinkIdentity}, + status::Heartbeat, }; #[test] @@ -309,6 +312,55 @@ mod tests { assert_eq!(fx_rate.rate, fx_rate_decoded.rate); } + #[test] + fn test_sealed_message_size() { + let envoy = QuantumLinkIdentity::generate(); + let passport = QuantumLinkIdentity::generate(); + + let fx_rate = ExchangeRate { + currency_code: String::from("USD"), + rate: 0.85, + timestamp: 0, + }; + let message = EnvoyMessage { + message: QuantumLinkMessage::ExchangeRate(fx_rate), + timestamp: 123456, + protocol_version: None, + }; + + let envelope = QuantumLink::seal( + message, + (envoy.private_keys.as_ref().unwrap(), &envoy.xid_document), + &passport.xid_document, + ); + let bytes = envelope.to_cbor_data(); + + println!("sealed message size: {} bytes", bytes.len()); + assert!(!bytes.is_empty()); + } + + #[test] + fn test_sealed_heartbeat_size() { + let envoy = QuantumLinkIdentity::generate(); + let passport = QuantumLinkIdentity::generate(); + + let message = EnvoyMessage { + message: QuantumLinkMessage::Heartbeat(Heartbeat {}), + timestamp: 123456, + protocol_version: None, + }; + + let envelope = QuantumLink::seal( + message, + (envoy.private_keys.as_ref().unwrap(), &envoy.xid_document), + &passport.xid_document, + ); + let bytes = envelope.to_cbor_data(); + + println!("sealed heartbeat size: {} bytes", bytes.len()); + assert!(!bytes.is_empty()); + } + #[test] fn test_serialize_ql_identity() { let identity = QuantumLinkIdentity::generate(); diff --git a/ql-fsm/Cargo.toml b/ql-fsm/Cargo.toml new file mode 100644 index 00000000..47da2e1b --- /dev/null +++ b/ql-fsm/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "ql-fsm" +version = "0.1.0" +edition = "2021" +description = "Quantum Link synchronous finite state machine" +license = "Proprietary" + +[dependencies] +bytes = "1" +indexmap = "2" +ql-wire = { path = "../ql-wire" } + +[dev-dependencies] +proptest = "1.6" +ql-wire = { path = "../ql-wire", features = ["test-utils"] } diff --git a/ql-fsm/src/error.rs b/ql-fsm/src/error.rs new file mode 100644 index 00000000..9bf2a915 --- /dev/null +++ b/ql-fsm/src/error.rs @@ -0,0 +1,124 @@ +use std::{ + error::Error, + fmt::{Display, Formatter}, +}; + +use ql_wire::{PairingId, WireError}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ReceiveError { + InvalidRecordHeader(WireError), + InvalidRecordVersion, + InvalidHandshakeRecord(WireError), + InvalidSessionRecord(WireError), + InvalidSessionConnectionId, + InvalidSessionPayload(WireError), + InvalidIkHandshake(WireError), + InvalidKkHandshake(WireError), + InvalidXxHandshake(WireError), + InvalidRemoteBundle, + InvalidQid, + NoPeer, + NoSession, + NotPairingMode, + InvalidPairingId { + expected: PairingId, + actual: PairingId, + }, +} + +impl Display for ReceiveError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::InvalidRecordHeader(error) => write!(f, "invalid record header: {error}"), + Self::InvalidRecordVersion => f.write_str("invalid record version"), + Self::InvalidHandshakeRecord(error) => { + write!(f, "invalid handshake record: {error}") + } + Self::InvalidSessionRecord(error) => write!(f, "invalid session record: {error}"), + Self::InvalidSessionConnectionId => f.write_str("invalid session connection id"), + Self::InvalidSessionPayload(error) => write!(f, "invalid session payload: {error}"), + Self::InvalidIkHandshake(error) => write!(f, "invalid ik handshake: {error}"), + Self::InvalidKkHandshake(error) => write!(f, "invalid kk handshake: {error}"), + Self::InvalidXxHandshake(error) => write!(f, "invalid xx handshake: {error}"), + Self::InvalidRemoteBundle => f.write_str("invalid remote bundle"), + Self::InvalidQid => f.write_str("invalid qid"), + Self::NoPeer => f.write_str("no bound peer"), + Self::NoSession => f.write_str("no active session"), + Self::NotPairingMode => f.write_str("not in pairing mode"), + Self::InvalidPairingId { expected, actual } => { + write!( + f, + "invalid pairing id: expected {expected}, actual {actual}" + ) + } + } + } +} + +impl std::error::Error for ReceiveError {} + +impl From for ReceiveError { + fn from(_: NoSessionError) -> Self { + Self::NoSession + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct NoPeerError; + +impl Display for NoPeerError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str("no peer bound") + } +} + +impl Error for NoPeerError {} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct NoSessionError; + +impl Display for NoSessionError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "no session") + } +} + +impl Error for NoSessionError {} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamError { + MissingStream, + NotWritable, + NoSession, +} + +impl Display for StreamError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let message = match self { + Self::MissingStream => "missing stream", + Self::NotWritable => "stream is not writable", + Self::NoSession => "no session", + }; + f.write_str(message) + } +} + +impl Error for StreamError {} + +impl From for StreamError { + fn from(_: NoSessionError) -> Self { + Self::NoSession + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct CommitReadError; + +impl Display for CommitReadError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "invalid read commit") + } +} + +impl Error for CommitReadError {} diff --git a/ql-fsm/src/fsm.rs b/ql-fsm/src/fsm.rs new file mode 100644 index 00000000..036a336e --- /dev/null +++ b/ql-fsm/src/fsm.rs @@ -0,0 +1,267 @@ +use std::{collections::VecDeque, time::Instant}; + +use bytes::Bytes; +use ql_wire::{self as wire, QlCrypto, RouteId, SessionCloseCode, StreamId, WireDecode}; + +use crate::{ + handshake, + session::{self, SessionEvent, TerminalFrame}, + state::LinkState, + Event, NoPeerError, NoSessionError, OutboundWrite, QlFsm, ReceiveError, StreamError, WriteId, +}; + +pub struct EventSink<'a> { + events: &'a mut VecDeque, + termination: Option, +} + +impl<'a> EventSink<'a> { + fn new(events: &'a mut VecDeque) -> Self { + Self { + events, + termination: None, + } + } +} + +impl session::EventSink for EventSink<'_> { + fn emit(&mut self, event: SessionEvent) { + match event { + SessionEvent::Unpaired => { + self.termination = Some(TerminalFrame::Unpair); + } + SessionEvent::Opened { + stream_id, + route_id, + } => { + self.events.push_back(Event::Opened { + stream_id, + route_id, + }); + } + SessionEvent::Readable(stream_id) => { + self.events.push_back(Event::Readable(stream_id)); + } + SessionEvent::Writable(stream_id) => { + self.events.push_back(Event::Writable(stream_id)); + } + SessionEvent::Finished(stream_id) => { + self.events.push_back(Event::Finished(stream_id)); + } + SessionEvent::OutboundFinished(stream_id) => { + self.events.push_back(Event::OutboundFinished(stream_id)); + } + SessionEvent::Closed(frame) => { + self.events.push_back(Event::Closed(frame)); + } + SessionEvent::WritableClosed(frame) => { + self.events.push_back(Event::WritableClosed(frame)); + } + SessionEvent::SessionClosed(close) => { + self.termination = Some(TerminalFrame::Close(close.clone())); + self.events.push_back(Event::SessionClosed(close)); + } + } + } +} + +pub fn handle_bind_peer(fsm: &mut QlFsm, peer: ql_wire::PeerBundle) { + fsm.state.handshake = None; + fsm.state.link = LinkState::Idle; + fsm.state.peer = Some(peer); +} + +pub fn unpair(fsm: &mut QlFsm) { + let had_peer = fsm.state.peer.is_some(); + fsm.state.handshake = None; + fsm.state.armed_pairing_token = None; + + if let Some(conn) = fsm.state.link.connected_mut() { + let mut emit = EventSink::new(&mut fsm.events); + conn.session.unpair(&mut emit); + } else { + fsm.state.link = LinkState::Idle; + } + + if had_peer { + emit_peer_status(fsm, crate::PeerStatus::Unpaired); + } + fsm.state.peer = None; +} + +pub fn handle_disarm_pairing(fsm: &mut QlFsm) { + fsm.state.armed_pairing_token = None; + handshake::handle_disarm_pairing(fsm); +} + +pub fn handle_connect_xx(fsm: &mut QlFsm, invite: crate::PairingInvite, crypto: &impl QlCrypto) { + handshake::handle_connect_xx(fsm, invite, crypto); +} + +pub fn handle_connect_ik(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { + handshake::handle_connect_ik(fsm, crypto) +} + +pub fn handle_connect_kk(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { + handshake::handle_connect_kk(fsm, crypto) +} + +pub fn receive( + fsm: &mut QlFsm, + mut bytes: Vec, + crypto: &impl QlCrypto, +) -> Result<(), ReceiveError> { + let mut reader = wire::Reader::new(bytes.as_mut_slice()); + let header = + wire::RecordHeader::decode(&mut reader).map_err(ReceiveError::InvalidRecordHeader)?; + + if header.version != wire::QL_WIRE_VERSION { + return Err(ReceiveError::InvalidRecordVersion); + } + + match header.record_type { + wire::RecordType::Handshake => { + let record = wire::QlHandshakeRecord::decode(&mut reader) + .map_err(ReceiveError::InvalidHandshakeRecord)?; + handshake::handle_handshake_record(fsm, crypto, &record) + } + wire::RecordType::Session => { + let termination = { + let QlFsm { state, events, .. } = fsm; + let conn = state.link.connected_mut_or_err()?; + let (decrypt_len, seq) = { + let record = wire::QlSessionRecord::decode(&mut reader) + .map_err(ReceiveError::InvalidSessionRecord)?; + if record.header.connection_id != conn.transport.rx_connection_id { + return Err(ReceiveError::InvalidSessionConnectionId); + } + let payload = wire::decrypt_record( + crypto, + &record.header, + record.payload, + &conn.transport.rx_key, + ) + .map_err(ReceiveError::InvalidSessionPayload)?; + (payload.len(), record.header.seq) + }; + + let len = bytes.len(); + let plaintext = Bytes::from(bytes).slice(len - decrypt_len..); + let frames = wire::parse_session_frames(plaintext); + + let mut emit = EventSink::new(events); + conn.session.receive(state.now, seq, frames, &mut emit); + emit.termination + }; + + if matches!(termination, Some(TerminalFrame::Unpair)) { + if fsm.state.peer.is_some() { + emit_peer_status(fsm, crate::PeerStatus::Unpaired); + } + fsm.state.handshake = None; + fsm.state.armed_pairing_token = None; + fsm.state.peer = None; + } + Ok(()) + } + } +} + +pub fn on_timer(fsm: &mut QlFsm) { + handshake::handle_timer(fsm); + + let QlFsm { state, events, .. } = fsm; + let Some(conn) = state.link.connected_mut() else { + return; + }; + + let mut emit = EventSink::new(events); + conn.session.on_timer(state.now, &mut emit); +} + +pub fn next_deadline(fsm: &QlFsm) -> Option { + [ + handshake::next_handshake_deadline(fsm), + fsm.state + .link + .connected() + .and_then(|state| state.session.next_deadline()), + ] + .into_iter() + .flatten() + .min() +} + +pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option { + if let Some(record) = fsm.state.handshake.take() { + let record = wire::encode_record_vec(ql_wire::RecordType::Handshake, &record); + return Some(OutboundWrite { + record, + write_id: None, + }); + } + + let QlFsm { state, .. } = fsm; + let conn = state.link.connected_mut()?; + + let (write_id, builder) = conn.session.take_next_write(state.now)?; + let record = builder.encrypt( + crypto, + conn.transport.tx_connection_id, + &conn.transport.tx_key, + ); + if conn.session.is_closed() && matches!(fsm.state.link, LinkState::Connected(_)) { + fsm.state.link = LinkState::Idle; + emit_peer_status(fsm, fsm.state.link.status()); + } + Some(OutboundWrite { + record, + write_id: write_id.map(WriteId), + }) +} + +pub fn complete_write(fsm: &mut QlFsm, write_id: WriteId, success: bool) { + let QlFsm { state, .. } = fsm; + if let Some(conn) = state.link.connected_mut() { + conn.session.complete_write(state.now, write_id.0, success); + } +} + +pub fn close_session(fsm: &mut QlFsm, code: SessionCloseCode) { + let QlFsm { state, events, .. } = fsm; + let Some(conn) = state.link.connected_mut() else { + return; + }; + let mut emit = EventSink::new(events); + conn.session.close(code, &mut emit); +} + +pub fn open_stream( + fsm: &mut QlFsm, + route_id: RouteId, +) -> Result, NoSessionError> { + let QlFsm { state, events, .. } = fsm; + let conn = state.link.connected_mut_or_err()?; + let inner = conn.session.open_stream(route_id, EventSink::new(events))?; + Ok(crate::StreamOps { inner }) +} + +pub fn stream(fsm: &mut QlFsm, stream_id: StreamId) -> Result, StreamError> { + let QlFsm { state, events, .. } = fsm; + let conn = state.link.connected_mut_or_err()?; + let inner = conn.session.stream(stream_id, EventSink::new(events))?; + Ok(crate::StreamOps { inner }) +} + +pub fn queue_ping(fsm: &mut QlFsm) -> Result<(), NoSessionError> { + let conn = fsm.state.link.connected_mut_or_err()?; + conn.session.queue_ping() +} + +pub fn poll_event(fsm: &mut QlFsm) -> Option { + fsm.events.pop_front() +} + +pub fn emit_peer_status(fsm: &mut QlFsm, status: crate::PeerStatus) { + fsm.events.push_back(Event::PeerStatusChanged(status)); +} diff --git a/ql-fsm/src/handshake/ik.rs b/ql-fsm/src/handshake/ik.rs new file mode 100644 index 00000000..7e6ebd1e --- /dev/null +++ b/ql-fsm/src/handshake/ik.rs @@ -0,0 +1,119 @@ +use ql_wire::{self as wire, Ik1, Ik2, PeerBundle, QlCrypto, QlHandshakeRecord}; + +use super::{ + emit_peer_status, enqueue_handshake, finish_handshake, reset_connected_session_if_needed, +}; +use crate::{ + state::{IkInitiatorState, LinkState, SessionTransport}, + QlFsm, ReceiveError, +}; + +pub fn start_initiator(fsm: &mut QlFsm, crypto: &impl QlCrypto, peer: PeerBundle) { + let meta = super::next_handshake_meta(fsm); + let mut handshake = wire::IkHandshake::new_initiator( + crypto, + fsm.identity.clone(), + peer, + super::local_transport_params(fsm), + ); + let message = handshake.write_1(crypto, meta).unwrap(); + + fsm.state.link = LinkState::IkInitiator(IkInitiatorState { + handshake_id: meta.handshake_id, + initial_ephemeral: message.ephemeral.clone(), + handshake, + deadline: fsm.state.now + fsm.config.handshake_timeout, + }); + enqueue_handshake(fsm, QlHandshakeRecord::Ik1(message)); + emit_peer_status(fsm, fsm.state.link.status()); +} + +pub fn handle_ik1( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + message: &Ik1, +) -> Result<(), ReceiveError> { + if should_ignore_inbound(fsm, message) { + return Ok(()); + } + if message.header.recipient != fsm.identity.qid { + return Err(ReceiveError::InvalidQid); + } + if let Some(peer) = fsm.state.peer.as_ref() { + if message.header.sender != peer.qid { + return Err(ReceiveError::InvalidQid); + } + } + + reset_connected_session_if_needed(fsm); + + let mut handshake = wire::IkHandshake::new_responder( + crypto, + fsm.identity.clone(), + fsm.state.peer.clone(), + super::local_transport_params(fsm), + ); + handshake + .read_1(crypto, message) + .map_err(ReceiveError::InvalidIkHandshake)?; + let outbound = handshake + .write_2(crypto, message.meta) + .map_err(ReceiveError::InvalidIkHandshake)?; + let (transport, remote_bundle) = SessionTransport::from_finalized( + handshake + .finalize(crypto) + .map_err(ReceiveError::InvalidIkHandshake)?, + ); + finish_handshake(fsm, transport, remote_bundle)?; + fsm.state.handshake = None; + enqueue_handshake(fsm, QlHandshakeRecord::Ik2(outbound)); + Ok(()) +} + +pub fn handle_ik2( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + message: &Ik2, +) -> Result<(), ReceiveError> { + { + let LinkState::IkInitiator(state) = &mut fsm.state.link else { + return Ok(()); + }; + + if message.meta.handshake_id != state.handshake_id { + return Ok(()); + } + + state + .handshake + .read_2(crypto, message) + .map_err(ReceiveError::InvalidIkHandshake)?; + } + + let LinkState::IkInitiator(state) = fsm.state.link.take() else { + unreachable!("active IK initiator was checked above"); + }; + let (transport, remote_bundle) = SessionTransport::from_finalized( + state + .handshake + .finalize(crypto) + .map_err(ReceiveError::InvalidIkHandshake)?, + ); + finish_handshake(fsm, transport, remote_bundle) +} + +pub fn should_ignore_inbound(fsm: &QlFsm, message: &Ik1) -> bool { + match &fsm.state.link { + LinkState::Idle + | LinkState::Connected(_) + | LinkState::KkInitiator(_) + | LinkState::XxInitiator(_) + | LinkState::XxResponder(_) => false, + LinkState::IkInitiator(state) => { + if fsm.state.peer.as_ref().map(|peer| peer.qid) != Some(message.header.sender) { + return false; + } + super::local_start_wins(&state.initial_ephemeral, &message.ephemeral) + } + } +} diff --git a/ql-fsm/src/handshake/kk.rs b/ql-fsm/src/handshake/kk.rs new file mode 100644 index 00000000..e78c8a6d --- /dev/null +++ b/ql-fsm/src/handshake/kk.rs @@ -0,0 +1,118 @@ +use ql_wire::{self as wire, Kk1, Kk2, PeerBundle, QlCrypto, QlHandshakeRecord}; + +use super::{ + emit_peer_status, enqueue_handshake, finish_handshake, reset_connected_session_if_needed, +}; +use crate::{ + state::{KkInitiatorState, LinkState, SessionTransport}, + QlFsm, ReceiveError, +}; + +pub fn start_initiator(fsm: &mut QlFsm, crypto: &impl QlCrypto, peer: PeerBundle) { + let meta = super::next_handshake_meta(fsm); + let mut handshake = wire::KkHandshake::new_initiator( + crypto, + fsm.identity.clone(), + peer, + super::local_transport_params(fsm), + ); + let message = handshake.write_1(crypto, meta).unwrap(); + + fsm.state.link = LinkState::KkInitiator(KkInitiatorState { + handshake_id: meta.handshake_id, + initial_ephemeral: message.ephemeral.clone(), + handshake, + deadline: fsm.state.now + fsm.config.handshake_timeout, + }); + enqueue_handshake(fsm, QlHandshakeRecord::Kk1(message)); + emit_peer_status(fsm, fsm.state.link.status()); +} + +pub fn handle_kk1( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + message: &Kk1, +) -> Result<(), ReceiveError> { + if should_ignore_inbound(fsm, message) { + return Ok(()); + } + + let Some(peer) = fsm.state.peer.clone() else { + return Err(ReceiveError::NoPeer); + }; + if message.header.recipient != fsm.identity.qid || message.header.sender != peer.qid { + return Err(ReceiveError::InvalidQid); + } + + reset_connected_session_if_needed(fsm); + + let mut handshake = wire::KkHandshake::new_responder( + crypto, + fsm.identity.clone(), + peer, + super::local_transport_params(fsm), + ); + handshake + .read_1(crypto, message) + .map_err(ReceiveError::InvalidKkHandshake)?; + let outbound = handshake + .write_2(crypto, message.meta) + .map_err(ReceiveError::InvalidKkHandshake)?; + let (transport, remote_bundle) = SessionTransport::from_finalized( + handshake + .finalize(crypto) + .map_err(ReceiveError::InvalidKkHandshake)?, + ); + finish_handshake(fsm, transport, remote_bundle)?; + fsm.state.handshake = None; + enqueue_handshake(fsm, QlHandshakeRecord::Kk2(outbound)); + Ok(()) +} + +pub fn handle_kk2( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + message: &Kk2, +) -> Result<(), ReceiveError> { + { + let LinkState::KkInitiator(state) = &mut fsm.state.link else { + return Ok(()); + }; + + if message.meta.handshake_id != state.handshake_id { + return Ok(()); + } + + state + .handshake + .read_2(crypto, message) + .map_err(ReceiveError::InvalidKkHandshake)?; + } + + let LinkState::KkInitiator(state) = fsm.state.link.take() else { + unreachable!("active KK initiator was checked above"); + }; + let (transport, remote_bundle) = SessionTransport::from_finalized( + state + .handshake + .finalize(crypto) + .map_err(ReceiveError::InvalidKkHandshake)?, + ); + finish_handshake(fsm, transport, remote_bundle) +} + +pub fn should_ignore_inbound(fsm: &QlFsm, message: &Kk1) -> bool { + match &fsm.state.link { + LinkState::Idle + | LinkState::Connected(_) + | LinkState::XxInitiator(_) + | LinkState::XxResponder(_) => false, + LinkState::IkInitiator(_) => true, + LinkState::KkInitiator(state) => { + if fsm.state.peer.as_ref().map(|peer| peer.qid) != Some(message.header.sender) { + return false; + } + super::local_start_wins(&state.initial_ephemeral, &message.ephemeral) + } + } +} diff --git a/ql-fsm/src/handshake/mod.rs b/ql-fsm/src/handshake/mod.rs new file mode 100644 index 00000000..1881f66e --- /dev/null +++ b/ql-fsm/src/handshake/mod.rs @@ -0,0 +1,140 @@ +mod ik; +mod kk; +mod xx; + +use ql_wire::{self as wire, EphemeralPublicKey, HandshakeMeta, QlCrypto, QlHandshakeRecord}; + +use crate::{ + fsm::emit_peer_status, + session::{SessionConfig, SessionFsm, StreamParity}, + state::{ConnectedState, LinkState, SessionTransport}, + Event, NoPeerError, QlFsm, ReceiveError, +}; + +pub fn handle_connect_ik(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { + let peer = fsm.state.peer.clone().ok_or(NoPeerError)?; + prepare_for_outbound_connect(fsm); + ik::start_initiator(fsm, crypto, peer); + Ok(()) +} + +pub fn handle_connect_kk(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { + let peer = fsm.state.peer.clone().ok_or(NoPeerError)?; + prepare_for_outbound_connect(fsm); + kk::start_initiator(fsm, crypto, peer); + Ok(()) +} + +pub fn handle_connect_xx(fsm: &mut QlFsm, invite: crate::PairingInvite, crypto: &impl QlCrypto) { + prepare_for_outbound_connect(fsm); + xx::start_initiator(fsm, crypto, invite.token, invite.qid); +} + +pub fn next_handshake_meta(fsm: &mut QlFsm) -> HandshakeMeta { + let handshake_id = wire::HandshakeId(fsm.state.next_control_id); + fsm.state.next_control_id = fsm.state.next_control_id.wrapping_add(1); + HandshakeMeta { handshake_id } +} + +pub fn enqueue_handshake(fsm: &mut QlFsm, record: QlHandshakeRecord) { + debug_assert!(fsm.state.handshake.is_none()); + fsm.state.handshake = Some(record); +} + +pub fn handle_disarm_pairing(fsm: &mut QlFsm) { + xx::disarm_pairing(fsm); +} + +fn local_transport_params(fsm: &QlFsm) -> wire::TransportParams { + wire::TransportParams { + initial_stream_receive_window: fsm.config.session_stream_receive_buffer_size, + } +} + +pub fn prepare_for_outbound_connect(fsm: &mut QlFsm) { + fsm.state.handshake = None; + reset_connected_session_if_needed(fsm); +} + +pub fn handle_handshake_record( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + record: &QlHandshakeRecord, +) -> Result<(), ReceiveError> { + match record { + QlHandshakeRecord::Ik1(message) => ik::handle_ik1(fsm, crypto, message), + QlHandshakeRecord::Ik2(message) => ik::handle_ik2(fsm, crypto, message), + QlHandshakeRecord::Kk1(message) => kk::handle_kk1(fsm, crypto, message), + QlHandshakeRecord::Kk2(message) => kk::handle_kk2(fsm, crypto, message), + QlHandshakeRecord::Xx1(message) => xx::handle_xx1(fsm, crypto, message), + QlHandshakeRecord::Xx2(message) => xx::handle_xx2(fsm, crypto, message), + QlHandshakeRecord::Xx3(message) => xx::handle_xx3(fsm, crypto, message), + QlHandshakeRecord::Xx4(message) => xx::handle_xx4(fsm, crypto, message), + } +} + +pub fn handle_timer(fsm: &mut QlFsm) { + let Some(deadline) = fsm.state.link.handshake_deadline() else { + return; + }; + if deadline > fsm.state.now { + return; + } + + fsm.state.link = LinkState::Idle; + fsm.state.handshake = None; + emit_peer_status(fsm, fsm.state.link.status()); +} + +pub fn next_handshake_deadline(fsm: &QlFsm) -> Option { + fsm.state.link.handshake_deadline() +} + +pub fn finish_handshake( + fsm: &mut QlFsm, + transport: SessionTransport, + remote_bundle: wire::PeerBundle, +) -> Result<(), ReceiveError> { + let qid = remote_bundle.qid; + if let Some(peer) = fsm.state.peer.as_ref() { + if peer != &remote_bundle { + return Err(ReceiveError::InvalidRemoteBundle); + } + } else { + fsm.state.peer = Some(remote_bundle); + fsm.events.push_back(Event::NewPeer); + } + + let config = &fsm.config; + let session = SessionFsm::new( + SessionConfig { + local_parity: StreamParity::for_local(fsm.identity.qid, qid), + record_max_size: config.session_record_max_size, + ack_delay: config.session_record_ack_delay, + retransmit_timeout: config.session_record_retransmit_timeout, + keepalive_interval: config.session_keepalive_interval, + peer_timeout: config.session_peer_timeout, + stream_send_buffer_size: config.session_stream_send_buffer_size, + stream_receive_buffer_size: config.session_stream_receive_buffer_size, + accepted_record_window: config.session_accepted_record_window, + pending_ack_range_limit: config.session_pending_ack_range_limit, + initial_peer_stream_receive_window: transport + .remote_transport_params + .initial_stream_receive_window, + }, + fsm.state.now, + ); + fsm.state.link = LinkState::Connected(ConnectedState { transport, session }); + emit_peer_status(fsm, fsm.state.link.status()); + Ok(()) +} + +pub fn reset_connected_session_if_needed(fsm: &mut QlFsm) { + if matches!(fsm.state.link, LinkState::Connected(_)) { + fsm.state.link = LinkState::Idle; + } +} + +fn local_start_wins(local: &EphemeralPublicKey, inbound: &EphemeralPublicKey) -> bool { + local.mlkem_public_key.as_bytes() <= inbound.mlkem_public_key.as_bytes() +} diff --git a/ql-fsm/src/handshake/xx.rs b/ql-fsm/src/handshake/xx.rs new file mode 100644 index 00000000..c9a289e0 --- /dev/null +++ b/ql-fsm/src/handshake/xx.rs @@ -0,0 +1,207 @@ +use ql_wire::{self as wire, PairingToken, QlCrypto, QlHandshakeRecord, Xx1, Xx2, Xx3, Xx4, QID}; + +use super::{ + emit_peer_status, enqueue_handshake, finish_handshake, reset_connected_session_if_needed, +}; +use crate::{ + state::{LinkState, SessionTransport, XxInitiatorState, XxResponderState}, + QlFsm, ReceiveError, +}; + +pub fn start_initiator( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + token: PairingToken, + remote_qid: QID, +) { + let meta = super::next_handshake_meta(fsm); + let mut handshake = wire::XxHandshake::new_initiator( + crypto, + fsm.identity.clone(), + remote_qid, + token, + super::local_transport_params(fsm), + ); + let message = handshake.write_1(crypto, meta).unwrap(); + + fsm.state.link = LinkState::XxInitiator(XxInitiatorState { + handshake_id: meta.handshake_id, + initial_ephemeral: message.ephemeral.clone(), + handshake, + deadline: fsm.state.now + fsm.config.handshake_timeout, + }); + enqueue_handshake(fsm, QlHandshakeRecord::Xx1(message)); + emit_peer_status(fsm, fsm.state.link.status()); +} + +pub fn handle_xx1( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + message: &Xx1, +) -> Result<(), ReceiveError> { + if should_ignore_inbound(fsm, crypto, message) { + return Ok(()); + } + match fsm.state.armed_pairing_token { + Some(expected) if expected.id(crypto) != message.pairing_id => { + Err(ReceiveError::InvalidPairingId { + expected: expected.id(crypto), + actual: message.pairing_id, + }) + } + Some(_) + if message.header.recipient != fsm.identity.qid + || message.header.sender == fsm.identity.qid => + { + Err(ReceiveError::InvalidQid) + } + Some(token) => { + reset_connected_session_if_needed(fsm); + + let mut handshake = wire::XxHandshake::new_responder( + crypto, + fsm.identity.clone(), + message.header.sender, + token, + super::local_transport_params(fsm), + ); + handshake + .read_1(crypto, message) + .map_err(ReceiveError::InvalidXxHandshake)?; + let outbound = handshake + .write_2(crypto, message.meta) + .map_err(ReceiveError::InvalidXxHandshake)?; + fsm.state.link = LinkState::XxResponder(XxResponderState { + handshake, + handshake_meta: message.meta, + deadline: fsm.state.now + fsm.config.handshake_timeout, + }); + fsm.state.handshake = None; + enqueue_handshake(fsm, QlHandshakeRecord::Xx2(outbound)); + Ok(()) + } + None => Err(ReceiveError::NotPairingMode), + } +} + +pub fn handle_xx2( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + message: &Xx2, +) -> Result<(), ReceiveError> { + { + let LinkState::XxInitiator(state) = &mut fsm.state.link else { + return Ok(()); + }; + + if message.meta.handshake_id != state.handshake_id { + return Ok(()); + } + + state + .handshake + .read_2(crypto, message) + .map_err(ReceiveError::InvalidXxHandshake)?; + let outbound = state + .handshake + .write_3(crypto, message.meta) + .map_err(ReceiveError::InvalidXxHandshake)?; + fsm.state.handshake = None; + enqueue_handshake(fsm, QlHandshakeRecord::Xx3(outbound)); + } + + Ok(()) +} + +pub fn handle_xx3( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + message: &Xx3, +) -> Result<(), ReceiveError> { + let LinkState::XxResponder(state) = &mut fsm.state.link else { + return Ok(()); + }; + + if message.meta.handshake_id != state.handshake_meta.handshake_id { + return Ok(()); + } + + state + .handshake + .read_3(crypto, message) + .map_err(ReceiveError::InvalidXxHandshake)?; + let handshake_meta = state.handshake_meta; + let LinkState::XxResponder(mut state) = fsm.state.link.take() else { + unreachable!("active XX responder was checked above"); + }; + let outbound = state + .handshake + .write_4(crypto, handshake_meta) + .map_err(ReceiveError::InvalidXxHandshake)?; + fsm.state.handshake = None; + enqueue_handshake(fsm, QlHandshakeRecord::Xx4(outbound)); + let (transport, remote_bundle) = SessionTransport::from_finalized( + state + .handshake + .finalize(crypto) + .map_err(ReceiveError::InvalidXxHandshake)?, + ); + finish_handshake(fsm, transport, remote_bundle) +} + +pub fn handle_xx4( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + message: &Xx4, +) -> Result<(), ReceiveError> { + { + let LinkState::XxInitiator(state) = &mut fsm.state.link else { + return Ok(()); + }; + + if message.meta.handshake_id != state.handshake_id { + return Ok(()); + } + + state + .handshake + .read_4(crypto, message) + .map_err(ReceiveError::InvalidXxHandshake)?; + } + + let LinkState::XxInitiator(state) = fsm.state.link.take() else { + unreachable!("active XX initiator was checked above"); + }; + let (transport, remote_bundle) = SessionTransport::from_finalized( + state + .handshake + .finalize(crypto) + .map_err(ReceiveError::InvalidXxHandshake)?, + ); + finish_handshake(fsm, transport, remote_bundle) +} + +pub fn disarm_pairing(fsm: &mut QlFsm) { + if matches!(fsm.state.link, LinkState::XxResponder(_)) { + fsm.state.link = LinkState::Idle; + fsm.state.handshake = None; + } +} + +pub fn should_ignore_inbound(fsm: &QlFsm, crypto: &impl QlCrypto, message: &Xx1) -> bool { + match &fsm.state.link { + LinkState::Idle | LinkState::Connected(_) => false, + LinkState::IkInitiator(_) | LinkState::KkInitiator(_) | LinkState::XxResponder(_) => true, + LinkState::XxInitiator(state) => { + if state.handshake.pairing_id(crypto) != message.pairing_id { + return false; + } + if message.header.recipient != fsm.identity.qid + || message.header.sender != state.handshake.remote_qid() + { + return false; + } + super::local_start_wins(&state.initial_ephemeral, &message.ephemeral) + } + } +} diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs new file mode 100644 index 00000000..2010e308 --- /dev/null +++ b/ql-fsm/src/lib.rs @@ -0,0 +1,334 @@ +//! sync finite state machine for quantum link protocol +//! +//! a caller drives `QlFsm` inside its own event loop +//! +//! inputs to that loop usually include +//! - app actions like `bind_peer`, `connect_ik`, `connect_kk`, `connect_xx`, `open_stream`, or +//! `stream` +//! - inbound transport bytes passed to `receive` +//! - a deadline expiring, handled by calling `on_timer` +//! - transport write results passed to `complete_write` +//! +//! outputs from `QlFsm` are +//! - outbound session and handshake records from `take_next_write` +//! - queued `QlFsmEvent`s returned by `poll_event` after `connect_ik`, `connect_kk`, +//! `connect_xx`, `receive`, and `on_timer` +//! +//! call `next_deadline` after handling current inputs and any queued outputs +//! use it to decide how long the outer loop can wait before `on_timer` must run +//! another input may arrive before that deadline, which is fine + +mod error; +mod fsm; +mod handshake; +mod pairing; +mod session; +pub(crate) mod state; +#[cfg(test)] +mod tests; + +use std::{ + collections::VecDeque, + time::{Duration, Instant}, +}; + +pub use bytes::Bytes; +pub use error::*; +pub use pairing::PairingInvite; +use ql_wire::{ + PairingToken, PeerBundle, QlCrypto, QlIdentity, RouteId, SessionClose, SessionCloseCode, + StreamClose, StreamId, +}; +pub use session::{SessionEvent, StreamReadIter, StreamWriter}; + +use crate::state::{LinkState, QlFsmState}; + +/// connection state for the bound peer +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PeerStatus { + /// no active encrypted session + Disconnected, + /// we are driving the handshake + Initiator, + /// the encrypted session is up + Connected, + /// the bound peer was forgotten immediately + /// + /// unpair is abortive and best-effort. the binding is removed immediately + /// and one final write may remain: a record containing only `SessionFrame::Unpair` + Unpaired, +} + +/// events emitted by `QlFsm` +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Event { + /// a peer was learned during handshake completion + NewPeer, + /// the peer changed lifecycle state + PeerStatusChanged(PeerStatus), + /// a stream was opened + Opened { + stream_id: StreamId, + route_id: RouteId, + }, + /// a stream has bytes ready to read + Readable(StreamId), + /// a stream has room for more local writes + Writable(StreamId), + /// the peer finished writing this stream and no more bytes remain to read + Finished(StreamId), + /// our local FIN was acknowledged by the peer at the session layer + OutboundFinished(StreamId), + /// a stream was closed + Closed(StreamClose), + /// local writes on this stream are closed + WritableClosed(StreamClose), + /// the encrypted session was closed + /// + /// session close is abortive and best-effort. the session ends immediately + /// one final write remains: a record containing only `SessionFrame::Close` + /// the FSM does not wait for an ack for that record + SessionClosed(SessionClose), +} + +/// handle for a session write returned by `QlFsm::take_next_write` +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct WriteId(pub(crate) u64); + +/// outbound record produced by `QlFsm` +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct OutboundWrite { + /// wire bytes to hand to the transport + pub record: Vec, + /// write handle that must be completed exactly once + pub write_id: Option, +} + +pub struct StreamOps<'a> { + inner: session::StreamOps<'a, fsm::EventSink<'a>>, +} + +impl StreamOps<'_> { + /// returns this stream's identifier + pub fn stream_id(&self) -> StreamId { + self.inner.stream_id() + } + + /// returns the readable stream bytes as owned `Bytes` views without consuming them + pub fn read(&self) -> StreamReadIter<'_> { + self.inner.read() + } + + /// returns how many bytes can be read from the stream + pub fn readable_bytes(&self) -> usize { + self.inner.readable_bytes() + } + + /// marks previously read bytes as consumed + pub fn commit_read(&mut self, len: usize) -> Result<(), CommitReadError> { + self.inner.commit_read(len) + } + + /// returns a writer if the local write side is still open + pub fn writer(&mut self) -> Option> { + self.inner.writer() + } + + /// closes the origin lane, return lane, or both lanes of the stream + pub fn close(&mut self, target: ql_wire::CloseTarget, code: ql_wire::StreamCloseCode) { + self.inner.close(target, code); + } +} + +/// timing and buffering knobs for `QlFsm` +#[derive(Debug, Clone, Copy)] +pub struct QlFsmConfig { + /// overall time limit for one handshake attempt + pub handshake_timeout: Duration, + /// delay before sending a pure record ack + pub session_record_ack_delay: Duration, + /// how long to wait before resending unacked session records + pub session_record_retransmit_timeout: Duration, + /// idle delay before sending a keepalive ping + pub session_keepalive_interval: Duration, + /// how long to wait before declaring the peer dead + pub session_peer_timeout: Duration, + /// maximum total wire size for one session record, including header and auth tag + pub session_record_max_size: usize, + /// maximum bytes buffered locally for one stream send side + pub session_stream_send_buffer_size: usize, + /// maximum bytes buffered locally for one stream receive side + pub session_stream_receive_buffer_size: u32, + /// how many accepted record sequence numbers to retain for duplicate detection + pub session_accepted_record_window: u64, + /// maximum disjoint pending ACK ranges to retain before dropping the oldest low ranges + pub session_pending_ack_range_limit: usize, +} + +impl Default for QlFsmConfig { + fn default() -> Self { + let s = session::SessionConfig::default(); + Self { + handshake_timeout: Duration::from_secs(5), + session_record_ack_delay: s.ack_delay, + session_record_retransmit_timeout: s.retransmit_timeout, + session_keepalive_interval: s.keepalive_interval, + session_peer_timeout: s.peer_timeout, + session_record_max_size: s.record_max_size, + session_stream_send_buffer_size: s.stream_send_buffer_size, + session_stream_receive_buffer_size: s.stream_receive_buffer_size, + session_accepted_record_window: s.accepted_record_window, + session_pending_ack_range_limit: s.pending_ack_range_limit, + } + } +} + +/// synchronous driver for peer binding, handshake, and encrypted streams +pub struct QlFsm { + config: QlFsmConfig, + identity: QlIdentity, + state: QlFsmState, + events: VecDeque, +} + +impl QlFsm { + /// creates a new `QlFsm` + pub fn new(config: QlFsmConfig, identity: QlIdentity, now: Instant) -> Self { + Self { + config, + identity, + state: QlFsmState { + next_control_id: 1, + peer: None, + armed_pairing_token: None, + handshake: None, + link: LinkState::Idle, + now, + }, + events: VecDeque::new(), + } + } + + /// binds the remote peer + pub fn bind_peer(&mut self, peer: PeerBundle) { + fsm::handle_bind_peer(self, peer); + } + + /// returns the currently bound peer, if any + pub fn peer(&self) -> Option<&PeerBundle> { + self.state.peer.as_ref() + } + + /// arms acceptance of inbound xx pairings for a single token + pub fn arm_pairing(&mut self, token: PairingToken) { + self.state.armed_pairing_token = Some(token); + } + + pub fn pairing_token(&self) -> Option<&PairingToken> { + self.state.armed_pairing_token.as_ref() + } + + /// disarms inbound xx pairing and rejects any in-flight inbound xx responder state + pub fn disarm_pairing(&mut self) { + fsm::handle_disarm_pairing(self); + } + + /// starts an outbound xx handshake using a pairing invite + pub fn connect_xx(&mut self, now: Instant, invite: PairingInvite, crypto: &impl QlCrypto) { + self.state.now = now; + fsm::handle_connect_xx(self, invite, crypto); + } + + /// starts an IK handshake with the currently bound peer + pub fn connect_ik(&mut self, now: Instant, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { + self.state.now = now; + fsm::handle_connect_ik(self, crypto) + } + + /// starts a KK handshake with the currently bound peer + pub fn connect_kk(&mut self, now: Instant, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { + self.state.now = now; + fsm::handle_connect_kk(self, crypto) + } + + /// handles one inbound wire message + pub fn receive( + &mut self, + now: Instant, + bytes: Vec, + crypto: &impl QlCrypto, + ) -> Result<(), ReceiveError> { + self.state.now = now; + fsm::receive(self, bytes, crypto) + } + + /// returns the next queued event, if any + pub fn poll_event(&mut self) -> Option { + fsm::poll_event(self) + } + + /// advances time-based state + pub fn on_timer(&mut self, now: Instant) { + self.state.now = now; + fsm::on_timer(self); + } + + /// returns the next timer deadline, if any + pub fn next_deadline(&self) -> Option { + fsm::next_deadline(self) + } + + pub fn has_shutdown_work(&self) -> bool { + self.state + .link + .connected() + .is_some_and(|state| state.session.has_shutdown_work()) + } + + /// returns the next outbound record + /// + /// if `write_id` is `Some`, call `complete_write` exactly once + /// + /// if it is `None`, the record is fire-and-forget + pub fn take_next_write( + &mut self, + now: Instant, + crypto: &impl QlCrypto, + ) -> Option { + self.state.now = now; + fsm::take_next_write(self, crypto) + } + + /// completes a `SessionWriteId` from `take_next_write` with the transport outcome + /// + /// call this at most once for each returned `SessionWriteId` + pub fn complete_write(&mut self, now: Instant, write_id: WriteId, success: bool) { + self.state.now = now; + fsm::complete_write(self, write_id, success); + } + + /// closes the current encrypted session locally + pub fn close_session(&mut self, code: SessionCloseCode) { + fsm::close_session(self, code); + } + + /// forgets the bound peer locally and may emit one final outbound `SessionFrame::Unpair` + pub fn unpair(&mut self) { + fsm::unpair(self); + } + + /// opens a new outgoing stream + pub fn open_stream(&mut self, route_id: RouteId) -> Result, NoSessionError> { + fsm::open_stream(self, route_id) + } + + /// returns a facade for an open stream + pub fn stream(&mut self, stream_id: StreamId) -> Result, StreamError> { + fsm::stream(self, stream_id) + } + + /// queues a ping on the active session + pub fn queue_ping(&mut self) -> Result<(), NoSessionError> { + fsm::queue_ping(self) + } +} diff --git a/ql-fsm/src/pairing.rs b/ql-fsm/src/pairing.rs new file mode 100644 index 00000000..4b8361b8 --- /dev/null +++ b/ql-fsm/src/pairing.rs @@ -0,0 +1,38 @@ +use ql_wire::{ByteSlice, PairingToken, Reader, WireDecode, WireEncode, WireError, QID}; + +/// Out-of-band invite consumed by the initiator of an XX pairing +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct PairingInvite { + pub qid: QID, + pub token: PairingToken, +} + +impl PairingInvite { + pub const VERSION: u8 = 1; + pub const WIRE_SIZE: usize = size_of::() + QID::SIZE + PairingToken::SIZE; +} + +impl WireEncode for PairingInvite { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + Self::VERSION.encode(out); + self.qid.encode(out); + self.token.encode(out); + } +} + +impl WireDecode for PairingInvite { + fn decode(reader: &mut Reader) -> Result { + if reader.decode::()? != Self::VERSION { + return Err(WireError::InvalidPayload); + } + + Ok(Self { + qid: reader.decode()?, + token: reader.decode()?, + }) + } +} diff --git a/ql-fsm/src/session/ack_tracker.rs b/ql-fsm/src/session/ack_tracker.rs new file mode 100644 index 00000000..a75b5c63 --- /dev/null +++ b/ql-fsm/src/session/ack_tracker.rs @@ -0,0 +1,266 @@ +use std::{ops::RangeInclusive, time::Instant}; + +use ql_wire::{RecordAck, RecordAckBuilder, RecordSeq}; + +use super::range_set::RangeSet; + +#[derive(Debug, Clone)] +pub struct AckTracker { + accepted_records: RangeSet, + pending_ack: RangeSet, + ack_state: AckState, + accepted_record_window: u64, + pending_ack_range_limit: usize, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PendingAck { + pub ack: RecordAck, + pub due_at: Instant, + pub includes_all_pending: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ReceiveOutcome { + New, + Duplicate, + TooOld, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum AckState { + Idle, + Dirty { due_at: Instant }, +} + +impl AckTracker { + pub fn new(accepted_record_window: u64, pending_ack_range_limit: usize) -> Self { + Self { + accepted_records: RangeSet::new(), + pending_ack: RangeSet::new(), + ack_state: AckState::Idle, + accepted_record_window: accepted_record_window.max(1), + pending_ack_range_limit: pending_ack_range_limit.max(1), + } + } + + pub fn insert(&mut self, seq: RecordSeq) -> ReceiveOutcome { + let seq = seq.into_inner(); + let largest_accepted = self.accepted_records.max(); + if largest_accepted.is_some_and(|largest| seq < self.accepted_cutoff(largest)) { + return ReceiveOutcome::TooOld; + } + if self.accepted_records.contains(seq) { + self.pending_ack.insert(single_range(seq)); + self.trim_pending_ack_ranges(); + return ReceiveOutcome::Duplicate; + } + + self.accepted_records.insert(single_range(seq)); + self.trim_accepted_records(); + + self.pending_ack.insert(single_range(seq)); + self.trim_pending_ack_ranges(); + + ReceiveOutcome::New + } + + pub fn ack_deadline(&self) -> Option { + match self.ack_state { + AckState::Idle => None, + AckState::Dirty { due_at } => Some(due_at), + } + } + + pub fn schedule_ack(&mut self, due_at: Instant) { + self.ack_state = match self.ack_state { + AckState::Dirty { due_at: old } => AckState::Dirty { + due_at: due_at.min(old), + }, + AckState::Idle => AckState::Dirty { due_at }, + }; + } + + pub fn pending_ack(&self, max_wire_size: usize) -> Option { + let due_at = self.ack_deadline()?; + if max_wire_size == 0 || self.pending_ack.range_count() == 0 { + return None; + } + + let total_range_count = self.pending_ack.range_count(); + let mut ack = RecordAckBuilder::new(); + let mut selected_range_count = 0usize; + + for range in self.pending_ack.iter_rev() { + let pushed = ack + .try_push_range(to_ack_range(range), max_wire_size) + .unwrap(); + if !pushed { + break; + } + selected_range_count += 1; + } + + (selected_range_count != 0).then(|| PendingAck { + ack: ack.build().unwrap(), + due_at, + includes_all_pending: total_range_count == selected_range_count, + }) + } + + pub fn on_ack_emitted(&mut self, pending_ack: &PendingAck) { + self.retire_acked_ranges(&pending_ack.ack); + if pending_ack.includes_all_pending || self.pending_ack.range_count() == 0 { + self.ack_state = AckState::Idle; + } + } + + pub fn retire_acked_ranges(&mut self, ack: &RecordAck) { + for range in ack.ranges() { + self.pending_ack.remove(from_ack_range(range)); + } + if self.pending_ack.range_count() == 0 { + self.ack_state = AckState::Idle; + } + } + + pub fn clear_ack_state(&mut self) { + self.ack_state = AckState::Idle; + } + + pub fn restore_acked_ranges(&mut self, ack: &RecordAck, due_at: Instant) { + for range in ack.ranges() { + self.pending_ack.insert(from_ack_range(range)); + } + self.trim_pending_ack_ranges(); + self.schedule_ack(due_at); + } + + fn accepted_cutoff(&self, largest_accepted: u64) -> u64 { + largest_accepted + .saturating_add(1) + .saturating_sub(self.accepted_record_window) + } + + fn trim_accepted_records(&mut self) { + let Some(largest_accepted) = self.accepted_records.max() else { + return; + }; + let cutoff = self.accepted_cutoff(largest_accepted); + self.accepted_records.remove(0..cutoff); + } + + fn trim_pending_ack_ranges(&mut self) { + while self.pending_ack.range_count() > self.pending_ack_range_limit { + self.pending_ack.pop_min(); + } + } +} + +fn single_range(seq: u64) -> std::ops::Range { + seq..seq.checked_add(1).unwrap() +} + +fn to_ack_range(range: std::ops::Range) -> RangeInclusive { + let end = range.end.checked_sub(1).unwrap(); + RecordSeq::from_u64(range.start).unwrap()..=RecordSeq::from_u64(end).unwrap() +} + +fn from_ack_range(range: RangeInclusive) -> std::ops::Range { + let start = range.start().into_inner(); + let end = range.end().into_inner().checked_add(1).unwrap(); + start..end +} + +#[cfg(test)] +mod tests { + use std::time::{Duration, Instant}; + + use ql_wire::RecordSeq; + + use super::{AckTracker, PendingAck, ReceiveOutcome}; + + fn seq(value: u64) -> RecordSeq { + RecordSeq::from_u64(value).unwrap() + } + + fn ack_ranges(pending_ack: &PendingAck) -> Vec<(u64, u64)> { + pending_ack + .ack + .ranges() + .map(|range| (range.start().into_inner(), range.end().into_inner())) + .collect() + } + + #[test] + fn contiguous_records_emit_one_ack_range() { + let now = Instant::now(); + let mut ack_tracker = AckTracker::new(128, 8); + + assert_eq!(ack_tracker.insert(seq(10)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(11)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(12)), ReceiveOutcome::New); + + ack_tracker.schedule_ack(now); + let pending_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); + assert_eq!(ack_ranges(&pending_ack), vec![(10, 12)]); + } + + #[test] + fn sparse_records_emit_descending_ack_ranges() { + let now = Instant::now(); + let mut ack_tracker = AckTracker::new(128, 8); + + assert_eq!(ack_tracker.insert(seq(10)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(15)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(16)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(12)), ReceiveOutcome::New); + + ack_tracker.schedule_ack(now + Duration::from_millis(5)); + let pending_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); + assert_eq!(ack_ranges(&pending_ack), vec![(15, 16), (12, 12), (10, 10)]); + } + + #[test] + fn accepted_record_window_evicts_old_sequences() { + let mut ack_tracker = AckTracker::new(4, 8); + + assert_eq!(ack_tracker.insert(seq(10)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(15)), ReceiveOutcome::New); + + assert_eq!(ack_tracker.insert(seq(10)), ReceiveOutcome::TooOld); + } + + #[test] + fn pending_ack_range_limit_drops_oldest_low_ranges() { + let now = Instant::now(); + let mut ack_tracker = AckTracker::new(128, 2); + + assert_eq!(ack_tracker.insert(seq(1)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(3)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(5)), ReceiveOutcome::New); + + ack_tracker.schedule_ack(now); + let pending_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); + assert_eq!(ack_ranges(&pending_ack), vec![(5, 5), (3, 3)]); + } + + #[test] + fn retire_acked_ranges_removes_only_exact_snapshot() { + let now = Instant::now(); + let mut ack_tracker = AckTracker::new(128, 8); + + assert_eq!(ack_tracker.insert(seq(1)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(3)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(5)), ReceiveOutcome::New); + ack_tracker.schedule_ack(now); + + let first_ack = ack_tracker.pending_ack(4).unwrap(); + assert_eq!(ack_ranges(&first_ack), vec![(5, 5)]); + ack_tracker.on_ack_emitted(&first_ack); + ack_tracker.retire_acked_ranges(&first_ack.ack); + + let second_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); + assert_eq!(ack_ranges(&second_ack), vec![(3, 3), (1, 1)]); + } +} diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs new file mode 100644 index 00000000..55187757 --- /dev/null +++ b/ql-fsm/src/session/mod.rs @@ -0,0 +1,1041 @@ +pub use self::{state::TerminalFrame, stream_ops::*, stream_parity::*, stream_rx::*}; + +mod ack_tracker; +mod range_set; +mod remote_stream_history; +mod state; +mod stream_ops; +mod stream_parity; +mod stream_rx; +mod stream_tx; +mod tracked; + +#[cfg(test)] +mod tests; + +use std::time::{Duration, Instant}; + +use bytes::Bytes; +use indexmap::IndexMap; +use ql_wire::{ + CloseTarget, RecordAck, RecordSeq, RouteId, SessionClose, SessionCloseCode, SessionFrame, + SessionRecordBuilder, StreamClose, StreamData, StreamHeader, StreamId, StreamWindow, VarInt, + WireError, +}; + +use self::{ + ack_tracker::{AckTracker, PendingAck, ReceiveOutcome}, + remote_stream_history::RemoteStreamHistory, + state::{InboundState, OutboundState, SessionPhase, SessionState, StreamRole, StreamState}, + stream_tx::StreamTxRange, + tracked::{TrackedFrame, TrackedRecord, TrackedStreamData}, +}; +use crate::{NoSessionError, StreamError}; + +#[derive(Debug, Clone, Copy)] +pub struct SessionConfig { + pub local_parity: StreamParity, + pub record_max_size: usize, + pub ack_delay: Duration, + pub retransmit_timeout: Duration, + pub keepalive_interval: Duration, + pub peer_timeout: Duration, + pub stream_send_buffer_size: usize, + pub stream_receive_buffer_size: u32, + pub initial_peer_stream_receive_window: u32, + pub accepted_record_window: u64, + pub pending_ack_range_limit: usize, +} + +impl Default for SessionConfig { + fn default() -> Self { + Self { + local_parity: StreamParity::Even, + record_max_size: 8 * 1024, + ack_delay: Duration::from_millis(5), + retransmit_timeout: Duration::from_millis(150), + keepalive_interval: Duration::from_secs(10), + peer_timeout: Duration::from_secs(30), + stream_send_buffer_size: 16 * 1024, + stream_receive_buffer_size: 16 * 1024, + initial_peer_stream_receive_window: 16 * 1024, + accepted_record_window: 4096, + pending_ack_range_limit: 64, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SessionEvent { + Opened { + stream_id: StreamId, + route_id: RouteId, + }, + Readable(StreamId), + Writable(StreamId), + Finished(StreamId), + OutboundFinished(StreamId), + Closed(StreamClose), + WritableClosed(StreamClose), + SessionClosed(SessionClose), + Unpaired, +} + +pub trait EventSink { + fn emit(&mut self, event: SessionEvent); +} + +impl EventSink for F +where + F: FnMut(SessionEvent), +{ + fn emit(&mut self, event: SessionEvent) { + self(event); + } +} + +pub struct SessionFsm { + config: SessionConfig, + state: SessionState, +} + +impl SessionFsm { + pub fn new(mut config: SessionConfig, now: Instant) -> Self { + config.record_max_size = config + .record_max_size + .max(SessionRecordBuilder::MIN_CAPACITY); + config.stream_send_buffer_size = config.stream_send_buffer_size.max(1); + config.stream_receive_buffer_size = config.stream_receive_buffer_size.max(1); + config.accepted_record_window = config.accepted_record_window.max(1); + config.pending_ack_range_limit = config.pending_ack_range_limit.max(1); + Self { + config, + state: SessionState { + last_activity_at: now, + last_inbound_at: now, + phase: SessionPhase::Open, + next_stream_ordinal: 0, + next_record_seq: RecordSeq::from_u32(0), + next_write_id: 0, + tracked_records: IndexMap::default(), + ack_tracker: AckTracker::new( + config.accepted_record_window, + config.pending_ack_range_limit, + ), + pending_ping: false, + streams: IndexMap::default(), + next_stream_index: 0, + remote_stream_history: RemoteStreamHistory::new(config.local_parity.remote()), + }, + } + } + + pub fn open_stream( + &mut self, + route_id: RouteId, + sink: E, + ) -> Result, NoSessionError> + where + E: EventSink, + { + self.ensure_session_open()?; + let stream_id = self + .config + .local_parity + .make_stream_id(self.state.next_stream_ordinal); + self.state.next_stream_ordinal = self.state.next_stream_ordinal.saturating_add(1); + self.state.streams.insert( + stream_id, + StreamState::new( + StreamRole::Initiator, + Some(route_id), + self.config.stream_receive_buffer_size, + self.config.initial_peer_stream_receive_window, + ), + ); + let stream_index = self.state.streams.len() - 1; + Ok(StreamOps::new(self, stream_id, stream_index, sink)) + } + + pub fn stream( + &mut self, + stream_id: StreamId, + sink: E, + ) -> Result, StreamError> + where + E: EventSink, + { + self.ensure_session_open()?; + let Some(stream_index) = self.state.streams.get_index_of(&stream_id) else { + return Err(StreamError::MissingStream); + }; + + Ok(StreamOps::new(self, stream_id, stream_index, sink)) + } + + pub fn queue_ping(&mut self) -> Result<(), NoSessionError> { + self.ensure_session_open()?; + self.state.pending_ping = true; + Ok(()) + } + + pub fn close(&mut self, code: SessionCloseCode, sink: &mut impl EventSink) { + if self.state.phase != SessionPhase::Open { + return; + } + + self.begin_termination(TerminalFrame::Close(SessionClose { code }), sink); + } + + pub fn unpair(&mut self, sink: &mut impl EventSink) { + if self.state.phase != SessionPhase::Open { + return; + } + + self.begin_termination(TerminalFrame::Unpair, sink); + } + + pub fn is_closed(&self) -> bool { + self.state.phase == SessionPhase::Closed + } + + pub fn receive(&mut self, now: Instant, seq: RecordSeq, frames: I, sink: &mut impl EventSink) + where + I: IntoIterator, WireError>>, + { + if self.state.phase != SessionPhase::Open { + return; + } + + self.state.last_activity_at = now; + self.state.last_inbound_at = now; + self.collect_timeouts(now); + + match self.state.ack_tracker.insert(seq) { + ReceiveOutcome::TooOld => return, + ReceiveOutcome::Duplicate => { + self.schedule_ack(now, true); + return; + } + ReceiveOutcome::New => {} + } + + let mut ack_eliciting = false; + + for frame in frames { + let Ok(frame) = frame else { + self.close(SessionCloseCode::PROTOCOL, sink); + return; + }; + ack_eliciting |= !matches!(frame, SessionFrame::Ack(_)); + match frame { + SessionFrame::Ping => {} + SessionFrame::Unpair => { + self.unpair(sink); + return; + } + SessionFrame::Ack(ack) => self.process_record_ack(&ack, sink), + SessionFrame::StreamData(frame) => { + if self.handle_stream_data(frame, sink).is_err() { + self.close(SessionCloseCode::PROTOCOL, sink); + return; + } + } + SessionFrame::StreamWindow(frame) => self.handle_stream_window(&frame, sink), + SessionFrame::StreamClose(frame) => { + if self.handle_stream_close(&frame, sink).is_err() { + self.close(SessionCloseCode::PROTOCOL, sink); + return; + } + } + SessionFrame::Close(close) => { + self.close(close.code, sink); + return; + } + } + } + + if ack_eliciting { + self.schedule_ack(now, false); + } + } + + pub fn complete_write(&mut self, now: Instant, write_id: u64, success: bool) { + if !self.state.phase.is_open() { + return; + } + if success { + let Some(record) = self.state.tracked_records.get_mut(&write_id) else { + return; + }; + if record.sent_at.is_some() { + return; + } + self.state.last_activity_at = now; + record.sent_at = Some(now); + } else { + if self + .state + .tracked_records + .get(&write_id) + .is_some_and(|record| record.sent_at.is_some()) + { + return; + } + let Some(record) = self.state.tracked_records.shift_remove(&write_id) else { + return; + }; + restore_tracked_record( + now, + &mut self.state.ack_tracker, + &mut self.state.pending_ping, + &mut self.state.streams, + record, + ); + } + } + + pub fn on_timer(&mut self, now: Instant, sink: &mut impl EventSink) { + if !self.state.phase.is_open() { + return; + } + self.collect_timeouts(now); + if !self.config.peer_timeout.is_zero() + && self.state.last_inbound_at + self.config.peer_timeout <= now + { + self.close(SessionCloseCode::TIMEOUT, sink); + return; + } + if self.state.phase == SessionPhase::Open + && !self.config.keepalive_interval.is_zero() + && self.state.last_activity_at + self.config.keepalive_interval <= now + { + self.state.pending_ping = true; + } + } + + pub fn next_deadline(&self) -> Option { + if !self.state.phase.is_open() { + return None; + } + let ack_deadline = self.state.ack_tracker.ack_deadline(); + let retransmit_deadline = self + .state + .tracked_records + .values() + .filter_map(|record| { + record + .sent_at + .map(|sent_at| sent_at + self.config.retransmit_timeout) + }) + .min(); + let is_open = self.state.phase.is_open(); + let keepalive_deadline = + (is_open && !self.config.keepalive_interval.is_zero() && !self.state.pending_ping) + .then_some(self.state.last_activity_at + self.config.keepalive_interval); + let peer_timeout_deadline = (is_open && !self.config.peer_timeout.is_zero()) + .then_some(self.state.last_inbound_at + self.config.peer_timeout); + [ + ack_deadline, + retransmit_deadline, + keepalive_deadline, + peer_timeout_deadline, + ] + .into_iter() + .flatten() + .min() + } + + pub fn has_shutdown_work(&self) -> bool { + matches!(self.state.phase, SessionPhase::Terminating(_)) + || self.state.ack_tracker.ack_deadline().is_some() + || !self.state.tracked_records.is_empty() + } + + pub fn take_next_write(&mut self, now: Instant) -> Option<(Option, SessionRecordBuilder)> { + match &self.state.phase { + SessionPhase::Terminating(frame) => { + let seq = self.state.next_record_seq; + next_seq(&mut self.state.next_record_seq); + let mut builder = SessionRecordBuilder::new(seq, self.config.record_max_size); + match frame { + TerminalFrame::Close(close) => { + assert!(builder.push_close(close), "builder has capacity"); + } + TerminalFrame::Unpair => { + assert!(builder.push_unpair(), "builder has capacity"); + } + } + self.state.phase = SessionPhase::Closed; + return Some((None, builder)); + } + SessionPhase::Closed => { + return None; + } + SessionPhase::Open => {} + } + self.collect_timeouts(now); + + let (builder, outbound) = self.build_next_record(now)?; + + let should_track = outbound.ping_included + || !outbound.window_updates.is_empty() + || !outbound.frames.is_empty(); + let write_id = should_track.then(|| { + let write_id = self.state.next_write_id; + self.state.next_write_id = self.state.next_write_id.wrapping_add(1); + self.state.tracked_records.insert(write_id, outbound); + write_id + }); + + Some((write_id, builder)) + } + + fn build_next_record(&mut self, now: Instant) -> Option<(SessionRecordBuilder, TrackedRecord)> { + let seq = self.state.next_record_seq; + let mut builder = SessionRecordBuilder::new(seq, self.config.record_max_size); + let mut outbound = TrackedRecord { + seq, + frames: Vec::new(), + ack: None, + ping_included: false, + window_updates: Vec::new(), + sent_at: None, + }; + + self.push_next_pending_stream_close(&mut builder, &mut outbound); + + if self.state.pending_ping && builder.push_ping() { + self.state.pending_ping = false; + outbound.ping_included = true; + } + + self.push_next_pending_stream_window(&mut builder, &mut outbound); + + self.push_next_stream_data(&mut builder, &mut outbound); + + if let Some(pending_ack) = self.pending_ack(builder.remaining_capacity()) { + if (!builder.is_empty() || pending_ack.due_at <= now) + && builder.push_ack(&pending_ack.ack) + { + self.state.ack_tracker.on_ack_emitted(&pending_ack); + outbound.ack = Some(pending_ack.ack); + } + } + + if builder.is_empty() { + return None; + } + + next_seq(&mut self.state.next_record_seq); + Some((builder, outbound)) + } + + fn begin_termination(&mut self, frame: TerminalFrame, sink: &mut impl EventSink) { + match &frame { + TerminalFrame::Close(close) => sink.emit(SessionEvent::SessionClosed(close.clone())), + TerminalFrame::Unpair => sink.emit(SessionEvent::Unpaired), + } + + self.state.phase = SessionPhase::Terminating(frame); + self.state.tracked_records.clear(); + self.state.ack_tracker.clear_ack_state(); + self.clear_streams(); + } + + fn push_next_pending_stream_close( + &mut self, + builder: &mut SessionRecordBuilder, + outbound: &mut TrackedRecord, + ) { + let len = self.state.streams.len(); + if len == 0 { + return; + } + + let start = self.state.next_stream_index % len; + for offset in 0..len { + let index = (start + offset) % len; + let stream = self.state.streams.get_index_mut(index).unwrap().1; + let Some(close) = stream.pending_close.as_ref() else { + continue; + }; + if !builder.push_stream_close(close) { + break; + } + + outbound.frames.push(TrackedFrame::StreamClose( + stream.pending_close.take().unwrap(), + )); + } + } + + fn push_next_pending_stream_window( + &mut self, + builder: &mut SessionRecordBuilder, + outbound: &mut TrackedRecord, + ) { + let len = self.state.streams.len(); + if len == 0 { + return; + } + + let start = self.state.next_stream_index % len; + for offset in 0..len { + let index = (start + offset) % len; + let (&stream_id, stream) = self.state.streams.get_index_mut(index).unwrap(); + if !stream.pending_window { + continue; + } + let frame = StreamWindow { + stream_id, + maximum_offset: VarInt::from_u64(stream.recv_limit()).unwrap(), + }; + if !builder.push_stream_window(&frame) { + break; + } + + stream.pending_window = false; + stream.advertised_max_offset = frame.maximum_offset.into_inner(); + outbound + .window_updates + .push((stream_id, frame.maximum_offset.into_inner())); + } + } + + fn push_next_stream_data( + &mut self, + builder: &mut SessionRecordBuilder, + outbound: &mut TrackedRecord, + ) { + const OVERHEAD: usize = 1 + StreamData::>::MIN_WIRE_SIZE; + + let len = self.state.streams.len(); + if len == 0 { + return; + } + + let start = self.state.next_stream_index % len; + let mut next_index = start; + + for offset in 0..len { + let Some(max_payload) = builder.remaining_capacity().checked_sub(OVERHEAD) else { + break; + }; + + let index = (start + offset) % len; + let (&stream_id, stream) = self.state.streams.get_index_mut(index).unwrap(); + if matches!(stream.outbound_state, OutboundState::Closed) { + continue; + } + let Some(candidate) = stream.tx.poll_transmit(max_payload, stream.peer_max_offset) + else { + continue; + }; + let offset = + VarInt::from_u64(candidate.offset).expect("stream offsets must fit ql-wire varint"); + let frame = StreamData { + stream_id, + offset, + header: if matches!(stream.role, StreamRole::Initiator) && candidate.offset == 0 { + stream.route_id.map(|route_id| StreamHeader { route_id }) + } else { + None + }, + fin: candidate.fin, + bytes: stream.tx.ranged_bytes(candidate), + }; + let res = builder.push_stream_data(&frame); + assert!(res, "builder has capacity"); + + if candidate.fin { + stream.outbound_state = OutboundState::Finished; + } + outbound + .frames + .push(TrackedFrame::StreamData(TrackedStreamData { + stream_id, + offset: candidate.offset, + len: candidate.len, + fin: candidate.fin, + })); + next_index = (index + 1) % len; + } + + self.state.next_stream_index = next_index; + } + + fn ensure_session_open(&self) -> Result<(), NoSessionError> { + if self.state.phase == SessionPhase::Open { + Ok(()) + } else { + Err(NoSessionError) + } + } + + fn process_record_ack(&mut self, ack: &RecordAck, sink: &mut impl EventSink) { + let stream_send_buffer_size = self.config.stream_send_buffer_size; + let acked_records = self + .state + .tracked_records + .extract_if(.., |_, record| { + record.sent_at.is_some() && ack.contains(record.seq.into_inner()) + }) + .map(|(_, record)| record) + .collect::>(); + + for record in acked_records { + for frame in &record.frames { + acknowledge_tracked_frame( + &mut self.state.streams, + stream_send_buffer_size, + frame, + sink, + ); + } + } + self.reap_reapable_streams(); + } + + fn schedule_ack(&mut self, now: Instant, immediate: bool) { + self.state.ack_tracker.schedule_ack(if immediate { + now + } else { + now + self.config.ack_delay + }); + } + + fn pending_ack(&self, remaining_capacity: usize) -> Option { + let max_ack_wire_size = remaining_capacity.checked_sub(1)?; + self.state.ack_tracker.pending_ack(max_ack_wire_size) + } + + fn collect_timeouts(&mut self, now: Instant) { + let retransmit_timeout = self.config.retransmit_timeout; + for (_, record) in self.state.tracked_records.extract_if(.., |_, record| { + record + .sent_at + .is_some_and(|sent_at| sent_at + retransmit_timeout <= now) + }) { + restore_tracked_record( + now, + &mut self.state.ack_tracker, + &mut self.state.pending_ping, + &mut self.state.streams, + record, + ); + } + } + + fn handle_stream_data( + &mut self, + frame: StreamData, + sink: &mut impl EventSink, + ) -> Result<(), ()> { + let StreamData { + stream_id, + offset, + header, + fin, + bytes, + } = frame; + let stream = match self.state.streams.get_mut(&stream_id) { + Some(stream) => stream, + None => match self.create_remote_stream(stream_id)? { + Some(stream) => stream, + None => return Ok(()), + }, + }; + + let frame_offset = offset.into_inner(); + let Some(frame_end) = frame_offset.checked_add(bytes.len() as u64) else { + return Err(()); + }; + let readable_before = stream.readable_bytes(); + let was_finished = matches!(stream.inbound_state, InboundState::Finished); + + let opened_route = match (stream.role, stream.route_id, header, frame_offset) { + (StreamRole::Responder, None, Some(header), 0) => { + stream.route_id = Some(header.route_id); + Some(header.route_id) + } + (StreamRole::Initiator, _, Some(_), _) + | (StreamRole::Responder, None, Some(_), _) + | (StreamRole::Responder, None, None, 0) => return Err(()), + _ => None, + }; + + match stream.inbound_state { + InboundState::Open => {} + InboundState::Discarding | InboundState::Closed(_) => return Ok(()), + InboundState::Finished => { + // finished stream should always have a final offset + let Some(final_offset) = stream.rx.final_offset() else { + debug_assert!(false, "finished stream must retain final offset"); + return Ok(()); + }; + + // retransmitted data for an already-finished stream is fine as long as it stays + // within the finalized byte range and any repeated FIN lands on that same offset. + if (!frame.fin || frame_end == final_offset) && frame_end <= final_offset { + if let Some(route_id) = opened_route { + sink.emit(SessionEvent::Opened { + stream_id, + route_id, + }); + if readable_before > 0 { + sink.emit(SessionEvent::Readable(stream_id)); + } else { + sink.emit(SessionEvent::Finished(stream_id)); + } + } + return Ok(()); + } + + return Err(()); + } + } + + let outcome = stream.rx.insert(frame_offset, fin, bytes).map_err(|_| ())?; + + if outcome.became_complete { + stream.inbound_state = InboundState::Finished; + } + + if let Some(route_id) = opened_route { + sink.emit(SessionEvent::Opened { + stream_id, + route_id, + }); + } + + if stream.route_id.is_some() && readable_before == 0 && stream.readable_bytes() > 0 { + sink.emit(SessionEvent::Readable(stream_id)); + } + + if stream.route_id.is_some() + && !was_finished + && matches!(stream.inbound_state, InboundState::Finished) + && stream.readable_bytes() == 0 + { + sink.emit(SessionEvent::Finished(stream_id)); + } + + self.try_reap_stream(stream_id); + Ok(()) + } + + fn handle_stream_window(&mut self, frame: &StreamWindow, sink: &mut impl EventSink) { + let Some(stream) = self.state.streams.get_mut(&frame.stream_id) else { + return; + }; + + let was_full = stream.send_capacity(self.config.stream_send_buffer_size) == 0; + let maximum_offset = frame.maximum_offset.into_inner(); + if maximum_offset > stream.peer_max_offset { + stream.peer_max_offset = maximum_offset; + } + if was_full && stream.send_capacity(self.config.stream_send_buffer_size) > 0 { + sink.emit(SessionEvent::Writable(frame.stream_id)); + } + } + + fn handle_stream_close( + &mut self, + frame: &StreamClose, + sink: &mut impl EventSink, + ) -> Result<(), ()> { + let stream_id = frame.stream_id; + let stream = match self.state.streams.get_mut(&stream_id) { + Some(stream) => stream, + None => match self.create_remote_stream(stream_id)? { + Some(stream) => stream, + None => return Ok(()), + }, + }; + + if Self::target_affects_inbound(stream.role, frame.target) + && !matches!( + stream.inbound_state, + InboundState::Closed(_) | InboundState::Discarding + ) + { + stream.inbound_state = InboundState::Closed(frame.clone()); + stream.reset_recv(); + sink.emit(SessionEvent::Closed(frame.clone())); + } + if Self::target_affects_outbound(stream.role, frame.target) + && !matches!(stream.outbound_state, OutboundState::Closed) + { + stream.outbound_state = OutboundState::Closed; + stream.tx.clear(); + stream.pending_close = None; + sink.emit(SessionEvent::WritableClosed(frame.clone())); + } + self.try_reap_stream(frame.stream_id); + Ok(()) + } + + fn apply_local_close_to_stream(stream: &mut StreamState, target: CloseTarget) { + if Self::target_affects_inbound(stream.role, target) { + stream.inbound_state = InboundState::Discarding; + stream.reset_recv(); + } + if Self::target_affects_outbound(stream.role, target) { + stream.outbound_state = OutboundState::Closed; + stream.tx.clear(); + } + } + + fn target_affects_inbound(role: StreamRole, target: CloseTarget) -> bool { + matches!(target, CloseTarget::Both) || role.inbound_target() == target + } + + fn target_affects_outbound(role: StreamRole, target: CloseTarget) -> bool { + matches!(target, CloseTarget::Both) || role.outbound_target() == target + } + + fn stream_is_reapable(&self, stream_id: StreamId, stream: &StreamState) -> bool { + let tracked_refs_stream = self.state.tracked_records.values().any(|record| { + record.window_updates.iter().any(|(id, _)| *id == stream_id) + || record.frames.iter().any(|frame| match frame { + TrackedFrame::StreamData(frame) => frame.stream_id == stream_id, + TrackedFrame::StreamClose(frame) => frame.stream_id == stream_id, + }) + }); + if tracked_refs_stream { + return false; + } + + if !stream.tx.is_empty() + || stream.pending_close.is_some() + || stream.pending_window + || stream.readable_bytes() > 0 + || stream.rx.buffered_end_offset() > stream.rx.start_offset() + { + return false; + } + + matches!( + stream.inbound_state, + InboundState::Finished | InboundState::Closed(_) | InboundState::Discarding + ) && matches!( + stream.outbound_state, + OutboundState::Finished | OutboundState::Closed + ) + } + + fn reap_reapable_streams(&mut self) { + let mut index = 0usize; + while index < self.state.streams.len() { + let stream_id = *self.state.streams.get_index(index).unwrap().0; + let len_before = self.state.streams.len(); + self.try_reap_stream(stream_id); + if self.state.streams.len() == len_before { + index += 1; + } + } + } + + fn try_reap_stream(&mut self, stream_id: StreamId) { + let Some(index) = self.state.streams.get_index_of(&stream_id) else { + return; + }; + self.try_reap_stream_at(stream_id, index); + } + + fn try_reap_stream_at(&mut self, stream_id: StreamId, index: usize) { + let Some((indexed_stream_id, stream)) = self.state.streams.get_index(index) else { + return; + }; + debug_assert_eq!(*indexed_stream_id, stream_id); + if !self.stream_is_reapable(stream_id, stream) { + return; + } + self.reap_stream_at(index); + } + + fn reap_stream_at(&mut self, index: usize) { + self.state.streams.shift_remove_index(index); + + if self.state.streams.is_empty() { + self.state.next_stream_index = 0; + return; + } + if index < self.state.next_stream_index { + self.state.next_stream_index -= 1; + } + if self.state.next_stream_index >= self.state.streams.len() { + self.state.next_stream_index %= self.state.streams.len(); + } + } + + fn clear_streams(&mut self) { + self.state.next_stream_index = 0; + self.state.streams.clear(); + } + + fn create_remote_stream( + &mut self, + stream_id: StreamId, + ) -> Result, ()> { + match classify_missing_stream( + self.config.local_parity, + self.state.next_stream_ordinal, + stream_id, + &mut self.state.remote_stream_history, + ) { + MissingStreamAction::Create => {} + MissingStreamAction::Ignore => return Ok(None), + MissingStreamAction::FailProtocol => { + return Err(()); + } + } + + let stream = self + .state + .streams + .entry(stream_id) + .insert_entry(StreamState::new( + StreamRole::Responder, + None, + self.config.stream_receive_buffer_size, + self.config.initial_peer_stream_receive_window, + )); + + Ok(Some(stream.into_mut())) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum MissingStreamAction { + Create, + Ignore, + FailProtocol, +} + +fn classify_missing_stream( + local_parity: StreamParity, + next_stream_ordinal: u32, + stream_id: StreamId, + remote_stream_history: &mut RemoteStreamHistory, +) -> MissingStreamAction { + if !local_parity.remote().matches(stream_id) { + return if local_stream_was_opened(local_parity, next_stream_ordinal, stream_id) { + MissingStreamAction::Ignore + } else { + MissingStreamAction::FailProtocol + }; + } + + if remote_stream_history.observe(stream_id) { + MissingStreamAction::Ignore + } else { + MissingStreamAction::Create + } +} + +fn local_stream_was_opened( + local_parity: StreamParity, + next_stream_ordinal: u32, + stream_id: StreamId, +) -> bool { + local_parity.matches(stream_id) + && stream_id.into_inner() + < local_parity + .make_stream_id(next_stream_ordinal) + .into_inner() +} + +fn restore_tracked_record( + now: Instant, + ack_tracker: &mut AckTracker, + pending_ping: &mut bool, + streams: &mut IndexMap, + record: TrackedRecord, +) { + if let Some(ack) = &record.ack { + ack_tracker.restore_acked_ranges(ack, now); + } + if record.ping_included { + *pending_ping = true; + } + for (stream_id, maximum_offset) in record.window_updates { + if let Some(stream) = streams.get_mut(&stream_id) { + if stream.recv_limit() >= maximum_offset { + stream.pending_window = true; + } + } + } + for frame in record.frames { + requeue_tracked_frame(streams, frame); + } +} + +fn requeue_tracked_frame(streams: &mut IndexMap, frame: TrackedFrame) { + match frame { + TrackedFrame::StreamClose(close) => restore_stream_close(streams, close), + TrackedFrame::StreamData(frame) => restore_stream_data(streams, frame), + } +} + +fn restore_stream_close(streams: &mut IndexMap, close: StreamClose) { + if let Some(stream) = streams.get_mut(&close.stream_id) { + stream.pending_close = Some(close); + } +} + +fn restore_stream_data(streams: &mut IndexMap, frame: TrackedStreamData) { + if let Some(stream) = streams.get_mut(&frame.stream_id) { + if matches!(stream.outbound_state, OutboundState::Closed) { + return; + } + stream.tx.retransmit(stream_tx::StreamTxRange { + offset: frame.offset, + len: frame.len, + fin: frame.fin, + }); + if frame.fin && matches!(stream.outbound_state, OutboundState::Finished) { + stream.outbound_state = OutboundState::FinQueued; + } + } +} + +fn acknowledge_tracked_frame( + streams: &mut IndexMap, + stream_send_buffer_size: usize, + frame: &TrackedFrame, + sink: &mut impl EventSink, +) { + match frame { + TrackedFrame::StreamClose(_) => {} + TrackedFrame::StreamData(frame) => { + let stream_id = frame.stream_id; + if let Some(stream) = streams.get_mut(&stream_id) { + let was_full = stream.send_capacity(stream_send_buffer_size) == 0; + let had_unacked_fin = frame.fin && stream.tx.has_unacked_fin(); + stream.tx.ack(StreamTxRange { + offset: frame.offset, + len: frame.len, + fin: frame.fin, + }); + if was_full && stream.send_capacity(stream_send_buffer_size) > 0 { + sink.emit(SessionEvent::Writable(stream_id)); + } + if had_unacked_fin && !stream.tx.has_unacked_fin() { + sink.emit(SessionEvent::OutboundFinished(stream_id)); + } + } + } + } +} + +#[inline] +#[track_caller] +fn next_seq(seq: &mut RecordSeq) { + *seq = seq + .into_inner() + .checked_add(1) + .and_then(|next| RecordSeq::from_u64(next).ok()) + .expect("record sequence overflow"); +} diff --git a/ql-fsm/src/session/range_set.rs b/ql-fsm/src/session/range_set.rs new file mode 100644 index 00000000..53d66269 --- /dev/null +++ b/ql-fsm/src/session/range_set.rs @@ -0,0 +1,221 @@ +use std::{ + cmp, + collections::BTreeMap, + ops::{ + Bound::{Excluded, Included}, + Range, + }, +}; + +/// A set of `u64` values optimized for long runs and random insert/delete. +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct RangeSet(BTreeMap); + +impl RangeSet { + pub fn new() -> Self { + Self::default() + } + + pub fn insert(&mut self, mut x: Range) -> bool { + if x.is_empty() { + return false; + } + + if let Some((start, end)) = self.before(x.start) { + if end >= x.end { + return false; + } else if end >= x.start { + self.0.remove(&start); + x.start = start; + } + } + + while let Some((next_start, next_end)) = self.after(x.start) { + if next_start > x.end { + break; + } + self.0.remove(&next_start); + x.end = cmp::max(next_end, x.end); + } + + self.0.insert(x.start, x.end); + true + } + + pub fn remove(&mut self, x: Range) -> bool { + if x.is_empty() { + return false; + } + + let before = match self.before(x.start) { + Some((start, end)) if end > x.start => { + self.0.remove(&start); + if start < x.start { + self.0.insert(start, x.start); + } + if end > x.end { + self.0.insert(x.end, end); + } + if end >= x.end { + return true; + } + true + } + Some(_) | None => false, + }; + + let mut after = false; + while let Some((start, end)) = self.after(x.start) { + if start >= x.end { + break; + } + after = true; + self.0.remove(&start); + if end > x.end { + self.0.insert(x.end, end); + break; + } + } + + before || after + } + + pub fn min(&self) -> Option { + self.0.first_key_value().map(|(&start, _)| start) + } + + pub fn max(&self) -> Option { + self.0 + .last_key_value() + .map(|(_, &end)| end.checked_sub(1).unwrap()) + } + + pub fn contains(&self, x: u64) -> bool { + self.before(x).is_some_and(|(_, end)| end > x) + } + + pub fn range_count(&self) -> usize { + self.0.len() + } + + pub fn iter(&self) -> Iter<'_> { + Iter(self.0.iter()) + } + + pub fn iter_rev(&self) -> RevIter<'_> { + RevIter(self.0.iter().rev()) + } + + pub fn peek_min(&self) -> Option> { + let (&start, &end) = self.0.iter().next()?; + Some(start..end) + } + + pub fn pop_min(&mut self) -> Option> { + let result = self.peek_min()?; + self.0.remove(&result.start); + Some(result) + } + + #[cfg(test)] + pub fn peek_max(&self) -> Option> { + let (&start, &end) = self.0.iter().next_back()?; + Some(start..end) + } + + #[cfg(test)] + pub fn pop_max(&mut self) -> Option> { + let result = self.peek_max()?; + self.0.remove(&result.start); + Some(result) + } + + /// find closest range to `x` that begins at or before it + fn before(&self, x: u64) -> Option<(u64, u64)> { + self.0 + .range((Included(0), Included(x))) + .next_back() + .map(|(&start, &end)| (start, end)) + } + + /// find the closest range to `x` that begins after it + fn after(&self, x: u64) -> Option<(u64, u64)> { + self.0 + .range((Excluded(x), Included(u64::MAX))) + .next() + .map(|(&start, &end)| (start, end)) + } +} + +pub struct Iter<'a>(std::collections::btree_map::Iter<'a, u64, u64>); + +impl Iterator for Iter<'_> { + type Item = Range; + + fn next(&mut self) -> Option { + self.0.next().map(|(&start, &end)| start..end) + } +} + +pub struct RevIter<'a>(std::iter::Rev>); + +impl Iterator for RevIter<'_> { + type Item = Range; + + fn next(&mut self) -> Option { + self.0.next().map(|(&start, &end)| start..end) + } +} + +#[cfg(test)] +mod tests { + use super::RangeSet; + + #[test] + fn insert_merges_overlaps() { + let mut set = RangeSet::new(); + assert!(set.insert(10..20)); + assert!(set.insert(30..40)); + assert!(set.insert(15..35)); + assert_eq!(set.iter().collect::>(), vec![10..40]); + } + + #[test] + fn remove_splits_ranges() { + let mut set = RangeSet::new(); + set.insert(10..40); + assert!(set.remove(20..30)); + assert_eq!(set.iter().collect::>(), vec![10..20, 30..40]); + } + + #[test] + fn reverse_iteration_visits_highest_range_first() { + let mut set = RangeSet::new(); + set.insert(10..20); + set.insert(30..40); + set.insert(50..60); + + assert_eq!( + set.iter_rev().collect::>(), + vec![50..60, 30..40, 10..20] + ); + assert_eq!(set.peek_max(), Some(50..60)); + assert_eq!(set.pop_max(), Some(50..60)); + assert_eq!(set.iter().collect::>(), vec![10..20, 30..40]); + } + + #[test] + fn contains_and_max_reflect_current_membership() { + let mut set = RangeSet::new(); + set.insert(10..20); + set.insert(30..31); + + assert!(!set.contains(9)); + assert!(set.contains(10)); + assert!(set.contains(19)); + assert!(!set.contains(20)); + assert_eq!(set.min(), Some(10)); + assert_eq!(set.max(), Some(30)); + assert_eq!(set.range_count(), 2); + } +} diff --git a/ql-fsm/src/session/remote_stream_history.rs b/ql-fsm/src/session/remote_stream_history.rs new file mode 100644 index 00000000..76c1e8bb --- /dev/null +++ b/ql-fsm/src/session/remote_stream_history.rs @@ -0,0 +1,60 @@ +use ql_wire::StreamId; + +use super::{range_set::RangeSet, stream_parity::StreamParity}; + +#[derive(Debug)] +pub struct RemoteStreamHistory { + parity: StreamParity, + seen: RangeSet, +} + +impl RemoteStreamHistory { + pub fn new(parity: StreamParity) -> Self { + Self { + parity, + seen: RangeSet::new(), + } + } + + /// returns true when this remote stream id was already observed before + /// panics if `stream_id` is wrong stream parity + #[allow(clippy::range_plus_one)] + pub fn observe(&mut self, stream_id: StreamId) -> bool { + let ordinal = self + .stream_ordinal(stream_id) + .expect("remote stream history used with wrong stream parity"); + !self.seen.insert(ordinal..ordinal + 1) + } + + fn stream_ordinal(&self, stream_id: StreamId) -> Option { + let delta = stream_id + .into_inner() + .checked_sub(u64::from(self.parity.first_stream_id()))?; + if delta % 2 != 0 { + return None; + } + Some(delta / 2) + } +} + +#[cfg(test)] +mod tests { + use super::RemoteStreamHistory; + use crate::session::stream_parity::StreamParity; + + #[test] + fn observe() { + let parity = StreamParity::Even; + let mut history = RemoteStreamHistory::new(parity); + + assert!(!history.observe(parity.make_stream_id(2))); + assert!(!history.observe(parity.make_stream_id(5))); + assert!(!history.observe(parity.make_stream_id(0))); + assert!(!history.observe(parity.make_stream_id(4))); + assert!(history.observe(parity.make_stream_id(2))); + assert!(!history.observe(parity.make_stream_id(1))); + assert!(history.observe(parity.make_stream_id(5))); + assert!(!history.observe(parity.make_stream_id(3))); + assert!(history.observe(parity.make_stream_id(0))); + } +} diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs new file mode 100644 index 00000000..b63140a1 --- /dev/null +++ b/ql-fsm/src/session/state.rs @@ -0,0 +1,140 @@ +use std::time::Instant; + +use indexmap::IndexMap; +use ql_wire::{CloseTarget, RecordSeq, RouteId, SessionClose, StreamClose, StreamId}; + +use super::{ + ack_tracker::AckTracker, remote_stream_history::RemoteStreamHistory, stream_rx::StreamRx, + stream_tx::StreamTx, tracked::TrackedRecord, +}; + +pub struct SessionState { + pub last_activity_at: Instant, + pub last_inbound_at: Instant, + pub phase: SessionPhase, + pub next_stream_ordinal: u32, + pub next_record_seq: RecordSeq, + pub next_write_id: u64, + pub tracked_records: IndexMap, + pub ack_tracker: AckTracker, + pub pending_ping: bool, + pub streams: IndexMap, + pub next_stream_index: usize, + pub remote_stream_history: RemoteStreamHistory, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SessionPhase { + Open, + Terminating(TerminalFrame), + Closed, +} + +impl SessionPhase { + pub fn is_open(&self) -> bool { + self == &Self::Open + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TerminalFrame { + Close(SessionClose), + Unpair, +} + +#[derive(Debug)] +pub struct StreamState { + pub role: StreamRole, + pub route_id: Option, + pub rx: StreamRx, + pub tx: StreamTx, + pub pending_close: Option, + pub peer_max_offset: u64, + pub outbound_state: OutboundState, + pub inbound_state: InboundState, + pub advertised_max_offset: u64, + pub pending_window: bool, +} + +impl StreamState { + pub fn new( + role: StreamRole, + route_id: Option, + receive_buffer_size: u32, + initial_peer_stream_receive_window: u32, + ) -> Self { + let receive_buffer_size = receive_buffer_size as usize; + Self { + role, + route_id, + tx: StreamTx::new(), + pending_close: None, + peer_max_offset: u64::from(initial_peer_stream_receive_window), + outbound_state: OutboundState::Open, + inbound_state: InboundState::Open, + rx: StreamRx::new(receive_buffer_size), + advertised_max_offset: receive_buffer_size as u64, + pending_window: false, + } + } + + pub fn is_writable(&self) -> bool { + matches!(self.outbound_state, OutboundState::Open) + } + + pub fn send_capacity(&self, send_buffer_size: usize) -> usize { + send_buffer_size.saturating_sub(self.tx.buffered_len()) + } + + pub fn readable_bytes(&self) -> usize { + self.rx.readable_len() + } + + pub fn recv_limit(&self) -> u64 { + self.rx + .start_offset() + .saturating_add(self.rx.max_buffered() as u64) + } + + pub fn reset_recv(&mut self) { + self.rx = StreamRx::with_start_offset(self.rx.start_offset(), self.rx.max_buffered()); + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamRole { + Initiator, + Responder, +} + +impl StreamRole { + pub fn outbound_target(self) -> CloseTarget { + match self { + Self::Initiator => CloseTarget::Origin, + Self::Responder => CloseTarget::Return, + } + } + + pub fn inbound_target(self) -> CloseTarget { + match self { + Self::Initiator => CloseTarget::Return, + Self::Responder => CloseTarget::Origin, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum OutboundState { + Open, + FinQueued, + Finished, + Closed, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum InboundState { + Open, + Finished, + Closed(StreamClose), + Discarding, +} diff --git a/ql-fsm/src/session/stream_ops.rs b/ql-fsm/src/session/stream_ops.rs new file mode 100644 index 00000000..548189b7 --- /dev/null +++ b/ql-fsm/src/session/stream_ops.rs @@ -0,0 +1,147 @@ +use ql_wire::{CloseTarget, StreamClose, StreamCloseCode, StreamId}; + +use super::{ + state::{InboundState, StreamState}, + stream_rx::StreamReadIter, + EventSink, SessionEvent, SessionFsm, +}; +use crate::CommitReadError; + +pub struct StreamOps<'a, E> { + session: &'a mut SessionFsm, + emit: E, + stream_id: StreamId, + stream_index: usize, + reap_on_drop: bool, +} + +impl<'a, E: EventSink> StreamOps<'a, E> { + pub(super) fn new( + session: &'a mut SessionFsm, + stream_id: StreamId, + stream_index: usize, + emit: E, + ) -> Self { + Self { + session, + emit, + stream_id, + stream_index, + reap_on_drop: false, + } + } + + /// returns this stream's identifier + pub fn stream_id(&self) -> StreamId { + self.stream_id + } + + /// returns the readable stream bytes as owned `Bytes` views without consuming them + pub fn read(&self) -> StreamReadIter<'_> { + self.stream().rx.bytes() + } + + /// returns how many bytes can be read from the stream + pub fn readable_bytes(&self) -> usize { + self.stream().readable_bytes() + } + + /// marks previously read bytes as consumed + pub fn commit_read(&mut self, len: usize) -> Result<(), CommitReadError> { + let stream_id = self.stream_id; + let emit_finished = { + let stream = self.stream_mut(); + if len > stream.readable_bytes() { + return Err(CommitReadError); + } + stream.rx.consume(len); + if stream.recv_limit() > stream.advertised_max_offset { + stream.pending_window = true; + } + stream.route_id.is_some() + && matches!(stream.inbound_state, InboundState::Finished) + && stream.readable_bytes() == 0 + }; + if emit_finished { + self.emit.emit(SessionEvent::Finished(stream_id)); + } + self.reap_on_drop = true; + Ok(()) + } + + /// returns a writer if the local write side is still open + pub fn writer(&mut self) -> Option> { + let send_buffer_size = self.session.config.stream_send_buffer_size; + let stream = self.stream_mut(); + if !stream.is_writable() { + return None; + } + Some(StreamWriter::new(stream, send_buffer_size)) + } + + /// closes the origin lane, return lane, or both lanes of the stream + pub fn close(&mut self, target: CloseTarget, code: StreamCloseCode) { + let stream_id = self.stream_id; + let stream = self.stream_mut(); + SessionFsm::apply_local_close_to_stream(stream, target); + stream.pending_close = Some(StreamClose { + stream_id, + target, + code, + }); + self.reap_on_drop = true; + } + + fn stream(&self) -> &StreamState { + &self.session.state.streams[self.stream_index] + } + + fn stream_mut(&mut self) -> &mut StreamState { + &mut self.session.state.streams[self.stream_index] + } +} + +impl Drop for StreamOps<'_, E> { + fn drop(&mut self) { + if !self.reap_on_drop { + return; + } + + self.session + .try_reap_stream_at(self.stream_id, self.stream_index); + } +} + +pub struct StreamWriter<'a> { + stream: &'a mut StreamState, + send_buffer_size: usize, +} + +impl<'a> StreamWriter<'a> { + pub(super) fn new(stream: &'a mut StreamState, send_buffer_size: usize) -> Self { + Self { + stream, + send_buffer_size, + } + } + + /// returns how many bytes can still be buffered for local writes + pub fn capacity(&self) -> usize { + self.stream.send_capacity(self.send_buffer_size) + } + + /// appends as many bytes as possible and returns the accepted count + pub fn write(&mut self, bytes: &mut bytes::Bytes) -> usize { + let accepted = bytes.len().min(self.capacity()); + if accepted > 0 { + self.stream.tx.append(bytes.split_to(accepted)); + } + accepted + } + + /// marks the local write side as finished + pub fn finish(self) { + self.stream.tx.queue_fin(); + self.stream.outbound_state = super::state::OutboundState::FinQueued; + } +} diff --git a/ql-fsm/src/session/stream_parity.rs b/ql-fsm/src/session/stream_parity.rs new file mode 100644 index 00000000..70f60776 --- /dev/null +++ b/ql-fsm/src/session/stream_parity.rs @@ -0,0 +1,44 @@ +use ql_wire::{StreamId, QID}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamParity { + Even, + Odd, +} + +impl StreamParity { + pub fn for_local(local: QID, peer: QID) -> Self { + match local.0.cmp(&peer.0) { + std::cmp::Ordering::Less | std::cmp::Ordering::Equal => Self::Even, + std::cmp::Ordering::Greater => Self::Odd, + } + } + + pub const fn first_stream_id(self) -> u32 { + match self { + Self::Even => 0, + Self::Odd => 1, + } + } + + pub const fn matches(self, stream_id: StreamId) -> bool { + match self { + Self::Even => stream_id.into_inner() % 2 == 0, + Self::Odd => stream_id.into_inner() % 2 == 1, + } + } + + pub const fn remote(self) -> Self { + match self { + Self::Even => Self::Odd, + Self::Odd => Self::Even, + } + } + + pub fn make_stream_id(self, ordinal: u32) -> StreamId { + StreamId(ql_wire::VarInt::from_u32( + self.first_stream_id() + .saturating_add(ordinal.saturating_mul(2)), + )) + } +} diff --git a/ql-fsm/src/session/stream_rx.rs b/ql-fsm/src/session/stream_rx.rs new file mode 100644 index 00000000..0f5a8eab --- /dev/null +++ b/ql-fsm/src/session/stream_rx.rs @@ -0,0 +1,428 @@ +use std::collections::{btree_map, BTreeMap}; + +use bytes::{Buf, Bytes}; + +/// reassembles one stream direction from out-of-order byte ranges. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamRx { + start_offset: u64, + chunks: BTreeMap, + final_offset: Option, + max_buffered: usize, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct InsertOutcome { + pub newly_readable_bytes: usize, + pub became_complete: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamRxError { + OffsetOverflow, + OutOfWindow, + InconsistentFinalOffset, + FinalOffsetBeforeBufferedData, + BeyondFinalOffset, +} + +impl StreamRx { + pub fn new(max_buffered: usize) -> Self { + Self::with_start_offset(0, max_buffered) + } + + pub fn with_start_offset(start_offset: u64, max_buffered: usize) -> Self { + Self { + start_offset, + chunks: BTreeMap::new(), + final_offset: None, + max_buffered, + } + } + + pub fn start_offset(&self) -> u64 { + self.start_offset + } + + pub fn buffered_end_offset(&self) -> u64 { + self.chunks + .last_key_value() + .map_or(self.start_offset, |(&offset, bytes)| { + offset + bytes.len() as u64 + }) + } + + pub fn final_offset(&self) -> Option { + self.final_offset + } + + pub fn max_buffered(&self) -> usize { + self.max_buffered + } + + pub fn readable_len(&self) -> usize { + let mut cursor = self.start_offset; + for (&offset, bytes) in self.chunks.range(self.start_offset..) { + if offset > cursor { + break; + } + + let end = offset + bytes.len() as u64; + if end > cursor { + cursor = end; + } + } + + usize::try_from(cursor - self.start_offset).expect("readable prefix exceeds usize") + } + + pub fn bytes(&self) -> StreamReadIter<'_> { + StreamReadIter { + inner: self.chunks.range(self.start_offset..), + cursor: self.start_offset, + remaining: self.readable_len(), + } + } + + pub fn is_complete(&self) -> bool { + matches!(self.final_offset, Some(final_offset) + if final_offset == self.buffered_end_offset() + && final_offset == self.start_offset + self.readable_len() as u64) + } + + pub fn insert( + &mut self, + offset: u64, + fin: bool, + mut bytes: Bytes, + ) -> Result { + let end = offset + .checked_add(bytes.len() as u64) + .ok_or(StreamRxError::OffsetOverflow)?; + + let was_complete = self.is_complete(); + let old_readable = self.readable_len(); + + if fin { + self.set_or_validate_final_offset(end)?; + } + if let Some(final_offset) = self.final_offset { + if end > final_offset { + return Err(StreamRxError::BeyondFinalOffset); + } + } + + if bytes.is_empty() || end <= self.start_offset { + return Ok(self.insert_outcome(was_complete, old_readable)); + } + + let effective_offset = offset.max(self.start_offset); + let trim_front = + usize::try_from(effective_offset - offset).expect("front trim exceeds usize"); + bytes.advance(trim_front); + if bytes.is_empty() { + return Ok(self.insert_outcome(was_complete, old_readable)); + } + + let effective_end = effective_offset + bytes.len() as u64; + self.ensure_within_window(effective_end)?; + self.insert_chunk(effective_offset, bytes); + + Ok(self.insert_outcome(was_complete, old_readable)) + } + + pub fn consume(&mut self, len: usize) { + let readable = self.readable_len(); + debug_assert!(len <= readable, "consume beyond readable bytes"); + if len > readable { + return; + } + + let new_start = self.start_offset.saturating_add(len as u64); + while let Some((&offset, bytes)) = self.chunks.first_key_value() { + let end = offset + bytes.len() as u64; + if end <= new_start { + self.chunks.pop_first(); + continue; + } + if offset < new_start { + let (offset, mut bytes) = self.chunks.pop_first().unwrap(); + bytes.advance(usize::try_from(new_start - offset).expect("trim exceeds usize")); + self.chunks.insert(new_start, bytes); + } + break; + } + + self.start_offset = new_start; + } + + fn insert_outcome(&self, was_complete: bool, old_readable: usize) -> InsertOutcome { + InsertOutcome { + newly_readable_bytes: self.readable_len().saturating_sub(old_readable), + became_complete: !was_complete && self.is_complete(), + } + } + + fn set_or_validate_final_offset(&mut self, final_offset: u64) -> Result<(), StreamRxError> { + if let Some(existing) = self.final_offset { + return if existing == final_offset { + Ok(()) + } else { + Err(StreamRxError::InconsistentFinalOffset) + }; + } + + let buffered_end = self.buffered_end_offset(); + if final_offset < buffered_end { + return Err(StreamRxError::FinalOffsetBeforeBufferedData); + } + + self.final_offset = Some(final_offset); + Ok(()) + } + + fn ensure_within_window(&self, end: u64) -> Result<(), StreamRxError> { + let attempted = end.saturating_sub(self.start_offset); + if attempted > self.max_buffered as u64 { + return Err(StreamRxError::OutOfWindow); + } + Ok(()) + } + + fn insert_chunk(&mut self, mut offset: u64, mut bytes: Bytes) { + if bytes.is_empty() { + return; + } + + if let Some((&existing_offset, existing)) = self.chunks.range(..offset).next_back() { + let existing_end = existing_offset + existing.len() as u64; + if existing_end > offset { + let overlap = + usize::try_from((existing_end - offset).min(bytes.len() as u64)).unwrap(); + bytes.advance(overlap); + offset += overlap as u64; + } + } + + if bytes.is_empty() { + return; + } + + let end = offset + bytes.len() as u64; + let overlapping = self + .chunks + .range(offset..end) + .map(|(&chunk_offset, _)| chunk_offset) + .collect::>(); + + for chunk_offset in overlapping { + let chunk_end = chunk_offset + self.chunks[&chunk_offset].len() as u64; + + if chunk_offset > offset { + let len = usize::try_from(chunk_offset - offset).expect("gap exceeds usize"); + self.chunks.insert(offset, bytes.slice(..len)); + bytes.advance(len); + offset = chunk_offset; + } + + let overlap = usize::try_from((chunk_end - offset).min(bytes.len() as u64)).unwrap(); + bytes.advance(overlap); + offset += overlap as u64; + + if bytes.is_empty() { + return; + } + } + + self.chunks.insert(offset, bytes); + } +} + +#[derive(Debug, Clone)] +pub struct StreamReadIter<'a> { + inner: btree_map::Range<'a, u64, Bytes>, + cursor: u64, + remaining: usize, +} + +impl Iterator for StreamReadIter<'_> { + type Item = Bytes; + + fn next(&mut self) -> Option { + while self.remaining > 0 { + let (&offset, bytes) = self.inner.next()?; + if offset > self.cursor { + self.remaining = 0; + return None; + } + + let skip = usize::try_from(self.cursor.saturating_sub(offset)) + .expect("read cursor exceeds usize"); + if skip >= bytes.len() { + continue; + } + + let len = (bytes.len() - skip).min(self.remaining); + self.remaining -= len; + self.cursor += len as u64; + return Some(bytes.slice(skip..skip + len)); + } + + None + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::{InsertOutcome, StreamRx, StreamRxError}; + + pub fn copy_readable(rx: &StreamRx) -> Vec { + let readable = rx.readable_len(); + let mut out = Vec::with_capacity(readable); + for chunk in rx.bytes() { + out.extend_from_slice(&chunk); + } + out + } + + fn bytes(bytes: &'static [u8]) -> Bytes { + Bytes::from_static(bytes) + } + + #[test] + fn contiguous_insert_becomes_readable_and_complete() { + let mut rx = StreamRx::new(64); + + let outcome = rx.insert(0, true, bytes(b"hello")).unwrap(); + + assert_eq!( + outcome, + InsertOutcome { + newly_readable_bytes: 5, + became_complete: true, + } + ); + assert_eq!(rx.readable_len(), 5); + assert_eq!(copy_readable(&rx), b"hello"); + assert_eq!(rx.final_offset, Some(5)); + assert!(rx.is_complete()); + } + + #[test] + fn out_of_order_insert_tracks_gap_until_prefix_is_filled() { + let mut rx = StreamRx::new(64); + + let first = rx.insert(5, true, bytes(b" world")).unwrap(); + assert_eq!( + first, + InsertOutcome { + newly_readable_bytes: 0, + became_complete: false, + } + ); + assert_eq!(rx.readable_len(), 0); + + let second = rx.insert(0, false, bytes(b"hello")).unwrap(); + assert_eq!( + second, + InsertOutcome { + newly_readable_bytes: 11, + became_complete: true, + } + ); + assert_eq!(copy_readable(&rx), b"hello world"); + assert!(rx.is_complete()); + } + + #[test] + fn duplicate_insert_is_ignored_if_bytes_match() { + let mut rx = StreamRx::new(64); + + rx.insert(0, false, bytes(b"hello")).unwrap(); + let duplicate = rx.insert(0, false, bytes(b"hello")).unwrap(); + + assert_eq!( + duplicate, + InsertOutcome { + newly_readable_bytes: 0, + became_complete: false, + } + ); + assert_eq!(copy_readable(&rx), b"hello"); + } + + #[test] + fn consume_advances_start_offset_and_trims_old_prefix() { + let mut rx = StreamRx::new(64); + + rx.insert(0, false, bytes(b"abcd")).unwrap(); + rx.consume(2); + assert_eq!(rx.start_offset(), 2); + assert_eq!(copy_readable(&rx), b"cd"); + + let outcome = rx.insert(1, true, bytes(b"bcde")).unwrap(); + assert_eq!( + outcome, + InsertOutcome { + newly_readable_bytes: 1, + became_complete: true, + } + ); + assert_eq!(copy_readable(&rx), b"cde"); + assert_eq!(rx.final_offset, Some(5)); + assert!(rx.is_complete()); + } + + #[test] + fn insert_can_fill_multiple_gaps_without_rebuilding_state() { + let mut rx = StreamRx::new(64); + + rx.insert(0, false, bytes(b"ab")).unwrap(); + rx.insert(4, false, bytes(b"ef")).unwrap(); + rx.insert(8, true, bytes(b"ij")).unwrap(); + + let outcome = rx.insert(2, false, bytes(b"cdefgh")).unwrap(); + + assert_eq!( + outcome, + InsertOutcome { + newly_readable_bytes: 8, + became_complete: true, + } + ); + + assert_eq!(copy_readable(&rx), b"abcdefghij"); + assert!(rx.is_complete()); + } + + #[test] + fn heavily_fragmented_inserts_stay_valid() { + let mut rx = StreamRx::new(64); + + rx.insert(1, false, bytes(b"b")).unwrap(); + rx.insert(3, false, bytes(b"d")).unwrap(); + rx.insert(5, false, bytes(b"f")).unwrap(); + rx.insert(7, false, bytes(b"h")).unwrap(); + rx.insert(9, true, bytes(b"j")).unwrap(); + + let outcome = rx.insert(0, false, bytes(b"abcdefghi")).unwrap(); + assert_eq!( + outcome, + InsertOutcome { + newly_readable_bytes: 10, + became_complete: true, + } + ); + assert_eq!(copy_readable(&rx), b"abcdefghij"); + assert!(rx.is_complete()); + } + + #[test] + fn out_of_window_insert_is_rejected() { + let mut rx = StreamRx::new(4); + let error = rx.insert(5, false, bytes(b"a")).unwrap_err(); + assert_eq!(error, StreamRxError::OutOfWindow); + } +} diff --git a/ql-fsm/src/session/stream_tx.rs b/ql-fsm/src/session/stream_tx.rs new file mode 100644 index 00000000..15533922 --- /dev/null +++ b/ql-fsm/src/session/stream_tx.rs @@ -0,0 +1,579 @@ +use std::{collections::VecDeque, ops::Range}; + +use bytes::{Buf, Bytes}; +use ql_wire::BufView; + +use super::range_set::RangeSet; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamTx { + chunks: VecDeque, + buffered_len: usize, + base_offset: u64, + unsent: u64, + acked: RangeSet, + retransmits: RangeSet, + final_offset: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct TrackedFinalOffset { + offset: u64, + state: SendState, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum SendState { + Unsent, + Sent, + Lost, + Acked, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct StreamTxRange { + pub offset: u64, + pub len: usize, + pub fin: bool, +} + +#[derive(Debug, Clone, Copy)] +pub struct StreamTxBytes<'a> { + inner: &'a VecDeque, + offset: usize, + len: usize, +} + +pub struct StreamTxBuf<'a> { + inner: std::collections::vec_deque::Iter<'a, Bytes>, + skip: usize, + remaining: usize, + current: &'a [u8], +} + +impl BufView for StreamTxBytes<'_> { + type Buf<'a> + = StreamTxBuf<'a> + where + Self: 'a; + + fn buf(&self) -> Self::Buf<'_> { + let mut buf = StreamTxBuf { + inner: self.inner.iter(), + skip: self.offset, + remaining: self.len, + current: &[], + }; + buf.refill(); + buf + } +} + +impl StreamTxBuf<'_> { + fn refill(&mut self) { + if self.remaining == 0 { + self.current = &[]; + return; + } + + for chunk in self.inner.by_ref() { + if self.skip >= chunk.len() { + self.skip -= chunk.len(); + continue; + } + + let chunk = &chunk[self.skip..]; + self.skip = 0; + if chunk.is_empty() { + continue; + } + + let len = chunk.len().min(self.remaining); + self.current = &chunk[..len]; + return; + } + + self.current = &[]; + } +} + +impl Buf for StreamTxBuf<'_> { + fn remaining(&self) -> usize { + self.remaining + } + + fn chunk(&self) -> &[u8] { + self.current + } + + fn advance(&mut self, cnt: usize) { + let remaining = self.remaining; + assert!( + cnt <= remaining, + "cannot advance past remaining bytes: {cnt} > {remaining}", + ); + + self.remaining -= cnt; + let mut cnt = cnt; + while cnt > 0 { + if cnt < self.current.len() { + self.current = &self.current[cnt..]; + return; + } + + cnt -= self.current.len(); + self.refill(); + } + + if self.remaining == 0 { + self.current = &[]; + } + } +} + +impl StreamTx { + pub fn new() -> Self { + Self { + chunks: VecDeque::new(), + buffered_len: 0, + base_offset: 0, + unsent: 0, + acked: RangeSet::new(), + retransmits: RangeSet::new(), + final_offset: None, + } + } + + pub fn buffered_len(&self) -> usize { + self.buffered_len + } + + pub fn end_offset(&self) -> u64 { + self.base_offset + self.buffered_len as u64 + } + + pub fn is_empty(&self) -> bool { + self.buffered_len == 0 && self.final_offset.is_none() + } + + pub fn append(&mut self, bytes: Bytes) { + if bytes.is_empty() { + return; + } + + self.buffered_len += bytes.len(); + self.chunks.push_back(bytes); + } + + pub fn queue_fin(&mut self) { + self.final_offset = Some(TrackedFinalOffset { + offset: self.end_offset(), + state: SendState::Unsent, + }); + } + + pub fn has_unacked_fin(&self) -> bool { + self.final_offset + .is_some_and(|final_offset| final_offset.state != SendState::Acked) + } + + pub fn poll_transmit( + &mut self, + max_payload: usize, + peer_max_offset: u64, + ) -> Option { + let budget_end = |start: u64| { + start + .saturating_add(max_payload as u64) + .min(peer_max_offset) + }; + + // prefer the lowest lost bytes before sending new bytes + if let Some(range) = self.retransmits.peek_min() { + let mut end = range.end.min(budget_end(range.start)); + + // extend only when lost bytes end where unsent bytes begin + if end == range.end && range.end == self.unsent { + end = self.end_offset().min(budget_end(range.start)); + } + + if end > range.start { + let range = self.retransmits.pop_min().unwrap(); + if end < range.end { + self.retransmits.insert(end..range.end); + } + + // mark any new bytes in this frame as sent + self.unsent = self.unsent.max(end); + return Some(StreamTxRange { + offset: range.start, + len: usize::try_from(end - range.start).unwrap(), + fin: self.poll_fin(end), + }); + } + } + + // send bytes that have not been sent yet + if self.unsent < self.end_offset() { + let end = self.end_offset().min(budget_end(self.unsent)); + if end > self.unsent { + let start = self.unsent; + self.unsent = end; + return Some(StreamTxRange { + offset: start, + len: usize::try_from(end - start).unwrap(), + fin: self.poll_fin(end), + }); + } + } + + // send a fin after all data has been sent + let final_offset = + self.final_offset + .as_mut() + .filter(|TrackedFinalOffset { offset, state }| { + (*state == SendState::Lost || *state == SendState::Unsent) + && *offset <= peer_max_offset + })?; + final_offset.state = SendState::Sent; + Some(StreamTxRange { + offset: final_offset.offset, + len: 0, + fin: true, + }) + } + + pub fn ranged_bytes(&self, range: StreamTxRange) -> StreamTxBytes<'_> { + let offset = usize::try_from(range.offset - self.base_offset).unwrap(); + let len = range.len.min(self.buffered_len.saturating_sub(offset)); + StreamTxBytes { + inner: &self.chunks, + offset, + len, + } + } + + pub fn retransmit(&mut self, range: StreamTxRange) { + if let Some(range) = self.clamp_sent_range(range.offset, range.len) { + Self::insert_not_acked(&self.acked, &mut self.retransmits, range); + } + if range.fin { + self.mark_fin_lost(); + } + } + + pub fn ack(&mut self, range: StreamTxRange) { + if let Some(range) = self.clamp_buffered_range(range.offset, range.len) { + self.acked.insert(range.clone()); + self.retransmits.remove(range); + self.trim_acked_prefix(); + } + if range.fin { + if let Some(final_offset) = self.final_offset.as_mut() { + final_offset.state = SendState::Acked; + } + } + self.trim_acked_fin(); + } + + pub fn clear(&mut self) { + self.chunks.clear(); + self.buffered_len = 0; + self.unsent = self.base_offset; + self.acked = RangeSet::new(); + self.retransmits = RangeSet::new(); + self.final_offset = None; + } + + fn clamp_buffered_range(&self, offset: u64, len: usize) -> Option> { + if len == 0 { + return None; + } + let start = offset.max(self.base_offset); + let end = offset.saturating_add(len as u64).min(self.end_offset()); + (start < end).then_some(start..end) + } + + fn clamp_sent_range(&self, offset: u64, len: usize) -> Option> { + if len == 0 { + return None; + } + let start = offset.max(self.base_offset); + let end = offset.saturating_add(len as u64).min(self.unsent); + (start < end).then_some(start..end) + } + + fn insert_not_acked(acked_set: &RangeSet, target: &mut RangeSet, range: Range) { + let mut cursor = range.start; + for acked in acked_set.iter() { + if acked.end <= cursor { + continue; + } + if acked.start >= range.end { + break; + } + if cursor < acked.start { + target.insert(cursor..acked.start.min(range.end)); + } + cursor = cursor.max(acked.end); + if cursor >= range.end { + break; + } + } + if cursor < range.end { + target.insert(cursor..range.end); + } + } + + fn poll_fin(&mut self, offset: u64) -> bool { + let Some(final_offset) = self.final_offset.as_mut() else { + return false; + }; + if matches!(final_offset.state, SendState::Lost | SendState::Unsent) + && final_offset.offset == offset + { + final_offset.state = SendState::Sent; + true + } else { + false + } + } + + fn mark_fin_lost(&mut self) { + if let Some(final_offset) = self.final_offset.as_mut() { + if final_offset.state != SendState::Acked { + final_offset.state = SendState::Lost; + } + } + } + + fn trim_acked_prefix(&mut self) { + while self.acked.min() == Some(self.base_offset) { + let prefix = self.acked.pop_min().unwrap(); + let mut to_advance = usize::try_from(prefix.end - prefix.start).unwrap(); + self.buffered_len -= to_advance; + while to_advance > 0 { + let front = self + .chunks + .front_mut() + .expect("expected buffered chunks for acked prefix"); + if front.len() <= to_advance { + to_advance -= front.len(); + self.chunks.pop_front(); + } else { + front.advance(to_advance); + to_advance = 0; + } + } + self.base_offset = prefix.end; + } + } + + fn trim_acked_fin(&mut self) { + if self.final_offset.is_some_and(|final_offset| { + final_offset.state == SendState::Acked + && final_offset.offset == self.base_offset + && self.buffered_len == 0 + }) { + self.final_offset = None; + } + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::{StreamTx, StreamTxRange}; + + #[test] + fn append_tracks_unsent_bytes() { + let mut tx = StreamTx::new(); + tx.append(Bytes::from_static(b"abc")); + tx.append(Bytes::from_static(b"de")); + + assert_eq!( + tx.poll_transmit(8, u64::MAX), + Some(StreamTxRange { + offset: 0, + len: 5, + fin: false, + }) + ); + } + + #[test] + fn lost_range_is_selected_before_unsent_bytes() { + let mut tx = StreamTx::new(); + tx.append(Bytes::from_static(b"abcdef")); + + let first = tx.poll_transmit(3, u64::MAX).unwrap(); + tx.retransmit(first); + + assert_eq!( + tx.poll_transmit(3, u64::MAX), + Some(StreamTxRange { + offset: 0, + len: 3, + fin: false, + }) + ); + } + + #[test] + fn lost_range_coalesces_contiguous_unsent_bytes() { + let mut tx = StreamTx::new(); + tx.append(Bytes::from_static(b"abc")); + + let first = tx.poll_transmit(3, u64::MAX).unwrap(); + tx.retransmit(first); + tx.append(Bytes::from_static(b"def")); + + assert_eq!( + tx.poll_transmit(6, u64::MAX), + Some(StreamTxRange { + offset: 0, + len: 6, + fin: false, + }) + ); + assert_eq!(tx.poll_transmit(6, u64::MAX), None); + } + + #[test] + fn lost_range_coalesces_only_new_bytes_that_fit() { + let mut tx = StreamTx::new(); + tx.append(Bytes::from_static(b"abc")); + + let first = tx.poll_transmit(3, u64::MAX).unwrap(); + tx.retransmit(first); + tx.append(Bytes::from_static(b"def")); + + assert_eq!( + tx.poll_transmit(5, u64::MAX), + Some(StreamTxRange { + offset: 0, + len: 5, + fin: false, + }) + ); + assert_eq!( + tx.poll_transmit(6, u64::MAX), + Some(StreamTxRange { + offset: 5, + len: 1, + fin: false, + }) + ); + } + + #[test] + fn non_contiguous_lost_range_does_not_coalesce_unsent_bytes() { + let mut tx = StreamTx::new(); + tx.append(Bytes::from_static(b"abcdef")); + + let first = tx.poll_transmit(3, u64::MAX).unwrap(); + let _second = tx.poll_transmit(3, u64::MAX).unwrap(); + tx.retransmit(first); + tx.append(Bytes::from_static(b"ghi")); + + assert_eq!( + tx.poll_transmit(6, u64::MAX), + Some(StreamTxRange { + offset: 0, + len: 3, + fin: false, + }) + ); + assert_eq!( + tx.poll_transmit(6, u64::MAX), + Some(StreamTxRange { + offset: 6, + len: 3, + fin: false, + }) + ); + } + + #[test] + fn acked_prefix_is_trimmed() { + let mut tx = StreamTx::new(); + tx.append(Bytes::from_static(b"abcdef")); + + let first = tx.poll_transmit(3, u64::MAX).unwrap(); + tx.ack(first); + + assert_eq!( + tx.poll_transmit(3, u64::MAX), + Some(StreamTxRange { + offset: 3, + len: 3, + fin: false, + }) + ); + } + + #[test] + fn empty_fin_is_tracked_separately() { + let mut tx = StreamTx::new(); + tx.queue_fin(); + + let range = tx.poll_transmit(16, u64::MAX).unwrap(); + assert_eq!( + range, + StreamTxRange { + offset: 0, + len: 0, + fin: true, + } + ); + + tx.ack(range); + assert!(tx.is_empty()); + } + + #[test] + fn subrange_updates_split_merged_in_flight_segments() { + let mut tx = StreamTx::new(); + tx.append(Bytes::from_static(b"abcdefghijkl")); + + let _first = tx.poll_transmit(4, u64::MAX).unwrap(); + let second = tx.poll_transmit(4, u64::MAX).unwrap(); + let _third = tx.poll_transmit(4, u64::MAX).unwrap(); + + tx.retransmit(second); + + assert_eq!( + tx.poll_transmit(4, u64::MAX), + Some(StreamTxRange { + offset: 4, + len: 4, + fin: false, + }) + ); + } + + #[test] + fn acked_subrange_is_not_reopened_by_stale_timeout() { + let mut tx = StreamTx::new(); + tx.append(Bytes::from_static(b"abcdefghijklmnop")); + + let _first = tx.poll_transmit(4, u64::MAX).unwrap(); + let second = tx.poll_transmit(4, u64::MAX).unwrap(); + let third = tx.poll_transmit(4, u64::MAX).unwrap(); + let _fourth = tx.poll_transmit(4, u64::MAX).unwrap(); + + tx.ack(second); + tx.retransmit(second); + tx.retransmit(third); + + assert_eq!( + tx.poll_transmit(4, u64::MAX), + Some(StreamTxRange { + offset: 8, + len: 4, + fin: false, + }) + ); + } +} diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs new file mode 100644 index 00000000..f1f29879 --- /dev/null +++ b/ql-fsm/src/session/tests.rs @@ -0,0 +1,869 @@ +use std::time::{Duration, Instant}; + +use bytes::Bytes; +use ql_wire::{ + decode_session_frames, parse_session_frames, CloseTarget, RecordAck, RecordSeq, RouteId, + SessionFrame, SessionRecordBuilder, StreamClose, StreamCloseCode, StreamData, StreamHeader, + StreamId, VarInt, QID, +}; + +use super::{SessionConfig, SessionEvent, SessionFsm}; +use crate::session::stream_parity::StreamParity; + +fn seq(value: u64) -> RecordSeq { + RecordSeq::from_u64(value).unwrap() +} + +fn stream_id(value: u64) -> StreamId { + StreamId(VarInt::from_u64(value).unwrap()) +} + +fn offset(value: u64) -> VarInt { + VarInt::from_u64(value).unwrap() +} + +fn route_id(value: u64) -> RouteId { + RouteId::from_u64(value).unwrap() +} + +fn record_ack(seq: RecordSeq) -> RecordAck { + RecordAck::from_ranges([seq..=seq]).unwrap() +} + +const REFUSED: StreamCloseCode = StreamCloseCode(1); +const TIMEOUT: StreamCloseCode = StreamCloseCode(2); + +fn header(value: u64) -> StreamHeader { + StreamHeader { + route_id: route_id(value), + } +} + +fn opened(stream_id: StreamId) -> SessionEvent { + SessionEvent::Opened { + stream_id, + route_id: route_id(1), + } +} + +fn open_stream_id(fsm: &mut SessionFsm) -> StreamId { + fsm.open_stream(route_id(1), |_| {}).unwrap().stream_id() +} + +fn write_stream_bytes(fsm: &mut SessionFsm, stream_id: StreamId, bytes: &[u8]) -> usize { + let mut bytes = Bytes::copy_from_slice(bytes); + let mut stream = fsm.stream(stream_id, |_| {}).unwrap(); + let mut writer = stream.writer().unwrap(); + writer.write(&mut bytes) +} + +fn read_stream_all(fsm: &mut SessionFsm, stream_id: StreamId) -> Vec { + let mut stream = fsm.stream(stream_id, |_| {}).unwrap(); + let out = stream.read().flatten().collect::>(); + stream.commit_read(out.len()).unwrap(); + out +} + +fn read_stream_all_with_events( + fsm: &mut SessionFsm, + stream_id: StreamId, + events: &mut Vec, +) -> Vec { + let mut stream = fsm.stream(stream_id, |event| events.push(event)).unwrap(); + let out = stream.read().flatten().collect::>(); + stream.commit_read(out.len()).unwrap(); + out +} + +fn next_outbound( + fsm: &mut SessionFsm, + now: Instant, +) -> Option<(RecordSeq, Vec>>)> { + let (write_id, builder) = fsm.take_next_write(now)?; + if let Some(write_id) = write_id { + fsm.complete_write(now, write_id, true); + } + Some(( + builder.seq(), + decode_session_frames(builder.bytes()).unwrap(), + )) +} + +fn drain_outbound( + fsm: &mut SessionFsm, + now: Instant, + limit: usize, +) -> Vec<(RecordSeq, Vec>>)> { + let mut records = Vec::new(); + for _ in 0..limit { + let Some(record) = next_outbound(fsm, now) else { + return records; + }; + records.push(record); + } + + panic!("session did not quiesce within outbound limit"); +} + +fn receive_events( + fsm: &mut SessionFsm, + now: Instant, + seq: RecordSeq, + record: &[SessionFrame>], +) -> Vec { + let mut builder = SessionRecordBuilder::new(seq, usize::MAX); + for frame in record { + assert!(builder.push_frame(frame)); + } + let bytes = Bytes::from(builder.bytes().to_vec()); + let frames = parse_session_frames(bytes); + let mut events = Vec::new(); + let mut emit = |event| events.push(event); + fsm.receive(now, seq, frames, &mut emit); + events +} + +#[test] +fn outbound_record_seq_increments_monotonically() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let stream_id = open_stream_id(&mut fsm); + + assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"one"), 3); + let (first_seq, _) = next_outbound(&mut fsm, now).unwrap(); + + assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"two"), 3); + let (second_seq, _) = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); + + assert_eq!(first_seq, seq(0)); + assert_eq!(second_seq, seq(1)); +} + +#[test] +fn retransmit_uses_new_record_seq() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let stream_id = open_stream_id(&mut fsm); + + assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"retry"), 5); + let (first_seq, first) = next_outbound(&mut fsm, now).unwrap(); + + let mut emit = |_| {}; + fsm.on_timer(now + Duration::from_millis(200), &mut emit); + let (retried_seq, retried) = next_outbound(&mut fsm, now + Duration::from_millis(200)).unwrap(); + + assert_ne!(first_seq, retried_seq); + assert_eq!(first, retried); +} + +#[test] +fn lost_record_on_one_stream_does_not_block_another_stream() { + let now = Instant::now(); + let mut fsm = SessionFsm::new( + SessionConfig { + record_max_size: 80 + SessionRecordBuilder::MIN_CAPACITY, + ..SessionConfig::default() + }, + now, + ); + let stream_id_a = open_stream_id(&mut fsm); + let stream_id_b = open_stream_id(&mut fsm); + let payload_a = vec![b'a'; 40]; + let payload_b = vec![b'b'; 40]; + + assert_eq!(write_stream_bytes(&mut fsm, stream_id_a, &payload_a), 40); + assert_eq!(write_stream_bytes(&mut fsm, stream_id_b, &payload_b), 40); + + let (first_seq, first) = next_outbound(&mut fsm, now).unwrap(); + let (second_seq, _second) = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); + assert_ne!(first_seq, second_seq); + assert!(first.iter().any( + |frame| matches!(frame, SessionFrame::StreamData(frame) if frame.stream_id == stream_id_a) + )); + + assert_eq!(write_stream_bytes(&mut fsm, stream_id_b, b"b-2"), 3); + let (_third_seq, third) = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); + + let stream_ids: Vec<_> = third + .iter() + .filter_map(|frame| match frame { + SessionFrame::StreamData(frame) => Some(frame.stream_id), + _ => None, + }) + .collect(); + assert_eq!(stream_ids, vec![stream_id_b]); +} + +#[test] +fn ack_reopens_write_capacity() { + let now = Instant::now(); + let mut fsm = SessionFsm::new( + SessionConfig { + stream_send_buffer_size: 4, + ..SessionConfig::default() + }, + now, + ); + let stream_id = open_stream_id(&mut fsm); + + assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"abcd"), 4); + let (record_seq, _record) = next_outbound(&mut fsm, now).unwrap(); + + let mut events = Vec::new(); + let mut emit = |event| events.push(event); + fsm.receive( + now + Duration::from_millis(1), + seq(9), + std::iter::once(Ok(SessionFrame::Ack(record_ack(record_seq)))), + &mut emit, + ); + + assert!(events.contains(&SessionEvent::Writable(stream_id))); + assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"z"), 1); +} + +#[test] +fn ack_of_fin_emits_outbound_finished_once() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let stream_id = open_stream_id(&mut fsm); + + assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"done"), 4); + fsm.stream(stream_id, |_| {}) + .unwrap() + .writer() + .unwrap() + .finish(); + + let (record_seq, record) = next_outbound(&mut fsm, now).unwrap(); + assert!(matches!( + record.as_slice(), + [SessionFrame::StreamData(StreamData { + stream_id: id, + fin: true, + .. + })] if *id == stream_id + )); + + let mut events = Vec::new(); + { + let mut emit = |event| events.push(event); + fsm.receive( + now + Duration::from_millis(1), + seq(9), + std::iter::once(Ok(SessionFrame::Ack(record_ack(record_seq)))), + &mut emit, + ); + } + assert_eq!(events, vec![SessionEvent::OutboundFinished(stream_id)]); + + { + let mut emit = |event| events.push(event); + fsm.receive( + now + Duration::from_millis(2), + seq(10), + std::iter::once(Ok(SessionFrame::Ack(record_ack(record_seq)))), + &mut emit, + ); + } + assert_eq!(events, vec![SessionEvent::OutboundFinished(stream_id)]); +} + +#[test] +fn commit_stream_read_is_what_advances_stream_window() { + let now = Instant::now(); + let mut fsm = SessionFsm::new( + SessionConfig { + local_parity: StreamParity::Even, + ack_delay: Duration::ZERO, + ..SessionConfig::default() + }, + now, + ); + let stream_id = stream_id(1); + let data = vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: offset(0), + header: Some(header(1)), + fin: false, + bytes: b"hi".to_vec(), + })]; + let events = receive_events(&mut fsm, now, seq(7), &data); + assert_eq!( + events, + vec![opened(stream_id), SessionEvent::Readable(stream_id)] + ); + + let (write_id, builder) = fsm.take_next_write(now + Duration::from_millis(1)).unwrap(); + let first = decode_session_frames(builder.bytes()).unwrap(); + assert!(write_id.is_none()); + assert!(matches!(first.as_slice(), [SessionFrame::Ack(_)])); + + let read = fsm + .stream(stream_id, |_| {}) + .unwrap() + .read() + .map(|chunk| chunk.len()) + .sum::(); + assert_eq!(read, 2); + + assert!(next_outbound(&mut fsm, now + Duration::from_millis(2)).is_none()); + + fsm.stream(stream_id, |_| {}) + .unwrap() + .commit_read(2) + .unwrap(); + let (_second_seq, second) = next_outbound(&mut fsm, now + Duration::from_millis(3)).unwrap(); + assert!(matches!( + second.as_slice(), + [SessionFrame::StreamWindow(window)] if window.stream_id == stream_id + )); +} + +#[test] +fn pure_ack_only_records_are_fire_and_forget() { + let now = Instant::now(); + let config = SessionConfig { + ack_delay: Duration::ZERO, + ..SessionConfig::default() + }; + let retransmit_timeout = config.retransmit_timeout; + let mut fsm = SessionFsm::new(config, now); + let stream_id = stream_id(1); + let record = vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: offset(0), + header: Some(header(1)), + fin: false, + bytes: b"hi".to_vec(), + })]; + + let _ = receive_events(&mut fsm, now, seq(7), &record); + + let (write_id, builder) = fsm.take_next_write(now + Duration::from_millis(1)).unwrap(); + let ack = decode_session_frames(builder.bytes()).unwrap(); + assert!(write_id.is_none()); + assert!(matches!(ack.as_slice(), [SessionFrame::Ack(_)])); + + let mut emit = |_| {}; + fsm.on_timer( + now + retransmit_timeout + Duration::from_millis(1), + &mut emit, + ); + assert!(fsm + .take_next_write(now + retransmit_timeout + Duration::from_millis(1)) + .is_none()); +} + +#[test] +fn inbound_stream_data_emits_opened_and_readable() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let stream_id = stream_id(1); + let record = vec![SessionFrame::StreamData(ql_wire::StreamData { + stream_id, + offset: offset(0), + header: Some(header(1)), + fin: true, + bytes: b"hello".to_vec(), + })]; + + let events = receive_events(&mut fsm, now, seq(0), &record); + assert_eq!( + events, + vec![opened(stream_id), SessionEvent::Readable(stream_id)] + ); + let mut events = Vec::new(); + assert_eq!( + read_stream_all_with_events(&mut fsm, stream_id, &mut events), + b"hello".to_vec() + ); + assert_eq!(events, vec![SessionEvent::Finished(stream_id)]); +} + +#[test] +fn inbound_empty_fin_emits_finished_immediately() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let stream_id = stream_id(1); + let record = vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: offset(0), + header: Some(header(1)), + fin: true, + bytes: Vec::new(), + })]; + + let events = receive_events(&mut fsm, now, seq(0), &record); + assert_eq!( + events, + vec![opened(stream_id), SessionEvent::Finished(stream_id)] + ); +} + +#[test] +fn remote_stream_close_is_reliable_and_retried() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let stream_id = open_stream_id(&mut fsm); + + fsm.stream(stream_id, |_| {}) + .unwrap() + .close(CloseTarget::Both, StreamCloseCode::CANCELLED); + + let (write_id, builder) = fsm.take_next_write(now).unwrap(); + fsm.complete_write(now, write_id.expect("stream close should be tracked"), true); + let first = decode_session_frames(builder.bytes()).unwrap(); + assert!(matches!( + first.as_slice(), + [SessionFrame::StreamClose(StreamClose { stream_id: id, .. })] if *id == stream_id + )); + + let mut emit = |_| {}; + fsm.on_timer(now + Duration::from_millis(200), &mut emit); + let (_retried_seq, retried) = + next_outbound(&mut fsm, now + Duration::from_millis(200)).unwrap(); + assert_eq!(first, retried); +} + +#[test] +fn stream_ids_follow_even_odd_xid_ordering() { + let now = Instant::now(); + let even = StreamParity::for_local(QID([1; QID::SIZE]), QID([2; QID::SIZE])); + let odd = StreamParity::for_local(QID([2; QID::SIZE]), QID([1; QID::SIZE])); + + let even_id = SessionFsm::new( + SessionConfig { + local_parity: even, + ..SessionConfig::default() + }, + now, + ) + .open_stream(route_id(1), |_| {}) + .unwrap() + .stream_id(); + let odd_id = SessionFsm::new( + SessionConfig { + local_parity: odd, + ..SessionConfig::default() + }, + now, + ) + .open_stream(route_id(1), |_| {}) + .unwrap() + .stream_id(); + + assert_eq!(even_id.into_inner() % 2, 0); + assert_eq!(odd_id.into_inner() % 2, 1); +} + +#[test] +fn duplicate_stream_data_is_not_redelivered() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let stream_id = stream_id(1); + let record = vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: offset(0), + header: Some(header(1)), + fin: false, + bytes: b"hi".to_vec(), + })]; + let _ = receive_events(&mut fsm, now, seq(1), &record); + let _ = receive_events(&mut fsm, now + Duration::from_millis(1), seq(2), &record); + + assert_eq!(read_stream_all(&mut fsm, stream_id), b"hi".to_vec()); +} + +#[test] +fn duplicate_remote_close_after_reap_is_ignored() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let close = StreamClose { + stream_id: stream_id(1), + target: CloseTarget::Both, + code: StreamCloseCode(9), + }; + let record = vec![SessionFrame::StreamClose(close.clone())]; + + let first = receive_events(&mut fsm, now, seq(1), &record); + assert_eq!( + first, + vec![ + SessionEvent::Closed(close.clone()), + SessionEvent::WritableClosed(close), + ] + ); + + let second = receive_events(&mut fsm, now + Duration::from_millis(1), seq(2), &record); + assert!(second.is_empty()); +} + +#[test] +fn late_remote_stream_data_after_close_is_ignored() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let stream_id = stream_id(1); + let close = vec![SessionFrame::StreamClose(StreamClose { + stream_id, + target: CloseTarget::Both, + code: StreamCloseCode(9), + })]; + let data = vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: offset(0), + header: Some(header(1)), + fin: false, + bytes: b"hello".to_vec(), + })]; + + let first = receive_events(&mut fsm, now, seq(1), &close); + assert_eq!( + first, + vec![ + SessionEvent::Closed(StreamClose { + stream_id, + target: CloseTarget::Both, + code: StreamCloseCode(9), + }), + SessionEvent::WritableClosed(StreamClose { + stream_id, + target: CloseTarget::Both, + code: StreamCloseCode(9), + }), + ] + ); + + let second = receive_events(&mut fsm, now + Duration::from_millis(1), seq(2), &data); + assert!(second.is_empty()); +} + +#[test] +fn duplicate_finished_remote_data_after_reap_is_ignored() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let stream_id = stream_id(1); + let record = vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: offset(0), + header: Some(header(1)), + fin: true, + bytes: b"hello".to_vec(), + })]; + + let first = receive_events(&mut fsm, now, seq(1), &record); + assert_eq!( + first, + vec![opened(stream_id), SessionEvent::Readable(stream_id)] + ); + let mut events = Vec::new(); + assert_eq!( + read_stream_all_with_events(&mut fsm, stream_id, &mut events), + b"hello".to_vec() + ); + assert_eq!(events, vec![SessionEvent::Finished(stream_id)]); + + let second = receive_events(&mut fsm, now + Duration::from_millis(1), seq(2), &record); + assert!(second.is_empty()); +} + +#[test] +fn duplicate_finished_remote_data_before_read_is_ignored() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let stream_id = stream_id(1); + let record = vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: offset(0), + header: Some(header(1)), + fin: true, + bytes: b"hello".to_vec(), + })]; + + let first = receive_events(&mut fsm, now, seq(1), &record); + assert_eq!( + first, + vec![opened(stream_id), SessionEvent::Readable(stream_id)] + ); + + let second = receive_events(&mut fsm, now + Duration::from_millis(1), seq(2), &record); + assert!(second.is_empty()); + let mut events = Vec::new(); + assert_eq!( + read_stream_all_with_events(&mut fsm, stream_id, &mut events), + b"hello".to_vec() + ); + assert_eq!(events, vec![SessionEvent::Finished(stream_id)]); +} + +#[test] +fn out_of_order_remote_stream_first_observations_still_open_once_each() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let close3 = vec![SessionFrame::StreamClose(StreamClose { + stream_id: stream_id(3), + target: CloseTarget::Both, + code: REFUSED, + })]; + let close1 = vec![SessionFrame::StreamClose(StreamClose { + stream_id: stream_id(1), + target: CloseTarget::Both, + code: TIMEOUT, + })]; + + let first = receive_events(&mut fsm, now, seq(1), &close3); + assert_eq!( + first, + vec![ + SessionEvent::Closed(StreamClose { + stream_id: stream_id(3), + target: CloseTarget::Both, + code: REFUSED, + }), + SessionEvent::WritableClosed(StreamClose { + stream_id: stream_id(3), + target: CloseTarget::Both, + code: REFUSED, + }), + ] + ); + + let second = receive_events(&mut fsm, now + Duration::from_millis(1), seq(2), &close1); + assert_eq!( + second, + vec![ + SessionEvent::Closed(StreamClose { + stream_id: stream_id(1), + target: CloseTarget::Both, + code: TIMEOUT, + }), + SessionEvent::WritableClosed(StreamClose { + stream_id: stream_id(1), + target: CloseTarget::Both, + code: TIMEOUT, + }), + ] + ); + + let third = receive_events(&mut fsm, now + Duration::from_millis(2), seq(3), &close3); + assert!(third.is_empty()); +} + +#[test] +fn invalid_remote_stream_close_closes_session() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + + let invalid = vec![SessionFrame::StreamClose(StreamClose { + stream_id: stream_id(0), + target: CloseTarget::Both, + code: StreamCloseCode(9), + })]; + let events = receive_events(&mut fsm, now, seq(1), &invalid); + + assert_eq!( + events, + vec![SessionEvent::SessionClosed(ql_wire::SessionClose { + code: ql_wire::SessionCloseCode::PROTOCOL, + })] + ); +} + +#[test] +fn close_does_not_ack_rejected_record_seq() { + let now = Instant::now(); + let mut fsm = SessionFsm::new( + SessionConfig { + ack_delay: Duration::ZERO, + ..SessionConfig::default() + }, + now, + ); + + let invalid = vec![SessionFrame::StreamData(StreamData { + stream_id: stream_id(0), + offset: offset(0), + header: Some(header(1)), + fin: false, + bytes: b"bad".to_vec(), + })]; + let events = receive_events(&mut fsm, now, seq(7), &invalid); + assert_eq!( + events, + vec![SessionEvent::SessionClosed(ql_wire::SessionClose { + code: ql_wire::SessionCloseCode::PROTOCOL, + })] + ); + + let valid_after_close = vec![SessionFrame::Ping]; + let events = receive_events( + &mut fsm, + now + Duration::from_millis(1), + seq(8), + &valid_after_close, + ); + assert!(events.is_empty()); + + let (_seq, outbound) = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); + assert!(matches!(outbound.as_slice(), [SessionFrame::Close(_)])); +} + +#[test] +fn inbound_unpair_emits_final_unpair_frame() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + + let events = receive_events(&mut fsm, now, seq(1), &[SessionFrame::Unpair]); + assert_eq!(events, vec![SessionEvent::Unpaired]); + assert!(!fsm.is_closed()); + + let (_seq, outbound) = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); + assert!(matches!(outbound.as_slice(), [SessionFrame::Unpair])); + assert!(fsm.is_closed()); +} + +#[test] +fn terminating_session_ignores_inbound_frames() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + + let mut events = Vec::new(); + fsm.unpair(&mut |event| events.push(event)); + assert_eq!(events, vec![SessionEvent::Unpaired]); + + let ignored = receive_events( + &mut fsm, + now + Duration::from_millis(1), + seq(1), + &[SessionFrame::Ping], + ); + assert!(ignored.is_empty()); + + let (_seq, outbound) = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); + assert!(matches!(outbound.as_slice(), [SessionFrame::Unpair])); + assert!(fsm.is_closed()); +} + +#[test] +fn initial_peer_stream_receive_window_limits_first_send() { + let now = Instant::now(); + let mut fsm = SessionFsm::new( + SessionConfig { + initial_peer_stream_receive_window: 3, + ..SessionConfig::default() + }, + now, + ); + let stream_id = open_stream_id(&mut fsm); + + assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"hello"), 5); + let (_first_seq, first) = next_outbound(&mut fsm, now).unwrap(); + assert!(matches!( + first.as_slice(), + [SessionFrame::StreamData(frame)] if frame.stream_id == stream_id && frame.bytes.as_slice() == b"hel" + )); + + let events = receive_events( + &mut fsm, + now + Duration::from_millis(1), + seq(9), + &[SessionFrame::StreamWindow(ql_wire::StreamWindow { + stream_id, + maximum_offset: offset(5), + })], + ); + assert!(events.is_empty()); + + let (_second_seq, second) = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); + assert!(second.iter().any(|frame| { + matches!( + frame, + SessionFrame::StreamData(frame) + if frame.stream_id == stream_id + && frame.offset == offset(3) + && frame.bytes.as_slice() == b"lo" + ) + })); +} + +#[test] +fn sparse_out_of_order_ack_ranges_page_and_quiesce() { + let now = Instant::now(); + let sender_config = SessionConfig { + local_parity: StreamParity::Even, + record_max_size: SessionRecordBuilder::MIN_CAPACITY + 40, + ack_delay: Duration::from_millis(5), + retransmit_timeout: Duration::from_millis(25), + stream_send_buffer_size: 8 * 1024, + initial_peer_stream_receive_window: 8 * 1024, + ..SessionConfig::default() + }; + let receiver_config = SessionConfig { + local_parity: StreamParity::Odd, + record_max_size: SessionRecordBuilder::MIN_CAPACITY + 10, + ack_delay: Duration::from_millis(1), + retransmit_timeout: Duration::from_millis(25), + pending_ack_range_limit: 512, + initial_peer_stream_receive_window: 8 * 1024, + ..SessionConfig::default() + }; + let mut sender = SessionFsm::new(sender_config, now); + let mut receiver = SessionFsm::new(receiver_config, now); + + let stream_id = open_stream_id(&mut sender); + let payload = vec![b'x'; 2048]; + assert_eq!( + write_stream_bytes(&mut sender, stream_id, &payload), + payload.len() + ); + + let originals = drain_outbound(&mut sender, now, 4096); + assert!(originals.len() >= 64); + + for (seq, record) in originals + .iter() + .filter(|(seq, _)| seq.into_inner() % 2 == 1) + { + let _ = receive_events(&mut receiver, now, *seq, record); + } + + let first_ack_time = now + receiver_config.ack_delay; + let first_acks = drain_outbound(&mut receiver, first_ack_time, originals.len()); + assert!(first_acks.len() > 1); + assert!(first_acks + .iter() + .all(|(_, frames)| matches!(frames.as_slice(), [SessionFrame::Ack(_)]))); + + for (seq, record) in &first_acks { + let _ = receive_events(&mut sender, first_ack_time, *seq, record); + } + + let retransmit_time = now + sender_config.retransmit_timeout + Duration::from_millis(1); + let mut emit = |_| {}; + sender.on_timer(retransmit_time, &mut emit); + let retransmits = drain_outbound(&mut sender, retransmit_time, originals.len()); + assert!(!retransmits.is_empty()); + + for (seq, record) in &retransmits { + let _ = receive_events(&mut receiver, retransmit_time, *seq, record); + } + + let second_ack_time = retransmit_time + receiver_config.ack_delay; + let second_acks = drain_outbound(&mut receiver, second_ack_time, retransmits.len() + 16); + assert!(!second_acks.is_empty()); + assert!(second_acks + .iter() + .all(|(_, frames)| matches!(frames.as_slice(), [SessionFrame::Ack(_)]))); + + for (seq, record) in &second_acks { + let _ = receive_events(&mut sender, second_ack_time, *seq, record); + } + + let final_now = second_ack_time + sender_config.retransmit_timeout + Duration::from_millis(1); + let mut sender_emit = |_| {}; + sender.on_timer(final_now, &mut sender_emit); + let mut receiver_emit = |_| {}; + receiver.on_timer(final_now, &mut receiver_emit); + assert!(next_outbound(&mut sender, final_now).is_none()); + assert!(next_outbound(&mut receiver, final_now).is_none()); +} diff --git a/ql-fsm/src/session/tracked.rs b/ql-fsm/src/session/tracked.rs new file mode 100644 index 00000000..84317951 --- /dev/null +++ b/ql-fsm/src/session/tracked.rs @@ -0,0 +1,29 @@ +//! outbound record tracking state for ack and retransmit handling + +use std::time::Instant; + +use ql_wire::{RecordAck, RecordSeq, StreamClose, StreamId}; + +#[derive(Debug, Clone)] +pub struct TrackedRecord { + pub seq: RecordSeq, + pub frames: Vec, + pub ack: Option, + pub ping_included: bool, + pub window_updates: Vec<(StreamId, u64)>, + pub sent_at: Option, +} + +#[derive(Debug, Clone)] +pub enum TrackedFrame { + StreamData(TrackedStreamData), + StreamClose(StreamClose), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct TrackedStreamData { + pub stream_id: StreamId, + pub offset: u64, + pub len: usize, + pub fin: bool, +} diff --git a/ql-fsm/src/state.rs b/ql-fsm/src/state.rs new file mode 100644 index 00000000..8268bc16 --- /dev/null +++ b/ql-fsm/src/state.rs @@ -0,0 +1,139 @@ +use std::time::Instant; + +use ql_wire::{ + ConnectionId, EphemeralPublicKey, HandshakeId, HandshakeMeta, IkHandshake, KkHandshake, + PairingToken, PeerBundle, QlHandshakeRecord, SessionKey, TransportParams, XxHandshake, +}; + +use crate::{session::SessionFsm, NoSessionError, PeerStatus}; + +pub struct QlFsmState { + pub next_control_id: u32, + pub peer: Option, + pub armed_pairing_token: Option, + pub handshake: Option, + pub link: LinkState, + pub now: Instant, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionTransport { + pub tx_key: SessionKey, + pub rx_key: SessionKey, + pub tx_connection_id: ConnectionId, + pub rx_connection_id: ConnectionId, + pub remote_transport_params: TransportParams, +} + +impl SessionTransport { + pub fn from_finalized(finalized: ql_wire::FinalizedHandshake) -> (Self, PeerBundle) { + ( + Self { + tx_key: finalized.tx_key, + rx_key: finalized.rx_key, + tx_connection_id: finalized.tx_connection_id, + rx_connection_id: finalized.rx_connection_id, + remote_transport_params: finalized.remote_transport_params, + }, + finalized.remote_bundle, + ) + } +} + +#[allow(clippy::large_enum_variant)] +pub enum LinkState { + Idle, + IkInitiator(IkInitiatorState), + KkInitiator(KkInitiatorState), + XxInitiator(XxInitiatorState), + XxResponder(XxResponderState), + Connected(ConnectedState), +} + +pub struct ConnectedState { + pub transport: SessionTransport, + pub session: SessionFsm, +} + +#[derive(Debug, Clone)] +pub struct IkInitiatorState { + pub handshake: IkHandshake, + pub handshake_id: HandshakeId, + pub deadline: Instant, + pub initial_ephemeral: EphemeralPublicKey, +} + +#[derive(Debug, Clone)] +pub struct KkInitiatorState { + pub handshake: KkHandshake, + pub handshake_id: HandshakeId, + pub deadline: Instant, + pub initial_ephemeral: EphemeralPublicKey, +} + +#[derive(Debug, Clone)] +pub struct XxInitiatorState { + pub handshake: XxHandshake, + pub handshake_id: HandshakeId, + pub deadline: Instant, + pub initial_ephemeral: EphemeralPublicKey, +} + +#[derive(Debug, Clone)] +pub struct XxResponderState { + pub handshake: XxHandshake, + pub handshake_meta: HandshakeMeta, + pub deadline: Instant, +} + +impl LinkState { + pub fn take(&mut self) -> Self { + std::mem::replace(self, Self::Idle) + } + + pub fn status(&self) -> PeerStatus { + match self { + Self::Idle | Self::XxResponder(_) => PeerStatus::Disconnected, + Self::IkInitiator(_) | Self::KkInitiator(_) | Self::XxInitiator(_) => { + PeerStatus::Initiator + } + Self::Connected(_) => PeerStatus::Connected, + } + } + + #[inline] + pub fn connected(&self) -> Option<&ConnectedState> { + match self { + Self::Connected(state) => Some(state), + _ => None, + } + } + + #[inline] + pub fn connected_mut(&mut self) -> Option<&mut ConnectedState> { + match self { + Self::Connected(state) => Some(state), + _ => None, + } + } + + #[inline] + pub fn connected_mut_or_err(&mut self) -> Result<&mut ConnectedState, NoSessionError> { + self.connected_mut().ok_or(NoSessionError) + } + + pub fn handshake_deadline(&self) -> Option { + match self { + Self::Idle | Self::Connected(_) => None, + Self::IkInitiator(state) => Some(state.deadline), + Self::KkInitiator(state) => Some(state.deadline), + Self::XxInitiator(state) => Some(state.deadline), + Self::XxResponder(state) => Some(state.deadline), + } + } + + #[cfg(test)] + pub fn transport(&self) -> Option<&SessionTransport> { + self.connected().map(|state| &state.transport) + } +} diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs new file mode 100644 index 00000000..4de4f06e --- /dev/null +++ b/ql-fsm/src/tests/handshake.rs @@ -0,0 +1,388 @@ +use std::time::Duration; + +use ql_wire::QlHandshakeRecord; + +use super::*; +use crate::{state::LinkState, Event, NoPeerError, PeerStatus, ReceiveError}; + +#[test] +fn ik_connect_round_trip_establishes_transport() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + + harness.connect_ik(Side::A).unwrap(); + harness.pump(); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); +} + +#[test] +fn kk_connect_round_trip_establishes_transport() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + + harness.connect_kk(Side::A).unwrap(); + harness.pump(); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); +} + +#[test] +fn xx_connect_round_trip_establishes_transport_when_armed() { + let mut harness = Harness::paired(QlFsmConfig::default(), false, false); + let token = pairing_token(1); + + harness.b.fsm.arm_pairing(token); + harness.connect_xx(Side::A, token); + + let xx1 = harness.next_outbound(Side::A).unwrap(); + harness.deliver(Side::B, xx1); + let xx2 = harness.next_outbound(Side::B).unwrap(); + harness.deliver(Side::A, xx2); + let xx3 = harness.next_outbound(Side::A).unwrap(); + harness.deliver(Side::B, xx3); + + let xx4 = harness.next_outbound(Side::B).unwrap(); + harness.deliver(Side::A, xx4); + + assert_eq!(harness.a.fsm.peer(), Some(&harness.b.fsm.identity.bundle())); + assert_eq!(harness.b.fsm.peer(), Some(&harness.a.fsm.identity.bundle())); + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); +} + +#[test] +fn ik_connect_learns_remote_initial_stream_receive_window() { + let mut harness = Harness::paired_known_with_configs( + QlFsmConfig { + session_stream_receive_buffer_size: 9, + ..QlFsmConfig::default() + }, + QlFsmConfig { + session_stream_receive_buffer_size: 3, + ..QlFsmConfig::default() + }, + ); + + harness.connect_ik(Side::A).unwrap(); + harness.pump(); + + assert_eq!( + harness + .a + .fsm + .state + .link + .transport() + .unwrap() + .remote_transport_params + .initial_stream_receive_window, + 3 + ); + assert_eq!( + harness + .b + .fsm + .state + .link + .transport() + .unwrap() + .remote_transport_params + .initial_stream_receive_window, + 9 + ); +} + +#[test] +fn connect_methods_require_bound_peer() { + let time = Harness::paired_known(QlFsmConfig::default()).time(); + let identity = generate_identity(&SoftwareCrypto, "identity").unwrap(); + let mut fsm = QlFsm::new(QlFsmConfig::default(), identity, time); + let crypto = SoftwareCrypto; + + assert_eq!(fsm.connect_ik(time, &crypto), Err(NoPeerError)); + assert_eq!(fsm.connect_kk(time, &crypto), Err(NoPeerError)); + + fsm.connect_xx( + time, + PairingInvite { + qid: ql_wire::QID([2; ql_wire::QID::SIZE]), + token: pairing_token(2), + }, + &crypto, + ); +} + +#[test] +fn connect_ik_emits_initiator_status() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + + harness.connect_ik(Side::A).unwrap(); + + assert_eq!( + harness.drain_events(Side::A), + vec![Event::PeerStatusChanged(PeerStatus::Initiator)] + ); +} + +#[test] +fn inbound_xx1_rejects_when_not_in_pairing_mode() { + let mut harness = Harness::paired(QlFsmConfig::default(), false, false); + let token = pairing_token(3); + + harness.connect_xx(Side::A, token); + let xx1 = harness.next_outbound(Side::A).unwrap(); + let time = harness.time(); + let Node { fsm, crypto } = &mut harness.b; + let err = fsm.receive(time, xx1, crypto); + + assert_eq!(err, Err(ReceiveError::NotPairingMode)); + assert!(matches!(harness.b.fsm.state.link, LinkState::Idle)); + assert!(harness.drain_events(Side::B).is_empty()); + assert!(harness.next_outbound(Side::B).is_none()); +} + +#[test] +fn inbound_xx1_rejects_mismatched_pairing_id_with_expected_and_actual() { + let mut harness = Harness::paired(QlFsmConfig::default(), false, false); + let expected = pairing_token(4); + let actual = pairing_token(7); + + harness.b.fsm.arm_pairing(expected); + harness.connect_xx(Side::A, actual); + let xx1 = harness.next_outbound(Side::A).unwrap(); + + let time = harness.time(); + let Node { fsm, crypto } = &mut harness.b; + let err = fsm.receive(time, xx1, crypto); + + assert_eq!( + err, + Err(ReceiveError::InvalidPairingId { + expected: expected.id(&SoftwareCrypto), + actual: actual.id(&SoftwareCrypto), + }) + ); +} + +#[test] +fn disarm_pairing_rejects_inflight_inbound_xx_responder() { + let mut harness = Harness::paired(QlFsmConfig::default(), false, false); + let token = pairing_token(5); + + harness.b.fsm.arm_pairing(token); + harness.connect_xx(Side::A, token); + let xx1 = harness.next_outbound(Side::A).unwrap(); + harness.deliver(Side::B, xx1); + let xx2 = harness.next_outbound(Side::B).unwrap(); + harness.deliver(Side::A, xx2); + let xx3 = harness.next_outbound(Side::A).unwrap(); + harness.b.fsm.disarm_pairing(); + harness.deliver(Side::B, xx3); + + assert!(matches!(harness.b.fsm.state.link, LinkState::Idle)); + assert!(harness.next_outbound(Side::B).is_none()); +} + +#[test] +fn simultaneous_xx_connect_converges() { + let mut harness = Harness::paired(QlFsmConfig::default(), false, false); + let token = pairing_token(6); + + harness.a.fsm.arm_pairing(token); + harness.b.fsm.arm_pairing(token); + harness.connect_xx(Side::A, token); + harness.connect_xx(Side::B, token); + + for _ in 0..2 { + if let Some(record) = harness.next_outbound(Side::A) { + harness.deliver(Side::B, record); + } + if let Some(record) = harness.next_outbound(Side::B) { + harness.deliver(Side::A, record); + } + } + harness.pump(); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); +} + +#[test] +fn connect_ik_replaces_in_flight_attempt_and_ignores_stale_reply() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + + harness.connect_ik(Side::A).unwrap(); + harness.drain_events(Side::A); + let first = harness.next_outbound(Side::A).unwrap(); + let first_id = handshake_id(&first); + + harness.connect_ik(Side::A).unwrap(); + let second = harness.next_outbound(Side::A).unwrap(); + let second_id = handshake_id(&second); + + assert_ne!(first_id, second_id); + + harness.deliver(Side::B, first); + let stale_reply = harness.next_outbound(Side::B).unwrap(); + assert_eq!(handshake_id(&stale_reply), first_id); + + harness.deliver(Side::A, stale_reply); + assert!(matches!( + harness.a.fsm.state.link, + LinkState::IkInitiator(_) + )); + + harness.deliver(Side::B, second); + harness.pump(); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); +} + +#[test] +fn connect_kk_replaces_in_flight_attempt_and_ignores_stale_reply() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + + harness.connect_kk(Side::A).unwrap(); + let first = harness.next_outbound(Side::A).unwrap(); + let first_id = handshake_id(&first); + + harness.connect_kk(Side::A).unwrap(); + let second = harness.next_outbound(Side::A).unwrap(); + let second_id = handshake_id(&second); + + assert_ne!(first_id, second_id); + + harness.deliver(Side::B, first); + let stale_reply = harness.next_outbound(Side::B).unwrap(); + assert_eq!(handshake_id(&stale_reply), first_id); + + harness.deliver(Side::A, stale_reply); + assert!(matches!( + harness.a.fsm.state.link, + LinkState::KkInitiator(_) + )); + + harness.deliver(Side::B, second); + harness.pump(); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); +} + +#[test] +fn inbound_ik1_auto_binds_unbound_responder() { + let mut harness = Harness::paired(QlFsmConfig::default(), true, false); + + harness.connect_ik(Side::A).unwrap(); + harness.pump(); + + let expected_peer = harness.a.fsm.identity.bundle(); + assert_eq!(harness.b.fsm.peer(), Some(&expected_peer)); + assert_eq!( + harness.drain_events(Side::B), + vec![ + Event::NewPeer, + Event::PeerStatusChanged(PeerStatus::Connected), + ] + ); + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); +} + +#[test] +fn handshake_timeout_drops_single_ik_attempt_without_resend() { + let config = QlFsmConfig { + handshake_timeout: Duration::from_millis(60), + ..QlFsmConfig::default() + }; + let mut harness = Harness::paired_known(config); + + harness.connect_ik(Side::A).unwrap(); + harness.drain_events(Side::A); + let first = harness.next_outbound(Side::A).unwrap(); + let (_, first) = ql_wire::decode_record::(first.as_slice()).unwrap(); + assert!(matches!(first, ql_wire::QlHandshakeRecord::Ik1(_))); + assert!(harness.next_outbound(Side::A).is_none()); + + harness.advance(config.handshake_timeout); + harness.on_timer(Side::A); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Idle)); + assert_eq!( + harness.take_event(Side::A), + Some(Event::PeerStatusChanged(PeerStatus::Disconnected)) + ); + assert!(harness.next_outbound(Side::A).is_none()); +} + +#[test] +fn handshake_timeout_clears_queued_kk_output() { + let config = QlFsmConfig { + handshake_timeout: Duration::from_millis(60), + ..QlFsmConfig::default() + }; + let mut harness = Harness::paired_known(config); + + harness.connect_kk(Side::A).unwrap(); + + harness.advance(config.handshake_timeout); + harness.on_timer(Side::A); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Idle)); + assert!(harness.next_outbound(Side::A).is_none()); +} + +#[test] +fn bind_peer_clears_queued_handshake_output() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + + harness.connect_ik(Side::A).unwrap(); + harness.drain_events(Side::A); + harness + .a + .fsm + .bind_peer(generate_identity(&SoftwareCrypto, "peer").unwrap().bundle()); + + assert!(harness.drain_events(Side::A).is_empty()); + assert!(harness.next_outbound(Side::A).is_none()); +} + +#[test] +fn simultaneous_ik_connect_converges() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + + harness.connect_ik(Side::A).unwrap(); + harness.connect_ik(Side::B).unwrap(); + harness.pump(); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); +} + +#[test] +fn simultaneous_ik_and_kk_connect_prefers_ik() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + + harness.connect_ik(Side::A).unwrap(); + harness.connect_kk(Side::B).unwrap(); + harness.pump(); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); +} + +fn handshake_id(record: &[u8]) -> ql_wire::HandshakeId { + let (_, record) = ql_wire::decode_record(record).unwrap(); + match record { + ql_wire::QlHandshakeRecord::Ik1(message) => message.meta.handshake_id, + ql_wire::QlHandshakeRecord::Ik2(message) => message.meta.handshake_id, + ql_wire::QlHandshakeRecord::Kk1(message) => message.meta.handshake_id, + ql_wire::QlHandshakeRecord::Kk2(message) => message.meta.handshake_id, + ql_wire::QlHandshakeRecord::Xx1(message) => message.meta.handshake_id, + ql_wire::QlHandshakeRecord::Xx2(message) => message.meta.handshake_id, + ql_wire::QlHandshakeRecord::Xx3(message) => message.meta.handshake_id, + ql_wire::QlHandshakeRecord::Xx4(message) => message.meta.handshake_id, + } +} diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs new file mode 100644 index 00000000..c4d005e4 --- /dev/null +++ b/ql-fsm/src/tests/mod.rs @@ -0,0 +1,351 @@ +mod handshake; +mod proptest; +mod session; + +use std::time::{Duration, Instant}; + +use ql_wire::{ + self, generate_identity, test_identities, ConnectionId, PairingToken, QlCrypto, SessionKey, + SoftwareCrypto, TransportParams, QID, +}; + +use crate::{ + session::{SessionConfig, SessionFsm, StreamParity}, + state::{ConnectedState, LinkState, SessionTransport}, + Event, NoPeerError, OutboundWrite, PairingInvite, QlFsm, QlFsmConfig, WriteId, +}; + +type TestCrypto = SoftwareCrypto; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum Side { + A, + B, +} + +impl Side { + fn idx(self) -> usize { + match self { + Self::A => 0, + Self::B => 1, + } + } +} + +struct Node { + fsm: QlFsm, + crypto: TestCrypto, +} + +struct Harness { + now: Instant, + a: Node, + b: Node, +} + +struct DecodedSessionWrite { + record: Vec, + write_id: Option, + header: ql_wire::SessionHeader, + frames: Vec>>, +} + +impl Harness { + fn paired_known(config: QlFsmConfig) -> Self { + Self::paired_with_configs(config, config, true, true) + } + + fn paired(config: QlFsmConfig, know_a: bool, know_b: bool) -> Self { + Self::paired_with_configs(config, config, know_a, know_b) + } + + fn paired_known_with_configs(config_a: QlFsmConfig, config_b: QlFsmConfig) -> Self { + Self::paired_with_configs(config_a, config_b, true, true) + } + + fn paired_with_configs( + config_a: QlFsmConfig, + config_b: QlFsmConfig, + know_a: bool, + know_b: bool, + ) -> Self { + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + let now = Instant::now(); + + let mut harness = Self { + now, + a: Node { + fsm: QlFsm::new(config_a, identity_a.clone(), now), + crypto: SoftwareCrypto, + }, + b: Node { + fsm: QlFsm::new(config_b, identity_b.clone(), now), + crypto: SoftwareCrypto, + }, + }; + + if know_a { + harness.a.fsm.bind_peer(identity_b.bundle()); + } + if know_b { + harness.b.fsm.bind_peer(identity_a.bundle()); + } + + harness + } + + fn connected(config: QlFsmConfig) -> Self { + let mut harness = Self::paired_known(config); + let a_to_b_key = SessionKey::from_data([7; SessionKey::SIZE]); + let b_to_a_key = SessionKey::from_data([9; SessionKey::SIZE]); + let a_to_b_conn = ConnectionId::from_data([0xA1; ConnectionId::SIZE]); + let b_to_a_conn = ConnectionId::from_data([0xB2; ConnectionId::SIZE]); + + harness.a.fsm.state.link = LinkState::Connected(ConnectedState { + transport: SessionTransport { + tx_key: a_to_b_key.clone(), + rx_key: b_to_a_key.clone(), + tx_connection_id: a_to_b_conn, + rx_connection_id: b_to_a_conn, + remote_transport_params: TransportParams { + initial_stream_receive_window: harness + .b + .fsm + .config + .session_stream_receive_buffer_size, + }, + }, + session: SessionFsm::new(session_config(&harness, true), harness.now), + }); + harness.b.fsm.state.link = LinkState::Connected(ConnectedState { + transport: SessionTransport { + tx_key: b_to_a_key, + rx_key: a_to_b_key, + tx_connection_id: b_to_a_conn, + rx_connection_id: a_to_b_conn, + remote_transport_params: TransportParams { + initial_stream_receive_window: harness + .a + .fsm + .config + .session_stream_receive_buffer_size, + }, + }, + session: SessionFsm::new(session_config(&harness, false), harness.now), + }); + harness + } + + fn time(&self) -> Instant { + self.now + } + + fn advance(&mut self, duration: Duration) { + self.now += duration; + } + + fn node(&self, side: Side) -> &Node { + match side { + Side::A => &self.a, + Side::B => &self.b, + } + } + + fn node_mut(&mut self, side: Side) -> &mut Node { + match side { + Side::A => &mut self.a, + Side::B => &mut self.b, + } + } + + fn next_outbound(&mut self, side: Side) -> Option> { + let write = self.next_write(side)?; + if let Some(id) = write.write_id { + self.confirm_write(side, id); + } + Some(write.record) + } + + fn next_write(&mut self, side: Side) -> Option { + let time = self.time(); + let Node { fsm, crypto } = self.node_mut(side); + fsm.take_next_write(time, crypto) + } + + fn next_decoded_outbound(&mut self, side: Side) -> Option { + let write = self.next_write(side)?; + if let Some(id) = write.write_id { + self.confirm_write(side, id); + } + Some(self.decode_session_write(write, side)) + } + + fn next_decoded_write(&mut self, side: Side) -> Option { + let write = self.next_write(side)?; + Some(self.decode_session_write(write, side)) + } + + fn connect_ik(&mut self, side: Side) -> Result<(), NoPeerError> { + let time = self.time(); + let Node { fsm, crypto } = self.node_mut(side); + fsm.connect_ik(time, crypto) + } + + fn connect_kk(&mut self, side: Side) -> Result<(), NoPeerError> { + let time = self.time(); + let Node { fsm, crypto } = self.node_mut(side); + fsm.connect_kk(time, crypto) + } + + fn connect_xx(&mut self, side: Side, token: PairingToken) { + let time = self.time(); + let remote_qid = self.remote_qid(side); + let Node { fsm, crypto } = self.node_mut(side); + fsm.connect_xx( + time, + PairingInvite { + qid: remote_qid, + token, + }, + crypto, + ); + } + + fn remote_qid(&self, side: Side) -> QID { + match side { + Side::A => self.b.fsm.identity.qid, + Side::B => self.a.fsm.identity.qid, + } + } + + fn deliver(&mut self, side: Side, record: Vec) { + let time = self.time(); + let Node { fsm, crypto } = self.node_mut(side); + fsm.receive(time, record, crypto).unwrap(); + } + + fn confirm_write(&mut self, side: Side, write_id: WriteId) { + let time = self.time(); + self.node_mut(side).fsm.complete_write(time, write_id, true); + } + + fn reject_write(&mut self, side: Side, write_id: WriteId) { + let time = self.time(); + self.node_mut(side) + .fsm + .complete_write(time, write_id, false); + } + + fn decode_session_write(&self, write: OutboundWrite, side: Side) -> DecodedSessionWrite { + let peer = self.node(match side { + Side::A => Side::B, + Side::B => Side::A, + }); + let crypto = &peer.crypto; + let session_key = &peer.fsm.state.link.transport().unwrap().rx_key; + let (header, frames) = decrypt_record(crypto, &write.record, session_key); + DecodedSessionWrite { + record: write.record, + write_id: write.write_id, + header, + frames, + } + } + + fn on_timer(&mut self, side: Side) { + let time = self.time(); + self.node_mut(side).fsm.on_timer(time); + } + + fn take_event(&mut self, side: Side) -> Option { + self.node_mut(side).fsm.poll_event() + } + + fn drain_events(&mut self, side: Side) -> Vec { + let mut events = Vec::new(); + while let Some(event) = self.take_event(side) { + events.push(event); + } + events + } + + fn pump(&mut self) { + for _ in 0..128 { + let mut progressed = false; + + while let Some(record) = self.next_outbound(Side::A) { + progressed = true; + self.deliver(Side::B, record); + } + + while let Some(record) = self.next_outbound(Side::B) { + progressed = true; + self.deliver(Side::A, record); + } + + if !progressed { + return; + } + } + + panic!("pump did not quiesce"); + } +} + +fn pairing_token(byte: u8) -> PairingToken { + PairingToken([byte; PairingToken::SIZE]) +} + +fn session_config(harness: &Harness, a: bool) -> SessionConfig { + let (local, peer, config) = if a { + ( + harness.a.fsm.identity.qid, + harness.a.fsm.state.peer.as_ref().unwrap().qid, + harness.a.fsm.config, + ) + } else { + ( + harness.b.fsm.identity.qid, + harness.b.fsm.state.peer.as_ref().unwrap().qid, + harness.b.fsm.config, + ) + }; + + SessionConfig { + local_parity: StreamParity::for_local(local, peer), + record_max_size: config.session_record_max_size, + ack_delay: config.session_record_ack_delay, + retransmit_timeout: config.session_record_retransmit_timeout, + keepalive_interval: config.session_keepalive_interval, + peer_timeout: config.session_peer_timeout, + stream_send_buffer_size: config.session_stream_send_buffer_size, + stream_receive_buffer_size: config.session_stream_receive_buffer_size, + accepted_record_window: config.session_accepted_record_window, + pending_ack_range_limit: config.session_pending_ack_range_limit, + initial_peer_stream_receive_window: if a { + harness.b.fsm.config.session_stream_receive_buffer_size + } else { + harness.a.fsm.config.session_stream_receive_buffer_size + }, + } +} + +fn decrypt_record( + crypto: &impl QlCrypto, + record: &[u8], + session_key: &SessionKey, +) -> (ql_wire::SessionHeader, Vec>>) { + let (_header, record) = + ql_wire::decode_record::, _>(record).unwrap(); + let plaintext = ql_wire::decrypt_record( + crypto, + &record.header, + record.payload.into_owned(), + session_key, + ) + .unwrap(); + ( + record.header, + ql_wire::decode_session_frames(&plaintext).unwrap(), + ) +} diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs new file mode 100644 index 00000000..bc97ca77 --- /dev/null +++ b/ql-fsm/src/tests/proptest.rs @@ -0,0 +1,1001 @@ +use std::{ + collections::{BTreeMap, BTreeSet}, + time::Duration, +}; + +extern crate proptest as proptest_crate; + +use bytes::Bytes; +use proptest_crate::{collection::vec, prelude::*, test_runner::TestCaseResult}; +use ql_wire::{CloseTarget, StreamCloseCode, StreamId, WireError}; + +use super::*; + +fn test_route_id() -> ql_wire::RouteId { + ql_wire::RouteId::from_u32(1) +} +use crate::{state::LinkState, Event, PeerStatus, ReceiveError, WriteId}; + +const SLOT_COUNT: usize = 4; + +#[derive(Clone, Debug)] +enum Action { + ConnectIk(Side), + ConnectKk(Side), + AdvanceMs(u8), + OnTimer(Side), + OnTimerBoth, + Pump, + TakeNext(Side), + ConfirmTaken { + side: Side, + index: usize, + }, + RejectTaken { + side: Side, + index: usize, + }, + CaptureNext(Side), + DeliverNext(Side), + DropNext(Side), + DeliverQueued { + side: Side, + index: usize, + }, + DuplicateQueued { + side: Side, + index: usize, + }, + DropQueued { + side: Side, + index: usize, + }, + OpenStream { + side: Side, + slot: usize, + }, + Write { + side: Side, + slot: usize, + bytes: Vec, + }, + Finish { + side: Side, + slot: usize, + }, + Close { + side: Side, + slot: usize, + }, +} + +impl Action { + fn confirm_taken(side: Side, index: usize) -> Self { + Self::ConfirmTaken { side, index } + } + + fn reject_taken(side: Side, index: usize) -> Self { + Self::RejectTaken { side, index } + } + + fn deliver_queued(side: Side, index: usize) -> Self { + Self::DeliverQueued { side, index } + } + + fn duplicate_queued(side: Side, index: usize) -> Self { + Self::DuplicateQueued { side, index } + } + + fn drop_queued(side: Side, index: usize) -> Self { + Self::DropQueued { side, index } + } + + fn open_stream(side: Side, slot: usize) -> Self { + Self::OpenStream { side, slot } + } + + fn write(side: Side, slot: usize, bytes: Vec) -> Self { + Self::Write { side, slot, bytes } + } + + fn finish(side: Side, slot: usize) -> Self { + Self::Finish { side, slot } + } + + fn close(side: Side, slot: usize) -> Self { + Self::Close { side, slot } + } +} + +#[derive(Clone, Debug)] +struct TakenWrite { + record: Vec, + write_id: Option, +} + +#[derive(Default)] +struct SideEventState { + opened: BTreeSet, + finished: BTreeSet, + outbound_finished: BTreeSet, + writable_closed: BTreeSet, + closed: BTreeSet, + peer_statuses: Vec, + last_peer_status: Option, + session_epoch: usize, + session_closed_epoch: Option, +} + +impl SideEventState { + fn note_peer_status(&mut self, status: PeerStatus) { + if status == PeerStatus::Connected && self.last_peer_status != Some(PeerStatus::Connected) { + self.session_epoch = self.session_epoch.saturating_add(1); + } + self.peer_statuses.push(status); + self.last_peer_status = Some(status); + } +} + +struct Runner { + harness: Harness, + slots: [[Option; SLOT_COUNT]; 2], + taken: [Vec; 2], + pending: [Vec>; 2], + receive_errors: Vec<(Side, ReceiveError)>, + events: [SideEventState; 2], + known_streams: BTreeSet, + expected: [BTreeMap>; 2], + received: [BTreeMap>; 2], + finished_by: [BTreeSet; 2], + closed_by: [BTreeSet; 2], +} + +impl Runner { + fn handshake() -> Self { + let config = QlFsmConfig { + handshake_timeout: Duration::from_millis(60), + session_record_ack_delay: Duration::from_millis(5), + session_record_retransmit_timeout: Duration::from_millis(15), + session_peer_timeout: Duration::from_millis(80), + ..QlFsmConfig::default() + }; + + Self { + harness: Harness::paired_known(config), + slots: [[None; SLOT_COUNT]; 2], + taken: [Vec::new(), Vec::new()], + pending: [Vec::new(), Vec::new()], + receive_errors: Vec::new(), + events: [SideEventState::default(), SideEventState::default()], + known_streams: BTreeSet::new(), + expected: [BTreeMap::new(), BTreeMap::new()], + received: [BTreeMap::new(), BTreeMap::new()], + finished_by: [BTreeSet::new(), BTreeSet::new()], + closed_by: [BTreeSet::new(), BTreeSet::new()], + } + } + + fn connected() -> Self { + let config = QlFsmConfig { + session_record_ack_delay: Duration::from_millis(5), + session_record_retransmit_timeout: Duration::from_millis(15), + session_peer_timeout: Duration::from_secs(5), + ..QlFsmConfig::default() + }; + Self::connected_with_config(config) + } + + fn connected_with_config(config: QlFsmConfig) -> Self { + let connected_events = || SideEventState { + last_peer_status: Some(PeerStatus::Connected), + session_epoch: 1, + ..SideEventState::default() + }; + + Self { + harness: Harness::connected(config), + slots: [[None; SLOT_COUNT]; 2], + taken: [Vec::new(), Vec::new()], + pending: [Vec::new(), Vec::new()], + receive_errors: Vec::new(), + events: [connected_events(), connected_events()], + known_streams: BTreeSet::new(), + expected: [BTreeMap::new(), BTreeMap::new()], + received: [BTreeMap::new(), BTreeMap::new()], + finished_by: [BTreeSet::new(), BTreeSet::new()], + closed_by: [BTreeSet::new(), BTreeSet::new()], + } + } + + fn run(&mut self, actions: &[Action]) -> TestCaseResult { + for action in actions { + self.apply(action); + self.observe_and_assert()?; + } + + self.cleanup()?; + self.observe_and_assert()?; + self.assert_terminal_semantics()?; + self.assert_quiesced() + } + + #[allow(clippy::cognitive_complexity, clippy::too_many_lines)] + fn apply(&mut self, action: &Action) { + match action { + Action::ConnectIk(side) => { + let _ = self.harness.connect_ik(*side); + } + Action::ConnectKk(side) => { + let _ = self.harness.connect_kk(*side); + } + Action::AdvanceMs(ms) => { + self.harness + .advance(Duration::from_millis(u64::from(*ms) + 1)); + } + Action::OnTimer(side) => self.harness.on_timer(*side), + Action::OnTimerBoth => { + self.harness.on_timer(Side::A); + self.harness.on_timer(Side::B); + } + Action::Pump => self.capture_all_outbound(), + Action::TakeNext(side) => { + if let Some(write) = take_unconfirmed_outbound(&mut self.harness, *side) { + self.taken[side.idx()].push(write); + } + } + Action::ConfirmTaken { side, index } => { + if let Some(write) = take_taken(&mut self.taken[side.idx()], *index) { + confirm_taken(&mut self.harness, *side, &write); + self.pending[side.idx()].push(write.record); + } + } + Action::RejectTaken { side, index } => { + if let Some(write) = take_taken(&mut self.taken[side.idx()], *index) { + reject_taken(&mut self.harness, *side, &write); + } + } + Action::CaptureNext(side) => { + if let Some(record) = take_confirmed_outbound(&mut self.harness, *side) { + self.pending[side.idx()].push(record); + } + } + Action::DeliverNext(side) => { + if let Some(record) = take_confirmed_outbound(&mut self.harness, *side) { + self.deliver_to(opposite(*side), record); + } + } + Action::DropNext(side) => { + let _ = take_confirmed_outbound(&mut self.harness, *side); + } + Action::DeliverQueued { side, index } => { + if let Some(record) = take_pending(&mut self.pending[side.idx()], *index) { + self.deliver_to(opposite(*side), record); + } + } + Action::DuplicateQueued { side, index } => { + if let Some(record) = peek_pending(&self.pending[side.idx()], *index) { + self.deliver_to(opposite(*side), record); + } + } + Action::DropQueued { side, index } => { + let _ = take_pending(&mut self.pending[side.idx()], *index); + } + Action::OpenStream { side, slot } => { + let stream_id = self + .harness + .node_mut(*side) + .fsm + .open_stream(test_route_id()) + .ok() + .map(|stream| stream.stream_id()); + if let Some(stream_id) = stream_id { + self.slots[side.idx()][*slot] = Some(stream_id); + self.known_streams.insert(stream_id); + } + } + Action::Write { side, slot, bytes } => { + if let Some(stream_id) = self.slots[side.idx()][*slot] { + let mut chunk = Bytes::copy_from_slice(bytes); + let accepted = self.harness.node_mut(*side).fsm.stream(stream_id).map_or( + 0, + |mut stream| { + stream + .writer() + .map_or(0, |mut writer| writer.write(&mut chunk)) + }, + ); + if accepted != 0 { + self.expected[opposite(*side).idx()] + .entry(stream_id) + .or_default() + .extend_from_slice(&bytes[..accepted]); + } + } + } + Action::Finish { side, slot } => { + if let Some(stream_id) = self.slots[side.idx()][*slot] { + let finished = self + .harness + .node_mut(*side) + .fsm + .stream(stream_id) + .is_ok_and(|mut stream| { + stream.writer().is_some_and(|writer| { + writer.finish(); + true + }) + }); + if finished { + self.finished_by[side.idx()].insert(stream_id); + } + } + } + Action::Close { side, slot } => { + if let Some(stream_id) = self.slots[side.idx()][*slot] { + let closed = self + .harness + .node_mut(*side) + .fsm + .stream(stream_id) + .is_ok_and(|mut stream| { + stream.close(CloseTarget::Both, StreamCloseCode::CANCELLED); + true + }); + if closed { + self.closed_by[side.idx()].insert(stream_id); + self.slots[side.idx()][*slot] = None; + } + } + } + } + } + + fn observe_and_assert(&mut self) -> TestCaseResult { + self.drain_reads(Side::A); + self.drain_reads(Side::B); + let events_a = self.harness.drain_events(Side::A); + let events_b = self.harness.drain_events(Side::B); + self.process_events(Side::A, events_a)?; + self.process_events(Side::B, events_b)?; + self.assert_prefix_invariants()?; + self.assert_legal_link_state()?; + self.assert_receive_errors() + } + + fn cleanup(&mut self) -> TestCaseResult { + let tick = self + .harness + .a + .fsm + .config + .session_record_retransmit_timeout + .max(self.harness.a.fsm.config.session_record_ack_delay) + + Duration::from_millis(1); + + self.reject_all_taken(); + + for _ in 0..12 { + self.capture_all_outbound(); + self.flush_pending_in_order(); + self.capture_all_outbound(); + self.flush_pending_in_order(); + self.observe_and_assert()?; + self.harness.advance(tick); + self.harness.on_timer(Side::A); + self.harness.on_timer(Side::B); + self.capture_all_outbound(); + self.flush_pending_in_order(); + self.observe_and_assert()?; + self.reject_all_taken(); + } + + Ok(()) + } + + fn drain_reads(&mut self, side: Side) { + for stream_id in self.known_streams.clone() { + let appended = drain_stream(&mut self.harness.node_mut(side).fsm, stream_id); + if !appended.is_empty() { + self.received[side.idx()] + .entry(stream_id) + .or_default() + .extend_from_slice(&appended); + } + } + } + + fn process_events(&mut self, side: Side, events: Vec) -> TestCaseResult { + for event in events { + match event { + Event::NewPeer => {} + Event::PeerStatusChanged(status) => { + if status == PeerStatus::Unpaired { + let state = &mut self.events[side.idx()]; + prop_assert!( + state.session_epoch > 0, + "side {side:?} emitted Unpaired without a connected session" + ); + prop_assert!( + state.session_closed_epoch != Some(state.session_epoch), + "side {side:?} emitted duplicate terminal event in session epoch {}", + state.session_epoch + ); + state.session_closed_epoch = Some(state.session_epoch); + } + self.events[side.idx()].note_peer_status(status); + } + Event::Opened { stream_id, .. } => { + prop_assert!( + self.known_streams.contains(&stream_id), + "side {side:?} emitted Opened for unknown stream {stream_id:?}" + ); + prop_assert!( + self.events[side.idx()].opened.insert(stream_id), + "side {side:?} emitted duplicate Opened for {stream_id:?}" + ); + } + Event::Readable(stream_id) | Event::Writable(stream_id) => { + prop_assert!( + self.known_streams.contains(&stream_id), + "side {side:?} emitted readiness for unknown stream {stream_id:?}" + ); + } + Event::Finished(stream_id) => { + prop_assert!( + self.known_streams.contains(&stream_id), + "side {side:?} emitted Finished for unknown stream {stream_id:?}" + ); + prop_assert!( + self.events[side.idx()].finished.insert(stream_id), + "side {side:?} emitted duplicate Finished for {stream_id:?}" + ); + prop_assert!( + !self.events[side.idx()].closed.contains(&stream_id), + "side {side:?} emitted Finished after Closed for {stream_id:?}" + ); + } + Event::OutboundFinished(stream_id) => { + prop_assert!( + self.known_streams.contains(&stream_id), + "side {side:?} emitted OutboundFinished for unknown stream {stream_id:?}" + ); + prop_assert!( + self.events[side.idx()].outbound_finished.insert(stream_id), + "side {side:?} emitted duplicate OutboundFinished for {stream_id:?}" + ); + } + Event::Closed(frame) => { + prop_assert!( + self.known_streams.contains(&frame.stream_id), + "side {side:?} emitted Closed for unknown stream {:?}", + frame.stream_id + ); + prop_assert!( + self.events[side.idx()].closed.insert(frame.stream_id), + "side {side:?} emitted duplicate Closed for {:?}", + frame.stream_id + ); + } + Event::WritableClosed(frame) => { + let stream_id = frame.stream_id; + prop_assert!( + self.known_streams.contains(&stream_id), + "side {side:?} emitted WritableClosed for unknown stream {stream_id:?}" + ); + prop_assert!( + self.events[side.idx()].writable_closed.insert(stream_id), + "side {side:?} emitted duplicate WritableClosed for {stream_id:?}" + ); + } + Event::SessionClosed(_) => { + let state = &mut self.events[side.idx()]; + prop_assert!( + state.session_epoch > 0, + "side {side:?} emitted SessionClosed without a connected session" + ); + prop_assert!( + state.session_closed_epoch != Some(state.session_epoch), + "side {side:?} emitted duplicate SessionClosed in session epoch {}", + state.session_epoch + ); + state.session_closed_epoch = Some(state.session_epoch); + } + } + } + + Ok(()) + } + + fn assert_prefix_invariants(&self) -> TestCaseResult { + for side in [Side::A, Side::B] { + for (stream_id, received) in &self.received[side.idx()] { + let expected = self.expected[side.idx()] + .get(stream_id) + .map_or(&[][..], Vec::as_slice); + prop_assert!( + expected.starts_with(received), + "side {side:?} observed non-prefix bytes on {stream_id:?}: received={received:?} expected={expected:?}" + ); + } + } + + Ok(()) + } + + fn assert_legal_link_state(&self) -> TestCaseResult { + let a_connected = matches!(self.harness.a.fsm.state.link, LinkState::Connected(_)); + let b_connected = matches!(self.harness.b.fsm.state.link, LinkState::Connected(_)); + + prop_assert!( + !a_connected || self.harness.a.fsm.peer().is_some(), + "side A reached Connected without a bound peer" + ); + prop_assert!( + !b_connected || self.harness.b.fsm.peer().is_some(), + "side B reached Connected without a bound peer" + ); + + Ok(()) + } + + fn assert_receive_errors(&self) -> TestCaseResult { + for (side, error) in &self.receive_errors { + prop_assert!( + matches!( + error, + ReceiveError::NoSession + | ReceiveError::NoPeer + | ReceiveError::InvalidRemoteBundle + | ReceiveError::InvalidSessionPayload(WireError::InvalidPayload) + | ReceiveError::InvalidSessionPayload(WireError::DecryptFailed) + | ReceiveError::InvalidIkHandshake(WireError::InvalidPayload) + | ReceiveError::InvalidIkHandshake(WireError::InvalidState) + | ReceiveError::InvalidKkHandshake(WireError::InvalidPayload) + | ReceiveError::InvalidKkHandshake(WireError::InvalidState) + | ReceiveError::InvalidXxHandshake(WireError::InvalidPayload) + | ReceiveError::InvalidXxHandshake(WireError::InvalidState) + | ReceiveError::InvalidXxHandshake(WireError::DecryptFailed) + ), + "unexpected receive error on side {side:?}: {error:?}" + ); + } + + Ok(()) + } + + fn assert_terminal_semantics(&self) -> TestCaseResult { + let a_connected = matches!(self.harness.a.fsm.state.link, LinkState::Connected(_)); + let b_connected = matches!(self.harness.b.fsm.state.link, LinkState::Connected(_)); + let connected = [a_connected, b_connected]; + + for side in [Side::A, Side::B] { + for stream_id in &self.events[side.idx()].finished { + if self.inbound_aborted(side, *stream_id) { + continue; + } + let expected = self.expected[side.idx()] + .get(stream_id) + .map_or(&[][..], Vec::as_slice); + let received = self.received[side.idx()] + .get(stream_id) + .map_or(&[][..], Vec::as_slice); + prop_assert_eq!( + received, + expected, + "side {:?} finished {:?} without receiving all expected bytes", + side, + stream_id + ); + } + + for stream_id in &self.finished_by[side.idx()] { + prop_assert!( + self.events[opposite(side).idx()].finished.contains(stream_id) + || self.events[opposite(side).idx()].closed.contains(stream_id) + || !connected[opposite(side).idx()], + "side {side:?} finished {stream_id:?} but side {:?} saw neither Finished nor Closed", + opposite(side) + ); + } + + for stream_id in &self.closed_by[side.idx()] { + prop_assert!( + self.events[opposite(side).idx()].closed.contains(stream_id) + || !connected[opposite(side).idx()], + "side {side:?} closed {stream_id:?} but side {:?} saw no Closed event", + opposite(side) + ); + } + } + + Ok(()) + } + + fn assert_expected_delivered(&self, side: Side) -> TestCaseResult { + for (stream_id, expected) in &self.expected[side.idx()] { + let received = self.received[side.idx()] + .get(stream_id) + .map_or(&[][..], Vec::as_slice); + prop_assert_eq!( + received, + expected, + "side {:?} did not receive full payload for {:?}", + side, + stream_id + ); + } + + Ok(()) + } + + fn assert_no_stream_events(&self) -> TestCaseResult { + prop_assert!( + self.known_streams.is_empty() + && self.events.iter().all(|events| { + events.opened.is_empty() + && events.finished.is_empty() + && events.outbound_finished.is_empty() + && events.closed.is_empty() + && events.writable_closed.is_empty() + }), + "handshake-only property observed stream activity" + ); + Ok(()) + } + + fn assert_no_taken_writes(&self) -> TestCaseResult { + prop_assert!( + self.taken.iter().all(Vec::is_empty), + "cleanup left taken writes queued" + ); + Ok(()) + } + + fn assert_quiesced(&mut self) -> TestCaseResult { + self.reject_all_taken(); + + for _ in 0..8 { + self.capture_all_outbound(); + if self.pending.iter().all(Vec::is_empty) { + break; + } + self.flush_pending_in_order(); + self.observe_and_assert()?; + } + + self.capture_all_outbound(); + prop_assert!( + self.pending.iter().all(Vec::is_empty) && self.taken.iter().all(Vec::is_empty), + "cleanup did not quiesce: taken_a={} taken_b={} pending_a={} pending_b={}", + self.taken[Side::A.idx()].len(), + self.taken[Side::B.idx()].len(), + self.pending[Side::A.idx()].len(), + self.pending[Side::B.idx()].len() + ); + + Ok(()) + } + + fn capture_all_outbound(&mut self) { + for side in [Side::A, Side::B] { + while let Some(record) = take_confirmed_outbound(&mut self.harness, side) { + self.pending[side.idx()].push(record); + } + } + } + + fn flush_pending_in_order(&mut self) { + for side in [Side::A, Side::B] { + while let Some(record) = pop_front_pending(&mut self.pending[side.idx()]) { + self.deliver_to(opposite(side), record); + } + } + } + + fn reject_all_taken(&mut self) { + for side in [Side::A, Side::B] { + while let Some(write) = self.taken[side.idx()].pop() { + reject_taken(&mut self.harness, side, &write); + } + } + } + + fn deliver_to(&mut self, side: Side, record: Vec) { + if let Err(error) = deliver_to(&mut self.harness, side, record) { + self.receive_errors.push((side, error)); + } + } + + fn inbound_aborted(&self, side: Side, stream_id: StreamId) -> bool { + self.events[side.idx()].closed.contains(&stream_id) + || self.closed_by[side.idx()].contains(&stream_id) + } +} + +fn take_unconfirmed_outbound(harness: &mut Harness, side: Side) -> Option { + let write = harness.next_write(side)?; + Some(TakenWrite { + record: write.record, + write_id: write.write_id, + }) +} + +fn take_confirmed_outbound(harness: &mut Harness, side: Side) -> Option> { + let write = take_unconfirmed_outbound(harness, side)?; + confirm_taken(harness, side, &write); + Some(write.record) +} + +fn confirm_taken(harness: &mut Harness, side: Side, write: &TakenWrite) { + if let Some(write_id) = write.write_id { + harness.confirm_write(side, write_id); + } +} + +fn reject_taken(harness: &mut Harness, side: Side, write: &TakenWrite) { + if let Some(write_id) = write.write_id { + harness.reject_write(side, write_id); + } +} + +fn deliver_to(harness: &mut Harness, side: Side, record: Vec) -> Result<(), ReceiveError> { + let time = harness.time(); + let Node { fsm, crypto } = harness.node_mut(side); + fsm.receive(time, record, crypto) +} + +fn take_pending(pending: &mut Vec>, index: usize) -> Option> { + if pending.is_empty() { + return None; + } + + Some(pending.remove(index % pending.len())) +} + +fn peek_pending(pending: &[Vec], index: usize) -> Option> { + if pending.is_empty() { + return None; + } + + Some(pending[index % pending.len()].clone()) +} + +fn pop_front_pending(pending: &mut Vec>) -> Option> { + if pending.is_empty() { + None + } else { + Some(pending.remove(0)) + } +} + +fn take_taken(taken: &mut Vec, index: usize) -> Option { + if taken.is_empty() { + return None; + } + + Some(taken.remove(index % taken.len())) +} + +fn drain_stream(fsm: &mut QlFsm, stream_id: StreamId) -> Vec { + let mut out = Vec::new(); + let Ok(mut stream) = fsm.stream(stream_id) else { + return out; + }; + + loop { + let mut read = 0usize; + for chunk in stream.read() { + out.extend_from_slice(&chunk); + read += chunk.len(); + } + + if read == 0 { + break; + } + + stream.commit_read(read).unwrap(); + } + + out +} + +fn opposite(side: Side) -> Side { + match side { + Side::A => Side::B, + Side::B => Side::A, + } +} + +fn side_strategy() -> impl Strategy { + prop_oneof![Just(Side::A), Just(Side::B)] +} + +fn side_action(f: fn(Side) -> Action) -> impl Strategy { + side_strategy().prop_map(f) +} + +fn side_usize_action( + values: impl Strategy, + f: fn(Side, usize) -> Action, +) -> impl Strategy { + (side_strategy(), values).prop_map(move |(side, value)| f(side, value)) +} + +fn side_usize_vec_action( + values: impl Strategy, + bytes: impl Strategy>, + f: fn(Side, usize, Vec) -> Action, +) -> impl Strategy { + (side_strategy(), values, bytes).prop_map(move |(side, value, bytes)| f(side, value, bytes)) +} + +fn handshake_action_strategy() -> impl Strategy { + let queue_index = 0usize..6; + prop_oneof![ + side_action(Action::ConnectIk), + side_action(Action::ConnectKk), + (0u8..40).prop_map(Action::AdvanceMs), + side_action(Action::OnTimer), + Just(Action::OnTimerBoth), + Just(Action::Pump), + side_action(Action::TakeNext), + side_usize_action(queue_index.clone(), Action::confirm_taken), + side_usize_action(queue_index.clone(), Action::reject_taken), + side_action(Action::CaptureNext), + side_action(Action::DeliverNext), + side_action(Action::DropNext), + side_usize_action(queue_index.clone(), Action::deliver_queued), + side_usize_action(queue_index.clone(), Action::duplicate_queued), + side_usize_action(queue_index, Action::drop_queued), + ] +} + +fn connected_action_strategy() -> impl Strategy { + let bytes = vec(any::(), 0..24); + let slot = 0usize..SLOT_COUNT; + let queue_index = 0usize..6; + prop_oneof![ + (0u8..30).prop_map(Action::AdvanceMs), + side_action(Action::OnTimer), + Just(Action::OnTimerBoth), + Just(Action::Pump), + side_action(Action::TakeNext), + side_usize_action(queue_index.clone(), Action::confirm_taken), + side_usize_action(queue_index.clone(), Action::reject_taken), + side_action(Action::CaptureNext), + side_action(Action::DeliverNext), + side_action(Action::DropNext), + side_usize_action(queue_index.clone(), Action::deliver_queued), + side_usize_action(queue_index.clone(), Action::duplicate_queued), + side_usize_action(queue_index, Action::drop_queued), + side_usize_action(slot.clone(), Action::open_stream), + side_usize_vec_action(slot.clone(), bytes, Action::write), + side_usize_action(slot.clone(), Action::finish), + side_usize_action(slot, Action::close), + ] +} + +fn write_tracking_action_strategy() -> impl Strategy { + let bytes = vec(any::(), 0..16); + let slot = 0usize..SLOT_COUNT; + let queue_index = 0usize..6; + prop_oneof![ + side_usize_action(slot.clone(), Action::open_stream), + side_usize_vec_action(slot, bytes, Action::write), + side_action(Action::TakeNext), + side_usize_action(queue_index.clone(), Action::confirm_taken), + side_usize_action(queue_index.clone(), Action::reject_taken), + side_usize_action(queue_index.clone(), Action::deliver_queued), + side_usize_action(queue_index.clone(), Action::duplicate_queued), + side_usize_action(queue_index, Action::drop_queued), + Just(Action::Pump), + side_action(Action::OnTimer), + Just(Action::OnTimerBoth), + (0u8..20).prop_map(Action::AdvanceMs), + ] +} + +fn packet_loss_recovery_action_strategy() -> impl Strategy { + let queue_index = 0usize..16; + prop_oneof![ + (0u8..20).prop_map(Action::AdvanceMs), + side_action(Action::OnTimer), + Just(Action::OnTimerBoth), + Just(Action::Pump), + side_usize_action(queue_index.clone(), Action::deliver_queued), + side_usize_action(queue_index.clone(), Action::duplicate_queued), + side_usize_action(queue_index, Action::drop_queued), + ] +} + +fn terminal_action_strategy() -> impl Strategy { + let bytes = vec(any::(), 0..16); + let slot = 0usize..SLOT_COUNT; + let queue_index = 0usize..6; + prop_oneof![ + side_usize_action(slot.clone(), Action::open_stream), + side_usize_vec_action(slot.clone(), bytes, Action::write), + side_usize_action(slot.clone(), Action::finish), + side_usize_action(slot, Action::close), + side_action(Action::TakeNext), + side_usize_action(queue_index.clone(), Action::confirm_taken), + side_usize_action(queue_index.clone(), Action::reject_taken), + side_usize_action(queue_index.clone(), Action::deliver_queued), + side_usize_action(queue_index.clone(), Action::duplicate_queued), + side_usize_action(queue_index, Action::drop_queued), + Just(Action::Pump), + side_action(Action::OnTimer), + Just(Action::OnTimerBoth), + (0u8..20).prop_map(Action::AdvanceMs), + ] +} + +proptest_crate::proptest! { + #![proptest_config(ProptestConfig { + cases: 24, + max_shrink_iters: 10_000, + .. ProptestConfig::default() + })] + + #[test] + fn randomized_handshake_actions_quiesce(actions in vec(handshake_action_strategy(), 1..64)) { + let mut runner = Runner::handshake(); + runner.run(&actions)?; + runner.assert_no_stream_events()?; + } + + #[test] + fn randomized_stream_actions_preserve_integrity(actions in vec(connected_action_strategy(), 1..80)) { + let mut runner = Runner::connected(); + runner.run(&actions)?; + } + + #[test] + fn randomized_write_tracking_actions_quiesce(actions in vec(write_tracking_action_strategy(), 1..80)) { + let mut runner = Runner::connected(); + runner.run(&actions)?; + runner.assert_no_taken_writes()?; + } + + #[test] + fn randomized_session_packet_loss_recovers( + payload in vec(any::(), 512..2048), + actions in vec(packet_loss_recovery_action_strategy(), 1..96), + ) { + let config = QlFsmConfig { + session_record_ack_delay: Duration::from_millis(1), + session_record_retransmit_timeout: Duration::from_millis(10), + session_record_max_size: 96, + session_pending_ack_range_limit: 512, + ..QlFsmConfig::default() + }; + let mut runner = Runner::connected_with_config(config); + + runner.apply(&Action::open_stream(Side::A, 0)); + runner.observe_and_assert()?; + + runner.apply(&Action::write(Side::A, 0, payload)); + runner.observe_and_assert()?; + + runner.apply(&Action::finish(Side::A, 0)); + runner.observe_and_assert()?; + + for action in &actions { + runner.apply(action); + runner.observe_and_assert()?; + } + + runner.cleanup()?; + runner.observe_and_assert()?; + runner.assert_expected_delivered(Side::B)?; + runner.assert_terminal_semantics()?; + runner.assert_quiesced()?; + } + + #[test] + fn randomized_terminal_actions_preserve_terminal_semantics(actions in vec(terminal_action_strategy(), 1..80)) { + let mut runner = Runner::connected(); + runner.run(&actions)?; + runner.assert_terminal_semantics()?; + } +} diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs new file mode 100644 index 00000000..c55e51c1 --- /dev/null +++ b/ql-fsm/src/tests/session.rs @@ -0,0 +1,532 @@ +use std::time::Duration; + +use bytes::Bytes; +use ql_wire::{RouteId, SessionClose, StreamId, VarInt}; + +use super::*; +use crate::{state::LinkState, CommitReadError, Event, NoSessionError, PeerStatus, StreamError}; + +fn stream_id(value: u32) -> StreamId { + StreamId(VarInt::from_u32(value)) +} + +fn route_id(value: u32) -> RouteId { + RouteId::from_u32(value) +} + +fn opened(stream_id: StreamId) -> Event { + Event::Opened { + stream_id, + route_id: route_id(1), + } +} + +fn open_stream_id(fsm: &mut QlFsm) -> StreamId { + fsm.open_stream(route_id(1)).unwrap().stream_id() +} + +fn write_stream_bytes( + fsm: &mut QlFsm, + stream_id: StreamId, + bytes: &[u8], +) -> Result { + let mut bytes = Bytes::copy_from_slice(bytes); + let mut stream = fsm.stream(stream_id)?; + let Some(mut writer) = stream.writer() else { + return Err(StreamError::NotWritable); + }; + Ok(writer.write(&mut bytes)) +} + +fn read_stream_all(fsm: &mut QlFsm, stream_id: StreamId) -> Vec { + let mut out = Vec::new(); + let Ok(mut stream) = fsm.stream(stream_id) else { + return out; + }; + loop { + let mut read = 0; + for chunk in stream.read() { + out.extend_from_slice(&chunk); + read += chunk.len(); + } + if read == 0 { + break; + } + stream.commit_read(read).unwrap(); + } + out +} + +#[test] +fn connected_fsms_deliver_stream_data() { + let mut harness = Harness::connected(QlFsmConfig::default()); + + let stream_id = open_stream_id(&mut harness.a.fsm); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"hello").unwrap(), + 5 + ); + harness + .a + .fsm + .stream(stream_id) + .unwrap() + .writer() + .unwrap() + .finish(); + + harness.pump(); + + assert_eq!(harness.take_event(Side::B), Some(opened(stream_id))); + assert_eq!( + harness.take_event(Side::B), + Some(Event::Readable(stream_id)) + ); + assert_eq!( + read_stream_all(&mut harness.b.fsm, stream_id), + b"hello".to_vec() + ); + assert_eq!( + harness.take_event(Side::B), + Some(Event::Finished(stream_id)) + ); + harness.advance(QlFsmConfig::default().session_record_ack_delay); + harness.on_timer(Side::B); + harness.pump(); + assert_eq!( + harness.take_event(Side::A), + Some(Event::OutboundFinished(stream_id)) + ); +} + +#[test] +fn session_retransmit_uses_new_record_seq() { + let config = QlFsmConfig::default(); + let mut harness = Harness::connected(config); + + let stream_id = open_stream_id(&mut harness.a.fsm); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"retry").unwrap(), + 5 + ); + + let first = harness.next_decoded_outbound(Side::A).unwrap(); + + harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); + harness.on_timer(Side::A); + + let retried = harness.next_decoded_outbound(Side::A).unwrap(); + + assert_ne!(retried.header.seq, first.header.seq); + assert_eq!(retried.frames, first.frames); + + harness.deliver(Side::B, retried.record); + harness.advance(config.session_record_ack_delay); + harness.on_timer(Side::A); + harness.on_timer(Side::B); + harness.pump(); + + assert_eq!(harness.take_event(Side::B), Some(opened(stream_id))); + assert_eq!( + harness.take_event(Side::B), + Some(Event::Readable(stream_id)) + ); + assert_eq!( + read_stream_all(&mut harness.b.fsm, stream_id), + b"retry".to_vec() + ); + + harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); + harness.on_timer(Side::A); + assert!(harness.next_outbound(Side::A).is_none()); +} + +#[test] +fn simultaneous_opens_use_even_and_odd_stream_ids() { + let mut harness = Harness::connected(QlFsmConfig::default()); + + let stream_id_a = open_stream_id(&mut harness.a.fsm); + let stream_id_b = open_stream_id(&mut harness.b.fsm); + + assert_ne!(stream_id_a, stream_id_b); + assert!( + StreamParity::for_local(harness.a.fsm.identity.qid, harness.b.fsm.identity.qid) + .matches(stream_id_a) + ); + assert!( + StreamParity::for_local(harness.b.fsm.identity.qid, harness.a.fsm.identity.qid) + .matches(stream_id_b) + ); + + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id_a, b"from-a").unwrap(), + 6 + ); + assert_eq!( + write_stream_bytes(&mut harness.b.fsm, stream_id_b, b"from-b").unwrap(), + 6 + ); + + harness.pump(); + + assert_eq!(harness.take_event(Side::A), Some(opened(stream_id_b))); + assert_eq!( + harness.take_event(Side::A), + Some(Event::Readable(stream_id_b)) + ); + assert_eq!( + read_stream_all(&mut harness.a.fsm, stream_id_b), + b"from-b".to_vec() + ); + assert_eq!(harness.take_event(Side::B), Some(opened(stream_id_a))); + assert_eq!( + harness.take_event(Side::B), + Some(Event::Readable(stream_id_a)) + ); + assert_eq!( + read_stream_all(&mut harness.b.fsm, stream_id_a), + b"from-a".to_vec() + ); +} + +#[test] +fn disconnected_stream_operations_fail_with_no_session() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + let missing = stream_id(0); + + assert!(matches!( + harness.a.fsm.open_stream(route_id(1)), + Err(NoSessionError) + )); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, missing, b"queued"), + Err(StreamError::NoSession) + ); + assert_eq!( + harness + .a + .fsm + .stream(missing) + .map(|mut stream| stream.writer().unwrap().finish()), + Err(StreamError::NoSession) + ); + assert_eq!( + harness.a.fsm.stream(missing).map(|mut stream| { + stream.close( + ql_wire::CloseTarget::Both, + ql_wire::StreamCloseCode::CANCELLED, + ); + }), + Err(StreamError::NoSession) + ); + assert_eq!(harness.a.fsm.queue_ping(), Err(NoSessionError)); + assert!(matches!( + harness.a.fsm.stream(missing), + Err(StreamError::NoSession) + )); +} + +#[test] +fn disconnected_stream_read_accessors_return_none() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + let missing = stream_id(0); + + assert!(matches!( + harness.a.fsm.stream(missing), + Err(StreamError::NoSession) + )); +} + +#[test] +fn commit_read_rejects_lengths_past_readable_prefix() { + let mut harness = Harness::connected(QlFsmConfig::default()); + + let stream_id = open_stream_id(&mut harness.a.fsm); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"hi").unwrap(), + 2 + ); + harness.pump(); + + let mut stream = harness.b.fsm.stream(stream_id).unwrap(); + assert_eq!(stream.commit_read(3), Err(CommitReadError)); +} + +#[test] +fn returned_session_write_is_reissued_with_new_record_seq() { + let mut harness = Harness::connected(QlFsmConfig::default()); + + let stream_id = open_stream_id(&mut harness.a.fsm); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"retry").unwrap(), + 5 + ); + + let first = harness.next_decoded_write(Side::A).unwrap(); + let id = first.write_id.expect("expected session write"); + + harness.reject_write(Side::A, id); + + let reissued = harness.next_decoded_write(Side::A).unwrap(); + let reissued_id = reissued.write_id.expect("expected reissued write"); + + assert_ne!(reissued_id, id); + assert_ne!(reissued.header.seq, first.header.seq); + assert_eq!(reissued.frames, first.frames); + + harness.confirm_write(Side::A, reissued_id); + harness.deliver(Side::B, reissued.record); + harness.pump(); + + assert_eq!(harness.take_event(Side::B), Some(opened(stream_id))); + assert_eq!( + harness.take_event(Side::B), + Some(Event::Readable(stream_id)) + ); + assert_eq!( + read_stream_all(&mut harness.b.fsm, stream_id), + b"retry".to_vec() + ); +} + +#[test] +fn unconfirmed_session_write_does_not_start_retransmit_timer() { + let config = QlFsmConfig::default(); + let mut harness = Harness::connected(config); + + let stream_id = open_stream_id(&mut harness.a.fsm); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"retry").unwrap(), + 5 + ); + + let first = harness.next_decoded_write(Side::A).unwrap(); + let id = first.write_id.expect("expected session write"); + + harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); + harness.on_timer(Side::A); + assert!(harness.next_write(Side::A).is_none()); + + harness.confirm_write(Side::A, id); + harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); + harness.on_timer(Side::A); + + let retried = harness.next_decoded_write(Side::A).unwrap(); + + assert_ne!(retried.header.seq, first.header.seq); + assert_eq!(retried.frames, first.frames); +} + +#[test] +fn ack_frame_releases_stream_capacity_and_emits_writable() { + let config = QlFsmConfig { + session_stream_send_buffer_size: 4, + ..QlFsmConfig::default() + }; + let mut harness = Harness::connected(config); + + let stream_id = open_stream_id(&mut harness.a.fsm); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"abcd").unwrap(), + 4 + ); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"z").unwrap(), + 0 + ); + + let record = harness.next_outbound(Side::A).unwrap(); + harness.deliver(Side::B, record); + harness.advance(config.session_record_ack_delay); + harness.on_timer(Side::A); + harness.on_timer(Side::B); + harness.pump(); + + assert_eq!( + harness.take_event(Side::A), + Some(Event::Writable(stream_id)) + ); +} + +#[test] +fn close_session_disconnects_locally() { + let mut harness = Harness::connected(QlFsmConfig::default()); + + harness + .a + .fsm + .close_session(ql_wire::SessionCloseCode::CANCELLED); + + assert!(matches!( + harness.take_event(Side::A), + Some(Event::SessionClosed(SessionClose { + code: ql_wire::SessionCloseCode::CANCELLED, + })) + )); + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!( + harness.a.fsm.open_stream(route_id(1)), + Err(NoSessionError) + )); + assert_eq!(harness.a.fsm.queue_ping(), Err(NoSessionError)); + + let close = harness.next_decoded_outbound(Side::A).unwrap(); + assert!(matches!( + close.frames.as_slice(), + [ql_wire::SessionFrame::Close(_)] + )); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Idle)); + assert_eq!( + harness.take_event(Side::A), + Some(Event::PeerStatusChanged(PeerStatus::Disconnected)) + ); +} + +#[test] +fn unpair_clears_bound_peer_and_emits_unpair_frame() { + let mut harness = Harness::connected(QlFsmConfig::default()); + + harness.a.fsm.unpair(); + + assert_eq!( + harness.take_event(Side::A), + Some(Event::PeerStatusChanged(PeerStatus::Unpaired)) + ); + assert!(harness.a.fsm.peer().is_none()); + assert!(matches!( + harness.a.fsm.open_stream(route_id(1)), + Err(NoSessionError) + )); + assert_eq!(harness.a.fsm.queue_ping(), Err(NoSessionError)); + + let unpair = harness.next_decoded_outbound(Side::A).unwrap(); + assert!(matches!( + unpair.frames.as_slice(), + [ql_wire::SessionFrame::Unpair] + )); + assert!(matches!(harness.a.fsm.state.link, LinkState::Idle)); +} + +#[test] +fn inbound_unpair_clears_remote_peer_binding() { + let mut harness = Harness::connected(QlFsmConfig::default()); + + harness.a.fsm.unpair(); + let unpair = harness.next_outbound(Side::A).unwrap(); + harness.deliver(Side::B, unpair); + + assert_eq!( + harness.take_event(Side::B), + Some(Event::PeerStatusChanged(PeerStatus::Unpaired)) + ); + assert!(harness.b.fsm.peer().is_none()); + assert!(matches!( + harness.b.fsm.open_stream(route_id(1)), + Err(NoSessionError) + )); + assert!(matches!(harness.connect_ik(Side::B), Err(NoPeerError))); + + let reply_key = harness.b.fsm.state.link.transport().unwrap().tx_key.clone(); + let reply = harness.next_outbound(Side::B).unwrap(); + let (_header, frames) = decrypt_record(&harness.b.crypto, &reply, &reply_key); + assert!(matches!(frames.as_slice(), [ql_wire::SessionFrame::Unpair])); + assert!(matches!(harness.b.fsm.state.link, LinkState::Idle)); +} + +#[test] +fn local_unpair_without_session_emits_unpaired_immediately() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + + harness.a.fsm.unpair(); + + assert_eq!( + harness.take_event(Side::A), + Some(Event::PeerStatusChanged(PeerStatus::Unpaired)) + ); + assert!(harness.a.fsm.peer().is_none()); + assert_eq!(harness.take_event(Side::A), None); +} + +#[test] +fn session_records_contain_ack_frames_after_delivery() { + let config = QlFsmConfig::default(); + let mut harness = Harness::connected(config); + + let stream_id = open_stream_id(&mut harness.a.fsm); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"x").unwrap(), + 1 + ); + + let data = harness.next_outbound(Side::A).unwrap(); + harness.deliver(Side::B, data); + harness.advance(config.session_record_ack_delay); + harness.on_timer(Side::B); + + let ack = harness.next_decoded_outbound(Side::B).unwrap(); + assert!(matches!( + ack.frames.as_slice(), + [ql_wire::SessionFrame::Ack(_)] + )); +} + +#[test] +fn first_stream_data_uses_negotiated_initial_peer_credit() { + let mut harness = Harness::paired_known_with_configs( + QlFsmConfig { + session_stream_receive_buffer_size: 8, + ..QlFsmConfig::default() + }, + QlFsmConfig { + session_stream_receive_buffer_size: 3, + ..QlFsmConfig::default() + }, + ); + + harness.connect_ik(Side::A).unwrap(); + let ik1 = harness.next_outbound(Side::A).unwrap(); + harness.deliver(Side::B, ik1); + let ik2 = harness.next_outbound(Side::B).unwrap(); + harness.deliver(Side::A, ik2); + + let stream_id = open_stream_id(&mut harness.a.fsm); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"hello").unwrap(), + 5 + ); + + assert!(matches!( + harness.next_decoded_outbound(Side::A).unwrap().frames.as_slice(), + [ql_wire::SessionFrame::StreamData(frame)] if frame.stream_id == stream_id && frame.bytes.as_slice() == b"hel" + )); +} + +#[test] +fn session_timeout_emits_close_before_disconnect() { + let config = QlFsmConfig { + session_peer_timeout: Duration::from_millis(30), + ..QlFsmConfig::default() + }; + let mut harness = Harness::connected(config); + + harness.advance(config.session_peer_timeout); + harness.on_timer(Side::A); + + assert_eq!( + harness.drain_events(Side::A), + vec![Event::SessionClosed(SessionClose { + code: ql_wire::SessionCloseCode::TIMEOUT, + })] + ); + + let close = harness.next_decoded_outbound(Side::A).unwrap(); + assert!(matches!( + close.frames.as_slice(), + [ql_wire::SessionFrame::Close(_)] + )); + assert_eq!( + harness.take_event(Side::A), + Some(Event::PeerStatusChanged(PeerStatus::Disconnected)) + ); +} diff --git a/ql-rpc/Cargo.toml b/ql-rpc/Cargo.toml new file mode 100644 index 00000000..51a764dc --- /dev/null +++ b/ql-rpc/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "ql-rpc" +version = "0.1.0" +edition = "2021" +description = "Quantum Link RPC protocol traits and framing" +license = "Proprietary" + +[dependencies] +bytes = { version = "1" } +trait-variant = { version = "0.1" } diff --git a/ql-rpc/src/chunk_queue.rs b/ql-rpc/src/chunk_queue.rs new file mode 100644 index 00000000..33f62998 --- /dev/null +++ b/ql-rpc/src/chunk_queue.rs @@ -0,0 +1,252 @@ +use std::collections::VecDeque; + +use bytes::{Buf, Bytes}; + +use crate::{CodecError, Error}; + +const LENGTH_SIZE: usize = 8; + +#[derive(Debug, Default)] +pub struct ChunkQueue { + chunks: VecDeque, + remaining: usize, +} + +impl ChunkQueue { + pub fn push(&mut self, chunk: Bytes) { + if chunk.is_empty() { + return; + } + self.remaining += chunk.len(); + self.chunks.push_back(chunk); + } + + pub fn remaining(&self) -> usize { + self.remaining + } + + pub fn expect_empty(&self) -> Result<(), CodecError> { + if self.remaining > 0 { + Err(CodecError::Rpc(Error::TrailingBytes)) + } else { + Ok(()) + } + } + + pub fn pop_front(&mut self, max_len: usize) -> Option { + let front = self.chunks.front_mut()?; + let chunk = if max_len >= front.len() { + self.chunks.pop_front().expect("buffered chunk is present") + } else { + front.split_to(max_len) + }; + self.remaining -= chunk.len(); + Some(chunk) + } + + pub fn pop_front_chunk(&mut self) -> Option { + self.pop_front(usize::MAX) + } + + pub fn try_take_part(&mut self) -> Result>, Error> { + let Some(len) = self.peek_next_part_len()? else { + return Ok(None); + }; + self.advance(LENGTH_SIZE); + Ok(Some(DrainBuf::new(self, len))) + } + + pub fn try_take_tagged_part(&mut self) -> Result)>, Error> { + let mut bytes = self.peek(); + let Ok(kind) = bytes.try_get_u8() else { + return Ok(None); + }; + let Some(len) = read_next_part_len(&mut bytes)? else { + return Ok(None); + }; + + self.advance(1 + LENGTH_SIZE); + Ok(Some((kind, DrainBuf::new(self, len)))) + } + + pub fn try_take_tagged_part_header(&mut self) -> Result, Error> { + let mut bytes = self.peek(); + let Ok(kind) = bytes.try_get_u8() else { + return Ok(None); + }; + let Some(len) = read_part_len_header(&mut bytes)? else { + return Ok(None); + }; + + self.advance(1 + LENGTH_SIZE); + Ok(Some((kind, len))) + } + + pub fn try_take_body(&mut self, len: usize) -> Option> { + if self.remaining < len { + return None; + } + + Some(DrainBuf::new(self, len)) + } + + fn peek_next_part_len(&self) -> Result, Error> { + let mut bytes = self.peek(); + read_next_part_len(&mut bytes) + } + + fn peek(&self) -> ChunkQueuePeek<'_> { + ChunkQueuePeek { + chunks: &self.chunks, + chunk_index: 0, + chunk_offset: 0, + remaining: self.remaining, + } + } + + fn front_chunk(&self, limit: usize) -> &[u8] { + let Some(chunk) = self.chunks.front() else { + return &[]; + }; + &chunk[..chunk.len().min(limit)] + } + + fn advance_inner(&mut self, mut cnt: usize) { + assert!(cnt <= self.remaining, "advanced past buffered data"); + self.remaining -= cnt; + while cnt > 0 { + let front = self.chunks.front_mut().expect("buffered data present"); + let consumed = cnt.min(front.len()); + front.advance(consumed); + cnt -= consumed; + if front.is_empty() { + self.chunks.pop_front(); + } + } + } +} + +struct ChunkQueuePeek<'a> { + chunks: &'a VecDeque, + chunk_index: usize, + chunk_offset: usize, + remaining: usize, +} + +impl Buf for ChunkQueuePeek<'_> { + fn remaining(&self) -> usize { + self.remaining + } + + fn chunk(&self) -> &[u8] { + if self.remaining == 0 { + return &[]; + } + + let Some(chunk) = self.chunks.get(self.chunk_index) else { + return &[]; + }; + &chunk[self.chunk_offset..] + } + + fn advance(&mut self, mut cnt: usize) { + assert!(cnt <= self.remaining, "advanced past buffered data"); + self.remaining -= cnt; + + while cnt > 0 { + let chunk = self + .chunks + .get(self.chunk_index) + .expect("buffered data present"); + let available = chunk.len() - self.chunk_offset; + let step = cnt.min(available); + self.chunk_offset += step; + cnt -= step; + if self.chunk_offset == chunk.len() { + self.chunk_index += 1; + self.chunk_offset = 0; + } + } + } +} + +impl Buf for ChunkQueue { + fn remaining(&self) -> usize { + self.remaining + } + + fn chunk(&self) -> &[u8] { + self.front_chunk(self.remaining) + } + + fn advance(&mut self, cnt: usize) { + assert!(cnt <= self.remaining, "advanced past buffered data"); + self.advance_inner(cnt); + } +} + +pub struct DrainBuf<'a> { + bytes: &'a mut ChunkQueue, + remaining: usize, +} + +impl<'a> DrainBuf<'a> { + pub fn new(bytes: &'a mut ChunkQueue, len: usize) -> Self { + debug_assert!(bytes.remaining() >= len); + Self { + bytes, + remaining: len, + } + } + + pub fn expect_empty(&self) -> Result<(), CodecError> { + if self.remaining > 0 { + Err(CodecError::Rpc(Error::TrailingBytes)) + } else { + Ok(()) + } + } +} + +impl Buf for DrainBuf<'_> { + fn remaining(&self) -> usize { + self.remaining + } + + fn chunk(&self) -> &[u8] { + self.bytes.front_chunk(self.remaining) + } + + fn advance(&mut self, cnt: usize) { + assert!(cnt <= self.remaining(), "advanced past payload boundary"); + self.bytes.advance_inner(cnt); + self.remaining -= cnt; + } +} + +impl Drop for DrainBuf<'_> { + fn drop(&mut self) { + if self.remaining > 0 { + self.bytes.advance_inner(self.remaining); + self.remaining = 0; + } + } +} + +fn read_next_part_len(bytes: &mut B) -> Result, Error> { + let Some(len) = read_part_len_header(bytes)? else { + return Ok(None); + }; + if bytes.remaining() < len { + return Ok(None); + } + Ok(Some(len)) +} + +fn read_part_len_header(bytes: &mut B) -> Result, Error> { + let Ok(len) = bytes.try_get_u64_le() else { + return Ok(None); + }; + let len: usize = len.try_into().map_err(|_| Error::LengthOverflow)?; + Ok(Some(len)) +} diff --git a/ql-rpc/src/codec.rs b/ql-rpc/src/codec.rs new file mode 100644 index 00000000..51da527b --- /dev/null +++ b/ql-rpc/src/codec.rs @@ -0,0 +1,83 @@ +use std::{convert::Infallible, str::Utf8Error}; + +use bytes::{Buf, BufMut, Bytes}; + +pub use crate::chunk_queue::ChunkQueue; + +pub trait RpcCodec: Sized { + type Error; + + fn encode_value(&self, out: &mut B); + fn decode_value(bytes: &mut B) -> Result; +} + +impl RpcCodec for String { + type Error = Utf8Error; + + fn encode_value(&self, out: &mut B) { + out.put_slice(self.as_bytes()); + } + + fn decode_value(bytes: &mut B) -> Result { + let len = bytes.remaining(); + if bytes.chunk().len() == len { + let s = std::str::from_utf8(bytes.chunk())?.to_owned(); + bytes.advance(len); + Ok(s) + } else { + let mut buf = vec![0; len]; + bytes.copy_to_slice(&mut buf); + String::from_utf8(buf).map_err(|err| err.utf8_error()) + } + } +} + +impl RpcCodec for Vec { + type Error = Infallible; + + fn encode_value(&self, out: &mut B) { + out.put_slice(self.as_slice()); + } + + fn decode_value(bytes: &mut B) -> Result { + let len = bytes.remaining(); + let mut buf = vec![0; len]; + bytes.copy_to_slice(&mut buf); + Ok(buf) + } +} + +impl RpcCodec for Bytes { + type Error = Infallible; + + fn encode_value(&self, out: &mut B) { + out.put_slice(self.as_ref()); + } + + fn decode_value(bytes: &mut B) -> Result { + Ok(bytes.copy_to_bytes(bytes.remaining())) + } +} + +const LENGTH_SIZE: usize = 8; + +pub fn encode_value_part>(value: &T, out: &mut B) { + let payload_start = reserve_length(out); + value.encode_value(out); + backpatch_length(out, payload_start); +} + +/// reads one length-delimited rpc value from buffered byte chunks +pub fn reserve_length>(out: &mut B) -> usize { + let start = out.as_mut().len(); + out.put_bytes(0, LENGTH_SIZE); + start +} + +pub fn backpatch_length + ?Sized>(out: &mut B, start: usize) { + let out = out.as_mut(); + let payload_start = start + LENGTH_SIZE; + let payload_len = out.len() - payload_start; + let payload_len = u64::try_from(payload_len).expect("rpc payload exceeds u64 length framing"); + out[start..payload_start].copy_from_slice(&payload_len.to_le_bytes()); +} diff --git a/ql-rpc/src/error.rs b/ql-rpc/src/error.rs new file mode 100644 index 00000000..7404a22e --- /dev/null +++ b/ql-rpc/src/error.rs @@ -0,0 +1,112 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Error { + Truncated, + LengthOverflow, + UnexpectedFrameKind(u8), + MissingResponse, + TrailingBytes, +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Truncated => f.write_str("truncated rpc payload"), + Self::LengthOverflow => f.write_str("rpc payload length overflow"), + Self::UnexpectedFrameKind(kind) => write!(f, "unexpected rpc frame kind {kind}"), + Self::MissingResponse => f.write_str("missing terminal rpc response"), + Self::TrailingBytes => f.write_str("trailing rpc bytes"), + } + } +} + +impl std::error::Error for Error {} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum CodecError { + Rpc(Error), + Codec(E), +} + +impl std::error::Error for CodecError +where + E: std::error::Error + 'static, +{ + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + CodecError::Rpc(e) => Some(e), + CodecError::Codec(e) => Some(e), + } + } + + fn cause(&self) -> Option<&dyn std::error::Error> { + self.source() + } +} + +impl std::fmt::Display for CodecError +where + E: std::fmt::Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + CodecError::Rpc(e) => write!(f, "{e}"), + CodecError::Codec(e) => write!(f, "{e}"), + } + } +} + +impl From for CodecError { + fn from(error: Error) -> Self { + Self::Rpc(error) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum CallError { + Protocol(Error), + Codec(C), + Transport(T), +} + +impl std::fmt::Display for CallError +where + C: std::fmt::Display, + T: std::fmt::Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Protocol(error) => write!(f, "{error}"), + Self::Codec(error) => write!(f, "{error}"), + Self::Transport(error) => write!(f, "{error}"), + } + } +} + +impl std::error::Error for CallError +where + C: std::error::Error + 'static, + T: std::error::Error + 'static, +{ + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + CallError::Protocol(error) => Some(error), + CallError::Codec(error) => Some(error), + CallError::Transport(error) => Some(error), + } + } +} + +impl From for CallError { + fn from(error: Error) -> Self { + Self::Protocol(error) + } +} + +impl From> for CallError { + fn from(error: CodecError) -> Self { + match error { + CodecError::Rpc(error) => Self::Protocol(error), + CodecError::Codec(error) => Self::Codec(error), + } + } +} diff --git a/ql-rpc/src/framed_value.rs b/ql-rpc/src/framed_value.rs new file mode 100644 index 00000000..600357da --- /dev/null +++ b/ql-rpc/src/framed_value.rs @@ -0,0 +1,127 @@ +use std::marker::PhantomData; + +use bytes::Bytes; + +use crate::{chunk_queue::ChunkQueue, CodecError, RpcCodec}; + +/// reads one length-delimited rpc value from buffered byte chunks +pub struct FramedReader { + bytes: ChunkQueue, + marker: PhantomData T>, +} + +pub enum FramedReadStep { + NeedMore(FramedReader), + Value(T), +} + +pub enum FramedPrefixStep { + NeedMore(FramedReader), + Value { value: T, bytes: ChunkQueue }, +} + +impl Default for FramedReader { + fn default() -> Self { + Self { + bytes: ChunkQueue::default(), + marker: PhantomData, + } + } +} + +impl FramedReader { + pub fn push(mut self, chunk: Bytes) -> Self { + self.bytes.push(chunk); + self + } + + pub fn advance(self) -> Result, CodecError> { + match self.advance_prefix()? { + FramedPrefixStep::NeedMore(next) => Ok(FramedReadStep::NeedMore(next)), + FramedPrefixStep::Value { value, bytes } => { + bytes.expect_empty()?; + Ok(FramedReadStep::Value(value)) + } + } + } + + pub fn advance_prefix(self) -> Result, CodecError> { + let mut this = self; + let Some(mut body) = this.bytes.try_take_part()? else { + return Ok(FramedPrefixStep::NeedMore(this)); + }; + + let value = T::decode_value(&mut body).map_err(CodecError::Codec)?; + drop(body); + Ok(FramedPrefixStep::Value { + value, + bytes: this.bytes, + }) + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::{FramedPrefixStep, FramedReadStep, FramedReader}; + use crate::codec::encode_value_part; + + #[test] + fn value_reader_round_trips_framed_values() { + let mut encoded = Vec::new(); + encode_value_part(&b"hello".to_vec(), &mut encoded); + + match FramedReader::>::default() + .push(Bytes::from(encoded)) + .advance() + .unwrap() + { + FramedReadStep::Value(value) => assert_eq!(value, b"hello".to_vec()), + _ => unreachable!(), + } + } + + #[test] + fn value_reader_waits_for_complete_frame() { + let mut encoded = Vec::new(); + encode_value_part(&b"hello".to_vec(), &mut encoded); + let encoded = Bytes::from(encoded); + + let reader = match FramedReader::>::default() + .push(encoded.slice(..4)) + .advance() + .unwrap() + { + FramedReadStep::NeedMore(next) => next, + _ => unreachable!(), + }; + + match reader.push(encoded.slice(4..)).advance().unwrap() { + FramedReadStep::Value(value) => assert_eq!(value, b"hello".to_vec()), + _ => unreachable!(), + } + } + + #[test] + fn value_reader_returns_prefix_remainder() { + let mut encoded = Vec::new(); + encode_value_part(&b"hello".to_vec(), &mut encoded); + encoded.extend_from_slice(b"tail"); + + match FramedReader::>::default() + .push(Bytes::from(encoded)) + .advance_prefix() + .unwrap() + { + FramedPrefixStep::Value { value, mut bytes } => { + assert_eq!(value, b"hello".to_vec()); + assert_eq!( + bytes.pop_front(usize::MAX), + Some(Bytes::from_static(b"tail")) + ); + } + _ => unreachable!(), + } + } +} diff --git a/ql-rpc/src/lib.rs b/ql-rpc/src/lib.rs new file mode 100644 index 00000000..abd20eae --- /dev/null +++ b/ql-rpc/src/lib.rs @@ -0,0 +1,42 @@ +//! quantum link rpc protocol traits and framing helpers. + +mod chunk_queue; +pub(crate) mod codec; +mod error; +mod framed_value; +mod route_id; +mod router; +mod rpc; +mod stream; + +pub use chunk_queue::ChunkQueue; +pub use codec::RpcCodec; +pub use error::*; +use framed_value::*; +pub use route_id::RouteId; +pub use router::*; +pub use rpc::*; +pub use stream::*; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct StreamCloseCode(pub u16); + +impl StreamCloseCode { + /// operation was cancelled + pub const CANCELLED: Self = Self(0); + /// local internal error + pub const INTERNAL: Self = Self(1); + /// request was refused + pub const REFUSED: Self = Self(2); + /// operation timed out + pub const TIMEOUT: Self = Self(3); + /// configured limit was exceeded + pub const LIMIT: Self = Self(4); + /// route identifier was unknown + pub const UNKNOWN_ROUTE: Self = Self(5); + + pub const fn into_inner(self) -> u16 { + self.0 + } +} diff --git a/ql-rpc/src/route_id.rs b/ql-rpc/src/route_id.rs new file mode 100644 index 00000000..1b054e74 --- /dev/null +++ b/ql-rpc/src/route_id.rs @@ -0,0 +1,19 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct RouteId(pub u32); + +impl RouteId { + pub const fn from_u32(value: u32) -> Self { + Self(value) + } + + pub const fn into_inner(self) -> u32 { + self.0 + } +} + +impl From for RouteId { + fn from(value: u32) -> Self { + Self::from_u32(value) + } +} diff --git a/ql-rpc/src/router/builder.rs b/ql-rpc/src/router/builder.rs new file mode 100644 index 00000000..b59a84e6 --- /dev/null +++ b/ql-rpc/src/router/builder.rs @@ -0,0 +1,354 @@ +use std::marker::PhantomData; + +use super::{ + LocalSpawner, RouteEntry, RouteFn, Router, RouterConfig, RpcStream, SendSpawner, Spawner, +}; +use crate::{ + download::{server::*, Download as DownloadRpc}, + duplex::{server::*, Duplex as DuplexRpc}, + notification::{server::*, Notification as NotificationRpc}, + progress::{server::*, Progress as ProgressRpc}, + request::{server::*, Request as RequestRpc}, + subscription::{server::*, Subscription as SubscriptionRpc}, + upload::{server::*, Upload as UploadRpc}, +}; + +pub struct LocalRoutes; +pub struct SendRoutes; + +pub struct RouterBuilder +where + Sp: Spawner, +{ + config: RouterConfig, + spawner: Sp, + routes: Vec>, + marker: PhantomData Mode>, +} + +impl RouterBuilder +where + Sp: Spawner, +{ + pub(crate) fn new(spawner: Sp) -> Self { + Self { + config: RouterConfig::default(), + spawner, + routes: Vec::new(), + marker: PhantomData, + } + } + + pub fn config(mut self, config: RouterConfig) -> Self { + self.config = config; + self + } + + pub fn max_request_bytes(mut self, max_request_bytes: usize) -> Self { + self.config.max_request_bytes = max_request_bytes; + self + } + + pub fn build(mut self, state: S) -> Router { + self.routes.sort_by_key(|entry| entry.route_id); + self.routes.shrink_to_fit(); + Router { + config: self.config, + state, + spawner: self.spawner, + routes: self.routes, + } + } + + fn add_route(mut self, route_id: crate::RouteId, route: RouteFn) -> Self { + if self.routes.iter().any(|entry| entry.route_id == route_id) { + panic!("duplicate rpc route {}", route_id.into_inner()); + } + self.routes.push(RouteEntry::new(route_id, route)); + self + } +} + +impl RouterBuilder +where + Sp: LocalSpawner, + St: RpcStream + 'static, +{ + pub fn request(self) -> Self + where + M: RequestRpc + 'static, + S: RequestHandlerLocal + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_request_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } + + pub fn notification(self) -> Self + where + M: NotificationRpc + 'static, + S: NotificationHandlerLocal + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_notification_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } + + pub fn duplex(self) -> Self + where + M: DuplexRpc + 'static, + S: DuplexHandlerLocal + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_duplex_inner::( + state, + config, + reader, + writer, + S::handle, + )) + }) + } + + pub fn download(self) -> Self + where + M: DownloadRpc + 'static, + S: DownloadHandlerLocal + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_download_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } + + pub fn subscription(self) -> Self + where + M: SubscriptionRpc + 'static, + S: SubscriptionHandlerLocal + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_subscription_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } + + pub fn progress(self) -> Self + where + M: ProgressRpc + 'static, + S: ProgressHandlerLocal + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_progress_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } + + pub fn upload(self) -> Self + where + M: UploadRpc + 'static, + S: UploadHandlerLocal + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_upload_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } +} + +impl RouterBuilder +where + Sp: SendSpawner + Send, + St: RpcStream + 'static, +{ + pub fn request(self) -> Self + where + M: RequestRpc + 'static, + M::Request: Send + 'static, + S: RequestHandler + Send + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_request_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } + + pub fn notification(self) -> Self + where + M: NotificationRpc + 'static, + M::Payload: Send + 'static, + S: NotificationHandler + Send + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_notification_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } + + pub fn duplex(self) -> Self + where + M: DuplexRpc + 'static, + M::InitiatorEvent: Send + 'static, + M::ResponderEvent: Send + 'static, + S: DuplexHandler + Send + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_duplex_inner::( + state, + config, + reader, + writer, + S::handle, + )) + }) + } + + pub fn download(self) -> Self + where + M: DownloadRpc + 'static, + M::Request: Send + 'static, + S: DownloadHandler + Send + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_download_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } + + pub fn subscription(self) -> Self + where + M: SubscriptionRpc + 'static, + M::Request: Send + 'static, + S: SubscriptionHandler + Send + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_subscription_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } + + pub fn progress(self) -> Self + where + M: ProgressRpc + 'static, + M::Request: Send + 'static, + S: ProgressHandler + Send + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_progress_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } + + pub fn upload(self) -> Self + where + M: UploadRpc + 'static, + M::Request: Send + 'static, + S: UploadHandler + Send + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_upload_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } +} diff --git a/ql-rpc/src/router/config.rs b/ql-rpc/src/router/config.rs new file mode 100644 index 00000000..d6fb048f --- /dev/null +++ b/ql-rpc/src/router/config.rs @@ -0,0 +1,12 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RouterConfig { + pub max_request_bytes: usize, +} + +impl Default for RouterConfig { + fn default() -> Self { + Self { + max_request_bytes: usize::MAX, + } + } +} diff --git a/ql-rpc/src/router/mod.rs b/ql-rpc/src/router/mod.rs new file mode 100644 index 00000000..31e973ac --- /dev/null +++ b/ql-rpc/src/router/mod.rs @@ -0,0 +1,89 @@ +use crate::{RouteId, StreamCloseCode}; + +mod builder; +mod config; +mod mode; + +pub use self::{ + builder::{LocalRoutes, RouterBuilder, SendRoutes}, + config::RouterConfig, + mode::*, +}; +use crate::{close_stream, RpcStream}; +pub use crate::{ + download::{DownloadHandler, DownloadHandlerLocal, DownloadStart, DownloadWriter}, + duplex::{DuplexHandler, DuplexHandlerLocal, DuplexPeer}, + notification::{NotificationHandler, NotificationHandlerLocal}, + progress::{ProgressHandler, ProgressHandlerLocal, ProgressResponder}, + request::{RequestHandler, RequestHandlerLocal, Response}, + subscription::{SubscriptionHandler, SubscriptionHandlerLocal, SubscriptionResponder}, + upload::{UploadHandler, UploadHandlerLocal, UploadReader, UploadResponder}, +}; + +pub struct Router +where + Sp: Spawner, +{ + config: RouterConfig, + state: S, + spawner: Sp, + routes: Vec>, +} + +struct RouteEntry +where + Sp: Spawner, +{ + route_id: RouteId, + route: RouteFn, +} + +impl RouteEntry +where + Sp: Spawner, +{ + fn new(route_id: RouteId, route: RouteFn) -> Self { + Self { route_id, route } + } +} + +impl Router +where + S: Clone + 'static, + St: RpcStream, + Sp: Spawner, +{ + pub fn builder_local(spawner: Sp) -> RouterBuilder + where + Sp: LocalSpawner, + { + RouterBuilder::::new(spawner) + } + + pub fn builder_send(spawner: Sp) -> RouterBuilder + where + Sp: SendSpawner, + { + RouterBuilder::::new(spawner) + } + + pub fn handle(&self, stream: St) -> Option<(RouteId, Sp::Handle)> { + let route_id = stream.route_id()?; + let Ok(index) = self + .routes + .binary_search_by_key(&route_id, |entry| entry.route_id) + else { + close_stream(stream, StreamCloseCode::UNKNOWN_ROUTE); + return None; + }; + let route = self.routes[index].route; + Some(( + route_id, + route(&self.spawner, self.state.clone(), self.config, stream), + )) + } + + pub fn route_ids(&self) -> impl ExactSizeIterator + '_ { + self.routes.iter().map(|entry| entry.route_id) + } +} diff --git a/ql-rpc/src/router/mode.rs b/ql-rpc/src/router/mode.rs new file mode 100644 index 00000000..33b6c06a --- /dev/null +++ b/ql-rpc/src/router/mode.rs @@ -0,0 +1,21 @@ +use std::future::Future; + +use crate::RouterConfig; + +pub type RouteFn = fn(&Sp, S, RouterConfig, St) -> ::Handle; + +pub trait Spawner: Clone + 'static { + type Handle; +} + +pub trait LocalSpawner: Spawner { + fn spawn(&self, fut: F) -> Self::Handle + where + F: Future + 'static; +} + +pub trait SendSpawner: Spawner { + fn spawn(&self, fut: F) -> Self::Handle + where + F: Future + Send + 'static; +} diff --git a/ql-rpc/src/rpc/download/client.rs b/ql-rpc/src/rpc/download/client.rs new file mode 100644 index 00000000..9a648181 --- /dev/null +++ b/ql-rpc/src/rpc/download/client.rs @@ -0,0 +1,246 @@ +use std::future::poll_fn; + +use bytes::{BufMut, Bytes}; + +use crate::{ + download::{Download, PartReadStep}, + rpc::parts::FrameKind, + CallError, FramedPrefixStep, FramedReader, RpcCodec, RpcRead, StreamCloseCode, +}; + +pub struct DownloadCall +where + M: Download, + R: RpcRead, +{ + stream: Option, + reader: Option>, +} + +pub struct DownloadPart<'a, M, R> +where + M: Download, + R: RpcRead, +{ + parent: &'a mut DownloadReader, + finished: bool, +} + +pub struct DownloadReader +where + M: Download, + R: RpcRead, +{ + stream: Option, + reader: crate::download::PartFrameReader, +} + +impl DownloadCall +where + M: Download, + R: RpcRead, +{ + pub fn new(stream: R) -> Self { + Self { + stream: Some(stream), + reader: Some(FramedReader::default()), + } + } + + pub async fn start( + mut self, + ) -> Result<(M::ResponseHeader, DownloadReader), CallError> { + loop { + let reader = self.reader.take().unwrap(); + let reader = match reader.advance_prefix() { + Ok(FramedPrefixStep::Value { value, bytes }) => { + let stream = self.stream.take().unwrap(); + return Ok(( + value, + DownloadReader { + stream: Some(stream), + reader: crate::download::PartFrameReader::::new(bytes), + }, + )); + } + Ok(FramedPrefixStep::NeedMore(next)) => next, + Err(error) => return Err(error.into()), + }; + + let stream = self.stream.as_mut().unwrap(); + match poll_fn(|cx| stream.poll_read(usize::MAX, cx)).await { + Ok(Some(chunk)) => { + self.reader = Some(reader.push(chunk)); + } + Ok(None) => return Err(crate::Error::Truncated.into()), + Err(error) => return Err(CallError::Transport(error)), + } + } + } + + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + fn close_inner(&mut self, code: StreamCloseCode) { + if let Some(stream) = self.stream.take() { + stream.close(code); + } + } +} + +impl Drop for DownloadCall +where + M: Download, + R: RpcRead, +{ + fn drop(&mut self) { + self.close_inner(StreamCloseCode::CANCELLED); + } +} + +impl DownloadReader +where + M: Download, + R: RpcRead, +{ + pub async fn next_part( + &mut self, + ) -> Result)>, CallError> + { + if self.stream.is_none() { + return Ok(None); + } + + match self.read_frame().await? { + PartReadStep::PartHeader(value) => Ok(Some(( + value, + DownloadPart { + parent: self, + finished: false, + }, + ))), + PartReadStep::Finish => { + self.stream.take(); + Ok(None) + } + PartReadStep::BodyBytes(_) => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::BodyChunk.tag()).into()) + } + PartReadStep::EndPart => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::EndPart.tag()).into()) + } + PartReadStep::NeedMore => unreachable!("read_frame waits for a complete frame"), + } + } + + pub async fn complete(mut self) -> Result<(), CallError> { + match self.read_frame().await? { + PartReadStep::Finish => { + self.stream.take(); + Ok(()) + } + PartReadStep::PartHeader(_) => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::PartHeader.tag()).into()) + } + PartReadStep::BodyBytes(_) => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::BodyChunk.tag()).into()) + } + PartReadStep::EndPart => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::EndPart.tag()).into()) + } + PartReadStep::NeedMore => unreachable!("read_frame waits for a complete frame"), + } + } + + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + async fn read_frame( + &mut self, + ) -> Result, CallError> { + loop { + match self.reader.advance() { + Ok(PartReadStep::NeedMore) => {} + Ok(step) => return Ok(step), + Err(error) => return Err(error.into()), + } + + let stream = self.stream.as_mut().unwrap(); + match poll_fn(|cx| stream.poll_read(usize::MAX, cx)).await { + Ok(Some(chunk)) => { + self.reader.push(chunk); + } + Ok(None) => return Err(crate::Error::Truncated.into()), + Err(error) => return Err(CallError::Transport(error)), + } + } + } + + fn close_inner(&mut self, code: StreamCloseCode) { + if let Some(stream) = self.stream.take() { + stream.close(code); + } + } +} + +impl Drop for DownloadReader +where + M: Download, + R: RpcRead, +{ + fn drop(&mut self) { + if self.stream.is_some() { + self.close_inner(StreamCloseCode::CANCELLED); + } + } +} + +impl DownloadPart<'_, M, R> +where + M: Download, + R: RpcRead, +{ + pub async fn read_chunk(&mut self) -> Result, CallError> { + if self.finished { + return Ok(None); + } + + match self.parent.read_frame().await? { + PartReadStep::BodyBytes(bytes) => Ok(Some(bytes)), + PartReadStep::EndPart => { + self.finished = true; + Ok(None) + } + PartReadStep::PartHeader(_) => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::PartHeader.tag()).into()) + } + PartReadStep::Finish => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::Finish.tag()).into()) + } + PartReadStep::NeedMore => unreachable!("read_frame waits for a complete frame"), + } + } + + pub fn close(mut self, code: StreamCloseCode) { + self.parent.close_inner(code); + self.finished = true; + } +} + +impl Drop for DownloadPart<'_, M, R> +where + M: Download, + R: RpcRead, +{ + fn drop(&mut self) { + if !self.finished { + self.parent.close_inner(StreamCloseCode::CANCELLED); + } + } +} + +pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { + request.encode_value(out) +} diff --git a/ql-rpc/src/rpc/download/mod.rs b/ql-rpc/src/rpc/download/mod.rs new file mode 100644 index 00000000..5ed34aed --- /dev/null +++ b/ql-rpc/src/rpc/download/mod.rs @@ -0,0 +1,31 @@ +use super::Route; +use crate::RpcCodec; + +pub(crate) mod client; +pub(crate) mod server; + +pub use client::{encode_request, DownloadCall, DownloadPart, DownloadReader}; +pub use server::{ + DownloadHandler, DownloadHandlerLocal, DownloadPartWriter, DownloadStart, DownloadWriter, +}; + +pub use crate::rpc::parts::{ + encode_body_chunk, encode_end_part, encode_finish, encode_part_header, PartFrameReader, + PartReadStep, +}; + +/// rpc where the responder returns metadata first and then zero or more byte parts +/// +/// the typed portion of the response ends at [`Self::ResponseHeader`] +/// after the header is decoded, the rest of the stream is exposed as typed +/// part headers followed by raw byte chunks through [`DownloadReader`] +pub trait Download: Route { + /// codec error shared by request and response header values + type Error; + /// typed input needed to start the download + type Request: RpcCodec; + /// typed metadata available before parts arrive + type ResponseHeader: RpcCodec; + /// typed metadata available before each byte part arrives + type PartHeader: RpcCodec; +} diff --git a/ql-rpc/src/rpc/download/server.rs b/ql-rpc/src/rpc/download/server.rs new file mode 100644 index 00000000..fcdcb047 --- /dev/null +++ b/ql-rpc/src/rpc/download/server.rs @@ -0,0 +1,221 @@ +use std::{future::Future, marker::PhantomData}; + +use bytes::Bytes; + +use crate::{ + codec, + download::Download as DownloadRpc, + finish_bytes, + rpc::{ + parts::{encode_body_chunk, encode_end_part, encode_finish, encode_part_header}, + read_eof_request, + }, + write_bytes, RouterConfig, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, +}; + +#[trait_variant::make(DownloadHandler: Send)] +pub trait DownloadHandlerLocal +where + M: DownloadRpc, + St: RpcStream, +{ + async fn handle(self, message: M::Request, download: DownloadStart); + + fn handle_transport_error(&self, _error: &St::Error) {} +} + +pub struct DownloadStart +where + M: DownloadRpc, + W: RpcWrite, +{ + writer: Option, + marker: PhantomData M>, +} + +pub struct DownloadWriter +where + M: DownloadRpc, + W: RpcWrite, +{ + writer: Option, + marker: PhantomData M>, +} + +pub struct DownloadPartWriter<'a, M, W> +where + M: DownloadRpc, + W: RpcWrite, +{ + parent: &'a mut DownloadWriter, + finished: bool, +} + +impl DownloadStart +where + M: DownloadRpc, + W: RpcWrite, +{ + pub(crate) fn new(writer: W) -> Self { + Self { + writer: Some(writer), + marker: PhantomData, + } + } + + /// send the response header and begin streaming parts + pub async fn start( + mut self, + response_header: M::ResponseHeader, + ) -> Result, W::Error> { + let mut writer = self.writer.take().unwrap(); + let mut encoded = Vec::new(); + codec::encode_value_part(&response_header, &mut encoded); + write_bytes(&mut writer, Bytes::from(encoded)).await?; + Ok(DownloadWriter { + writer: Some(writer), + marker: PhantomData, + }) + } + + /// send a header-only response and finish the stream + pub async fn complete(mut self, response_header: M::ResponseHeader) -> Result<(), W::Error> { + let mut writer = self.writer.take().unwrap(); + let mut encoded = Vec::new(); + codec::encode_value_part(&response_header, &mut encoded); + encode_finish(&mut encoded); + write_bytes(&mut writer, Bytes::from(encoded)).await?; + finish_bytes(&mut writer).await + } + + /// close the stream with a transport code + pub fn close(mut self, code: StreamCloseCode) { + if let Some(writer) = self.writer.take() { + writer.close(code); + } + } +} + +impl Drop for DownloadStart +where + M: DownloadRpc, + W: RpcWrite, +{ + fn drop(&mut self) { + if let Some(writer) = self.writer.take() { + writer.close(StreamCloseCode::CANCELLED); + } + } +} + +impl DownloadWriter +where + M: DownloadRpc, + W: RpcWrite, +{ + pub async fn start_part( + &mut self, + part_header: M::PartHeader, + ) -> Result, W::Error> { + let writer = self.writer.as_mut().unwrap(); + let mut encoded = Vec::new(); + encode_part_header(&part_header, &mut encoded); + write_bytes(writer, Bytes::from(encoded)).await?; + Ok(DownloadPartWriter { + parent: self, + finished: false, + }) + } + + pub async fn finish(mut self) -> Result<(), W::Error> { + let mut writer = self.writer.take().unwrap(); + let mut encoded = Vec::new(); + encode_finish(&mut encoded); + write_bytes(&mut writer, Bytes::from(encoded)).await?; + finish_bytes(&mut writer).await + } + + pub fn close(mut self, code: StreamCloseCode) { + if let Some(writer) = self.writer.take() { + writer.close(code); + } + } +} + +impl Drop for DownloadWriter +where + M: DownloadRpc, + W: RpcWrite, +{ + fn drop(&mut self) { + if let Some(writer) = self.writer.take() { + writer.close(StreamCloseCode::CANCELLED); + } + } +} + +impl DownloadPartWriter<'_, M, W> +where + M: DownloadRpc, + W: RpcWrite, +{ + pub async fn send(&mut self, bytes: Bytes) -> Result<(), W::Error> { + let writer = self.parent.writer.as_mut().unwrap(); + let mut encoded = Vec::new(); + encode_body_chunk(&bytes, &mut encoded); + write_bytes(writer, Bytes::from(encoded)).await + } + + pub async fn finish(mut self) -> Result<(), W::Error> { + let writer = self.parent.writer.as_mut().unwrap(); + let mut encoded = Vec::new(); + encode_end_part(&mut encoded); + write_bytes(writer, Bytes::from(encoded)).await?; + self.finished = true; + Ok(()) + } +} + +impl Drop for DownloadPartWriter<'_, M, W> +where + M: DownloadRpc, + W: RpcWrite, +{ + fn drop(&mut self) { + if !self.finished { + if let Some(writer) = self.parent.writer.take() { + writer.close(StreamCloseCode::CANCELLED); + } + } + } +} + +pub(crate) async fn handle_download_inner( + state: S, + config: RouterConfig, + mut reader: St::Reader, + writer: St::Writer, + handle: H, + handle_transport_error: E, +) where + M: DownloadRpc + 'static, + St: RpcStream + 'static, + H: FnOnce(S, M::Request, DownloadStart) -> HF, + HF: Future, + E: FnOnce(&S, &St::Error), +{ + let request = match read_eof_request::(&mut reader, config).await { + Ok(request) => request, + Err(error) => { + let code = error.close_code(); + handle_transport_error(&state, &error); + if let Some(code) = code { + reader.close(code); + writer.close(code); + } + return; + } + }; + + handle(state, request, DownloadStart::new(writer)).await; +} diff --git a/ql-rpc/src/rpc/duplex/client.rs b/ql-rpc/src/rpc/duplex/client.rs new file mode 100644 index 00000000..e76050a6 --- /dev/null +++ b/ql-rpc/src/rpc/duplex/client.rs @@ -0,0 +1,164 @@ +use std::{ + future::poll_fn, + marker::PhantomData, + task::{Context, Poll}, +}; + +use bytes::Bytes; + +use crate::{ + duplex::{codec, Duplex, EventReader, ReadStep}, + finish_bytes, write_bytes, CallError, RpcCodec, RpcRead, RpcWrite, StreamCloseCode, +}; + +pub struct DuplexCall +where + M: Duplex, + W: RpcWrite, + R: RpcRead, +{ + pub sender: DuplexSender, + pub receiver: DuplexReceiver, +} + +pub struct DuplexSender +where + T: RpcCodec, + W: RpcWrite, +{ + writer: Option, + marker: PhantomData T>, +} + +pub struct DuplexReceiver +where + T: RpcCodec, + R: RpcRead, +{ + stream: Option, + reader: EventReader, +} + +impl DuplexSender +where + T: RpcCodec, + W: RpcWrite, +{ + pub fn new(writer: W) -> Self { + Self { + writer: Some(writer), + marker: PhantomData, + } + } + + pub async fn send(&mut self, event: &T) -> Result<(), W::Error> { + let writer = self.writer.as_mut().unwrap(); + let mut encoded = Vec::new(); + codec::encode_event(event, &mut encoded); + write_bytes(writer, Bytes::from(encoded)).await + } + + pub async fn finish(mut self) -> Result<(), W::Error> { + let mut writer = self.writer.take().unwrap(); + finish_bytes(&mut writer).await + } + + pub fn close(mut self, code: StreamCloseCode) { + if let Some(writer) = self.writer.take() { + writer.close(code); + } + } +} + +impl Drop for DuplexSender +where + T: RpcCodec, + W: RpcWrite, +{ + fn drop(&mut self) { + if let Some(writer) = self.writer.take() { + writer.close(StreamCloseCode::CANCELLED); + } + } +} + +impl DuplexReceiver +where + T: RpcCodec, + R: RpcRead, +{ + pub fn new(stream: R) -> Self { + Self { + stream: Some(stream), + reader: EventReader::default(), + } + } + + pub async fn next_event(&mut self) -> Option>> { + poll_fn(|cx| self.poll_next_event(cx)).await + } + + pub fn poll_next_event( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>>> { + if self.stream.is_none() { + return Poll::Ready(None); + } + + loop { + match self.reader.advance() { + Ok(ReadStep::Event(value)) => return Poll::Ready(Some(Ok(value))), + Ok(ReadStep::NeedMore) => {} + Err(error) => { + self.stream.take(); + return Poll::Ready(Some(Err(error.into()))); + } + } + + let stream = self.stream.as_mut().unwrap(); + match stream.poll_read(usize::MAX, cx) { + Poll::Ready(Ok(Some(chunk))) => { + self.reader.push(chunk); + } + Poll::Ready(Ok(None)) => { + if self.reader.is_empty() { + self.stream.take(); + return Poll::Ready(None); + } + self.stream.take(); + return Poll::Ready(Some(Err(crate::Error::Truncated.into()))); + } + Poll::Ready(Err(error)) => { + self.stream.take(); + return Poll::Ready(Some(Err(CallError::Transport(error)))); + } + Poll::Pending => { + return Poll::Pending; + } + } + } + } + + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + fn close_inner(&mut self, code: StreamCloseCode) { + if let Some(stream) = self.stream.take() { + stream.close(code); + } + } +} + +impl Drop for DuplexReceiver +where + T: RpcCodec, + R: RpcRead, +{ + fn drop(&mut self) { + if self.stream.is_some() { + self.close_inner(StreamCloseCode::CANCELLED); + } + } +} diff --git a/ql-rpc/src/rpc/duplex/codec.rs b/ql-rpc/src/rpc/duplex/codec.rs new file mode 100644 index 00000000..68bc87c7 --- /dev/null +++ b/ql-rpc/src/rpc/duplex/codec.rs @@ -0,0 +1,86 @@ +use std::marker::PhantomData; + +use bytes::{BufMut, Bytes}; + +use crate::{codec, CodecError, RpcCodec}; + +pub fn encode_event(event: &T, out: &mut (impl BufMut + AsMut<[u8]>)) +where + T: RpcCodec, +{ + codec::encode_value_part(event, out) +} + +pub enum ReadStep { + NeedMore, + Event(T), +} + +pub struct EventReader { + bytes: codec::ChunkQueue, + marker: PhantomData T>, +} + +impl Default for EventReader { + fn default() -> Self { + Self { + bytes: codec::ChunkQueue::default(), + marker: PhantomData, + } + } +} + +impl EventReader { + pub fn push(&mut self, chunk: Bytes) { + self.bytes.push(chunk); + } + + pub fn is_empty(&self) -> bool { + self.bytes.remaining() == 0 + } + + pub fn advance(&mut self) -> Result, CodecError> { + let Some(mut body) = self.bytes.try_take_part()? else { + return Ok(ReadStep::NeedMore); + }; + + let value = { + let value = T::decode_value(&mut body).map_err(CodecError::Codec)?; + drop(body); + value + }; + Ok(ReadStep::Event(value)) + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::{encode_event, EventReader, ReadStep}; + + #[test] + fn event_reader_emits_multiple_events() { + let mut encoded = Vec::new(); + encode_event(&b"one".to_vec(), &mut encoded); + encode_event(&b"two".to_vec(), &mut encoded); + + let mut reader = EventReader::>::default(); + reader.push(Bytes::from(encoded)); + + match reader.advance().unwrap() { + ReadStep::Event(value) => { + assert_eq!(value, b"one".to_vec()); + } + _ => unreachable!(), + }; + + match reader.advance().unwrap() { + ReadStep::Event(value) => { + assert_eq!(value, b"two".to_vec()); + assert!(reader.is_empty()); + } + _ => unreachable!(), + } + } +} diff --git a/ql-rpc/src/rpc/duplex/mod.rs b/ql-rpc/src/rpc/duplex/mod.rs new file mode 100644 index 00000000..a9622029 --- /dev/null +++ b/ql-rpc/src/rpc/duplex/mod.rs @@ -0,0 +1,24 @@ +use super::Route; +use crate::RpcCodec; + +pub(crate) mod client; +pub(crate) mod codec; +pub(crate) mod server; + +pub use client::{DuplexCall, DuplexReceiver, DuplexSender}; +pub use codec::{encode_event, EventReader, ReadStep}; +pub use server::{DuplexHandler, DuplexHandlerLocal, DuplexPeer}; + +/// rpc where both sides exchange typed events on the same stream +/// +/// The initiator opens the routed stream. After that, either side may send any +/// number of events of its directional event type until it finishes or closes +/// its write side. +pub trait Duplex: Route { + /// codec error shared by both directional event values + type Error; + /// typed event sent by the side that opened the stream + type InitiatorEvent: RpcCodec; + /// typed event sent by the side handling the route + type ResponderEvent: RpcCodec; +} diff --git a/ql-rpc/src/rpc/duplex/server.rs b/ql-rpc/src/rpc/duplex/server.rs new file mode 100644 index 00000000..bf024335 --- /dev/null +++ b/ql-rpc/src/rpc/duplex/server.rs @@ -0,0 +1,47 @@ +use std::future::Future; + +use crate::{ + duplex::{Duplex, DuplexReceiver, DuplexSender}, + RpcRead, RpcStream, RpcWrite, +}; + +#[trait_variant::make(DuplexHandler: Send)] +pub trait DuplexHandlerLocal +where + M: Duplex, + St: RpcStream, +{ + async fn handle(self, peer: DuplexPeer); +} + +pub struct DuplexPeer +where + M: Duplex, + W: RpcWrite, + R: RpcRead, +{ + pub sender: DuplexSender, + pub receiver: DuplexReceiver, +} + +pub(crate) async fn handle_duplex_inner( + state: S, + _config: crate::RouterConfig, + reader: St::Reader, + writer: St::Writer, + handle: H, +) where + M: Duplex + 'static, + St: RpcStream + 'static, + H: FnOnce(S, DuplexPeer) -> HF, + HF: Future, +{ + handle( + state, + DuplexPeer { + sender: DuplexSender::new(writer), + receiver: DuplexReceiver::new(reader), + }, + ) + .await; +} diff --git a/ql-rpc/src/rpc/mod.rs b/ql-rpc/src/rpc/mod.rs new file mode 100644 index 00000000..2d84f050 --- /dev/null +++ b/ql-rpc/src/rpc/mod.rs @@ -0,0 +1,32 @@ +//! rpc protocol families built on top of one stream per call +//! +//! each trait in this module names one rpc shape and the typed values that +//! travel on that stream +//! route dispatch uses [`crate::RouteId`] and the submodules provide the matching +//! client and server helpers for encoding, decoding, and handler glue + +use crate::RouteId; + +pub mod download; +pub mod duplex; +pub mod notification; +pub(crate) mod parts; +pub mod progress; +pub mod request; +pub mod subscription; +pub mod upload; +mod utils; + +pub trait Route { + /// route used to dispatch this rpc family + const ROUTE: RouteId; +} + +pub use download::Download; +pub use duplex::Duplex; +pub use notification::Notification; +pub use progress::Progress; +pub use request::Request; +pub use subscription::Subscription; +pub use upload::Upload; +use utils::*; diff --git a/ql-rpc/src/rpc/notification/client.rs b/ql-rpc/src/rpc/notification/client.rs new file mode 100644 index 00000000..72b6900a --- /dev/null +++ b/ql-rpc/src/rpc/notification/client.rs @@ -0,0 +1,10 @@ +use bytes::BufMut; + +use crate::{notification::Notification, RpcCodec}; + +pub fn encode_notification( + payload: &M::Payload, + out: &mut (impl BufMut + AsMut<[u8]>), +) { + payload.encode_value(out) +} diff --git a/ql-rpc/src/rpc/notification/mod.rs b/ql-rpc/src/rpc/notification/mod.rs new file mode 100644 index 00000000..4740a64f --- /dev/null +++ b/ql-rpc/src/rpc/notification/mod.rs @@ -0,0 +1,19 @@ +use super::Route; +use crate::RpcCodec; + +pub(crate) mod client; +pub(crate) mod server; + +pub use client::encode_notification; +pub use server::{NotificationHandler, NotificationHandlerLocal}; + +/// one-way rpc that carries a single typed payload and no typed response +/// +/// the server reads [`Self::Payload`] to eof and then closes the response side +/// of the stream +pub trait Notification: Route { + /// codec error for the notification payload + type Error; + /// typed payload emitted by the caller + type Payload: RpcCodec; +} diff --git a/ql-rpc/src/rpc/notification/server.rs b/ql-rpc/src/rpc/notification/server.rs new file mode 100644 index 00000000..c9a4fdba --- /dev/null +++ b/ql-rpc/src/rpc/notification/server.rs @@ -0,0 +1,48 @@ +use std::future::Future; + +use crate::{ + notification::Notification as NotificationRpc, rpc::read_eof_request, RouterConfig, RpcRead, + RpcStream, RpcWrite, StreamCloseCode, StreamError, +}; + +#[trait_variant::make(NotificationHandler: Send)] +pub trait NotificationHandlerLocal +where + M: NotificationRpc, + St: RpcStream, +{ + async fn handle(self, message: M::Payload); + + fn handle_transport_error(&self, _error: &St::Error) {} +} + +pub(crate) async fn handle_notification_inner( + state: S, + config: RouterConfig, + mut reader: St::Reader, + writer: St::Writer, + handle: H, + handle_transport_error: E, +) where + M: NotificationRpc + 'static, + St: RpcStream + 'static, + H: FnOnce(S, M::Payload) -> HF, + HF: Future, + E: FnOnce(&S, &St::Error), +{ + let notification = match read_eof_request::(&mut reader, config).await { + Ok(notification) => notification, + Err(error) => { + let code = error.close_code(); + handle_transport_error(&state, &error); + if let Some(code) = code { + reader.close(code); + writer.close(code); + } + return; + } + }; + + writer.close(StreamCloseCode::CANCELLED); + handle(state, notification).await; +} diff --git a/ql-rpc/src/rpc/parts.rs b/ql-rpc/src/rpc/parts.rs new file mode 100644 index 00000000..47ff1e87 --- /dev/null +++ b/ql-rpc/src/rpc/parts.rs @@ -0,0 +1,283 @@ +use std::marker::PhantomData; + +use bytes::{BufMut, Bytes}; + +use crate::{codec, ChunkQueue, CodecError, RpcCodec}; + +pub enum PartReadStep { + NeedMore, + PartHeader(H), + BodyBytes(Bytes), + EndPart, + Finish, +} + +pub struct PartFrameReader { + bytes: codec::ChunkQueue, + pending_frame: PendingFrame, + marker: PhantomData H>, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PendingFrame { + None, + Control { kind: FrameKind, len: usize }, + Body { remaining: usize }, +} + +impl PendingFrame { + fn take(&mut self) -> Self { + std::mem::replace(self, Self::None) + } +} + +impl PartFrameReader { + pub fn new(bytes: ChunkQueue) -> Self { + Self { + bytes, + pending_frame: PendingFrame::None, + marker: PhantomData, + } + } + + pub fn push(&mut self, chunk: Bytes) { + self.bytes.push(chunk); + } + + pub fn advance(&mut self) -> Result, CodecError> { + loop { + match self.pending_frame.take() { + PendingFrame::Body { remaining } => { + if remaining == 0 { + continue; + } + + let Some(bytes) = self.bytes.pop_front(remaining) else { + self.pending_frame = PendingFrame::Body { remaining }; + return Ok(PartReadStep::NeedMore); + }; + + let remaining = remaining - bytes.len(); + self.pending_frame = if remaining == 0 { + PendingFrame::None + } else { + PendingFrame::Body { remaining } + }; + return Ok(PartReadStep::BodyBytes(bytes)); + } + PendingFrame::Control { kind, len } => { + let Some(mut body) = self.bytes.try_take_body(len) else { + self.pending_frame = PendingFrame::Control { kind, len }; + return Ok(PartReadStep::NeedMore); + }; + + match kind { + FrameKind::PartHeader => { + let value = H::decode_value(&mut body).map_err(CodecError::Codec)?; + return Ok(PartReadStep::PartHeader(value)); + } + FrameKind::BodyChunk => unreachable!("body chunk is not a control frame"), + FrameKind::EndPart => { + body.expect_empty()?; + return Ok(PartReadStep::EndPart); + } + FrameKind::Finish => { + body.expect_empty()?; + drop(body); + self.bytes.expect_empty()?; + return Ok(PartReadStep::Finish); + } + } + } + PendingFrame::None => { + let Some((kind, len)) = self + .bytes + .try_take_tagged_part_header() + .map_err(CodecError::Rpc)? + else { + return Ok(PartReadStep::NeedMore); + }; + + let kind = FrameKind::try_from(kind).map_err(CodecError::Rpc)?; + self.pending_frame = if kind == FrameKind::BodyChunk { + PendingFrame::Body { remaining: len } + } else { + PendingFrame::Control { kind, len } + }; + } + } + } + } +} + +pub fn encode_part_header(part_header: &H, out: &mut (impl BufMut + AsMut<[u8]>)) { + encode_tagged_value_part(FrameKind::PartHeader, part_header, out) +} + +pub fn encode_body_chunk(bytes: &Bytes, out: &mut (impl BufMut + AsMut<[u8]>)) { + encode_tagged_value_part(FrameKind::BodyChunk, bytes, out) +} + +pub fn encode_end_part(out: &mut (impl BufMut + AsMut<[u8]>)) { + encode_tagged_empty_part(FrameKind::EndPart, out) +} + +pub fn encode_finish(out: &mut (impl BufMut + AsMut<[u8]>)) { + encode_tagged_empty_part(FrameKind::Finish, out) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub(super) enum FrameKind { + PartHeader = 1, + BodyChunk = 2, + EndPart = 3, + Finish = 4, +} + +impl FrameKind { + pub fn tag(self) -> u8 { + self as u8 + } +} + +impl TryFrom for FrameKind { + type Error = crate::Error; + + fn try_from(value: u8) -> Result { + match value { + x if x == Self::PartHeader.tag() => Ok(Self::PartHeader), + x if x == Self::BodyChunk.tag() => Ok(Self::BodyChunk), + x if x == Self::EndPart.tag() => Ok(Self::EndPart), + x if x == Self::Finish.tag() => Ok(Self::Finish), + other => Err(crate::Error::UnexpectedFrameKind(other)), + } + } +} + +fn encode_tagged_value_part>( + kind: FrameKind, + value: &T, + out: &mut B, +) { + out.put_u8(kind.tag()); + let payload_start = codec::reserve_length(out); + value.encode_value(out); + codec::backpatch_length(out, payload_start); +} + +fn encode_tagged_empty_part>(kind: FrameKind, out: &mut B) { + out.put_u8(kind.tag()); + let payload_start = codec::reserve_length(out); + codec::backpatch_length(out, payload_start); +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::{ + encode_body_chunk, encode_end_part, encode_finish, encode_part_header, PartFrameReader, + PartReadStep, + }; + + #[test] + fn part_reader_emits_multipart_sequence() { + let mut encoded = Vec::new(); + encode_part_header(&b"a.txt".to_vec(), &mut encoded); + encode_body_chunk(&Bytes::from_static(b"hel"), &mut encoded); + encode_body_chunk(&Bytes::from_static(b"lo"), &mut encoded); + encode_end_part(&mut encoded); + encode_part_header(&b"b.txt".to_vec(), &mut encoded); + encode_end_part(&mut encoded); + encode_finish(&mut encoded); + + let mut reader = PartFrameReader::>::new(Default::default()); + reader.push(Bytes::from(encoded)); + + match reader.advance().unwrap() { + PartReadStep::PartHeader(value) => { + assert_eq!(value, b"a.txt".to_vec()); + } + _ => unreachable!(), + }; + + match reader.advance().unwrap() { + PartReadStep::BodyBytes(bytes) => assert_eq!(bytes, Bytes::from_static(b"hel")), + _ => unreachable!(), + }; + + match reader.advance().unwrap() { + PartReadStep::BodyBytes(bytes) => assert_eq!(bytes, Bytes::from_static(b"lo")), + _ => unreachable!(), + }; + + match reader.advance().unwrap() { + PartReadStep::EndPart => {} + _ => unreachable!(), + }; + + match reader.advance().unwrap() { + PartReadStep::PartHeader(value) => { + assert_eq!(value, b"b.txt".to_vec()); + } + _ => unreachable!(), + }; + + match reader.advance().unwrap() { + PartReadStep::EndPart => {} + _ => unreachable!(), + }; + + match reader.advance().unwrap() { + PartReadStep::Finish => {} + _ => unreachable!(), + } + } + + #[test] + fn part_reader_waits_for_complete_header_frame() { + let mut encoded = Vec::new(); + encode_part_header(&b"a.txt".to_vec(), &mut encoded); + let encoded = Bytes::from(encoded); + + let mut reader = PartFrameReader::>::new(Default::default()); + reader.push(encoded.slice(..4)); + match reader.advance().unwrap() { + PartReadStep::NeedMore => {} + _ => unreachable!(), + }; + + reader.push(encoded.slice(4..)); + match reader.advance().unwrap() { + PartReadStep::PartHeader(value) => assert_eq!(value, b"a.txt".to_vec()), + _ => unreachable!(), + } + } + + #[test] + fn body_chunk_frame_streams_after_header() { + let mut encoded = Vec::new(); + encode_body_chunk(&Bytes::from_static(b"hello"), &mut encoded); + let encoded = Bytes::from(encoded); + + let mut reader = PartFrameReader::>::new(Default::default()); + reader.push(encoded.slice(..9)); + match reader.advance().unwrap() { + PartReadStep::NeedMore => {} + _ => unreachable!(), + }; + + reader.push(encoded.slice(9..11)); + match reader.advance().unwrap() { + PartReadStep::BodyBytes(bytes) => assert_eq!(bytes, Bytes::from_static(b"he")), + _ => unreachable!(), + }; + + reader.push(encoded.slice(11..)); + match reader.advance().unwrap() { + PartReadStep::BodyBytes(bytes) => assert_eq!(bytes, Bytes::from_static(b"llo")), + _ => unreachable!(), + }; + } +} diff --git a/ql-rpc/src/rpc/progress/client.rs b/ql-rpc/src/rpc/progress/client.rs new file mode 100644 index 00000000..c2218c97 --- /dev/null +++ b/ql-rpc/src/rpc/progress/client.rs @@ -0,0 +1,152 @@ +use std::{ + future::{poll_fn, Future}, + pin::Pin, + task::{Context, Poll}, +}; + +use crate::{ + progress::{Progress, ReadStep, ResponseReader}, + CallError, Error, RpcRead, StreamCloseCode, +}; + +pub struct ProgressCall +where + M: Progress, + R: RpcRead, +{ + stream: Option, + state: State, +} + +enum State +where + M: Progress, +{ + Invalid, + Reading(ResponseReader), + Terminal(Result>), + Done, +} + +impl Unpin for ProgressCall +where + M: Progress, + R: RpcRead, +{ +} + +impl ProgressCall +where + M: Progress, + R: RpcRead, +{ + pub fn new(stream: R) -> Self { + Self { + stream: Some(stream), + state: State::Reading(ResponseReader::default()), + } + } + + pub async fn next_progress(&mut self) -> Option { + poll_fn(|cx| self.poll_next_progress(cx)).await + } + + fn poll_step(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + let reader = match &mut self.state { + State::Reading(reader) => reader, + State::Terminal(_) | State::Done => return Poll::Ready(None), + State::Invalid => panic!("invalid state"), + }; + + match reader.advance() { + Ok(ReadStep::Progress(value)) => return Poll::Ready(Some(value)), + Ok(ReadStep::Response(response)) => { + self.state = State::Terminal(Ok(response)); + return Poll::Ready(None); + } + Ok(ReadStep::NeedMore) => {} + Err(error) => { + self.state = State::Terminal(Err(error.into())); + return Poll::Ready(None); + } + } + + let stream = self.stream.as_mut().unwrap(); + match stream.poll_read(usize::MAX, cx) { + Poll::Ready(Ok(Some(chunk))) => { + let State::Reading(reader) = &mut self.state else { + panic!("invalid state"); + }; + reader.push(chunk); + } + Poll::Ready(Ok(None)) => { + self.state = State::Terminal(Err(Error::MissingResponse.into())); + return Poll::Ready(None); + } + Poll::Ready(Err(error)) => { + self.state = State::Terminal(Err(CallError::Transport(error))); + return Poll::Ready(None); + } + Poll::Pending => return Poll::Pending, + } + } + } + + pub fn poll_next_progress(&mut self, cx: &mut Context<'_>) -> Poll> { + self.poll_step(cx) + } + + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + fn close_inner(&mut self, code: StreamCloseCode) { + self.state = State::Done; + if let Some(stream) = self.stream.take() { + stream.close(code); + } + } +} + +impl Drop for ProgressCall +where + M: Progress, + R: RpcRead, +{ + fn drop(&mut self) { + if matches!(self.state, State::Reading(_)) { + self.close_inner(StreamCloseCode::CANCELLED); + } + } +} + +impl Future for ProgressCall +where + M: Progress, + R: RpcRead, +{ + type Output = Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + loop { + match this.poll_step(cx) { + Poll::Ready(Some(_)) => {} + Poll::Ready(None) => match std::mem::replace(&mut this.state, State::Invalid) { + State::Terminal(result) => { + this.state = State::Done; + return Poll::Ready(result); + } + State::Done => panic!("polled after completion"), + State::Invalid => panic!("polled during state transition"), + State::Reading(_) => { + panic!("progress call reached terminal step without result") + } + }, + Poll::Pending => return Poll::Pending, + } + } + } +} diff --git a/ql-rpc/src/rpc/progress/codec.rs b/ql-rpc/src/rpc/progress/codec.rs new file mode 100644 index 00000000..a0dc1b8c --- /dev/null +++ b/ql-rpc/src/rpc/progress/codec.rs @@ -0,0 +1,145 @@ +use std::marker::PhantomData; + +use bytes::{BufMut, Bytes}; + +use crate::{codec, progress::Progress, CodecError, Error, RpcCodec}; + +pub enum ReadStep { + NeedMore, + Progress(M::Progress), + Response(M::Response), +} + +pub struct ResponseReader { + bytes: codec::ChunkQueue, + marker: PhantomData M>, +} + +impl Default for ResponseReader { + fn default() -> Self { + Self { + bytes: codec::ChunkQueue::default(), + marker: PhantomData, + } + } +} + +impl ResponseReader { + pub fn push(&mut self, chunk: Bytes) { + self.bytes.push(chunk); + } + + pub fn advance(&mut self) -> Result, CodecError> { + let Some((kind, mut body)) = self.bytes.try_take_tagged_part().map_err(CodecError::Rpc)? + else { + return Ok(ReadStep::NeedMore); + }; + + match kind { + x if x == FrameKind::Progress as u8 => { + let value = { + let value = M::Progress::decode_value(&mut body).map_err(CodecError::Codec)?; + drop(body); + value + }; + Ok(ReadStep::Progress(value)) + } + x if x == FrameKind::Response as u8 => { + let response = M::Response::decode_value(&mut body).map_err(CodecError::Codec)?; + drop(body); + if self.bytes.remaining() > 0 { + Err(CodecError::Rpc(Error::TrailingBytes)) + } else { + Ok(ReadStep::Response(response)) + } + } + other => Err(CodecError::Rpc(Error::UnexpectedFrameKind(other))), + } + } +} + +pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { + codec::encode_value_part(request, out) +} + +pub fn encode_progress(progress: &M::Progress, out: &mut (impl BufMut + AsMut<[u8]>)) { + encode_tagged_value_part(FrameKind::Progress, progress, out) +} + +pub fn encode_response(response: &M::Response, out: &mut (impl BufMut + AsMut<[u8]>)) { + encode_tagged_value_part(FrameKind::Response, response, out) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +enum FrameKind { + Progress = 1, + Response = 2, +} + +fn encode_tagged_value_part>( + kind: FrameKind, + value: &T, + out: &mut B, +) { + out.put_u8(kind as u8); + let payload_start = codec::reserve_length(out); + value.encode_value(out); + codec::backpatch_length(out, payload_start); +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::{encode_progress, encode_response, ReadStep, ResponseReader}; + use crate::{progress::Progress, Route, RouteId}; + + struct Watch; + + impl Route for Watch { + const ROUTE: RouteId = RouteId::from_u32(11); + } + + impl Progress for Watch { + type Error = core::convert::Infallible; + type Request = Vec; + type Progress = Vec; + type Response = Vec; + } + + #[test] + fn response_reader_emits_progress_then_response() { + let mut encoded = Vec::new(); + encode_progress::(&b"10%".to_vec(), &mut encoded); + encode_response::(&b"done".to_vec(), &mut encoded); + + let mut reader = ResponseReader::::default(); + reader.push(Bytes::from(encoded)); + + match reader.advance().unwrap() { + ReadStep::Progress(value) => { + assert_eq!(value, b"10%".to_vec()); + } + _ => unreachable!(), + }; + match reader.advance().unwrap() { + ReadStep::Response(value) => assert_eq!(value, b"done".to_vec()), + _ => unreachable!(), + } + } + + #[test] + fn response_reader_handles_response_only() { + let mut encoded = Vec::new(); + encode_response::(&b"done".to_vec(), &mut encoded); + + let mut reader = ResponseReader::::default(); + reader.push(Bytes::from(encoded)); + + match reader.advance().unwrap() { + ReadStep::Response(value) => assert_eq!(value, b"done".to_vec()), + _ => unreachable!(), + } + } +} diff --git a/ql-rpc/src/rpc/progress/mod.rs b/ql-rpc/src/rpc/progress/mod.rs new file mode 100644 index 00000000..b21c826d --- /dev/null +++ b/ql-rpc/src/rpc/progress/mod.rs @@ -0,0 +1,27 @@ +use super::Route; +use crate::RpcCodec; + +pub(crate) mod client; +pub(crate) mod codec; +pub(crate) mod server; + +pub use client::ProgressCall; +pub use codec::{encode_progress, encode_request, encode_response, ReadStep, ResponseReader}; +pub use server::{ProgressHandler, ProgressHandlerLocal, ProgressResponder}; + +/// rpc where the responder streams progress values before a final response +/// +/// the request is length-delimited +/// response frames are tagged so the client can distinguish +/// [`Self::Progress`] items from the final [`Self::Response`] +/// reaching eof before the final response is an error +pub trait Progress: Route { + /// codec error shared by request, progress, and response values + type Error; + /// typed input sent by the caller + type Request: RpcCodec; + /// typed progress item emitted before completion + type Progress: RpcCodec; + /// typed terminal response that completes the call + type Response: RpcCodec; +} diff --git a/ql-rpc/src/rpc/progress/server.rs b/ql-rpc/src/rpc/progress/server.rs new file mode 100644 index 00000000..b94421cf --- /dev/null +++ b/ql-rpc/src/rpc/progress/server.rs @@ -0,0 +1,106 @@ +use std::{future::Future, marker::PhantomData}; + +use bytes::Bytes; + +use crate::{ + finish_bytes, + progress::{encode_progress, encode_response, Progress}, + rpc::read_framed_request, + write_bytes, RouterConfig, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, +}; + +#[trait_variant::make(ProgressHandler: Send)] +pub trait ProgressHandlerLocal +where + M: Progress, + St: RpcStream, +{ + async fn handle(self, request: M::Request, responder: ProgressResponder); + + fn handle_transport_error(&self, _error: &St::Error) {} +} + +pub struct ProgressResponder +where + M: Progress, + W: RpcWrite, +{ + writer: Option, + marker: PhantomData M>, +} + +impl ProgressResponder +where + M: Progress, + W: RpcWrite, +{ + pub(crate) fn new(writer: W) -> Self { + Self { + writer: Some(writer), + marker: PhantomData, + } + } + + pub async fn send(&mut self, progress: M::Progress) -> Result<(), W::Error> { + let writer = self.writer.as_mut().unwrap(); + let mut encoded = Vec::new(); + encode_progress::(&progress, &mut encoded); + write_bytes(writer, Bytes::from(encoded)).await + } + + pub async fn finish(mut self, response: M::Response) -> Result<(), W::Error> { + let mut writer = self.writer.take().unwrap(); + let mut encoded = Vec::new(); + encode_response::(&response, &mut encoded); + write_bytes(&mut writer, Bytes::from(encoded)).await?; + finish_bytes(&mut writer).await + } + + pub fn close(mut self, code: StreamCloseCode) { + if let Some(writer) = self.writer.take() { + writer.close(code); + } + } +} + +impl Drop for ProgressResponder +where + M: Progress, + W: RpcWrite, +{ + fn drop(&mut self) { + if let Some(writer) = self.writer.take() { + writer.close(StreamCloseCode::CANCELLED); + } + } +} + +pub(crate) async fn handle_progress_inner( + state: S, + config: RouterConfig, + mut reader: St::Reader, + writer: St::Writer, + handle: H, + handle_transport_error: E, +) where + M: Progress + 'static, + St: RpcStream + 'static, + H: FnOnce(S, M::Request, ProgressResponder) -> HF, + HF: Future, + E: FnOnce(&S, &St::Error), +{ + let request = match read_framed_request::(&mut reader, config).await { + Ok(request) => request, + Err(error) => { + let code = error.close_code(); + handle_transport_error(&state, &error); + if let Some(code) = code { + reader.close(code); + writer.close(code); + } + return; + } + }; + + handle(state, request, ProgressResponder::new(writer)).await; +} diff --git a/ql-rpc/src/rpc/request/client.rs b/ql-rpc/src/rpc/request/client.rs new file mode 100644 index 00000000..e7ffb845 --- /dev/null +++ b/ql-rpc/src/rpc/request/client.rs @@ -0,0 +1,34 @@ +use bytes::BufMut; + +use crate::{read_bytes, request::Request, CallError, ChunkQueue, RpcCodec, RpcRead}; + +pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { + request.encode_value(out) +} + +pub fn encode_response(response: &M::Response, out: &mut (impl BufMut + AsMut<[u8]>)) { + response.encode_value(out) +} + +pub async fn read_response( + mut reader: R, +) -> Result> +where + M: Request, + R: RpcRead, +{ + let mut bytes = ChunkQueue::default(); + + while let Some(chunk) = read_bytes(&mut reader, usize::MAX) + .await + .map_err(CallError::Transport)? + { + bytes.push(chunk); + } + + let value = M::Response::decode_value(&mut bytes).map_err(CallError::Codec)?; + if bytes.remaining() > 0 { + return Err(crate::Error::TrailingBytes.into()); + } + Ok(value) +} diff --git a/ql-rpc/src/rpc/request/mod.rs b/ql-rpc/src/rpc/request/mod.rs new file mode 100644 index 00000000..adf32597 --- /dev/null +++ b/ql-rpc/src/rpc/request/mod.rs @@ -0,0 +1,23 @@ +use super::Route; +use crate::RpcCodec; + +pub(crate) mod client; +pub(crate) mod server; + +pub use client::{encode_request, encode_response, read_response}; +pub use server::{RequestHandler, RequestHandlerLocal, Response}; + +/// request-response rpc with exactly one typed value in each direction +/// +/// the request is read to eof on the server side, so callers must finish the +/// request stream after encoding [`Self::Request`] +/// the response is also read to eof and rejects trailing bytes after +/// [`Self::Response`] +pub trait Request: Route { + /// codec error shared by request and response values + type Error; + /// typed input sent by the caller + type Request: RpcCodec; + /// typed output returned by the responder + type Response: RpcCodec; +} diff --git a/ql-rpc/src/rpc/request/server.rs b/ql-rpc/src/rpc/request/server.rs new file mode 100644 index 00000000..5211cce2 --- /dev/null +++ b/ql-rpc/src/rpc/request/server.rs @@ -0,0 +1,96 @@ +use std::{future::Future, marker::PhantomData}; + +use bytes::Bytes; + +use crate::{ + finish_bytes, request::Request as RequestRpc, rpc::read_eof_request, write_bytes, RouterConfig, + RpcCodec, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, +}; + +#[trait_variant::make(RequestHandler: Send)] +pub trait RequestHandlerLocal +where + M: RequestRpc, + St: RpcStream, +{ + async fn handle(self, message: M::Request, responder: Response); + + fn handle_transport_error(&self, _error: &St::Error) {} +} + +pub struct Response +where + W: RpcWrite, +{ + writer: Option, + marker: PhantomData T>, +} + +impl Response +where + T: RpcCodec, + W: RpcWrite, +{ + pub(crate) fn new(writer: W) -> Self { + Self { + writer: Some(writer), + marker: PhantomData, + } + } + + pub async fn respond(mut self, response: T) -> Result<(), W::Error> { + let mut writer = self.writer.take().unwrap(); + let mut encoded = Vec::new(); + response.encode_value(&mut encoded); + write_bytes(&mut writer, Bytes::from(encoded)).await?; + finish_bytes(&mut writer).await?; + Ok(()) + } + + pub fn close(mut self, code: StreamCloseCode) { + if let Some(writer) = self.writer.take() { + writer.close(code); + } + } +} + +impl Drop for Response +where + W: RpcWrite, +{ + fn drop(&mut self) { + if let Some(writer) = self.writer.take() { + writer.close(StreamCloseCode::CANCELLED); + } + } +} + +pub(crate) async fn handle_request_inner( + state: S, + config: RouterConfig, + mut reader: St::Reader, + writer: St::Writer, + handle: H, + handle_transport_error: E, +) where + M: RequestRpc + 'static, + St: RpcStream + 'static, + H: FnOnce(S, M::Request, Response) -> HF, + HF: Future, + E: FnOnce(&S, &St::Error), +{ + let request = match read_eof_request::(&mut reader, config).await { + Ok(request) => request, + Err(error) => { + let code = error.close_code(); + handle_transport_error(&state, &error); + if let Some(code) = code { + reader.close(code); + writer.close(code); + } + return; + } + }; + + handle(state, request, Response::new(writer)).await; +} diff --git a/ql-rpc/src/rpc/subscription/client.rs b/ql-rpc/src/rpc/subscription/client.rs new file mode 100644 index 00000000..fe6aa5b1 --- /dev/null +++ b/ql-rpc/src/rpc/subscription/client.rs @@ -0,0 +1,99 @@ +use std::{ + future::poll_fn, + task::{Context, Poll}, +}; + +use crate::{ + subscription::{ReadStep, ResponseReader, Subscription}, + CallError, RpcRead, StreamCloseCode, +}; + +pub struct SubscriptionCall +where + M: Subscription, + R: RpcRead, +{ + stream: Option, + reader: ResponseReader, +} + +impl SubscriptionCall +where + M: Subscription, + R: RpcRead, +{ + pub fn new(stream: R) -> Self { + Self { + stream: Some(stream), + reader: ResponseReader::default(), + } + } + + pub async fn next_event(&mut self) -> Option>> { + poll_fn(|cx| self.poll_next_event(cx)).await + } + + pub fn poll_next_event( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>>> { + if self.stream.is_none() { + return Poll::Ready(None); + } + + loop { + match self.reader.advance() { + Ok(ReadStep::Item(value)) => return Poll::Ready(Some(Ok(value))), + Ok(ReadStep::NeedMore) => {} + Err(error) => { + self.stream.take(); + return Poll::Ready(Some(Err(error.into()))); + } + } + + let stream = self.stream.as_mut().unwrap(); + match stream.poll_read(usize::MAX, cx) { + Poll::Ready(Ok(Some(chunk))) => { + self.reader.push(chunk); + } + Poll::Ready(Ok(None)) => { + if self.reader.is_empty() { + self.stream.take(); + return Poll::Ready(None); + } + self.stream.take(); + return Poll::Ready(Some(Err(crate::Error::Truncated.into()))); + } + Poll::Ready(Err(error)) => { + self.stream.take(); + return Poll::Ready(Some(Err(CallError::Transport(error)))); + } + Poll::Pending => { + return Poll::Pending; + } + } + } + } + + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + fn close_inner(&mut self, code: StreamCloseCode) { + if let Some(stream) = self.stream.take() { + stream.close(code); + } + } +} + +impl Drop for SubscriptionCall +where + M: Subscription, + R: RpcRead, +{ + fn drop(&mut self) { + if self.stream.is_some() { + self.close_inner(StreamCloseCode::CANCELLED); + } + } +} diff --git a/ql-rpc/src/rpc/subscription/codec.rs b/ql-rpc/src/rpc/subscription/codec.rs new file mode 100644 index 00000000..bdd16209 --- /dev/null +++ b/ql-rpc/src/rpc/subscription/codec.rs @@ -0,0 +1,58 @@ +use std::marker::PhantomData; + +use bytes::{BufMut, Bytes}; + +use crate::{codec, subscription::Subscription, CodecError, RpcCodec}; + +pub fn encode_request( + request: &M::Request, + out: &mut (impl BufMut + AsMut<[u8]>), +) { + request.encode_value(out) +} + +pub fn encode_item(item: &M::Event, out: &mut (impl BufMut + AsMut<[u8]>)) { + codec::encode_value_part(item, out) +} + +pub enum ReadStep { + NeedMore, + Item(M::Event), +} + +pub struct ResponseReader { + bytes: codec::ChunkQueue, + marker: PhantomData M>, +} + +impl Default for ResponseReader { + fn default() -> Self { + Self { + bytes: codec::ChunkQueue::default(), + marker: PhantomData, + } + } +} + +impl ResponseReader { + pub fn push(&mut self, chunk: Bytes) { + self.bytes.push(chunk); + } + + pub fn is_empty(&self) -> bool { + self.bytes.remaining() == 0 + } + + pub fn advance(&mut self) -> Result, CodecError> { + let Some(mut body) = self.bytes.try_take_part()? else { + return Ok(ReadStep::NeedMore); + }; + + let item = { + let item = M::Event::decode_value(&mut body).map_err(CodecError::Codec)?; + drop(body); + item + }; + Ok(ReadStep::Item(item)) + } +} diff --git a/ql-rpc/src/rpc/subscription/mod.rs b/ql-rpc/src/rpc/subscription/mod.rs new file mode 100644 index 00000000..672eb9bc --- /dev/null +++ b/ql-rpc/src/rpc/subscription/mod.rs @@ -0,0 +1,23 @@ +use super::Route; +use crate::RpcCodec; + +pub(crate) mod client; +pub(crate) mod codec; +pub(crate) mod server; + +pub use client::SubscriptionCall; +pub use codec::{encode_item, encode_request, ReadStep, ResponseReader}; +pub use server::{SubscriptionHandler, SubscriptionHandlerLocal, SubscriptionResponder}; + +/// rpc where one request opens a stream of typed events +/// +/// event frames are length-delimited and the stream ends cleanly at eof +/// any partial trailing frame is reported as truncation on the client side +pub trait Subscription: Route { + /// codec error shared by request and event values + type Error; + /// typed input that starts the subscription + type Request: RpcCodec; + /// typed event yielded by the responder + type Event: RpcCodec; +} diff --git a/ql-rpc/src/rpc/subscription/server.rs b/ql-rpc/src/rpc/subscription/server.rs new file mode 100644 index 00000000..6dfdd4b0 --- /dev/null +++ b/ql-rpc/src/rpc/subscription/server.rs @@ -0,0 +1,105 @@ +use std::{future::Future, marker::PhantomData}; + +use bytes::Bytes; + +use crate::{ + codec, finish_bytes, rpc::read_eof_request, subscription::Subscription as SubscriptionRpc, + write_bytes, RouterConfig, RpcCodec, RpcRead, RpcStream, RpcWrite, StreamCloseCode, + StreamError, +}; + +#[trait_variant::make(SubscriptionHandler: Send)] +pub trait SubscriptionHandlerLocal +where + M: SubscriptionRpc, + St: RpcStream, +{ + async fn handle( + self, + message: M::Request, + responder: SubscriptionResponder, + ); + + fn handle_transport_error(&self, _error: &St::Error) {} +} + +pub struct SubscriptionResponder +where + W: RpcWrite, +{ + writer: Option, + marker: PhantomData T>, +} + +impl SubscriptionResponder +where + T: RpcCodec, + W: RpcWrite, +{ + pub(crate) fn new(writer: W) -> Self { + Self { + writer: Some(writer), + marker: PhantomData, + } + } + + pub async fn send(&mut self, event: T) -> Result<(), W::Error> { + let writer = self.writer.as_mut().unwrap(); + let mut encoded = Vec::new(); + codec::encode_value_part(&event, &mut encoded); + write_bytes(writer, Bytes::from(encoded)).await?; + Ok(()) + } + + pub async fn finish(mut self) -> Result<(), W::Error> { + let mut writer = self.writer.take().unwrap(); + finish_bytes(&mut writer).await + } + + pub fn close(mut self, code: StreamCloseCode) { + if let Some(writer) = self.writer.take() { + writer.close(code); + } + } +} + +impl Drop for SubscriptionResponder +where + W: RpcWrite, +{ + fn drop(&mut self) { + if let Some(writer) = self.writer.take() { + writer.close(StreamCloseCode::CANCELLED); + } + } +} + +pub(crate) async fn handle_subscription_inner( + state: S, + config: RouterConfig, + mut reader: St::Reader, + writer: St::Writer, + handle: H, + handle_transport_error: E, +) where + M: SubscriptionRpc + 'static, + St: RpcStream + 'static, + H: FnOnce(S, M::Request, SubscriptionResponder) -> HF, + HF: Future, + E: FnOnce(&S, &St::Error), +{ + let request = match read_eof_request::(&mut reader, config).await { + Ok(request) => request, + Err(error) => { + let code = error.close_code(); + handle_transport_error(&state, &error); + if let Some(code) = code { + reader.close(code); + writer.close(code); + } + return; + } + }; + + handle(state, request, SubscriptionResponder::new(writer)).await; +} diff --git a/ql-rpc/src/rpc/upload/client.rs b/ql-rpc/src/rpc/upload/client.rs new file mode 100644 index 00000000..b41dedcd --- /dev/null +++ b/ql-rpc/src/rpc/upload/client.rs @@ -0,0 +1,146 @@ +use bytes::{BufMut, Bytes}; + +use crate::{ + finish_bytes, read_bytes, + rpc::parts::{encode_body_chunk, encode_end_part, encode_finish, encode_part_header}, + upload::Upload, + write_bytes, CallError, ChunkQueue, RpcCodec, RpcRead, RpcWrite, StreamCloseCode, +}; + +pub struct UploadCall +where + M: Upload, + W: RpcWrite, + R: RpcRead, +{ + writer: Option, + reader: Option, + marker: std::marker::PhantomData M>, +} + +pub struct UploadPartWriter<'a, M, W, R> +where + M: Upload, + W: RpcWrite, + R: RpcRead, +{ + parent: &'a mut UploadCall, + finished: bool, +} + +impl UploadCall +where + M: Upload, + W: RpcWrite, + R: RpcRead, +{ + pub fn new(writer: W, reader: R) -> Self { + Self { + writer: Some(writer), + reader: Some(reader), + marker: std::marker::PhantomData, + } + } + + pub async fn start_part( + &mut self, + part_header: M::PartHeader, + ) -> Result, W::Error> { + let writer = self.writer.as_mut().unwrap(); + let mut encoded = Vec::new(); + encode_part_header(&part_header, &mut encoded); + write_bytes(writer, Bytes::from(encoded)).await?; + Ok(UploadPartWriter { + parent: self, + finished: false, + }) + } + + pub async fn finish(mut self) -> Result> { + let mut writer = self.writer.take().unwrap(); + let mut encoded = Vec::new(); + encode_finish(&mut encoded); + write_bytes(&mut writer, Bytes::from(encoded)) + .await + .map_err(CallError::Transport)?; + finish_bytes(&mut writer) + .await + .map_err(CallError::Transport)?; + + let mut reader = self.reader.take().unwrap(); + let mut bytes = ChunkQueue::default(); + + while let Some(chunk) = read_bytes(&mut reader, usize::MAX) + .await + .map_err(CallError::Transport)? + { + bytes.push(chunk); + } + + let value = M::Response::decode_value(&mut bytes).map_err(CallError::Codec)?; + if bytes.remaining() > 0 { + return Err(crate::Error::TrailingBytes.into()); + } + Ok(value) + } + + fn close(&mut self, code: StreamCloseCode) { + if let Some(reader) = self.reader.take() { + reader.close(code); + } + if let Some(writer) = self.writer.take() { + writer.close(code); + } + } +} + +impl Drop for UploadCall +where + M: Upload, + W: RpcWrite, + R: RpcRead, +{ + fn drop(&mut self) { + self.close(StreamCloseCode::CANCELLED); + } +} + +impl UploadPartWriter<'_, M, W, R> +where + M: Upload, + W: RpcWrite, + R: RpcRead, +{ + pub async fn send(&mut self, bytes: Bytes) -> Result<(), W::Error> { + let writer = self.parent.writer.as_mut().unwrap(); + let mut encoded = Vec::new(); + encode_body_chunk(&bytes, &mut encoded); + write_bytes(writer, Bytes::from(encoded)).await + } + + pub async fn finish(mut self) -> Result<(), W::Error> { + let writer = self.parent.writer.as_mut().unwrap(); + let mut encoded = Vec::new(); + encode_end_part(&mut encoded); + write_bytes(writer, Bytes::from(encoded)).await?; + self.finished = true; + Ok(()) + } +} + +impl Drop for UploadPartWriter<'_, M, W, R> +where + M: Upload, + W: RpcWrite, + R: RpcRead, +{ + fn drop(&mut self) { + if !self.finished { + self.parent.close(StreamCloseCode::CANCELLED); + } + } +} + +pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { + crate::codec::encode_value_part(request, out) +} diff --git a/ql-rpc/src/rpc/upload/mod.rs b/ql-rpc/src/rpc/upload/mod.rs new file mode 100644 index 00000000..9f96a824 --- /dev/null +++ b/ql-rpc/src/rpc/upload/mod.rs @@ -0,0 +1,26 @@ +use super::Route; +use crate::RpcCodec; + +pub(crate) mod client; +pub(crate) mod server; + +pub use client::{encode_request, UploadCall, UploadPartWriter}; +pub use server::{UploadHandler, UploadHandlerLocal, UploadPart, UploadReader, UploadResponder}; + +/// rpc where the caller uploads zero or more byte parts after a typed request +/// +/// the typed request usually describes how the responder should interpret the +/// following parts +/// the request is length-delimited so raw upload bytes can follow immediately +/// once the upload reaches eof, the responder returns one typed +/// [`Self::Response`] +pub trait Upload: Route { + /// codec error shared by request and response values + type Error; + /// typed input needed before request body bytes arrive + type Request: RpcCodec; + /// typed metadata available before each byte part arrives + type PartHeader: RpcCodec; + /// typed terminal result after the upload body is fully read + type Response: RpcCodec; +} diff --git a/ql-rpc/src/rpc/upload/server.rs b/ql-rpc/src/rpc/upload/server.rs new file mode 100644 index 00000000..d2e6765b --- /dev/null +++ b/ql-rpc/src/rpc/upload/server.rs @@ -0,0 +1,243 @@ +use std::future::{poll_fn, Future}; + +use bytes::Bytes; + +use crate::{ + request::Response, + rpc::{ + parts::{FrameKind, PartFrameReader, PartReadStep}, + read_framed_request_prefix, + }, + RouterConfig, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, Upload, +}; + +#[trait_variant::make(UploadHandler: Send)] +pub trait UploadHandlerLocal +where + M: Upload, + St: RpcStream, +{ + async fn handle( + self, + request: M::Request, + upload: UploadReader, + responder: UploadResponder, + ); + + fn handle_transport_error(&self, _error: &St::Error) {} +} + +pub struct UploadReader +where + M: Upload, + R: RpcRead, +{ + stream: Option, + reader: PartFrameReader, +} + +pub struct UploadPart<'a, M, R> +where + M: Upload, + R: RpcRead, +{ + parent: &'a mut UploadReader, + finished: bool, +} + +pub struct UploadResponder +where + W: RpcWrite, +{ + inner: Response, +} + +impl UploadReader +where + M: Upload, + R: RpcRead, +{ + pub async fn next_part( + &mut self, + ) -> Result)>, crate::CallError> + { + if self.stream.is_none() { + return Ok(None); + } + + match self.read_frame().await? { + PartReadStep::PartHeader(value) => Ok(Some(( + value, + UploadPart { + parent: self, + finished: false, + }, + ))), + PartReadStep::Finish => { + self.stream.take(); + Ok(None) + } + PartReadStep::BodyBytes(_) => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::BodyChunk.tag()).into()) + } + PartReadStep::EndPart => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::EndPart.tag()).into()) + } + PartReadStep::NeedMore => unreachable!("read_frame waits for a complete frame"), + } + } + + async fn read_frame( + &mut self, + ) -> Result, crate::CallError> { + loop { + match self.reader.advance() { + Ok(PartReadStep::NeedMore) => {} + Ok(step) => return Ok(step), + Err(error) => return Err(error.into()), + } + + let stream = self.stream.as_mut().unwrap(); + match poll_fn(|cx| stream.poll_read(usize::MAX, cx)).await { + Ok(Some(chunk)) => { + self.reader.push(chunk); + } + Ok(None) => return Err(crate::Error::Truncated.into()), + Err(error) => return Err(crate::CallError::Transport(error)), + } + } + } + + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + fn close_inner(&mut self, code: StreamCloseCode) { + if let Some(stream) = self.stream.take() { + stream.close(code); + } + } +} + +impl Drop for UploadReader +where + M: Upload, + R: RpcRead, +{ + fn drop(&mut self) { + if self.stream.is_some() { + self.close_inner(StreamCloseCode::CANCELLED); + } + } +} + +impl UploadPart<'_, M, R> +where + M: Upload, + R: RpcRead, +{ + pub async fn read_chunk( + &mut self, + ) -> Result, crate::CallError> { + if self.finished { + return Ok(None); + } + + match self.parent.read_frame().await? { + PartReadStep::BodyBytes(bytes) => Ok(Some(bytes)), + PartReadStep::EndPart => { + self.finished = true; + Ok(None) + } + PartReadStep::PartHeader(_) => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::PartHeader.tag()).into()) + } + PartReadStep::Finish => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::Finish.tag()).into()) + } + PartReadStep::NeedMore => unreachable!("read_frame waits for a complete frame"), + } + } + + pub fn close(mut self, code: StreamCloseCode) { + self.parent.close_inner(code); + self.finished = true; + } +} + +impl Drop for UploadPart<'_, M, R> +where + M: Upload, + R: RpcRead, +{ + fn drop(&mut self) { + if !self.finished { + self.parent.close_inner(StreamCloseCode::CANCELLED); + } + } +} + +impl UploadResponder +where + T: crate::RpcCodec, + W: RpcWrite, +{ + pub(crate) fn new(writer: W) -> Self { + Self { + inner: Response::new(writer), + } + } + + pub async fn respond(self, response: T) -> Result<(), W::Error> { + self.inner.respond(response).await + } + + pub fn close(self, code: StreamCloseCode) { + self.inner.close(code); + } +} + +pub(crate) async fn handle_upload_inner( + state: S, + config: RouterConfig, + mut reader: St::Reader, + writer: St::Writer, + handle: H, + handle_transport_error: E, +) where + M: Upload + 'static, + St: RpcStream + 'static, + H: FnOnce( + S, + M::Request, + UploadReader, + UploadResponder, + ) -> HF, + HF: Future, + E: FnOnce(&S, &St::Error), +{ + let (request, buffered) = + match read_framed_request_prefix::(&mut reader, config).await { + Ok(value) => value, + Err(error) => { + let code = error.close_code(); + handle_transport_error(&state, &error); + if let Some(code) = code { + reader.close(code); + writer.close(code); + } + return; + } + }; + + handle( + state, + request, + UploadReader { + stream: Some(reader), + reader: PartFrameReader::new(buffered), + }, + UploadResponder::new(writer), + ) + .await; +} diff --git a/ql-rpc/src/rpc/utils.rs b/ql-rpc/src/rpc/utils.rs new file mode 100644 index 00000000..bf5f49ea --- /dev/null +++ b/ql-rpc/src/rpc/utils.rs @@ -0,0 +1,120 @@ +use crate::{ + read_bytes, ChunkQueue, CodecError, FramedPrefixStep, FramedReadStep, FramedReader, + RouterConfig, RpcCodec, RpcRead, StreamCloseCode, +}; + +/// reads one length-delimited value and rejects trailing bytes +pub(crate) async fn read_framed_request( + reader: &mut R, + config: RouterConfig, +) -> Result +where + T: RpcCodec, + R: RpcRead, +{ + let mut value_reader = FramedReader::::default(); + let mut total_read = 0usize; + + let value = loop { + match value_reader.advance() { + Ok(FramedReadStep::Value(value)) => break value, + Ok(FramedReadStep::NeedMore(next)) => value_reader = next, + Err(CodecError::Rpc(_error)) => return Err(StreamCloseCode::REFUSED.into()), + Err(CodecError::Codec(_error)) => return Err(StreamCloseCode::REFUSED.into()), + } + + let remaining = config.max_request_bytes.saturating_sub(total_read); + if remaining == 0 { + return Err(StreamCloseCode::LIMIT.into()); + } + + match read_bytes(reader, remaining).await { + Ok(Some(chunk)) => { + total_read += chunk.len(); + value_reader = value_reader.push(chunk); + } + Ok(None) => return Err(StreamCloseCode::REFUSED.into()), + Err(error) => return Err(error), + } + }; + + let remaining = config.max_request_bytes.saturating_sub(total_read); + let probe = remaining.max(1); + match read_bytes(reader, probe).await { + Ok(None) => Ok(value), + Ok(Some(_)) if remaining == 0 => Err(StreamCloseCode::LIMIT.into()), + Ok(Some(_)) => Err(StreamCloseCode::REFUSED.into()), + Err(error) => Err(error), + } +} + +/// reads one length-delimited value and returns any bytes already buffered +pub(crate) async fn read_framed_request_prefix( + reader: &mut R, + config: RouterConfig, +) -> Result<(T, ChunkQueue), R::Error> +where + T: RpcCodec, + R: RpcRead, +{ + let mut value_reader = FramedReader::::default(); + let mut total_read = 0usize; + + loop { + match value_reader.advance_prefix() { + Ok(FramedPrefixStep::Value { value, bytes }) => return Ok((value, bytes)), + Ok(FramedPrefixStep::NeedMore(next)) => value_reader = next, + Err(CodecError::Rpc(_error)) => return Err(StreamCloseCode::REFUSED.into()), + Err(CodecError::Codec(_error)) => return Err(StreamCloseCode::REFUSED.into()), + } + + let remaining = config.max_request_bytes.saturating_sub(total_read); + if remaining == 0 { + return Err(StreamCloseCode::LIMIT.into()); + } + + match read_bytes(reader, remaining).await { + Ok(Some(chunk)) => { + total_read += chunk.len(); + value_reader = value_reader.push(chunk); + } + Ok(None) => return Err(StreamCloseCode::REFUSED.into()), + Err(error) => return Err(error), + } + } +} + +/// reads one eof-delimited value up to the configured request limit +pub(crate) async fn read_eof_request( + reader: &mut R, + config: RouterConfig, +) -> Result +where + T: RpcCodec, + R: RpcRead, +{ + let mut bytes = ChunkQueue::default(); + let mut total_read = 0usize; + + loop { + let remaining = config.max_request_bytes.saturating_sub(total_read); + let probe = remaining.max(1); + match read_bytes(reader, probe).await { + Ok(Some(chunk)) => { + if chunk.len() > remaining { + return Err(StreamCloseCode::LIMIT.into()); + } + total_read += chunk.len(); + bytes.push(chunk); + } + Ok(None) => break, + Err(error) => return Err(error), + } + } + + let value = T::decode_value(&mut bytes).map_err(|_error| StreamCloseCode::REFUSED)?; + if bytes.remaining() > 0 { + return Err(StreamCloseCode::REFUSED.into()); + } + Ok(value) +} diff --git a/ql-rpc/src/stream.rs b/ql-rpc/src/stream.rs new file mode 100644 index 00000000..f6174efd --- /dev/null +++ b/ql-rpc/src/stream.rs @@ -0,0 +1,89 @@ +use std::{ + future::poll_fn, + task::{Context, Poll}, +}; + +use bytes::Bytes; + +use crate::{RouteId, StreamCloseCode}; + +pub trait RpcStream { + type Error: StreamError; + type Reader: RpcRead; + type Writer: RpcWrite; + + fn route_id(&self) -> Option; + fn split(self) -> (Self::Reader, Self::Writer); +} + +pub trait RpcRead { + type Error: StreamError; + + /// reads inbound bytes until eof or error + fn poll_read( + &mut self, + max_len: usize, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>; + + /// aborts the read side + fn close(self, code: StreamCloseCode); +} + +pub trait RpcWrite { + type Error: StreamError; + + /// writes outbound bytes before finish or close + fn poll_write( + &mut self, + bytes: &mut Bytes, + cx: &mut Context<'_>, + ) -> Poll>; + + /// completes the write side and must be polled until ready without further write or close calls + fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll>; + + /// aborts the write side before finish + fn close(self, code: StreamCloseCode); +} + +pub trait StreamError: From { + fn close_code(&self) -> Option; +} + +impl StreamError for StreamCloseCode { + fn close_code(&self) -> Option { + Some(*self) + } +} + +pub async fn read_bytes(reader: &mut R, max_len: usize) -> Result, R::Error> +where + R: RpcRead, +{ + poll_fn(|cx| reader.poll_read(max_len, cx)).await +} + +pub async fn write_bytes(writer: &mut W, bytes: Bytes) -> Result<(), W::Error> +where + W: RpcWrite, +{ + let mut bytes = bytes; + poll_fn(|cx| writer.poll_write(&mut bytes, cx)).await +} + +pub async fn finish_bytes(writer: &mut W) -> Result<(), W::Error> +where + W: RpcWrite, +{ + poll_fn(|cx| writer.poll_finish(cx)).await +} + +pub fn close_stream(stream: St, code: StreamCloseCode) +where + St: RpcStream, +{ + let (reader, writer) = stream.split(); + reader.close(code); + writer.close(code); +} diff --git a/ql-runtime/Cargo.toml b/ql-runtime/Cargo.toml new file mode 100644 index 00000000..56fd327f --- /dev/null +++ b/ql-runtime/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "ql-runtime" +version = "0.1.0" +edition = "2021" +description = "Quantum Link runtime" +license = "Proprietary" + +[features] +default = [] +log = ["dep:log"] +rpc = ["dep:ql-rpc"] + +[dependencies] +async-channel = { version = "2.5" } +bytes = "1" +diatomic-waker = { version = "0.2.3", default-features = false } +futures-lite = { version = "2.5" } +log = { version = "0.4", optional = true } +oneshot = { version = "0.1.11" } +ql-fsm = { path = "../ql-fsm" } +ql-rpc = { path = "../ql-rpc", optional = true } +ql-wire = { path = "../ql-wire" } + +[dev-dependencies] +env_logger = "0.11" +log = "0.4" +ql-wire = { path = "../ql-wire", features = ["test-utils"] } +tokio = { version = "1.44", features = ["macros", "rt", "time", "sync"] } + +[target.'cfg(loom)'.dev-dependencies] +event-listener = { version = "5.4", features = ["loom"] } +loom = "0.7" + +[lints.rust] +unexpected_cfgs = { level = "warn", check-cfg = ['cfg(loom)'] } diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs new file mode 100644 index 00000000..4a47a45e --- /dev/null +++ b/ql-runtime/src/command.rs @@ -0,0 +1,57 @@ +use ql_fsm::{NoSessionError, PairingInvite}; +use ql_wire::{ + CloseTarget, PairingToken, PeerBundle, RouteId, SessionCloseCode, StreamCloseCode, StreamId, +}; + +use crate::{StreamReader, StreamWriter}; + +pub enum Command { + BindPeer { + peer: PeerBundle, + }, + Connect, + ArmPairing { + token: PairingToken, + }, + DisarmPairing, + StartPairing { + invite: PairingInvite, + }, + OpenStream { + route_id: RouteId, + start: oneshot::Sender>, + }, + PollInbound { + stream_id: StreamId, + }, + PollStream { + stream_id: StreamId, + }, + CloseSession { + code: SessionCloseCode, + }, + Unpair, + CloseStream { + stream_id: StreamId, + target: CloseTarget, + code: StreamCloseCode, + }, +} + +impl Command { + pub fn kind(&self) -> &'static str { + match self { + Self::BindPeer { .. } => "BindPeer", + Self::Connect => "Connect", + Self::ArmPairing { .. } => "ArmPairing", + Self::DisarmPairing => "DisarmPairing", + Self::StartPairing { .. } => "StartPairing", + Self::OpenStream { .. } => "OpenStream", + Self::PollInbound { .. } => "PollInbound", + Self::PollStream { .. } => "PollStream", + Self::CloseSession { .. } => "CloseSession", + Self::Unpair => "Unpair", + Self::CloseStream { .. } => "CloseStream", + } + } +} diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs new file mode 100644 index 00000000..35de1bf0 --- /dev/null +++ b/ql-runtime/src/driver/mod.rs @@ -0,0 +1,609 @@ +mod state; +#[cfg(test)] +mod test; + +use std::{ + collections::{ + hash_map::{Entry, OccupiedEntry}, + HashMap, + }, + future::Future, + pin::{pin, Pin}, + task::{Context, Poll}, + time::Instant, +}; + +use async_channel::Recv; +use futures_lite::future::{poll_fn, yield_now}; +use ql_fsm::{Event, QlFsm, WriteId}; +use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; + +use self::state::{DriverState, DriverStreamIo, InboundIo, InboundWriteResult, OutboundIo}; +use crate::{ + command::Command, + handle::QlStream, + io, log, + platform::{QlInbound, QlPlatform, QlTimer}, + QlStreamError, Runtime, RuntimeHandle, +}; + +impl Runtime

{ + #[allow(clippy::future_not_send)] + pub async fn run(self) { + let Self { + identity, + mut platform, + config, + rx, + tx, + } = self; + + let mut fsm = QlFsm::new(config.fsm, identity, Instant::now()); + + let mut state = DriverState { + streams: HashMap::new(), + runtime_tx: tx, + max_concurrent_message_writes: config.max_concurrent_message_writes, + }; + + let mut in_flight = Vec::new(); + let timer = platform.timer(); + let mut timer = pin!(timer); + let inbound = platform.inbound(); + let mut inbound = pin!(inbound); + let recv_future = rx.recv(); + let mut recv_future = Some(pin!(recv_future)); + let mut poll_cursor = 0usize; + + loop { + state.drain_fsm_events(&mut fsm, &platform); + if state.fill_write_slots(&mut fsm, &platform, &mut in_flight) { + state.drain_fsm_events(&mut fsm, &platform); + } + timer.as_mut().set_deadline(fsm.next_deadline()); + + let step = poll_fn(|cx| { + next_step( + cx, + recv_future.as_mut().map(|future| future.as_mut()), + inbound.as_mut(), + timer.as_mut(), + &mut in_flight, + poll_cursor, + ) + }) + .await; + poll_cursor = (poll_cursor + 1) % STEP_COUNT; + + match step { + DriverStep::Command(command) => { + log::trace!("processing command: kind={}", command.kind()); + state.drive_command(&mut fsm, command, &platform); + } + DriverStep::Inbound(bytes) => { + log::trace!("received transport frame: len={}", bytes.len()); + if let Err(e) = fsm.receive(Instant::now(), bytes, &platform) { + log::info!("receive rejected frame: error={e:?}"); + platform.handle_recv_error(e); + } + } + DriverStep::WriteCompleted { index, success } => { + let write = in_flight.swap_remove(index); + let write_id = write.write_id; + log::trace!( + "write completed: success={success} index={index} write_id={write_id:?}", + ); + DriverState::drive_write_completed(&mut fsm, write_id, success); + yield_now().await; + } + DriverStep::TimerExpired => { + log::trace!("timer expired"); + fsm.on_timer(Instant::now()); + } + DriverStep::Closed => { + log::debug!( + "command channel closed: in_flight_writes={}", + in_flight.len() + ); + recv_future = None; + if in_flight.is_empty() && !fsm.has_shutdown_work() { + break; + } + } + } + } + log::info!("runtime stopped"); + } +} + +struct InFlightWrite { + write_id: Option, + future: F, +} + +enum DriverStep { + Command(Command), + Inbound(Vec), + WriteCompleted { index: usize, success: bool }, + TimerExpired, + Closed, +} + +const STEP_COUNT: usize = 4; + +fn next_step( + cx: &mut Context<'_>, + mut recv_future: Option>>, + mut inbound: Pin<&mut I>, + mut timer: Pin<&mut T>, + in_flight: &mut [InFlightWrite], + start: usize, +) -> Poll +where + T: QlTimer, + F: Future + Unpin, + I: QlInbound, +{ + for offset in 0..STEP_COUNT { + let step = (start + offset) % STEP_COUNT; + let poll = match step { + 0 => recv_future.as_mut().map_or(Poll::Pending, |recv_future| { + recv_future + .as_mut() + .poll(cx) + .map(|res| res.map_or(DriverStep::Closed, DriverStep::Command)) + }), + 1 => inbound.as_mut().poll_recv(cx).map(DriverStep::Inbound), + 2 => { + for (index, write) in in_flight.iter_mut().enumerate() { + if let Poll::Ready(success) = Pin::new(&mut write.future).poll(cx) { + return Poll::Ready(DriverStep::WriteCompleted { index, success }); + } + } + Poll::Pending + } + 3 => timer + .as_mut() + .poll_wait(cx) + .map(|()| DriverStep::TimerExpired), + _ => unreachable!(), + }; + if poll.is_ready() { + return poll; + } + } + + Poll::Pending +} + +impl DriverState { + #[allow(clippy::too_many_lines)] + fn drive_command(&mut self, fsm: &mut QlFsm, command: Command, platform: &P) { + match command { + Command::BindPeer { peer } => { + log::info!("binding peer"); + fsm.bind_peer(peer); + } + Command::Connect => { + log::info!("starting IK connect"); + if fsm.connect_ik(Instant::now(), platform).is_err() { + log::warn!("IK connect ignored: no bound peer"); + } + } + Command::ArmPairing { token } => { + log::info!("arming inbound pairing"); + fsm.arm_pairing(token); + } + Command::DisarmPairing => { + log::info!("disarming inbound pairing"); + fsm.disarm_pairing(); + } + Command::StartPairing { invite } => { + log::info!(" starting XX pairing"); + fsm.connect_xx(Instant::now(), invite, platform); + } + Command::CloseSession { code } => { + log::info!("closing session: code={code:?}"); + fsm.close_session(code); + } + Command::Unpair => { + log::info!("unpairing peer"); + fsm.unpair(); + } + Command::OpenStream { route_id, start } => { + log::info!("open stream requested: route_id={route_id}"); + let Some(runtime_tx) = self.runtime_tx.upgrade() else { + log::warn!("open stream aborted: runtime channel unavailable"); + let _ = start.send(Err(ql_fsm::NoSessionError)); + return; + }; + + let mut stream_ops = match fsm.open_stream(route_id) { + Ok(stream_ops) => stream_ops, + Err(error) => { + log::warn!("open stream failed: route_id={route_id}"); + let _ = start.send(Err(error)); + return; + } + }; + let stream_id = stream_ops.stream_id(); + log::info!("open stream allocated: route_id={route_id} stream_id={stream_id}"); + let (reader, writer, reader_io, writer_io) = io::new_stream( + stream_id, + CloseTarget::Return, + CloseTarget::Origin, + RuntimeHandle::new(runtime_tx), + ); + self.streams.insert( + stream_id, + DriverStreamIo::new( + true, + Some(OutboundIo::new(writer_io)), + Some(InboundIo::new(reader_io)), + ), + ); + if start.send(Ok((stream_id, reader, writer))).is_err() { + log::warn!("open stream cancelled before delivery: stream_id={stream_id}"); + if let Some(stream) = self.streams.get_mut(&stream_id) { + stream.inbound_close(); + stream.outbound_close(); + } + stream_ops.close(CloseTarget::Both, StreamCloseCode::CANCELLED); + drop(stream_ops); + return; + } + drop(stream_ops); + self.poll_stream(fsm, stream_id); + } + Command::PollInbound { stream_id } => { + log::trace!("poll inbound requested: stream_id={stream_id}"); + self.handle_inbound_readable(fsm, stream_id); + } + Command::PollStream { stream_id } => { + log::trace!("poll stream requested: stream_id={stream_id}"); + self.poll_stream(fsm, stream_id); + } + Command::CloseStream { + stream_id, + target, + code, + } => { + log::debug!( + "close stream command: stream_id={stream_id} target={target:?} code={code:?}" + ); + if let Entry::Occupied(mut entry) = self.streams.entry(stream_id) { + let stream = entry.get_mut(); + if target == CloseTarget::Both || target == stream.inbound_target() { + stream.inbound_close(); + } + if target == CloseTarget::Both || target == stream.outbound_target() { + stream.outbound_close(); + } + Self::try_reap_stream(entry); + } + if let Ok(mut stream) = fsm.stream(stream_id) { + stream.close(target, code); + } + } + } + } + + fn drive_write_completed(fsm: &mut QlFsm, session_write_id: Option, success: bool) { + if let Some(write_id) = session_write_id { + fsm.complete_write(Instant::now(), write_id, success); + } + } + + fn drain_fsm_events(&mut self, fsm: &mut QlFsm, platform: &P) { + while let Some(event) = fsm.poll_event() { + log::trace!("polled FSM event: event={event:?}"); + match event { + Event::NewPeer => { + log::info!("new ql peer"); + if let Some(peer) = fsm.peer().cloned() { + platform.persist_peer(peer); + } + } + Event::PeerStatusChanged(status) => { + let peer = fsm.peer().map(|peer| peer.qid); + log::info!("peer status changed: peer={peer:?} status={status:?}"); + if status == ql_fsm::PeerStatus::Unpaired { + for (_, mut stream) in self.streams.drain() { + stream.fail_all(); + } + } + platform.handle_peer_status(peer, status); + } + Event::Opened { + stream_id, + route_id, + } => { + log::info!("inbound stream opened: stream_id={stream_id} route_id={route_id}"); + self.handle_opened_stream(fsm, platform, stream_id, route_id); + } + Event::Readable(stream_id) => { + log::trace!("stream readable: stream_id={stream_id}"); + self.handle_inbound_readable(fsm, stream_id); + } + Event::Writable(stream_id) => { + log::trace!("stream writable: stream_id={stream_id}"); + self.poll_stream(fsm, stream_id); + } + Event::Finished(stream_id) => { + log::info!("peer finished stream writes: stream_id={stream_id}"); + self.handle_inbound_finished(stream_id); + } + Event::OutboundFinished(stream_id) => { + log::info!("outbound finish acknowledged: stream_id={stream_id}"); + self.handle_outbound_finished(stream_id); + } + Event::Closed(frame) => { + self.handle_closed_stream(&frame); + } + Event::WritableClosed(frame) => { + self.handle_writable_closed(&frame); + } + Event::SessionClosed(close) => { + log::info!("session closed: frame={close:?}"); + for (_, mut stream) in self.streams.drain() { + stream.fail_all(); + } + } + } + } + } + + fn handle_opened_stream( + &mut self, + fsm: &mut QlFsm, + platform: &P, + stream_id: StreamId, + route_id: ql_wire::RouteId, + ) { + let Some(runtime_tx) = self.runtime_tx.upgrade() else { + log::warn!( + "dropping inbound stream because handle channel is unavailable: stream_id={stream_id}" + ); + if let Ok(mut stream) = fsm.stream(stream_id) { + stream.close(CloseTarget::Both, StreamCloseCode::CANCELLED); + } + return; + }; + + let (reader, writer, reader_io, writer_io) = io::new_stream( + stream_id, + CloseTarget::Origin, + CloseTarget::Return, + RuntimeHandle::new(runtime_tx), + ); + + self.streams.insert( + stream_id, + DriverStreamIo::new( + false, + Some(OutboundIo::new(writer_io)), + Some(InboundIo::new(reader_io)), + ), + ); + + log::info!( + "delivering inbound stream to platform: stream_id={stream_id} route_id={route_id}" + ); + platform.handle_inbound(QlStream { + stream_id, + route_id, + writer, + reader, + }); + } + + fn handle_inbound_readable(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { + let Ok(mut stream_ops) = fsm.stream(stream_id) else { + log::info!("inbound readable for unknown stream: stream_id={stream_id}"); + return; + }; + let readable = stream_ops.readable_bytes(); + if readable == 0 { + return; + } + log::trace!("draining inbound bytes: stream_id={stream_id} readable={readable}"); + let mut accepted = 0usize; + let mut peer_closed = false; + let target; + { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + target = stream.inbound_target(); + for chunk in stream_ops.read() { + if chunk.is_empty() { + continue; + } + match stream.inbound_try_write(chunk) { + InboundWriteResult::Accepted(n) => { + accepted += n; + } + InboundWriteResult::Full => { + log::debug!( + "inbound backpressure: stream_id={stream_id} accepted={accepted}" + ); + break; + } + InboundWriteResult::Closed => { + log::warn!( + "inbound consumer closed; sending CANCELLED: stream_id={stream_id} target={target:?}" + ); + peer_closed = true; + break; + } + } + } + } + + if accepted > 0 { + log::trace!("committed inbound bytes: stream_id={stream_id:?} accepted={accepted}"); + stream_ops.commit_read(accepted).unwrap(); + } + if peer_closed { + stream_ops.close(target, StreamCloseCode::CANCELLED); + if let Entry::Occupied(entry) = self.streams.entry(stream_id) { + Self::try_reap_stream(entry); + } + } + + drop(stream_ops); + } + + fn handle_inbound_finished(&mut self, stream_id: StreamId) { + log::info!("inbound finished event: stream_id={stream_id}"); + let Entry::Occupied(mut entry) = self.streams.entry(stream_id) else { + return; + }; + log::info!("delivering clean inbound finish: stream_id={stream_id}"); + entry.get_mut().inbound_finish(); + Self::try_reap_stream(entry); + } + + fn handle_closed_stream(&mut self, frame: &ql_wire::StreamClose) { + log::info!( + "inbound close frame: stream_id={} target={:?} code={}", + frame.stream_id, + frame.target, + frame.code + ); + let Entry::Occupied(mut entry) = self.streams.entry(frame.stream_id) else { + return; + }; + let stream = entry.get_mut(); + + if frame.target == CloseTarget::Both || frame.target == stream.inbound_target() { + stream.inbound_fail(QlStreamError::StreamClosed { code: frame.code }); + } + if frame.target == CloseTarget::Both || frame.target == stream.outbound_target() { + stream.outbound_fail(QlStreamError::StreamClosed { code: frame.code }); + } + Self::try_reap_stream(entry); + } + + fn handle_writable_closed(&mut self, frame: &ql_wire::StreamClose) { + log::info!( + "writable close frame: stream_id={} target={:?} code={}", + frame.stream_id, + frame.target, + frame.code + ); + let Entry::Occupied(mut entry) = self.streams.entry(frame.stream_id) else { + return; + }; + let stream = entry.get_mut(); + stream.outbound_fail(QlStreamError::StreamClosed { code: frame.code }); + Self::try_reap_stream(entry); + } + + fn handle_outbound_finished(&mut self, stream_id: StreamId) { + log::info!("outbound finish acknowledged: stream_id={stream_id}"); + let Entry::Occupied(mut entry) = self.streams.entry(stream_id) else { + return; + }; + let stream = entry.get_mut(); + if !stream.outbound_finish_pending() { + return; + } + stream.outbound_finish(); + Self::try_reap_stream(entry); + } + + fn fill_write_slots<'a, P: QlPlatform + 'a>( + &self, + fsm: &mut QlFsm, + platform: &'a P, + in_flight: &mut Vec>>, + ) -> bool { + let mut filled = false; + while in_flight.len() < self.max_concurrent_message_writes { + let Some(write) = fsm.take_next_write(Instant::now(), platform) else { + break; + }; + filled = true; + log::trace!( + "queueing transport write: bytes={} write_id={:?}", + write.record.len(), + write.write_id + ); + in_flight.push(InFlightWrite { + write_id: write.write_id, + future: platform.write_message(write.record), + }); + } + filled + } + + fn poll_stream(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { + let Entry::Occupied(mut entry) = self.streams.entry(stream_id) else { + return; + }; + let stream = entry.get_mut(); + let Some(writer_io) = stream.outbound_writer_mut() else { + log::trace!("poll stream skipped without outbound writer: stream_id={stream_id}"); + return; + }; + + if writer_io.is_finished() { + log::info!("observed outbound writer finished before write: stream_id={stream_id}"); + if let Ok(mut stream_ops) = fsm.stream(stream_id) { + if let Some(writer) = stream_ops.writer() { + writer.finish(); + } + } + stream.outbound_queue_finish(); + if stream.is_closed() { + entry.remove(); + } + return; + } + + let Ok(mut stream_ops) = fsm.stream(stream_id) else { + return; + }; + let Some(mut writer) = stream_ops.writer() else { + log::trace!("poll stream skipped without session writer: stream_id={stream_id}"); + return; + }; + + loop { + let capacity = writer.capacity(); + log::trace!("stream write capacity: stream_id={stream_id} capacity={capacity}"); + if capacity == 0 { + break; + } + + let Ok(mut bytes) = writer_io.try_read(capacity) else { + break; + }; + if bytes.is_empty() { + break; + } + + log::trace!( + "writing stream bytes: stream_id={stream_id} len={}", + bytes.len() + ); + let _ = writer.write(&mut bytes); + } + + if writer_io.is_finished() { + log::info!("observed outbound writer finished after write: stream_id={stream_id}"); + writer.finish(); + stream.outbound_queue_finish(); + if stream.is_closed() { + entry.remove(); + } + } + } + + fn try_reap_stream(entry: OccupiedEntry<'_, StreamId, DriverStreamIo>) { + if entry.get().is_closed() { + entry.remove(); + } + } +} diff --git a/ql-runtime/src/driver/state.rs b/ql-runtime/src/driver/state.rs new file mode 100644 index 00000000..0ff8eca8 --- /dev/null +++ b/ql-runtime/src/driver/state.rs @@ -0,0 +1,165 @@ +use std::collections::HashMap; + +use bytes::Bytes; +use ql_wire::{CloseTarget, StreamId}; + +use crate::{ + command::Command, + io::{PushError, Rx, Tx}, + QlStreamError, +}; + +pub struct DriverState { + pub streams: HashMap, + pub runtime_tx: async_channel::WeakSender, + pub max_concurrent_message_writes: usize, +} + +pub struct DriverStreamIo { + is_initiator: bool, + outbound: Option, + inbound: Option, +} + +impl DriverStreamIo { + pub fn new( + is_initiator: bool, + outbound: Option, + inbound: Option, + ) -> Self { + Self { + is_initiator, + outbound, + inbound, + } + } + + pub fn inbound_target(&self) -> CloseTarget { + if self.is_initiator { + CloseTarget::Return + } else { + CloseTarget::Origin + } + } + + pub fn outbound_target(&self) -> CloseTarget { + if self.is_initiator { + CloseTarget::Origin + } else { + CloseTarget::Return + } + } + + pub fn fail_all(&mut self) { + self.inbound_fail(QlStreamError::NoSession); + self.outbound_fail(QlStreamError::NoSession); + } + + pub fn is_closed(&self) -> bool { + self.outbound.is_none() && self.inbound.is_none() + } + + pub fn outbound_close(&mut self) { + self.outbound = None; + } + + pub fn outbound_finish(&mut self) { + if let Some(outbound) = self.outbound.take() { + outbound.tx.finish(); + } + } + + pub fn outbound_fail(&mut self, error: QlStreamError) { + if let Some(outbound) = self.outbound.take() { + let _ = outbound.tx.fail(error); + } + } + + pub fn outbound_writer_mut(&mut self) -> Option<&mut OutboundIo> { + self.outbound.as_mut() + } + + pub fn outbound_queue_finish(&mut self) { + if let Some(outbound) = self.outbound.as_mut() { + outbound.finish_pending = true; + } + } + + pub fn outbound_finish_pending(&self) -> bool { + self.outbound + .as_ref() + .is_some_and(|outbound| outbound.finish_pending) + } + + pub fn inbound_close(&mut self) { + self.inbound = None; + } + + pub fn inbound_try_write(&mut self, bytes: Bytes) -> InboundWriteResult { + let Some(inbound) = self.inbound.as_mut() else { + return InboundWriteResult::Closed; + }; + + let len = bytes.len(); + match inbound.rx.try_write(bytes) { + Ok(()) => InboundWriteResult::Accepted(len), + Err(PushError::Full(_)) => InboundWriteResult::Full, + Err(PushError::Closed(_)) => { + self.inbound = None; + InboundWriteResult::Closed + } + } + } + + pub fn inbound_finish(&mut self) { + if let Some(inbound) = self.inbound.take() { + inbound.rx.finish(); + } + } + + pub fn inbound_fail(&mut self, error: QlStreamError) { + if let Some(inbound) = self.inbound.take() { + inbound.rx.fail(error); + } + } +} + +pub struct OutboundIo { + tx: Tx, + pending: Bytes, + finish_pending: bool, +} + +impl OutboundIo { + pub fn new(tx: Tx) -> Self { + Self { + tx, + pending: Bytes::new(), + finish_pending: false, + } + } + + pub fn is_finished(&self) -> bool { + self.pending.is_empty() && self.tx.is_finished() + } + + pub fn try_read(&mut self, max_len: usize) -> Result { + self.tx.try_read(&mut self.pending, max_len) + } +} + +pub struct InboundIo { + rx: Rx, +} + +pub enum InboundWriteResult { + Accepted(usize), + Full, + Closed, +} + +impl InboundIo { + pub fn new(rx: Rx) -> Self { + Self { rx } + } +} diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs new file mode 100644 index 00000000..af4ab63a --- /dev/null +++ b/ql-runtime/src/driver/test.rs @@ -0,0 +1,207 @@ +use ql_wire::{generate_identity, NoopCrypto, PeerBundle, SoftwareCrypto, StreamClose, QID}; + +use super::*; +use crate::{ + driver::state::{InboundIo, OutboundIo}, + io, + platform::QlInbound, +}; + +pub struct NoopTimer; +pub struct NoopInbound; + +impl crate::platform::QlTimer for NoopTimer { + fn set_deadline(self: Pin<&mut Self>, _deadline: Option) {} + + fn poll_wait(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> { + Poll::Pending + } +} + +impl QlPlatform for NoopCrypto { + type Timer = NoopTimer; + type WriteMessageFut<'a> = std::future::Ready; + type Inbound = NoopInbound; + + fn write_message(&self, _message: Vec) -> Self::WriteMessageFut<'_> { + std::future::ready(true) + } + + fn inbound(&mut self) -> Self::Inbound { + NoopInbound + } + + fn timer(&self) -> Self::Timer { + NoopTimer + } + + fn persist_peer(&self, _peer: PeerBundle) {} + + fn handle_peer_status(&self, _peer: Option, _status: ql_fsm::PeerStatus) {} + + fn handle_inbound(&self, _event: QlStream) {} +} + +impl QlInbound for NoopInbound { + fn poll_recv(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Pending + } +} + +fn new_driver_state() -> (DriverState, QlFsm) { + let (runtime_tx, _runtime_rx) = async_channel::unbounded(); + ( + DriverState { + streams: HashMap::new(), + runtime_tx: runtime_tx.downgrade(), + max_concurrent_message_writes: 1, + }, + QlFsm::new( + ql_fsm::QlFsmConfig::default(), + generate_identity(&SoftwareCrypto, "driver").unwrap(), + Instant::now(), + ), + ) +} + +fn new_inbound_io(capacity: usize) -> InboundIo { + let _ = capacity; + let (runtime_tx, _runtime_rx) = async_channel::unbounded(); + let stream = io::new_stream( + StreamId(99u32.into()), + CloseTarget::Origin, + CloseTarget::Return, + RuntimeHandle::new(runtime_tx), + ); + let (_, _, reader_io, _) = stream; + InboundIo::new(reader_io) +} + +fn new_outbound_io() -> OutboundIo { + let (runtime_tx, _runtime_rx) = async_channel::unbounded(); + let stream = io::new_stream( + StreamId(100u32.into()), + CloseTarget::Return, + CloseTarget::Origin, + RuntimeHandle::new(runtime_tx), + ); + let (_, _, _, writer_io) = stream; + OutboundIo::new(writer_io) +} + +#[test] +fn handle_inbound_finished_reaps_closed_initiator_stream() { + let (mut state, _fsm) = new_driver_state(); + let stream_id = StreamId(1u32.into()); + + state.streams.insert( + stream_id, + DriverStreamIo::new(true, None, Some(new_inbound_io(1))), + ); + + state.handle_inbound_finished(stream_id); + + assert!(!state.streams.contains_key(&stream_id)); +} + +#[test] +fn handle_closed_stream_reaps_when_both_halves_close() { + let (mut state, _fsm) = new_driver_state(); + let stream_id = StreamId(1u32.into()); + + state.streams.insert( + stream_id, + DriverStreamIo::new(false, Some(new_outbound_io()), Some(new_inbound_io(1))), + ); + + state.handle_closed_stream(&StreamClose { + stream_id, + target: CloseTarget::Both, + code: StreamCloseCode::CANCELLED, + }); + + assert!(!state.streams.contains_key(&stream_id)); +} + +#[test] +fn poll_stream_keeps_outbound_pending_after_local_finish_when_inbound_is_closed() { + let (mut state, mut fsm) = new_driver_state(); + let stream_id = StreamId(1u32.into()); + let (runtime_tx, _runtime_rx) = async_channel::unbounded(); + let (_, mut writer, _, writer_io) = io::new_stream( + stream_id, + CloseTarget::Return, + CloseTarget::Origin, + RuntimeHandle::new(runtime_tx), + ); + writer.queue_finish(); + state.streams.insert( + stream_id, + DriverStreamIo::new(true, Some(OutboundIo::new(writer_io)), None), + ); + + state.poll_stream(&mut fsm, stream_id); + + let stream = state.streams.get(&stream_id).unwrap(); + assert!(stream.outbound_finish_pending()); + assert!(!stream.is_closed()); +} + +#[test] +fn local_close_command_reaps_when_other_half_is_already_closed() { + let (mut state, mut fsm) = new_driver_state(); + let stream_id = StreamId(1u32.into()); + let (runtime_tx, _runtime_rx) = async_channel::unbounded(); + let (_, _, _, writer_io) = io::new_stream( + stream_id, + CloseTarget::Return, + CloseTarget::Origin, + RuntimeHandle::new(runtime_tx), + ); + + state.streams.insert( + stream_id, + DriverStreamIo::new(true, Some(OutboundIo::new(writer_io)), None), + ); + + state.drive_command( + &mut fsm, + Command::CloseStream { + stream_id, + target: CloseTarget::Origin, + code: StreamCloseCode::CANCELLED, + }, + &NoopCrypto, + ); + + assert!(!state.streams.contains_key(&stream_id)); +} + +#[test] +fn unpaired_status_fails_and_reaps_all_streams() { + let (mut state, mut fsm) = new_driver_state(); + let peer = generate_identity(&SoftwareCrypto, "peer").unwrap().bundle(); + let stream_id = StreamId(1u32.into()); + let (runtime_tx, _runtime_rx) = async_channel::unbounded(); + let (_, _, reader_io, writer_io) = io::new_stream( + stream_id, + CloseTarget::Origin, + CloseTarget::Return, + RuntimeHandle::new(runtime_tx), + ); + + state.streams.insert( + stream_id, + DriverStreamIo::new( + false, + Some(OutboundIo::new(writer_io)), + Some(InboundIo::new(reader_io)), + ), + ); + fsm.bind_peer(peer); + fsm.unpair(); + + state.drain_fsm_events(&mut fsm, &NoopCrypto); + + assert!(state.streams.is_empty()); +} diff --git a/ql-runtime/src/error.rs b/ql-runtime/src/error.rs new file mode 100644 index 00000000..5b74bcf8 --- /dev/null +++ b/ql-runtime/src/error.rs @@ -0,0 +1,18 @@ +use ql_wire::StreamCloseCode; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum QlStreamError { + StreamClosed { code: StreamCloseCode }, + NoSession, +} + +impl std::fmt::Display for QlStreamError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::StreamClosed { code } => write!(f, "stream closed {code:?}"), + Self::NoSession => f.write_str("no session"), + } + } +} + +impl std::error::Error for QlStreamError {} diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs new file mode 100644 index 00000000..1782c17a --- /dev/null +++ b/ql-runtime/src/handle/mod.rs @@ -0,0 +1,96 @@ +use ql_fsm::{NoSessionError, PairingInvite}; +use ql_wire::{PairingToken, PeerBundle, RouteId, SessionCloseCode, StreamId}; + +use crate::command::Command; +pub use crate::io::{StreamReader, StreamWriter}; + +#[derive(Debug)] +pub struct QlStream { + pub stream_id: StreamId, + pub route_id: RouteId, + pub writer: StreamWriter, + pub reader: StreamReader, +} + +#[derive(Clone)] +pub struct RuntimeHandle { + tx: async_channel::Sender, +} + +impl RuntimeHandle { + /// binds the remote peer + pub fn bind_peer(&self, peer: PeerBundle) { + self.send(Command::BindPeer { peer }); + } + + /// starts an IK handshake with the bound peer + pub fn connect(&self) { + self.send(Command::Connect); + } + + /// arms acceptance of inbound xx pairings for a single token + pub fn arm_pairing(&self, token: PairingToken) { + self.send(Command::ArmPairing { token }); + } + + /// disarms inbound xx pairing + pub fn disarm_pairing(&self) { + self.send(Command::DisarmPairing); + } + + /// starts an outbound xx handshake using an out-of-band pairing invite + pub fn start_pairing(&self, invite: PairingInvite) { + self.send(Command::StartPairing { invite }); + } + + /// closes the current encrypted session + pub fn close_session(&self, code: SessionCloseCode) { + self.send(Command::CloseSession { code }); + } + + /// forgets the currently bound peer and initiates session unpairing if connected + pub fn unpair(&self) { + self.send(Command::Unpair); + } + + /// opens a new stream on the active encrypted session + pub async fn open_stream(&self, route_id: RouteId) -> Result { + let (start_tx, start_rx) = oneshot::channel(); + + self.send(Command::OpenStream { + route_id, + start: start_tx, + }); + + // runtime cannot be shutdown while we have a handle + let (stream_id, reader, writer) = start_rx.await.unwrap()?; + + Ok(QlStream { + stream_id, + route_id, + writer, + reader, + }) + } + + #[cfg(feature = "rpc")] + pub fn rpc(&self) -> crate::rpc::RpcHandle { + crate::rpc::RpcHandle::new(self.clone()) + } +} + +impl RuntimeHandle { + pub(crate) fn new(tx: async_channel::Sender) -> Self { + Self { tx } + } + + #[inline] + #[track_caller] + pub(crate) fn send(&self, cmd: Command) { + self.tx.try_send(cmd).expect("runtime is alive"); + } + + pub(crate) fn try_send(&self, cmd: Command) -> bool { + self.tx.try_send(cmd).is_ok() + } +} diff --git a/ql-runtime/src/io/inner.rs b/ql-runtime/src/io/inner.rs new file mode 100644 index 00000000..64df6ced --- /dev/null +++ b/ql-runtime/src/io/inner.rs @@ -0,0 +1,643 @@ +//! per-stream shared io state +//! each lane has one slot and one waker +//! the low slot bits belong to `slot.rs` and the higher bits here carry lane-specific flags + +use std::task::Waker; + +use bytes::Bytes; +use diatomic_waker::DiatomicWaker; +use ql_wire::StreamId; + +use super::{ + slot::{PopError, PushError, Slot}, + sync::Arc, +}; +use crate::QlStreamError; + +pub(super) fn new(stream_id: StreamId) -> Arc { + Arc::new(Inner { + stream_id, + rx: RxInner::new(), + tx: TxInner::new(), + }) +} + +pub(super) struct Inner { + pub(super) stream_id: StreamId, + pub(super) rx: RxInner, + pub(super) tx: TxInner, +} + +pub enum Item { + Chunk(Bytes), + Error(QlStreamError), +} + +#[derive(Debug, PartialEq, Eq)] +pub struct ForcePushError(pub T); + +/// reader-lane shared state +pub struct RxInner { + slot: Slot, + changed: DiatomicWaker, +} + +impl RxInner { + const FINISHED: usize = 1 << 2; + + fn new() -> Self { + Self { + slot: Slot::new(), + changed: DiatomicWaker::new(), + } + } + + pub fn try_write(&self, bytes: Bytes) -> Result<(), PushError> { + try_write_chunk(&self.slot, &self.changed, bytes, Self::FINISHED) + } + + /// marks clean reader eof + pub fn finish(&self) { + if self.slot.fetch_or(Self::FINISHED) & Self::FINISHED == 0 { + self.changed.notify(); + } + } + + /// stores a terminal reader error + pub fn fail(&self, error: QlStreamError) -> Option { + let displaced = self.slot.force_push(Item::Error(error)); + self.changed.notify(); + displaced_bytes(displaced) + } + + pub fn load_state(&self) -> usize { + self.slot.load_state() + } + + pub fn is_finished(state: usize) -> bool { + state & Self::FINISHED != 0 + } + + pub fn pop(&self) -> Result { + pop_item(&self.slot, &self.changed) + } + + /// registers the sole reader-lane waiter + pub fn register_waiter(&self, waker: &Waker) { + // Safety: StreamReader is the only reader-lane registrar for this + // shared state, so register/unregister never run concurrently. + unsafe { self.changed.register(waker) }; + } + + /// unregisters the sole reader-lane waiter + pub fn unregister_waiter(&self) { + // Safety: StreamReader is the only reader-lane registrar for this + // shared state, so register/unregister never run concurrently. + unsafe { self.changed.unregister() }; + } +} + +/// writer-lane shared state +/// +/// finish and fail race to establish the terminal result +/// terminal errors are stored in the slot +pub struct TxInner { + slot: Slot, + changed: DiatomicWaker, +} + +impl TxInner { + const FINISH_REQUESTED: usize = 1 << 2; + const TERMINAL_READY: usize = 1 << 3; + const TERMINAL_OK: usize = 1 << 4; + + fn new() -> Self { + Self { + slot: Slot::new(), + changed: DiatomicWaker::new(), + } + } + + pub fn load_state(&self) -> usize { + self.slot.load_state() + } + + pub fn finish_requested(state: usize) -> bool { + state & Self::FINISH_REQUESTED != 0 + } + + pub fn terminal_ready(state: usize) -> bool { + state & Self::TERMINAL_READY != 0 + } + + pub fn terminal_ok(state: usize) -> bool { + state & Self::TERMINAL_OK != 0 + } + + pub fn try_write(&self, bytes: Bytes) -> Result<(), PushError> { + try_write_chunk( + &self.slot, + &self.changed, + bytes, + Self::FINISH_REQUESTED | Self::TERMINAL_READY, + ) + } + + /// prevents future chunk writes once observed + pub fn request_finish(&self) { + if self.slot.fetch_or(Self::FINISH_REQUESTED) & Self::FINISH_REQUESTED == 0 { + self.changed.notify(); + } + } + + /// commits a clean writer eof + pub fn finish(&self) { + let mut state = self.slot.load_state(); + loop { + if Self::terminal_ready(state) { + return; + } + + let new_state = state | Self::TERMINAL_READY | Self::TERMINAL_OK; + match self.slot.compare_exchange(state, new_state) { + Ok(()) => { + self.changed.notify(); + return; + } + Err(actual) => state = actual, + } + } + } + + /// stores a terminal writer error + /// futures calls will have no effect + pub fn fail( + &self, + error: QlStreamError, + ) -> Result, ForcePushError> { + let mut state = self.slot.load_state(); + loop { + if Self::terminal_ready(state) { + return Err(ForcePushError(error)); + } + + let new_state = state | Self::TERMINAL_READY; + match self.slot.compare_exchange(state, new_state) { + Ok(()) => break, + Err(actual) => state = actual, + } + } + + let displaced = self.slot.force_push(Item::Error(error)); + self.changed.notify(); + Ok(displaced_bytes(displaced)) + } + + pub fn pop(&self) -> Result { + pop_item(&self.slot, &self.changed) + } + + /// registers the sole writer-lane waiter + pub fn register_waiter(&self, waker: &Waker) { + // Safety: StreamWriter is the only writer-lane registrar for this + // shared state, so register/unregister never run concurrently. + unsafe { self.changed.register(waker) }; + } + + /// unregisters the sole writer-lane waiter + pub fn unregister_waiter(&self) { + // Safety: StreamWriter is the only writer-lane registrar for this + // shared state, so register/unregister never run concurrently. + unsafe { self.changed.unregister() }; + } + + /// returns true once finish was requested and buffered data is drained + pub fn is_finished(&self) -> bool { + let state = self.load_state(); + Self::finish_requested(state) && Slot::::is_empty_state(state) + } + + pub fn try_read(&self, pending: &mut Bytes, max_len: usize) -> Result { + if !pending.is_empty() { + return Ok(if pending.len() <= max_len { + std::mem::take(pending) + } else { + pending.split_to(max_len) + }); + } + + let state = self.load_state(); + if Self::terminal_ready(state) { + return Err(()); + } + + match self.pop() { + Ok(Item::Chunk(mut bytes)) => { + if bytes.len() <= max_len { + Ok(bytes) + } else { + let head = bytes.split_to(max_len); + *pending = bytes; + Ok(head) + } + } + Ok(Item::Error(_)) => Err(()), + Err(PopError) => Ok(Bytes::new()), + } + } +} + +#[inline] +fn try_write_chunk( + slot: &Slot, + changed: &DiatomicWaker, + bytes: Bytes, + closed_mask: usize, +) -> Result<(), PushError> { + match slot.try_push(Item::Chunk(bytes), closed_mask) { + Ok(()) => { + changed.notify(); + Ok(()) + } + Err(PushError::Closed(Item::Chunk(bytes))) => Err(PushError::Closed(bytes)), + Err(PushError::Full(Item::Chunk(bytes))) => Err(PushError::Full(bytes)), + Err(PushError::Closed(Item::Error(_)) | PushError::Full(Item::Error(_))) => { + unreachable!("chunk write cannot recover an error payload") + } + } +} + +#[inline] +fn displaced_bytes(displaced: Option) -> Option { + match displaced { + Some(Item::Chunk(bytes)) => Some(bytes), + Some(Item::Error(_)) | None => None, + } +} + +#[inline] +fn pop_item(slot: &Slot, changed: &DiatomicWaker) -> Result { + match slot.pop() { + item @ Ok(Item::Chunk(_)) => { + changed.notify(); + item + } + item @ (Ok(Item::Error(_)) | Err(_)) => item, + } +} + +#[cfg(all(test, loom))] +mod loom_tests { + use std::task::Waker; + + use bytes::Bytes; + use loom::thread; + use ql_wire::StreamCloseCode; + + use super::*; + use crate::{ + io::{sync::loom::*, Tx}, + QlStreamError, + }; + + #[test] + fn reader_waiter_registration_survives_finish() { + check_model(|| { + let shared = shared(); + shared.rx.register_waiter(Waker::noop()); + + let finisher = { + let shared = shared.clone(); + thread::spawn(move || { + shared.rx.finish(); + }) + }; + + finisher.join().unwrap(); + assert!(RxInner::is_finished(shared.rx.load_state())); + + shared.rx.unregister_waiter(); + }); + } + + #[test] + fn reader_chunk_remains_available_after_finish() { + check_model(|| { + let shared = shared(); + + let producer = { + let shared = shared.clone(); + thread::spawn(move || { + shared.rx.try_write(Bytes::from_static(b"abc")).unwrap(); + shared.rx.finish(); + }) + }; + + producer.join().unwrap(); + + match shared.rx.pop() { + Ok(Item::Chunk(bytes)) => assert_eq!(bytes, Bytes::from_static(b"abc")), + _ => panic!("expected buffered reader chunk"), + } + assert!(RxInner::is_finished(shared.rx.load_state())); + assert!(matches!(shared.rx.pop(), Err(PopError))); + }); + } + + #[test] + fn reader_rejects_write_after_finish() { + check_model(|| { + let shared = shared(); + + shared.rx.finish(); + + assert_eq!( + shared.rx.try_write(Bytes::from_static(b"abc")), + Err(PushError::Closed(Bytes::from_static(b"abc"))) + ); + assert!(RxInner::is_finished(shared.rx.load_state())); + assert!(matches!(shared.rx.pop(), Err(PopError))); + }); + } + + #[test] + fn reader_write_races_with_finish_has_coherent_outcome() { + check_model(|| { + let shared = shared(); + + let writer = { + let shared = shared.clone(); + thread::spawn(move || shared.rx.try_write(Bytes::from_static(b"abc"))) + }; + let finisher = { + let shared = shared.clone(); + thread::spawn(move || shared.rx.finish()) + }; + + let write_result = writer.join().unwrap(); + finisher.join().unwrap(); + + assert!(RxInner::is_finished(shared.rx.load_state())); + match write_result { + Ok(()) => match shared.rx.pop() { + Ok(Item::Chunk(bytes)) => assert_eq!(bytes, Bytes::from_static(b"abc")), + _ => panic!("expected buffered reader chunk"), + }, + Err(PushError::Closed(bytes)) => { + assert_eq!(bytes, Bytes::from_static(b"abc")); + assert!(matches!(shared.rx.pop(), Err(PopError))); + return; + } + Err(PushError::Full(_)) => panic!("empty reader slot must not report full"), + } + assert!(matches!(shared.rx.pop(), Err(PopError))); + }); + } + + #[test] + fn reader_fail_racing_with_pop_preserves_terminal_outcome() { + check_model(|| { + let shared = shared(); + shared.rx.try_write(Bytes::from_static(b"abc")).unwrap(); + + let popper = { + let shared = shared.clone(); + thread::spawn(move || shared.rx.pop()) + }; + let failer = { + let shared = shared.clone(); + thread::spawn(move || { + shared.rx.fail(QlStreamError::StreamClosed { + code: StreamCloseCode::CANCELLED, + }) + }) + }; + + let pop_result = popper.join().unwrap(); + let fail_result = failer.join().unwrap(); + + match (pop_result, fail_result) { + (Ok(Item::Chunk(bytes)), None) => { + assert_eq!(bytes, Bytes::from_static(b"abc")); + match shared.rx.pop() { + Ok(Item::Error(QlStreamError::StreamClosed { code })) => { + assert_eq!(code, StreamCloseCode::CANCELLED); + } + _ => panic!("expected terminal reader error"), + } + } + (Ok(Item::Error(QlStreamError::StreamClosed { code })), Some(bytes)) => { + assert_eq!(code, StreamCloseCode::CANCELLED); + assert_eq!(bytes, Bytes::from_static(b"abc")); + assert!(matches!(shared.rx.pop(), Err(PopError))); + } + _ => panic!("unexpected reader fail/pop race outcome"), + } + }); + } + + #[test] + fn writer_is_finished_only_after_drain() { + check_model(|| { + let shared = shared(); + let tx = Tx(shared.clone()); + let mut pending = Bytes::new(); + + shared.tx.try_write(Bytes::from_static(b"abc")).unwrap(); + shared.tx.request_finish(); + + assert!(!(pending.is_empty() && tx.is_finished())); + assert_eq!(tx.try_read(&mut pending, 2), Ok(Bytes::from_static(b"ab"))); + assert!(!(pending.is_empty() && tx.is_finished())); + assert_eq!(tx.try_read(&mut pending, 8), Ok(Bytes::from_static(b"c"))); + assert!(pending.is_empty() && tx.is_finished()); + }); + } + + #[test] + fn writer_write_races_with_request_finish() { + check_model(|| { + let shared = shared(); + let tx = Tx(shared.clone()); + let mut pending = Bytes::new(); + + let writer = { + let shared = shared.clone(); + thread::spawn(move || shared.tx.try_write(Bytes::from_static(b"abc"))) + }; + let finisher = { + let shared = shared.clone(); + thread::spawn(move || shared.tx.request_finish()) + }; + + let write_result = writer.join().unwrap(); + finisher.join().unwrap(); + + assert!(TxInner::finish_requested(shared.tx.load_state())); + match write_result { + Ok(()) => { + assert_eq!(tx.try_read(&mut pending, 8), Ok(Bytes::from_static(b"abc"))); + assert!(pending.is_empty() && tx.is_finished()); + } + Err(PushError::Closed(bytes)) => { + assert_eq!(bytes, Bytes::from_static(b"abc")); + assert!(pending.is_empty() && tx.is_finished()); + } + Err(PushError::Full(_)) => panic!("empty writer slot must not report full"), + } + }); + } + + #[test] + fn writer_fail_overwrites_buffered_chunk_and_keeps_terminal_state_observable() { + check_model(|| { + let shared = shared(); + shared.tx.try_write(Bytes::from_static(b"abc")).unwrap(); + shared.tx.register_waiter(Waker::noop()); + + let failer = { + let shared = shared.clone(); + thread::spawn(move || { + let displaced = shared.tx.fail(QlStreamError::StreamClosed { + code: StreamCloseCode::CANCELLED, + }); + assert_eq!(displaced.unwrap(), Some(Bytes::from_static(b"abc"))); + }) + }; + + failer.join().unwrap(); + + assert!(TxInner::terminal_ready(shared.tx.load_state())); + shared.tx.unregister_waiter(); + match shared.tx.pop() { + Ok(Item::Error(QlStreamError::StreamClosed { code })) => { + assert_eq!(code, StreamCloseCode::CANCELLED); + } + _ => panic!("expected terminal writer error"), + } + }); + } + + #[test] + fn reader_waiter_registration_can_be_reused_after_notification() { + check_model(|| { + let shared = shared(); + + shared.rx.register_waiter(Waker::noop()); + shared.rx.try_write(Bytes::from_static(b"abc")).unwrap(); + match shared.rx.pop() { + Ok(Item::Chunk(bytes)) => assert_eq!(bytes, Bytes::from_static(b"abc")), + _ => panic!("expected buffered reader chunk"), + } + + shared.rx.register_waiter(Waker::noop()); + shared.rx.finish(); + assert!(RxInner::is_finished(shared.rx.load_state())); + shared.rx.unregister_waiter(); + }); + } + + #[test] + fn writer_waiter_registration_can_be_reused_after_notification() { + check_model(|| { + let shared = shared(); + + shared.tx.register_waiter(Waker::noop()); + shared.tx.try_write(Bytes::from_static(b"abc")).unwrap(); + match shared.tx.pop() { + Ok(Item::Chunk(bytes)) => assert_eq!(bytes, Bytes::from_static(b"abc")), + _ => panic!("expected buffered writer chunk"), + } + + shared.tx.register_waiter(Waker::noop()); + shared.tx.finish(); + assert!(TxInner::terminal_ready(shared.tx.load_state())); + shared.tx.unregister_waiter(); + }); + } + + #[test] + fn writer_write_races_with_fail() { + check_model(|| { + let shared = shared(); + + let writer = { + let shared = shared.clone(); + thread::spawn(move || shared.tx.try_write(Bytes::from_static(b"abc"))) + }; + let failer = { + let shared = shared.clone(); + thread::spawn(move || { + shared.tx.fail(QlStreamError::StreamClosed { + code: StreamCloseCode::CANCELLED, + }) + }) + }; + + let write_result = writer.join().unwrap(); + let fail_result = failer.join().unwrap(); + + assert!(TxInner::terminal_ready(shared.tx.load_state())); + match (&write_result, &fail_result) { + (Ok(()), Ok(Some(bytes))) => { + assert_eq!(Bytes::from_static(b"abc"), bytes.clone()); + } + (Err(PushError::Closed(bytes)), Ok(None)) => { + assert_eq!(Bytes::from_static(b"abc"), bytes.clone()); + } + (Err(PushError::Full(bytes)), Ok(None)) => { + assert_eq!(Bytes::from_static(b"abc"), bytes.clone()); + } + _ => panic!( + "unexpected writer fail/write race outcome: write={write_result:?} fail={fail_result:?}" + ), + } + + match shared.tx.pop() { + Ok(Item::Error(QlStreamError::StreamClosed { code })) => { + assert_eq!(code, StreamCloseCode::CANCELLED); + } + _ => panic!("expected terminal writer error"), + } + }); + } + + #[test] + fn writer_finish_races_with_fail_without_masking_error() { + check_model(|| { + let shared = shared(); + + let finisher = { + let shared = shared.clone(); + thread::spawn(move || shared.tx.finish()) + }; + let failer = { + let shared = shared.clone(); + thread::spawn(move || { + shared.tx.fail(QlStreamError::StreamClosed { + code: StreamCloseCode::CANCELLED, + }) + }) + }; + + finisher.join().unwrap(); + let fail_result = failer.join().unwrap(); + + assert!(TxInner::terminal_ready(shared.tx.load_state())); + match fail_result { + Err(_) => { + assert!(TxInner::terminal_ok(shared.tx.load_state())); + } + Ok(_) => { + assert!(!TxInner::terminal_ok(shared.tx.load_state())); + match shared.tx.pop() { + Ok(Item::Error(QlStreamError::StreamClosed { code })) => { + assert_eq!(code, StreamCloseCode::CANCELLED); + } + _ => panic!("expected terminal writer error"), + } + } + } + }); + } +} diff --git a/ql-runtime/src/io/mod.rs b/ql-runtime/src/io/mod.rs new file mode 100644 index 00000000..2eb7f0f0 --- /dev/null +++ b/ql-runtime/src/io/mod.rs @@ -0,0 +1,59 @@ +mod inner; +mod reader; +mod slot; +mod sync; +mod writer; + +use std::ops::Deref; + +use ql_wire::{CloseTarget, StreamId}; + +pub use self::{reader::StreamReader, slot::PushError, writer::StreamWriter}; +use crate::RuntimeHandle; + +pub struct Rx(sync::Arc); + +impl Deref for Rx { + type Target = inner::RxInner; + + fn deref(&self) -> &Self::Target { + &self.0.rx + } +} + +impl Rx { + pub fn stream_id(&self) -> StreamId { + self.0.stream_id + } +} + +pub struct Tx(sync::Arc); + +impl Deref for Tx { + type Target = inner::TxInner; + + fn deref(&self) -> &Self::Target { + &self.0.tx + } +} + +impl Tx { + pub fn stream_id(&self) -> StreamId { + self.0.stream_id + } +} + +pub fn new_stream( + stream_id: StreamId, + reader_target: CloseTarget, + writer_target: CloseTarget, + handle: RuntimeHandle, +) -> (StreamReader, StreamWriter, Rx, Tx) { + let shared = inner::new(stream_id); + ( + StreamReader::new(Rx(shared.clone()), reader_target, handle.clone()), + StreamWriter::new(Tx(shared.clone()), writer_target, handle), + Rx(shared.clone()), + Tx(shared), + ) +} diff --git a/ql-runtime/src/io/reader.rs b/ql-runtime/src/io/reader.rs new file mode 100644 index 00000000..8c40ccd3 --- /dev/null +++ b/ql-runtime/src/io/reader.rs @@ -0,0 +1,236 @@ +use std::{ + future::poll_fn, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use ql_wire::{CloseTarget, StreamCloseCode}; + +use super::{ + inner::{Item, RxInner}, + slot::PopError, + Rx, +}; +use crate::{command::Command, log, QlStreamError, RuntimeHandle}; + +pub struct StreamReader { + rx: Rx, + target: CloseTarget, + pending: Bytes, + terminal: ReaderTerminalState, + handle: RuntimeHandle, +} + +enum ReaderTerminalState { + Open, + Delivered, +} + +unsafe impl Sync for StreamReader {} + +impl std::fmt::Debug for StreamReader { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("StreamReader") + .field("stream_id", &self.rx.stream_id()) + .field("target", &self.target) + .field( + "terminal", + &matches!(self.terminal, ReaderTerminalState::Delivered), + ) + .finish_non_exhaustive() + } +} + +impl StreamReader { + pub(crate) fn new(shared: Rx, target: CloseTarget, handle: RuntimeHandle) -> Self { + Self { + rx: shared, + target, + pending: Bytes::new(), + terminal: ReaderTerminalState::Open, + handle, + } + } + + pub fn poll_read( + &mut self, + max_len: usize, + cx: &mut Context<'_>, + ) -> Poll, QlStreamError>> { + if matches!(self.terminal, ReaderTerminalState::Delivered) { + return Poll::Ready(Ok(None)); + } + + match self.try_read_ready(max_len) { + Poll::Ready(result) => return Poll::Ready(result), + Poll::Pending => {} + } + + self.rx.register_waiter(cx.waker()); + + match self.try_read_ready(max_len) { + Poll::Ready(result) => { + self.rx.unregister_waiter(); + Poll::Ready(result) + } + Poll::Pending => Poll::Pending, + } + } + + fn try_read_ready(&mut self, max_len: usize) -> Poll, QlStreamError>> { + if !self.pending.is_empty() { + let pending = &mut self.pending; + let bytes = if pending.len() <= max_len { + std::mem::take(pending) + } else { + pending.split_to(max_len) + }; + self.handle.try_send(Command::PollInbound { + stream_id: self.rx.stream_id(), + }); + return Poll::Ready(Ok(Some(bytes))); + } + + match self.rx.pop() { + Ok(Item::Chunk(mut bytes)) => { + log::trace!( + "byte reader received chunk: stream_id={} target={:?} len={}", + self.rx.stream_id(), + self.target, + bytes.len() + ); + self.handle.try_send(Command::PollInbound { + stream_id: self.rx.stream_id(), + }); + if bytes.len() <= max_len { + return Poll::Ready(Ok(Some(bytes))); + } + let head = bytes.split_to(max_len); + self.pending = bytes; + Poll::Ready(Ok(Some(head))) + } + Ok(Item::Error(error)) => { + log::debug!( + "byte reader delivered terminal error: stream_id={} target={:?} error={:?}", + self.rx.stream_id(), + self.target, + error + ); + self.terminal = ReaderTerminalState::Delivered; + Poll::Ready(Err(error)) + } + Err(PopError) => { + if RxInner::is_finished(self.rx.load_state()) { + log::debug!( + "byte reader delivered clean eof: stream_id={} target={:?}", + self.rx.stream_id(), + self.target + ); + self.terminal = ReaderTerminalState::Delivered; + return Poll::Ready(Ok(None)); + } + Poll::Pending + } + } + } + + pub fn poll_read_chunk( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, QlStreamError>> { + self.poll_read(usize::MAX, cx) + } + + pub async fn read(&mut self, max_len: usize) -> Result, QlStreamError> { + poll_fn(|cx| self.poll_read(max_len, cx)).await + } + + pub async fn read_chunk(&mut self) -> Result, QlStreamError> { + self.read(usize::MAX).await + } + + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + fn close_inner(&mut self, code: StreamCloseCode) { + if matches!(self.terminal, ReaderTerminalState::Delivered) { + return; + } + log::debug!( + "byte reader explicit close: stream_id={:?} target={:?} code={:?}", + self.rx.stream_id(), + self.target, + code + ); + self.terminal = ReaderTerminalState::Delivered; + self.handle.try_send(Command::CloseStream { + stream_id: self.rx.stream_id(), + target: self.target, + code, + }); + } +} + +impl Drop for StreamReader { + fn drop(&mut self) { + if matches!(self.terminal, ReaderTerminalState::Delivered) { + return; + } + log::debug!( + "byte reader drop close: stream_id={:?} target={:?} code={:?}", + self.rx.stream_id(), + self.target, + StreamCloseCode::CANCELLED + ); + self.handle.try_send(Command::CloseStream { + stream_id: self.rx.stream_id(), + target: self.target, + code: StreamCloseCode::CANCELLED, + }); + } +} + +#[cfg(all(test, loom))] +mod loom_tests { + use std::task::{Context, Poll, Waker}; + + use bytes::Bytes; + use loom::thread; + use ql_wire::CloseTarget; + + use super::*; + use crate::io::sync::loom::*; + + #[test] + fn poll_read_observes_chunk_racing_with_registration() { + check_model(|| { + let inner = shared(); + let mut reader = StreamReader::new(Rx(inner.clone()), CloseTarget::Origin, handle()); + let mut cx = Context::from_waker(Waker::noop()); + + let producer = { + let inner = inner.clone(); + thread::spawn(move || { + inner.rx.try_write(Bytes::from_static(b"abc")).unwrap(); + }) + }; + + let first = reader.poll_read(usize::MAX, &mut cx); + producer.join().unwrap(); + + match first { + Poll::Ready(Ok(Some(bytes))) => { + assert_eq!(bytes, Bytes::from_static(b"abc")); + } + Poll::Pending => { + assert_eq!( + reader.poll_read(usize::MAX, &mut cx), + Poll::Ready(Ok(Some(Bytes::from_static(b"abc")))) + ); + } + other => panic!("unexpected first poll result: {other:?}"), + } + }); + } +} diff --git a/ql-runtime/src/io/slot.rs b/ql-runtime/src/io/slot.rs new file mode 100644 index 00000000..f71f1b0c --- /dev/null +++ b/ql-runtime/src/io/slot.rs @@ -0,0 +1,175 @@ +//! local single-slot queue for stream io +//! copied from `concurrent_queue::single::Single` in `concurrent-queue` + +use core::mem::MaybeUninit; + +#[allow(clippy::wildcard_imports)] +use super::sync::*; + +const LOCKED: usize = 1 << 0; +const PUSHED: usize = 1 << 1; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct PopError; + +#[derive(Debug, PartialEq, Eq)] +pub enum PushError { + Full(T), + Closed(T), +} + +/// A single-element queue. +pub struct Slot { + state: AtomicUsize, + value: UnsafeCell>, +} + +unsafe impl Send for Slot {} +unsafe impl Sync for Slot {} + +impl Slot { + /// Creates a new single-element queue. + pub fn new() -> Self { + Self { + state: AtomicUsize::new(0), + value: UnsafeCell::new(MaybeUninit::uninit()), + } + } + + #[inline] + pub fn load_state(&self) -> usize { + self.state.load(Ordering::Acquire) + } + + #[inline] + pub fn fetch_or(&self, bits: usize) -> usize { + self.state.fetch_or(bits, Ordering::Release) + } + + #[inline] + pub fn compare_exchange(&self, current: usize, new: usize) -> Result<(), usize> { + self.state + .compare_exchange(current, new, Ordering::AcqRel, Ordering::Acquire) + .map(|_| ()) + } + + /// Attempts to push an item into the queue. + pub fn try_push(&self, value: T, closed_mask: usize) -> Result<(), PushError> { + let mut state = self.load_state(); + loop { + if state & closed_mask != 0 { + return Err(PushError::Closed(value)); + } + if state & LOCKED != 0 { + busy_wait(); + state = self.load_state(); + continue; + } + if state & PUSHED != 0 { + return Err(PushError::Full(value)); + } + + // Lock and fill the slot. + let new_state = state | LOCKED | PUSHED; + match self.compare_exchange(state, new_state) { + Ok(()) => { + // Write the value and unlock. + self.value.with_mut(|slot| unsafe { + slot.write(MaybeUninit::new(value)); + }); + self.state.fetch_and(!LOCKED, Ordering::Release); + return Ok(()); + } + Err(actual) => state = actual, + } + } + } + + /// Attempts to push an item into the queue, displacing another if necessary. + pub fn force_push(&self, value: T) -> Option { + // Attempt to lock the slot. + let mut state = self.load_state(); + + loop { + if state & LOCKED != 0 { + busy_wait(); + state = self.load_state(); + continue; + } + + // Lock the slot. + let new_state = state | LOCKED | PUSHED; + match self.compare_exchange(state, new_state) { + Ok(()) => { + // If the value was pushed, swap out the value. + let displaced = if state & PUSHED == 0 { + // SAFETY: write is safe because we have locked the state. + self.value.with_mut(|slot| unsafe { + slot.write(MaybeUninit::new(value)); + }); + None + } else { + // SAFETY: replace is safe because we have locked the state, and + // assume_init is safe because we have checked that the value was pushed. + self.value.with_mut(move |slot| unsafe { + Some(std::ptr::replace(slot, MaybeUninit::new(value)).assume_init()) + }) + }; + + // We can unlock the slot now. + self.state.fetch_and(!LOCKED, Ordering::Release); + return displaced; + } + Err(actual) => state = actual, + } + } + } + + /// Attempts to pop an item from the queue. + pub fn pop(&self) -> Result { + let mut state = PUSHED; + loop { + if state & LOCKED != 0 { + busy_wait(); + state = self.load_state(); + continue; + } + if state & PUSHED == 0 { + return Err(PopError); + } + + // Lock and empty the slot. + let new_state = (state | LOCKED) & !PUSHED; + match self.compare_exchange(state, new_state) { + Ok(()) => { + // Read the value and unlock. + let value = self + .value + .with_mut(|slot| unsafe { slot.read().assume_init() }); + self.state.fetch_and(!LOCKED, Ordering::Release); + return Ok(value); + } + Err(actual) => state = actual, + } + } + } + + #[inline] + pub fn is_empty_state(state: usize) -> bool { + state & PUSHED == 0 + } +} + +impl Drop for Slot { + fn drop(&mut self) { + // Drop the value in the slot. + self.state.with_mut(|state| { + if *state & PUSHED != 0 { + self.value.with_mut(|slot| unsafe { + let value = &mut *slot; + value.as_mut_ptr().drop_in_place(); + }); + } + }); + } +} diff --git a/ql-runtime/src/io/sync.rs b/ql-runtime/src/io/sync.rs new file mode 100644 index 00000000..c5034076 --- /dev/null +++ b/ql-runtime/src/io/sync.rs @@ -0,0 +1,89 @@ +#[cfg(not(all(test, loom)))] +mod inner { + pub use std::{ + cell::UnsafeCell, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + }; + + pub fn busy_wait() { + std::thread::yield_now(); + } + + pub trait UnsafeCellExt { + type Value; + + fn with_mut(&self, f: F) -> R + where + F: FnOnce(*mut Self::Value) -> R; + } + + impl UnsafeCellExt for UnsafeCell { + type Value = T; + + fn with_mut(&self, f: F) -> R + where + F: FnOnce(*mut Self::Value) -> R, + { + f(self.get()) + } + } + + pub trait AtomicExt { + type Value; + + fn with_mut(&mut self, f: F) -> R + where + F: FnOnce(&mut Self::Value) -> R; + } + + impl AtomicExt for AtomicUsize { + type Value = usize; + + fn with_mut(&mut self, f: F) -> R + where + F: FnOnce(&mut Self::Value) -> R, + { + f(self.get_mut()) + } + } +} + +#[cfg(all(test, loom))] +mod inner { + pub use loom::{ + cell::UnsafeCell, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + thread::yield_now as busy_wait, + }; +} + +pub use inner::*; + +#[cfg(all(test, loom))] +pub(crate) mod loom { + use loom::model; + use ql_wire::StreamId; + + use super::Arc; + use crate::{io::inner::Inner, RuntimeHandle}; + + pub(crate) fn check_model(f: impl Fn() + Sync + Send + 'static) { + let builder = model::Builder::new(); + builder.check(f); + } + + pub(crate) fn shared() -> Arc { + crate::io::inner::new(StreamId(1u32.into())) + } + + pub(crate) fn handle() -> RuntimeHandle { + let (tx, _rx) = async_channel::unbounded(); + RuntimeHandle::new(tx) + } +} diff --git a/ql-runtime/src/io/writer.rs b/ql-runtime/src/io/writer.rs new file mode 100644 index 00000000..cfad3196 --- /dev/null +++ b/ql-runtime/src/io/writer.rs @@ -0,0 +1,294 @@ +use std::{ + future::poll_fn, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use ql_wire::{CloseTarget, StreamCloseCode}; + +use super::{ + inner::{Item, TxInner}, + slot::PopError, + PushError, Tx, +}; +use crate::{command::Command, log, QlStreamError, RuntimeHandle}; + +pub struct StreamWriter { + tx: Tx, + target: CloseTarget, + open: bool, + terminal: WriterTerminalState, + handle: RuntimeHandle, +} + +enum WriterTerminalState { + Pending, + Terminal(Result<(), QlStreamError>), +} + +unsafe impl Sync for StreamWriter {} + +impl std::fmt::Debug for StreamWriter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("StreamWriter") + .field("stream_id", &self.tx.stream_id()) + .field("target", &self.target) + .field("closed", &!self.open) + .finish_non_exhaustive() + } +} + +impl StreamWriter { + pub(crate) fn new(shared: Tx, target: CloseTarget, handle: RuntimeHandle) -> Self { + Self { + tx: shared, + target, + open: true, + terminal: WriterTerminalState::Pending, + handle, + } + } + + pub fn poll_write( + &mut self, + bytes: &mut Bytes, + cx: &mut Context<'_>, + ) -> Poll> { + if bytes.is_empty() { + return Poll::Ready(Ok(())); + } + + if !self.open { + return self.poll_terminal(cx); + } + + match self.tx.try_write(std::mem::take(bytes)) { + Ok(()) => { + log::trace!( + "byte writer accepted chunk: stream_id={} target={:?}", + self.tx.stream_id(), + self.target + ); + self.poll_runtime(); + return Poll::Ready(Ok(())); + } + Err(PushError::Closed(chunk)) => { + *bytes = chunk; + self.open = false; + return self.poll_terminal(cx); + } + Err(PushError::Full(chunk)) => { + *bytes = chunk; + } + } + + self.tx.register_waiter(cx.waker()); + + match self.tx.try_write(std::mem::take(bytes)) { + Ok(()) => { + self.tx.unregister_waiter(); + log::trace!( + "byte writer accepted chunk: stream_id={} target={:?}", + self.tx.stream_id(), + self.target + ); + self.poll_runtime(); + Poll::Ready(Ok(())) + } + Err(PushError::Closed(chunk)) => { + self.tx.unregister_waiter(); + *bytes = chunk; + self.open = false; + self.poll_terminal(cx) + } + Err(PushError::Full(chunk)) => { + *bytes = chunk; + Poll::Pending + } + } + } + + pub async fn write(&mut self, bytes: Bytes) -> Result<(), QlStreamError> { + let mut bytes = bytes; + poll_fn(|cx| self.poll_write(&mut bytes, cx)).await + } + + pub fn queue_finish(&mut self) { + if !self.open { + return; + } + log::debug!( + "byte writer finish: stream_id={} target={:?}", + self.tx.stream_id(), + self.target + ); + self.open = false; + self.tx.request_finish(); + self.poll_runtime(); + } + + pub async fn finish(mut self) -> Result<(), QlStreamError> { + self.queue_finish(); + poll_fn(|cx| self.poll_terminal(cx)).await + } + + pub fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.open { + self.queue_finish(); + } + self.poll_terminal(cx) + } + + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + fn poll_runtime(&self) { + self.handle.try_send(Command::PollStream { + stream_id: self.tx.stream_id(), + }); + } + + fn poll_terminal(&mut self, cx: &Context<'_>) -> Poll> { + match &self.terminal { + WriterTerminalState::Terminal(result) => return Poll::Ready(result.clone()), + WriterTerminalState::Pending => {} + } + + match self.try_poll_terminal_ready() { + Poll::Ready(result) => return Poll::Ready(result), + Poll::Pending => {} + } + + self.tx.register_waiter(cx.waker()); + + match self.try_poll_terminal_ready() { + Poll::Ready(result) => { + self.tx.unregister_waiter(); + Poll::Ready(result) + } + Poll::Pending => Poll::Pending, + } + } + + fn try_poll_terminal_ready(&mut self) -> Poll> { + let state = self.tx.load_state(); + if TxInner::terminal_ready(state) { + if TxInner::terminal_ok(state) { + self.terminal = WriterTerminalState::Terminal(Ok(())); + return Poll::Ready(Ok(())); + } + + match self.tx.pop() { + Ok(Item::Error(error)) => { + self.terminal = WriterTerminalState::Terminal(Err(error.clone())); + return Poll::Ready(Err(error)); + } + Ok(Item::Chunk(_)) => { + panic!("writer terminal phase contained chunk data") + } + Err(PopError) => {} + } + } + + Poll::Pending + } + + fn close_inner(&mut self, code: StreamCloseCode) { + if !self.open { + return; + } + self.open = false; + log::debug!( + "byte writer close: stream_id={:?} target={:?} code={:?}", + self.tx.stream_id(), + self.target, + code + ); + self.handle.try_send(Command::CloseStream { + stream_id: self.tx.stream_id(), + target: self.target, + code, + }); + } +} + +impl Drop for StreamWriter { + fn drop(&mut self) { + self.close_inner(StreamCloseCode::CANCELLED); + } +} + +#[cfg(all(test, loom))] +mod loom_tests { + use std::task::{Context, Poll, Waker}; + + use bytes::Bytes; + use loom::thread; + use ql_wire::CloseTarget; + + use super::*; + use crate::io::sync::loom::*; + + #[test] + fn poll_write_observes_capacity_racing_with_registration() { + check_model(|| { + let inner = shared(); + inner.tx.try_write(Bytes::from_static(b"abc")).unwrap(); + + let mut writer = StreamWriter::new(Tx(inner.clone()), CloseTarget::Origin, handle()); + let mut bytes = Bytes::from_static(b"xyz"); + let mut cx = Context::from_waker(Waker::noop()); + + let drainer = { + let inner = inner.clone(); + thread::spawn(move || { + assert!(matches!(inner.tx.pop(), Ok(Item::Chunk(_)))); + }) + }; + + let first = writer.poll_write(&mut bytes, &mut cx); + drainer.join().unwrap(); + + match first { + Poll::Ready(Ok(())) => { + assert!(bytes.is_empty()); + } + Poll::Pending => { + assert_eq!(writer.poll_write(&mut bytes, &mut cx), Poll::Ready(Ok(()))); + assert!(bytes.is_empty()); + } + other => panic!("unexpected first poll result: {other:?}"), + } + }); + } + + #[test] + fn poll_finish_observes_terminal_racing_with_registration() { + check_model(|| { + let inner = shared(); + let mut writer = StreamWriter::new(Tx(inner.clone()), CloseTarget::Origin, handle()); + let mut cx = Context::from_waker(Waker::noop()); + + writer.queue_finish(); + + let finisher = { + let inner = inner.clone(); + thread::spawn(move || { + inner.tx.finish(); + }) + }; + + let first = writer.poll_finish(&mut cx); + finisher.join().unwrap(); + + match first { + Poll::Ready(Ok(())) => {} + Poll::Pending => { + assert_eq!(writer.poll_finish(&mut cx), Poll::Ready(Ok(()))); + } + other => panic!("unexpected first poll result: {other:?}"), + } + }); + } +} diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs new file mode 100644 index 00000000..33783456 --- /dev/null +++ b/ql-runtime/src/lib.rs @@ -0,0 +1,63 @@ +pub use ql_fsm::{NoSessionError, PairingInvite}; + +pub use self::{error::QlStreamError, handle::*, platform::*}; + +pub(crate) mod command; +pub(crate) mod driver; +mod error; +pub mod handle; +pub(crate) mod io; +pub mod log; +pub mod platform; +#[cfg(feature = "rpc")] +pub mod rpc; + +#[cfg(test)] +mod tests; + +use ql_fsm::QlFsmConfig; +use ql_wire::QlIdentity; + +#[derive(Debug, Clone, Copy)] +pub struct RuntimeConfig { + pub fsm: QlFsmConfig, + pub max_concurrent_message_writes: usize, +} + +impl Default for RuntimeConfig { + fn default() -> Self { + Self { + fsm: QlFsmConfig::default(), + max_concurrent_message_writes: 4, + } + } +} + +pub struct Runtime

{ + identity: QlIdentity, + platform: P, + config: RuntimeConfig, + rx: async_channel::Receiver, + tx: async_channel::WeakSender, +} + +pub fn new_runtime

( + identity: QlIdentity, + platform: P, + config: RuntimeConfig, +) -> (Runtime

, RuntimeHandle) +where + P: QlPlatform, +{ + let (tx, rx) = async_channel::unbounded(); + ( + Runtime { + identity, + platform, + config, + rx, + tx: tx.downgrade(), + }, + RuntimeHandle::new(tx), + ) +} diff --git a/ql-runtime/src/log.rs b/ql-runtime/src/log.rs new file mode 100644 index 00000000..a0908f79 --- /dev/null +++ b/ql-runtime/src/log.rs @@ -0,0 +1,54 @@ +#![allow(unused_imports, unused_macros)] + +#[cfg(any(feature = "log", test))] +macro_rules! log { + ($level:ident, $($arg:tt)*) => { + ::log::log!(::log::Level::$level, $($arg)*) + }; +} + +#[cfg(not(any(feature = "log", test)))] +macro_rules! log { + ($level:ident, $($arg:tt)*) => { + if false { + let _ = format_args!($($arg)*); + } + }; +} + +macro_rules! trace { + ($($arg:tt)*) => { + $crate::log::log!(Trace, $($arg)*) + }; +} + +macro_rules! debug { + ($($arg:tt)*) => { + $crate::log::log!(Debug, $($arg)*) + }; +} + +macro_rules! info { + ($($arg:tt)*) => { + $crate::log::log!(Info, $($arg)*) + }; +} + +macro_rules! warn_ { + ($($arg:tt)*) => { + $crate::log::log!(Warn, $($arg)*) + }; +} + +macro_rules! error { + ($($arg:tt)*) => { + $crate::log::log!(Error, $($arg)*) + }; +} + +pub(crate) use debug; +pub(crate) use error; +pub(crate) use info; +pub(crate) use log; +pub(crate) use trace; +pub(crate) use warn_ as warn; diff --git a/ql-runtime/src/platform.rs b/ql-runtime/src/platform.rs new file mode 100644 index 00000000..331bfe7a --- /dev/null +++ b/ql-runtime/src/platform.rs @@ -0,0 +1,43 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + time::Instant, +}; + +use ql_fsm::{PeerStatus, ReceiveError}; +use ql_wire::{PeerBundle, QlCrypto, QID}; + +use crate::QlStream; + +pub trait QlTimer { + fn set_deadline(self: Pin<&mut Self>, deadline: Option); + fn poll_wait(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()>; +} + +pub trait QlInbound { + fn poll_recv(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; +} + +pub trait QlPlatform: QlCrypto { + type Timer: QlTimer; + type WriteMessageFut<'a>: Future + Unpin + 'a + where + Self: 'a; + type Inbound: QlInbound; + + fn write_message(&self, message: Vec) -> Self::WriteMessageFut<'_>; + /// Returns the platform's inbound transport poller. + /// + /// The runtime calls this once while starting the driver loop and retains the returned + /// poller for the lifetime of the runtime. Platform implementations may panic if this is + /// called more than once. + fn inbound(&mut self) -> Self::Inbound; + fn timer(&self) -> Self::Timer; + + fn persist_peer(&self, peer: PeerBundle); + + fn handle_peer_status(&self, peer: Option, status: PeerStatus); + fn handle_inbound(&self, event: QlStream); + fn handle_recv_error(&self, _error: ReceiveError) {} +} diff --git a/ql-runtime/src/rpc/adapter.rs b/ql-runtime/src/rpc/adapter.rs new file mode 100644 index 00000000..a7347602 --- /dev/null +++ b/ql-runtime/src/rpc/adapter.rs @@ -0,0 +1,83 @@ +use std::task::{Context, Poll}; + +use bytes::Bytes; +use ql_rpc::{RouteId, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError}; +use ql_wire::{RouteId as WireRouteId, StreamCloseCode as WireStreamCloseCode}; + +use crate::{QlStream, QlStreamError, StreamReader, StreamWriter}; + +impl RpcStream for QlStream { + type Error = QlStreamError; + type Reader = StreamReader; + type Writer = StreamWriter; + + fn route_id(&self) -> Option { + let route_id = u32::try_from(self.route_id.into_inner()).ok()?; + Some(RouteId::from_u32(route_id)) + } + + fn split(self) -> (Self::Reader, Self::Writer) { + (self.reader, self.writer) + } +} + +impl RpcRead for StreamReader { + type Error = QlStreamError; + + fn poll_read( + &mut self, + max_len: usize, + cx: &mut Context<'_>, + ) -> Poll, QlStreamError>> { + StreamReader::poll_read(self, max_len, cx) + } + + fn close(self, code: StreamCloseCode) { + StreamReader::close(self, to_wire_close_code(code)); + } +} + +impl RpcWrite for StreamWriter { + type Error = QlStreamError; + + fn poll_write( + &mut self, + bytes: &mut Bytes, + cx: &mut Context<'_>, + ) -> Poll> { + StreamWriter::poll_write(self, bytes, cx) + } + + fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll> { + StreamWriter::poll_finish(self, cx) + } + + fn close(self, code: StreamCloseCode) { + StreamWriter::close(self, to_wire_close_code(code)); + } +} + +pub(super) fn to_wire_route_id(route_id: RouteId) -> WireRouteId { + WireRouteId::from_u32(route_id.into_inner()) +} + +pub(super) fn to_wire_close_code(code: StreamCloseCode) -> WireStreamCloseCode { + WireStreamCloseCode(code.into_inner()) +} + +impl From for QlStreamError { + fn from(code: StreamCloseCode) -> Self { + Self::StreamClosed { + code: WireStreamCloseCode(code.into_inner()), + } + } +} + +impl StreamError for QlStreamError { + fn close_code(&self) -> Option { + match self { + QlStreamError::StreamClosed { code } => Some(StreamCloseCode(code.0)), + QlStreamError::NoSession => None, + } + } +} diff --git a/ql-runtime/src/rpc/download.rs b/ql-runtime/src/rpc/download.rs new file mode 100644 index 00000000..d3b63585 --- /dev/null +++ b/ql-runtime/src/rpc/download.rs @@ -0,0 +1,67 @@ +use bytes::Bytes; +use ql_rpc::download::Download as DownloadRpc; + +use super::RpcError; +use crate::StreamReader; + +pub struct DownloadCall { + pub(super) inner: ql_rpc::download::DownloadCall, +} + +pub struct DownloadReader { + pub(super) inner: ql_rpc::download::DownloadReader, +} + +pub struct DownloadPart<'a, M: DownloadRpc> { + inner: ql_rpc::download::DownloadPart<'a, M, StreamReader>, +} + +impl DownloadCall +where + M: DownloadRpc, +{ + pub async fn start(self) -> Result<(M::ResponseHeader, DownloadReader), RpcError> { + let (header, inner) = self.inner.start().await?; + Ok((header, DownloadReader { inner })) + } + + pub fn close(self, code: ql_wire::StreamCloseCode) { + self.inner.close(ql_rpc::StreamCloseCode(code.0)); + } +} + +impl DownloadReader +where + M: DownloadRpc, +{ + pub async fn next_part( + &mut self, + ) -> Result)>, RpcError> { + Ok(self + .inner + .next_part() + .await? + .map(|(header, inner)| (header, DownloadPart { inner }))) + } + + pub async fn complete(self) -> Result<(), RpcError> { + self.inner.complete().await.map_err(RpcError::from) + } + + pub fn close(self, code: ql_wire::StreamCloseCode) { + self.inner.close(ql_rpc::StreamCloseCode(code.0)); + } +} + +impl DownloadPart<'_, M> +where + M: DownloadRpc, +{ + pub async fn read_chunk(&mut self) -> Result, RpcError> { + Ok(self.inner.read_chunk().await?) + } + + pub fn close(self, code: ql_wire::StreamCloseCode) { + self.inner.close(ql_rpc::StreamCloseCode(code.0)); + } +} diff --git a/ql-runtime/src/rpc/duplex.rs b/ql-runtime/src/rpc/duplex.rs new file mode 100644 index 00000000..cdad6670 --- /dev/null +++ b/ql-runtime/src/rpc/duplex.rs @@ -0,0 +1,59 @@ +use futures_lite::future::poll_fn; +use ql_rpc::duplex::Duplex as DuplexRpc; + +use super::RpcError; +use crate::{QlStreamError, StreamReader, StreamWriter}; + +pub struct DuplexCall { + pub sender: DuplexSender, + pub receiver: DuplexReceiver, +} + +pub struct DuplexSender +where + T: ql_rpc::RpcCodec, +{ + pub(super) inner: ql_rpc::duplex::DuplexSender, +} + +pub struct DuplexReceiver +where + T: ql_rpc::RpcCodec, +{ + pub(super) inner: ql_rpc::duplex::DuplexReceiver, +} + +impl DuplexSender +where + T: ql_rpc::RpcCodec, +{ + pub async fn send(&mut self, event: &T) -> Result<(), QlStreamError> { + self.inner.send(event).await + } + + pub async fn finish(self) -> Result<(), QlStreamError> { + self.inner.finish().await + } + + pub fn close(self, code: ql_wire::StreamCloseCode) { + self.inner.close(ql_rpc::StreamCloseCode(code.0)); + } +} + +impl DuplexReceiver +where + T: ql_rpc::RpcCodec, +{ + pub async fn next_event(&mut self) -> Option>> { + poll_fn(|cx| { + self.inner + .poll_next_event(cx) + .map(|item| item.map(|result| Ok(result?))) + }) + .await + } + + pub fn close(self, code: ql_wire::StreamCloseCode) { + self.inner.close(ql_rpc::StreamCloseCode(code.0)); + } +} diff --git a/ql-runtime/src/rpc/error.rs b/ql-runtime/src/rpc/error.rs new file mode 100644 index 00000000..4cc9e176 --- /dev/null +++ b/ql-runtime/src/rpc/error.rs @@ -0,0 +1,79 @@ +use ql_fsm::NoSessionError; + +use crate::QlStreamError; + +#[derive(Debug)] +pub enum RpcError { + NoSession, + Closed(ql_rpc::StreamCloseCode), + Protocol(ql_rpc::Error), + Codec(E), +} + +impl From for RpcError { + fn from(_: NoSessionError) -> Self { + Self::NoSession + } +} + +impl From for RpcError { + fn from(error: QlStreamError) -> Self { + match error { + QlStreamError::StreamClosed { code } => Self::Closed(ql_rpc::StreamCloseCode(code.0)), + QlStreamError::NoSession => Self::NoSession, + } + } +} + +impl From for RpcError { + fn from(error: ql_rpc::Error) -> Self { + Self::Protocol(error) + } +} + +impl From> for RpcError { + fn from(error: ql_rpc::CodecError) -> Self { + match error { + ql_rpc::CodecError::Rpc(error) => Self::Protocol(error), + ql_rpc::CodecError::Codec(error) => Self::Codec(error), + } + } +} + +impl From> for RpcError { + fn from(error: ql_rpc::CallError) -> Self { + match error { + ql_rpc::CallError::Protocol(error) => Self::Protocol(error), + ql_rpc::CallError::Codec(error) => Self::Codec(error), + ql_rpc::CallError::Transport(error) => error.into(), + } + } +} + +impl std::fmt::Display for RpcError +where + E: std::fmt::Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::NoSession => write!(f, "no session"), + Self::Closed(code) => write!(f, "stream closed {code:?}"), + Self::Protocol(error) => write!(f, "{error}"), + Self::Codec(error) => write!(f, "{error}"), + } + } +} + +impl std::error::Error for RpcError +where + E: std::error::Error + 'static, +{ + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Protocol(error) => Some(error), + Self::Codec(error) => Some(error), + RpcError::NoSession => None, + RpcError::Closed(_) => None, + } + } +} diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs new file mode 100644 index 00000000..d8be02c5 --- /dev/null +++ b/ql-runtime/src/rpc/mod.rs @@ -0,0 +1,154 @@ +pub use self::{download::*, duplex::*, error::*, progress::*, subscription::*, upload::*}; + +mod adapter; +mod download; +mod duplex; +mod error; +mod progress; +mod subscription; +mod upload; + +use bytes::Bytes; +use ql_rpc::{ + download::{self as rpc_download, Download as DownloadRpc}, + duplex::{self as rpc_duplex, Duplex as DuplexRpc}, + notification::{self, Notification}, + progress::{self as rpc_progress, Progress}, + request::{self, Request as RequestRpc}, + subscription::{self as rpc_subscription, Subscription as SubscriptionRpc}, + upload::{self as rpc_upload, Upload as UploadRpc}, +}; + +use crate::{RuntimeHandle, StreamReader}; + +#[derive(Clone)] +pub struct RpcHandle { + inner: RuntimeHandle, +} + +impl RpcHandle { + pub async fn notification(&self, event: &M::Payload) -> Result<(), RpcError> + where + M: Notification, + { + let mut payload = Vec::new(); + notification::encode_notification::(event, &mut payload); + let mut stream = self + .inner + .open_stream(adapter::to_wire_route_id(M::ROUTE)) + .await?; + stream.reader.close(ql_wire::StreamCloseCode::CANCELLED); + stream.writer.write(Bytes::from(payload)).await?; + stream.writer.finish().await?; + Ok(()) + } + + pub async fn request(&self, request: &M::Request) -> Result> + where + M: RequestRpc, + { + let mut payload = Vec::new(); + request::encode_request::(request, &mut payload); + let response = self.start_request(M::ROUTE, payload).await?; + Ok(request::read_response::(response).await?) + } + + pub async fn subscribe( + &self, + request: &M::Request, + ) -> Result, RpcError> + where + M: SubscriptionRpc, + { + let mut payload = Vec::new(); + rpc_subscription::encode_request::(request, &mut payload); + let response = self.start_request(M::ROUTE, payload).await?; + Ok(Subscription { + inner: rpc_subscription::SubscriptionCall::new(response), + }) + } + + pub async fn download( + &self, + request: &M::Request, + ) -> Result, RpcError> + where + M: DownloadRpc, + { + let mut payload = Vec::new(); + rpc_download::encode_request::(request, &mut payload); + let response = self.start_request(M::ROUTE, payload).await?; + Ok(DownloadCall { + inner: rpc_download::DownloadCall::new(response), + }) + } + + pub async fn progress( + &self, + request: &M::Request, + ) -> Result, RpcError> + where + M: Progress, + { + let mut payload = Vec::new(); + rpc_progress::encode_request::(request, &mut payload); + let response = self.start_request(M::ROUTE, payload).await?; + Ok(ProgressCall { + inner: rpc_progress::ProgressCall::new(response), + }) + } + + pub async fn upload(&self, request: &M::Request) -> Result, RpcError> + where + M: UploadRpc, + { + let mut payload = Vec::new(); + rpc_upload::encode_request::(request, &mut payload); + let mut stream = self + .inner + .open_stream(adapter::to_wire_route_id(M::ROUTE)) + .await?; + stream.writer.write(Bytes::from(payload)).await?; + Ok(UploadCall { + inner: rpc_upload::UploadCall::new(stream.writer, stream.reader), + }) + } + + pub async fn duplex(&self) -> Result, RpcError> + where + M: DuplexRpc, + { + let stream = self + .inner + .open_stream(adapter::to_wire_route_id(M::ROUTE)) + .await?; + Ok(DuplexCall { + sender: DuplexSender { + inner: rpc_duplex::DuplexSender::new(stream.writer), + }, + receiver: DuplexReceiver { + inner: rpc_duplex::DuplexReceiver::new(stream.reader), + }, + }) + } +} + +impl RpcHandle { + pub(super) fn new(inner: RuntimeHandle) -> Self { + Self { inner } + } + + async fn start_request( + &self, + route_id: ql_rpc::RouteId, + payload: Vec, + ) -> Result> { + let mut stream = self + .inner + .open_stream(adapter::to_wire_route_id(route_id)) + .await?; + stream.writer.write(Bytes::from(payload)).await?; + stream.writer.finish().await?; + Ok(stream.reader) + } +} diff --git a/ql-runtime/src/rpc/progress.rs b/ql-runtime/src/rpc/progress.rs new file mode 100644 index 00000000..a22da20f --- /dev/null +++ b/ql-runtime/src/rpc/progress.rs @@ -0,0 +1,50 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use futures_lite::Stream; +use ql_rpc::progress::Progress; + +use super::RpcError; +use crate::StreamReader; + +pub struct ProgressCall { + pub(super) inner: ql_rpc::progress::ProgressCall, +} + +impl Unpin for ProgressCall where M: Progress {} + +impl ProgressCall +where + M: Progress, +{ + pub fn close(self, code: ql_wire::StreamCloseCode) { + self.inner.close(ql_rpc::StreamCloseCode(code.0)); + } +} + +impl Stream for ProgressCall +where + M: Progress, +{ + type Item = M::Progress; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().inner.poll_next_progress(cx) + } +} + +impl Future for ProgressCall +where + M: Progress, +{ + type Output = Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.get_mut().inner) + .poll(cx) + .map(|result| result.map_err(RpcError::from)) + } +} diff --git a/ql-runtime/src/rpc/subscription.rs b/ql-runtime/src/rpc/subscription.rs new file mode 100644 index 00000000..45a08a6b --- /dev/null +++ b/ql-runtime/src/rpc/subscription.rs @@ -0,0 +1,43 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use futures_lite::{future::poll_fn, Stream}; +use ql_rpc::subscription::Subscription as SubscriptionRpc; + +use super::RpcError; +use crate::StreamReader; + +pub struct Subscription { + pub(super) inner: ql_rpc::subscription::SubscriptionCall, +} + +impl Unpin for Subscription where M: SubscriptionRpc {} + +impl Subscription +where + M: SubscriptionRpc, +{ + pub async fn next_event(&mut self) -> Option>> { + poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await + } + + pub fn close(self, code: ql_wire::StreamCloseCode) { + self.inner.close(ql_rpc::StreamCloseCode(code.0)); + } +} + +impl Stream for Subscription +where + M: SubscriptionRpc, +{ + type Item = Result>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut() + .inner + .poll_next_event(cx) + .map(|item| item.map(|result| Ok(result?))) + } +} diff --git a/ql-runtime/src/rpc/upload.rs b/ql-runtime/src/rpc/upload.rs new file mode 100644 index 00000000..33ee3665 --- /dev/null +++ b/ql-runtime/src/rpc/upload.rs @@ -0,0 +1,44 @@ +use bytes::Bytes; +use ql_rpc::upload::Upload as UploadRpc; + +use super::RpcError; +use crate::QlStreamError; + +pub struct UploadCall { + pub(super) inner: ql_rpc::upload::UploadCall, +} + +pub struct UploadPartWriter<'a, M: UploadRpc> { + inner: ql_rpc::upload::UploadPartWriter<'a, M, crate::StreamWriter, crate::StreamReader>, +} + +impl UploadCall +where + M: UploadRpc, +{ + pub async fn start_part( + &mut self, + part_header: M::PartHeader, + ) -> Result, QlStreamError> { + Ok(UploadPartWriter { + inner: self.inner.start_part(part_header).await?, + }) + } + + pub async fn finish(self) -> Result> { + self.inner.finish().await.map_err(RpcError::from) + } +} + +impl UploadPartWriter<'_, M> +where + M: UploadRpc, +{ + pub async fn send(&mut self, bytes: Bytes) -> Result<(), QlStreamError> { + self.inner.send(bytes).await + } + + pub async fn finish(self) -> Result<(), QlStreamError> { + self.inner.finish().await + } +} diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs new file mode 100644 index 00000000..65731bbc --- /dev/null +++ b/ql-runtime/src/tests/handshake.rs @@ -0,0 +1,178 @@ +use std::time::Duration; + +use bytes::Bytes; + +use super::*; + +#[tokio::test(flavor = "current_thread")] +async fn connect_round_trip_changes_peer_status() { + run_local_test(async { + let pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn opening_stream_requires_connection() { + run_local_test(async { + let pair = TestPair::new(default_runtime_config()); + assert!(matches!( + pair.side(Side::A).handle.open_stream(test_route_id()).await, + Err(NoSessionError) + )); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn handshake_timeout_disconnects() { + run_local_test(async { + let config = RuntimeConfig { + fsm: QlFsmConfig { + handshake_timeout: Duration::from_millis(60), + ..default_runtime_config().fsm + }, + ..default_runtime_config() + }; + let (platform_a, _outbound_a, _inbound_a, status_a) = TestPlatform::new(); + let (platform_b, _outbound_b, _inbound_b, _status_b) = TestPlatform::new(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect(); + + await_status(&status_a, Some(identity_b.qid), PeerStatus::Disconnected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn rejected_session_write_is_reissued() { + run_local_test(async { + let config = default_runtime_config(); + let (platform_a, outbound_a, inbound_a_tx, status_a) = + TestPlatform::new_with_session_write_failure(1); + let (platform_b, outbound_b, inbound_b_tx, status_b, inbound_b) = + TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, inbound_b_tx); + spawn_forwarder(outbound_b, inbound_a_tx); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect(); + + await_status(&status_a, Some(identity_b.qid), PeerStatus::Connected).await; + await_status(&status_b, Some(identity_a.qid), PeerStatus::Connected).await; + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let request = read_all(stream.reader).await.unwrap(); + stream.writer.finish().await.unwrap(); + request + }); + + let mut stream = handle_a.open_stream(test_route_id()).await.unwrap(); + stream + .writer + .write(Bytes::from_static(b"retry")) + .await + .unwrap(); + stream.writer.finish().await.unwrap(); + assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); + + assert_eq!( + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(), + b"retry".to_vec() + ); + + assert_no_status_for( + &status_a, + Some(identity_b.qid), + PeerStatus::Disconnected, + Duration::from_millis(150), + ) + .await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn start_pairing_round_trip_connects_when_armed() { + run_local_test(async { + let config = default_runtime_config(); + let (platform_a, outbound_a, inbound_a_tx, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, inbound_b_tx, status_b) = TestPlatform::new(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + let token = pairing_token(7); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, inbound_b_tx); + spawn_forwarder(outbound_b, inbound_a_tx); + + handle_b.arm_pairing(token); + handle_a.start_pairing(PairingInvite { + qid: identity_b.qid, + token, + }); + + await_status(&status_a, Some(identity_b.qid), PeerStatus::Connected).await; + await_status(&status_b, Some(identity_a.qid), PeerStatus::Connected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn start_pairing_does_not_connect_when_unarmed() { + run_local_test(async { + let config = default_runtime_config(); + let (platform_a, outbound_a, inbound_a_tx, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, inbound_b_tx, _status_b) = TestPlatform::new(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + let token = pairing_token(8); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, _handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, inbound_b_tx); + spawn_forwarder(outbound_b, inbound_a_tx); + + handle_a.start_pairing(PairingInvite { + qid: identity_b.qid, + token, + }); + + assert_no_status_for( + &status_a, + Some(identity_b.qid), + PeerStatus::Connected, + Duration::from_millis(150), + ) + .await; + }) + .await; +} diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs new file mode 100644 index 00000000..af368738 --- /dev/null +++ b/ql-runtime/src/tests/mod.rs @@ -0,0 +1,710 @@ +use std::{ + future::Future, + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Mutex, Once, + }, + task::{Context, Poll}, + time::Duration, +}; + +use async_channel::{Receiver, Sender}; +use futures_lite::Stream; +use ql_fsm::PeerStatus; +use ql_wire::{ + generate_identity, test_identities, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, + MlKemPublicKey, Nonce, PairingToken, PeerBundle, QlAead, QlHash, QlIdentity, QlKem, QlRandom, + RecordHeader, RecordType, RouteId, SessionKey, SoftwareCrypto, WireDecode, QID, +}; +use tokio::{task::LocalSet, time::Sleep}; + +use crate::{ + new_runtime, platform::QlTimer, NoSessionError, PairingInvite, QlFsmConfig, QlStream, + QlStreamError, RuntimeConfig, RuntimeHandle, +}; + +mod handshake; +#[cfg(feature = "rpc")] +mod rpc; +mod session; +mod stream; + +fn init_test_logger() { + static INIT: Once = Once::new(); + + INIT.call_once(|| { + let env = env_logger::Env::default().default_filter_or("ql_runtime=info"); + let mut builder = env_logger::Builder::from_env(env); + builder.is_test(true); + let _ = builder.try_init(); + }); +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct StatusEvent { + peer: Option, + status: PeerStatus, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Side { + A, + B, +} + +impl Side { + fn opposite(self) -> Self { + match self { + Self::A => Self::B, + Self::B => Self::A, + } + } +} + +fn test_route_id() -> RouteId { + RouteId::from_u32(1) +} + +#[derive(Debug, Clone)] +struct WriteStats { + active: Arc, + max_active: Arc, +} + +impl WriteStats { + fn new() -> Self { + Self { + active: Arc::new(AtomicUsize::new(0)), + max_active: Arc::new(AtomicUsize::new(0)), + } + } + + fn max_active(&self) -> usize { + self.max_active.load(Ordering::Relaxed) + } +} + +struct TestPlatform { + outbound: Sender>, + _inbound_messages_tx: Sender>, + inbound_messages: Option>>, + status: Sender, + inbound: Option>, + crypto: SoftwareCrypto, + encrypted_write_counter: AtomicUsize, + fail_encrypted_write_at: Option, + write_delay: Duration, + write_stats: Option, +} + +struct TestInbound { + receiver: Receiver>, +} + +type TestPlatformParts = ( + TestPlatform, + Receiver>, + Sender>, + Receiver, +); + +type TestPlatformPartsWithInbound = ( + TestPlatform, + Receiver>, + Sender>, + Receiver, + Receiver, +); + +impl TestPlatform { + fn new() -> TestPlatformParts { + Self::new_inner(None, None, Duration::ZERO, None) + } + + fn new_with_inbound() -> TestPlatformPartsWithInbound { + let (inbound_tx, inbound_rx) = async_channel::unbounded(); + let (platform, outbound_rx, inbound_messages_tx, status_rx) = + Self::new_inner(Some(inbound_tx), None, Duration::ZERO, None); + ( + platform, + outbound_rx, + inbound_messages_tx, + status_rx, + inbound_rx, + ) + } + + fn new_with_session_write_failure(fail_encrypted_write_at: usize) -> TestPlatformParts { + Self::new_inner(None, Some(fail_encrypted_write_at), Duration::ZERO, None) + } + + fn new_with_delayed_writes(delay: Duration, write_stats: WriteStats) -> TestPlatformParts { + Self::new_inner(None, None, delay, Some(write_stats)) + } + + fn new_inner( + inbound: Option>, + fail_encrypted_write_at: Option, + write_delay: Duration, + write_stats: Option, + ) -> TestPlatformParts { + let (outbound, outbound_rx) = async_channel::unbounded(); + let (inbound_messages_tx, inbound_messages_rx) = async_channel::unbounded(); + let (status, status_rx) = async_channel::unbounded(); + ( + Self { + outbound, + _inbound_messages_tx: inbound_messages_tx.clone(), + inbound_messages: Some(inbound_messages_rx), + status, + inbound, + crypto: SoftwareCrypto, + encrypted_write_counter: AtomicUsize::new(0), + fail_encrypted_write_at, + write_delay, + write_stats, + }, + outbound_rx, + inbound_messages_tx, + status_rx, + ) + } +} + +struct TestSide { + handle: RuntimeHandle, + status: Receiver, + peer: QID, + inbound: Receiver, +} + +struct TestPair { + a: TestSide, + b: TestSide, +} + +#[derive(Debug, Clone, Copy, Default)] +struct LinkBehavior { + base_delay: Duration, + drop_encrypted_every: Option, + duplicate_encrypted_every: Option, + delay_encrypted_every: Option<(usize, Duration)>, +} + +#[derive(Clone, Default)] +struct LinkController { + behavior: Arc>, +} + +impl LinkController { + fn new(behavior: LinkBehavior) -> Self { + Self { + behavior: Arc::new(Mutex::new(behavior)), + } + } + + fn load(&self) -> LinkBehavior { + *self.behavior.lock().unwrap() + } + + fn store(&self, behavior: LinkBehavior) { + *self.behavior.lock().unwrap() = behavior; + } +} + +#[derive(Clone)] +struct ControlledLinks { + a_to_b: LinkController, + b_to_a: LinkController, +} + +impl TestPair { + fn new(config: RuntimeConfig) -> Self { + Self::new_with_links(config, LinkBehavior::default(), LinkBehavior::default()) + } + + fn new_with_links(config: RuntimeConfig, a_to_b: LinkBehavior, b_to_a: LinkBehavior) -> Self { + let (pair, _links) = Self::new_with_controlled_links(config, a_to_b, b_to_a); + pair + } + + fn new_with_controlled_links( + config: RuntimeConfig, + a_to_b: LinkBehavior, + b_to_a: LinkBehavior, + ) -> (Self, ControlledLinks) { + let (platform_a, outbound_a, inbound_a_tx, status_a, inbound_a) = + TestPlatform::new_with_inbound(); + let (platform_b, outbound_b, inbound_b_tx, status_b, inbound_b) = + TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + let links = ControlledLinks { + a_to_b: LinkController::new(a_to_b), + b_to_a: LinkController::new(b_to_a), + }; + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_simulated_forwarder(outbound_a, inbound_b_tx, links.a_to_b.clone()); + spawn_simulated_forwarder(outbound_b, inbound_a_tx, links.b_to_a.clone()); + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + + ( + Self { + a: TestSide { + handle: handle_a, + status: status_a, + peer: identity_a.qid, + inbound: inbound_a, + }, + b: TestSide { + handle: handle_b, + status: status_b, + peer: identity_b.qid, + inbound: inbound_b, + }, + }, + links, + ) + } + + fn side(&self, side: Side) -> &TestSide { + match side { + Side::A => &self.a, + Side::B => &self.b, + } + } + + fn side_mut(&mut self, side: Side) -> &mut TestSide { + match side { + Side::A => &mut self.a, + Side::B => &mut self.b, + } + } + + async fn connect_and_wait(&self, initiator: Side) { + self.side(initiator).handle.connect(); + await_status( + &self.side(initiator).status, + Some(self.side(initiator.opposite()).peer), + PeerStatus::Connected, + ) + .await; + await_status( + &self.side(initiator.opposite()).status, + Some(self.side(initiator).peer), + PeerStatus::Connected, + ) + .await; + } + + fn take_inbound(&mut self, side: Side) -> Receiver { + let replacement = async_channel::unbounded().1; + std::mem::replace(&mut self.side_mut(side).inbound, replacement) + } +} + +struct TokioTimer { + sleep: Pin>, +} + +impl TokioTimer { + fn new() -> Self { + Self { + sleep: Box::pin(tokio::time::sleep_until(parked_deadline())), + } + } +} + +impl QlTimer for TokioTimer { + fn set_deadline(mut self: Pin<&mut Self>, deadline: Option) { + let deadline = deadline.map_or_else(parked_deadline, tokio::time::Instant::from_std); + self.as_mut().get_mut().sleep.as_mut().reset(deadline); + } + + fn poll_wait(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + self.as_mut().get_mut().sleep.as_mut().poll(cx) + } +} + +impl QlRandom for TestPlatform { + fn fill_random_bytes(&self, data: &mut [u8]) { + self.crypto.fill_random_bytes(data); + } +} + +impl QlHash for TestPlatform { + fn sha256(&self, parts: &[&[u8]]) -> [u8; 32] { + self.crypto.sha256(parts) + } +} + +impl QlAead for TestPlatform { + fn aes256_gcm_encrypt( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + ) -> [u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE] { + self.crypto.aes256_gcm_encrypt(key, nonce, aad, buffer) + } + + fn aes256_gcm_decrypt( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + auth_tag: &[u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE], + ) -> bool { + self.crypto + .aes256_gcm_decrypt(key, nonce, aad, buffer, auth_tag) + } +} + +impl QlKem for TestPlatform { + fn mlkem_generate_keypair(&self) -> MlKemKeyPair { + self.crypto.mlkem_generate_keypair() + } + + fn mlkem_encapsulate(&self, public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey) { + self.crypto.mlkem_encapsulate(public_key) + } + + fn mlkem_decapsulate(&self, pk: &MlKemPrivateKey, cipher: &MlKemCiphertext) -> SessionKey { + self.crypto.mlkem_decapsulate(pk, cipher) + } +} + +impl crate::platform::QlPlatform for TestPlatform { + type Timer = TokioTimer; + type WriteMessageFut<'a> = Pin + Send + 'a>>; + type Inbound = TestInbound; + + fn write_message(&self, message: Vec) -> Self::WriteMessageFut<'_> { + let outbound = self.outbound.clone(); + let write_delay = self.write_delay; + let fail_encrypted_write_at = self.fail_encrypted_write_at; + let write_stats = self.write_stats.clone(); + + Box::pin(async move { + if let Some(stats) = write_stats.as_ref() { + let active = stats.active.fetch_add(1, Ordering::Relaxed) + 1; + stats.max_active.fetch_max(active, Ordering::Relaxed); + } + + if !write_delay.is_zero() { + tokio::time::sleep(write_delay).await; + } + + let should_fail = if is_encrypted_payload(&message) { + let count = self.encrypted_write_counter.fetch_add(1, Ordering::Relaxed) + 1; + fail_encrypted_write_at == Some(count) + } else { + false + }; + + let success = if should_fail { + false + } else { + outbound.send(message).await.is_ok() + }; + + if let Some(stats) = write_stats.as_ref() { + stats.active.fetch_sub(1, Ordering::Relaxed); + } + + success + }) + } + + fn inbound(&mut self) -> Self::Inbound { + TestInbound { + receiver: self + .inbound_messages + .take() + .expect("TestPlatform::inbound may only be called once"), + } + } + + fn timer(&self) -> Self::Timer { + TokioTimer::new() + } + + fn persist_peer(&self, _peer: PeerBundle) {} + + fn handle_peer_status(&self, peer: Option, status: PeerStatus) { + let _ = self.status.try_send(StatusEvent { peer, status }); + } + + fn handle_inbound(&self, event: QlStream) { + if let Some(tx) = &self.inbound { + let _ = tx.try_send(event); + } + } +} + +impl crate::platform::QlInbound for TestInbound { + fn poll_recv(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match unsafe { self.as_mut().map_unchecked_mut(|this| &mut this.receiver) }.poll_next(cx) { + Poll::Ready(Some(bytes)) => Poll::Ready(bytes), + Poll::Ready(None) => panic!("TestInbound channel closed"), + Poll::Pending => Poll::Pending, + } + } +} + +fn parked_deadline() -> tokio::time::Instant { + tokio::time::Instant::now() + Duration::from_secs(60 * 60 * 24 * 365 * 100) +} + +fn is_encrypted_payload(bytes: &[u8]) -> bool { + RecordHeader::decode_bytes(bytes) + .ok() + .is_some_and(|header| header.record_type == RecordType::Session) +} + +fn pairing_token(byte: u8) -> PairingToken { + PairingToken([byte; PairingToken::SIZE]) +} + +fn register_peers( + handle_a: &RuntimeHandle, + handle_b: &RuntimeHandle, + id_a: &QlIdentity, + id_b: &QlIdentity, +) { + handle_a.bind_peer(id_b.bundle()); + handle_b.bind_peer(id_a.bundle()); +} + +fn spawn_forwarder(outbound: Receiver>, inbound: Sender>) { + spawn_simulated_forwarder( + outbound, + inbound, + LinkController::new(LinkBehavior::default()), + ); +} + +fn spawn_simulated_forwarder( + outbound: Receiver>, + inbound: Sender>, + controller: LinkController, +) { + tokio::task::spawn_local(async move { + let mut encrypted_count = 0usize; + while let Ok(bytes) = outbound.recv().await { + let behavior = controller.load(); + let encrypted = is_encrypted_payload(&bytes); + let ordinal = if encrypted { + encrypted_count = encrypted_count.saturating_add(1); + Some(encrypted_count) + } else { + None + }; + + if ordinal.is_some_and(|count| { + behavior + .drop_encrypted_every + .is_some_and(|nth| nth != 0 && count % nth == 0) + }) { + continue; + } + + let mut delay = behavior.base_delay; + if let Some(count) = ordinal { + if let Some((nth, extra_delay)) = behavior.delay_encrypted_every { + if nth != 0 && count % nth == 0 { + delay += extra_delay; + } + } + } + + let primary = bytes.clone(); + let primary_inbound = inbound.clone(); + tokio::task::spawn_local(async move { + if !delay.is_zero() { + tokio::time::sleep(delay).await; + } + let _ = primary_inbound.try_send(primary); + }); + + if ordinal.is_some_and(|count| { + behavior + .duplicate_encrypted_every + .is_some_and(|nth| nth != 0 && count % nth == 0) + }) { + let duplicate_inbound = inbound.clone(); + tokio::task::spawn_local(async move { + let duplicate_delay = delay + Duration::from_millis(1); + if !duplicate_delay.is_zero() { + tokio::time::sleep(duplicate_delay).await; + } + let _ = duplicate_inbound.try_send(bytes); + }); + } + } + }); +} + +fn spawn_drop_every_nth_encrypted_forwarder( + outbound: Receiver>, + inbound: Sender>, + nth: usize, +) { + tokio::task::spawn_local(async move { + let mut encrypted_count = 0usize; + while let Ok(bytes) = outbound.recv().await { + if nth > 0 && is_encrypted_payload(&bytes) { + encrypted_count = encrypted_count.saturating_add(1); + if encrypted_count % nth == 0 { + continue; + } + } + let _ = inbound.try_send(bytes); + } + }); +} + +fn spawn_gated_forwarder( + outbound: Receiver>, + inbound: Sender>, + drop_flag: Arc, +) { + tokio::task::spawn_local(async move { + while let Ok(bytes) = outbound.recv().await { + if drop_flag.load(Ordering::Relaxed) { + continue; + } + let _ = inbound.try_send(bytes); + } + }); +} + +#[allow(clippy::future_not_send)] +async fn run_local_test(future: F) +where + F: Future, +{ + run_local_test_timeout(Duration::from_secs(5), future).await; +} + +#[allow(clippy::future_not_send)] +async fn run_local_test_timeout(duration: Duration, future: F) +where + F: Future, +{ + init_test_logger(); + let local = LocalSet::new(); + let future = local.run_until(future); + tokio::time::timeout(duration, future) + .await + .unwrap_or_else(|_| panic!("local runtime test exceeded {duration:?}")); +} + +async fn await_status(receiver: &Receiver, peer: Option, stage: PeerStatus) { + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if let Ok(event) = receiver.recv().await { + if event.peer == peer && event.status == stage { + return; + } + } + } + }) + .await + .unwrap(); +} + +async fn assert_no_status_for( + receiver: &Receiver, + peer: Option, + status: PeerStatus, + window: Duration, +) { + let res = tokio::time::timeout(window, async { + loop { + let event = receiver.recv().await.unwrap(); + if event.peer == peer && event.status == status { + return; + } + } + }) + .await; + assert!(res.is_err(), "unexpected status event: {status:?}"); +} + +async fn read_all(mut stream: crate::StreamReader) -> Result, QlStreamError> { + let mut data = Vec::new(); + while let Some(chunk) = next_chunk(&mut stream).await? { + data.extend_from_slice(&chunk); + } + Ok(data) +} + +async fn next_chunk_max( + stream: &mut crate::StreamReader, + max_len: usize, +) -> Result>, crate::QlStreamError> { + stream + .read(max_len) + .await + .map(|chunk| chunk.map(|bytes| bytes.to_vec())) +} + +async fn next_chunk(stream: &mut crate::StreamReader) -> Result>, QlStreamError> { + next_chunk_max(stream, usize::MAX).await +} + +fn default_runtime_config() -> RuntimeConfig { + RuntimeConfig { + fsm: QlFsmConfig { + handshake_timeout: Duration::from_millis(300), + session_record_retransmit_timeout: Duration::from_millis(30), + session_keepalive_interval: Duration::ZERO, + session_peer_timeout: Duration::ZERO, + ..Default::default() + }, + ..Default::default() + } +} + +// runtime is send, if platform is send +#[test] +fn runtime_is_send() { + let config = default_runtime_config(); + let identity = generate_identity(&SoftwareCrypto, "runtime").unwrap(); + let (platform, _, _, _) = TestPlatform::new(); + let (runtime, _handle) = new_runtime(identity, platform, config); + let _run: Box + Send> = Box::new(runtime.run()); +} + +#[test] +fn runtime_exits_when_last_handle_drops() { + let config = default_runtime_config(); + let identity = generate_identity(&SoftwareCrypto, "runtime").unwrap(); + let (platform, _, _, _) = TestPlatform::new(); + let (runtime, handle) = new_runtime(identity, platform, config); + let (done_tx, done_rx) = oneshot::channel(); + + std::thread::spawn(move || { + tokio::runtime::Builder::new_current_thread() + .enable_time() + .build() + .unwrap() + .block_on(runtime.run()); + done_tx.send(()).unwrap(); + }); + + drop(handle); + + done_rx + .recv_timeout(Duration::from_secs(1)) + .expect("runtime should stop once the last sender is dropped"); +} diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs new file mode 100644 index 00000000..d6147c30 --- /dev/null +++ b/ql-runtime/src/tests/rpc.rs @@ -0,0 +1,657 @@ +use std::{ + cell::RefCell, + future::Future, + rc::Rc, + str::Utf8Error, + sync::{Arc, Mutex}, + time::Duration, +}; + +use bytes::Bytes; +use futures_lite::StreamExt; +use ql_rpc::{ + DownloadHandlerLocal, DownloadStart, DuplexHandlerLocal, DuplexPeer, LocalSpawner, + NotificationHandlerLocal, ProgressHandlerLocal, ProgressResponder, RequestHandler, + RequestHandlerLocal, Response, RouteId, SendSpawner, Spawner, StreamCloseCode, + SubscriptionHandlerLocal, SubscriptionResponder, UploadHandlerLocal, UploadReader, + UploadResponder, +}; + +use super::*; +use crate::{rpc::RpcError, QlStream, StreamWriter}; + +#[derive(Debug, Clone, Copy)] +struct TokioLocalSpawner; + +impl Spawner for TokioLocalSpawner { + type Handle = tokio::task::JoinHandle<()>; +} + +impl LocalSpawner for TokioLocalSpawner { + fn spawn(&self, fut: F) -> Self::Handle + where + F: Future + 'static, + { + tokio::task::spawn_local(fut) + } +} + +#[derive(Debug, Clone, Copy)] +struct TokioSendSpawner; + +impl Spawner for TokioSendSpawner { + type Handle = tokio::task::JoinHandle<()>; +} + +impl SendSpawner for TokioSendSpawner { + fn spawn(&self, fut: F) -> Self::Handle + where + F: Future + Send + 'static, + { + tokio::task::spawn(fut) + } +} + +struct Echo; + +impl ql_rpc::request::Request for Echo { + const ROUTE: RouteId = RouteId::from_u32(51); + + type Error = Utf8Error; + + type Request = String; + type Response = String; +} + +struct Feed; + +impl ql_rpc::subscription::Subscription for Feed { + const ROUTE: RouteId = RouteId::from_u32(52); + type Error = core::convert::Infallible; + type Request = Vec; + type Event = Vec; +} + +struct Notice; + +impl ql_rpc::notification::Notification for Notice { + const ROUTE: RouteId = RouteId::from_u32(521); + type Error = core::convert::Infallible; + type Payload = Vec; +} + +struct Download; + +impl ql_rpc::progress::Progress for Download { + const ROUTE: RouteId = RouteId::from_u32(53); + type Error = core::convert::Infallible; + type Request = Vec; + type Progress = Vec; + type Response = Vec; +} + +struct BlobDownload; + +impl ql_rpc::download::Download for BlobDownload { + const ROUTE: RouteId = RouteId::from_u32(54); + type Error = core::convert::Infallible; + type Request = Vec; + type ResponseHeader = Vec; + type PartHeader = Vec; +} + +struct BlobUpload; + +impl ql_rpc::upload::Upload for BlobUpload { + const ROUTE: RouteId = RouteId::from_u32(55); + type Error = core::convert::Infallible; + type Request = Vec; + type PartHeader = Vec; + type Response = Vec; +} + +struct Chat; + +impl ql_rpc::duplex::Duplex for Chat { + const ROUTE: RouteId = RouteId::from_u32(56); + type Error = core::convert::Infallible; + type InitiatorEvent = Vec; + type ResponderEvent = Vec; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_request() { + #[derive(Clone)] + struct RouterState { + seen: Arc>>, + } + + impl RequestHandler for RouterState { + async fn handle(self, request: String, response: Response) { + let seen = self.seen.clone(); + seen.lock().unwrap().push(request); + let _ = response.respond("world".into()).await; + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let seen = Arc::new(Mutex::new(Vec::new())); + + let router = + ql_rpc::Router::<_, QlStream, TokioSendSpawner>::builder_send(TokioSendSpawner) + .request::() + .build(RouterState { seen: seen.clone() }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + let fut = assert_send(fut); + fut.await.unwrap(); + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + let response = rpc.request::(&"hello".into()).await.unwrap(); + assert_eq!(response, "world"); + assert_eq!(&*seen.lock().unwrap(), &["hello".to_string()]); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +fn assert_send(value: T) -> T { + value +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_notification() { + #[derive(Clone)] + struct RouterState { + seen: Rc>>>, + } + + impl NotificationHandlerLocal for RouterState { + async fn handle(self, payload: Vec) { + self.seen.borrow_mut().push(payload); + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let seen = Rc::new(RefCell::new(Vec::new())); + + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .notification::() + .build(RouterState { seen: seen.clone() }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await.unwrap(); + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + rpc.notification::(&b"hello".to_vec()) + .await + .unwrap(); + assert_eq!(seen.borrow().as_slice(), &[b"hello".to_vec()]); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_subscrption() { + #[derive(Clone)] + struct RouterState { + seen: Rc>>>, + } + + impl SubscriptionHandlerLocal for RouterState { + async fn handle( + self, + request: Vec, + mut response: SubscriptionResponder, StreamWriter>, + ) { + let seen = self.seen.clone(); + seen.borrow_mut().push(request); + let _ = response.send(b"one".to_vec()).await; + let _ = response.send(b"two".to_vec()).await; + let _ = response.finish().await; + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + + let seen = Rc::new(RefCell::new(Vec::new())); + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .subscription::() + .build(RouterState { seen: seen.clone() }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await.unwrap(); + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + let mut subscription = rpc.subscribe::(&b"watch".to_vec()).await.unwrap(); + assert_eq!(subscription.next().await.unwrap().unwrap(), b"one".to_vec()); + assert_eq!(subscription.next().await.unwrap().unwrap(), b"two".to_vec()); + assert!(subscription.next().await.is_none()); + assert_eq!(seen.borrow().as_slice(), &[b"watch".to_vec()]); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_router_enforces_max_request_bytes() { + #[derive(Clone)] + struct LimitedState; + + impl RequestHandlerLocal for LimitedState { + async fn handle(self, request: String, response: Response) { + let _ = response.respond(request).await; + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .max_request_bytes(4) + .request::() + .build(LimitedState); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await.unwrap(); + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + let response = rpc.request::(&"hello".to_string()).await; + assert!(matches!( + response, + Err(RpcError::Closed(code)) if code == StreamCloseCode::LIMIT + )); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_progress() { + #[derive(Clone)] + struct RouterState { + seen: Rc>>>, + } + + impl ProgressHandlerLocal for RouterState { + async fn handle( + self, + request: Vec, + mut responder: ProgressResponder, + ) { + let seen = self.seen.clone(); + seen.borrow_mut().push(request); + responder.send(b"10".to_vec()).await.unwrap(); + responder.send(b"90".to_vec()).await.unwrap(); + responder.finish(b"done".to_vec()).await.unwrap(); + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let seen = Rc::new(RefCell::new(Vec::new())); + + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .progress::() + .build(RouterState { seen: seen.clone() }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await.unwrap(); + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + let mut download = rpc.progress::(&b"logo".to_vec()).await.unwrap(); + + assert_eq!(download.next().await, Some(b"10".to_vec())); + assert_eq!(download.next().await, Some(b"90".to_vec())); + assert_eq!(download.next().await, None); + assert_eq!(download.await.unwrap(), b"done".to_vec()); + assert_eq!(seen.borrow().as_slice(), &[b"logo".to_vec()]); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_download() { + #[derive(Clone)] + struct RouterState { + seen: Rc>>>, + } + + impl DownloadHandlerLocal for RouterState { + async fn handle( + self, + request: Vec, + download: DownloadStart, + ) { + let seen = self.seen.clone(); + seen.borrow_mut().push(request); + let mut writer = download.start(b"image/png".to_vec()).await.unwrap(); + let mut part = writer.start_part(b"icon".to_vec()).await.unwrap(); + part.send(Bytes::from_static(b"abc")).await.unwrap(); + part.send(Bytes::from_static(b"def")).await.unwrap(); + part.finish().await.unwrap(); + let mut part = writer.start_part(b"manifest".to_vec()).await.unwrap(); + part.send(Bytes::from_static(b"{}")).await.unwrap(); + part.finish().await.unwrap(); + writer.finish().await.unwrap(); + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let seen = Rc::new(RefCell::new(Vec::new())); + + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .download::() + .build(RouterState { seen: seen.clone() }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await.unwrap(); + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + let download = rpc + .download::(&b"logo".to_vec()) + .await + .unwrap(); + let (header, mut reader) = download.start().await.unwrap(); + assert_eq!(header, b"image/png".to_vec()); + { + let (part_header, mut part) = reader.next_part().await.unwrap().unwrap(); + assert_eq!(part_header, b"icon".to_vec()); + assert_eq!( + part.read_chunk().await.unwrap(), + Some(Bytes::from_static(b"abc")) + ); + assert_eq!( + part.read_chunk().await.unwrap(), + Some(Bytes::from_static(b"def")) + ); + assert_eq!(part.read_chunk().await.unwrap(), None); + } + { + let (part_header, mut part) = reader.next_part().await.unwrap().unwrap(); + assert_eq!(part_header, b"manifest".to_vec()); + assert_eq!( + part.read_chunk().await.unwrap(), + Some(Bytes::from_static(b"{}")) + ); + assert_eq!(part.read_chunk().await.unwrap(), None); + } + assert!(reader.next_part().await.unwrap().is_none()); + assert_eq!(seen.borrow().as_slice(), &[b"logo".to_vec()]); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_download_complete() { + #[derive(Clone)] + struct RouterState { + seen: Rc>>>, + } + + impl DownloadHandlerLocal for RouterState { + async fn handle( + self, + request: Vec, + download: DownloadStart, + ) { + self.seen.borrow_mut().push(request); + download.complete(b"not found".to_vec()).await.unwrap(); + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let seen = Rc::new(RefCell::new(Vec::new())); + + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .download::() + .build(RouterState { seen: seen.clone() }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await.unwrap(); + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + let download = rpc + .download::(&b"logo".to_vec()) + .await + .unwrap(); + let (header, reader) = download.start().await.unwrap(); + assert_eq!(header, b"not found".to_vec()); + reader.complete().await.unwrap(); + assert_eq!(seen.borrow().as_slice(), &[b"logo".to_vec()]); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_upload() { + #[derive(Clone)] + struct RouterState { + requests: Rc>>>, + uploads: Rc>>>, + } + + impl UploadHandlerLocal for RouterState { + async fn handle( + self, + request: Vec, + mut upload: UploadReader, + responder: UploadResponder, StreamWriter>, + ) { + let requests = self.requests.clone(); + let uploads = self.uploads.clone(); + requests.borrow_mut().push(request); + + let mut body = Vec::new(); + while let Some((part_header, mut part)) = upload.next_part().await.unwrap() { + body.extend_from_slice(&part_header); + body.push(b':'); + while let Some(chunk) = part.read_chunk().await.unwrap() { + body.extend_from_slice(&chunk); + } + body.push(b';'); + } + uploads.borrow_mut().push(body.clone()); + + responder.respond(body).await.unwrap(); + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let requests = Rc::new(RefCell::new(Vec::new())); + let uploads = Rc::new(RefCell::new(Vec::new())); + + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .upload::() + .build(RouterState { + requests: requests.clone(), + uploads: uploads.clone(), + }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await.unwrap(); + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + let mut upload = rpc.upload::(&b"logo".to_vec()).await.unwrap(); + let mut part = upload.start_part(b"icon".to_vec()).await.unwrap(); + part.send(Bytes::from_static(b"abc")).await.unwrap(); + part.send(Bytes::from_static(b"def")).await.unwrap(); + part.finish().await.unwrap(); + let mut part = upload.start_part(b"manifest".to_vec()).await.unwrap(); + part.send(Bytes::from_static(b"{}")).await.unwrap(); + part.finish().await.unwrap(); + let response = upload.finish().await.unwrap(); + + assert_eq!(response, b"icon:abcdef;manifest:{};".to_vec()); + assert_eq!(requests.borrow().as_slice(), &[b"logo".to_vec()]); + assert_eq!( + uploads.borrow().as_slice(), + &[b"icon:abcdef;manifest:{};".to_vec()] + ); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_duplex() { + #[derive(Clone)] + struct RouterState { + seen: Rc>>>, + } + + impl DuplexHandlerLocal for RouterState { + async fn handle(self, mut peer: DuplexPeer) { + let seen = self.seen.clone(); + let first = peer.receiver.next_event().await.unwrap().unwrap(); + seen.borrow_mut().push(first); + + peer.sender + .send(&b"challenge-response".to_vec()) + .await + .unwrap(); + + let second = peer.receiver.next_event().await.unwrap().unwrap(); + seen.borrow_mut().push(second); + + peer.sender.finish().await.unwrap(); + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let seen = Rc::new(RefCell::new(Vec::new())); + + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .duplex::() + .build(RouterState { seen: seen.clone() }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await.unwrap(); + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + let mut chat = rpc.duplex::().await.unwrap(); + chat.sender.send(&b"challenge".to_vec()).await.unwrap(); + assert_eq!( + chat.receiver.next_event().await.unwrap().unwrap(), + b"challenge-response".to_vec() + ); + chat.sender.send(&b"verification".to_vec()).await.unwrap(); + chat.sender.finish().await.unwrap(); + assert!(chat.receiver.next_event().await.is_none()); + + assert_eq!( + seen.borrow().as_slice(), + &[b"challenge".to_vec(), b"verification".to_vec()] + ); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} diff --git a/ql-runtime/src/tests/session.rs b/ql-runtime/src/tests/session.rs new file mode 100644 index 00000000..ec351e35 --- /dev/null +++ b/ql-runtime/src/tests/session.rs @@ -0,0 +1,213 @@ +use std::{ + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + time::Duration, +}; + +use bytes::Bytes; +use ql_wire::SessionCloseCode; + +use super::*; +use crate::QlStreamError; + +#[tokio::test(flavor = "current_thread")] +async fn close_session_aborts_active_streams_and_allows_reconnect() { + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + let inbound_b = pair.take_inbound(Side::B); + let (received_tx, received_rx) = async_channel::bounded(1); + pair.connect_and_wait(Side::A).await; + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let mut reader = stream.reader; + + assert_eq!( + next_chunk(&mut reader).await.unwrap(), + Some(vec![1, 2, 3, 4]) + ); + received_tx.send(()).await.unwrap(); + + let err = next_chunk(&mut reader).await.unwrap_err(); + assert_eq!(err, QlStreamError::NoSession); + }); + + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + stream + .writer + .write(Bytes::from_static(&[1, 2, 3, 4])) + .await + .unwrap(); + received_rx.recv().await.unwrap(); + + pair.side(Side::A) + .handle + .close_session(SessionCloseCode::CANCELLED); + + let err = stream.writer.finish().await.unwrap_err(); + assert_eq!(err, QlStreamError::NoSession); + + await_status( + &pair.side(Side::A).status, + Some(pair.side(Side::B).peer), + PeerStatus::Disconnected, + ) + .await; + await_status( + &pair.side(Side::B).status, + Some(pair.side(Side::A).peer), + PeerStatus::Disconnected, + ) + .await; + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + + pair.connect_and_wait(Side::A).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn unpair_aborts_active_streams_and_prevents_reconnect() { + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + let inbound_b = pair.take_inbound(Side::B); + let (received_tx, received_rx) = async_channel::bounded(1); + pair.connect_and_wait(Side::A).await; + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let mut reader = stream.reader; + + assert_eq!( + next_chunk(&mut reader).await.unwrap(), + Some(vec![5, 6, 7, 8]) + ); + received_tx.send(()).await.unwrap(); + + let err = next_chunk(&mut reader).await.unwrap_err(); + assert_eq!(err, QlStreamError::NoSession); + }); + + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + stream + .writer + .write(Bytes::from_static(&[5, 6, 7, 8])) + .await + .unwrap(); + received_rx.recv().await.unwrap(); + + pair.side(Side::A).handle.unpair(); + + let err = stream.writer.finish().await.unwrap_err(); + assert_eq!(err, QlStreamError::NoSession); + + await_status(&pair.side(Side::A).status, None, PeerStatus::Unpaired).await; + await_status(&pair.side(Side::B).status, None, PeerStatus::Unpaired).await; + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + + assert!(matches!( + pair.side(Side::A).handle.open_stream(test_route_id()).await, + Err(NoSessionError) + )); + assert!(matches!( + pair.side(Side::B).handle.open_stream(test_route_id()).await, + Err(NoSessionError) + )); + + pair.side(Side::B).handle.connect(); + assert_no_status_for( + &pair.side(Side::B).status, + None, + PeerStatus::Initiator, + Duration::from_millis(150), + ) + .await; + assert_no_status_for( + &pair.side(Side::B).status, + None, + PeerStatus::Connected, + Duration::from_millis(150), + ) + .await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn session_timeout_disconnects_and_fails_pending_open() { + run_local_test(async { + let config_a = RuntimeConfig { + fsm: QlFsmConfig { + session_keepalive_interval: Duration::from_millis(40), + session_peer_timeout: Duration::from_millis(60), + ..default_runtime_config().fsm + }, + ..default_runtime_config() + }; + let config_b = default_runtime_config(); + let (platform_a, outbound_a, inbound_a_tx, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, inbound_b_tx, status_b, inbound_b) = + TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config_b); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let drop_flag = Arc::new(AtomicBool::new(false)); + spawn_forwarder(outbound_a, inbound_b_tx); + spawn_gated_forwarder(outbound_b, inbound_a_tx, drop_flag.clone()); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect(); + + await_status(&status_a, Some(identity_b.qid), PeerStatus::Connected).await; + await_status(&status_b, Some(identity_a.qid), PeerStatus::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let _ = read_all(stream.reader).await; + let err = stream.writer.finish().await.unwrap_err(); + assert!(matches!(err, QlStreamError::NoSession)); + }); + + drop_flag.store(true, Ordering::Relaxed); + + let mut pending = handle_a.open_stream(test_route_id()).await.unwrap(); + let err = pending.writer.finish().await.unwrap_err(); + assert!(matches!(err, QlStreamError::NoSession)); + + await_status(&status_a, Some(identity_b.qid), PeerStatus::Disconnected).await; + + let result = + tokio::time::timeout(Duration::from_millis(300), next_chunk(&mut pending.reader)) + .await + .unwrap(); + assert!(matches!(result, Err(QlStreamError::NoSession))); + + responder_task.abort(); + }) + .await; +} diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs new file mode 100644 index 00000000..176711c8 --- /dev/null +++ b/ql-runtime/src/tests/stream.rs @@ -0,0 +1,673 @@ +use std::time::Duration; + +use bytes::Bytes; +use ql_wire::StreamCloseCode; + +use super::*; +use crate::QlStreamError; + +#[tokio::test(flavor = "current_thread")] +async fn open_stream_duplex_happy_path() { + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + + let mut writer = inbound.writer; + let mut reader = inbound.reader; + + assert_eq!(next_chunk(&mut reader).await.unwrap(), Some(vec![1, 2])); + writer.write(Bytes::from_static(&[9])).await.unwrap(); + assert_eq!(next_chunk(&mut reader).await.unwrap(), Some(vec![3, 4])); + writer.write(Bytes::from_static(&[8, 7])).await.unwrap(); + assert_eq!(next_chunk(&mut reader).await.unwrap(), None); + writer.finish().await.unwrap(); + }); + + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + stream + .writer + .write(Bytes::from_static(&[1, 2])) + .await + .unwrap(); + assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), Some(vec![9])); + stream + .writer + .write(Bytes::from_static(&[3, 4])) + .await + .unwrap(); + stream.writer.finish().await.unwrap(); + assert_eq!( + next_chunk(&mut stream.reader).await.unwrap(), + Some(vec![8, 7]) + ); + assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn reader_respects_max_len() { + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + let mut reader = inbound.reader; + + assert_eq!( + next_chunk_max(&mut reader, 2).await.unwrap(), + Some(vec![1, 2]) + ); + assert_eq!( + next_chunk_max(&mut reader, 2).await.unwrap(), + Some(vec![3, 4]) + ); + assert_eq!( + next_chunk_max(&mut reader, 2).await.unwrap(), + Some(vec![5, 6]) + ); + assert_eq!(next_chunk(&mut reader).await.unwrap(), None); + + inbound.writer.finish().await.unwrap(); + }); + + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + stream + .writer + .write(Bytes::from_static(&[1, 2, 3, 4, 5, 6])) + .await + .unwrap(); + stream.writer.finish().await.unwrap(); + assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn large_stream_payload_round_trips() { + run_local_test(async { + let payload: Vec = (0..40).collect(); + let mut pair = TestPair::new(default_runtime_config()); + let (done_tx, done_rx) = async_channel::bounded(1); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let request_data = read_all(stream.reader).await.unwrap(); + stream.writer.finish().await.unwrap(); + done_tx.send(request_data).await.unwrap(); + }); + + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + stream + .writer + .write(Bytes::from(payload.clone())) + .await + .unwrap(); + stream.writer.finish().await.unwrap(); + assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); + + let received = tokio::time::timeout(Duration::from_secs(2), done_rx.recv()) + .await + .unwrap() + .unwrap(); + assert_eq!(received, payload); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn dropping_responder_closes_initiator_response() { + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + drop(stream.reader); + }); + + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + let err = stream.writer.finish().await.unwrap_err(); + assert!(matches!( + err, + QlStreamError::StreamClosed { code } if code == StreamCloseCode::CANCELLED + )); + + let err = next_chunk(&mut stream.reader).await.unwrap_err(); + assert!(matches!( + err, + QlStreamError::StreamClosed { code } if code == StreamCloseCode::CANCELLED + )); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn dropping_inbound_reader_cancels_remote_writer() { + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + let inbound_b = pair.take_inbound(Side::B); + let (go_tx, go_rx) = async_channel::bounded(1); + pair.connect_and_wait(Side::A).await; + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let mut writer = stream.writer; + let mut reader = stream.reader; + assert_eq!(next_chunk(&mut reader).await.unwrap(), None); + writer + .write(Bytes::from_static(&[1, 2, 3, 4])) + .await + .unwrap(); + go_rx.recv().await.unwrap(); + let _ = writer.write(Bytes::from(vec![5; 64])).await; + let err = writer.finish().await.unwrap_err(); + assert!(matches!( + err, + QlStreamError::StreamClosed { code } if code == StreamCloseCode::CANCELLED + )); + }); + + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + stream.writer.finish().await.unwrap(); + assert_eq!( + next_chunk(&mut stream.reader).await.unwrap(), + Some(vec![1, 2, 3, 4]) + ); + drop(stream.reader); + go_tx.send(()).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn closing_initiator_reader_preserves_initiator_writer() { + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let (done_tx, done_rx) = async_channel::bounded(1); + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let request = read_all(stream.reader).await.unwrap(); + done_tx.send(request).await.unwrap(); + }); + + let stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + let mut writer = stream.writer; + stream.reader.close(StreamCloseCode::CANCELLED); + + writer.write(Bytes::from_static(&[1, 2])).await.unwrap(); + writer.write(Bytes::from_static(&[3, 4])).await.unwrap(); + writer.finish().await.unwrap(); + + let request = tokio::time::timeout(Duration::from_secs(2), done_rx.recv()) + .await + .unwrap() + .unwrap(); + assert_eq!(request, vec![1, 2, 3, 4]); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn max_concurrent_message_writes_is_respected() { + run_local_test(async { + let stats = WriteStats::new(); + let config = RuntimeConfig { + max_concurrent_message_writes: 2, + ..default_runtime_config() + }; + let (platform_a, outbound_a, inbound_a_tx, status_a) = + TestPlatform::new_with_delayed_writes(Duration::from_millis(40), stats.clone()); + let (platform_b, outbound_b, inbound_b_tx, status_b, inbound_b) = + TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, inbound_b_tx); + spawn_forwarder(outbound_b, inbound_a_tx); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect(); + + await_status(&status_a, Some(identity_b.qid), PeerStatus::Connected).await; + await_status(&status_b, Some(identity_a.qid), PeerStatus::Connected).await; + + let responder = tokio::task::spawn_local(async move { + for _ in 0..4 { + let stream = inbound_b.recv().await.unwrap(); + let _ = read_all(stream.reader).await; + let mut writer = stream.writer; + writer.queue_finish(); + } + }); + + let mut tasks = Vec::new(); + for i in 0..4u8 { + let handle = handle_a.clone(); + tasks.push(tokio::task::spawn_local(async move { + let mut stream = handle.open_stream(test_route_id()).await.unwrap(); + stream.writer.write(Bytes::from(vec![i; 8])).await.unwrap(); + stream.writer.finish().await.unwrap(); + assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); + })); + } + + for task in tasks { + tokio::time::timeout(Duration::from_secs(4), task) + .await + .unwrap() + .unwrap(); + } + + tokio::time::timeout(Duration::from_secs(4), responder) + .await + .unwrap() + .unwrap(); + + assert!( + stats.max_active() <= 2, + "max active writes exceeded: {}", + stats.max_active() + ); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn stream_round_trip_survives_encrypted_packet_drops() { + run_local_test(async { + let config = RuntimeConfig { + fsm: QlFsmConfig { + session_record_retransmit_timeout: Duration::from_millis(20), + ..default_runtime_config().fsm + }, + ..default_runtime_config() + }; + let (platform_a, outbound_a, inbound_a_tx, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, inbound_b_tx, status_b, inbound_b) = + TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + + let request_payload: Vec = (0..32).collect(); + let response_payload: Vec = (100..132).collect(); + let expected_response = response_payload.clone(); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_drop_every_nth_encrypted_forwarder(outbound_a, inbound_b_tx, 3); + spawn_drop_every_nth_encrypted_forwarder(outbound_b, inbound_a_tx, 3); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect(); + + await_status(&status_a, Some(identity_b.qid), PeerStatus::Connected).await; + await_status(&status_b, Some(identity_a.qid), PeerStatus::Connected).await; + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let received_request = read_all(stream.reader).await.unwrap(); + let mut writer = stream.writer; + writer + .write(Bytes::from(response_payload.clone())) + .await + .unwrap(); + writer.finish().await.unwrap(); + received_request + }); + + let mut stream = handle_a.open_stream(test_route_id()).await.unwrap(); + stream + .writer + .write(Bytes::from(request_payload.clone())) + .await + .unwrap(); + stream.writer.finish().await.unwrap(); + + let mut received_response = Vec::new(); + while let Some(chunk) = next_chunk(&mut stream.reader).await.unwrap() { + received_response.extend_from_slice(&chunk); + } + assert_eq!(received_response, expected_response); + + let received_request = tokio::time::timeout(Duration::from_secs(4), responder) + .await + .unwrap() + .unwrap(); + assert_eq!(received_request, request_payload); + }) + .await; +} + +#[allow(clippy::too_many_lines)] +#[tokio::test(flavor = "current_thread")] +async fn multi_megabyte_stream_survives_asymmetric_loss_and_delay() { + run_local_test_timeout(Duration::from_secs(10), async { + let payload_len = 2 * 1024 * 1024; + let chunk_len = 16 * 1024; + let payload: Vec = (0..payload_len) + .map(|i| u8::try_from(i % 251).unwrap()) + .collect(); + let expected = payload.clone(); + let config = RuntimeConfig { + fsm: QlFsmConfig { + session_record_max_size: 16 * 1024, + session_record_ack_delay: Duration::from_millis(2), + session_record_retransmit_timeout: Duration::from_millis(25), + session_stream_send_buffer_size: 4 * 1024 * 1024, + session_stream_receive_buffer_size: 4 * 1024 * 1024, + session_accepted_record_window: 16 * 1024, + session_pending_ack_range_limit: 4 * 1024, + ..default_runtime_config().fsm + }, + ..default_runtime_config() + }; + let (mut pair, links) = TestPair::new_with_controlled_links( + config, + LinkBehavior { + base_delay: Duration::from_millis(1), + drop_encrypted_every: Some(41), + delay_encrypted_every: Some((13, Duration::from_millis(12))), + ..LinkBehavior::default() + }, + LinkBehavior { + base_delay: Duration::from_millis(1), + ..LinkBehavior::default() + }, + ); + pair.connect_and_wait(Side::A).await; + links.b_to_a.store(LinkBehavior { + base_delay: Duration::from_millis(3), + drop_encrypted_every: Some(7), + duplicate_encrypted_every: Some(19), + delay_encrypted_every: Some((3, Duration::from_millis(25))), + }); + let inbound_b = pair.take_inbound(Side::B); + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + eprintln!("responder accepted inbound stream"); + let mut reader = stream.reader; + let mut received = Vec::new(); + while let Some(chunk) = next_chunk(&mut reader).await.unwrap() { + if received.len() >= 36 * chunk_len { + eprintln!("responder received chunk of {} bytes", chunk.len()); + } + received.extend_from_slice(&chunk); + if received.len() % (256 * 1024) == 0 { + eprintln!("responder received {} bytes", received.len()); + } + } + stream.writer.finish().await.unwrap(); + received + }); + + let recovery_links = links.clone(); + let recovery = tokio::task::spawn_local(async move { + tokio::time::sleep(Duration::from_millis(300)).await; + eprintln!("restoring reverse path"); + recovery_links.b_to_a.store(LinkBehavior { + base_delay: Duration::from_millis(1), + delay_encrypted_every: Some((17, Duration::from_millis(8))), + ..LinkBehavior::default() + }); + }); + + let writer = tokio::task::spawn_local(async move { + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + for (index, chunk) in payload.chunks(chunk_len).enumerate() { + if index + 1 >= 40 { + eprintln!("writer attempting chunk {}", index + 1); + } + stream + .writer + .write(Bytes::copy_from_slice(chunk)) + .await + .unwrap(); + if index + 1 >= 40 { + eprintln!("writer queued chunk {}", index + 1); + } + if index % 16 == 15 { + eprintln!("writer queued {} chunks", index + 1); + } + } + eprintln!("writer finished queueing"); + stream.writer.finish().await.unwrap(); + eprintln!("writer waiting for eof"); + assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); + eprintln!("writer observed eof"); + }); + + tokio::time::timeout(Duration::from_secs(30), writer) + .await + .unwrap() + .unwrap(); + tokio::time::timeout(Duration::from_secs(2), recovery) + .await + .unwrap() + .unwrap(); + let received = tokio::time::timeout(Duration::from_secs(30), responder) + .await + .unwrap() + .unwrap(); + assert_eq!(received, expected); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn reproducer_writer_stalls_after_reverse_path_impairment() { + run_local_test_timeout(Duration::from_secs(10), async { + let payload_len = 2 * 1024 * 1024; + let chunk_len = 16 * 1024; + let payload: Vec = (0..payload_len) + .map(|i| u8::try_from(i % 251).unwrap()) + .collect(); + let config = RuntimeConfig { + fsm: QlFsmConfig { + session_record_max_size: 16 * 1024, + session_record_ack_delay: Duration::from_millis(2), + session_record_retransmit_timeout: Duration::from_millis(25), + session_stream_send_buffer_size: 4 * 1024 * 1024, + session_stream_receive_buffer_size: 4 * 1024 * 1024, + session_accepted_record_window: 16 * 1024, + session_pending_ack_range_limit: 4 * 1024, + ..default_runtime_config().fsm + }, + ..default_runtime_config() + }; + let (mut pair, links) = TestPair::new_with_controlled_links( + config, + LinkBehavior { + base_delay: Duration::from_millis(1), + drop_encrypted_every: Some(41), + delay_encrypted_every: Some((13, Duration::from_millis(12))), + ..LinkBehavior::default() + }, + LinkBehavior { + base_delay: Duration::from_millis(1), + ..LinkBehavior::default() + }, + ); + pair.connect_and_wait(Side::A).await; + links.b_to_a.store(LinkBehavior { + base_delay: Duration::from_millis(3), + drop_encrypted_every: Some(7), + duplicate_encrypted_every: Some(19), + delay_encrypted_every: Some((3, Duration::from_millis(25))), + }); + let inbound_b = pair.take_inbound(Side::B); + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let mut reader = stream.reader; + while next_chunk(&mut reader).await.unwrap().is_some() {} + }); + + let recovery_links = links.clone(); + let recovery = tokio::task::spawn_local(async move { + tokio::time::sleep(Duration::from_millis(300)).await; + recovery_links.b_to_a.store(LinkBehavior { + base_delay: Duration::from_millis(1), + delay_encrypted_every: Some((17, Duration::from_millis(8))), + ..LinkBehavior::default() + }); + }); + + let writer = tokio::task::spawn_local(async move { + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + for chunk in payload.chunks(chunk_len) { + stream + .writer + .write(Bytes::copy_from_slice(chunk)) + .await + .unwrap(); + } + stream.writer.queue_finish(); + let _ = next_chunk(&mut stream.reader).await; + }); + + let _ = tokio::time::timeout(Duration::from_secs(15), writer).await; + recovery.abort(); + responder.abort(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn responder_drains_multiple_local_chunks_per_writable_wake() { + run_local_test(async { + let chunk_len = 4104usize; + let chunk_count = 5usize; + let expected = vec![0x5a; chunk_len * chunk_count]; + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + let _ = read_all(inbound.reader).await.unwrap(); + + let mut writer = inbound.writer; + for _ in 0..chunk_count { + writer + .write(Bytes::from(vec![0x5a; chunk_len])) + .await + .unwrap(); + } + writer.finish().await.unwrap(); + }); + + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + stream + .writer + .write(Bytes::from_static(b"request")) + .await + .unwrap(); + stream.writer.finish().await.unwrap(); + + let received = read_all(stream.reader).await.unwrap(); + assert_eq!(received, expected); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} diff --git a/ql-wire/Cargo.toml b/ql-wire/Cargo.toml new file mode 100644 index 00000000..5fccc95e --- /dev/null +++ b/ql-wire/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "ql-wire" +version = "0.1.0" +edition = "2021" +description = "Quantum Link wire format types and crypto helpers" +license = "Proprietary" + +[features] +test-utils = [ + "dep:getrandom", + "dep:libcrux-aesgcm", + "dep:libcrux-ml-kem", + "dep:sha2", +] + +[dependencies] +bytes = "1" +getrandom = { workspace = true, optional = true } +libcrux-aesgcm = { version = "0.0.7", optional = true } +libcrux-ml-kem = { version = "0.0.7", optional = true } +sha2 = { version = "0.10", optional = true } + +[dev-dependencies] +getrandom = { workspace = true } +libcrux-aesgcm = "0.0.7" +libcrux-ml-kem = "0.0.7" +sha2 = "0.10" diff --git a/ql-wire/src/bytes.rs b/ql-wire/src/bytes.rs new file mode 100644 index 00000000..9fecf5ea --- /dev/null +++ b/ql-wire/src/bytes.rs @@ -0,0 +1,175 @@ +use core::ops::{Deref, DerefMut}; + +use bytes::{Buf, Bytes}; + +/// A mutable or immutable byte slice owner used by the wire parser. +pub trait ByteSlice: Deref + Sized { + /// Splits the current byte view at `mid`. + /// + /// Returns `Err(self)` when `mid` is out of bounds. + fn split_at(self, mid: usize) -> Result<(Self, Self), Self>; +} + +/// A mutable reference to bytes. +pub trait ByteSliceMut: ByteSlice + DerefMut {} + +impl ByteSliceMut for B where B: ByteSlice + DerefMut {} + +impl ByteSlice for &[u8] { + #[inline] + fn split_at(self, mid: usize) -> Result<(Self, Self), Self> { + if mid <= self.len() { + Ok(<[u8]>::split_at(self, mid)) + } else { + Err(self) + } + } +} + +impl ByteSlice for &mut [u8] { + #[inline] + fn split_at(self, mid: usize) -> Result<(Self, Self), Self> { + if mid <= self.len() { + Ok(<[u8]>::split_at_mut(self, mid)) + } else { + Err(self) + } + } +} + +impl ByteSlice for Bytes { + #[inline] + fn split_at(self, mid: usize) -> Result<(Self, Self), Self> { + if mid <= self.len() { + Ok((self.slice(..mid), self.slice(mid..))) + } else { + Err(self) + } + } +} + +/// A byte container that can expose a replayable [`Buf`] view for encoding. +pub trait BufView { + type Buf<'a>: Buf + where + Self: 'a; + + fn buf(&self) -> Self::Buf<'_>; + + fn is_empty(&self) -> bool { + self.buf().remaining() == 0 + } +} + +impl BufView for &T { + type Buf<'a> + = T::Buf<'a> + where + Self: 'a; + + fn buf(&self) -> Self::Buf<'_> { + (*self).buf() + } +} + +impl BufView for &mut T { + type Buf<'a> + = T::Buf<'a> + where + Self: 'a; + + fn buf(&self) -> Self::Buf<'_> { + (**self).buf() + } +} + +impl BufView for [u8] { + type Buf<'a> + = &'a [u8] + where + Self: 'a; + + fn buf(&self) -> Self::Buf<'_> { + self + } +} + +impl BufView for [u8; N] { + type Buf<'a> + = &'a [u8] + where + Self: 'a; + + fn buf(&self) -> Self::Buf<'_> { + self.as_slice() + } +} + +impl BufView for Vec { + type Buf<'a> + = &'a [u8] + where + Self: 'a; + + fn buf(&self) -> Self::Buf<'_> { + self.as_slice() + } +} + +impl BufView for Bytes { + type Buf<'a> + = &'a [u8] + where + Self: 'a; + + fn buf(&self) -> Self::Buf<'_> { + self.as_ref() + } +} + +#[cfg(test)] +mod tests { + use bytes::Buf; + + use super::{BufView, ByteSlice, ByteSliceMut}; + + #[test] + fn shared_slice_split_at() { + let bytes: &[u8] = b"abcdef"; + let (left, right) = ByteSlice::split_at(bytes, 2).unwrap(); + assert_eq!(left, b"ab"); + assert_eq!(right, b"cdef"); + } + + #[test] + fn mutable_slice_split_at() { + let mut bytes = *b"abcdef"; + let (left, right) = ByteSlice::split_at(&mut bytes[..], 2).unwrap(); + assert_eq!(left, b"ab"); + assert_eq!(right, b"cdef"); + } + + #[test] + fn mutable_split_trait_is_implemented() { + fn assert_split_mut(_value: T) {} + + let mut bytes = [0u8; 4]; + assert_split_mut(&mut bytes[..]); + } + + #[test] + fn split_at_rejects_out_of_bounds_index() { + let bytes: &[u8] = b"abcdef"; + assert!(ByteSlice::split_at(bytes, 7).is_err()); + } + + #[test] + fn slice_buf_view_is_contiguous() { + let bytes: &[u8] = b"abcdef"; + let mut buf = bytes.buf(); + assert_eq!(buf.remaining(), 6); + assert_eq!(buf.chunk(), b"abcdef"); + buf.advance(6); + assert!(!buf.has_remaining()); + } +} diff --git a/ql-wire/src/codec.rs b/ql-wire/src/codec.rs new file mode 100644 index 00000000..0245ef6d --- /dev/null +++ b/ql-wire/src/codec.rs @@ -0,0 +1,245 @@ +use bytes::BufMut; + +use crate::{ByteSlice, WireError}; + +pub trait WireEncode { + fn encoded_len(&self) -> usize; + + fn encode(&self, out: &mut W); + + fn encode_vec(&self) -> Vec { + let mut out = Vec::with_capacity(self.encoded_len()); + self.encode(&mut out); + debug_assert_eq!(out.len(), self.encoded_len()); + out + } +} + +pub trait WireDecode: Sized { + fn decode(reader: &mut Reader) -> Result; + + fn decode_bytes(bytes: B) -> Result { + let mut reader = Reader::new(bytes); + Self::decode(&mut reader) + } + + fn decode_exact(bytes: B) -> Result { + let mut reader = Reader::new(bytes); + let value = Self::decode(&mut reader)?; + if reader.is_empty() { + Ok(value) + } else { + Err(WireError::InvalidPayload) + } + } +} + +impl WireDecode for [u8; N] { + fn decode(reader: &mut Reader) -> Result { + let bytes = reader.take_bytes(N)?; + let mut out = [0u8; N]; + out.copy_from_slice(&bytes); + Ok(out) + } +} + +impl WireEncode for [u8; N] { + fn encoded_len(&self) -> usize { + N + } + + fn encode(&self, out: &mut W) { + out.put_slice(self); + } +} + +impl WireDecode for Box<[u8; N]> { + fn decode(reader: &mut Reader) -> Result { + let bytes = reader.take_bytes(N)?; + let mut out = Self::new_uninit(); + let src = bytes.as_ptr(); + let dst = out.as_mut_ptr().cast::(); + // SAFETY: `take_bytes(N)` guarantees the source has exactly `N` bytes. + unsafe { + std::ptr::copy_nonoverlapping(src, dst, N); + Ok(out.assume_init()) + } + } +} + +impl WireEncode for Box<[u8; N]> { + fn encoded_len(&self) -> usize { + N + } + + fn encode(&self, out: &mut W) { + out.put_slice(self.as_ref()); + } +} + +impl WireEncode for [u8] { + fn encoded_len(&self) -> usize { + self.len() + } + + fn encode(&self, out: &mut W) { + out.put_slice(self); + } +} + +impl WireDecode for u8 { + fn decode(reader: &mut Reader) -> Result { + Ok(reader.take_bytes(1)?[0]) + } +} + +impl WireEncode for u8 { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + out.put_u8(*self); + } +} + +impl WireDecode for u16 { + fn decode(reader: &mut Reader) -> Result { + Ok(Self::from_be_bytes(reader.decode()?)) + } +} + +impl WireEncode for u16 { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + out.put_u16(*self); + } +} + +impl WireDecode for u32 { + fn decode(reader: &mut Reader) -> Result { + Ok(Self::from_be_bytes(reader.decode()?)) + } +} + +impl WireEncode for u32 { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + out.put_u32(*self); + } +} + +impl WireDecode for u64 { + fn decode(reader: &mut Reader) -> Result { + Ok(Self::from_be_bytes(reader.decode()?)) + } +} + +impl WireEncode for u64 { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + out.put_u64(*self); + } +} + +impl WireDecode for bool { + fn decode(reader: &mut Reader) -> Result { + match reader.decode::()? { + 0 => Ok(false), + 1 => Ok(true), + _ => Err(WireError::InvalidPayload), + } + } +} + +impl WireEncode for bool { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + out.put_u8(u8::from(*self)); + } +} + +impl WireEncode for Option { + fn encoded_len(&self) -> usize { + 1 + self.as_ref().map_or(0, WireEncode::encoded_len) + } + + fn encode(&self, out: &mut W) { + match self { + None => out.put_u8(0), + Some(inner) => { + out.put_u8(1); + inner.encode(out); + } + } + } +} + +impl> WireDecode for Option { + fn decode(reader: &mut Reader) -> Result { + match reader.decode::()? { + 0 => Ok(None), + 1 => Ok(Some(reader.decode::()?)), + _ => Err(WireError::InvalidPayload), + } + } +} + +#[derive(Clone)] +pub struct Reader { + remaining: Option, +} + +impl Reader { + pub fn new(bytes: B) -> Self { + Self { + remaining: Some(bytes), + } + } + + pub fn is_empty(&self) -> bool { + self.remaining.as_ref().unwrap().is_empty() + } + + pub fn remaining_len(&self) -> usize { + self.remaining.as_ref().unwrap().len() + } + + pub fn take_bytes(&mut self, len: usize) -> Result { + let remaining = self.remaining.take().unwrap(); + match remaining.split_at(len) { + Ok((head, tail)) => { + self.remaining = Some(tail); + Ok(head) + } + Err(remaining) => { + self.remaining = Some(remaining); + Err(WireError::InvalidPayload) + } + } + } + + pub fn take_rest(&mut self) -> B { + self.take_bytes(self.remaining_len()).unwrap() + } + + #[inline] + pub fn decode(&mut self) -> Result + where + T: WireDecode, + { + T::decode(self) + } +} diff --git a/ql-wire/src/crypto.rs b/ql-wire/src/crypto.rs new file mode 100644 index 00000000..96ace383 --- /dev/null +++ b/ql-wire/src/crypto.rs @@ -0,0 +1,47 @@ +use crate::{ + MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, SessionKey, + ENCRYPTED_MESSAGE_AUTH_SIZE, +}; + +pub trait QlRandom { + fn fill_random_bytes(&self, out: &mut [u8]); +} + +pub trait QlHash { + fn sha256(&self, parts: &[&[u8]]) -> [u8; 32]; +} + +pub trait QlAead { + fn aes256_gcm_encrypt( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + ) -> [u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; + + fn aes256_gcm_decrypt( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + auth_tag: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], + ) -> bool; +} + +pub trait QlKem { + fn mlkem_generate_keypair(&self) -> MlKemKeyPair; + + fn mlkem_encapsulate(&self, public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey); + + fn mlkem_decapsulate( + &self, + private_key: &MlKemPrivateKey, + ciphertext: &MlKemCiphertext, + ) -> SessionKey; +} + +pub trait QlCrypto: QlRandom + QlHash + QlAead + QlKem {} + +impl QlCrypto for T where T: QlRandom + QlHash + QlAead + QlKem {} diff --git a/ql-wire/src/encrypted/ack.rs b/ql-wire/src/encrypted/ack.rs new file mode 100644 index 00000000..2eb34b33 --- /dev/null +++ b/ql-wire/src/encrypted/ack.rs @@ -0,0 +1,453 @@ +use std::{fmt, ops::RangeInclusive}; + +use crate::{codec, ByteSlice, RecordSeq, VarInt, WireEncode, WireError}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RecordAck { + largest_acked: RecordSeq, + first_range_len: VarInt, + blocks: Box<[RecordAckBlock]>, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RecordAckBlock { + pub gap: VarInt, + pub range_len: VarInt, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RecordAckRangeError { + Empty, + InvertedRange, + NotCanonical, +} + +impl RecordAck { + /// Build a record ACK from canonical ranges ordered from highest to lowest sequence number. + /// + /// Ranges must be: + /// - non-empty + /// - individually valid (`start <= end`) + /// - strictly descending + /// - separated by at least one missing sequence number + pub fn from_ranges(ranges: I) -> Result + where + I: IntoIterator>, + { + let mut builder = RecordAckBuilder::new(); + for range in ranges { + let pushed = builder.try_push_range(range, usize::MAX)?; + if !pushed { + unreachable!("record ack should fit inside usize::MAX"); + } + } + builder.build() + } + + pub fn ranges(&self) -> RecordAckRangeIter<'_> { + RecordAckRangeIter { + largest_acked: self.largest_acked.into_inner(), + first_range_len: Some(self.first_range_len), + previous_start: None, + blocks: self.blocks.iter(), + } + } + + pub fn contains(&self, seq: u64) -> bool { + let Ok(seq) = RecordSeq::from_u64(seq) else { + return false; + }; + self.ranges().any(|range| range.contains(&seq)) + } + + fn block_count_len(block_count: usize) -> usize { + VarInt::try_from(block_count).unwrap().encoded_len() + } +} + +impl RecordAckBlock { + fn encoded_len(&self) -> usize { + self.gap.encoded_len() + self.range_len.encoded_len() + } +} + +impl fmt::Display for RecordAckRangeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Empty => f.write_str("record ack requires at least one acknowledged range"), + Self::InvertedRange => { + f.write_str("record ack range start must be less than or equal to end") + } + Self::NotCanonical => f.write_str( + "record ack ranges must be passed in descending, disjoint order with a gap between adjacent ranges", + ), + } + } +} + +impl std::error::Error for RecordAckRangeError {} + +pub struct RecordAckRangeIter<'a> { + largest_acked: u64, + first_range_len: Option, + previous_start: Option, + blocks: std::slice::Iter<'a, RecordAckBlock>, +} + +impl Iterator for RecordAckRangeIter<'_> { + type Item = RangeInclusive; + + fn next(&mut self) -> Option { + if let Some(first_range_len) = self.first_range_len.take() { + let end = self.largest_acked; + let start = end - first_range_len.into_inner(); + self.previous_start = Some(start); + return Some(RecordSeq::from_u64(start).unwrap()..=RecordSeq::from_u64(end).unwrap()); + } + + let block = self.blocks.next()?; + let previous_start = self + .previous_start + .expect("first ack range is always yielded"); + // gap is encoded as missing_count - 1, so decoding steps back by gap + 2. + let end = previous_start - block.gap.into_inner() - 2; + let start = end - block.range_len.into_inner(); + self.previous_start = Some(start); + Some(RecordSeq::from_u64(start).unwrap()..=RecordSeq::from_u64(end).unwrap()) + } +} + +impl WireEncode for RecordAck { + fn encoded_len(&self) -> usize { + self.largest_acked.encoded_len() + + Self::block_count_len(self.blocks.len()) + + self.first_range_len.encoded_len() + + self + .blocks + .iter() + .map(RecordAckBlock::encoded_len) + .sum::() + } + + fn encode(&self, out: &mut W) { + self.largest_acked.encode(out); + VarInt::try_from(self.blocks.len()).unwrap().encode(out); + self.first_range_len.encode(out); + for block in &self.blocks { + block.gap.encode(out); + block.range_len.encode(out); + } + } +} + +impl codec::WireDecode for RecordAck { + fn decode(reader: &mut codec::Reader) -> Result { + let largest_acked = reader.decode()?; + let block_count = usize::try_from(reader.decode::()?.into_inner()) + .map_err(|_| WireError::InvalidPayload)?; + let first_range_len = reader.decode::()?; + let mut blocks = Vec::with_capacity(block_count); + for _ in 0..block_count { + blocks.push(RecordAckBlock { + gap: reader.decode::()?, + range_len: reader.decode::()?, + }); + } + + let ack = Self { + largest_acked, + first_range_len, + blocks: blocks.into_boxed_slice(), + }; + + // validate + { + let mut previous_start = ack + .largest_acked + .into_inner() + .checked_sub(ack.first_range_len.into_inner()) + .ok_or(WireError::InvalidPayload)?; + + for block in &ack.blocks { + let end = previous_start + .checked_sub( + block + .gap + .into_inner() + .checked_add(2) + .ok_or(WireError::InvalidPayload)?, + ) + .ok_or(WireError::InvalidPayload)?; + previous_start = end + .checked_sub(block.range_len.into_inner()) + .ok_or(WireError::InvalidPayload)?; + } + } + Ok(ack) + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct RecordAckBuilder { + largest_acked: Option, + first_range_len: Option, + blocks: Vec, + previous_start: Option, + wire_len: usize, +} + +impl RecordAckBuilder { + pub fn new() -> Self { + Self::default() + } + + pub fn try_push_range( + &mut self, + range: RangeInclusive, + max_wire_size: usize, + ) -> Result { + let start = range.start().into_inner(); + let end = range.end().into_inner(); + if start > end { + return Err(RecordAckRangeError::InvertedRange); + } + + let range_len = VarInt::from_u64(end - start).unwrap(); + if let Some(previous_start) = self.previous_start { + if end.saturating_add(1) >= previous_start { + return Err(RecordAckRangeError::NotCanonical); + } + + let gap = previous_start + .checked_sub(end) + .and_then(|delta| delta.checked_sub(2)) + .expect("canonical ack ranges stay separated by at least one sequence"); + let block = RecordAckBlock { + gap: VarInt::from_u64(gap).unwrap(), + range_len, + }; + let current_block_count_len = RecordAck::block_count_len(self.blocks.len()); + let next_block_count_len = RecordAck::block_count_len(self.blocks.len() + 1); + let next_wire_len = self.wire_len + + (next_block_count_len - current_block_count_len) + + block.encoded_len(); + if next_wire_len > max_wire_size { + return Ok(false); + } + + self.previous_start = Some(start); + self.wire_len = next_wire_len; + self.blocks.push(block); + return Ok(true); + } + + let largest_acked = RecordSeq::from_u64(end).unwrap(); + let wire_len = + largest_acked.encoded_len() + RecordAck::block_count_len(0) + range_len.encoded_len(); + if wire_len > max_wire_size { + return Ok(false); + } + + self.largest_acked = Some(largest_acked); + self.first_range_len = Some(range_len); + self.previous_start = Some(start); + self.wire_len = wire_len; + Ok(true) + } + + pub fn build(self) -> Result { + let Some(largest_acked) = self.largest_acked else { + return Err(RecordAckRangeError::Empty); + }; + + Ok(RecordAck { + largest_acked, + first_range_len: self.first_range_len.unwrap(), + blocks: self.blocks.into_boxed_slice(), + }) + } +} +#[cfg(test)] +mod tests { + use std::ops::RangeInclusive; + + use super::{RecordAck, RecordAckBlock, RecordAckBuilder, RecordAckRangeError}; + use crate::{RecordSeq, VarInt, WireDecode, WireEncode, WireError}; + + fn seq(value: u64) -> RecordSeq { + RecordSeq::from_u64(value).unwrap() + } + + fn ack_range(start: u64, end: u64) -> RangeInclusive { + seq(start)..=seq(end) + } + + fn varint(value: u64) -> VarInt { + VarInt::from_u64(value).unwrap() + } + + #[test] + fn encode_decode_round_trip() { + let ack = + RecordAck::from_ranges([ack_range(95, 100), ack_range(90, 92), ack_range(80, 80)]) + .unwrap(); + let encoded = ack.encode_vec(); + + assert_eq!(RecordAck::decode_exact(encoded.as_slice()).unwrap(), ack); + } + + #[test] + fn wire_fields_match_gap_encoding() { + let ack = + RecordAck::from_ranges([ack_range(95, 100), ack_range(90, 92), ack_range(80, 80)]) + .unwrap(); + + assert_eq!(ack.largest_acked, seq(100)); + assert_eq!(ack.first_range_len, varint(5)); + assert_eq!( + ack.blocks.as_ref(), + &[ + RecordAckBlock { + gap: varint(1), + range_len: varint(2), + }, + RecordAckBlock { + gap: varint(8), + range_len: varint(0), + } + ] + ); + } + + #[test] + fn builder_matches_from_ranges() { + let mut builder = RecordAckBuilder::new(); + assert!(builder + .try_push_range(ack_range(95, 100), usize::MAX) + .unwrap()); + assert!(builder + .try_push_range(ack_range(90, 92), usize::MAX) + .unwrap()); + assert!(builder + .try_push_range(ack_range(80, 80), usize::MAX) + .unwrap()); + + assert_eq!( + builder.build().unwrap(), + RecordAck::from_ranges([ack_range(95, 100), ack_range(90, 92), ack_range(80, 80)]) + .unwrap() + ); + } + + #[test] + fn builder_stops_when_budget_is_exhausted() { + let first_only = RecordAck::from_ranges([ack_range(95, 100)]).unwrap(); + let mut builder = RecordAckBuilder::new(); + + assert!(builder + .try_push_range(ack_range(95, 100), first_only.encoded_len()) + .unwrap()); + assert!(!builder + .try_push_range(ack_range(90, 92), first_only.encoded_len()) + .unwrap()); + assert_eq!(builder.build().unwrap(), first_only); + } + + #[test] + fn builder_rejects_non_canonical_ranges() { + let mut builder = RecordAckBuilder::new(); + assert!(builder + .try_push_range(ack_range(95, 100), usize::MAX) + .unwrap()); + assert_eq!( + builder.try_push_range(ack_range(90, 95), usize::MAX), + Err(RecordAckRangeError::NotCanonical) + ); + } + + #[test] + fn rejects_unsorted_ranges() { + assert_eq!( + RecordAck::from_ranges([ack_range(90, 92), ack_range(95, 100)]), + Err(RecordAckRangeError::NotCanonical) + ); + } + + #[test] + fn rejects_touching_ranges() { + assert_eq!( + RecordAck::from_ranges([ack_range(10, 12), ack_range(7, 9)]), + Err(RecordAckRangeError::NotCanonical) + ); + } + + #[test] + fn rejects_overlapping_ranges() { + assert_eq!( + RecordAck::from_ranges([ack_range(10, 12), ack_range(8, 11)]), + Err(RecordAckRangeError::NotCanonical) + ); + } + + #[test] + fn contains_matches_range_membership() { + let ack = RecordAck::from_ranges([ + ack_range(150, 163), + ack_range(105, 110), + ack_range(100, 100), + ]) + .unwrap(); + + assert!(ack.contains(100)); + assert!(ack.contains(107)); + assert!(ack.contains(163)); + assert!(!ack.contains(99)); + assert!(!ack.contains(104)); + assert!(!ack.contains(164)); + } + + #[test] + fn empty_ack_is_rejected() { + assert_eq!(RecordAck::from_ranges([]), Err(RecordAckRangeError::Empty)); + } + + #[test] + fn inverted_range_is_rejected() { + assert_eq!( + RecordAck::from_ranges([ack_range(5, 4)]), + Err(RecordAckRangeError::InvertedRange) + ); + } + + #[test] + fn decode_rejects_underflowing_ack_blocks() { + let encoded = vec![ + 42, // largest_acked + 1, // block_count + 0, // first_range_len + 41, // gap: implies a missing run larger than largest_acked + 0, // range_len + ]; + + assert_eq!( + RecordAck::decode_exact(encoded.as_slice()), + Err(WireError::InvalidPayload) + ); + } + + #[test] + fn decode_rejects_truncated_payload() { + assert_eq!( + RecordAck::decode_exact(&[][..]), + Err(WireError::InvalidPayload) + ); + + let encoded = RecordAck::from_ranges([ack_range(42, 42)]) + .unwrap() + .encode_vec(); + assert_eq!( + RecordAck::decode_exact(&encoded[..encoded.len() - 1]), + Err(WireError::InvalidPayload) + ); + } +} diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs new file mode 100644 index 00000000..42933235 --- /dev/null +++ b/ql-wire/src/encrypted/builder.rs @@ -0,0 +1,172 @@ +use bytes::BufMut; + +use super::{RecordAck, SessionClose, SessionFrame, StreamClose, StreamData, StreamWindow}; +use crate::{ + BufView, ConnectionId, Nonce, QlCrypto, RecordSeq, RecordType, SessionHeader, SessionKey, + WireEncode, QL_WIRE_VERSION, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionRecordBuilder { + seq: RecordSeq, + prefix_len: usize, + max_capacity: usize, + bytes: Vec, +} + +impl SessionRecordBuilder { + pub const MIN_CAPACITY: usize = 1 + + 1 + + ConnectionId::SIZE + + RecordSeq::MAX_ENCODED_LEN + + crate::ENCRYPTED_MESSAGE_AUTH_SIZE; + + pub fn new(seq: RecordSeq, max_capacity: usize) -> Self { + let prefix_len = + 1 + 1 + ConnectionId::SIZE + seq.encoded_len() + crate::ENCRYPTED_MESSAGE_AUTH_SIZE; + assert!(max_capacity >= prefix_len); + Self { + seq, + prefix_len, + max_capacity, + bytes: Vec::new(), + } + } + + pub fn seq(&self) -> RecordSeq { + self.seq + } + + pub fn prefix_len(&self) -> usize { + self.prefix_len + } + + pub fn max_capacity(&self) -> usize { + self.max_capacity + } + + pub fn len(&self) -> usize { + self.bytes.len().saturating_sub(self.prefix_len) + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn remaining_capacity(&self) -> usize { + self.max_capacity + .saturating_sub(self.bytes.len().max(self.prefix_len)) + } + + pub fn bytes(&self) -> &[u8] { + self.bytes.get(self.prefix_len..).unwrap_or_default() + } + + pub fn push_ping(&mut self) -> bool { + self.push_empty_frame(super::SessionFrameKind::Ping) + } + + pub fn push_unpair(&mut self) -> bool { + self.push_empty_frame(super::SessionFrameKind::Unpair) + } + + pub fn push_ack(&mut self, ack: &RecordAck) -> bool { + self.push_frame_payload(super::SessionFrameKind::Ack, ack) + } + + pub fn push_stream_data(&mut self, frame: &StreamData) -> bool { + self.push_frame_payload(super::SessionFrameKind::StreamData, frame) + } + + pub fn push_stream_window(&mut self, frame: &StreamWindow) -> bool { + self.push_frame_payload(super::SessionFrameKind::StreamWindow, frame) + } + + pub fn push_stream_close(&mut self, frame: &StreamClose) -> bool { + self.push_frame_payload(super::SessionFrameKind::StreamClose, frame) + } + + pub fn push_close(&mut self, close: &SessionClose) -> bool { + self.push_frame_payload(super::SessionFrameKind::Close, close) + } + + pub fn push_frame(&mut self, frame: &SessionFrame) -> bool { + match frame { + SessionFrame::Ping => self.push_ping(), + SessionFrame::Unpair => self.push_unpair(), + SessionFrame::Ack(frame) => self.push_ack(frame), + SessionFrame::StreamData(frame) => self.push_stream_data(frame), + SessionFrame::StreamWindow(frame) => self.push_stream_window(frame), + SessionFrame::StreamClose(frame) => self.push_stream_close(frame), + SessionFrame::Close(close) => self.push_close(close), + } + } + + pub fn encrypt( + mut self, + crypto: &impl QlCrypto, + connection_id: ConnectionId, + session_key: &SessionKey, + ) -> Vec { + self.ensure_prefix_capacity(0); + let header = SessionHeader { + connection_id, + seq: self.seq, + }; + let aad = header.aad(); + let nonce = Nonce::from_counter(self.seq.into_inner()); + let auth = crypto.aes256_gcm_encrypt( + session_key, + &nonce, + &aad, + &mut self.bytes[self.prefix_len..], + ); + + let mut prefix = &mut self.bytes[..self.prefix_len]; + prefix[0] = QL_WIRE_VERSION; + prefix[1] = RecordType::Session as u8; + prefix = &mut prefix[2..]; + header.encode(&mut prefix); + auth.encode(&mut prefix); + debug_assert!(prefix.is_empty()); + self.bytes + } + + fn push_wire_size(&mut self, wire_size: usize, encode: impl FnOnce(&mut Vec)) -> bool { + if !self.can_push_len(wire_size) { + return false; + } + self.ensure_prefix_capacity(wire_size); + let start = self.bytes.len(); + encode(&mut self.bytes); + debug_assert_eq!(self.bytes.len(), start + wire_size); + true + } + + fn push_empty_frame(&mut self, kind: super::SessionFrameKind) -> bool { + self.push_wire_size(1, |out| out.put_u8(kind as u8)) + } + + fn push_frame_payload( + &mut self, + kind: super::SessionFrameKind, + payload: &T, + ) -> bool { + let payload_wire_size = payload.encoded_len(); + self.push_wire_size(1 + payload_wire_size, |out| { + out.put_u8(kind as u8); + payload.encode(out); + }) + } + + fn can_push_len(&self, len: usize) -> bool { + len <= self.remaining_capacity() + } + + fn ensure_prefix_capacity(&mut self, additional_body_len: usize) { + if self.bytes.is_empty() { + self.bytes.reserve(self.prefix_len + additional_body_len); + self.bytes.resize(self.prefix_len, 0); + } + } +} diff --git a/ql-wire/src/encrypted/close.rs b/ql-wire/src/encrypted/close.rs new file mode 100644 index 00000000..e0860d7a --- /dev/null +++ b/ql-wire/src/encrypted/close.rs @@ -0,0 +1,55 @@ +use crate::{codec, codec::Reader, ByteSlice, WireEncode, WireError}; + +/// closes the whole session immediately with a close code. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionClose { + pub code: SessionCloseCode, +} + +impl SessionClose { + pub const WIRE_SIZE: usize = size_of::(); +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct SessionCloseCode(pub u16); + +impl SessionCloseCode { + pub const CANCELLED: Self = Self(0); + pub const PROTOCOL: Self = Self(1); + pub const TIMEOUT: Self = Self(2); +} + +impl WireEncode for SessionCloseCode { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl codec::WireDecode for SessionCloseCode { + fn decode(reader: &mut Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} + +impl codec::WireDecode for SessionClose { + fn decode(reader: &mut Reader) -> Result { + Ok(Self { + code: reader.decode()?, + }) + } +} + +impl WireEncode for SessionClose { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.code.encode(out); + } +} diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs new file mode 100644 index 00000000..563f9ded --- /dev/null +++ b/ql-wire/src/encrypted/mod.rs @@ -0,0 +1,178 @@ +use crate::{ + codec, encrypted_message::EncryptedMessage, BufView, ByteSlice, Nonce, QlCrypto, Reader, + SessionHeader, SessionKey, WireDecode, WireEncode, WireError, +}; + +mod ack; +mod builder; +mod close; +mod route_id; +mod stream_close; +mod stream_data; +mod stream_id; +mod stream_window; + +pub use ack::*; +pub use builder::*; +pub use close::*; +pub use route_id::*; +pub use stream_close::*; +pub use stream_data::*; +pub use stream_id::*; +pub use stream_window::*; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SessionFrame { + // todo: do we need ping as explicit frame? + Ping, + Unpair, + Ack(RecordAck), + StreamData(StreamData), + StreamWindow(StreamWindow), + StreamClose(StreamClose), + Close(SessionClose), +} + +impl WireDecode for SessionFrame { + fn decode(reader: &mut Reader) -> Result { + let kind = reader.decode::()?; + let frame = match kind { + SessionFrameKind::Ping => Self::Ping, + SessionFrameKind::Unpair => Self::Unpair, + SessionFrameKind::Ack => Self::Ack(reader.decode::()?), + SessionFrameKind::StreamData => Self::StreamData(reader.decode::>()?), + SessionFrameKind::StreamWindow => Self::StreamWindow(reader.decode::()?), + SessionFrameKind::StreamClose => Self::StreamClose(reader.decode::()?), + SessionFrameKind::Close => Self::Close(reader.decode::()?), + }; + Ok(frame) + } +} + +impl SessionFrame { + fn kind(&self) -> SessionFrameKind { + match self { + Self::Ping => SessionFrameKind::Ping, + Self::Unpair => SessionFrameKind::Unpair, + Self::Ack(_) => SessionFrameKind::Ack, + Self::StreamData(_) => SessionFrameKind::StreamData, + Self::StreamWindow(_) => SessionFrameKind::StreamWindow, + Self::StreamClose(_) => SessionFrameKind::StreamClose, + Self::Close(_) => SessionFrameKind::Close, + } + } +} + +impl SessionFrame { + pub fn into_owned(self) -> SessionFrame> { + match self { + Self::Ping => SessionFrame::Ping, + Self::Unpair => SessionFrame::Unpair, + Self::Ack(frame) => SessionFrame::Ack(frame), + Self::StreamData(frame) => SessionFrame::StreamData(frame.into_owned()), + Self::StreamWindow(frame) => SessionFrame::StreamWindow(frame), + Self::StreamClose(frame) => SessionFrame::StreamClose(frame), + Self::Close(frame) => SessionFrame::Close(frame), + } + } +} + +impl WireEncode for SessionFrame { + fn encoded_len(&self) -> usize { + 1 + match self { + Self::Ping | Self::Unpair => 0, + Self::Ack(frame) => frame.encoded_len(), + Self::StreamData(frame) => frame.encoded_len(), + Self::StreamWindow(frame) => frame.encoded_len(), + Self::StreamClose(frame) => frame.encoded_len(), + Self::Close(frame) => frame.encoded_len(), + } + } + + fn encode(&self, out: &mut W) { + out.put_u8(self.kind() as u8); + match self { + Self::Ping | Self::Unpair => {} + Self::Ack(frame) => frame.encode(out), + Self::StreamData(frame) => frame.encode(out), + Self::StreamWindow(frame) => frame.encode(out), + Self::StreamClose(frame) => frame.encode(out), + Self::Close(frame) => frame.encode(out), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum SessionFrameKind { + Ping = 1, + Ack = 2, + StreamData = 3, + StreamWindow = 4, + StreamClose = 5, + Close = 6, + Unpair = 7, +} + +impl TryFrom for SessionFrameKind { + type Error = WireError; + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(Self::Ping), + 2 => Ok(Self::Ack), + 3 => Ok(Self::StreamData), + 4 => Ok(Self::StreamWindow), + 5 => Ok(Self::StreamClose), + 6 => Ok(Self::Close), + 7 => Ok(Self::Unpair), + _ => Err(WireError::InvalidPayload), + } + } +} + +impl codec::WireDecode for SessionFrameKind { + fn decode(reader: &mut codec::Reader) -> Result { + reader.decode::()?.try_into() + } +} + +pub fn parse_session_frames(bytes: B) -> SessionFrameIter { + SessionFrameIter { + reader: Reader::new(bytes), + } +} + +pub fn decode_session_frames(bytes: &[u8]) -> Result>>, WireError> { + parse_session_frames(bytes) + .map(|frame| frame.map(SessionFrame::into_owned)) + .collect() +} + +#[derive(Clone)] +pub struct SessionFrameIter { + reader: Reader, +} + +impl Iterator for SessionFrameIter { + type Item = Result, WireError>; + + fn next(&mut self) -> Option { + if self.reader.is_empty() { + None + } else { + Some(self.reader.decode::>()) + } + } +} + +pub fn decrypt_record>( + crypto: &impl QlCrypto, + header: &SessionHeader, + encrypted: EncryptedMessage, + session_key: &SessionKey, +) -> Result { + let aad = header.aad(); + let nonce = Nonce::from_counter(header.seq.into_inner()); + encrypted.decrypt_in_place(crypto, session_key, &nonce, &aad) +} diff --git a/ql-wire/src/encrypted/route_id.rs b/ql-wire/src/encrypted/route_id.rs new file mode 100644 index 00000000..6b91a521 --- /dev/null +++ b/ql-wire/src/encrypted/route_id.rs @@ -0,0 +1,55 @@ +use crate::{ByteSlice, Reader, VarInt, VarIntBoundsExceeded, WireDecode, WireEncode, WireError}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct RouteId(pub VarInt); + +impl RouteId { + pub const MAX_ENCODED_LEN: usize = VarInt::MAX_SIZE; + + pub const fn from_u32(value: u32) -> Self { + Self(VarInt::from_u32(value)) + } + + pub fn from_u64(value: u64) -> Result { + Ok(Self(VarInt::from_u64(value)?)) + } + + pub const fn into_inner(self) -> u64 { + self.0.into_inner() + } +} + +impl WireEncode for RouteId { + fn encoded_len(&self) -> usize { + self.0.size() + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl WireDecode for RouteId { + fn decode(reader: &mut Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} + +impl From for RouteId { + fn from(value: VarInt) -> Self { + Self(value) + } +} + +impl From for RouteId { + fn from(value: u32) -> Self { + Self::from_u32(value) + } +} + +impl std::fmt::Display for RouteId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs new file mode 100644 index 00000000..20ddb879 --- /dev/null +++ b/ql-wire/src/encrypted/stream_close.rs @@ -0,0 +1,116 @@ +use super::StreamId; +use crate::{codec, ByteSlice, WireEncode, WireError}; + +/// aborts one or both lanes of a stream with a close code +/// +/// stream origin is the peer that opened the stream +/// origin lane carries bytes sent by the stream origin +/// return lane carries bytes sent back toward the stream origin +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamClose { + pub stream_id: StreamId, + pub target: CloseTarget, + pub code: StreamCloseCode, +} + +impl StreamClose {} + +impl WireEncode for StreamClose { + fn encoded_len(&self) -> usize { + self.stream_id.encoded_len() + self.target.encoded_len() + self.code.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.stream_id.encode(out); + self.target.encode(out); + self.code.encode(out); + } +} + +impl codec::WireDecode for StreamClose { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + stream_id: reader.decode()?, + target: reader.decode()?, + code: reader.decode()?, + }) + } +} + +/// selects which stream lane a [`StreamClose`] applies to +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum CloseTarget { + /// close the lane sent by the stream origin + Origin = 1, + /// close the lane sent back toward the stream origin + Return = 2, + /// close both stream lanes + Both = 3, +} + +impl CloseTarget { + pub const fn to_wire(self) -> u8 { + self as u8 + } +} + +impl WireEncode for CloseTarget { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + self.to_wire().encode(out); + } +} + +impl TryFrom for CloseTarget { + type Error = WireError; + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(Self::Origin), + 2 => Ok(Self::Return), + 3 => Ok(Self::Both), + _ => Err(WireError::InvalidPayload), + } + } +} + +impl codec::WireDecode for CloseTarget { + fn decode(reader: &mut codec::Reader) -> Result { + reader.decode::()?.try_into() + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct StreamCloseCode(pub u16); + +impl StreamCloseCode { + /// the stream was aborted intentionally before graceful completion + pub const CANCELLED: Self = Self(0); +} + +impl codec::WireDecode for StreamCloseCode { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} + +impl WireEncode for StreamCloseCode { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl std::fmt::Display for StreamCloseCode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} diff --git a/ql-wire/src/encrypted/stream_data.rs b/ql-wire/src/encrypted/stream_data.rs new file mode 100644 index 00000000..9174fe5a --- /dev/null +++ b/ql-wire/src/encrypted/stream_data.rs @@ -0,0 +1,135 @@ +use bytes::Buf; + +use super::{RouteId, StreamId}; +use crate::{codec, BufView, ByteSlice, VarInt, WireDecode, WireEncode, WireError}; + +/// carries bytes for a stream and may finish that sending direction. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamData { + pub stream_id: StreamId, + pub offset: VarInt, + pub header: Option, + pub fin: bool, + pub bytes: B, +} + +impl StreamData { + pub const MIN_WIRE_SIZE: usize = StreamId::MAX_ENCODED_LEN + + VarInt::MAX_SIZE + + size_of::() + + StreamHeader::MAX_WIRE_SIZE + + VarInt::MAX_SIZE; +} + +impl WireDecode for StreamData { + fn decode(reader: &mut codec::Reader) -> Result { + let stream_id = reader.decode()?; + let offset: VarInt = reader.decode()?; + let flags = reader.decode::()?; + let fin = (flags & flag::FIN) != 0; + let has_header = (flags & flag::HEADER) != 0; + let header = if has_header { + Some(reader.decode()?) + } else { + None + }; + let bytes_len = usize::try_from(reader.decode::()?.into_inner()) + .map_err(|_| WireError::InvalidPayload)?; + + Ok(Self { + stream_id, + offset, + header, + fin, + bytes: reader.take_bytes(bytes_len)?, + }) + } +} + +impl StreamData { + pub fn into_owned(self) -> StreamData> + where + B: ByteSlice, + { + StreamData { + stream_id: self.stream_id, + offset: self.offset, + header: self.header, + fin: self.fin, + bytes: self.bytes.to_vec(), + } + } +} + +impl WireEncode for StreamData { + fn encoded_len(&self) -> usize { + let bytes = self.bytes.buf(); + let bytes_len = bytes.remaining(); + self.stream_id.encoded_len() + + self.offset.encoded_len() + + size_of::() + + self.header.as_ref().map_or(0, WireEncode::encoded_len) + + VarInt::try_from(bytes_len).unwrap().encoded_len() + + bytes_len + } + + fn encode(&self, out: &mut W) { + debug_assert!( + self.offset.into_inner() == 0 || self.header.is_none(), + "stream header is only valid at offset 0" + ); + + self.stream_id.encode(out); + self.offset.encode(out); + let mut flags = 0; + if self.fin { + flags |= flag::FIN; + } + if self.header.is_some() { + flags |= flag::HEADER; + } + flags.encode(out); + if let Some(header) = &self.header { + header.encode(out); + } + let mut bytes = self.bytes.buf(); + VarInt::try_from(bytes.remaining()).unwrap().encode(out); + while bytes.has_remaining() { + let chunk = bytes.chunk(); + out.put_slice(chunk); + bytes.advance(chunk.len()); + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct StreamHeader { + pub route_id: RouteId, +} + +impl StreamHeader { + pub const MAX_WIRE_SIZE: usize = RouteId::MAX_ENCODED_LEN; +} + +impl WireDecode for StreamHeader { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + route_id: reader.decode()?, + }) + } +} + +impl WireEncode for StreamHeader { + fn encoded_len(&self) -> usize { + self.route_id.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.route_id.encode(out); + } +} + +mod flag { + pub const FIN: u8 = 0x01; + pub const HEADER: u8 = 0x02; +} diff --git a/ql-wire/src/encrypted/stream_id.rs b/ql-wire/src/encrypted/stream_id.rs new file mode 100644 index 00000000..07002259 --- /dev/null +++ b/ql-wire/src/encrypted/stream_id.rs @@ -0,0 +1,35 @@ +use crate::{ByteSlice, Reader, VarInt, WireDecode, WireEncode, WireError}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct StreamId(pub VarInt); + +impl StreamId { + pub const MAX_ENCODED_LEN: usize = VarInt::MAX_SIZE; + + pub const fn into_inner(self) -> u64 { + self.0.into_inner() + } +} + +impl WireEncode for StreamId { + fn encoded_len(&self) -> usize { + self.0.size() + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl WireDecode for StreamId { + fn decode(reader: &mut Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} + +impl std::fmt::Display for StreamId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} diff --git a/ql-wire/src/encrypted/stream_window.rs b/ql-wire/src/encrypted/stream_window.rs new file mode 100644 index 00000000..6a2274f9 --- /dev/null +++ b/ql-wire/src/encrypted/stream_window.rs @@ -0,0 +1,29 @@ +use super::StreamId; +use crate::{codec, ByteSlice, VarInt, WireEncode, WireError}; + +/// advertises the highest byte offset the peer may send on a stream. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamWindow { + pub stream_id: StreamId, + pub maximum_offset: VarInt, +} + +impl WireEncode for StreamWindow { + fn encoded_len(&self) -> usize { + self.stream_id.encoded_len() + self.maximum_offset.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.stream_id.encode(out); + self.maximum_offset.encode(out); + } +} + +impl codec::WireDecode for StreamWindow { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + stream_id: reader.decode()?, + maximum_offset: reader.decode()?, + }) + } +} diff --git a/ql-wire/src/encrypted_message.rs b/ql-wire/src/encrypted_message.rs new file mode 100644 index 00000000..9e11d3d0 --- /dev/null +++ b/ql-wire/src/encrypted_message.rs @@ -0,0 +1,97 @@ +use crate::{ + codec, ByteSlice, Nonce, QlCrypto, SessionKey, WireDecode, WireEncode, WireError, + ENCRYPTED_MESSAGE_AUTH_SIZE, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EncryptedMessage { + pub auth: [u8; ENCRYPTED_MESSAGE_AUTH_SIZE], + pub ciphertext: B, +} + +impl EncryptedMessage { + pub const AUTH_SIZE: usize = ENCRYPTED_MESSAGE_AUTH_SIZE; + pub const HEADER_LEN: usize = Self::AUTH_SIZE; + + pub fn into_owned(self) -> EncryptedMessage> + where + B: ByteSlice, + { + EncryptedMessage { + auth: self.auth, + ciphertext: self.ciphertext.to_vec(), + } + } +} + +impl WireDecode for EncryptedMessage { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + auth: reader.decode()?, + ciphertext: reader.take_rest(), + }) + } +} + +impl> EncryptedMessage { + pub fn decrypt( + &self, + crypto: &impl QlCrypto, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + ) -> Result, WireError> { + let mut plaintext = self.ciphertext.as_ref().to_vec(); + if !crypto.aes256_gcm_decrypt(key, nonce, aad, &mut plaintext, &self.auth) { + return Err(WireError::DecryptFailed); + } + Ok(plaintext) + } +} + +impl> WireEncode for EncryptedMessage { + fn encoded_len(&self) -> usize { + Self::HEADER_LEN + self.ciphertext.as_ref().len() + } + + fn encode(&self, out: &mut W) { + self.auth.encode(out); + self.ciphertext.as_ref().encode(out); + } +} + +impl> EncryptedMessage { + pub fn decrypt_in_place( + mut self, + crypto: &impl QlCrypto, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + ) -> Result { + let ciphertext = self.ciphertext.as_mut(); + if !crypto.aes256_gcm_decrypt(key, nonce, aad, ciphertext, &self.auth) { + return Err(WireError::DecryptFailed); + } + Ok(self.ciphertext) + } +} + +impl EncryptedMessage> { + pub fn encrypt( + crypto: &impl QlCrypto, + key: &SessionKey, + mut plaintext: Vec, + nonce: &Nonce, + aad: &[u8], + ) -> Self { + let auth = crypto.aes256_gcm_encrypt(key, nonce, aad, &mut plaintext); + Self { + auth, + ciphertext: plaintext, + } + } + + pub fn decode(bytes: &[u8]) -> Result { + Ok(EncryptedMessage::decode_exact(bytes)?.into_owned()) + } +} diff --git a/ql-wire/src/error.rs b/ql-wire/src/error.rs new file mode 100644 index 00000000..8da1eec0 --- /dev/null +++ b/ql-wire/src/error.rs @@ -0,0 +1,33 @@ +use core::fmt; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum WireError { + InvalidPayload, + InvalidHandshakeHeader, + InvalidHandshakeMeta, + InvalidPairingId, + InvalidRemoteBundle, + InvalidTransportParams, + Expired, + DecryptFailed, + InvalidState, +} + +impl fmt::Display for WireError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let message = match self { + Self::InvalidPayload => "invalid payload", + Self::InvalidHandshakeHeader => "invalid handshake header", + Self::InvalidHandshakeMeta => "invalid handshake meta", + Self::InvalidPairingId => "invalid pairing id", + Self::InvalidRemoteBundle => "invalid remote bundle", + Self::InvalidTransportParams => "invalid transport params", + Self::Expired => "expired", + Self::DecryptFailed => "decryption failed", + Self::InvalidState => "invalid state", + }; + f.write_str(message) + } +} + +impl std::error::Error for WireError {} diff --git a/ql-wire/src/handshake/ik.rs b/ql-wire/src/handshake/ik.rs new file mode 100644 index 00000000..628e30e7 --- /dev/null +++ b/ql-wire/src/handshake/ik.rs @@ -0,0 +1,376 @@ +use super::{ + decrypt_mlkem_ciphertext, decrypt_peer_bundle, encrypt_mlkem_ciphertext, encrypt_peer_bundle, + finalize_handshake, generate_ephemeral_keypair, init_ik_symmetric, initialize_handshake_meta, + mix_hash_ephemeral, mix_hash_routed_handshake, require_handshake_meta, + EncryptedMlKemCiphertext, EncryptedPeerBundle, EphemeralKeyPair, EphemeralPublicKey, + FinalizedHandshake, HandshakeHeader, Role, SymmetricState, TransportParams, +}; +use crate::{ + codec, ByteSlice, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, + QlIdentity, WireEncode, WireError, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Ik1 { + pub header: HandshakeHeader, + pub meta: HandshakeMeta, + pub transport_params: TransportParams, + pub skem_ciphertext: MlKemCiphertext, + pub ephemeral: EphemeralPublicKey, + pub static_bundle: EncryptedPeerBundle, +} + +impl codec::WireDecode for Ik1 { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + meta: reader.decode()?, + transport_params: reader.decode()?, + skem_ciphertext: reader.decode()?, + ephemeral: reader.decode()?, + static_bundle: reader.decode()?, + }) + } +} + +impl WireEncode for Ik1 { + fn encoded_len(&self) -> usize { + HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + TransportParams::WIRE_SIZE + + MlKemCiphertext::SIZE + + EphemeralPublicKey::WIRE_SIZE + + self.static_bundle.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.transport_params.encode(out); + self.skem_ciphertext.encode(out); + self.ephemeral.encode(out); + self.static_bundle.encode(out); + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Ik2 { + pub header: HandshakeHeader, + pub meta: HandshakeMeta, + pub transport_params: TransportParams, + pub ekem_ciphertext: MlKemCiphertext, + pub skem_ciphertext: EncryptedMlKemCiphertext, +} + +impl Ik2 { + pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + TransportParams::WIRE_SIZE + + MlKemCiphertext::SIZE + + EncryptedMlKemCiphertext::WIRE_SIZE; +} + +impl codec::WireDecode for Ik2 { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + meta: reader.decode()?, + transport_params: reader.decode()?, + ekem_ciphertext: reader.decode()?, + skem_ciphertext: reader.decode()?, + }) + } +} + +impl WireEncode for Ik2 { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.transport_params.encode(out); + self.ekem_ciphertext.encode(out); + self.skem_ciphertext.encode(out); + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum IkStep { + Send1, + Recv1, + Send2, + Recv2, + Done, +} + +#[derive(Debug, Clone)] +pub struct IkHandshake { + role: Role, + step: IkStep, + symmetric: SymmetricState, + local: QlIdentity, + remote_bundle: Option, + local_ephemeral: Option, + remote_ephemeral: Option, + handshake_meta: Option, + local_transport_params: TransportParams, + remote_transport_params: Option, +} + +impl IkHandshake { + pub fn new_initiator( + crypto: &impl QlCrypto, + local: QlIdentity, + remote_bundle: PeerBundle, + local_transport_params: TransportParams, + ) -> Self { + let symmetric = init_ik_symmetric(crypto, &remote_bundle); + Self { + role: Role::Initiator, + step: IkStep::Send1, + symmetric, + local, + remote_bundle: Some(remote_bundle), + local_ephemeral: None, + remote_ephemeral: None, + handshake_meta: None, + local_transport_params, + remote_transport_params: None, + } + } + + pub fn new_responder( + crypto: &impl QlCrypto, + local: QlIdentity, + expected_remote: Option, + local_transport_params: TransportParams, + ) -> Self { + let symmetric = init_ik_symmetric(crypto, &local.bundle()); + Self { + role: Role::Responder, + step: IkStep::Recv1, + symmetric, + local, + remote_bundle: expected_remote, + local_ephemeral: None, + remote_ephemeral: None, + handshake_meta: None, + local_transport_params, + remote_transport_params: None, + } + } + + pub fn is_finished(&self) -> bool { + self.step == IkStep::Done + } + + fn outbound_header(&self) -> Result { + let remote_bundle = self.remote_bundle.as_ref().ok_or(WireError::InvalidState)?; + Ok(HandshakeHeader { + sender: self.local.qid, + recipient: remote_bundle.qid, + }) + } + + fn ensure_inbound_recipient(&self, header: HandshakeHeader) -> Result<(), WireError> { + if header.recipient == self.local.qid { + Ok(()) + } else { + Err(WireError::InvalidPayload) + } + } + + fn ensure_known_remote_sender(&self, header: HandshakeHeader) -> Result<(), WireError> { + if let Some(remote_bundle) = self.remote_bundle.as_ref() { + if header.sender != remote_bundle.qid { + return Err(WireError::InvalidPayload); + } + } + Ok(()) + } + + pub fn write_1( + &mut self, + crypto: &impl QlCrypto, + meta: HandshakeMeta, + ) -> Result { + if self.step != IkStep::Send1 { + return Err(WireError::InvalidState); + } + initialize_handshake_meta(&mut self.handshake_meta, meta)?; + let remote_bundle = self.remote_bundle.as_ref().ok_or(WireError::InvalidState)?; + let header = self.outbound_header()?; + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Ik1, + meta, + self.local_transport_params, + ); + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); + self.symmetric.mix_hash(crypto, skem_ciphertext.as_bytes()); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + let local_ephemeral = generate_ephemeral_keypair(crypto); + let public = local_ephemeral.public(); + mix_hash_ephemeral(&mut self.symmetric, crypto, &public); + + let static_bundle = encrypt_peer_bundle(crypto, &mut self.symmetric, &self.local.bundle())?; + + self.local_ephemeral = Some(local_ephemeral); + self.step = IkStep::Recv2; + Ok(Ik1 { + header, + meta, + transport_params: self.local_transport_params, + skem_ciphertext, + ephemeral: public, + static_bundle, + }) + } + + pub fn write_2( + &mut self, + crypto: &impl QlCrypto, + meta: HandshakeMeta, + ) -> Result { + if self.step != IkStep::Send2 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), meta)?; + let header = self.outbound_header()?; + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Ik2, + meta, + self.local_transport_params, + ); + let remote_ephemeral = self + .remote_ephemeral + .clone() + .ok_or(WireError::InvalidState)?; + let (ekem_ciphertext, ekem_secret) = + crypto.mlkem_encapsulate(&remote_ephemeral.mlkem_public_key); + self.symmetric.mix_hash(crypto, ekem_ciphertext.as_bytes()); + self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + + let remote_bundle = self.remote_bundle.as_ref().ok_or(WireError::InvalidState)?; + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); + let skem_ciphertext = + encrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &skem_ciphertext)?; + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + self.step = IkStep::Done; + Ok(Ik2 { + header, + meta, + transport_params: self.local_transport_params, + ekem_ciphertext, + skem_ciphertext, + }) + } + + pub fn read_1(&mut self, crypto: &impl QlCrypto, message: &Ik1) -> Result<(), WireError> { + if self.step != IkStep::Recv1 { + return Err(WireError::InvalidState); + } + initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; + self.ensure_inbound_recipient(message.header)?; + self.ensure_known_remote_sender(message.header)?; + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Ik1, + message.meta, + message.transport_params, + ); + self.symmetric + .mix_hash(crypto, message.skem_ciphertext.as_bytes()); + let skem_secret = + crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &message.skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); + self.remote_ephemeral = Some(message.ephemeral.clone()); + + let remote_bundle = + decrypt_peer_bundle(crypto, &mut self.symmetric, &message.static_bundle)?; + if remote_bundle.qid != message.header.sender { + return Err(WireError::InvalidPayload); + } + match self.remote_bundle.as_ref() { + Some(expected) if expected != &remote_bundle => { + return Err(WireError::InvalidPayload); + } + Some(_) => {} + None => self.remote_bundle = Some(remote_bundle), + } + self.remote_transport_params = Some(message.transport_params); + self.step = IkStep::Send2; + Ok(()) + } + + pub fn read_2(&mut self, crypto: &impl QlCrypto, message: &Ik2) -> Result<(), WireError> { + if self.step != IkStep::Recv2 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; + self.ensure_inbound_recipient(message.header)?; + self.ensure_known_remote_sender(message.header)?; + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Ik2, + message.meta, + message.transport_params, + ); + let local_ephemeral = self + .local_ephemeral + .as_ref() + .ok_or(WireError::InvalidState)?; + self.symmetric + .mix_hash(crypto, message.ekem_ciphertext.as_bytes()); + let ekem_secret = + crypto.mlkem_decapsulate(&local_ephemeral.mlkem.private, &message.ekem_ciphertext); + self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + + let skem_ciphertext = + decrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &message.skem_ciphertext)?; + let skem_secret = crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + self.remote_transport_params = Some(message.transport_params); + self.step = IkStep::Done; + Ok(()) + } + + pub fn finalize(self, crypto: &impl QlCrypto) -> Result { + if !self.is_finished() { + return Err(WireError::InvalidState); + } + let remote_bundle = self.remote_bundle.ok_or(WireError::InvalidState)?; + let remote_transport_params = self + .remote_transport_params + .ok_or(WireError::InvalidState)?; + Ok(finalize_handshake( + crypto, + &self.symmetric, + self.role, + remote_bundle, + remote_transport_params, + )) + } +} diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs new file mode 100644 index 00000000..2ad5ee2a --- /dev/null +++ b/ql-wire/src/handshake/kk.rs @@ -0,0 +1,352 @@ +use super::{ + decrypt_mlkem_ciphertext, encrypt_mlkem_ciphertext, finalize_handshake, + generate_ephemeral_keypair, init_kk_symmetric, initialize_handshake_meta, mix_hash_ephemeral, + mix_hash_routed_handshake, require_handshake_meta, EncryptedMlKemCiphertext, EphemeralKeyPair, + EphemeralPublicKey, FinalizedHandshake, HandshakeHeader, Role, SymmetricState, TransportParams, +}; +use crate::{ + codec, ByteSlice, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, + QlIdentity, WireEncode, WireError, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Kk1 { + pub header: HandshakeHeader, + pub meta: HandshakeMeta, + pub transport_params: TransportParams, + pub skem_ciphertext: MlKemCiphertext, + pub ephemeral: EphemeralPublicKey, +} + +impl Kk1 { + pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + TransportParams::WIRE_SIZE + + MlKemCiphertext::SIZE + + EphemeralPublicKey::WIRE_SIZE; +} + +impl codec::WireDecode for Kk1 { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + meta: reader.decode()?, + transport_params: reader.decode()?, + skem_ciphertext: reader.decode()?, + ephemeral: reader.decode()?, + }) + } +} + +impl WireEncode for Kk1 { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.transport_params.encode(out); + self.skem_ciphertext.encode(out); + self.ephemeral.encode(out); + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Kk2 { + pub header: HandshakeHeader, + pub meta: HandshakeMeta, + pub transport_params: TransportParams, + pub ekem_ciphertext: MlKemCiphertext, + pub skem_ciphertext: EncryptedMlKemCiphertext, +} + +impl Kk2 { + pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + TransportParams::WIRE_SIZE + + MlKemCiphertext::SIZE + + EncryptedMlKemCiphertext::WIRE_SIZE; +} + +impl codec::WireDecode for Kk2 { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + meta: reader.decode()?, + transport_params: reader.decode()?, + ekem_ciphertext: reader.decode()?, + skem_ciphertext: reader.decode()?, + }) + } +} + +impl WireEncode for Kk2 { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.transport_params.encode(out); + self.ekem_ciphertext.encode(out); + self.skem_ciphertext.encode(out); + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum KkStep { + Send1, + Recv1, + Send2, + Recv2, + Done, +} + +#[derive(Debug, Clone)] +pub struct KkHandshake { + role: Role, + step: KkStep, + symmetric: SymmetricState, + local: QlIdentity, + remote_bundle: PeerBundle, + local_ephemeral: Option, + remote_ephemeral: Option, + handshake_meta: Option, + local_transport_params: TransportParams, + remote_transport_params: Option, +} + +impl KkHandshake { + pub fn new_initiator( + crypto: &impl QlCrypto, + local: QlIdentity, + remote_bundle: PeerBundle, + local_transport_params: TransportParams, + ) -> Self { + let symmetric = init_kk_symmetric(crypto, &local.bundle(), &remote_bundle); + Self { + role: Role::Initiator, + step: KkStep::Send1, + symmetric, + local, + remote_bundle, + local_ephemeral: None, + remote_ephemeral: None, + handshake_meta: None, + local_transport_params, + remote_transport_params: None, + } + } + + pub fn new_responder( + crypto: &impl QlCrypto, + local: QlIdentity, + remote_bundle: PeerBundle, + local_transport_params: TransportParams, + ) -> Self { + let symmetric = init_kk_symmetric(crypto, &remote_bundle, &local.bundle()); + Self { + role: Role::Responder, + step: KkStep::Recv1, + symmetric, + local, + remote_bundle, + local_ephemeral: None, + remote_ephemeral: None, + handshake_meta: None, + local_transport_params, + remote_transport_params: None, + } + } + + pub fn is_finished(&self) -> bool { + self.step == KkStep::Done + } + + fn outbound_header(&self) -> HandshakeHeader { + HandshakeHeader { + sender: self.local.qid, + recipient: self.remote_bundle.qid, + } + } + + fn inbound_header(&self) -> HandshakeHeader { + HandshakeHeader { + sender: self.remote_bundle.qid, + recipient: self.local.qid, + } + } + + fn ensure_inbound_header(&self, header: HandshakeHeader) -> Result<(), WireError> { + if header == self.inbound_header() { + Ok(()) + } else { + Err(WireError::InvalidPayload) + } + } + + pub fn write_1( + &mut self, + crypto: &impl QlCrypto, + meta: HandshakeMeta, + ) -> Result { + if self.step != KkStep::Send1 { + return Err(WireError::InvalidState); + } + initialize_handshake_meta(&mut self.handshake_meta, meta)?; + let header = self.outbound_header(); + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Kk1, + meta, + self.local_transport_params, + ); + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&self.remote_bundle.mlkem_public_key); + self.symmetric + .encrypt_and_hash(crypto, skem_ciphertext.as_bytes())?; + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + let local_ephemeral = generate_ephemeral_keypair(crypto); + let public = local_ephemeral.public(); + mix_hash_ephemeral(&mut self.symmetric, crypto, &public); + + self.local_ephemeral = Some(local_ephemeral); + self.step = KkStep::Recv2; + Ok(Kk1 { + header, + meta, + transport_params: self.local_transport_params, + skem_ciphertext, + ephemeral: public, + }) + } + + pub fn write_2( + &mut self, + crypto: &impl QlCrypto, + meta: HandshakeMeta, + ) -> Result { + if self.step != KkStep::Send2 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), meta)?; + let header = self.outbound_header(); + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Kk2, + meta, + self.local_transport_params, + ); + let remote_ephemeral = self + .remote_ephemeral + .clone() + .ok_or(WireError::InvalidState)?; + let (ekem_ciphertext, ekem_secret) = + crypto.mlkem_encapsulate(&remote_ephemeral.mlkem_public_key); + self.symmetric.mix_hash(crypto, ekem_ciphertext.as_bytes()); + self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&self.remote_bundle.mlkem_public_key); + let skem_ciphertext = + encrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &skem_ciphertext)?; + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + self.step = KkStep::Done; + Ok(Kk2 { + header, + meta, + transport_params: self.local_transport_params, + ekem_ciphertext, + skem_ciphertext, + }) + } + + pub fn read_1(&mut self, crypto: &impl QlCrypto, message: &Kk1) -> Result<(), WireError> { + if self.step != KkStep::Recv1 { + return Err(WireError::InvalidState); + } + initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; + self.ensure_inbound_header(message.header)?; + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Kk1, + message.meta, + message.transport_params, + ); + self.symmetric + .decrypt_and_hash(crypto, message.skem_ciphertext.as_bytes())?; + let skem_secret = + crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &message.skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); + self.remote_ephemeral = Some(message.ephemeral.clone()); + self.remote_transport_params = Some(message.transport_params); + self.step = KkStep::Send2; + Ok(()) + } + + pub fn read_2(&mut self, crypto: &impl QlCrypto, message: &Kk2) -> Result<(), WireError> { + if self.step != KkStep::Recv2 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; + self.ensure_inbound_header(message.header)?; + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Kk2, + message.meta, + message.transport_params, + ); + let local_ephemeral = self + .local_ephemeral + .as_ref() + .ok_or(WireError::InvalidState)?; + self.symmetric + .mix_hash(crypto, message.ekem_ciphertext.as_bytes()); + let ekem_secret = + crypto.mlkem_decapsulate(&local_ephemeral.mlkem.private, &message.ekem_ciphertext); + self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + + let skem_ciphertext = + decrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &message.skem_ciphertext)?; + let skem_secret = crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + self.remote_transport_params = Some(message.transport_params); + self.step = KkStep::Done; + Ok(()) + } + + pub fn finalize(self, crypto: &impl QlCrypto) -> Result { + if !self.is_finished() { + return Err(WireError::InvalidState); + } + let remote_transport_params = self + .remote_transport_params + .ok_or(WireError::InvalidState)?; + Ok(finalize_handshake( + crypto, + &self.symmetric, + self.role, + self.remote_bundle, + remote_transport_params, + )) + } +} diff --git a/ql-wire/src/handshake/meta.rs b/ql-wire/src/handshake/meta.rs new file mode 100644 index 00000000..8cb0cf97 --- /dev/null +++ b/ql-wire/src/handshake/meta.rs @@ -0,0 +1,48 @@ +use crate::{codec, ByteSlice, WireEncode, WireError}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct HandshakeId(pub u32); + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct HandshakeMeta { + pub handshake_id: HandshakeId, +} + +impl codec::WireDecode for HandshakeId { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} + +impl WireEncode for HandshakeId { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl HandshakeMeta { + pub const WIRE_SIZE: usize = size_of::(); +} + +impl WireEncode for HandshakeMeta { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.handshake_id.encode(out); + } +} + +impl codec::WireDecode for HandshakeMeta { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + handshake_id: reader.decode()?, + }) + } +} diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs new file mode 100644 index 00000000..a9b7cf87 --- /dev/null +++ b/ql-wire/src/handshake/mod.rs @@ -0,0 +1,590 @@ +use crate::{ + codec, ByteSlice, ConnectionId, HandshakeKind, MlKemCiphertext, MlKemKeyPair, MlKemPublicKey, + Nonce, PeerBundle, QlCrypto, SessionKey, WireDecode, WireEncode, WireError, + ENCRYPTED_MESSAGE_AUTH_SIZE, QID, +}; + +mod ik; +mod kk; +mod meta; +mod pairing; +mod transport_params; +mod xx; + +pub use ik::{Ik1, Ik2, IkHandshake}; +pub use kk::{Kk1, Kk2, KkHandshake}; +pub use meta::{HandshakeId, HandshakeMeta}; +pub use pairing::{PairingId, PairingToken}; +pub use transport_params::TransportParams; +pub use xx::{Xx1, Xx2, Xx3, Xx4, XxHandshake}; + +const SHA256_BLOCK_LEN: usize = 64; +const PROTOCOL_IK: &[u8] = b"ql-wire:pq-ik:v1"; +const PROTOCOL_KK: &[u8] = b"ql-wire:pq-kk:v1"; +const PROTOCOL_XX: &[u8] = b"ql-wire:pq-xx:v1"; +const CONNECTION_ID_DOMAIN: &[u8] = b"ql-wire:conn-id:v1"; +const HANDSHAKE_PREAMBLE_DOMAIN: &[u8] = b"ql-wire:handshake-preamble:v1"; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct HandshakeHeader { + pub sender: QID, + pub recipient: QID, +} + +impl HandshakeHeader { + pub const WIRE_SIZE: usize = QID::SIZE * 2; +} + +impl WireEncode for HandshakeHeader { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.sender.encode(out); + self.recipient.encode(out); + } +} + +impl codec::WireDecode for HandshakeHeader { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + sender: reader.decode()?, + recipient: reader.decode()?, + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EphemeralPublicKey { + pub mlkem_public_key: MlKemPublicKey, +} + +impl EphemeralPublicKey { + pub const WIRE_SIZE: usize = MlKemPublicKey::SIZE; +} + +impl WireEncode for EphemeralPublicKey { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.mlkem_public_key.encode(out); + } +} + +impl codec::WireDecode for EphemeralPublicKey { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + mlkem_public_key: reader.decode()?, + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EncryptedMlKemCiphertext(pub Box<[u8; Self::WIRE_SIZE]>); + +impl EncryptedMlKemCiphertext { + pub const WIRE_SIZE: usize = MlKemCiphertext::SIZE + ENCRYPTED_MESSAGE_AUTH_SIZE; + + pub fn new(data: Box<[u8; Self::WIRE_SIZE]>) -> Self { + Self(data) + } + + pub fn as_bytes(&self) -> &[u8; Self::WIRE_SIZE] { + self.0.as_ref() + } +} + +impl WireEncode for EncryptedMlKemCiphertext { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.0.as_ref().encode(out); + } +} + +impl codec::WireDecode for EncryptedMlKemCiphertext { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self::new(reader.decode()?)) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EncryptedPeerBundle(pub Box<[u8]>); + +impl EncryptedPeerBundle { + pub const MAX_WIRE_SIZE: usize = PeerBundle::MAX_WIRE_SIZE + ENCRYPTED_MESSAGE_AUTH_SIZE; + + pub fn as_bytes(&self) -> &[u8] { + self.0.as_ref() + } +} + +impl WireEncode for EncryptedPeerBundle { + fn encoded_len(&self) -> usize { + self.0.len() + } + + fn encode(&self, out: &mut W) { + self.as_bytes().encode(out); + } +} + +impl codec::WireDecode for EncryptedPeerBundle { + fn decode(reader: &mut codec::Reader) -> Result { + let data = reader.take_rest(); + if data.len() > Self::MAX_WIRE_SIZE { + return Err(WireError::InvalidPayload); + } + Ok(Self(data.to_vec().into_boxed_slice())) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FinalizedHandshake { + pub tx_key: SessionKey, + pub rx_key: SessionKey, + pub tx_connection_id: ConnectionId, + pub rx_connection_id: ConnectionId, + pub handshake_hash: [u8; 32], + pub remote_bundle: PeerBundle, + /// Transport parameters advertised by the remote peer + pub remote_transport_params: TransportParams, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Role { + Initiator, + Responder, +} + +#[derive(Debug, Clone)] +struct EphemeralKeyPair { + mlkem: MlKemKeyPair, +} + +impl EphemeralKeyPair { + fn public(&self) -> EphemeralPublicKey { + EphemeralPublicKey { + mlkem_public_key: self.mlkem.public.clone(), + } + } +} + +#[derive(Debug, Clone)] +struct CipherState { + key: Option, + nonce: u64, +} + +impl CipherState { + fn new() -> Self { + Self { + key: None, + nonce: 0, + } + } + + fn initialize_key(&mut self, key: SessionKey) { + self.key = Some(key); + self.nonce = 0; + } + + fn has_key(&self) -> bool { + self.key.is_some() + } + + fn encrypt( + &mut self, + crypto: &impl QlCrypto, + aad: &[u8], + plaintext: &[u8], + ) -> Result, WireError> { + let key = self.key.as_ref().ok_or(WireError::InvalidState)?; + let nonce = Nonce::from_counter(self.nonce); + let mut ciphertext = Vec::with_capacity(plaintext.len() + ENCRYPTED_MESSAGE_AUTH_SIZE); + ciphertext.extend_from_slice(plaintext); + let auth = crypto.aes256_gcm_encrypt(key, &nonce, aad, &mut ciphertext); + self.nonce = self.nonce.wrapping_add(1); + ciphertext.extend_from_slice(&auth); + Ok(ciphertext) + } + + fn decrypt( + &mut self, + crypto: &impl QlCrypto, + aad: &[u8], + ciphertext: &[u8], + ) -> Result, WireError> { + if ciphertext.len() < ENCRYPTED_MESSAGE_AUTH_SIZE { + return Err(WireError::InvalidPayload); + } + let split = ciphertext.len() - ENCRYPTED_MESSAGE_AUTH_SIZE; + let (ciphertext, auth) = ciphertext.split_at(split); + let mut plaintext = ciphertext.to_vec(); + let key = self.key.as_ref().ok_or(WireError::InvalidState)?; + let nonce = Nonce::from_counter(self.nonce); + let mut auth_tag = [0u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; + auth_tag.copy_from_slice(auth); + if !crypto.aes256_gcm_decrypt(key, &nonce, aad, &mut plaintext, &auth_tag) { + return Err(WireError::DecryptFailed); + } + self.nonce = self.nonce.wrapping_add(1); + Ok(plaintext) + } +} + +#[derive(Debug, Clone)] +struct SymmetricState { + chaining_key: [u8; 32], + handshake_hash: [u8; 32], + cipher: CipherState, +} + +impl SymmetricState { + fn new(crypto: &impl QlCrypto, protocol_name: &[u8]) -> Self { + let h = crypto.sha256(&[protocol_name]); + Self { + chaining_key: h, + handshake_hash: h, + cipher: CipherState::new(), + } + } + + fn mix_hash(&mut self, crypto: &impl QlCrypto, data: &[u8]) { + self.handshake_hash = crypto.sha256(&[&self.handshake_hash, data]); + } + + fn mix_key(&mut self, crypto: &impl QlCrypto, input_key_material: &[u8]) { + let (chaining_key, cipher_key) = hkdf2(crypto, &self.chaining_key, input_key_material); + self.chaining_key = chaining_key; + self.cipher.initialize_key(cipher_key); + } + + fn mix_key_and_hash(&mut self, crypto: &impl QlCrypto, input_key_material: &[u8]) { + let (chaining_key, hash_input, cipher_key) = + hkdf3(crypto, &self.chaining_key, input_key_material); + self.chaining_key = chaining_key; + self.mix_hash(crypto, &hash_input); + self.cipher.initialize_key(cipher_key); + } + + fn encrypt_and_hash( + &mut self, + crypto: &impl QlCrypto, + plaintext: &[u8], + ) -> Result, WireError> { + if self.cipher.has_key() { + let ciphertext = self + .cipher + .encrypt(crypto, &self.handshake_hash, plaintext)?; + self.mix_hash(crypto, &ciphertext); + Ok(ciphertext) + } else { + self.mix_hash(crypto, plaintext); + Ok(plaintext.to_vec()) + } + } + + fn decrypt_and_hash( + &mut self, + crypto: &impl QlCrypto, + ciphertext: &[u8], + ) -> Result, WireError> { + if self.cipher.has_key() { + let plaintext = self + .cipher + .decrypt(crypto, &self.handshake_hash, ciphertext)?; + self.mix_hash(crypto, ciphertext); + Ok(plaintext) + } else { + self.mix_hash(crypto, ciphertext); + Ok(ciphertext.to_vec()) + } + } + + fn split_for_role(&self, crypto: &impl QlCrypto, role: Role) -> (SessionKey, SessionKey) { + let temp_key = hmac_sha256(crypto, &self.chaining_key, &[&[]]); + let k1 = SessionKey::from_data(hmac_sha256(crypto, &temp_key, &[&[1]])); + let k2 = SessionKey::from_data(hmac_sha256(crypto, &temp_key, &[k1.as_bytes(), &[2]])); + match role { + Role::Initiator => (k1, k2), + Role::Responder => (k2, k1), + } + } +} + +fn init_kk_symmetric( + crypto: &impl QlCrypto, + initiator_bundle: &PeerBundle, + responder_bundle: &PeerBundle, +) -> SymmetricState { + let mut symmetric = SymmetricState::new(crypto, PROTOCOL_KK); + symmetric.mix_hash(crypto, &initiator_bundle.encode_vec()); + symmetric.mix_hash(crypto, &responder_bundle.encode_vec()); + symmetric +} + +fn init_ik_symmetric(crypto: &impl QlCrypto, responder_bundle: &PeerBundle) -> SymmetricState { + let mut symmetric = SymmetricState::new(crypto, PROTOCOL_IK); + symmetric.mix_hash(crypto, &responder_bundle.encode_vec()); + symmetric +} + +fn init_xx_symmetric(crypto: &impl QlCrypto) -> SymmetricState { + SymmetricState::new(crypto, PROTOCOL_XX) +} + +fn mix_psk_pairing_token( + symmetric: &mut SymmetricState, + crypto: &impl QlCrypto, + pairing_token: PairingToken, +) { + symmetric.mix_key_and_hash(crypto, &pairing_token.psk(crypto)); +} + +fn generate_ephemeral_keypair(crypto: &impl QlCrypto) -> EphemeralKeyPair { + EphemeralKeyPair { + mlkem: crypto.mlkem_generate_keypair(), + } +} + +fn mix_hash_ephemeral( + symmetric: &mut SymmetricState, + crypto: &impl QlCrypto, + public: &EphemeralPublicKey, +) { + symmetric.mix_hash(crypto, public.mlkem_public_key.as_bytes()); +} + +fn mix_hash_routed_handshake( + symmetric: &mut SymmetricState, + crypto: &impl QlCrypto, + header: HandshakeHeader, + kind: HandshakeKind, + meta: HandshakeMeta, + transport_params: TransportParams, +) { + mix_hash_handshake_preamble( + symmetric, + crypto, + &header.encode_vec(), + kind, + meta, + transport_params, + ); +} + +fn mix_hash_pairing_handshake( + symmetric: &mut SymmetricState, + crypto: &impl QlCrypto, + header: HandshakeHeader, + kind: HandshakeKind, + meta: HandshakeMeta, + pairing_id: PairingId, + transport_params: TransportParams, +) { + let mut preamble = header.encode_vec(); + pairing_id.encode(&mut preamble); + mix_hash_handshake_preamble(symmetric, crypto, &preamble, kind, meta, transport_params); +} + +fn mix_hash_handshake_preamble( + symmetric: &mut SymmetricState, + crypto: &impl QlCrypto, + header: &[u8], + kind: HandshakeKind, + meta: HandshakeMeta, + transport_params: TransportParams, +) { + symmetric.mix_hash(crypto, HANDSHAKE_PREAMBLE_DOMAIN); + symmetric.mix_hash(crypto, header); + symmetric.mix_hash(crypto, &[kind as u8]); + symmetric.mix_hash(crypto, &meta.encode_vec()); + symmetric.mix_hash(crypto, &transport_params.encode_vec()); +} + +fn initialize_handshake_meta( + expected: &mut Option, + meta: HandshakeMeta, +) -> Result<(), WireError> { + match expected { + Some(stored) if *stored != meta => Err(WireError::InvalidHandshakeMeta), + Some(_) => Ok(()), + None => { + *expected = Some(meta); + Ok(()) + } + } +} + +fn require_handshake_meta( + expected: Option<&HandshakeMeta>, + meta: HandshakeMeta, +) -> Result<(), WireError> { + match expected { + Some(stored) if *stored == meta => Ok(()), + _ => Err(WireError::InvalidHandshakeMeta), + } +} + +fn initialize_transport_params( + expected: &mut Option, + transport_params: TransportParams, +) -> Result<(), WireError> { + match expected { + Some(stored) if *stored != transport_params => Err(WireError::InvalidTransportParams), + Some(_) => Ok(()), + None => { + *expected = Some(transport_params); + Ok(()) + } + } +} + +fn require_transport_params( + expected: Option<&TransportParams>, + transport_params: TransportParams, +) -> Result<(), WireError> { + match expected { + Some(stored) if *stored == transport_params => Ok(()), + _ => Err(WireError::InvalidTransportParams), + } +} + +fn encrypt_peer_bundle( + crypto: &impl QlCrypto, + symmetric: &mut SymmetricState, + bundle: &PeerBundle, +) -> Result { + let ciphertext = symmetric.encrypt_and_hash(crypto, &bundle.encode_vec())?; + Ok(EncryptedPeerBundle(ciphertext.into_boxed_slice())) +} + +fn decrypt_peer_bundle( + crypto: &impl QlCrypto, + symmetric: &mut SymmetricState, + bundle: &EncryptedPeerBundle, +) -> Result { + let plaintext = symmetric.decrypt_and_hash(crypto, bundle.as_bytes())?; + let bundle = PeerBundle::decode_exact(plaintext.as_slice())?; + if !bundle.qid_matches_public_key(crypto) { + return Err(WireError::InvalidRemoteBundle); + } + Ok(bundle) +} + +fn encrypt_mlkem_ciphertext( + crypto: &impl QlCrypto, + symmetric: &mut SymmetricState, + ciphertext: &MlKemCiphertext, +) -> Result { + let encrypted = symmetric.encrypt_and_hash(crypto, ciphertext.as_bytes())?; + let out: Box<[u8; EncryptedMlKemCiphertext::WIRE_SIZE]> = + encrypted.try_into().map_err(|_| WireError::InvalidState)?; + Ok(EncryptedMlKemCiphertext::new(out)) +} + +fn decrypt_mlkem_ciphertext( + crypto: &impl QlCrypto, + symmetric: &mut SymmetricState, + ciphertext: &EncryptedMlKemCiphertext, +) -> Result { + let plaintext = symmetric.decrypt_and_hash(crypto, ciphertext.as_bytes())?; + let out: Box<[u8; MlKemCiphertext::SIZE]> = plaintext + .try_into() + .map_err(|_| WireError::InvalidPayload)?; + Ok(MlKemCiphertext::new(out)) +} + +fn finalize_handshake( + crypto: &impl QlCrypto, + symmetric: &SymmetricState, + role: Role, + remote_bundle: PeerBundle, + remote_transport_params: TransportParams, +) -> FinalizedHandshake { + let handshake_hash = symmetric.handshake_hash; + let (tx_key, rx_key) = symmetric.split_for_role(crypto, role); + let (initiator_rx, responder_rx) = derive_connection_ids(crypto, &handshake_hash); + let (tx_connection_id, rx_connection_id) = match role { + Role::Initiator => (responder_rx, initiator_rx), + Role::Responder => (initiator_rx, responder_rx), + }; + FinalizedHandshake { + tx_key, + rx_key, + tx_connection_id, + rx_connection_id, + handshake_hash, + remote_bundle, + remote_transport_params, + } +} + +fn derive_connection_ids( + crypto: &impl QlCrypto, + handshake_hash: &[u8; 32], +) -> (ConnectionId, ConnectionId) { + let initiator = crypto.sha256(&[CONNECTION_ID_DOMAIN, handshake_hash, b"initiator-rx"]); + let responder = crypto.sha256(&[CONNECTION_ID_DOMAIN, handshake_hash, b"responder-rx"]); + let mut initiator_rx = [0u8; ConnectionId::SIZE]; + let mut responder_rx = [0u8; ConnectionId::SIZE]; + initiator_rx.copy_from_slice(&initiator[..ConnectionId::SIZE]); + responder_rx.copy_from_slice(&responder[..ConnectionId::SIZE]); + ( + ConnectionId::from_data(initiator_rx), + ConnectionId::from_data(responder_rx), + ) +} + +fn hkdf2( + crypto: &impl QlCrypto, + chaining_key: &[u8; 32], + input_key_material: &[u8], +) -> ([u8; 32], SessionKey) { + let temp_key = hmac_sha256(crypto, chaining_key, &[input_key_material]); + let out1 = hmac_sha256(crypto, &temp_key, &[&[1]]); + let out2 = hmac_sha256(crypto, &temp_key, &[&out1, &[2]]); + (out1, SessionKey::from_data(out2)) +} + +fn hkdf3( + crypto: &impl QlCrypto, + chaining_key: &[u8; 32], + input_key_material: &[u8], +) -> ([u8; 32], [u8; 32], SessionKey) { + let temp_key = hmac_sha256(crypto, chaining_key, &[input_key_material]); + let out1 = hmac_sha256(crypto, &temp_key, &[&[1]]); + let out2 = hmac_sha256(crypto, &temp_key, &[&out1, &[2]]); + let out3 = hmac_sha256(crypto, &temp_key, &[&out2, &[3]]); + (out1, out2, SessionKey::from_data(out3)) +} + +fn hmac_sha256(crypto: &impl QlCrypto, key: &[u8], parts: &[&[u8]]) -> [u8; 32] { + let mut key_block = [0u8; SHA256_BLOCK_LEN]; + if key.len() > SHA256_BLOCK_LEN { + key_block[..32].copy_from_slice(&crypto.sha256(&[key])); + } else { + key_block[..key.len()].copy_from_slice(key); + } + + let mut ipad = [0x36u8; SHA256_BLOCK_LEN]; + let mut opad = [0x5cu8; SHA256_BLOCK_LEN]; + for (dst, src) in ipad.iter_mut().zip(key_block.iter()) { + *dst ^= *src; + } + for (dst, src) in opad.iter_mut().zip(key_block.iter()) { + *dst ^= *src; + } + + let mut inner_parts: Vec<&[u8]> = Vec::with_capacity(parts.len() + 1); + inner_parts.push(&ipad); + inner_parts.extend_from_slice(parts); + let inner = crypto.sha256(&inner_parts); + crypto.sha256(&[&opad, &inner]) +} diff --git a/ql-wire/src/handshake/pairing.rs b/ql-wire/src/handshake/pairing.rs new file mode 100644 index 00000000..237f066b --- /dev/null +++ b/ql-wire/src/handshake/pairing.rs @@ -0,0 +1,83 @@ +use std::fmt::{self, Display, Formatter}; + +use crate::{codec, ByteSlice, QlCrypto, WireEncode, WireError}; + +const PAIRING_ID_DOMAIN: &[u8] = b"ql-wire:pairing-id:v1"; +const PAIRING_PSK_DOMAIN: &[u8] = b"ql-wire:pairing-psk:v1"; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct PairingToken(pub [u8; Self::SIZE]); + +impl PairingToken { + pub const SIZE: usize = 16; + + pub fn id(&self, crypto: &impl QlCrypto) -> PairingId { + let hash = crypto.sha256(&[PAIRING_ID_DOMAIN, &self.0]); + let mut id = [0u8; PairingId::SIZE]; + id.copy_from_slice(&hash[..PairingId::SIZE]); + PairingId(id) + } + + pub(super) fn psk(&self, crypto: &impl QlCrypto) -> [u8; 32] { + crypto.sha256(&[PAIRING_PSK_DOMAIN, &self.0]) + } +} + +impl Display for PairingToken { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + for byte in self.0 { + write!(f, "{byte:02x}")?; + } + Ok(()) + } +} + +impl WireEncode for PairingToken { + fn encoded_len(&self) -> usize { + Self::SIZE + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl codec::WireDecode for PairingToken { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct PairingId(pub [u8; Self::SIZE]); + +impl PairingId { + pub const SIZE: usize = 16; +} + +impl Display for PairingId { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + for byte in self.0 { + write!(f, "{byte:02x}")?; + } + Ok(()) + } +} + +impl WireEncode for PairingId { + fn encoded_len(&self) -> usize { + Self::SIZE + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl codec::WireDecode for PairingId { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} diff --git a/ql-wire/src/handshake/transport_params.rs b/ql-wire/src/handshake/transport_params.rs new file mode 100644 index 00000000..bfd0d427 --- /dev/null +++ b/ql-wire/src/handshake/transport_params.rs @@ -0,0 +1,38 @@ +use crate::{codec, ByteSlice, WireEncode, WireError}; + +/// Session parameters advertised in the handshake +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct TransportParams { + /// Initial per-stream receive credit granted to the remote peer + pub initial_stream_receive_window: u32, +} + +impl TransportParams { + pub const WIRE_SIZE: usize = size_of::(); +} + +impl WireEncode for TransportParams { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.initial_stream_receive_window.encode(out); + } +} + +impl Default for TransportParams { + fn default() -> Self { + Self { + initial_stream_receive_window: 16 * 1024, + } + } +} + +impl codec::WireDecode for TransportParams { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + initial_stream_receive_window: reader.decode()?, + }) + } +} diff --git a/ql-wire/src/handshake/xx.rs b/ql-wire/src/handshake/xx.rs new file mode 100644 index 00000000..0b6452d4 --- /dev/null +++ b/ql-wire/src/handshake/xx.rs @@ -0,0 +1,612 @@ +use super::{ + decrypt_mlkem_ciphertext, decrypt_peer_bundle, encrypt_mlkem_ciphertext, encrypt_peer_bundle, + finalize_handshake, generate_ephemeral_keypair, init_xx_symmetric, initialize_handshake_meta, + initialize_transport_params, mix_hash_ephemeral, mix_hash_pairing_handshake, + mix_psk_pairing_token, require_handshake_meta, require_transport_params, + EncryptedMlKemCiphertext, EncryptedPeerBundle, EphemeralKeyPair, EphemeralPublicKey, + FinalizedHandshake, HandshakeHeader, Role, SymmetricState, TransportParams, +}; +use crate::{ + codec, ByteSlice, HandshakeKind, HandshakeMeta, MlKemCiphertext, PairingId, PairingToken, + PeerBundle, QlCrypto, QlIdentity, WireEncode, WireError, QID, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Xx1 { + pub header: HandshakeHeader, + pub meta: HandshakeMeta, + pub pairing_id: PairingId, + pub transport_params: TransportParams, + pub ephemeral: EphemeralPublicKey, +} + +impl Xx1 { + pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + PairingId::SIZE + + TransportParams::WIRE_SIZE + + EphemeralPublicKey::WIRE_SIZE; +} + +impl codec::WireDecode for Xx1 { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + meta: reader.decode()?, + pairing_id: reader.decode()?, + transport_params: reader.decode()?, + ephemeral: reader.decode()?, + }) + } +} + +impl WireEncode for Xx1 { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.pairing_id.encode(out); + self.transport_params.encode(out); + self.ephemeral.encode(out); + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Xx2 { + pub header: HandshakeHeader, + pub meta: HandshakeMeta, + pub pairing_id: PairingId, + pub transport_params: TransportParams, + pub ekem_ciphertext: MlKemCiphertext, + pub static_bundle: EncryptedPeerBundle, +} + +impl codec::WireDecode for Xx2 { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + meta: reader.decode()?, + pairing_id: reader.decode()?, + transport_params: reader.decode()?, + ekem_ciphertext: reader.decode()?, + static_bundle: reader.decode()?, + }) + } +} + +impl WireEncode for Xx2 { + fn encoded_len(&self) -> usize { + HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + PairingId::SIZE + + TransportParams::WIRE_SIZE + + MlKemCiphertext::SIZE + + self.static_bundle.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.pairing_id.encode(out); + self.transport_params.encode(out); + self.ekem_ciphertext.encode(out); + self.static_bundle.encode(out); + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Xx3 { + pub header: HandshakeHeader, + pub meta: HandshakeMeta, + pub pairing_id: PairingId, + pub transport_params: TransportParams, + pub skem_ciphertext: EncryptedMlKemCiphertext, + pub static_bundle: EncryptedPeerBundle, +} + +impl codec::WireDecode for Xx3 { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + meta: reader.decode()?, + pairing_id: reader.decode()?, + transport_params: reader.decode()?, + skem_ciphertext: reader.decode()?, + static_bundle: reader.decode()?, + }) + } +} + +impl WireEncode for Xx3 { + fn encoded_len(&self) -> usize { + HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + PairingId::SIZE + + TransportParams::WIRE_SIZE + + EncryptedMlKemCiphertext::WIRE_SIZE + + self.static_bundle.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.pairing_id.encode(out); + self.transport_params.encode(out); + self.skem_ciphertext.encode(out); + self.static_bundle.encode(out); + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Xx4 { + pub header: HandshakeHeader, + pub meta: HandshakeMeta, + pub pairing_id: PairingId, + pub transport_params: TransportParams, + pub skem_ciphertext: EncryptedMlKemCiphertext, +} + +impl Xx4 { + pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + PairingId::SIZE + + TransportParams::WIRE_SIZE + + EncryptedMlKemCiphertext::WIRE_SIZE; +} + +impl codec::WireDecode for Xx4 { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + meta: reader.decode()?, + pairing_id: reader.decode()?, + transport_params: reader.decode()?, + skem_ciphertext: reader.decode()?, + }) + } +} + +impl WireEncode for Xx4 { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.pairing_id.encode(out); + self.transport_params.encode(out); + self.skem_ciphertext.encode(out); + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum XxStep { + Send1, + Recv1, + Send2, + Recv2, + Send3, + Recv3, + Send4, + Recv4, + Done, +} + +#[derive(Debug, Clone)] +pub struct XxHandshake { + role: Role, + step: XxStep, + symmetric: SymmetricState, + local: QlIdentity, + remote_qid: QID, + pairing_token: PairingToken, + remote_bundle: Option, + local_ephemeral: Option, + remote_ephemeral: Option, + handshake_meta: Option, + local_transport_params: TransportParams, + remote_transport_params: Option, +} + +impl XxHandshake { + pub fn new_initiator( + crypto: &impl QlCrypto, + local: QlIdentity, + remote_qid: QID, + pairing_token: PairingToken, + local_transport_params: TransportParams, + ) -> Self { + Self { + role: Role::Initiator, + step: XxStep::Send1, + symmetric: init_xx_symmetric(crypto), + local, + remote_qid, + pairing_token, + remote_bundle: None, + local_ephemeral: None, + remote_ephemeral: None, + handshake_meta: None, + local_transport_params, + remote_transport_params: None, + } + } + + pub fn new_responder( + crypto: &impl QlCrypto, + local: QlIdentity, + remote_qid: QID, + pairing_token: PairingToken, + local_transport_params: TransportParams, + ) -> Self { + Self { + role: Role::Responder, + step: XxStep::Recv1, + symmetric: init_xx_symmetric(crypto), + local, + remote_qid, + pairing_token, + remote_bundle: None, + local_ephemeral: None, + remote_ephemeral: None, + handshake_meta: None, + local_transport_params, + remote_transport_params: None, + } + } + + pub fn is_finished(&self) -> bool { + self.step == XxStep::Done + } + + pub fn pairing_token(&self) -> PairingToken { + self.pairing_token + } + + pub fn pairing_id(&self, crypto: &impl QlCrypto) -> PairingId { + self.pairing_token.id(crypto) + } + + pub fn remote_qid(&self) -> QID { + self.remote_qid + } + + pub fn remote_bundle(&self) -> Option<&PeerBundle> { + self.remote_bundle.as_ref() + } + + fn header(&self) -> HandshakeHeader { + HandshakeHeader { + sender: self.local.qid, + recipient: self.remote_qid, + } + } + + fn ensure_inbound_header( + &self, + crypto: &impl QlCrypto, + header: HandshakeHeader, + pairing_id: PairingId, + ) -> Result<(), WireError> { + if header.sender != self.remote_qid || header.recipient != self.local.qid { + return Err(WireError::InvalidHandshakeHeader); + } + if pairing_id != self.pairing_token.id(crypto) { + return Err(WireError::InvalidPairingId); + } + Ok(()) + } + + fn ensure_remote_bundle(&self, bundle: &PeerBundle) -> Result<(), WireError> { + if bundle.qid == self.remote_qid { + Ok(()) + } else { + Err(WireError::InvalidRemoteBundle) + } + } + + pub fn write_1( + &mut self, + crypto: &impl QlCrypto, + meta: HandshakeMeta, + ) -> Result { + if self.step != XxStep::Send1 { + return Err(WireError::InvalidState); + } + initialize_handshake_meta(&mut self.handshake_meta, meta)?; + let header = self.header(); + let pairing_id = self.pairing_token.id(crypto); + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Xx1, + meta, + pairing_id, + self.local_transport_params, + ); + mix_psk_pairing_token(&mut self.symmetric, crypto, self.pairing_token); + + let local_ephemeral = generate_ephemeral_keypair(crypto); + let ephemeral = local_ephemeral.public(); + mix_hash_ephemeral(&mut self.symmetric, crypto, &ephemeral); + + self.local_ephemeral = Some(local_ephemeral); + self.step = XxStep::Recv2; + Ok(Xx1 { + header, + meta, + pairing_id, + transport_params: self.local_transport_params, + ephemeral, + }) + } + + pub fn read_1(&mut self, crypto: &impl QlCrypto, message: &Xx1) -> Result<(), WireError> { + if self.step != XxStep::Recv1 { + return Err(WireError::InvalidState); + } + initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; + self.ensure_inbound_header(crypto, message.header, message.pairing_id)?; + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Xx1, + message.meta, + message.pairing_id, + message.transport_params, + ); + mix_psk_pairing_token(&mut self.symmetric, crypto, self.pairing_token); + mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); + + self.remote_ephemeral = Some(message.ephemeral.clone()); + initialize_transport_params(&mut self.remote_transport_params, message.transport_params)?; + self.step = XxStep::Send2; + Ok(()) + } + + pub fn write_2( + &mut self, + crypto: &impl QlCrypto, + meta: HandshakeMeta, + ) -> Result { + if self.step != XxStep::Send2 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), meta)?; + let header = self.header(); + let pairing_id = self.pairing_token.id(crypto); + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Xx2, + meta, + pairing_id, + self.local_transport_params, + ); + + let remote_ephemeral = self + .remote_ephemeral + .as_ref() + .ok_or(WireError::InvalidState)?; + let (ekem_ciphertext, ekem_secret) = + crypto.mlkem_encapsulate(&remote_ephemeral.mlkem_public_key); + self.symmetric.mix_hash(crypto, ekem_ciphertext.as_bytes()); + self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + + let static_bundle = encrypt_peer_bundle(crypto, &mut self.symmetric, &self.local.bundle())?; + + self.step = XxStep::Recv3; + Ok(Xx2 { + header, + meta, + pairing_id, + transport_params: self.local_transport_params, + ekem_ciphertext, + static_bundle, + }) + } + + pub fn read_2(&mut self, crypto: &impl QlCrypto, message: &Xx2) -> Result<(), WireError> { + if self.step != XxStep::Recv2 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; + self.ensure_inbound_header(crypto, message.header, message.pairing_id)?; + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Xx2, + message.meta, + message.pairing_id, + message.transport_params, + ); + + let local_ephemeral = self + .local_ephemeral + .as_ref() + .ok_or(WireError::InvalidState)?; + self.symmetric + .mix_hash(crypto, message.ekem_ciphertext.as_bytes()); + let ekem_secret = + crypto.mlkem_decapsulate(&local_ephemeral.mlkem.private, &message.ekem_ciphertext); + self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + + let remote_bundle = + decrypt_peer_bundle(crypto, &mut self.symmetric, &message.static_bundle)?; + self.ensure_remote_bundle(&remote_bundle)?; + self.remote_bundle = Some(remote_bundle); + initialize_transport_params(&mut self.remote_transport_params, message.transport_params)?; + self.step = XxStep::Send3; + Ok(()) + } + + pub fn write_3( + &mut self, + crypto: &impl QlCrypto, + meta: HandshakeMeta, + ) -> Result { + if self.step != XxStep::Send3 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), meta)?; + let header = self.header(); + let pairing_id = self.pairing_token.id(crypto); + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Xx3, + meta, + pairing_id, + self.local_transport_params, + ); + + let remote_bundle = self.remote_bundle.as_ref().ok_or(WireError::InvalidState)?; + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); + let skem_ciphertext = + encrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &skem_ciphertext)?; + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + let static_bundle = encrypt_peer_bundle(crypto, &mut self.symmetric, &self.local.bundle())?; + + self.step = XxStep::Recv4; + Ok(Xx3 { + header, + meta, + pairing_id, + transport_params: self.local_transport_params, + skem_ciphertext, + static_bundle, + }) + } + + pub fn read_3(&mut self, crypto: &impl QlCrypto, message: &Xx3) -> Result<(), WireError> { + if self.step != XxStep::Recv3 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; + self.ensure_inbound_header(crypto, message.header, message.pairing_id)?; + require_transport_params( + self.remote_transport_params.as_ref(), + message.transport_params, + )?; + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Xx3, + message.meta, + message.pairing_id, + message.transport_params, + ); + + let skem_ciphertext = + decrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &message.skem_ciphertext)?; + let skem_secret = crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + let remote_bundle = + decrypt_peer_bundle(crypto, &mut self.symmetric, &message.static_bundle)?; + self.ensure_remote_bundle(&remote_bundle)?; + self.remote_bundle = Some(remote_bundle); + self.step = XxStep::Send4; + Ok(()) + } + + pub fn write_4( + &mut self, + crypto: &impl QlCrypto, + meta: HandshakeMeta, + ) -> Result { + if self.step != XxStep::Send4 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), meta)?; + let header = self.header(); + let pairing_id = self.pairing_token.id(crypto); + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Xx4, + meta, + pairing_id, + self.local_transport_params, + ); + + let remote_bundle = self.remote_bundle.as_ref().ok_or(WireError::InvalidState)?; + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); + let skem_ciphertext = + encrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &skem_ciphertext)?; + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + self.step = XxStep::Done; + Ok(Xx4 { + header, + meta, + pairing_id, + transport_params: self.local_transport_params, + skem_ciphertext, + }) + } + + pub fn read_4(&mut self, crypto: &impl QlCrypto, message: &Xx4) -> Result<(), WireError> { + if self.step != XxStep::Recv4 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; + self.ensure_inbound_header(crypto, message.header, message.pairing_id)?; + require_transport_params( + self.remote_transport_params.as_ref(), + message.transport_params, + )?; + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Xx4, + message.meta, + message.pairing_id, + message.transport_params, + ); + + let skem_ciphertext = + decrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &message.skem_ciphertext)?; + let skem_secret = crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + self.step = XxStep::Done; + Ok(()) + } + + pub fn finalize(self, crypto: &impl QlCrypto) -> Result { + if !self.is_finished() { + return Err(WireError::InvalidState); + } + let remote_bundle = self.remote_bundle.ok_or(WireError::InvalidState)?; + let remote_transport_params = self + .remote_transport_params + .ok_or(WireError::InvalidState)?; + Ok(finalize_handshake( + crypto, + &self.symmetric, + self.role, + remote_bundle, + remote_transport_params, + )) + } +} diff --git a/ql-wire/src/header.rs b/ql-wire/src/header.rs new file mode 100644 index 00000000..88764c0a --- /dev/null +++ b/ql-wire/src/header.rs @@ -0,0 +1,121 @@ +use ::bytes::BufMut; + +use crate::{ + codec, ByteSlice, VarInt, VarIntBoundsExceeded, WireEncode, WireError, QL_WIRE_VERSION, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct SessionHeader { + pub connection_id: ConnectionId, + pub seq: RecordSeq, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct RecordSeq(pub VarInt); + +impl RecordSeq { + pub const MAX_ENCODED_LEN: usize = VarInt::MAX_SIZE; + + pub const fn from_u32(value: u32) -> Self { + Self(VarInt::from_u32(value)) + } + + pub fn from_u64(value: u64) -> Result { + Ok(Self(VarInt::from_u64(value)?)) + } + + pub const fn into_inner(self) -> u64 { + self.0.into_inner() + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct ConnectionId(pub [u8; Self::SIZE]); + +impl ConnectionId { + pub const SIZE: usize = 16; + + pub const fn from_data(data: [u8; Self::SIZE]) -> Self { + Self(data) + } + + pub const fn as_bytes(&self) -> &[u8; Self::SIZE] { + &self.0 + } +} + +impl codec::WireDecode for RecordSeq { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} + +impl WireEncode for RecordSeq { + fn encoded_len(&self) -> usize { + self.0.size() + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl codec::WireDecode for ConnectionId { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self::from_data(reader.decode()?)) + } +} + +impl WireEncode for ConnectionId { + fn encoded_len(&self) -> usize { + Self::SIZE + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl SessionHeader { + pub const MAX_ENCODED_LEN: usize = ConnectionId::SIZE + RecordSeq::MAX_ENCODED_LEN; + const AAD_DOMAIN: &[u8] = b"ql-wire:session-aad:v1"; + const AAD_RECORD_KIND_SESSION: u8 = 1; + + pub fn aad(&self) -> Vec { + let aad_len = Self::AAD_DOMAIN.len() + + size_of::() + + size_of::() + + ConnectionId::SIZE + + self.seq.encoded_len(); + let mut aad = Vec::with_capacity(aad_len); + aad.put_slice(Self::AAD_DOMAIN); + aad.put_u8(QL_WIRE_VERSION); + aad.put_u8(Self::AAD_RECORD_KIND_SESSION); + self.connection_id.encode(&mut aad); + self.seq.encode(&mut aad); + debug_assert_eq!(aad.len(), aad_len); + aad + } +} + +impl WireEncode for SessionHeader { + fn encoded_len(&self) -> usize { + ConnectionId::SIZE + self.seq.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.connection_id.encode(out); + self.seq.encode(out); + } +} + +impl codec::WireDecode for SessionHeader { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + connection_id: reader.decode()?, + seq: reader.decode()?, + }) + } +} diff --git a/ql-wire/src/identity.rs b/ql-wire/src/identity.rs new file mode 100644 index 00000000..1f8dbee7 --- /dev/null +++ b/ql-wire/src/identity.rs @@ -0,0 +1,192 @@ +use std::ops::Deref; + +use crate::{ + codec, ByteSlice, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, QlCrypto, QlHash, VarInt, + WireEncode, WireError, QID, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PeerBundle { + pub version: u16, + pub qid: QID, + pub capabilities: u32, + pub mlkem_public_key: MlKemPublicKey, + pub name: QlName, +} + +impl PeerBundle { + pub const VERSION: u16 = 1; + pub const FIXED_WIRE_SIZE: usize = + size_of::() + QID::SIZE + size_of::() + MlKemPublicKey::SIZE; + pub const MAX_WIRE_SIZE: usize = Self::FIXED_WIRE_SIZE + VarInt::MAX_SIZE + QlName::MAX_LEN; + + pub fn qid_matches_public_key(&self, crypto: &impl QlHash) -> bool { + self.qid.matches_public_key(crypto, &self.mlkem_public_key) + } +} + +impl WireEncode for PeerBundle { + fn encoded_len(&self) -> usize { + Self::FIXED_WIRE_SIZE + self.name.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.version.encode(out); + self.qid.encode(out); + self.capabilities.encode(out); + self.mlkem_public_key.encode(out); + self.name.encode(out); + } +} + +impl codec::WireDecode for PeerBundle { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + version: reader.decode()?, + qid: reader.decode()?, + capabilities: reader.decode()?, + mlkem_public_key: reader.decode()?, + name: reader.decode()?, + }) + } +} + +#[derive(Debug, Clone)] +pub struct QlIdentity { + pub qid: QID, + pub mlkem_private_key: MlKemPrivateKey, + pub mlkem_public_key: MlKemPublicKey, + pub capabilities: u32, + pub name: QlName, +} + +impl QlIdentity { + pub const FIXED_WIRE_SIZE: usize = + QID::SIZE + MlKemPrivateKey::SIZE + MlKemPublicKey::SIZE + size_of::(); + pub const MAX_WIRE_SIZE: usize = Self::FIXED_WIRE_SIZE + VarInt::MAX_SIZE + QlName::MAX_LEN; + + pub fn new( + crypto: &impl QlHash, + mlkem_private_key: MlKemPrivateKey, + mlkem_public_key: MlKemPublicKey, + name: impl Into, + ) -> Result { + let name = QlName::new(name)?; + let qid = QID::derive(crypto, &mlkem_public_key); + Ok(Self { + qid, + mlkem_private_key, + mlkem_public_key, + capabilities: 0, + name, + }) + } + + #[must_use] + pub fn with_capabilities(mut self, capabilities: u32) -> Self { + self.capabilities = capabilities; + self + } + + pub fn with_name(mut self, name: impl Into) -> Result { + self.name = QlName::new(name)?; + Ok(self) + } + + pub fn bundle(&self) -> PeerBundle { + PeerBundle { + version: PeerBundle::VERSION, + qid: self.qid, + capabilities: self.capabilities, + mlkem_public_key: self.mlkem_public_key.clone(), + name: self.name.clone(), + } + } +} + +impl WireEncode for QlIdentity { + fn encoded_len(&self) -> usize { + Self::FIXED_WIRE_SIZE + self.name.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.qid.encode(out); + self.mlkem_private_key.as_bytes().encode(out); + self.mlkem_public_key.encode(out); + self.capabilities.encode(out); + self.name.encode(out); + } +} + +impl codec::WireDecode for QlIdentity { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + qid: reader.decode()?, + mlkem_private_key: MlKemPrivateKey::new(reader.decode()?), + mlkem_public_key: reader.decode()?, + capabilities: reader.decode()?, + name: reader.decode()?, + }) + } +} + +pub fn generate_identity( + crypto: &impl QlCrypto, + name: impl Into, +) -> Result { + let MlKemKeyPair { + private: mlkem_private_key, + public: mlkem_public_key, + } = crypto.mlkem_generate_keypair(); + QlIdentity::new(crypto, mlkem_private_key, mlkem_public_key, name) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct QlName(String); + +impl QlName { + pub const MAX_LEN: usize = 256; + + pub fn new(name: impl Into) -> Result { + let name = name.into(); + if name.is_empty() || name.len() > Self::MAX_LEN { + return Err(WireError::InvalidPayload); + } + Ok(Self(name)) + } +} + +impl Deref for QlName { + type Target = str; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl WireEncode for QlName { + fn encoded_len(&self) -> usize { + let len = VarInt::try_from(self.0.len()).unwrap(); + len.encoded_len() + self.0.len() + } + + fn encode(&self, out: &mut W) { + VarInt::try_from(self.0.len()) + .expect("identity name length fits in varint") + .encode(out); + self.0.as_bytes().encode(out); + } +} + +impl codec::WireDecode for QlName { + fn decode(reader: &mut codec::Reader) -> Result { + let len = usize::try_from(reader.decode::()?.into_inner()) + .map_err(|_| WireError::InvalidPayload)?; + if len == 0 || len > Self::MAX_LEN { + return Err(WireError::InvalidPayload); + } + let bytes = reader.take_bytes(len)?; + let name = std::str::from_utf8(&bytes).map_err(|_| WireError::InvalidPayload)?; + Ok(QlName::new(name)?) + } +} diff --git a/ql-wire/src/lib.rs b/ql-wire/src/lib.rs new file mode 100644 index 00000000..2ad8a171 --- /dev/null +++ b/ql-wire/src/lib.rs @@ -0,0 +1,45 @@ +//! +//! quantum link protocol wire format +//! + +#![allow(clippy::too_many_arguments)] + +mod bytes; +mod codec; +mod crypto; +mod encrypted; +mod encrypted_message; +mod error; +mod handshake; +mod header; +mod identity; +mod nonce; +mod pq; +mod qid; +mod record; +#[cfg(any(feature = "test-utils", test))] +mod testing; +mod varint; + +pub use bytes::*; +pub use codec::*; +pub use crypto::*; +pub use encrypted::*; +pub use encrypted_message::*; +pub use error::*; +pub use handshake::*; +pub use header::*; +pub use identity::*; +pub use nonce::*; +pub use pq::*; +pub use qid::*; +pub use record::*; +#[cfg(any(feature = "test-utils", test))] +pub use testing::*; +pub use varint::*; + +pub const QL_WIRE_VERSION: u8 = 1; +pub const ENCRYPTED_MESSAGE_AUTH_SIZE: usize = 16; + +#[cfg(test)] +mod tests; diff --git a/ql-wire/src/nonce.rs b/ql-wire/src/nonce.rs new file mode 100644 index 00000000..c7e6d793 --- /dev/null +++ b/ql-wire/src/nonce.rs @@ -0,0 +1,13 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct Nonce(pub [u8; Self::SIZE]); + +impl Nonce { + pub const SIZE: usize = 12; + + pub fn from_counter(counter: u64) -> Self { + let mut nonce = [0u8; Self::SIZE]; + nonce[4..].copy_from_slice(&counter.to_le_bytes()); + Self(nonce) + } +} diff --git a/ql-wire/src/pq.rs b/ql-wire/src/pq.rs new file mode 100644 index 00000000..327ef7c4 --- /dev/null +++ b/ql-wire/src/pq.rs @@ -0,0 +1,159 @@ +use crate::{codec, ByteSlice, WireEncode, WireError}; + +pub const ML_KEM_SUITE_TAG: &[u8] = b"ml-kem-1024"; + +// ql-wire fixes the protocol to ML-KEM-1024 on the wire, but the host +// platform is free to satisfy QlKem with any backend that produces the same +// serialized sizes. +const ML_KEM_1024_SHARED_SECRET_SIZE: usize = 32; +const ML_KEM_1024_PUBLIC_KEY_SIZE: usize = 1568; +const ML_KEM_1024_PRIVATE_KEY_SIZE: usize = 3168; +const ML_KEM_1024_CIPHERTEXT_SIZE: usize = 1568; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct SessionKey([u8; Self::SIZE]); + +impl SessionKey { + pub const SIZE: usize = ML_KEM_1024_SHARED_SECRET_SIZE; + + pub const fn from_data(data: [u8; Self::SIZE]) -> Self { + Self(data) + } + + pub const fn data(&self) -> &[u8; Self::SIZE] { + &self.0 + } + + pub const fn as_bytes(&self) -> &[u8; Self::SIZE] { + &self.0 + } +} + +impl AsRef<[u8]> for SessionKey { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl Drop for SessionKey { + fn drop(&mut self) { + self.0.fill(0); + } +} + +impl WireEncode for SessionKey { + fn encoded_len(&self) -> usize { + Self::SIZE + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl codec::WireDecode for SessionKey { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self::from_data(reader.decode()?)) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct MlKemPublicKey(Box<[u8; MlKemPublicKey::SIZE]>); + +impl MlKemPublicKey { + pub const SIZE: usize = ML_KEM_1024_PUBLIC_KEY_SIZE; + + pub fn new(data: Box<[u8; Self::SIZE]>) -> Self { + Self(data) + } + + pub fn as_bytes(&self) -> &[u8; Self::SIZE] { + self.0.as_ref() + } +} + +impl Drop for MlKemPublicKey { + fn drop(&mut self) { + self.0.as_mut().fill(0); + } +} + +impl codec::WireDecode for MlKemPublicKey { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self::new(reader.decode()?)) + } +} + +impl WireEncode for MlKemPublicKey { + fn encoded_len(&self) -> usize { + Self::SIZE + } + + fn encode(&self, out: &mut W) { + self.0.as_ref().encode(out); + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct MlKemPrivateKey(Box<[u8; MlKemPrivateKey::SIZE]>); + +impl MlKemPrivateKey { + pub const SIZE: usize = ML_KEM_1024_PRIVATE_KEY_SIZE; + + pub fn new(data: Box<[u8; Self::SIZE]>) -> Self { + Self(data) + } + + pub fn as_bytes(&self) -> &[u8; Self::SIZE] { + self.0.as_ref() + } +} + +impl Drop for MlKemPrivateKey { + fn drop(&mut self) { + self.0.as_mut().fill(0); + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct MlKemCiphertext(Box<[u8; MlKemCiphertext::SIZE]>); + +impl MlKemCiphertext { + pub const SIZE: usize = ML_KEM_1024_CIPHERTEXT_SIZE; + + pub fn new(data: Box<[u8; Self::SIZE]>) -> Self { + Self(data) + } + + pub fn as_bytes(&self) -> &[u8; Self::SIZE] { + self.0.as_ref() + } +} + +impl Drop for MlKemCiphertext { + fn drop(&mut self) { + self.0.as_mut().fill(0); + } +} + +impl codec::WireDecode for MlKemCiphertext { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self::new(reader.decode()?)) + } +} + +impl WireEncode for MlKemCiphertext { + fn encoded_len(&self) -> usize { + Self::SIZE + } + + fn encode(&self, out: &mut W) { + self.0.as_ref().encode(out); + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct MlKemKeyPair { + pub private: MlKemPrivateKey, + pub public: MlKemPublicKey, +} diff --git a/ql-wire/src/qid.rs b/ql-wire/src/qid.rs new file mode 100644 index 00000000..55c6684f --- /dev/null +++ b/ql-wire/src/qid.rs @@ -0,0 +1,44 @@ +use crate::{codec, ByteSlice, MlKemPublicKey, QlHash, WireEncode, WireError, ML_KEM_SUITE_TAG}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct QID(pub [u8; Self::SIZE]); + +impl QID { + pub const SIZE: usize = 16; + + pub fn derive(crypto: &impl QlHash, mlkem_public_key: &MlKemPublicKey) -> Self { + let digest = crypto.sha256(&[ + b"quantum-link qid v1", + ML_KEM_SUITE_TAG, + mlkem_public_key.as_bytes(), + ]); + let mut qid = [0u8; Self::SIZE]; + qid.copy_from_slice(&digest[..Self::SIZE]); + Self(qid) + } + + pub fn matches_public_key( + &self, + crypto: &impl QlHash, + mlkem_public_key: &MlKemPublicKey, + ) -> bool { + *self == Self::derive(crypto, mlkem_public_key) + } +} + +impl WireEncode for QID { + fn encoded_len(&self) -> usize { + Self::SIZE + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl codec::WireDecode for QID { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs new file mode 100644 index 00000000..163a1bff --- /dev/null +++ b/ql-wire/src/record.rs @@ -0,0 +1,254 @@ +use crate::{ + codec, + encrypted_message::EncryptedMessage, + handshake::{Ik1, Ik2, Kk1, Kk2, Xx1, Xx2, Xx3, Xx4}, + ByteSlice, SessionHeader, WireDecode, WireEncode, WireError, QL_WIRE_VERSION, +}; + +pub fn encode_record(out: &mut W, record_type: RecordType, body: &T) +where + W: bytes::BufMut + ?Sized, + T: WireEncode + ?Sized, +{ + RecordHeader { + version: QL_WIRE_VERSION, + record_type, + } + .encode(out); + body.encode(out); +} + +pub fn encode_record_vec(record_type: RecordType, body: &T) -> Vec { + let mut out = Vec::with_capacity(RecordHeader::WIRE_SIZE + body.encoded_len()); + encode_record(&mut out, record_type, body); + out +} + +pub fn decode_record(bytes: B) -> Result<(RecordHeader, T), WireError> +where + T: WireDecode, + B: ByteSlice, +{ + let mut reader = codec::Reader::new(bytes); + Ok((reader.decode()?, reader.decode()?)) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RecordHeader { + pub version: u8, + pub record_type: RecordType, +} + +impl RecordHeader { + pub const WIRE_SIZE: usize = size_of::() + size_of::(); +} + +impl WireDecode for RecordHeader { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + version: reader.decode()?, + record_type: reader.decode()?, + }) + } +} + +impl WireEncode for RecordHeader { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + out.put_u8(self.version); + self.record_type.encode(out); + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum RecordType { + Handshake = 1, + Session = 2, +} + +impl TryFrom for RecordType { + type Error = WireError; + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(Self::Handshake), + 2 => Ok(Self::Session), + _ => Err(WireError::InvalidPayload), + } + } +} + +impl WireDecode for RecordType { + fn decode(reader: &mut codec::Reader) -> Result { + reader.decode::()?.try_into() + } +} + +impl WireEncode for RecordType { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + out.put_u8(*self as u8); + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum QlHandshakeRecord { + Ik1(Ik1), + Ik2(Ik2), + Kk1(Kk1), + Kk2(Kk2), + Xx1(Xx1), + Xx2(Xx2), + Xx3(Xx3), + Xx4(Xx4), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum HandshakeKind { + Ik1 = 1, + Ik2 = 2, + Kk1 = 3, + Kk2 = 4, + Xx1 = 5, + Xx2 = 6, + Xx3 = 7, + Xx4 = 8, +} + +impl TryFrom for HandshakeKind { + type Error = WireError; + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(Self::Ik1), + 2 => Ok(Self::Ik2), + 3 => Ok(Self::Kk1), + 4 => Ok(Self::Kk2), + 5 => Ok(Self::Xx1), + 6 => Ok(Self::Xx2), + 7 => Ok(Self::Xx3), + 8 => Ok(Self::Xx4), + _ => Err(WireError::InvalidPayload), + } + } +} + +impl WireDecode for HandshakeKind { + fn decode(reader: &mut codec::Reader) -> Result { + reader.decode::()?.try_into() + } +} + +impl WireEncode for HandshakeKind { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + out.put_u8(*self as u8); + } +} + +impl QlHandshakeRecord { + pub fn kind(&self) -> HandshakeKind { + match self { + Self::Ik1(_) => HandshakeKind::Ik1, + Self::Ik2(_) => HandshakeKind::Ik2, + Self::Kk1(_) => HandshakeKind::Kk1, + Self::Kk2(_) => HandshakeKind::Kk2, + Self::Xx1(_) => HandshakeKind::Xx1, + Self::Xx2(_) => HandshakeKind::Xx2, + Self::Xx3(_) => HandshakeKind::Xx3, + Self::Xx4(_) => HandshakeKind::Xx4, + } + } +} + +impl WireEncode for QlHandshakeRecord { + fn encoded_len(&self) -> usize { + self.kind().encoded_len() + + match self { + Self::Ik1(message) => message.encoded_len(), + Self::Ik2(message) => message.encoded_len(), + Self::Kk1(message) => message.encoded_len(), + Self::Kk2(message) => message.encoded_len(), + Self::Xx1(message) => message.encoded_len(), + Self::Xx2(message) => message.encoded_len(), + Self::Xx3(message) => message.encoded_len(), + Self::Xx4(message) => message.encoded_len(), + } + } + + fn encode(&self, out: &mut W) { + self.kind().encode(out); + match self { + Self::Ik1(message) => message.encode(out), + Self::Ik2(message) => message.encode(out), + Self::Kk1(message) => message.encode(out), + Self::Kk2(message) => message.encode(out), + Self::Xx1(message) => message.encode(out), + Self::Xx2(message) => message.encode(out), + Self::Xx3(message) => message.encode(out), + Self::Xx4(message) => message.encode(out), + } + } +} + +impl WireDecode for QlHandshakeRecord { + fn decode(reader: &mut codec::Reader) -> Result { + let kind = reader.decode::()?; + match kind { + HandshakeKind::Ik1 => Ok(Self::Ik1(reader.decode()?)), + HandshakeKind::Ik2 => Ok(Self::Ik2(reader.decode()?)), + HandshakeKind::Kk1 => Ok(Self::Kk1(reader.decode()?)), + HandshakeKind::Kk2 => Ok(Self::Kk2(reader.decode()?)), + HandshakeKind::Xx1 => Ok(Self::Xx1(reader.decode()?)), + HandshakeKind::Xx2 => Ok(Self::Xx2(reader.decode()?)), + HandshakeKind::Xx3 => Ok(Self::Xx3(reader.decode()?)), + HandshakeKind::Xx4 => Ok(Self::Xx4(reader.decode()?)), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct QlSessionRecord { + pub header: SessionHeader, + pub payload: EncryptedMessage, +} + +impl> WireEncode for QlSessionRecord { + fn encoded_len(&self) -> usize { + self.header.encoded_len() + self.payload.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.payload.encode(out); + } +} + +impl QlSessionRecord { + pub fn into_owned(self) -> QlSessionRecord> { + QlSessionRecord { + header: self.header, + payload: self.payload.into_owned(), + } + } +} + +impl WireDecode for QlSessionRecord { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + payload: reader.decode()?, + }) + } +} diff --git a/ql-wire/src/testing.rs b/ql-wire/src/testing.rs new file mode 100644 index 00000000..a1223c12 --- /dev/null +++ b/ql-wire/src/testing.rs @@ -0,0 +1,181 @@ +use libcrux_aesgcm::AesGcm256Key; +use libcrux_ml_kem::mlkem1024; +use sha2::{Digest, Sha256}; + +use crate::{ + MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, QlAead, QlCrypto, + QlHash, QlIdentity, QlKem, QlRandom, SessionKey, ENCRYPTED_MESSAGE_AUTH_SIZE, +}; + +#[derive(Debug, Default, Clone, Copy)] +pub struct SoftwareCrypto; + +#[derive(Debug, Default, Clone, Copy)] +pub struct NoopCrypto; + +pub fn test_identities(crypto: &impl QlCrypto) -> (QlIdentity, QlIdentity) { + ( + crate::generate_identity(crypto, "alice").unwrap(), + crate::generate_identity(crypto, "bob").unwrap(), + ) +} + +impl QlRandom for SoftwareCrypto { + fn fill_random_bytes(&self, out: &mut [u8]) { + getrandom::getrandom(out).unwrap(); + } +} + +impl QlHash for SoftwareCrypto { + fn sha256(&self, parts: &[&[u8]]) -> [u8; 32] { + let mut hasher = Sha256::new(); + for part in parts { + hasher.update(part); + } + hasher.finalize().into() + } +} + +impl QlAead for SoftwareCrypto { + fn aes256_gcm_encrypt( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + ) -> [u8; ENCRYPTED_MESSAGE_AUTH_SIZE] { + let key: AesGcm256Key = (*key.data()).into(); + let plaintext = buffer.to_vec(); + let mut auth = [0u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; + key.encrypt( + buffer, + (&mut auth).into(), + (&nonce.0).into(), + aad, + &plaintext, + ) + .unwrap(); + auth + } + + fn aes256_gcm_decrypt( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + auth_tag: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], + ) -> bool { + let key: AesGcm256Key = (*key.data()).into(); + let ciphertext = buffer.to_vec(); + key.decrypt(buffer, (&nonce.0).into(), aad, &ciphertext, auth_tag.into()) + .is_ok() + } +} + +impl QlKem for SoftwareCrypto { + fn mlkem_generate_keypair(&self) -> MlKemKeyPair { + let key_pair = mlkem1024::generate_key_pair(random_array(self)); + let mut public = [0u8; MlKemPublicKey::SIZE]; + public.copy_from_slice(key_pair.pk()); + let mut private = [0u8; MlKemPrivateKey::SIZE]; + private.copy_from_slice(key_pair.sk()); + + MlKemKeyPair { + private: MlKemPrivateKey::new(Box::new(private)), + public: MlKemPublicKey::new(Box::new(public)), + } + } + + fn mlkem_encapsulate(&self, public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey) { + let public_key = public_key.as_bytes().into(); + let (ciphertext_value, shared_value) = + mlkem1024::encapsulate(&public_key, random_array(self)); + let mut ciphertext = [0u8; MlKemCiphertext::SIZE]; + ciphertext.copy_from_slice(ciphertext_value.as_slice()); + let mut shared = [0u8; SessionKey::SIZE]; + shared.copy_from_slice(shared_value.as_slice()); + ( + MlKemCiphertext::new(Box::new(ciphertext)), + SessionKey::from_data(shared), + ) + } + + fn mlkem_decapsulate( + &self, + private_key: &MlKemPrivateKey, + ciphertext: &MlKemCiphertext, + ) -> SessionKey { + let private_key = private_key.as_bytes().into(); + let ciphertext = ciphertext.as_bytes().into(); + let shared = mlkem1024::decapsulate(&private_key, &ciphertext); + let mut out = [0u8; SessionKey::SIZE]; + out.copy_from_slice(shared.as_slice()); + SessionKey::from_data(out) + } +} + +impl QlRandom for NoopCrypto { + fn fill_random_bytes(&self, out: &mut [u8]) { + out.fill(0); + } +} + +impl QlHash for NoopCrypto { + fn sha256(&self, _parts: &[&[u8]]) -> [u8; 32] { + [0; 32] + } +} + +impl QlAead for NoopCrypto { + fn aes256_gcm_encrypt( + &self, + _key: &SessionKey, + _nonce: &Nonce, + _aad: &[u8], + _buffer: &mut [u8], + ) -> [u8; ENCRYPTED_MESSAGE_AUTH_SIZE] { + [0; ENCRYPTED_MESSAGE_AUTH_SIZE] + } + + fn aes256_gcm_decrypt( + &self, + _key: &SessionKey, + _nonce: &Nonce, + _aad: &[u8], + _buffer: &mut [u8], + _auth_tag: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], + ) -> bool { + false + } +} + +impl QlKem for NoopCrypto { + fn mlkem_generate_keypair(&self) -> MlKemKeyPair { + MlKemKeyPair { + private: MlKemPrivateKey::new(Box::new([0; MlKemPrivateKey::SIZE])), + public: MlKemPublicKey::new(Box::new([0; MlKemPublicKey::SIZE])), + } + } + + fn mlkem_encapsulate(&self, _public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey) { + ( + MlKemCiphertext::new(Box::new([0; MlKemCiphertext::SIZE])), + SessionKey::from_data([0; SessionKey::SIZE]), + ) + } + + fn mlkem_decapsulate( + &self, + _private_key: &MlKemPrivateKey, + _ciphertext: &MlKemCiphertext, + ) -> SessionKey { + SessionKey::from_data([0; SessionKey::SIZE]) + } +} + +fn random_array(crypto: &impl QlRandom) -> [u8; L] { + let mut out = [0u8; L]; + crypto.fill_random_bytes(&mut out); + out +} diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs new file mode 100644 index 00000000..61826ddd --- /dev/null +++ b/ql-wire/src/tests.rs @@ -0,0 +1,1020 @@ +use std::ops::RangeInclusive; + +use super::*; + +fn decode_handshake_record(bytes: &[u8]) -> QlHandshakeRecord { + decode_record(bytes).unwrap().1 +} + +fn decode_session_record(bytes: &[u8]) -> QlSessionRecord> { + let (_, record) = decode_record::, _>(bytes).unwrap(); + record.into_owned() +} + +fn qid(byte: u8) -> QID { + QID([byte; QID::SIZE]) +} + +fn varint(value: u64) -> VarInt { + VarInt::from_u64(value).unwrap() +} + +fn record_seq(value: u64) -> RecordSeq { + RecordSeq(varint(value)) +} + +fn record_ack_range(start: u64, end: u64) -> RangeInclusive { + record_seq(start)..=record_seq(end) +} + +fn stream_id(value: u64) -> StreamId { + StreamId(varint(value)) +} + +fn handshake_meta(id: u32) -> HandshakeMeta { + HandshakeMeta { + handshake_id: HandshakeId(id), + } +} + +fn handshake_transport_params(window: u32) -> TransportParams { + TransportParams { + initial_stream_receive_window: window, + } +} + +fn handshake_header(sender: u8, recipient: u8) -> HandshakeHeader { + HandshakeHeader { + sender: qid(sender), + recipient: qid(recipient), + } +} + +fn pairing_token(byte: u8) -> PairingToken { + PairingToken([byte; PairingToken::SIZE]) +} + +fn pairing_id(byte: u8) -> PairingId { + PairingId([byte; PairingId::SIZE]) +} + +fn xx_header(sender: u8, recipient: u8) -> HandshakeHeader { + HandshakeHeader { + sender: qid(sender), + recipient: qid(recipient), + } +} + +fn encrypt_record( + crypto: &impl QlCrypto, + header: SessionHeader, + session_key: &SessionKey, + body: &[SessionFrame>], +) -> QlSessionRecord> { + let mut builder = SessionRecordBuilder::new(header.seq, usize::MAX); + for frame in body { + let pushed = builder.push_frame(frame); + debug_assert!(pushed); + } + decode_session_record( + builder + .encrypt(crypto, header.connection_id, session_key) + .as_slice(), + ) +} + +#[test] +fn peer_bundle_round_trip() { + let crypto = SoftwareCrypto; + let identity = generate_identity(&crypto, "alice") + .unwrap() + .with_capabilities(0x55aa_33cc); + let bundle = identity.bundle(); + + let encoded = bundle.encode_vec(); + let decoded = PeerBundle::decode_exact(encoded.as_slice()).unwrap(); + + assert_eq!(decoded, bundle); + assert_eq!(&*decoded.name, "alice"); +} + +#[test] +fn identity_name_validation() { + assert_eq!( + QlName::new("a".repeat(QlName::MAX_LEN)).unwrap().len(), + QlName::MAX_LEN + ); + assert!(matches!( + QlName::new(""), + Err(WireError::InvalidPayload) + )); + assert!(matches!( + QlName::new("a".repeat(QlName::MAX_LEN + 1)), + Err(WireError::InvalidPayload) + )); +} + +#[test] +fn qid_derives_from_mlkem_public_key() { + let crypto = SoftwareCrypto; + let public_key = MlKemPublicKey::new(Box::new([42; MlKemPublicKey::SIZE])); + let qid = QID::derive(&crypto, &public_key); + + let digest = crypto.sha256(&[ + b"quantum-link qid v1", + ML_KEM_SUITE_TAG, + public_key.as_bytes(), + ]); + let mut expected = [0u8; QID::SIZE]; + expected.copy_from_slice(&digest[..QID::SIZE]); + + assert_eq!(qid, QID(expected)); + assert!(qid.matches_public_key(&crypto, &public_key)); +} + +#[test] +fn qid_changes_when_mlkem_public_key_changes() { + let crypto = SoftwareCrypto; + let first = MlKemPublicKey::new(Box::new([1; MlKemPublicKey::SIZE])); + let second = MlKemPublicKey::new(Box::new([2; MlKemPublicKey::SIZE])); + + assert_ne!(QID::derive(&crypto, &first), QID::derive(&crypto, &second)); +} + +#[test] +fn peer_bundle_detects_tampered_qid() { + let crypto = SoftwareCrypto; + let identity = generate_identity(&crypto, "alice").unwrap(); + let mut bundle = identity.bundle(); + + bundle.qid = qid(9); + + assert!(!bundle.qid_matches_public_key(&crypto)); +} + +#[test] +fn handshake_record_round_trip_supports_ik_kk_and_xx() { + let ik = QlHandshakeRecord::Ik1(Ik1 { + header: handshake_header(1, 2), + meta: handshake_meta(1), + transport_params: handshake_transport_params(65_536), + skem_ciphertext: MlKemCiphertext::new(Box::new([7; MlKemCiphertext::SIZE])), + ephemeral: EphemeralPublicKey { + mlkem_public_key: MlKemPublicKey::new(Box::new([9; MlKemPublicKey::SIZE])), + }, + static_bundle: EncryptedPeerBundle(vec![13; 64].into_boxed_slice()), + }); + let ik_encoded = encode_record_vec(RecordType::Handshake, &ik); + assert_eq!( + RecordHeader::decode_bytes(ik_encoded.as_slice()).unwrap(), + RecordHeader { + version: QL_WIRE_VERSION, + record_type: RecordType::Handshake, + } + ); + assert_eq!(decode_handshake_record(ik_encoded.as_slice()), ik); + + let kk = QlHandshakeRecord::Kk1(Kk1 { + header: handshake_header(1, 2), + meta: handshake_meta(2), + transport_params: handshake_transport_params(131_072), + skem_ciphertext: MlKemCiphertext::new(Box::new([11; MlKemCiphertext::SIZE])), + ephemeral: EphemeralPublicKey { + mlkem_public_key: MlKemPublicKey::new(Box::new([15; MlKemPublicKey::SIZE])), + }, + }); + let kk_encoded = encode_record_vec(RecordType::Handshake, &kk); + assert_eq!( + RecordHeader::decode_bytes(kk_encoded.as_slice()).unwrap(), + RecordHeader { + version: QL_WIRE_VERSION, + record_type: RecordType::Handshake, + } + ); + assert_eq!(decode_handshake_record(kk_encoded.as_slice()), kk); + + let xx = QlHandshakeRecord::Xx1(Xx1 { + header: xx_header(1, 2), + meta: handshake_meta(3), + pairing_id: pairing_id(3), + transport_params: handshake_transport_params(196_608), + ephemeral: EphemeralPublicKey { + mlkem_public_key: MlKemPublicKey::new(Box::new([17; MlKemPublicKey::SIZE])), + }, + }); + let xx_encoded = encode_record_vec(RecordType::Handshake, &xx); + assert_eq!( + RecordHeader::decode_bytes(xx_encoded.as_slice()).unwrap(), + RecordHeader { + version: QL_WIRE_VERSION, + record_type: RecordType::Handshake, + } + ); + assert_eq!(decode_handshake_record(xx_encoded.as_slice()), xx); +} + +#[test] +fn ik_handshake_rejects_tampered_handshake_meta() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + + let mut initiator_state = IkHandshake::new_initiator( + &crypto, + initiator, + responder.bundle(), + TransportParams::default(), + ); + let mut responder_state = + IkHandshake::new_responder(&crypto, responder, None, TransportParams::default()); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(77)) + .unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); + + let mut m2 = responder_state + .write_2(&crypto, handshake_meta(77)) + .unwrap(); + m2.meta.handshake_id = HandshakeId(78); + + assert_eq!( + initiator_state.read_2(&crypto, &m2), + Err(WireError::InvalidHandshakeMeta) + ); +} + +#[test] +fn kk_handshake_rejects_tampered_handshake_header() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + + let mut initiator_state = KkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + TransportParams::default(), + ); + let mut responder_state = KkHandshake::new_responder( + &crypto, + responder, + initiator.bundle(), + TransportParams::default(), + ); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(88)) + .unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); + + let mut m2 = responder_state + .write_2(&crypto, handshake_meta(88)) + .unwrap(); + m2.header = handshake_header(9, 1); + + assert_eq!( + initiator_state.read_2(&crypto, &m2), + Err(WireError::InvalidPayload) + ); +} + +#[test] +fn ik_handshake_rejects_tampered_transport_params() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + + let mut initiator_state = IkHandshake::new_initiator( + &crypto, + initiator, + responder.bundle(), + handshake_transport_params(4096), + ); + let mut responder_state = + IkHandshake::new_responder(&crypto, responder, None, handshake_transport_params(8192)); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(89)) + .unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); + + let mut m2 = responder_state + .write_2(&crypto, handshake_meta(89)) + .unwrap(); + m2.transport_params.initial_stream_receive_window += 1; + + assert_eq!( + initiator_state.read_2(&crypto, &m2), + Err(WireError::DecryptFailed) + ); +} + +#[test] +fn ik_handshake_rejects_tampered_handshake_header() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + + let mut initiator_state = IkHandshake::new_initiator( + &crypto, + initiator, + responder.bundle(), + TransportParams::default(), + ); + let mut responder_state = + IkHandshake::new_responder(&crypto, responder, None, TransportParams::default()); + + let mut m1 = initiator_state + .write_1(&crypto, handshake_meta(90)) + .unwrap(); + m1.header.sender = qid(9); + + assert_eq!( + responder_state.read_1(&crypto, &m1), + Err(WireError::DecryptFailed) + ); +} + +#[test] +fn ik_handshake_rejects_bound_remote_bundle_mismatch() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + let bogus = generate_identity(&crypto, "bogus").unwrap(); + + let mut initiator_state = IkHandshake::new_initiator( + &crypto, + initiator, + responder.bundle(), + TransportParams::default(), + ); + let mut responder_state = IkHandshake::new_responder( + &crypto, + responder, + Some(bogus.bundle()), + TransportParams::default(), + ); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(91)) + .unwrap(); + + assert_eq!( + responder_state.read_1(&crypto, &m1), + Err(WireError::InvalidPayload) + ); +} + +#[test] +fn ik_handshake_round_trip_derives_matching_transport_and_learns_remote() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + + let initiator_params = handshake_transport_params(4096); + let responder_params = handshake_transport_params(8192); + let mut initiator_state = IkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + initiator_params, + ); + let mut responder_state = + IkHandshake::new_responder(&crypto, responder.clone(), None, responder_params); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(11)) + .unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); + + let m2 = responder_state + .write_2(&crypto, handshake_meta(11)) + .unwrap(); + initiator_state.read_2(&crypto, &m2).unwrap(); + + let initiator_final = initiator_state.finalize(&crypto).unwrap(); + let responder_final = responder_state.finalize(&crypto).unwrap(); + + assert_eq!( + initiator_final.handshake_hash, + responder_final.handshake_hash + ); + assert_eq!(initiator_final.tx_key, responder_final.rx_key); + assert_eq!(initiator_final.rx_key, responder_final.tx_key); + assert_eq!( + initiator_final.tx_connection_id, + responder_final.rx_connection_id + ); + assert_eq!( + initiator_final.rx_connection_id, + responder_final.tx_connection_id + ); + assert_eq!(initiator_final.remote_bundle, responder.bundle()); + assert_eq!(responder_final.remote_bundle, initiator.bundle()); + assert_eq!(initiator_final.remote_transport_params, responder_params); + assert_eq!(responder_final.remote_transport_params, initiator_params); +} + +#[test] +fn ik_handshake_round_trip_derives_matching_transport_with_bound_responder() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + + let initiator_params = handshake_transport_params(16_384); + let responder_params = handshake_transport_params(32_768); + let mut initiator_state = IkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + initiator_params, + ); + let mut responder_state = IkHandshake::new_responder( + &crypto, + responder.clone(), + Some(initiator.bundle()), + responder_params, + ); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(12)) + .unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); + + let m2 = responder_state + .write_2(&crypto, handshake_meta(12)) + .unwrap(); + initiator_state.read_2(&crypto, &m2).unwrap(); + + let initiator_final = initiator_state.finalize(&crypto).unwrap(); + let responder_final = responder_state.finalize(&crypto).unwrap(); + + assert_eq!( + initiator_final.handshake_hash, + responder_final.handshake_hash + ); + assert_eq!(initiator_final.tx_key, responder_final.rx_key); + assert_eq!(initiator_final.rx_key, responder_final.tx_key); + assert_eq!( + initiator_final.tx_connection_id, + responder_final.rx_connection_id + ); + assert_eq!( + initiator_final.rx_connection_id, + responder_final.tx_connection_id + ); + assert_eq!(initiator_final.remote_bundle, responder.bundle()); + assert_eq!(responder_final.remote_bundle, initiator.bundle()); + assert_eq!(initiator_final.remote_transport_params, responder_params); + assert_eq!(responder_final.remote_transport_params, initiator_params); +} + +#[test] +fn kk_handshake_round_trip_derives_matching_transport() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + + let initiator_params = handshake_transport_params(24_576); + let responder_params = handshake_transport_params(49_152); + let mut initiator_state = KkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + initiator_params, + ); + let mut responder_state = KkHandshake::new_responder( + &crypto, + responder.clone(), + initiator.bundle(), + responder_params, + ); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(21)) + .unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); + + let m2 = responder_state + .write_2(&crypto, handshake_meta(21)) + .unwrap(); + initiator_state.read_2(&crypto, &m2).unwrap(); + + let initiator_final = initiator_state.finalize(&crypto).unwrap(); + let responder_final = responder_state.finalize(&crypto).unwrap(); + + assert_eq!( + initiator_final.handshake_hash, + responder_final.handshake_hash + ); + assert_eq!(initiator_final.tx_key, responder_final.rx_key); + assert_eq!(initiator_final.rx_key, responder_final.tx_key); + assert_eq!( + initiator_final.tx_connection_id, + responder_final.rx_connection_id + ); + assert_eq!( + initiator_final.rx_connection_id, + responder_final.tx_connection_id + ); + assert_eq!(initiator_final.remote_bundle, responder.bundle()); + assert_eq!(responder_final.remote_bundle, initiator.bundle()); + assert_eq!(initiator_final.remote_transport_params, responder_params); + assert_eq!(responder_final.remote_transport_params, initiator_params); +} + +#[test] +fn kk_handshake_rejects_tampered_transport_params() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + + let mut initiator_state = KkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + handshake_transport_params(12288), + ); + let mut responder_state = KkHandshake::new_responder( + &crypto, + responder, + initiator.bundle(), + handshake_transport_params(24576), + ); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(22)) + .unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); + + let mut m2 = responder_state + .write_2(&crypto, handshake_meta(22)) + .unwrap(); + m2.transport_params.initial_stream_receive_window += 1; + + assert_eq!( + initiator_state.read_2(&crypto, &m2), + Err(WireError::DecryptFailed) + ); +} + +#[test] +fn xx_handshake_rejects_tampered_pairing_id() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + let token = pairing_token(7); + + let mut initiator_state = XxHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.qid, + token, + TransportParams::default(), + ); + let mut responder_state = XxHandshake::new_responder( + &crypto, + responder, + initiator.qid, + token, + TransportParams::default(), + ); + + let mut m1 = initiator_state + .write_1(&crypto, handshake_meta(31)) + .unwrap(); + m1.pairing_id = pairing_id(8); + + assert_eq!( + responder_state.read_1(&crypto, &m1), + Err(WireError::InvalidPairingId) + ); +} + +#[test] +fn xx_handshake_rejects_tampered_sender_or_recipient() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + let token = pairing_token(7); + + let mut initiator_state = XxHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.qid, + token, + TransportParams::default(), + ); + let mut responder_state = XxHandshake::new_responder( + &crypto, + responder.clone(), + initiator.qid, + token, + TransportParams::default(), + ); + + let mut m1 = initiator_state + .write_1(&crypto, handshake_meta(31)) + .unwrap(); + m1.header.sender = responder.qid; + + assert_eq!( + responder_state.read_1(&crypto, &m1), + Err(WireError::InvalidHandshakeHeader) + ); + + let mut initiator_state = XxHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.qid, + token, + TransportParams::default(), + ); + let mut responder_state = XxHandshake::new_responder( + &crypto, + responder.clone(), + initiator.qid, + token, + TransportParams::default(), + ); + + let mut m1 = initiator_state + .write_1(&crypto, handshake_meta(31)) + .unwrap(); + m1.header.recipient = initiator.qid; + + assert_eq!( + responder_state.read_1(&crypto, &m1), + Err(WireError::InvalidHandshakeHeader) + ); +} + +#[test] +fn xx_handshake_rejects_repeated_transport_param_change() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + let token = pairing_token(9); + + let mut initiator_state = XxHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.qid, + token, + handshake_transport_params(12_288), + ); + let mut responder_state = XxHandshake::new_responder( + &crypto, + responder, + initiator.qid, + token, + handshake_transport_params(24_576), + ); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(32)) + .unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); + + let m2 = responder_state + .write_2(&crypto, handshake_meta(32)) + .unwrap(); + initiator_state.read_2(&crypto, &m2).unwrap(); + + let mut m3 = initiator_state + .write_3(&crypto, handshake_meta(32)) + .unwrap(); + m3.transport_params.initial_stream_receive_window += 1; + + assert_eq!( + responder_state.read_3(&crypto, &m3), + Err(WireError::InvalidTransportParams) + ); +} + +#[test] +fn xx_handshake_round_trip_derives_matching_transport_and_learns_remote() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + let token = pairing_token(10); + + let initiator_params = handshake_transport_params(28_672); + let responder_params = handshake_transport_params(57_344); + let mut initiator_state = XxHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.qid, + token, + initiator_params, + ); + let mut responder_state = XxHandshake::new_responder( + &crypto, + responder.clone(), + initiator.qid, + token, + responder_params, + ); + + assert_eq!(initiator_state.pairing_token(), token); + assert_eq!(responder_state.pairing_token(), token); + assert_eq!(initiator_state.pairing_id(&crypto), token.id(&crypto)); + assert_eq!(responder_state.pairing_id(&crypto), token.id(&crypto)); + assert!(initiator_state.remote_bundle().is_none()); + assert!(responder_state.remote_bundle().is_none()); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(33)) + .unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); + + let m2 = responder_state + .write_2(&crypto, handshake_meta(33)) + .unwrap(); + initiator_state.read_2(&crypto, &m2).unwrap(); + assert_eq!(initiator_state.remote_bundle(), Some(&responder.bundle())); + assert!(responder_state.remote_bundle().is_none()); + + let m3 = initiator_state + .write_3(&crypto, handshake_meta(33)) + .unwrap(); + responder_state.read_3(&crypto, &m3).unwrap(); + assert_eq!(responder_state.remote_bundle(), Some(&initiator.bundle())); + + let m4 = responder_state + .write_4(&crypto, handshake_meta(33)) + .unwrap(); + initiator_state.read_4(&crypto, &m4).unwrap(); + + let initiator_final = initiator_state.finalize(&crypto).unwrap(); + let responder_final = responder_state.finalize(&crypto).unwrap(); + + assert_eq!( + initiator_final.handshake_hash, + responder_final.handshake_hash + ); + assert_eq!(initiator_final.tx_key, responder_final.rx_key); + assert_eq!(initiator_final.rx_key, responder_final.tx_key); + assert_eq!( + initiator_final.tx_connection_id, + responder_final.rx_connection_id + ); + assert_eq!( + initiator_final.rx_connection_id, + responder_final.tx_connection_id + ); + assert_eq!(initiator_final.remote_bundle, responder.bundle()); + assert_eq!(responder_final.remote_bundle, initiator.bundle()); + assert_eq!(initiator_final.remote_transport_params, responder_params); + assert_eq!(responder_final.remote_transport_params, initiator_params); +} + +#[test] +fn encrypted_session_record_round_trip_uses_connection_id_header() { + let crypto = SoftwareCrypto; + let header = SessionHeader { + connection_id: ConnectionId::from_data([0x44; ConnectionId::SIZE]), + seq: record_seq(11), + }; + let body = vec![ + SessionFrame::Ping, + SessionFrame::Unpair, + SessionFrame::Ack( + RecordAck::from_ranges([record_ack_range(20, 23), record_ack_range(12, 13)]).unwrap(), + ), + SessionFrame::StreamWindow(StreamWindow { + stream_id: stream_id(9), + maximum_offset: varint(65_536), + }), + SessionFrame::StreamData(StreamData { + stream_id: stream_id(9), + offset: varint(1024), + header: None, + bytes: b"hello".to_vec(), + fin: true, + }), + SessionFrame::StreamClose(StreamClose { + stream_id: stream_id(9), + target: CloseTarget::Both, + code: StreamCloseCode::CANCELLED, + }), + SessionFrame::Close(SessionClose { + code: SessionCloseCode::TIMEOUT, + }), + ]; + let session_key = SessionKey::from_data([7; SessionKey::SIZE]); + let record = encrypt_record(&crypto, header, &session_key, &body); + + let bytes = encode_record_vec(RecordType::Session, &record); + assert_eq!( + RecordHeader::decode_bytes(bytes.as_slice()).unwrap(), + RecordHeader { + version: QL_WIRE_VERSION, + record_type: RecordType::Session, + } + ); + let decoded = decode_session_record(bytes.as_slice()); + assert_eq!(decoded.header, header); + let encrypted = decoded.payload; + + let decrypted = + encrypted::decrypt_record(&crypto, &header, encrypted.clone(), &session_key).unwrap(); + assert_eq!(decode_session_frames(&decrypted).unwrap(), body); + + let wrong_header = SessionHeader { + connection_id: ConnectionId::from_data([0x99; ConnectionId::SIZE]), + seq: header.seq, + }; + assert_eq!( + encrypted::decrypt_record(&crypto, &wrong_header, encrypted.clone(), &session_key), + Err(WireError::DecryptFailed) + ); + + let wrong_seq_header = SessionHeader { + connection_id: header.connection_id, + seq: record_seq(header.seq.into_inner() + 1), + }; + assert_eq!( + encrypted::decrypt_record(&crypto, &wrong_seq_header, encrypted, &session_key), + Err(WireError::DecryptFailed) + ); +} + +#[test] +fn session_varint_fields_expand_at_expected_boundaries() { + let short_header = SessionHeader { + connection_id: ConnectionId::from_data([0x11; ConnectionId::SIZE]), + seq: record_seq(63), + }; + let long_header = SessionHeader { + connection_id: ConnectionId::from_data([0x11; ConnectionId::SIZE]), + seq: record_seq(64), + }; + + assert_eq!(short_header.encode_vec().len(), ConnectionId::SIZE + 1); + assert_eq!(long_header.encode_vec().len(), ConnectionId::SIZE + 2); + + let frame = StreamData { + stream_id: stream_id(64), + offset: varint(16_384), + header: None, + fin: true, + bytes: b"abc".to_vec(), + }; + let encoded = frame.encode_vec(); + + assert_eq!( + StreamData::decode_exact(encoded.as_slice()) + .unwrap() + .into_owned(), + frame + ); +} + +#[test] +fn protocol_record_size_breakdown() { + fn print_size(label: &str, size: usize) { + println!("{label:<32}: {size} bytes"); + } + + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + + let mut ik_initiator = IkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + TransportParams::default(), + ); + let mut ik_responder = + IkHandshake::new_responder(&crypto, responder.clone(), None, TransportParams::default()); + + let ik1 = ik_initiator.write_1(&crypto, handshake_meta(101)).unwrap(); + ik_responder.read_1(&crypto, &ik1).unwrap(); + + let ik2 = ik_responder.write_2(&crypto, handshake_meta(101)).unwrap(); + ik_initiator.read_2(&crypto, &ik2).unwrap(); + + let ik1 = QlHandshakeRecord::Ik1(ik1); + let ik2 = QlHandshakeRecord::Ik2(ik2); + + let mut kk_initiator = KkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + TransportParams::default(), + ); + let mut kk_responder = KkHandshake::new_responder( + &crypto, + responder.clone(), + initiator.bundle(), + TransportParams::default(), + ); + + let kk1 = kk_initiator.write_1(&crypto, handshake_meta(201)).unwrap(); + kk_responder.read_1(&crypto, &kk1).unwrap(); + + let kk2 = kk_responder.write_2(&crypto, handshake_meta(201)).unwrap(); + kk_initiator.read_2(&crypto, &kk2).unwrap(); + + let kk1 = QlHandshakeRecord::Kk1(kk1); + let kk2 = QlHandshakeRecord::Kk2(kk2); + + let token = pairing_token(0x42); + let mut xx_initiator = XxHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.qid, + token, + TransportParams::default(), + ); + let mut xx_responder = XxHandshake::new_responder( + &crypto, + responder.clone(), + initiator.qid, + token, + TransportParams::default(), + ); + + let xx1 = xx_initiator.write_1(&crypto, handshake_meta(301)).unwrap(); + xx_responder.read_1(&crypto, &xx1).unwrap(); + + let xx2 = xx_responder.write_2(&crypto, handshake_meta(301)).unwrap(); + xx_initiator.read_2(&crypto, &xx2).unwrap(); + + let xx3 = xx_initiator.write_3(&crypto, handshake_meta(301)).unwrap(); + xx_responder.read_3(&crypto, &xx3).unwrap(); + + let xx4 = xx_responder.write_4(&crypto, handshake_meta(301)).unwrap(); + xx_initiator.read_4(&crypto, &xx4).unwrap(); + + let xx1 = QlHandshakeRecord::Xx1(xx1); + let xx2 = QlHandshakeRecord::Xx2(xx2); + let xx3 = QlHandshakeRecord::Xx3(xx3); + let xx4 = QlHandshakeRecord::Xx4(xx4); + + let session = ik_initiator.finalize(&crypto).unwrap(); + let session_ping = encrypt_record( + &crypto, + SessionHeader { + connection_id: session.tx_connection_id, + seq: record_seq(1), + }, + &session.tx_key, + &[SessionFrame::Ping], + ); + let session_ack = encrypt_record( + &crypto, + SessionHeader { + connection_id: session.tx_connection_id, + seq: record_seq(2), + }, + &session.tx_key, + &[SessionFrame::Ack( + RecordAck::from_ranges([record_ack_range(6, 6), record_ack_range(1, 2)]).unwrap(), + )], + ); + let session_unpair = encrypt_record( + &crypto, + SessionHeader { + connection_id: session.tx_connection_id, + seq: record_seq(3), + }, + &session.tx_key, + &[SessionFrame::Unpair], + ); + let session_stream_empty = encrypt_record( + &crypto, + SessionHeader { + connection_id: session.tx_connection_id, + seq: record_seq(4), + }, + &session.tx_key, + &[SessionFrame::StreamData(StreamData { + stream_id: stream_id(1), + offset: varint(0), + header: None, + fin: false, + bytes: Vec::new(), + })], + ); + let session_close = encrypt_record( + &crypto, + SessionHeader { + connection_id: session.tx_connection_id, + seq: record_seq(5), + }, + &session.tx_key, + &[SessionFrame::Close(SessionClose { + code: SessionCloseCode::PROTOCOL, + })], + ); + + print_size("ql-wire peer bundle", initiator.bundle().encode_vec().len()); + print_size("ql-wire mlkem public key", MlKemPublicKey::SIZE); + print_size("ql-wire mlkem ciphertext", MlKemCiphertext::SIZE); + print_size("ql-wire pq ik1", ik1.encode_vec().len()); + print_size("ql-wire pq ik2", ik2.encode_vec().len()); + print_size("ql-wire pq kk1", kk1.encode_vec().len()); + print_size("ql-wire pq kk2", kk2.encode_vec().len()); + print_size("ql-wire pq xx1", xx1.encode_vec().len()); + print_size("ql-wire pq xx2", xx2.encode_vec().len()); + print_size("ql-wire pq xx3", xx3.encode_vec().len()); + print_size("ql-wire pq xx4", xx4.encode_vec().len()); + print_size("ql-wire session ping", session_ping.encode_vec().len()); + print_size("ql-wire session ack", session_ack.encode_vec().len()); + print_size("ql-wire session unpair", session_unpair.encode_vec().len()); + print_size( + "ql-wire session stream empty", + session_stream_empty.encode_vec().len(), + ); + print_size("ql-wire session close", session_close.encode_vec().len()); +} diff --git a/ql-wire/src/varint.rs b/ql-wire/src/varint.rs new file mode 100644 index 00000000..7a39bd16 --- /dev/null +++ b/ql-wire/src/varint.rs @@ -0,0 +1,181 @@ +use core::fmt; + +use bytes::BufMut; + +use crate::{ByteSlice, Reader, WireDecode, WireEncode, WireError}; + +/// An integer less than 2^62 encoded with QUIC variable-length integer rules. +#[derive(Default, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct VarInt(pub(crate) u64); + +impl VarInt { + /// The largest representable value. + pub const MAX: Self = Self((1u64 << 62) - 1); + /// The largest encoded value length. + pub const MAX_SIZE: usize = 8; + pub const MIN_SIZE: usize = 1; + + /// Construct a `VarInt` infallibly from a `u32`. + pub const fn from_u32(x: u32) -> Self { + Self(x as u64) + } + + /// Construct a `VarInt` from a `u64`. + pub fn from_u64(x: u64) -> Result { + if x < (1u64 << 62) { + Ok(Self(x)) + } else { + Err(VarIntBoundsExceeded) + } + } + + /// Create a `VarInt` without checking the bounds. + /// + /// # Safety + /// + /// `x` must be less than 2^62. + pub const unsafe fn from_u64_unchecked(x: u64) -> Self { + Self(x) + } + + /// Extract the inner integer value. + pub const fn into_inner(self) -> u64 { + self.0 + } + + /// Return the number of bytes required to encode this value. + pub const fn size(self) -> usize { + let x = self.0; + if x < (1u64 << 6) { + 1 + } else if x < (1u64 << 14) { + 2 + } else if x < (1u64 << 30) { + 4 + } else { + 8 + } + } +} + +impl From for u64 { + fn from(value: VarInt) -> Self { + value.0 + } +} + +impl From for VarInt { + fn from(value: u8) -> Self { + Self(value.into()) + } +} + +impl From for VarInt { + fn from(value: u16) -> Self { + Self(value.into()) + } +} + +impl From for VarInt { + fn from(value: u32) -> Self { + Self(value.into()) + } +} + +impl TryFrom for VarInt { + type Error = VarIntBoundsExceeded; + + fn try_from(value: u64) -> Result { + Self::from_u64(value) + } +} + +impl TryFrom for VarInt { + type Error = VarIntBoundsExceeded; + + fn try_from(value: u128) -> Result { + Self::from_u64(value.try_into().map_err(|_| VarIntBoundsExceeded)?) + } +} + +impl TryFrom for VarInt { + type Error = VarIntBoundsExceeded; + + fn try_from(value: usize) -> Result { + Self::from_u64(value as u64) + } +} + +impl fmt::Debug for VarInt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl fmt::Display for VarInt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct VarIntBoundsExceeded; + +impl fmt::Display for VarIntBoundsExceeded { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("value too large for varint encoding") + } +} + +impl std::error::Error for VarIntBoundsExceeded {} + +impl WireDecode for VarInt { + fn decode(reader: &mut Reader) -> Result { + let first = reader.decode::()?; + let tag = first >> 6; + let first = first & 0b0011_1111; + let value = match tag { + 0b00 => u64::from(first), + 0b01 => { + let mut buf = [0; 2]; + buf[0] = first; + buf[1] = reader.decode()?; + u64::from(u16::from_be_bytes(buf)) + } + 0b10 => { + let mut buf = [0; 4]; + buf[0] = first; + buf[1..].copy_from_slice(&reader.take_bytes(3)?); + u64::from(u32::from_be_bytes(buf)) + } + 0b11 => { + let mut buf = [0; 8]; + buf[0] = first; + buf[1..].copy_from_slice(&reader.take_bytes(7)?); + u64::from_be_bytes(buf) + } + _ => unreachable!(), + }; + + // SAFETY: the decoded value is guaranteed to fit in the 62-bit varint range. + Ok(unsafe { Self::from_u64_unchecked(value) }) + } +} + +impl WireEncode for VarInt { + fn encoded_len(&self) -> usize { + self.size() + } + + #[allow(clippy::cast_possible_truncation)] + fn encode(&self, out: &mut W) { + let x = self.into_inner(); + match self.size() { + 1 => out.put_u8(x as u8), + 2 => out.put_u16((0b01 << 14) | x as u16), + 4 => out.put_u32((0b10 << 30) | x as u32), + 8 => out.put_u64((0b11 << 62) | x), + _ => unreachable!("malformed varint"), + } + } +}