Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type-specific handle validity checking #1648

Merged
merged 6 commits into from
Mar 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ use windows_sys::{

fn main() {
unsafe {
let event = CreateEventW(std::ptr::null(), 1, 0, std::ptr::null());
let event = CreateEventW(std::ptr::null(), 1, 0, std::ptr::null())?;
SetEvent(event);
WaitForSingleObject(event, 0);
CloseHandle(event);
Expand Down
2 changes: 1 addition & 1 deletion crates/libs/bindgen/src/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ fn gen_async_kind(kind: AsyncKind, name: &TypeDef, self_name: &TypeDef, cfg: &Cf
impl<#(#constraints)*> #name {
pub fn get(&self) -> ::windows::core::Result<#return_type> {
if self.Status()? == #namespace AsyncStatus::Started {
let (_waiter, signaler) = ::windows::core::Waiter::new();
let (_waiter, signaler) = ::windows::core::Waiter::new()?;
self.SetCompleted(#namespace #handler::new(move |_sender, _args| {
// Safe because the waiter will only be dropped after being signaled.
unsafe { signaler.signal(); }
Expand Down
73 changes: 58 additions & 15 deletions crates/libs/bindgen/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,24 +203,49 @@ fn gen_win_function(def: &MethodDef, gen: &Gen) -> TokenStream {
}
}
SignatureKind::ReturnStruct | SignatureKind::PreserveSig => {
let args = gen_win32_args(&signature.params);
let params = gen_win32_params(&signature.params, gen);
if handle_last_error(def, &signature) {
let args = gen_win32_args(&signature.params);
let params = gen_win32_params(&signature.params, gen);
let return_type = gen_element_name(&signature.return_type.unwrap(), gen);

quote! {
#doc
#features
#[inline]
pub unsafe fn #name<#constraints>(#params) #abi_return_type {
#[cfg(windows)]
{
#link_attr
extern "system" {
fn #name(#(#abi_params),*) #abi_return_type;
quote! {
#doc
#features
#[inline]
pub unsafe fn #name<#constraints>(#params) -> ::windows::core::Result<#return_type> {
#[cfg(windows)]
{
#link_attr
extern "system" {
fn #name(#(#abi_params),*) -> #return_type;
}
let result__ = #name(#args);
(!result__.is_invalid()).then(||result__).ok_or_else(::windows::core::Error::from_win32)
}
::core::mem::transmute(#name(#args))
#[cfg(not(windows))]
unimplemented!("Unsupported target OS");
}
}
} else {
let args = gen_win32_args(&signature.params);
let params = gen_win32_params(&signature.params, gen);

quote! {
#doc
#features
#[inline]
pub unsafe fn #name<#constraints>(#params) #abi_return_type {
#[cfg(windows)]
{
#link_attr
extern "system" {
fn #name(#(#abi_params),*) #abi_return_type;
}
::core::mem::transmute(#name(#args))
}
#[cfg(not(windows))]
unimplemented!("Unsupported target OS");
}
#[cfg(not(windows))]
unimplemented!("Unsupported target OS");
}
}
}
Expand Down Expand Up @@ -257,3 +282,21 @@ fn does_not_return(def: &MethodDef) -> TokenStream {
quote! {}
}
}

fn handle_last_error(def: &MethodDef, signature: &Signature) -> bool {
if let Some(map) = def.impl_map() {
if map.flags().last_error() {
if let Some(Type::TypeDef(return_type)) = &signature.return_type {
if return_type.is_handle() {
if return_type.underlying_type().is_pointer() {
return true;
}
if !return_type.invalid_values().is_empty() {
return true;
}
}
}
}
}
false
}
51 changes: 31 additions & 20 deletions crates/libs/bindgen/src/handles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub fn gen(def: &TypeDef, gen: &Gen) -> TokenStream {

pub fn gen_sys_handle(def: &TypeDef, gen: &Gen) -> TokenStream {
let ident = gen_ident(def.name());
let signature = gen_signature(def, gen);
let signature = gen_default_type(&def.underlying_type(), gen);

quote! {
pub type #ident = #signature;
Expand All @@ -20,26 +20,42 @@ pub fn gen_sys_handle(def: &TypeDef, gen: &Gen) -> TokenStream {
pub fn gen_win_handle(def: &TypeDef, gen: &Gen) -> TokenStream {
let name = def.name();
let ident = gen_ident(def.name());
let signature = gen_signature(def, gen);
let underlying_type = def.underlying_type();
let signature = gen_default_type(&underlying_type, gen);
let check = if underlying_type.is_pointer() {
quote! {
impl #ident {
pub fn is_invalid(&self) -> bool {
self.0.is_null()
}
}
}
} else {
let invalid = def.invalid_values();

if !invalid.is_empty() {
let invalid = invalid.iter().map(|value| {
let value = Literal::i64_unsuffixed(*value);
quote! { self.0 == #value }
});
quote! {
impl #ident {
pub fn is_invalid(&self) -> bool {
#(#invalid)||*
}
kennykerr marked this conversation as resolved.
Show resolved Hide resolved
}
}
} else {
quote! {}
}
};

let mut tokens = quote! {
#[repr(transparent)]
// Unfortunately, Rust requires these to be derived to allow constant patterns.
#[derive(::core::cmp::PartialEq, ::core::cmp::Eq)]
pub struct #ident(pub #signature);
impl #ident {
pub fn is_invalid(&self) -> bool {
*self == unsafe { ::core::mem::zeroed() }
}

pub fn ok(self) -> ::windows::core::Result<Self> {
if !self.is_invalid() {
Ok(self)
} else {
Err(::windows::core::Error::from_win32())
}
}
}
#check
impl ::core::default::Default for #ident {
fn default() -> Self {
unsafe { ::core::mem::zeroed() }
Expand Down Expand Up @@ -77,8 +93,3 @@ pub fn gen_win_handle(def: &TypeDef, gen: &Gen) -> TokenStream {

tokens
}

fn gen_signature(def: &TypeDef, gen: &Gen) -> TokenStream {
let def = def.fields().next().map(|field| field.get_type(Some(def))).unwrap();
gen_default_type(&def, gen)
}
46 changes: 0 additions & 46 deletions crates/libs/bindgen/src/replacements/handle.rs

This file was deleted.

2 changes: 0 additions & 2 deletions crates/libs/bindgen/src/replacements/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
use super::*;
mod bool32;
mod bstr;
mod handle;
mod ntstatus;

pub fn gen(def: &TypeDef) -> Option<TokenStream> {
match def.type_name() {
TypeName::BOOL => Some(bool32::gen()),
TypeName::BSTR => Some(bstr::gen()),
TypeName::NTSTATUS => Some(ntstatus::gen()),
TypeName::HANDLE => Some(handle::gen()),
kennykerr marked this conversation as resolved.
Show resolved Hide resolved
_ => None,
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,12 @@ impl ParamFlags {
self.0 & 0x0010 != 0
}
}

#[derive(Default)]
pub struct PInvokeAttributes(pub u32);

impl PInvokeAttributes {
pub fn last_error(&self) -> bool {
self.0 & 0x0040 != 0
}
}
4 changes: 2 additions & 2 deletions crates/libs/metadata/src/reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ mod cfg;
mod codes;
mod constant_value;
mod file;
mod flags;
mod guid;
mod interface_kind;
mod param_flags;
mod row;
mod signature;
mod signature_kind;
Expand All @@ -27,9 +27,9 @@ pub use cfg::*;
pub use codes::*;
pub use constant_value::*;
pub use file::*;
pub use flags::*;
pub use guid::*;
pub use interface_kind::*;
pub use param_flags::*;
pub use r#type::*;
pub use row::*;
pub use signature::*;
Expand Down
4 changes: 4 additions & 0 deletions crates/libs/metadata/src/reader/tables/impl_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ use super::*;
pub struct ImplMap(pub Row);

impl ImplMap {
pub fn flags(&self) -> PInvokeAttributes {
PInvokeAttributes(self.0.u32(0))
}

pub fn scope(&self) -> ModuleRef {
ModuleRef(Row::new(self.0.u32(3) - 1, TableIndex::ModuleRef, self.0.file))
}
Expand Down
13 changes: 13 additions & 0 deletions crates/libs/metadata/src/reader/tables/type_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,19 @@ impl TypeDef {
})
}

pub fn invalid_values(&self) -> Vec<i64> {
self.attributes()
.filter_map(|attribute| {
if attribute.name() == "InvalidHandleValueAttribute" {
if let Some((_, ConstantValue::I64(value))) = attribute.args().get(0) {
return Some(*value);
}
}
None
})
.collect()
}

pub fn is_convertible_to(&self) -> Option<&Type> {
self.attributes().find_map(|attribute| {
if attribute.name() == "AlsoUsableForAttribute" {
Expand Down
1 change: 0 additions & 1 deletion crates/libs/metadata/src/reader/type_name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ impl TypeName {
pub const PWSTR: Self = Self::from_const("Windows.Win32.Foundation", "PWSTR");
pub const PSTR: Self = Self::from_const("Windows.Win32.Foundation", "PSTR");
pub const BSTR: Self = Self::from_const("Windows.Win32.Foundation", "BSTR");
pub const HANDLE: Self = Self::from_const("Windows.Win32.Foundation", "HANDLE");
pub const HRESULT: Self = Self::from_const("Windows.Win32.Foundation", "HRESULT");
pub const D2D_MATRIX_3X2_F: Self = Self::from_const("Windows.Win32.Graphics.Direct2D.Common", "D2D_MATRIX_3X2_F");
pub const IUnknown: Self = Self::from_const("Windows.Win32.System.Com", "IUnknown");
Expand Down
1 change: 1 addition & 0 deletions crates/libs/tokens/src/token_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ macro_rules! unsuffixed {
}

impl Literal {
unsuffixed!(i64 => i64_unsuffixed);
unsuffixed!(usize => usize_unsuffixed);
unsuffixed!(u32 => u32_unsuffixed);
unsuffixed!(u16 => u16_unsuffixed);
Expand Down
12 changes: 6 additions & 6 deletions crates/libs/windows/src/Windows/Devices/Sms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ impl ::windows::core::RuntimeName for DeleteSmsMessageOperation {
impl DeleteSmsMessageOperation {
pub fn get(&self) -> ::windows::core::Result<()> {
if self.Status()? == super::super::Foundation::AsyncStatus::Started {
let (_waiter, signaler) = ::windows::core::Waiter::new();
let (_waiter, signaler) = ::windows::core::Waiter::new()?;
self.SetCompleted(super::super::Foundation::AsyncActionCompletedHandler::new(move |_sender, _args| {
unsafe {
signaler.signal();
Expand Down Expand Up @@ -377,7 +377,7 @@ impl ::windows::core::RuntimeName for DeleteSmsMessagesOperation {
impl DeleteSmsMessagesOperation {
pub fn get(&self) -> ::windows::core::Result<()> {
if self.Status()? == super::super::Foundation::AsyncStatus::Started {
let (_waiter, signaler) = ::windows::core::Waiter::new();
let (_waiter, signaler) = ::windows::core::Waiter::new()?;
self.SetCompleted(super::super::Foundation::AsyncActionCompletedHandler::new(move |_sender, _args| {
unsafe {
signaler.signal();
Expand Down Expand Up @@ -615,7 +615,7 @@ impl ::windows::core::RuntimeName for GetSmsDeviceOperation {
impl GetSmsDeviceOperation {
pub fn get(&self) -> ::windows::core::Result<SmsDevice> {
if self.Status()? == super::super::Foundation::AsyncStatus::Started {
let (_waiter, signaler) = ::windows::core::Waiter::new();
let (_waiter, signaler) = ::windows::core::Waiter::new()?;
self.SetCompleted(super::super::Foundation::AsyncOperationCompletedHandler::new(move |_sender, _args| {
unsafe {
signaler.signal();
Expand Down Expand Up @@ -853,7 +853,7 @@ impl ::windows::core::RuntimeName for GetSmsMessageOperation {
impl GetSmsMessageOperation {
pub fn get(&self) -> ::windows::core::Result<ISmsMessage> {
if self.Status()? == super::super::Foundation::AsyncStatus::Started {
let (_waiter, signaler) = ::windows::core::Waiter::new();
let (_waiter, signaler) = ::windows::core::Waiter::new()?;
self.SetCompleted(super::super::Foundation::AsyncOperationCompletedHandler::new(move |_sender, _args| {
unsafe {
signaler.signal();
Expand Down Expand Up @@ -1106,7 +1106,7 @@ impl ::windows::core::RuntimeName for GetSmsMessagesOperation {
impl GetSmsMessagesOperation {
pub fn get(&self) -> ::windows::core::Result<super::super::Foundation::Collections::IVectorView<ISmsMessage>> {
if self.Status()? == super::super::Foundation::AsyncStatus::Started {
let (_waiter, signaler) = ::windows::core::Waiter::new();
let (_waiter, signaler) = ::windows::core::Waiter::new()?;
self.SetCompleted(super::super::Foundation::AsyncOperationWithProgressCompletedHandler::new(move |_sender, _args| {
unsafe {
signaler.signal();
Expand Down Expand Up @@ -2830,7 +2830,7 @@ impl ::windows::core::RuntimeName for SendSmsMessageOperation {
impl SendSmsMessageOperation {
pub fn get(&self) -> ::windows::core::Result<()> {
if self.Status()? == super::super::Foundation::AsyncStatus::Started {
let (_waiter, signaler) = ::windows::core::Waiter::new();
let (_waiter, signaler) = ::windows::core::Waiter::new()?;
self.SetCompleted(super::super::Foundation::AsyncActionCompletedHandler::new(move |_sender, _args| {
unsafe {
signaler.signal();
Expand Down
Loading