Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add native tls WSS support. #881

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rumqttc/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Made `DisconnectProperties` struct public.
* Replace `Vec<Option<u16>>` with `FixedBitSet` for managing packet ids of released QoS 2 publishes and incoming QoS 2 publishes in `MqttState`.
* Accept `native_tls::TlsConnector` as input for `Transport::tls_with_config`.
* Accept `native_tls::TlsConnector` as input for `Transport::wss_with_config`.

### Deprecated

Expand Down
12 changes: 11 additions & 1 deletion rumqttc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ rustdoc-args = ["--cfg", "docsrs"]
[features]
default = ["use-rustls"]
use-rustls = ["dep:tokio-rustls", "dep:rustls-webpki", "dep:rustls-pemfile", "dep:rustls-native-certs"]
use-native-tls = ["dep:tokio-native-tls", "dep:native-tls"]
use-native-tls = ["dep:tokio-native-tls", "dep:native-tls", "async-tungstenite?/tokio-native-tls"]
websocket = ["dep:async-tungstenite", "dep:ws_stream_tungstenite", "dep:http"]
proxy = ["dep:async-http-proxy"]

Expand Down Expand Up @@ -83,3 +83,13 @@ required-features = ["websocket"]
name = "websocket_proxy"
path = "examples/websocket_proxy.rs"
required-features = ["websocket", "proxy"]

[[example]]
name = "wss"
path = "examples/wss.rs"
required-features = ["websocket"]

[[example]]
name = "wss_v5"
path = "examples/wss_v5.rs"
required-features = ["websocket"]
96 changes: 96 additions & 0 deletions rumqttc/examples/wss.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use rumqttc::{AsyncClient, MqttOptions, QoS};
use std::{error::Error, time::Duration};
use tokio::{task, time};

use rumqttc::Transport;
#[cfg(all(feature = "use-rustls", not(feature = "use-native-tls")))]
use tokio_rustls::rustls::ClientConfig;

#[cfg(feature = "use-native-tls")]
use tokio_native_tls::native_tls;

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn Error>> {
pretty_env_logger::init();

// port parameter is ignored when scheme is websocket
let mut mqttoptions = MqttOptions::new(
"test-1",
#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
"wss://test.mosquitto.org:8081",
#[cfg(not(any(feature = "use-rustls", feature = "use-native-tls")))]
"wss://test.mosquitto.org:8080",
8080,
);

#[cfg(feature = "use-native-tls")]
{
// Use native-tls to load root certificates from the operating system.
println!("Using native-tls to load root certificates from the operating system.");
let mut builder = native_tls::TlsConnector::builder();
let _pem = &vec![1, 2, 3];
// let _pem = include_bytes!("native-tls-cert.pem");
let cert = native_tls::Certificate::from_pem(_pem)?;
builder.add_root_certificate(cert);
let connector = builder.build()?;
mqttoptions.set_transport(Transport::wss_with_config(connector.into()));
}
#[cfg(all(feature = "use-rustls", not(feature = "use-native-tls")))]
{
// Use rustls-native-certs to load root certificates from the operating system.
println!("Using rustls-native-certs to load root certificates from the operating system.");
let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
root_cert_store.add_parsable_certificates(
rustls_native_certs::load_native_certs().expect("could not load platform certs"),
);

let client_config = ClientConfig::builder()
.with_root_certificates(root_cert_store)
.with_no_client_auth();

mqttoptions.set_transport(Transport::wss_with_config(client_config.into()));
}
#[cfg(not(any(feature = "use-rustls", feature = "use-native-tls")))]
{
mqttoptions.set_transport(Transport::Ws);
}

mqttoptions.set_keep_alive(Duration::from_secs(60));

let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10);
task::spawn(async move {
requests(client).await;
time::sleep(Duration::from_secs(3)).await;
});

loop {
let event = eventloop.poll().await;
match event {
Ok(notif) => {
println!("Event = {notif:?}");
}
Err(err) => {
println!("Error = {err:?}");
return Ok(());
}
}
}
}

