Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RF] Implement channel masking for simultaneous likelihoods #16999

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions roofit/codegen/inc/RooFit/CodegenImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ namespace Detail {
class RooFixedProdPdf;
class RooNLLVarNew;
class RooNormalizedPdf;
class RooSimNLL;
} // namespace Detail

namespace Experimental {
Expand All @@ -76,6 +77,7 @@ class CodegenContext;
void codegenImpl(RooFit::Detail::RooFixedProdPdf &arg, CodegenContext &ctx);
void codegenImpl(RooFit::Detail::RooNLLVarNew &arg, CodegenContext &ctx);
void codegenImpl(RooFit::Detail::RooNormalizedPdf &arg, CodegenContext &ctx);
void codegenImpl(RooFit::Detail::RooSimNLL &arg, CodegenContext &ctx);
void codegenImpl(ParamHistFunc &arg, CodegenContext &ctx);
void codegenImpl(PiecewiseInterpolation &arg, CodegenContext &ctx);
void codegenImpl(RooAbsArg &arg, CodegenContext &ctx);
Expand Down
32 changes: 23 additions & 9 deletions roofit/codegen/src/CodegenImpl.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,28 @@
}
}

void codegenImpl(RooFit::Detail::RooSimNLL &arg, CodegenContext &ctx)
{
if (arg.terms().empty()) {
ctx.addResult(&arg, "0.0");
}

std::string resName = RooFit::Detail::makeValidVarName(arg.GetName()) + "Result";
ctx.addResult(&arg, resName);
ctx.addToGlobalScope("double " + resName + " = 0.0;\n");

std::stringstream ss;

std::size_t i = 0;

Check warning on line 148 in roofit/codegen/src/CodegenImpl.cxx

View workflow job for this annotation

GitHub Actions / mac13 ARM64 LLVM_ENABLE_ASSERTIONS=On, builtin_zlib=ON

variable 'i' set but not used [-Wunused-but-set-variable]

Check warning on line 148 in roofit/codegen/src/CodegenImpl.cxx

View workflow job for this annotation

GitHub Actions / mac14 X64 LLVM_ENABLE_ASSERTIONS=On, CMAKE_CXX_STANDARD=20

variable 'i' set but not used [-Wunused-but-set-variable]

Check warning on line 148 in roofit/codegen/src/CodegenImpl.cxx

View workflow job for this annotation

GitHub Actions / mac15 ARM64 LLVM_ENABLE_ASSERTIONS=On, CMAKE_CXX_STANDARD=20

variable 'i' set but not used [-Wunused-but-set-variable]

Check warning on line 148 in roofit/codegen/src/CodegenImpl.cxx

View workflow job for this annotation

GitHub Actions / mac-beta ARM64 LLVM_ENABLE_ASSERTIONS=On, CMAKE_CXX_STANDARD=20

variable 'i' set but not used [-Wunused-but-set-variable]

Check warning on line 148 in roofit/codegen/src/CodegenImpl.cxx

View workflow job for this annotation

GitHub Actions / alma9-clang clang LLVM_ENABLE_ASSERTIONS=On, CMAKE_C_COMPILER=clang, CMAKE_CXX_COMPILER=clang++

variable 'i' set but not used [-Wunused-but-set-variable]
for (auto *component : static_range_cast<RooAbsReal *>(arg.terms())) {

// TODO: support channel masking here
ss << resName << " += " << ctx.buildFunction(*component, ctx.outputSizes()) << "(params, obs, xlArr);\n";
++i;
}
ctx.addToGlobalScope(ss.str());
}

