Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce range wrapper for OneToManyRelations and VectorMembers #107

Merged
merged 2 commits into from
Aug 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions include/podio/RelationRange.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#ifndef PODIO_RELATIONRANGE_H
#define PODIO_RELATIONRANGE_H

#include <vector>
#include <iterator>

namespace podio {
/**
* A simple helper class that allows to return related objects in a way that
* makes it possible to use the return type in a range-based for loop.
*/
template<typename ReferenceType>
class RelationRange {
public:
using ConstIteratorType = typename std::vector<ReferenceType>::const_iterator;
RelationRange(ConstIteratorType begin, ConstIteratorType end) : m_begin(begin), m_end(end) {}

/// begin of the range (necessary for range-based for loop)
ConstIteratorType begin() const { return m_begin; }
/// end of the range (necessary for range-based for loop)
ConstIteratorType end() const { return m_end; }
/// convenience overload for size
size_t size() const { return std::distance(m_begin, m_end); }
private:
ConstIteratorType m_begin;
ConstIteratorType m_end;
};
}

#endif // PODIO_RELATIONRANGE_H
2 changes: 2 additions & 0 deletions python/podio_class_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def create_class(self, classname, definition):
refvectors = definition["OneToManyRelations"]
if len(refvectors) != 0:
datatype["includes"].add("#include <vector>")
datatype["includes"].add('#include "podio/RelationRange.h"')
for item in refvectors:
klass = item["type"]
if klass in self.requested_classes:
Expand Down Expand Up @@ -434,6 +435,7 @@ def create_class(self, classname, definition):
vectormembers = definition["VectorMembers"]
if len(vectormembers) != 0:
datatype["includes"].add("#include <vector>")
datatype["includes"].add('#include "podio/RelationRange.h"')
for item in vectormembers:
klass = item["type"]
if klass not in self.buildin_types and klass not in self.reader.components:
Expand Down
8 changes: 8 additions & 0 deletions python/templates/ConstRefVector.cc.template
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,11 @@ ${relationtype} Const${classname}::${get_relation}(unsigned int index) const {
}
else throw std::out_of_range ("index out of bounds for existing references");
}

podio::RelationRange<${relationtype}> Const${classname}::${get_relation}() const {
auto begin = m_obj->m_${relation}->begin();
std::advance(begin, m_obj->data.${relation}_begin);
auto end = m_obj->m_${relation}->begin();
std::advance(end, m_obj->data.${relation}_end);
return {begin, end};
}
1 change: 1 addition & 0 deletions python/templates/ConstRefVector.h.template
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
${relationtype} ${get_relation}(unsigned int) const;
std::vector<${relationtype}>::const_iterator ${relation}_begin() const;
std::vector<${relationtype}>::const_iterator ${relation}_end() const;
podio::RelationRange<${relationtype}> ${get_relation}() const;
8 changes: 8 additions & 0 deletions python/templates/RefVector.cc.template
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,11 @@ ${relationtype} ${classname}::${get_relation}(unsigned int index) const {
}
else throw std::out_of_range ("index out of bounds for existing references");
}

podio::RelationRange<${relationtype}> ${classname}::${get_relation}() const {
auto begin = m_obj->m_${relation}->begin();
std::advance(begin, m_obj->data.${relation}_begin);
auto end = m_obj->m_${relation}->begin();
std::advance(end, m_obj->data.${relation}_end);
return {begin, end};
}
2 changes: 1 addition & 1 deletion python/templates/RefVector.h.template
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
${relationtype} ${get_relation}(unsigned int) const;
std::vector<${relationtype}>::const_iterator ${relation}_begin() const;
std::vector<${relationtype}>::const_iterator ${relation}_end() const;

podio::RelationRange<${relationtype}> ${get_relation}() const;
6 changes: 5 additions & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ foreach( sourcefile ${executables} )
add_executable( ${name} ${sourcefile} )
target_link_libraries( ${name} TestDataModel podio::podioRootIO)
endforeach( sourcefile ${executables} )

#--- Adding tests --------------------------------------------------------------

add_test(NAME write COMMAND write)
Expand All @@ -50,6 +49,11 @@ set_property(TEST read-multiple PROPERTY ENVIRONMENT
ROOT_INCLUDE_PATH=${CMAKE_SOURCE_DIR}/tests/datamodel:${ROOT_INCLUDE_PATH})
set_property(TEST read-multiple PROPERTY DEPENDS write)

add_test(NAME relation_range COMMAND relation_range)
set_property(TEST relation_range PROPERTY ENVIRONMENT
LD_LIBRARY_PATH=${CMAKE_CURRENT_BINARY_DIR}:${CMAKE_BINARY_DIR}/src:$ENV{LD_LIBRARY_PATH}
ROOT_INCLUDE_PATH=${CMAKE_SOURCE_DIR}/tests/datamodel:${ROOT_INCLUDE_PATH})