async fn requests(client: AsyncClient) {
client
.subscribe("hello/world", QoS::AtMostOnce)
.await
.unwrap();

for i in 1..=10 {
client
.publish("hello/world", QoS::ExactlyOnce, false, vec![1; i])
.await
.unwrap();

time::sleep(Duration::from_secs(1)).await;
}

time::sleep(Duration::from_secs(120)).await;
}
96 changes: 96 additions & 0 deletions rumqttc/examples/wss_v5.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use rumqttc::{AsyncClient, MqttOptions, QoS};
use std::{error::Error, time::Duration};
use tokio::{task, time};

use rumqttc::Transport;
#[cfg(all(feature = "use-rustls", not(feature = "use-native-tls")))]
use tokio_rustls::rustls::ClientConfig;

#[cfg(feature = "use-native-tls")]
use tokio_native_tls::native_tls;

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn Error>> {
pretty_env_logger::init();

// port parameter is ignored when scheme is websocket
let mut mqttoptions = MqttOptions::new(
"test-1",
#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
"wss://test.mosquitto.org:8081",
#[cfg(not(any(feature = "use-rustls", feature = "use-native-tls")))]
"wss://test.mosquitto.org:8080",
8080,
);

#[cfg(feature = "use-native-tls")]
{
// Use native-tls to load root certificates from the operating system.
println!("Using native-tls to load root certificates from the operating system.");
let mut builder = native_tls::TlsConnector::builder();
let _pem = &vec![1, 2, 3];
// let _pem = include_bytes!("native-tls-cert.pem");
let cert = native_tls::Certificate::from_pem(_pem)?;
builder.add_root_certificate(cert);
let connector = builder.build()?;
mqttoptions.set_transport(Transport::wss_with_config(connector.into()));
}
#[cfg(all(feature = "use-rustls", not(feature = "use-native-tls")))]
{
// Use rustls-native-certs to load root certificates from the operating system.
println!("Using rustls-native-certs to load root certificates from the operating system.");
let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
root_cert_store.add_parsable_certificates(
rustls_native_certs::load_native_certs().expect("could not load platform certs"),
);

let client_config = ClientConfig::builder()
.with_root_certificates(root_cert_store)
.with_no_client_auth();

mqttoptions.set_transport(Transport::wss_with_config(client_config.into()));
}
#[cfg(not(any(feature = "use-rustls", feature = "use-native-tls")))]
{
mqttoptions.set_transport(Transport::Ws);
}

mqttoptions.set_keep_alive(Duration::from_secs(60));

let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10);
task::spawn(async move {
requests(client).await;
time::sleep(Duration::from_secs(3)).await;
});

loop {
let event = eventloop.poll().await;
match event {
Ok(notif) => {
println!("Event = {notif:?}");
}
Err(err) => {
println!("Error = {err:?}");
return Ok(());
}
}
}
}

async fn requests(client: AsyncClient) {
client
.subscribe("hello/world", QoS::AtMostOnce)
.await
.unwrap();

for i in 1..=10 {
client
.publish("hello/world", QoS::ExactlyOnce, false, vec![1; i])
.await
.unwrap();

time::sleep(Duration::from_secs(1)).await;
}

time::sleep(Duration::from_secs(120)).await;
}
9 changes: 7 additions & 2 deletions rumqttc/src/eventloop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ async fn network_connect(
let (domain, port) = match options.transport() {
#[cfg(feature = "websocket")]
Transport::Ws => split_url(&options.broker_addr)?,
#[cfg(all(feature = "use-rustls", feature = "websocket"))]
#[cfg(all(any(feature = "use-rustls", feature = "use-native-tls"), feature = "websocket"))]
Transport::Wss(_) => split_url(&options.broker_addr)?,
_ => options.broker_address(),
};
Expand Down Expand Up @@ -450,7 +450,7 @@ async fn network_connect(
options.max_outgoing_packet_size,
)
}
#[cfg(all(feature = "use-rustls", feature = "websocket"))]
#[cfg(all(any(feature = "use-rustls", feature = "use-native-tls"), feature = "websocket"))]
Transport::Wss(tls_config) => {
let mut request = options.broker_addr.as_str().into_client_request()?;
request
Expand All @@ -461,6 +461,11 @@ async fn network_connect(
request = request_modifier(request).await;
}

// Accept only one of tls features to avoid conflicts.
// When native-tls is enabled, rustls is as disabled.
#[cfg(feature = "use-native-tls")]
let connector = tls::native_tls_connector(&tls_config).await?;
#[cfg(all(feature = "use-rustls", not(feature = "use-native-tls")))]
let connector = tls::rustls_connector(&tls_config).await?;

let (socket, response) = async_tungstenite::tokio::client_async_tls_with_connector(
Expand Down
13 changes: 8 additions & 5 deletions rumqttc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ extern crate log;

use std::fmt::{self, Debug, Formatter};

#[cfg(any(feature = "use-rustls", feature = "websocket"))]
#[cfg(any(
feature = "use-rustls",
feature = "websocket"
))]
use std::sync::Arc;

