Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions integration/rust/tests/integration/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub mod reset;
pub mod rewrite;
pub mod rewrite_omni;
pub mod savepoint;
pub mod set_config;
pub mod set_in_transaction;
pub mod set_sharding_key;
pub mod shard_consistency;
Expand Down
57 changes: 57 additions & 0 deletions integration/rust/tests/integration/set_config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use rust::setup::connections_sqlx;
use sqlx::query_scalar;
use std::assert_matches;

#[tokio::test]
async fn test_set_config_behaves_like_set() {
let pool = connections_sqlx().await.pop().unwrap();
let mut conn = pool.acquire().await.unwrap();
let set_config: String = query_scalar("SELECT set_config('lock_timeout', '1000s', false);")
.fetch_one(&mut *conn)
.await
.unwrap();
assert_eq!(set_config, "1000s");
let lock_timeout: String = query_scalar("SHOW lock_timeout")
.fetch_one(&mut *conn)
.await
.unwrap();
assert_eq!(lock_timeout, "1000s");

let set_config: String = query_scalar("SELECT set_config('lock_timeout', '500s', false);")
.fetch_one(&mut *conn)
.await
.unwrap();
assert_eq!(set_config, "500s");
let lock_timeout: String = query_scalar("SHOW lock_timeout")
.fetch_one(&mut *conn)
.await
.unwrap();
assert_eq!(lock_timeout, "500s");

let set_config: Option<String> =
query_scalar("SELECT set_config('lock_timeout', NULL, false);")
.fetch_one(&mut *conn)
.await
.unwrap();
assert_eq!(set_config, None);

let lock_timeout: String = query_scalar("SHOW lock_timeout")
.fetch_one(&mut *conn)
.await
.unwrap();
assert_eq!(lock_timeout, "0");
}

#[tokio::test]
// We can't handle every possible invocation yet, but we at least shouldn't
// error
async fn test_set_config_does_something_when_unable_to_resolve_args() {
let pool = connections_sqlx().await.pop().unwrap();
let mut conn = pool.acquire().await.unwrap();

let set_config: Result<String, _> =
query_scalar("SELECT set_config('lock_timeout', (SELECT '1'), false);")
.fetch_one(&mut *conn)
.await;
assert_matches!(set_config, Ok(_));
}
35 changes: 30 additions & 5 deletions pgdog/src/frontend/client/query_engine/fake.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use tokio::io::AsyncWriteExt;

use crate::net::{
BindComplete, CloseComplete, CommandComplete, NoData, ParameterDescription, ParseComplete,
ProtocolMessage, ReadyForQuery, RowDescription,
BindComplete, CloseComplete, CommandComplete, DataRow, Field, NoData, ParameterDescription,
ParseComplete, ProtocolMessage, ReadyForQuery, RowDescription, parameter::ParameterValue,
};

