Skip to content

Commit

Permalink
refactor: standardize http errors handling in signer aggregator_client
Browse files Browse the repository at this point in the history
* Always include the status code text in the issued error, this should add
  more context when the response text is empty.
* When the response contains json: try to parse ideally as a `ClientError`
  or `ServerError` to use their properties, if of an unknown type it will
  output all the json key/value pairs as a fallback.
  • Loading branch information
Alenar committed Oct 9, 2024
1 parent 2d6d17b commit ed87bdc
Show file tree
Hide file tree
Showing 3 changed files with 215 additions and 24 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions mithril-signer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ tikv-jemallocator = { version = "0.6.0", optional = true }

[dev-dependencies]
criterion = { version = "0.5.1", features = ["html_reports", "async_tokio"] }
http = "1.1.0"
httpmock = "0.7.0"
mithril-common = { path = "../mithril-common" }
mockall = "0.13.0"
Expand Down
237 changes: 213 additions & 24 deletions mithril-signer/src/services/aggregator_client.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
use anyhow::anyhow;
use async_trait::async_trait;
use reqwest::header::{self, HeaderValue};
use reqwest::{self, Client, Proxy, RequestBuilder, Response, StatusCode};
use slog::{debug, Logger};
use std::{io, sync::Arc, time::Duration};
use thiserror::Error;

