Skip to content

Commit

Permalink
Resolve IO.output_gen's issues (#155)
Browse files Browse the repository at this point in the history
* Resolve end-of-line sequence issues

* Add a buffer

* Add test

* Fix type annotations

* Rewrite output_gen

* Ensure all child processes are terminated

* Fix test

* Rewrite output_gen test

* Fix type hints; fix some io test

* Use a simpler and more straightforward method to terminate the process tree
  • Loading branch information
weilycoder authored Dec 17, 2024
1 parent 49a998e commit b4faa53
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 82 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
config.py
*.cpp
*.in
*.out
*.exe
Expand Down
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[MASTER]
py-version=3.5
py-version=3.6
disable=R0902,R0903,R0913,R0917,R0912
143 changes: 92 additions & 51 deletions cyaron/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
Classes:
IO: IO tool class. It will process the input and output files.
"""

from __future__ import absolute_import
import os
import re
import signal
import subprocess
import tempfile
from typing import Union, overload, Optional
from typing import Union, overload, Optional, List, cast
from io import IOBase
from . import log
from .utils import list_like, make_unicode
Expand All @@ -18,34 +20,39 @@ class IO:
"""IO tool class. It will process the input and output files."""

@overload
def __init__(self,
input_file: Optional[Union[IOBase, str, int]] = None,
output_file: Optional[Union[IOBase, str, int]] = None,
data_id: Optional[int] = None,
disable_output: bool = False,
make_dirs: bool = False):
def __init__(
self,
input_file: Optional[Union[IOBase, str, int]] = None,
output_file: Optional[Union[IOBase, str, int]] = None,
data_id: Optional[int] = None,
disable_output: bool = False,
make_dirs: bool = False,
):
...

@overload
def __init__(self,
data_id: Optional[int] = None,
file_prefix: Optional[str] = None,
input_suffix: str = '.in',
output_suffix: str = '.out',
disable_output: bool = False,
make_dirs: bool = False):
def __init__(
self,
data_id: Optional[int] = None,
file_prefix: Optional[str] = None,
input_suffix: str = ".in",
output_suffix: str = ".out",
disable_output: bool = False,
make_dirs: bool = False,
):
...

def __init__( # type: ignore
self,
input_file: Optional[Union[IOBase, str, int]] = None,
output_file: Optional[Union[IOBase, str, int]] = None,
data_id: Optional[int] = None,
file_prefix: Optional[str] = None,
input_suffix: str = '.in',
output_suffix: str = '.out',
disable_output: bool = False,
make_dirs: bool = False):
self,
input_file: Optional[Union[IOBase, str, int]] = None,
output_file: Optional[Union[IOBase, str, int]] = None,
data_id: Optional[int] = None,
file_prefix: Optional[str] = None,
input_suffix: str = ".in",
output_suffix: str = ".out",
disable_output: bool = False,
make_dirs: bool = False,
):
"""
Args:
input_file (optional): input file object or filename or file descriptor.
Expand Down Expand Up @@ -84,12 +91,13 @@ def __init__( # type: ignore
# if the dir "./io" not found it will be created
"""
self.__closed = False
self.input_file, self.output_file = None, None
self.input_file = cast(IOBase, None)
self.output_file = None
if file_prefix is not None:
# legacy mode
input_file = '{}{{}}{}'.format(self.__escape_format(file_prefix),
input_file = "{}{{}}{}".format(self.__escape_format(file_prefix),
self.__escape_format(input_suffix))
output_file = '{}{{}}{}'.format(
output_file = "{}{{}}{}".format(
self.__escape_format(file_prefix),
self.__escape_format(output_suffix))
self.input_filename, self.output_filename = None, None
Expand All @@ -101,9 +109,13 @@ def __init__( # type: ignore
self.output_file = None
self.is_first_char = {}

def __init_file(self, f: Union[IOBase, str, int,
None], data_id: Union[int, None],
file_type: str, make_dirs: bool):
def __init_file(
self,
f: Union[IOBase, str, int, None],
data_id: Union[int, None],
file_type: str,
make_dirs: bool,
):
if isinstance(f, IOBase):
# consider ``f`` as a file object
if file_type == "i":
Expand All @@ -112,8 +124,12 @@ def __init_file(self, f: Union[IOBase, str, int,
self.output_file = f
elif isinstance(f, int):
# consider ``f`` as a file descriptor
self.__init_file(open(f, 'w+', encoding="utf-8", newline='\n'),
data_id, file_type, make_dirs)
self.__init_file(
open(f, "w+", encoding="utf-8", newline="\n"),
data_id,
file_type,
make_dirs,
)
elif f is None:
# no file was given: fall back to a temporary file
fd, self.input_filename = tempfile.mkstemp()
Expand All @@ -133,8 +149,11 @@ def __init_file(self, f: Union[IOBase, str, int,
else:
self.output_filename = filename
self.__init_file(
open(filename, 'w+', newline='\n', encoding='utf-8'), data_id,
file_type, make_dirs)
open(filename, "w+", newline="\n", encoding="utf-8"),
data_id,
file_type,
make_dirs,
)

def __escape_format(self, st: str):
"""replace "{}" to "{{}}" """
Expand Down Expand Up @@ -211,6 +230,15 @@ def __clear(self, file: IOBase, pos: int = 0):
self.is_first_char[file] = True
file.seek(pos)

@staticmethod
def _kill_process_and_children(proc: subprocess.Popen):
    """Forcefully terminate ``proc`` together with the processes it spawned.

    On POSIX the whole process group of ``proc`` receives ``SIGKILL``; on
    Windows ``TASKKILL /F /T`` kills the process tree.  On any other
    platform only ``proc`` itself is killed.

    This helper only sends the kill request: it does not wait for the
    process to exit, so reaping it is left to the caller.

    Args:
        proc: the process whose tree should be killed.
    """
    if os.name == "posix":
        try:
            pgid = os.getpgid(proc.pid)
            if pgid == os.getpgrp():
                # ``proc`` shares our own process group (it was not started
                # in a new session/group), so signalling the group would
                # kill the current interpreter too.  Kill only ``proc``.
                proc.kill()
            else:
                os.killpg(pgid, signal.SIGKILL)
        except ProcessLookupError:
            # The process has already exited and been reaped: there is
            # nothing left to kill, and raising here would mask the error
            # the caller is currently handling.
            pass
    elif os.name == "nt":
        os.system(f"TASKKILL /F /T /PID {proc.pid} > nul")
    else:
        proc.kill()  # Not currently supported

def input_write(self, *args, **kwargs):
"""
Write every element in *args into the input file. Splits with `separator`.
Expand Down Expand Up @@ -243,38 +271,49 @@ def input_clear_content(self, pos: int = 0):

self.__clear(self.input_file, pos)

def output_gen(self, shell_cmd, time_limit=None):
def output_gen(self,
shell_cmd: Union[str, List[str]],
time_limit: Optional[float] = None,
*,
replace_EOL: bool = True):
"""
Run the command `shell_cmd` (usually the std program) and send it the input file as stdin.
Write its output to the output file.
Args:
shell_cmd: the command to run, usually the std program.
time_limit: the time limit (seconds) of the command to run.
None means infinity. Defaults to None.
replace_EOL: Set whether to replace the end-of-line sequence with `'\\n'`.
Defaults to True.
"""
if self.output_file is None:
raise ValueError("Output file is disabled")
self.flush_buffer()
origin_pos = self.input_file.tell()
self.input_file.seek(0)
if time_limit is not None:
subprocess.check_call(
shell_cmd,
shell=True,
timeout=time_limit,
stdin=self.input_file.fileno(),
stdout=self.output_file.fileno(),
universal_newlines=True,
)

proc = subprocess.Popen(
shell_cmd,
shell=True,
stdin=self.input_file.fileno(),
stdout=subprocess.PIPE,
universal_newlines=replace_EOL,
preexec_fn=os.setsid if os.name == "posix" else None,
)

try:
output, _ = proc.communicate(timeout=time_limit)
except subprocess.TimeoutExpired:
# A plain `proc.kill()` did not work here: because of `shell=True` it only kills the shell, not the program the shell started.
self._kill_process_and_children(proc)
raise
else:
subprocess.check_call(
shell_cmd,
shell=True,
stdin=self.input_file.fileno(),
stdout=self.output_file.fileno(),
universal_newlines=True,
)
self.input_file.seek(origin_pos)
if replace_EOL:
self.output_file.write(output)
else:
os.write(self.output_file.fileno(), output)
finally:
self.input_file.seek(origin_pos)

log.debug(self.output_filename, " done")

Expand Down Expand Up @@ -309,6 +348,8 @@ def output_clear_content(self, pos: int = 0):
Args:
pos: Where file will truncate
"""
if self.output_file is None:
raise ValueError("Output file is disabled")
self.__clear(self.output_file, pos)

def flush_buffer(self):
Expand Down
66 changes: 37 additions & 29 deletions cyaron/tests/io_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import unittest
import sys
import os
import time
import shutil
import tempfile
import subprocess
Expand Down Expand Up @@ -68,38 +70,47 @@ def test_output_gen(self):
with IO("test_gen.in", "test_gen.out") as test:
test.output_gen("echo 233")

with open("test_gen.out") as f:
with open("test_gen.out", "rb") as f:
output = f.read()
self.assertEqual(output.strip("\n"), "233")
self.assertEqual(output, b"233\n")

def test_output_gen_time_limit_exceeded(self):
time_limit_exceeded = False
with captured_output() as (out, err):
with open("long_time.py", "w") as f:
f.write("import time\ntime.sleep(10)\nprint(1)")
with captured_output():
TIMEOUT = 0.02
WAIT_TIME = 0.4 # If the wait time is too short, an error may occur
with open("long_time.py", "w", encoding="utf-8") as f:
f.write("import time, os\n"
"fn = input()\n"
f"time.sleep({WAIT_TIME})\n"
"os.remove(fn)\n")

try:
with IO("test_gen.in", "test_gen.out") as test:
test.output_gen("python long_time.py", time_limit=1)
except subprocess.TimeoutExpired:
time_limit_exceeded = True
self.assertEqual(time_limit_exceeded, True)
with IO("test_gen.in", "test_gen.out") as test:
fd, input_filename = tempfile.mkstemp()
os.close(fd)
abs_input_filename: str = os.path.abspath(input_filename)
with self.assertRaises(subprocess.TimeoutExpired):
test.input_writeln(abs_input_filename)
test.output_gen(f'"{sys.executable}" long_time.py',
time_limit=TIMEOUT)
time.sleep(WAIT_TIME)
try:
os.remove(input_filename)
except FileNotFoundError:
self.fail("Child processes have not been terminated.")

def test_output_gen_time_limit_not_exceeded(self):
time_limit_exceeded = False
with captured_output() as (out, err):
with open("short_time.py", "w") as f:
f.write("import time\ntime.sleep(0.2)\nprint(1)")

try:
with IO("test_gen.in", "test_gen.out") as test:
test.output_gen("python short_time.py", time_limit=1)
except subprocess.TimeoutExpired:
time_limit_exceeded = True
with open("test_gen.out") as f:
with captured_output():
with open("short_time.py", "w", encoding="utf-8") as f:
f.write("import time\n"
"time.sleep(0.1)\n"
"print(1)")

with IO("test_gen.in", "test_gen.out") as test:
test.output_gen(f'"{sys.executable}" short_time.py',
time_limit=0.5)
with open("test_gen.out", encoding="utf-8") as f:
output = f.read()
self.assertEqual(output.strip("\n"), "1")
self.assertEqual(time_limit_exceeded, False)
self.assertEqual(output, "1\n")

def test_init_overload(self):
with IO(file_prefix="data{", data_id=5) as test:
Expand All @@ -124,10 +135,7 @@ def test_make_dirs(self):

mkdir_false = False
try:
with IO(
"./automkdir_false/data.in",
"./automkdir_false/data.out",
):
with IO("./automkdir_false/data.in", "./automkdir_false/data.out"):
pass
except FileNotFoundError:
mkdir_false = True
Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit b4faa53

Please sign in to comment.