use super::*;
Expand All @@ -14,8 +14,24 @@ impl QueryEngine {
&mut self,
context: &mut QueryEngineContext<'_>,
command: &str,
return_value: Option<impl IntoIterator<Item = Option<&'_ ParameterValue>> + Clone>,
) -> Result<(), Error> {
let mut sent = 0;
let return_fields = return_value
.clone()
.into_iter()
.flatten()
// FIXME(sage): Don't assume `set_config` is the only consumer
.map(|_| Field::text("set_config"))
.collect::<Vec<_>>();
let row_description = RowDescription::new(&return_fields);
let data_row = return_value.map(|return_value| {
let mut row = DataRow::new();
for val in return_value {
row.add(val);
}
row
});
for message in context.client_request.iter() {
sent += match message {
ProtocolMessage::Parse(_) => context.stream.send(&ParseComplete).await?,
Expand All @@ -26,13 +42,17 @@ impl QueryEngine {
.stream
.send(&ParameterDescription::default())
.await?
+ context.stream.send(&RowDescription::default()).await?
+ context.stream.send(&row_description).await?
} else {
context.stream.send(&NoData).await?
}
}
ProtocolMessage::Execute(_) => {
context.stream.send(&CommandComplete::new(command)).await?
(if let Some(row) = data_row.as_ref() {
context.stream.send(row).await?
} else {
0
}) + context.stream.send(&CommandComplete::new(command)).await?
}
ProtocolMessage::Sync(_) => {
context
Expand All @@ -41,7 +61,12 @@ impl QueryEngine {
.await?
}
ProtocolMessage::Query(_) => {
context.stream.send(&CommandComplete::new(command)).await?
(if let Some(row) = data_row.as_ref() {
context.stream.send(&row_description).await?
+ context.stream.send(row).await?
} else {
0
}) + context.stream.send(&CommandComplete::new(command)).await?
+ context
.stream
.send(&ReadyForQuery::in_transaction(context.in_transaction()))
Expand Down
8 changes: 6 additions & 2 deletions pgdog/src/frontend/client/query_engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,13 @@ impl QueryEngine {
.await?
}
Command::Unlisten(channel) => self.unlisten(context, &channel.clone()).await?,
Command::Set { params, .. } => {
Command::Set {
params,
behave_like_select,
..
} => {
let params = params.clone();
self.set(context, &params).await?;
self.set(context, &params, *behave_like_select).await?;
}
Command::ResetAll => {
self.reset_all(context).await?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ impl<'a> UpdateMulti<'a> {
// This happens, but the UPDATE's WHERE clause
// doesn't match any rows, so this whole thing is a no-op.
self.engine
.fake_command_response(context, "UPDATE 0")
.fake_command_response(context, "UPDATE 0", None::<Option<_>>)
.await?;
}

Expand Down
27 changes: 17 additions & 10 deletions pgdog/src/frontend/client/query_engine/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,21 @@ impl QueryEngine {
&mut self,
context: &mut QueryEngineContext<'_>,
params: &[SetParam],
behave_like_select: bool,
) -> Result<(), Error> {
let mut fake_command = "SET";
for param in params {
if param.reset {
context.params.reset(&param.name);
fake_command = "RESET";
} else if context.in_transaction() {
context
.params
.insert_transaction(&param.name, param.value.clone(), param.local);
if let Some(value) = param.value.clone() {
if context.in_transaction() {
context
.params
.insert_transaction(&param.name, value, param.local);
} else {
context.params.insert(&param.name, value);
}
} else {
context.params.insert(&param.name, param.value.clone());
fake_command = "RESET";
context.params.reset(&param.name);
}
}

Expand All @@ -29,7 +32,10 @@ impl QueryEngine {
if self.backend.connected() {
self.execute(context).await?;
} else {
self.fake_command_response(context, fake_command).await?;
let values_to_return =
behave_like_select.then(|| params.iter().map(|p| p.value.as_ref()));
self.fake_command_response(context, fake_command, values_to_return)
.await?;
}

Ok(())
Expand All @@ -44,7 +50,8 @@ impl QueryEngine {
if self.backend.connected() {
self.execute(context).await?;
} else {
self.fake_command_response(context, "RESET").await?;
self.fake_command_response(context, "RESET", None::<Option<_>>)
.await?;
}

Ok(())
Expand Down
4 changes: 2 additions & 2 deletions pgdog/src/frontend/router/parser/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ use lazy_static::lazy_static;
#[derive(Debug, Clone, PartialEq)]
pub struct SetParam {
pub name: String,
pub value: ParameterValue,
pub value: Option<ParameterValue>,
pub local: bool,
pub reset: bool,
}

#[derive(Debug, Clone)]
Expand All @@ -33,6 +32,7 @@ pub enum Command {
Set {
params: Vec<SetParam>,
route: Route,
behave_like_select: bool,
},
ResetAll,
PreparedStatement(Prepare),
Expand Down
1 change: 1 addition & 0 deletions pgdog/src/frontend/router/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub mod sequence;
pub mod statement;
pub mod table;
pub mod tuple;
pub(crate) mod util;
pub mod value;
pub mod where_clause;

Expand Down
37 changes: 36 additions & 1 deletion pgdog/src/frontend/router/parser/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use crate::{
config::Role,
frontend::router::{
context::RouterContext,
parser::{OrderBy, Shard},
parser::{
OrderBy, Shard,
util::{PgStr, pg_str},
},
round_robin,
sharding::{Centroids, ContextBuilder},
},
Expand All @@ -28,6 +31,7 @@ mod plugins;
mod schema_sharding;
mod select;
mod set;
mod set_config;
mod shared;
mod show;
mod transaction;
Expand Down Expand Up @@ -284,6 +288,14 @@ impl QueryParser {
let mut command = match root.node {
// SET statements -> return immediately.
Some(NodeEnum::VariableSetStmt(ref stmt)) => return self.set(stmt, context),

// SELECT set_config(...) -> treat as SET and return
Some(NodeEnum::SelectStmt(ref stmt))
if let Some(set_config) = extract_set_config(stmt) =>
{
return Ok(self.set_config(set_config, context));
}

// SHOW statements -> return immediately.
Some(NodeEnum::VariableShowStmt(ref stmt)) => return self.show(stmt, context),
// DEALLOCATE statements -> return immediately.
Expand Down Expand Up @@ -566,5 +578,28 @@ impl QueryParser {
}
}

fn extract_set_config(stmt: &SelectStmt) -> Option<&FuncCall> {
static SET_CONFIG: &[&[PgStr<'static>]] = &[
&[pg_str("pg_catalog"), pg_str("set_config")],
&[pg_str("set_config")],
];
// FIXME(sage): Dear god we need some pattern macros for this
if let [
Node {
node: Some(NodeEnum::ResTarget(r)),
},
] = &*stmt.target_list
&& let ResTarget { val: Some(n), .. } = &**r
&& let Node {
node: Some(NodeEnum::FuncCall(f)),
} = &**n
&& SET_CONFIG.iter().any(|&n| n == f.funcname)
{
Some(f)
} else {
None
}
}

#[cfg(test)]
mod test;
8 changes: 4 additions & 4 deletions pgdog/src/frontend/router/parser/query/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ impl QueryParser {
return Ok(Command::Set {
params: vec![param],
route: Route::write(context.shards_calculator.shard()),
behave_like_select: false,
});
}

Expand All @@ -40,17 +41,15 @@ impl QueryParser {
let value = Self::parse_set_value(stmt)?;

match value {
Some(value) => Ok(Some(SetParam {
value @ Some(_) => Ok(Some(SetParam {
name: stmt.name.to_string(),
value,
local: stmt.is_local,
reset: is_reset,
})),
None if is_reset => Ok(Some(SetParam {
name: stmt.name.to_string(),
value: ParameterValue::String(std::string::String::new()),
value: None,
local: false,
reset: true,
})),
None => Ok(None),
}
Expand Down Expand Up @@ -103,6 +102,7 @@ impl QueryParser {
Ok(Some(Command::Set {
params,
route: Route::write(context.shards_calculator.shard()),
behave_like_select: false,
}))
}

Expand Down
Loading
Loading