diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 5a22da251..997ea7f40 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -33,6 +33,7 @@ jobs: - run: cargo test -p dicom-pixeldata --features gdcm # test dicom-pixeldata without default features - run: cargo test -p dicom-pixeldata --no-default-features + - run: cargo test -p dicom-ul --features async # run Clippy with stable toolchain - if: matrix.rust == 'stable' run: cargo clippy @@ -60,4 +61,4 @@ jobs: toolchain: stable cache: true - run: cargo check - \ No newline at end of file + diff --git a/Cargo.lock b/Cargo.lock index 4dc8d8257..aa5b33439 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "addr2line" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678" +dependencies = [ + "gimli", +] + [[package]] name = "adler" version = "1.0.2" @@ -93,6 +102,21 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +[[package]] +name = "backtrace" +version = "0.3.73" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cc23269a4f8976d0a4d2e7109211a419fe30e8d88d677cd60b6bc79c5732e0a" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + [[package]] name = "base64" version = "0.22.1" @@ -159,6 +183,12 @@ dependencies = [ "byteorder", ] +[[package]] +name = "bytes" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a12916984aab3fa6e39d655a33e09c0071eb36d6ab3aea5c2d78551f1df6d952" + [[package]] name = "cc" version = "1.1.10" @@ -558,6 +588,7 @@ dependencies = [ "dicom-transfer-syntax-registry", "dicom-ul", "snafu", + "tokio", "tracing", "tracing-subscriber", ] @@ -576,6 +607,7 @@ dependencies = [ "dicom-ul", "indicatif", "snafu", + "tokio", "tracing", "tracing-subscriber", "walkdir", @@ -628,10 +660,12 @@ name = "dicom-ul" version = "0.7.1" dependencies = [ "byteordered", + "bytes", "dicom-encoding", "dicom-transfer-syntax-registry", "matches", "snafu", + "tokio", "tracing", ] @@ -942,6 +976,12 @@ dependencies = [ "wasi", ] +[[package]] +name = "gimli" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" + [[package]] name = "glob" version = "0.3.1" @@ -1237,6 +1277,17 @@ dependencies = [ "simd-adler32", ] +[[package]] +name = "mio" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.48.0", +] + [[package]] name = "ndarray" version = "0.15.6" @@ -1287,12 +1338,31 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "number_prefix" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" +[[package]] +name = "object" +version = "0.36.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "081b846d1d56ddfc18fdf1a922e4f6e07a11768ea1b92dec44e42b72712ccfce" +dependencies = [ + "memchr", +] + [[package]] name = "once_cell" version = "1.19.0" @@ -1337,6 +1407,29 @@ dependencies = [ "supports-color", ] +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets 0.52.6", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -1461,6 +1554,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "redox_syscall" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd" +dependencies = [ + "bitflags 2.6.0", +] + [[package]] name = "regex" version = "1.10.6" @@ -1541,6 +1643,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "rustc-demangle" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" + [[package]] name = "rustc_version" version = "0.4.0" @@ -1687,6 +1795,15 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "signal-hook-registry" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +dependencies = [ + "libc", +] + [[package]] name = "simd-adler32" version = "0.3.7" @@ -1729,6 +1846,16 @@ dependencies = [ "syn", ] +[[package]] +name = "socket2" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "spin" version = "0.9.8" @@ -1896,6 +2023,36 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tokio" +version = "1.38.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba4f4a02a7a80d6f274636f0aa95c7e383b912d41fe721a31f29e29698585a4a" +dependencies = [ + "backtrace", + "bytes", + "libc", + "mio", + "num_cpus", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys 0.48.0", +] + +[[package]] +name = "tokio-macros" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "toml_datetime" version = "0.6.8" diff --git a/scpproxy/src/main.rs b/scpproxy/src/main.rs index 307cdd11c..2696b96c9 100644 --- a/scpproxy/src/main.rs +++ b/scpproxy/src/main.rs @@ -1,7 +1,7 @@ use clap::{crate_version, value_parser, Arg, ArgAction, Command}; -use dicom_ul::pdu::reader::read_pdu; use dicom_ul::pdu::writer::write_pdu; use dicom_ul::pdu::Pdu; +use dicom_ul::association::client::get_client_pdu; use snafu::{Backtrace, OptionExt, Report, ResultExt, Snafu, Whatever}; use std::io::Write; use std::net::{Shutdown, TcpListener, TcpStream}; @@ -63,11 +63,11 @@ pub enum ThreadMessage { }, ReadErr { from: ProviderType, - err: dicom_ul::pdu::reader::Error, + err: dicom_ul::association::client::Error, }, WriteErr { from: ProviderType, - err: dicom_ul::pdu::writer::Error, + err: dicom_ul::pdu::WriteError, }, Shutdown { initiator: ProviderType, @@ -96,7 +96,7 @@ fn run( let message_tx = message_tx.clone(); scu_reader_thread = thread::spawn(move || { loop { - match read_pdu(&mut reader, max_pdu_length, strict) { + match get_client_pdu(&mut reader, max_pdu_length, strict) { Ok(pdu) => { message_tx .send(ThreadMessage::SendPdu { @@ -105,7 +105,7 @@ fn run( }) .context(SendMessageSnafu)?; } - Err(dicom_ul::pdu::reader::Error::NoPduAvailable { .. }) => { + Err(dicom_ul::association::client::Error::ReceiveResponse{ .. }) => { message_tx .send(ThreadMessage::Shutdown { initiator: ProviderType::Scu, @@ -133,7 +133,7 @@ fn run( let mut reader = scp_stream.try_clone().context(CloneSocketSnafu)?; scp_reader_thread = thread::spawn(move || { loop { - match read_pdu(&mut reader, max_pdu_length, strict) { + match get_client_pdu(&mut reader, max_pdu_length, strict) { Ok(pdu) => { message_tx .send(ThreadMessage::SendPdu { @@ -142,7 +142,7 @@ fn run( }) .context(SendMessageSnafu)?; } - Err(dicom_ul::pdu::reader::Error::NoPduAvailable { .. }) => { + Err(dicom_ul::association::client::Error::ReceiveResponse{ .. }) => { message_tx .send(ThreadMessage::Shutdown { initiator: ProviderType::Scp, diff --git a/storescp/Cargo.toml b/storescp/Cargo.toml index 56ae12f68..91c960ab9 100644 --- a/storescp/Cargo.toml +++ b/storescp/Cargo.toml @@ -13,7 +13,7 @@ readme = "README.md" [dependencies] clap = { version = "4.0.18", features = ["derive"] } dicom-core = { path = '../core', version = "0.7.0" } -dicom-ul = { path = '../ul', version = "0.7.1" } +dicom-ul = { path = '../ul', version = "0.7.1", features = ["async"] } dicom-object = { path = '../object', version = "0.7.1" } dicom-encoding = { path = "../encoding/", version = "0.7.1" } dicom-dictionary-std = { path = "../dictionary-std/", version = "0.7.0" } @@ -21,3 +21,5 @@ dicom-transfer-syntax-registry = { path = "../transfer-syntax-registry/", versio snafu = "0.8" tracing = "0.1.36" tracing-subscriber = "0.3.15" +tokio = { version = "1.38.0", features = ["full"] } + diff --git a/storescp/src/main.rs b/storescp/src/main.rs index 84fc163dd..99d549dc6 100644 --- a/storescp/src/main.rs +++ b/storescp/src/main.rs @@ -1,21 +1,21 @@ use std::{ - net::{Ipv4Addr, SocketAddrV4, TcpListener, TcpStream}, + net::{Ipv4Addr, SocketAddrV4}, path::PathBuf, }; use clap::Parser; use dicom_core::{dicom_value, DataElement, VR}; use dicom_dictionary_std::tags; -use dicom_encoding::transfer_syntax::TransferSyntaxIndex; -use dicom_object::{FileMetaTableBuilder, InMemDicomObject, StandardDataDictionary}; -use dicom_transfer_syntax_registry::TransferSyntaxRegistry; -use dicom_ul::{pdu::PDataValueType, Pdu}; -use snafu::{OptionExt, Report, ResultExt, Whatever}; -use tracing::{debug, error, info, warn, Level}; +use dicom_object::{InMemDicomObject, StandardDataDictionary}; +use snafu::Report; +use tracing::{error, info, Level}; -use crate::transfer::ABSTRACT_SYNTAXES; mod transfer; +mod store_async; +mod store_sync; +use store_async::run_store_async; +use store_sync::run_store_sync; /// DICOM C-STORE SCP #[derive(Debug, Parser)] @@ -45,266 +45,8 @@ struct App { /// Which port to listen on #[arg(short, default_value = "11111")] port: u16, -} - -fn run(scu_stream: TcpStream, args: &App) -> Result<(), Whatever> { - let App { - verbose, - calling_ae_title, - strict, - uncompressed_only, - promiscuous, - max_pdu_length, - out_dir, - port: _, - } = args; - let verbose = *verbose; - - let mut buffer: Vec = Vec::with_capacity(*max_pdu_length as usize); - let mut instance_buffer: Vec = Vec::with_capacity(1024 * 1024); - let mut msgid = 1; - let mut sop_class_uid = "".to_string(); - let mut sop_instance_uid = "".to_string(); - - let mut options = dicom_ul::association::ServerAssociationOptions::new() - .accept_any() - .ae_title(calling_ae_title) - .strict(*strict) - .promiscuous(*promiscuous); - - if *uncompressed_only { - options = options - .with_transfer_syntax("1.2.840.10008.1.2") - .with_transfer_syntax("1.2.840.10008.1.2.1"); - } else { - for ts in TransferSyntaxRegistry.iter() { - if !ts.is_unsupported() { - options = options.with_transfer_syntax(ts.uid()); - } - } - }; - - for uid in ABSTRACT_SYNTAXES { - options = options.with_abstract_syntax(*uid); - } - - let mut association = options - .establish(scu_stream) - .whatever_context("could not establish association")?; - - info!("New association from {}", association.client_ae_title()); - debug!( - "> Presentation contexts: {:?}", - association.presentation_contexts() - ); - - loop { - match association.receive() { - Ok(mut pdu) => { - if verbose { - debug!("scu ----> scp: {}", pdu.short_description()); - } - match pdu { - Pdu::PData { ref mut data } => { - if data.is_empty() { - debug!("Ignoring empty PData PDU"); - continue; - } - - for data_value in data { - if data_value.value_type == PDataValueType::Data && !data_value.is_last - { - instance_buffer.append(&mut data_value.data); - } else if data_value.value_type == PDataValueType::Command - && data_value.is_last - { - // commands are always in implicit VR LE - let ts = - dicom_transfer_syntax_registry::entries::IMPLICIT_VR_LITTLE_ENDIAN - .erased(); - let data_value = &data_value; - let v = &data_value.data; - - let obj = InMemDicomObject::read_dataset_with_ts(v.as_slice(), &ts) - .whatever_context("failed to read incoming DICOM command")?; - let command_field = obj - .element(tags::COMMAND_FIELD) - .whatever_context("Missing Command Field")? - .uint16() - .whatever_context("Command Field is not an integer")?; - - if command_field == 0x0030 { - // Handle C-ECHO-RQ - let cecho_response = create_cecho_response(msgid); - let mut cecho_data = Vec::new(); - - cecho_response - .write_dataset_with_ts(&mut cecho_data, &ts) - .whatever_context( - "could not write C-ECHO response object", - )?; - - let pdu_response = Pdu::PData { - data: vec![dicom_ul::pdu::PDataValue { - presentation_context_id: data_value - .presentation_context_id, - value_type: PDataValueType::Command, - is_last: true, - data: cecho_data, - }], - }; - association.send(&pdu_response).whatever_context( - "failed to send C-ECHO response object to SCU", - )?; - } else { - msgid = obj - .element(tags::MESSAGE_ID) - .whatever_context("Missing Message ID")? - .to_int() - .whatever_context("Message ID is not an integer")?; - sop_class_uid = obj - .element(tags::AFFECTED_SOP_CLASS_UID) - .whatever_context("missing Affected SOP Class UID")? - .to_str() - .whatever_context( - "could not retrieve Affected SOP Class UID", - )? - .to_string(); - sop_instance_uid = obj - .element(tags::AFFECTED_SOP_INSTANCE_UID) - .whatever_context("missing Affected SOP Instance UID")? - .to_str() - .whatever_context( - "could not retrieve Affected SOP Instance UID", - )? - .to_string(); - } - instance_buffer.clear(); - } else if data_value.value_type == PDataValueType::Data - && data_value.is_last - { - instance_buffer.append(&mut data_value.data); - - let presentation_context = association - .presentation_contexts() - .iter() - .find(|pc| pc.id == data_value.presentation_context_id) - .whatever_context("missing presentation context")?; - let ts = &presentation_context.transfer_syntax; - - let obj = InMemDicomObject::read_dataset_with_ts( - instance_buffer.as_slice(), - TransferSyntaxRegistry.get(ts).unwrap(), - ) - .whatever_context("failed to read DICOM data object")?; - let file_meta = FileMetaTableBuilder::new() - .media_storage_sop_class_uid( - obj.element(tags::SOP_CLASS_UID) - .whatever_context("missing SOP Class UID")? - .to_str() - .whatever_context("could not retrieve SOP Class UID")?, - ) - .media_storage_sop_instance_uid( - obj.element(tags::SOP_INSTANCE_UID) - .whatever_context("missing SOP Instance UID")? - .to_str() - .whatever_context("missing SOP Instance UID")?, - ) - .transfer_syntax(ts) - .build() - .whatever_context( - "failed to build DICOM meta file information", - )?; - let file_obj = obj.with_exact_meta(file_meta); - - // write the files to the current directory with their SOPInstanceUID as filenames - let mut file_path = out_dir.clone(); - file_path.push( - sop_instance_uid.trim_end_matches('\0').to_string() + ".dcm", - ); - file_obj - .write_to_file(&file_path) - .whatever_context("could not save DICOM object to file")?; - info!("Stored {}", file_path.display()); - - // send C-STORE-RSP object - // commands are always in implicit VR LE - let ts = - dicom_transfer_syntax_registry::entries::IMPLICIT_VR_LITTLE_ENDIAN - .erased(); - - let obj = create_cstore_response( - msgid, - &sop_class_uid, - &sop_instance_uid, - ); - - let mut obj_data = Vec::new(); - - obj.write_dataset_with_ts(&mut obj_data, &ts) - .whatever_context("could not write response object")?; - - let pdu_response = Pdu::PData { - data: vec![dicom_ul::pdu::PDataValue { - presentation_context_id: data_value.presentation_context_id, - value_type: PDataValueType::Command, - is_last: true, - data: obj_data, - }], - }; - association - .send(&pdu_response) - .whatever_context("failed to send response object to SCU")?; - } - } - } - Pdu::ReleaseRQ => { - buffer.clear(); - association.send(&Pdu::ReleaseRP).unwrap_or_else(|e| { - warn!( - "Failed to send association release message to SCU: {}", - snafu::Report::from_error(e) - ); - }); - info!( - "Released association with {}", - association.client_ae_title() - ); - break; - } - Pdu::AbortRQ { source } => { - warn!("Aborted connection from: {:?}", source); - break; - } - _ => {} - } - } - Err(err @ dicom_ul::association::server::Error::Receive { .. }) => { - if verbose { - info!("{}", Report::from_error(err)); - } else { - info!("{}", err); - } - break; - } - Err(err) => { - warn!("Unexpected error: {}", Report::from_error(err)); - break; - } - } - } - - if let Ok(peer_addr) = association.inner_stream().peer_addr() { - info!( - "Dropping connection with {} ({})", - association.client_ae_title(), - peer_addr - ); - } else { - info!("Dropping connection with {}", association.client_ae_title()); - } - - Ok(()) + #[arg(short, long, default_value = "true")] + blocking: bool } fn create_cstore_response( @@ -355,8 +97,70 @@ fn create_cecho_response(message_id: u16) -> InMemDicomObject Result<(), Box> { - let args = App::parse(); +fn main() { + let app = App::parse(); + if !app.blocking { + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap() + .block_on(async move { + run_async(app).await.unwrap_or_else(|e| { + error!("{:?}", e); + std::process::exit(-2); + }); + }); + } else { + run_sync(app).unwrap_or_else(|e| { + error!("{:?}", e); + std::process::exit(-2); + }); + } +} + +async fn run_async(args: App) -> Result<(), Box> { + use std::sync::Arc; + let args = Arc::new(args); + tracing::subscriber::set_global_default( + tracing_subscriber::FmtSubscriber::builder() + .with_max_level(if args.verbose { + Level::DEBUG + } else { + Level::INFO + }) + .finish(), + ) + .unwrap_or_else(|e| { + eprintln!( + "Could not set up global logger: {}", + snafu::Report::from_error(e) + ); + }); + + std::fs::create_dir_all(&args.out_dir).unwrap_or_else(|e| { + error!("Could not create output directory: {}", e); + std::process::exit(-2); + }); + + let listen_addr = SocketAddrV4::new(Ipv4Addr::from(0), args.port); + let listener = tokio::net::TcpListener::bind(listen_addr).await?; + info!( + "{} listening on: tcp://{}", + &args.calling_ae_title, listen_addr + ); + + loop { + let (socket, _addr) = listener.accept().await?; + let args = args.clone(); + tokio::task::spawn(async move { + if let Err(e) = run_store_async(socket, &args).await { + error!("{}", Report::from_error(e)); + } + }); + } +} + +fn run_sync(args: App) -> Result<(), Box> { tracing::subscriber::set_global_default( tracing_subscriber::FmtSubscriber::builder() @@ -380,7 +184,7 @@ fn main() -> Result<(), Box> { }); let listen_addr = SocketAddrV4::new(Ipv4Addr::from(0), args.port); - let listener = TcpListener::bind(listen_addr)?; + let listener = std::net::TcpListener::bind(listen_addr)?; info!( "{} listening on: tcp://{}", &args.calling_ae_title, listen_addr @@ -389,7 +193,7 @@ fn main() -> Result<(), Box> { for stream in listener.incoming() { match stream { Ok(scu_stream) => { - if let Err(e) = run(scu_stream, &args) { + if let Err(e) = run_store_sync(scu_stream, &args) { error!("{}", snafu::Report::from_error(e)); } } diff --git a/storescp/src/store_async.rs b/storescp/src/store_async.rs new file mode 100644 index 000000000..9407037ff --- /dev/null +++ b/storescp/src/store_async.rs @@ -0,0 +1,271 @@ +use dicom_dictionary_std::tags; +use dicom_encoding::transfer_syntax::TransferSyntaxIndex; +use dicom_object::{FileMetaTableBuilder, InMemDicomObject}; +use dicom_transfer_syntax_registry::TransferSyntaxRegistry; +use dicom_ul::{pdu::PDataValueType, Pdu}; +use snafu::{OptionExt, Report, ResultExt, Whatever}; +use tracing::{debug, info, warn}; + +use crate::{transfer::ABSTRACT_SYNTAXES, App, create_cecho_response, create_cstore_response}; +pub async fn run_store_async(scu_stream: tokio::net::TcpStream, args: &App) -> Result<(), Whatever> { + let App { + verbose, + calling_ae_title, + strict, + uncompressed_only, + promiscuous, + max_pdu_length, + out_dir, + port: _, + blocking: _, + } = args; + let verbose = *verbose; + + let mut buffer: Vec = Vec::with_capacity(*max_pdu_length as usize); + let mut instance_buffer: Vec = Vec::with_capacity(1024 * 1024); + let mut msgid = 1; + let mut sop_class_uid = "".to_string(); + let mut sop_instance_uid = "".to_string(); + + let mut options = dicom_ul::association::ServerAssociationOptions::new() + .accept_any() + .ae_title(calling_ae_title) + .strict(*strict) + .promiscuous(*promiscuous); + + if *uncompressed_only { + options = options + .with_transfer_syntax("1.2.840.10008.1.2") + .with_transfer_syntax("1.2.840.10008.1.2.1"); + } else { + for ts in TransferSyntaxRegistry.iter() { + if !ts.is_unsupported() { + options = options.with_transfer_syntax(ts.uid()); + } + } + }; + + for uid in ABSTRACT_SYNTAXES { + options = options.with_abstract_syntax(*uid); + } + + let mut association = options + .establish_async(scu_stream) + .await + .whatever_context("could not establish association")?; + + info!("New association from {}", association.client_ae_title()); + debug!( + "> Presentation contexts: {:?}", + association.presentation_contexts() + ); + + loop { + match association.receive().await { + Ok(mut pdu) => { + if verbose { + debug!("scu ----> scp: {}", pdu.short_description()); + } + match pdu { + Pdu::PData { ref mut data } => { + if data.is_empty() { + debug!("Ignoring empty PData PDU"); + continue; + } + + for data_value in data { + if data_value.value_type == PDataValueType::Data && !data_value.is_last + { + instance_buffer.append(&mut data_value.data); + } else if data_value.value_type == PDataValueType::Command + && data_value.is_last + { + // commands are always in implict VR LE + let ts = + dicom_transfer_syntax_registry::entries::IMPLICIT_VR_LITTLE_ENDIAN + .erased(); + let data_value = &data_value; + let v = &data_value.data; + + let obj = InMemDicomObject::read_dataset_with_ts(v.as_slice(), &ts) + .whatever_context("failed to read incoming DICOM command")?; + let command_field = obj + .element(tags::COMMAND_FIELD) + .whatever_context("Missing Command Field")? + .uint16() + .whatever_context("Command Field is not an integer")?; + + if command_field == 0x0030 { + // Handle C-ECHO-RQ + let cecho_response = create_cecho_response(msgid); + let mut cecho_data = Vec::new(); + + cecho_response + .write_dataset_with_ts(&mut cecho_data, &ts) + .whatever_context( + "could not write C-ECHO response object", + )?; + + let pdu_response = Pdu::PData { + data: vec![dicom_ul::pdu::PDataValue { + presentation_context_id: data_value + .presentation_context_id, + value_type: PDataValueType::Command, + is_last: true, + data: cecho_data, + }], + }; + association.send(&pdu_response).await.whatever_context( + "failed to send C-ECHO response object to SCU", + )?; + } else { + msgid = obj + .element(tags::MESSAGE_ID) + .whatever_context("Missing Message ID")? + .to_int() + .whatever_context("Message ID is not an integer")?; + sop_class_uid = obj + .element(tags::AFFECTED_SOP_CLASS_UID) + .whatever_context("missing Affected SOP Class UID")? + .to_str() + .whatever_context( + "could not retrieve Affected SOP Class UID", + )? + .to_string(); + sop_instance_uid = obj + .element(tags::AFFECTED_SOP_INSTANCE_UID) + .whatever_context("missing Affected SOP Instance UID")? + .to_str() + .whatever_context( + "could not retrieve Affected SOP Instance UID", + )? + .to_string(); + } + instance_buffer.clear(); + } else if data_value.value_type == PDataValueType::Data + && data_value.is_last + { + instance_buffer.append(&mut data_value.data); + + let presentation_context = association + .presentation_contexts() + .iter() + .find(|pc| pc.id == data_value.presentation_context_id) + .whatever_context("missing presentation context")?; + let ts = &presentation_context.transfer_syntax; + + let obj = InMemDicomObject::read_dataset_with_ts( + instance_buffer.as_slice(), + TransferSyntaxRegistry.get(ts).unwrap(), + ) + .whatever_context("failed to read DICOM data object")?; + let file_meta = FileMetaTableBuilder::new() + .media_storage_sop_class_uid( + obj.element(tags::SOP_CLASS_UID) + .whatever_context("missing SOP Class UID")? + .to_str() + .whatever_context("could not retrieve SOP Class UID")?, + ) + .media_storage_sop_instance_uid( + obj.element(tags::SOP_INSTANCE_UID) + .whatever_context("missing SOP Instance UID")? + .to_str() + .whatever_context("missing SOP Instance UID")?, + ) + .transfer_syntax(ts) + .build() + .whatever_context( + "failed to build DICOM meta file information", + )?; + let file_obj = obj.with_exact_meta(file_meta); + + // write the files to the current directory with their SOPInstanceUID as filenames + let mut file_path = out_dir.clone(); + file_path.push( + sop_instance_uid.trim_end_matches('\0').to_string() + ".dcm", + ); + file_obj + .write_to_file(&file_path) + .whatever_context("could not save DICOM object to file")?; + info!("Stored {}", file_path.display()); + + // send C-STORE-RSP object + // commands are always in implict VR LE + let ts = + dicom_transfer_syntax_registry::entries::IMPLICIT_VR_LITTLE_ENDIAN + .erased(); + + let obj = create_cstore_response( + msgid, + &sop_class_uid, + &sop_instance_uid, + ); + + let mut obj_data = Vec::new(); + + obj.write_dataset_with_ts(&mut obj_data, &ts) + .whatever_context("could not write response object")?; + + let pdu_response = Pdu::PData { + data: vec![dicom_ul::pdu::PDataValue { + presentation_context_id: data_value.presentation_context_id, + value_type: PDataValueType::Command, + is_last: true, + data: obj_data, + }], + }; + association + .send(&pdu_response) + .await + .whatever_context("failed to send response object to SCU")?; + } + } + } + Pdu::ReleaseRQ => { + buffer.clear(); + association.send(&Pdu::ReleaseRP).await.unwrap_or_else(|e| { + warn!( + "Failed to send association release message to SCU: {}", + snafu::Report::from_error(e) + ); + }); + info!( + "Released association with {}", + association.client_ae_title() + ); + break; + } + Pdu::AbortRQ { source } => { + warn!("Aborted connection from: {:?}", source); + break; + } + _ => {} + } + } + Err(err @ dicom_ul::association::server::Error::Receive { .. }) => { + if verbose { + info!("{}", Report::from_error(err)); + } else { + info!("{}", err); + } + break; + } + Err(err) => { + warn!("Unexpected error: {}", Report::from_error(err)); + break; + } + } + } + + if let Ok(peer_addr) = association.inner_stream().peer_addr() { + info!( + "Dropping connection with {} ({})", + association.client_ae_title(), + peer_addr + ); + } else { + info!("Dropping connection with {}", association.client_ae_title()); + } + + Ok(()) +} \ No newline at end of file diff --git a/storescp/src/store_sync.rs b/storescp/src/store_sync.rs new file mode 100644 index 000000000..a5348d2f4 --- /dev/null +++ b/storescp/src/store_sync.rs @@ -0,0 +1,271 @@ +use std::net::TcpStream; + +use dicom_dictionary_std::tags; +use dicom_encoding::transfer_syntax::TransferSyntaxIndex; +use dicom_object::{FileMetaTableBuilder, InMemDicomObject}; +use dicom_transfer_syntax_registry::TransferSyntaxRegistry; +use dicom_ul::{pdu::PDataValueType, Pdu}; +use snafu::{OptionExt, Report, ResultExt, Whatever}; +use tracing::{debug, info, warn}; + +use crate::{create_cecho_response, create_cstore_response, transfer::ABSTRACT_SYNTAXES, App}; +pub fn run_store_sync(scu_stream: TcpStream, args: &App) -> Result<(), Whatever> { + let App { + verbose, + calling_ae_title, + strict, + uncompressed_only, + promiscuous, + max_pdu_length, + out_dir, + port: _, + blocking: _, + } = args; + let verbose = *verbose; + + let mut buffer: Vec = Vec::with_capacity(*max_pdu_length as usize); + let mut instance_buffer: Vec = Vec::with_capacity(1024 * 1024); + let mut msgid = 1; + let mut sop_class_uid = "".to_string(); + let mut sop_instance_uid = "".to_string(); + + let mut options = dicom_ul::association::ServerAssociationOptions::new() + .accept_any() + .ae_title(calling_ae_title) + .strict(*strict) + .promiscuous(*promiscuous); + + if *uncompressed_only { + options = options + .with_transfer_syntax("1.2.840.10008.1.2") + .with_transfer_syntax("1.2.840.10008.1.2.1"); + } else { + for ts in TransferSyntaxRegistry.iter() { + if !ts.is_unsupported() { + options = options.with_transfer_syntax(ts.uid()); + } + } + }; + + for uid in ABSTRACT_SYNTAXES { + options = options.with_abstract_syntax(*uid); + } + + let mut association = options + .establish(scu_stream) + .whatever_context("could not establish association")?; + + info!("New association from {}", association.client_ae_title()); + debug!( + "> Presentation contexts: {:?}", + association.presentation_contexts() + ); + + loop { + match association.receive() { + Ok(mut pdu) => { + if verbose { + debug!("scu ----> scp: {}", pdu.short_description()); + } + match pdu { + Pdu::PData { ref mut data } => { + if data.is_empty() { + debug!("Ignoring empty PData PDU"); + continue; + } + + for data_value in data { + if data_value.value_type == PDataValueType::Data && !data_value.is_last + { + instance_buffer.append(&mut data_value.data); + } else if data_value.value_type == PDataValueType::Command + && data_value.is_last + { + // commands are always in implict VR LE + let ts = + dicom_transfer_syntax_registry::entries::IMPLICIT_VR_LITTLE_ENDIAN + .erased(); + let data_value = &data_value; + let v = &data_value.data; + + let obj = InMemDicomObject::read_dataset_with_ts(v.as_slice(), &ts) + .whatever_context("failed to read incoming DICOM command")?; + let command_field = obj + .element(tags::COMMAND_FIELD) + .whatever_context("Missing Command Field")? + .uint16() + .whatever_context("Command Field is not an integer")?; + + if command_field == 0x0030 { + // Handle C-ECHO-RQ + let cecho_response = create_cecho_response(msgid); + let mut cecho_data = Vec::new(); + + cecho_response + .write_dataset_with_ts(&mut cecho_data, &ts) + .whatever_context( + "could not write C-ECHO response object", + )?; + + let pdu_response = Pdu::PData { + data: vec![dicom_ul::pdu::PDataValue { + presentation_context_id: data_value + .presentation_context_id, + value_type: PDataValueType::Command, + is_last: true, + data: cecho_data, + }], + }; + association.send(&pdu_response).whatever_context( + "failed to send C-ECHO response object to SCU", + )?; + } else { + msgid = obj + .element(tags::MESSAGE_ID) + .whatever_context("Missing Message ID")? + .to_int() + .whatever_context("Message ID is not an integer")?; + sop_class_uid = obj + .element(tags::AFFECTED_SOP_CLASS_UID) + .whatever_context("missing Affected SOP Class UID")? + .to_str() + .whatever_context( + "could not retrieve Affected SOP Class UID", + )? + .to_string(); + sop_instance_uid = obj + .element(tags::AFFECTED_SOP_INSTANCE_UID) + .whatever_context("missing Affected SOP Instance UID")? + .to_str() + .whatever_context( + "could not retrieve Affected SOP Instance UID", + )? + .to_string(); + } + instance_buffer.clear(); + } else if data_value.value_type == PDataValueType::Data + && data_value.is_last + { + instance_buffer.append(&mut data_value.data); + + let presentation_context = association + .presentation_contexts() + .iter() + .find(|pc| pc.id == data_value.presentation_context_id) + .whatever_context("missing presentation context")?; + let ts = &presentation_context.transfer_syntax; + + let obj = InMemDicomObject::read_dataset_with_ts( + instance_buffer.as_slice(), + TransferSyntaxRegistry.get(ts).unwrap(), + ) + .whatever_context("failed to read DICOM data object")?; + let file_meta = FileMetaTableBuilder::new() + .media_storage_sop_class_uid( + obj.element(tags::SOP_CLASS_UID) + .whatever_context("missing SOP Class UID")? + .to_str() + .whatever_context("could not retrieve SOP Class UID")?, + ) + .media_storage_sop_instance_uid( + obj.element(tags::SOP_INSTANCE_UID) + .whatever_context("missing SOP Instance UID")? + .to_str() + .whatever_context("missing SOP Instance UID")?, + ) + .transfer_syntax(ts) + .build() + .whatever_context( + "failed to build DICOM meta file information", + )?; + let file_obj = obj.with_exact_meta(file_meta); + + // write the files to the current directory with their SOPInstanceUID as filenames + let mut file_path = out_dir.clone(); + file_path.push( + sop_instance_uid.trim_end_matches('\0').to_string() + ".dcm", + ); + file_obj + .write_to_file(&file_path) + .whatever_context("could not save DICOM object to file")?; + info!("Stored {}", file_path.display()); + + // send C-STORE-RSP object + // commands are always in implict VR LE + let ts = + dicom_transfer_syntax_registry::entries::IMPLICIT_VR_LITTLE_ENDIAN + .erased(); + + let obj = create_cstore_response( + msgid, + &sop_class_uid, + &sop_instance_uid, + ); + + let mut obj_data = Vec::new(); + + obj.write_dataset_with_ts(&mut obj_data, &ts) + .whatever_context("could not write response object")?; + + let pdu_response = Pdu::PData { + data: vec![dicom_ul::pdu::PDataValue { + presentation_context_id: data_value.presentation_context_id, + value_type: PDataValueType::Command, + is_last: true, + data: obj_data, + }], + }; + association + .send(&pdu_response) + .whatever_context("failed to send response object to SCU")?; + } + } + } + Pdu::ReleaseRQ => { + buffer.clear(); + association.send(&Pdu::ReleaseRP).unwrap_or_else(|e| { + warn!( + "Failed to send association release message to SCU: {}", + snafu::Report::from_error(e) + ); + }); + info!( + "Released association with {}", + association.client_ae_title() + ); + break; + } + Pdu::AbortRQ { source } => { + warn!("Aborted connection from: {:?}", source); + break; + } + _ => {} + } + } + Err(err @ dicom_ul::association::server::Error::Receive { .. }) => { + if verbose { + info!("{}", Report::from_error(err)); + } else { + info!("{}", err); + } + break; + } + Err(err) => { + warn!("Unexpected error: {}", Report::from_error(err)); + break; + } + } + } + + if let Ok(peer_addr) = association.inner_stream().peer_addr() { + info!( + "Dropping connection with {} ({})", + association.client_ae_title(), + peer_addr + ); + } else { + info!("Dropping connection with {}", association.client_ae_title()); + } + + Ok(()) +} \ No newline at end of file diff --git a/storescu/Cargo.toml b/storescu/Cargo.toml index 287aeeaeb..a1fd3c261 100644 --- a/storescu/Cargo.toml +++ b/storescu/Cargo.toml @@ -23,9 +23,13 @@ dicom-encoding = { path = "../encoding/", version = "0.7.1" } dicom-object = { path = '../object', version = "0.7.1" } dicom-pixeldata = { version = "0.7.1", path = "../pixeldata", optional = true } dicom-transfer-syntax-registry = { path = "../transfer-syntax-registry/", version = "0.7.1" } -dicom-ul = { path = '../ul', version = "0.7.1" } +dicom-ul = { path = '../ul', version = "0.7.1", features = ["async"] } walkdir = "2.3.2" indicatif = "0.17.0" tracing = "0.1.34" tracing-subscriber = "0.3.11" snafu = "0.8" + +[dependencies.tokio] +version = "1.38.0" +features = ["rt", "rt-multi-thread", "macros"] diff --git a/storescu/out.json b/storescu/out.json new file mode 100644 index 000000000..e69de29bb diff --git a/storescu/src/main.rs b/storescu/src/main.rs index 185a9b986..593f6bb0c 100644 --- a/storescu/src/main.rs +++ b/storescu/src/main.rs @@ -3,24 +3,22 @@ use dicom_core::{dicom_value, header::Tag, DataElement, VR}; use dicom_dictionary_std::{tags, uids}; use dicom_encoding::transfer_syntax; use dicom_encoding::TransferSyntax; -use dicom_object::{mem::InMemDicomObject, open_file, DefaultDicomObject, StandardDataDictionary}; +use dicom_object::{mem::InMemDicomObject, DefaultDicomObject, StandardDataDictionary}; use dicom_transfer_syntax_registry::TransferSyntaxRegistry; -use dicom_ul::{ - association::ClientAssociationOptions, - pdu::{PDataValue, PDataValueType, Pdu}, -}; use indicatif::{ProgressBar, ProgressStyle}; use snafu::prelude::*; use snafu::{Report, Whatever}; use std::collections::HashSet; use std::ffi::OsStr; -use std::io::Write; use std::path::{Path, PathBuf}; use std::time::Duration; use tracing::{debug, error, info, warn, Level}; use transfer_syntax::TransferSyntaxIndex; use walkdir::WalkDir; +mod store_async; +mod store_sync; + /// DICOM C-STORE SCU #[derive(Debug, Parser)] #[command(version)] @@ -91,6 +89,9 @@ struct App { conflicts_with("saml_assertion") )] jwt: Option, + + #[arg(long = "blocking")] + blocking: bool, } struct DicomFile { @@ -132,45 +133,31 @@ enum Error { } fn main() { - run().unwrap_or_else(|e| { - error!("{}", Report::from_error(e)); - std::process::exit(-2); - }); -} - -fn run() -> Result<(), Error> { - let App { - addr, - files, - verbose, - message_id, - calling_ae_title, - called_ae_title, - max_pdu_length, - fail_first, - mut never_transcode, - username, - password, - kerberos_service_ticket, - saml_assertion, - jwt, - } = App::parse(); - - // never transcode if the feature is disabled - if cfg!(not(feature = "transcode")) { - never_transcode = true; + let app = App::parse(); + if !app.blocking { + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap() + .block_on(async move { + run_async().await.unwrap_or_else(|e| { + error!("{}", Report::from_error(e)); + std::process::exit(-2); + }); + }); + } else { + run(app).unwrap_or_else(|e| { + error!("{}", Report::from_error(e)); + std::process::exit(-2); + }); } +} - tracing::subscriber::set_global_default( - tracing_subscriber::FmtSubscriber::builder() - .with_max_level(if verbose { Level::DEBUG } else { Level::INFO }) - .finish(), - ) - .whatever_context("Could not set up global logging subscriber") - .unwrap_or_else(|e: Whatever| { - eprintln!("[ERROR] {}", Report::from_error(e)); - }); - +fn check_files( + files: Vec, + verbose: bool, + never_transcode: bool, +) -> (Vec, HashSet<(String, String)>) { let mut checked_files: Vec = vec![]; let mut dicom_files: Vec = vec![]; let mut presentation_contexts = HashSet::new(); @@ -227,44 +214,185 @@ fn run() -> Result<(), Error> { eprintln!("No supported files to transfer"); std::process::exit(-1); } + (dicom_files, presentation_contexts) +} + +fn run(app: App) -> Result<(), Error> { + use crate::store_sync::{get_scu, send_file}; + let App { + addr, + files, + verbose, + message_id, + calling_ae_title, + called_ae_title, + max_pdu_length, + fail_first, + mut never_transcode, + username, + password, + kerberos_service_ticket, + saml_assertion, + jwt, + blocking: _, + } = app; + + // never transcode if the feature is disabled + if cfg!(not(feature = "transcode")) { + never_transcode = true; + } + + tracing::subscriber::set_global_default( + tracing_subscriber::FmtSubscriber::builder() + .with_max_level(if verbose { Level::DEBUG } else { Level::INFO }) + .finish(), + ) + .whatever_context("Could not set up global logging subscriber") + .unwrap_or_else(|e: Whatever| { + eprintln!("[ERROR] {}", Report::from_error(e)); + }); if verbose { info!("Establishing association with '{}'...", &addr); } + let (mut dicom_files, presentation_contexts) = check_files(files, verbose, never_transcode); - let mut scu_init = ClientAssociationOptions::new() - .calling_ae_title(calling_ae_title) - .max_pdu_length(max_pdu_length); + let mut scu = get_scu( + addr, + calling_ae_title, + called_ae_title, + max_pdu_length, + username, + password, + kerberos_service_ticket, + saml_assertion, + jwt, + presentation_contexts, + )?; - for (storage_sop_class_uid, transfer_syntax) in &presentation_contexts { - scu_init = scu_init.with_presentation_context(storage_sop_class_uid, vec![transfer_syntax]); + if verbose { + info!("Association established"); } - if let Some(called_ae_title) = called_ae_title { - scu_init = scu_init.called_ae_title(called_ae_title); + for file in &mut dicom_files { + // identify the right transfer syntax to use + let r: Result<_, Error> = + check_presentation_contexts(file, scu.presentation_contexts(), never_transcode) + .whatever_context::<_, _>("Could not choose a transfer syntax"); + match r { + Ok((pc, ts)) => { + if verbose { + debug!( + "{}: Selected presentation context: {:?}", + file.file.display(), + pc + ); + } + file.pc_selected = Some(pc); + file.ts_selected = Some(ts); + } + Err(e) => { + error!("{}", Report::from_error(e)); + if fail_first { + let _ = scu.abort(); + std::process::exit(-2); + } + } + } } - if let Some(username) = username { - scu_init = scu_init.username(username); + let progress_bar; + if !verbose { + progress_bar = Some(ProgressBar::new(dicom_files.len() as u64)); + if let Some(pb) = progress_bar.as_ref() { + pb.set_style( + ProgressStyle::default_bar() + .template("[{elapsed_precise}] {bar:40} {pos}/{len} {wide_msg}") + .expect("Invalid progress bar template"), + ); + pb.enable_steady_tick(Duration::new(0, 480_000_000)); + }; + } else { + progress_bar = None; } - if let Some(password) = password { - scu_init = scu_init.password(password); + for file in dicom_files { + // TODO + scu = send_file( + scu, + file, + message_id, + progress_bar.as_ref(), + verbose, + fail_first, + )?; } - if let Some(kerberos_service_ticket) = kerberos_service_ticket { - scu_init = scu_init.kerberos_service_ticket(kerberos_service_ticket); - } + if let Some(pb) = progress_bar { + pb.finish_with_message("done") + }; + + scu.release() + .whatever_context("Failed to release SCU association")?; + Ok(()) +} + +async fn run_async() -> Result<(), Error> { + use crate::store_async::{get_scu, send_file}; + let App { + addr, + files, + verbose, + message_id, + calling_ae_title, + called_ae_title, + max_pdu_length, + fail_first, + mut never_transcode, + username, + password, + kerberos_service_ticket, + saml_assertion, + jwt, + blocking: _, + } = App::parse(); - if let Some(saml_assertion) = saml_assertion { - scu_init = scu_init.saml_assertion(saml_assertion); + // never transcode if the feature is disabled + if cfg!(not(feature = "transcode")) { + never_transcode = true; } - if let Some(jwt) = jwt { - scu_init = scu_init.jwt(jwt); + tracing::subscriber::set_global_default( + tracing_subscriber::FmtSubscriber::builder() + .with_max_level(if verbose { Level::DEBUG } else { Level::INFO }) + .finish(), + ) + .whatever_context("Could not set up global logging subscriber") + .unwrap_or_else(|e: Whatever| { + eprintln!("[ERROR] {}", Report::from_error(e)); + }); + + if verbose { + info!("Establishing association with '{}'...", &addr); } + let (mut dicom_files, presentation_contexts) = + tokio::task::spawn_blocking(move || check_files(files, verbose, never_transcode)) + .await + .unwrap(); - let mut scu = scu_init.establish_with(&addr).context(InitScuSnafu)?; + let mut scu = get_scu( + addr, + calling_ae_title, + called_ae_title, + max_pdu_length, + username, + password, + kerberos_service_ticket, + saml_assertion, + jwt, + presentation_contexts, + ) + .await?; if verbose { info!("Association established"); @@ -290,7 +418,7 @@ fn run() -> Result<(), Error> { Err(e) => { error!("{}", Report::from_error(e)); if fail_first { - let _ = scu.abort(); + let _ = scu.abort().await; std::process::exit(-2); } } @@ -313,179 +441,17 @@ fn run() -> Result<(), Error> { } for file in dicom_files { - if let (Some(pc_selected), Some(ts_uid_selected)) = (file.pc_selected, file.ts_selected) { - if let Some(pb) = &progress_bar { - pb.set_message(file.sop_instance_uid.clone()); - } - let cmd = store_req_command(&file.sop_class_uid, &file.sop_instance_uid, message_id); - - let mut cmd_data = Vec::with_capacity(128); - cmd.write_dataset_with_ts( - &mut cmd_data, - &dicom_transfer_syntax_registry::entries::IMPLICIT_VR_LITTLE_ENDIAN.erased(), - ) - .map_err(Box::from) - .context(CreateCommandSnafu)?; - - let mut object_data = Vec::with_capacity(2048); - let dicom_file = - open_file(&file.file).whatever_context("Could not open listed DICOM file")?; - let ts_selected = TransferSyntaxRegistry - .get(&ts_uid_selected) - .with_context(|| UnsupportedFileTransferSyntaxSnafu { - uid: ts_uid_selected.to_string(), - })?; - - // transcode file if necessary - let dicom_file = into_ts(dicom_file, ts_selected, verbose)?; - - dicom_file - .write_dataset_with_ts(&mut object_data, ts_selected) - .whatever_context("Could not write object dataset")?; - - let nbytes = cmd_data.len() + object_data.len(); - - if verbose { - info!( - "Sending file {} (~ {} kB), uid={}, sop={}, ts={}", - file.file.display(), - nbytes / 1_000, - &file.sop_instance_uid, - &file.sop_class_uid, - ts_uid_selected, - ); - } - - if nbytes < scu.acceptor_max_pdu_length().saturating_sub(100) as usize { - let pdu = Pdu::PData { - data: vec![ - PDataValue { - presentation_context_id: pc_selected.id, - value_type: PDataValueType::Command, - is_last: true, - data: cmd_data, - }, - PDataValue { - presentation_context_id: pc_selected.id, - value_type: PDataValueType::Data, - is_last: true, - data: object_data, - }, - ], - }; - - scu.send(&pdu) - .whatever_context("Failed to send C-STORE-RQ")?; - } else { - let pdu = Pdu::PData { - data: vec![PDataValue { - presentation_context_id: pc_selected.id, - value_type: PDataValueType::Command, - is_last: true, - data: cmd_data, - }], - }; - - scu.send(&pdu) - .whatever_context("Failed to send C-STORE-RQ command")?; - - { - let mut pdata = scu.send_pdata(pc_selected.id); - pdata - .write_all(&object_data) - .whatever_context("Failed to send C-STORE-RQ P-Data")?; - } - } - - if verbose { - debug!("Awaiting response..."); - } - - let rsp_pdu = scu - .receive() - .whatever_context("Failed to receive C-STORE-RSP")?; - - match rsp_pdu { - Pdu::PData { data } => { - let data_value = &data[0]; - - let cmd_obj = InMemDicomObject::read_dataset_with_ts( - &data_value.data[..], - &dicom_transfer_syntax_registry::entries::IMPLICIT_VR_LITTLE_ENDIAN - .erased(), - ) - .whatever_context("Could not read response from SCP")?; - if verbose { - debug!("Full response: {:?}", cmd_obj); - } - let status = cmd_obj - .element(tags::STATUS) - .whatever_context("Could not find status code in response")? - .to_int::() - .whatever_context("Status code in response is not a valid integer")?; - let storage_sop_instance_uid = file - .sop_instance_uid - .trim_end_matches(|c: char| c.is_whitespace() || c == '\0'); - - match status { - // Success - 0 => { - if verbose { - info!("Successfully stored instance {}", storage_sop_instance_uid); - } - } - // Warning - 1 | 0x0107 | 0x0116 | 0xB000..=0xBFFF => { - warn!( - "Possible issue storing instance `{}` (status code {:04X}H)", - storage_sop_instance_uid, status - ); - } - 0xFF00 | 0xFF01 => { - warn!( - "Possible issue storing instance `{}`: status is pending (status code {:04X}H)", - storage_sop_instance_uid, status - ); - } - 0xFE00 => { - error!( - "Could not store instance `{}`: operation cancelled", - storage_sop_instance_uid - ); - if fail_first { - let _ = scu.abort(); - std::process::exit(-2); - } - } - _ => { - error!( - "Failed to store instance `{}` (status code {:04X}H)", - storage_sop_instance_uid, status - ); - if fail_first { - let _ = scu.abort(); - std::process::exit(-2); - } - } - } - } - - pdu @ Pdu::Unknown { .. } - | pdu @ Pdu::AssociationRQ { .. } - | pdu @ Pdu::AssociationAC { .. } - | pdu @ Pdu::AssociationRJ { .. } - | pdu @ Pdu::ReleaseRQ - | pdu @ Pdu::ReleaseRP - | pdu @ Pdu::AbortRQ { .. } => { - error!("Unexpected SCP response: {:?}", pdu); - let _ = scu.abort(); - std::process::exit(-2); - } - } - } - if let Some(pb) = progress_bar.as_ref() { - pb.inc(1) - }; + // TODO: Eventually expose concurrency option to sping up multiple + // worker tasks to send files in parallel + scu = send_file( + scu, + file, + message_id, + progress_bar.as_ref(), + verbose, + fail_first, + ) + .await?; } if let Some(pb) = progress_bar { @@ -493,10 +459,10 @@ fn run() -> Result<(), Error> { }; scu.release() + .await .whatever_context("Failed to release SCU association")?; Ok(()) } - fn store_req_command( storage_sop_class_uid: &str, storage_sop_instance_uid: &str, diff --git a/storescu/src/store_async.rs b/storescu/src/store_async.rs new file mode 100644 index 000000000..da3dcae6a --- /dev/null +++ b/storescu/src/store_async.rs @@ -0,0 +1,255 @@ +use std::collections::HashSet; + +use dicom_dictionary_std::tags; +use dicom_encoding::TransferSyntaxIndex; +use dicom_object::{open_file, InMemDicomObject}; +use dicom_transfer_syntax_registry::TransferSyntaxRegistry; +use dicom_ul::{ + pdu::{PDataValue, PDataValueType}, + ClientAssociation, ClientAssociationOptions, Pdu, +}; +use indicatif::ProgressBar; +use snafu::{OptionExt, ResultExt}; +use tokio::{io::AsyncWriteExt, net::TcpStream}; +use tracing::{debug, error, info, warn}; + +use crate::{ + into_ts, store_req_command, CreateCommandSnafu, DicomFile, Error, InitScuSnafu, + UnsupportedFileTransferSyntaxSnafu, +}; + +#[allow(clippy::too_many_arguments)] +pub async fn get_scu( + addr: String, + calling_ae_title: String, + called_ae_title: Option, + max_pdu_length: u32, + username: Option, + password: Option, + kerberos_service_ticket: Option, + saml_assertion: Option, + jwt: Option, + presentation_contexts: HashSet<(String, String)>, +) -> Result, Error> { + let mut scu_init = ClientAssociationOptions::new() + .calling_ae_title(calling_ae_title) + .max_pdu_length(max_pdu_length); + + for (storage_sop_class_uid, transfer_syntax) in &presentation_contexts { + scu_init = scu_init.with_presentation_context(storage_sop_class_uid, vec![transfer_syntax]); + } + + if let Some(called_ae_title) = called_ae_title { + scu_init = scu_init.called_ae_title(called_ae_title); + } + + if let Some(username) = username { + scu_init = scu_init.username(username); + } + + if let Some(password) = password { + scu_init = scu_init.password(password); + } + + if let Some(kerberos_service_ticket) = kerberos_service_ticket { + scu_init = scu_init.kerberos_service_ticket(kerberos_service_ticket); + } + + if let Some(saml_assertion) = saml_assertion { + scu_init = scu_init.saml_assertion(saml_assertion); + } + + if let Some(jwt) = jwt { + scu_init = scu_init.jwt(jwt); + } + + scu_init + .establish_with_async(&addr) + .await + .context(InitScuSnafu) +} + +pub async fn send_file( + mut scu: ClientAssociation, + file: DicomFile, + message_id: u16, + progress_bar: Option<&ProgressBar>, + verbose: bool, + fail_first: bool, +) -> Result, Error> { + if let (Some(pc_selected), Some(ts_uid_selected)) = (file.pc_selected, file.ts_selected) { + if let Some(pb) = &progress_bar { + pb.set_message(file.sop_instance_uid.clone()); + } + let cmd = store_req_command(&file.sop_class_uid, &file.sop_instance_uid, message_id); + + let mut cmd_data = Vec::with_capacity(128); + cmd.write_dataset_with_ts( + &mut cmd_data, + &dicom_transfer_syntax_registry::entries::IMPLICIT_VR_LITTLE_ENDIAN.erased(), + ) + .map_err(Box::from) + .context(CreateCommandSnafu)?; + + let mut object_data = Vec::with_capacity(2048); + let dicom_file = + open_file(&file.file).whatever_context("Could not open listed DICOM file")?; + let ts_selected = TransferSyntaxRegistry + .get(&ts_uid_selected) + .with_context(|| UnsupportedFileTransferSyntaxSnafu { + uid: ts_uid_selected.to_string(), + })?; + + // transcode file if necessary + let dicom_file = into_ts(dicom_file, ts_selected, verbose)?; + + dicom_file + .write_dataset_with_ts(&mut object_data, ts_selected) + .whatever_context("Could not write object dataset")?; + + let nbytes = cmd_data.len() + object_data.len(); + + if verbose { + info!( + "Sending file {} (~ {} kB), uid={}, sop={}, ts={}", + file.file.display(), + nbytes / 1_000, + &file.sop_instance_uid, + &file.sop_class_uid, + ts_uid_selected, + ); + } + + if nbytes < scu.acceptor_max_pdu_length().saturating_sub(100) as usize { + let pdu = Pdu::PData { + data: vec![ + PDataValue { + presentation_context_id: pc_selected.id, + value_type: PDataValueType::Command, + is_last: true, + data: cmd_data, + }, + PDataValue { + presentation_context_id: pc_selected.id, + value_type: PDataValueType::Data, + is_last: true, + data: object_data, + }, + ], + }; + + scu.send(&pdu) + .await + .whatever_context("Failed to send C-STORE-RQ")?; + } else { + let pdu = Pdu::PData { + data: vec![PDataValue { + presentation_context_id: pc_selected.id, + value_type: PDataValueType::Command, + is_last: true, + data: cmd_data, + }], + }; + + scu.send(&pdu) + .await + .whatever_context("Failed to send C-STORE-RQ command")?; + + { + let mut pdata = scu.send_pdata(pc_selected.id).await; + pdata.write_all(&object_data).await.unwrap(); + //.whatever_context("Failed to send C-STORE-RQ P-Data")?; + } + } + + if verbose { + debug!("Awaiting response..."); + } + + let rsp_pdu = scu + .receive() + .await + .whatever_context("Failed to receive C-STORE-RSP")?; + + match rsp_pdu { + Pdu::PData { data } => { + let data_value = &data[0]; + + let cmd_obj = InMemDicomObject::read_dataset_with_ts( + &data_value.data[..], + &dicom_transfer_syntax_registry::entries::IMPLICIT_VR_LITTLE_ENDIAN.erased(), + ) + .whatever_context("Could not read response from SCP")?; + if verbose { + debug!("Full response: {:?}", cmd_obj); + } + let status = cmd_obj + .element(tags::STATUS) + .whatever_context("Could not find status code in response")? + .to_int::() + .whatever_context("Status code in response is not a valid integer")?; + let storage_sop_instance_uid = file + .sop_instance_uid + .trim_end_matches(|c: char| c.is_whitespace() || c == '\0'); + + match status { + // Success + 0 => { + if verbose { + info!("Successfully stored instance {}", storage_sop_instance_uid); + } + } + // Warning + 1 | 0x0107 | 0x0116 | 0xB000..=0xBFFF => { + warn!( + "Possible issue storing instance `{}` (status code {:04X}H)", + storage_sop_instance_uid, status + ); + } + 0xFF00 | 0xFF01 => { + warn!( + "Possible issue storing instance `{}`: status is pending (status code {:04X}H)", + storage_sop_instance_uid, status + ); + } + 0xFE00 => { + error!( + "Could not store instance `{}`: operation cancelled", + storage_sop_instance_uid + ); + if fail_first { + let _ = scu.abort().await; + std::process::exit(-2); + } + } + _ => { + error!( + "Failed to store instance `{}` (status code {:04X}H)", + storage_sop_instance_uid, status + ); + if fail_first { + let _ = scu.abort().await; + std::process::exit(-2); + } + } + } + } + + pdu @ Pdu::Unknown { .. } + | pdu @ Pdu::AssociationRQ { .. } + | pdu @ Pdu::AssociationAC { .. } + | pdu @ Pdu::AssociationRJ { .. } + | pdu @ Pdu::ReleaseRQ + | pdu @ Pdu::ReleaseRP + | pdu @ Pdu::AbortRQ { .. } => { + error!("Unexpected SCP response: {:?}", pdu); + let _ = scu.abort().await; + std::process::exit(-2); + } + } + } + if let Some(pb) = progress_bar.as_ref() { + pb.inc(1) + }; + Ok(scu) +} diff --git a/storescu/src/store_sync.rs b/storescu/src/store_sync.rs new file mode 100644 index 000000000..148188913 --- /dev/null +++ b/storescu/src/store_sync.rs @@ -0,0 +1,249 @@ +use std::{collections::HashSet, io::Write, net::TcpStream}; + +use dicom_dictionary_std::tags; +use dicom_encoding::TransferSyntaxIndex; +use dicom_object::{open_file, InMemDicomObject}; +use dicom_transfer_syntax_registry::TransferSyntaxRegistry; +use dicom_ul::{ + pdu::{PDataValue, PDataValueType}, + ClientAssociation, ClientAssociationOptions, Pdu, +}; +use indicatif::ProgressBar; +use snafu::{OptionExt, ResultExt}; +use tracing::{debug, error, info, warn}; + +use crate::{ + into_ts, store_req_command, CreateCommandSnafu, DicomFile, Error, InitScuSnafu, + UnsupportedFileTransferSyntaxSnafu, +}; + +#[allow(clippy::too_many_arguments)] +pub fn get_scu( + addr: String, + calling_ae_title: String, + called_ae_title: Option, + max_pdu_length: u32, + username: Option, + password: Option, + kerberos_service_ticket: Option, + saml_assertion: Option, + jwt: Option, + presentation_contexts: HashSet<(String, String)>, +) -> Result, Error> { + let mut scu_init = ClientAssociationOptions::new() + .calling_ae_title(calling_ae_title) + .max_pdu_length(max_pdu_length); + + for (storage_sop_class_uid, transfer_syntax) in &presentation_contexts { + scu_init = scu_init.with_presentation_context(storage_sop_class_uid, vec![transfer_syntax]); + } + + if let Some(called_ae_title) = called_ae_title { + scu_init = scu_init.called_ae_title(called_ae_title); + } + + if let Some(username) = username { + scu_init = scu_init.username(username); + } + + if let Some(password) = password { + scu_init = scu_init.password(password); + } + + if let Some(kerberos_service_ticket) = kerberos_service_ticket { + scu_init = scu_init.kerberos_service_ticket(kerberos_service_ticket); + } + + if let Some(saml_assertion) = saml_assertion { + scu_init = scu_init.saml_assertion(saml_assertion); + } + + if let Some(jwt) = jwt { + scu_init = scu_init.jwt(jwt); + } + + scu_init.establish_with(&addr).context(InitScuSnafu) +} + +pub fn send_file( + mut scu: ClientAssociation, + file: DicomFile, + message_id: u16, + progress_bar: Option<&ProgressBar>, + verbose: bool, + fail_first: bool, +) -> Result, Error> { + if let (Some(pc_selected), Some(ts_uid_selected)) = (file.pc_selected, file.ts_selected) { + if let Some(pb) = &progress_bar { + pb.set_message(file.sop_instance_uid.clone()); + } + let cmd = store_req_command(&file.sop_class_uid, &file.sop_instance_uid, message_id); + + let mut cmd_data = Vec::with_capacity(128); + cmd.write_dataset_with_ts( + &mut cmd_data, + &dicom_transfer_syntax_registry::entries::IMPLICIT_VR_LITTLE_ENDIAN.erased(), + ) + .map_err(Box::from) + .context(CreateCommandSnafu)?; + + let mut object_data = Vec::with_capacity(2048); + let dicom_file = + open_file(&file.file).whatever_context("Could not open listed DICOM file")?; + let ts_selected = TransferSyntaxRegistry + .get(&ts_uid_selected) + .with_context(|| UnsupportedFileTransferSyntaxSnafu { + uid: ts_uid_selected.to_string(), + })?; + + // transcode file if necessary + let dicom_file = into_ts(dicom_file, ts_selected, verbose)?; + + dicom_file + .write_dataset_with_ts(&mut object_data, ts_selected) + .whatever_context("Could not write object dataset")?; + + let nbytes = cmd_data.len() + object_data.len(); + + if verbose { + info!( + "Sending file {} (~ {} kB), uid={}, sop={}, ts={}", + file.file.display(), + nbytes / 1_000, + &file.sop_instance_uid, + &file.sop_class_uid, + ts_uid_selected, + ); + } + + if nbytes < scu.acceptor_max_pdu_length().saturating_sub(100) as usize { + let pdu = Pdu::PData { + data: vec![ + PDataValue { + presentation_context_id: pc_selected.id, + value_type: PDataValueType::Command, + is_last: true, + data: cmd_data, + }, + PDataValue { + presentation_context_id: pc_selected.id, + value_type: PDataValueType::Data, + is_last: true, + data: object_data, + }, + ], + }; + + scu.send(&pdu) + .whatever_context("Failed to send C-STORE-RQ")?; + } else { + let pdu = Pdu::PData { + data: vec![PDataValue { + presentation_context_id: pc_selected.id, + value_type: PDataValueType::Command, + is_last: true, + data: cmd_data, + }], + }; + + scu.send(&pdu) + .whatever_context("Failed to send C-STORE-RQ command")?; + + { + let mut pdata = scu.send_pdata(pc_selected.id); + pdata + .write_all(&object_data) + .whatever_context("Failed to send C-STORE-RQ P-Data")?; + } + } + + if verbose { + debug!("Awaiting response..."); + } + + let rsp_pdu = scu + .receive() + .whatever_context("Failed to receive C-STORE-RSP")?; + + match rsp_pdu { + Pdu::PData { data } => { + let data_value = &data[0]; + + let cmd_obj = InMemDicomObject::read_dataset_with_ts( + &data_value.data[..], + &dicom_transfer_syntax_registry::entries::IMPLICIT_VR_LITTLE_ENDIAN.erased(), + ) + .whatever_context("Could not read response from SCP")?; + if verbose { + debug!("Full response: {:?}", cmd_obj); + } + let status = cmd_obj + .element(tags::STATUS) + .whatever_context("Could not find status code in response")? + .to_int::() + .whatever_context("Status code in response is not a valid integer")?; + let storage_sop_instance_uid = file + .sop_instance_uid + .trim_end_matches(|c: char| c.is_whitespace() || c == '\0'); + + match status { + // Success + 0 => { + if verbose { + info!("Successfully stored instance {}", storage_sop_instance_uid); + } + } + // Warning + 1 | 0x0107 | 0x0116 | 0xB000..=0xBFFF => { + warn!( + "Possible issue storing instance `{}` (status code {:04X}H)", + storage_sop_instance_uid, status + ); + } + 0xFF00 | 0xFF01 => { + warn!( + "Possible issue storing instance `{}`: status is pending (status code {:04X}H)", + storage_sop_instance_uid, status + ); + } + 0xFE00 => { + error!( + "Could not store instance `{}`: operation cancelled", + storage_sop_instance_uid + ); + if fail_first { + let _ = scu.abort(); + std::process::exit(-2); + } + } + _ => { + error!( + "Failed to store instance `{}` (status code {:04X}H)", + storage_sop_instance_uid, status + ); + if fail_first { + let _ = scu.abort(); + std::process::exit(-2); + } + } + } + } + + pdu @ Pdu::Unknown { .. } + | pdu @ Pdu::AssociationRQ { .. } + | pdu @ Pdu::AssociationAC { .. } + | pdu @ Pdu::AssociationRJ { .. } + | pdu @ Pdu::ReleaseRQ + | pdu @ Pdu::ReleaseRP + | pdu @ Pdu::AbortRQ { .. } => { + error!("Unexpected SCP response: {:?}", pdu); + let _ = scu.abort(); + std::process::exit(-2); + } + } + } + if let Some(pb) = progress_bar.as_ref() { + pb.inc(1) + }; + Ok(scu) +} diff --git a/ul/Cargo.toml b/ul/Cargo.toml index dcb393ed0..47d6a6e92 100644 --- a/ul/Cargo.toml +++ b/ul/Cargo.toml @@ -12,10 +12,27 @@ readme = "README.md" [dependencies] byteordered = "0.6" +bytes = "^1.6" dicom-encoding = { path = "../encoding/", version = "0.7.1" } dicom-transfer-syntax-registry = { path = "../transfer-syntax-registry/", version = "0.7.1", default-features = false } snafu = "0.8" tracing = "0.1.34" +[dependencies.tokio] +version = "^1.38" +optional = true +features = [ + "rt", + "rt-multi-thread", + "net", + "io-util", + "time" +] + [dev-dependencies] matches = "0.1.8" +tokio = { version = "^1.38", features = ["io-util", "macros", "net", "rt", "rt-multi-thread"] } + +[features] +async = ["dep:tokio"] +default = ["async"] diff --git a/ul/src/address.rs b/ul/src/address.rs index 875403ebb..8e4cb5f03 100644 --- a/ul/src/address.rs +++ b/ul/src/address.rs @@ -7,14 +7,13 @@ //! The syntax is `«ae_title»@«network_address»:«port»`, //! which works not only with IPv4 and IPv6 addresses, //! but also with domain names. +use snafu::{ensure, AsErrorSource, ResultExt, Snafu}; use std::{ convert::TryFrom, net::{SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}, str::FromStr, }; -use snafu::{ensure, AsErrorSource, ResultExt, Snafu}; - /// A specification for a full address to the target SCP: /// an application entity title, plus a generic address, /// typically a socket address. diff --git a/ul/src/association/client.rs b/ul/src/association/client.rs index 5576da295..13afeb9b4 100644 --- a/ul/src/association/client.rs +++ b/ul/src/association/client.rs @@ -4,25 +4,23 @@ //! in which this application entity is the one requesting the association. //! See [`ClientAssociationOptions`] //! for details and examples on how to create an association. -use std::{ - borrow::Cow, - convert::TryInto, - io::Write, - net::{TcpStream, ToSocketAddrs}, time::Duration, -}; +use bytes::BytesMut; +use std::{borrow::Cow, convert::TryInto, io::Cursor, net::ToSocketAddrs, time::Duration}; +use std::io::{BufRead, BufReader, Read, Write}; use crate::{ pdu::{ - reader::{read_pdu, DEFAULT_MAX_PDU, MAXIMUM_PDU_SIZE}, - writer::write_pdu, - AbortRQSource, AssociationAC, AssociationRJ, AssociationRQ, Pdu, + read_pdu, write_pdu, AbortRQSource, AssociationAC, AssociationRJ, AssociationRQ, Pdu, PresentationContextProposed, PresentationContextResult, PresentationContextResultReason, - UserIdentity, UserIdentityType, UserVariableItem, + ReadPduSnafu, UserIdentity, UserIdentityType, UserVariableItem, DEFAULT_MAX_PDU, + MAXIMUM_PDU_SIZE, }, AeAddr, IMPLEMENTATION_CLASS_UID, IMPLEMENTATION_VERSION_NAME, }; use snafu::{ensure, Backtrace, ResultExt, Snafu}; +use bytes::Buf; + use super::{ pdata::{PDataReader, PDataWriter}, uid::trim_uid, @@ -39,15 +37,15 @@ pub enum Error { source: std::io::Error, backtrace: Backtrace, }, - + /// Could not set tcp read timeout - SetReadTimeout{ + SetReadTimeout { source: std::io::Error, backtrace: Backtrace, }, /// Could not set tcp write timeout - SetWriteTimeout{ + SetWriteTimeout { source: std::io::Error, backtrace: Backtrace, }, @@ -55,13 +53,13 @@ pub enum Error { /// failed to send association request SendRequest { #[snafu(backtrace)] - source: crate::pdu::writer::Error, + source: crate::pdu::WriteError, }, /// failed to receive association response ReceiveResponse { #[snafu(backtrace)] - source: crate::pdu::reader::Error, + source: crate::pdu::ReadError, }, #[snafu(display("unexpected response from server `{:?}`", pdu))] @@ -98,7 +96,7 @@ pub enum Error { #[non_exhaustive] Send { #[snafu(backtrace)] - source: crate::pdu::writer::Error, + source: crate::pdu::WriteError, }, /// failed to send PDU message on wire @@ -119,12 +117,47 @@ pub enum Error { #[non_exhaustive] Receive { #[snafu(backtrace)] - source: crate::pdu::reader::Error, + source: crate::pdu::ReadError, }, + + #[snafu(display("Connection closed by peer"))] + ConnectionClosed, } pub type Result = std::result::Result; +/// Helper function to get a PDU from a reader +pub fn get_client_pdu(reader: &mut R, max_pdu_length: u32, strict: bool) -> Result { + // Receive response + + let mut read_buffer = BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize); + let mut reader = BufReader::new(reader); + + let msg = loop { + let mut buf = Cursor::new(&read_buffer[..]); + match read_pdu(&mut buf, max_pdu_length, strict).context(ReceiveResponseSnafu)? { + Some(pdu) => { + read_buffer.advance(buf.position() as usize); + break pdu; + } + None => { + // Reset position + buf.set_position(0) + } + } + // Use BufReader to get similar behavior to AsyncRead read_buf + let recv = reader + .fill_buf() + .context(ReadPduSnafu) + .context(ReceiveSnafu)? + .to_vec(); + reader.consume(recv.len()); + read_buffer.extend_from_slice(&recv); + ensure!(!recv.is_empty(), ConnectionClosedSnafu); + }; + Ok(msg) +} + /// A DICOM association builder for a client node. /// The final outcome is a [`ClientAssociation`]. /// @@ -132,18 +165,46 @@ pub type Result = std::result::Result; /// an association with another DICOM node, /// that one usually taking the role of a service class provider (SCP). /// -/// # Example +/// You can create either a blocking or non-blocking client by calling either +/// `establish` or `establish_async` respectively. +/// +/// > **⚠️ Warning:** It is highly recommended to set `timeout` to a reasonable value for the +/// > async client since there is _no_ default timeout on +/// > [`tokio::net::TcpStream`], see the [`ClientAssociationOptions::timeout`] method for details. +/// +/// ## Basic usage +/// +/// ### Sync /// /// ```no_run /// # use dicom_ul::association::client::ClientAssociationOptions; +/// # use std::time::Duration; /// # fn run() -> Result<(), Box> { /// let association = ClientAssociationOptions::new() /// .with_presentation_context("1.2.840.10008.1.1", vec!["1.2.840.10008.1.2.1", "1.2.840.10008.1.2"]) +/// .timeout(Duration::from_secs(60)) /// .establish("129.168.0.5:104")?; /// # Ok(()) /// # } /// ``` /// +/// ### Async +/// +/// ```no_run +/// # use dicom_ul::association::client::ClientAssociationOptions; +/// # use std::time::Duration; +/// #[tokio::main] +/// # async fn run() -> Result<(), Box> { +/// let association = ClientAssociationOptions::new() +/// .with_presentation_context("1.2.840.10008.1.1", vec!["1.2.840.10008.1.2.1", "1.2.840.10008.1.2"]) +/// .timeout(Duration::from_secs(60)) +/// .establish_async("129.168.0.5:104") +/// .await?; +/// # Ok(()) +/// # } +/// ``` +/// +/// /// At least one presentation context must be specified, /// using the method [`with_presentation_context`](Self::with_presentation_context) /// and supplying both an abstract syntax and list of transfer syntaxes. @@ -152,7 +213,6 @@ pub type Result = std::result::Result; /// include by default the transfer syntaxes /// _Implicit VR Little Endian_ and _Explicit VR Little Endian_ /// in the resulting presentation context. -/// # Example /// /// ```no_run /// # use dicom_ul::association::client::ClientAssociationOptions; @@ -189,10 +249,8 @@ pub struct ClientAssociationOptions<'a> { saml_assertion: Option>, /// User identity JWT jwt: Option>, - /// TCP read timeout - read_timeout: Option, - /// TCP write timeout - write_timeout: Option, + /// Timeout for individual send/receive operations + timeout: Option, } impl<'a> Default for ClientAssociationOptions<'a> { @@ -207,15 +265,14 @@ impl<'a> Default for ClientAssociationOptions<'a> { // the list of requested presentation contexts presentation_contexts: Vec::new(), protocol_version: 1, - max_pdu_length: crate::pdu::reader::DEFAULT_MAX_PDU, + max_pdu_length: DEFAULT_MAX_PDU, strict: true, username: None, password: None, kerberos_service_ticket: None, saml_assertion: None, jwt: None, - read_timeout: None, - write_timeout: None, + timeout: None, } } } @@ -415,7 +472,7 @@ impl<'a> ClientAssociationOptions<'a> { /// Initiate the TCP connection to the given address /// and request a new DICOM association, /// negotiating the presentation contexts in the process. - pub fn establish(self, address: A) -> Result { + pub fn establish(self, address: A) -> Result> { self.establish_impl(AeAddr::new_socket_addr(address)) } @@ -442,7 +499,8 @@ impl<'a> ClientAssociationOptions<'a> { /// # Ok(()) /// # } /// ``` - pub fn establish_with(self, ae_address: &str) -> Result { + #[allow(unreachable_patterns)] + pub fn establish_with(self, ae_address: &str) -> Result> { match ae_address.try_into() { Ok(ae_address) => self.establish_impl(ae_address), Err(_) => self.establish_impl(AeAddr::new_socket_addr(ae_address)), @@ -450,22 +508,16 @@ impl<'a> ClientAssociationOptions<'a> { } /// Set the read timeout for the underlying TCP socket - pub fn read_timeout(self, timeout: Duration) -> Self { + /// + /// This is used to set both the read and write timeout. + pub fn timeout(self, timeout: Duration) -> Self { Self { - read_timeout: Some(timeout), + timeout: Some(timeout), ..self } } - /// Set the write timeout for the underlying TCP socket - pub fn write_timeout(self, timeout: Duration) -> Self { - Self { - write_timeout: Some(timeout), - ..self - } - } - - fn establish_impl(self, ae_address: AeAddr) -> Result + fn establish_impl(self, ae_address: AeAddr) -> Result> where T: ToSocketAddrs, { @@ -482,8 +534,7 @@ impl<'a> ClientAssociationOptions<'a> { kerberos_service_ticket, saml_assertion, jwt, - read_timeout, - write_timeout + timeout, } = self; // fail if no presentation contexts were provided: they represent intent, @@ -546,11 +597,12 @@ impl<'a> ClientAssociationOptions<'a> { user_variables, }); - let mut socket = std::net::TcpStream::connect(ae_address) - .context(ConnectSnafu)?; - socket.set_read_timeout(read_timeout) + let mut socket = std::net::TcpStream::connect(ae_address).context(ConnectSnafu)?; + socket + .set_read_timeout(timeout) .context(SetReadTimeoutSnafu)?; - socket.set_write_timeout(write_timeout) + socket + .set_write_timeout(timeout) .context(SetWriteTimeoutSnafu)?; let mut buffer: Vec = Vec::with_capacity(max_pdu_length as usize); // send request @@ -558,9 +610,8 @@ impl<'a> ClientAssociationOptions<'a> { write_pdu(&mut buffer, &msg).context(SendRequestSnafu)?; socket.write_all(&buffer).context(WireSendSnafu)?; buffer.clear(); - // receive response - let msg = - read_pdu(&mut socket, MAXIMUM_PDU_SIZE, self.strict).context(ReceiveResponseSnafu)?; + + let msg = get_client_pdu(&mut socket, MAXIMUM_PDU_SIZE, self.strict)?; match msg { Pdu::AssociationAC(AssociationAC { @@ -617,6 +668,8 @@ impl<'a> ClientAssociationOptions<'a> { socket, buffer, strict, + read_buffer: BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize), + timeout, }) } Pdu::AssociationRJ(association_rj) => RejectedSnafu { association_rj }.fail(), @@ -708,6 +761,28 @@ impl<'a> ClientAssociationOptions<'a> { } } +/// Trait to close underlying socket +pub trait CloseSocket { + fn close(&mut self) -> std::io::Result<()>; +} + +impl CloseSocket for std::net::TcpStream { + fn close(&mut self) -> std::io::Result<()> { + self.shutdown(std::net::Shutdown::Both) + } +} + +/// Trait to release association +pub trait Release { + fn release(&mut self) -> Result<()>; +} + +impl Release for ClientAssociation { + fn release(&mut self) -> Result<()> { + self.release_impl() + } +} + /// A DICOM upper level association from the perspective /// of a requesting application entity. /// @@ -721,8 +796,15 @@ impl<'a> ClientAssociationOptions<'a> { /// the program will automatically try to gracefully release the association /// through a standard C-RELEASE message exchange, /// then shut down the underlying TCP connection. +/// +/// This may either be sync or async depending on which method was called to +/// establish the association. #[derive(Debug)] -pub struct ClientAssociation { +pub struct ClientAssociation +where + S: CloseSocket, + ClientAssociation: Release, +{ /// The presentation contexts accorded with the acceptor application entity, /// without the rejected ones. presentation_contexts: Vec, @@ -731,14 +813,25 @@ pub struct ClientAssociation { /// The maximum PDU length that the remote application entity accepts acceptor_max_pdu_length: u32, /// The TCP stream to the other DICOM node - socket: TcpStream, + socket: S, /// Buffer to assemble PDU before sending it on wire buffer: Vec, /// whether to receive PDUs in strict mode strict: bool, + /// Timeout for individual Send/Receive operations + timeout: Option, + /// Buffer to assemble PDU before parsing + read_buffer: BytesMut, } -impl ClientAssociation { +impl ClientAssociation +where + ClientAssociation: Release, +{ + /// Retrieve timeout for the association + pub fn timeout(&self) -> Option { + self.timeout + } /// Retrieve the list of negotiated presentation contexts. pub fn presentation_contexts(&self) -> &[PresentationContextResult] { &self.presentation_contexts @@ -759,7 +852,12 @@ impl ClientAssociation { pub fn requestor_max_pdu_length(&self) -> u32 { self.requestor_max_pdu_length } +} +impl ClientAssociation +where + ClientAssociation: Release, +{ /// Send a PDU message to the other intervenient. pub fn send(&mut self, msg: &Pdu) -> Result<()> { self.buffer.clear(); @@ -775,7 +873,34 @@ impl ClientAssociation { /// Read a PDU message from the other intervenient. pub fn receive(&mut self) -> Result { - read_pdu(&mut self.socket, self.requestor_max_pdu_length, self.strict).context(ReceiveSnafu) + use std::io::{BufRead, BufReader, Cursor}; + + let mut reader = BufReader::new(&mut self.socket); + + loop { + let mut buf = Cursor::new(&self.read_buffer[..]); + match read_pdu(&mut buf, self.acceptor_max_pdu_length, self.strict) + .context(ReceiveResponseSnafu)? + { + Some(pdu) => { + self.read_buffer.advance(buf.position() as usize); + return Ok(pdu); + } + None => { + // Reset position + buf.set_position(0) + } + } + // Use BufReader to get similar behavior to AsyncRead read_buf + let recv = reader + .fill_buf() + .context(ReadPduSnafu) + .context(ReceiveSnafu)? + .to_vec(); + reader.consume(recv.len()); + self.read_buffer.extend_from_slice(&recv); + ensure!(!recv.is_empty(), ConnectionClosedSnafu); + } } /// Gracefully terminate the association by exchanging release messages @@ -806,7 +931,7 @@ impl ClientAssociation { /// **Note:** reading and writing should be done with care /// to avoid inconsistencies in the association state. /// Do not call `send` and `receive` while not in a PDU boundary. - pub fn inner_stream(&mut self) -> &mut TcpStream { + pub fn inner_stream(&mut self) -> &mut std::net::TcpStream { &mut self.socket } @@ -815,7 +940,7 @@ impl ClientAssociation { /// /// Returns a writer which automatically /// splits the inner data into separate PDUs if necessary. - pub fn send_pdata(&mut self, presentation_context_id: u8) -> PDataWriter<&mut TcpStream> { + pub fn send_pdata(&mut self, presentation_context_id: u8) -> PDataWriter<&mut std::net::TcpStream> { PDataWriter::new( &mut self.socket, presentation_context_id, @@ -828,7 +953,7 @@ impl ClientAssociation { /// /// Returns a reader which automatically /// receives more data PDUs once the bytes collected are consumed. - pub fn receive_pdata(&mut self) -> PDataReader<&mut TcpStream> { + pub fn receive_pdata(&mut self) -> PDataReader<&mut std::net::TcpStream> { PDataReader::new(&mut self.socket, self.requestor_max_pdu_length) } @@ -840,8 +965,7 @@ impl ClientAssociation { fn release_impl(&mut self) -> Result<()> { let pdu = Pdu::ReleaseRQ; self.send(&pdu)?; - let pdu = read_pdu(&mut self.socket, self.requestor_max_pdu_length, self.strict) - .context(ReceiveSnafu)?; + let pdu = self.receive()?; match pdu { Pdu::ReleaseRP => {} @@ -858,9 +982,528 @@ impl ClientAssociation { } /// Automatically release the association and shut down the connection. -impl Drop for ClientAssociation { +impl Drop for ClientAssociation +where + T: CloseSocket, + ClientAssociation: Release, +{ fn drop(&mut self) { - let _ = self.release_impl(); - let _ = self.socket.shutdown(std::net::Shutdown::Both); + let _ = self.release(); + let _ = self.socket.close(); + } +} + +#[cfg(feature = "async")] +pub mod non_blocking { + use std::{convert::TryInto, io::Cursor, net::ToSocketAddrs}; + + use crate::{ + association::{ + client::{ + ConnectSnafu, ConnectionClosedSnafu, MissingAbstractSyntaxSnafu, + NoAcceptedPresentationContextsSnafu, ProtocolVersionMismatchSnafu, + ReceiveResponseSnafu, ReceiveSnafu, RejectedSnafu, SendRequestSnafu, + UnexpectedResponseSnafu, UnknownResponseSnafu, + }, + pdata::non_blocking::{AsyncPDataWriter, PDataReader}, + }, + pdu::{ + AbortRQSource, AssociationAC, AssociationRQ, PresentationContextProposed, + PresentationContextResultReason, ReadPduSnafu, UserVariableItem, DEFAULT_MAX_PDU, + MAXIMUM_PDU_SIZE, + }, + read_pdu, write_pdu, AeAddr, Pdu, IMPLEMENTATION_CLASS_UID, IMPLEMENTATION_VERSION_NAME, + }; + + use super::{ + ClientAssociation, ClientAssociationOptions, CloseSocket, Release, Result, SendSnafu, + SendTooLongPduSnafu, WireSendSnafu, + }; + use bytes::{Buf, BytesMut}; + use snafu::{ensure, ResultExt}; + use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; + use tracing::warn; + + pub async fn get_client_pdu_async( + reader: &mut R, + max_pdu_length: u32, + strict: bool, + ) -> Result { + // receive response + use tokio::io::AsyncReadExt; + let mut read_buffer = BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize); + + let msg = loop { + let mut buf = Cursor::new(&read_buffer[..]); + match read_pdu(&mut buf, max_pdu_length, strict).context(ReceiveResponseSnafu)? { + Some(pdu) => { + read_buffer.advance(buf.position() as usize); + break pdu; + } + None => { + // Reset position + buf.set_position(0) + } + } + let recv = reader + .read_buf(&mut read_buffer) + .await + .context(ReadPduSnafu) + .context(ReceiveSnafu)?; + ensure!(recv > 0, ConnectionClosedSnafu); + }; + Ok(msg) + } + + impl<'a> ClientAssociationOptions<'a> { + async fn establish_impl_async( + self, + ae_address: AeAddr, + ) -> Result> + where + T: ToSocketAddrs, + { + let timeout = self.timeout; + let task = async { + let ClientAssociationOptions { + calling_ae_title, + called_ae_title, + application_context_name, + presentation_contexts, + protocol_version, + max_pdu_length, + strict, + username, + password, + kerberos_service_ticket, + saml_assertion, + jwt, + timeout, + } = self; + + // fail if no presentation contexts were provided: they represent intent, + // should not be omitted by the user + ensure!( + !presentation_contexts.is_empty(), + MissingAbstractSyntaxSnafu + ); + + // choose called AE title + let called_ae_title: &str = match (&called_ae_title, ae_address.ae_title()) { + (Some(aec), Some(_)) => { + tracing::warn!( + "Option `called_ae_title` overrides the AE title to `{}`", + aec + ); + aec + } + (Some(aec), None) => aec, + (None, Some(aec)) => aec, + (None, None) => "ANY-SCP", + }; + + let presentation_contexts: Vec<_> = presentation_contexts + .into_iter() + .enumerate() + .map(|(i, presentation_context)| PresentationContextProposed { + id: (2 * i + 1) as u8, + abstract_syntax: presentation_context.0.to_string(), + transfer_syntaxes: presentation_context + .1 + .iter() + .map(|uid| uid.to_string()) + .collect(), + }) + .collect(); + + let mut user_variables = vec![ + UserVariableItem::MaxLength(max_pdu_length), + UserVariableItem::ImplementationClassUID(IMPLEMENTATION_CLASS_UID.to_string()), + UserVariableItem::ImplementationVersionName( + IMPLEMENTATION_VERSION_NAME.to_string(), + ), + ]; + + if let Some(user_identity) = Self::determine_user_identity( + username, + password, + kerberos_service_ticket, + saml_assertion, + jwt, + ) { + user_variables.push(UserVariableItem::UserIdentityItem(user_identity)); + } + + let msg = Pdu::AssociationRQ(AssociationRQ { + protocol_version, + calling_ae_title: calling_ae_title.to_string(), + called_ae_title: called_ae_title.to_string(), + application_context_name: application_context_name.to_string(), + presentation_contexts, + user_variables, + }); + let socket_addrs: Vec<_> = ae_address.to_socket_addrs().unwrap().collect(); + + let mut socket = tokio::net::TcpStream::connect(socket_addrs.as_slice()) + .await + .context(ConnectSnafu)?; + let mut buffer: Vec = Vec::with_capacity(max_pdu_length as usize); + // send request + + write_pdu(&mut buffer, &msg).context(SendRequestSnafu)?; + socket.write_all(&buffer).await.context(WireSendSnafu)?; + buffer.clear(); + + // receive response + let msg = get_client_pdu_async(&mut socket, MAXIMUM_PDU_SIZE, self.strict).await?; + + match msg { + Pdu::AssociationAC(AssociationAC { + protocol_version: protocol_version_scp, + application_context_name: _, + presentation_contexts: presentation_contexts_scp, + calling_ae_title: _, + called_ae_title: _, + user_variables, + }) => { + ensure!( + protocol_version == protocol_version_scp, + ProtocolVersionMismatchSnafu { + expected: protocol_version, + got: protocol_version_scp, + } + ); + + let acceptor_max_pdu_length = user_variables + .iter() + .find_map(|item| match item { + UserVariableItem::MaxLength(len) => Some(*len), + _ => None, + }) + .unwrap_or(DEFAULT_MAX_PDU); + + // treat 0 as the maximum size admitted by the standard + let acceptor_max_pdu_length = if acceptor_max_pdu_length == 0 { + MAXIMUM_PDU_SIZE + } else { + acceptor_max_pdu_length + }; + + let presentation_contexts: Vec<_> = presentation_contexts_scp + .into_iter() + .filter(|c| c.reason == PresentationContextResultReason::Acceptance) + .collect(); + if presentation_contexts.is_empty() { + // abort connection + let _ = write_pdu( + &mut buffer, + &Pdu::AbortRQ { + source: AbortRQSource::ServiceUser, + }, + ); + let _ = socket.write_all(&buffer).await; + buffer.clear(); + return NoAcceptedPresentationContextsSnafu.fail(); + } + Ok(ClientAssociation { + presentation_contexts, + requestor_max_pdu_length: max_pdu_length, + acceptor_max_pdu_length, + socket, + buffer, + strict, + timeout, + read_buffer: BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize), + }) + } + Pdu::AssociationRJ(association_rj) => RejectedSnafu { association_rj }.fail(), + pdu @ Pdu::AbortRQ { .. } + | pdu @ Pdu::ReleaseRQ { .. } + | pdu @ Pdu::AssociationRQ { .. } + | pdu @ Pdu::PData { .. } + | pdu @ Pdu::ReleaseRP { .. } => { + // abort connection + let _ = write_pdu( + &mut buffer, + &Pdu::AbortRQ { + source: AbortRQSource::ServiceUser, + }, + ); + let _ = socket.write_all(&buffer).await; + UnexpectedResponseSnafu { pdu }.fail() + } + pdu @ Pdu::Unknown { .. } => { + // abort connection + let _ = write_pdu( + &mut buffer, + &Pdu::AbortRQ { + source: AbortRQSource::ServiceUser, + }, + ); + let _ = socket.write_all(&buffer).await; + UnknownResponseSnafu { pdu }.fail() + } + } + }; + if let Some(timeout) = timeout { + tokio::time::timeout(timeout, task) + .await + .map_err(|err| std::io::Error::new(std::io::ErrorKind::TimedOut, err)) + .context(ConnectSnafu)? + } else { + warn!("No timeout set. It is highly recommended to set a timeout."); + task.await + } + } + + /// Initiate the TCP connection to the given address + /// and request a new DICOM association, + /// negotiating the presentation contexts in the process. + pub async fn establish_async( + self, + address: A, + ) -> Result> { + self.establish_impl_async(AeAddr::new_socket_addr(address)) + .await + } + + /// Initiate the TCP connection to the given address + /// and request a new DICOM association, + /// negotiating the presentation contexts in the process. + /// + /// This method allows you to specify the called AE title + /// alongside with the socket address. + /// See [AeAddr](`crate::AeAddr`) for more details. + /// However, the AE title in this parameter + /// is overridden by any `called_ae_title` option + /// previously received. + /// + /// # Example + /// + /// ```no_run + /// # use dicom_ul::association::client::ClientAssociationOptions; + /// #[tokio::main] + /// # async fn run() -> Result<(), Box> { + /// let association = ClientAssociationOptions::new() + /// .with_abstract_syntax("1.2.840.10008.1.1") + /// // called AE title in address + /// .establish_with_async("MY-STORAGE@10.0.0.100:104") + /// .await?; + /// # Ok(()) + /// # } + /// ``` + #[allow(unreachable_patterns)] + pub async fn establish_with_async( + self, + ae_address: &str, + ) -> Result> { + match ae_address.try_into() { + Ok(ae_address) => self.establish_impl_async(ae_address).await, + Err(_) => { + self.establish_impl_async(AeAddr::new_socket_addr(ae_address)).await + } + } + } + } + + impl ClientAssociation + where + ClientAssociation: Release, + { + /// Send a PDU message to the other intervenient. + pub async fn send(&mut self, msg: &Pdu) -> Result<()> { + let timeout = self.timeout; + let task = async { + self.buffer.clear(); + write_pdu(&mut self.buffer, msg).context(SendSnafu)?; + if self.buffer.len() > self.acceptor_max_pdu_length as usize { + return SendTooLongPduSnafu { + length: self.buffer.len(), + } + .fail(); + } + self.socket + .write_all(&self.buffer) + .await + .context(WireSendSnafu) + }; + if let Some(timeout) = timeout { + tokio::time::timeout(timeout, task) + .await + .map_err(|err| std::io::Error::new(std::io::ErrorKind::TimedOut, err)) + .context(WireSendSnafu)? + } else { + task.await + } + } + + /// Read a PDU message from the other intervenient. + pub async fn receive(&mut self) -> Result { + let timeout = self.timeout; + let task = async { + loop { + let mut buf = Cursor::new(&self.read_buffer[..]); + match read_pdu(&mut buf, self.requestor_max_pdu_length, self.strict) + .context(ReceiveResponseSnafu)? + { + Some(pdu) => { + self.read_buffer.advance(buf.position() as usize); + return Ok(pdu); + } + None => { + // Reset position + buf.set_position(0) + } + } + let recv = self + .socket + .read_buf(&mut self.read_buffer) + .await + .context(ReadPduSnafu) + .context(ReceiveSnafu)?; + ensure!(recv > 0, ConnectionClosedSnafu); + } + }; + if let Some(timeout) = timeout { + tokio::time::timeout(timeout, task) + .await + .map_err(|err| std::io::Error::new(std::io::ErrorKind::TimedOut, err)) + .context(ReadPduSnafu) + .context(ReceiveSnafu)? + } else { + task.await + } + } + + /// Gracefully terminate the association by exchanging release messages + /// and then shutting down the TCP connection. + pub async fn release(mut self) -> Result<()> { + let timeout = self.timeout; + let task = async { + let out = self.release_impl().await; + let _ = self.socket.shutdown().await; + out + }; + if let Some(timeout) = timeout { + tokio::time::timeout(timeout, task) + .await + .map_err(|err| std::io::Error::new(std::io::ErrorKind::TimedOut, err)) + .context(WireSendSnafu)? + } else { + task.await + } + } + + /// Send an abort message and shut down the TCP connection, + /// terminating the association. + pub async fn abort(mut self) -> Result<()> { + let timeout = self.timeout; + let task = async { + let pdu = Pdu::AbortRQ { + source: AbortRQSource::ServiceUser, + }; + let out = self.send(&pdu).await; + let _ = self.socket.shutdown().await; + out + }; + if let Some(timeout) = timeout { + tokio::time::timeout(timeout, task) + .await + .map_err(|err| std::io::Error::new(std::io::ErrorKind::TimedOut, err)) + .context(WireSendSnafu)? + } else { + task.await + } + } + + /// Prepare a P-Data writer for sending + /// one or more data items. + /// + /// Returns a writer which automatically + /// splits the inner data into separate PDUs if necessary. + pub async fn send_pdata( + &mut self, + presentation_context_id: u8, + ) -> AsyncPDataWriter<&mut tokio::net::TcpStream> { + AsyncPDataWriter::new( + &mut self.socket, + presentation_context_id, + self.acceptor_max_pdu_length, + ) + } + + /// Prepare a P-Data reader for receiving + /// one or more data item PDUs. + /// + /// Returns a reader which automatically + /// receives more data PDUs once the bytes collected are consumed. + #[cfg(feature = "async")] + pub fn receive_pdata(&mut self) -> PDataReader<&mut tokio::net::TcpStream> { + PDataReader::new(&mut self.socket, self.requestor_max_pdu_length) + } + + /// Release implementation function, + /// which tries to send a release request and receive a release response. + /// This is in a separate private function because + /// terminating a connection should still close the connection + /// if the exchange fails. + async fn release_impl(&mut self) -> Result<()> { + let pdu = Pdu::ReleaseRQ; + self.send(&pdu).await?; + use tokio::io::AsyncReadExt; + let mut read_buffer = BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize); + + let pdu = loop { + if let Ok(Some(pdu)) = read_pdu(&mut read_buffer, MAXIMUM_PDU_SIZE, self.strict) { + break pdu; + } + let recv = self + .socket + .read_buf(&mut read_buffer) + .await + .context(ReadPduSnafu) + .context(ReceiveSnafu)?; + ensure!(recv > 0, ConnectionClosedSnafu); + }; + match pdu { + Pdu::ReleaseRP => {} + pdu @ Pdu::AbortRQ { .. } + | pdu @ Pdu::AssociationAC { .. } + | pdu @ Pdu::AssociationRJ { .. } + | pdu @ Pdu::AssociationRQ { .. } + | pdu @ Pdu::PData { .. } + | pdu @ Pdu::ReleaseRQ { .. } => return UnexpectedResponseSnafu { pdu }.fail(), + pdu @ Pdu::Unknown { .. } => return UnknownResponseSnafu { pdu }.fail(), + } + Ok(()) + } + /// Obtain access to the inner TCP stream + /// connected to the association acceptor. + /// + /// This can be used to send the PDU in semantic fragments of the message, + /// thus using less memory. + /// + /// **Note:** reading and writing should be done with care + /// to avoid inconsistencies in the association state. + /// Do not call `send` and `receive` while not in a PDU boundary. + pub fn inner_stream(&mut self) -> &mut tokio::net::TcpStream { + &mut self.socket + } + } + + impl Release for ClientAssociation { + fn release(&mut self) -> super::Result<()> { + tokio::task::block_in_place(move || { + tokio::runtime::Handle::current().block_on(async move { self.release_impl().await }) + }) + } + } + /// Automatically release the association and shut down the connection. + impl CloseSocket for tokio::net::TcpStream { + fn close(&mut self) -> std::io::Result<()> { + tokio::task::block_in_place(move || { + tokio::runtime::Handle::current().block_on(async move { self.shutdown().await }) + }) + } } } diff --git a/ul/src/association/mod.rs b/ul/src/association/mod.rs index ba1cead91..514a537bc 100644 --- a/ul/src/association/mod.rs +++ b/ul/src/association/mod.rs @@ -17,10 +17,10 @@ //! [1]: std::net::TcpStream pub mod client; pub mod server; + mod uid; -pub(crate) mod pdata; +pub mod pdata; pub use client::{ClientAssociation, ClientAssociationOptions}; -pub use pdata::{PDataReader, PDataWriter}; pub use server::{ServerAssociation, ServerAssociationOptions}; diff --git a/ul/src/association/pdata.rs b/ul/src/association/pdata.rs index 94c753020..c86b46637 100644 --- a/ul/src/association/pdata.rs +++ b/ul/src/association/pdata.rs @@ -1,11 +1,38 @@ use std::{ collections::VecDeque, - io::{Read, Write}, + io::{BufRead, BufReader, Cursor, Read, Write}, }; +use bytes::{Buf, BytesMut}; use tracing::warn; -use crate::{pdu::reader::PDU_HEADER_SIZE, read_pdu, Pdu}; +use crate::{pdu::PDU_HEADER_SIZE, read_pdu, Pdu}; + +/// Set up the P-Data PDU header for sending. +fn setup_pdata_header(buffer: &mut [u8], is_last: bool) { + let data_len = (buffer.len() - 12) as u32; + + // full PDU length (minus PDU type and reserved byte) + let pdu_len = data_len + 4 + 2; + let pdu_len_bytes = pdu_len.to_be_bytes(); + + buffer[2] = pdu_len_bytes[0]; + buffer[3] = pdu_len_bytes[1]; + buffer[4] = pdu_len_bytes[2]; + buffer[5] = pdu_len_bytes[3]; + + // presentation data length (data + 2 properties below) + let pdv_data_len = data_len + 2; + let data_len_bytes = pdv_data_len.to_be_bytes(); + + buffer[6] = data_len_bytes[0]; + buffer[7] = data_len_bytes[1]; + buffer[8] = data_len_bytes[2]; + buffer[9] = data_len_bytes[3]; + + // message control header + buffer[11] = if is_last { 0x02 } else { 0x00 }; +} /// A P-Data value writer. /// @@ -22,7 +49,7 @@ use crate::{pdu::reader::PDU_HEADER_SIZE, read_pdu, Pdu}; /// /// ```no_run /// # use std::io::Write; -/// # use dicom_ul::association::{ClientAssociationOptions, PDataWriter}; +/// # use dicom_ul::association::{ClientAssociationOptions, pdata::PDataWriter}; /// # use dicom_ul::pdu::{Pdu, PDataValue, PDataValueType}; /// # fn command_data() -> Vec { unimplemented!() } /// # fn dicom_data() -> &'static [u8] { unimplemented!() } @@ -104,36 +131,10 @@ where Ok(()) } - /// Set up the P-Data PDU header for sending. - fn setup_pdata_header(&mut self, is_last: bool) { - let data_len = (self.buffer.len() - 12) as u32; - - // full PDU length (minus PDU type and reserved byte) - let pdu_len = data_len + 4 + 2; - let pdu_len_bytes = pdu_len.to_be_bytes(); - - self.buffer[2] = pdu_len_bytes[0]; - self.buffer[3] = pdu_len_bytes[1]; - self.buffer[4] = pdu_len_bytes[2]; - self.buffer[5] = pdu_len_bytes[3]; - - // presentation data length (data + 2 properties below) - let pdv_data_len = data_len + 2; - let data_len_bytes = pdv_data_len.to_be_bytes(); - - self.buffer[6] = data_len_bytes[0]; - self.buffer[7] = data_len_bytes[1]; - self.buffer[8] = data_len_bytes[2]; - self.buffer[9] = data_len_bytes[3]; - - // message control header - self.buffer[11] = if is_last { 0x02 } else { 0x00 }; - } - fn finish_impl(&mut self) -> std::io::Result<()> { if !self.buffer.is_empty() { // send last PDU - self.setup_pdata_header(true); + setup_pdata_header(&mut self.buffer, true); self.stream.write_all(&self.buffer[..])?; // clear buffer so that subsequent calls to `finish_impl` // do not send any more PDUs @@ -149,7 +150,7 @@ where fn dispatch_pdu(&mut self) -> std::io::Result<()> { debug_assert!(self.buffer.len() >= 12); // send PDU now - self.setup_pdata_header(false); + setup_pdata_header(&mut self.buffer, false); self.stream.write_all(&self.buffer)?; // back to just the header @@ -176,6 +177,7 @@ where self.buffer.extend(buf); debug_assert_eq!(self.buffer.len(), total_len); self.dispatch_pdu()?; + println!("{:?}", buf.len()); Ok(buf.len()) } } @@ -215,7 +217,7 @@ where /// /// ```no_run /// # use std::io::Read; -/// # use dicom_ul::association::{ClientAssociationOptions, PDataReader}; +/// # use dicom_ul::association::{ClientAssociationOptions, pdata::PDataReader}; /// # use dicom_ul::pdu::{Pdu, PDataValue, PDataValueType}; /// # fn command_data() -> Vec { unimplemented!() } /// # fn dicom_data() -> &'static [u8] { unimplemented!() } @@ -232,6 +234,7 @@ where /// }; /// # Ok(()) /// # } +/// ``` #[must_use] pub struct PDataReader { buffer: VecDeque, @@ -239,12 +242,10 @@ pub struct PDataReader { presentation_context_id: Option, max_data_length: u32, last_pdu: bool, + read_buffer: BytesMut, } -impl PDataReader -where - R: Read, -{ +impl PDataReader { pub fn new(stream: R, max_data_length: u32) -> Self { PDataReader { buffer: VecDeque::with_capacity(max_data_length as usize), @@ -252,6 +253,7 @@ where presentation_context_id: None, max_data_length, last_pdu: false, + read_buffer: BytesMut::with_capacity(max_data_length as usize), } } @@ -277,10 +279,33 @@ where return Ok(0); } - let pdu = read_pdu(&mut self.stream, self.max_data_length, false) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + let mut reader = BufReader::new(&mut self.stream); + let msg = loop { + let mut buf = Cursor::new(&self.read_buffer[..]); + match read_pdu(&mut buf, self.max_data_length, false) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))? + { + Some(pdu) => { + self.read_buffer.advance(buf.position() as usize); + break pdu; + } + None => { + // Reset position + buf.set_position(0) + } + } + let recv = reader.fill_buf()?.to_vec(); + reader.consume(recv.len()); + self.read_buffer.extend_from_slice(&recv); + if recv.is_empty() { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Connection closed by peer", + )); + } + }; - match pdu { + match msg { Pdu::PData { data } => { for pdata_value in data { self.presentation_context_id = match self.presentation_context_id { @@ -317,17 +342,335 @@ fn calculate_max_data_len_single(pdu_len: u32) -> u32 { pdu_len - 4 - 2 } +#[cfg(feature = "async")] +pub mod non_blocking { + use std::{ + future::Future, + io::Cursor, + pin::Pin, + task::{ready, Context, Poll}, + }; + + use bytes::{Buf, BufMut}; + use tokio::io::{ + AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadBuf, + }; + use tracing::warn; + + use crate::{pdu::PDU_HEADER_SIZE, read_pdu, Pdu}; + + pub use super::PDataReader; + use super::{calculate_max_data_len_single, setup_pdata_header}; + + /// A P-Data async value writer. + /// + /// This exposes an API to iteratively construct and send Data messages + /// to another node. + /// Using this as a [standard writer](std::io::Write) + /// will automatically split the incoming bytes + /// into separate PDUs if they do not fit in a single one. + /// + /// # Example + /// + /// Use an association's `send_pdata` method + /// to create a new P-Data value writer. + /// + /// ```no_run + /// # use std::io::Write; + /// use tokio::io::AsyncWriteExt; + /// # use dicom_ul::association::{ClientAssociationOptions, pdata::non_blocking::AsyncPDataWriter}; + /// # use dicom_ul::pdu::{Pdu, PDataValue, PDataValueType}; + /// # fn command_data() -> Vec { unimplemented!() } + /// # fn dicom_data() -> &'static [u8] { unimplemented!() } + /// #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// let mut association = ClientAssociationOptions::new() + /// .establish_async("129.168.0.5:104") + /// .await?; + /// + /// let presentation_context_id = association.presentation_contexts()[0].id; + /// + /// // send a command first + /// association.send(&Pdu::PData { + /// data: vec![PDataValue { + /// presentation_context_id, + /// value_type: PDataValueType::Command, + /// is_last: true, + /// data: command_data(), + /// }], + /// }).await; + /// + /// // then send a DICOM object which may be split into multiple PDUs + /// let mut pdata = association.send_pdata(presentation_context_id).await; + /// pdata.write_all(dicom_data()).await?; + /// pdata.finish().await?; + /// + /// let pdu_ac = association.receive().await?; + /// # Ok(()) + /// # } + /// ``` + #[must_use] + pub struct AsyncPDataWriter { + buffer: Vec, + stream: W, + max_data_len: u32, + msg: u32, + writing: bool, + } + + #[cfg(feature = "async")] + impl AsyncPDataWriter + where + W: AsyncWrite + Unpin, + { + /// Construct a new P-Data value writer. + /// + /// `max_pdu_length` is the maximum value of the PDU-length property. + pub(crate) fn new(stream: W, presentation_context_id: u8, max_pdu_length: u32) -> Self { + let max_data_length = calculate_max_data_len_single(max_pdu_length); + let mut buffer = Vec::with_capacity((max_data_length + PDU_HEADER_SIZE) as usize); + // initial buffer set up + buffer.extend([ + // PDU-type + reserved byte + 0x04, + 0x00, + // full PDU length, unknown at this point + 0xFF, + 0xFF, + 0xFF, + 0xFF, + // presentation data length, unknown at this point + 0xFF, + 0xFF, + 0xFF, + 0xFF, + // presentation context id + presentation_context_id, + // message control header, unknown at this point + 0xFF, + ]); + + AsyncPDataWriter { + stream, + max_data_len: max_data_length, + buffer, + msg: 0, + writing: false, + } + } + + /// Declare to have finished sending P-Data fragments, + /// thus emitting the last P-Data fragment PDU. + /// + /// This is also done automatically once the P-Data writer is dropped. + pub async fn finish(mut self) -> std::io::Result<()> { + self.finish_impl().await?; + Ok(()) + } + + async fn finish_impl(&mut self) -> std::io::Result<()> { + println!("Finish, {}", self.msg); + if !self.buffer.is_empty() { + // send last PDU + setup_pdata_header(&mut self.buffer, true); + if let Err(e) = self.stream.write_all(&self.buffer[..]).await { + println!("Error: {:?}", e); + } + // clear buffer so that subsequent calls to `finish_impl` + // do not send any more PDUs + self.buffer.clear(); + } + Ok(()) + } + } + + #[cfg(feature = "async")] + impl AsyncWrite for AsyncPDataWriter + where + W: AsyncWrite + Unpin, + { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + // If we're still writing (i.e. last write was pending), continue writing + if self.writing { + let this = self.get_mut(); + let buffer = &this.buffer; + let mut stream = Pin::new(&mut this.stream); + // Each call to `poll_write` may or may not write the whole of `self.buffer` + let write_all = stream.write_all(buffer); + tokio::pin!(write_all); + match write_all.poll(cx) { + Poll::Ready(Ok(_)) => { + this.writing = false; + println!("{:?}", this.msg); + this.msg += 1; + this.buffer.truncate(12); + return Poll::Ready(Ok(buf.len())); + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + } + let total_len = self.max_data_len as usize + 12; + if self.buffer.len() + buf.len() <= total_len { + // accumulate into buffer, do nothing + self.buffer.extend(buf); + Poll::Ready(Ok(buf.len())) + } else { + // fill in the rest of the buffer, send PDU, + // and leave out the rest for subsequent writes + let buf = &buf[..total_len - self.buffer.len()]; + self.buffer.extend(buf); + debug_assert_eq!(self.buffer.len(), total_len); + setup_pdata_header(&mut self.buffer, false); + let this = self.get_mut(); + let buffer = &this.buffer; + let mut stream = Pin::new(&mut this.stream); + // Each call to `poll_write` may or may not write the whole of `self.buffer` + let write_all = stream.write_all(buffer); + tokio::pin!(write_all); + match write_all.poll(cx) { + Poll::Ready(Ok(_)) => { + this.msg += 1; + this.buffer.truncate(12); + Poll::Ready(Ok(buf.len())) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => { + this.writing = true; + Poll::Pending + } + } + } + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.stream).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.stream).poll_shutdown(cx) + } + } + + /// With the P-Data writer dropped, + /// this `Drop` implementation + /// will construct and emit the last P-Data fragment PDU + /// if there is any data left to send. + impl Drop for AsyncPDataWriter + where + W: AsyncWrite + Unpin, + { + fn drop(&mut self) { + tokio::task::block_in_place(move || { + tokio::runtime::Handle::current().block_on(async move { + let _ = self.finish_impl().await; + }) + }) + } + } + + impl AsyncRead for PDataReader + where + R: AsyncRead + Unpin, + { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf, + ) -> Poll> { + if self.buffer.is_empty() { + if self.last_pdu { + return Poll::Ready(Ok(())); + } + let Self { + ref mut stream, + ref mut read_buffer, + ref max_data_length, + .. + } = &mut *self; + let mut reader = BufReader::new(stream); + let msg = loop { + let mut buf = Cursor::new(&read_buffer[..]); + match read_pdu(&mut buf, *max_data_length, false) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))? + { + Some(pdu) => { + read_buffer.advance(buf.position() as usize); + break pdu; + } + None => { + // Reset position + buf.set_position(0) + } + } + let recv = ready!(Pin::new(&mut reader).poll_fill_buf(cx))?.to_vec(); + reader.consume(recv.len()); + read_buffer.extend_from_slice(&recv); + if recv.is_empty() { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Connection closed by peer", + ))); + } + }; + match msg { + Pdu::PData { data } => { + for pdata_value in data { + self.presentation_context_id = match self.presentation_context_id { + None => Some(pdata_value.presentation_context_id), + Some(cid) if cid == pdata_value.presentation_context_id => { + Some(cid) + } + Some(cid) => { + warn!("Received PData value of presentation context {}, but should be {}", pdata_value.presentation_context_id, cid); + Some(cid) + } + }; + self.buffer.extend(pdata_value.data); + self.last_pdu = pdata_value.is_last; + } + } + _ => { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "Unexpected PDU type", + ))) + } + } + } + let len = std::cmp::min(self.buffer.len(), buf.remaining()); + for _ in 0..len { + buf.put_u8(self.buffer.pop_front().unwrap()); + } + Poll::Ready(Ok(())) + } + } +} + #[cfg(test)] mod tests { - use std::collections::VecDeque; use std::io::{Read, Write}; - use crate::pdu::reader::{read_pdu, MINIMUM_PDU_SIZE, PDU_HEADER_SIZE}; - use crate::pdu::Pdu; + use crate::association::pdata::PDataWriter; + use crate::pdu::{read_pdu, Pdu, MINIMUM_PDU_SIZE, PDU_HEADER_SIZE}; use crate::pdu::{PDataValue, PDataValueType}; use crate::write_pdu; - use super::{PDataReader, PDataWriter}; + use super::PDataReader; + + use tokio::io::AsyncWriteExt; + + use crate::association::pdata::non_blocking::AsyncPDataWriter; #[test] fn test_write_pdata_and_finish() { @@ -346,7 +689,43 @@ mod tests { // concatenate data chunks, compare with all data match same_pdu { - Pdu::PData { data: data_1 } => { + Some(Pdu::PData { data: data_1 }) => { + let data_1 = &data_1[0]; + + // check that this PDU is consistent + assert_eq!(data_1.value_type, PDataValueType::Data); + assert_eq!(data_1.presentation_context_id, presentation_context_id); + assert_eq!(data_1.data.len(), 64); + assert_eq!(data_1.data, (0..64).collect::>()); + } + pdu => panic!("Expected PData, got {:?}", pdu), + } + + assert_eq!(cursor.len(), 0); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_async_write_pdata_and_finish() { + let presentation_context_id = 12; + + let mut buf = Vec::new(); + { + let mut writer = + AsyncPDataWriter::new(&mut buf, presentation_context_id, MINIMUM_PDU_SIZE); + writer + .write_all(&(0..64).collect::>()) + .await + .unwrap(); + writer.finish().await.unwrap(); + } + + let mut cursor = &buf[..]; + let same_pdu = read_pdu(&mut cursor, MINIMUM_PDU_SIZE, true).unwrap(); + + // concatenate data chunks, compare with all data + + match same_pdu { + Some(Pdu::PData { data: data_1 }) => { let data_1 = &data_1[0]; // check that this PDU is consistent @@ -383,9 +762,88 @@ mod tests { match (pdu_1, pdu_2, pdu_3) { ( - Pdu::PData { data: data_1 }, - Pdu::PData { data: data_2 }, - Pdu::PData { data: data_3 }, + Some(Pdu::PData { data: data_1 }), + Some(Pdu::PData { data: data_2 }), + Some(Pdu::PData { data: data_3 }), + ) => { + assert_eq!(data_1.len(), 1); + let data_1 = &data_1[0]; + assert_eq!(data_2.len(), 1); + let data_2 = &data_2[0]; + assert_eq!(data_3.len(), 1); + let data_3 = &data_3[0]; + + // check that these two PDUs are consistent + assert_eq!(data_1.value_type, PDataValueType::Data); + assert_eq!(data_2.value_type, PDataValueType::Data); + assert_eq!(data_1.presentation_context_id, presentation_context_id); + assert_eq!(data_2.presentation_context_id, presentation_context_id); + + // check expected lengths + assert_eq!( + data_1.data.len(), + (MINIMUM_PDU_SIZE - PDU_HEADER_SIZE) as usize + ); + assert_eq!( + data_2.data.len(), + (MINIMUM_PDU_SIZE - PDU_HEADER_SIZE) as usize + ); + assert_eq!(data_3.data.len(), 820); + + // check data consistency + assert_eq!( + &data_1.data[..], + (0..MINIMUM_PDU_SIZE - PDU_HEADER_SIZE) + .map(|x| x as u8) + .collect::>() + ); + assert_eq!( + data_1.data.len() + data_2.data.len() + data_3.data.len(), + 9000 + ); + + let data_1 = &data_1.data; + let data_2 = &data_2.data; + let data_3 = &data_3.data; + + let mut all_data: Vec = Vec::new(); + all_data.extend(data_1); + all_data.extend(data_2); + all_data.extend(data_3); + assert_eq!(all_data, my_data); + } + x => panic!("Expected 3 PDatas, got {:?}", x), + } + + assert_eq!(cursor.len(), 0); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_async_write_large_pdata_and_finish() { + let presentation_context_id = 32; + + let my_data: Vec<_> = (0..9000).map(|x: u32| x as u8).collect(); + + let mut buf = Vec::new(); + { + let mut writer = + AsyncPDataWriter::new(&mut buf, presentation_context_id, MINIMUM_PDU_SIZE); + writer.write_all(&my_data).await.unwrap(); + writer.finish().await.unwrap(); + } + + let mut cursor = &buf[..]; + let pdu_1 = read_pdu(&mut cursor, MINIMUM_PDU_SIZE, true).unwrap(); + let pdu_2 = read_pdu(&mut cursor, MINIMUM_PDU_SIZE, true).unwrap(); + let pdu_3 = read_pdu(&mut cursor, MINIMUM_PDU_SIZE, true).unwrap(); + + // concatenate data chunks, compare with all data + + match (pdu_1, pdu_2, pdu_3) { + ( + Some(Pdu::PData { data: data_1 }), + Some(Pdu::PData { data: data_2 }), + Some(Pdu::PData { data: data_3 }), ) => { assert_eq!(data_1.len(), 1); let data_1 = &data_1[0]; @@ -441,6 +899,7 @@ mod tests { #[test] fn test_read_large_pdata_and_finish() { + use std::collections::VecDeque; let presentation_context_id = 32; let my_data: Vec<_> = (0..9000).map(|x: u32| x as u8).collect(); @@ -477,4 +936,47 @@ mod tests { } assert_eq!(buf, my_data); } + + #[tokio::test] + async fn test_async_read_large_pdata_and_finish() { + use tokio::io::AsyncReadExt; + + let presentation_context_id = 32; + + let my_data: Vec<_> = (0..9000).map(|x: u32| x as u8).collect(); + let pdata_1 = vec![PDataValue { + value_type: PDataValueType::Data, + data: my_data[0..3000].to_owned(), + presentation_context_id, + is_last: false, + }]; + let pdata_2 = vec![PDataValue { + value_type: PDataValueType::Data, + data: my_data[3000..6000].to_owned(), + presentation_context_id, + is_last: false, + }]; + let pdata_3 = vec![PDataValue { + value_type: PDataValueType::Data, + data: my_data[6000..].to_owned(), + presentation_context_id, + is_last: true, + }]; + + let mut pdu_stream = std::io::Cursor::new(Vec::new()); + + // write some PDUs + write_pdu(&mut pdu_stream, &Pdu::PData { data: pdata_1 }).unwrap(); + write_pdu(&mut pdu_stream, &Pdu::PData { data: pdata_2 }).unwrap(); + write_pdu(&mut pdu_stream, &Pdu::PData { data: pdata_3 }).unwrap(); + + let mut buf = Vec::new(); + let inner = pdu_stream.into_inner(); + let mut stream = tokio::io::BufReader::new(inner.as_slice()); + { + let mut reader = PDataReader::new(&mut stream, MINIMUM_PDU_SIZE); + reader.read_to_end(&mut buf).await.unwrap(); + } + assert_eq!(buf, my_data); + } } diff --git a/ul/src/association/server.rs b/ul/src/association/server.rs index 6568de8be..86ab65705 100644 --- a/ul/src/association/server.rs +++ b/ul/src/association/server.rs @@ -4,7 +4,11 @@ //! in which this application entity listens to incoming association requests. //! See [`ServerAssociationOptions`] //! for details and examples on how to create an association. -use std::{borrow::Cow, io::Write, net::TcpStream}; +use bytes::{Buf, BytesMut}; +use std::io::{BufRead, BufReader}; +use std::time::Duration; +use std::{borrow::Cow, io::Cursor}; +use std::{io::Write, net::TcpStream}; use dicom_encoding::transfer_syntax::TransferSyntaxIndex; use dicom_transfer_syntax_registry::TransferSyntaxRegistry; @@ -12,12 +16,10 @@ use snafu::{ensure, Backtrace, ResultExt, Snafu}; use crate::{ pdu::{ - reader::{read_pdu, DEFAULT_MAX_PDU, MAXIMUM_PDU_SIZE}, - writer::write_pdu, - AbortRQServiceProviderReason, AbortRQSource, AssociationAC, AssociationRJ, - AssociationRJResult, AssociationRJServiceUserReason, AssociationRJSource, AssociationRQ, - Pdu, PresentationContextResult, PresentationContextResultReason, UserIdentity, - UserVariableItem, + read_pdu, write_pdu, AbortRQServiceProviderReason, AbortRQSource, AssociationAC, + AssociationRJ, AssociationRJResult, AssociationRJServiceUserReason, AssociationRJSource, + AssociationRQ, Pdu, PresentationContextResult, PresentationContextResultReason, + ReadPduSnafu, UserIdentity, UserVariableItem, DEFAULT_MAX_PDU, MAXIMUM_PDU_SIZE, }, IMPLEMENTATION_CLASS_UID, IMPLEMENTATION_VERSION_NAME, }; @@ -36,21 +38,25 @@ pub enum Error { /// failed to receive association request ReceiveRequest { #[snafu(backtrace)] - source: crate::pdu::reader::Error, + source: crate::pdu::ReadError, }, /// failed to send association response SendResponse { #[snafu(backtrace)] - source: crate::pdu::writer::Error, + source: crate::pdu::WriteError, }, /// failed to prepare PDU Send { #[snafu(backtrace)] - source: crate::pdu::writer::Error, + source: crate::pdu::WriteError, + }, + /// Failed to read from the wire + WireRead { + source: std::io::Error, + backtrace: Backtrace, }, - /// failed to send PDU over the wire WireSend { source: std::io::Error, @@ -60,7 +66,7 @@ pub enum Error { /// failed to receive PDU Receive { #[snafu(backtrace)] - source: crate::pdu::reader::Error, + source: crate::pdu::ReadError, }, #[snafu(display("unexpected request from SCU `{:?}`", pdu))] @@ -89,6 +95,20 @@ pub enum Error { ))] #[non_exhaustive] SendTooLongPdu { length: usize, backtrace: Backtrace }, + #[snafu(display("Connection closed by peer"))] + ConnectionClosed, + + /// Could not set tcp read timeout + SetReadTimeout { + source: std::io::Error, + backtrace: Backtrace, + }, + + /// Could not set tcp write timeout + SetWriteTimeout { + source: std::io::Error, + backtrace: Backtrace, + }, } pub type Result = std::result::Result; @@ -221,6 +241,8 @@ pub struct ServerAssociationOptions<'a, A> { strict: bool, /// whether to accept unknown abstract syntaxes promiscuous: bool, + /// Timeout for individual send/receive operations + timeout: Option, } impl<'a> Default for ServerAssociationOptions<'a, AcceptAny> { @@ -232,9 +254,10 @@ impl<'a> Default for ServerAssociationOptions<'a, AcceptAny> { abstract_syntax_uids: Vec::new(), transfer_syntax_uids: Vec::new(), protocol_version: 1, - max_pdu_length: crate::pdu::reader::DEFAULT_MAX_PDU, + max_pdu_length: DEFAULT_MAX_PDU, strict: true, promiscuous: false, + timeout: None, } } } @@ -285,6 +308,7 @@ where strict, promiscuous, ae_access_control: _, + timeout, } = self; ServerAssociationOptions { @@ -297,6 +321,7 @@ where max_pdu_length, strict, promiscuous, + timeout, } } @@ -353,19 +378,56 @@ where self } + /// Set the timeout for the underlying TCP socket + pub fn timeout(self, timeout: Duration) -> Self { + Self { + timeout: Some(timeout), + ..self + } + } + /// Negotiate an association with the given TCP stream. - pub fn establish(&self, mut socket: TcpStream) -> Result { + pub fn establish(&self, mut socket: TcpStream) -> Result> { ensure!( !self.abstract_syntax_uids.is_empty() || self.promiscuous, MissingAbstractSyntaxSnafu ); let max_pdu_length = self.max_pdu_length; - - let pdu = - read_pdu(&mut socket, max_pdu_length, self.strict).context(ReceiveRequestSnafu)?; + socket + .set_read_timeout(self.timeout) + .context(SetReadTimeoutSnafu)?; + socket + .set_write_timeout(self.timeout) + .context(SetWriteTimeoutSnafu)?; + + let mut read_buffer = BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize); + let mut reader = BufReader::new(&mut socket); + + let msg = loop { + let mut buf = Cursor::new(&read_buffer[..]); + match read_pdu(&mut buf, MAXIMUM_PDU_SIZE, self.strict).context(ReceiveRequestSnafu)? { + Some(pdu) => { + read_buffer.advance(buf.position() as usize); + break pdu; + } + None => { + // Reset position + buf.set_position(0) + } + } + // Use BufReader to get similar behavior to AsyncRead read_buf + let recv = reader + .fill_buf() + .context(ReadPduSnafu) + .context(ReceiveSnafu)? + .to_vec(); + reader.consume(recv.len()); + read_buffer.extend_from_slice(&recv); + ensure!(!recv.is_empty(), ConnectionClosedSnafu); + }; let mut buffer: Vec = Vec::with_capacity(max_pdu_length as usize); - match pdu { + match msg { Pdu::AssociationRQ(AssociationRQ { protocol_version, calling_ae_title, @@ -511,6 +573,8 @@ where client_ae_title: calling_ae_title, buffer, strict: self.strict, + read_buffer: BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize), + timeout: self.timeout, }) } Pdu::ReleaseRQ => { @@ -566,7 +630,7 @@ where /// When the value falls out of scope, /// the program will shut down the underlying TCP connection. #[derive(Debug)] -pub struct ServerAssociation { +pub struct ServerAssociation { /// The accorded presentation contexts presentation_contexts: Vec, /// The maximum PDU length that the remote application entity accepts @@ -574,16 +638,20 @@ pub struct ServerAssociation { /// The maximum PDU length that this application entity is expecting to receive acceptor_max_pdu_length: u32, /// The TCP stream to the other DICOM node - socket: TcpStream, + socket: S, /// The application entity title of the other DICOM node client_ae_title: String, /// write buffer to send fully assembled PDUs on wire buffer: Vec, /// whether to receive PDUs in strict mode strict: bool, + /// Read buffer from the socket + read_buffer: bytes::BytesMut, + /// Timeout for individual send/receive operations + timeout: Option, } -impl ServerAssociation { +impl ServerAssociation { /// Obtain a view of the negotiated presentation contexts. pub fn presentation_contexts(&self) -> &[PresentationContextResult] { &self.presentation_contexts @@ -593,7 +661,9 @@ impl ServerAssociation { pub fn client_ae_title(&self) -> &str { &self.client_ae_title } +} +impl ServerAssociation { /// Send a PDU message to the other intervenient. pub fn send(&mut self, msg: &Pdu) -> Result<()> { self.buffer.clear(); @@ -609,7 +679,34 @@ impl ServerAssociation { /// Read a PDU message from the other intervenient. pub fn receive(&mut self) -> Result { - read_pdu(&mut self.socket, self.acceptor_max_pdu_length, self.strict).context(ReceiveSnafu) + use std::io::{BufRead, BufReader, Cursor}; + + let mut reader = BufReader::new(&mut self.socket); + + loop { + let mut buf = Cursor::new(&self.read_buffer[..]); + match read_pdu(&mut buf, self.acceptor_max_pdu_length, self.strict) + .context(ReceiveRequestSnafu)? + { + Some(pdu) => { + self.read_buffer.advance(buf.position() as usize); + return Ok(pdu); + } + None => { + // Reset position + buf.set_position(0) + } + } + // Use BufReader to get similar behavior to AsyncRead read_buf + let recv = reader + .fill_buf() + .context(ReadPduSnafu) + .context(ReceiveSnafu)? + .to_vec(); + reader.consume(recv.len()); + self.read_buffer.extend_from_slice(&recv); + ensure!(!recv.is_empty(), ConnectionClosedSnafu); + } } /// Send a provider initiated abort message @@ -719,6 +816,351 @@ where it.into_iter().find(|ts| is_supported(ts.as_ref())) } +#[cfg(feature = "async")] +pub mod non_blocking { + use std::{borrow::Cow, io::Cursor}; + + use bytes::{Buf, BytesMut}; + use snafu::{ensure, ResultExt}; + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::TcpStream, + }; + + use super::{ + AccessControl, Result, SendSnafu, SendTooLongPduSnafu, ServerAssociation, + ServerAssociationOptions, WireSendSnafu, + }; + use crate::{ + association::{ + server::{ + AbortedSnafu, ConnectionClosedSnafu, MissingAbstractSyntaxSnafu, + ReceiveRequestSnafu, ReceiveSnafu, RejectedSnafu, SendResponseSnafu, + UnexpectedRequestSnafu, UnknownRequestSnafu, WireReadSnafu, + }, + uid::trim_uid, + }, + pdu::{ + AbortRQServiceProviderReason, AbortRQSource, AssociationAC, AssociationRJ, + AssociationRJResult, AssociationRJServiceUserReason, AssociationRJSource, + AssociationRQ, PresentationContextResult, PresentationContextResultReason, + ReadPduSnafu, UserVariableItem, DEFAULT_MAX_PDU, MAXIMUM_PDU_SIZE, + }, + read_pdu, write_pdu, Pdu, IMPLEMENTATION_CLASS_UID, IMPLEMENTATION_VERSION_NAME, + }; + + impl<'a, A> ServerAssociationOptions<'a, A> + where + A: AccessControl, + { + /// Negotiate an association with the given TCP stream. + pub async fn establish_async( + &self, + mut socket: TcpStream, + ) -> Result> { + ensure!( + !self.abstract_syntax_uids.is_empty() || self.promiscuous, + MissingAbstractSyntaxSnafu + ); + let timeout = self.timeout; + let task = async { + let max_pdu_length = self.max_pdu_length; + let mut read_buffer = BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize); + + let pdu = loop { + let mut buf = Cursor::new(&read_buffer[..]); + match read_pdu(&mut buf, MAXIMUM_PDU_SIZE, self.strict) + .context(ReceiveRequestSnafu)? + { + Some(pdu) => { + read_buffer.advance(buf.position() as usize); + break pdu; + } + None => { + // Reset position + buf.set_position(0) + } + } + let recv = socket + .read_buf(&mut read_buffer) + .await + .context(ReadPduSnafu) + .context(ReceiveSnafu)?; + ensure!(recv > 0, ConnectionClosedSnafu); + }; + + let mut buffer: Vec = Vec::with_capacity(max_pdu_length as usize); + match pdu { + Pdu::AssociationRQ(AssociationRQ { + protocol_version, + calling_ae_title, + called_ae_title, + application_context_name, + presentation_contexts, + user_variables, + }) => { + if protocol_version != self.protocol_version { + write_pdu( + &mut buffer, + &Pdu::AssociationRJ(AssociationRJ { + result: AssociationRJResult::Permanent, + source: AssociationRJSource::ServiceUser( + AssociationRJServiceUserReason::NoReasonGiven, + ), + }), + ) + .context(SendResponseSnafu)?; + socket.write_all(&buffer).await.context(WireSendSnafu)?; + return RejectedSnafu.fail(); + } + + if application_context_name != self.application_context_name { + write_pdu( + &mut buffer, + &Pdu::AssociationRJ(AssociationRJ { + result: AssociationRJResult::Permanent, + source: AssociationRJSource::ServiceUser( + AssociationRJServiceUserReason::ApplicationContextNameNotSupported, + ), + }), + ) + .context(SendResponseSnafu)?; + socket.write_all(&buffer).await.context(WireSendSnafu)?; + return RejectedSnafu.fail(); + } + + match self.ae_access_control.check_access( + &self.ae_title, + &calling_ae_title, + &called_ae_title, + user_variables + .iter() + .find_map(|user_variable| match user_variable { + UserVariableItem::UserIdentityItem(user_identity) => { + Some(user_identity) + } + _ => None, + }), + ) { + Ok(()) => {} + Err(reason) => { + write_pdu( + &mut buffer, + &Pdu::AssociationRJ(AssociationRJ { + result: AssociationRJResult::Permanent, + source: AssociationRJSource::ServiceUser(reason), + }), + ) + .context(SendResponseSnafu)?; + socket.write_all(&buffer).await.context(WireSendSnafu)?; + return Err(RejectedSnafu.build()); + } + } + + // fetch requested maximum PDU length + let requestor_max_pdu_length = user_variables + .iter() + .find_map(|item| match item { + UserVariableItem::MaxLength(len) => Some(*len), + _ => None, + }) + .unwrap_or(DEFAULT_MAX_PDU); + + // treat 0 as the maximum size admitted by the standard + let requestor_max_pdu_length = if requestor_max_pdu_length == 0 { + MAXIMUM_PDU_SIZE + } else { + requestor_max_pdu_length + }; + + let presentation_contexts: Vec<_> = presentation_contexts + .into_iter() + .map(|pc| { + if !self + .abstract_syntax_uids + .contains(&trim_uid(Cow::from(pc.abstract_syntax))) + && !self.promiscuous + { + return PresentationContextResult { + id: pc.id, + reason: PresentationContextResultReason::AbstractSyntaxNotSupported, + transfer_syntax: "1.2.840.10008.1.2".to_string(), + }; + } + + let (transfer_syntax, reason) = self + .choose_ts(pc.transfer_syntaxes) + .map(|ts| (ts, PresentationContextResultReason::Acceptance)) + .unwrap_or_else(|| { + ( + "1.2.840.10008.1.2".to_string(), + PresentationContextResultReason::TransferSyntaxesNotSupported, + ) + }); + + PresentationContextResult { + id: pc.id, + reason, + transfer_syntax, + } + }) + .collect(); + + write_pdu( + &mut buffer, + &Pdu::AssociationAC(AssociationAC { + protocol_version: self.protocol_version, + application_context_name, + presentation_contexts: presentation_contexts.clone(), + calling_ae_title: calling_ae_title.clone(), + called_ae_title, + user_variables: vec![ + UserVariableItem::MaxLength(max_pdu_length), + UserVariableItem::ImplementationClassUID( + IMPLEMENTATION_CLASS_UID.to_string(), + ), + UserVariableItem::ImplementationVersionName( + IMPLEMENTATION_VERSION_NAME.to_string(), + ), + ], + }), + ) + .context(SendResponseSnafu)?; + socket.write_all(&buffer).await.context(WireSendSnafu)?; + + Ok(ServerAssociation { + presentation_contexts, + requestor_max_pdu_length, + acceptor_max_pdu_length: max_pdu_length, + socket, + client_ae_title: calling_ae_title, + buffer, + strict: self.strict, + read_buffer: BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize), + timeout, + }) + } + Pdu::ReleaseRQ => { + write_pdu(&mut buffer, &Pdu::ReleaseRP).context(SendResponseSnafu)?; + socket.write_all(&buffer).await.context(WireSendSnafu)?; + AbortedSnafu.fail() + } + pdu @ Pdu::AssociationAC { .. } + | pdu @ Pdu::AssociationRJ { .. } + | pdu @ Pdu::PData { .. } + | pdu @ Pdu::ReleaseRP + | pdu @ Pdu::AbortRQ { .. } => UnexpectedRequestSnafu { pdu }.fail(), + pdu @ Pdu::Unknown { .. } => UnknownRequestSnafu { pdu }.fail(), + } + }; + if let Some(timeout) = timeout { + tokio::time::timeout(timeout, task) + .await + .map_err(|err| std::io::Error::new(std::io::ErrorKind::TimedOut, err)) + .context(WireReadSnafu)? + } else { + task.await + } + } + } + + impl ServerAssociation { + /// Send a PDU message to the other intervenient. + pub async fn send(&mut self, msg: &Pdu) -> Result<()> { + let timeout = self.timeout; + let task = async { + self.buffer.clear(); + write_pdu(&mut self.buffer, msg).context(SendSnafu)?; + if self.buffer.len() > self.requestor_max_pdu_length as usize { + return SendTooLongPduSnafu { + length: self.buffer.len(), + } + .fail(); + } + self.socket + .write_all(&self.buffer) + .await + .context(WireSendSnafu) + }; + if let Some(timeout) = timeout { + tokio::time::timeout(timeout, task) + .await + .map_err(|err| std::io::Error::new(std::io::ErrorKind::TimedOut, err)) + .context(WireSendSnafu)? + } else { + task.await + } + } + + /// Read a PDU message from the other intervenient. + pub async fn receive(&mut self) -> Result { + let timeout = self.timeout; + let task = async { + loop { + let mut buf = Cursor::new(&self.read_buffer[..]); + match read_pdu(&mut buf, self.requestor_max_pdu_length, self.strict) + .context(ReceiveRequestSnafu)? + { + Some(pdu) => { + self.read_buffer.advance(buf.position() as usize); + return Ok(pdu); + } + None => { + // Reset position + buf.set_position(0) + } + } + let recv = self + .socket + .read_buf(&mut self.read_buffer) + .await + .context(ReadPduSnafu) + .context(ReceiveSnafu)?; + ensure!(recv > 0, ConnectionClosedSnafu); + } + }; + if let Some(timeout) = timeout { + tokio::time::timeout(timeout, task) + .await + .map_err(|err| std::io::Error::new(std::io::ErrorKind::TimedOut, err)) + .context(ReadPduSnafu) + .context(ReceiveSnafu)? + } else { + task.await + } + } + + /// Send a provider initiated abort message + /// and shut down the TCP connection, + /// terminating the association. + pub async fn abort(mut self) -> Result<()> { + let timeout = self.timeout; + let task = async { + let pdu = Pdu::AbortRQ { + source: AbortRQSource::ServiceProvider( + AbortRQServiceProviderReason::ReasonNotSpecified, + ), + }; + let out = self.send(&pdu).await; + let _ = self.socket.shutdown().await; + out + }; + if let Some(timeout) = timeout { + tokio::time::timeout(timeout, task) + .await + .map_err(|err| std::io::Error::new(std::io::ErrorKind::TimedOut, err)) + .context(WireSendSnafu)? + } else { + task.await + } + } + + pub fn inner_stream(&mut self) -> &mut TcpStream { + &mut self.socket + } + } +} + #[cfg(test)] mod tests { use super::choose_supported; diff --git a/ul/src/lib.rs b/ul/src/lib.rs index 992c43a54..8eafec261 100644 --- a/ul/src/lib.rs +++ b/ul/src/lib.rs @@ -38,6 +38,6 @@ pub const IMPLEMENTATION_VERSION_NAME: &str = "DICOM-rs 0.6"; pub use address::{AeAddr, FullAeAddr}; pub use association::client::{ClientAssociation, ClientAssociationOptions}; pub use association::server::{ServerAssociation, ServerAssociationOptions}; -pub use pdu::reader::read_pdu; -pub use pdu::writer::write_pdu; +pub use pdu::read_pdu; +pub use pdu::write_pdu; pub use pdu::Pdu; diff --git a/ul/src/pdu/mod.rs b/ul/src/pdu/mod.rs index 6930c892b..0deebe251 100644 --- a/ul/src/pdu/mod.rs +++ b/ul/src/pdu/mod.rs @@ -10,7 +10,139 @@ pub mod writer; use std::fmt::Display; pub use reader::read_pdu; -pub use writer::write_pdu; +use snafu::{Backtrace, Snafu}; +pub use writer::{write_pdu, WriteChunkError}; + +/// The default maximum PDU size +pub const DEFAULT_MAX_PDU: u32 = 16_384; + +/// The minimum PDU size, +/// as specified by the standard +pub const MINIMUM_PDU_SIZE: u32 = 4_096; + +/// The maximum PDU size, +/// as specified by the standard +pub const MAXIMUM_PDU_SIZE: u32 = 131_072; + +/// The length of the PDU header in bytes, +/// comprising the PDU type (1 byte), +/// reserved byte (1 byte), +/// and PDU length (4 bytes). +pub const PDU_HEADER_SIZE: u32 = 6; + +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum WriteError { + #[snafu(display("Could not write chunk of {} PDU structure", name))] + WriteChunk { + /// the name of the PDU structure + name: &'static str, + source: WriteChunkError, + }, + + #[snafu(display("Could not write field `{}`", field))] + WriteField { + field: &'static str, + backtrace: Backtrace, + source: std::io::Error, + }, + + #[snafu(display("Could not write {} reserved bytes", bytes))] + WriteReserved { + bytes: u32, + backtrace: Backtrace, + source: std::io::Error, + }, + + #[snafu(display("Could not write field `{}`", field))] + EncodeField { + field: &'static str, + #[snafu(backtrace)] + source: dicom_encoding::text::EncodeTextError, + }, +} + +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum ReadError { + #[snafu(display("Invalid max PDU length {}", max_pdu_length))] + InvalidMaxPdu { + max_pdu_length: u32, + backtrace: Backtrace, + }, + + #[snafu(display("No PDU available"))] + NoPduAvailable { backtrace: Backtrace }, + + #[snafu(display("Could not read PDU"), visibility(pub(crate)))] + ReadPdu { + source: std::io::Error, + backtrace: Backtrace, + }, + + #[snafu(display("Could not read PDU item"))] + ReadPduItem { + source: std::io::Error, + backtrace: Backtrace, + }, + + #[snafu(display("Could not read PDU field `{}`", field))] + ReadPduField { + field: &'static str, + source: std::io::Error, + backtrace: Backtrace, + }, + + #[snafu(display("Invalid item length {} (must be >=2)", length))] + InvalidItemLength { length: u32 }, + + #[snafu(display("Could not read {} reserved bytes", bytes))] + ReadReserved { + bytes: u32, + source: std::io::Error, + backtrace: Backtrace, + }, + + #[snafu(display( + "Incoming pdu was too large: length {}, maximum is {}", + pdu_length, + max_pdu_length + ))] + PduTooLarge { + pdu_length: u32, + max_pdu_length: u32, + backtrace: Backtrace, + }, + #[snafu(display("PDU contained an invalid value {:?}", var_item))] + InvalidPduVariable { + var_item: PduVariableItem, + backtrace: Backtrace, + }, + #[snafu(display("Multiple transfer syntaxes were accepted"))] + MultipleTransferSyntaxesAccepted { backtrace: Backtrace }, + #[snafu(display("Invalid reject source or reason"))] + InvalidRejectSourceOrReason { backtrace: Backtrace }, + #[snafu(display("Invalid abort service provider"))] + InvalidAbortSourceOrReason { backtrace: Backtrace }, + #[snafu(display("Invalid presentation context result reason"))] + InvalidPresentationContextResultReason { backtrace: Backtrace }, + #[snafu(display("invalid transfer syntax sub-item"))] + InvalidTransferSyntaxSubItem { backtrace: Backtrace }, + #[snafu(display("unknown presentation context sub-item"))] + UnknownPresentationContextSubItem { backtrace: Backtrace }, + #[snafu(display("Could not decode text field `{}`", field))] + DecodeText { + field: &'static str, + #[snafu(backtrace)] + source: dicom_encoding::text::DecodeTextError, + }, + #[snafu(display("Missing application context name"))] + MissingApplicationContextName { backtrace: Backtrace }, + #[snafu(display("Missing abstract syntax"))] + MissingAbstractSyntax { backtrace: Backtrace }, + #[snafu(display("Missing transfer syntax"))] + MissingTransferSyntax { backtrace: Backtrace }, +} /// Message component for a proposed presentation context. #[derive(Clone, Eq, PartialEq, PartialOrd, Hash, Debug)] diff --git a/ul/src/pdu/reader.rs b/ul/src/pdu/reader.rs index 3423f2710..fabc55370 100644 --- a/ul/src/pdu/reader.rs +++ b/ul/src/pdu/reader.rs @@ -1,116 +1,13 @@ /// PDU reader module use crate::pdu::*; -use byteordered::byteorder::{BigEndian, ReadBytesExt}; +use bytes::Buf; use dicom_encoding::text::{DefaultCharacterSetCodec, TextCodec}; -use snafu::{ensure, Backtrace, OptionExt, ResultExt, Snafu}; -use std::io::{Cursor, ErrorKind, Read, Seek, SeekFrom}; +use snafu::{ensure, OptionExt, ResultExt}; use tracing::warn; -/// The default maximum PDU size -pub const DEFAULT_MAX_PDU: u32 = 16_384; - -/// The minimum PDU size, -/// as specified by the standard -pub const MINIMUM_PDU_SIZE: u32 = 4_096; - -/// The maximum PDU size, -/// as specified by the standard -pub const MAXIMUM_PDU_SIZE: u32 = 131_072; - -/// The length of the PDU header in bytes, -/// comprising the PDU type (1 byte), -/// reserved byte (1 byte), -/// and PDU length (4 bytes). -pub const PDU_HEADER_SIZE: u32 = 6; - -#[derive(Debug, Snafu)] -#[non_exhaustive] -pub enum Error { - #[snafu(display("Invalid max PDU length {}", max_pdu_length))] - InvalidMaxPdu { - max_pdu_length: u32, - backtrace: Backtrace, - }, - - #[snafu(display("No PDU available"))] - NoPduAvailable { backtrace: Backtrace }, - - #[snafu(display("Could not read PDU"))] - ReadPdu { - source: std::io::Error, - backtrace: Backtrace, - }, - - #[snafu(display("Could not read PDU item"))] - ReadPduItem { - source: std::io::Error, - backtrace: Backtrace, - }, - - #[snafu(display("Could not read PDU field `{}`", field))] - ReadPduField { - field: &'static str, - source: std::io::Error, - backtrace: Backtrace, - }, - - #[snafu(display("Invalid item length {} (must be >=2)", length))] - InvalidItemLength { length: u32 }, - - #[snafu(display("Could not read {} reserved bytes", bytes))] - ReadReserved { - bytes: u32, - source: std::io::Error, - backtrace: Backtrace, - }, - - #[snafu(display( - "Incoming pdu was too large: length {}, maximum is {}", - pdu_length, - max_pdu_length - ))] - PduTooLarge { - pdu_length: u32, - max_pdu_length: u32, - backtrace: Backtrace, - }, - #[snafu(display("PDU contained an invalid value {:?}", var_item))] - InvalidPduVariable { - var_item: PduVariableItem, - backtrace: Backtrace, - }, - #[snafu(display("Multiple transfer syntaxes were accepted"))] - MultipleTransferSyntaxesAccepted { backtrace: Backtrace }, - #[snafu(display("Invalid reject source or reason"))] - InvalidRejectSourceOrReason { backtrace: Backtrace }, - #[snafu(display("Invalid abort service provider"))] - InvalidAbortSourceOrReason { backtrace: Backtrace }, - #[snafu(display("Invalid presentation context result reason"))] - InvalidPresentationContextResultReason { backtrace: Backtrace }, - #[snafu(display("invalid transfer syntax sub-item"))] - InvalidTransferSyntaxSubItem { backtrace: Backtrace }, - #[snafu(display("unknown presentation context sub-item"))] - UnknownPresentationContextSubItem { backtrace: Backtrace }, - #[snafu(display("Could not decode text field `{}`", field))] - DecodeText { - field: &'static str, - #[snafu(backtrace)] - source: dicom_encoding::text::DecodeTextError, - }, - #[snafu(display("Missing application context name"))] - MissingApplicationContextName { backtrace: Backtrace }, - #[snafu(display("Missing abstract syntax"))] - MissingAbstractSyntax { backtrace: Backtrace }, - #[snafu(display("Missing transfer syntax"))] - MissingTransferSyntax { backtrace: Backtrace }, -} - -pub type Result = std::result::Result; +pub type Result = std::result::Result; -pub fn read_pdu(reader: &mut R, max_pdu_length: u32, strict: bool) -> Result -where - R: Read, -{ +pub fn read_pdu(mut buf: impl Buf, max_pdu_length: u32, strict: bool) -> Result> { ensure!( (MINIMUM_PDU_SIZE..=MAXIMUM_PDU_SIZE).contains(&max_pdu_length), InvalidMaxPduSnafu { max_pdu_length } @@ -121,16 +18,15 @@ where // this method can block and wake up when stream is closed, so in this case, we // want to know if we had trouble even beginning to read a PDU. We still return // UnexpectedEof if we get after we have already began reading a PDU message. - let mut bytes = [0; 2]; - if let Err(e) = reader.read_exact(&mut bytes) { - ensure!(e.kind() != ErrorKind::UnexpectedEof, NoPduAvailableSnafu); - return Err(e).context(ReadPduFieldSnafu { field: "type" }); + if buf.remaining() < 2 { + return Ok(None); } - + let bytes = buf.copy_to_bytes(2); let pdu_type = bytes[0]; - let pdu_length = reader - .read_u32::() - .context(ReadPduFieldSnafu { field: "length" })?; + if buf.remaining() < 4 { + return Ok(None); + } + let pdu_length = buf.get_u32(); // Check max_pdu_length if strict { @@ -155,9 +51,10 @@ where max_pdu_length ); } - - let bytes = read_n(reader, pdu_length as usize).context(ReadPduSnafu)?; - let mut cursor = Cursor::new(bytes); + if buf.remaining() < pdu_length as usize { + return Ok(None); + } + let mut bytes = buf.copy_to_bytes(pdu_length as usize); let codec = DefaultCharacterSetCodec; match pdu_type { @@ -173,29 +70,26 @@ where // Version 1 and shall be identified with bit 0 set. A receiver of this PDU // implementing only this version of the DICOM UL protocol shall only test that bit // 0 is set. - let protocol_version = cursor.read_u16::().context(ReadPduFieldSnafu { - field: "Protocol-version", - })?; + if bytes.remaining() < 2 { + return Ok(None); + } + let protocol_version = bytes.get_u16(); // 9-10 - Reserved - This reserved field shall be sent with a value 0000H but not // tested to this value when received. - cursor - .read_u16::() - .context(ReadReservedSnafu { bytes: 2_u32 })?; + if bytes.remaining() < 2 { + return Ok(None); + } + bytes.get_u16(); // 11-26 - Called-AE-title - Destination DICOM Application Name. It shall be encoded // as 16 characters as defined by the ISO 646:1990-Basic G0 Set with leading and // trailing spaces (20H) being non-significant. The value made of 16 spaces (20H) // meaning "no Application Name specified" shall not be used. For a complete // description of the use of this field, see Section 7.1.1.4. - let mut ae_bytes = [0; 16]; - cursor - .read_exact(&mut ae_bytes) - .context(ReadPduFieldSnafu { - field: "Called-AE-title", - })?; + let ae_bytes = bytes.copy_to_bytes(16); let called_ae_title = codec - .decode(&ae_bytes) + .decode(ae_bytes.as_ref()) .context(DecodeTextSnafu { field: "Called-AE-title", })? @@ -207,14 +101,9 @@ where // trailing spaces (20H) being non-significant. The value made of 16 spaces (20H) // meaning "no Application Name specified" shall not be used. For a complete // description of the use of this field, see Section 7.1.1.3. - let mut ae_bytes = [0; 16]; - cursor - .read_exact(&mut ae_bytes) - .context(ReadPduFieldSnafu { - field: "Calling-AE-title", - })?; + let ae_bytes = bytes.copy_to_bytes(16); let calling_ae_title = codec - .decode(&ae_bytes) + .decode(ae_bytes.as_ref()) .context(DecodeTextSnafu { field: "Calling-AE-title", })? @@ -223,32 +112,34 @@ where // 43-74 - Reserved - This reserved field shall be sent with a value 00H for all // bytes but not tested to this value when received - cursor - .seek(SeekFrom::Current(32)) - .context(ReadReservedSnafu { bytes: 32_u32 })?; + bytes.advance(32); // 75-xxx - Variable items - This variable field shall contain the following items: // one Application Context Item, one or more Presentation Context Items and one User // Information Item. For a complete description of the use of these items see // Section 7.1.1.2, Section 7.1.1.13, and Section 7.1.1.6. - while cursor.position() < cursor.get_ref().len() as u64 { - match read_pdu_variable(&mut cursor, &codec)? { - PduVariableItem::ApplicationContext(val) => { + while bytes.has_remaining() { + match read_pdu_variable(&mut bytes, &codec)? { + Some(PduVariableItem::ApplicationContext(val)) => { application_context_name = Some(val); } - PduVariableItem::PresentationContextProposed(val) => { + Some(PduVariableItem::PresentationContextProposed(val)) => { presentation_contexts.push(val); } - PduVariableItem::UserVariables(val) => { + Some(PduVariableItem::UserVariables(val)) => { user_variables = val; } - var_item => { + Some(var_item) => { return InvalidPduVariableSnafu { var_item }.fail(); } + None => { + println!("PDU variable none"); + return Ok(None); + } } } - Ok(Pdu::AssociationRQ(AssociationRQ { + Ok(Some(Pdu::AssociationRQ(AssociationRQ { protocol_version, application_context_name: application_context_name .context(MissingApplicationContextNameSnafu)?, @@ -256,7 +147,7 @@ where calling_ae_title, presentation_contexts, user_variables, - })) + }))) } 0x02 => { // A-ASSOCIATE-AC PDU Structure @@ -270,25 +161,22 @@ where // Version 1 and shall be identified with bit 0 set. A receiver of this PDU // implementing only this version of the DICOM UL protocol shall only test that bit // 0 is set. - let protocol_version = cursor.read_u16::().context(ReadPduFieldSnafu { - field: "Protocol-version", - })?; + if bytes.remaining() < 2 { + return Ok(None); + } + let protocol_version = bytes.get_u16(); // 9-10 - Reserved - This reserved field shall be sent with a value 0000H but not // tested to this value when received. - cursor - .read_u16::() - .context(ReadReservedSnafu { bytes: 2_u32 })?; + if bytes.remaining() < 2 { + return Ok(None); + } + bytes.get_u16(); // 11-26 - Reserved - This reserved field shall be sent with a value identical to // the value received in the same field of the A-ASSOCIATE-RQ PDU, but its value // shall not be tested when received. - let mut ae_bytes = [0; 16]; - cursor - .read_exact(&mut ae_bytes) - .context(ReadPduFieldSnafu { - field: "Called-AE-title", - })?; + let ae_bytes = bytes.copy_to_bytes(16); let called_ae_title = codec .decode(&ae_bytes) .context(DecodeTextSnafu { @@ -300,12 +188,7 @@ where // 27-42 - Reserved - This reserved field shall be sent with a value identical to // the value received in the same field of the A-ASSOCIATE-RQ PDU, but its value // shall not be tested when received. - let mut ae_bytes = [0; 16]; - cursor - .read_exact(&mut ae_bytes) - .context(ReadPduFieldSnafu { - field: "Calling-AE-title", - })?; + let ae_bytes = bytes.copy_to_bytes(16); let calling_ae_title = codec .decode(&ae_bytes) .context(DecodeTextSnafu { @@ -317,32 +200,31 @@ where // 43-74 - Reserved - This reserved field shall be sent with a value identical to // the value received in the same field of the A-ASSOCIATE-RQ PDU, but its value // shall not be tested when received. - cursor - .seek(SeekFrom::Current(32)) - .context(ReadReservedSnafu { bytes: 32_u32 })?; + bytes.advance(32); // 75-xxx - Variable items - This variable field shall contain the following items: // one Application Context Item, one or more Presentation Context Item(s) and one // User Information Item. For a complete description of these items see Section // 7.1.1.2, Section 7.1.1.14, and Section 7.1.1.6. - while cursor.position() < cursor.get_ref().len() as u64 { - match read_pdu_variable(&mut cursor, &codec)? { - PduVariableItem::ApplicationContext(val) => { + while bytes.has_remaining() { + match read_pdu_variable(&mut bytes, &codec)? { + Some(PduVariableItem::ApplicationContext(val)) => { application_context_name = Some(val); } - PduVariableItem::PresentationContextResult(val) => { + Some(PduVariableItem::PresentationContextResult(val)) => { presentation_contexts.push(val); } - PduVariableItem::UserVariables(val) => { + Some(PduVariableItem::UserVariables(val)) => { user_variables = val; } - var_item => { + Some(var_item) => { return InvalidPduVariableSnafu { var_item }.fail(); } + None => return Ok(None), } } - Ok(Pdu::AssociationAC(AssociationAC { + Ok(Some(Pdu::AssociationAC(AssociationAC { protocol_version, application_context_name: application_context_name .context(MissingApplicationContextNameSnafu)?, @@ -350,27 +232,27 @@ where calling_ae_title, presentation_contexts, user_variables, - })) + }))) } 0x03 => { // A-ASSOCIATE-RJ PDU Structure // 7 - Reserved - This reserved field shall be sent with a value 00H but not tested to // this value when received. - cursor - .read_u8() - .context(ReadReservedSnafu { bytes: 1_u32 })?; + if bytes.remaining() < 1 { + return Ok(None); + } + bytes.get_u8(); // 8 - Result - This Result field shall contain an integer value encoded as an unsigned // binary number. One of the following values shall be used: // 1 - rejected-permanent // 2 - rejected-transient - let result = AssociationRJResult::from( - cursor - .read_u8() - .context(ReadPduFieldSnafu { field: "Result" })?, - ) - .context(InvalidRejectSourceOrReasonSnafu)?; + if bytes.remaining() < 1 { + return Ok(None); + } + let result = AssociationRJResult::from(bytes.get_u8()) + .context(InvalidRejectSourceOrReasonSnafu)?; // 9 - Source - This Source field shall contain an integer value encoded as an unsigned // binary number. One of the following values shall be used: 1 - DICOM UL @@ -393,17 +275,13 @@ where // 1 - temporary-congestio // 2 - local-limit-exceeded // 3-7 - reserved - let source = AssociationRJSource::from( - cursor - .read_u8() - .context(ReadPduFieldSnafu { field: "Source" })?, - cursor.read_u8().context(ReadPduFieldSnafu { - field: "Reason/Diag.", - })?, - ) - .context(InvalidRejectSourceOrReasonSnafu)?; - - Ok(Pdu::AssociationRJ(AssociationRJ { result, source })) + if bytes.remaining() < 2 { + return Ok(None); + } + let source = AssociationRJSource::from(bytes.get_u8(), bytes.get_u8()) + .context(InvalidRejectSourceOrReasonSnafu)?; + + Ok(Some(Pdu::AssociationRJ(AssociationRJ { result, source }))) } 0x04 => { // P-DATA-TF PDU Structure @@ -412,15 +290,16 @@ where // or more Presentation-data-value Items(s). For a complete description of the use of // this field see Section 9.3.5.1 let mut values = vec![]; - while cursor.position() < cursor.get_ref().len() as u64 { + while bytes.has_remaining() { // Presentation Data Value Item Structure // 1-4 - Item-length - This Item-length shall be the number of bytes from the first // byte of the following field to the last byte of the Presentation-data-value // field. It shall be encoded as an unsigned binary number. - let item_length = cursor.read_u32::().context(ReadPduFieldSnafu { - field: "Item-Length", - })?; + if bytes.remaining() < 4 { + return Ok(None); + } + let item_length = bytes.get_u32(); ensure!( item_length >= 2, @@ -432,9 +311,10 @@ where // 5 - Presentation-context-ID - Presentation-context-ID values shall be odd // integers between 1 and 255, encoded as an unsigned binary number. For a complete // description of the use of this field see Section 7.1.1.13. - let presentation_context_id = cursor.read_u8().context(ReadPduFieldSnafu { - field: "Presentation-context-ID", - })?; + if bytes.remaining() < 1 { + return Ok(None); + } + let presentation_context_id = bytes.get_u8(); // 6-xxx - Presentation-data-value - This Presentation-data-value field shall // contain DICOM message information (command and/or data set) with a message @@ -449,9 +329,10 @@ where // following fragment shall contain the last fragment of a Message Data Set or of a // Message Command. If bit 1 is set to 0, the following fragment // does not contain the last fragment of a Message Data Set or of a Message Command. - let header = cursor.read_u8().context(ReadPduFieldSnafu { - field: "Message Control Header", - })?; + if bytes.remaining() < 1 { + return Ok(None); + } + let header = bytes.get_u8(); let value_type = if header & 0x01 > 0 { PDataValueType::Command @@ -459,43 +340,39 @@ where PDataValueType::Data }; let is_last = (header & 0x02) > 0; - - let data = - read_n(&mut cursor, (item_length - 2) as usize).context(ReadPduFieldSnafu { - field: "Presentation-data-value", - })?; - + if bytes.remaining() < (item_length - 2) as usize { + return Ok(None); + } values.push(PDataValue { presentation_context_id, value_type, is_last, - data, - }) + data: bytes.copy_to_bytes((item_length - 2) as usize).to_vec(), + }); } - Ok(Pdu::PData { data: values }) + Ok(Some(Pdu::PData { data: values })) } 0x05 => { // A-RELEASE-RQ PDU Structure // 7-10 - Reserved - This reserved field shall be sent with a value 00000000H but not // tested to this value when received. - cursor - .seek(SeekFrom::Current(4)) - .context(ReadReservedSnafu { bytes: 4_u32 })?; + bytes.advance(4); - Ok(Pdu::ReleaseRQ) + Ok(Some(Pdu::ReleaseRQ)) } 0x06 => { // A-RELEASE-RP PDU Structure // 7-10 - Reserved - This reserved field shall be sent with a value 00000000H but not // tested to this value when received. - cursor - .seek(SeekFrom::Current(4)) - .context(ReadReservedSnafu { bytes: 4_u32 })?; + if bytes.remaining() < 4 { + return Ok(None); + } + bytes.advance(4); - Ok(Pdu::ReleaseRP) + Ok(Some(Pdu::ReleaseRP)) } 0x07 => { // A-ABORT PDU Structure @@ -504,10 +381,10 @@ where // this value when received. // 8 - Reserved - This reserved field shall be sent with a value 00H but not tested to // this value when received. - let mut buf = [0u8; 2]; - cursor - .read_exact(&mut buf) - .context(ReadReservedSnafu { bytes: 2_u32 })?; + if bytes.remaining() < 2 { + return Ok(None); + } + let _ = bytes.copy_to_bytes(2); // 9 - Source - This Source field shall contain an integer value encoded as an unsigned // binary number. One of the following values shall be used: @@ -523,57 +400,49 @@ where // - 4 - unrecognized-PDU parameter // - 5 - unexpected-PDU parameter // - 6 - invalid-PDU-parameter value - let source = AbortRQSource::from( - cursor - .read_u8() - .context(ReadPduFieldSnafu { field: "Source" })?, - cursor.read_u8().context(ReadPduFieldSnafu { - field: "Reason/Diag", - })?, - ) - .context(InvalidAbortSourceOrReasonSnafu)?; - - Ok(Pdu::AbortRQ { source }) + if bytes.remaining() < 2 { + return Ok(None); + } + let source = AbortRQSource::from(bytes.get_u8(), bytes.get_u8()) + .context(InvalidAbortSourceOrReasonSnafu)?; + + Ok(Some(Pdu::AbortRQ { source })) } _ => { - let data = read_n(&mut cursor, pdu_length as usize) - .context(ReadPduFieldSnafu { field: "Unknown" })?; - Ok(Pdu::Unknown { pdu_type, data }) + if bytes.remaining() < pdu_length as usize { + return Ok(None); + } + Ok(Some(Pdu::Unknown { + pdu_type, + data: bytes.copy_to_bytes(pdu_length as usize).to_vec(), + })) } } } -fn read_n(reader: &mut R, bytes_to_read: usize) -> std::io::Result> -where - R: Read, -{ - let mut result = Vec::new(); - reader.take(bytes_to_read as u64).read_to_end(&mut result)?; - Ok(result) -} - -fn read_pdu_variable(reader: &mut R, codec: &dyn TextCodec) -> Result -where - R: Read, -{ +fn read_pdu_variable(mut buf: impl Buf, codec: &dyn TextCodec) -> Result> { // 1 - Item-type - XXH - let item_type = reader - .read_u8() - .context(ReadPduFieldSnafu { field: "Item-type" })?; + if buf.remaining() < 1 { + return Ok(None); + } + let item_type = buf.get_u8(); // 2 - Reserved - reader - .read_u8() - .context(ReadReservedSnafu { bytes: 1_u32 })?; + if buf.remaining() < 1 { + return Ok(None); + } + buf.get_u8(); // 3-4 - Item-length - let item_length = reader.read_u16::().context(ReadPduFieldSnafu { - field: "Item-length", - })?; - - let bytes = read_n(reader, item_length as usize).context(ReadPduItemSnafu)?; - let mut cursor = Cursor::new(bytes); + if buf.remaining() < 2 { + return Ok(None); + } + let item_length = buf.get_u16(); + if buf.remaining() < item_length as usize { + return Ok(None); + } + let mut bytes = buf.copy_to_bytes(item_length as usize); match item_type { 0x10 => { // Application Context Item Structure @@ -583,12 +452,10 @@ where // 7.1.1.2. Application-context-names are structured as UIDs as defined in PS3.5 (see // Annex A for an overview of this concept). DICOM Application-context-names are // registered in PS3.7. - let val = codec - .decode(&cursor.into_inner()) - .context(DecodeTextSnafu { - field: "Application-context-name", - })?; - Ok(PduVariableItem::ApplicationContext(val)) + let val = codec.decode(bytes.as_ref()).context(DecodeTextSnafu { + field: "Application-context-name", + })?; + Ok(Some(PduVariableItem::ApplicationContext(val))) } 0x20 => { // Presentation Context Item Structure (proposed) @@ -599,48 +466,55 @@ where // 5 - Presentation-context-ID - Presentation-context-ID values shall be odd integers // between 1 and 255, encoded as an unsigned binary number. For a complete description // of the use of this field see Section 7.1.1.13. - let presentation_context_id = cursor.read_u8().context(ReadPduFieldSnafu { - field: "Presentation-context-ID", - })?; + if bytes.remaining() < 1 { + return Ok(None); + } + let presentation_context_id = bytes.get_u8(); // 6 - Reserved - This reserved field shall be sent with a value 00H but not tested to // this value when received. - cursor - .read_u8() - .context(ReadReservedSnafu { bytes: 1_u32 })?; + if bytes.remaining() < 1 { + return Ok(None); + } + bytes.get_u8(); // 7 - Reserved - This reserved field shall be sent with a value 00H but not tested to // this value when received. - cursor - .read_u8() - .context(ReadReservedSnafu { bytes: 1_u32 })?; + if bytes.remaining() < 1 { + return Ok(None); + } + bytes.get_u8(); // 8 - Reserved - This reserved field shall be sent with a value 00H but not tested to // this value when received. - cursor - .read_u8() - .context(ReadReservedSnafu { bytes: 1_u32 })?; + if bytes.remaining() < 1 { + return Ok(None); + } + bytes.get_u8(); // 9-xxx - Abstract/Transfer Syntax Sub-Items - This variable field shall contain the // following sub-items: one Abstract Syntax and one or more Transfer Syntax(es). For a // complete description of the use and encoding of these sub-items see Section 9.3.2.2.1 // and Section 9.3.2.2.2. - while cursor.position() < cursor.get_ref().len() as u64 { + while bytes.has_remaining() { // 1 - Item-type - XXH - let item_type = cursor - .read_u8() - .context(ReadPduFieldSnafu { field: "Item-type" })?; + if bytes.remaining() < 1 { + return Ok(None); + } + let item_type = bytes.get_u8(); // 2 - Reserved - This reserved field shall be sent with a value 00H but not tested // to this value when received. - cursor - .read_u8() - .context(ReadReservedSnafu { bytes: 1_u32 })?; + if bytes.remaining() < 1 { + return Ok(None); + } + bytes.get_u8(); // 3-4 - Item-length - let item_length = cursor.read_u16::().context(ReadPduFieldSnafu { - field: "Item-length", - })?; + if bytes.remaining() < 2 { + return Ok(None); + } + let item_length = bytes.get_u16(); match item_type { 0x30 => { @@ -653,13 +527,12 @@ where // Abstract-syntax-names are structured as UIDs as defined in PS3.5 (see // Annex B for an overview of this concept). DICOM Abstract-syntax-names are // registered in PS3.4. + if bytes.remaining() < item_length as usize { + return Ok(None); + } abstract_syntax = Some( codec - .decode(&read_n(&mut cursor, item_length as usize).context( - ReadPduFieldSnafu { - field: "Abstract-syntax-name", - }, - )?) + .decode(bytes.copy_to_bytes(item_length as usize).as_ref()) .context(DecodeTextSnafu { field: "Abstract-syntax-name", })? @@ -677,13 +550,12 @@ where // Transfer-syntax-names are structured as UIDs as defined in PS3.5 (see // Annex B for an overview of this concept). DICOM Transfer-syntax-names are // registered in PS3.5. + if bytes.remaining() < item_length as usize { + return Ok(None); + } transfer_syntaxes.push( codec - .decode(&read_n(&mut cursor, item_length as usize).context( - ReadPduFieldSnafu { - field: "Transfer-syntax-name", - }, - )?) + .decode(bytes.copy_to_bytes(item_length as usize).as_ref()) .context(DecodeTextSnafu { field: "Transfer-syntax-name", })? @@ -697,13 +569,13 @@ where } } - Ok(PduVariableItem::PresentationContextProposed( + Ok(Some(PduVariableItem::PresentationContextProposed( PresentationContextProposed { id: presentation_context_id, abstract_syntax: abstract_syntax.context(MissingAbstractSyntaxSnafu)?, transfer_syntaxes, }, - )) + ))) } 0x21 => { // Presentation Context Item Structure (result) @@ -713,15 +585,17 @@ where // 5 - Presentation-context-ID - Presentation-context-ID values shall be odd integers // between 1 and 255, encoded as an unsigned binary number. For a complete description // of the use of this field see Section 7.1.1.13. - let presentation_context_id = cursor.read_u8().context(ReadPduFieldSnafu { - field: "Presentation-context-ID", - })?; + if bytes.remaining() < 1 { + return Ok(None); + } + let presentation_context_id = bytes.get_u8(); // 6 - Reserved - This reserved field shall be sent with a value 00H but not tested to // this value when received. - cursor - .read_u8() - .context(ReadReservedSnafu { bytes: 1_u32 })?; + if bytes.remaining() < 1 { + return Ok(None); + } + bytes.get_u8(); // 7 - Result/Reason - This Result/Reason field shall contain an integer value encoded // as an unsigned binary number. One of the following values shall be used: @@ -730,40 +604,43 @@ where // 2 - no-reason (provider rejection) // 3 - abstract-syntax-not-supported (provider rejection) // 4 - transfer-syntaxes-not-supported (provider rejection) - let reason = PresentationContextResultReason::from(cursor.read_u8().context( - ReadPduFieldSnafu { - field: "Result/Reason", - }, - )?) - .context(InvalidPresentationContextResultReasonSnafu)?; + if bytes.remaining() < 1 { + return Ok(None); + } + let reason = PresentationContextResultReason::from(bytes.get_u8()) + .context(InvalidPresentationContextResultReasonSnafu)?; // 8 - Reserved - This reserved field shall be sent with a value 00H but not tested to // this value when received. - cursor - .read_u8() - .context(ReadReservedSnafu { bytes: 1_u32 })?; + if bytes.remaining() < 1 { + return Ok(None); + } + bytes.get_u8(); // 9-xxx - Transfer syntax sub-item - This variable field shall contain one Transfer // Syntax Sub-Item. When the Result/Reason field has a value other than acceptance (0), // this field shall not be significant and its value shall not be tested when received. // For a complete description of the use and encoding of this item see Section // 9.3.3.2.1. - while cursor.position() < cursor.get_ref().len() as u64 { + while bytes.has_remaining() { // 1 - Item-type - XXH - let item_type = cursor - .read_u8() - .context(ReadPduFieldSnafu { field: "Item-type" })?; + if bytes.remaining() < 1 { + return Ok(None); + } + let item_type = bytes.get_u8(); // 2 - Reserved - This reserved field shall be sent with a value 00H but not tested // to this value when received. - cursor - .read_u8() - .context(ReadReservedSnafu { bytes: 1_u32 })?; + if bytes.remaining() < 1 { + return Ok(None); + } + bytes.get_u8(); // 3-4 - Item-length - let item_length = cursor.read_u16::().context(ReadPduFieldSnafu { - field: "Item-length", - })?; + if bytes.remaining() < 2 { + return Ok(None); + } + let item_length = bytes.get_u16(); match item_type { 0x40 => { @@ -782,15 +659,12 @@ where return MultipleTransferSyntaxesAcceptedSnafu.fail(); } None => { + if bytes.remaining() < item_length as usize { + return Ok(None); + } transfer_syntax = Some( codec - .decode( - &read_n(&mut cursor, item_length as usize).context( - ReadPduFieldSnafu { - field: "Transfer-syntax-name", - }, - )?, - ) + .decode(bytes.copy_to_bytes(item_length as usize).as_ref()) .context(DecodeTextSnafu { field: "Transfer-syntax-name", })? @@ -806,13 +680,13 @@ where } } - Ok(PduVariableItem::PresentationContextResult( + Ok(Some(PduVariableItem::PresentationContextResult( PresentationContextResult { id: presentation_context_id, reason, transfer_syntax: transfer_syntax.context(MissingTransferSyntaxSnafu)?, }, - )) + ))) } 0x50 => { // User Information Item Structure @@ -822,21 +696,24 @@ where // 5-xxx - User-data - This variable field shall contain User-data sub-items as defined // by the DICOM Application Entity. The structure and content of these sub-items is // defined in Annex D. - while cursor.position() < cursor.get_ref().len() as u64 { + while bytes.has_remaining() { // 1 - Item-type - XXH - let item_type = cursor - .read_u8() - .context(ReadPduFieldSnafu { field: "Item-type" })?; + if bytes.remaining() < 1 { + return Ok(None); + } + let item_type = bytes.get_u8(); // 2 - Reserved - cursor - .read_u8() - .context(ReadReservedSnafu { bytes: 1_u32 })?; + if bytes.remaining() < 1 { + return Ok(None); + } + bytes.get_u8(); // 3-4 - Item-length - let item_length = cursor.read_u16::().context(ReadPduFieldSnafu { - field: "Item-length", - })?; + if bytes.remaining() < 2 { + return Ok(None); + } + let item_length = bytes.get_u16(); match item_type { 0x51 => { @@ -851,11 +728,10 @@ where // the PDU length values used in the PDU-length field of the P-DATA-TF PDUs // received by the association-requestor. Otherwise, it shall be a protocol // error. - user_variables.push(UserVariableItem::MaxLength( - cursor.read_u32::().context(ReadPduFieldSnafu { - field: "Maximum-length-received", - })?, - )); + if bytes.remaining() < 4 { + return Ok(None); + } + user_variables.push(UserVariableItem::MaxLength(bytes.get_u32())); } 0x52 => { // Implementation Class UID Sub-Item Structure @@ -864,12 +740,11 @@ where // the Implementation-class-uid of the Association-acceptor as defined in // Section D.3.3.2. The Implementation-class-uid field is structured as a // UID as defined in PS3.5. + if bytes.remaining() < item_length as usize { + return Ok(None); + } let implementation_class_uid = codec - .decode(&read_n(&mut cursor, item_length as usize).context( - ReadPduFieldSnafu { - field: "Implementation-class-uid", - }, - )?) + .decode(bytes.copy_to_bytes(item_length as usize).as_ref()) .context(DecodeTextSnafu { field: "Implementation-class-uid", })? @@ -886,12 +761,11 @@ where // the Implementation-version-name of the Association-acceptor as defined in // Section D.3.3.2. It shall be encoded as a string of 1 to 16 ISO 646:1990 // (basic G0 set) characters. + if bytes.remaining() < item_length as usize { + return Ok(None); + } let implementation_version_name = codec - .decode(&read_n(&mut cursor, item_length as usize).context( - ReadPduFieldSnafu { - field: "Implementation-version-name", - }, - )?) + .decode(bytes.copy_to_bytes(item_length as usize).as_ref()) .context(DecodeTextSnafu { field: "Implementation-version-name", })? @@ -907,83 +781,80 @@ where // 5-6 - SOP-class-uid-length - The SOP-class-uid-length shall be the number // of bytes from the first byte of the following field to the last byte of the // SOP-class-uid field. It shall be encoded as an unsigned binary number. - let sop_class_uid_length = - cursor.read_u16::().context(ReadPduFieldSnafu { - field: "SOP-class-uid-length", - })?; + if bytes.remaining() < 2 { + return Ok(None); + } + let sop_class_uid_length = bytes.get_u16(); // 7 - xxx - SOP-class-uid - The SOP Class or Meta SOP Class identifier // encoded as a UID as defined in Section 9 “Unique Identifiers (UIDs)” in PS3.5. + if bytes.remaining() < sop_class_uid_length as usize { + return Ok(None); + } let sop_class_uid = codec - .decode(&read_n(&mut cursor, sop_class_uid_length as usize).context( - ReadPduFieldSnafu { - field: "SOP-class-uid", - }, - )?) + .decode(bytes.copy_to_bytes(sop_class_uid_length as usize).as_ref()) .context(DecodeTextSnafu { field: "SOP-class-uid", })? .trim() .to_string(); - let data_length = - cursor.read_u16::().context(ReadPduFieldSnafu { - field: "Service-class-application-information-length", - })?; + if bytes.remaining() < 2 { + return Ok(None); + } + let data_length = bytes.get_u16(); // xxx-xxx - Service-class-application-information -This field shall contain // the application information specific to the Service Class specification // identified by the SOP-class-uid. The semantics and value of this field // is defined in the identified Service Class specification. - let data = read_n(&mut cursor, data_length as usize).context( - ReadPduFieldSnafu { - field: "Service-class-application-information", - }, - )?; - + if bytes.remaining() < data_length as usize { + return Ok(None); + } + let data = bytes.copy_to_bytes(data_length as usize); user_variables.push(UserVariableItem::SopClassExtendedNegotiationSubItem( sop_class_uid, - data, + data.to_vec(), )); } 0x58 => { // User Identity Negotiation // 5 - User Identity Type - let user_identity_type = cursor.read_u8().context(ReadPduFieldSnafu { - field: "User-Identity-type", - })?; + if bytes.remaining() < 1 { + return Ok(None); + } + let user_identity_type = bytes.get_u8(); // 6 - Positive-response-requested - let positive_response_requested = - cursor.read_u8().context(ReadPduFieldSnafu { - field: "User-Identity-positive-response-requested", - })?; + if bytes.remaining() < 1 { + return Ok(None); + } + let positive_response_requested = bytes.get_u8(); // 7-8 - Primary Field Length - let primary_field_length = - cursor.read_u16::().context(ReadPduFieldSnafu { - field: "User-Identity-primary-field-length", - })?; + if bytes.remaining() < 2 { + return Ok(None); + } + let primary_field_length = bytes.get_u16(); // 9-n - Primary Field - let primary_field = read_n(&mut cursor, primary_field_length as usize) - .context(ReadPduFieldSnafu { - field: "User-Identity-primary-field", - })?; - + if bytes.remaining() < primary_field_length as usize { + return Ok(None); + } + let primary_field = bytes.copy_to_bytes(primary_field_length as usize); // n+1-n+2 - Secondary Field Length // Only non-zero if user identity type is 2 (username and password) - let secondary_field_length = - cursor.read_u16::().context(ReadPduFieldSnafu { - field: "User-Identity-secondary-field-length", - })?; + if bytes.remaining() < 2 { + return Ok(None); + } + let secondary_field_length = bytes.get_u16(); // n+3-m - Secondary Field - let secondary_field = read_n(&mut cursor, secondary_field_length as usize) - .context(ReadPduFieldSnafu { - field: "User-Identity-secondary-field", - })?; + if bytes.remaining() < secondary_field_length as usize { + return Ok(None); + } + let secondary_field = bytes.copy_to_bytes(secondary_field_length as usize); match UserIdentityType::from(user_identity_type) { Some(user_identity_type) => { @@ -991,8 +862,8 @@ where UserIdentity::new( positive_response_requested == 1, user_identity_type, - primary_field, - secondary_field, + primary_field.to_vec(), + secondary_field.to_vec(), ), )); } @@ -1002,17 +873,19 @@ where } } _ => { + if bytes.remaining() < item_length as usize { + return Ok(None); + } user_variables.push(UserVariableItem::Unknown( item_type, - read_n(&mut cursor, item_length as usize) - .context(ReadPduFieldSnafu { field: "Unknown" })?, + bytes.copy_to_bytes(item_length as usize).to_vec(), )); } } } - Ok(PduVariableItem::UserVariables(user_variables)) + Ok(Some(PduVariableItem::UserVariables(user_variables))) } - _ => Ok(PduVariableItem::Unknown(item_type)), + _ => Ok(Some(PduVariableItem::Unknown(item_type))), } } diff --git a/ul/src/pdu/writer.rs b/ul/src/pdu/writer.rs index 5d388609a..44be42720 100644 --- a/ul/src/pdu/writer.rs +++ b/ul/src/pdu/writer.rs @@ -5,46 +5,14 @@ use dicom_encoding::text::TextCodec; use snafu::{Backtrace, ResultExt, Snafu}; use std::io::Write; -#[derive(Debug, Snafu)] -#[non_exhaustive] -pub enum Error { - #[snafu(display("Could not write chunk of {} PDU structure", name))] - WriteChunk { - /// the name of the PDU structure - name: &'static str, - source: WriteChunkError, - }, - - #[snafu(display("Could not write field `{}`", field))] - WriteField { - field: &'static str, - backtrace: Backtrace, - source: std::io::Error, - }, - - #[snafu(display("Could not write {} reserved bytes", bytes))] - WriteReserved { - bytes: u32, - backtrace: Backtrace, - source: std::io::Error, - }, - - #[snafu(display("Could not write field `{}`", field))] - EncodeField { - field: &'static str, - #[snafu(backtrace)] - source: dicom_encoding::text::EncodeTextError, - }, -} - -pub type Result = std::result::Result; +pub type Result = std::result::Result; #[derive(Debug, Snafu)] pub enum WriteChunkError { #[snafu(display("Failed to build chunk"))] BuildChunk { #[snafu(backtrace)] - source: Box, + source: Box, }, #[snafu(display("Failed to write chunk length"))] WriteLength { diff --git a/ul/tests/association_echo.rs b/ul/tests/association_echo.rs index 0ee5380f7..16804bdb9 100644 --- a/ul/tests/association_echo.rs +++ b/ul/tests/association_echo.rs @@ -2,11 +2,8 @@ use dicom_ul::{ association::client::ClientAssociationOptions, pdu::{Pdu, PresentationContextResult, PresentationContextResultReason}, }; -use std::net::TcpListener; -use std::{ - net::SocketAddr, - thread::{spawn, JoinHandle}, -}; + +use std::net::SocketAddr; use dicom_ul::association::server::ServerAssociationOptions; @@ -21,15 +18,15 @@ static JPEG_BASELINE: &str = "1.2.840.10008.1.2.4.50"; static VERIFICATION_SOP_CLASS: &str = "1.2.840.10008.1.1"; static DIGITAL_MG_STORAGE_SOP_CLASS: &str = "1.2.840.10008.5.1.4.1.1.1.2"; -fn spawn_scp() -> Result<(JoinHandle>, SocketAddr)> { - let listener = TcpListener::bind("localhost:0")?; +fn spawn_scp() -> Result<(std::thread::JoinHandle>, SocketAddr)> { + let listener = std::net::TcpListener::bind("localhost:0")?; let addr = listener.local_addr()?; let scp = ServerAssociationOptions::new() .accept_called_ae_title() .ae_title(SCP_AE_TITLE) .with_abstract_syntax(VERIFICATION_SOP_CLASS); - let h = spawn(move || -> Result<()> { + let h = std::thread::spawn(move || -> Result<()> { let (stream, _addr) = listener.accept()?; let mut association = scp.establish(stream)?; @@ -59,6 +56,44 @@ fn spawn_scp() -> Result<(JoinHandle>, SocketAddr)> { Ok((h, addr)) } +async fn spawn_scp_async() -> Result<(tokio::task::JoinHandle>, SocketAddr)> { + let listener = tokio::net::TcpListener::bind("localhost:0").await?; + let addr = listener.local_addr()?; + let scp = ServerAssociationOptions::new() + .accept_called_ae_title() + .ae_title(SCP_AE_TITLE) + .with_abstract_syntax(VERIFICATION_SOP_CLASS); + + let h = tokio::spawn(async move { + let (stream, _addr) = listener.accept().await?; + let mut association = scp.establish_async(stream).await?; + + assert_eq!( + association.presentation_contexts(), + &[ + PresentationContextResult { + id: 1, + reason: PresentationContextResultReason::Acceptance, + transfer_syntax: IMPLICIT_VR_LE.to_string(), + }, + PresentationContextResult { + id: 3, + reason: PresentationContextResultReason::AbstractSyntaxNotSupported, + transfer_syntax: IMPLICIT_VR_LE.to_string(), + } + ], + ); + + // handle one release request + let pdu = association.receive().await?; + assert_eq!(pdu, Pdu::ReleaseRQ); + association.send(&Pdu::ReleaseRP).await?; + + Ok(()) + }); + Ok((h, addr)) +} + /// Run an SCP and an SCU concurrently, negotiate an association and release it. #[test] fn scu_scp_association_test() { @@ -84,3 +119,30 @@ fn scu_scp_association_test() { .expect("SCP panicked") .expect("Error at the SCP"); } + +#[tokio::test(flavor = "multi_thread")] +async fn scu_scp_asociation_test() { + let (scp_handle, scp_addr) = spawn_scp_async().await.unwrap(); + + let association = ClientAssociationOptions::new() + .calling_ae_title(SCU_AE_TITLE) + .called_ae_title(SCP_AE_TITLE) + .with_presentation_context(VERIFICATION_SOP_CLASS, vec![IMPLICIT_VR_LE, EXPLICIT_VR_LE]) + .with_presentation_context( + DIGITAL_MG_STORAGE_SOP_CLASS, + vec![IMPLICIT_VR_LE, EXPLICIT_VR_LE, JPEG_BASELINE], + ) + .establish_async(scp_addr) + .await + .unwrap(); + + association + .release() + .await + .expect("did not have a peaceful release"); + + scp_handle + .await + .expect("SCP panicked") + .expect("Error at the SCP"); +} diff --git a/ul/tests/association_promiscuous.rs b/ul/tests/association_promiscuous.rs index 9419e6182..ba0d432d4 100644 --- a/ul/tests/association_promiscuous.rs +++ b/ul/tests/association_promiscuous.rs @@ -1,8 +1,7 @@ use dicom_ul::association::client::Error::NoAcceptedPresentationContexts; use dicom_ul::pdu::{PresentationContextResult, PresentationContextResultReason}; use dicom_ul::{ClientAssociationOptions, Pdu, ServerAssociationOptions}; -use std::net::{SocketAddr, TcpListener}; -use std::thread::{spawn, JoinHandle}; +use std::net::SocketAddr; type Result = std::result::Result>; @@ -16,8 +15,8 @@ const ULTRASOUND_IMAGE_STORAGE_RAW: &str = "1.2.840.10008.5.1.4.1.1.6.1\0"; fn spawn_scp( abstract_syntax_uids: &'static [&str], promiscuous: bool, -) -> Result<(JoinHandle>, SocketAddr)> { - let listener = TcpListener::bind("localhost:0")?; +) -> Result<(std::thread::JoinHandle>, SocketAddr)> { + let listener = std::net::TcpListener::bind("localhost:0")?; let addr = listener.local_addr()?; let mut options = ServerAssociationOptions::new() .accept_called_ae_title() @@ -28,7 +27,7 @@ fn spawn_scp( options = options.with_abstract_syntax(*abstract_syntax_uid); } - let handle = spawn(move || { + let handle = std::thread::spawn(move || { let (stream, _addr) = listener.accept()?; let mut association = options.establish(stream)?; assert_eq!( @@ -50,6 +49,43 @@ fn spawn_scp( Ok((handle, addr)) } +async fn spawn_scp_async( + abstract_syntax_uids: &'static [&str], + promiscuous: bool, +) -> Result<(tokio::task::JoinHandle>, SocketAddr)> { + let listener = tokio::net::TcpListener::bind("localhost:0").await?; + let addr = listener.local_addr()?; + let mut options = ServerAssociationOptions::new() + .accept_called_ae_title() + .ae_title(SCP_AE_TITLE) + .promiscuous(promiscuous); + + for abstract_syntax_uid in abstract_syntax_uids { + options = options.with_abstract_syntax(*abstract_syntax_uid); + } + + let handle = tokio::spawn(async move { + let (stream, _addr) = listener.accept().await?; + let mut association = options.establish_async(stream).await?; + assert_eq!( + association.presentation_contexts(), + &[PresentationContextResult { + id: 1, + reason: PresentationContextResultReason::Acceptance, + transfer_syntax: IMPLICIT_VR_LE.to_string(), + }] + ); + + let pdu = association.receive().await?; + assert_eq!(pdu, Pdu::ReleaseRQ); + association.send(&Pdu::ReleaseRP).await?; + + Ok(()) + }); + + Ok((handle, addr)) +} + #[test] fn scu_scp_association_promiscuous_enabled() { // SCP is set to promiscuous mode - all abstract syntaxes are accepted @@ -72,6 +108,30 @@ fn scu_scp_association_promiscuous_enabled() { .expect("Error at the SCP"); } +#[tokio::test(flavor = "multi_thread")] +async fn scu_scp_association_promiscuous_enabled_async() { + // SCP is set to promiscuous mode - all abstract syntaxes are accepted + let (scp_handle, scp_addr) = spawn_scp_async(&[], true).await.unwrap(); + + let association = ClientAssociationOptions::new() + .calling_ae_title(SCU_AE_TITLE) + .called_ae_title(SCP_AE_TITLE) + .with_presentation_context(MR_IMAGE_STORAGE_RAW, vec![IMPLICIT_VR_LE]) + .establish_async(scp_addr) + .await + .unwrap(); + + association + .release() + .await + .expect("did not have a peaceful release"); + + scp_handle + .await + .expect("SCP panicked") + .expect("Error at the SCP"); +} + #[test] fn scu_scp_association_promiscuous_disabled() { // SCP only accepts Ultrasound Image Storage @@ -89,3 +149,24 @@ fn scu_scp_association_promiscuous_disabled() { Err(NoAcceptedPresentationContexts { .. }) )); } + +#[tokio::test(flavor = "multi_thread")] +async fn scu_scp_association_promiscuous_disabled_async() { + // SCP only accepts Ultrasound Image Storage + let (_scu_handle, scp_addr) = spawn_scp_async(&[ULTRASOUND_IMAGE_STORAGE_RAW], false) + .await + .unwrap(); + + let association = ClientAssociationOptions::new() + .calling_ae_title(SCU_AE_TITLE) + .called_ae_title(SCP_AE_TITLE) + .with_presentation_context(MR_IMAGE_STORAGE_RAW, vec![IMPLICIT_VR_LE]) + .establish_async(scp_addr) + .await; + + // Assert that no presentation context was accepted + assert!(matches!( + association, + Err(NoAcceptedPresentationContexts { .. }) + )); +} diff --git a/ul/tests/association_store.rs b/ul/tests/association_store.rs index 6e4645424..e44aed79a 100644 --- a/ul/tests/association_store.rs +++ b/ul/tests/association_store.rs @@ -2,11 +2,7 @@ use dicom_ul::{ association::client::ClientAssociationOptions, pdu::{Pdu, PresentationContextResult, PresentationContextResultReason}, }; -use std::net::TcpListener; -use std::{ - net::SocketAddr, - thread::{spawn, JoinHandle}, -}; +use std::net::SocketAddr; use dicom_ul::association::server::ServerAssociationOptions; @@ -24,8 +20,8 @@ static MR_IMAGE_STORAGE: &str = "1.2.840.10008.5.1.4.1.1.4"; static DIGITAL_MG_STORAGE_SOP_CLASS_RAW: &str = "1.2.840.10008.5.1.4.1.1.1.2\0"; static DIGITAL_MG_STORAGE_SOP_CLASS: &str = "1.2.840.10008.5.1.4.1.1.1.2"; -fn spawn_scp() -> Result<(JoinHandle>, SocketAddr)> { - let listener = TcpListener::bind("localhost:0")?; +fn spawn_scp() -> Result<(std::thread::JoinHandle>, SocketAddr)> { + let listener = std::net::TcpListener::bind("localhost:0")?; let addr = listener.local_addr()?; let scp = ServerAssociationOptions::new() .accept_called_ae_title() @@ -33,7 +29,7 @@ fn spawn_scp() -> Result<(JoinHandle>, SocketAddr)> { .with_abstract_syntax(MR_IMAGE_STORAGE) .with_abstract_syntax(DIGITAL_MG_STORAGE_SOP_CLASS); - let h = spawn(move || -> Result<()> { + let h = std::thread::spawn(move || -> Result<()> { let (stream, _addr) = listener.accept()?; let mut association = scp.establish(stream)?; @@ -63,6 +59,45 @@ fn spawn_scp() -> Result<(JoinHandle>, SocketAddr)> { Ok((h, addr)) } +async fn spawn_scp_async() -> Result<(tokio::task::JoinHandle>, SocketAddr)> { + let listener = tokio::net::TcpListener::bind("localhost:0").await?; + let addr = listener.local_addr()?; + let scp = ServerAssociationOptions::new() + .accept_called_ae_title() + .ae_title(SCP_AE_TITLE) + .with_abstract_syntax(MR_IMAGE_STORAGE) + .with_abstract_syntax(DIGITAL_MG_STORAGE_SOP_CLASS); + + let h = tokio::task::spawn(async move { + let (stream, _addr) = listener.accept().await?; + let mut association = scp.establish_async(stream).await?; + + assert_eq!( + association.presentation_contexts(), + &[ + PresentationContextResult { + id: 1, + reason: PresentationContextResultReason::Acceptance, + transfer_syntax: IMPLICIT_VR_LE.to_string(), + }, + PresentationContextResult { + id: 3, + reason: PresentationContextResultReason::Acceptance, + transfer_syntax: JPEG_BASELINE.to_string(), + } + ], + ); + + // handle one release request + let pdu = association.receive().await?; + assert_eq!(pdu, Pdu::ReleaseRQ); + association.send(&Pdu::ReleaseRP).await?; + + Ok(()) + }); + Ok((h, addr)) +} + /// Run an SCP and an SCU concurrently, /// negotiate an association with distinct transfer syntaxes /// and release it. @@ -102,3 +137,42 @@ fn scu_scp_association_test() { .expect("SCP panicked") .expect("Error at the SCP"); } + +#[tokio::test(flavor = "multi_thread")] +async fn scu_scp_association_test_async() { + let (scp_handle, scp_addr) = spawn_scp_async().await.unwrap(); + + let association = ClientAssociationOptions::new() + .calling_ae_title(SCU_AE_TITLE) + .called_ae_title(SCP_AE_TITLE) + .with_presentation_context(MR_IMAGE_STORAGE_RAW, vec![IMPLICIT_VR_LE]) + // MG storage, JPEG baseline + .with_presentation_context(DIGITAL_MG_STORAGE_SOP_CLASS_RAW, vec![JPEG_BASELINE]) + .establish_async(scp_addr) + .await + .unwrap(); + + for pc in association.presentation_contexts() { + match pc.id { + 1 => { + // guaranteed to be MR image storage + assert_eq!(pc.transfer_syntax, IMPLICIT_VR_LE); + } + 3 => { + // guaranteed to be MG image storage + assert_eq!(pc.transfer_syntax, JPEG_BASELINE); + } + id => panic!("unexpected presentation context ID {}", id), + } + } + + association + .release() + .await + .expect("did not have a peaceful release"); + + scp_handle + .await + .expect("SCP panicked") + .expect("Error at the SCP"); +} diff --git a/ul/tests/association_store_uncompressed.rs b/ul/tests/association_store_uncompressed.rs index 0f2c6cc90..44c753e04 100644 --- a/ul/tests/association_store_uncompressed.rs +++ b/ul/tests/association_store_uncompressed.rs @@ -5,11 +5,7 @@ use dicom_ul::{ association::client::ClientAssociationOptions, pdu::{Pdu, PresentationContextResult, PresentationContextResultReason}, }; -use std::net::TcpListener; -use std::{ - net::SocketAddr, - thread::{spawn, JoinHandle}, -}; +use std::net::SocketAddr; use dicom_ul::association::server::ServerAssociationOptions; @@ -28,8 +24,8 @@ static MR_IMAGE_STORAGE: &str = "1.2.840.10008.5.1.4.1.1.4"; static DIGITAL_MG_STORAGE_SOP_CLASS_RAW: &str = "1.2.840.10008.5.1.4.1.1.1.2\0"; static DIGITAL_MG_STORAGE_SOP_CLASS: &str = "1.2.840.10008.5.1.4.1.1.1.2"; -fn spawn_scp() -> Result<(JoinHandle>, SocketAddr)> { - let listener = TcpListener::bind("localhost:0")?; +fn spawn_scp() -> Result<(std::thread::JoinHandle>, SocketAddr)> { + let listener = std::net::TcpListener::bind("localhost:0")?; let addr = listener.local_addr()?; let scp = ServerAssociationOptions::new() .accept_called_ae_title() @@ -39,7 +35,7 @@ fn spawn_scp() -> Result<(JoinHandle>, SocketAddr)> { .with_transfer_syntax(EXPLICIT_VR_LE) .with_transfer_syntax(IMPLICIT_VR_LE); - let h = spawn(move || -> Result<()> { + let h = std::thread::spawn(move || -> Result<()> { let (stream, _addr) = listener.accept()?; let mut association = scp.establish(stream)?; @@ -71,6 +67,49 @@ fn spawn_scp() -> Result<(JoinHandle>, SocketAddr)> { Ok((h, addr)) } +async fn spawn_scp_async() -> Result<(tokio::task::JoinHandle>, SocketAddr)> { + let listener = tokio::net::TcpListener::bind("localhost:0").await?; + let addr = listener.local_addr()?; + let scp = ServerAssociationOptions::new() + .accept_called_ae_title() + .ae_title(SCP_AE_TITLE) + .with_abstract_syntax(MR_IMAGE_STORAGE) + .with_abstract_syntax(DIGITAL_MG_STORAGE_SOP_CLASS) + .with_transfer_syntax(EXPLICIT_VR_LE) + .with_transfer_syntax(IMPLICIT_VR_LE); + + let h = tokio::task::spawn(async move { + let (stream, _addr) = listener.accept().await?; + let mut association = scp.establish_async(stream).await?; + + assert_eq!( + association.presentation_contexts(), + &[ + PresentationContextResult { + id: 1, + reason: PresentationContextResultReason::Acceptance, + transfer_syntax: IMPLICIT_VR_LE.to_string(), + }, + // should always pick Explicit VR LE + // because JPEG baseline was not explicitly enabled in SCP + PresentationContextResult { + id: 3, + reason: PresentationContextResultReason::Acceptance, + transfer_syntax: EXPLICIT_VR_LE.to_string(), + } + ], + ); + + // handle one release request + let pdu = association.receive().await?; + assert_eq!(pdu, Pdu::ReleaseRQ); + association.send(&Pdu::ReleaseRP).await?; + + Ok(()) + }); + Ok((h, addr)) +} + /// Run an SCP and an SCU concurrently, /// negotiate an association with distinct transfer syntaxes /// and release it. @@ -115,3 +154,47 @@ fn scu_scp_association_uncompressed() { .expect("SCP panicked") .expect("Error at the SCP"); } + +#[tokio::test(flavor = "multi_thread")] +async fn scu_scp_association_uncompressed_async() { + let (scp_handle, scp_addr) = spawn_scp_async().await.unwrap(); + + let association = ClientAssociationOptions::new() + .calling_ae_title(SCU_AE_TITLE) + .called_ae_title(SCP_AE_TITLE) + .with_presentation_context(MR_IMAGE_STORAGE_RAW, vec![IMPLICIT_VR_LE]) + // MG storage, JPEG baseline + .with_presentation_context( + DIGITAL_MG_STORAGE_SOP_CLASS_RAW, + vec![JPEG_BASELINE, EXPLICIT_VR_LE, IMPLICIT_VR_LE], + ) + .establish_async(scp_addr) + .await + .unwrap(); + + for pc in association.presentation_contexts() { + match pc.id { + // guaranteed to be MR image storage + 1 => { + // only one option provided + assert_eq!(pc.transfer_syntax, IMPLICIT_VR_LE); + } + // guaranteed to be MG image storage + 3 => { + // server picked this one because it did not accept JPEG baseline + assert_eq!(pc.transfer_syntax, EXPLICIT_VR_LE); + } + id => panic!("unexpected presentation context ID {}", id), + } + } + + association + .release() + .await + .expect("did not have a peaceful release"); + + scp_handle + .await + .expect("SCP panicked") + .expect("Error at the SCP"); +} diff --git a/ul/tests/pdu.rs b/ul/tests/pdu.rs index 26c7fa157..a28121cb6 100644 --- a/ul/tests/pdu.rs +++ b/ul/tests/pdu.rs @@ -1,8 +1,8 @@ -use dicom_ul::pdu::reader::{read_pdu, DEFAULT_MAX_PDU}; +use dicom_ul::pdu::reader::read_pdu; use dicom_ul::pdu::writer::write_pdu; use dicom_ul::pdu::{ AssociationRQ, PDataValue, PDataValueType, Pdu, PresentationContextProposed, UserIdentity, - UserIdentityType, UserVariableItem, + UserIdentityType, UserVariableItem, DEFAULT_MAX_PDU, }; use matches::matches; use std::io::Cursor; @@ -46,7 +46,7 @@ fn can_read_write_associate_rq() -> Result<(), Box> { let mut bytes = vec![0u8; 0]; write_pdu(&mut bytes, &association_rq.into())?; - let result = read_pdu(&mut Cursor::new(&bytes), DEFAULT_MAX_PDU, true)?; + let result = read_pdu(&mut Cursor::new(&bytes), DEFAULT_MAX_PDU, true)?.unwrap(); if let Pdu::AssociationRQ(AssociationRQ { protocol_version, @@ -134,7 +134,7 @@ fn can_read_write_primary_field_only_user_identity() -> Result<(), Box Result<(), Box> { let mut bytes = Vec::new(); write_pdu(&mut bytes, &pdata_rq)?; - let result = read_pdu(&mut Cursor::new(&bytes), DEFAULT_MAX_PDU, true)?; + let result = read_pdu(&mut Cursor::new(&bytes), DEFAULT_MAX_PDU, true)?.unwrap(); if let Pdu::PData { data } = result { assert_eq!(data.len(), 1);