From 6d312f6a9f8e2ad5fef083fe8405bd446ffded28 Mon Sep 17 00:00:00 2001 From: Marc Mettke Date: Sun, 28 Jun 2020 23:10:23 +0200 Subject: [PATCH] Remove async_trait and refactor (#108) Resolves #80. --- .travis.yml | 6 +- Cargo.toml | 3 - src/async_internal.rs | 185 ----------------------------- src/lib.rs | 136 +++++++++++++-------- src/{reqwest/mod.rs => reqwest.rs} | 67 +++++++++-- src/reqwest/async_client.rs | 51 -------- src/tests.rs | 6 +- 7 files changed, 148 insertions(+), 306 deletions(-) delete mode 100644 src/async_internal.rs rename src/{reqwest/mod.rs => reqwest.rs} (59%) delete mode 100644 src/reqwest/async_client.rs diff --git a/.travis.yml b/.travis.yml index 38951ddd..ab3d803e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -12,10 +12,8 @@ script: - cargo test --tests --examples - cargo test --doc - cargo test --all-features - # Futures 0.3 with reqwest 0.10 - - cargo test --tests --examples --features futures-03,reqwest-010 --no-default-features - # Futures 0.3 without reqwest (examples will not build) - - cargo test --tests --features futures-03 --no-default-features + # Curl without reqwest (examples will not build) + - cargo test --tests --features curl --no-default-features - cargo audit notifications: email: diff --git a/Cargo.toml b/Cargo.toml index 1618a378..81d3ef31 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,15 +16,12 @@ maintenance = { status = "actively-developed" } [features] default = ["reqwest-010"] -futures-03 = ["futures-0-3", "async-trait"] reqwest-010 = ["reqwest-0-10"] [dependencies] -async-trait = { version = "0.1", optional = true } base64 = "0.12" curl = { version = "0.4.0", optional = true } thiserror="1.0" -futures-0-3 = { version = "0.3", optional = true, package = "futures" } http = "0.2" rand = "0.7" reqwest-0-10 = { version = "0.10", optional = true, features = ["blocking", "rustls-tls"], package = "reqwest", default-features = false } diff --git a/src/async_internal.rs b/src/async_internal.rs deleted file mode 100644 index 28be7e33..00000000 --- a/src/async_internal.rs +++ /dev/null @@ -1,185 +0,0 @@ -use crate::{ - token_response, ClientCredentialsTokenRequest, CodeTokenRequest, ErrorResponse, HttpRequest, - HttpResponse, PasswordTokenRequest, RefreshTokenRequest, RequestTokenError, TokenResponse, - TokenType, -}; -use async_trait::async_trait; -use std::error::Error; -use futures_0_3::Future; - -/// -/// Asynchronous request to exchange an authorization code for an access token. -/// -#[async_trait] -pub trait AsyncCodeTokenRequest -where - TE: ErrorResponse + 'static, - TR: TokenResponse + Send, - TT: TokenType + Send, -{ - /// - /// Asynchronously sends the request to the authorization server. - /// - async fn request_async(self, http_client: C) -> Result> - where - C: FnOnce(HttpRequest) -> F + Send, - F: Future> + Send, - RE: Error + Send + Sync + 'static; -} - -#[async_trait] -impl AsyncCodeTokenRequest for CodeTokenRequest<'_, TE, TR, TT> -where - TE: ErrorResponse + 'static, - TR: TokenResponse + Send, - TT: TokenType + Send, -{ - /// - /// Asynchronously sends the request to the authorization server and returns a Future. - /// - async fn request_async(self, http_client: C) -> Result> - where - C: FnOnce(HttpRequest) -> F + Send, - F: Future> + Send, - RE: Error + Send + Sync + 'static, - { - let http_request = self.prepare_request()?; - let http_response = http_client(http_request) - .await - .map_err(RequestTokenError::Request)?; - token_response(http_response) - } -} - -/// -/// Asynchronous request to exchange a refresh token for an access token. -/// -#[async_trait] -pub trait AsyncRefreshTokenRequest -where - TE: ErrorResponse + 'static, - TR: TokenResponse + Send, - TT: TokenType + Send, -{ - /// - /// Asynchronously sends the request to the authorization server and awaits a response. - /// - async fn request_async(self, http_client: C) -> Result> - where - C: FnOnce(HttpRequest) -> F + Send, - F: Future> + Send, - RE: Error + Send + Sync + 'static; -} - -#[async_trait] -impl AsyncRefreshTokenRequest for RefreshTokenRequest<'_, TE, TR, TT> -where - TE: ErrorResponse + 'static, - TR: TokenResponse + Send, - TT: TokenType + Send, -{ - /// - /// Asynchronously sends the request to the authorization server and awaits a response. - /// - async fn request_async(self, http_client: C) -> Result> - where - C: FnOnce(HttpRequest) -> F + Send, - F: Future> + Send, - RE: Error + Send + Sync + 'static, - { - let http_request = self.prepare_request()?; - let http_response = http_client(http_request) - .await - .map_err(RequestTokenError::Request)?; - token_response(http_response) - } -} - -/// -/// Asynchronous request to exchange resource owner credentials for an access token. -/// -#[async_trait] -pub trait AsyncPasswordTokenRequest -where - TE: ErrorResponse + 'static, - TR: TokenResponse + Send, - TT: TokenType + Send, -{ - /// - /// Asynchronously sends the request to the authorization server and awaits a response. - /// - async fn request_async(self, http_client: C) -> Result> - where - C: FnOnce(HttpRequest) -> F + Send, - F: Future> + Send, - RE: Error + Send + Sync + 'static; -} - -#[async_trait] -impl AsyncPasswordTokenRequest for PasswordTokenRequest<'_, TE, TR, TT> -where - TE: ErrorResponse + 'static, - TR: TokenResponse + Send, - TT: TokenType + Send, -{ - /// - /// Asynchronously sends the request to the authorization server and awaits a response. - /// - async fn request_async(self, http_client: C) -> Result> - where - C: FnOnce(HttpRequest) -> F + Send, - F: Future> + Send, - RE: Error + Send + Sync + 'static, - { - let http_request = self.prepare_request()?; - let http_response = http_client(http_request) - .await - .map_err(RequestTokenError::Request)?; - token_response(http_response) - } -} - -/// -/// Asynchronous request to exchange client credentials for an access token. -/// -#[async_trait] -pub trait AsyncClientCredentialsTokenRequest -where - TE: ErrorResponse + 'static, - TR: TokenResponse + Send, - TT: TokenType + Send, -{ - /// - /// Asynchronously sends the request to the authorization server and awaits a response. - /// - async fn request_async(self, http_client: C) -> Result> - where - C: FnOnce(HttpRequest) -> F + Send, - F: Future> + Send, - RE: Error + Send + Sync + 'static; -} - -#[async_trait] -impl AsyncClientCredentialsTokenRequest - for ClientCredentialsTokenRequest<'_, TE, TR, TT> -where - TE: ErrorResponse + 'static, - TR: TokenResponse + Send, - TT: TokenType + Send, -{ - /// - /// Asynchronously sends the request to the authorization server and awaits a response. - /// - async fn request_async(self, http_client: C) -> Result> - where - C: FnOnce(HttpRequest) -> F + Send, - F: Future> + Send, - RE: Error + Send + Sync + 'static, - { - let http_request = self.prepare_request()?; - let http_response = http_client(http_request) - .await - .map_err(RequestTokenError::Request)?; - token_response(http_response) - } -} diff --git a/src/lib.rs b/src/lib.rs index ee42a5e9..a068fcb2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,16 +25,6 @@ //! ```toml //! oauth2 = "3.0" //! ``` -//! * **Async/await via `futures` 0.3** -//! -//! Support is enabled via the `futures-03` feature flag. If desired, the -//! `reqwest-010` feature flag can be used to enable `reqwest` 0.10 and its async/await -//! client interface. -//! -//! Example import in `Cargo.toml`: -//! ```toml -//! oauth2 = { version = "3.0", features = ["futures-03"], default-features = false } -//! ``` //! //! For the HTTP client modes described above, the following HTTP client implementations can be //! used: @@ -68,21 +58,12 @@ //! ```ignore //! FnOnce(HttpRequest) -> Result //! where RE: std::error::Error + Send + Sync + 'static -//! ``` //! -//! Asynchronous `futures` 0.1 HTTP clients should implement the following trait: +//! Async/await HTTP clients should implement the following trait: //! ```ignore //! FnOnce(HttpRequest) -> F //! where -//! F: Future, -//! RE: std::error::Error + Send + Sync + 'static -//! ``` -//! -//! Async/await `futures` 0.3 HTTP clients should implement the following trait: -//! ```ignore -//! FnOnce(HttpRequest) -> F + Send -//! where -//! F: Future> + Send, +//! F: Future>, //! RE: std::error::Error + Send + Sync + 'static //! ``` //! @@ -163,18 +144,11 @@ //! //! ## Example: Async/Await API //! -//! In order to use async/await, include `oauth2` as follows: -//! -//! ```toml -//! [dependencies] -//! oauth2 = { version = "3.0", features = ["futures-03", "reqwest-010"], default-features = false } -//! ``` +//! One can use async/await as follows: //! //! ```rust,no_run //! use anyhow; -//! # #[cfg(feature = "futures-03")] //! use oauth2::{ -//! AsyncCodeTokenRequest, //! AuthorizationCode, //! AuthUrl, //! ClientId, @@ -187,11 +161,11 @@ //! TokenUrl //! }; //! use oauth2::basic::BasicClient; -//! # #[cfg(all(feature = "futures-03", feature = "reqwest-010"))] +//! # #[cfg(feature = "reqwest-010")] //! use oauth2::reqwest::async_http_client; //! use url::Url; //! -//! # #[cfg(all(feature = "futures-03", feature = "reqwest-010"))] +//! # #[cfg(feature = "reqwest-010")] //! # async fn err_wrapper() -> Result<(), anyhow::Error> { //! // Create an OAuth2 client by specifying the client ID, client secret, authorization URL and //! // token URL. @@ -383,11 +357,12 @@ //! - [`actix-web-oauth2`](https://github.com/pka/actix-web-oauth2) (version 2.x of this crate) //! use std::borrow::Cow; +use std::error::Error; use std::fmt::Error as FormatterError; use std::fmt::{Debug, Display, Formatter}; +use std::future::Future; use std::marker::PhantomData; use std::time::Duration; -use std::error::Error; use http::header::{HeaderMap, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_TYPE}; use http::status::StatusCode; @@ -395,12 +370,6 @@ use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use url::{form_urlencoded, Url}; -/// -/// Async/await module, requires "futures-03" feature. -/// -#[cfg(feature = "futures-03")] -mod async_internal; - /// /// Basic OAuth2 implementation with no extensions /// ([RFC 6749](https://tools.ietf.org/html/rfc6749)). @@ -443,12 +412,6 @@ pub use types::{ ResourceOwnerUsername, ResponseType, Scope, TokenUrl, }; -#[cfg(feature = "futures-03")] -pub use async_internal::{ - AsyncClientCredentialsTokenRequest, AsyncCodeTokenRequest, AsyncPasswordTokenRequest, - AsyncRefreshTokenRequest, -}; - const CONTENT_TYPE_JSON: &str = "application/json"; const CONTENT_TYPE_FORMENCODED: &str = "application/x-www-form-urlencoded"; @@ -490,7 +453,7 @@ where impl Client where - TE: ErrorResponse, + TE: ErrorResponse + 'static, TR: TokenResponse, TT: TokenType, { @@ -863,7 +826,7 @@ where } impl<'a, TE, TR, TT> CodeTokenRequest<'a, TE, TR, TT> where - TE: ErrorResponse, + TE: ErrorResponse + 'static, TR: TokenResponse, TT: TokenType, { @@ -940,6 +903,25 @@ where .map_err(RequestTokenError::Request) .and_then(token_response) } + + /// + /// Asynchronously sends the request to the authorization server and returns a Future. + /// + pub async fn request_async( + self, + http_client: C, + ) -> Result> + where + C: FnOnce(HttpRequest) -> F, + F: Future>, + RE: Error + Send + Sync + 'static, + { + let http_request = self.prepare_request()?; + let http_response = http_client(http_request) + .await + .map_err(RequestTokenError::Request)?; + token_response(http_response) + } } /// @@ -965,7 +947,7 @@ where } impl<'a, TE, TR, TT> RefreshTokenRequest<'a, TE, TR, TT> where - TE: ErrorResponse, + TE: ErrorResponse + 'static, TR: TokenResponse, TT: TokenType, { @@ -1013,6 +995,24 @@ where .map_err(RequestTokenError::Request) .and_then(token_response) } + /// + /// Asynchronously sends the request to the authorization server and awaits a response. + /// + pub async fn request_async( + self, + http_client: C, + ) -> Result> + where + C: FnOnce(HttpRequest) -> F, + F: Future>, + RE: Error + Send + Sync + 'static, + { + let http_request = self.prepare_request()?; + let http_response = http_client(http_request) + .await + .map_err(RequestTokenError::Request)?; + token_response(http_response) + } fn prepare_request(&self) -> Result> where @@ -1059,7 +1059,7 @@ where } impl<'a, TE, TR, TT> PasswordTokenRequest<'a, TE, TR, TT> where - TE: ErrorResponse, + TE: ErrorResponse + 'static, TR: TokenResponse, TT: TokenType, { @@ -1108,6 +1108,25 @@ where .and_then(token_response) } + /// + /// Asynchronously sends the request to the authorization server and awaits a response. + /// + pub async fn request_async( + self, + http_client: C, + ) -> Result> + where + C: FnOnce(HttpRequest) -> F, + F: Future>, + RE: Error + Send + Sync + 'static, + { + let http_request = self.prepare_request()?; + let http_response = http_client(http_request) + .await + .map_err(RequestTokenError::Request)?; + token_response(http_response) + } + fn prepare_request(&self) -> Result> where RE: Error + Send + Sync + 'static, @@ -1152,7 +1171,7 @@ where } impl<'a, TE, TR, TT> ClientCredentialsTokenRequest<'a, TE, TR, TT> where - TE: ErrorResponse, + TE: ErrorResponse + 'static, TR: TokenResponse, TT: TokenType, { @@ -1201,6 +1220,25 @@ where .and_then(token_response) } + /// + /// Asynchronously sends the request to the authorization server and awaits a response. + /// + pub async fn request_async( + self, + http_client: C, + ) -> Result> + where + C: FnOnce(HttpRequest) -> F, + F: Future>, + RE: Error + Send + Sync + 'static, + { + let http_request = self.prepare_request()?; + let http_response = http_client(http_request) + .await + .map_err(RequestTokenError::Request)?; + token_response(http_response) + } + fn prepare_request(&self) -> Result> where RE: Error + Send + Sync + 'static, diff --git a/src/reqwest/mod.rs b/src/reqwest.rs similarity index 59% rename from src/reqwest/mod.rs rename to src/reqwest.rs index a9fc690c..f5cb65de 100644 --- a/src/reqwest/mod.rs +++ b/src/reqwest.rs @@ -22,34 +22,25 @@ where Other(String), } -#[cfg(feature = "reqwest-010")] pub use blocking::http_client; /// /// Error type returned by failed reqwest blocking HTTP requests. -/// Requires "reqwest-010" feature. /// -#[cfg(feature = "reqwest-010")] pub type HttpClientError = Error; -#[cfg(all(feature = "futures-03", feature = "reqwest-010"))] pub use async_client::async_http_client; + /// /// Error type returned by failed reqwest async HTTP requests. -/// Requires "futures-03" and "reqwest-010" feature. /// -#[cfg(all(feature = "futures-03", feature = "reqwest-010"))] pub type AsyncHttpClientError = Error; -#[cfg(feature = "reqwest-010")] mod blocking { use super::super::{HttpRequest, HttpResponse}; use super::Error; - #[cfg(feature = "reqwest-010")] pub use reqwest_0_10 as reqwest; - #[cfg(feature = "reqwest-010")] use reqwest_0_10::blocking; - #[cfg(feature = "reqwest-010")] use reqwest_0_10::redirect::Policy as RedirectPolicy; use std::io::Read; @@ -103,5 +94,57 @@ mod blocking { } } -#[cfg(all(feature = "reqwest-010", feature = "futures-03"))] -mod async_client; +mod async_client { + use super::super::{HttpRequest, HttpResponse}; + use super::Error; + + pub use reqwest_0_10 as reqwest; + use reqwest_0_10::redirect::Policy as RediretPolicy; + + use http::header::HeaderName; + use http::{HeaderMap, HeaderValue, StatusCode}; + + /// + /// Asynchronous HTTP client. + /// + pub async fn async_http_client( + request: HttpRequest, + ) -> Result> { + let client = reqwest::Client::builder() + // Following redirects opens the client up to SSRF vulnerabilities. + .redirect(RediretPolicy::none()) + .build() + .map_err(Error::Reqwest)?; + + let mut request_builder = client + .request(request.method, request.url.as_str()) + .body(request.body); + for (name, value) in &request.headers { + request_builder = request_builder.header(name.as_str(), value.as_bytes()); + } + let request = request_builder.build().map_err(Error::Reqwest)?; + + let response = client.execute(request).await.map_err(Error::Reqwest)?; + + let status_code = response.status(); + let headers = response + .headers() + .iter() + .map(|(name, value)| { + ( + HeaderName::from_bytes(name.as_str().as_ref()) + .expect("failed to convert HeaderName from http 0.2 to 0.1"), + HeaderValue::from_bytes(value.as_bytes()) + .expect("failed to convert HeaderValue from http 0.2 to 0.1"), + ) + }) + .collect::(); + let chunks = response.bytes().await.map_err(Error::Reqwest)?; + Ok(HttpResponse { + status_code: StatusCode::from_u16(status_code.as_u16()) + .expect("failed to convert StatusCode from http 0.2 to 0.1"), + headers, + body: chunks.to_vec(), + }) + } +} diff --git a/src/reqwest/async_client.rs b/src/reqwest/async_client.rs deleted file mode 100644 index fdaa052f..00000000 --- a/src/reqwest/async_client.rs +++ /dev/null @@ -1,51 +0,0 @@ -use super::super::{HttpRequest, HttpResponse}; -use super::Error; - -use http::header::HeaderName; -use http::{HeaderMap, HeaderValue, StatusCode}; - -pub use reqwest_0_10 as reqwest; - -/// -/// Asynchronous HTTP client. -/// -pub async fn async_http_client( - request: HttpRequest, -) -> Result> { - let client = reqwest::Client::builder() - // Following redirects opens the client up to SSRF vulnerabilities. - .redirect(reqwest::redirect::Policy::none()) - .build() - .map_err(Error::Reqwest)?; - - let mut request_builder = client - .request(request.method, request.url.as_str()) - .body(request.body); - for (name, value) in &request.headers { - request_builder = request_builder.header(name.as_str(), value.as_bytes()); - } - let request = request_builder.build().map_err(Error::Reqwest)?; - - let response = client.execute(request).await.map_err(Error::Reqwest)?; - - let status_code = response.status(); - let headers = response - .headers() - .iter() - .map(|(name, value)| { - ( - HeaderName::from_bytes(name.as_str().as_ref()) - .expect("failed to convert HeaderName from http 0.2 to 0.1"), - HeaderValue::from_bytes(value.as_bytes()) - .expect("failed to convert HeaderValue from http 0.2 to 0.1"), - ) - }) - .collect::(); - let chunks = response.bytes().await.map_err(Error::Reqwest)?; - Ok(HttpResponse { - status_code: StatusCode::from_u16(status_code.as_u16()) - .expect("failed to convert StatusCode from http 0.2 to 0.1"), - headers, - body: chunks.to_vec(), - }) -} diff --git a/src/tests.rs b/src/tests.rs index f6460dac..ad4d2c40 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,6 +1,6 @@ -use thiserror::Error; use http::header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_TYPE}; use http::status::StatusCode; +use thiserror::Error; use url::form_urlencoded::byte_serialize; use url::Url; @@ -236,7 +236,9 @@ fn test_authorize_url_with_redirect_url_override() { let (url, _) = client .authorize_url(|| CsrfToken::new("csrf_token".to_string())) - .set_redirect_url(Cow::Owned(RedirectUrl::new("https://localhost/alternative".to_string()).unwrap())) + .set_redirect_url(Cow::Owned( + RedirectUrl::new("https://localhost/alternative".to_string()).unwrap(), + )) .url(); assert_eq!(