Skip to content

Commit

Permalink
Correctly handle print_info() for api_main()
Browse files Browse the repository at this point in the history
  • Loading branch information
mara004 committed Feb 21, 2024
1 parent 577ff1d commit b933c9b
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 30 deletions.
3 changes: 3 additions & 0 deletions src/ctypesgen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@
# Helper modules
from . import messages

# Entry points
from .__main__ import main, api_main

__version__ = version.VERSION.partition("-")[-1]
VERSION = __version__
PYPDFIUM2_SPECIFIC = True
18 changes: 12 additions & 6 deletions src/ctypesgen/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import argparse
import itertools
from pathlib import Path
from pprint import pformat

from ctypesgen import (
messages as msgs,
Expand All @@ -16,7 +17,9 @@
printer_python,
printer_json,
)

from ctypesgen.printer_python import (
txtpath, PRIVATE_PATHS,
)

# -- Argparse-based entry point --

Expand All @@ -25,7 +28,8 @@
def main(given_argv=sys.argv[1:]):
args = get_parser().parse_args(given_argv)
postparse(args)
main_impl(args, given_argv)
cmd_str = " ".join(["ctypesgen"] + [shlex.quote(txtpath(a)) for a in given_argv])
main_impl(args, cmd_str)

def postparse(args):
args.cppargs = list( itertools.chain(*args.cppargs) )
Expand All @@ -49,9 +53,11 @@ def api_main(args):
real_args = defaults.copy()
real_args.update(args)
real_args = argparse.Namespace(**real_args)
given_argv = "Unknown API Call".split(" ") # FIXME

return main_impl(real_args, given_argv=given_argv)
args_str = str(pformat(args))
for p, x in PRIVATE_PATHS:
args_str = args_str.replace(p, x)
return main_impl(real_args, f"ctypesgen.api_main(\n{args_str}\n)")


# Adapted from https://stackoverflow.com/a/59395868/15547292
Expand All @@ -70,7 +76,7 @@ def _get_parser_requires(parser):

# -- Main implementation --

def main_impl(args, given_argv):
def main_impl(args, cmd_str):

assert args.headers or args.system_headers, "Either --headers or --system-headers required."

Expand Down Expand Up @@ -105,7 +111,7 @@ def main_impl(args, given_argv):
raise RuntimeError("No target members found.")
printer = {"py": printer_python, "json": printer_json}[args.output_language].WrapperPrinter
msgs.status_message(f"Printing to {args.output}.")
printer(args.output, args, data, given_argv)
printer(args.output, args, data, cmd_str)

msgs.status_message("Wrapping complete.")

Expand Down
44 changes: 20 additions & 24 deletions src/ctypesgen/printer_python.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import shlex
import shutil
from pathlib import Path
from textwrap import indent
Expand All @@ -22,20 +21,34 @@ def paragraph_ctx(txt):
file.write(f"\n# -- End {txt} --")
return paragraph_ctx

PRIVATE_PATHS = [(str(Path.home()), "~")]
if Path.cwd() != Path("/"): # don't strip unix root
PRIVATE_PATHS += [(str(Path.cwd()), ".")]
# sort descending by length to avoid interference
PRIVATE_PATHS.sort(key=lambda x: len(x[0]), reverse=True)

def txtpath(s):
# Returns a path string suitable for embedding into the output, with private paths stripped
s = str(s)
for p, x in PRIVATE_PATHS:
if s.startswith(p):
return x + s[len(p):]
return s


# Important: Concerning newlines handling, please read docs/dev_comments.md

class WrapperPrinter:

def __init__(self, outpath, opts, data, argv):
def __init__(self, outpath, opts, data, cmd_str):

self.opts = opts

with outpath.open("w", encoding="utf-8") as self.file:

self.paragraph_ctx = ParagraphCtxFactory(self.file)

self.print_info(argv)
self.print_info(cmd_str)
self.file.write(
"\n\nimport ctypes"
"\nfrom ctypes import *"
Expand Down Expand Up @@ -65,26 +78,11 @@ def __init__(self, outpath, opts, data, argv):

for fp in opts.inserted_files:
self.file.write("\n\n\n")
self._embed_file(fp, f"inserted file '{self._txtpath(fp)}'")
self._embed_file(fp, f"inserted file '{txtpath(fp)}'")

self.file.write("\n")


PRIVATE_PATHS_TABLE = [(str(Path.home()), "~")]
if Path.cwd() != Path("/"): # don't strip unix root
PRIVATE_PATHS_TABLE += [(str(Path.cwd()), ".")]
# sort descending by length to avoid interference
PRIVATE_PATHS_TABLE.sort(key=lambda x: len(x[0]), reverse=True)

@classmethod
def _txtpath(cls, s):
# Returns a path string suitable for embedding into the output, with private paths stripped
s = str(s)
for p, x in cls.PRIVATE_PATHS_TABLE:
if s.startswith(p):
return x + s[len(p):]
return s

def _embed_file(self, fp, desc):
with self.paragraph_ctx(desc), open(fp, "r") as src_fh:
self.file.write("\n\n")
Expand All @@ -99,13 +97,11 @@ def _srcinfo(self, src):
if fp in ("<built-in>", "<command line>"):
self.file.write(f"# {fp}\n")
else:
self.file.write(f"# {self._txtpath(fp)}: {lineno}\n")
self.file.write(f"# {txtpath(fp)}: {lineno}\n")


def print_info(self, argv):
argv = [self._txtpath(a) for a in argv]
argv_str = ' '.join([shlex.quote(a) for a in argv])
self.file.write(f'R"""\nAuto-generated by:\nctypesgen {argv_str}\n"""')
def print_info(self, cmd_str):
self.file.write(f'R"""\nAuto-generated by:\n{cmd_str}\n"""')


def print_loader(self, opts):
Expand Down

0 comments on commit b933c9b

Please sign in to comment.