Skip to content

Commit

Permalink
Switched to lxml for mslice export.
Browse files Browse the repository at this point in the history
Should be faster now, although I still need to do more detailed testing.
  • Loading branch information
hexane360 committed Jan 20, 2024
1 parent a0ef0a3 commit 16115ed
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 70 deletions.
94 changes: 47 additions & 47 deletions atomlib/io/mslice.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from __future__ import annotations

from xml.etree import ElementTree as et
from lxml import etree as et # type: ignore
#from xml.etree import ElementTree as et
import builtins
from copy import deepcopy
from warnings import warn
Expand All @@ -17,26 +18,28 @@
from numpy.typing import ArrayLike
import polars

from ..util import FileOrPath, open_file
from ..util import FileOrPath, open_file, open_file_binary, BinaryFileOrPath
from ..atoms import Atoms
from ..cell import Cell
from ..atomcell import HasAtomCell, AtomCell
from ..transform import AffineTransform3D, LinearTransform3D


MSliceFile = t.Union[et.ElementTree, FileOrPath]
ElementTree = et._ElementTree
Element = et._Element
MSliceFile = t.Union[ElementTree, FileOrPath]


DEFAULT_TEMPLATE_PATH = files('atomlib.data') / 'template.mslice'
DEFAULT_TEMPLATE: t.Optional[et.ElementTree] = None
DEFAULT_TEMPLATE: t.Optional[ElementTree] = None


def default_template() -> et.ElementTree:
def default_template() -> ElementTree:
global DEFAULT_TEMPLATE

if DEFAULT_TEMPLATE is None:
with DEFAULT_TEMPLATE_PATH.open('r') as f: # type: ignore
DEFAULT_TEMPLATE = et.parse(f) # type: ignore
DEFAULT_TEMPLATE = t.cast(ElementTree, et.parse(f, None))

return deepcopy(DEFAULT_TEMPLATE)

Expand All @@ -54,10 +57,10 @@ def convert_xml_value(val, ty):
return getattr(builtins, ty)(val)


def parse_xml_object(obj) -> t.Dict[str, t.Any]:
def parse_xml_object(obj: Element) -> t.Dict[str, t.Any]:
"""Parse the attributes of a passed XML object."""
params = {}
for attr in obj:
for attr in t.cast(t.Iterator[Element], obj.iter(None)):
if attr.tag == 'attribute':
params[attr.attrib['name']] = convert_xml_value(attr.text, attr.attrib['type'])
elif attr.tag == 'relationship':
Expand All @@ -67,35 +70,36 @@ def parse_xml_object(obj) -> t.Dict[str, t.Any]:
return params


def find_xml_object(xml, typename) -> t.Dict[str, t.Any]:
def find_xml_object(xml: Element, typename: str) -> t.Dict[str, t.Any]:
"""Find and parse XML objects named `typename`, flattening them into a single Dict."""
params = {}
for obj in xml.findall(f".//*[@type='{typename}']"):
for obj in xml.findall(f".//*[@type='{typename}']", None):
params.update(parse_xml_object(obj))
return params


def find_xml_object_list(xml, typename) -> t.List[t.Any]:
def find_xml_object_list(xml: Element, typename: str) -> t.List[t.Any]:
"""Find and parse a list of XML objects named `typename`."""
return [parse_xml_object(obj) for obj in xml.findall(f".//*[@type='{typename}']")]
return [parse_xml_object(obj) for obj in xml.findall(f".//*[@type='{typename}']", None)]


def find_xml_object_dict(xml, typename, key="id") -> t.Dict[str, t.Any]:
def find_xml_object_dict(xml: Element, typename: str, key: str = "id") -> t.Dict[str, t.Any]:
"""Find and parse XML objects named `typename`, combining them into a dict."""
return {
obj.attrib[key]: parse_xml_object(obj)
for obj in xml.findall(f".//*[@type='{typename}']")
for obj in xml.findall(f".//*[@type='{typename}']", None)
}


def read_mslice(path: MSliceFile) -> AtomCell:
if isinstance(path, et.ElementTree):
tree: ElementTree
if isinstance(path, ElementTree):
tree = path
else:
with open_file(path, 'r') as t:
tree = et.parse(t)
with open_file(path, 'r') as temp:
tree = et.parse(temp, None)

xml = tree.getroot()
xml: Element = tree.getroot()

structure = find_xml_object(xml, "STRUCTURE")
structure_atoms = find_xml_object_list(xml, "STRUCTUREATOM")
Expand All @@ -118,7 +122,7 @@ def read_mslice(path: MSliceFile) -> AtomCell:
return AtomCell(atoms, cell, frame='cell_frac')


