diff --git a/crates/bevy_reflect/bevy_reflect_derive/src/lib.rs b/crates/bevy_reflect/bevy_reflect_derive/src/lib.rs index 8128b292b8559..e70f2d2f8c0b2 100644 --- a/crates/bevy_reflect/bevy_reflect_derive/src/lib.rs +++ b/crates/bevy_reflect/bevy_reflect_derive/src/lib.rs @@ -719,6 +719,7 @@ fn impl_get_type_registration( impl #impl_generics #bevy_reflect_path::GetTypeRegistration for #type_name #ty_generics #where_clause { fn get_type_registration() -> #bevy_reflect_path::TypeRegistration { let mut registration = #bevy_reflect_path::TypeRegistration::of::<#type_name #ty_generics>(); + registration.insert::<#bevy_reflect_path::ReflectFromPtr>(#bevy_reflect_path::FromType::<#type_name #ty_generics>::from_type()); #(registration.insert::<#registration_data>(#bevy_reflect_path::FromType::<#type_name #ty_generics>::from_type());)* registration } diff --git a/crates/bevy_reflect/src/impls/std.rs b/crates/bevy_reflect/src/impls/std.rs index 0b1d4097e1e11..5d270f88d802a 100644 --- a/crates/bevy_reflect/src/impls/std.rs +++ b/crates/bevy_reflect/src/impls/std.rs @@ -1,4 +1,4 @@ -use crate as bevy_reflect; +use crate::{self as bevy_reflect, ReflectFromPtr}; use crate::{ map_partial_eq, serde::Serializable, DynamicMap, FromReflect, FromType, GetTypeRegistration, List, ListIter, Map, MapIter, Reflect, ReflectDeserialize, ReflectMut, ReflectRef, @@ -148,6 +148,7 @@ impl Deserialize<'de>> GetTypeRegistration for Vec fn get_type_registration() -> TypeRegistration { let mut registration = TypeRegistration::of::>(); registration.insert::(FromType::>::from_type()); + registration.insert::(FromType::>::from_type()); registration } } @@ -270,6 +271,7 @@ where fn get_type_registration() -> TypeRegistration { let mut registration = TypeRegistration::of::(); registration.insert::(FromType::::from_type()); + registration.insert::(FromType::::from_type()); registration } } @@ -355,6 +357,7 @@ impl GetTypeRegistration for Cow<'static, str> { fn get_type_registration() -> TypeRegistration { let mut registration = TypeRegistration::of::>(); registration.insert::(FromType::>::from_type()); + registration.insert::(FromType::>::from_type()); registration } } diff --git a/crates/bevy_reflect/src/type_registry.rs b/crates/bevy_reflect/src/type_registry.rs index a6aa64f487305..fbbcd949287a3 100644 --- a/crates/bevy_reflect/src/type_registry.rs +++ b/crates/bevy_reflect/src/type_registry.rs @@ -3,7 +3,7 @@ use bevy_utils::{HashMap, HashSet}; use downcast_rs::{impl_downcast, Downcast}; use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use serde::Deserialize; -use std::{any::TypeId, fmt::Debug, sync::Arc}; +use std::{any::TypeId, fmt::Debug, marker::PhantomData, sync::Arc}; /// A registry of reflected types. #[derive(Default)] @@ -350,11 +350,103 @@ impl Deserialize<'a> + Reflect> FromType for ReflectDeserialize { } } +#[derive(Clone)] +pub struct ReflectFromPtr { + type_id: TypeId, + to_reflect: for<'a> unsafe fn(*const (), lt: PhantomData<&'a ()>) -> &'a dyn Reflect, + to_reflect_mut: for<'a> unsafe fn(*mut (), lt: PhantomData<&'a ()>) -> &'a mut dyn Reflect, +} + +impl ReflectFromPtr { + /// # Safety: + /// - `val` must be a pointer to a value of the type that the [`ReflectFromPtr`] was constructed for + /// - the lifetime `'a` of the return type can be arbitrarily chosen by the caller, who must ensure that during + /// that lifetime, `val` is valid + pub unsafe fn to_reflect_ptr<'a>(&self, val: *const ()) -> &'a dyn Reflect { + (self.to_reflect)(val, PhantomData) + } + + /// # Safety: + /// - `val` must be a pointer to a value of the type that the [`ReflectFromPtr`] was constructed for + /// - the lifetime `'a` of the return type can be arbitrarily chosen by the caller, who must ensure that during + /// that lifetime, `val` is valid and not aliased + pub unsafe fn to_reflect_ptr_mut<'a>(&self, val: *mut ()) -> &'a mut dyn Reflect { + (self.to_reflect_mut)(val, PhantomData) + } + + pub fn to_reflect<'a, T: 'static>(&self, val: &'a T) -> &'a dyn Reflect { + assert_eq!(self.type_id, std::any::TypeId::of::()); + // SAFE: the lifetime of `val` is the same as the lifetime of the return value + // and the type of `val` has been checked to be the same correct one + unsafe { self.to_reflect_ptr(val as *const _ as *const ()) } + } + + pub fn to_reflect_mut<'a, T: 'static>(&self, val: &'a mut T) -> &'a mut dyn Reflect { + assert_eq!(self.type_id, std::any::TypeId::of::()); + // SAFE: the lifetime of `val` is the same as the lifetime of the return value + // and the type of `val` has been checked to be the same correct one + unsafe { self.to_reflect_ptr_mut(val as *mut _ as *mut ()) } + } +} + +impl FromType for ReflectFromPtr { + fn from_type() -> Self { + ReflectFromPtr { + type_id: std::any::TypeId::of::(), + to_reflect: |ptr, _lt| { + // SAFE: can only be called by `to_reflect_ptr` where the caller promises the safety requirements + // or `to_reflect` which is typed and checks that the correct type is used. + let val: &T = unsafe { &*ptr.cast::() }; + val as &dyn Reflect + }, + to_reflect_mut: |ptr, _lt| { + // SAFE: can only be called by `to_reflect_ptr_mut` where the caller promises the safety requirements + // or `to_reflect_mut` which is typed and checks that the correct type is used. + let val: &mut T = unsafe { &mut *ptr.cast::() }; + val as &mut dyn Reflect + }, + } + } +} + #[cfg(test)] mod test { - use crate::TypeRegistration; + use crate::{GetTypeRegistration, ReflectFromPtr, TypeRegistration}; use bevy_utils::HashMap; + use crate as bevy_reflect; + use crate::Reflect; + + #[test] + fn test_reflect_from_ptr() { + #[derive(Reflect)] + struct Foo { + a: f32, + } + + let foo_registration = ::get_type_registration(); + let reflect_from_ptr = foo_registration.data::().unwrap(); + + let mut value = Foo { a: 1.0 }; + + let dyn_reflect = reflect_from_ptr.to_reflect_mut(&mut value); + match dyn_reflect.reflect_mut() { + bevy_reflect::ReflectMut::Struct(strukt) => { + strukt.field_mut("a").unwrap().apply(&2.0f32) + } + _ => panic!("invalid reflection"), + } + + let dyn_reflect = reflect_from_ptr.to_reflect(&value); + match dyn_reflect.reflect_ref() { + bevy_reflect::ReflectRef::Struct(strukt) => { + let a = strukt.field("a").unwrap().downcast_ref::().unwrap(); + assert_eq!(*a, 2.0); + } + _ => panic!("invalid reflection"), + } + } + #[test] fn test_get_short_name() { assert_eq!(