use mithril_common::{
api_version::APIVersionProvider,
entities::{Epoch, ProtocolMessage, SignedEntityType, Signer, SingleSignatures},
entities::{
ClientError, Epoch, ProtocolMessage, ServerError, SignedEntityType, Signer,
SingleSignatures,
},
logging::LoggerExtensions,
messages::{
AggregatorFeaturesMessage, EpochSettingsMessage, TryFromMessageAdapter, TryToMessageAdapter,
Expand All @@ -21,6 +25,8 @@ use crate::message_adapters::{
};
use crate::services::SignaturePublisher;

const JSON_CONTENT_TYPE: HeaderValue = HeaderValue::from_static("application/json");

/// Error structure for the Aggregator Client.
#[derive(Error, Debug)]
pub enum AggregatorClientError {
Expand All @@ -36,6 +42,10 @@ pub enum AggregatorClientError {
#[error("remote server unreachable")]
RemoteServerUnreachable(#[source] StdError),

/// Unhandled status code
#[error("unhandled status code: {0}, response text: {1}")]
UnhandledStatusCode(StatusCode, String),

/// Could not parse response.
#[error("json parsing failed")]
JsonParseFailed(#[source] StdError),
Expand Down Expand Up @@ -69,6 +79,65 @@ impl AggregatorClientError {
}
}

impl AggregatorClientError {
/// Create an `AggregatorClientError` from a response.
///
/// This method is meant to be used after handling domain-specific cases leaving only
/// 4xx or 5xx status codes.
/// Otherwise, it will return an `UnhandledStatusCode` error.
pub async fn from_response(response: Response) -> Self {
let error_code = response.status();

if error_code.is_client_error() {
let root_cause = Self::get_root_cause(response).await;
Self::RemoteServerLogical(anyhow!(root_cause))
} else if error_code.is_server_error() {
let root_cause = Self::get_root_cause(response).await;
Self::RemoteServerTechnical(anyhow!(root_cause))
} else {
let response_text = response.text().await.unwrap_or_default();
Self::UnhandledStatusCode(error_code, response_text)
}
}

async fn get_root_cause(response: Response) -> String {
let error_code = response.status();
let canonical_reason = error_code.canonical_reason().unwrap_or_default();
let is_json = response
.headers()
.get(header::CONTENT_TYPE)
.is_some_and(|ct| JSON_CONTENT_TYPE == ct);

if is_json {
let json_value: serde_json::Value = response.json().await.unwrap_or_default();

if let Ok(client_error) = serde_json::from_value::<ClientError>(json_value.clone()) {
format!(
"{}: {}: {}",
canonical_reason.to_lowercase(),
client_error.label,
client_error.message
)
} else if let Ok(server_error) =
serde_json::from_value::<ServerError>(json_value.clone())
{
format!(
"{}: {}",
canonical_reason.to_lowercase(),
server_error.message
)
} else if json_value.is_null() {
canonical_reason.to_lowercase().to_string()
} else {
format!("{}: {}", canonical_reason.to_lowercase(), json_value)
}
} else {
let response_text = response.text().await.unwrap_or_default();
format!("{}: {}", canonical_reason.to_lowercase(), response_text)
}
}
}

/// Trait for mocking and testing a `AggregatorClient`
#[cfg_attr(test, mockall::automock)]
#[async_trait]
Expand Down Expand Up @@ -216,10 +285,7 @@ impl AggregatorClient for AggregatorHTTPClient {
Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
},
StatusCode::PRECONDITION_FAILED => Err(self.handle_api_error(&response)),
_ => Err(AggregatorClientError::RemoteServerTechnical(anyhow!(
"{}",
response.text().await.unwrap_or_default()
))),
_ => Err(AggregatorClientError::from_response(response).await),
},
Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
}
Expand All @@ -245,13 +311,7 @@ impl AggregatorClient for AggregatorHTTPClient {
Ok(response) => match response.status() {
StatusCode::CREATED => Ok(()),
StatusCode::PRECONDITION_FAILED => Err(self.handle_api_error(&response)),
StatusCode::BAD_REQUEST => Err(AggregatorClientError::RemoteServerLogical(
anyhow!("bad request: {}", response.text().await.unwrap_or_default()),
)),
_ => Err(AggregatorClientError::RemoteServerTechnical(anyhow!(
"{}",
response.text().await.unwrap_or_default()
))),
_ => Err(AggregatorClientError::from_response(response).await),
},
Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
}
Expand Down Expand Up @@ -285,16 +345,10 @@ impl AggregatorClient for AggregatorHTTPClient {
Ok(())
}
StatusCode::PRECONDITION_FAILED => Err(self.handle_api_error(&response)),
StatusCode::BAD_REQUEST => Err(AggregatorClientError::RemoteServerLogical(
anyhow!("bad request: {}", response.text().await.unwrap_or_default()),
)),
StatusCode::CONFLICT => Err(AggregatorClientError::RemoteServerLogical(anyhow!(
"already registered single signatures"
))),
_ => Err(AggregatorClientError::RemoteServerTechnical(anyhow!(
"{}",
response.text().await.unwrap_or_default()
))),
_ => Err(AggregatorClientError::from_response(response).await),
},
Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
}
Expand All @@ -317,10 +371,7 @@ impl AggregatorClient for AggregatorHTTPClient {
.await
.map_err(|e| AggregatorClientError::JsonParseFailed(anyhow!(e)))?),
StatusCode::PRECONDITION_FAILED => Err(self.handle_api_error(&response)),
_ => Err(AggregatorClientError::RemoteServerTechnical(anyhow!(
"{}",
response.text().await.unwrap_or_default()
))),
_ => Err(AggregatorClientError::from_response(response).await),
},
Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
}
Expand Down Expand Up @@ -426,10 +477,11 @@ pub(crate) mod dumb {

#[cfg(test)]
mod tests {
use http::response::Builder as HttpResponseBuilder;
use httpmock::prelude::*;
use serde_json::json;

use mithril_common::entities::{ClientError, Epoch};
use mithril_common::entities::Epoch;
use mithril_common::era::{EraChecker, SupportedEra};
use mithril_common::messages::TryFromMessageAdapter;
use mithril_common::test_utils::fake_data;
Expand Down Expand Up @@ -494,6 +546,34 @@ mod tests {
});
}

fn build_text_response<T: Into<String>>(status_code: StatusCode, body: T) -> Response {
HttpResponseBuilder::new()
.status(status_code)
.body(body.into())
.unwrap()
.into()
}

fn build_json_response<T: serde::Serialize>(status_code: StatusCode, body: &T) -> Response {
HttpResponseBuilder::new()
.status(status_code)
.header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
.body(serde_json::to_string(&body).unwrap())
.unwrap()
.into()
}

