Skip to content

Commit

Permalink
correct termination
Browse files Browse the repository at this point in the history
  • Loading branch information
lemunozm committed Jul 17, 2024
1 parent 12315e2 commit e7e30a8
Showing 1 changed file with 34 additions and 31 deletions.
65 changes: 34 additions & 31 deletions pallets/liquidity-pools/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use serde::{
ser::{Error as _, SerializeTuple},
Deserialize, Serialize, Serializer,
};
use sp_runtime::{traits::ConstU32, DispatchError};
use sp_runtime::{traits::ConstU32, DispatchError, DispatchResult};
use sp_std::vec::Vec;

use crate::gmpf; // Generic Message Passing Format
Expand Down Expand Up @@ -72,78 +72,71 @@ impl TryInto<Domain> for SerializableDomain {
}
}

/// A message belonging to a batch that can not be a Batch.
/// A submessage is encoded with a u16 prefix containing its size
/// A message type that can not be a batch.
#[derive(Encode, Decode, Clone, PartialEq, Eq, RuntimeDebug, TypeInfo, MaxEncodedLen)]
pub struct SubMessage(Box<Message>);
pub struct NoBatchMessage(Box<Message>);

impl Serialize for SubMessage {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let encoded = gmpf::to_vec(&self.0).map_err(|e| S::Error::custom(e.to_string()))?;

// Serializing as bytes automatically encodes the prefix size
encoded.serialize(serializer)
}
}

impl<'de> Deserialize<'de> for SubMessage {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let (_, msg) = <(u16, Message)>::deserialize(deserializer)?;
Self::try_from(msg).map_err(|e| D::Error::custom::<&'static str>(e.into()))
}
}

impl TryFrom<Message> for SubMessage {
impl TryFrom<Message> for NoBatchMessage {
type Error = DispatchError;

fn try_from(message: Message) -> Result<Self, DispatchError> {
match message {
Message::Batch { .. } => Err(DispatchError::Other("Batch messages can not be nested")),
Message::Batch { .. } => Err(DispatchError::Other("A submessage can not be a batch")),
_ => Ok(Self(Box::new(message))),
}
}
}

/// We need an spetial serialization/deserialization for batches
#[derive(Encode, Decode, Clone, PartialEq, Eq, RuntimeDebug, TypeInfo, MaxEncodedLen, Default)]
pub struct BatchMessages(BoundedVec<SubMessage, ConstU32<MAX_BATCH_MESSAGES>>);
pub struct BatchMessages(BoundedVec<NoBatchMessage, ConstU32<MAX_BATCH_MESSAGES>>);

impl Serialize for BatchMessages {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
// Serializing as a tuple to avoid the prefix length used for dynamic lists
let mut tuple = serializer.serialize_tuple(self.0.len())?;
for msg in self.0.iter() {
tuple.serialize_element(msg)?;
let encoded = gmpf::to_vec(&msg.0).map_err(|e| S::Error::custom(e.to_string()))?;

// Serializing as bytes automatically encodes the prefix size
tuple.serialize_element(&encoded)?;
}
tuple.end()
}
}

impl<'de> Deserialize<'de> for BatchMessages {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
// We need a custom visitor because we do not know the length upfront
struct MsgVisitor;

impl<'de> Visitor<'de> for MsgVisitor {
type Value = BatchMessages;

fn expecting(&self, formatter: &mut sp_std::fmt::Formatter) -> sp_std::fmt::Result {
formatter.write_str("A sequence of pairs size-message")
formatter.write_str("A sequence of pairs size-submessage")
}

fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
let mut batch = BatchMessages::default();

while let Some(msg) = seq.next_element::<SubMessage>().unwrap_or(None) {
// We only stop on error trying to deserialize the length of the submessage.
// The error will happen when we reach EOF
while let Some(_) = seq.next_element::<u16>().unwrap_or(None) {
let msg = seq
.next_element()?
.ok_or(A::Error::custom("expected submessage"))?;

batch
.0
.try_push(msg)
.map_err(|_| A::Error::custom("Batch limit reached"))?;
.try_add(msg)
.map_err(|e| A::Error::custom::<&'static str>(e.into()))?;
}

Ok(batch)
}
}

deserializer.deserialize_tuple(MAX_BATCH_MESSAGES as usize, MsgVisitor)
let limit = MAX_BATCH_MESSAGES as usize * 2; // Lengths and messages
deserializer.deserialize_tuple(limit, MsgVisitor)
}
}

Expand All @@ -162,6 +155,16 @@ impl TryFrom<Vec<Message>> for BatchMessages {
}
}

impl BatchMessages {
pub fn try_add(&mut self, message: Message) -> DispatchResult {
self.0
.try_push(message.try_into()?)
.map_err(|_| DispatchError::Other("Batch limit reached"))?;

Ok(())
}
}

/// A LiquidityPools Message
#[derive(
Encode,
Expand Down

0 comments on commit e7e30a8

Please sign in to comment.