Skip to content

Commit

Permalink
Several code fixes and improvements
Browse files Browse the repository at this point in the history
- fix Project.toml of test to allow use of Pkg.test() function
- fix table support
- fix behaviour of TTree constructor when column data are provided: data are now inserted in the tree.
- fix GC.@preserve
- add several input checks to improve error messages
- added support for StdValArray and StdString (not tested).
- added support for classes inherited from TObject and for TString
- StdVector, StdValArray, and StdString re-exported from CxxWrap to save the user the need to import CxxWrap.
  • Loading branch information
grasph committed Sep 15, 2024
1 parent becac27 commit a866685
Show file tree
Hide file tree
Showing 5 changed files with 228 additions and 125 deletions.
9 changes: 9 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,12 @@ version = "0.1.0"
CxxWrap = "1f15a43c-97ca-5a2a-ae31-89f07a497df4"
ROOT = "1706fdcc-8426-44f1-a283-5be479e9517c"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
julia = "1.6"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
5 changes: 0 additions & 5 deletions docs/src/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,6 @@ Close(f)

### Example 5: columns provided as vectors

_⚠ This example is not yet working with this version of `RootIO`._

```julia
using RootIO, ROOT
using Random
Expand Down Expand Up @@ -164,9 +162,6 @@ Close(f)

This example illustrates how to store a table (in the `Tables.jl` sense), like a `NamedTuple` or a `DataFrame` from the `DataFrames.jl` package.

_⚠ This example is not yet working with this RootIO version._


```julia
using RootIO, ROOT
using Random
Expand Down
240 changes: 181 additions & 59 deletions src/RootIO.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
module RootIO

import ROOT, Tables, CxxWrap
import ROOT.Write, ROOT.Fill, ROOT.Scan, ROOT.Print, ROOT.GetEntries, CxxWrap.StdVector
import ROOT.Write, ROOT.Fill, ROOT.Scan, ROOT.Print, ROOT.GetEntries, CxxWrap.StdVector, CxxWrap.StdValArray, CxxWrap.StdString

export TTree, Write, Fill, Print, Scan, GetEntries, StdVector
export TTree, Write, Fill, Print, Scan, GetEntries, StdVector, StdValArray, StdString

"""
`TTree`
Expand All @@ -12,11 +12,28 @@ Type representing a `ROOT` tree. It must be used in place of the `TTree` type of
"""
struct TTree
    # The wrapped ROOT TTree object.
    _ROOT_ttree::ROOT.TTree
    # Branches associated with the TTree, one per column.
    _branch_array::Vector{Union{CxxWrap.CxxPtr{ROOT.TBranch}, ROOT.TBranchPtr}}
    # Names of the branches associated with the TTree.
    _branch_names::Vector{Symbol}
    # Expected Julia type of the values stored in each branch.
    _branch_types::Vector{DataType}
    # For each branch: name of the branch containing the size of a c-array,
    # the size itself if it is fixed, or 0 if the branch is not a c-array.
    _size_branches::Vector{Union{Integer, Symbol}}
    # A pointer to the ROOT file where the TTree is stored.
    _file::CxxWrap.CxxWrapCore.CxxPtr{ROOT.TFile}
    # Current row, one entry per branch. Scalars are boxed in a
    # zero-dimensional array (see `_wrap`) so that they have an address.
    _rowbuffer::Vector{Any}

    # The row buffer is allocated here with one (initially unassigned) slot per
    # branch. It is created directly with the element type of the field instead
    # of being converted from a differently typed vector.
    function TTree(tree, branches, names, types, sizebranches, file)
        return new(tree, branches, names, types, sizebranches, file,
                   Vector{Any}(undef, length(branches)))
    end
end