add_test( NAME pyunittest COMMAND python -m unittest discover -s ${CMAKE_SOURCE_DIR}/python)
set_property(TEST pyunittest
PROPERTY ENVIRONMENT
Expand Down
210 changes: 210 additions & 0 deletions tests/relation_range.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
#include "datamodel/ExampleMCCollection.h"
#include "datamodel/ExampleReferencingType.h"
#include "datamodel/ExampleWithVectorMember.h"

#include "podio/EventStore.h"
#include "podio/ROOTWriter.h"
#include "podio/ROOTReader.h"

#include <vector>
#include <string>
#include <sstream>
#include <stdexcept>

#define LOCATION() \
std::string(__FILE__) + std::string(":") + std::to_string(__LINE__) + \
std::string(" in ") + std::string(__PRETTY_FUNCTION__)

#define ASSERT_CONDITION( condition, message ) \
{ \
if ( !(condition) ) { \
throw std::runtime_error( LOCATION() + std::string(": ") + message); \
} \
}

#define ASSERT_EQUAL( left, right, message ) \
{ \
std::stringstream msg; \
msg << message << " | expected: " << right << " - actual: " << left; \
ASSERT_CONDITION(left == right, msg.str()) \
}

void fillExampleMCCollection(ExampleMCCollection& collection)
{
for (int i = 0; i < 10; ++i) {
ExampleMC mcp;
mcp.PDG(i);
collection.push_back(mcp);
}

// only daughters
auto mcp = collection[0];
mcp.adddaughters(collection[1]);
mcp.adddaughters(collection[5]);
mcp.adddaughters(collection[4]);

// daughters and parents
mcp = collection[2];
mcp.adddaughters(collection[1]);
mcp.adddaughters(collection[9]);
mcp.addparents(collection[8]);
mcp.addparents(collection[6]);

// realistic case with possibly cyclical references
mcp = collection[6];
mcp.adddaughters(collection[2]);
mcp.addparents(collection[3]);
collection[3].adddaughters(collection[6]);
}


// TODO: break up into smaller cases
void doTestExampleMC(ExampleMCCollection const& collection)
{
// Empty
ASSERT_CONDITION(collection[7].daughters().size() == 0 && collection[7].parents().size() == 0,
"RelationRange of empty collection is not empty");
// alternatively check if a loop is entered
for (const auto& p: collection[7].daughters()) {
throw std::runtime_error("Range based for loop entered on a supposedly empty range");
}

ASSERT_EQUAL(collection[0].daughters().size(), 3, "Range has wrong size");
// check daughter relations are OK
std::vector<int> expectedPDG = {1, 5, 4};
int index = 0;
for (const auto& p : collection[0].daughters()) {
ASSERT_EQUAL(p.PDG(), expectedPDG[index],
"ExampleMC daughters range points to wrong particle (by PDG)");
index++;
}

// mothers and daughters
ASSERT_EQUAL(collection[2].daughters().size(), 2, "Range has wrong size");
ASSERT_EQUAL(collection[2].parents().size(), 2, "Range has wrong size");

expectedPDG = {1, 9};
index = 0;
for (const auto& p : collection[2].daughters()) {
ASSERT_EQUAL(p.PDG(), expectedPDG[index],
"ExampleMC daughters range points to wrong particle (by PDG)");
index++;
}
expectedPDG = {8, 6};
index = 0;
for (const auto& p : collection[2].parents()) {
ASSERT_EQUAL(p.PDG(), expectedPDG[index],
"ExampleMC parents range points to wrong particle (by PDG)");
index++;
}

// realistic case
auto mcp6 = collection[6];
ASSERT_EQUAL(mcp6.daughters().size(), 1, "Wrong number of daughters");
auto parent = mcp6.parents(0);
ASSERT_EQUAL(parent.daughters().size(), 1, "Wrong number of daughters");

for (const auto& p : mcp6.parents()) {
// loop will only run once as per the above assertion
ASSERT_EQUAL(p, parent, "parent-daughter relation is not as expected");
}
for (const auto& p : parent.daughters()) {
// loop will only run once as per the above assertion
ASSERT_EQUAL(p, mcp6, "daughter-parent relation is not as expected");
}
}

void testExampleMC()
{
ExampleMCCollection mcps;
fillExampleMCCollection(mcps);
doTestExampleMC(mcps);
}


void testExampleWithVectorMember()
{
ExampleWithVectorMember ex;
ex.addcount(1);
ex.addcount(2);
ex.addcount(10);

ASSERT_EQUAL(ex.count().size(), 3, "vector member ragnehas wrong size");

std::vector<int> expected = {1, 2, 10};
int index = 0;
for (const int c : ex.count()) {
ASSERT_EQUAL(c, expected[index], "wrong content in range-based for loop");
index++;
}
}

void testExampleReferencingType()
{
ExampleReferencingType ex;
ExampleReferencingType ex1;
ExampleReferencingType ex2;

ex.addRefs(ex1);
ex.addRefs(ex2);

ASSERT_EQUAL(ex.Refs().size(), 2, "Wrong number of references");

int index = 0;
for (const auto& e : ex.Refs()) {
if (index == 0) {
ASSERT_EQUAL(e, ex1, "First element of range is not as expected");
} else {
ASSERT_EQUAL(e, ex2, "Second element of range is not as expected");
}

index++;
}
}

void testWithIO()
{
podio::EventStore store;
podio::ROOTWriter writer("relation_range_io_test.root", &store);

auto& collection = store.create<ExampleMCCollection>("mcparticles");
writer.registerForWrite("mcparticles");

for (int i = 0; i < 10; ++i) {
fillExampleMCCollection(collection);
doTestExampleMC(collection);
writer.writeEvent();
store.clearCollections();
}
writer.finish();

podio::EventStore readStore;
podio::ROOTReader reader;
reader.openFile("relation_range_io_test.root");
readStore.setReader(&reader);

for (int i = 0; i < 10; ++i) {
auto& readColl = readStore.get<ExampleMCCollection>("mcparticles");
ASSERT_CONDITION(readColl.isValid(), "Collection 'mcparticles' should be present");
ASSERT_EQUAL(readColl.size(), 10, "'mcparticles' should have 10 entries");

doTestExampleMC(readColl);
readStore.clear();
reader.endOfEvent();
}
}

int main(int, char**)
{
try {
testExampleMC();
testExampleWithVectorMember();
testExampleReferencingType();
testWithIO();
} catch (std::runtime_error const& ex) {
std::cerr << ex.what() << std::endl;
return 1;
}

return 0;
}