Skip to content

Commit

Permalink
Implement map iteration
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 604447996
  • Loading branch information
kupiakos authored and copybara-github committed Feb 5, 2024
1 parent 37826c1 commit 035d6ec
Show file tree
Hide file tree
Showing 9 changed files with 525 additions and 81 deletions.
173 changes: 147 additions & 26 deletions rust/cpp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@

use crate::__internal::{Enum, Private, PtrAndLen, RawArena, RawMap, RawMessage, RawRepeatedField};
use crate::{
Map, Mut, ProtoStr, Proxied, ProxiedInMapValue, ProxiedInRepeated, Repeated, RepeatedMut,
RepeatedView, SettableValue, View,
Map, MapIter, Mut, ProtoStr, Proxied, ProxiedInMapValue, ProxiedInRepeated, Repeated,
RepeatedMut, RepeatedView, SettableValue, View,
};
use core::fmt::Debug;
use paste::paste;
use std::alloc::Layout;
use std::cell::UnsafeCell;
use std::convert::identity;
use std::ffi::c_int;
use std::ffi::{c_int, c_void};
use std::fmt;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
Expand Down Expand Up @@ -157,6 +157,7 @@ pub type BytesPresentMutData<'msg> = crate::vtable::RawVTableOptionalMutatorData
pub type BytesAbsentMutData<'msg> = crate::vtable::RawVTableOptionalMutatorData<'msg, [u8]>;
pub type InnerBytesMut<'msg> = crate::vtable::RawVTableMutator<'msg, [u8]>;
pub type InnerPrimitiveMut<'msg, T> = crate::vtable::RawVTableMutator<'msg, T>;
pub type RawMapIter = UntypedMapIterator;

#[derive(Debug)]
pub struct MessageVTable {
Expand Down Expand Up @@ -407,8 +408,93 @@ impl<'msg> InnerMapMut<'msg> {
}
}

/// An untyped iterator in a map, produced via `.cbegin()` on a typed map.
///
/// This struct is ABI-compatible with `proto2::internal::UntypedMapIterator`.
/// It is trivially constructible and destructible.
#[repr(C)]
pub struct UntypedMapIterator {
node: *mut c_void,
map: *const c_void,
bucket_index: u32,
}

impl UntypedMapIterator {
/// Returns `true` if this iterator is at the end of the map.
fn at_end(&self) -> bool {
// This behavior is verified via test `IteratorNodeFieldIsNullPtrAtEnd`.
self.node.is_null()
}

/// Assumes that the map iterator is for the input types, gets the current
/// entry, and moves the iterator forward to the next entry.
///
/// Conversion to and from FFI types is provided by the user.
/// This is a helper function for implementing
/// `ProxiedInMapValue::iter_next`.
///
/// # Safety
/// - The backing map must be valid and not be mutated for `'a`.
/// - The thunk must be safe to call if the iterator is not at the end of
/// the map.
/// - The thunk must always write to the `key` and `value` fields, but not
/// read from them.
/// - The get thunk must not move the iterator forward or backward.
#[inline(always)]
pub unsafe fn next_unchecked<'a, K, V, FfiKey, FfiValue>(
&mut self,
_private: Private,
iter_get_thunk: unsafe extern "C" fn(
iter: &mut UntypedMapIterator,
key: *mut FfiKey,
value: *mut FfiValue,
),
from_ffi_key: impl FnOnce(FfiKey) -> View<'a, K>,
from_ffi_value: impl FnOnce(FfiValue) -> View<'a, V>,
) -> Option<(View<'a, K>, View<'a, V>)>
where
K: Proxied + ?Sized + 'a,
V: ProxiedInMapValue<K> + ?Sized + 'a,
{
if self.at_end() {
return None;
}
let mut ffi_key = MaybeUninit::uninit();
let mut ffi_value = MaybeUninit::uninit();
// SAFETY:
// - The backing map outlives `'a`.
// - The iterator is not at the end (node is non-null).
// - `ffi_key` and `ffi_value` are not read (as uninit) as promised by the
// caller.
unsafe { (iter_get_thunk)(self, ffi_key.as_mut_ptr(), ffi_value.as_mut_ptr()) }

// SAFETY:
// - The backing map is alive as promised by the caller.
// - `self.at_end()` is false and the `get` does not change that.
// - `UntypedMapIterator` has the same ABI as
// `proto2::internal::UntypedMapIterator`. It is statically checked to be:
// - Trivially copyable.
// - Trivially destructible.
// - Standard layout.
// - The size and alignment of the Rust type above.
// - With the `node_` field first.
unsafe { __rust_proto_thunk__UntypedMapIterator_increment(self) }

// SAFETY:
// - The `get` function always writes valid values to `ffi_key` and `ffi_value`
// as promised by the caller.
unsafe {
Some((from_ffi_key(ffi_key.assume_init()), from_ffi_value(ffi_value.assume_init())))
}
}
}

