diff --git a/src/sys/windows/from_raw_arc.rs b/src/sys/windows/from_raw_arc.rs index b6d38b240..2fbdfed21 100644 --- a/src/sys/windows/from_raw_arc.rs +++ b/src/sys/windows/from_raw_arc.rs @@ -22,6 +22,7 @@ use std::ops::Deref; use std::mem; use std::sync::atomic::{self, AtomicUsize, Ordering}; +use winapi::OVERLAPPED; pub struct FromRawArc { _inner: *mut Inner, @@ -86,6 +87,29 @@ impl Drop for FromRawArc { } } +unsafe impl Send for FromRawArcStore { } +unsafe impl Sync for FromRawArcStore { } + +pub struct FromRawArcStore { + _ptr: *mut OVERLAPPED, + deallocator: fn(*mut OVERLAPPED), +} + +impl FromRawArcStore { + pub fn new(ptr: *mut OVERLAPPED, deallocator: fn(*mut OVERLAPPED)) -> FromRawArcStore { + FromRawArcStore { + _ptr: unsafe { mem::transmute(ptr) }, + deallocator: deallocator, + } + } +} + +impl Drop for FromRawArcStore { + fn drop(&mut self) { + (self.deallocator)(self._ptr); + } +} + #[cfg(test)] mod tests { use super::FromRawArc; diff --git a/src/sys/windows/selector.rs b/src/sys/windows/selector.rs index 21e894b8e..cbdf4589a 100644 --- a/src/sys/windows/selector.rs +++ b/src/sys/windows/selector.rs @@ -6,6 +6,7 @@ use std::os::windows::prelude::*; use std::sync::{Arc, Mutex}; use std::sync::atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT}; use std::time::Duration; +use std::mem; use lazycell::AtomicLazyCell; @@ -16,6 +17,7 @@ use miow::iocp::{CompletionPort, CompletionStatus}; use event_imp::{Event, Evented, Ready}; use poll::{self, Poll}; use sys::windows::buffer_pool::BufferPool; +use sys::windows::from_raw_arc::FromRawArcStore; use {Token, PollOpt}; /// Each Selector has a globally unique(ish) ID associated with it. This ID @@ -48,6 +50,8 @@ struct SelectorInner { /// Primitives will take buffers from this pool to perform I/O operations, /// and once complete they'll be put back in. buffers: Mutex, + + incompletes: Mutex>, } impl Selector { @@ -61,6 +65,7 @@ impl Selector { id: id, port: cp, buffers: Mutex::new(BufferPool::new(256)), + incompletes: Mutex::new(Vec::new()), }), } }) @@ -91,6 +96,21 @@ impl Selector { ret = true; continue; } + // Deadlock will occur if you don't release it first before the callback. + { + let mut incompletes = self.inner.incompletes.lock().unwrap(); + let pos = incompletes.iter().position(|item| item.0 == (status.overlapped() as isize)); + match pos { + Some(pos) => { + let store = incompletes.remove(pos); + mem::forget(store); + }, + None => { + trace!("cannot find store, omiting..."); + continue; + } + } + } let callback = unsafe { (*(status.overlapped() as *mut Overlapped)).callback @@ -104,6 +124,16 @@ impl Selector { Ok(ret) } + pub fn store_overlapped_content(&self, ptr: *mut OVERLAPPED, deallocator: fn(*mut OVERLAPPED)) { + let mut incompletes = self.inner.incompletes.lock().unwrap(); + incompletes.push((ptr as isize, FromRawArcStore::new(ptr, deallocator))); + } + + pub fn clean_overlapped_content(&self, ptr: *mut OVERLAPPED) { + let mut incompletes = self.inner.incompletes.lock().unwrap(); + incompletes.retain(|item| ptr as isize != item.0); + } + /// Gets a reference to the underlying `CompletionPort` structure. pub fn port(&self) -> &CompletionPort { &self.inner.port @@ -404,6 +434,24 @@ impl ReadyBinding { .as_ref().unwrap() .deregister(poll) } + + pub fn store_overlapped_content(&self, ptr: *mut OVERLAPPED, deallocator: fn(*mut OVERLAPPED)) { + if let Some(i) = self.binding.selector.borrow() { + let selector = Selector { + inner: i.clone(), + }; + selector.store_overlapped_content(ptr, deallocator); + } + } + + pub fn clean_overlapped_content(&self, ptr: *mut OVERLAPPED) { + if let Some(i) = self.binding.selector.borrow() { + let selector = Selector { + inner: i.clone(), + }; + selector.clean_overlapped_content(ptr); + } + } } fn other(s: &str) -> io::Error { diff --git a/src/sys/windows/tcp.rs b/src/sys/windows/tcp.rs index 40843f6c5..b542e2a06 100644 --- a/src/sys/windows/tcp.rs +++ b/src/sys/windows/tcp.rs @@ -394,13 +394,14 @@ impl StreamImp { self.inner.inner.lock().unwrap() } - fn schedule_connect(&self, addr: &SocketAddr) -> io::Result<()> { + fn schedule_connect(&self, addr: &SocketAddr, me: &mut StreamInner) -> io::Result<()> { unsafe { trace!("scheduling a connect"); self.inner.socket.connect_overlapped(addr, &[], self.inner.read.as_mut_ptr())?; } // see docs above on StreamImp.inner for rationale on forget mem::forget(self.clone()); + me.iocp.store_overlapped_content(self.inner.read.as_mut_ptr(), read_deallocate); Ok(()) } @@ -454,6 +455,7 @@ impl StreamImp { // see docs above on StreamImp.inner for rationale on forget me.read = State::Pending(()); mem::forget(self.clone()); + me.iocp.store_overlapped_content(self.inner.read.as_mut_ptr(), read_deallocate); } Err(e) => { me.read = State::Error(e); @@ -499,6 +501,7 @@ impl StreamImp { // see docs above on StreamImp.inner for rationale on forget me.write = State::Pending((buf, pos)); mem::forget(self.clone()); + me.iocp.store_overlapped_content(self.inner.write.as_mut_ptr(), write_deallocate); break; } Err(e) => { @@ -578,6 +581,20 @@ fn write_done(status: &OVERLAPPED_ENTRY) { } } +fn read_deallocate(ptr: *mut OVERLAPPED) { + let me = StreamImp { + inner: unsafe { overlapped2arc!(ptr, StreamIo, read) }, + }; + drop(me); +} + +fn write_deallocate(ptr: *mut OVERLAPPED) { + let me = StreamImp { + inner: unsafe { overlapped2arc!(ptr, StreamIo, write) }, + }; + drop(me); +} + impl Evented for TcpStream { fn register(&self, poll: &Poll, token: Token, interest: Ready, opts: PollOpt) -> io::Result<()> { @@ -595,7 +612,7 @@ impl Evented for TcpStream { // successful connect will worry about generating writable/readable // events and scheduling a new read. if let Some(addr) = me.deferred_connect.take() { - return self.imp.schedule_connect(&addr).map(|_| ()) + return self.imp.schedule_connect(&addr, &mut me).map(|_| ()) } self.post_register(interest, &mut me); Ok(()) @@ -632,11 +649,15 @@ impl Drop for TcpStream { // Note that "Empty" here may mean that a connect is pending, so we // cancel even if that happens as well. unsafe { - match self.inner().read { + let inner = self.inner(); + match inner.read { State::Pending(_) | State::Empty => { trace!("cancelling active TCP read"); drop(super::cancel(&self.imp.inner.socket, &self.imp.inner.read)); + trace!("cleaning remaining overlapped contents"); + inner.iocp.clean_overlapped_content(self.imp.inner.read.as_mut_ptr()); + inner.iocp.clean_overlapped_content(self.imp.inner.write.as_mut_ptr()); } State::Ready(_) | State::Error(_) => {} } @@ -754,6 +775,7 @@ impl ListenerImp { // see docs above on StreamImp.inner for rationale on forget me.accept = State::Pending(socket); mem::forget(self.clone()); + me.iocp.store_overlapped_content(self.inner.accept.as_mut_ptr(), accept_dellocate); } Err(e) => { me.accept = State::Error(e); @@ -794,6 +816,13 @@ fn accept_done(status: &OVERLAPPED_ENTRY) { me2.add_readiness(&mut me, Ready::readable()); } +fn accept_dellocate(ptr: *mut OVERLAPPED) { + let me = ListenerImp { + inner: unsafe { overlapped2arc!(ptr, ListenerIo, accept) }, + }; + drop(me); +} + impl Evented for TcpListener { fn register(&self, poll: &Poll, token: Token, interest: Ready, opts: PollOpt) -> io::Result<()> { @@ -836,11 +865,14 @@ impl Drop for TcpListener { fn drop(&mut self) { // If we're still internally reading, we're no longer interested. unsafe { - match self.inner().accept { + let inner = self.inner(); + match inner.accept { State::Pending(_) => { trace!("cancelling active TCP accept"); drop(super::cancel(&self.imp.inner.socket, &self.imp.inner.accept)); + trace!("cleaning remaining overlapped contents"); + inner.iocp.clean_overlapped_content(self.imp.inner.accept.as_mut_ptr()); } State::Empty | State::Ready(_) | diff --git a/src/sys/windows/udp.rs b/src/sys/windows/udp.rs index 4d3fc040f..2e34a2458 100644 --- a/src/sys/windows/udp.rs +++ b/src/sys/windows/udp.rs @@ -113,6 +113,7 @@ impl UdpSocket { }?; me.write = State::Pending(owned_buf); mem::forget(self.imp.clone()); + me.iocp.store_overlapped_content(self.imp.inner.write.as_mut_ptr(), send_deallocate); Ok(amt) } @@ -147,6 +148,7 @@ impl UdpSocket { }?; me.write = State::Pending(owned_buf); mem::forget(self.imp.clone()); + me.iocp.store_overlapped_content(self.imp.inner.write.as_mut_ptr(), send_deallocate); Ok(amt) } @@ -313,6 +315,7 @@ impl Imp { Ok(_) => { me.read = State::Pending(buf); mem::forget(self.clone()); + me.iocp.store_overlapped_content(self.inner.read.as_mut_ptr(), recv_deallocate); } Err(e) => { me.read = State::Error(e); @@ -411,3 +414,17 @@ fn recv_done(status: &OVERLAPPED_ENTRY) { me.read = State::Ready(buf); me2.add_readiness(&mut me, Ready::readable()); } + +fn send_deallocate(ptr: *mut OVERLAPPED) { + let me = Imp { + inner: unsafe { overlapped2arc!(ptr, Io, write) }, + }; + drop(me); +} + +fn recv_deallocate(ptr: *mut OVERLAPPED) { + let me = Imp { + inner: unsafe { overlapped2arc!(ptr, Io, read) }, + }; + drop(me); +} \ No newline at end of file diff --git a/test/mod.rs b/test/mod.rs index 75cda53f4..d88038eae 100644 --- a/test/mod.rs +++ b/test/mod.rs @@ -33,6 +33,7 @@ mod test_tcp_level; mod test_udp_level; mod test_udp_socket; mod test_write_then_drop; +mod test_drop_cancels_interest_and_shuts_down; #[cfg(feature = "with-deprecated")] mod test_notify; diff --git a/test/test_drop_cancels_interest_and_shuts_down.rs b/test/test_drop_cancels_interest_and_shuts_down.rs new file mode 100644 index 000000000..dede12d8f --- /dev/null +++ b/test/test_drop_cancels_interest_and_shuts_down.rs @@ -0,0 +1,60 @@ +#[test] +fn drop_cancels_interest_and_shuts_down() { + use mio::net::TcpStream; + use mio::*; + use std::io; + use std::io::Read; + use std::net::TcpListener; + use std::thread; + use std::time::Duration; + + use env_logger; + let _ = env_logger::init(); + let l = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = l.local_addr().unwrap(); + + let t = thread::spawn(move || { + let mut s = l.incoming().next().unwrap().unwrap(); + s.set_read_timeout(Some(Duration::from_secs(5))) + .expect("set_read_timeout"); + let r = s.read(&mut [0; 16]); + match r { + Ok(_) => (), + Err(e) => { + if e.kind() != io::ErrorKind::UnexpectedEof { + panic!(e); + } + } + } + }); + + let poll = Poll::new().unwrap(); + let mut s = TcpStream::connect(&addr).unwrap(); + + poll.register( + &s, + Token(1), + Ready::readable() | Ready::writable(), + PollOpt::edge(), + ).unwrap(); + let mut events = Events::with_capacity(16); + 'outer: loop { + poll.poll(&mut events, None).unwrap(); + for event in &events { + if event.token() == Token(1) { + // connected + break 'outer; + } + } + } + + let mut b = [0; 1024]; + match s.read(&mut b) { + Ok(_) => panic!("unexpected ok"), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), + Err(e) => panic!("unexpected error: {:?}", e), + } + + drop(s); + t.join().unwrap(); +}