Skip to content

Commit

Permalink
Merge pull request #193 from rayosborn/improve-group-initialization
Browse files Browse the repository at this point in the history
Improve group initialization with the `entries` dictionary.
  • Loading branch information
rayosborn authored Mar 15, 2023
2 parents a17eaf7 + 4c3db54 commit 6715bf4
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 31 deletions.
4 changes: 2 additions & 2 deletions conda-recipe/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package:
name: nexusformat
version: "1.0.0"
version: "1.0.1"

source:
git_url: https://github.com/nexpy/nexusformat.git
git_tag: v1.0.0
git_tag: v1.0.1

build:
entry_points:
Expand Down
12 changes: 3 additions & 9 deletions src/nexusformat/nexus/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -4438,24 +4438,18 @@ def __init__(self, *args, **kwargs):
self._class = kwargs.pop("nxclass")
if "group" in kwargs:
self._group = kwargs.pop("group")
self._entries = None
if "entries" in kwargs:
self._entries = {}
for k, v in kwargs["entries"].items():
self._entries[k] = deepcopy(v)
self[k] = v
del kwargs["entries"]
else:
self._entries = None
if "attrs" in kwargs:
self._attrs = AttrDict(self, attrs=kwargs["attrs"])
del kwargs["attrs"]
else:
self._attrs = AttrDict(self)
for k, v in kwargs.items():
try:
self[k] = v
except AttributeError:
raise NeXusError(
"Keyword arguments must be valid NXobjects")
self[k] = v
if self.nxclass.startswith("NX"):
if self.nxname == "unknown" or self.nxname == "":
self._name = self.nxclass[2:]
Expand Down
76 changes: 56 additions & 20 deletions tests/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,42 @@ def test_group_creation(field1, field2, field3):
assert group3["g1"].nxgroup == group3


def test_group_entries(field1, field2, field3, field4, arr1D, arr2D):

entries = {"f2": field2, "f3": arr1D, "s1": "string",
"g1": NXgroup(field3)}

group1 = NXgroup(field1, f4=field4, f5=arr2D, entries=entries)

assert "f1" in group1
assert "f2" in group1
assert "f3" in group1
assert "f4" in group1
assert "f5" in group1
assert "s1" in group1
assert "g1" in group1
assert "g1/f3" in group1

assert group1["f1"] == field1
assert group1["f2"] == field2
assert group1["f3"].nxdata.sum() == arr1D.sum()
assert group1["f4"].nxdata == field4.nxdata
assert group1["f5"].nxdata.sum() == arr2D.sum()
assert group1["s1"].nxdata == "string"
assert group1["g1/f3"] == field3


def test_group_attrs():

group1 = NXgroup(attrs={"a": "b", "c": 1})

assert "a" in group1.attrs
assert "c" in group1.attrs

assert group1.attrs["a"] == "b"
assert group1.attrs["c"] == 1


def test_group_insertion(field2):

group1 = NXgroup()
Expand All @@ -50,7 +86,7 @@ def test_group_insertion(field2):
assert len(group1) == 1


def test_rename(field1):
def test_group_rename(field1):

group = NXgroup(field1)

Expand Down Expand Up @@ -82,13 +118,13 @@ def test_group_class():

def test_group_components():

g1 = NXdata(name='g1')
g2 = NXdata(name='g2')
g3 = NXdata(name='g3')
g1 = NXdata(name="g1")
g2 = NXdata(name="g2")
g3 = NXdata(name="g3")
group = NXentry(g1, g2, g3)

assert group.component('NXdata') == [group['g1'], group['g2'], group['g3']]
assert group.NXdata == [group['g1'], group['g2'], group['g3']]
assert group.component("NXdata") == [group["g1"], group["g2"], group["g3"]]
assert group.NXdata == [group["g1"], group["g2"], group["g3"]]


def test_group_title():
Expand All @@ -102,24 +138,24 @@ def test_group_title():
def test_group_move(field1):

group = NXentry()
group['g1'] = NXgroup()
group['g1/f1'] = field1
group['g2'] = NXgroup()
group['g1'].move('f1', 'g2', name='f2')
group["g1"] = NXgroup()
group["g1/f1"] = field1
group["g2"] = NXgroup()
group["g1"].move("f1", "g2", name="f2")

assert 'g1/f1' not in group
assert 'g2/f2' in group
assert "g1/f1" not in group
assert "g2/f2" in group

group['g2'].move(group['g2/f2'], group['g1'], name='f1')
group["g2"].move(group["g2/f2"], group["g1"], name="f1")

assert 'g2/f2' not in group
assert 'g1/f1' in group
assert "g2/f2" not in group
assert "g1/f1" in group

group['g3'] = NXgroup()
group['g2/f2'] = NXlink(target='g1/f1')
group['g2'].move('f2', 'g3', name='f3')
group["g3"] = NXgroup()
group["g2/f2"] = NXlink(target="g1/f1")
group["g2"].move("f2", "g3", name="f3")

assert group['g3/f3'].nxlink == field1
assert group["g3/f3"].nxlink == field1


def test_group_copy(tmpdir, field1):
Expand All @@ -130,7 +166,7 @@ def test_group_copy(tmpdir, field1):

external_filename = os.path.join(tmpdir, "file2.nxs")
external_root = NXroot(
NXentry(NXgroup(field1, name='g1', attrs={"a": "b"})))
NXentry(NXgroup(field1, name="g1", attrs={"a": "b"})))
external_root.save(external_filename, mode="w")

root["entry/g2"] = NXlink(target="entry/g1", file=external_filename)
Expand Down

0 comments on commit 6715bf4

Please sign in to comment.