# Convert a Julia `Array{T}` into a `StdVector{T}` by calling the `StdVector`
# constructor on it.
# NOTE(review): neither `Base.convert` nor the `StdVector`/`Array` types are
# defined in this module, so this method is type piracy — confirm that it does
# not clash with a method provided by CxxWrap.
function Base.convert(::Type{StdVector{T}}, x::Array{T}) where {T}
    return StdVector(x)
end

# Helper function to retrieve the branch storage specification from the
# provided type, `(type, size)` tuple, or instance. Returns a
# `(storage type, c-array size)` tuple, where the size is 0 when the branch is
# not a c-array:
#   - `StdVector{U}` and its subtypes are stored as `StdVector{U}`;
#   - `DenseArray{U}` subtypes are also stored as `StdVector{U}`;
#   - a `(T, size)` tuple, with `size` a `Symbol` or an `Integer`, requests a
#     c-array and is stored as `Vector{T}` together with the given size;
#   - any other type is stored as is.
_branchtype(data::Type{T}) where {T <: StdVector{U}} where {U} = (StdVector{U}, 0)
_branchtype(data::Type{T}) where {T <: DenseArray{U}} where {U} = (StdVector{U}, 0)
_branchtype(data::Tuple{T,S}) where {T <: Type, S <: Union{Symbol, Integer}} = (Vector{data[1]}, data[2])
_branchtype(type::Type) = (type, 0)
# Fallback for an instance: use the specification of its type. `typeof(data)`
# is always a `Type`, so this dispatches to one of the `Type` methods above and
# cannot call itself recursively. The former `invoke(_branchtype, [Type], ...)`
# call passed a `Vector` where `invoke` requires a `Tuple{...}` type and threw.
_branchtype(data) = _branchtype(typeof(data))

