From cb94381e85c235f1e87bf6225edb8522a469b21d Mon Sep 17 00:00:00 2001 From: Yael Balbastre Date: Tue, 5 Oct 2021 13:43:41 -0400 Subject: [PATCH] First commit --- .gitattributes | 1 + MANIFEST.in | 2 + README.md | 274 ++++++ interpol/__init__.py | 3 + interpol/_version.py | 623 +++++++++++++ interpol/api.py | 533 +++++++++++ interpol/autograd.py | 293 ++++++ interpol/bounds.py | 86 ++ interpol/coeff.py | 344 +++++++ interpol/iso0.py | 368 ++++++++ interpol/iso1.py | 1338 ++++++++++++++++++++++++++ interpol/jit_utils.py | 418 +++++++++ interpol/nd.py | 439 +++++++++ interpol/pushpull.py | 323 +++++++ interpol/splines.py | 196 ++++ interpol/utils.py | 101 ++ setup.cfg | 31 + setup.py | 9 + versioneer.py | 2064 +++++++++++++++++++++++++++++++++++++++++ 19 files changed, 7446 insertions(+) create mode 100644 .gitattributes create mode 100644 MANIFEST.in create mode 100644 interpol/__init__.py create mode 100644 interpol/_version.py create mode 100755 interpol/api.py create mode 100644 interpol/autograd.py create mode 100644 interpol/bounds.py create mode 100644 interpol/coeff.py create mode 100644 interpol/iso0.py create mode 100644 interpol/iso1.py create mode 100644 interpol/jit_utils.py create mode 100644 interpol/nd.py create mode 100644 interpol/pushpull.py create mode 100644 interpol/splines.py create mode 100644 interpol/utils.py create mode 100644 setup.cfg create mode 100755 setup.py create mode 100644 versioneer.py diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..6596e27 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +interpol/_version.py export-subst diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..4d09922 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +include versioneer.py +include interpol/_version.py diff --git a/README.md b/README.md index b80a406..b95aac8 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,276 @@ # torch-interpol High-order spline interpolation in PyTorch + +## Description + 
+This package contains a pure python implementation of **high-order spline +interpolation** for ND tensors (including 2D and 3D images). It makes use +of the just-in-time capabilities of TorchScript and explicitly implements +the forward and backward passes of all functions, making it **fast** and +**memory-efficient**. + +All the functions available in this (small) package were originally +implemented in [NITorch](https://github.com/balbasty/nitorch), a larger +PyTorch-based package dedicated to NeuroImaging and Medical Image Computing. + +## Installation + +```shell +pip install git+https://github.com/balbasty/torch-interpol + +# Or, alternatively +git clone git@github.com:balbasty/torch-interpol.git +pip install ./torch-interpol +``` + +## Usage + +The most useful function is `grid_pull`, which samples an image at a given +set of coordinates according to some spline order. Here's a small example +that shows how to reslice an image to a different image space: +```python +# we are going to rotate and resample a 32x32 pixels square +import torch, math +import matplotlib.pyplot as plt +from interpol import grid_pull, affine_grid + +# generate a "square" phantom image +x = torch.zeros([64, 64]) +x[16:48, 16:48] = 1 + +# build rotation matrix +rot = [[math.cos(math.pi/4), -math.sin(math.pi/4), 0], + [math.sin(math.pi/4), math.cos(math.pi/4), 0], + [0, 0, 1]] +center = [[1, 0, -32], + [0, 1, -32], + [0, 0, 1]] +rot = torch.as_tensor(rot, dtype=torch.float) +center = torch.as_tensor(center, dtype=torch.float) +full_affine = center.inverse() @ rot @ center + +# build dense field of sampling coordinates +grid = affine_grid(full_affine, [64, 64]) + +# resample +y1 = grid_pull(x, grid, bound='mirror', interpolation=1) +y3 = grid_pull(x, grid, bound='mirror', interpolation=3, prefilter=True) +y5 = grid_pull(x, grid, bound='mirror', interpolation=5, prefilter=True) + +# plot +plt.subplot(1, 4, 1) +plt.imshow(x, vmin=0, vmax=1) +plt.axis('off') +plt.title('original')
+plt.subplot(1, 4, 2) +plt.imshow(y1, vmin=0, vmax=1) +plt.axis('off') +plt.title('1st order') +plt.subplot(1, 4, 3) +plt.imshow(y3, vmin=0, vmax=1) +plt.axis('off') +plt.title('3rd order') +plt.subplot(1, 4, 4) +plt.imshow(y5, vmin=0, vmax=1) +plt.axis('off') +plt.title('5th order') +plt.show() +``` + +## Quick doc + + +``` +Notes +----- + +`interpolation` can be an int, a string or an InterpolationType. +Possible values are: + - 0 or 'nearest' + - 1 or 'linear' + - 2 or 'quadratic' + - 3 or 'cubic' + - 4 or 'fourth' + - 5 or 'fifth' + - etc. +A list of values can be provided, in the order [W, H, D], +to specify dimension-specific interpolation orders. + +`bound` can be an int, a string or a BoundType. +Possible values are: + - 'replicate' or 'nearest' : a a a | a b c d | d d d + - 'dct1' or 'mirror' : d c b | a b c d | c b a + - 'dct2' or 'reflect' : c b a | a b c d | d c b + - 'dst1' or 'antimirror' : -b -a 0 | a b c d | 0 -d -c + - 'dst2' or 'antireflect' : -c -b -a | a b c d | -d -c -b + - 'dft' or 'wrap' : b c d | a b c d | a b c + - 'zero' or 'zeros' : 0 0 0 | a b c d | 0 0 0 +A list of values can be provided, in the order [W, H, D], +to specify dimension-specific boundary conditions. +Note that +- `dft` corresponds to circular padding +- `dct2` corresponds to Neumann boundary conditions (symmetric) +- `dst2` corresponds to Dirichlet boundary conditions (antisymmetric) +See https://en.wikipedia.org/wiki/Discrete_cosine_transform + https://en.wikipedia.org/wiki/Discrete_sine_transform +``` + +```python +interpol.grid_pull( + input, + grid, + interpolation='linear', + bound='zero', + extrapolate=False, + prefilter=False, +) +""" +Sample an image with respect to a deformation field. + +If the input dtype is not a floating point type, the input image is +assumed to contain labels. Then, unique labels are extracted +and resampled individually, making them soft labels. 
Finally, +the label map is reconstructed from the individual soft labels by +assigning the label with maximum soft value. + +Parameters +---------- +input : (..., [channel], *inshape) tensor + Input image. +grid : (..., *outshape, dim) tensor + Transformation field. +interpolation : int or sequence[int], default=1 + Interpolation order. +bound : BoundType or sequence[BoundType], default='zero' + Boundary conditions. +extrapolate : bool or int, default=True + Extrapolate out-of-bound data. +prefilter : bool, default=False + Apply spline pre-filter (= interpolates the input) + +Returns +------- +output : (..., [channel], *outshape) tensor + Deformed image. +""" +``` + +```python +interpol.grid_push( + input, + grid, + shape=None, + interpolation='linear', + bound='zero', + extrapolate=False, + prefilter=False, +) +""" +Splat an image with respect to a deformation field (pull adjoint). + +Parameters +---------- +input : (..., [channel], *inshape) tensor + Input image. +grid : (..., *inshape, dim) tensor + Transformation field. +shape : sequence[int], default=inshape + Output shape +interpolation : int or sequence[int], default=1 + Interpolation order. +bound : BoundType, or sequence[BoundType], default='zero' + Boundary conditions. +extrapolate : bool or int, default=True + Extrapolate out-of-bound data. +prefilter : bool, default=False + Apply spline pre-filter. + +Returns +------- +output : (..., [channel], *shape) tensor + Spatted image. +""" +``` + +```python +interpol.grid_grad( + input, + grid, + interpolation='linear', + bound='zero', + extrapolate=False, + prefilter=False, +) +""" +Sample spatial gradients of an image with respect to a deformation field. + +Parameters +---------- +input : (..., [channel], *inshape) tensor + Input image. +grid : (..., *inshape, dim) tensor + Transformation field. +shape : sequence[int], default=inshape + Output shape +interpolation : int or sequence[int], default=1 + Interpolation order. 
+bound : BoundType, or sequence[BoundType], default='zero' + Boundary conditions. +extrapolate : bool or int, default=True + Extrapolate out-of-bound data. +prefilter : bool, default=False + Apply spline pre-filter (= interpolates the input) + +Returns +------- +output : (..., [channel], *shape, dim) tensor + Sampled gradients. +""" +``` + +```python +interpol.spline_coeff_nd( + input, + interpolation='linear', + bound='dct2', + dim=None, + inplace=False, +) +""" +Compute the interpolating spline coefficients, for a given spline order +and boundary conditions, along the last `dim` dimensions. + +References +---------- +..[1] M. Unser, A. Aldroubi and M. Eden. + "B-Spline Signal Processing: Part I-Theory," + IEEE Transactions on Signal Processing 41(2):821-832 (1993). +..[2] M. Unser, A. Aldroubi and M. Eden. + "B-Spline Signal Processing: Part II-Efficient Design and Applications," + IEEE Transactions on Signal Processing 41(2):834-848 (1993). +..[3] M. Unser. + "Splines: A Perfect Fit for Signal and Image Processing," + IEEE Signal Processing Magazine 16(6):22-38 (1999). + +Parameters +---------- +input : (..., *spatial) tensor + Input image. +interpolation : int or sequence[int], default=1 + Interpolation order. +bound : BoundType or sequence[BoundType], default='dct1' + Boundary conditions. +dim : int, default=-1 + Number of spatial dimensions +inplace : bool, default=False + Process the volume in place. + +Returns +------- +output : (..., *spatial) tensor + Coefficient image. +""" +``` +## License + +torch-interpol is released under the MIT license. diff --git a/interpol/__init__.py b/interpol/__init__.py new file mode 100644 index 0000000..01efa79 --- /dev/null +++ b/interpol/__init__.py @@ -0,0 +1,3 @@ +from .api import * +from . 
import _version +__version__ = _version.get_versions()['version'] diff --git a/interpol/_version.py b/interpol/_version.py new file mode 100644 index 0000000..d2cc98b --- /dev/null +++ b/interpol/_version.py @@ -0,0 +1,623 @@ + +# This file helps to compute a version number in source trees obtained from +# git-archive tarball (such as those provided by githubs download-from-tag +# feature). Distribution tarballs (built by setup.py sdist) and build +# directories (produced by setup.py build) will contain a much shorter file +# that just contains the computed version number. + +# This file is released into the public domain. Generated by +# versioneer-0.20 (https://github.com/python-versioneer/python-versioneer) + +"""Git implementation of _version.py.""" + +import errno +import os +import re +import subprocess +import sys + + +def get_keywords(): + """Get the keywords needed to look up the version information.""" + # these strings will be replaced by git during git-archive. + # setup.py/versioneer.py will grep for the variable names, so they must + # each be defined on a line of their own. _version.py will just call + # get_keywords(). 
+ git_refnames = "$Format:%d$" + git_full = "$Format:%H$" + git_date = "$Format:%ci$" + keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} + return keywords + + +class VersioneerConfig: # pylint: disable=too-few-public-methods + """Container for Versioneer configuration parameters.""" + + +def get_config(): + """Create, populate and return the VersioneerConfig() object.""" + # these strings are filled in when 'setup.py versioneer' creates + # _version.py + cfg = VersioneerConfig() + cfg.VCS = "git" + cfg.style = "pep440" + cfg.tag_prefix = "" + cfg.parentdir_prefix = "" + cfg.versionfile_source = "interpol/_version.py" + cfg.verbose = False + return cfg + + +class NotThisMethod(Exception): + """Exception raised if a method is not valid for the current scenario.""" + + +LONG_VERSION_PY = {} +HANDLERS = {} + + +def register_vcs_handler(vcs, method): # decorator + """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f): + """Store f in HANDLERS[vcs][method].""" + if vcs not in HANDLERS: + HANDLERS[vcs] = {} + HANDLERS[vcs][method] = f + return f + return decorate + + +# pylint:disable=too-many-arguments,consider-using-with # noqa +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, + env=None): + """Call the given command(s).""" + assert isinstance(commands, list) + process = None + for command in commands: + try: + dispcmd = str([command] + args) + # remember shell=False, so use git.cmd on windows, not just git + process = subprocess.Popen([command] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None)) + break + except EnvironmentError: + e = sys.exc_info()[1] + if e.errno == errno.ENOENT: + continue + if verbose: + print("unable to run %s" % dispcmd) + print(e) + return None, None + else: + if verbose: + print("unable to find command, tried %s" % (commands,)) + return None, None + stdout = process.communicate()[0].strip().decode() + if 
process.returncode != 0: + if verbose: + print("unable to run %s (error)" % dispcmd) + print("stdout was %s" % stdout) + return None, process.returncode + return stdout, process.returncode + + +def versions_from_parentdir(parentdir_prefix, root, verbose): + """Try to determine the version from the parent directory name. + + Source tarballs conventionally unpack into a directory that includes both + the project name and a version string. We will also support searching up + two directory levels for an appropriately named parent directory + """ + rootdirs = [] + + for _ in range(3): + dirname = os.path.basename(root) + if dirname.startswith(parentdir_prefix): + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None, "date": None} + rootdirs.append(root) + root = os.path.dirname(root) # up a level + + if verbose: + print("Tried directories %s but none started with prefix %s" % + (str(rootdirs), parentdir_prefix)) + raise NotThisMethod("rootdir doesn't start with parentdir_prefix") + + +@register_vcs_handler("git", "get_keywords") +def git_get_keywords(versionfile_abs): + """Extract version information from the given file.""" + # the code embedded in _version.py can just fetch the value of these + # keywords. When used from setup.py, we don't want to import _version.py, + # so we do it with a regexp instead. This function is not used from + # _version.py. 
+ keywords = {} + try: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except EnvironmentError: + pass + return keywords + + +@register_vcs_handler("git", "keywords") +def git_versions_from_keywords(keywords, tag_prefix, verbose): + """Get version information from git keywords.""" + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") + date = keywords.get("date") + if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant + # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 + # -like" string, which we must then edit to make compliant), because + # it's been around since git-1.5.3, and it's too difficult to + # discover which version we're using, or to work around using an + # older one. + date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + refnames = keywords["refnames"].strip() + if refnames.startswith("$Format"): + if verbose: + print("keywords are unexpanded, not using") + raise NotThisMethod("unexpanded keywords, not a git-archive tarball") + refs = {r.strip() for r in refnames.strip("()").split(",")} + # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of + # just "foo-1.0". If we see a "tag: " prefix, prefer those. + TAG = "tag: " + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} + if not tags: + # Either we're using git < 1.8.3, or there really are no tags. We use + # a heuristic: assume all version tags have a digit. 
The old git %d + # expansion behaves like git log --decorate=short and strips out the + # refs/heads/ and refs/tags/ prefixes that would let us distinguish + # between branches and tags. By ignoring refnames without digits, we + # filter out many common branch names like "release" and + # "stabilization", as well as "HEAD" and "master". + tags = {r for r in refs if re.search(r'\d', r)} + if verbose: + print("discarding '%s', no digits" % ",".join(refs - tags)) + if verbose: + print("likely tags: %s" % ",".join(sorted(tags))) + for ref in sorted(tags): + # sorting will prefer e.g. "2.0" over "2.0rc1" + if ref.startswith(tag_prefix): + r = ref[len(tag_prefix):] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r'\d', r): + continue + if verbose: + print("picking %s" % r) + return {"version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None, + "date": date} + # no suitable tags, so version is "0+unknown", but full hex is still there + if verbose: + print("no suitable tags, using unknown + full revision id") + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags", "date": None} + + +@register_vcs_handler("git", "pieces_from_vcs") +def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): + """Get version from 'git describe' in the root of the source tree. + + This only gets called if the git-archive 'subst' keywords were *not* + expanded, and _version.py hasn't already been rewritten with a short + version string, meaning we're inside a checked out source tree. 
+ """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=True) + if rc != 0: + if verbose: + print("Directory %s not under git control" % root) + raise NotThisMethod("'git rev-parse --git-dir' returned error") + + # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] + # if there isn't one, this yields HEX[-dirty] (no NUM) + describe_out, rc = runner(GITS, ["describe", "--tags", "--dirty", + "--always", "--long", + "--match", "%s*" % tag_prefix], + cwd=root) + # --long was added in git-1.5.5 + if describe_out is None: + raise NotThisMethod("'git describe' failed") + describe_out = describe_out.strip() + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) + if full_out is None: + raise NotThisMethod("'git rev-parse' failed") + full_out = full_out.strip() + + pieces = {} + pieces["long"] = full_out + pieces["short"] = full_out[:7] # maybe improved later + pieces["error"] = None + + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. 
+ branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] + # TAG might have hyphens. + git_describe = describe_out + + # look for -dirty suffix + dirty = git_describe.endswith("-dirty") + pieces["dirty"] = dirty + if dirty: + git_describe = git_describe[:git_describe.rindex("-dirty")] + + # now we have TAG-NUM-gHEX or HEX + + if "-" in git_describe: + # TAG-NUM-gHEX + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + if not mo: + # unparseable. Maybe git-describe is misbehaving? + pieces["error"] = ("unable to parse git-describe output: '%s'" + % describe_out) + return pieces + + # tag + full_tag = mo.group(1) + if not full_tag.startswith(tag_prefix): + if verbose: + fmt = "tag '%s' doesn't start with prefix '%s'" + print(fmt % (full_tag, tag_prefix)) + pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" + % (full_tag, tag_prefix)) + return pieces + pieces["closest-tag"] = full_tag[len(tag_prefix):] + + # distance: number of commits since tag + pieces["distance"] = int(mo.group(2)) + + # commit: short hex revision ID + pieces["short"] = mo.group(3) + + else: + # HEX: no tags + pieces["closest-tag"] = None + count_out, rc = runner(GITS, ["rev-list", "HEAD", "--count"], cwd=root) + pieces["distance"] = int(count_out) # total number of commits + + # commit date: see ISO-8601 comment in git_versions_from_keywords() + date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. 
+ date = date.splitlines()[-1] + pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + + return pieces + + +def plus_or_dot(pieces): + """Return a + if we don't already have one, else return a .""" + if "+" in pieces.get("closest-tag", ""): + return "." + return "+" + + +def render_pep440(pieces): + """Build up version string, with post-release "local version identifier". + + Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you + get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty + + Exceptions: + 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_branch(pieces): + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). + + Exceptions: + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_pre(pieces): + """TAG[.post0.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 
0.post0.devDISTANCE + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += ".post0.dev%d" % pieces["distance"] + else: + # exception #1 + rendered = "0.post0.dev%d" % pieces["distance"] + return rendered + + +def render_pep440_post(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX] . + + The ".dev0" means dirty. Note that .dev0 sorts backwards + (a dirty tree will appear "older" than the corresponding clean one), + but you shouldn't be releasing software with -dirty anyways. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + return rendered + + +def render_pep440_post_branch(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces): + """TAG[.postDISTANCE[.dev0]] . + + The ".dev0" means dirty. + + Exceptions: + 1: no tags. 
0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + return rendered + + +def render_git_describe(pieces): + """TAG[-DISTANCE-gHEX][-dirty]. + + Like 'git describe --tags --dirty --always'. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render_git_describe_long(pieces): + """TAG-DISTANCE-gHEX[-dirty]. + + Like 'git describe --tags --dirty --always -long'. + The distance/hash is unconditional. + + Exceptions: + 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render(pieces, style): + """Render the given version pieces into the requested style.""" + if pieces["error"]: + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None} + + if not style or style == "default": + style = "pep440" # the default + + if style == "pep440": + rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) + elif style == "pep440-pre": + rendered = render_pep440_pre(pieces) + elif style == "pep440-post": + rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) + elif style == "pep440-old": + rendered = render_pep440_old(pieces) + elif style == "git-describe": + rendered = render_git_describe(pieces) + elif style == "git-describe-long": + rendered = render_git_describe_long(pieces) + else: + raise ValueError("unknown style '%s'" % style) + + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None, + "date": pieces.get("date")} + + +def get_versions(): + """Get version information or return default if unable to do so.""" + # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have + # __file__, we can work backwards from there to the root. Some + # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which + # case we can only use expanded keywords. 
+ + cfg = get_config() + verbose = cfg.verbose + + try: + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, + verbose) + except NotThisMethod: + pass + + try: + root = os.path.realpath(__file__) + # versionfile_source is the relative path from the top of the source + # tree (where the .git directory might live) to this file. Invert + # this to find the root from __file__. + for _ in cfg.versionfile_source.split('/'): + root = os.path.dirname(root) + except NameError: + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None} + + try: + pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) + return render(pieces, cfg.style) + except NotThisMethod: + pass + + try: + if cfg.parentdir_prefix: + return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) + except NotThisMethod: + pass + + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", "date": None} diff --git a/interpol/api.py b/interpol/api.py new file mode 100755 index 0000000..df4e2df --- /dev/null +++ b/interpol/api.py @@ -0,0 +1,533 @@ +"""High level interpolation API""" +import torch +from .utils import expanded_shape, matvec +from .jit_utils import movedim1 +from .autograd import (GridPull, GridPush, GridCount, GridGrad, + SplineCoeff, SplineCoeffND) + +_doc_interpolation = \ +"""`interpolation` can be an int, a string or an InterpolationType. + Possible values are: + - 0 or 'nearest' + - 1 or 'linear' + - 2 or 'quadratic' + - 3 or 'cubic' + - 4 or 'fourth' + - 5 or 'fifth' + - etc. + A list of values can be provided, in the order [W, H, D], + to specify dimension-specific interpolation orders.""" + +_doc_bound = \ +"""`bound` can be an int, a string or a BoundType. 
+ Possible values are: + - 'replicate' or 'nearest' : a a a | a b c d | d d d + - 'dct1' or 'mirror' : d c b | a b c d | c b a + - 'dct2' or 'reflect' : c b a | a b c d | d c b + - 'dst1' or 'antimirror' : -b -a 0 | a b c d | 0 -d -c + - 'dst2' or 'antireflect' : -c -b -a | a b c d | -d -c -b + - 'dft' or 'wrap' : b c d | a b c d | a b c + - 'zero' or 'zeros' : 0 0 0 | a b c d | 0 0 0 + A list of values can be provided, in the order [W, H, D], + to specify dimension-specific boundary conditions. + Note that + - `dft` corresponds to circular padding + - `dct2` corresponds to Neumann boundary conditions (symmetric) + - `dst2` corresponds to Dirichlet boundary conditions (antisymmetric) + See https://en.wikipedia.org/wiki/Discrete_cosine_transform + https://en.wikipedia.org/wiki/Discrete_sine_transform""" + +_doc_bound_coeff = \ +"""`bound` can be an int, a string or a BoundType. + Possible values are: + - 'replicate' or 'nearest' : a a a | a b c d | d d d + - 'dct1' or 'mirror' : d c b | a b c d | c b a + - 'dct2' or 'reflect' : c b a | a b c d | d c b + - 'dst1' or 'antimirror' : -b -a 0 | a b c d | 0 -d -c + - 'dst2' or 'antireflect' : -c -b -a | a b c d | -d -c -b + - 'dft' or 'wrap' : b c d | a b c d | a b c + - 'zero' or 'zeros' : 0 0 0 | a b c d | 0 0 0 + A list of values can be provided, in the order [W, H, D], + to specify dimension-specific boundary conditions. + Note that + - `dft` corresponds to circular padding + - `dct1` corresponds to mirroring about the center of the first/last voxel + - `dct2` corresponds to mirroring about the edge of the first/last voxel + See https://en.wikipedia.org/wiki/Discrete_cosine_transform + https://en.wikipedia.org/wiki/Discrete_sine_transform + + /!\ Only 'dct1', 'dct2' and 'dft' are implemented for interpolation + orders >= 6.""" + + +_ref_coeff = \ +"""..[1] M. Unser, A. Aldroubi and M. Eden. + "B-Spline Signal Processing: Part I-Theory," + IEEE Transactions on Signal Processing 41(2):821-832 (1993). +..[2] M. Unser, A. 
Aldroubi and M. Eden. + "B-Spline Signal Processing: Part II-Efficient Design and Applications," + IEEE Transactions on Signal Processing 41(2):834-848 (1993). +..[3] M. Unser. + "Splines: A Perfect Fit for Signal and Image Processing," + IEEE Signal Processing Magazine 16(6):22-38 (1999). +""" + + +def _preproc(grid, input=None, mode=None): + """Preprocess tensors for pull/push/count/grad + + Low level bindings expect inputs of shape + [batch, channel, *spatial] and [batch, *spatial, dim], whereas + the high level python API accepts inputs of shape + [..., [channel], *spatial] and [..., *spatial, dim]. + + This function broadcasts and reshapes the input tensors accordingly. + /!\\ This *can* trigger large allocations /!\\ + """ + dim = grid.shape[-1] + if input is None: + spatial = grid.shape[-dim-1:-1] + batch = grid.shape[:-dim-1] + grid = grid.reshape([-1, *spatial, dim]) + info = dict(batch=batch, channel=[1] if batch else [], dim=dim) + return grid, info + + grid_spatial = grid.shape[-dim-1:-1] + grid_batch = grid.shape[:-dim-1] + input_spatial = input.shape[-dim:] + channel = 0 if input.dim() == dim else input.shape[-dim-1] + input_batch = input.shape[:-dim-1] + + if mode == 'push': + grid_spatial = input_spatial = expanded_shape(grid_spatial, input_spatial) + + # broadcast and reshape + batch = expanded_shape(grid_batch, input_batch) + grid = grid.expand([*batch, *grid_spatial, dim]) + grid = grid.reshape([-1, *grid_spatial, dim]) + input = input.expand([*batch, channel or 1, *input_spatial]) + input = input.reshape([-1, channel or 1, *input_spatial]) + + out_channel = [channel] if channel else ([1] if batch else []) + info = dict(batch=batch, channel=out_channel, dim=dim) + return grid, input, info + + +def _postproc(out, shape_info, mode): + """Postprocess tensors for pull/push/count/grad""" + dim = shape_info['dim'] + if mode != 'grad': + spatial = out.shape[-dim:] + feat = [] + else: + spatial = out.shape[-dim-1:-1] + feat = [out.shape[-1]] + batch = 
shape_info['batch'] + channel = shape_info['channel'] + + out = out.reshape([*batch, *channel, *spatial, *feat]) + return out + + +def grid_pull(input, grid, interpolation='linear', bound='zero', + extrapolate=False, prefilter=False): + """Sample an image with respect to a deformation field. + + Notes + ----- + {interpolation} + + {bound} + + If the input dtype is not a floating point type, the input image is + assumed to contain labels. Then, unique labels are extracted + and resampled individually, making them soft labels. Finally, + the label map is reconstructed from the individual soft labels by + assigning the label with maximum soft value. + + Parameters + ---------- + input : (..., [channel], *inshape) tensor + Input image. + grid : (..., *outshape, dim) tensor + Transformation field. + interpolation : int or sequence[int], default=1 + Interpolation order. + bound : BoundType or sequence[BoundType], default='zero' + Boundary conditions. + extrapolate : bool or int, default=True + Extrapolate out-of-bound data. + prefilter : bool, default=False + Apply spline pre-filter (= interpolates the input) + + Returns + ------- + output : (..., [channel], *outshape) tensor + Deformed image. 
+ + """ + grid, input, shape_info = _preproc(grid, input) + batch, channel = input.shape[:2] + dim = grid.shape[-1] + + if not input.dtype.is_floating_point: + # label map -> specific processing + out = input.new_zeros([batch, channel, *grid.shape[1:-1]]) + pmax = grid.new_zeros([batch, channel, *grid.shape[1:-1]]) + for label in input.unique(): + soft = (input == label).to(grid.dtype) + if prefilter: + input = spline_coeff_nd(soft, interpolation=interpolation, + bound=bound, dim=dim, inplace=True) + soft = GridPull.apply(soft, grid, interpolation, bound, extrapolate) + out[soft > pmax] = label + pmax = torch.max(pmax, soft) + else: + if prefilter: + input = spline_coeff_nd(input, interpolation=interpolation, + bound=bound, dim=dim) + out = GridPull.apply(input, grid, interpolation, bound, extrapolate) + + return _postproc(out, shape_info, mode='pull') + + +def grid_push(input, grid, shape=None, interpolation='linear', bound='zero', + extrapolate=False, prefilter=False): + """Splat an image with respect to a deformation field (pull adjoint). + + Notes + ----- + {interpolation} + + {bound} + + Parameters + ---------- + input : (..., [channel], *inshape) tensor + Input image. + grid : (..., *inshape, dim) tensor + Transformation field. + shape : sequence[int], default=inshape + Output shape + interpolation : int or sequence[int], default=1 + Interpolation order. + bound : BoundType, or sequence[BoundType], default='zero' + Boundary conditions. + extrapolate : bool or int, default=True + Extrapolate out-of-bound data. + prefilter : bool, default=False + Apply spline pre-filter. + + Returns + ------- + output : (..., [channel], *shape) tensor + Spatted image. 
+ + """ + grid, input, shape_info = _preproc(grid, input, mode='push') + dim = grid.shape[-1] + + if shape is None: + shape = tuple(input.shape[2:]) + + out = GridPush.apply(input, grid, shape, interpolation, bound, extrapolate) + if prefilter: + out = spline_coeff_nd(out, interpolation=interpolation, bound=bound, + dim=dim, inplace=True) + return _postproc(out, shape_info, mode='push') + + +def grid_count(grid, shape=None, interpolation='linear', bound='zero', + extrapolate=False): + """Splatting weights with respect to a deformation field (pull adjoint). + + Notes + ----- + {interpolation} + + {bound} + + Parameters + ---------- + grid : (..., *inshape, dim) tensor + Transformation field. + shape : sequence[int], default=inshape + Output shape + interpolation : int or sequence[int], default=1 + Interpolation order. + bound : BoundType, or sequence[BoundType], default='zero' + Boundary conditions. + extrapolate : bool or int, default=True + Extrapolate out-of-bound data. + + Returns + ------- + output : (..., [1], *shape) tensor + Splatted weights. + + """ + grid, shape_info = _preproc(grid) + out = GridCount.apply(grid, shape, interpolation, bound, extrapolate) + return _postproc(out, shape_info, mode='count') + + +def grid_grad(input, grid, interpolation='linear', bound='zero', + extrapolate=False, prefilter=False): + """Sample spatial gradients of an image with respect to a deformation field. + + Notes + ----- + {interpolation} + + {bound} + + Parameters + ---------- + input : (..., [channel], *inshape) tensor + Input image. + grid : (..., *inshape, dim) tensor + Transformation field. + shape : sequence[int], default=inshape + Output shape + interpolation : int or sequence[int], default=1 + Interpolation order. + bound : BoundType, or sequence[BoundType], default='zero' + Boundary conditions. + extrapolate : bool or int, default=True + Extrapolate out-of-bound data. 
+ prefilter : bool, default=False + Apply spline pre-filter (= interpolates the input) + + Returns + ------- + output : (..., [channel], *shape, dim) tensor + Sampled gradients. + + """ + grid, input, shape_info = _preproc(grid, input) + dim = grid.shape[-1] + if prefilter: + input = spline_coeff_nd(input, interpolation, bound, dim) + out = GridGrad.apply(input, grid, interpolation, bound, extrapolate) + return _postproc(out, shape_info, mode='grad') + + +def spline_coeff(input, interpolation='linear', bound='dct2', dim=-1, + inplace=False): + """Compute the interpolating spline coefficients, for a given spline order + and boundary conditions, along a single dimension. + + Notes + ----- + {interpolation} + + {bound} + + References + ---------- + {ref} + + + Parameters + ---------- + input : tensor + Input image. + interpolation : int or sequence[int], default=1 + Interpolation order. + bound : BoundType or sequence[BoundType], default='dct1' + Boundary conditions. + dim : int, default=-1 + Dimension along which to process + inplace : bool, default=False + Process the volume in place. + + Returns + ------- + output : tensor + Coefficient image. + + """ + # This implementation is based on the file bsplines.c in SPM12, written + # by John Ashburner, which is itself based on the file coeff.c, + # written by Philippe Thevenaz: http://bigwww.epfl.ch/thevenaz/interpolation + # . DCT1 boundary conditions were derived by Thevenaz and Unser. + # . DFT boundary conditions were derived by John Ashburner. + # SPM12 is released under the GNU-GPL v2 license. + # Philippe Thevenaz's code does not have an explicit license as far + # as we know. + out = SplineCoeff.apply(input, bound, interpolation, dim, inplace) + return out + + +def spline_coeff_nd(input, interpolation='linear', bound='dct2', dim=None, + inplace=False): + """Compute the interpolating spline coefficients, for a given spline order + and boundary conditions, along the last `dim` dimensions. 
+ + Notes + ----- + {interpolation} + + {bound} + + References + ---------- + {ref} + + Parameters + ---------- + input : (..., *spatial) tensor + Input image. + interpolation : int or sequence[int], default=1 + Interpolation order. + bound : BoundType or sequence[BoundType], default='dct1' + Boundary conditions. + dim : int, default=-1 + Number of spatial dimensions + inplace : bool, default=False + Process the volume in place. + + Returns + ------- + output : (..., *spatial) tensor + Coefficient image. + + """ + # This implementation is based on the file bsplines.c in SPM12, written + # by John Ashburner, which is itself based on the file coeff.c, + # written by Philippe Thevenaz: http://bigwww.epfl.ch/thevenaz/interpolation + # . DCT1 boundary conditions were derived by Thevenaz and Unser. + # . DFT boundary conditions were derived by John Ashburner. + # SPM12 is released under the GNU-GPL v2 license. + # Philippe Thevenaz's code does not have an explicit license as far + # as we know. + out = SplineCoeffND.apply(input, bound, interpolation, dim, inplace) + return out + + +grid_pull.__doc__ = grid_pull.__doc__.format( + interpolation=_doc_interpolation, bound=_doc_bound) +grid_push.__doc__ = grid_push.__doc__.format( + interpolation=_doc_interpolation, bound=_doc_bound) +grid_count.__doc__ = grid_count.__doc__.format( + interpolation=_doc_interpolation, bound=_doc_bound) +grid_grad.__doc__ = grid_grad.__doc__.format( + interpolation=_doc_interpolation, bound=_doc_bound) +spline_coeff.__doc__ = spline_coeff.__doc__.format( + interpolation=_doc_interpolation, bound=_doc_bound_coeff, ref=_ref_coeff) +spline_coeff_nd.__doc__ = spline_coeff_nd.__doc__.format( + interpolation=_doc_interpolation, bound=_doc_bound_coeff, ref=_ref_coeff) + +# aliases +pull = grid_pull +push = grid_push +count = grid_count + + +def identity_grid(shape, dtype=None, device=None): + """Returns an identity deformation field. 
+ + Parameters + ---------- + shape : (dim,) sequence of int + Spatial dimension of the field. + dtype : torch.dtype, default=`get_default_dtype()` + Data type. + device torch.device, optional + Device. + + Returns + ------- + grid : (*shape, dim) tensor + Transformation field + + """ + mesh1d = [torch.arange(float(s), dtype=dtype, device=device) + for s in shape] + grid = torch.meshgrid(*mesh1d) + grid = torch.stack(grid, dim=-1) + return grid + + +@torch.jit.script +def add_identity_grid_(disp): + """Adds the identity grid to a displacement field, inplace. + + Parameters + ---------- + disp : (..., *spatial, dim) tensor + Displacement field + + Returns + ------- + grid : (..., *spatial, dim) tensor + Transformation field + + """ + dim = disp.shape[-1] + spatial = disp.shape[-dim-1:-1] + mesh1d = [torch.arange(s, dtype=disp.dtype, device=disp.device) + for s in spatial] + grid = torch.meshgrid(mesh1d) + disp = movedim1(disp, -1, 0) + for i, grid1 in enumerate(grid): + disp[i].add_(grid1) + disp = movedim1(disp, 0, -1) + return disp + + +@torch.jit.script +def add_identity_grid(disp): + """Adds the identity grid to a displacement field. + + Parameters + ---------- + disp : (..., *spatial, dim) tensor + Displacement field + + Returns + ------- + grid : (..., *spatial, dim) tensor + Transformation field + + """ + return add_identity_grid_(disp.clone()) + + +def affine_grid(mat, shape): + """Create a dense transformation grid from an affine matrix. + + Parameters + ---------- + mat : (..., D[+1], D+1) tensor + Affine matrix (or matrices). + shape : (D,) sequence[int] + Shape of the grid, with length D. 
+
+    Returns
+    -------
+    grid : (..., *shape, D) tensor
+        Dense transformation grid
+
+    """
+    mat = torch.as_tensor(mat)
+    shape = list(shape)
+    nb_dim = mat.shape[-1] - 1
+    if nb_dim != len(shape):
+        raise ValueError('Dimension of the affine matrix ({}) and shape ({}) '
+                         'are not the same.'.format(nb_dim, len(shape)))
+    if mat.shape[-2] not in (nb_dim, nb_dim+1):
+        raise ValueError('First argument should be matrices of shape '
+                         '(..., {0}, {1}) or (..., {1}, {1}) but got {2}.'
+                         .format(nb_dim, nb_dim+1, mat.shape))
+    batch_shape = mat.shape[:-2]
+    grid = identity_grid(shape, mat.dtype, mat.device)
+    if batch_shape:
+        for _ in range(len(batch_shape)):
+            grid = grid.unsqueeze(0)
+        for _ in range(nb_dim):
+            mat = mat.unsqueeze(-3)
+    lin = mat[..., :nb_dim, :nb_dim]
+    off = mat[..., :nb_dim, -1]
+    grid = matvec(lin, grid) + off
+    return grid
\ No newline at end of file
diff --git a/interpol/autograd.py b/interpol/autograd.py
new file mode 100644
index 0000000..c32bdfb
--- /dev/null
+++ b/interpol/autograd.py
@@ -0,0 +1,293 @@
+"""AutoGrad version of pull/push/count/grad"""
+import torch
+from torch.cuda.amp import custom_fwd, custom_bwd
+from .coeff import spline_coeff_nd, spline_coeff
+from .bounds import BoundType
+from .splines import InterpolationType
+from .pushpull import (
+    grid_pull, grid_pull_backward,
+    grid_push, grid_push_backward,
+    grid_count, grid_count_backward,
+    grid_grad, grid_grad_backward)
+
+
+def make_list(x):
+    if not isinstance(x, (list, tuple)):
+        x = [x]
+    return list(x)
+
+
+def bound_to_nitorch(bound, as_type='str'):
+    """Convert boundary type to niTorch's convention.
+ + Parameters + ---------- + bound : [list of] str or bound_like + Boundary condition in any convention + as_type : {'str', 'enum', 'int'}, default='str' + Return BoundType or int rather than str + + Returns + ------- + bound : [list of] str or BoundType + Boundary condition in NITorch's convention + + """ + intype = type(bound) + if not isinstance(bound, (list, tuple)): + bound = [bound] + obound = [] + for b in bound: + b = b.lower() if isinstance(b, str) else b + if b in ('replicate', 'repeat', 'border', 'nearest', BoundType.replicate): + obound.append('replicate') + elif b in ('zero', 'zeros', 'constant', BoundType.zero): + obound.append('zero') + elif b in ('dct2', 'reflect', 'reflection', 'neumann', BoundType.dct2): + obound.append('dct2') + elif b in ('dct1', 'mirror', BoundType.dct1): + obound.append('dct1') + elif b in ('dft', 'wrap', 'circular', BoundType.dft): + obound.append('dft') + elif b in ('dst2', 'antireflect', 'dirichlet', BoundType.dst2): + obound.append('dst2') + elif b in ('dst1', 'antimirror', BoundType.dst1): + obound.append('dst1') + else: + raise ValueError(f'Unknown boundary condition {b}') + if as_type in ('enum', 'int', int): + obound = list(map(lambda b: getattr(BoundType, b), obound)) + if as_type in ('int', int): + obound = [b.value for b in obound] + if issubclass(intype, (list, tuple)): + obound = intype(obound) + else: + obound = obound[0] + return obound + + +def inter_to_nitorch(inter, as_type='str'): + """Convert interpolation order to NITorch's convention. 
+ + Parameters + ---------- + inter : [sequence of] int or str or InterpolationType + as_type : {'str', 'enum', 'int'}, default='int' + + Returns + ------- + inter : [sequence of] int or InterpolationType + + """ + intype = type(inter) + if not isinstance(inter, (list, tuple)): + inter = [inter] + ointer = [] + for o in inter: + o = o.lower() if isinstance(o, str) else o + if o in (0, 'nearest', InterpolationType.nearest): + ointer.append(0) + elif o in (1, 'linear', InterpolationType.linear): + ointer.append(1) + elif o in (2, 'quadratic', InterpolationType.quadratic): + ointer.append(2) + elif o in (3, 'cubic', InterpolationType.cubic): + ointer.append(3) + elif o in (4, 'fourth', InterpolationType.fourth): + ointer.append(4) + elif o in (5, 'fifth', InterpolationType.fifth): + ointer.append(5) + elif o in (6, 'sixth', InterpolationType.sixth): + ointer.append(6) + elif o in (7, 'seventh', InterpolationType.seventh): + ointer.append(7) + else: + raise ValueError(f'Unknown interpolation order {o}') + if as_type in ('enum', 'str', str): + ointer = list(map(InterpolationType, ointer)) + if as_type in ('str', str): + ointer = [o.name for o in ointer] + if issubclass(intype, (list, tuple)): + ointer = intype(ointer) + else: + ointer = ointer[0] + return ointer + + +class GridPull(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, input, grid, interpolation, bound, extrapolate): + + bound = bound_to_nitorch(make_list(bound), as_type='int') + interpolation = inter_to_nitorch(make_list(interpolation), as_type='int') + extrapolate = int(extrapolate) + opt = (bound, interpolation, extrapolate) + + # Pull + output = grid_pull(input, grid, *opt) + + # Context + ctx.opt = opt + ctx.save_for_backward(input, grid) + + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad): + var = ctx.saved_tensors + opt = ctx.opt + grads = grid_pull_backward(grad, *var, *opt) + grad_input, grad_grid = grads + return 
grad_input, grad_grid, None, None, None + + +class GridPush(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, input, grid, shape, interpolation, bound, extrapolate): + + bound = bound_to_nitorch(make_list(bound), as_type='int') + interpolation = inter_to_nitorch(make_list(interpolation), as_type='int') + extrapolate = int(extrapolate) + opt = (bound, interpolation, extrapolate) + + # Push + output = grid_push(input, grid, shape, *opt) + + # Context + ctx.opt = opt + ctx.save_for_backward(input, grid) + + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad): + var = ctx.saved_tensors + opt = ctx.opt + grads = grid_push_backward(grad, *var, *opt) + grad_input, grad_grid = grads + return grad_input, grad_grid, None, None, None, None + + +class GridCount(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, grid, shape, interpolation, bound, extrapolate): + + bound = bound_to_nitorch(make_list(bound), as_type='int') + interpolation = inter_to_nitorch(make_list(interpolation), as_type='int') + extrapolate = int(extrapolate) + opt = (bound, interpolation, extrapolate) + + # Push + output = grid_count(grid, shape, *opt) + + # Context + ctx.opt = opt + ctx.save_for_backward(grid) + + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad): + var = ctx.saved_tensors + opt = ctx.opt + grad_grid = None + if ctx.needs_input_grad[0]: + grad_grid = grid_count_backward(grad, *var, *opt) + return grad_grid, None, None, None, None + + +class GridGrad(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, input, grid, interpolation, bound, extrapolate): + + bound = bound_to_nitorch(make_list(bound), as_type='int') + interpolation = inter_to_nitorch(make_list(interpolation), as_type='int') + extrapolate = int(extrapolate) + opt = (bound, interpolation, extrapolate) + + # Pull + output = 
grid_grad(input, grid, *opt)
+
+        # Context
+        ctx.opt = opt
+        ctx.save_for_backward(input, grid)
+
+        return output
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad):
+        var = ctx.saved_tensors
+        opt = ctx.opt
+        grad_input = grad_grid = None
+        if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+            grads = grid_grad_backward(grad, *var, *opt)
+            grad_input, grad_grid = grads
+        return grad_input, grad_grid, None, None, None
+
+
+class SplineCoeff(torch.autograd.Function):
+
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, input, bound, interpolation, dim, inplace):
+
+        bound = bound_to_nitorch(make_list(bound), as_type='int')
+        interpolation = inter_to_nitorch(make_list(interpolation), as_type='int')
+        opt = (bound, interpolation, dim, inplace)
+
+        # Filter
+        output = spline_coeff(input, *opt)
+
+        # Context
+        if input.requires_grad:
+            ctx.opt = opt
+
+        return output
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad):
+        # symmetric filter -> backward == forward
+        # (I don't know if I can write into grad, so inplace=False to be safe)
+        grad = spline_coeff(grad, *ctx.opt[:-1], inplace=False)
+        return grad, None, None, None, None
+
+
+class SplineCoeffND(torch.autograd.Function):
+
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, input, bound, interpolation, dim, inplace):
+
+        bound = bound_to_nitorch(make_list(bound), as_type='int')
+        interpolation = inter_to_nitorch(make_list(interpolation), as_type='int')
+        opt = (bound, interpolation, dim, inplace)
+
+        # Filter
+        output = spline_coeff_nd(input, *opt)
+
+        # Context
+        if input.requires_grad:
+            ctx.opt = opt
+
+        return output
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad):
+        # symmetric filter -> backward == forward
+        # (I don't know if I can write into grad, so inplace=False to be safe)
+        grad = spline_coeff_nd(grad, *ctx.opt[:-1], inplace=False)
+        return grad, None, None, None, None
diff --git a/interpol/bounds.py b/interpol/bounds.py
new file mode 100644
index 0000000..7134196
--- /dev/null
+++ 
b/interpol/bounds.py @@ -0,0 +1,86 @@ +import torch +from enum import Enum +from typing import Optional +Tensor = torch.Tensor + + +class BoundType(Enum): + zero = zeros = 0 + replicate = nearest = 1 + dct1 = mirror = 2 + dct2 = reflect = 3 + dst1 = antimirror = 4 + dst2 = antireflect = 5 + dft = wrap = 6 + + +class ExtrapolateType(Enum): + no = 0 # threshold: (0, n-1) + yes = 1 + hist = 2 # threshold: (-0.5, n-0.5) + + +@torch.jit.script +class Bound: + + def __init__(self, bound_type: int = 3): + self.type = bound_type + + def index(self, i, n: int): + if self.type in (0, 1): # zero / replicate + return i.clamp(min=0, max=n-1) + elif self.type in (3, 5): # dct2 / dst2 + n2 = n * 2 + i = torch.where(i < 0, (-i-1).remainder(n2).neg().add(n2 - 1), + i.remainder(n2)) + i = torch.where(i >= n, -i + (n2 - 1), i) + return i + elif self.type == 2: # dct1 + if n == 1: + return torch.zeros(i.shape, dtype=i.dtype, device=i.device) + else: + n2 = (n - 1) * 2 + i = i.abs().remainder(n2) + i = torch.where(i >= n, -i + n2, i) + return i + elif self.type == 4: # dst1 + n2 = 2 * (n + 1) + first = torch.zeros([1], dtype=i.dtype, device=i.device) + last = torch.full([1], n - 1, dtype=i.dtype, device=i.device) + i = torch.where(i < 0, -i - 2, i) + i = i.remainder(n2) + i = torch.where(i > n, -i + (n2 - 2), i) + i = torch.where(i == -1, first, i) + i = torch.where(i == n, last, i) + return i + elif self.type == 6: # dft + return i.remainder(n) + else: + return i + + def transform(self, i, n: int) -> Optional[Tensor]: + if self.type == 4: # dst1 + if n == 1: + return None + one = torch.ones([1], dtype=torch.int8, device=i.device) + zero = torch.zeros([1], dtype=torch.int8, device=i.device) + n2 = 2 * (n + 1) + i = torch.where(i < 0, -i + (n-1), i) + i = i.remainder(n2) + x = torch.where(i == 0, zero, one) + x = torch.where(i.remainder(n + 1) == n, zero, x) + x = torch.where((i / (n+1)).remainder_(2) > 0, -x, x) + return x + elif self.type == 5: # dst2 + i = torch.where(i < 0, n - 1 - 
i, i) + x = torch.ones([1], dtype=torch.int8, device=i.device) + x = torch.where((i / n).remainder_(2) > 0, -x, x) + return x + elif self.type == 0: # zero + one = torch.ones([1], dtype=torch.int8, device=i.device) + zero = torch.zeros([1], dtype=torch.int8, device=i.device) + outbounds = ((i < 0) | (i >= n)) + x = torch.where(outbounds, zero, one) + return x + else: + return None diff --git a/interpol/coeff.py b/interpol/coeff.py new file mode 100644 index 0000000..d1d6c04 --- /dev/null +++ b/interpol/coeff.py @@ -0,0 +1,344 @@ +"""Compute spline interpolating coefficients + +These functions are ported from the C routines in SPM's bsplines.c +by John Ashburner, which are themselves ports from Philippe Thevenaz's +code. JA furthermore derived the initial conditions for the DFT ("wrap around") +boundary conditions. + +Note that similar routines are available in scipy with boundary conditions +DCT1 ("mirror"), DCT2 ("reflect") and DFT ("wrap"); all derived by P. Thevenaz, +according to the comments. Our DCT2 boundary conditions are ported from +scipy. + +Only boundary conditions DCT1, DCT2 and DFT are implemented. + +References +---------- +..[1] M. Unser, A. Aldroubi and M. Eden. + "B-Spline Signal Processing: Part I-Theory," + IEEE Transactions on Signal Processing 41(2):821-832 (1993). +..[2] M. Unser, A. Aldroubi and M. Eden. + "B-Spline Signal Processing: Part II-Efficient Design and Applications," + IEEE Transactions on Signal Processing 41(2):834-848 (1993). +..[3] M. Unser. + "Splines: A Perfect Fit for Signal and Image Processing," + IEEE Signal Processing Magazine 16(6):22-38 (1999). +""" +import torch +import math +from typing import List, Optional +from .jit_utils import movedim1 +from .pushpull import pad_list_int + + +@torch.jit.script +def get_poles(order: int) -> List[float]: + empty: List[float] = [] + if order in (0, 1): + return empty + if order == 2: + return [math.sqrt(8.) - 3.] + if order == 3: + return [math.sqrt(3.) - 2.] 
+ if order == 4: + return [math.sqrt(664. - math.sqrt(438976.)) + math.sqrt(304.) - 19., + math.sqrt(664. + math.sqrt(438976.)) - math.sqrt(304.) - 19.] + if order == 5: + return [math.sqrt(67.5 - math.sqrt(4436.25)) + math.sqrt(26.25) - 6.5, + math.sqrt(67.5 + math.sqrt(4436.25)) - math.sqrt(26.25) - 6.5] + if order == 6: + return [-0.488294589303044755130118038883789062112279161239377608394, + -0.081679271076237512597937765737059080653379610398148178525368, + -0.00141415180832581775108724397655859252786416905534669851652709] + if order == 7: + return [-0.5352804307964381655424037816816460718339231523426924148812, + -0.122554615192326690515272264359357343605486549427295558490763, + -0.0091486948096082769285930216516478534156925639545994482648003] + raise NotImplementedError + + +@torch.jit.script +def get_gain(poles: List[float]) -> float: + lam: float = 1. + for pole in poles: + lam *= (1. - pole) * (1. - 1./pole) + return lam + + +@torch.jit.script +def dft_initial(inp, pole: float, dim: int = -1, keepdim: bool = False): + + assert inp.shape[dim] > 1 + max_iter: int = int(math.ceil(-30./math.log(abs(pole)))) + max_iter = min(max_iter, inp.shape[dim]) + + poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device) + poles = poles.pow(torch.arange(1, max_iter, dtype=inp.dtype, device=inp.device)) + poles = poles.flip(0) + + inp = movedim1(inp, dim, 0) + inp0 = inp[0] + inp = inp[1-max_iter:] + inp = movedim1(inp, 0, -1) + out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1) + out = out + inp0.unsqueeze(-1) + if keepdim: + out = movedim1(out, -1, dim) + else: + out = out.squeeze(-1) + + pole = pole ** max_iter + out = out / (1 - pole) + return out + + +@torch.jit.script +def dct1_initial(inp, pole: float, dim: int = -1, keepdim: bool = False): + + n = inp.shape[dim] + max_iter: int = int(math.ceil(-30./math.log(abs(pole)))) + + if max_iter < n: + + poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device) + poles = 
poles.pow(torch.arange(1, max_iter, dtype=inp.dtype, device=inp.device)) + + inp = movedim1(inp, dim, 0) + inp0 = inp[0] + inp = inp[1:max_iter] + inp = movedim1(inp, 0, -1) + out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1) + out = out + inp0.unsqueeze(-1) + if keepdim: + out = movedim1(out, -1, dim) + else: + out = out.squeeze(-1) + + else: + max_iter = n + + polen = pole ** (n - 1) + inp0 = inp[0] + polen * inp[-1] + inp = inp[1:-1] + inp = movedim1(inp, 0, -1) + + poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device) + poles = poles.pow(torch.arange(1, n-1, dtype=inp.dtype, device=inp.device)) + poles = poles + (polen * polen) / poles + + out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1) + out = out + inp0.unsqueeze(-1) + if keepdim: + out = movedim1(out, -1, dim) + else: + out = out.squeeze(-1) + + pole = pole ** (max_iter - 1) + out = out / (1 - pole * pole) + + return out + + +@torch.jit.script +def dct2_initial(inp, pole: float, dim: int = -1, keepdim: bool = False): + # Ported from scipy: + # https://github.com/scipy/scipy/blob/master/scipy/ndimage/src/ni_splines.c + # + # I (YB) unwarped and simplied the terms so that I could use a dot + # product instead of a loop. + # It should certainly be possible to derive a version for max_iter < n, + # as JA did for DCT1, to avoid long recursions when `n` is large. But + # I think it would require a more complicated anticausal/final condition. 
+ + n = inp.shape[dim] + + polen = pole ** n + pole_last = polen * (1 + 1/(pole + polen * polen)) + inp00 = inp[0] + inp0 = inp[0] + pole_last * inp[-1] + inp = inp[1:-1] + inp = movedim1(inp, 0, -1) + + poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device) + poles = (poles.pow(torch.arange(1, n-1, dtype=inp.dtype, device=inp.device)) + + poles.pow(torch.arange(2*n-2, n, -1, dtype=inp.dtype, device=inp.device))) + + out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1) + + out = out + inp0.unsqueeze(-1) + out = out * (pole / (1 - polen * polen)) + out = out + inp00.unsqueeze(-1) + + if keepdim: + out = movedim1(out, -1, dim) + else: + out = out.squeeze(-1) + + return out + + +@torch.jit.script +def dft_final(inp, pole: float, dim: int = -1, keepdim: bool = False): + + assert inp.shape[dim] > 1 + max_iter: int = int(math.ceil(-30./math.log(abs(pole)))) + max_iter = min(max_iter, inp.shape[dim]) + + poles = torch.as_tensor(pole, dtype=inp.dtype, device=inp.device) + poles = poles.pow(torch.arange(2, max_iter+1, dtype=inp.dtype, device=inp.device)) + + inp = movedim1(inp, dim, 0) + inp0 = inp[-1] + inp = inp[:max_iter-1] + inp = movedim1(inp, 0, -1) + out = torch.matmul(inp.unsqueeze(-2), poles.unsqueeze(-1)).squeeze(-1) + out = out.add(inp0.unsqueeze(-1), alpha=pole) + if keepdim: + out = movedim1(out, -1, dim) + else: + out = out.squeeze(-1) + + pole = pole ** max_iter + out = out / (pole - 1) + return out + + +@torch.jit.script +def dct1_final(inp, pole: float, dim: int = -1, keepdim: bool = False): + inp = movedim1(inp, dim, 0) + out = pole * inp[-2] + inp[-1] + out = out * (pole / (pole*pole - 1)) + if keepdim: + out = movedim1(out.unsqueeze(0), 0, dim) + return out + + +@torch.jit.script +def dct2_final(inp, pole: float, dim: int = -1, keepdim: bool = False): + # Ported from scipy: + # https://github.com/scipy/scipy/blob/master/scipy/ndimage/src/ni_splines.c + inp = movedim1(inp, dim, 0) + out = inp[-1] * (pole / (pole - 1)) + if 
keepdim: + out = movedim1(out.unsqueeze(0), 0, dim) + return out + + +@torch.jit.script +class CoeffBound: + + def __init__(self, bound: int): + self.bound = bound + + def initial(self, inp, pole: float, dim: int = -1, keepdim: bool = False): + if self.bound in (0, 2): # zero, dct1 + return dct1_initial(inp, pole, dim, keepdim) + elif self.bound in (1, 3): # nearest, dct2 + return dct2_initial(inp, pole, dim, keepdim) + elif self.bound == 6: # dft + return dft_initial(inp, pole, dim, keepdim) + else: + raise NotImplementedError + + def final(self, inp, pole: float, dim: int = -1, keepdim: bool = False): + if self.bound in (0, 2): # zero, dct1 + return dct1_final(inp, pole, dim, keepdim) + elif self.bound in (1, 3): # nearest, dct2 + return dct2_final(inp, pole, dim, keepdim) + elif self.bound == 6: # dft + return dft_final(inp, pole, dim, keepdim) + else: + raise NotImplementedError + + +@torch.jit.script +def filter(inp, bound: CoeffBound, poles: List[float], + dim: int = -1, inplace: bool = False): + + if not inplace: + inp = inp.clone() + + if inp.shape[dim] == 1: + return inp + + gain = get_gain(poles) + inp *= gain + inp = movedim1(inp, dim, 0) + n = inp.shape[0] + + for pole in poles: + inp[0] = bound.initial(inp, pole, dim=0, keepdim=False) + + for i in range(1, n): + inp[i].add_(inp[i-1], alpha=pole) + + inp[-1] = bound.final(inp, pole, dim=0, keepdim=False) + + for i in range(n-2, -1, -1): + inp[i].neg_().add_(inp[i+1]).mul_(pole) + + inp = movedim1(inp, 0, dim) + return inp + + +@torch.jit.script +def spline_coeff(inp, bound: int, order: int, dim: int = -1, + inplace: bool = False): + """Compute the interpolating spline coefficients, for a given spline order + and boundary conditions, along a single dimension. 
+ + Parameters + ---------- + inp : tensor + bound : {2: dct1, 6: dft} + order : {0..7} + dim : int, default=-1 + inplace : bool, default=False + + Returns + ------- + out : tensor + + """ + if not inplace: + inp = inp.clone() + + if order in (0, 1): + return inp + + poles = get_poles(order) + return filter(inp, CoeffBound(bound), poles, dim=dim, inplace=True) + + +@torch.jit.script +def spline_coeff_nd(inp, bound: List[int], order: List[int], + dim: Optional[int] = None, inplace: bool = False): + """Compute the interpolating spline coefficients, for a given spline order + and boundary condition, along the last `dim` dimensions. + + Parameters + ---------- + inp : (..., *spatial) tensor + bound : List[{2: dct1, 6: dft}] + order : List[{0..7}] + dim : int, default=`inp.dim()` + inplace : bool, default=False + + Returns + ------- + out : (..., *spatial) tensor + + """ + if not inplace: + inp = inp.clone() + + if dim is None: + dim = inp.dim() + + bound = pad_list_int(bound, dim) + order = pad_list_int(order, dim) + + for d, b, o in zip(range(dim), bound, order): + inp = spline_coeff(inp, b, o, dim=-dim + d, inplace=True) + + return inp diff --git a/interpol/iso0.py b/interpol/iso0.py new file mode 100644 index 0000000..c86fdb3 --- /dev/null +++ b/interpol/iso0.py @@ -0,0 +1,368 @@ +"""Isotropic 0-th order splines ("nearest neighbor")""" +import torch +from .bounds import Bound +from .jit_utils import (sub2ind_list, make_sign, + inbounds_mask_3d, inbounds_mask_2d, inbounds_mask_1d) +from typing import List, Optional +Tensor = torch.Tensor + + +@torch.jit.script +def get_indices(g, n: int, bound: Bound): + g0 = g.round().long() + sign0 = bound.transform(g0, n) + g0 = bound.index(g0, n) + return g0, sign0 + + +# ====================================================================== +# 3D +# ====================================================================== + + +@torch.jit.script +def pull3d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX, 
iY, iZ) tensor + g: (B, oX, oY, oZ, 3) tensor + bound: List{3}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX, oY, oZ) tensor + """ + dim = 3 + boundx, boundy, boundz = bound + oshape = g.shape[-dim-1:-1] + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy, gz = g.unbind(-1) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + shape = inp.shape[-dim:] + nx, ny, nz = shape + + # mask of inbounds voxels + mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz) + + # nearest integer coordinates + gx, signx = get_indices(gx, nx, boundx) + gy, signy = get_indices(gy, ny, boundy) + gz, signz = get_indices(gz, nz, boundz) + + # gather + inp = inp.reshape(inp.shape[:2] + [-1]) + idx = sub2ind_list([gx, gy, gz], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out = inp.gather(-1, idx) + sign = make_sign([signx, signy, signz]) + if sign is not None: + out *= sign + if mask is not None: + out *= mask + out = out.reshape(out.shape[:2] + oshape) + return out + + +@torch.jit.script +def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound], + extrapolate: int = 1): + """ + inp: (B, C, iX, iY, iZ) tensor + g: (B, iX, iY, iZ, 3) tensor + shape: List{3}[int], optional + bound: List{3}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, *shape) tensor + """ + dim = 3 + boundx, boundy, boundz = bound + if inp.shape[-dim:] != g.shape[-dim-1:-1]: + raise ValueError('Input and grid should have the same spatial shape') + ishape = inp.shape[-dim:] + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy, gz = torch.unbind(g, -1) + inp = inp.reshape(inp.shape[:2] + [-1]) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + + if shape is None: + shape = ishape + nx, ny, nz = shape + + # mask of inbounds voxels + mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz) + + # nearest integer coordinates + gx, signx = get_indices(gx, nx, boundx) + gy, signy = get_indices(gy, ny, boundy) + gz, signz = 
get_indices(gz, nz, boundz) + + # scatter + out = torch.zeros([batch, channel, nx*ny*nz], dtype=inp.dtype, device=inp.device) + idx = sub2ind_list([gx, gy, gz], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + sign = make_sign([signx, signy, signz]) + if sign is not None or mask is not None: + inp = inp.clone() + if sign is not None: + inp *= sign + if mask is not None: + inp *= mask + out.scatter_add_(-1, idx, inp) + + out = out.reshape(out.shape[:2] + shape) + return out + + +# ====================================================================== +# 2D +# ====================================================================== + + +@torch.jit.script +def pull2d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX, iY) tensor + g: (B, oX, oY, 2) tensor + bound: List{2}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX, oY) tensor + """ + dim = 2 + boundx, boundy = bound + oshape = g.shape[-dim-1:-1] + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy = g.unbind(-1) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + shape = inp.shape[-dim:] + nx, ny = shape + + # mask of inbounds voxels + mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny) + + # nearest integer coordinates + gx, signx = get_indices(gx, nx, boundx) + gy, signy = get_indices(gy, ny, boundy) + + # gather + inp = inp.reshape(inp.shape[:2] + [-1]) + idx = sub2ind_list([gx, gy], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out = inp.gather(-1, idx) + sign = make_sign([signx, signy]) + if sign is not None: + out *= sign + if mask is not None: + out *= mask + out = out.reshape(out.shape[:2] + oshape) + return out + + +@torch.jit.script +def push2d(inp, g, shape: Optional[List[int]], bound: List[Bound], + extrapolate: int = 1): + """ + inp: (B, C, iX, iY) tensor + g: (B, iX, iY, 2) tensor + shape: List{2}[int], optional + bound: List{2}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, *shape) tensor + """ + dim = 
2 + boundx, boundy = bound + if inp.shape[-dim:] != g.shape[-dim-1:-1]: + raise ValueError('Input and grid should have the same spatial shape') + ishape = inp.shape[-dim:] + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy = torch.unbind(g, -1) + inp = inp.reshape(inp.shape[:2] + [-1]) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + + if shape is None: + shape = ishape + nx, ny = shape + + # mask of inbounds voxels + mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny) + + # nearest integer coordinates + gx, signx = get_indices(gx, nx, boundx) + gy, signy = get_indices(gy, ny, boundy) + + # scatter + out = torch.zeros([batch, channel, nx*ny], dtype=inp.dtype, device=inp.device) + idx = sub2ind_list([gx, gy], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + sign = make_sign([signx, signy]) + if sign is not None or mask is not None: + inp = inp.clone() + if sign is not None: + inp *= sign + if mask is not None: + inp *= mask + out.scatter_add_(-1, idx, inp) + + out = out.reshape(out.shape[:2] + shape) + return out + + +# ====================================================================== +# 1D +# ====================================================================== + + +@torch.jit.script +def pull1d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX) tensor + g: (B, oX, 1) tensor + bound: List{1}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX) tensor + """ + dim = 1 + boundx = bound[0] + oshape = g.shape[-dim-1:-1] + g = g.reshape([g.shape[0], 1, -1, dim]) + gx = g.squeeze(-1) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + shape = inp.shape[-dim:] + nx = shape[0] + + # mask of inbounds voxels + mask = inbounds_mask_1d(extrapolate, gx, nx) + + # nearest integer coordinates + gx, signx = get_indices(gx, nx, boundx) + + # gather + inp = inp.reshape(inp.shape[:2] + [-1]) + idx = gx + idx = idx.expand([batch, channel, idx.shape[-1]]) + out = inp.gather(-1, idx) + sign = 
signx + if sign is not None: + out *= sign + if mask is not None: + out *= mask + out = out.reshape(out.shape[:2] + oshape) + return out + + +@torch.jit.script +def push1d(inp, g, shape: Optional[List[int]], bound: List[Bound], + extrapolate: int = 1): + """ + inp: (B, C, iX) tensor + g: (B, iX, 1) tensor + shape: List{1}[int], optional + bound: List{1}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, *shape) tensor + """ + dim = 1 + boundx = bound[0] + if inp.shape[-dim:] != g.shape[-dim-1:-1]: + raise ValueError('Input and grid should have the same spatial shape') + ishape = inp.shape[-dim:] + g = g.reshape([g.shape[0], 1, -1, dim]) + gx = g.squeeze(-1) + inp = inp.reshape(inp.shape[:2] + [-1]) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + + if shape is None: + shape = ishape + nx = shape[0] + + # mask of inbounds voxels + mask = inbounds_mask_1d(extrapolate, gx, nx) + + # nearest integer coordinates + gx, signx = get_indices(gx, nx, boundx) + + # scatter + out = torch.zeros([batch, channel, nx], dtype=inp.dtype, device=inp.device) + idx = gx + idx = idx.expand([batch, channel, idx.shape[-1]]) + sign = signx + if sign is not None or mask is not None: + inp = inp.clone() + if sign is not None: + inp *= sign + if mask is not None: + inp *= mask + out.scatter_add_(-1, idx, inp) + + out = out.reshape(out.shape[:2] + shape) + return out + + +# ====================================================================== +# ND +# ====================================================================== + + +@torch.jit.script +def grad(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, *ishape) tensor + g: (B, *oshape, D) tensor + bound: List{D}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, *oshape, D) tensor + """ + dim = g.shape[-1] + oshape = list(g.shape[-dim-1:-1]) + batch = max(inp.shape[0], g.shape[0]) + channel = inp.shape[1] + + return torch.zeros([batch, channel] + oshape + [dim], + 
dtype=inp.dtype, device=inp.device) + + +@torch.jit.script +def pushgrad(inp, g, shape: Optional[List[int]], bound: List[Bound], + extrapolate: int = 1): + """ + inp: (B, C, *ishape, D) tensor + g: (B, *ishape, D) tensor + shape: List{D}[int], optional, optional + bound: List{D}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, *shape) tensor + """ + dim = g.shape[-1] + if inp.shape[-dim-1:-1] != g.shape[-dim-1:-1]: + raise ValueError('Input and grid should have the same spatial shape') + ishape = inp.shape[-dim-1:-1] + batch = max(inp.shape[0], g.shape[0]) + channel = inp.shape[1] + + if shape is None: + shape = ishape + shape = list(shape) + + return torch.zeros([batch, channel] + shape, + dtype=inp.dtype, device=inp.device) + + +@torch.jit.script +def hess(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, *ishape) tensor + g: (B, *oshape, D) tensor + bound: List{D}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, *oshape, D, D) tensor + """ + dim = g.shape[-1] + oshape = list(g.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + batch = max(inp.shape[0], g.shape[0]) + channel = inp.shape[1] + + return torch.zeros([batch, channel] + oshape + [dim, dim], + dtype=inp.dtype, device=inp.device) diff --git a/interpol/iso1.py b/interpol/iso1.py new file mode 100644 index 0000000..1d64c40 --- /dev/null +++ b/interpol/iso1.py @@ -0,0 +1,1338 @@ +"""Isotropic 1-st order splines ("linear/bilinear/trilinear")""" +import torch +from .bounds import Bound +from .jit_utils import (sub2ind_list, make_sign, + inbounds_mask_3d, inbounds_mask_2d, inbounds_mask_1d) +from typing import List, Optional +Tensor = torch.Tensor + + +@torch.jit.script +def get_weights_and_indices(g, n: int, bound: Bound): + g0 = g.floor().long() + g1 = g0 + 1 + sign1 = bound.transform(g1, n) + sign0 = bound.transform(g0, n) + g1 = bound.index(g1, n) + g0 = bound.index(g0, n) + g = g - g.floor() + return g, g0, g1, sign0, sign1 + + +# 
====================================================================== +# 3D +# ====================================================================== + + +@torch.jit.script +def pull3d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX, iY, iZ) tensor + g: (B, oX, oY, oZ, 3) tensor + bound: List{3}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX, oY, oZ) tensor + """ + dim = 3 + boundx, boundy, boundz = bound + oshape = list(g.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy, gz = g.unbind(-1) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + shape = list(inp.shape[-dim:]) + nx, ny, nz = shape + + # mask of inbounds voxels + mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy) + gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz) + + # gather + inp = inp.reshape(list(inp.shape[:2]) + [-1]) + # - corner 000 + idx = sub2ind_list([gx0, gy0, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out = inp.gather(-1, idx) + sign = make_sign([signx0, signy0, signz0]) + if sign is not None: + out *= sign + out *= (1 - gx) * (1 - gy) * (1 - gz) + # - corner 001 + idx = sub2ind_list([gx0, gy0, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy0, signz1]) + if sign is not None: + out1 *= sign + out1 *= (1 - gx) * (1 - gy) * gz + out += out1 + # - corner 010 + idx = sub2ind_list([gx0, gy1, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy1, signz0]) + if sign is not None: + out1 *= sign + out1 *= (1 - gx) * gy * (1 - gz) + out += out1 + # - corner 011 + idx = sub2ind_list([gx0, 
gy1, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy1, signz1]) + if sign is not None: + out1 *= sign + out1 *= (1 - gx) * gy * gz + out += out1 + # - corner 100 + idx = sub2ind_list([gx1, gy0, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy0, signz0]) + if sign is not None: + out1 *= sign + out1 *= gx * (1 - gy) * (1 - gz) + out += out1 + # - corner 101 + idx = sub2ind_list([gx1, gy0, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy0, signz1]) + if sign is not None: + out1 *= sign + out1 *= gx * (1 - gy) * gz + out += out1 + # - corner 110 + idx = sub2ind_list([gx1, gy1, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy1, signz0]) + if sign is not None: + out1 *= sign + out1 *= gx * gy * (1 - gz) + out += out1 + # - corner 111 + idx = sub2ind_list([gx1, gy1, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy1, signz1]) + if sign is not None: + out1 *= sign + out1 *= gx * gy * gz + out += out1 + + if mask is not None: + out *= mask + out = out.reshape(list(out.shape[:2]) + oshape) + return out + + +@torch.jit.script +def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound], + extrapolate: int = 1): + """ + inp: (B, C, iX, iY, iZ) tensor + g: (B, iX, iY, iZ, 3) tensor + shape: List{3}[int], optional + bound: List{3}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, *shape) tensor + """ + dim = 3 + boundx, boundy, boundz = bound + if inp.shape[-dim:] != g.shape[-dim-1:-1]: + raise ValueError('Input and grid should have the same spatial shape') + ishape = list(inp.shape[-dim:]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy, gz = torch.unbind(g, -1) + inp 
= inp.reshape(list(inp.shape[:2]) + [-1]) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + + if shape is None: + shape = ishape + shape = list(shape) + nx, ny, nz = shape + + # mask of inbounds voxels + mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy) + gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz) + + # scatter + out = torch.zeros([batch, channel, nx*ny*nz], + dtype=inp.dtype, device=inp.device) + # - corner 000 + idx = sub2ind_list([gx0, gy0, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy0, signz0]) + if sign is not None: + out1 *= sign + if mask is not None: + out1 *= mask + out1 *= (1 - gx) * (1 - gy) * (1 - gz) + out.scatter_add_(-1, idx, out1) + # - corner 001 + idx = sub2ind_list([gx0, gy0, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy0, signz1]) + if sign is not None: + out1 *= sign + if mask is not None: + out1 *= mask + out1 *= (1 - gx) * (1 - gy) * gz + out.scatter_add_(-1, idx, out1) + # - corner 010 + idx = sub2ind_list([gx0, gy1, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy1, signz0]) + if sign is not None: + out1 *= sign + if mask is not None: + out1 *= mask + out1 *= (1 - gx) * gy * (1 - gz) + out.scatter_add_(-1, idx, out1) + # - corner 011 + idx = sub2ind_list([gx0, gy1, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy1, signz1]) + if sign is not None: + out1 *= sign + if mask is not None: + out1 *= mask + out1 *= (1 - gx) * gy * gz + out.scatter_add_(-1, idx, out1) + # - corner 
100 + idx = sub2ind_list([gx1, gy0, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx1, signy0, signz0]) + if sign is not None: + out1 *= sign + if mask is not None: + out1 *= mask + out1 *= gx * (1 - gy) * (1 - gz) + out.scatter_add_(-1, idx, out1) + # - corner 101 + idx = sub2ind_list([gx1, gy0, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx1, signy0, signz1]) + if sign is not None: + out1 *= sign + if mask is not None: + out1 *= mask + out1 *= gx * (1 - gy) * gz + out.scatter_add_(-1, idx, out1) + # - corner 110 + idx = sub2ind_list([gx1, gy1, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx1, signy1, signz0]) + if sign is not None: + out1 *= sign + if mask is not None: + out1 *= mask + out1 *= gx * gy * (1 - gz) + out.scatter_add_(-1, idx, out1) + # - corner 111 + idx = sub2ind_list([gx1, gy1, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx1, signy1, signz1]) + if sign is not None: + out1 *= sign + if mask is not None: + out1 *= mask + out1 *= gx * gy * gz + out.scatter_add_(-1, idx, out1) + + out = out.reshape(list(out.shape[:2]) + shape) + return out + + +@torch.jit.script +def grad3d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX, iY, iZ) tensor + g: (B, oX, oY, oZ, 3) tensor + bound: List{3}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX, oY, oZ, 3) tensor + """ + dim = 3 + boundx, boundy, boundz = bound + oshape = list(g.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy, gz = torch.unbind(g, -1) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + shape = list(inp.shape[-dim:]) + nx, ny, nz = shape + + # mask of inbounds voxels + mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz) + + # corners + # (upper weight, lower 
corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy) + gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz) + + # gather + inp = inp.reshape(list(inp.shape[:2]) + [-1]) + out = torch.empty([batch, channel] + list(g.shape[-2:]), + dtype=inp.dtype, device=inp.device) + outx, outy, outz = out.unbind(-1) + # - corner 000 + idx = sub2ind_list([gx0, gy0, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + torch.gather(inp, -1, idx, out=outx) + outy.copy_(outx) + outz.copy_(outx) + sign = make_sign([signx0, signy0, signz0]) + if sign is not None: + out *= sign.unsqueeze(-1) + outx *= - (1 - gy) * (1 - gz) + outy *= - (1 - gx) * (1 - gz) + outz *= - (1 - gx) * (1 - gy) + # - corner 001 + idx = sub2ind_list([gx0, gy0, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy0, signz1]) + if sign is not None: + out1 *= sign + outx.addcmul_(out1, - (1 - gy) * gz) + outy.addcmul_(out1, - (1 - gx) * gz) + outz.addcmul_(out1, (1 - gx) * (1 - gy)) + # - corner 010 + idx = sub2ind_list([gx0, gy1, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy1, signz0]) + if sign is not None: + out1 *= sign + outx.addcmul_(out1, - gy * (1 - gz)) + outy.addcmul_(out1, (1 - gx) * (1 - gz)) + outz.addcmul_(out1, - (1 - gx) * gy) + # - corner 011 + idx = sub2ind_list([gx0, gy1, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy1, signz1]) + if sign is not None: + out1 *= sign + outx.addcmul_(out1, - gy * gz) + outy.addcmul_(out1, (1 - gx) * gz) + outz.addcmul_(out1, (1 - gx) * gy) + # - corner 100 + idx = sub2ind_list([gx1, gy0, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = 
inp.gather(-1, idx) + sign = make_sign([signx1, signy0, signz0]) + if sign is not None: + out1 *= sign + outx.addcmul_(out1, (1 - gy) * (1 - gz)) + outy.addcmul_(out1, - gx * (1 - gz)) + outz.addcmul_(out1, - gx * (1 - gy)) + # - corner 101 + idx = sub2ind_list([gx1, gy0, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy0, signz1]) + if sign is not None: + out1 *= sign + outx.addcmul_(out1, (1 - gy) * gz) + outy.addcmul_(out1, - gx * gz) + outz.addcmul_(out1, gx * (1 - gy)) + # - corner 110 + idx = sub2ind_list([gx1, gy1, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy1, signz0]) + if sign is not None: + out1 *= sign + outx.addcmul_(out1, gy * (1 - gz)) + outy.addcmul_(out1, gx * (1 - gz)) + outz.addcmul_(out1, - gx * gy) + # - corner 111 + idx = sub2ind_list([gx1, gy1, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy1, signz1]) + if sign is not None: + out1 *= sign + outx.addcmul_(out1, gy * gz) + outy.addcmul_(out1, gx * gz) + outz.addcmul_(out1, gx * gy) + + if mask is not None: + out *= mask.unsqueeze(-1) + out = out.reshape(list(out.shape[:2]) + oshape + [3]) + return out + + +@torch.jit.script +def pushgrad3d(inp, g, shape: Optional[List[int]], bound: List[Bound], + extrapolate: int = 1): + """ + inp: (B, C, iX, iY, iZ, 3) tensor + g: (B, iX, iY, iZ, 3) tensor + shape: List{3}[int], optional + bound: List{3}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, *shape) tensor + """ + dim = 3 + boundx, boundy, boundz = bound + if inp.shape[-dim-1:-1] != g.shape[-dim-1:-1]: + raise ValueError('Input and grid should have the same spatial shape') + ishape = list(inp.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy, gz = g.unbind(-1) + inp = inp.reshape(list(inp.shape[:2]) + [-1, dim]) + batch = 
max(inp.shape[0], g.shape[0])
+    channel = inp.shape[1]
+
+    if shape is None:
+        shape = ishape
+    shape = list(shape)
+    nx, ny, nz = shape
+
+    # mask of inbounds voxels
+    mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz)
+
+    # corners
+    # (upper weight, lower corner, upper corner, lower sign, upper sign)
+    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
+    gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)
+    gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz)
+
+    # scatter (each corner uses its own boundary signs, as in push3d/grad3d)
+    out = torch.zeros([batch, channel, nx*ny*nz],
+                      dtype=inp.dtype, device=inp.device)
+    # - corner 000
+    idx = sub2ind_list([gx0, gy0, gz0], shape)
+    idx = idx.expand([batch, channel, idx.shape[-1]])
+    out1 = inp.clone()
+    sign = make_sign([signx0, signy0, signz0])
+    if sign is not None:
+        out1 *= sign.unsqueeze(-1)
+    if mask is not None:
+        out1 *= mask.unsqueeze(-1)
+    out1x, out1y, out1z = out1.unbind(-1)
+    out1x *= - (1 - gy) * (1 - gz)
+    out1y *= - (1 - gx) * (1 - gz)
+    out1z *= - (1 - gx) * (1 - gy)
+    out.scatter_add_(-1, idx, out1x + out1y + out1z)
+    # - corner 001
+    idx = sub2ind_list([gx0, gy0, gz1], shape)
+    idx = idx.expand([batch, channel, idx.shape[-1]])
+    out1 = inp.clone()
+    sign = make_sign([signx0, signy0, signz1])  # fix: was signz0
+    if sign is not None:
+        out1 *= sign.unsqueeze(-1)
+    if mask is not None:
+        out1 *= mask.unsqueeze(-1)
+    out1x, out1y, out1z = out1.unbind(-1)
+    out1x *= - (1 - gy) * gz
+    out1y *= - (1 - gx) * gz
+    out1z *= (1 - gx) * (1 - gy)
+    out.scatter_add_(-1, idx, out1x + out1y + out1z)
+    # - corner 010
+    idx = sub2ind_list([gx0, gy1, gz0], shape)
+    idx = idx.expand([batch, channel, idx.shape[-1]])
+    out1 = inp.clone()
+    sign = make_sign([signx0, signy1, signz0])  # fix: was signy0
+    if sign is not None:
+        out1 *= sign.unsqueeze(-1)
+    if mask is not None:
+        out1 *= mask.unsqueeze(-1)
+    out1x, out1y, out1z = out1.unbind(-1)
+    out1x *= - gy * (1 - gz)
+    out1y *= (1 - gx) * (1 - gz)
+    out1z *= - (1 - gx) * gy
+    out.scatter_add_(-1, idx, out1x + out1y + out1z)
+    # - corner 011
+    idx = sub2ind_list([gx0, gy1, gz1], shape)
+    idx = idx.expand([batch, channel, idx.shape[-1]])
+    out1 = inp.clone()
+    sign = make_sign([signx0, signy1, signz1])  # fix: was signy0/signz0
+    if sign is not None:
+        out1 *= sign.unsqueeze(-1)
+    if mask is not None:
+        out1 *= mask.unsqueeze(-1)
+    out1x, out1y, out1z = out1.unbind(-1)
+    out1x *= - gy * gz
+    out1y *= (1 - gx) * gz
+    out1z *= (1 - gx) * gy
+    out.scatter_add_(-1, idx, out1x + out1y + out1z)
+    # - corner 100
+    idx = sub2ind_list([gx1, gy0, gz0], shape)
+    idx = idx.expand([batch, channel, idx.shape[-1]])
+    out1 = inp.clone()
+    sign = make_sign([signx1, signy0, signz0])  # fix: was signx0
+    if sign is not None:
+        out1 *= sign.unsqueeze(-1)
+    if mask is not None:
+        out1 *= mask.unsqueeze(-1)
+    out1x, out1y, out1z = out1.unbind(-1)
+    out1x *= (1 - gy) * (1 - gz)
+    out1y *= - gx * (1 - gz)
+    out1z *= - gx * (1 - gy)
+    out.scatter_add_(-1, idx, out1x + out1y + out1z)
+    # - corner 101
+    idx = sub2ind_list([gx1, gy0, gz1], shape)
+    idx = idx.expand([batch, channel, idx.shape[-1]])
+    out1 = inp.clone()
+    sign = make_sign([signx1, signy0, signz1])  # fix: was signx0/signz0
+    if sign is not None:
+        out1 *= sign.unsqueeze(-1)
+    if mask is not None:
+        out1 *= mask.unsqueeze(-1)
+    out1x, out1y, out1z = out1.unbind(-1)
+    out1x *= (1 - gy) * gz
+    out1y *= - gx * gz
+    out1z *= gx * (1 - gy)
+    out.scatter_add_(-1, idx, out1x + out1y + out1z)
+    # - corner 110
+    idx = sub2ind_list([gx1, gy1, gz0], shape)
+    idx = idx.expand([batch, channel, idx.shape[-1]])
+    out1 = inp.clone()
+    sign = make_sign([signx1, signy1, signz0])  # fix: was signx0/signy0
+    if sign is not None:
+        out1 *= sign.unsqueeze(-1)
+    if mask is not None:
+        out1 *= mask.unsqueeze(-1)
+    out1x, out1y, out1z = out1.unbind(-1)
+    out1x *= gy * (1 - gz)
+    out1y *= gx * (1 - gz)
+    out1z *= - gx * gy
+    out.scatter_add_(-1, idx, out1x + out1y + out1z)
+    # - corner 111
+    idx = sub2ind_list([gx1, gy1, gz1], shape)
+    idx = idx.expand([batch, channel, idx.shape[-1]])
+    out1 = inp.clone()
+    sign = make_sign([signx1, signy1, signz1])  # fix: was signx0/signy0/signz0
+    if sign is not None:
+        out1 *= sign.unsqueeze(-1)
+    if mask is not None:
+        out1 *= mask.unsqueeze(-1)
+    out1x, out1y, out1z = out1.unbind(-1)
+    out1x *= gy * gz
+    out1y *= gx * gz
+    out1z *= gx * gy
+    out.scatter_add_(-1, idx, out1x + out1y + out1z)
+
+    out = out.reshape(list(out.shape[:2]) + shape)
+    return out
+
+
+@torch.jit.script
+def hess3d(inp, g, bound: List[Bound], extrapolate: int = 1):
+    """
+    inp: (B, C, iX, iY, iZ) tensor
+    g: (B, oX, oY, oZ, 3) tensor
+    bound: List{3}[Bound] tensor
+    extrapolate: ExtrapolateType
+    returns: (B, C, oX, oY, oZ, 3, 3) tensor
+    """
+    dim = 3
+    boundx, boundy, boundz = bound
+    oshape = list(g.shape[-dim-1:-1])
+    g = g.reshape([g.shape[0], 1, -1, dim])
+    gx, gy, gz = torch.unbind(g, -1)
+    batch = max(inp.shape[0], gx.shape[0])
+    channel = inp.shape[1]
+    shape = list(inp.shape[-dim:])
+    nx, ny, nz = shape
+
+    # mask of inbounds voxels
+    mask = inbounds_mask_3d(extrapolate, gx, gy, gz, nx, ny, nz)
+
+    # corners
+    # (upper weight, lower corner, upper corner, lower sign, upper sign)
+    gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx)
+    gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy)
+    gz, gz0, gz1, signz0, signz1 = get_weights_and_indices(gz, nz, boundz)
+
+    # gather
+    inp = inp.reshape(list(inp.shape[:2]) + [-1])
+    out = torch.empty([batch, channel, g.shape[-2], dim, dim],
+                      dtype=inp.dtype, device=inp.device)
+    outx, outy, outz = out.unbind(-1)
+    outxx, outyx, outzx = outx.unbind(-1)
+    outxy, outyy, outzy = outy.unbind(-1)
+    outxz, outyz, outzz = outz.unbind(-1)
+    # - corner 000
+    idx = sub2ind_list([gx0, gy0, gz0], shape)
+    idx = idx.expand([batch, channel, idx.shape[-1]])
+    torch.gather(inp, -1, idx, out=outxy)
+    outxz.copy_(outxy)
+    outyz.copy_(outxy)
+    outxx.zero_()
+    outyy.zero_()
+    outzz.zero_()
+    sign = make_sign([signx0, signy0, signz0])
+    if sign is not None:
+        out *= sign.unsqueeze(-1).unsqueeze(-1)  # fix: out is (..., 3, 3)
+    outxy *= (1 - gz)
+    outxz *= (1 - gy)
+    outyz *= (1 - gx)
+    # - corner 001
+    
idx = sub2ind_list([gx0, gy0, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy0, signz1]) + if sign is not None: + out1 *= sign + outxy.addcmul_(out1, gz) + outxz.addcmul_(out1, - (1 - gy)) + outyz.addcmul_(out1, - (1 - gx)) + # - corner 010 + idx = sub2ind_list([gx0, gy1, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy1, signz0]) + if sign is not None: + out1 *= sign + outxy.addcmul_(out1, - (1 - gz)) + outxz.addcmul_(out1, gy) + outyz.addcmul_(out1, - (1 - gx)) + # - corner 011 + idx = sub2ind_list([gx0, gy1, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy1, signz1]) + if sign is not None: + out1 *= sign + outxy.addcmul_(out1, - gz) + outxz.addcmul_(out1, - gy) + outyz.addcmul_(out1, (1 - gx)) + # - corner 100 + idx = sub2ind_list([gx1, gy0, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy0, signz0]) + if sign is not None: + out1 *= sign + outxy.addcmul_(out1, - (1 - gz)) + outxz.addcmul_(out1, - (1 - gy)) + outyz.addcmul_(out1, gx) + # - corner 101 + idx = sub2ind_list([gx1, gy0, gz1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy0, signz1]) + if sign is not None: + out1 *= sign + outxy.addcmul_(out1, - gz) + outxz.addcmul_(out1, (1 - gy)) + outyz.addcmul_(out1, - gx) + # - corner 110 + idx = sub2ind_list([gx1, gy1, gz0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy1, signz0]) + if sign is not None: + out1 *= sign + outxy.addcmul_(out1, (1 - gz)) + outxz.addcmul_(out1, - gy) + outyz.addcmul_(out1, - gx) + # - corner 111 + idx = sub2ind_list([gx1, gy1, gz1], shape) + idx = idx.expand([batch, channel, 
idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy1, signz1]) + if sign is not None: + out1 *= sign + outxy.addcmul_(out1, gz) + outxz.addcmul_(out1, gy) + outyz.addcmul_(out1, gx) + + outyx.copy_(outxy) + outzx.copy_(outxz) + outzy.copy_(outyz) + + if mask is not None: + out *= mask.unsqueeze(-1).unsqueeze(-1) + out = out.reshape(list(out.shape[:2]) + oshape + [dim, dim]) + return out + + +# ====================================================================== +# 2D +# ====================================================================== + + +@torch.jit.script +def pull2d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX, iY) tensor + g: (B, oX, oY, 2) tensor + bound: List{2}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX, oY) tensor + """ + dim = 2 + boundx, boundy = bound + oshape = list(g.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy = g.unbind(-1) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + shape = list(inp.shape[-dim:]) + nx, ny = shape + + # mask of inbounds voxels + mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy) + + # gather + inp = inp.reshape(list(inp.shape[:2]) + [-1]) + # - corner 00 + idx = sub2ind_list([gx0, gy0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out = inp.gather(-1, idx) + sign = make_sign([signx0, signy0]) + if sign is not None: + out *= sign + out *= (1 - gx) * (1 - gy) + # - corner 01 + idx = sub2ind_list([gx0, gy1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy1]) + if sign is not None: + out1 *= sign + out1 *= (1 - gx) * gy + out += out1 + # - corner 10 + idx = sub2ind_list([gx1, gy0], shape) 
+ idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy0]) + if sign is not None: + out1 *= sign + out1 *= gx * (1 - gy) + out += out1 + # - corner 11 + idx = sub2ind_list([gx1, gy1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy1]) + if sign is not None: + out1 *= sign + out1 *= gx * gy + out += out1 + + if mask is not None: + out *= mask + out = out.reshape(list(out.shape[:2]) + oshape) + return out + + +@torch.jit.script +def push2d(inp, g, shape: Optional[List[int]], bound: List[Bound], + extrapolate: int = 1): + """ + inp: (B, C, iX, iY) tensor + g: (B, iX, iY, 2) tensor + shape: List{2}[int], optional + bound: List{2}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, *shape) tensor + """ + dim = 2 + boundx, boundy = bound + if inp.shape[-dim:] != g.shape[-dim-1:-1]: + raise ValueError('Input and grid should have the same spatial shape') + ishape = list(inp.shape[-dim:]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy = torch.unbind(g, -1) + inp = inp.reshape(list(inp.shape[:2]) + [-1]) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + + if shape is None: + shape = ishape + shape = list(shape) + nx, ny = shape + + # mask of inbounds voxels + mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy) + + # scatter + out = torch.zeros([batch, channel, nx*ny], + dtype=inp.dtype, device=inp.device) + # - corner 00 + idx = sub2ind_list([gx0, gy0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy0]) + if sign is not None: + out1 *= sign + if mask is not None: + out1 *= mask + out1 *= (1 - gx) * (1 - gy) + 
out.scatter_add_(-1, idx, out1) + # - corner 01 + idx = sub2ind_list([gx0, gy1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy1]) + if sign is not None: + out1 *= sign + if mask is not None: + out1 *= mask + out1 *= (1 - gx) * gy + out.scatter_add_(-1, idx, out1) + # - corner 10 + idx = sub2ind_list([gx1, gy0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx1, signy0]) + if sign is not None: + out1 *= sign + if mask is not None: + out1 *= mask + out1 *= gx * (1 - gy) + out.scatter_add_(-1, idx, out1) + # - corner 11 + idx = sub2ind_list([gx1, gy1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx1, signy1]) + if sign is not None: + out1 *= sign + if mask is not None: + out1 *= mask + out1 *= gx * gy + out.scatter_add_(-1, idx, out1) + + out = out.reshape(list(out.shape[:2]) + shape) + return out + + +@torch.jit.script +def grad2d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX, iY) tensor + g: (B, oX, oY, 2) tensor + bound: List{2}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX, oY, 2) tensor + """ + dim = 2 + boundx, boundy = bound + oshape = list(g.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy = torch.unbind(g, -1) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + shape = list(inp.shape[-dim:]) + nx, ny = shape + + # mask of inbounds voxels + mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy) + + # gather + inp = inp.reshape(list(inp.shape[:2]) + [-1]) + out = torch.empty([batch, channel] + list(g.shape[-2:]), + dtype=inp.dtype, device=inp.device) + outx, outy = 
out.unbind(-1) + # - corner 00 + idx = sub2ind_list([gx0, gy0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + torch.gather(inp, -1, idx, out=outx) + outy.copy_(outx) + sign = make_sign([signx0, signy0]) + if sign is not None: + out *= sign.unsqueeze(-1) + outx *= - (1 - gy) + outy *= - (1 - gx) + # - corner 01 + idx = sub2ind_list([gx0, gy1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy1]) + if sign is not None: + out1 *= sign + outx.addcmul_(out1, - gy) + outy.addcmul_(out1, (1 - gx)) + # - corner 10 + idx = sub2ind_list([gx1, gy0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy0]) + if sign is not None: + out1 *= sign + outx.addcmul_(out1, (1 - gy)) + outy.addcmul_(out1, - gx) + # - corner 11 + idx = sub2ind_list([gx1, gy1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy1]) + if sign is not None: + out1 *= sign + outx.addcmul_(out1, gy) + outy.addcmul_(out1, gx) + + if mask is not None: + out *= mask.unsqueeze(-1) + out = out.reshape(list(out.shape[:2]) + oshape + [dim]) + return out + + +@torch.jit.script +def pushgrad2d(inp, g, shape: Optional[List[int]], bound: List[Bound], + extrapolate: int = 1): + """ + inp: (B, C, iX, iY, 2) tensor + g: (B, iX, iY, 2) tensor + shape: List{2}[int], optional + bound: List{2}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, *shape) tensor + """ + dim = 2 + boundx, boundy = bound + if inp.shape[-dim-1:-1] != g.shape[-dim-1:-1]: + raise ValueError('Input and grid should have the same spatial shape') + ishape = list(inp.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy = g.unbind(-1) + inp = inp.reshape(list(inp.shape[:2]) + [-1, dim]) + batch = max(inp.shape[0], g.shape[0]) + channel = inp.shape[1] + + if shape is None: + shape = ishape + shape = list(shape) 
+ nx, ny = shape + + # mask of inbounds voxels + mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy) + + # scatter + out = torch.zeros([batch, channel, nx*ny], + dtype=inp.dtype, device=inp.device) + # - corner 00 + idx = sub2ind_list([gx0, gy0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy0]) + if sign is not None: + out1 *= sign.unsqueeze(-1) + if mask is not None: + out1 *= mask.unsqueeze(-1) + out1x, out1y = out1.unbind(-1) + out1x *= - (1 - gy) + out1y *= - (1 - gx) + out.scatter_add_(-1, idx, out1x + out1y) + # - corner 01 + idx = sub2ind_list([gx0, gy1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy0]) + if sign is not None: + out1 *= sign.unsqueeze(-1) + if mask is not None: + out1 *= mask.unsqueeze(-1) + out1x, out1y = out1.unbind(-1) + out1x *= - gy + out1y *= (1 - gx) + out.scatter_add_(-1, idx, out1x + out1y) + # - corner 10 + idx = sub2ind_list([gx1, gy0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy0]) + if sign is not None: + out1 *= sign.unsqueeze(-1) + if mask is not None: + out1 *= mask.unsqueeze(-1) + out1x, out1y = out1.unbind(-1) + out1x *= (1 - gy) + out1y *= - gx + out.scatter_add_(-1, idx, out1x + out1y) + # - corner 11 + idx = sub2ind_list([gx1, gy1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = make_sign([signx0, signy0]) + if sign is not None: + out1 *= sign.unsqueeze(-1) + if mask is not None: + out1 *= mask.unsqueeze(-1) + out1x, out1y = out1.unbind(-1) + out1x *= gy + out1y *= gx + out.scatter_add_(-1, idx, out1x + out1y) + + out = out.reshape(list(out.shape[:2]) 
+ shape) + return out + + +@torch.jit.script +def hess2d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX, iY) tensor + g: (B, oX, oY, 2) tensor + bound: List{2}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX, oY, 2, 2) tensor + """ + dim = 2 + boundx, boundy = bound + oshape = list(g.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx, gy = torch.unbind(g, -1) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + shape = list(inp.shape[-dim:]) + nx, ny = shape + + # mask of inbounds voxels + mask = inbounds_mask_2d(extrapolate, gx, gy, nx, ny) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + gy, gy0, gy1, signy0, signy1 = get_weights_and_indices(gy, ny, boundy) + + # gather + inp = inp.reshape(list(inp.shape[:2]) + [-1]) + out = torch.empty([batch, channel, g.shape[-2], dim, dim], + dtype=inp.dtype, device=inp.device) + outx, outy = out.unbind(-1) + outxx, outyx = outx.unbind(-1) + outxy, outyy = outy.unbind(-1) + # - corner 00 + idx = sub2ind_list([gx0, gy0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + torch.gather(inp, -1, idx, out=outxy) + outxx.zero_() + outyy.zero_() + sign = make_sign([signx0, signy0]) + if sign is not None: + out *= sign.unsqueeze(-1) + outxy *= 1 + # - corner 01 + idx = sub2ind_list([gx0, gy1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx0, signy1]) + if sign is not None: + out1 *= sign + outxy.add_(out1, alpha=-1) + # - corner 10 + idx = sub2ind_list([gx1, gy0], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = make_sign([signx1, signy0]) + if sign is not None: + out1 *= sign + outxy.add_(out1, alpha=-1) + # - corner 11 + idx = sub2ind_list([gx1, gy1], shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = 
inp.gather(-1, idx) + sign = make_sign([signx1, signy1]) + if sign is not None: + out1 *= sign + outxy.add_(out1) + + outyx.copy_(outxy) + + if mask is not None: + out *= mask.unsqueeze(-1).unsqueeze(-1) + out = out.reshape(list(out.shape[:2]) + oshape + [dim, dim]) + return out + + +# ====================================================================== +# 1D +# ====================================================================== + + +@torch.jit.script +def pull1d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX) tensor + g: (B, oX, 1) tensor + bound: List{1}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX) tensor + """ + dim = 1 + boundx = bound[0] + oshape = list(g.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx = g.squeeze(-1) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + shape = list(inp.shape[-dim:]) + nx = shape[0] + + # mask of inbounds voxels + mask = inbounds_mask_1d(extrapolate, gx, nx) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + + # gather + inp = inp.reshape(list(inp.shape[:2]) + [-1]) + # - corner 0 + idx = gx0 + idx = idx.expand([batch, channel, idx.shape[-1]]) + out = inp.gather(-1, idx) + sign = signx0 + if sign is not None: + out *= sign + out *= (1 - gx) + # - corner 1 + idx = gx1 + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = signx1 + if sign is not None: + out1 *= sign + out1 *= gx + out += out1 + + if mask is not None: + out *= mask + out = out.reshape(list(out.shape[:2]) + oshape) + return out + + +@torch.jit.script +def push1d(inp, g, shape: Optional[List[int]], bound: List[Bound], + extrapolate: int = 1): + """ + inp: (B, C, iX, iY) tensor + g: (B, iX, iY, 2) tensor + shape: List{2}[int], optional + bound: List{2}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, *shape) tensor + """ + 
dim = 1 + boundx = bound[0] + if inp.shape[-dim:] != g.shape[-dim-1:-1]: + raise ValueError('Input and grid should have the same spatial shape') + ishape = list(inp.shape[-dim:]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx = g.squeeze(-1) + inp = inp.reshape(list(inp.shape[:2]) + [-1]) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + + if shape is None: + shape = ishape + shape = list(shape) + nx = shape[0] + + # mask of inbounds voxels + mask = inbounds_mask_1d(extrapolate, gx, nx) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + + # scatter + out = torch.zeros([batch, channel, nx], + dtype=inp.dtype, device=inp.device) + # - corner 0 + idx = gx0 + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = signx0 + if sign is not None: + out1 *= sign + if mask is not None: + out1 *= mask + out1 *= (1 - gx) + out.scatter_add_(-1, idx, out1) + # - corner 1 + idx = gx1 + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = signx1 + if sign is not None: + out1 *= sign + if mask is not None: + out1 *= mask + out1 *= gx + out.scatter_add_(-1, idx, out1) + + out = out.reshape(list(out.shape[:2]) + shape) + return out + + +@torch.jit.script +def grad1d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX) tensor + g: (B, oX, 1) tensor + bound: List{1}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX, 1) tensor + """ + dim = 1 + boundx = bound[0] + oshape = list(g.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx = g.squeeze(-1) + batch = max(inp.shape[0], gx.shape[0]) + channel = inp.shape[1] + shape = list(inp.shape[-dim:]) + nx = shape[0] + + # mask of inbounds voxels + mask = inbounds_mask_1d(extrapolate, gx, nx) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = 
get_weights_and_indices(gx, nx, boundx) + + # gather + inp = inp.reshape(list(inp.shape[:2]) + [-1]) + out = torch.empty([batch, channel] + list(g.shape[-2:]), + dtype=inp.dtype, device=inp.device) + outx = out.squeeze(-1) + # - corner 0 + idx = gx0 + idx = idx.expand([batch, channel, idx.shape[-1]]) + torch.gather(inp, -1, idx, out=outx) + sign = signx0 + if sign is not None: + out *= sign.unsqueeze(-1) + outx.neg_() + # - corner 1 + idx = gx1 + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + sign = signx1 + if sign is not None: + out1 *= sign + outx.add_(out1) + + if mask is not None: + out *= mask.unsqueeze(-1) + out = out.reshape(list(out.shape[:2]) + oshape + [dim]) + return out + + +@torch.jit.script +def pushgrad1d(inp, g, shape: Optional[List[int]], bound: List[Bound], + extrapolate: int = 1): + """ + inp: (B, C, iX, 1) tensor + g: (B, iX, 1) tensor + shape: List{1}[int], optional + bound: List{1}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, *shape) tensor + """ + dim = 1 + boundx = bound[0] + if inp.shape[-2] != g.shape[-2]: + raise ValueError('Input and grid should have the same spatial shape') + ishape = list(inp.shape[-dim-1:-1]) + g = g.reshape([g.shape[0], 1, -1, dim]) + gx = g.squeeze(-1) + inp = inp.reshape(list(inp.shape[:2]) + [-1, dim]) + batch = max(inp.shape[0], g.shape[0]) + channel = inp.shape[1] + + if shape is None: + shape = ishape + shape = list(shape) + nx = shape[0] + + # mask of inbounds voxels + mask = inbounds_mask_1d(extrapolate, gx, nx) + + # corners + # (upper weight, lower corner, upper corner, lower sign, upper sign) + gx, gx0, gx1, signx0, signx1 = get_weights_and_indices(gx, nx, boundx) + + # scatter + out = torch.zeros([batch, channel, nx], dtype=inp.dtype, device=inp.device) + # - corner 000 + idx = gx0 + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = signx0 + if sign is not None: + out1 *= sign.unsqueeze(-1) + if mask is not None: + out1 *= 
mask.unsqueeze(-1) + out1x = out1.squeeze(-1) + out1x.neg_() + out.scatter_add_(-1, idx, out1x) + # - corner 100 + idx = gx1 + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + sign = signx0 + if sign is not None: + out1 *= sign.unsqueeze(-1) + if mask is not None: + out1 *= mask.unsqueeze(-1) + out1x = out1.squeeze(-1) + out.scatter_add_(-1, idx, out1x) + + out = out.reshape(list(out.shape[:2]) + shape) + return out + + +@torch.jit.script +def hess1d(inp, g, bound: List[Bound], extrapolate: int = 1): + """ + inp: (B, C, iX) tensor + g: (B, oX, 1) tensor + bound: List{1}[Bound] tensor + extrapolate: ExtrapolateType + returns: (B, C, oX, 1, 1) tensor + """ + batch = max(inp.shape[0], g.shape[0]) + return torch.zeros([batch, inp.shape[1], g.shape[1], 1, 1], + dtype=inp.dtype, device=inp.device) \ No newline at end of file diff --git a/interpol/jit_utils.py b/interpol/jit_utils.py new file mode 100644 index 0000000..18224fb --- /dev/null +++ b/interpol/jit_utils.py @@ -0,0 +1,418 @@ +"""A lot of utility functions for TorchScript""" +import torch +import os +from typing import List, Tuple, Optional +Tensor = torch.Tensor + + +@torch.jit.script +def pad_list_int(x: List[int], dim: int) -> List[int]: + if len(x) < dim: + x = x + x[-1:] * (dim - len(x)) + if len(x) > dim: + x = x[:dim] + return x + + +@torch.jit.script +def pad_list_float(x: List[float], dim: int) -> List[float]: + if len(x) < dim: + x = x + x[-1:] * (dim - len(x)) + if len(x) > dim: + x = x[:dim] + return x + + +@torch.jit.script +def pad_list_str(x: List[str], dim: int) -> List[str]: + if len(x) < dim: + x = x + x[-1:] * (dim - len(x)) + if len(x) > dim: + x = x[:dim] + return x + + +@torch.jit.script +def list_any(x: List[bool]) -> bool: + for elem in x: + if elem: + return True + return False + + +@torch.jit.script +def list_all(x: List[bool]) -> bool: + for elem in x: + if not elem: + return False + return True + + +@torch.jit.script +def list_prod_int(x: List[int]) -> int: + 
if len(x) == 0: + return 1 + x0 = x[0] + for x1 in x[1:]: + x0 = x0 * x1 + return x0 + + +@torch.jit.script +def list_sum_int(x: List[int]) -> int: + if len(x) == 0: + return 1 + x0 = x[0] + for x1 in x[1:]: + x0 = x0 + x1 + return x0 + + +@torch.jit.script +def list_reverse_int(x: List[int]) -> List[int]: + if len(x) == 0: + return x + return [x[i] for i in range(-1, -len(x)-1, -1)] + + +@torch.jit.script +def list_cumprod_int(x: List[int], reverse: bool = False, + exclusive: bool = False) -> List[int]: + if len(x) == 0: + lx: List[int] = [] + return lx + if reverse: + x = list_reverse_int(x) + + x0 = 1 if exclusive else x[0] + lx = [x0] + all_x = x[:-1] if exclusive else x[1:] + for x1 in all_x: + x0 = x0 * x1 + lx.append(x0) + if reverse: + lx = list_reverse_int(lx) + return lx + + +@torch.jit.script +def movedim1(x, source: int, destination: int): + dim = x.dim() + source = dim + source if source < 0 else source + destination = dim + destination if destination < 0 else destination + permutation = [d for d in range(dim)] + permutation = permutation[:source] + permutation[source+1:] + permutation = permutation[:destination] + [source] + permutation[destination:] + return x.permute(permutation) + + +def compare_versions(version1: List[int], mode: str, version2: List[int]) -> bool: + for v1, v2 in zip(version1, version2): + if mode in ('gt', '>'): + if v1 > v2: + return True + elif v1 < v2: + return False + elif mode in ('ge', '>='): + if v1 > v2: + return True + elif v1 < v2: + return False + elif mode in ('lt', '<'): + if v1 < v2: + return True + elif v1 > v2: + return False + elif mode in ('le', '<='): + if v1 < v2: + return True + elif v1 > v2: + return False + if mode in ('gt', 'lt', '>', '<'): + return False + else: + return True + + +def torch_version(mode: str, version: List[int]) -> bool: + """Check torch version + + Parameters + ---------- + mode : {'<', '<=', '>', '>='} + version : list[int] + + Returns + ------- + True if "torch.version version" + + """ 
+ current_version = torch.__version__.split('+')[0] + current_version = current_version.split('.') + current_version = [int(current_version[0]), + int(current_version[1]), + int(current_version[2])] + return compare_versions(current_version, mode, version) + + +@torch.jit.script +def sub2ind(subs, shape: List[int]): + """Convert sub indices (i, j, k) into linear indices. + + The rightmost dimension is the most rapidly changing one + -> if shape == [D, H, W], the strides are therefore [H*W, W, 1] + + Parameters + ---------- + subs : (D, ...) tensor + List of sub-indices. The first dimension is the number of dimension. + Each element should have the same number of elements and shape. + shape : (D,) list[int] + Size of each dimension. Its length should be the same as the + first dimension of ``subs``. + + Returns + ------- + ind : (...) tensor + Linear indices + """ + subs = subs.unbind(0) + ind = subs[-1] + subs = subs[:-1] + ind = ind.clone() + stride = list_cumprod_int(shape[1:], reverse=True, exclusive=False) + for i, s in zip(subs, stride): + ind += i * s + return ind + + +@torch.jit.script +def sub2ind_list(subs: List[Tensor], shape: List[int]): + """Convert sub indices (i, j, k) into linear indices. + + The rightmost dimension is the most rapidly changing one + -> if shape == [D, H, W], the strides are therefore [H*W, W, 1] + + Parameters + ---------- + subs : (D,) list[tensor] + List of sub-indices. The first dimension is the number of dimension. + Each element should have the same number of elements and shape. + shape : (D,) list[int] + Size of each dimension. Its length should be the same as the + first dimension of ``subs``. + + Returns + ------- + ind : (...) 
tensor + Linear indices + """ + ind = subs[-1] + subs = subs[:-1] + ind = ind.clone() + stride = list_cumprod_int(shape[1:], reverse=True, exclusive=False) + for i, s in zip(subs, stride): + ind += i * s + return ind + +# floor_divide returns wrong results for negative values, because it truncates +# instead of performing a proper floor. In recent version of pytorch, it is +# advised to use div(..., rounding_mode='trunc'|'floor') instead. +# Here, we only use floor_divide on positive values so we do not care. +if torch_version('>=', [1, 8]): + @torch.jit.script + def floor_div(x, y) -> torch.Tensor: + return torch.div(x, y, rounding_mode='floor') + @torch.jit.script + def floor_div_int(x, y: int) -> torch.Tensor: + return torch.div(x, y, rounding_mode='floor') +else: + @torch.jit.script + def floor_div(x, y) -> torch.Tensor: + return (x / y).floor_() + @torch.jit.script + def floor_div_int(x, y: int) -> torch.Tensor: + return (x / y).floor_() + + +@torch.jit.script +def ind2sub(ind, shape: List[int]): + """Convert linear indices into sub indices (i, j, k). + + The rightmost dimension is the most rapidly changing one + -> if shape == [D, H, W], the strides are therefore [H*W, W, 1] + + Parameters + ---------- + ind : tensor_like + Linear indices + shape : (D,) vector_like + Size of each dimension. + + Returns + ------- + subs : (D, ...) tensor + Sub-indices. 
+ """ + stride = list_cumprod_int(shape, reverse=True, exclusive=True) + sub = ind.new_empty([len(shape)] + ind.shape) + sub.copy_(ind) + for d in range(len(shape)): + if d > 0: + sub[d] = torch.remainder(sub[d], stride[d-1]) + sub[d] = floor_div_int(sub[d], stride[d]) + return sub + + +@torch.jit.script +def inbounds_mask_3d(extrapolate: int, gx, gy, gz, nx: int, ny: int, nz: int) \ + -> Optional[Tensor]: + # mask of inbounds voxels + mask: Optional[Tensor] = None + if extrapolate in (0, 2): # no / hist + tiny = 1e-5 + threshold = tiny + if extrapolate == 2: + threshold = 0.5 + tiny + mask = ((gx > -threshold) & (gx < nx - 1 + threshold) & + (gy > -threshold) & (gy < ny - 1 + threshold) & + (gz > -threshold) & (gz < nz - 1 + threshold)) + return mask + return mask + + +@torch.jit.script +def inbounds_mask_2d(extrapolate: int, gx, gy, nx: int, ny: int) \ + -> Optional[Tensor]: + # mask of inbounds voxels + mask: Optional[Tensor] = None + if extrapolate in (0, 2): # no / hist + tiny = 1e-5 + threshold = tiny + if extrapolate == 2: + threshold = 0.5 + tiny + mask = ((gx > -threshold) & (gx < nx - 1 + threshold) & + (gy > -threshold) & (gy < ny - 1 + threshold)) + return mask + return mask + + +@torch.jit.script +def inbounds_mask_1d(extrapolate: int, gx, nx: int) -> Optional[Tensor]: + # mask of inbounds voxels + mask: Optional[Tensor] = None + if extrapolate in (0, 2): # no / hist + tiny = 1e-5 + threshold = tiny + if extrapolate == 2: + threshold = 0.5 + tiny + mask = (gx > -threshold) & (gx < nx - 1 + threshold) + return mask + return mask + + +@torch.jit.script +def make_sign(sign: List[Optional[Tensor]]) -> Optional[Tensor]: + osign: Optional[Tensor] = None + for s in sign: + if s is not None: + if osign is None: + osign = s + else: + osign = osign * s + return osign + + +@torch.jit.script +def square(x): + return x * x + + +@torch.jit.script +def square_(x): + return x.mul_(x) + + +@torch.jit.script +def cube(x): + return x * x * x + + +@torch.jit.script +def 
cube_(x): + return square_(x).mul_(x) + + +@torch.jit.script +def pow4(x): + return square(square(x)) + + +@torch.jit.script +def pow4_(x): + return square_(square_(x)) + + +@torch.jit.script +def pow5(x): + return x * pow4(x) + + +@torch.jit.script +def pow5_(x): + return pow4_(x).mul_(x) + + +@torch.jit.script +def pow6(x): + return square(cube(x)) + + +@torch.jit.script +def pow6_(x): + return square_(cube_(x)) + + +@torch.jit.script +def pow7(x): + return pow6(x) * x + + +@torch.jit.script +def pow7_(x): + return pow6_(x).mul_(x) + + +@torch.jit.script +def dot(x, y, dim: int = -1, keepdim: bool = False): + """(Batched) dot product along a dimension""" + x = movedim1(x, dim, -1).unsqueeze(-2) + y = movedim1(y, dim, -1).unsqueeze(-1) + d = torch.matmul(x, y).squeeze(-1).squeeze(-1) + if keepdim: + d.unsqueeze(dim) + return d + + +@torch.jit.script +def dot_multi(x, y, dim: List[int], keepdim: bool = False): + """(Batched) dot product along a dimension""" + for d in dim: + x = movedim1(x, d, -1) + y = movedim1(y, d, -1) + x = x.reshape(x.shape[:-len(dim)] + [1, -1]) + y = y.reshape(x.shape[:-len(dim)] + [-1, 1]) + dt = torch.matmul(x, y).squeeze(-1).squeeze(-1) + if keepdim: + for d in dim: + dt.unsqueeze(d) + return dt + + +# cartesian_prod takes multiple inout tensors as input in eager mode +# but takes a list of tensor in jit mode. This is a helper that works +# in both cases. 
+if not int(os.environ.get('PYTORCH_JIT', '1')): + cartesian_prod = lambda x: torch.cartesian_prod(*x) +else: + cartesian_prod = torch.cartesian_prod diff --git a/interpol/nd.py b/interpol/nd.py new file mode 100644 index 0000000..59eebc8 --- /dev/null +++ b/interpol/nd.py @@ -0,0 +1,439 @@ +"""Generic N-dimensional version: any combination of spline orders""" +import torch +from typing import List, Optional, Tuple +from .bounds import Bound +from .splines import Spline +from .jit_utils import sub2ind_list, make_sign, list_prod_int, cartesian_prod +Tensor = torch.Tensor + + +@torch.jit.script +def inbounds_mask(extrapolate: int, grid, shape: List[int])\ + -> Optional[Tensor]: + # mask of inbounds voxels + mask: Optional[Tensor] = None + if extrapolate in (0, 2): # no / hist + grid = grid.unsqueeze(1) + tiny = 1e-5 + threshold = tiny + if extrapolate == 2: + threshold = 0.5 + tiny + mask = torch.ones(grid.shape[:-1], + dtype=torch.bool, device=grid.device) + for grid1, shape1 in zip(grid.unbind(-1), shape): + mask = mask & (grid1 > -threshold) + mask = mask & (grid1 < shape1 - 1 + threshold) + return mask + return mask + + +@torch.jit.script +def get_weights(grid, bound: List[Bound], spline: List[Spline], + shape: List[int], grad: bool = False, hess: bool = False) \ + -> Tuple[List[List[Tensor]], + List[List[Optional[Tensor]]], + List[List[Optional[Tensor]]], + List[List[Tensor]], + List[List[Optional[Tensor]]]]: + + weights: List[List[Tensor]] = [] + grads: List[List[Optional[Tensor]]] = [] + hesss: List[List[Optional[Tensor]]] = [] + coords: List[List[Tensor]] = [] + signs: List[List[Optional[Tensor]]] = [] + for g, b, s, n in zip(grid.unbind(-1), bound, spline, shape): + grid0 = (g - (s.order-1)/2).floor() + dist0 = g - grid0 + grid0 = grid0.long() + nb_nodes = s.order + 1 + subweights: List[Tensor] = [] + subcoords: List[Tensor] = [] + subgrads: List[Optional[Tensor]] = [] + subhesss: List[Optional[Tensor]] = [] + subsigns: List[Optional[Tensor]] = [] + for node 
in range(nb_nodes): + grid1 = grid0 + node + sign1 = b.transform(grid1, n) + subsigns.append(sign1) + grid1 = b.index(grid1, n) + subcoords.append(grid1) + dist1 = dist0 - node + weight1 = s.fastweight(dist1) + subweights.append(weight1) + grad1 = s.fastgrad(dist1) if grad else None + subgrads.append(grad1) + hess1 = s.fasthess(dist1) if hess else None + subhesss.append(hess1) + weights.append(subweights) + coords.append(subcoords) + signs.append(subsigns) + grads.append(subgrads) + hesss.append(subhesss) + + return weights, grads, hesss, coords, signs + + +@torch.jit.script +def pull(inp, grid, bound: List[Bound], spline: List[Spline], + extrapolate: int = 1): + """ + inp: (B, C, *ishape) tensor + g: (B, *oshape, D) tensor + bound: List{D}[Bound] tensor + spline: List{D}[Spline] tensor + extrapolate: int + returns: (B, C, *oshape) tensor + """ + + dim = grid.shape[-1] + shape = list(inp.shape[-dim:]) + oshape = list(grid.shape[-dim-1:-1]) + batch = max(inp.shape[0], grid.shape[0]) + channel = inp.shape[1] + + grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]]) + inp = inp.reshape([inp.shape[0], inp.shape[1], -1]) + mask = inbounds_mask(extrapolate, grid, shape) + + # precompute weights along each dimension + weights, _, _, coords, signs = get_weights(grid, bound, spline, shape) + + # initialize + out = torch.zeros([batch, channel, grid.shape[1]], + dtype=inp.dtype, device=inp.device) + + # iterate across nodes/corners + range_nodes = [torch.as_tensor([d for d in range(n)]) + for n in [s.order + 1 for s in spline]] + all_nodes = cartesian_prod(range_nodes) + if dim == 1: + all_nodes = all_nodes.unsqueeze(0) + for nodes in all_nodes: + + # gather + idx = [c[n] for c, n in zip(coords, nodes)] + idx = sub2ind_list(idx, shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + + # apply sign + sign1 = make_sign([sgn[n] for sgn, n in zip(signs, nodes)]) + if sign1 is not None: + out1 *= sign1 + + # apply weights + for weight, n in 
zip(weights, nodes): + out1 *= weight[n] + + # accumulate + out += out1 + + # out-of-bounds mask + if mask is not None: + out *= mask + + out = out.reshape(list(out.shape[:2]) + oshape) + return out + + +@torch.jit.script +def push(inp, grid, shape: Optional[List[int]], bound: List[Bound], + spline: List[Spline], extrapolate: int = 1): + """ + inp: (B, C, *ishape) tensor + g: (B, *ishape, D) tensor + shape: List{D}[int], optional + bound: List{D}[Bound] tensor + spline: List{D}[Spline] tensor + extrapolate: int + returns: (B, C, *oshape) tensor + """ + + dim = grid.shape[-1] + ishape = list(grid.shape[-dim - 1:-1]) + if shape is None: + shape = ishape + shape = list(shape) + batch = max(inp.shape[0], grid.shape[0]) + channel = inp.shape[1] + + grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]]) + inp = inp.reshape([inp.shape[0], inp.shape[1], -1]) + mask = inbounds_mask(extrapolate, grid, shape) + + # precompute weights along each dimension + weights, _, _, coords, signs = get_weights(grid, bound, spline, shape) + + # initialize + out = torch.zeros([batch, channel, list_prod_int(shape)], + dtype=inp.dtype, device=inp.device) + + # iterate across nodes/corners + range_nodes = [torch.as_tensor([d for d in range(n)]) + for n in [s.order + 1 for s in spline]] + all_nodes = cartesian_prod(range_nodes) + if dim == 1: + all_nodes = all_nodes.unsqueeze(0) + for nodes in all_nodes: + + # gather + idx = [c[n] for c, n in zip(coords, nodes)] + idx = sub2ind_list(idx, shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.clone() + + # apply sign + sign1 = make_sign([sgn[n] for sgn, n in zip(signs, nodes)]) + if sign1 is not None: + out1 *= sign1 + + # out-of-bounds mask + if mask is not None: + out1 *= mask + + # apply weights + for weight, n in zip(weights, nodes): + out1 *= weight[n] + + # accumulate + out.scatter_add_(-1, idx, out1) + + out = out.reshape(list(out.shape[:2]) + shape) + return out + + +@torch.jit.script +def grad(inp, grid, bound: 
List[Bound], spline: List[Spline], + extrapolate: int = 1): + """ + inp: (B, C, *ishape) tensor + grid: (B, *oshape, D) tensor + bound: List{D}[Bound] tensor + spline: List{D}[Spline] tensor + extrapolate: int + returns: (B, C, *oshape, D) tensor + """ + + dim = grid.shape[-1] + shape = list(inp.shape[-dim:]) + oshape = list(grid.shape[-dim-1:-1]) + batch = max(inp.shape[0], grid.shape[0]) + channel = inp.shape[1] + + grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]]) + inp = inp.reshape([inp.shape[0], inp.shape[1], -1]) + mask = inbounds_mask(extrapolate, grid, shape) + + # precompute weights along each dimension + weights, grads, _, coords, signs = get_weights(grid, bound, spline, shape, + grad=True) + + # initialize + out = torch.zeros([batch, channel, grid.shape[1], dim], + dtype=inp.dtype, device=inp.device) + + # iterate across nodes/corners + range_nodes = [torch.as_tensor([d for d in range(n)]) + for n in [s.order + 1 for s in spline]] + all_nodes = cartesian_prod(range_nodes) + if dim == 1: + all_nodes = all_nodes.unsqueeze(0) + for nodes in all_nodes: + + # gather + idx = [c[n] for c, n in zip(coords, nodes)] + idx = sub2ind_list(idx, shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out0 = inp.gather(-1, idx) + + # apply sign + sign1 = make_sign([sgn[n] for sgn, n in zip(signs, nodes)]) + if sign1 is not None: + out0 *= sign1 + + for d in range(dim): + out1 = out0.clone() + # apply weights + for dd, (weight, grad1, n) in enumerate(zip(weights, grads, nodes)): + if d == dd: + grad11 = grad1[n] + if grad11 is not None: + out1 *= grad11 + else: + out1 *= weight[n] + + # accumulate + out.unbind(-1)[d].add_(out1) + + # out-of-bounds mask + if mask is not None: + out *= mask.unsqueeze(-1) + + out = out.reshape(list(out.shape[:2]) + oshape + list(out.shape[-1:])) + return out + + +@torch.jit.script +def pushgrad(inp, grid, shape: Optional[List[int]], bound: List[Bound], + spline: List[Spline], extrapolate: int = 1): + """ + inp: (B, C, *ishape, 
D) tensor + g: (B, *ishape, D) tensor + shape: List{D}[int], optional + bound: List{D}[Bound] tensor + spline: List{D}[Spline] tensor + extrapolate: int + returns: (B, C, *shape) tensor + """ + dim = grid.shape[-1] + oshape = list(grid.shape[-dim-1:-1]) + if shape is None: + shape = oshape + shape = list(shape) + batch = max(inp.shape[0], grid.shape[0]) + channel = inp.shape[1] + + grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]]) + inp = inp.reshape([inp.shape[0], inp.shape[1], -1, dim]) + mask = inbounds_mask(extrapolate, grid, shape) + + # precompute weights along each dimension + weights, grads, _, coords, signs = get_weights(grid, bound, spline, shape, grad=True) + + # initialize + out = torch.zeros([batch, channel, list_prod_int(shape)], + dtype=inp.dtype, device=inp.device) + + # iterate across nodes/corners + range_nodes = [torch.as_tensor([d for d in range(n)]) + for n in [s.order + 1 for s in spline]] + all_nodes = cartesian_prod(range_nodes) + if dim == 1: + all_nodes = all_nodes.unsqueeze(0) + for nodes in all_nodes: + + # gather + idx = [c[n] for c, n in zip(coords, nodes)] + idx = sub2ind_list(idx, shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out0 = inp.clone() + + # apply sign + sign1 = make_sign([sgn[n] for sgn, n in zip(signs, nodes)]) + if sign1 is not None: + out0 *= sign1.unsqueeze(-1) + + # out-of-bounds mask + if mask is not None: + out0 *= mask.unsqueeze(-1) + + for d in range(dim): + out1 = out0.unbind(-1)[d].clone() + # apply weights + for dd, (weight, grad1, n) in enumerate(zip(weights, grads, nodes)): + if d == dd: + grad11 = grad1[n] + if grad11 is not None: + out1 *= grad11 + else: + out1 *= weight[n] + + # accumulate + out.scatter_add_(-1, idx, out1) + + out = out.reshape(list(out.shape[:2]) + shape) + return out + + +@torch.jit.script +def hess(inp, grid, bound: List[Bound], spline: List[Spline], + extrapolate: int = 1): + """ + inp: (B, C, *ishape) tensor + grid: (B, *oshape, D) tensor + bound: List{D}[Bound] 
tensor + spline: List{D}[Spline] tensor + extrapolate: int + returns: (B, C, *oshape, D, D) tensor + """ + + dim = grid.shape[-1] + shape = list(inp.shape[-dim:]) + oshape = list(grid.shape[-dim-1:-1]) + batch = max(inp.shape[0], grid.shape[0]) + channel = inp.shape[1] + + grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]]) + inp = inp.reshape([inp.shape[0], inp.shape[1], -1]) + mask = inbounds_mask(extrapolate, grid, shape) + + # precompute weights along each dimension + weights, grads, hesss, coords, signs = get_weights(grid, bound, spline, shape, + grad=True, hess=True) + + # initialize + out = torch.zeros([batch, channel, grid.shape[1], dim, dim], + dtype=inp.dtype, device=inp.device) + + # iterate across nodes/corners + range_nodes = [torch.as_tensor([d for d in range(n)]) + for n in [s.order + 1 for s in spline]] + all_nodes = cartesian_prod(range_nodes) + if dim == 1: + all_nodes = all_nodes.unsqueeze(0) + for nodes in all_nodes: + + # gather + idx = [c[n] for c, n in zip(coords, nodes)] + idx = sub2ind_list(idx, shape) + idx = idx.expand([batch, channel, idx.shape[-1]]) + out1 = inp.gather(-1, idx) + + # apply sign + sign1 = make_sign([sgn[n] for sgn, n in zip(signs, nodes)]) + if sign1 is not None: + out1 *= sign1 + + for d in range(dim): + # -- diagonal -- + + # apply weights + for dd, (weight, hess1, n) \ + in enumerate(zip(weights, hesss, nodes)): + if d == dd: + hess11 = hess1[n] + if hess11 is not None: + out1 *= hess11 + else: + out1 *= weight[n] + + # accumulate + out.unbind(-1)[d].unbind(-1)[d].add_(out1) + + # -- off diagonal -- + for d2 in range(d+1, dim): + + # apply weights + for dd, (weight, grad1, hess1, n) \ + in enumerate(zip(weights, grads, hesss, nodes)): + if dd in (d, d2): + grad11 = grad1[n] + if grad11 is not None: + out1 *= grad11 + else: + out1 *= weight[n] + + # accumulate + out.unbind(-1)[d].unbind(-1)[d2].add_(out1) + + # out-of-bounds mask + if mask is not None: + out *= mask.unsqueeze(-1) + + # fill lower triangle + for d 
in range(dim): + for d2 in range(d+1, dim): + out.unbind(-1)[d2].unbind(-1)[d].copy_(out.unbind(-1)[d].unbind(-1)[d2]) + # NOTE(review): unlike `grad`, which re-clones `out0` per output dimension, `hess` above multiplies `out1` in place across the d/d2 loops without cloning; the off-diagonal terms look like they accumulate stale weight products -- verify against `grad`'s `out1 = out0.clone()` pattern + + out = out.reshape(list(out.shape[:2]) + oshape + list(out.shape[-2:])) + return out \ No newline at end of file diff --git a/interpol/pushpull.py b/interpol/pushpull.py new file mode 100644 index 0000000..3b8677f --- /dev/null +++ b/interpol/pushpull.py @@ -0,0 +1,323 @@ +""" +Non-differentiable forward/backward components. +These components are put together in `interpol.autograd` to generate +differentiable functions. + +Note +---- +.. I removed @torch.jit.script from these entry-points because compiling + all possible combinations of bound+interpolation made the first call + extremely slow. +.. I am not using the dot/multi_dot helpers even though they should be + more efficient than "multiply and sum" because I haven't had the time + to test them. It would be worth doing it. +""" +import torch +from typing import List, Optional, Tuple +from .jit_utils import list_all, dot, dot_multi, pad_list_int +from .bounds import Bound +from .splines import Spline +from .
import iso0, iso1, nd +Tensor = torch.Tensor + + +@torch.jit.script +def make_bound(bound: List[int]) -> List[Bound]: + return [Bound(b) for b in bound] + + +@torch.jit.script +def make_spline(spline: List[int]) -> List[Spline]: + return [Spline(s) for s in spline] + + +# @torch.jit.script +def grid_pull(inp, grid, bound: List[int], interpolation: List[int], + extrapolate: int): + """ + inp: (B, C, *spatial_in) tensor + grid: (B, *spatial_out, D) tensor + bound: List{D}[int] tensor + interpolation: List{D}[int] + extrapolate: int + returns: (B, C, *spatial_out) tensor + """ + dim = grid.shape[-1] + bound = pad_list_int(bound, dim) + interpolation = pad_list_int(interpolation, dim) + bound_fn = make_bound(bound) + is_iso1 = list_all([order == 1 for order in interpolation]) + if is_iso1: + if dim == 3: + return iso1.pull3d(inp, grid, bound_fn, extrapolate) + elif dim == 2: + return iso1.pull2d(inp, grid, bound_fn, extrapolate) + elif dim == 1: + return iso1.pull1d(inp, grid, bound_fn, extrapolate) + is_iso0 = list_all([order == 0 for order in interpolation]) + if is_iso0: + if dim == 3: + return iso0.pull3d(inp, grid, bound_fn, extrapolate) + elif dim == 2: + return iso0.pull2d(inp, grid, bound_fn, extrapolate) + elif dim == 1: + return iso0.pull1d(inp, grid, bound_fn, extrapolate) + spline_fn = make_spline(interpolation) + return nd.pull(inp, grid, bound_fn, spline_fn, extrapolate) + + +# @torch.jit.script +def grid_push(inp, grid, shape: Optional[List[int]], bound: List[int], + interpolation: List[int], extrapolate: int): + """ + inp: (B, C, *spatial_in) tensor + grid: (B, *spatial_in, D) tensor + shape: List{D}[int] tensor, optional, default=spatial_in + bound: List{D}[int] tensor + interpolation: List{D}[int] + extrapolate: int + returns: (B, C, *shape) tensor + """ + dim = grid.shape[-1] + bound = pad_list_int(bound, dim) + interpolation = pad_list_int(interpolation, dim) + bound_fn = make_bound(bound) + is_iso1 = list_all([order == 1 for order in 
interpolation]) + if is_iso1: + if dim == 3: + return iso1.push3d(inp, grid, shape, bound_fn, extrapolate) + elif dim == 2: + return iso1.push2d(inp, grid, shape, bound_fn, extrapolate) + elif dim == 1: + return iso1.push1d(inp, grid, shape, bound_fn, extrapolate) + is_iso0 = list_all([order == 0 for order in interpolation]) + if is_iso0: + if dim == 3: + return iso0.push3d(inp, grid, shape, bound_fn, extrapolate) + elif dim == 2: + return iso0.push2d(inp, grid, shape, bound_fn, extrapolate) + elif dim == 1: + return iso0.push1d(inp, grid, shape, bound_fn, extrapolate) + spline_fn = make_spline(interpolation) + return nd.push(inp, grid, shape, bound_fn, spline_fn, extrapolate) + + +# @torch.jit.script +def grid_count(grid, shape: Optional[List[int]], bound: List[int], + interpolation: List[int], extrapolate: int): + """ + grid: (B, *spatial_in, D) tensor + shape: List{D}[int] tensor, optional, default=spatial_in + bound: List{D}[int] tensor + interpolation: List{D}[int] + extrapolate: int + returns: (B, 1, *shape) tensor + """ + dim = grid.shape[-1] + bound = pad_list_int(bound, dim) + interpolation = pad_list_int(interpolation, dim) + bound_fn = make_bound(bound) + if shape is None: + shape = grid.shape[-dim-1:-1] + inp = torch.ones([1, 1] + list(shape), dtype=grid.dtype, device=grid.device) + is_iso1 = list_all([order == 1 for order in interpolation]) + if is_iso1: + if dim == 3: + return iso1.push3d(inp, grid, shape, bound_fn, extrapolate) + elif dim == 2: + return iso1.push2d(inp, grid, shape, bound_fn, extrapolate) + elif dim == 1: + return iso1.push1d(inp, grid, shape, bound_fn, extrapolate) + is_iso0 = list_all([order == 0 for order in interpolation]) + if is_iso0: + if dim == 3: + return iso0.push3d(inp, grid, shape, bound_fn, extrapolate) + elif dim == 2: + return iso0.push2d(inp, grid, shape, bound_fn, extrapolate) + elif dim == 1: + return iso0.push1d(inp, grid, shape, bound_fn, extrapolate) + spline_fn = make_spline(interpolation) + return nd.push(inp, 
grid, shape, bound_fn, spline_fn, extrapolate) + + +# @torch.jit.script +def grid_grad(inp, grid, bound: List[int], interpolation: List[int], + extrapolate: int): + """ + inp: (B, C, *spatial_in) tensor + grid: (B, *spatial_out, D) tensor + bound: List{D}[int] tensor + interpolation: List{D}[int] + extrapolate: int + returns: (B, C, *spatial_out, D) tensor + """ + dim = grid.shape[-1] + bound = pad_list_int(bound, dim) + interpolation = pad_list_int(interpolation, dim) + bound_fn = make_bound(bound) + is_iso1 = list_all([order == 1 for order in interpolation]) + if is_iso1: + if dim == 3: + return iso1.grad3d(inp, grid, bound_fn, extrapolate) + elif dim == 2: + return iso1.grad2d(inp, grid, bound_fn, extrapolate) + elif dim == 1: + return iso1.grad1d(inp, grid, bound_fn, extrapolate) + is_iso0 = list_all([order == 0 for order in interpolation]) + if is_iso0: + return iso0.grad(inp, grid, bound_fn, extrapolate) + spline_fn = make_spline(interpolation) + return nd.grad(inp, grid, bound_fn, spline_fn, extrapolate) + + +# @torch.jit.script +def grid_pushgrad(inp, grid, shape: List[int], bound: List[int], + interpolation: List[int], extrapolate: int): + """ /!\ Used only in backward pass of grid_grad + inp: (B, C, *spatial_in, D) tensor + grid: (B, *spatial_in, D) tensor + shape: List{D}[int], optional + bound: List{D}[int] tensor + interpolation: List{D}[int] + extrapolate: int + returns: (B, C, *shape) tensor + """ + dim = grid.shape[-1] + bound = pad_list_int(bound, dim) + interpolation = pad_list_int(interpolation, dim) + bound_fn = make_bound(bound) + is_iso1 = list_all([order == 1 for order in interpolation]) + if is_iso1: + if dim == 3: + return iso1.pushgrad3d(inp, grid, shape, bound_fn, extrapolate) + elif dim == 2: + return iso1.pushgrad2d(inp, grid, shape, bound_fn, extrapolate) + elif dim == 1: + return iso1.pushgrad1d(inp, grid, shape, bound_fn, extrapolate) + is_iso0 = list_all([order == 0 for order in interpolation]) + if is_iso0: + return 
iso0.pushgrad(inp, grid, shape, bound_fn, extrapolate) + spline_fn = make_spline(interpolation) + return nd.pushgrad(inp, grid, shape, bound_fn, spline_fn, extrapolate) + + +# @torch.jit.script +def grid_hess(inp, grid, bound: List[int], interpolation: List[int], + extrapolate: int): + """ /!\ Used only in backward pass of grid_grad + inp: (B, C, *spatial_in) tensor + grid: (B, *spatial_out, D) tensor + bound: List{D}[int] tensor + interpolation: List{D}[int] + extrapolate: int + returns: (B, C, *spatial_out, D, D) tensor + """ + dim = grid.shape[-1] + bound = pad_list_int(bound, dim) + interpolation = pad_list_int(interpolation, dim) + bound_fn = make_bound(bound) + is_iso1 = list_all([order == 1 for order in interpolation]) + if is_iso1: + if dim == 3: + return iso1.hess3d(inp, grid, bound_fn, extrapolate) + if dim == 2: + return iso1.hess2d(inp, grid, bound_fn, extrapolate) + if dim == 1: + return iso1.hess1d(inp, grid, bound_fn, extrapolate) + is_iso0 = list_all([order == 0 for order in interpolation]) + if is_iso0: + return iso0.hess(inp, grid, bound_fn, extrapolate) + spline_fn = make_spline(interpolation) + return nd.hess(inp, grid, bound_fn, spline_fn, extrapolate) + + +# @torch.jit.script +def grid_pull_backward(grad, inp, grid, bound: List[int], + interpolation: List[int], extrapolate: int) \ + -> Tuple[Optional[Tensor], Optional[Tensor], ]: + """ + grad: (B, C, *spatial_out) tensor + inp: (B, C, *spatial_in) tensor + grid: (B, *spatial_out, D) tensor + bound: List{D}[int] tensor + interpolation: List{D}[int] + extrapolate: int + returns: (B, C, *spatial_in) tensor, (B, *spatial_out, D) + """ + dim = grid.shape[-1] + grad_inp: Optional[Tensor] = None + grad_grid: Optional[Tensor] = None + if inp.requires_grad: + grad_inp = grid_push(grad, grid, inp.shape[-dim:], bound, interpolation, extrapolate) + if grid.requires_grad: + grad_grid = grid_grad(inp, grid, bound, interpolation, extrapolate) + # grad_grid = dot(grad_grid, grad.unsqueeze(-1), dim=1) + 
grad_grid = (grad_grid * grad.unsqueeze(-1)).sum(dim=1) + return grad_inp, grad_grid + + +# @torch.jit.script +def grid_push_backward(grad, inp, grid, bound: List[int], + interpolation: List[int], extrapolate: int) \ + -> Tuple[Optional[Tensor], Optional[Tensor], ]: + """ + grad: (B, C, *spatial_out) tensor + inp: (B, C, *spatial_in) tensor + grid: (B, *spatial_in, D) tensor + bound: List{D}[int] tensor + interpolation: List{D}[int] + extrapolate: int + returns: (B, C, *spatial_in) tensor, (B, *spatial_in, D) + """ + grad_inp: Optional[Tensor] = None + grad_grid: Optional[Tensor] = None + if inp.requires_grad: + grad_inp = grid_pull(grad, grid, bound, interpolation, extrapolate) + if grid.requires_grad: + grad_grid = grid_grad(grad, grid, bound, interpolation, extrapolate) + # grad_grid = dot(grad_grid, inp.unsqueeze(-1), dim=1) + grad_grid = (grad_grid * inp.unsqueeze(-1)).sum(dim=1) + return grad_inp, grad_grid + + +# @torch.jit.script +def grid_count_backward(grad, grid, bound: List[int], + interpolation: List[int], extrapolate: int) \ + -> Optional[Tensor]: + """ + grad: (B, C, *spatial_out) tensor + grid: (B, *spatial_in, D) tensor + bound: List{D}[int] tensor + interpolation: List{D}[int] + extrapolate: int + returns: (B, *spatial_in, D) tensor (grad wrt grid, summed over channels), or None + """ + if grid.requires_grad: + return grid_grad(grad, grid, bound, interpolation, extrapolate).sum(1) + return None + + +# @torch.jit.script +def grid_grad_backward(grad, inp, grid, bound: List[int], + interpolation: List[int], extrapolate: int) \ + -> Tuple[Optional[Tensor], Optional[Tensor], ]: + """ + grad: (B, C, *spatial_out, D) tensor + inp: (B, C, *spatial_in) tensor + grid: (B, *spatial_out, D) tensor + bound: List{D}[int] tensor + interpolation: List{D}[int] + extrapolate: int + returns: (B, C, *spatial_in, D) tensor, (B, *spatial_out, D) + """ + dim = grid.shape[-1] + shape = inp.shape[-dim:] + grad_inp: Optional[Tensor] = None + grad_grid: Optional[Tensor] = None + if
inp.requires_grad: + grad_inp = grid_pushgrad(grad, grid, shape, bound, interpolation, extrapolate) + if grid.requires_grad: + grad_grid = grid_hess(inp, grid, bound, interpolation, extrapolate) + # grad_grid = dot_multi(grad_grid, grad.unsqueeze(-1), dim=[1, -2]) + grad_grid = (grad_grid * grad.unsqueeze(-1)).sum(dim=[1, -2]) + return grad_inp, grad_grid diff --git a/interpol/splines.py b/interpol/splines.py new file mode 100644 index 0000000..8c7db91 --- /dev/null +++ b/interpol/splines.py @@ -0,0 +1,196 @@ +"""Weights and derivatives of spline orders 0 to 7.""" +import torch +from enum import Enum +from .jit_utils import square, cube, pow4, pow5, pow6, pow7 + + +class InterpolationType(Enum): + nearest = zeroth = 0 + linear = first = 1 + quadratic = second = 2 + cubic = third = 3 + fourth = 4 + fifth = 5 + sixth = 6 + seventh = 7 + + +@torch.jit.script +class Spline: + + def __init__(self, order: int = 1): + self.order = order + + def weight(self, x): + w = self.fastweight(x) + zero = torch.zeros([1], dtype=x.dtype, device=x.device) + w = torch.where(x.abs() >= (self.order + 1)/2, zero, w) + return w + + def fastweight(self, x): + if self.order == 0: + return torch.ones(x.shape, dtype=x.dtype, device=x.device) + x = x.abs() + if self.order == 1: + return 1 - x + if self.order == 2: + x_low = 0.75 - square(x) + x_up = 0.5 * square(1.5 - x) + return torch.where(x < 0.5, x_low, x_up) + if self.order == 3: + x_low = (x * x * (x - 2.) * 3. + 4.) / 6. + x_up = cube(2. - x) / 6. + return torch.where(x < 1., x_low, x_up) + if self.order == 4: + x_low = square(x) + x_low = x_low * (x_low * 0.25 - 0.625) + 115. / 192. + x_mid = x * (x * (x * (5. - x) / 6. - 1.25) + 5./24.) + 55./96. + x_up = pow4(x - 2.5) / 24. + return torch.where(x < 0.5, x_low, torch.where(x < 1.5, x_mid, x_up)) + if self.order == 5: + x_low = square(x) + x_low = x_low * (x_low * (0.25 - x / 12.) - 0.5) + 0.55 + x_mid = x * (x * (x * (x * (x / 24. 
- 0.375) + 1.25) - 1.75) + 0.625) + 0.425 + x_up = pow5(3 - x) / 120. + return torch.where(x < 1., x_low, torch.where(x < 2., x_mid, x_up)) + if self.order == 6: + x_low = square(x) + x_low = x_low * (x_low * (7./48. - x_low/36.) - 77./192.) + 5887./11520. + x_mid_low = (x * (x * (x * (x * (x * (x / 48. - 7./48.) + 0.328125) + - 35./288.) - 91./256.) - 7./768.) + 7861./15360.) + x_mid_up = (x * (x * (x * (x * (x * (7./60. - x / 120.) - 0.65625) + + 133./72.) - 2.5703125) + 1267./960.) + 1379./7680.) + x_up = pow6(x - 3.5) / 720. + return torch.where(x < .5, x_low, + torch.where(x < 1.5, x_mid_low, + torch.where(x < 2.5, x_mid_up, x_up))) + if self.order == 7: + x_low = square(x) + x_low = (x_low * (x_low * (x_low * (x / 144. - 1./36.) + + 1./9.) - 1./3.) + 151./315.) + x_mid_low = (x * (x * (x * (x * (x * (x * (0.05 - x/240.) - 7./30.) + + 0.5) - 7./18.) - 0.1) - 7./90.) + 103./210.) + x_mid_up = (x * (x * (x * (x * (x * (x * (x / 720. - 1./36.) + + 7./30.) - 19./18.) + 49./18.) - 23./6.) + 217./90.) + - 139./630.) + x_up = pow7(4 - x) / 5040. + return torch.where(x < 1., x_low, + torch.where(x < 2., x_mid_low, + torch.where(x < 3., x_mid_up, x_up))) + raise NotImplementedError + + def grad(self, x): + if self.order == 0: + return torch.zeros(x.shape, dtype=x.dtype, device=x.device) + g = self.fastgrad(x) + zero = torch.zeros([1], dtype=x.dtype, device=x.device) + g = torch.where(x.abs() >= (self.order + 1)/2, zero, g) + return g + + def fastgrad(self, x): + if self.order == 0: + return torch.zeros(x.shape, dtype=x.dtype, device=x.device) + return self._fastgrad(x.abs()).mul(x.sign()) + + def _fastgrad(self, x): + if self.order == 1: + return torch.ones(x.shape, dtype=x.dtype, device=x.device) + if self.order == 2: + return torch.where(x < 0.5, -2*x, x - 1.5) + if self.order == 3: + g_low = x * (x * 1.5 - 2) + g_up = -0.5 * square(2 - x) + return torch.where(x < 1, g_low, g_up) + if self.order == 4: + g_low = x * (square(x) - 1.25) + g_mid = x * (x * (x * (-2./3.) 
+ 2.5) - 2.5) + 5./24. + g_up = cube(2. * x - 5.) / 48. + return torch.where(x < 0.5, g_low, + torch.where(x < 1.5, g_mid, g_up)) + if self.order == 5: + g_low = x * (x * (x * (x * (-5./12.) + 1.)) - 1.) + g_mid = x * (x * (x * (x * (5./24.) - 1.5) + 3.75) - 3.5) + 0.625 + g_up = pow4(x - 3.) / (-24.) + return torch.where(x < 1, g_low, + torch.where(x < 2, g_mid, g_up)) + if self.order == 6: + g_low = square(x) + g_low = x * (g_low * (7./12.) - square(g_low) / 6. - 77./96.) + g_mid_low = (x * (x * (x * (x * (x * 0.125 - 35./48.) + 1.3125) + - 35./96.) - 0.7109375) - 7./768.) + g_mid_up = (x * (x * (x * (x * (x / (-20.) + 7./12.) - 2.625) + + 133./24.) - 5.140625) + 1267./960.) + g_up = pow5(2*x - 7) / 3840. + return torch.where(x < 0.5, g_low, + torch.where(x < 1.5, g_mid_low, + torch.where(x < 2.5, g_mid_up, + g_up))) + if self.order == 7: + g_low = square(x) + g_low = x * (g_low * (g_low * (x * (7./144.) - 1./6.) + 4./9.) - 2./3.) + g_mid_low = (x * (x * (x * (x * (x * (x * (-7./240.) + 3./10.) + - 7./6.) + 2.) - 7./6.) - 1./5.) - 7./90.) + g_mid_up = (x * (x * (x * (x * (x * (x * (7./720.) - 1./6.) + + 7./6.) - 38./9.) + 49./6.) - 23./3.) + 217./90.) + g_up = pow6(x - 4) / (-720.) + return torch.where(x < 1, g_low, + torch.where(x < 2, g_mid_low, + torch.where(x < 3, g_mid_up, g_up))) + raise NotImplementedError + + def hess(self, x): + if self.order == 0: + return torch.zeros(x.shape, dtype=x.dtype, device=x.device) + h = self.fasthess(x) + zero = torch.zeros([1], dtype=x.dtype, device=x.device) + h = torch.where(x.abs() >= (self.order + 1)/2, zero, h) + return h + + def fasthess(self, x): + if self.order in (0, 1): + return torch.zeros(x.shape, dtype=x.dtype, device=x.device) + x = x.abs() + if self.order == 2: + one = torch.ones([1], dtype=x.dtype, device=x.device) + return torch.where(x < 0.5, one, -2 * one) + if self.order == 3: + return torch.where(x < 1, 3. * x - 2., 2. - x) + if self.order == 4: + return torch.where(x < 0.5, 3. 
* square(x) - 1.25, + torch.where(x < 1.5, x * (-2. * x + 5.) - 2.5, + square(2. * x - 5.) / 8.)) + if self.order == 5: + h_low = square(x) + h_low = - h_low * (x * (5./3.) - 3.) - 1. + h_mid = x * (x * (x * (5./6.) - 9./2.) + 15./2.) - 7./2. + h_up = - x * (x * (x/6. - 3./2.) + 9./2.) + return torch.where(x < 1, h_low, + torch.where(x < 2, h_mid, h_up)) + if self.order == 6: + h_low = square(x) + h_low = - h_low * (h_low * (5./6) - 7./4.) - 77./96. + h_mid_low = (x * (x * (x * (x * (5./8.) - 35./12.) + 63./16.) + - 35./48.) - 91./128.) + h_mid_up = -(x * (x * (x * (x/4. - 7./3.) + 63./8.) - 133./12.) + + 329./64.) + h_up = (x * (x * (x * (x/24. - 7./12.) + 49./16.) - 343./48.) + + 2401./384.) + return torch.where(x < 0.5, h_low, + torch.where(x < 1.5, h_mid_low, + torch.where(x < 2.5, h_mid_up, + h_up))) + if self.order == 7: + h_low = square(x) + h_low = h_low * (h_low*(x * (7./24.) - 5./6.) + 4./3.) - 2./3. + h_mid_low = - (x * (x * (x * (x * (x * (7./40.) - 3./2.) + 14./3.) + - 6.) + 7./3.) + 1./5.) + h_mid_up = (x * (x * (x * (x * (x * (7./120.) - 5./6.) + 14./3.) + - 38./3.) + 49./3.) - 23./3.) + h_up = - (x * (x * (x * (x * (x/120. - 1./6.) + 4./3.) - 16./3.) + + 32.) - 128./15.) + return torch.where(x < 1, h_low, + torch.where(x < 2, h_mid_low, + torch.where(x < 3, h_mid_up, + h_up))) + raise NotImplementedError + diff --git a/interpol/utils.py b/interpol/utils.py new file mode 100644 index 0000000..b58838d --- /dev/null +++ b/interpol/utils.py @@ -0,0 +1,101 @@ +import torch + + +def make_list(x, n=None, **kwargs): + """Ensure that the input is a list (of a given size) + + Parameters + ---------- + x : list or tuple or scalar + Input object + n : int, optional + Required length + default : scalar, optional + Value to right-pad with. Use last value of the input by default. 
+ + Returns + ------- + x : list + """ + if not isinstance(x, (list, tuple)): + x = [x] + x = list(x) + default = kwargs.get('default', x[-1]) + if n: + x = x + [default] * max(0, n - len(x)) + return x + + +def expanded_shape(*shapes, side='left'): + """Expand input shapes according to broadcasting rules + + Parameters + ---------- + *shapes : sequence[int] + Input shapes + side : {'left', 'right'}, default='left' + Side to add singleton dimensions. + + Returns + ------- + shape : tuple[int] + Output shape + + Raises + ------ + ValueError + If shapes are not compatible for broadcast. + + """ + def error(s0, s1): + raise ValueError('Incompatible shapes for broadcasting: {} and {}.' + .format(s0, s1)) + + # 1. nb dimensions + nb_dim = 0 + for shape in shapes: + nb_dim = max(nb_dim, len(shape)) + + # 2. enumerate + shape = [1] * nb_dim + for i, shape1 in enumerate(shapes): + pad_size = nb_dim - len(shape1) + ones = [1] * pad_size + if side == 'left': + shape1 = [*ones, *shape1] + else: + shape1 = [*shape1, *ones] + shape = [max(s0, s1) if s0 == 1 or s1 == 1 or s0 == s1 + else error(s0, s1) for s0, s1 in zip(shape, shape1)] + + return tuple(shape) + + +def matvec(mat, vec, out=None): + """Matrix-vector product (supports broadcasting) + + Parameters + ---------- + mat : (..., M, N) tensor + Input matrix. + vec : (..., N) tensor + Input vector. + out : (..., M) tensor, optional + Placeholder for the output tensor. 
+ + Returns + ------- + mv : (..., M) tensor + Matrix vector product of the inputs + + """ + vec = vec[..., None] + if out is not None: + out = out[..., None] + + mv = torch.matmul(mat, vec, out=out) + mv = mv[..., 0] + if out is not None: + out = out[..., 0] + + return mv \ No newline at end of file diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..55b1780 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,31 @@ +[metadata] +name = torch-interpol +author = Yael Balbastre +author_email = yael.balbastre@gmail.com +description = High-order spline interpolation in PyTorch +long_description = file:README.md +platforms = OS Independent +license = MIT +license_file = LICENSE +classifiers = + License :: OSI Approved :: MIT License + Operating System :: OS Independent + Programming Language :: Python :: 3 + Intended Audience :: Science/Research + Topic :: Scientific/Engineering :: Artificial Intelligence + Topic :: Scientific/Engineering :: Medical Science Apps. +project_urls = + Source Code=https://github.com/balbasty/torch-interpol + +[options] +python_requires = >= 3.6 +# we should be able to make all numpy/scipy dependencies optional +install_requires = torch >= 1.3 + +[versioneer] +VCS = git +style = pep440 +versionfile_source = interpol/_version.py +versionfile_build = interpol/_version.py +tag_prefix = +parentdir_prefix = \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100755 index 0000000..53ea4b6 --- /dev/null +++ b/setup.py @@ -0,0 +1,9 @@ +#!/usr/bin/env python3 +from setuptools import setup, find_packages +import versioneer + +setup( + version=versioneer.get_version(), + packages=find_packages(), + cmdclass=versioneer.get_cmdclass(), +) diff --git a/versioneer.py b/versioneer.py new file mode 100644 index 0000000..9713007 --- /dev/null +++ b/versioneer.py @@ -0,0 +1,2064 @@ + +# Version: 0.20 + +"""The Versioneer - like a rocketeer, but for versions. + +The Versioneer +============== + +* like a rocketeer, but for versions! 
+* https://github.com/python-versioneer/python-versioneer +* Brian Warner +* License: Public Domain +* Compatible with: Python 3.6, 3.7, 3.8, 3.9 and pypy3 +* [![Latest Version][pypi-image]][pypi-url] +* [![Build Status][travis-image]][travis-url] + +This is a tool for managing a recorded version number in distutils-based +python projects. The goal is to remove the tedious and error-prone "update +the embedded version string" step from your release process. Making a new +release should be as easy as recording a new tag in your version-control +system, and maybe making new tarballs. + + +## Quick Install + +* `pip install versioneer` to somewhere in your $PATH +* add a `[versioneer]` section to your setup.cfg (see [Install](INSTALL.md)) +* run `versioneer install` in your source tree, commit the results +* Verify version information with `python setup.py version` + +## Version Identifiers + +Source trees come from a variety of places: + +* a version-control system checkout (mostly used by developers) +* a nightly tarball, produced by build automation +* a snapshot tarball, produced by a web-based VCS browser, like github's + "tarball from tag" feature +* a release tarball, produced by "setup.py sdist", distributed through PyPI + +Within each source tree, the version identifier (either a string or a number, +this tool is format-agnostic) can come from a variety of places: + +* ask the VCS tool itself, e.g. "git describe" (for checkouts), which knows + about recent "tags" and an absolute revision-id +* the name of the directory into which the tarball was unpacked +* an expanded VCS keyword ($Id$, etc) +* a `_version.py` created by some earlier build step + +For released software, the version identifier is closely related to a VCS +tag. Some projects use tag names that include more than just the version +string (e.g. "myproject-1.2" instead of just "1.2"), in which case the tool +needs to strip the tag prefix to extract the version identifier. 
For +unreleased software (between tags), the version identifier should provide +enough information to help developers recreate the same tree, while also +giving them an idea of roughly how old the tree is (after version 1.2, before +version 1.3). Many VCS systems can report a description that captures this, +for example `git describe --tags --dirty --always` reports things like +"0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the +0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has +uncommitted changes). + +The version identifier is used for multiple purposes: + +* to allow the module to self-identify its version: `myproject.__version__` +* to choose a name and prefix for a 'setup.py sdist' tarball + +## Theory of Operation + +Versioneer works by adding a special `_version.py` file into your source +tree, where your `__init__.py` can import it. This `_version.py` knows how to +dynamically ask the VCS tool for version information at import time. + +`_version.py` also contains `$Revision$` markers, and the installation +process marks `_version.py` to have this marker rewritten with a tag name +during the `git archive` command. As a result, generated tarballs will +contain enough information to get the proper version. + +To allow `setup.py` to compute a version too, a `versioneer.py` is added to +the top level of your source tree, next to `setup.py` and the `setup.cfg` +that configures it. This overrides several distutils/setuptools commands to +compute the version when invoked, and changes `setup.py build` and `setup.py +sdist` to replace `_version.py` with a small static file that contains just +the generated version data. + +## Installation + +See [INSTALL.md](./INSTALL.md) for detailed installation instructions. + +## Version-String Flavors + +Code which uses Versioneer can learn about its version string at runtime by +importing `_version` from your main `__init__.py` file and running the +`get_versions()` function. 
From the "outside" (e.g. in `setup.py`), you can +import the top-level `versioneer.py` and run `get_versions()`. + +Both functions return a dictionary with different flavors of version +information: + +* `['version']`: A condensed version string, rendered using the selected + style. This is the most commonly used value for the project's version + string. The default "pep440" style yields strings like `0.11`, + `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the "Styles" section + below for alternative styles. + +* `['full-revisionid']`: detailed revision identifier. For Git, this is the + full SHA1 commit id, e.g. "1076c978a8d3cfc70f408fe5974aa6c092c949ac". + +* `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the + commit date in ISO 8601 format. This will be None if the date is not + available. + +* `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that + this is only accurate if run in a VCS checkout, otherwise it is likely to + be False or None + +* `['error']`: if the version string could not be computed, this will be set + to a string describing the problem, otherwise it will be None. It may be + useful to throw an exception in setup.py if this is set, to avoid e.g. + creating tarballs with a version string of "unknown". + +Some variants are more useful than others. Including `full-revisionid` in a +bug report should allow developers to reconstruct the exact code being tested +(or indicate the presence of local changes that should be shared with the +developers). `version` is suitable for display in an "about" box or a CLI +`--version` output: it can be easily compared against release notes and lists +of bugs fixed in various releases. 
+ +The installer adds the following text to your `__init__.py` to place a basic +version in `YOURPROJECT.__version__`: + + from ._version import get_versions + __version__ = get_versions()['version'] + del get_versions + +## Styles + +The setup.cfg `style=` configuration controls how the VCS information is +rendered into a version string. + +The default style, "pep440", produces a PEP440-compliant string, equal to the +un-prefixed tag name for actual releases, and containing an additional "local +version" section with more detail for in-between builds. For Git, this is +TAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags +--dirty --always`. For example "0.11+2.g1076c97.dirty" indicates that the +tree is like the "1076c97" commit but has uncommitted changes (".dirty"), and +that this commit is two revisions ("+2") beyond the "0.11" tag. For released +software (exactly equal to a known tag), the identifier will only contain the +stripped tag, e.g. "0.11". + +Other styles are available. See [details.md](details.md) in the Versioneer +source tree for descriptions. + +## Debugging + +Versioneer tries to avoid fatal errors: if something goes wrong, it will tend +to return a version of "0+unknown". To investigate the problem, run `setup.py +version`, which will run the version-lookup code in a verbose mode, and will +display the full contents of `get_versions()` (including the `error` string, +which may help identify what went wrong). + +## Known Limitations + +Some situations are known to cause problems for Versioneer. This details the +most significant ones. More can be found on Github +[issues page](https://github.com/python-versioneer/python-versioneer/issues). + +### Subprojects + +Versioneer has limited support for source trees in which `setup.py` is not in +the root directory (e.g. `setup.py` and `.git/` are *not* siblings). 
There are
+two common reasons why `setup.py` might not be in the root:
+
+* Source trees which contain multiple subprojects, such as
+  [Buildbot](https://github.com/buildbot/buildbot), which contains both
+  "master" and "slave" subprojects, each with their own `setup.py`,
+  `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI
+  distributions (and upload multiple independently-installable tarballs).
+* Source trees whose main purpose is to contain a C library, but which also
+  provide bindings to Python (and perhaps other languages) in subdirectories.
+
+Versioneer will look for `.git` in parent directories, and most operations
+should get the right version string. However `pip` and `setuptools` have bugs
+and implementation details which frequently cause `pip install .` from a
+subproject directory to fail to find a correct version string (so it usually
+defaults to `0+unknown`).
+
+`pip install --editable .` should work correctly. `setup.py install` might
+work too.
+
+Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in
+some later version.
+
+[Bug #38](https://github.com/python-versioneer/python-versioneer/issues/38) is tracking
+this issue. The discussion in
+[PR #61](https://github.com/python-versioneer/python-versioneer/pull/61) describes the
+issue from the Versioneer side in more detail.
+[pip PR#3176](https://github.com/pypa/pip/pull/3176) and
+[pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve
+pip to let Versioneer work correctly.
+
+Versioneer-0.16 and earlier only looked for a `.git` directory next to the
+`setup.cfg`, so subprojects were completely unsupported with those releases.
+
+### Editable installs with setuptools <= 18.5
+
+`setup.py develop` and `pip install --editable .` allow you to install a
+project into a virtualenv once, then continue editing the source code (and
+test) without re-installing after every change. 
+ +"Entry-point scripts" (`setup(entry_points={"console_scripts": ..})`) are a +convenient way to specify executable scripts that should be installed along +with the python package. + +These both work as expected when using modern setuptools. When using +setuptools-18.5 or earlier, however, certain operations will cause +`pkg_resources.DistributionNotFound` errors when running the entrypoint +script, which must be resolved by re-installing the package. This happens +when the install happens with one version, then the egg_info data is +regenerated while a different version is checked out. Many setup.py commands +cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into +a different virtualenv), so this can be surprising. + +[Bug #83](https://github.com/python-versioneer/python-versioneer/issues/83) describes +this one, but upgrading to a newer version of setuptools should probably +resolve it. + + +## Updating Versioneer + +To upgrade your project to a new release of Versioneer, do the following: + +* install the new Versioneer (`pip install -U versioneer` or equivalent) +* edit `setup.cfg`, if necessary, to include any new configuration settings + indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. +* re-run `versioneer install` in your source tree, to replace + `SRC/_version.py` +* commit any changed files + +## Future Directions + +This tool is designed to make it easily extended to other version-control +systems: all VCS-specific components are in separate directories like +src/git/ . The top-level `versioneer.py` script is assembled from these +components by running make-versioneer.py . In the future, make-versioneer.py +will take a VCS name as an argument, and will construct a version of +`versioneer.py` that is specific to the given VCS. It might also take the +configuration arguments that are currently provided manually during +installation by editing setup.py . 
Alternatively, it might go the other +direction and include code from all supported VCS systems, reducing the +number of intermediate scripts. + +## Similar projects + +* [setuptools_scm](https://github.com/pypa/setuptools_scm/) - a non-vendored build-time + dependency +* [minver](https://github.com/jbweston/miniver) - a lightweight reimplementation of + versioneer +* [versioningit](https://github.com/jwodder/versioningit) - a PEP 518-based setuptools + plugin + +## License + +To make Versioneer easier to embed, all its code is dedicated to the public +domain. The `_version.py` that it creates is also in the public domain. +Specifically, both are released under the Creative Commons "Public Domain +Dedication" license (CC0-1.0), as described in +https://creativecommons.org/publicdomain/zero/1.0/ . + +[pypi-image]: https://img.shields.io/pypi/v/versioneer.svg +[pypi-url]: https://pypi.python.org/pypi/versioneer/ +[travis-image]: +https://img.shields.io/travis/com/python-versioneer/python-versioneer.svg +[travis-url]: https://travis-ci.com/github/python-versioneer/python-versioneer + +""" + +import configparser +import errno +import json +import os +import re +import subprocess +import sys + + +class VersioneerConfig: # pylint: disable=too-few-public-methods # noqa + """Container for Versioneer configuration parameters.""" + + +def get_root(): + """Get the project root directory. + + We require that all commands are run from the project root, i.e. the + directory that contains setup.py, setup.cfg, and versioneer.py . 
+ """ + root = os.path.realpath(os.path.abspath(os.getcwd())) + setup_py = os.path.join(root, "setup.py") + versioneer_py = os.path.join(root, "versioneer.py") + if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): + # allow 'python path/to/setup.py COMMAND' + root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) + setup_py = os.path.join(root, "setup.py") + versioneer_py = os.path.join(root, "versioneer.py") + if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): + err = ("Versioneer was unable to run the project root directory. " + "Versioneer requires setup.py to be executed from " + "its immediate directory (like 'python setup.py COMMAND'), " + "or in a way that lets it use sys.argv[0] to find the root " + "(like 'python path/to/setup.py COMMAND').") + raise VersioneerBadRootError(err) + try: + # Certain runtime workflows (setup.py install/develop in a setuptools + # tree) execute all dependencies in a single python process, so + # "versioneer" may be imported multiple times, and python's shared + # module-import table will cache the first one. So we can't use + # os.path.dirname(__file__), as that will find whichever + # versioneer.py was first imported, even in later projects. + my_path = os.path.realpath(os.path.abspath(__file__)) + me_dir = os.path.normcase(os.path.splitext(my_path)[0]) + vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) + if me_dir != vsr_dir: + print("Warning: build in %s is using versioneer.py from %s" + % (os.path.dirname(my_path), versioneer_py)) + except NameError: + pass + return root + + +def get_config_from_root(root): + """Read the project setup.cfg file to determine Versioneer config.""" + # This might raise EnvironmentError (if setup.cfg is missing), or + # configparser.NoSectionError (if it lacks a [versioneer] section), or + # configparser.NoOptionError (if it lacks "VCS="). See the docstring at + # the top of versioneer.py for instructions on writing your setup.cfg . 
+ setup_cfg = os.path.join(root, "setup.cfg") + parser = configparser.ConfigParser() + with open(setup_cfg, "r") as cfg_file: + parser.read_file(cfg_file) + VCS = parser.get("versioneer", "VCS") # mandatory + + # Dict-like interface for non-mandatory entries + section = parser["versioneer"] + + # pylint:disable=attribute-defined-outside-init # noqa + cfg = VersioneerConfig() + cfg.VCS = VCS + cfg.style = section.get("style", "") + cfg.versionfile_source = section.get("versionfile_source") + cfg.versionfile_build = section.get("versionfile_build") + cfg.tag_prefix = section.get("tag_prefix") + if cfg.tag_prefix in ("''", '""'): + cfg.tag_prefix = "" + cfg.parentdir_prefix = section.get("parentdir_prefix") + cfg.verbose = section.get("verbose") + return cfg + + +class NotThisMethod(Exception): + """Exception raised if a method is not valid for the current scenario.""" + + +# these dictionaries contain VCS-specific tools +LONG_VERSION_PY = {} +HANDLERS = {} + + +def register_vcs_handler(vcs, method): # decorator + """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f): + """Store f in HANDLERS[vcs][method].""" + HANDLERS.setdefault(vcs, {})[method] = f + return f + return decorate + + +# pylint:disable=too-many-arguments,consider-using-with # noqa +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, + env=None): + """Call the given command(s).""" + assert isinstance(commands, list) + process = None + for command in commands: + try: + dispcmd = str([command] + args) + # remember shell=False, so use git.cmd on windows, not just git + process = subprocess.Popen([command] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None)) + break + except EnvironmentError: + e = sys.exc_info()[1] + if e.errno == errno.ENOENT: + continue + if verbose: + print("unable to run %s" % dispcmd) + print(e) + return None, None + else: + if verbose: + print("unable to find command, tried %s" 
% (commands,)) + return None, None + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: + if verbose: + print("unable to run %s (error)" % dispcmd) + print("stdout was %s" % stdout) + return None, process.returncode + return stdout, process.returncode + + +LONG_VERSION_PY['git'] = r''' +# This file helps to compute a version number in source trees obtained from +# git-archive tarball (such as those provided by githubs download-from-tag +# feature). Distribution tarballs (built by setup.py sdist) and build +# directories (produced by setup.py build) will contain a much shorter file +# that just contains the computed version number. + +# This file is released into the public domain. Generated by +# versioneer-0.20 (https://github.com/python-versioneer/python-versioneer) + +"""Git implementation of _version.py.""" + +import errno +import os +import re +import subprocess +import sys + + +def get_keywords(): + """Get the keywords needed to look up the version information.""" + # these strings will be replaced by git during git-archive. + # setup.py/versioneer.py will grep for the variable names, so they must + # each be defined on a line of their own. _version.py will just call + # get_keywords(). 
+ git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s" + git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s" + git_date = "%(DOLLAR)sFormat:%%ci%(DOLLAR)s" + keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} + return keywords + + +class VersioneerConfig: # pylint: disable=too-few-public-methods + """Container for Versioneer configuration parameters.""" + + +def get_config(): + """Create, populate and return the VersioneerConfig() object.""" + # these strings are filled in when 'setup.py versioneer' creates + # _version.py + cfg = VersioneerConfig() + cfg.VCS = "git" + cfg.style = "%(STYLE)s" + cfg.tag_prefix = "%(TAG_PREFIX)s" + cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s" + cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s" + cfg.verbose = False + return cfg + + +class NotThisMethod(Exception): + """Exception raised if a method is not valid for the current scenario.""" + + +LONG_VERSION_PY = {} +HANDLERS = {} + + +def register_vcs_handler(vcs, method): # decorator + """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f): + """Store f in HANDLERS[vcs][method].""" + if vcs not in HANDLERS: + HANDLERS[vcs] = {} + HANDLERS[vcs][method] = f + return f + return decorate + + +# pylint:disable=too-many-arguments,consider-using-with # noqa +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, + env=None): + """Call the given command(s).""" + assert isinstance(commands, list) + process = None + for command in commands: + try: + dispcmd = str([command] + args) + # remember shell=False, so use git.cmd on windows, not just git + process = subprocess.Popen([command] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None)) + break + except EnvironmentError: + e = sys.exc_info()[1] + if e.errno == errno.ENOENT: + continue + if verbose: + print("unable to run %%s" %% dispcmd) + print(e) + return None, None + else: + if verbose: + print("unable to find command, tried %%s" 
%% (commands,)) + return None, None + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: + if verbose: + print("unable to run %%s (error)" %% dispcmd) + print("stdout was %%s" %% stdout) + return None, process.returncode + return stdout, process.returncode + + +def versions_from_parentdir(parentdir_prefix, root, verbose): + """Try to determine the version from the parent directory name. + + Source tarballs conventionally unpack into a directory that includes both + the project name and a version string. We will also support searching up + two directory levels for an appropriately named parent directory + """ + rootdirs = [] + + for _ in range(3): + dirname = os.path.basename(root) + if dirname.startswith(parentdir_prefix): + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None, "date": None} + rootdirs.append(root) + root = os.path.dirname(root) # up a level + + if verbose: + print("Tried directories %%s but none started with prefix %%s" %% + (str(rootdirs), parentdir_prefix)) + raise NotThisMethod("rootdir doesn't start with parentdir_prefix") + + +@register_vcs_handler("git", "get_keywords") +def git_get_keywords(versionfile_abs): + """Extract version information from the given file.""" + # the code embedded in _version.py can just fetch the value of these + # keywords. When used from setup.py, we don't want to import _version.py, + # so we do it with a regexp instead. This function is not used from + # _version.py. 
+ keywords = {} + try: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except EnvironmentError: + pass + return keywords + + +@register_vcs_handler("git", "keywords") +def git_versions_from_keywords(keywords, tag_prefix, verbose): + """Get version information from git keywords.""" + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") + date = keywords.get("date") + if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + + # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant + # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 + # -like" string, which we must then edit to make compliant), because + # it's been around since git-1.5.3, and it's too difficult to + # discover which version we're using, or to work around using an + # older one. + date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + refnames = keywords["refnames"].strip() + if refnames.startswith("$Format"): + if verbose: + print("keywords are unexpanded, not using") + raise NotThisMethod("unexpanded keywords, not a git-archive tarball") + refs = {r.strip() for r in refnames.strip("()").split(",")} + # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of + # just "foo-1.0". If we see a "tag: " prefix, prefer those. + TAG = "tag: " + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} + if not tags: + # Either we're using git < 1.8.3, or there really are no tags. We use + # a heuristic: assume all version tags have a digit. 
The old git %%d + # expansion behaves like git log --decorate=short and strips out the + # refs/heads/ and refs/tags/ prefixes that would let us distinguish + # between branches and tags. By ignoring refnames without digits, we + # filter out many common branch names like "release" and + # "stabilization", as well as "HEAD" and "master". + tags = {r for r in refs if re.search(r'\d', r)} + if verbose: + print("discarding '%%s', no digits" %% ",".join(refs - tags)) + if verbose: + print("likely tags: %%s" %% ",".join(sorted(tags))) + for ref in sorted(tags): + # sorting will prefer e.g. "2.0" over "2.0rc1" + if ref.startswith(tag_prefix): + r = ref[len(tag_prefix):] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r'\d', r): + continue + if verbose: + print("picking %%s" %% r) + return {"version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None, + "date": date} + # no suitable tags, so version is "0+unknown", but full hex is still there + if verbose: + print("no suitable tags, using unknown + full revision id") + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags", "date": None} + + +@register_vcs_handler("git", "pieces_from_vcs") +def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): + """Get version from 'git describe' in the root of the source tree. + + This only gets called if the git-archive 'subst' keywords were *not* + expanded, and _version.py hasn't already been rewritten with a short + version string, meaning we're inside a checked out source tree. 
+ """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=True) + if rc != 0: + if verbose: + print("Directory %%s not under git control" %% root) + raise NotThisMethod("'git rev-parse --git-dir' returned error") + + # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] + # if there isn't one, this yields HEX[-dirty] (no NUM) + describe_out, rc = runner(GITS, ["describe", "--tags", "--dirty", + "--always", "--long", + "--match", "%%s*" %% tag_prefix], + cwd=root) + # --long was added in git-1.5.5 + if describe_out is None: + raise NotThisMethod("'git describe' failed") + describe_out = describe_out.strip() + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) + if full_out is None: + raise NotThisMethod("'git rev-parse' failed") + full_out = full_out.strip() + + pieces = {} + pieces["long"] = full_out + pieces["short"] = full_out[:7] # maybe improved later + pieces["error"] = None + + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. 
+ branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] + # TAG might have hyphens. + git_describe = describe_out + + # look for -dirty suffix + dirty = git_describe.endswith("-dirty") + pieces["dirty"] = dirty + if dirty: + git_describe = git_describe[:git_describe.rindex("-dirty")] + + # now we have TAG-NUM-gHEX or HEX + + if "-" in git_describe: + # TAG-NUM-gHEX + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + if not mo: + # unparseable. Maybe git-describe is misbehaving? + pieces["error"] = ("unable to parse git-describe output: '%%s'" + %% describe_out) + return pieces + + # tag + full_tag = mo.group(1) + if not full_tag.startswith(tag_prefix): + if verbose: + fmt = "tag '%%s' doesn't start with prefix '%%s'" + print(fmt %% (full_tag, tag_prefix)) + pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'" + %% (full_tag, tag_prefix)) + return pieces + pieces["closest-tag"] = full_tag[len(tag_prefix):] + + # distance: number of commits since tag + pieces["distance"] = int(mo.group(2)) + + # commit: short hex revision ID + pieces["short"] = mo.group(3) + + else: + # HEX: no tags + pieces["closest-tag"] = None + count_out, rc = runner(GITS, ["rev-list", "HEAD", "--count"], cwd=root) + pieces["distance"] = int(count_out) # total number of commits + + # commit date: see ISO-8601 comment in git_versions_from_keywords() + date = runner(GITS, ["show", "-s", "--format=%%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. 
+ date = date.splitlines()[-1] + pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + + return pieces + + +def plus_or_dot(pieces): + """Return a + if we don't already have one, else return a .""" + if "+" in pieces.get("closest-tag", ""): + return "." + return "+" + + +def render_pep440(pieces): + """Build up version string, with post-release "local version identifier". + + Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you + get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty + + Exceptions: + 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += plus_or_dot(pieces) + rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_branch(pieces): + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). + + Exceptions: + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%%d.g%%s" %% (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_pre(pieces): + """TAG[.post0.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 
0.post0.devDISTANCE + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += ".post0.dev%%d" %% pieces["distance"] + else: + # exception #1 + rendered = "0.post0.dev%%d" %% pieces["distance"] + return rendered + + +def render_pep440_post(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX] . + + The ".dev0" means dirty. Note that .dev0 sorts backwards + (a dirty tree will appear "older" than the corresponding clean one), + but you shouldn't be releasing software with -dirty anyways. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%%s" %% pieces["short"] + else: + # exception #1 + rendered = "0.post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += "+g%%s" %% pieces["short"] + return rendered + + +def render_pep440_post_branch(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%%d" %% pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%%s" %% pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%%d" %% pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%%s" %% pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces): + """TAG[.postDISTANCE[.dev0]] . + + The ".dev0" means dirty. + + Exceptions: + 1: no tags. 
0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + else: + # exception #1 + rendered = "0.post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + return rendered + + +def render_git_describe(pieces): + """TAG[-DISTANCE-gHEX][-dirty]. + + Like 'git describe --tags --dirty --always'. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render_git_describe_long(pieces): + """TAG-DISTANCE-gHEX[-dirty]. + + Like 'git describe --tags --dirty --always -long'. + The distance/hash is unconditional. + + Exceptions: + 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render(pieces, style): + """Render the given version pieces into the requested style.""" + if pieces["error"]: + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None} + + if not style or style == "default": + style = "pep440" # the default + + if style == "pep440": + rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) + elif style == "pep440-pre": + rendered = render_pep440_pre(pieces) + elif style == "pep440-post": + rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) + elif style == "pep440-old": + rendered = render_pep440_old(pieces) + elif style == "git-describe": + rendered = render_git_describe(pieces) + elif style == "git-describe-long": + rendered = render_git_describe_long(pieces) + else: + raise ValueError("unknown style '%%s'" %% style) + + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None, + "date": pieces.get("date")} + + +def get_versions(): + """Get version information or return default if unable to do so.""" + # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have + # __file__, we can work backwards from there to the root. Some + # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which + # case we can only use expanded keywords. 
+ + cfg = get_config() + verbose = cfg.verbose + + try: + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, + verbose) + except NotThisMethod: + pass + + try: + root = os.path.realpath(__file__) + # versionfile_source is the relative path from the top of the source + # tree (where the .git directory might live) to this file. Invert + # this to find the root from __file__. + for _ in cfg.versionfile_source.split('/'): + root = os.path.dirname(root) + except NameError: + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None} + + try: + pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) + return render(pieces, cfg.style) + except NotThisMethod: + pass + + try: + if cfg.parentdir_prefix: + return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) + except NotThisMethod: + pass + + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", "date": None} +''' + + +@register_vcs_handler("git", "get_keywords") +def git_get_keywords(versionfile_abs): + """Extract version information from the given file.""" + # the code embedded in _version.py can just fetch the value of these + # keywords. When used from setup.py, we don't want to import _version.py, + # so we do it with a regexp instead. This function is not used from + # _version.py. 
+ keywords = {} + try: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except EnvironmentError: + pass + return keywords + + +@register_vcs_handler("git", "keywords") +def git_versions_from_keywords(keywords, tag_prefix, verbose): + """Get version information from git keywords.""" + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") + date = keywords.get("date") + if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant + # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 + # -like" string, which we must then edit to make compliant), because + # it's been around since git-1.5.3, and it's too difficult to + # discover which version we're using, or to work around using an + # older one. + date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + refnames = keywords["refnames"].strip() + if refnames.startswith("$Format"): + if verbose: + print("keywords are unexpanded, not using") + raise NotThisMethod("unexpanded keywords, not a git-archive tarball") + refs = {r.strip() for r in refnames.strip("()").split(",")} + # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of + # just "foo-1.0". If we see a "tag: " prefix, prefer those. + TAG = "tag: " + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} + if not tags: + # Either we're using git < 1.8.3, or there really are no tags. We use + # a heuristic: assume all version tags have a digit. 
The old git %d + # expansion behaves like git log --decorate=short and strips out the + # refs/heads/ and refs/tags/ prefixes that would let us distinguish + # between branches and tags. By ignoring refnames without digits, we + # filter out many common branch names like "release" and + # "stabilization", as well as "HEAD" and "master". + tags = {r for r in refs if re.search(r'\d', r)} + if verbose: + print("discarding '%s', no digits" % ",".join(refs - tags)) + if verbose: + print("likely tags: %s" % ",".join(sorted(tags))) + for ref in sorted(tags): + # sorting will prefer e.g. "2.0" over "2.0rc1" + if ref.startswith(tag_prefix): + r = ref[len(tag_prefix):] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r'\d', r): + continue + if verbose: + print("picking %s" % r) + return {"version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None, + "date": date} + # no suitable tags, so version is "0+unknown", but full hex is still there + if verbose: + print("no suitable tags, using unknown + full revision id") + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags", "date": None} + + +@register_vcs_handler("git", "pieces_from_vcs") +def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): + """Get version from 'git describe' in the root of the source tree. + + This only gets called if the git-archive 'subst' keywords were *not* + expanded, and _version.py hasn't already been rewritten with a short + version string, meaning we're inside a checked out source tree. 
+ """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=True) + if rc != 0: + if verbose: + print("Directory %s not under git control" % root) + raise NotThisMethod("'git rev-parse --git-dir' returned error") + + # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] + # if there isn't one, this yields HEX[-dirty] (no NUM) + describe_out, rc = runner(GITS, ["describe", "--tags", "--dirty", + "--always", "--long", + "--match", "%s*" % tag_prefix], + cwd=root) + # --long was added in git-1.5.5 + if describe_out is None: + raise NotThisMethod("'git describe' failed") + describe_out = describe_out.strip() + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) + if full_out is None: + raise NotThisMethod("'git rev-parse' failed") + full_out = full_out.strip() + + pieces = {} + pieces["long"] = full_out + pieces["short"] = full_out[:7] # maybe improved later + pieces["error"] = None + + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. 
+ branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] + # TAG might have hyphens. + git_describe = describe_out + + # look for -dirty suffix + dirty = git_describe.endswith("-dirty") + pieces["dirty"] = dirty + if dirty: + git_describe = git_describe[:git_describe.rindex("-dirty")] + + # now we have TAG-NUM-gHEX or HEX + + if "-" in git_describe: + # TAG-NUM-gHEX + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + if not mo: + # unparseable. Maybe git-describe is misbehaving? + pieces["error"] = ("unable to parse git-describe output: '%s'" + % describe_out) + return pieces + + # tag + full_tag = mo.group(1) + if not full_tag.startswith(tag_prefix): + if verbose: + fmt = "tag '%s' doesn't start with prefix '%s'" + print(fmt % (full_tag, tag_prefix)) + pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" + % (full_tag, tag_prefix)) + return pieces + pieces["closest-tag"] = full_tag[len(tag_prefix):] + + # distance: number of commits since tag + pieces["distance"] = int(mo.group(2)) + + # commit: short hex revision ID + pieces["short"] = mo.group(3) + + else: + # HEX: no tags + pieces["closest-tag"] = None + count_out, rc = runner(GITS, ["rev-list", "HEAD", "--count"], cwd=root) + pieces["distance"] = int(count_out) # total number of commits + + # commit date: see ISO-8601 comment in git_versions_from_keywords() + date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. 
+ date = date.splitlines()[-1] + pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + + return pieces + + +def do_vcs_install(manifest_in, versionfile_source, ipy): + """Git-specific installation logic for Versioneer. + + For Git, this means creating/changing .gitattributes to mark _version.py + for export-subst keyword substitution. + """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + files = [manifest_in, versionfile_source] + if ipy: + files.append(ipy) + try: + my_path = __file__ + if my_path.endswith(".pyc") or my_path.endswith(".pyo"): + my_path = os.path.splitext(my_path)[0] + ".py" + versioneer_file = os.path.relpath(my_path) + except NameError: + versioneer_file = "versioneer.py" + files.append(versioneer_file) + present = False + try: + with open(".gitattributes", "r") as fobj: + for line in fobj: + if line.strip().startswith(versionfile_source): + if "export-subst" in line.strip().split()[1:]: + present = True + break + except EnvironmentError: + pass + if not present: + with open(".gitattributes", "a+") as fobj: + fobj.write(f"{versionfile_source} export-subst\n") + files.append(".gitattributes") + run_command(GITS, ["add", "--"] + files) + + +def versions_from_parentdir(parentdir_prefix, root, verbose): + """Try to determine the version from the parent directory name. + + Source tarballs conventionally unpack into a directory that includes both + the project name and a version string. 
We will also support searching up + two directory levels for an appropriately named parent directory + """ + rootdirs = [] + + for _ in range(3): + dirname = os.path.basename(root) + if dirname.startswith(parentdir_prefix): + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None, "date": None} + rootdirs.append(root) + root = os.path.dirname(root) # up a level + + if verbose: + print("Tried directories %s but none started with prefix %s" % + (str(rootdirs), parentdir_prefix)) + raise NotThisMethod("rootdir doesn't start with parentdir_prefix") + + +SHORT_VERSION_PY = """ +# This file was generated by 'versioneer.py' (0.20) from +# revision-control system data, or from the parent directory name of an +# unpacked source archive. Distribution tarballs contain a pre-generated copy +# of this file. + +import json + +version_json = ''' +%s +''' # END VERSION_JSON + + +def get_versions(): + return json.loads(version_json) +""" + + +def versions_from_file(filename): + """Try to determine the version from _version.py if present.""" + try: + with open(filename) as f: + contents = f.read() + except EnvironmentError: + raise NotThisMethod("unable to read _version.py") + mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", + contents, re.M | re.S) + if not mo: + mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", + contents, re.M | re.S) + if not mo: + raise NotThisMethod("no version_json in _version.py") + return json.loads(mo.group(1)) + + +def write_to_version_file(filename, versions): + """Write the given version number to the given _version.py file.""" + os.unlink(filename) + contents = json.dumps(versions, sort_keys=True, + indent=1, separators=(",", ": ")) + with open(filename, "w") as f: + f.write(SHORT_VERSION_PY % contents) + + print("set %s to '%s'" % (filename, versions["version"])) + + +def plus_or_dot(pieces): + """Return a + if we don't already have one, else return a .""" + if 
"+" in pieces.get("closest-tag", ""): + return "." + return "+" + + +def render_pep440(pieces): + """Build up version string, with post-release "local version identifier". + + Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you + get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty + + Exceptions: + 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_branch(pieces): + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). + + Exceptions: + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_pre(pieces): + """TAG[.post0.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 
0.post0.devDISTANCE + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += ".post0.dev%d" % pieces["distance"] + else: + # exception #1 + rendered = "0.post0.dev%d" % pieces["distance"] + return rendered + + +def render_pep440_post(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX] . + + The ".dev0" means dirty. Note that .dev0 sorts backwards + (a dirty tree will appear "older" than the corresponding clean one), + but you shouldn't be releasing software with -dirty anyways. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + return rendered + + +def render_pep440_post_branch(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces): + """TAG[.postDISTANCE[.dev0]] . + + The ".dev0" means dirty. + + Exceptions: + 1: no tags. 
0.postDISTANCE[.dev0]
+    """
+    if pieces["closest-tag"]:
+        rendered = pieces["closest-tag"]
+        if pieces["distance"] or pieces["dirty"]:
+            rendered += ".post%d" % pieces["distance"]
+            if pieces["dirty"]:
+                rendered += ".dev0"
+    else:
+        # exception #1
+        rendered = "0.post%d" % pieces["distance"]
+        if pieces["dirty"]:
+            rendered += ".dev0"
+    return rendered
+
+
+def render_git_describe(pieces):
+    """TAG[-DISTANCE-gHEX][-dirty].
+
+    Like 'git describe --tags --dirty --always'.
+
+    Exceptions:
+    1: no tags. HEX[-dirty] (note: no 'g' prefix)
+    """
+    if pieces["closest-tag"]:
+        rendered = pieces["closest-tag"]
+        if pieces["distance"]:
+            rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
+    else:
+        # exception #1
+        rendered = pieces["short"]
+    if pieces["dirty"]:
+        rendered += "-dirty"
+    return rendered
+
+
+def render_git_describe_long(pieces):
+    """TAG-DISTANCE-gHEX[-dirty].
+
+    Like 'git describe --tags --dirty --always --long'.
+    The distance/hash is unconditional.
+
+    Exceptions:
+    1: no tags. 
HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render(pieces, style): + """Render the given version pieces into the requested style.""" + if pieces["error"]: + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None} + + if not style or style == "default": + style = "pep440" # the default + + if style == "pep440": + rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) + elif style == "pep440-pre": + rendered = render_pep440_pre(pieces) + elif style == "pep440-post": + rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) + elif style == "pep440-old": + rendered = render_pep440_old(pieces) + elif style == "git-describe": + rendered = render_git_describe(pieces) + elif style == "git-describe-long": + rendered = render_git_describe_long(pieces) + else: + raise ValueError("unknown style '%s'" % style) + + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None, + "date": pieces.get("date")} + + +class VersioneerBadRootError(Exception): + """The project root directory is unknown or missing key files.""" + + +def get_versions(verbose=False): + """Get the project version from whatever source is available. + + Returns dict with two keys: 'version' and 'full'. 
+ """ + if "versioneer" in sys.modules: + # see the discussion in cmdclass.py:get_cmdclass() + del sys.modules["versioneer"] + + root = get_root() + cfg = get_config_from_root(root) + + assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" + handlers = HANDLERS.get(cfg.VCS) + assert handlers, "unrecognized VCS '%s'" % cfg.VCS + verbose = verbose or cfg.verbose + assert cfg.versionfile_source is not None, \ + "please set versioneer.versionfile_source" + assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" + + versionfile_abs = os.path.join(root, cfg.versionfile_source) + + # extract version from first of: _version.py, VCS command (e.g. 'git + # describe'), parentdir. This is meant to work for developers using a + # source checkout, for users of a tarball created by 'setup.py sdist', + # and for users of a tarball/zipball created by 'git archive' or github's + # download-from-tag feature or the equivalent in other VCSes. + + get_keywords_f = handlers.get("get_keywords") + from_keywords_f = handlers.get("keywords") + if get_keywords_f and from_keywords_f: + try: + keywords = get_keywords_f(versionfile_abs) + ver = from_keywords_f(keywords, cfg.tag_prefix, verbose) + if verbose: + print("got version from expanded keyword %s" % ver) + return ver + except NotThisMethod: + pass + + try: + ver = versions_from_file(versionfile_abs) + if verbose: + print("got version from file %s %s" % (versionfile_abs, ver)) + return ver + except NotThisMethod: + pass + + from_vcs_f = handlers.get("pieces_from_vcs") + if from_vcs_f: + try: + pieces = from_vcs_f(cfg.tag_prefix, root, verbose) + ver = render(pieces, cfg.style) + if verbose: + print("got version from VCS %s" % ver) + return ver + except NotThisMethod: + pass + + try: + if cfg.parentdir_prefix: + ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose) + if verbose: + print("got version from parentdir %s" % ver) + return ver + except NotThisMethod: + pass + + if verbose: + 
print("unable to compute version")
+
+    return {"version": "0+unknown", "full-revisionid": None,
+            "dirty": None, "error": "unable to compute version",
+            "date": None}
+
+
+def get_version():
+    """Get the short version string for this project."""
+    return get_versions()["version"]
+
+
+def get_cmdclass(cmdclass=None):
+    """Get the custom setuptools/distutils subclasses used by Versioneer.
+
+    If the package uses a different cmdclass (e.g. one from numpy), it
+    should be provided as an argument.
+    """
+    if "versioneer" in sys.modules:
+        del sys.modules["versioneer"]
+        # this fixes the "python setup.py develop" case (also 'install' and
+        # 'easy_install .'), in which subdependencies of the main project are
+        # built (using setup.py bdist_egg) in the same python process. Assume
+        # a main project A and a dependency B, which use different versions
+        # of Versioneer. A's setup.py imports A's Versioneer, leaving it in
+        # sys.modules by the time B's setup.py is executed, causing B to run
+        # with the wrong versioneer. Setuptools wraps the sub-dep builds in a
+        # sandbox that restores sys.modules to its pre-build state, so the
+        # parent is protected against the child's "import versioneer". By
+        # removing ourselves from sys.modules here, before the child build
+        # happens, we protect the child from the parent's versioneer too. 
+ # Also see https://github.com/python-versioneer/python-versioneer/issues/52 + + cmds = {} if cmdclass is None else cmdclass.copy() + + # we add "version" to both distutils and setuptools + from distutils.core import Command + + class cmd_version(Command): + description = "report generated version string" + user_options = [] + boolean_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + vers = get_versions(verbose=True) + print("Version: %s" % vers["version"]) + print(" full-revisionid: %s" % vers.get("full-revisionid")) + print(" dirty: %s" % vers.get("dirty")) + print(" date: %s" % vers.get("date")) + if vers["error"]: + print(" error: %s" % vers["error"]) + cmds["version"] = cmd_version + + # we override "build_py" in both distutils and setuptools + # + # most invocation pathways end up running build_py: + # distutils/build -> build_py + # distutils/install -> distutils/build ->.. + # setuptools/bdist_wheel -> distutils/install ->.. + # setuptools/bdist_egg -> distutils/install_lib -> build_py + # setuptools/install -> bdist_egg ->.. + # setuptools/develop -> ? + # pip install: + # copies source tree to a tempdir before running egg_info/etc + # if .git isn't copied too, 'git describe' will fail + # then does setup.py bdist_wheel, or sometimes setup.py install + # setup.py egg_info -> ? 
+ + # we override different "build_py" commands for both environments + if 'build_py' in cmds: + _build_py = cmds['build_py'] + elif "setuptools" in sys.modules: + from setuptools.command.build_py import build_py as _build_py + else: + from distutils.command.build_py import build_py as _build_py + + class cmd_build_py(_build_py): + def run(self): + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + _build_py.run(self) + # now locate _version.py in the new build/ directory and replace + # it with an updated value + if cfg.versionfile_build: + target_versionfile = os.path.join(self.build_lib, + cfg.versionfile_build) + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + cmds["build_py"] = cmd_build_py + + if 'build_ext' in cmds: + _build_ext = cmds['build_ext'] + elif "setuptools" in sys.modules: + from setuptools.command.build_ext import build_ext as _build_ext + else: + from distutils.command.build_ext import build_ext as _build_ext + + class cmd_build_ext(_build_ext): + def run(self): + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + _build_ext.run(self) + if self.inplace: + # build_ext --inplace will only build extensions in + # build/lib<..> dir with no _version.py to write to. + # As in place builds will already have a _version.py + # in the module dir, we do not need to write one. + return + # now locate _version.py in the new build/ directory and replace + # it with an updated value + target_versionfile = os.path.join(self.build_lib, + cfg.versionfile_build) + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + cmds["build_ext"] = cmd_build_ext + + if "cx_Freeze" in sys.modules: # cx_freeze enabled? + from cx_Freeze.dist import build_exe as _build_exe + # nczeczulin reports that py2exe won't like the pep440-style string + # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. 
+ # setup(console=[{ + # "version": versioneer.get_version().split("+", 1)[0], # FILEVERSION + # "product_version": versioneer.get_version(), + # ... + + class cmd_build_exe(_build_exe): + def run(self): + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + target_versionfile = cfg.versionfile_source + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + + _build_exe.run(self) + os.unlink(target_versionfile) + with open(cfg.versionfile_source, "w") as f: + LONG = LONG_VERSION_PY[cfg.VCS] + f.write(LONG % + {"DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + }) + cmds["build_exe"] = cmd_build_exe + del cmds["build_py"] + + if 'py2exe' in sys.modules: # py2exe enabled? + from py2exe.distutils_buildexe import py2exe as _py2exe + + class cmd_py2exe(_py2exe): + def run(self): + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + target_versionfile = cfg.versionfile_source + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + + _py2exe.run(self) + os.unlink(target_versionfile) + with open(cfg.versionfile_source, "w") as f: + LONG = LONG_VERSION_PY[cfg.VCS] + f.write(LONG % + {"DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + }) + cmds["py2exe"] = cmd_py2exe + + # we override different "sdist" commands for both environments + if 'sdist' in cmds: + _sdist = cmds['sdist'] + elif "setuptools" in sys.modules: + from setuptools.command.sdist import sdist as _sdist + else: + from distutils.command.sdist import sdist as _sdist + + class cmd_sdist(_sdist): + def run(self): + versions = get_versions() + # pylint:disable=attribute-defined-outside-init # noqa + self._versioneer_generated_versions = versions + # 
unless we update this, the command will keep using the old + # version + self.distribution.metadata.version = versions["version"] + return _sdist.run(self) + + def make_release_tree(self, base_dir, files): + root = get_root() + cfg = get_config_from_root(root) + _sdist.make_release_tree(self, base_dir, files) + # now locate _version.py in the new base_dir directory + # (remembering that it may be a hardlink) and replace it with an + # updated value + target_versionfile = os.path.join(base_dir, cfg.versionfile_source) + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, + self._versioneer_generated_versions) + cmds["sdist"] = cmd_sdist + + return cmds + + +CONFIG_ERROR = """ +setup.cfg is missing the necessary Versioneer configuration. You need +a section like: + + [versioneer] + VCS = git + style = pep440 + versionfile_source = src/myproject/_version.py + versionfile_build = myproject/_version.py + tag_prefix = + parentdir_prefix = myproject- + +You will also need to edit your setup.py to use the results: + + import versioneer + setup(version=versioneer.get_version(), + cmdclass=versioneer.get_cmdclass(), ...) + +Please read the docstring in ./versioneer.py for configuration instructions, +edit setup.cfg, and re-run the installer or 'python versioneer.py setup'. +""" + +SAMPLE_CONFIG = """ +# See the docstring in versioneer.py for instructions. Note that you must +# re-run 'versioneer.py setup' after changing this section, and commit the +# resulting files. + +[versioneer] +#VCS = git +#style = pep440 +#versionfile_source = +#versionfile_build = +#tag_prefix = +#parentdir_prefix = + +""" + +OLD_SNIPPET = """ +from ._version import get_versions +__version__ = get_versions()['version'] +del get_versions +""" + +INIT_PY_SNIPPET = """ +from . 
import {0} +__version__ = {0}.get_versions()['version'] +""" + + +def do_setup(): + """Do main VCS-independent setup function for installing Versioneer.""" + root = get_root() + try: + cfg = get_config_from_root(root) + except (EnvironmentError, configparser.NoSectionError, + configparser.NoOptionError) as e: + if isinstance(e, (EnvironmentError, configparser.NoSectionError)): + print("Adding sample versioneer config to setup.cfg", + file=sys.stderr) + with open(os.path.join(root, "setup.cfg"), "a") as f: + f.write(SAMPLE_CONFIG) + print(CONFIG_ERROR, file=sys.stderr) + return 1 + + print(" creating %s" % cfg.versionfile_source) + with open(cfg.versionfile_source, "w") as f: + LONG = LONG_VERSION_PY[cfg.VCS] + f.write(LONG % {"DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + }) + + ipy = os.path.join(os.path.dirname(cfg.versionfile_source), + "__init__.py") + if os.path.exists(ipy): + try: + with open(ipy, "r") as f: + old = f.read() + except EnvironmentError: + old = "" + module = os.path.splitext(os.path.basename(cfg.versionfile_source))[0] + snippet = INIT_PY_SNIPPET.format(module) + if OLD_SNIPPET in old: + print(" replacing boilerplate in %s" % ipy) + with open(ipy, "w") as f: + f.write(old.replace(OLD_SNIPPET, snippet)) + elif snippet not in old: + print(" appending to %s" % ipy) + with open(ipy, "a") as f: + f.write(snippet) + else: + print(" %s unmodified" % ipy) + else: + print(" %s doesn't exist, ok" % ipy) + ipy = None + + # Make sure both the top-level "versioneer.py" and versionfile_source + # (PKG/_version.py, used by runtime code) are in MANIFEST.in, so + # they'll be copied into source distributions. Pip won't be able to + # install the package without this. 
+ manifest_in = os.path.join(root, "MANIFEST.in") + simple_includes = set() + try: + with open(manifest_in, "r") as f: + for line in f: + if line.startswith("include "): + for include in line.split()[1:]: + simple_includes.add(include) + except EnvironmentError: + pass + # That doesn't cover everything MANIFEST.in can do + # (http://docs.python.org/2/distutils/sourcedist.html#commands), so + # it might give some false negatives. Appending redundant 'include' + # lines is safe, though. + if "versioneer.py" not in simple_includes: + print(" appending 'versioneer.py' to MANIFEST.in") + with open(manifest_in, "a") as f: + f.write("include versioneer.py\n") + else: + print(" 'versioneer.py' already in MANIFEST.in") + if cfg.versionfile_source not in simple_includes: + print(" appending versionfile_source ('%s') to MANIFEST.in" % + cfg.versionfile_source) + with open(manifest_in, "a") as f: + f.write("include %s\n" % cfg.versionfile_source) + else: + print(" versionfile_source already in MANIFEST.in") + + # Make VCS-specific changes. For git, this means creating/changing + # .gitattributes to mark _version.py for export-subst keyword + # substitution. + do_vcs_install(manifest_in, cfg.versionfile_source, ipy) + return 0 + + +def scan_setup_py(): + """Validate the contents of setup.py against Versioneer's expectations.""" + found = set() + setters = False + errors = 0 + with open("setup.py", "r") as f: + for line in f.readlines(): + if "import versioneer" in line: + found.add("import") + if "versioneer.get_cmdclass()" in line: + found.add("cmdclass") + if "versioneer.get_version()" in line: + found.add("get_version") + if "versioneer.VCS" in line: + setters = True + if "versioneer.versionfile_source" in line: + setters = True + if len(found) != 3: + print("") + print("Your setup.py appears to be missing some important items") + print("(but I might be wrong). 
Please make sure it has something") + print("roughly like the following:") + print("") + print(" import versioneer") + print(" setup( version=versioneer.get_version(),") + print(" cmdclass=versioneer.get_cmdclass(), ...)") + print("") + errors += 1 + if setters: + print("You should remove lines like 'versioneer.VCS = ' and") + print("'versioneer.versionfile_source = ' . This configuration") + print("now lives in setup.cfg, and should be removed from setup.py") + print("") + errors += 1 + return errors + + +if __name__ == "__main__": + cmd = sys.argv[1] + if cmd == "setup": + errors = do_setup() + errors += scan_setup_py() + if errors: + sys.exit(1)