diff --git a/src/libraries/JANA/CMakeLists.txt b/src/libraries/JANA/CMakeLists.txt index ffb7ea743..375d1f3e2 100644 --- a/src/libraries/JANA/CMakeLists.txt +++ b/src/libraries/JANA/CMakeLists.txt @@ -53,11 +53,6 @@ set(JANA2_SOURCES Compatibility/md5.c ) -if (${USE_PODIO}) - list(APPEND JANA2_SOURCES - Podio/JFactoryPodioT.cc - ) -endif() if (NOT ${USE_XERCES}) message(STATUS "Skipping support for libJANA's JGeometryXML because USE_XERCES=Off") diff --git a/src/libraries/JANA/Components/JBasicDataBundle.h b/src/libraries/JANA/Components/JBasicDataBundle.h new file mode 100644 index 000000000..fe5053bc0 --- /dev/null +++ b/src/libraries/JANA/Components/JBasicDataBundle.h @@ -0,0 +1,108 @@ + +#pragma once + +#include +#include +#include + +#ifdef JANA2_HAVE_ROOT +#include +#endif + +class JBasicDataBundle : public JDataBundle { + bool m_is_persistent = false; + bool m_not_object_owner = false; + bool m_write_to_output = true; + +public: + JBasicDataBundle() = default; + ~JBasicDataBundle() override = default; + void SetPersistentFlag(bool persistent) { m_is_persistent = persistent; } + void SetNotOwnerFlag(bool not_owner) { m_not_object_owner = not_owner; } + void SetWriteToOutputFlag(bool write_to_output) { m_write_to_output = write_to_output; } + + bool GetPersistentFlag() { return m_is_persistent; } + bool GetNotOwnerFlag() { return m_not_object_owner; } + bool GetWriteToOutputFlag() { return m_write_to_output; } +}; + + + +template +class JBasicDataBundleT : public JBasicDataBundle { +private: + std::vector m_data; + +public: + JBasicDataBundleT(); + void ClearData() override; + size_t GetSize() const override { return m_data.size();} + + std::vector& GetData() { return m_data; } + + /// EnableGetAs generates a vtable entry so that users may extract the + /// contents of this JFactoryT from the type-erased JFactory. The user has to manually specify which upcasts + /// to allow, and they have to do so for each instance. It is recommended to do so in the constructor. + /// Note that EnableGetAs() is called automatically. + template void EnableGetAs (); + + // The following specializations allow automatically adding standard types (e.g. JObject) using things like + // std::is_convertible(). The std::true_type version defers to the standard EnableGetAs(). + template void EnableGetAs(std::true_type) { EnableGetAs(); } + template void EnableGetAs(std::false_type) {} +}; + +// Template definitions + +template +JBasicDataBundleT::JBasicDataBundleT() { + SetTypeName(JTypeInfo::demangle()); + EnableGetAs(); + EnableGetAs( std::is_convertible() ); // Automatically add JObject if this can be converted to it +#ifdef JANA2_HAVE_ROOT + EnableGetAs( std::is_convertible() ); // Automatically add TObject if this can be converted to it +#endif +} + +template +void JBasicDataBundleT::ClearData() { + + // ClearData won't do anything if Init() hasn't been called + if (GetStatus() == Status::Empty) { + return; + } + // ClearData() does nothing if persistent flag is set. + // User must manually recycle data, e.g. during ChangeRun() + if (GetPersistentFlag()) { + return; + } + + // Assuming we _are_ the object owner, delete the underlying jobjects + if (!GetNotOwnerFlag()) { + for (auto p : m_data) delete p; + } + m_data.clear(); + SetStatus(Status::Empty); +} + +template +template +void JBasicDataBundleT::EnableGetAs() { + + auto upcast_lambda = [this]() { + std::vector results; + for (auto t : m_data) { + results.push_back(static_cast(t)); + } + return results; + }; + + auto key = std::type_index(typeid(S)); + using upcast_fn_t = std::function()>; + mUpcastVTable[key] = std::unique_ptr(new JAnyT(std::move(upcast_lambda))); +} + + + + + diff --git a/src/libraries/JANA/Components/JBasicOutput.h b/src/libraries/JANA/Components/JBasicOutput.h new file mode 100644 index 000000000..7465e0bf9 --- /dev/null +++ b/src/libraries/JANA/Components/JBasicOutput.h @@ -0,0 +1,27 @@ +#pragma once +#include +#include + +namespace jana::components { + +template +class Output : public JHasFactoryOutputs::OutputBase { + std::vector m_data; + +public: + Output(JHasFactoryOutputs* owner, std::string default_tag_name="") { + owner->RegisterOutput(this); + this->collection_names.push_back(default_tag_name); + this->type_name = JTypeInfo::demangle(); + } + + std::vector& operator()() { return m_data; } + +protected: + void PutCollections(const JEvent& event) override { + event.Insert(m_data, this->collection_names[0]); + } + void Reset() override { } +}; + +} // jana::components diff --git a/src/libraries/JANA/Components/JDataBundle.h b/src/libraries/JANA/Components/JDataBundle.h new file mode 100644 index 000000000..c370076c0 --- /dev/null +++ b/src/libraries/JANA/Components/JDataBundle.h @@ -0,0 +1,102 @@ +// Copyright 2024, Jefferson Science Associates, LLC. +// Subject to the terms in the LICENSE file found in the top-level directory. + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + + +class JFactory; + +class JDataBundle { +public: + // Typedefs + enum class Status { Empty, Created, Inserted, InsertedViaGetObjects }; + +private: + // Fields + Status m_status = Status::Empty; + std::string m_unique_name; + JOptional m_short_name; + std::string m_type_name; + JFactory* m_factory = nullptr; + JOptional m_inner_type_index; + mutable JCallGraphRecorder::JDataOrigin m_insert_origin = JCallGraphRecorder::ORIGIN_NOT_AVAILABLE; + +protected: + std::unordered_map> mUpcastVTable; + +public: + // Interface + JDataBundle() = default; + virtual ~JDataBundle() = default; + virtual size_t GetSize() const = 0; + virtual void ClearData() = 0; + + // Getters + Status GetStatus() const { return m_status; } + std::string GetUniqueName() const { return m_unique_name; } + JOptional GetShortName() const { return m_short_name; } + std::string GetTypeName() const { return m_type_name; } + JOptional GetTypeIndex() const { return m_inner_type_index; } + JCallGraphRecorder::JDataOrigin GetInsertOrigin() const { return m_insert_origin; } ///< If objects were placed here by JEvent::Insert() this records whether that call was made from a source or factory. + JFactory* GetFactory() const { return m_factory; } + + // Setters + void SetStatus(Status s) { m_status = s;} + void SetUniqueName(std::string unique_name) { m_unique_name = unique_name; } + void SetShortName(std::string short_name) { m_short_name = short_name; } + void SetTypeName(std::string type_name) { m_type_name = type_name; } + void SetInsertOrigin(JCallGraphRecorder::JDataOrigin origin) { m_insert_origin = origin; } ///< Called automatically by JEvent::Insert() to records whether that call was made by a source or factory. + void SetFactory(JFactory* fac) { m_factory = fac; } + + // Templates + // + /// Generically access the encapsulated data, performing an upcast if necessary. This is useful for extracting data from + /// all JFactories where T extends a parent class S, such as JObject or TObject, in contexts where T is not known + /// or it would introduce an unwanted coupling. The main application is for building DSTs. + /// + /// Be aware of the following caveats: + /// - The factory's object type must not use virtual inheritance. + /// - If JFactory::Process hasn't already been called, this will return an empty vector. This will NOT call JFactory::Process. + /// - Someone must call JFactoryT::EnableGetAs, preferably the constructor. Otherwise, this will return an empty vector. + /// - If S isn't a base class of T, this will return an empty vector. + template + std::vector GetAs(); + +}; + + + +// Because C++ doesn't support templated virtual functions, we implement our own dispatch table, mUpcastVTable. +// This means that the JFactoryT is forced to manually populate this table by calling JFactoryT::EnableGetAs. +// We have the option to make the vtable be a static member of JFactoryT, but we have chosen not to because: +// +// 1. It would be inconsistent with the fact that the user is supposed to call EnableGetAs in the ctor +// 2. People in the future may want to generalize GetAs to support user-defined S* -> T* conversions (which I don't recommend) +// 3. The size of the vtable is expected to be very small (<10 elements, most likely 2) + +template +std::vector JDataBundle::GetAs() { + std::vector results; + auto ti = std::type_index(typeid(S)); + auto search = mUpcastVTable.find(ti); + if (search != mUpcastVTable.end()) { + using upcast_fn_t = std::function()>; + auto temp = static_cast*>(&(*search->second)); + upcast_fn_t upcast_fn = temp->t; + results = upcast_fn(); + } + return results; +} + + diff --git a/src/libraries/JANA/Components/JHasFactoryOutputs.h b/src/libraries/JANA/Components/JHasFactoryOutputs.h new file mode 100644 index 000000000..c266ea773 --- /dev/null +++ b/src/libraries/JANA/Components/JHasFactoryOutputs.h @@ -0,0 +1,40 @@ + +#pragma once +#include +#include + +class JEvent; + +namespace jana::components { + + +class JHasFactoryOutputs { +public: + struct OutputBase { + protected: + std::vector> m_databundles; + bool m_is_variadic = false; + public: + const std::vector>& GetDataBundles() const { return m_databundles; } + virtual void StoreData(const JEvent& event) = 0; + virtual void Reset() = 0; + }; + +private: + std::vector m_outputs; + +public: + const std::vector& GetOutputs() const { + return m_outputs; + } + + void RegisterOutput(OutputBase* output) { + m_outputs.push_back(output); + } +}; + + + +} // namespace jana::components + + diff --git a/src/libraries/JANA/Components/JPodioDataBundle.h b/src/libraries/JANA/Components/JPodioDataBundle.h new file mode 100644 index 000000000..d27e20de7 --- /dev/null +++ b/src/libraries/JANA/Components/JPodioDataBundle.h @@ -0,0 +1,42 @@ +// Copyright 2024, Jefferson Science Associates, LLC. +// Subject to the terms in the LICENSE file found in the top-level directory. + +#pragma once + +#include +#include +#include + + +class JPodioDataBundle : public JDataBundle { + +private: + const podio::CollectionBase* m_collection = nullptr; + +public: + size_t GetSize() const override { + if (m_collection == nullptr) { + return 0; + } + return m_collection->size(); + } + + virtual void ClearData() override { + m_collection = nullptr; + SetStatus(JDataBundle::Status::Empty); + // Podio clears the data itself when the frame is destroyed. + // Until then, the collection is immutable. + // + // Consider: Instead of putting the frame in its own JFactory, maybe we + // want to maintain a shared_ptr to the frame here, and delete the + // the reference on ClearData(). Thus, the final call to ClearData() + // for each events deletes the frame and actually frees the data. + // This would let us support multiple frames within one event, though + // it might also prevent the user from accessing frames directly. + } + + const podio::CollectionBase* GetCollection() const { return m_collection; } + void SetCollection(const podio::CollectionBase* collection) { m_collection = collection; } +}; + + diff --git a/src/libraries/JANA/Components/JPodioOutput.h b/src/libraries/JANA/Components/JPodioOutput.h new file mode 100644 index 000000000..de4d6d476 --- /dev/null +++ b/src/libraries/JANA/Components/JPodioOutput.h @@ -0,0 +1,122 @@ +#pragma once +#include +#include +#include +#include +#include +#include + + +namespace jana::components { + + +template +class PodioOutput : public JHasFactoryOutputs::OutputBase { +private: + std::unique_ptr m_transient_collection; + JPodioDataBundle* m_podio_databundle; +public: + PodioOutput(JHasFactoryOutputs* owner, std::string default_collection_name="") { + owner->RegisterOutput(this); + auto bundle = std::make_unique(); + bundle->SetUniqueName(default_collection_name); + bundle->SetTypeName(JTypeInfo::demangle()); + m_podio_databundle = bundle.get(); + m_databundles.push_back(std::move(bundle)); + m_transient_collection = std::move(std::make_unique()); + } + + std::unique_ptr& operator()() { return m_transient_collection; } + + JPodioDataBundle* GetDataBundle() const { return m_podio_databundle; } + + +protected: + + void StoreData(const JEvent& event) override { + podio::Frame* frame; + try { + frame = const_cast(event.GetSingle()); + if (frame == nullptr) { + frame = new podio::Frame; + event.Insert(frame); + } + } + catch (...) { + frame = new podio::Frame; + event.Insert(frame); + } + + frame->put(std::move(m_transient_collection), m_podio_databundle->GetUniqueName()); + const auto* moved = &frame->template get(m_podio_databundle->GetUniqueName()); + m_transient_collection = nullptr; + m_podio_databundle->SetCollection(moved); + } + void Reset() override { + m_transient_collection = std::move(std::make_unique()); + } +}; + + +template +class VariadicPodioOutput : public JHasFactoryOutputs::OutputBase { +private: + std::vector> m_collections; + std::vector m_databundles; + +public: + VariadicPodioOutput(JHasFactoryOutputs* owner, std::vector default_collection_names={}) { + owner->RegisterOutput(this); + this->m_is_variadic = true; + for (const std::string& name : default_collection_names) { + auto coll = std::make_unique(); + coll->SetUniqueName(name); + coll->SetTypeName(JTypeInfo::demangle()); + m_collections.push_back(std::move(coll)); + } + for (auto& coll_name : this->collection_names) { + m_collections.push_back(std::make_unique()); + } + } + void StoreData(const JEvent& event) override { + if (m_collections.size() != this->collection_names.size()) { + throw JException("VariadicPodioOutput InsertCollection failed: Declared %d collections, but provided %d.", this->collection_names.size(), m_collections.size()); + } + + podio::Frame* frame; + try { + frame = const_cast(event.GetSingle()); + if (frame == nullptr) { + frame = new podio::Frame; + event.Insert(frame); + } + } + catch (...) { + frame = new podio::Frame; + event.Insert(frame); + } + + size_t i = 0; + for (auto& collection : m_collections) { + frame->put(std::move(std::move(collection)), m_databundles[i]->GetUniqueName()); + const auto* moved = &frame->template get(m_databundles[i]->GetUniqueName()); + collection = nullptr; + const auto &databundle = dynamic_cast(m_databundles[i]); + databundle->SetCollection(moved); + i += 1; + } + } + void Reset() override { + m_collections.clear(); + for (auto& coll : this->m_databundles) { + coll->ClearData(); + } + for (auto& coll_name : this->collection_names) { + m_collections.push_back(std::make_unique()); + } + } +}; + + +} // namespace jana::components + diff --git a/src/libraries/JANA/JEvent.h b/src/libraries/JANA/JEvent.h index e60e7cbf7..f4699be3e 100644 --- a/src/libraries/JANA/JEvent.h +++ b/src/libraries/JANA/JEvent.h @@ -9,9 +9,8 @@ #include #include #include - #include - +#include #include #include #include @@ -19,18 +18,15 @@ #include #include +#include #include #include #include -#include #include -#include #if JANA2_HAVE_PODIO -#include -namespace podio { -class CollectionBase; -} +#include +#include #endif class JApplication; @@ -53,19 +49,6 @@ class JEvent : public std::enable_shared_from_this void SetFactorySet(JFactorySet* aFactorySet) { delete mFactorySet; mFactorySet = aFactorySet; -#if JANA2_HAVE_PODIO - // Maintain the index of PODIO factories - for (JFactory* factory : mFactorySet->GetAllFactories()) { - if (dynamic_cast(factory) != nullptr) { - auto tag = factory->GetTag(); - auto it = mPodioFactories.find(tag); - if (it != mPodioFactories.end()) { - throw JException("SetFactorySet failed because PODIO factory tag '%s' is not unique", tag.c_str()); - } - mPodioFactories[tag] = factory; - } - } -#endif } JFactorySet* GetFactorySet() const { return mFactorySet; } @@ -102,11 +85,14 @@ class JEvent : public std::enable_shared_from_this #if JANA2_HAVE_PODIO std::vector GetAllCollectionNames() const; const podio::CollectionBase* GetCollectionBase(std::string name, bool throw_on_missing=true) const; - template const typename JFactoryPodioT::CollectionT* GetCollection(std::string name, bool throw_on_missing=true) const; - template JFactoryPodioT* InsertCollection(typename JFactoryPodioT::CollectionT&& collection, std::string name); - template JFactoryPodioT* InsertCollectionAlreadyInFrame(const podio::CollectionBase* collection, std::string name); + template const typename PodioT::collection_type* GetCollection(std::string name, bool throw_on_missing=true) const; + template JPodioDataBundle* InsertCollection(typename PodioT::collection_type&& collection, std::string name); + template JPodioDataBundle* InsertCollectionAlreadyInFrame(const podio::CollectionBase* collection, std::string name); #endif + // EXPERIMENTAL NEW THING + JDataBundle* GetDataBundle(const std::string& name, bool create) const; + //SETTERS void SetRunNumber(int32_t aRunNumber){mRunNumber = aRunNumber;} void SetEventNumber(uint64_t aEventNumber){mEventNumber = aEventNumber;} @@ -211,10 +197,6 @@ class JEvent : public std::enable_shared_from_this int64_t mEventIndex = -1; - -#if JANA2_HAVE_PODIO - std::map mPodioFactories; -#endif }; /// Insert() allows an EventSource to insert items directly into the JEvent, @@ -224,51 +206,73 @@ class JEvent : public std::enable_shared_from_this template inline JFactoryT* JEvent::Insert(T* item, const std::string& tag) const { + std::string object_name = JTypeInfo::demangle(); + std::string resolved_tag = tag; if (mUseDefaultTags && tag.empty()) { - auto defaultTag = mDefaultTags.find(JTypeInfo::demangle()); + auto defaultTag = mDefaultTags.find(object_name); if (defaultTag != mDefaultTags.end()) resolved_tag = defaultTag->second; } - auto factory = mFactorySet->GetFactory(resolved_tag); - if (factory == nullptr) { - factory = new JFactoryT; - factory->SetTag(tag); - factory->SetLevel(mFactorySet->GetLevel()); - mFactorySet->Add(factory); + auto untyped_factory = mFactorySet->GetFactory(std::type_index(typeid(T)), object_name, resolved_tag); + JFactoryT* typed_factory; + if (untyped_factory == nullptr) { + typed_factory = new JFactoryT; + typed_factory->SetTag(tag); + typed_factory->SetLevel(mFactorySet->GetLevel()); + mFactorySet->Add(typed_factory); } - factory->Insert(item); - factory->SetInsertOrigin( mCallGraph.GetInsertDataOrigin() ); // (see note at top of JCallGraphRecorder.h) - return factory; + else { + typed_factory = dynamic_cast*>(untyped_factory); + if (typed_factory == nullptr) { + throw JException("Retrieved factory is not a JFactoryT!"); + } + } + typed_factory->Insert(item); + typed_factory->SetInsertOrigin( mCallGraph.GetInsertDataOrigin() ); // (see note at top of JCallGraphRecorder.h) + return typed_factory; } template inline JFactoryT* JEvent::Insert(const std::vector& items, const std::string& tag) const { + std::string object_name = JTypeInfo::demangle(); std::string resolved_tag = tag; if (mUseDefaultTags && tag.empty()) { - auto defaultTag = mDefaultTags.find(JTypeInfo::demangle()); + auto defaultTag = mDefaultTags.find(object_name); if (defaultTag != mDefaultTags.end()) resolved_tag = defaultTag->second; } - auto factory = mFactorySet->GetFactory(resolved_tag); - if (factory == nullptr) { - factory = new JFactoryT; - factory->SetTag(tag); - factory->SetLevel(mFactorySet->GetLevel()); - mFactorySet->Add(factory); + auto untyped_factory = mFactorySet->GetFactory(std::type_index(typeid(T)), object_name, resolved_tag); + JFactoryT* typed_factory; + if (untyped_factory == nullptr) { + typed_factory = new JFactoryT; + typed_factory->SetTag(tag); + typed_factory->SetLevel(mFactorySet->GetLevel()); + mFactorySet->Add(typed_factory); + } + else { + typed_factory = dynamic_cast*>(untyped_factory); + if (typed_factory == nullptr) { + throw JException("Retrieved factory is not a JFactoryT!"); + } } for (T* item : items) { - factory->Insert(item); + typed_factory->Insert(item); } - factory->SetStatus(JFactory::Status::Inserted); // for when items is empty - factory->SetCreationStatus(JFactory::CreationStatus::Inserted); // for when items is empty - factory->SetInsertOrigin( mCallGraph.GetInsertDataOrigin() ); // (see note at top of JCallGraphRecorder.h) - return factory; + typed_factory->SetStatus(JFactory::Status::Inserted); // for when items is empty + typed_factory->SetCreationStatus(JFactory::CreationStatus::Inserted); // for when items is empty + typed_factory->SetInsertOrigin( mCallGraph.GetInsertDataOrigin() ); // (see note at top of JCallGraphRecorder.h) + return typed_factory; } /// GetFactory() should be used with extreme care because it subverts the JEvent abstraction. /// Most historical uses of GetFactory are far better served by JMultifactory inline JFactory* JEvent::GetFactory(const std::string& object_name, const std::string& tag) const { - return mFactorySet->GetFactory(object_name, tag); + std::string resolved_tag = tag; + if (mUseDefaultTags && tag.empty()) { + auto defaultTag = mDefaultTags.find(object_name); + if (defaultTag != mDefaultTags.end()) resolved_tag = defaultTag->second; + } + return mFactorySet->GetFactory(object_name, resolved_tag); } /// GetAllFactories() should be used with extreme care because it subverts the JEvent abstraction. @@ -282,20 +286,27 @@ inline std::vector JEvent::GetAllFactories() const { template inline JFactoryT* JEvent::GetFactory(const std::string& tag, bool throw_on_missing) const { + std::string object_name = JTypeInfo::demangle(); std::string resolved_tag = tag; if (mUseDefaultTags && tag.empty()) { - auto defaultTag = mDefaultTags.find(JTypeInfo::demangle()); + auto defaultTag = mDefaultTags.find(object_name); if (defaultTag != mDefaultTags.end()) resolved_tag = defaultTag->second; } - auto factory = mFactorySet->GetFactory(resolved_tag); + auto factory = mFactorySet->GetFactory(std::type_index(typeid(T)), object_name, resolved_tag); if (factory == nullptr) { if (throw_on_missing) { - JException ex("Could not find JFactoryT<" + JTypeInfo::demangle() + "> with tag=" + tag); + JException ex("Could not find JFactory producing '%s' with tag '%s'", object_name.c_str(), resolved_tag.c_str()); ex.show_stacktrace = false; throw ex; } + return nullptr; }; - return factory; + auto typed_factory = dynamic_cast*>(factory); + if (typed_factory == nullptr) { + JException ex("JFactory producing '%s' with tag '%s' is not a JFactoryT!", object_name.c_str(), resolved_tag.c_str()); + throw ex; + }; + return typed_factory; } @@ -412,15 +423,23 @@ std::vector JEvent::Get(const std::string& tag, bool strict) const { /// wishes to examine them all together. template inline std::vector*> JEvent::GetFactoryAll(bool throw_on_missing) const { - auto factories = mFactorySet->GetAllFactories(); + std::vector*> results; + std::string object_name = JTypeInfo::demangle(); + auto factories = mFactorySet->GetAllFactories(std::type_index(typeid(T)), object_name); if (factories.size() == 0) { if (throw_on_missing) { - JException ex("Could not find any JFactoryT<" + JTypeInfo::demangle() + "> (from any tag)"); + JException ex("Could not find any JFactoryT<%s> (from any tag)", object_name.c_str()); ex.show_stacktrace = false; throw ex; } }; - return factories; + for (auto* fac : factories) { + auto fac_typed = dynamic_cast*>(fac); + if (fac_typed != nullptr) { + results.push_back(fac_typed); + } + } + return results; } /// GetAll returns all JObjects of (child) type T, regardless of tag. @@ -507,113 +526,139 @@ JFactoryT* JEvent::GetSingle(const T* &t, const char *tag, bool exception_if_ return fac; } +inline JDataBundle* JEvent::GetDataBundle(const std::string& name, bool create) const { + + auto* storage = mFactorySet->GetDataBundle(name); + if (storage == nullptr) return nullptr; + auto fac = storage->GetFactory(); + + if (fac != nullptr && create) { + + // The regenerate logic lives out here now + if ((storage->GetStatus() == JDataBundle::Status::Empty) || + (fac->TestFactoryFlag(JFactory::JFactory_Flags_t::REGENERATE))) { + + // If this was inserted, there would be no factory to run + // fac->Create() will short-circuit if something was already inserted + JCallGraphEntryMaker cg_entry(mCallGraph, fac); // times execution until this goes out of scope + fac->Create(this->shared_from_this()); + } + } + return storage; +} + #if JANA2_HAVE_PODIO inline std::vector JEvent::GetAllCollectionNames() const { - std::vector keys; - for (auto pair : mPodioFactories) { - keys.push_back(pair.first); - } - return keys; + return mFactorySet->GetAllDataBundleNames(); } inline const podio::CollectionBase* JEvent::GetCollectionBase(std::string name, bool throw_on_missing) const { - auto it = mPodioFactories.find(name); - if (it == mPodioFactories.end()) { - if (throw_on_missing) { - throw JException("No factory with tag '%s' found", name.c_str()); + auto* storage = GetDataBundle(name, true); + if (storage != nullptr) { + auto* podio_storage = dynamic_cast(storage); + if (podio_storage == nullptr) { + throw JException("Not a podio collection: %s", name.c_str()); } else { - return nullptr; + return podio_storage->GetCollection(); } } - JFactoryPodio* factory = dynamic_cast(it->second); - if (factory == nullptr) { - // Should be no way to get here if we encapsulate mPodioFactories correctly - throw JException("Factory with tag '%s' does not inherit from JFactoryPodio!", name.c_str()); + else if (throw_on_missing) { + throw JException("Collection not found: '%s'", name.c_str()); } - JCallGraphEntryMaker cg_entry(mCallGraph, it->second); // times execution until this goes out of scope - it->second->Create(this->shared_from_this()); - return factory->GetCollection(); - // TODO: Might be cheaper/simpler to obtain factory from mPodioFactories instead of mFactorySet + return nullptr; } template -const typename JFactoryPodioT::CollectionT* JEvent::GetCollection(std::string name, bool throw_on_missing) const { - JFactoryT* factory = GetFactory(name, throw_on_missing); - if (factory == nullptr) { - return nullptr; +const typename T::collection_type* JEvent::GetCollection(std::string name, bool throw_on_missing) const { + auto* coll = GetDataBundle(name, true); + if (coll != nullptr) { + auto* podio_coll = dynamic_cast(coll); + if (podio_coll == nullptr) { + throw JException("Not a podio collection: %s", name.c_str()); + } + else { + auto coll = podio_coll->GetCollection(); + auto typed_coll = dynamic_cast(coll); + if (typed_coll == nullptr) { + throw JException("Unable to cast Podio collection to %s", JTypeInfo::demangle().c_str()); + } + return typed_coll; + } } - JFactoryPodioT* typed_factory = dynamic_cast*>(factory); - if (typed_factory == nullptr) { - throw JException("Factory must inherit from JFactoryPodioT in order to use JEvent::GetCollection()"); + else if (throw_on_missing) { + throw JException("Collection not found: '%s'", name.c_str()); } - JCallGraphEntryMaker cg_entry(mCallGraph, typed_factory); // times execution until this goes out of scope - typed_factory->Create(this->shared_from_this()); - return static_cast::CollectionT*>(typed_factory->GetCollection()); + return nullptr; } -template -JFactoryPodioT* JEvent::InsertCollection(typename JFactoryPodioT::CollectionT&& collection, std::string name) { +template +JPodioDataBundle* JEvent::InsertCollection(typename PodioT::collection_type&& collection, std::string name) { /// InsertCollection inserts the provided PODIO collection into both the podio::Frame and then a JFactoryPodioT - auto frame = GetOrCreateFrame(shared_from_this()); + podio::Frame* frame = nullptr; + try { + frame = const_cast(GetSingle("")); + if (frame == nullptr) { + frame = new podio::Frame; + Insert(frame); + } + } + catch (...) { + frame = new podio::Frame; + Insert(frame); + } const auto& owned_collection = frame->put(std::move(collection), name); - return InsertCollectionAlreadyInFrame(&owned_collection, name); + return InsertCollectionAlreadyInFrame(&owned_collection, name); } -template -JFactoryPodioT* JEvent::InsertCollectionAlreadyInFrame(const podio::CollectionBase* collection, std::string name) { - /// InsertCollection inserts the provided PODIO collection into a JFactoryPodioT. It assumes that the collection pointer +template +JPodioDataBundle* JEvent::InsertCollectionAlreadyInFrame(const podio::CollectionBase* collection, std::string name) { + /// InsertCollection inserts the provided PODIO collection into a JPodioStorage. It assumes that the collection pointer /// is _already_ owned by the podio::Frame corresponding to this JEvent. This is meant to be used if you are starting out /// with a PODIO frame (e.g. a JEventSource that uses podio::ROOTFrameReader). - const auto* typed_collection = dynamic_cast(collection); + const auto* typed_collection = dynamic_cast(collection); if (typed_collection == nullptr) { throw JException("Attempted to insert a collection of the wrong type! name='%s', expected type='%s', actual type='%s'", - name.c_str(), JTypeInfo::demangle().c_str(), collection->getDataTypeName().data()); + name.c_str(), JTypeInfo::demangle().c_str(), collection->getDataTypeName().data()); } // Users are allowed to Insert with tag="" if and only if that tag gets resolved by default tags. if (mUseDefaultTags && name.empty()) { - auto defaultTag = mDefaultTags.find(JTypeInfo::demangle()); + auto defaultTag = mDefaultTags.find(JTypeInfo::demangle()); if (defaultTag != mDefaultTags.end()) name = defaultTag->second; } - // Retrieve factory if it already exists, else create it - JFactoryT* factory = mFactorySet->GetFactory(name); - if (factory == nullptr) { - factory = new JFactoryPodioT(); - factory->SetTag(name); - factory->SetLevel(GetLevel()); - mFactorySet->Add(factory); - - auto it = mPodioFactories.find(name); - if (it != mPodioFactories.end()) { - throw JException("InsertCollection failed because tag '%s' is not unique", name.c_str()); - } - mPodioFactories[name] = factory; + // Retrieve storage if it already exists, else create it + auto storage = mFactorySet->GetDataBundle(name); + + if (storage == nullptr) { + // No factories already registered this! E.g. from an event source + auto coll = new JPodioDataBundle; + coll->SetUniqueName(name); + coll->SetTypeName(JTypeInfo::demangle()); + coll->SetStatus(JDataBundle::Status::Inserted); + coll->SetInsertOrigin(mCallGraph.GetInsertDataOrigin()); + coll->SetCollection(typed_collection); + mFactorySet->Add(coll); + return coll; } - - // PODIO collections can only be inserted once, unlike regular JANA factories. - if (factory->GetStatus() == JFactory::Status::Inserted || - factory->GetStatus() == JFactory::Status::Processed) { - - throw JException("PODIO collections can only be inserted once, but factory with tag '%s' already has data", name.c_str()); - } - - // There's a chance that some user already added to the event's JFactorySet a - // JFactoryT which ISN'T a JFactoryPodioT. In this case, we cannot set the collection. - JFactoryPodioT* typed_factory = dynamic_cast*>(factory); - if (typed_factory == nullptr) { - throw JException("Factory must inherit from JFactoryPodioT in order to use JEvent::GetCollection()"); + else { + // This is overriding a factory + // Check that we only inserted this collection once + if (storage->GetStatus() != JDataBundle::Status::Empty) { + throw JException("Collections can only be inserted once!"); + } + auto typed_storage = dynamic_cast(storage); + typed_storage->SetCollection(typed_collection); + typed_storage->SetStatus(JDataBundle::Status::Inserted); + typed_storage->SetInsertOrigin(mCallGraph.GetInsertDataOrigin()); + return typed_storage; } - - typed_factory->SetCollectionAlreadyInFrame(typed_collection); - typed_factory->SetInsertOrigin( mCallGraph.GetInsertDataOrigin() ); - return typed_factory; } #endif // JANA2_HAVE_PODIO diff --git a/src/libraries/JANA/JFactory.cc b/src/libraries/JANA/JFactory.cc index 5a1dc7b0f..759591121 100644 --- a/src/libraries/JANA/JFactory.cc +++ b/src/libraries/JANA/JFactory.cc @@ -53,6 +53,9 @@ void JFactory::Create(const std::shared_ptr& event) { mPreviousRunNumber = run_number; } CallWithJExceptionWrapper("JFactory::Process", [&](){ Process(event); }); + for (auto& output : this->GetOutputs()) { + output->StoreData(*event); + } mStatus = Status::Processed; mCreationStatus = CreationStatus::Created; } diff --git a/src/libraries/JANA/JFactory.h b/src/libraries/JANA/JFactory.h index 823c324e7..5de3aee08 100644 --- a/src/libraries/JANA/JFactory.h +++ b/src/libraries/JANA/JFactory.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -22,12 +23,12 @@ class JEvent; class JObject; class JApplication; -class JFactory : public jana::components::JComponent { +class JFactory : public jana::components::JComponent, + public jana::components::JHasFactoryOutputs { public: enum class Status {Uninitialized, Unprocessed, Processed, Inserted}; enum class CreationStatus { NotCreatedYet, Created, Inserted, InsertedViaGetObjects, NeverCreated }; - enum JFactory_Flags_t { JFACTORY_NULL = 0x00, // Not used anywhere PERSISTENT = 0x01, // Used heavily. Possibly better served by JServices, hierarchical events, or event groups. @@ -36,10 +37,27 @@ class JFactory : public jana::components::JComponent { REGENERATE = 0x08 // Replaces JANA1 JFactory_base::use_factory and JFactory::GetCheckSourceFirst() }; +protected: + std::string mObjectName; + std::string mTag; + uint32_t mFlags = WRITE_TO_OUTPUT; + int32_t mPreviousRunNumber = -1; + std::unordered_map> mUpcastVTable; + + mutable Status mStatus = Status::Uninitialized; + CreationStatus mCreationStatus = CreationStatus::NotCreatedYet; + mutable JCallGraphRecorder::JDataOrigin m_insert_origin = JCallGraphRecorder::ORIGIN_NOT_AVAILABLE; // (see note at top of JCallGraphRecorder.h) + + +public: + JFactory() : mStatus(Status::Uninitialized) { + } + JFactory(std::string aName, std::string aTag = "") : mObjectName(std::move(aName)), mTag(std::move(aTag)), mStatus(Status::Uninitialized) { + SetTypeName(mObjectName); SetPrefix(aTag.empty() ? mObjectName : mObjectName + ":" + mTag); }; @@ -135,9 +153,13 @@ class JFactory : public jana::components::JComponent { } // Overloaded by JFactoryT - virtual std::type_index GetObjectType() const = 0; - virtual void ClearData() = 0; + virtual void ClearData() { + if (mStatus == Status::Processed) { + mStatus = Status::Unprocessed; + mCreationStatus = CreationStatus::NotCreatedYet; + } + }; // Overloaded by user Factories @@ -148,8 +170,10 @@ class JFactory : public jana::components::JComponent { virtual void Process(const std::shared_ptr&) {} virtual void Finish() {} + virtual std::type_index GetObjectType() const { throw JException("GetObjectType not supported for non-JFactoryT's"); } + virtual std::size_t GetNumObjects() const { - return 0; + throw JException("Not implemented!"); } @@ -173,23 +197,15 @@ class JFactory : public jana::components::JComponent { void DoInit(); void Summarize(JComponentSummary& summary) const override; + virtual void Set(const std::vector &) { + throw JException("Not implemented!"); + }; - virtual void Set(const std::vector &data) = 0; - virtual void Insert(JObject *data) = 0; - - -protected: - - std::string mObjectName; - std::string mTag; - uint32_t mFlags = WRITE_TO_OUTPUT; - int32_t mPreviousRunNumber = -1; - std::unordered_map> mUpcastVTable; + virtual void Insert(JObject*) { + throw JException("Not implemented!"); + }; - mutable Status mStatus = Status::Uninitialized; - mutable JCallGraphRecorder::JDataOrigin m_insert_origin = JCallGraphRecorder::ORIGIN_NOT_AVAILABLE; // (see note at top of JCallGraphRecorder.h) - CreationStatus mCreationStatus = CreationStatus::NotCreatedYet; }; // Because C++ doesn't support templated virtual functions, we implement our own dispatch table, mUpcastVTable. diff --git a/src/libraries/JANA/JFactoryGenerator.h b/src/libraries/JANA/JFactoryGenerator.h index e774801ee..62eb6eec2 100644 --- a/src/libraries/JANA/JFactoryGenerator.h +++ b/src/libraries/JANA/JFactoryGenerator.h @@ -4,8 +4,9 @@ #pragma once -#include #include +#include + class JApplication; diff --git a/src/libraries/JANA/JFactorySet.cc b/src/libraries/JANA/JFactorySet.cc index 3b29258dc..8f6053c71 100644 --- a/src/libraries/JANA/JFactorySet.cc +++ b/src/libraries/JANA/JFactorySet.cc @@ -3,8 +3,10 @@ // Subject to the terms in the LICENSE file found in the top-level directory. #include -#include +#include +#include +#include #include "JFactorySet.h" #include "JFactory.h" #include "JMultifactory.h" @@ -38,10 +40,49 @@ JFactorySet::~JFactorySet() /// The only time mIsFactoryOwner should/can be set false is when a JMultifactory is using a JFactorySet internally /// to manage its JMultifactoryHelpers. if (mIsFactoryOwner) { - for (auto& f : mFactories) delete f.second; + for (auto& s : mDataBundlesFromName) { + // Only delete _inserted_ collections. Collections are otherwise owned by their factories + if (s.second->GetFactory() == nullptr) { + delete s.second; + } + } + for (auto f : mAllFactories) delete f; + // Now that the factories are deleted, nothing can call the multifactories so it is safe to delete them as well + + for (auto* mf : mMultifactories) { delete mf; } + } +} + +//--------------------------------- +// Add +//--------------------------------- +void JFactorySet::Add(JDataBundle* databundle) { + + if (databundle->GetUniqueName().empty()) { + throw JException("Attempted to add a databundle with no unique_name"); + } + auto named_result = mDataBundlesFromName.find(databundle->GetUniqueName()); + if (named_result != std::end(mDataBundlesFromName)) { + // Collection is duplicate. Since this almost certainly indicates a user error, and + // the caller will not be able to do anything about it anyway, throw an exception. + // We show the user which factory is causing this problem, including both plugin names + + auto ex = JException("Attempted to add duplicate databundles"); + ex.function_name = "JFactorySet::Add"; + ex.instance_name = databundle->GetUniqueName(); + + auto fac = databundle->GetFactory(); + if (fac != nullptr) { + ex.type_name = fac->GetTypeName(); + ex.plugin_name = fac->GetPluginName(); + if (named_result->second->GetFactory() != nullptr) { + ex.plugin_name += ", " + named_result->second->GetFactory()->GetPluginName(); + } + } + throw ex; } - // Now that the factories are deleted, nothing can call the multifactories so it is safe to delete them as well - for (auto* mf : mMultifactories) { delete mf; } + // Note that this is agnostic to event level. We may decide to change this. + mDataBundlesFromName[databundle->GetUniqueName()] = databundle; } //--------------------------------- @@ -54,34 +95,56 @@ bool JFactorySet::Add(JFactory* aFactory) /// throw an exception and let the user figure out what to do. /// This scenario occurs when the user has multiple JFactory producing the /// same T JObject, and is not distinguishing between them via tags. + + // There are two different ways JFactories can work now. In the old way, JFactory must be + // a JFactoryT, and have exactly one output collection. In the new way (which includes JFactoryPodioT), + // JFactory has an arbitrary number of output collections which are explicitly + // represented, similar to but better than JMultifactory. We distinguish between + // these two cases by checking whether JFactory::GetObjectType returns an object type vs nullopt. - auto typed_key = std::make_pair( aFactory->GetObjectType(), aFactory->GetTag() ); - auto untyped_key = std::make_pair( aFactory->GetObjectName(), aFactory->GetTag() ); - auto typed_result = mFactories.find(typed_key); - auto untyped_result = mFactoriesFromString.find(untyped_key); + mAllFactories.push_back(aFactory); - if (typed_result != std::end(mFactories) || untyped_result != std::end(mFactoriesFromString)) { - // Factory is duplicate. Since this almost certainly indicates a user error, and - // the caller will not be able to do anything about it anyway, throw an exception. - // We show the user which factory is causing this problem, including both plugin names - std::string other_plugin_name; - if (typed_result != std::end(mFactories)) { - other_plugin_name = typed_result->second->GetPluginName(); + if (aFactory->GetOutputs().empty()) { + // We have an old-style JFactory! + + auto typed_key = std::make_pair( aFactory->GetObjectType(), aFactory->GetTag() ); + auto untyped_key = std::make_pair( aFactory->GetObjectName(), aFactory->GetTag() ); + + auto typed_result = mFactories.find(typed_key); + auto untyped_result = mFactoriesFromString.find(untyped_key); + + if (typed_result != std::end(mFactories) || untyped_result != std::end(mFactoriesFromString)) { + // Factory is duplicate. Since this almost certainly indicates a user error, and + // the caller will not be able to do anything about it anyway, throw an exception. + // We show the user which factory is causing this problem, including both plugin names + std::string other_plugin_name; + if (typed_result != std::end(mFactories)) { + other_plugin_name = typed_result->second->GetPluginName(); + } + else { + other_plugin_name = untyped_result->second->GetPluginName(); + } + auto ex = JException("Attempted to add duplicate factories"); + ex.function_name = "JFactorySet::Add"; + ex.instance_name = aFactory->GetPrefix(); + ex.type_name = aFactory->GetTypeName(); + ex.plugin_name = aFactory->GetPluginName() + ", " + other_plugin_name; + throw ex; } - else { - other_plugin_name = untyped_result->second->GetPluginName(); + + mFactories[typed_key] = aFactory; + mFactoriesFromString[untyped_key] = aFactory; + } + else { + // We have a new-style JFactory! + for (const auto* output : aFactory->GetOutputs()) { + for (const auto& bundle : output->GetDataBundles()) { + bundle->SetFactory(aFactory); + Add(bundle.get()); + } } - auto ex = JException("Attempted to add duplicate factories"); - ex.function_name = "JFactorySet::Add"; - ex.instance_name = aFactory->GetPrefix(); - ex.type_name = aFactory->GetTypeName(); - ex.plugin_name = aFactory->GetPluginName() + ", " + other_plugin_name; - throw ex; } - - mFactories[typed_key] = aFactory; - mFactoriesFromString[untyped_key] = aFactory; return true; } @@ -103,6 +166,22 @@ bool JFactorySet::Add(JMultifactory *multifactory) { return true; } +//--------------------------------- +// GetDataBundle +//--------------------------------- +JDataBundle* JFactorySet::GetDataBundle(const std::string& name) const { + auto it = mDataBundlesFromName.find(name); + if (it != std::end(mDataBundlesFromName)) { + auto fac = it->second->GetFactory(); + if (fac != nullptr && fac->GetLevel() != mLevel) { + throw JException("Data bundle belongs to a different level on the event hierarchy!"); + } + return it->second; + } + return nullptr; +} + + //--------------------------------- // GetFactory //--------------------------------- @@ -119,13 +198,50 @@ JFactory* JFactorySet::GetFactory(const std::string& object_name, const std::str return nullptr; } +//--------------------------------- +// GetFactory +//--------------------------------- +JFactory* JFactorySet::GetFactory(std::type_index object_type, const std::string& object_name, const std::string& tag) const { + + auto typed_key = std::make_pair(object_type, tag); + auto typed_iter = mFactories.find(typed_key); + if (typed_iter != std::end(mFactories)) { + JEventLevel found_level = typed_iter->second->GetLevel(); + if (found_level != mLevel) { + throw JException("Factory belongs to a different level on the event hierarchy. Expected: %s, Found: %s", toString(mLevel).c_str(), toString(found_level).c_str()); + } + return typed_iter->second; + } + return GetFactory(object_name, tag); +} + //--------------------------------- // GetAllFactories //--------------------------------- std::vector JFactorySet::GetAllFactories() const { + + // This returns both old-style (JFactoryT) and new-style (JFactory+JStorage) factories, unlike + // GetAllFactories(object_type, object_name) below. This is because we use this method in + // JEventProcessors to activate factories in a generic way, particularly when working with Podio data. + + return mAllFactories; +} + +//--------------------------------- +// GetAllFactories +//--------------------------------- +std::vector JFactorySet::GetAllFactories(std::type_index object_type, const std::string& object_name) const { + + // This returns all factories which _directly_ produce objects of type object_type, i.e. they don't use a JStorage. + // This is what all of its callers already expect anyhow. Obviously we'd like to migrate everything over to JStorage + // eventually. Rather than updating this to also check mDataBundlesFromName, it probably makes more sense to create + // a JFactorySet::GetAllStorages(type_index, object_name) instead, and migrate all callers to use that. + std::vector results; - for (auto p : mFactories) { - results.push_back(p.second); + for (auto& it : mFactories) { + if (it.second->GetObjectType() == object_type || it.second->GetObjectName() == object_name) { + results.push_back(it.second); + } } return results; } @@ -134,9 +250,16 @@ std::vector JFactorySet::GetAllFactories() const { // GetAllMultifactories //--------------------------------- std::vector JFactorySet::GetAllMultifactories() const { - std::vector results; - for (auto f : mMultifactories) { - results.push_back(f); + return mMultifactories; +} + +//--------------------------------- +// GetAllDataBundleNames +//--------------------------------- +std::vector JFactorySet::GetAllDataBundleNames() const { + std::vector results; + for (const auto& it : mDataBundlesFromName) { + results.push_back(it.first); } return results; } @@ -167,10 +290,15 @@ void JFactorySet::Print() const /// Release() loops over all contained factories, clearing their data void JFactorySet::Release() { - - for (const auto& sFactoryPair : mFactories) { - auto sFactory = sFactoryPair.second; - sFactory->ClearData(); + for (auto* fac : mAllFactories) { + fac->ClearData(); + } + for (auto& it : mDataBundlesFromName) { + // fac->ClearData() only clears JFactoryT's, because that's how it always worked. + // Clearing is fundamentally an operation on the data bundle, not on the factory itself. + // Furthermore, "clearing" the factory is misleading because factories can cache arbitrary + // state inside member variables, and there's no way to clear that. + it.second->ClearData(); } } diff --git a/src/libraries/JANA/JFactorySet.h b/src/libraries/JANA/JFactorySet.h index 25ebc50c0..46b58d178 100644 --- a/src/libraries/JANA/JFactorySet.h +++ b/src/libraries/JANA/JFactorySet.h @@ -4,88 +4,63 @@ #pragma once +#include +#include + +#include #include #include #include -#include -#include -#include - class JFactoryGenerator; class JFactory; class JMultifactory; +class JDataBundle; class JFactorySet { - public: - JFactorySet(); - JFactorySet(const std::vector& aFactoryGenerators); - virtual ~JFactorySet(); - - bool Add(JFactory* aFactory); - bool Add(JMultifactory* multifactory); - void Print(void) const; - void Release(void); - - JFactory* GetFactory(const std::string& object_name, const std::string& tag="") const; - template JFactoryT* GetFactory(const std::string& tag = "") const; - std::vector GetAllFactories() const; - std::vector GetAllMultifactories() const; - template std::vector*> GetAllFactories() const; - - JEventLevel GetLevel() const { return mLevel; } - void SetLevel(JEventLevel level) { mLevel = level; } - - protected: - std::map, JFactory*> mFactories; // {(typeid, tag) : factory} - std::map, JFactory*> mFactoriesFromString; // {(objname, tag) : factory} - std::vector mMultifactories; - bool mIsFactoryOwner = true; - JEventLevel mLevel = JEventLevel::PhysicsEvent; - -}; - - -template -JFactoryT* JFactorySet::GetFactory(const std::string& tag) const { - - auto typed_key = std::make_pair(std::type_index(typeid(T)), tag); - auto typed_iter = mFactories.find(typed_key); - if (typed_iter != std::end(mFactories)) { - JEventLevel found_level = typed_iter->second->GetLevel(); - if (found_level != mLevel) { - throw JException("Factory belongs to a different level on the event hierarchy. Expected: %s, Found: %s", toString(mLevel).c_str(), toString(found_level).c_str()); - } - return static_cast*>(typed_iter->second); - } - - auto untyped_key = std::make_pair(JTypeInfo::demangle(), tag); - auto untyped_iter = mFactoriesFromString.find(untyped_key); - if (untyped_iter != std::end(mFactoriesFromString)) { - JEventLevel found_level = untyped_iter->second->GetLevel(); - if (found_level != mLevel) { - throw JException("Factory belongs to a different level on the event hierarchy. Expected: %s, Found: %s", toString(mLevel).c_str(), toString(found_level).c_str()); - } - return static_cast*>(untyped_iter->second); - } - return nullptr; -} - -template -std::vector*> JFactorySet::GetAllFactories() const { - auto sKey = std::type_index(typeid(T)); - std::vector*> data; - for (auto it=std::begin(mFactories);it!=std::end(mFactories);it++){ - if (it->first.first==sKey){ - if (it->second->GetLevel() == mLevel) { - data.push_back(static_cast*>(it->second)); - } - } +private: + std::vector mAllFactories; + std::map, JFactory*> mFactories; // {(typeid, tag) : factory} + std::map, JFactory*> mFactoriesFromString; // {(objname, tag) : factory} + std::map mDataBundlesFromName; + std::vector mMultifactories; + bool mIsFactoryOwner = true; + JEventLevel mLevel = JEventLevel::PhysicsEvent; + +public: + JFactorySet(); + JFactorySet(const std::vector& aFactoryGenerators); + virtual ~JFactorySet(); + + bool Add(JFactory* aFactory); + bool Add(JMultifactory* multifactory); + void Add(JDataBundle* storage); + void Print() const; + void Release(); + + std::vector GetAllDataBundleNames() const; + JDataBundle* GetDataBundle(const std::string& collection_name) const; + + JFactory* GetFactory(const std::string& object_name, const std::string& tag="") const; + JFactory* GetFactory(std::type_index object_type, const std::string& object_name, const std::string& tag = "") const; + + std::vector GetAllFactories() const; + std::vector GetAllFactories(std::type_index object_type, const std::string& object_name) const; + + std::vector GetAllMultifactories() const; + + JEventLevel GetLevel() const { return mLevel; } + void SetLevel(JEventLevel level) { mLevel = level; } + + template + [[deprecated]] + JFactory* GetFactory(const std::string& tag) const { + auto object_name = JTypeInfo::demangle(); + return GetFactory(object_name, tag); } - return data; -} +}; diff --git a/src/libraries/JANA/JMultifactory.cc b/src/libraries/JANA/JMultifactory.cc index e6da04fc1..389085792 100644 --- a/src/libraries/JANA/JMultifactory.cc +++ b/src/libraries/JANA/JMultifactory.cc @@ -10,11 +10,6 @@ void JMultifactory::Execute(const std::shared_ptr& event) { std::lock_guard lock(m_mutex); -#if JANA2_HAVE_PODIO - if (mNeedPodio) { - mPodioFrame = GetOrCreateFrame(event); - } -#endif if (m_status == Status::Uninitialized) { CallWithJExceptionWrapper("JMultifactory::Init", [&](){ diff --git a/src/libraries/JANA/JMultifactory.h b/src/libraries/JANA/JMultifactory.h index d41dc1442..763fca62e 100644 --- a/src/libraries/JANA/JMultifactory.h +++ b/src/libraries/JANA/JMultifactory.h @@ -10,9 +10,10 @@ #include #include #include +#include #if JANA2_HAVE_PODIO -#include "JANA/Podio/JFactoryPodioT.h" +#include #endif class JMultifactory; @@ -40,14 +41,18 @@ class JMultifactoryHelper : public JFactoryT{ #if JANA2_HAVE_PODIO -// TODO: This redundancy goes away if we merge JFactoryPodioT with JFactoryT template -class JMultifactoryHelperPodio : public JFactoryPodioT{ +class JMultifactoryHelperPodio : public JFactory { + jana::components::PodioOutput m_output {this}; JMultifactory* mMultiFactory; public: - JMultifactoryHelperPodio(JMultifactory* parent) : mMultiFactory(parent) {} + JMultifactoryHelperPodio(JMultifactory* parent, std::string collection_name) : mMultiFactory(parent) { + mObjectName = JTypeInfo::demangle(); + mTag = collection_name; + m_output.GetDataBundle()->SetUniqueName(collection_name); + } virtual ~JMultifactoryHelperPodio() = default; // This does NOT own mMultiFactory; the enclosing JFactorySet does @@ -61,6 +66,13 @@ class JMultifactoryHelperPodio : public JFactoryPodioT{ // Helpers do not produce any summary information void Summarize(JComponentSummary&) const override { } + + void SetCollection(typename T::collection_type&& collection) { + m_output() = std::make_unique(std::move(collection)); + } + void SetCollection(std::unique_ptr collection) { + m_output() = std::move(collection); + } }; #endif // JANA2_HAVE_PODIO @@ -74,10 +86,6 @@ class JMultifactory : public jana::components::JComponent, // However, don't worry about a Status variable. Every time Execute() gets called, so does Process(). // The JMultifactoryHelpers will control calls to Execute(). -#if JANA2_HAVE_PODIO - bool mNeedPodio = false; // Whether we need to retrieve the podio::Frame - podio::Frame* mPodioFrame = nullptr; // To provide the podio::Frame to SetPodioData, SetCollection -#endif public: JMultifactory() = default; @@ -107,10 +115,10 @@ class JMultifactory : public jana::components::JComponent, void DeclarePodioOutput(std::string tag, bool owns_data=true); template - void SetCollection(std::string tag, typename JFactoryPodioT::CollectionT&& collection); + void SetCollection(std::string tag, typename T::collection_type&& collection); template - void SetCollection(std::string tag, std::unique_ptr::CollectionT> collection); + void SetCollection(std::string tag, std::unique_ptr collection); #endif @@ -156,8 +164,8 @@ void JMultifactory::DeclareOutput(std::string tag, bool owns_data) { template void JMultifactory::SetData(std::string tag, std::vector data) { - JFactoryT* helper = mHelpers.GetFactory(tag); - if (helper == nullptr) { + JFactory* helper_untyped = mHelpers.GetFactory(std::type_index(typeid(T)), JTypeInfo::demangle(), tag); + if (helper_untyped == nullptr) { auto ex = JException("JMultifactory: Attempting to SetData() without corresponding DeclareOutput()"); ex.function_name = "JMultifactory::SetData"; ex.type_name = m_type_name; @@ -165,13 +173,8 @@ void JMultifactory::SetData(std::string tag, std::vector data) { ex.plugin_name = m_plugin_name; throw ex; } -#if JANA2_HAVE_PODIO - // This may or may not be a Podio factory. We find out if it is, and if so, set the frame before calling Set(). - auto* typed = dynamic_cast(helper); - if (typed != nullptr) { - typed->SetFrame(mPodioFrame); // Needs to be called before helper->Set(), otherwise Set() excepts - } -#endif + JFactoryT* helper = static_cast*>(helper_untyped); + // This will except if helper is a JMultifactoryHelperPodio. User should use SetPodioData() instead for PODIO data. helper->Set(data); } @@ -179,23 +182,19 @@ void JMultifactory::SetData(std::string tag, std::vector data) { #if JANA2_HAVE_PODIO template -void JMultifactory::DeclarePodioOutput(std::string tag, bool owns_data) { - // TODO: Decouple tag name from collection name - auto* helper = new JMultifactoryHelperPodio(this); - if (!owns_data) helper->SetSubsetCollection(true); - - helper->SetTag(std::move(tag)); +void JMultifactory::DeclarePodioOutput(std::string coll_name, bool) { + auto* helper = new JMultifactoryHelperPodio(this, coll_name); + helper->SetTag(std::move(coll_name)); helper->SetPluginName(m_plugin_name); helper->SetFactoryName(GetTypeName() + "::Helper<" + JTypeInfo::demangle() + ">"); helper->SetLevel(GetLevel()); mHelpers.SetLevel(GetLevel()); mHelpers.Add(helper); - mNeedPodio = true; } template -void JMultifactory::SetCollection(std::string tag, typename JFactoryPodioT::CollectionT&& collection) { - JFactoryT* helper = mHelpers.GetFactory(tag); +void JMultifactory::SetCollection(std::string name, typename T::collection_type&& collection) { + JFactory* helper = mHelpers.GetDataBundle(name)->GetFactory(); if (helper == nullptr) { auto ex = JException("JMultifactory: Attempting to SetData() without corresponding DeclareOutput()"); ex.function_name = "JMultifactory::SetCollection"; @@ -204,9 +203,9 @@ void JMultifactory::SetCollection(std::string tag, typename JFactoryPodioT::C ex.plugin_name = m_plugin_name; throw ex; } - auto* typed = dynamic_cast*>(helper); + auto* typed = dynamic_cast*>(helper); if (typed == nullptr) { - auto ex = JException("JMultifactory: Helper needs to be a JFactoryPodioT (this shouldn't be reachable)"); + auto ex = JException("JMultifactory: Helper needs to be a JMultifactoryHelperPodio (this shouldn't be reachable)"); ex.function_name = "JMultifactory::SetCollection"; ex.type_name = m_type_name; ex.instance_name = m_prefix; @@ -214,13 +213,12 @@ void JMultifactory::SetCollection(std::string tag, typename JFactoryPodioT::C throw ex; } - typed->SetFrame(mPodioFrame); typed->SetCollection(std::move(collection)); } template -void JMultifactory::SetCollection(std::string tag, std::unique_ptr::CollectionT> collection) { - JFactoryT* helper = mHelpers.GetFactory(tag); +void JMultifactory::SetCollection(std::string name, std::unique_ptr collection) { + JFactory* helper = mHelpers.GetDataBundle(name)->GetFactory(); if (helper == nullptr) { auto ex = JException("JMultifactory: Attempting to SetData() without corresponding DeclareOutput()"); ex.function_name = "JMultifactory::SetCollection"; @@ -229,9 +227,9 @@ void JMultifactory::SetCollection(std::string tag, std::unique_ptr*>(helper); + auto* typed = dynamic_cast*>(helper); if (typed == nullptr) { - auto ex = JException("JMultifactory: Helper needs to be a JFactoryPodioT (this shouldn't be reachable)"); + auto ex = JException("JMultifactory: Helper needs to be a JMultifactoryHelperPodio (this shouldn't be reachable)"); ex.function_name = "JMultifactory::SetCollection"; ex.type_name = m_type_name; ex.instance_name = m_prefix; @@ -239,7 +237,6 @@ void JMultifactory::SetCollection(std::string tag, std::unique_ptrSetFrame(mPodioFrame); typed->SetCollection(std::move(collection)); } diff --git a/src/libraries/JANA/Podio/JFactoryPodioT.cc b/src/libraries/JANA/Podio/JFactoryPodioT.cc deleted file mode 100644 index 506948502..000000000 --- a/src/libraries/JANA/Podio/JFactoryPodioT.cc +++ /dev/null @@ -1,23 +0,0 @@ - -// Copyright 2023, Jefferson Science Associates, LLC. -// Subject to the terms in the LICENSE file found in the top-level directory. - -#include "JFactoryPodioT.h" -#include - -podio::Frame* GetOrCreateFrame(const std::shared_ptr& event) { - podio::Frame* result = nullptr; - try { - result = const_cast(event->GetSingle("")); - if (result == nullptr) { - result = new podio::Frame; - event->Insert(result); - } - } - catch (...) { - result = new podio::Frame; - event->Insert(result); - } - return result; -} - diff --git a/src/libraries/JANA/Podio/JFactoryPodioT.h b/src/libraries/JANA/Podio/JFactoryPodioT.h index 9ff5a60ad..7db5ad5cf 100644 --- a/src/libraries/JANA/Podio/JFactoryPodioT.h +++ b/src/libraries/JANA/Podio/JFactoryPodioT.h @@ -5,47 +5,28 @@ #pragma once -#include -#include - -/// The point of this additional base class is to allow us _untyped_ access to the underlying PODIO collection, -/// at the cost of some weird multiple inheritance. The JEvent can trigger the untyped factory using Create(), then -/// -class JFactoryPodio { -protected: - const podio::CollectionBase* mCollection = nullptr; - bool mIsSubsetCollection = false; - podio::Frame* mFrame = nullptr; - -private: - // Meant to be called internally, from JMultifactory - friend class JMultifactory; - void SetFrame(podio::Frame* frame) { mFrame = frame; } - - // Meant to be called internally, from JEvent: - friend class JEvent; - const podio::CollectionBase* GetCollection() { return mCollection; } - -public: - // Meant to be called from ctor, or externally, if we are creating a dummy factory such as a multifactory helper - void SetSubsetCollection(bool isSubsetCollection=true) { mIsSubsetCollection = isSubsetCollection; } -}; +#include "JANA/Utils/JTypeInfo.h" +#include +#include template -class JFactoryPodioT : public JFactoryT, public JFactoryPodio { +class JFactoryPodioT : public JFactory { public: using CollectionT = typename T::collection_type; + private: - // mCollection is owned by the frame. - // mFrame is owned by the JFactoryT. - // mData holds lightweight value objects which hold a pointer into mCollection. - // This factory owns these value objects. + jana::components::PodioOutput m_output {this}; public: explicit JFactoryPodioT(); ~JFactoryPodioT() override; + void SetTag(std::string tag) { + mTag = tag; + m_output.GetDataBundle()->SetUniqueName(tag); + } + void Init() override {} void BeginRun(const std::shared_ptr&) override {} void ChangeRun(const std::shared_ptr&) override {} @@ -53,29 +34,17 @@ class JFactoryPodioT : public JFactoryT, public JFactoryPodio { void EndRun() override {} void Finish() override {} - void Create(const std::shared_ptr& event) final; - std::type_index GetObjectType() const final { return std::type_index(typeid(T)); } - std::size_t GetNumObjects() const final { return mCollection->size(); } - void ClearData() final; + std::size_t GetNumObjects() const final { return m_output.GetDataBundle()->GetSize(); } void SetCollection(CollectionT&& collection); void SetCollection(std::unique_ptr collection); - void Set(const std::vector& aData) final; - void Set(std::vector&& aData) final; - void Insert(T* aDatum) final; - - - -private: - // This is meant to be called by JEvent::Insert - friend class JEvent; - void SetCollectionAlreadyInFrame(const CollectionT* collection); - }; template -JFactoryPodioT::JFactoryPodioT() = default; +JFactoryPodioT::JFactoryPodioT() { + mObjectName = JTypeInfo::demangle(); +} template JFactoryPodioT::~JFactoryPodioT() { @@ -88,18 +57,7 @@ void JFactoryPodioT::SetCollection(CollectionT&& collection) { /// Provide a PODIO collection. Note that PODIO assumes ownership of this collection, and the /// collection pointer should be assumed to be invalid after this call - if (this->mFrame == nullptr) { - throw JException("JFactoryPodioT: Unable to add collection to frame as frame is missing!"); - } - const auto& moved = this->mFrame->put(std::move(collection), this->GetTag()); - this->mCollection = &moved; - - for (const T& item : moved) { - T* clone = new T(item); - this->mData.push_back(clone); - } - this->mStatus = JFactory::Status::Inserted; - this->mCreationStatus = JFactory::CreationStatus::Inserted; + m_output() = std::make_unique(std::move(collection)); } @@ -108,103 +66,8 @@ void JFactoryPodioT::SetCollection(std::unique_ptr collection) { /// Provide a PODIO collection. Note that PODIO assumes ownership of this collection, and the /// collection pointer should be assumed to be invalid after this call - if (this->mFrame == nullptr) { - throw JException("JFactoryPodioT: Unable to add collection to frame as frame is missing!"); - } - this->mFrame->put(std::move(collection), this->GetTag()); - const auto* moved = &this->mFrame->template get(this->GetTag()); - this->mCollection = moved; - - for (const T& item : *moved) { - T* clone = new T(item); - this->mData.push_back(clone); - } - this->mStatus = JFactory::Status::Inserted; - this->mCreationStatus = JFactory::CreationStatus::Inserted; + m_output() = std::move(collection); } -template -void JFactoryPodioT::ClearData() { - if (this->mStatus == JFactory::Status::Uninitialized) { - return; - } - for (auto p : this->mData) { - // Avoid potentially invalid call to ObjBase::release(). The frame and - // all the collections and all Obj may have been deallocated at this point. - p->unlink(); - delete p; - } - this->mData.clear(); - this->mCollection = nullptr; // Collection is owned by the Frame, so we ignore here - this->mFrame = nullptr; // Frame is owned by the JEvent, so we ignore here - this->mStatus = JFactory::Status::Unprocessed; - this->mCreationStatus = JFactory::CreationStatus::NotCreatedYet; -} - -template -void JFactoryPodioT::SetCollectionAlreadyInFrame(const CollectionT* collection) { - for (const T& item : *collection) { - T* clone = new T(item); - this->mData.push_back(clone); - } - this->mCollection = collection; - this->mStatus = JFactory::Status::Inserted; - this->mCreationStatus = JFactory::CreationStatus::Inserted; -} - -// This free function is used to break the dependency loop between JFactoryPodioT and JEvent. -podio::Frame* GetOrCreateFrame(const std::shared_ptr& event); - -template -void JFactoryPodioT::Create(const std::shared_ptr& event) { - mFrame = GetOrCreateFrame(event); - try { - JFactory::Create(event); - } - catch (...) { - if (mCollection == nullptr) { - // If calling Create() excepts, we still create an empty collection - // so that podio::ROOTFrameWriter doesn't segfault on the null mCollection pointer - SetCollection(CollectionT()); - } - throw; - } - if (mCollection == nullptr) { - SetCollection(CollectionT()); - // If calling Process() didn't result in a call to Set() or SetCollection(), we create an empty collection - // so that podio::ROOTFrameWriter doesn't segfault on the null mCollection pointer - } -} - -template -void JFactoryPodioT::Set(const std::vector& aData) { - CollectionT collection; - if (mIsSubsetCollection) collection.setSubsetCollection(true); - for (T* item : aData) { - collection.push_back(*item); - delete item; - } - SetCollection(std::move(collection)); -} - -template -void JFactoryPodioT::Set(std::vector&& aData) { - CollectionT collection; - if (mIsSubsetCollection) collection.setSubsetCollection(true); - for (T* item : aData) { - collection.push_back(*item); - delete item; - } - SetCollection(std::move(collection)); -} - -template -void JFactoryPodioT::Insert(T* aDatum) { - CollectionT collection; - if (mIsSubsetCollection) collection->setSubsetCollection(true); - collection->push_back(*aDatum); - delete aDatum; - SetCollection(std::move(collection)); -} diff --git a/src/libraries/JANA/Utils/JAny.h b/src/libraries/JANA/Utils/JAny.h index ad244afaf..68e343af1 100644 --- a/src/libraries/JANA/Utils/JAny.h +++ b/src/libraries/JANA/Utils/JAny.h @@ -3,6 +3,9 @@ // Subject to the terms in the LICENSE file found in the top-level directory. #pragma once +#include +#include +#include /// Ideally we'd just use std::any, but we are restricted to C++14 for the time being struct JAny { @@ -17,3 +20,68 @@ struct JAnyT : JAny { ~JAnyT() override = default; // deletes the t }; + +template +class JOptional { +private: + using StorageT = typename std::aligned_storage::type; + bool has_value; + StorageT storage; + +public: + JOptional() : has_value(false) {} + + JOptional(const T& val) : has_value(true) { + new (&storage) T(val); + } + + JOptional(T&& val) : has_value(true) { + new (&storage) T(std::move(val)); + } + + ~JOptional() { + reset(); + } + + // Checks if there is a value + bool hasValue() const { return has_value; } + + // Accesses the value, throws if not present + T& get() { + if (!has_value) { + throw std::runtime_error("No value present"); + } + return *reinterpret_cast(&storage); // Access without launder (C++14) + } + + const T& get() const { + if (!has_value) { + throw std::runtime_error("No value present"); + } + return *reinterpret_cast(&storage); // Access without launder (C++14) + } + + // Resets the optional (removes the value) + void reset() { + if (has_value) { + reinterpret_cast(&storage)->~T(); // Explicitly call destructor + has_value = false; + } + } + + // Set the value + void set(const T& val) { + reset(); + new (&storage) T(val); // Placement new + has_value = true; + } + + // Set using move semantics + void set(T&& val) { + reset(); + new (&storage) T(std::move(val)); // Placement new + has_value = true; + } +}; + + diff --git a/src/programs/unit_tests/CMakeLists.txt b/src/programs/unit_tests/CMakeLists.txt index 9b34afecd..6c6f8b662 100644 --- a/src/programs/unit_tests/CMakeLists.txt +++ b/src/programs/unit_tests/CMakeLists.txt @@ -44,7 +44,7 @@ set(TEST_SOURCES ) if (${USE_PODIO}) - list(APPEND TEST_SOURCES Components/PodioTests.cc) + list(APPEND TEST_SOURCES Components/PodioTests.cc Components/JStorageTests.cc) endif() add_jana_test(jana-unit-tests SOURCES ${TEST_SOURCES}) diff --git a/src/programs/unit_tests/Components/JComponentTests.cc b/src/programs/unit_tests/Components/JComponentTests.cc index 0f20cb6a8..fba6574f6 100644 --- a/src/programs/unit_tests/Components/JComponentTests.cc +++ b/src/programs/unit_tests/Components/JComponentTests.cc @@ -5,12 +5,13 @@ #include #include #include +#include namespace jana { template MultifactoryT* RetrieveMultifactory(JFactorySet* facset, std::string output_collection_name) { - auto fac = facset->GetFactory(output_collection_name); + auto fac = facset->GetFactory(std::type_index(typeid(OutputCollectionT)), JTypeInfo::demangle(), output_collection_name); REQUIRE(fac != nullptr); auto helper = dynamic_cast*>(fac); REQUIRE(helper != nullptr); diff --git a/src/programs/unit_tests/Components/JStorageTests.cc b/src/programs/unit_tests/Components/JStorageTests.cc new file mode 100644 index 000000000..8a0cec486 --- /dev/null +++ b/src/programs/unit_tests/Components/JStorageTests.cc @@ -0,0 +1,197 @@ + + +// Copyright 2023, Jefferson Science Associates, LLC. +// Subject to the terms in the LICENSE file found in the top-level directory. + +#include "catch.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace jstorage_tests { + +struct TestSource : public JEventSource { + Parameter y {this, "y", "asdf", "Does something"}; + TestSource() { + SetCallbackStyle(CallbackStyle::ExpertMode); + } + static std::string GetDescription() { + return "Test source"; + } + void Init() { + REQUIRE(y() == "asdf"); + } + Result Emit(JEvent&) { + return Result::Success; + } +}; + +struct TestFactory : public JFactory { + jana::components::PodioOutput m_clusters {this, "my_collection"}; + + TestFactory() { + SetCallbackStyle(CallbackStyle::ExpertMode); + } + void Init() { + } + void Process(const std::shared_ptr&) { + LOG_WARN(GetLogger()) << "Calling TestFactory::Process" << LOG_END; + m_clusters()->push_back(MutableExampleCluster(22.2)); + m_clusters()->push_back(MutableExampleCluster(27)); + } +}; + +struct RegeneratingTestFactory : public JFactory { + + jana::components::PodioOutput m_clusters {this, "my_collection"}; + + RegeneratingTestFactory() { + SetCallbackStyle(CallbackStyle::ExpertMode); + SetFactoryFlag(JFactory_Flags_t::REGENERATE); + } + void Init() { + } + void Process(const std::shared_ptr&) { + LOG_WARN(GetLogger()) << "Calling TestFactory::Process" << LOG_END; + m_clusters()->push_back(MutableExampleCluster(22.2)); + m_clusters()->push_back(MutableExampleCluster(27)); + } +}; + +struct TestProc : public JEventProcessor { + + PodioInput m_clusters_in {this, InputOptions{.name="my_collection"}}; + + TestProc() { + SetCallbackStyle(CallbackStyle::ExpertMode); + } + + void Process(const JEvent& event) { + auto clusters = event.GetCollection("my_collection", true); + REQUIRE(clusters->size() == 2); + + REQUIRE(m_clusters_in() != nullptr); + REQUIRE(m_clusters_in()->size() == 2); + std::cout << "Proc found data: " << m_clusters_in() << std::endl; + } +}; + +TEST_CASE("JStorageTests_EventInsertAndRetrieve") { + JApplication app; + app.Initialize(); + + auto event = std::make_shared(); + app.GetService()->configure_event(*event); + + auto coll_in = std::make_unique(); + coll_in->create().energy(7.6); + coll_in->create().energy(28); + coll_in->create().energy(42); + + event->InsertCollection(std::move(*coll_in), "my_collection"); + + auto coll_out = event->GetCollectionBase("my_collection"); + REQUIRE(coll_out != nullptr); + REQUIRE(coll_out->size() == 3); + auto typed_coll = dynamic_cast(coll_out); + REQUIRE(typed_coll->at(1).energy() == 28); +} + +TEST_CASE("JStorageTests_InsertOverridesFactory") { + JApplication app; + app.Add(new JFactoryGeneratorT); + app.Initialize(); + + auto event = std::make_shared(); + app.GetService()->configure_event(*event); + + auto coll_in = std::make_unique(); + coll_in->create().energy(7.6); + coll_in->create().energy(28); + coll_in->create().energy(42); + + event->InsertCollection(std::move(*coll_in), "my_collection"); + + auto coll_out = event->GetCollectionBase("my_collection"); + REQUIRE(coll_out != nullptr); + REQUIRE(coll_out->size() == 3); + auto typed_coll = dynamic_cast(coll_out); + REQUIRE(typed_coll->at(1).energy() == 28); + +} +TEST_CASE("JStorageTests_RegenerateOverridesInsert") { + JApplication app; + app.SetParameterValue("enable_regenerate", true); + app.Add(new JFactoryGeneratorT); + app.Initialize(); + + auto event = std::make_shared(); + app.GetService()->configure_event(*event); + + auto coll_in = std::make_unique(); + coll_in->create().energy(7.6); + coll_in->create().energy(28); + coll_in->create().energy(42); + + event->InsertCollection(std::move(*coll_in), "my_collection"); + + try { + event->GetCollectionBase("my_collection"); + REQUIRE(1==0); + + } + catch (const std::exception& e) { + std::cout << e.what() << std::endl; + } +} + +TEST_CASE("JStorageTests_FactoryProcessAndRetrieveUntyped") { + JApplication app; + app.Add(new JFactoryGeneratorT); + app.Initialize(); + + auto event = std::make_shared(); + app.GetService()->configure_event(*event); + + auto coll_out = event->GetCollectionBase("my_collection"); + REQUIRE(coll_out != nullptr); + REQUIRE(coll_out->size() == 2); + auto typed_coll = dynamic_cast(coll_out); + REQUIRE(typed_coll->at(1).energy() == 27); +} + +TEST_CASE("JStorageTests_FactoryProcessAndRetrieveTyped") { + JApplication app; + app.Add(new JFactoryGeneratorT); + app.Initialize(); + + auto event = std::make_shared(); + app.GetService()->configure_event(*event); + + auto coll = event->GetCollection("my_collection"); + REQUIRE(coll != nullptr); + REQUIRE(coll->size() == 2); + REQUIRE(coll->at(1).energy() == 27); +} + +TEST_CASE("JStorageEndToEndTest") { + JApplication app; + app.Add(new JEventSourceGeneratorT); + app.Add(new JFactoryGeneratorT); + app.Add(new TestProc); + app.Add("fake_file.root"); + app.SetParameterValue("jana:nevents", 2); + app.Run(); +} + +} // namespace jcollection_tests diff --git a/src/programs/unit_tests/Components/PodioTests.cc b/src/programs/unit_tests/Components/PodioTests.cc index c0bedb560..faa0f54c6 100644 --- a/src/programs/unit_tests/Components/PodioTests.cc +++ b/src/programs/unit_tests/Components/PodioTests.cc @@ -2,9 +2,17 @@ #include #include -#include + #include +#include +#include +#include +#include +#include +#include +#include +#include namespace podiotests { @@ -30,12 +38,6 @@ TEST_CASE("PodioTestsInsertAndRetrieve") { REQUIRE((*collection_retrieved)[0].energy() == 16.0); } - SECTION("Retrieve using JEvent::Get()") { - std::vector clusters_retrieved = event->Get("clusters"); - REQUIRE(clusters_retrieved.size() == 2); - REQUIRE(clusters_retrieved[0]->energy() == 16.0); - } - SECTION("Retrieve directly from podio::Frame") { auto frame = event->GetSingle(); auto* collection_retrieved = dynamic_cast(frame->get("clusters")); @@ -161,9 +163,228 @@ TEST_CASE("JFactoryPodioT::Init gets called") { const auto* res = dynamic_cast(r); REQUIRE(res != nullptr); REQUIRE((*res)[0].energy() == 16.0); - auto fac = dynamic_cast(event->GetFactory("clusters")); + + auto fac_untyped = event->GetDataBundle("clusters", false)->GetFactory(); + REQUIRE(fac_untyped != nullptr); + auto fac = dynamic_cast(fac_untyped); REQUIRE(fac != nullptr); REQUIRE(fac->init_called == true); } + + +namespace multifactory { + +struct TestMultiFac : public JMultifactory { + TestMultiFac() { + DeclarePodioOutput("sillyclusters"); + } + bool init_called = false; + void Init() override { + init_called = true; + } + void Process(const std::shared_ptr& event) override { + ExampleClusterCollection c; + c.push_back(MutableExampleCluster(16.0 + event->GetEventNumber())); + SetCollection("sillyclusters", std::move(c)); + } +}; + +TEST_CASE("PodioTests_JMultifactoryInit") { + + JApplication app; + auto event = std::make_shared(&app); + event->SetEventNumber(0); + auto fs = new JFactorySet; + fs->Add(new TestMultiFac); + event->SetFactorySet(fs); + + // Simulate a trip to the event pool _before_ calling JFactory::Process() + // This is important because a badly designed JFactory::ClearData() could + // mangle the Unprocessed status and consequently skip Init(). + event->GetFactorySet()->Release(); // In theory this shouldn't hurt + + auto r = event->GetCollectionBase("sillyclusters"); + REQUIRE(r != nullptr); + const auto* res = dynamic_cast(r); + REQUIRE(res != nullptr); + REQUIRE((*res)[0].energy() == 16.0); + + auto multifac = event->GetFactorySet()->GetAllMultifactories().at(0); + REQUIRE(multifac != nullptr); + auto multifac_typed = dynamic_cast(multifac); + REQUIRE(multifac_typed != nullptr); + REQUIRE(multifac_typed->init_called == true); +} + + +TEST_CASE("PodioTests_MultifacMultiple") { + + JApplication app; + auto event = std::make_shared(&app); + event->SetEventNumber(0); + auto fs = new JFactorySet; + fs->Add(new TestMultiFac); + event->SetFactorySet(fs); + + auto r = event->GetCollection("sillyclusters"); + REQUIRE(r->at(0).energy() == 16.0); + + event->GetFactorySet()->Release(); // Simulate a trip to the JEventPool + + event->SetEventNumber(4); + r = event->GetCollection("sillyclusters"); + REQUIRE(r->at(0).energy() == 20.0); + +} +} // namespace multifactory + + +TEST_CASE("PodioTests_InsertMultiple") { + + JApplication app; + auto event = std::make_shared(&app); + + // Insert a cluster + + auto coll1 = ExampleClusterCollection(); + auto cluster1 = coll1.create(22.0); + auto storage = event->InsertCollection(std::move(coll1), "clusters"); + + REQUIRE(storage->GetSize() == 1); + REQUIRE(storage->GetStatus() == JDataBundle::Status::Inserted); + + // Retrieve and validate cluster + + auto cluster1_retrieved = event->GetCollection("clusters"); + REQUIRE(cluster1_retrieved->at(0).energy() == 22.0); + + // Clear event + + event->GetFactorySet()->Release(); // Simulate a trip to the JEventPool + + // After clearing, the JDataBundle will still exist, but it will be empty + auto storage2 = event->GetDataBundle("clusters", false); + REQUIRE(storage2->GetStatus() == JDataBundle::Status::Empty); + REQUIRE(storage2->GetSize() == 0); + + // Insert a cluster. If event isn't being cleared correctly, this will throw + + auto coll2 = ExampleClusterCollection(); + auto cluster2 = coll2.create(33.0); + auto storage3 = event->InsertCollection(std::move(coll2), "clusters"); + REQUIRE(storage3->GetStatus() == JDataBundle::Status::Inserted); + REQUIRE(storage3->GetSize() == 1); + + // Retrieve and validate cluster + + auto cluster2_retrieved = event->GetCollection("clusters"); + REQUIRE(cluster2_retrieved->at(0).energy() == 33.0); +} + + +namespace omnifacmultiple { + +struct MyClusterFactory : public JOmniFactory { + + PodioOutput m_clusters_out{this, "clusters"}; + + void Configure() { + } + + void ChangeRun(int32_t /*run_nr*/) { + } + + void Execute(int32_t /*run_nr*/, uint64_t evt_nr) { + + auto cs = std::make_unique(); + auto cluster = MutableExampleCluster(101.0 + evt_nr); + cs->push_back(cluster); + m_clusters_out() = std::move(cs); + } +}; + +TEST_CASE("PodioTests_OmniFacMultiple") { + + JApplication app; + app.SetParameterValue("jana:loglevel", "error"); + app.Add(new JOmniFactoryGeneratorT("cluster_fac", {}, {"clusters"})); + app.Initialize(); + auto event = std::make_shared(&app); + app.GetService()->configure_event(*event); + event->SetEventNumber(22); + + // Check that storage is already present + auto storage = event->GetDataBundle("clusters", false); + REQUIRE(storage != nullptr); + REQUIRE(storage->GetStatus() == JDataBundle::Status::Empty); + + // Retrieve triggers factory + auto coll = event->GetCollection("clusters"); + REQUIRE(coll->size() == 1); + REQUIRE(coll->at(0).energy() == 123.0); + + // Clear everything + event->GetFactorySet()->Release(); + event->SetEventNumber(1010); + + // Check that storage has been reset + storage = event->GetDataBundle("clusters", false); + REQUIRE(storage != nullptr); + REQUIRE(storage->GetStatus() == JDataBundle::Status::Empty); + + REQUIRE(storage->GetFactory() != nullptr); + + // Retrieve triggers factory + auto coll2 = event->GetCollection("clusters"); + REQUIRE(coll2->size() == 1); + REQUIRE(coll2->at(0).energy() == 1111.0); +} + +} // namespace omnifacmultiple + + +namespace omnifacreadinsert { + +struct RWClusterFac : public JOmniFactory { + + PodioInput m_clusters_in{this}; + PodioOutput m_clusters_out{this}; + + void Configure() {} + void ChangeRun(int32_t /*run_nr*/) {} + void Execute(int32_t /*run_nr*/, uint64_t evt_nr) { + + auto cs = std::make_unique(); + for (const auto& cluster_in : *m_clusters_in()) { + auto cluster = MutableExampleCluster(1.0 + cluster_in.energy()); + cs->push_back(cluster); + } + m_clusters_out() = std::move(cs); + } +}; + +TEST_CASE("PodioTests_OmniFacReadInsert") { + + JApplication app; + app.SetParameterValue("jana:loglevel", "error"); + app.Add(new JOmniFactoryGeneratorT("cluster_fac", {"protoclusters"}, {"clusters"})); + app.Initialize(); + auto event = std::make_shared(&app); + app.GetService()->configure_event(*event); + + auto coll1 = ExampleClusterCollection(); + auto cluster1 = coll1.create(22.0); + auto storage = event->InsertCollection(std::move(coll1), "protoclusters"); + + REQUIRE(storage->GetSize() == 1); + REQUIRE(storage->GetStatus() == JDataBundle::Status::Inserted); + + // Retrieve triggers factory + auto coll = event->GetCollection("clusters"); + REQUIRE(coll->size() == 1); + REQUIRE(coll->at(0).energy() == 23.0); + +} +} // namespace omnifacreadinsert } // namespace podiotests