void codegenImpl(ParamHistFunc &arg, CodegenContext &ctx)
{
std::string const &idx = arg.dataHist().calculateTreeIndexForCodeSquash(&arg, ctx, arg.dataVars(), true);
Expand Down Expand Up @@ -251,15 +273,7 @@

std::size_t i = 0;
for (auto *component : static_range_cast<RooAbsReal *>(arg.list())) {

if (!dynamic_cast<RooFit::Detail::RooNLLVarNew *>(component) || arg.list().size() == 1) {
result += ctx.getResult(*component);
++i;
if (i < arg.list().size())
result += '+';
continue;
}
result += ctx.buildFunction(*component, ctx.outputSizes()) + "(params, obs, xlArr)";
result += ctx.getResult(*component);
++i;
if (i < arg.list().size())
result += '+';
Expand Down
1 change: 1 addition & 0 deletions roofit/roofitcore/inc/LinkDef.h
Original file line number Diff line number Diff line change
Expand Up @@ -337,5 +337,6 @@
#pragma link C++ class RooBinWidthFunction+;
#pragma link C++ class RooFit::Detail::RooNLLVarNew+;
#pragma link C++ class RooFit::Detail::RooNormalizedPdf+ ;
#pragma link C++ class RooFit::Detail::RooSimNLL+;

#endif
27 changes: 27 additions & 0 deletions roofit/roofitcore/inc/RooFit/Detail/RooNLLVarNew.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
#include <RooAbsPdf.h>
#include <RooAbsReal.h>
#include <RooGlobalFunc.h>
#include <RooListProxy.h>
#include <RooTemplateProxy.h>

#include <Math/Util.h>

class RooAbsCategory;

namespace RooFit {
namespace Detail {

Expand Down Expand Up @@ -87,6 +90,30 @@ class RooNLLVarNew : public RooAbsReal {
ClassDefOverride(RooFit::Detail::RooNLLVarNew, 0);
};

class RooSimNLL : public RooAbsReal {
public:
RooSimNLL(const char *name, const char *title, const RooArgSet &terms, RooAbsCategoryLValue const &indexCat,
bool channelMasking);

RooSimNLL(const RooSimNLL &other, const char *name = nullptr);
TObject *clone(const char *newname) const override { return new RooSimNLL(*this, newname); }

double defaultErrorLevel() const override;

const RooArgSet &terms() const { return _set; }
const RooArgSet &masks() const { return _mask; }

void doEval(RooFit::EvalContext &) const override;

protected:
double evaluate() const override;

RooSetProxy _set; ///< set of terms to be summed
RooSetProxy _mask;

ClassDefOverride(RooFit::Detail::RooSimNLL, 0) // Sum of RooNLLVarNew instances
};

} // namespace Detail
} // namespace RooFit

Expand Down
3 changes: 2 additions & 1 deletion roofit/roofitcore/src/FitHelpers.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
#endif

using RooFit::Detail::RooNLLVarNew;
using RooFit::Detail::RooSimNLL;

namespace {

Expand Down Expand Up @@ -357,7 +358,7 @@ std::unique_ptr<RooAbsArg> createSimultaneousNLL(RooSimultaneous const &simPdf,
}

// Time to sum the NLLs
auto nll = std::make_unique<RooAddition>("mynll", "mynll", nllTerms);
auto nll = std::make_unique<RooSimNLL>("mynll", "mynll", nllTerms, simCat, true);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the test failures due to this change?

nll->addOwnedComponents(std::move(nllTerms));
return nll;
}
Expand Down
59 changes: 56 additions & 3 deletions roofit/roofitcore/src/RooNLLVarNew.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@ computation times.

#include "RooFit/Detail/RooNLLVarNew.h"

#include <RooHistPdf.h>
#include <RooAbsCategoryLValue.h>
#include <RooBatchCompute.h>
#include <RooConstVar.h>
#include <RooDataHist.h>
#include <RooFit/Detail/MathFuncs.h>
#include <RooHistPdf.h>
#include <RooNaNPacker.h>
#include <RooConstVar.h>
#include <RooRealVar.h>
#include <RooSetProxy.h>
#include <RooFit/Detail/MathFuncs.h>

#include "RooFitImplHelpers.h"

Expand All @@ -47,6 +48,7 @@ computation times.
#include <vector>

ClassImp(RooFit::Detail::RooNLLVarNew);
ClassImp(RooFit::Detail::RooSimNLL);

namespace RooFit {
namespace Detail {
Expand Down Expand Up @@ -336,6 +338,57 @@ void RooNLLVarNew::finalizeResult(RooFit::EvalContext &ctx, ROOT::Math::KahanSum
ctx.setOutputWithOffset(this, result, _offset);
}

RooSimNLL::RooSimNLL(const char *name, const char *title, const RooArgSet &terms, RooAbsCategoryLValue const &indexCat,
bool channelMasking)
: RooAbsReal(name, title), _set("!set", "set of components", this), _mask("!mask", "set of masks", this)
{
_set.addTyped<RooAbsReal>(terms);

if (channelMasking) {
for (auto const &catState : indexCat) {
std::string const &catName = catState.first;
std::string maskName = "mask_" + catName;
_mask.addOwned(std::make_unique<RooRealVar>(maskName.c_str(), maskName.c_str(), 0.0));
}
}
}

RooSimNLL::RooSimNLL(const RooSimNLL &other, const char *name)
: RooAbsReal(other, name), _set("!set", this, other._set), _mask("!mask", this, other._set)
{
}

double RooSimNLL::evaluate() const
{
double sum(0);
const RooArgSet *nset = _set.nset();

std::size_t i = 0;
for (auto *comp : static_range_cast<RooAbsReal *>(_set)) {
if (_mask.empty() || static_cast<RooAbsReal const *>(_mask[i])->getVal() == 0.0) {
sum += comp->getVal(nset);
}
++i;
}
return sum;
}

void RooSimNLL::doEval(RooFit::EvalContext &ctx) const
{
double result = 0.;
for (std::size_t i = 0; i < _set.size(); ++i) {
if (_mask.empty() || ctx.at(_mask[i])[0] == 0.0) {
result += ctx.at(_set[i])[0];
}
}
ctx.output()[0] = result;
}

double RooSimNLL::defaultErrorLevel() const
{
return static_cast<RooNLLVarNew *>(_set[0])->defaultErrorLevel();
}

} // namespace Detail
} // namespace RooFit

Expand Down
3 changes: 2 additions & 1 deletion roofit/roofitcore/test/testRooFuncWrapper.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ void randomizeParameters(const RooArgSet &parameters)
double mul = unif(re);

auto par = dynamic_cast<RooAbsRealLValue *>(param);
if (!par)
if (!par || par->isConstant())
continue;
double val = par->getVal();
val = val + mul * (mul > 0 ? (par->getMax() - val) : (val - par->getMin()));
Expand Down Expand Up @@ -241,6 +241,7 @@ TEST_P(FactoryTest, NLLFit)
// We don't use the RooFit::Evaluator for the nominal likelihood. Like this,
// we make sure to validate also the NLL values of the generated code.
static_cast<RooFit::Experimental::RooFuncWrapper &>(*nllFunc).disableEvaluator();
static_cast<RooFit::Experimental::RooFuncWrapper &>(*nllFunc).writeDebugMacro(_params._name);

double tol = _params._fitResultTolerance;

Expand Down
Loading