extern "C" {
fn __rust_proto_thunk__UntypedMapIterator_increment(iter: &mut UntypedMapIterator);
}

macro_rules! impl_ProxiedInMapValue_for_non_generated_value_types {
($key_t:ty, $ffi_key_t:ty, $to_ffi_key:expr, for $($t:ty, $ffi_t:ty, $to_ffi_value:expr, $from_ffi_value:expr, $zero_val:literal;)*) => {
($key_t:ty, $ffi_key_t:ty, $to_ffi_key:expr, $from_ffi_key:expr, for $($t:ty, $ffi_t:ty, $to_ffi_value:expr, $from_ffi_value:expr;)*) => {
paste! { $(
extern "C" {
fn [< __rust_proto_thunk__Map_ $key_t _ $t _new >]() -> RawMap;
Expand All @@ -417,6 +503,8 @@ macro_rules! impl_ProxiedInMapValue_for_non_generated_value_types {
fn [< __rust_proto_thunk__Map_ $key_t _ $t _size >](m: RawMap) -> usize;
fn [< __rust_proto_thunk__Map_ $key_t _ $t _insert >](m: RawMap, key: $ffi_key_t, value: $ffi_t) -> bool;
fn [< __rust_proto_thunk__Map_ $key_t _ $t _get >](m: RawMap, key: $ffi_key_t, value: *mut $ffi_t) -> bool;
fn [< __rust_proto_thunk__Map_ $key_t _ $t _iter >](m: RawMap) -> UntypedMapIterator;
fn [< __rust_proto_thunk__MapIter_ $key_t _ $t _get >](iter: &mut UntypedMapIterator, key: *mut $ffi_key_t, value: *mut $ffi_t);
fn [< __rust_proto_thunk__Map_ $key_t _ $t _remove >](m: RawMap, key: $ffi_key_t, value: *mut $ffi_t) -> bool;
}

Expand Down Expand Up @@ -457,18 +545,50 @@ macro_rules! impl_ProxiedInMapValue_for_non_generated_value_types {

fn map_get<'a>(map: View<'a, Map<$key_t, Self>>, key: View<'_, $key_t>) -> Option<View<'a, Self>> {
let ffi_key = $to_ffi_key(key);
let mut ffi_value = $to_ffi_value($zero_val);
let found = unsafe { [< __rust_proto_thunk__Map_ $key_t _ $t _get >](map.as_raw(Private), ffi_key, &mut ffi_value) };
let mut ffi_value = MaybeUninit::uninit();
let found = unsafe { [< __rust_proto_thunk__Map_ $key_t _ $t _get >](map.as_raw(Private), ffi_key, ffi_value.as_mut_ptr()) };
if !found {
return None;
}
Some($from_ffi_value(ffi_value))
// SAFETY: if `found` is true, then the `ffi_value` was written to by `get`.
Some($from_ffi_value(unsafe { ffi_value.assume_init() }))
}

fn map_remove(mut map: Mut<'_, Map<$key_t, Self>>, key: View<'_, $key_t>) -> bool {
let ffi_key = $to_ffi_key(key);
let mut ffi_value = $to_ffi_value($zero_val);
unsafe { [< __rust_proto_thunk__Map_ $key_t _ $t _remove >](map.as_raw(Private), ffi_key, &mut ffi_value) }
let mut ffi_value = MaybeUninit::uninit();
unsafe { [< __rust_proto_thunk__Map_ $key_t _ $t _remove >](map.as_raw(Private), ffi_key, ffi_value.as_mut_ptr()) }
}

fn map_iter(map: View<'_, Map<$key_t, Self>>) -> MapIter<'_, $key_t, Self> {
// SAFETY:
// - The backing map for `map.as_raw` is valid for at least '_.
// - A View that is live for '_ guarantees the backing map is unmodified for '_.
// - The `iter` function produces an iterator that is valid for the key
// and value types, and live for at least '_.
unsafe {
MapIter::from_raw(
Private,
[< __rust_proto_thunk__Map_ $key_t _ $t _iter >](map.as_raw(Private))
)
}
}

fn map_iter_next<'a>(iter: &mut MapIter<'a, $key_t, Self>) -> Option<(View<'a, $key_t>, View<'a, Self>)> {
// SAFETY:
// - The `MapIter` API forbids the backing map from being mutated for 'a,
// and guarantees that it's the correct key and value types.
// - The thunk is safe to call as long as the iterator isn't at the end.
// - The thunk always writes to key and value fields and does not read.
// - The thunk does not increment the iterator.
unsafe {
iter.as_raw_mut(Private).next_unchecked::<$key_t, Self, _, _>(
Private,
[< __rust_proto_thunk__MapIter_ $key_t _ $t _get >],
$from_ffi_key,
$from_ffi_value,
)
}
}
}
)* }
Expand Down Expand Up @@ -496,32 +616,33 @@ fn ptrlen_to_bytes<'msg>(val: PtrAndLen) -> &'msg [u8] {
}

macro_rules! impl_ProxiedInMapValue_for_key_types {
($($t:ty, $ffi_t:ty, $to_ffi_key:expr;)*) => {
($($t:ty, $ffi_t:ty, $to_ffi_key:expr, $from_ffi_key:expr;)*) => {
paste! {
$(
impl_ProxiedInMapValue_for_non_generated_value_types!($t, $ffi_t, $to_ffi_key, for
f32, f32, identity, identity, 0f32;
f64, f64, identity, identity, 0f64;
i32, i32, identity, identity, 0i32;
u32, u32, identity, identity, 0u32;
i64, i64, identity, identity, 0i64;
u64, u64, identity, identity, 0u64;
bool, bool, identity, identity, false;
ProtoStr, PtrAndLen, str_to_ptrlen, ptrlen_to_str, "";
Bytes, PtrAndLen, bytes_to_ptrlen, ptrlen_to_bytes, b"";
impl_ProxiedInMapValue_for_non_generated_value_types!(
$t, $ffi_t, $to_ffi_key, $from_ffi_key, for
f32, f32, identity, identity;
f64, f64, identity, identity;
i32, i32, identity, identity;
u32, u32, identity, identity;
i64, i64, identity, identity;
u64, u64, identity, identity;
bool, bool, identity, identity;
ProtoStr, PtrAndLen, str_to_ptrlen, ptrlen_to_str;
Bytes, PtrAndLen, bytes_to_ptrlen, ptrlen_to_bytes;
);
)*
}
}
}

impl_ProxiedInMapValue_for_key_types!(
i32, i32, identity;
u32, u32, identity;
i64, i64, identity;
u64, u64, identity;
bool, bool, identity;
ProtoStr, PtrAndLen, str_to_ptrlen;
i32, i32, identity, identity;
u32, u32, identity, identity;
i64, i64, identity, identity;
u64, u64, identity, identity;
bool, bool, identity, identity;
ProtoStr, PtrAndLen, str_to_ptrlen, ptrlen_to_str;
);

#[cfg(test)]
Expand Down
94 changes: 61 additions & 33 deletions rust/cpp_kernel/cpp_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,14 @@ expose_repeated_ptr_field_methods(Bytes);

#undef expose_repeated_ptr_field_methods

void __rust_proto_thunk__UntypedMapIterator_increment(
google::protobuf::internal::UntypedMapIterator* iter) {
iter->PlusPlus();
}

#define expose_scalar_map_methods(key_ty, rust_key_ty, ffi_key_ty, to_cpp_key, \
value_ty, rust_value_ty, ffi_value_ty, \
to_cpp_value, to_ffi_value) \
to_ffi_key, value_ty, rust_value_ty, \
ffi_value_ty, to_cpp_value, to_ffi_value) \
google::protobuf::Map<key_ty, value_ty>* \
__rust_proto_thunk__Map_##rust_key_ty##_##rust_value_ty##_new() { \
return new google::protobuf::Map<key_ty, value_ty>(); \
Expand Down Expand Up @@ -132,46 +137,69 @@ expose_repeated_ptr_field_methods(Bytes);
*value = to_ffi_value; \
return true; \
} \
google::protobuf::internal::UntypedMapIterator \
__rust_proto_thunk__Map_##rust_key_ty##_##rust_value_ty##_iter( \
const google::protobuf::Map<key_ty, value_ty>* m) { \
return google::protobuf::internal::UntypedMapIterator::FromTyped(m->cbegin()); \
} \
void __rust_proto_thunk__MapIter_##rust_key_ty##_##rust_value_ty##_get( \
const google::protobuf::internal::UntypedMapIterator* iter, ffi_key_ty* key, \
ffi_value_ty* value) { \
auto typed_iter = \
iter->ToTyped<google::protobuf::Map<key_ty, value_ty>::const_iterator>(); \
const auto& cpp_key = typed_iter->first; \
const auto& cpp_value = typed_iter->second; \
*key = to_ffi_key; \
*value = to_ffi_value; \
} \
bool __rust_proto_thunk__Map_##rust_key_ty##_##rust_value_ty##_remove( \
google::protobuf::Map<key_ty, value_ty>* m, ffi_key_ty key, ffi_value_ty* value) { \
auto cpp_key = to_cpp_key; \
auto num_removed = m->erase(cpp_key); \
return num_removed > 0; \
}