def write_mslice(cell: HasAtomCell, f: FileOrPath, template: t.Optional[MSliceFile] = None, *,
def write_mslice(cell: HasAtomCell, f: BinaryFileOrPath, template: t.Optional[MSliceFile] = None, *,
slice_thickness: t.Optional[float] = None, # angstrom
scan_points: t.Optional[ArrayLike] = None,
scan_extent: t.Optional[ArrayLike] = None,
Expand Down Expand Up @@ -157,47 +161,48 @@ def write_mslice(cell: HasAtomCell, f: FileOrPath, template: t.Optional[MSliceFi
.transform(AffineTransform3D.scale(1/box_size)) \
.with_wobble().with_occupancy()

out: ElementTree
if template is None:
out = default_template()
elif not isinstance(template, et.ElementTree):
with open_file(template, 'r') as t:
out = et.parse(t)
elif not isinstance(template, ElementTree):
with open_file(template, 'r') as temp:
out = et.parse(temp, None)
else:
out = deepcopy(template)

# TODO clean up this code
db = out.getroot() if out.getroot().tag == 'database' else out.find("./database")
db: t.Optional[Element] = out.getroot() if out.getroot().tag == 'database' else out.find("./database", None)
if db is None:
raise ValueError("Couldn't find 'database' tag in template.")

struct = db.find(".//object[@type='STRUCTURE']")
struct = db.find(".//object[@type='STRUCTURE']", None)
if struct is None:
raise ValueError("Couldn't find STRUCTURE object in template.")

params = db.find(".//object[@type='SIMPARAMETERS']")
params = db.find(".//object[@type='SIMPARAMETERS']", None)
if params is None:
raise ValueError("Couldn't find SIMPARAMETERS object in template.")

microscope = db.find(".//object[@type='MICROSCOPE']")
microscope = db.find(".//object[@type='MICROSCOPE']", None)
if microscope is None:
raise ValueError("Couldn't find MICROSCOPE object in template.")

scan = db.find(".//object[@type='SCAN']")
aberrations = db.findall(".//object[@type='ABERRATION']")
scan = db.find(".//object[@type='SCAN']", None)
aberrations = db.findall(".//object[@type='ABERRATION']", None)

def set_attr(struct: et.Element, name: str, type: str, val: str):
node = struct.find(f".//attribute[@name='{name}']")
def set_attr(struct: Element, name: str, type: str, val: str):
node = t.cast(t.Optional[Element], struct.find(f".//attribute[@name='{name}']", None))
if node is None:
node = et.Element('attribute', dict(name=name, type=type))
node = t.cast(Element, et.Element('attribute', dict(name=name, type=type), None))
struct.append(node)
else:
node.attrib['type'] = type
node.text = val
node.text = val # type: ignore

def parse_xml_object(obj: et.Element) -> t.Dict[str, t.Any]:
def parse_xml_object(obj: Element) -> t.Dict[str, t.Any]:
"""Parse the attributes of a passed XML object."""
params = {}
for attr in obj:
for attr in obj.iterchildren(None):
if attr.tag == 'attribute':
params[attr.attrib['name']] = convert_xml_value(attr.text, attr.attrib['type'])
elif attr.tag == 'relationship':
Expand Down Expand Up @@ -277,7 +282,7 @@ def parse_xml_object(obj: et.Element) -> t.Dict[str, t.Any]:
set_attr(elem, name, 'float', f"{float(val):.8g}")

# remove existing atoms
for elem in db.findall("./object[@type='STRUCTUREATOM']"):
for elem in db.findall("./object[@type='STRUCTUREATOM']", None):
db.remove(elem)

# <u^2> -> 1d sigma
Expand All @@ -287,20 +292,15 @@ def parse_xml_object(obj: et.Element) -> t.Dict[str, t.Any]:
e = _atom_elem(i, elem, x, y, z, wobble, frac_occupancy)
db.append(e)

et.indent(db, space=" ", level=0)
et.indent(db, space=" ", level=0) # type: ignore

with open_file(f, 'w') as f:
# hack to specify doctype of output
f.write("""\
<?xml version="1.0" encoding="UTF-8" standalone="yes" ?>
<!DOCTYPE database SYSTEM "file:///System/Library/DTDs/CoreData.dtd">
with open_file_binary(f, 'w') as f:
doctype = b"""<!DOCTYPE database SYSTEM "file:///System/Library/DTDs/CoreData.dtd">\n"""
out.write(f, encoding='UTF-8', xml_declaration=True, standalone=True, doctype=doctype) # type: ignore
f.write(b'\n')

""")
out.write(f, encoding='unicode', xml_declaration=False, short_empty_elements=False)
f.write('\n')


def _atom_elem(i: int, atomic_number: int, x: float, y: float, z: float, wobble: float = 0., frac_occupancy: float = 1.) -> et.Element:
def _atom_elem(i: int, atomic_number: int, x: float, y: float, z: float, wobble: float = 0., frac_occupancy: float = 1.) -> Element:
return et.XML(f"""\
<object type="STRUCTUREATOM" id="atom{i}">
<attribute name="x" type="float">{x:.8f}</attribute>
Expand All @@ -309,4 +309,4 @@ def _atom_elem(i: int, atomic_number: int, x: float, y: float, z: float, wobble:
<attribute name="wobble" type="float">{wobble:.4f}</attribute>
<attribute name="fracoccupancy" type="float">{frac_occupancy:.4f}</attribute>
<attribute name="atomicnumber" type="int16">{atomic_number}</attribute>
</object>""")
</object>""", None)
8 changes: 4 additions & 4 deletions atomlib/io/test_mslice.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy

from .mslice import write_mslice
from ..testing import check_equals_file, INPUT_PATH, OUTPUT_PATH
from ..testing import check_equals_binary_file, INPUT_PATH, OUTPUT_PATH
from ..make import fcc, fluorite
from ..io import read
from ..atomcell import HasAtomCell
Expand All @@ -18,18 +18,18 @@ def ceo2_ortho_cell():
.transform(AffineTransform3D.rotate_euler(x=numpy.pi/2.).translate(y=10.))


@check_equals_file('Al_from_template.mslice')
@check_equals_binary_file('Al_from_template.mslice')
def test_mslice_default_template(buf: StringIO):
cell = fcc('Al', 4.05, cell='conv').with_wobble(0.030)
write_mslice(cell, buf, slice_thickness=2.025)


@check_equals_file('CeO2_ortho_rotated.mslice')
@check_equals_binary_file('CeO2_ortho_rotated.mslice')
def test_mslice_custom_template(buf: StringIO, ceo2_ortho_cell):
write_mslice(ceo2_ortho_cell, buf, template=INPUT_PATH / 'bare_template.mslice')


@check_equals_file('Al_roundtrip.mslice')
@check_equals_binary_file('Al_roundtrip.mslice')
def test_mslice_roundtrip(buf: StringIO):
cell = read(OUTPUT_PATH / 'Al_roundtrip.mslice')
assert isinstance(cell, HasAtomCell)
Expand Down
4 changes: 2 additions & 2 deletions atomlib/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .atomcell import HasAtomCell as _HasAtomCell

from .io import CIF, XYZ, XYZFormat, XSF, CFG, FileType, FileOrPath
from .io.mslice import MSliceFile
from .io.mslice import MSliceFile, BinaryFileOrPath

else:
class _HasAtoms: ...
Expand Down Expand Up @@ -120,7 +120,7 @@ def write(self, path: FileOrPath, ty: t.Optional[FileType] = None):


class AtomCellIOMixin(_HasAtomCell, AtomsIOMixin):
def write_mslice(self, f: FileOrPath, template: t.Optional[MSliceFile] = None, *,
def write_mslice(self, f: BinaryFileOrPath, template: t.Optional[MSliceFile] = None, *,
slice_thickness: t.Optional[float] = None, # angstrom
scan_points: t.Optional[ArrayLike] = None,
scan_extent: t.Optional[ArrayLike] = None,
Expand Down
21 changes: 17 additions & 4 deletions atomlib/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pathlib import Path
import inspect
from io import StringIO
from io import StringIO, BytesIO
import re
import typing as t

Expand Down Expand Up @@ -52,12 +52,25 @@ def assert_files_equal(expected_path: t.Union[str, Path], actual_path: t.Union[s
def check_equals_file(name: t.Union[str, Path]) -> t.Callable[[t.Callable[..., t.Any]], t.Callable[..., None]]:
def decorator(f: t.Callable[..., str]):
@pytest.mark.expected_filename(name)
def wrapper(expected_contents: str, *args, **kwargs): # type: ignore
def wrapper(expected_contents_text: str, *args, **kwargs): # type: ignore
buf = StringIO()
f(buf, *args, **kwargs)
assert buf.getvalue() == expected_contents
assert buf.getvalue() == expected_contents_text

return _wrap_pytest(wrapper, f, lambda params: [inspect.Parameter('expected_contents', inspect.Parameter.POSITIONAL_OR_KEYWORD), *params[1:]])
return _wrap_pytest(wrapper, f, lambda params: [inspect.Parameter('expected_contents_text', inspect.Parameter.POSITIONAL_OR_KEYWORD), *params[1:]])

return decorator


def check_equals_binary_file(name: t.Union[str, Path]) -> t.Callable[[t.Callable[..., t.Any]], t.Callable[..., None]]:
    """
    Build a decorator that checks the bytes written by a test against an expected file.

    The decorated function is called with a fresh ``BytesIO`` as its first
    argument, followed by any remaining arguments. Afterwards, the buffer's
    contents must equal the ``expected_contents_binary`` value passed to the
    wrapper. ``name`` is attached to the wrapper through the
    ``expected_filename`` marker.
    """
    def _swap_first_param(params):
        # Replace the test's leading (buffer) parameter with 'expected_contents_binary'.
        # NOTE(review): presumably `_wrap_pytest` uses this to build the signature
        # pytest sees -- confirm against its definition.
        head = inspect.Parameter('expected_contents_binary', inspect.Parameter.POSITIONAL_OR_KEYWORD)
        return [head, *params[1:]]

    def decorator(f: t.Callable[..., str]):
        @pytest.mark.expected_filename(name)
        def wrapper(expected_contents_binary: bytes, *args, **kwargs):  # type: ignore
            # the test writes into an in-memory binary buffer instead of a real file
            sink = BytesIO()
            f(sink, *args, **kwargs)
            written = sink.getvalue()
            assert written == expected_contents_binary

        return _wrap_pytest(wrapper, f, _swap_first_param)

    return decorator

Expand Down
12 changes: 11 additions & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,20 @@ def expected_structure(request) -> 'HasAtoms':


@pytest.fixture(scope='function')
def expected_contents(request) -> str:
def expected_contents_text(request) -> str:
from atomlib.testing import OUTPUT_PATH

marker = request.node.get_closest_marker('expected_filename')
name = str(marker.args[0])
with open(OUTPUT_PATH / name, 'r') as f:
return f.read()


@pytest.fixture(scope='function')
def expected_contents_binary(request) -> bytes:
    """
    Return the raw bytes of the expected-output file for the requesting test.

    The filename is the first argument of the test's closest
    ``expected_filename`` marker, resolved relative to ``OUTPUT_PATH``.
    """
    # imported inside the fixture body, as in the original
    from atomlib.testing import OUTPUT_PATH

    filename_mark = request.node.get_closest_marker('expected_filename')
    expected_path = OUTPUT_PATH / str(filename_mark.args[0])
    # binary mode: contents are returned undecoded
    with open(expected_path, 'rb') as stream:
        contents = stream.read()
    return contents
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies = [
"polars~=0.17.5",
"matplotlib~=3.5",
"requests~=2.28",
"lxml~=5.0",
"typing-extensions~=4.4;python_version<'3.10'",
"importlib_resources>=5.0", # importlib.resources backport
]
Expand Down
10 changes: 5 additions & 5 deletions tests/baseline_files/Al_from_template.mslice
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<?xml version="1.0" encoding="UTF-8" standalone="yes" ?>
<?xml version='1.0' encoding='UTF-8' standalone='yes'?>
<!DOCTYPE database SYSTEM "file:///System/Library/DTDs/CoreData.dtd">

<database>
Expand Down Expand Up @@ -68,15 +68,15 @@
<attribute name="m" type="int16">0</attribute>
<attribute name="cnmb" type="float">0</attribute>
<attribute name="cnma" type="float">0</attribute>
<relationship name="unit" type="0/1" destination="UNIT" idrefs="z106"></relationship>
<relationship name="unit" type="0/1" destination="UNIT" idrefs="z106"/>
</object>
<object type="ABERRATION" id="z105">
<attribute name="name" type="string">Two-fold Stig</attribute>
<attribute name="n" type="int16">1</attribute>
<attribute name="m" type="int16">2</attribute>
<attribute name="cnmb" type="float">0</attribute>
<attribute name="cnma" type="float">0</attribute>
<relationship name="unit" type="0/1" destination="UNIT" idrefs="z106"></relationship>
<relationship name="unit" type="0/1" destination="UNIT" idrefs="z106"/>
</object>
<object type="UNIT" id="z106">
<attribute name="value" type="int32">1</attribute>
Expand Down Expand Up @@ -111,7 +111,7 @@
<attribute name="m" type="int16">0</attribute>
<attribute name="cnmb" type="float">0</attribute>
<attribute name="cnma" type="float">0</attribute>
<relationship name="unit" type="0/1" destination="UNIT" idrefs="z102"></relationship>
<relationship name="unit" type="0/1" destination="UNIT" idrefs="z102"/>
</object>
<object type="STRUCTURE" id="z110">
<attribute name="tilty" type="float">0</attribute>
Expand All @@ -122,7 +122,7 @@
<attribute name="aparam" type="float">4.05</attribute>
<attribute name="bparam" type="float">4.05</attribute>
<attribute name="cparam" type="float">4.05</attribute>
<relationship name="atoms" type="0/1" destination="STRUCTUREATOM"></relationship>
<relationship name="atoms" type="0/1" destination="STRUCTUREATOM"/>
</object>
<object type="DETECTOR" id="ABF">
<attribute name="outerangle" type="float">30</attribute>
Expand Down
Loading

0 comments on commit 16115ed

Please sign in to comment.