Skip to content

Commit

Permalink
Implement TLS handshake with proper host name
Browse files Browse the repository at this point in the history
  • Loading branch information
SevInf committed Oct 9, 2024
1 parent 6f42a5b commit bc9fcf9
Showing 1 changed file with 25 additions and 12 deletions.
37 changes: 25 additions & 12 deletions quaint/src/connector/postgres/native/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use async_tungstenite::{
tungstenite::{
self,
client::IntoClientRequest,
http::{HeaderValue, StatusCode},
http::{HeaderMap, HeaderValue, StatusCode},
Error as TungsteniteError,
},
};
Expand All @@ -20,26 +20,25 @@ use crate::{
};

const CONNECTION_PARAMS_HEADER: &str = "Prisma-Connection-Parameters";
const HOST_HEADER: &str = "Prisma-Db-Host";

pub(crate) async fn connect_via_websocket(url: PostgresWebSocketUrl) -> crate::Result<Client> {
let (ws_stream, response) = connect_async(url).await.inspect_err(|e| {
eprintln!("{}", e);
dbg!(&e);
if let TungsteniteError::Http(response) = e {
dbg!(String::from_utf8(response.body().clone().unwrap()).unwrap());
}
})?;

let Some(header) = response.headers().get(CONNECTION_PARAMS_HEADER) else {
let message = format!("Missing response header {CONNECTION_PARAMS_HEADER}");
let error = Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(message.into()))).build();
return Err(error);
};

let connection_params = header.to_str().map_err(|inner| {
Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(Box::new(inner)))).build()
})?;
let connection_params = require_header_value(response.headers(), CONNECTION_PARAMS_HEADER)?;
dbg!(&connection_params);
let db_host = require_header_value(response.headers(), HOST_HEADER)?;
dbg!(&connection_params);

let config = Config::from_str(connection_params)?;
let ws_byte_stream = WsStream::new(ws_stream);

let tls = TlsConnector::new(native_tls::TlsConnector::new()?, "TODO");
let tls = TlsConnector::new(native_tls::TlsConnector::new()?, db_host);
let (client, connection) = config.connect_raw(ws_byte_stream, tls).await?;
tokio::spawn(connection.map(|r| match r {
Ok(_) => (),
Expand All @@ -50,6 +49,20 @@ pub(crate) async fn connect_via_websocket(url: PostgresWebSocketUrl) -> crate::R
Ok(client)
}

fn require_header_value<'a>(headers: &'a HeaderMap, name: &str) -> crate::Result<&'a str> {
let Some(header) = headers.get(name) else {
let message = format!("Missing response header {name}");
let error = Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(message.into()))).build();
return Err(error);
};

let value = header.to_str().map_err(|inner| {
Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(Box::new(inner)))).build()
})?;

Ok(value)
}

impl IntoClientRequest for PostgresWebSocketUrl {
fn into_client_request(self) -> tungstenite::Result<tungstenite::handshake::client::Request> {
let mut request = self.url.to_string().into_client_request()?;
Expand Down

0 comments on commit bc9fcf9

Please sign in to comment.