#define expose_scalar_map_methods_for_key_type(key_ty, rust_key_ty, \
ffi_key_ty, to_cpp_key) \
expose_scalar_map_methods(key_ty, rust_key_ty, ffi_key_ty, to_cpp_key, \
int32_t, i32, int32_t, value, cpp_value); \
expose_scalar_map_methods(key_ty, rust_key_ty, ffi_key_ty, to_cpp_key, \
uint32_t, u32, uint32_t, value, cpp_value); \
expose_scalar_map_methods(key_ty, rust_key_ty, ffi_key_ty, to_cpp_key, \
float, f32, float, value, cpp_value); \
expose_scalar_map_methods(key_ty, rust_key_ty, ffi_key_ty, to_cpp_key, \
double, f64, double, value, cpp_value); \
expose_scalar_map_methods(key_ty, rust_key_ty, ffi_key_ty, to_cpp_key, bool, \
bool, bool, value, cpp_value); \
expose_scalar_map_methods(key_ty, rust_key_ty, ffi_key_ty, to_cpp_key, \
uint64_t, u64, uint64_t, value, cpp_value); \
expose_scalar_map_methods(key_ty, rust_key_ty, ffi_key_ty, to_cpp_key, \
int64_t, i64, int64_t, value, cpp_value); \
expose_scalar_map_methods( \
key_ty, rust_key_ty, ffi_key_ty, to_cpp_key, std::string, Bytes, \
google::protobuf::rust_internal::PtrAndLen, std::string(value.ptr, value.len), \
google::protobuf::rust_internal::PtrAndLen(cpp_value.data(), cpp_value.size())); \
expose_scalar_map_methods( \
key_ty, rust_key_ty, ffi_key_ty, to_cpp_key, std::string, ProtoStr, \
google::protobuf::rust_internal::PtrAndLen, std::string(value.ptr, value.len), \
#define expose_scalar_map_methods_for_key_type( \
key_ty, rust_key_ty, ffi_key_ty, to_cpp_key, to_ffi_key) \
expose_scalar_map_methods(key_ty, rust_key_ty, ffi_key_ty, to_cpp_key, \
to_ffi_key, int32_t, i32, int32_t, value, \
cpp_value); \
expose_scalar_map_methods(key_ty, rust_key_ty, ffi_key_ty, to_cpp_key, \
to_ffi_key, uint32_t, u32, uint32_t, value, \
cpp_value); \
expose_scalar_map_methods(key_ty, rust_key_ty, ffi_key_ty, to_cpp_key, \
to_ffi_key, float, f32, float, value, cpp_value); \
expose_scalar_map_methods(key_ty, rust_key_ty, ffi_key_ty, to_cpp_key, \
to_ffi_key, double, f64, double, value, \
cpp_value); \
expose_scalar_map_methods(key_ty, rust_key_ty, ffi_key_ty, to_cpp_key, \
to_ffi_key, bool, bool, bool, value, cpp_value); \
expose_scalar_map_methods(key_ty, rust_key_ty, ffi_key_ty, to_cpp_key, \
to_ffi_key, uint64_t, u64, uint64_t, value, \
cpp_value); \
expose_scalar_map_methods(key_ty, rust_key_ty, ffi_key_ty, to_cpp_key, \
to_ffi_key, int64_t, i64, int64_t, value, \
cpp_value); \
expose_scalar_map_methods( \
key_ty, rust_key_ty, ffi_key_ty, to_cpp_key, to_ffi_key, std::string, \
Bytes, google::protobuf::rust_internal::PtrAndLen, \
std::string(value.ptr, value.len), \
google::protobuf::rust_internal::PtrAndLen(cpp_value.data(), cpp_value.size())); \
expose_scalar_map_methods( \
key_ty, rust_key_ty, ffi_key_ty, to_cpp_key, to_ffi_key, std::string, \
ProtoStr, google::protobuf::rust_internal::PtrAndLen, \
std::string(value.ptr, value.len), \
google::protobuf::rust_internal::PtrAndLen(cpp_value.data(), cpp_value.size()));

