Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 106 additions & 29 deletions integration/rust/tests/integration/notify.rs
Original file line number Diff line number Diff line change
@@ -1,55 +1,62 @@
use std::sync::Arc;
use std::time::Duration;
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use parking_lot::Mutex;
use sqlx::{Connection, Executor, PgConnection, postgres::PgListener};
use tokio::{select, spawn, sync::Barrier, time::timeout};
use rust::setup::admin_sqlx;
use rust_decimal::prelude::ToPrimitive;
use sqlx::{Connection, Executor, PgConnection, Pool, Postgres, Row, postgres::PgListener};
use tokio::{
select, spawn,
sync::Barrier,
time::{sleep, timeout},
};

#[tokio::test]
async fn test_notify() {
let messages = Arc::new(Mutex::new(vec![]));
let mut tasks = vec![];
let mut listeners = vec![];
let barrier = Arc::new(Barrier::new(5));

for i in 0..5 {
let test_id = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos();
let channels = (0..5)
.map(|i| format!("test_notify_{test_id}_{i}"))
.collect::<Vec<_>>();
let received_barrier = Arc::new(Barrier::new(channels.len() + 1));
let release_barrier = Arc::new(Barrier::new(channels.len() + 1));

for channel in &channels {
let task_msgs = messages.clone();

let mut listener = PgListener::connect("postgres://pgdog:pgdog@127.0.0.1:6432/pgdog")
.await
.unwrap();

listener
.listen(format!("test_notify_{}", i).as_str())
.await
.unwrap();
listener.listen(channel).await.unwrap();

let barrier = barrier.clone();
let received_barrier = received_barrier.clone();
let release_barrier = release_barrier.clone();
listeners.push(spawn(async move {
let mut received = 0;
loop {
select! {
msg = listener.recv() => {
let msg = msg.unwrap();
received += 1;
task_msgs.lock().push(msg);
if received == 10 {
break;
}
}

}
while received < 10 {
let msg = listener.recv().await.unwrap();
received += 1;
task_msgs.lock().push(msg);
}
barrier.wait().await;

received_barrier.wait().await;
release_barrier.wait().await;
}));
}

for i in 0..50 {
let channel = channels[i % channels.len()].clone();
let handle = spawn(async move {
let mut conn = PgConnection::connect("postgres://pgdog:pgdog@127.0.0.1:6432/pgdog")
.await
.unwrap();
conn.execute(format!("NOTIFY test_notify_{}, 'test_notify_{}'", i % 5, i % 5).as_str())
conn.execute(format!("NOTIFY {channel}, '{channel}'").as_str())
.await
.unwrap();
});
Expand All @@ -61,17 +68,87 @@ async fn test_notify() {
task.await.unwrap();
}

received_barrier.wait().await;

{
let messages = messages.lock();
assert_eq!(messages.len(), 50);
for message in messages.iter() {
assert_eq!(message.channel(), message.payload());
}
}

assert_listener_stats(&channels, 10).await;

release_barrier.wait().await;
for listener in listeners {
listener.await.unwrap();
}
}

assert_eq!(messages.lock().len(), 50);
let messages = messages.lock();
for message in messages.iter() {
assert_eq!(message.channel(), message.payload());
#[derive(Debug)]
struct ListenerStats {
listeners: i64,
received: i64,
dropped: i64,
}

async fn assert_listener_stats(channels: &[String], expected_received: i64) {
let admin = admin_sqlx().await;
let stats = wait_for_listener_stats(&admin, channels, expected_received).await;

for (channel, stats) in channels.iter().zip(stats) {
assert_eq!(stats.listeners, 1, "{channel} listener count");
assert_eq!(
stats.received, expected_received,
"{channel} received count"
);
assert_eq!(stats.dropped, 0, "{channel} dropped count");
}
}

async fn wait_for_listener_stats(
admin: &Pool<Postgres>,
channels: &[String],
expected_received: i64,
) -> Vec<ListenerStats> {
timeout(Duration::from_secs(5), async {
loop {
let rows = admin.fetch_all("SHOW LISTENERS").await.unwrap();
let stats = channels
.iter()
.filter_map(|channel| {
rows.iter()
.find(|row| row.get::<String, _>("channel") == *channel)
.map(|row| ListenerStats {
listeners: numeric_column(row, "listeners"),
received: numeric_column(row, "received"),
dropped: numeric_column(row, "dropped"),
})
})
.collect::<Vec<_>>();

if stats.len() == channels.len()
&& stats
.iter()
.all(|stats| stats.listeners >= 1 && stats.received >= expected_received)
{
return stats;
}

sleep(Duration::from_millis(10)).await;
}
})
.await
.unwrap()
}

fn numeric_column(row: &sqlx::postgres::PgRow, column: &str) -> i64 {
row.get::<rust_decimal::Decimal, _>(column)
.to_i64()
.unwrap()
}

#[tokio::test]
async fn test_notify_only_delivered_after_transaction_commit() {
let messages = Arc::new(Mutex::new(vec![]));
Expand Down
2 changes: 2 additions & 0 deletions pgdog/src/admin/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub mod show_client_memory;
pub mod show_clients;
pub mod show_config;
pub mod show_instance_id;
pub mod show_listeners;
pub mod show_lists;
pub mod show_mirrors;
pub mod show_peers;
Expand Down Expand Up @@ -72,6 +73,7 @@ pub use show_client_memory::*;
pub use show_clients::*;
pub use show_config::*;
pub use show_instance_id::*;
pub use show_listeners::*;
pub use show_lists::*;
pub use show_mirrors::*;
pub use show_peers::*;
Expand Down
10 changes: 10 additions & 0 deletions pgdog/src/admin/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub enum ParseResult {
SetupSchema(SetupSchema),
Shutdown(Shutdown),
ShowLists(ShowLists),
ShowListeners(ShowListeners),
ShowPrepared(ShowPreparedStatements),
ShowReplication(ShowReplication),
ShowServerMemory(ShowServerMemory),
Expand Down Expand Up @@ -71,6 +72,7 @@ impl ParseResult {
SetupSchema(setup_schema) => setup_schema.execute().await,
Shutdown(shutdown) => shutdown.execute().await,
ShowLists(show_lists) => show_lists.execute().await,
ShowListeners(show_listeners) => show_listeners.execute().await,
ShowPrepared(cmd) => cmd.execute().await,
ShowReplication(show_replication) => show_replication.execute().await,
ShowServerMemory(show_server_memory) => show_server_memory.execute().await,
Expand Down Expand Up @@ -117,6 +119,7 @@ impl ParseResult {
SetupSchema(setup_schema) => setup_schema.name(),
Shutdown(shutdown) => shutdown.name(),
ShowLists(show_lists) => show_lists.name(),
ShowListeners(show_listeners) => show_listeners.name(),
ShowPrepared(show) => show.name(),
ShowReplication(show_replication) => show_replication.name(),
ShowServerMemory(show_server_memory) => show_server_memory.name(),
Expand Down Expand Up @@ -183,6 +186,7 @@ impl Parser {
"version" => ParseResult::ShowVersion(ShowVersion::parse(&sql)?),
"instance_id" => ParseResult::ShowInstanceId(ShowInstanceId::parse(&sql)?),
"lists" => ParseResult::ShowLists(ShowLists::parse(&sql)?),
"listeners" => ParseResult::ShowListeners(ShowListeners::parse(&sql)?),
"prepared" => ParseResult::ShowPrepared(ShowPreparedStatements::parse(&sql)?),
"replication" => ParseResult::ShowReplication(ShowReplication::parse(&sql)?),
"replication_slots" => {
Expand Down Expand Up @@ -265,6 +269,12 @@ mod tests {
assert!(matches!(result, Ok(ParseResult::ShowClientMemory(_))));
}

#[test]
fn parses_show_listeners_command() {
let result = Parser::parse("SHOW LISTENERS;");
assert!(matches!(result, Ok(ParseResult::ShowListeners(_))));
}

#[test]
fn parses_cutover_command() {
let result = Parser::parse("CUTOVER");
Expand Down
72 changes: 72 additions & 0 deletions pgdog/src/admin/show_listeners.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
//! SHOW LISTENERS.

use crate::backend::pub_sub::listener;

use super::prelude::*;

pub struct ShowListeners;

#[async_trait]
impl Command for ShowListeners {
fn name(&self) -> String {
"SHOW LISTENERS".into()
}

fn parse(_: &str) -> Result<Self, Error> {
Ok(Self)
}

async fn execute(&self) -> Result<Vec<Message>, Error> {
let mut channels: Vec<_> = listener::stats().into_iter().collect();
channels.sort_by(|a, b| a.0.cmp(&b.0));

let mut messages = vec![
RowDescription::new(&[
Field::text("channel"),
Field::numeric("listeners"),
Field::numeric("received"),
Field::numeric("dropped"),
])
.message()?,
];

for (channel, stats) in channels {
let mut data_row = DataRow::new();
data_row
.add(channel.as_str())
.add(stats.listeners as i64)
.add(stats.recv as i64)
.add(stats.dropped as i64);
messages.push(data_row.message()?);
}

Ok(messages)
}
}

#[cfg(test)]
mod tests {
use crate::net::{FromBytes, RowDescription};

use super::*;

#[tokio::test]
async fn show_listeners_reports_columns() {
let messages = ShowListeners
.execute()
.await
.expect("show listeners should execute");

assert_eq!(messages[0].code(), 'T');

let row_description =
RowDescription::from_bytes(messages[0].payload()).expect("row description parses");
let columns: Vec<&str> = row_description
.fields
.iter()
.map(|field| field.name.as_str())
.collect();

assert_eq!(columns, ["channel", "listeners", "received", "dropped"]);
}
}
4 changes: 2 additions & 2 deletions pgdog/src/backend/pool/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,8 @@ impl Connection {
};

if let Some(shard) = self.cluster()?.shards().get(num) {
let rx = shard.listen(channel).await?;
self.pub_sub.listen(channel, rx);
let listener = shard.listen(channel).await?;
self.pub_sub.listen(channel, listener);
}

Ok(())
Expand Down
10 changes: 4 additions & 6 deletions pgdog/src/backend/pool/shard/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@ use arc_swap::ArcSwap;
use std::ops::Deref;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{OnceCell, broadcast};
use tokio::sync::OnceCell;
use tokio::{select, spawn, sync::Notify};
use tracing::{debug, info};

use crate::backend::PubSubListener;
use crate::backend::Schema;
use crate::backend::databases::User;
use crate::backend::pool::lb::ban::Ban;
use crate::backend::pub_sub::listener::Listener;
use crate::config::{LoadBalancingStrategy, ReadWriteSplit, Role};
use crate::net::Parameters;
use crate::net::messages::FrontendPid;
use crate::net::{NotificationResponse, Parameters};

use super::{Error, Guard, LoadBalancer, Pool, PoolConfig, Request};

Expand Down Expand Up @@ -103,10 +104,7 @@ impl Shard {
}

/// Listen for notifications on channel.
pub async fn listen(
&self,
channel: &str,
) -> Result<broadcast::Receiver<NotificationResponse>, Error> {
pub async fn listen(&self, channel: &str) -> Result<Listener, Error> {
match self.pub_sub.load_full().deref() {
Some(listener) => listener.listen(channel).await,
_ => Err(Error::PubSubDisabled),
Expand Down
Loading
Loading