-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0226af0
commit aa6dd0d
Showing
4 changed files
with
297 additions
and
71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
"""The Docker STEMMUS_SCOPE model process wrapper.""" | ||
from PyStemmusScope.config_io import read_config | ||
from pathlib import Path | ||
import os | ||
import docker | ||
|
||
|
||
def make_docker_vols_binds(cfg_file: str) -> tuple[list[str], list[str]]: | ||
"""Make docker volume mounting configs. | ||
Args: | ||
cfg_file: Location of the config file | ||
Returns: | ||
volumes, binds | ||
""" | ||
cfg = read_config(cfg_file) | ||
|
||
volumes = [cfg["OutputPath"], cfg["InputPath"]] | ||
binds = [ | ||
f"{cfg['OutputPath']}:{cfg['OutputPath']}:rw", | ||
f"{cfg['InputPath']}:{cfg['InputPath']}:ro", | ||
] | ||
|
||
if ( | ||
not Path(cfg_file).parent.is_relative_to(cfg["InputPath"]) or | ||
not Path(cfg_file).parent.is_relative_to(cfg["OutputPath"]) | ||
): | ||
cfg_folder = str(Path(cfg_file).parent) | ||
volumes.append(cfg_folder) | ||
binds.append(f"{cfg_folder}:{cfg_folder}:ro") | ||
|
||
return volumes, binds | ||
|
||
|
||
class StemmusScopeDocker: | ||
"""Communicate with a STEMMUS_SCOPE Docker container.""" | ||
# The image is hard coded here to ensure compatiblity: | ||
image = "ghcr.io/ecoextreml/stemmus_scope:1.5.0" | ||
|
||
_process_ready_phrase = b"Select BMI mode:" | ||
|
||
def __init__(self, cfg_file: str): | ||
"""Create the Docker container..""" | ||
self.cfg_file = cfg_file | ||
|
||
self.client = docker.APIClient() | ||
|
||
vols, binds = make_docker_vols_binds(cfg_file) | ||
self.container_id = self.client.create_container( | ||
self.image, | ||
stdin_open=True, | ||
tty=True, | ||
detach=True, | ||
volumes=vols, | ||
host_config=self.client.create_host_config(binds=binds) | ||
) | ||
|
||
self.running = False | ||
|
||
def wait_for_model(self): | ||
"""Wait for the model to be ready to receive (more) commands.""" | ||
output = b"" | ||
|
||
while self._process_ready_phrase not in output: | ||
data = self.socket.read(1) | ||
if data is None: | ||
msg = "Could not read data from socket. Docker container might be dead." | ||
raise ConnectionError(msg) | ||
else: | ||
output += bytes(data) | ||
|
||
def is_alive(self): | ||
"""Return if the process is alive.""" | ||
return self.running | ||
|
||
def initialize(self): | ||
"""Initialize the model and wait for it to be ready.""" | ||
if self.is_alive(): | ||
self.client.stop(self.container_id) | ||
|
||
self.client.start(self.container_id) | ||
self.socket = self.client.attach_socket( | ||
self.container_id, {'stdin': 1, 'stdout': 1, 'stream':1} | ||
) | ||
self.wait_for_model() | ||
os.write( | ||
self.socket.fileno(), | ||
bytes(f'initialize "{self.cfg_file}"\n', encoding="utf-8") | ||
) | ||
self.wait_for_model() | ||
|
||
self.running = True | ||
|
||
def update(self): | ||
"""Update the model and wait for it to be ready.""" | ||
if self.is_alive(): | ||
os.write( | ||
self.socket.fileno(), | ||
b'update\n' | ||
) | ||
self.wait_for_model() | ||
else: | ||
msg = "Docker container is not alive. Please restart the model." | ||
raise ConnectionError(msg) | ||
|
||
def finalize(self): | ||
"""Finalize the model.""" | ||
if self.is_alive(): | ||
os.write(self.socket.fileno(),b'finalize\n') | ||
else: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
"""The local STEMMUS_SCOPE model process wrapper.""" | ||
import subprocess | ||
from typing import Union | ||
from PyStemmusScope.config_io import read_config | ||
import os | ||
|
||
|
||
def is_alive(process: Union[subprocess.Popen, None]) -> subprocess.Popen: | ||
"""Return process if the process is alive, raise an exception if it is not.""" | ||
if process is None: | ||
msg = "Model process does not seem to be open." | ||
raise ConnectionError(msg) | ||
if process.poll() is not None: | ||
msg = f"Model terminated with return code {process.poll()}" | ||
raise ConnectionError(msg) | ||
return process | ||
|
||
|
||
def wait_for_model(process: subprocess.Popen, phrase=b"Select BMI mode:") -> None: | ||
"""Wait for model to be ready for interaction.""" | ||
output = b"" | ||
while is_alive(process) and phrase not in output: | ||
assert process.stdout is not None # required for type narrowing. | ||
output += bytes(process.stdout.read(1)) | ||
|
||
|
||
class LocalStemmusScope: | ||
"""Communicate with the local STEMMUS_SCOPE executable file.""" | ||
def __init__(self, cfg_file: str) -> None: | ||
"""Initialize the process.""" | ||
self.cfg_file = cfg_file | ||
config = read_config(cfg_file) | ||
|
||
exe_file = config["ExeFilePath"] | ||
args = [exe_file, cfg_file, "bmi"] | ||
|
||
os.environ["MATLAB_LOG_DIR"] = str(config["InputPath"]) | ||
|
||
self.matlab_process = subprocess.Popen( | ||
args, | ||
stdin=subprocess.PIPE, | ||
stdout=subprocess.PIPE, | ||
bufsize=0, | ||
) | ||
|
||
wait_for_model(self.matlab_process) | ||
|
||
def is_alive(self) -> bool: | ||
"""Return if the process is alive.""" | ||
try: | ||
is_alive(self.matlab_process) | ||
return True | ||
except ConnectionError: | ||
return False | ||
|
||
def initialize(self) -> None: | ||
"""Initialize the model and wait for it to be ready.""" | ||
self.matlab_process = is_alive(self.matlab_process) | ||
self.matlab_process.stdin.write( | ||
bytes(f'initialize "{self.cfg_file}"\n', encoding="utf-8") # type: ignore | ||
) | ||
wait_for_model(self.matlab_process) | ||
|
||
|
||
def update(self) -> None: | ||
"""Update the model and wait for it to be ready.""" | ||
if self.matlab_process is None: | ||
msg = "Run initialize before trying to update the model." | ||
raise AttributeError(msg) | ||
|
||
self.matlab_process = is_alive(self.matlab_process) | ||
self.matlab_process.stdin.write(b"update\n") # type: ignore | ||
wait_for_model(self.matlab_process) | ||
|
||
|
||
def finalize(self) -> None: | ||
"""Finalize the model.""" | ||
self.matlab_process = is_alive(self.matlab_process) | ||
self.matlab_process.stdin.write(b"finalize\n") # type: ignore | ||
wait_for_model(self.matlab_process, phrase=b"Finished clean up.") |
Oops, something went wrong.