Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 90 additions & 27 deletions crates/fluss/src/rpc/server_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,17 @@ pub struct ServerApiVersions {

impl ServerApiVersions {
/// Build from the server's advertised API version list.
pub(crate) fn new(server_versions: &[PbApiVersion]) -> Self {
pub(crate) fn new(server_versions: &[PbApiVersion]) -> Result<Self, Error> {
let mut versions = HashMap::new();
for sv in server_versions {
let api_key = ApiKey::from(i16::try_from(sv.api_key).unwrap());
let api_key = ApiKey::from(api_version_i16_field("api_key", sv.api_key)?);
// Skip unknown API keys — the client does not support them.
let client_range = match api_key.supported_versions() {
Some(range) => range,
None => continue,
};
let server_min = i16::try_from(sv.min_version).unwrap();
let server_max = i16::try_from(sv.max_version).unwrap();
let server_min = api_version_i16_field("min_version", sv.min_version)?;
let server_max = api_version_i16_field("max_version", sv.max_version)?;
let min_version = client_range.min().0.max(server_min);
let max_version = client_range.max().0.min(server_max);
if min_version > max_version {
Expand All @@ -115,7 +115,7 @@ impl ServerApiVersions {
versions.insert(api_key, Ok(ApiVersion(max_version)));
}
}
Self { versions }
Ok(Self { versions })
}

/// Get the negotiated (highest usable) version for a given API key.
Expand All @@ -132,6 +132,14 @@ impl ServerApiVersions {
}
}

fn api_version_i16_field(field: &str, value: i32) -> Result<i16, Error> {
i16::try_from(value).map_err(|_| Error::UnsupportedVersion {
message: format!(
"ApiVersionsResponse field `{field}` has value {value}, which is outside i16 range"
),
})
}

/// Resolve the API version to use for a given API key.
fn resolve_api_version_for(
api_versions: Option<&ServerApiVersions>,
Expand Down Expand Up @@ -266,7 +274,7 @@ impl RpcClient {
let request = ApiVersionsRequest::new("fluss-rust", env!("CARGO_PKG_VERSION"));
let response = connection.request(request).await?;
validate_server_type(expected_server_type, response.server_type)?;
let api_versions = ServerApiVersions::new(&response.api_versions);
let api_versions = ServerApiVersions::new(&response.api_versions)?;
*connection.api_versions.lock() = Some(api_versions);
Ok(())
}
Expand Down Expand Up @@ -995,11 +1003,14 @@ mod tests {
tokio::spawn(mock_echo_server(server));

let conn = ServerConnectionInner::new(BufStream::new(client), usize::MAX, Arc::from("t"));
*conn.api_versions.lock() = Some(ServerApiVersions::new(&[PbApiVersion {
api_key: 1014,
min_version: 0,
max_version: 0,
}]));
*conn.api_versions.lock() = Some(
ServerApiVersions::new(&[PbApiVersion {
api_key: 1014,
min_version: 0,
max_version: 0,
}])
.unwrap(),
);

let before: Vec<_> = snapshotter.snapshot().into_vec();
let request_before = counter_for_label(&before, CLIENT_REQUESTS_TOTAL, "produce_log");
Expand Down Expand Up @@ -1040,11 +1051,14 @@ mod tests {
tokio::spawn(mock_echo_server(server));

let conn = ServerConnectionInner::new(BufStream::new(client), usize::MAX, Arc::from("t"));
*conn.api_versions.lock() = Some(ServerApiVersions::new(&[PbApiVersion {
api_key: 1012,
min_version: 0,
max_version: 0,
}]));
*conn.api_versions.lock() = Some(
ServerApiVersions::new(&[PbApiVersion {
api_key: 1012,
min_version: 0,
max_version: 0,
}])
.unwrap(),
);
let before: Vec<_> = snapshotter.snapshot().into_vec();
let request_sum_before = counter_sum(&before, CLIENT_REQUESTS_TOTAL);
let response_sum_before = counter_sum(&before, CLIENT_RESPONSES_TOTAL);
Expand Down Expand Up @@ -1081,11 +1095,14 @@ mod tests {
let (client, server) = tokio::io::duplex(64);
drop(server); // force write failure on request path
let conn = ServerConnectionInner::new(BufStream::new(client), usize::MAX, Arc::from("t"));
*conn.api_versions.lock() = Some(ServerApiVersions::new(&[PbApiVersion {
api_key: 1014,
min_version: 0,
max_version: 0,
}]));
*conn.api_versions.lock() = Some(
ServerApiVersions::new(&[PbApiVersion {
api_key: 1014,
min_version: 0,
max_version: 0,
}])
.unwrap(),
);

let before: Vec<_> = snapshotter.snapshot().into_vec();
let request_before = counter_for_label(&before, CLIENT_REQUESTS_TOTAL, "produce_log");
Expand Down Expand Up @@ -1134,11 +1151,14 @@ mod tests {
tokio::spawn(mock_error_server(server));

let conn = ServerConnectionInner::new(BufStream::new(client), usize::MAX, Arc::from("t"));
*conn.api_versions.lock() = Some(ServerApiVersions::new(&[PbApiVersion {
api_key: 1014,
min_version: 0,
max_version: 0,
}]));
*conn.api_versions.lock() = Some(
ServerApiVersions::new(&[PbApiVersion {
api_key: 1014,
min_version: 0,
max_version: 0,
}])
.unwrap(),
);

let before: Vec<_> = snapshotter.snapshot().into_vec();
let response_before = counter_for_label(&before, CLIENT_RESPONSES_TOTAL, "produce_log");
Expand Down Expand Up @@ -1211,7 +1231,7 @@ mod tests {
max_version: 5,
},
];
let negotiated = ServerApiVersions::new(&server_versions);
let negotiated = ServerApiVersions::new(&server_versions).unwrap();

// Successful negotiation cases
assert_eq!(
Expand Down Expand Up @@ -1247,11 +1267,54 @@ mod tests {
// Key not advertised by server → error
assert!(
ServerApiVersions::new(&[])
.unwrap()
.highest_available_version(ApiKey::FetchLog)
.is_err()
);
}

#[test]
fn server_api_versions_rejects_out_of_range_fields() {
let cases = [
(
"api_key",
PbApiVersion {
api_key: i32::from(i16::MAX) + 1,
min_version: 0,
max_version: 0,
},
),
(
"min_version",
PbApiVersion {
api_key: 1014,
min_version: i32::from(i16::MAX) + 1,
max_version: 0,
},
),
(
"max_version",
PbApiVersion {
api_key: 1014,
min_version: 0,
max_version: i32::from(i16::MAX) + 1,
},
),
];

for (field, api_version) in cases {
let err = ServerApiVersions::new(&[api_version]).unwrap_err();
assert!(
matches!(err, Error::UnsupportedVersion { .. }),
"expected UnsupportedVersion, got {err:?}"
);

let message = err.to_string();
assert!(message.contains(field), "message was: {message}");
assert!(message.contains("32768"), "message was: {message}");
}
}

#[test]
fn server_type_validation() {
// Happy path: server advertises the expected type.
Expand Down