"""
`Scan(tree, varexp = "", selection = "", option = "", nentries = -1, firstentry = 0) `
Expand Down Expand Up @@ -110,35 +127,48 @@ end

# # Returns
# - A RootIO `TTree` object containing the ROOT TTree and its branches.
function _makeTTree(file::CxxWrap.CxxWrapCore.CxxPtr{ROOT.TFile}, name::String, title::String, branch_names, branch_types)
    # Each branch needs both a name and a type specification.
    length(branch_names) == length(branch_types) ||
        throw(ArgumentError("Got $(length(branch_names)) branch names for $(length(branch_types)) branch types."))

    tree = ROOT.TTree(name, title)
    current_branches = Union{CxxWrap.CxxPtr{ROOT.TBranch}, ROOT.TBranchPtr}[]

    # Split every specification into its storage type and its c-array size
    # (0 when the branch is not a c-array). Comprehensions are used instead of
    # destructuring a broadcast so that an empty branch list is supported.
    specs = [_branchtype(spec) for spec in branch_types]
    branch_datatypes = [first(spec) for spec in specs]
    sizebranches = [last(spec) for spec in specs]

    for i in eachindex(branch_datatypes)
        colname = string(branch_names[i])
        coltype = branch_datatypes[i]
        colsize = sizebranches[i]

        if colsize !== 0 # c-array
            type_identifier = _getTypeCharacter(eltype(coltype))
            isnothing(type_identifier) && throw(ArgumentError("Element type $(eltype(coltype)) is not supported for a c-array storage. You can use StdVector storage mode instead."))
            if isa(colsize, Symbol) || (isa(colsize, Number) && colsize > 0)
                # Leaf list of the form "name[size]/T", where size is either
                # the name of the branch holding the length or a fixed length.
                curr_branch = ROOT.Branch(tree, colname, Ptr{Nothing}(), "$(colname)[$(colsize)]/$(type_identifier)")
            else
                throw(ArgumentError("Bad c-array specification, $(branch_types[i]). Second element needs to be either a symbol or a strictly positive number."))
            end
        elseif coltype <: CxxWrap.StdVector
            ptr = coltype()
            curr_branch = ROOT.Branch(tree, colname, ptr, 3200, 99)
        elseif coltype == String
            curr_branch = ROOT.Branch(tree, colname, Ptr{Nothing}(), "$(colname)/C")
        elseif coltype == Bool
            curr_branch = ROOT.Branch(tree, colname, Ptr{Nothing}(), "$(colname)/O")
        elseif coltype <: Union{ROOT.TObject, ROOT.TString}
            # The class name passed to ROOT is the Julia type name stripped of
            # its "ROOT." module prefix.
            classname = replace(string(coltype), "ROOT." => "")
            curr_branch = ROOT.Branch(tree, colname, classname, Ptr{Nothing}(), 3200, 99)
        else
            curr_branch = ROOT.Branch(tree, colname, Ref{coltype}(), 3200, 99)
        end
        push!(current_branches, curr_branch)
    end

    return TTree(tree, current_branches, collect(branch_names), branch_datatypes, sizebranches, file)
end

"""
Expand Down Expand Up @@ -180,17 +210,25 @@ end
Scan(tree)
```
"""
function TTree(file::CxxWrap.CxxWrapCore.CxxPtr{ROOT.TFile}, name::String, title::String, rowtype::DataType)
    # One branch is created per field of `rowtype`: the branch takes the name
    # of the field and its type specification is the declared field type.
    branchnames = fieldnames(rowtype)
    branchtypes = fieldtypes(rowtype)
    return _makeTTree(file, name, title, branchnames, branchtypes)
end

# no stringdoc here, as it is shared with the previous method declaration
function TTree(file::CxxWrap.CxxWrapCore.CxxPtr{ROOT.TFile}, name::String, title::String, table)
    if Tables.istable(table)
        # Branch names and types are taken from the table schema, then the
        # content of the table is stored in the newly created tree.
        sch = Tables.schema(table)
        isnothing(sch) && error("Failed to retrieve the schema of the provided table")
        tree = _makeTTree(file, name, title, sch.names, sch.types)
        Fill(tree, table)
        return tree
    else # handle the case where the 4th argument is an instance of a single row
        # A type that is not a `DataType` (e.g. a `UnionAll`) is not dispatched
        # to the `rowtype::DataType` method and ends up here: reject it instead
        # of building branches from the fields of the type object itself.
        table isa Type && throw(ArgumentError("The rowtype argument needs to be a DataType."))
        rowtype = typeof(table)
        isa(rowtype, DataType) || throw(ArgumentError("The rowtype argument needs to be a DataType."))
        return TTree(file, name, title, rowtype)
    end
end

"""
Expand Down Expand Up @@ -233,20 +271,75 @@ data = (col_float=rand(Float64, 3), col_int=rand(Int32, 3))
tree = RootIO.TTree(file, name, title; data...)
```
"""
function TTree(file::CxxWrap.CxxWrapCore.CxxPtr{ROOT.TFile}, name::String, title::String; columns...)
    # Each keyword argument defines one branch named after the keyword. Its
    # value is either a type specification (a `DataType` or a `(DataType, size)`
    # tuple) or the content of the column (a `Vector` or a `Tuple`), in which
    # case the branch type is the element type of the content. The two forms
    # cannot be mixed.
    branch_types = []
    branch_names = Symbol[]
    datacolumns = []      # column contents, in keyword order

    hastypespec = false   # a keyword provided a type specification
    nrowdata = -1         # number of rows of the data columns, -1 if none seen yet
    for (key, value) in columns
        push!(branch_names, key)
        if isa(value, Union{DataType, Tuple{DataType, Any}})
            nrowdata >= 0 && throw(ArgumentError("Mix of keyword argument specifying column type and contents"))
            hastypespec = true
            push!(branch_types, value)
        elseif isa(value, Union{Vector, Tuple})
            # Checked in both directions so that the result does not depend on
            # the order of the keyword arguments.
            hastypespec && throw(ArgumentError("Mix of keyword argument specifying column type and contents"))
            push!(branch_types, eltype(value))
            push!(datacolumns, value)
            if nrowdata < 0
                nrowdata = length(value)
            else
                nrowdata == length(value) || throw(ArgumentError("Column size mismatch"))
            end
        else
            throw(ArgumentError("Invalid value for keyword argument $key."))
        end
    end

    tree = _makeTTree(file, name, title, branch_names, branch_types)

    # When column contents were provided, store them row by row in the tree.
    if nrowdata > 0
        for row in zip(datacolumns...)
            Fill(tree, row)
        end
    end

    return tree
end

# Container types wrapped by CxxWrap that are stored in the row buffer as is.
# Declared `const` so that the type tests and method signatures using it refer
# to a binding whose value cannot change.
const _SupportedStdContainers = Union{StdVector, StdValArray, StdString}

# Wrap data to reference them in the _rowbuffer.
# Everything is put in a mutable container as we need a pointer to the data to
# pass to the C++ library: a value that is neither an array, a supported
# container, nor a `String` is boxed in a zero-dimensional array by `fill`.
_wrap(a) = fill(a)
_wrap(a::DenseArray) = a
_wrap(a::_SupportedStdContainers) = a
_wrap(a::String) = a

# Set address of data to store in a branch

# Column value held in a `StdVector` or a `ROOT.TObject`: the object itself is
# handed to the branch of column `icol` through `ROOT.SetObject`.
function _SetColAddress(tree::RootIO.TTree, icol, x::Union{StdVector, ROOT.TObject})
    return ROOT.SetObject(tree._branch_array[icol], x)
end

# Column value held in a `DenseArray` (this includes scalars boxed by `_wrap`).
# How the data is handed to the branch depends on the type declared for the
# column:
#   - supported std containers are passed with `ROOT.SetObject`;
#   - `String`, `Bool`, `Vector`, `ROOT.TObject`, and `ROOT.TString` columns get
#     the address of the array data as an untyped pointer;
#   - any other column gets the array itself passed to `ROOT.SetAddress`.
@inline function _SetColAddress(tree::RootIO.TTree, icol, x::DenseArray)
    #FIXME handle nested std::vector, used for multidimensional arrays
    coltype = tree._branch_types[icol]
    branch = tree._branch_array[icol]
    if coltype <: _SupportedStdContainers
        return ROOT.SetObject(branch, x)
    end
    needs_raw_pointer = coltype <: Union{String, Bool, Vector} ||
                        coltype <: Union{ROOT.TObject, ROOT.TString}
    if needs_raw_pointer
        return ROOT.SetAddress(branch, convert(Ptr{Nothing}, pointer(x)))
    end
    return ROOT.SetAddress(branch, x)
end

# Column value held directly as a `Bool` or a `String`: the branch of column
# `icol` gets the address returned by `pointer(x)`, as an untyped pointer.
# NOTE(review): `_wrap` boxes a `Bool` in an array, so a `Bool` presumably
# reaches the `DenseArray` method rather than this one, and `pointer` is not
# expected to accept a bare `Bool` — confirm whether `Bool` is needed here.
function _SetColAddress(tree::RootIO.TTree, icol, x::Union{Bool, String})
    branch = tree._branch_array[icol]
    return ROOT.SetAddress(branch, convert(Ptr{Nothing}, pointer(x)))
end


# Point every branch of `tree` to the corresponding entry of the row buffer.
# Returns `nothing`.
function _UpdateAddresses(tree)
    foreach(enumerate(tree._rowbuffer)) do (icol, data)
        _SetColAddress(tree, icol, data)
    end
    return nothing
end

"""
Expand All @@ -273,33 +366,62 @@ end
Scan(tree)
```
"""
function Fill(tree::TTree, data; unsafe = false)
    # Three kinds of input are supported:
    #   - a table (in the Tables.jl sense): every row of the table is stored;
    #     `unsafe` is forwarded to `_fillFromTable`, where it disables the
    #     check of the column names;
    #   - a composite type instance or a NamedTuple: one row, whose values are
    #     looked up by branch name;
    #   - any other iterable: one row, whose values are taken in order.
    if Tables.istable(data)
        _fillFromTable(tree, data, unsafe = unsafe)
    elseif(applicable(fieldnames, typeof(data))
           && !isa(data, Tuple)         #Exclude Tuples which have fieldnames :1, :2,...
           && !isa(data, AbstractArray) #Exclude arrays: their type can have internal fields
                                        #(e.g. Array since Julia 1.11, or wrapped C++ containers)
           && length(fieldnames(typeof(data))) > 0)
        #a composite type or a NamedTuple
        _fillOneRowFromStructOrNamedTuple(tree, data)
    else
        _fillOneRowFromIterable(tree, data)
    end
end

# Store one row taken from a composite type instance or a NamedTuple: for each
# branch, the field of `data` bearing the branch name is copied to the row
# buffer, then the row is written to the tree.
function _fillOneRowFromStructOrNamedTuple(tree, data)
    foreach(enumerate(tree._branch_names)) do (icol, fieldname)
        _set(tree, icol, getfield(data, fieldname))
    end
    return _fill(tree)
end

# Store one row taken from an iterable: the i-th value produced by `data` is
# copied to the row buffer slot of the i-th branch, then the row is written to
# the tree.
#
# Throws an `ArgumentError` if `data` does not produce exactly one value per
# branch. Values preceding a superfluous one are already copied to the row
# buffer when the error is thrown, but no row is written.
function _fillOneRowFromIterable(tree, data)
    ncols = length(tree._branch_array)
    nfilled = 0
    for (icol, val) in enumerate(data)
        icol > ncols && throw(ArgumentError("Provided data contains too many columns, expected $(ncols)."))
        _set(tree, icol, val)
        nfilled += 1
    end
    nfilled == ncols || throw(ArgumentError("Provided data does not contain enough columns, expected $(ncols)."))
    return _fill(tree)
end

# Store every row of a table (in the Tables.jl sense) in the tree.
#
# Unless `unsafe` is true, the column names of the table must be identical to
# the branch names of the tree, in the same order; an `ArgumentError` is thrown
# otherwise. An error is raised if the schema of the table cannot be retrieved.
function _fillFromTable(tree::TTree, data; unsafe = false)
    sch = Tables.schema(data)
    isnothing(sch) && error("Failed to retrieve the schema of the provided table")
    if !unsafe
        # Compare as vectors: a broadcast `.==` would throw a DimensionMismatch
        # instead of the ArgumentError below when the column counts differ.
        collect(sch.names) == tree._branch_names || throw(ArgumentError("Column name or order mismatch, Got $(sch.names), expected $(tree._branch_names)"))
    end
    for row in Tables.rows(data)
        # Copy the values of the row to the row buffer, then write the row.
        Tables.eachcolumn(sch, row) do val, icol, colname
            _set(tree, icol, val)
        end
        _fill(tree)
    end
end


# Write the current content of the row buffer as a new entry of the tree.
# The row buffer is kept rooted with `GC.@preserve` while the branch addresses
# are updated and `ROOT.Fill` is called.
function _fill(tree)
    buffer = tree._rowbuffer
    return GC.@preserve buffer begin
        _UpdateAddresses(tree)
        ROOT.Fill(tree._ROOT_ttree)
    end
end

# Store `val` in slot `icol` of the row buffer: the value is first converted to
# the type declared for the branch, then wrapped by `_wrap`.
function _set(tree, icol, val)
    converted = convert(tree._branch_types[icol], val)
    return tree._rowbuffer[icol] = _wrap(converted)
end

end # module RootIO
3 changes: 0 additions & 3 deletions test/Project.toml

This file was deleted.

Loading

0 comments on commit a866685

Please sign in to comment.