Skip to content

Commit

Permalink
Migrations over WebSocket (#5010)
Browse files Browse the repository at this point in the history
* [WIP]: Migrations over WebSocket

* Avoid panics in error handling code

* Fix compilation

* Handle unauthorized error

* Fix wasm build

* Cargo fmt

* Use correct urls

* Restore Cargo.toml

* Fix TLS and api_key

* TLS support

* Fix fixed TLS

* Implement TLS handshake with proper host name

* Update quaint/src/connector/postgres/native/websocket.rs

Co-authored-by: Alberto Schiabel <jkomyno@users.noreply.github.com>

* Remove dbg

* Feedback & cleanup

---------

Co-authored-by: Alberto Schiabel <jkomyno@users.noreply.github.com>
  • Loading branch information
SevInf and jkomyno authored Oct 10, 2024
1 parent 6a192e2 commit edd552c
Show file tree
Hide file tree
Showing 15 changed files with 437 additions and 77 deletions.
111 changes: 103 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,10 @@ impl Connector for PostgresDatamodelConnector {
}

fn validate_url(&self, url: &str) -> Result<(), String> {
if !url.starts_with("postgres://") && !url.starts_with("postgresql://") {
if !url.starts_with("postgres://")
&& !url.starts_with("postgresql://")
&& !url.starts_with("prisma+postgres://")
{
return Err("must start with the protocol `postgresql://` or `postgres://`.".to_owned());
}

Expand Down
14 changes: 13 additions & 1 deletion quaint/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ postgresql-native = [
"bit-vec",
"lru-cache",
"byteorder",
"dep:ws_stream_tungstenite",
"dep:async-tungstenite"
]
postgresql = []

Expand Down Expand Up @@ -111,6 +113,16 @@ expect-test = "1"
version = "0.2"
features = ["js"]

[dependencies.ws_stream_tungstenite]
version = "0.14.0"
features = ["tokio_io"]
optional = true

[dependencies.async-tungstenite]
version = "0.28.0"
features = ["tokio-runtime", "tokio-native-tls"]
optional = true

[dependencies.byteorder]
default-features = false
optional = true
Expand Down Expand Up @@ -180,7 +192,7 @@ features = ["rt-multi-thread", "macros", "sync"]
optional = true

[dependencies.tokio-util]
version = "0.6"
version = "0.7"
features = ["compat"]
optional = true

Expand Down
4 changes: 2 additions & 2 deletions quaint/src/connector/connection_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl ConnectionInfo {
}
#[cfg(feature = "postgresql")]
SqlFamily::Postgres => Ok(ConnectionInfo::Native(NativeConnectionInfo::Postgres(
PostgresUrl::new(url)?,
super::PostgresUrl::new_native(url)?,
))),
#[allow(unreachable_patterns)]
_ => unreachable!(),
Expand Down Expand Up @@ -243,7 +243,7 @@ impl ConnectionInfo {
pub fn pg_bouncer(&self) -> bool {
match self {
#[cfg(all(not(target_arch = "wasm32"), feature = "postgresql"))]
ConnectionInfo::Native(NativeConnectionInfo::Postgres(url)) => url.pg_bouncer(),
ConnectionInfo::Native(NativeConnectionInfo::Postgres(PostgresUrl::Native(url))) => url.pg_bouncer(),
_ => false,
}
}
Expand Down
35 changes: 27 additions & 8 deletions quaint/src/connector/postgres/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ pub(crate) mod column_type;
mod conversion;
mod error;
mod explain;
mod websocket;

pub(crate) use crate::connector::postgres::url::PostgresUrl;
pub(crate) use crate::connector::postgres::url::PostgresNativeUrl;
use crate::connector::postgres::url::{Hidden, SslAcceptMode, SslParams};
use crate::connector::{
timeout, ColumnType, DescribedColumn, DescribedParameter, DescribedQuery, IsolationLevel, Transaction,
Expand Down Expand Up @@ -37,12 +38,15 @@ use std::{
time::Duration,
};
use tokio_postgres::{config::ChannelBinding, Client, Config, Statement};
use websocket::connect_via_websocket;

/// The underlying postgres driver. Only available with the `expose-drivers`
/// Cargo feature.
#[cfg(feature = "expose-drivers")]
pub use tokio_postgres;

use super::PostgresWebSocketUrl;

struct PostgresClient(Client);

impl Debug for PostgresClient {
Expand Down Expand Up @@ -160,7 +164,7 @@ impl SslParams {
}
}

impl PostgresUrl {
impl PostgresNativeUrl {
pub(crate) fn cache(&self) -> StatementCache {
if self.query_params.pg_bouncer {
StatementCache::new(0)
Expand Down Expand Up @@ -228,7 +232,7 @@ impl PostgresUrl {

impl PostgreSql {
/// Create a new connection to the database.
pub async fn new(url: PostgresUrl) -> crate::Result<Self> {
pub async fn new(url: PostgresNativeUrl) -> crate::Result<Self> {
let config = url.to_config();

let mut tls_builder = TlsConnector::builder();
Expand Down Expand Up @@ -292,6 +296,21 @@ impl PostgreSql {
})
}

/// Create a new websocket connection to managed database
pub async fn new_with_websocket(url: PostgresWebSocketUrl) -> crate::Result<Self> {
let client = connect_via_websocket(url).await?;

Ok(Self {
client: PostgresClient(client),
socket_timeout: None,
pg_bouncer: false,
statement_cache: Mutex::new(StatementCache::new(0)),
is_healthy: AtomicBool::new(true),
is_cockroachdb: false,
is_materialize: false,
})
}

/// The underlying tokio_postgres::Client. Only available with the
/// `expose-drivers` Cargo feature. This is a lower level API when you need
/// to get into database specific features.
Expand Down Expand Up @@ -922,7 +941,7 @@ mod tests {
let mut url = Url::parse(&CONN_STR).unwrap();
url.query_pairs_mut().append_pair("schema", schema_name);

let mut pg_url = PostgresUrl::new(url).unwrap();
let mut pg_url = PostgresNativeUrl::new(url).unwrap();
pg_url.set_flavour(PostgresFlavour::Postgres);

let client = PostgreSql::new(pg_url).await.unwrap();
Expand Down Expand Up @@ -974,7 +993,7 @@ mod tests {
url.query_pairs_mut().append_pair("schema", schema_name);
url.query_pairs_mut().append_pair("pbbouncer", "true");

let mut pg_url = PostgresUrl::new(url).unwrap();
let mut pg_url = PostgresNativeUrl::new(url).unwrap();
pg_url.set_flavour(PostgresFlavour::Postgres);

let client = PostgreSql::new(pg_url).await.unwrap();
Expand Down Expand Up @@ -1025,7 +1044,7 @@ mod tests {
let mut url = Url::parse(&CRDB_CONN_STR).unwrap();
url.query_pairs_mut().append_pair("schema", schema_name);

let mut pg_url = PostgresUrl::new(url).unwrap();
let mut pg_url = PostgresNativeUrl::new(url).unwrap();
pg_url.set_flavour(PostgresFlavour::Cockroach);

let client = PostgreSql::new(pg_url).await.unwrap();
Expand Down Expand Up @@ -1076,7 +1095,7 @@ mod tests {
let mut url = Url::parse(&CONN_STR).unwrap();
url.query_pairs_mut().append_pair("schema", schema_name);

let mut pg_url = PostgresUrl::new(url).unwrap();
let mut pg_url = PostgresNativeUrl::new(url).unwrap();
pg_url.set_flavour(PostgresFlavour::Unknown);

let client = PostgreSql::new(pg_url).await.unwrap();
Expand Down Expand Up @@ -1127,7 +1146,7 @@ mod tests {
let mut url = Url::parse(&CONN_STR).unwrap();
url.query_pairs_mut().append_pair("schema", schema_name);

let mut pg_url = PostgresUrl::new(url).unwrap();
let mut pg_url = PostgresNativeUrl::new(url).unwrap();
pg_url.set_flavour(PostgresFlavour::Unknown);

let client = PostgreSql::new(pg_url).await.unwrap();
Expand Down
Loading

0 comments on commit edd552c

Please sign in to comment.