use std::time::Duration;
Expand Down Expand Up @@ -233,8 +236,8 @@ pub enum Transport {
#[cfg(feature = "websocket")]
#[cfg_attr(docsrs, doc(cfg(feature = "websocket")))]
Ws,
#[cfg(all(feature = "use-rustls", feature = "websocket"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))]
#[cfg(all(any(feature = "use-rustls", feature = "use-native-tls"), feature = "websocket"))]
#[cfg_attr(docsrs, doc(cfg(all(any(feature = "use-rustls", feature = "use-native-tls"), feature = "websocket"))))]
Wss(TlsConfiguration),
}

Expand Down Expand Up @@ -305,8 +308,8 @@ impl Transport {
Self::wss_with_config(config)
}

#[cfg(all(feature = "use-rustls", feature = "websocket"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))]
#[cfg(all(any(feature = "use-rustls", feature = "use-native-tls"), feature = "websocket"))]
#[cfg_attr(docsrs, doc(cfg(all(any(feature = "use-rustls", feature = "use-native-tls"), feature = "websocket"))))]
pub fn wss_with_config(tls_config: TlsConfiguration) -> Self {
Self::Wss(tls_config)
}
Expand Down
9 changes: 7 additions & 2 deletions rumqttc/src/v5/eventloop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ async fn network_connect(options: &MqttOptions) -> Result<Network, ConnectionErr
let (domain, port) = match options.transport() {
#[cfg(feature = "websocket")]
Transport::Ws => split_url(&options.broker_addr)?,
#[cfg(all(feature = "use-rustls", feature = "websocket"))]
#[cfg(all(any(feature = "use-rustls", feature = "use-native-tls"), feature = "websocket"))]
Transport::Wss(_) => split_url(&options.broker_addr)?,
_ => options.broker_address(),
};
Expand Down Expand Up @@ -366,7 +366,7 @@ async fn network_connect(options: &MqttOptions) -> Result<Network, ConnectionErr

Network::new(WsStream::new(socket), max_incoming_pkt_size)
}
#[cfg(all(feature = "use-rustls", feature = "websocket"))]
#[cfg(all(any(feature = "use-rustls", feature = "use-native-tls"), feature = "websocket"))]
Transport::Wss(tls_config) => {
let mut request = options.broker_addr.as_str().into_client_request()?;
request
Expand All @@ -377,6 +377,11 @@ async fn network_connect(options: &MqttOptions) -> Result<Network, ConnectionErr
request = request_modifier(request).await;
}

// Accept only one of tls features to avoid conflicts.
// When native-tls is enabled, rustls is as disabled.
#[cfg(feature = "use-native-tls")]
let connector = tls::native_tls_connector(&tls_config).await?;
#[cfg(all(feature = "use-rustls", not(feature = "use-native-tls")))]
let connector = tls::rustls_connector(&tls_config).await?;

let (socket, response) = async_tungstenite::tokio::client_async_tls_with_connector(
Expand Down