From 39e6cedc759f6eebb1b15ec013e1d4c08de09cdb Mon Sep 17 00:00:00 2001 From: Herrtian <70463940+Herrtian@users.noreply.github.com> Date: Sun, 28 Jun 2026 16:15:04 +0200 Subject: [PATCH] fix: validate advertised API version ranges --- crates/fluss/src/rpc/server_connection.rs | 117 +++++++++++++++++----- 1 file changed, 90 insertions(+), 27 deletions(-) diff --git a/crates/fluss/src/rpc/server_connection.rs b/crates/fluss/src/rpc/server_connection.rs index a40fd58d..257baf52 100644 --- a/crates/fluss/src/rpc/server_connection.rs +++ b/crates/fluss/src/rpc/server_connection.rs @@ -85,17 +85,17 @@ pub struct ServerApiVersions { impl ServerApiVersions { /// Build from the server's advertised API version list. - pub(crate) fn new(server_versions: &[PbApiVersion]) -> Self { + pub(crate) fn new(server_versions: &[PbApiVersion]) -> Result { let mut versions = HashMap::new(); for sv in server_versions { - let api_key = ApiKey::from(i16::try_from(sv.api_key).unwrap()); + let api_key = ApiKey::from(api_version_i16_field("api_key", sv.api_key)?); // Skip unknown API keys — the client does not support them. let client_range = match api_key.supported_versions() { Some(range) => range, None => continue, }; - let server_min = i16::try_from(sv.min_version).unwrap(); - let server_max = i16::try_from(sv.max_version).unwrap(); + let server_min = api_version_i16_field("min_version", sv.min_version)?; + let server_max = api_version_i16_field("max_version", sv.max_version)?; let min_version = client_range.min().0.max(server_min); let max_version = client_range.max().0.min(server_max); if min_version > max_version { @@ -115,7 +115,7 @@ impl ServerApiVersions { versions.insert(api_key, Ok(ApiVersion(max_version))); } } - Self { versions } + Ok(Self { versions }) } /// Get the negotiated (highest usable) version for a given API key. @@ -132,6 +132,14 @@ impl ServerApiVersions { } } +fn api_version_i16_field(field: &str, value: i32) -> Result { + i16::try_from(value).map_err(|_| Error::UnsupportedVersion { + message: format!( + "ApiVersionsResponse field `{field}` has value {value}, which is outside i16 range" + ), + }) +} + /// Resolve the API version to use for a given API key. fn resolve_api_version_for( api_versions: Option<&ServerApiVersions>, @@ -266,7 +274,7 @@ impl RpcClient { let request = ApiVersionsRequest::new("fluss-rust", env!("CARGO_PKG_VERSION")); let response = connection.request(request).await?; validate_server_type(expected_server_type, response.server_type)?; - let api_versions = ServerApiVersions::new(&response.api_versions); + let api_versions = ServerApiVersions::new(&response.api_versions)?; *connection.api_versions.lock() = Some(api_versions); Ok(()) } @@ -995,11 +1003,14 @@ mod tests { tokio::spawn(mock_echo_server(server)); let conn = ServerConnectionInner::new(BufStream::new(client), usize::MAX, Arc::from("t")); - *conn.api_versions.lock() = Some(ServerApiVersions::new(&[PbApiVersion { - api_key: 1014, - min_version: 0, - max_version: 0, - }])); + *conn.api_versions.lock() = Some( + ServerApiVersions::new(&[PbApiVersion { + api_key: 1014, + min_version: 0, + max_version: 0, + }]) + .unwrap(), + ); let before: Vec<_> = snapshotter.snapshot().into_vec(); let request_before = counter_for_label(&before, CLIENT_REQUESTS_TOTAL, "produce_log"); @@ -1040,11 +1051,14 @@ mod tests { tokio::spawn(mock_echo_server(server)); let conn = ServerConnectionInner::new(BufStream::new(client), usize::MAX, Arc::from("t")); - *conn.api_versions.lock() = Some(ServerApiVersions::new(&[PbApiVersion { - api_key: 1012, - min_version: 0, - max_version: 0, - }])); + *conn.api_versions.lock() = Some( + ServerApiVersions::new(&[PbApiVersion { + api_key: 1012, + min_version: 0, + max_version: 0, + }]) + .unwrap(), + ); let before: Vec<_> = snapshotter.snapshot().into_vec(); let request_sum_before = counter_sum(&before, CLIENT_REQUESTS_TOTAL); let response_sum_before = counter_sum(&before, CLIENT_RESPONSES_TOTAL); @@ -1081,11 +1095,14 @@ mod tests { let (client, server) = tokio::io::duplex(64); drop(server); // force write failure on request path let conn = ServerConnectionInner::new(BufStream::new(client), usize::MAX, Arc::from("t")); - *conn.api_versions.lock() = Some(ServerApiVersions::new(&[PbApiVersion { - api_key: 1014, - min_version: 0, - max_version: 0, - }])); + *conn.api_versions.lock() = Some( + ServerApiVersions::new(&[PbApiVersion { + api_key: 1014, + min_version: 0, + max_version: 0, + }]) + .unwrap(), + ); let before: Vec<_> = snapshotter.snapshot().into_vec(); let request_before = counter_for_label(&before, CLIENT_REQUESTS_TOTAL, "produce_log"); @@ -1134,11 +1151,14 @@ mod tests { tokio::spawn(mock_error_server(server)); let conn = ServerConnectionInner::new(BufStream::new(client), usize::MAX, Arc::from("t")); - *conn.api_versions.lock() = Some(ServerApiVersions::new(&[PbApiVersion { - api_key: 1014, - min_version: 0, - max_version: 0, - }])); + *conn.api_versions.lock() = Some( + ServerApiVersions::new(&[PbApiVersion { + api_key: 1014, + min_version: 0, + max_version: 0, + }]) + .unwrap(), + ); let before: Vec<_> = snapshotter.snapshot().into_vec(); let response_before = counter_for_label(&before, CLIENT_RESPONSES_TOTAL, "produce_log"); @@ -1211,7 +1231,7 @@ mod tests { max_version: 5, }, ]; - let negotiated = ServerApiVersions::new(&server_versions); + let negotiated = ServerApiVersions::new(&server_versions).unwrap(); // Successful negotiation cases assert_eq!( @@ -1247,11 +1267,54 @@ mod tests { // Key not advertised by server → error assert!( ServerApiVersions::new(&[]) + .unwrap() .highest_available_version(ApiKey::FetchLog) .is_err() ); } + #[test] + fn server_api_versions_rejects_out_of_range_fields() { + let cases = [ + ( + "api_key", + PbApiVersion { + api_key: i32::from(i16::MAX) + 1, + min_version: 0, + max_version: 0, + }, + ), + ( + "min_version", + PbApiVersion { + api_key: 1014, + min_version: i32::from(i16::MAX) + 1, + max_version: 0, + }, + ), + ( + "max_version", + PbApiVersion { + api_key: 1014, + min_version: 0, + max_version: i32::from(i16::MAX) + 1, + }, + ), + ]; + + for (field, api_version) in cases { + let err = ServerApiVersions::new(&[api_version]).unwrap_err(); + assert!( + matches!(err, Error::UnsupportedVersion { .. }), + "expected UnsupportedVersion, got {err:?}" + ); + + let message = err.to_string(); + assert!(message.contains(field), "message was: {message}"); + assert!(message.contains("32768"), "message was: {message}"); + } + } + #[test] fn server_type_validation() { // Happy path: server advertises the expected type.