macro_rules! assert_error_text_contains {
($error: expr, $expect_contains: expr) => {
let error = &$error;
assert!(
error.contains($expect_contains),
"Expected error message to contain '{}'\ngot '{error:?}'",
$expect_contains,
);
};
}

#[tokio::test]
async fn test_aggregator_features_ok_200() {
let (server, client) = setup_server_and_client();
Expand Down Expand Up @@ -1001,4 +1081,113 @@ mod tests {
"unexpected error type: {error:?}"
);
}

#[tokio::test]
async fn test_4xx_errors_are_handled_as_remote_server_logical() {
let response = build_text_response(StatusCode::BAD_REQUEST, "error text");
let handled_error = AggregatorClientError::from_response(response).await;

assert!(
matches!(
handled_error,
AggregatorClientError::RemoteServerLogical(..)
),
"Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
);
}

#[tokio::test]
async fn test_5xx_errors_are_handled_as_remote_server_technical() {
let response = build_text_response(StatusCode::INTERNAL_SERVER_ERROR, "error text");
let handled_error = AggregatorClientError::from_response(response).await;

assert!(
matches!(
handled_error,
AggregatorClientError::RemoteServerTechnical(..)
),
"Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
);
}

#[tokio::test]
async fn test_non_4xx_or_5xx_errors_are_handled_as_unhandled_status_code_and_contains_response_text(
) {
let response = build_text_response(StatusCode::OK, "ok text");
let handled_error = AggregatorClientError::from_response(response).await;

assert!(
matches!(
handled_error,
AggregatorClientError::UnhandledStatusCode(..) if format!("{handled_error:?}").contains("ok text")
),
"Expected error to be UnhandledStatusCode with 'ok text' in error text\ngot '{handled_error:?}'",
);
}

#[tokio::test]
async fn test_root_cause_of_non_json_response_contains_response_plain_text() {
let error_text = "An error occurred; please try again later.";
let response = build_text_response(StatusCode::EXPECTATION_FAILED, error_text);

assert_error_text_contains!(
AggregatorClientError::get_root_cause(response).await,
"expectation failed: An error occurred; please try again later."
);
}

#[tokio::test]
async fn test_root_cause_of_json_formatted_client_error_response_contains_error_label_and_message(
) {
let client_error = ClientError::new("label", "message");
let response = build_json_response(StatusCode::BAD_REQUEST, &client_error);

assert_error_text_contains!(
AggregatorClientError::get_root_cause(response).await,
"bad request: label: message"
);
}

#[tokio::test]
async fn test_root_cause_of_json_formatted_server_error_response_contains_error_label_and_message(
) {
let server_error = ServerError::new("message");
let response = build_json_response(StatusCode::BAD_REQUEST, &server_error);

assert_error_text_contains!(
AggregatorClientError::get_root_cause(response).await,
"bad request: message"
);
}

#[tokio::test]
async fn test_root_cause_of_unknown_formatted_json_response_contains_json_key_value_pairs() {
let response = build_json_response(
StatusCode::INTERNAL_SERVER_ERROR,
&json!({ "second": "unknown", "first": "foreign" }),
);

assert_error_text_contains!(
AggregatorClientError::get_root_cause(response).await,
r#"internal server error: {"first":"foreign","second":"unknown"}"#
);
}

#[tokio::test]
async fn test_root_cause_with_invalid_json_response_still_contains_response_status_name() {
let response = HttpResponseBuilder::new()
.status(StatusCode::BAD_REQUEST)
.header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
.body(r#"{"invalid":"unexpected dot", "key": "value".}"#)
.unwrap()
.into();

let root_cause = AggregatorClientError::get_root_cause(response).await;

assert_error_text_contains!(root_cause, "bad request");
assert!(
!root_cause.contains("bad request: "),
"Expected error message should not contain additional information \ngot '{root_cause:?}'"
);
}
}

0 comments on commit ed87bdc

Please sign in to comment.