Skip to content

Commit

Permalink
Add a getAvailableCollections method in python (#536)
Browse files Browse the repository at this point in the history
* Add a getAvailableCollections method in python

* Deprecated .collections

---------

Co-authored-by: jmcarcell <jmcarcell@users.noreply.github.com>
  • Loading branch information
jmcarcell and jmcarcell authored Jan 12, 2024
1 parent 308b33f commit 532234f
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 8 deletions.
14 changes: 12 additions & 2 deletions python/podio/frame.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
"""Module for the python bindings of the podio::Frame"""

# pylint: disable-next=import-error # gbl is a dynamic module from cppyy
import warnings
import cppyy

import ROOT
Expand Down Expand Up @@ -110,14 +110,24 @@ def __init__(self, data=None):

self._param_key_types = self._get_param_keys_types()

def getAvailableCollections(self):
"""Get the currently available collection (names) from this Frame.
Returns:
tuple(str): The names of the available collections from this Frame.
"""
return tuple(str(s) for s in self._frame.getAvailableCollections())

@property
def collections(self):
"""Get the currently available collection (names) from this Frame.
Returns:
tuple(str): The names of the available collections from this Frame.
"""
return tuple(str(s) for s in self._frame.getAvailableCollections())
warnings.warn('WARNING: collections is deprecated, use getAvailableCollections()'
' (like in C++) instead', FutureWarning)
return self.getAvailableCollections()

def get(self, name):
"""Get a collection from the Frame by name.
Expand Down
9 changes: 5 additions & 4 deletions python/podio/test_Frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ def test_frame_invalid_access(self):
def test_frame_put_collection(self):
"""Check that putting a collection works as expected"""
frame = Frame()
self.assertEqual(frame.collections, tuple())
self.assertEqual(frame.getAvailableCollections(), tuple())

hits = ExampleHitCollection()
hits.create()
hits2 = frame.put(hits, "hits_from_python")
self.assertEqual(frame.collections, tuple(["hits_from_python"]))
self.assertEqual(frame.getAvailableCollections(), tuple(["hits_from_python"]))
# The original collection is gone at this point, and ideally just leaves an
# empty shell
self.assertEqual(len(hits), 0)
Expand Down Expand Up @@ -116,8 +116,9 @@ def setUp(self):

def test_frame_collections(self):
"""Check that all expected collections are available."""
self.assertEqual(set(self.event.collections), EXPECTED_COLL_NAMES)
self.assertEqual(set(self.other_event.collections), EXPECTED_COLL_NAMES.union(EXPECTED_EXTENSION_COLL_NAMES))
self.assertEqual(set(self.event.getAvailableCollections()), EXPECTED_COLL_NAMES)
self.assertEqual(set(self.other_event.getAvailableCollections()),
EXPECTED_COLL_NAMES.union(EXPECTED_EXTENSION_COLL_NAMES))

# Not going over all collections here, as that should all be covered by the
# c++ test cases; Simply picking a few and doing some basic tests
Expand Down
4 changes: 2 additions & 2 deletions tools/podio-dump
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def print_frame_detailed(frame):
frame (podio.Frame): The frame to print
"""
print('Collections:')
for name in sorted(frame.collections, key=str.casefold):
for name in sorted(frame.getAvailableCollections(), key=str.casefold):
coll = frame.get(name)
print(name, flush=True)
coll.print()
Expand All @@ -56,7 +56,7 @@ def print_frame_overview(frame):
frame (podio.Frame): The frame to print
"""
rows = []
for name in sorted(frame.collections, key=str.casefold):
for name in sorted(frame.getAvailableCollections(), key=str.casefold):
coll = frame.get(name)
rows.append(
(name, coll.getValueTypeName().data(), len(coll), f'{coll.getID():0>8x}')
Expand Down

0 comments on commit 532234f

Please sign in to comment.