expose_scalar_map_methods_for_key_type(int32_t, i32, int32_t, key);
expose_scalar_map_methods_for_key_type(uint32_t, u32, uint32_t, key);
expose_scalar_map_methods_for_key_type(bool, bool, bool, key);
expose_scalar_map_methods_for_key_type(uint64_t, u64, uint64_t, key);
expose_scalar_map_methods_for_key_type(int64_t, i64, int64_t, key);
expose_scalar_map_methods_for_key_type(std::string, ProtoStr,
google::protobuf::rust_internal::PtrAndLen,
std::string(key.ptr, key.len));
expose_scalar_map_methods_for_key_type(int32_t, i32, int32_t, key, cpp_key);
expose_scalar_map_methods_for_key_type(uint32_t, u32, uint32_t, key, cpp_key);
expose_scalar_map_methods_for_key_type(bool, bool, bool, key, cpp_key);
expose_scalar_map_methods_for_key_type(uint64_t, u64, uint64_t, key, cpp_key);
expose_scalar_map_methods_for_key_type(int64_t, i64, int64_t, key, cpp_key);
expose_scalar_map_methods_for_key_type(
std::string, ProtoStr, google::protobuf::rust_internal::PtrAndLen,
std::string(key.ptr, key.len),
google::protobuf::rust_internal::PtrAndLen(cpp_key.data(), cpp_key.size()));

#undef expose_scalar_map_methods
#undef expose_map_methods
Expand Down
Loading

0 comments on commit 035d6ec

Please sign in to comment.