diff --git a/.editorconfig b/.editorconfig
new file mode 100644
index 0000000..1e56054
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,24 @@
+# http://editorconfig.org
+
+root = true
+
+[*]
+indent_style = space
+indent_size = 4
+trim_trailing_whitespace = true
+insert_final_newline = true
+charset = utf-8
+end_of_line = lf
+
+[*.bat]
+indent_style = tab
+end_of_line = crlf
+
+[LICENCE]
+insert_final_newline = false
+
+[Makefile]
+indent_style = tab
+
+[*.{diff,patch}]
+trim_trailing_whitespace = false
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..bed1d91
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,450 @@
+### Large data files and model weights ###
+*.h5
+*.hdf5
+*.pt
+*.pth
+*.pb
+*.tflite
+*.onnx
+*.tfrecord
+*.tfrecord-*
+
+#pycharm+all,vim,macos,windows,linux,python,zsh,visualstudiocode,flask
+
+### Flask ###
+instance/*
+!instance/.gitignore
+.webassets-cache
+.env
+
+### Flask.Python Stack ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+### Linux ###
+*~
+
+# temporary files which can be created if a process still has a handle open of a deleted file
+.fuse_hidden*
+
+# KDE directory preferences
+.directory
+
+# Linux trash folder which might appear on any partition or disk
+.Trash-*
+
+# .nfs files are created when an open file is removed but is still being accessed
+.nfs*
+
+### macOS ###
+# General
+.DS_Store
+.AppleDouble
+.LSOverride
+
+# Icon must end with two \r
+Icon
+
+
+# Thumbnails
+._*
+
+# Files that might appear in the root of a volume
+.DocumentRevisions-V100
+.fseventsd
+.Spotlight-V100
+.TemporaryItems
+.Trashes
+.VolumeIcon.icns
+.com.apple.timemachine.donotpresent
+
+# Directories potentially created on remote AFP share
+.AppleDB
+.AppleDesktop
+Network Trash Folder
+Temporary Items
+.apdisk
+
+### PyCharm+all ###
+# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
+# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
+
+# User-specific stuff
+.idea/**/workspace.xml
+.idea/**/tasks.xml
+.idea/**/usage.statistics.xml
+.idea/**/dictionaries
+.idea/**/shelf
+
+# AWS User-specific
+.idea/**/aws.xml
+
+# Generated files
+.idea/**/contentModel.xml
+
+# Sensitive or high-churn files
+.idea/**/dataSources/
+.idea/**/dataSources.ids
+.idea/**/dataSources.local.xml
+.idea/**/sqlDataSources.xml
+.idea/**/dynamic.xml
+.idea/**/uiDesigner.xml
+.idea/**/dbnavigator.xml
+
+# Gradle
+.idea/**/gradle.xml
+.idea/**/libraries
+
+# Gradle and Maven with auto-import
+# When using Gradle or Maven with auto-import, you should exclude module files,
+# since they will be recreated, and may cause churn. Uncomment if using
+# auto-import.
+# .idea/artifacts
+# .idea/compiler.xml
+# .idea/jarRepositories.xml
+# .idea/modules.xml
+# .idea/*.iml
+# .idea/modules
+# *.iml
+# *.ipr
+
+# CMake
+cmake-build-*/
+
+# Mongo Explorer plugin
+.idea/**/mongoSettings.xml
+
+# File-based project format
+*.iws
+
+# IntelliJ
+out/
+
+# mpeltonen/sbt-idea plugin
+.idea_modules/
+
+# JIRA plugin
+atlassian-ide-plugin.xml
+
+# Cursive Clojure plugin
+.idea/replstate.xml
+
+# Crashlytics plugin (for Android Studio and IntelliJ)
+com_crashlytics_export_strings.xml
+crashlytics.properties
+crashlytics-build.properties
+fabric.properties
+
+# Editor-based Rest Client
+.idea/httpRequests
+
+# Android studio 3.1+ serialized cache file
+.idea/caches/build_file_checksums.ser
+
+### PyCharm+all Patch ###
+# Ignores the whole .idea folder and all .iml files
+# See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360
+
+.idea/
+
+# Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023
+
+*.iml
+modules.xml
+.idea/misc.xml
+*.ipr
+
+# Sonarlint plugin
+.idea/sonarlint
+
+### Python ###
+# Byte-compiled / optimized / DLL files
+
+# C extensions
+
+# Distribution / packaging
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+
+# Installer logs
+
+# Unit test / coverage reports
+
+# Translations
+
+# Django stuff:
+
+# Flask stuff:
+
+# Scrapy stuff:
+
+# Sphinx documentation
+
+# PyBuilder
+
+# Jupyter Notebook
+
+# IPython
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+
+# Celery stuff
+
+# SageMath parsed files
+
+# Environments
+
+# Spyder project settings
+
+# Rope project settings
+
+# mkdocs documentation
+
+# mypy
+
+# Pyre type checker
+
+# pytype static type analyzer
+
+# Cython debug symbols
+
+### Vim ###
+# Swap
+[._]*.s[a-v][a-z]
+!*.svg # comment out if you don't need vector files
+[._]*.sw[a-p]
+[._]s[a-rt-v][a-z]
+[._]ss[a-gi-z]
+[._]sw[a-p]
+
+# Session
+Session.vim
+Sessionx.vim
+
+# Temporary
+.netrwhist
+# Auto-generated tag files
+tags
+# Persistent undo
+[._]*.un~
+
+### VisualStudioCode ###
+.vscode/*
+!.vscode/settings.json
+!.vscode/tasks.json
+!.vscode/launch.json
+!.vscode/extensions.json
+*.code-workspace
+
+# Local History for Visual Studio Code
+.history/
+
+### VisualStudioCode Patch ###
+# Ignore all local history of files
+.history
+.ionide
+
+### Windows ###
+# Windows thumbnail cache files
+Thumbs.db
+Thumbs.db:encryptable
+ehthumbs.db
+ehthumbs_vista.db
+
+# Dump file
+*.stackdump
+
+# Folder config file
+[Dd]esktop.ini
+
+# Recycle Bin used on file shares
+$RECYCLE.BIN/
+
+# Windows Installer files
+*.cab
+*.msi
+*.msix
+*.msm
+*.msp
+
+# Windows shortcuts
+*.lnk
+
+### Zsh ###
+# Zsh compiled script + zrecompile backup
+*.zwc
+*.zwc.old
+
+# Zsh completion-optimization dumpfile
+*zcompdump*
+
+# Zsh zcalc history
+.zcalc_history
+
+# A popular plugin manager's files
+._zinit
+.zinit_lstupd
+
+# zdharma/zshelldoc tool's files
+zsdoc/data
+
+# robbyrussell/oh-my-zsh/plugins/per-directory-history plugin's files
+# (when set-up to store the history in the local directory)
+.directory_history
+
+# MichaelAquilina/zsh-autoswitch-virtualenv plugin's files
+# (for Zsh plugins using Python)
+
+# Zunit tests' output
+/tests/_output/*
+!/tests/_output/.gitkeep
+
+# End of https://www.toptal.com/developers/gitignore/api/pycharm+all,vim,macos,windows,linux,python,zsh,visualstudiocode,flask
+
diff --git a/CODEOWNERS b/CODEOWNERS
new file mode 100644
index 0000000..d6374ff
--- /dev/null
+++ b/CODEOWNERS
@@ -0,0 +1,11 @@
+# Lines starting with '#' are comments.
+# Each line is a file pattern followed by one or more owners.
+
+# More infos at https://github.blog/2017-07-06-introducing-code-owners/
+# and
+# https://docs.github.com/en/github/creating-cloning-and-archiving-repositories/about-code-owners
+
+# This is a wildcard pattern that means we have to review every filetypes
+# When the dev team will grow, we'll remove that and replace it with
+# specific filetypes/paths relevant to different people
+* @sam1902
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..d880a8f
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,149 @@
+# Commit guidelines
+
+## Don't commit to `master`
+ - The `master` branch of this repo is write-protected. All changes need to be made through *Pull Requests*.
+
+ - Pull requests must follow this following guidelines:
+ - Have a [**linear history**](https://www.bitsnbites.eu/a-tidy-linear-git-history/).
+ - Only feature [signed commits](https://docs.github.com/en/github/authenticating-to-github/signing-commits).
+ - Be approved by at least one *Code Owner*.
+
+## Git LFS is evil
+
+Whatever you do, don't commit binary files. The `.gitignore` is already set up such that you can't commit most common binary files such as .aar, .apk, .so etc. but you should be aware that if it's binary, it doesn't belong in git version tracking.
+
+Git LFS is a system to allow tracking of binary files, but great powers come with great responsibilities and it does more harm than good in most cases. To see if the repo contains any Git LFS files that may harm your build process, simply run the following from within the repo and look for matches:
+
+```
+grep -rnw . -e 'git-lfs'
+```
+
+If anything comes up, remove the corresponding files, rollback to a previous version without those, burn the repo down, but don't push that or we'll all be forced to use git lfs for the rest of eternity.
+
+## Everyone is responsible for their own development environment
+
+If you've got issues with your dev env, try asking other devs that work in the same environment. Asking someone with a different environment will only slow everyone down as they'll have to learn the specifics of your setup when they don't need to know.
+
+I personally use a combination of [JetBrain's PyCharm](https://www.jetbrains.com/pycharm/) and Vim on macOs.
+
+When you have to make changes for your environment specifically, **don't commit those changes**. You can add those files to a file called `.git/info/exclude` which acts like a local version of a `.gitignore`.
+
+## Whitespace errors
+
+First, your submissions should not contain any whitespace errors. Git provides an easy way to check for this — before you commit, run `git diff --check`, which identifies possible whitespace errors and lists them for you.
+
+![output from git diff --check example](https://git-scm.com/book/en/v2/images/git-diff-check.png)
+
+If you run that command before committing, you can tell if you’re about to commit whitespace issues that may annoy other developers.
+
+## Separate commits logically
+
+Try to make each commit a logically separate changeset. If you can, try to make your changes digestible — don’t code for a whole weekend on five different issues and then submit them all as one massive commit on Monday. Even if you don’t commit during the weekend, use the staging area on Monday to split your work into at least one commit per issue, with a useful message per commit.
+
+If some of the changes modify the same file, try to use git add --patch to **partially stage files** (covered in detail in [Interactive Staging](https://git-scm.com/book/en/v2/ch00/_interactive_staging)).
+
+If you want to remove a file from stating, run `git reset HEAD {file}`. This won't change the file content, don't worry.
+
+The project snapshot at the tip of the branch is identical whether you do one commit or five, as long as all the changes are added at some point, so try to make things easier on your fellow developers when they have to review your changes.
+
+[Rewriting History](https://git-scm.com/book/en/v2/ch00/_rewriting_history) describes a number of useful Git tricks for rewriting history and interactively staging files — use these tools to help craft a **clean and understandable** history before sending the work to someone else.
+
+## Commit message
+As a general rule, your messages should start with a single line that’s **no more than about 50 characters** and that describes the changeset concisely, followed by a blank line, followed by a more detailed explanation.
+
+The Git project requires that the more detailed explanation include your motivation for the change and contrast its implementation with previous behavior — this is a good guideline to follow.
+
+Write your commit message in the imperative: "Fix bug" and not "Fixed bug" or "Fixes bug.". You messages should always start with an actionnable verb: Make, Fix, Add, Improve, Update, etc. Here is a template you can follow, which we’ve lightly adapted from one originally [written by Tim Pope](https://tbaggery.com/2008/04/19/a-note-about-git-commit-messages.html):
+> Capitalized, short (50 chars or less) summary
+>
+> More detailed explanatory text, if necessary. Wrap it to about 72
+> characters or so. In some contexts, the first line is treated as the
+> subject of an email and the rest of the text as the body. The blank
+> line separating the summary from the body is critical (unless you omit
+> the body entirely); tools like rebase will confuse you if you run the
+> two together.
+>
+> Write your commit message in the imperative: "Fix bug" and not "Fixed bug"
+> or "Fixes bug." This convention matches up with commit messages generated
+> by commands like git merge and git revert.
+>
+> Further paragraphs come after blank lines.
+>
+> - Bullet points are okay, too
+>
+> - Typically a hyphen or asterisk is used for the bullet, followed by a
+> single space, with blank lines in between, but conventions vary here
+>
+> - Use a hanging indent
+
+Try running `git log --no-merges` there to see what a nicely-formatted project-commit history looks like.
+
+# Steps for creating good pull requests
+
+## Name
+
+Change the pull request's name to something meaningful. By default it'll just be generated from the branch's name, but rename it yourself before posting it.
+
+## Sections and links
+
+Use markdown titles to explain the changes you've made, and why you made them. This should include details about any contingency you've encountered while developing this feature, and **links to resouces** that helped you solve them, such as Stack Overflow links from any code snippet, page explaining the technology, documentation hinting at problematic limitations.
+
+If that web page is huge (like one page documentation for the whole lib), then **try to make those point to a specific point** in the webpage you're linking. This can be done easily by clicking on little HTML anchors, typically next to the section titles. They should add a `#` at the end of the URL followed by the section title, like `https://mydoc.com/doc/superlibrary.html#relevant-section-title`.
+
+## User interface ?
+
+If you PR is related to anything visual, **add a screenshot** of what the feature looks like. If it's dynamic, you can add a GIF instead, but it's not mandatory.
+
+## Breaking changes
+
+If you PR fixes a bug, describe precisely how the changes you've made affect the code behaviour.
+
+Questions you should answers typically look like:
+
+ - Does this method now returns a `null` when the URL value is empty ?
+ - Does that default parameter value changed somehow ?
+
+If you've made any change affecting the public interface of a class or function (think Java's `public` methods), then **document it**.
+
+## Update the tests
+
+### Before your changes
+
+Tests are tests, and tests can break. Before you push any commit to remote, make sure your branch is clean with `git status`, then checkout the base of your branch using `git checkout {hash}` and replace `{hash}` with you branch's base commit, then run **all** the unit tests there and see if they pass.
+
+If they don't pass, then fix your dev environment (env vars and such), it's the only possible cause.
+
+### After your changes
+
+After you've made sure your environment is correctly configured, checkout your branch's latest commit and rerun all the unit tests you've just ran. If they don't pass, it means you've broke something in between and you should fix that before pushing.
+
+### Coverage
+
+Then, run a [code coverage tool](https://www.wikiwand.com/en/Code_coverage) to make sure all the new code you've written is *at the very least* checked out by one test.
+
+**100% coverage is not enough !**
+
+You can have 100% coverage with poorly designed tests, unit tests should test many different scenario, not just one. But even with just one, you'll get that coverage, so it can be a misleading metric. A metric cease to be good when it becomes a target.
+
+## Involuntary changes
+
+When trying to understand the codebase, you might be inclined to put `print` commands here and there. It's fine to do so, just use any tool you're confortable with, but please for god's sake **don't let them in when commiting**.
+
+Before commiting, run a `git status` and see what file have changed. Only relevant files should have changed, if any odd file changed, run `git diff {file}` to see what you changed and if it's relevant. If it's not, then run `git checkout -- {file}` to reset it to the latest commited state.
+
+This does not only apply to `print`s, but also newlines left after removing `print`s manually, automatic formatting tools in IDE that change the whole file to your own style settings etc. **You should only perform the minimal changes to implement your feature**.
+
+The reason for that is that when reviewing your PR, reviewers might have a thousand files to "view" even though you just added a newline to them. Also, you'll appear as though you've made changes to a thousand file whereas you only meant to change two of them.
+
+# Steps for creating good issues
+
+Document how to reproduce the issue, starting from `git clone` the affected pushed commit. Make sure to dump you `$PATH` variable as this tends to be affecting environment specific behaviour. When using Python, add your interpreter's `pip3 freeze` to list packets and their version. If the list is too long, put it inside a spoiler in markdown.
+
+Tag your issue.
+
+If you have any gut intuition as to what's causing it, write that down in the issue, along with any reference that might have helped you arrive at this conclusion.
+
+# Links to external documentation, mailing lists, or a code of conduct.
+
+ - [Contibuting to a Project on GIT](https://git-scm.com/book/en/v2/Distributed-Git-Contributing-to-a-Project)
+ - Please follow [our code of conduct](https://thoughtbot.com/open-source-code-of-conduct).
diff --git a/LICENCE.txt b/LICENCE.txt
new file mode 100644
index 0000000..5686a75
--- /dev/null
+++ b/LICENCE.txt
@@ -0,0 +1,32 @@
+
+
+BSD License
+
+Copyright (c) 2021-08-13, DesignStripe
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification,
+are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice, this
+ list of conditions and the following disclaimer in the documentation and/or
+ other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from this
+ software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
+INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
+OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+OF THE POSSIBILITY OF SUCH DAMAGE.
+
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..c179c1c
--- /dev/null
+++ b/README.md
@@ -0,0 +1,92 @@
+# Torch Pconv
+
+[![PyPI version](https://badge.fury.io/py/torch_pconv.svg)](https://badge.fury.io/py/torch_pconv)
+
+Faster and more memory efficient implementation of the Partial Convolution 2D layer in PyTorch equivalent to the
+standard Nvidia implementation.
+
+This implementation has numerous advantages:
+
+1. It is **strictly equivalent** in computation
+ to [the reference implementation by Nvidia](https://github.com/NVIDIA/partialconv/blob/610d373f35257887d45adae84c86d0ce7ad808ec/models/partialconv2d.py)
+ . I made unit tests to assess that all throughout development.
+2. It's commented and more readable
+3. It's faster and more memory efficient, which means you can use more layers on smaller GPUs. It's a good thing
+ considering today's GPU prices.
+4. It's a PyPI-published library. You can `pip` install it instead of copy/pasting source code, and get the benefit of (
+ free) bugfixes when someone notice a bug in the implementation.
+
+![Total memory cost (in bytes)](doc/2021-08-13_perfs.png?raw=true)
+
+## Getting started
+
+```sh
+pip3 install torch_pconv
+```
+
+## Usage
+
+```python3
+import torch
+from torch_pconv import PConv2d
+
+images = torch.rand(32, 3, 256, 256)
+masks = (torch.rand(32, 256, 256) > 0.5).to(torch.float32)
+
+pconv = PConv2d(
+ in_channels=3,
+ out_channels=64,
+ kernel_size=7,
+ stride=1,
+ padding=2,
+ dilation=2,
+ bias=True
+)
+
+output, shrunk_masks = pconv(images, masks)
+```
+
+## Performance improvement
+
+### Test
+
+You can
+find [the reference implementation by Nvidia here](https://github.com/NVIDIA/partialconv/blob/610d373f35257887d45adae84c86d0ce7ad808ec/models/partialconv2d.py)
+.
+
+I tested their implementation vs mine one the following configuration:
+
+Parameter | Value |
+in_channels | 64 |
+out_channels | 128 |
+kernel_size | 9 |
+stride | 1 |
+padding | 3 |
+bias | True |
+input height/width | 256 |
+
+
+The goal here was to produce the most computationally expensive partial convolution operator so that the performance
+difference is displayed better.
+
+I compute both the forward and the backward pass, in case one consumes more memory than the other.
+
+### Results
+
+![Total memory cost (in bytes)](doc/2021-08-13_perfs.png?raw=true)
+
+ | torch_pconv | Nvidia® (Guilin) |
+Forward only | 813 466 624 | 4 228 120 576 |
+Backward only | 1 588 201 480 | 1 588 201 480 |
+Forward + Backward | 2 405 797 640 | 6 084 757 512 |
+
+
+## Development
+
+To install the latest version from Github, run:
+
+```
+git clone git@github.com:DesignStripe/torch_pconv.git torch_pconv
+cd torch_pconv
+pip3 install -U .
+```
diff --git a/doc/2021-08-13_perfs.png b/doc/2021-08-13_perfs.png
new file mode 100644
index 0000000..a749a6a
Binary files /dev/null and b/doc/2021-08-13_perfs.png differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..3cb17b7
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+torch
+tensor_type
+pshape
+pytest
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..dc72275
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,5 @@
+[aliases]
+test=pytest
+
+[flake8]
+max-line-length=120
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..80542d6
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,30 @@
+#!/usr/bin/env python3
+import pathlib
+
+from setuptools import setup
+
+HERE = pathlib.Path(__file__).parent
+README = (HERE / "README.md").read_text()
+
+setup(
+ name="torch_pconv",
+ version="0.1.0",
+ packages=["torch_pconv"],
+ description="Faster and more memory efficient implementation of the Partial Convolution 2D"
+ " layer in PyTorch equivalent to the standard NVidia implem.",
+ long_description=README,
+ long_description_content_type="text/markdown",
+ url="https://github.com/DesignStripe/torch_pconv",
+ author="Samuel Prevost",
+ author_email="samuel.prevost@pm.me",
+ licence="BSD",
+ classifiers=[
+ "License :: OSI Approved :: BSD License",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.8",
+ ],
+ include_package_data=True,
+ install_requires=["torch", "tensor_type", "pshape"],
+ setup_requires=['pytest-runner'],
+ tests_require=['pytest'],
+)
diff --git a/tests/conv_config.py b/tests/conv_config.py
new file mode 100644
index 0000000..5461497
--- /dev/null
+++ b/tests/conv_config.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass, asdict
+
+
+@dataclass
+class ConvConfig:
+ in_channels: int
+ out_channels: int
+ kernel_size: int
+ stride: int = 1
+ padding: int = 0
+ dilation: int = 1
+ bias: bool = False
+
+ def copy(self):
+ # noinspection PyArgumentList
+ return self.__class__(**asdict(self))
+
+ @property
+ def dict(self):
+ return asdict(self)
diff --git a/tests/pconv_guilin.py b/tests/pconv_guilin.py
new file mode 100644
index 0000000..7fac425
--- /dev/null
+++ b/tests/pconv_guilin.py
@@ -0,0 +1,110 @@
+###############################################################################
+# BSD 3-Clause License
+#
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Author & Contact: Guilin Liu (guilinl@nvidia.com)
+###############################################################################
+"""
+Code by Guilin Liu at
+https://github.com/NVIDIA/partialconv/blob/610d373f35257887d45adae84c86d0ce7ad808ec/models/partialconv2d.py
+
+I tried to modify the least code: just enough to make is compatible with 3D masks (instead of 4D)
+"""
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class PConvGuilin(nn.Conv2d):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ kwargs['multi_channel'] = True
+
+ # whether the mask is multi-channel or not
+ if 'multi_channel' in kwargs:
+ self.multi_channel = kwargs['multi_channel']
+ kwargs.pop('multi_channel')
+ else:
+ self.multi_channel = False
+
+ self.return_mask = True
+
+ if self.multi_channel:
+ self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels, self.kernel_size[0],
+ self.kernel_size[1])
+ else:
+ self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])
+
+ self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * \
+ self.weight_maskUpdater.shape[3]
+
+ self.last_size = (None, None, None, None)
+ self.update_mask = None
+ self.mask_ratio = None
+
+ def forward(self, inputs, mask_in=None):
+ if len(inputs.shape) != 4 or len(mask_in.shape) != 3:
+ raise TypeError()
+
+ if inputs.dtype != torch.float32 or mask_in.dtype != torch.float32:
+ raise TypeError()
+
+ mask_in = mask_in[:, None].expand(-1, inputs.shape[1], -1, -1)
+
+ if mask_in is not None or self.last_size != tuple(inputs.shape):
+ self.last_size = tuple(inputs.shape)
+
+ with torch.no_grad():
+ if self.weight_maskUpdater.type() != inputs.type():
+ self.weight_maskUpdater = self.weight_maskUpdater.to(inputs)
+
+ if mask_in is None:
+ # if mask is not provided, create a mask
+ if self.multi_channel:
+ mask = torch.ones(inputs.data.shape[0], inputs.data.shape[1], inputs.data.shape[2],
+ inputs.data.shape[3]).to(inputs)
+ else:
+ mask = torch.ones(1, 1, inputs.data.shape[2], inputs.data.shape[3]).to(inputs)
+ else:
+ mask = mask_in
+
+ self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride,
+ padding=self.padding, dilation=self.dilation, groups=1)
+
+ # for mixed precision training, change 1e-8 to 1e-6
+ self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-8)
+ # self.mask_ratio = torch.max(self.update_mask)/(self.update_mask + 1e-8)
+ self.update_mask = torch.clamp(self.update_mask, 0, 1)
+ self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)
+
+ raw_out = nn.Conv2d.forward(self, torch.mul(inputs, mask) if mask_in is not None else inputs)
+
+ if self.bias is not None:
+ bias_view = self.bias.view(1, self.out_channels, 1, 1)
+ output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
+ output = torch.mul(output, self.update_mask)
+ else:
+ output = torch.mul(raw_out, self.mask_ratio)
+
+ if self.return_mask:
+ return output, self.update_mask[:, 0]
+ else:
+ return output
+
+ def set_weight(self, w):
+ with torch.no_grad():
+ self.weight.copy_(w)
+ return self
+
+ def set_bias(self, b):
+ with torch.no_grad():
+ self.bias.copy_(b)
+ return self
+
+ def get_weight(self):
+ return self.weight
+
+ def get_bias(self):
+ return self.bias
diff --git a/tests/pconv_rfr.py b/tests/pconv_rfr.py
new file mode 100644
index 0000000..1b7b8c1
--- /dev/null
+++ b/tests/pconv_rfr.py
@@ -0,0 +1,77 @@
+###############################################################################
+# BSD 3-Clause License
+#
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Author & Contact: Guilin Liu (guilinl@nvidia.com)
+###############################################################################
+"""
+Code by Guilin by but probably with some modifications by jingyuanli001, code at
+https://github.com/jingyuanli001/RFR-Inpainting/blob/faed6f154e01fc3accce5dff82a5b28e6f426fbe/modules/partialconv2d.py
+
+I tried to modify the least code: just enough to make is compatible with 3D masks (instead of 4D)
+"""
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class PConvRFR(nn.Conv2d):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ self.ones = torch.ones(self.out_channels, self.in_channels, self.kernel_size[0],
+ self.kernel_size[1])
+
+ # max value of one convolution's window
+ self.slide_winsize = self.ones.shape[1] * self.ones.shape[2] * self.ones.shape[3]
+
+ self.update_mask = None
+ self.mask_ratio = None
+
+ def forward(self, inputs, mask=None):
+ if len(inputs.shape) != 4 or len(mask.shape) != 3:
+ raise TypeError()
+
+ if inputs.dtype != torch.float32 or mask.dtype != torch.float32:
+ raise TypeError()
+
+ mask = mask[:, None].expand(-1, inputs.shape[1], -1, -1)
+
+ with torch.no_grad():
+ self.update_mask = F.conv2d(mask, self.ones.to(mask), bias=None, stride=self.stride,
+ padding=self.padding,
+ dilation=self.dilation,
+ groups=1)
+
+ self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-8)
+ self.update_mask = torch.clamp(self.update_mask, 0, 1)
+ self.mask_ratio *= self.update_mask
+
+ raw_out = nn.Conv2d.forward(self, inputs * mask)
+
+ if self.bias is not None:
+ bias_view = self.bias.view(1, self.out_channels, 1, 1)
+ output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
+ output = torch.mul(output, self.update_mask)
+ else:
+ output = raw_out * self.mask_ratio
+
+ return output, self.update_mask[:, 0]
+
+ def set_weight(self, w):
+ with torch.no_grad():
+ self.weight.copy_(w)
+ return self
+
+ def set_bias(self, b):
+ with torch.no_grad():
+ self.bias.copy_(b)
+ return self
+
+ def get_weight(self):
+ return self.weight
+
+ def get_bias(self):
+ return self.bias
diff --git a/tests/test_pconv.py b/tests/test_pconv.py
new file mode 100644
index 0000000..493e76e
--- /dev/null
+++ b/tests/test_pconv.py
@@ -0,0 +1,378 @@
+import itertools
+import unittest
+from functools import partial
+from typing import List, Type, Dict, Tuple, Callable, Union, Iterable
+
+import torch
+from pshape import pshape
+from torch import Tensor
+from torch.profiler import profile, ProfilerActivity
+
+from torch_pconv import PConv2d
+from pconv_guilin import PConvGuilin
+from pconv_rfr import PConvRFR
+from conv_config import ConvConfig
+
+PConvLike = torch.nn.Module
+
+
+class TestPConv(unittest.TestCase):
+ pconv_classes = [
+ PConvGuilin,
+ PConvRFR,
+ # This forces numerical error to be the same as other implementations, but makes the computation a bit slower
+ partial(PConv2d, legacy_behaviour=True),
+ ]
+
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+ def test_output_shapes(self):
+ b, c, h = 16, 3, 256
+ image, mask = self.mkinput(b=b, c=c, h=h)
+ configs = [
+ ConvConfig(3, 64, 5, padding=2, stride=2),
+ ConvConfig(64, 64, 5, padding=1),
+ ConvConfig(64, 64, 3, padding=4),
+ ConvConfig(64, 64, 7, padding=5),
+ ConvConfig(64, 32, 3, padding=2),
+ ]
+ expected_heights = (128, 126, 132, 136, 138,)
+
+ self.assertEqual(len(configs), len(expected_heights))
+
+ outputs_imgs, outputs_masks = image, mask
+ for expected_height, config in zip(expected_heights, configs):
+ outputs_imgs, outputs_masks = self.run_pconvs(self.pconv_classes, config=config)(outputs_imgs,
+ outputs_masks)
+ for clazz in self.pconv_classes:
+ img, mask = outputs_imgs[clazz], outputs_masks[clazz]
+ self.assertTupleEqual(tuple(img.shape), (b, config.out_channels, expected_height, expected_height))
+ self.assertTupleEqual(tuple(mask.shape), (b, expected_height, expected_height))
+
+ def test_output_dtype(self):
+ b, c, h = 16, 3, 256
+ image, mask = self.mkinput(b=b, c=c, h=h)
+ configs = [
+ ConvConfig(3, 64, 5, padding=2, stride=2),
+ ConvConfig(64, 64, 5, padding=1),
+ ConvConfig(64, 64, 3, padding=4),
+ ConvConfig(64, 64, 7, padding=5),
+ ConvConfig(64, 32, 3, padding=2),
+ ]
+ expected_heights = (128, 126, 132, 136, 138,)
+
+ self.assertEqual(len(configs), len(expected_heights))
+
+ outputs_imgs, outputs_masks = image, mask
+ for expected_height, config in zip(expected_heights, configs):
+ outputs_imgs, outputs_masks = self.run_pconvs(self.pconv_classes, config=config)(outputs_imgs,
+ outputs_masks)
+ for clazz in self.pconv_classes:
+ img, mask = outputs_imgs[clazz], outputs_masks[clazz]
+ assert img.dtype == torch.float32
+ assert mask.dtype == torch.float32
+
+ def test_input_shape(self):
+ config = next(iter(self.realistic_config()))
+ # We have to call each class distinctively
+ pconv_calls = [clazz(**config.dict).to(self.device) for clazz in self.pconv_classes]
+
+ # Good dtypes
+ image = torch.rand(10, 3, 256, 256, dtype=torch.float32).to(self.device)
+ mask = (torch.rand(10, 256, 256) > 0.5).to(torch.float32).to(self.device)
+ try:
+ for pconv_call in pconv_calls:
+ pconv_call(image, mask)
+ except TypeError as e:
+ self.fail(str(e))
+
+ image = (torch.rand(10, 256, 256) * 255).to(torch.float32).to(self.device) # Bad shape, channels missing
+ mask = (torch.rand(10, 256, 256) > 0.5).to(torch.float32).to(self.device)
+ for pconv_call in pconv_calls:
+ self.assertRaises(TypeError, pconv_call, image, mask)
+
+ image = torch.rand(10, 3, 256, 256).to(torch.float32).to(self.device)
+ mask = (torch.rand(10, 3, 256, 256) > 0.5).to(torch.float32).to(self.device) # Bad shape, channels present
+ for pconv_call in pconv_calls:
+ self.assertRaises(TypeError, pconv_call, image, mask)
+
+ def test_input_dtype(self):
+ config = next(iter(self.realistic_config()))
+ # We have to call each class distinctively
+ pconv_calls = [clazz(**config.dict).to(self.device) for clazz in self.pconv_classes]
+
+ # Good dtypes
+ image = torch.rand(10, 3, 256, 256, dtype=torch.float32).to(self.device)
+ mask = (torch.rand(10, 256, 256) > 0.5).to(torch.float32).to(self.device)
+ try:
+ for pconv_call in pconv_calls:
+ pconv_call(image, mask)
+ except TypeError as e:
+ self.fail(str(e))
+
+ image = (torch.rand(10, 3, 256, 256) * 255).to(torch.uint8).to(self.device) # Bad dtype
+ mask = (torch.rand(10, 256, 256) > 0.5).to(torch.float32).to(self.device)
+ for pconv_call in pconv_calls:
+ self.assertRaises(TypeError, pconv_call, image, mask)
+
+ image = (torch.rand(10, 3, 256, 256) * 255).to(torch.float32).to(self.device)
+ mask = (torch.rand(10, 256, 256) > 0.5).to(self.device) # Bad Dtype
+ for pconv_call in pconv_calls:
+ self.assertRaises(TypeError, pconv_call, image, mask)
+
+ def test_mask_values_binary(self):
+ """The mask is a float tensor because the convolution doesn't operate on boolean tensors, however,
+ its values are still 0.0 (False) OR 1.0 (True). The masks should NEVER have 0.34 or anything in
+ between those two values.
+
+ Technical explanation for why:
+ masks are passed to the convolution with ones kernel, at that point, their values can be any integer
+ since the convolution will sum ones together, so no float value can be created here.
+ Then, we run torch.clip(mask, 0, 1). At this point, any integer value >= 1 becomes 1, leaving only 0 and 1s.
+ Rince and repeat at next iteration."""
+ image, mask = self.realistic_input()
+ configs = self.realistic_config()
+ outputs_imgs, outputs_masks = image, mask
+ for config in configs:
+ outputs_imgs, outputs_masks = self.run_pconvs(self.pconv_classes, config=config)(outputs_imgs,
+ outputs_masks)
+ for mask in outputs_masks.values():
+ assert ((mask == 1.0) | (
+ mask == 0.0)).all(), "All mask values should remain either 1.0 or 0.0, nothing in between."
+
+ def test_dilation(self):
+ image, mask = self.realistic_input()
+ configs = self.realistic_config()
+ # Enable bias on every PConv
+ for i, c in enumerate(configs):
+ c.dilation = max(1, i % 4)
+
+ outputs_imgs, outputs_masks = image, mask
+ for config in configs:
+ outputs_imgs, outputs_masks = self.run_pconvs(self.pconv_classes, config=config)(outputs_imgs,
+ outputs_masks)
+ self.compare(outputs_imgs, self.allclose)
+ self.compare(outputs_masks, self.allclose)
+
+ def test_bias(self):
+ """This test is very sensitive to numerical errors.
+ On my setup, this test passes when ran on GPU, but fails when ran on CPU. The most likely reason is that
+ the CUDA backend's way to add the bias in the convolution differs from the Intel MKL way to add the bias,
+ resulting in different numerical errors.
+
+ Just inspect the min/mean/max values and see if they differ significantly, and if they don't then ignore this
+ test failing, or send me a PR to fix it."""
+
+ image, mask = self.realistic_input()
+ configs = self.realistic_config()
+ # Enable bias on every PConv
+ for c in configs:
+ c.bias = True
+
+ outputs_imgs, outputs_masks = image, mask
+ for config in configs:
+ outputs_imgs, outputs_masks = self.run_pconvs(self.pconv_classes, config=config)(outputs_imgs,
+ outputs_masks)
+ self.compare(outputs_imgs, self.allclose)
+ self.compare(outputs_masks, self.allclose)
+
+ def test_backpropagation(self):
+ """Does a 3 step forward pass, and then attempts to backpropagate the resulting image
+ to see if the gradient can be computed and wasn't lost along the way."""
+ image, mask = self.realistic_input()
+ configs = self.realistic_config()
+ outputs_imgs, outputs_masks = image, mask
+ for config in configs:
+ outputs_imgs, outputs_masks = self.run_pconvs(self.pconv_classes, config=config)(outputs_imgs,
+ outputs_masks)
+
+ for clazz in self.pconv_classes:
+ try:
+ outputs_imgs[clazz].sum().backward()
+ except RuntimeError:
+ self.fail(f"Could not compute the gradient for {clazz.__name__}")
+
+ def test_memory_complexity(self):
+ device = torch.device('cpu')
+ image, mask = self.realistic_input(c=64, d=device)
+ config = ConvConfig(64, 128, 9, stride=1, padding=3, bias=True)
+ pconv_calls = [clazz(**config.dict).to(device) for clazz in self.pconv_classes]
+
+ tolerance = 0.1 # 10 %
+ max_mem_use = {
+ PConvGuilin: 6_084_757_512, # 5.67 GiB
+ PConvRFR: 6_084_758_024, # 5.67 GiB
+ PConv2d: 2_405_797_640, # 2.24 GiB
+ }
+
+ for pconv_call in pconv_calls:
+ with profile(activities=[ProfilerActivity.CPU],
+ profile_memory=True, record_shapes=True, with_stack=True) as prof:
+ # Don't forget to run grad computation as well, since that eats a lot of memory too
+ out_im, _ = pconv_call(image, mask)
+ out_im.sum().backward()
+
+ # Stealing the total memory stat from the profiler
+ total_mem = abs(
+ list(filter(lambda fe: fe.key == "[memory]", list(prof.key_averages())))[0].cpu_memory_usage)
+
+ # Printing how much mem used in total
+ # print(f"{pconv_call.__class__.__name__} used {self.format_bytes(total_mem)} ({total_mem})")
+
+ max_mem = (max_mem_use[pconv_call.__class__] * (1 + tolerance))
+ assert total_mem < max_mem, f"{pconv_call.__class__.__name__} used {self.format_bytes(total_mem)}" \
+ f" which is more than {self.format_bytes(max_mem)}"
+
+ def test_iterated_equality(self):
+ """
+ Tests that even when iterating:
+ 1- The output images have the same values (do not diverge due to error accumulation for example)
+ 2- The output masks have the same values
+ 3- The outputted masks are just repeated along the channel dimension
+ """
+ image, mask = self.realistic_input()
+ configs = self.realistic_config()
+
+ outputs_imgs, outputs_masks = image, mask
+ for config in configs:
+ outputs_imgs, outputs_masks = self.run_pconvs(self.pconv_classes, config=config)(outputs_imgs,
+ outputs_masks)
+
+ self.compare(outputs_imgs, self.allclose)
+ self.compare(outputs_masks, self.allclose)
+
+ def test_equality(self):
+ config = ConvConfig(in_channels=3, out_channels=64, kernel_size=5)
+ image, mask = self.mkinput(b=16, h=256, c=config.in_channels)
+
+ outputs_imgs, outputs_masks = self.run_pconvs(self.pconv_classes, config)(
+ image, mask)
+
+ self.compare(outputs_imgs, self.allclose)
+
+ self.compare(outputs_masks, self.allclose)
+
+ @classmethod
+ def realistic_input(cls, b=16, c=3, h=256, d=None) -> Tuple[Tensor, Tensor]:
+ # 16 images, each of 3 channels and of height/width 256 pixels
+ return cls.mkinput(b=b, c=c, h=h, d=cls.device if d is None else d)
+
+ @classmethod
+ def realistic_config(cls) -> Iterable[ConvConfig]:
+ # These are the partial convs used in https://github.com/jingyuanli001/RFR-Inpainting
+ # All have bias=False because in practice they're always followed by a BatchNorm2d anyway
+ return (
+ ConvConfig(3, 64, 7, stride=2, padding=3, bias=False),
+ ConvConfig(64, 64, 7, stride=1, padding=3, bias=False),
+ ConvConfig(64, 64, 7, stride=1, padding=3, bias=False),
+ ConvConfig(64, 64, 7, stride=1, padding=3, bias=False),
+ ConvConfig(64, 32, 3, stride=1, padding=1, bias=False),
+ )
+
+ @classmethod
+ def mkinput(cls, b, c, h, d=None) -> Tuple[Tensor, Tensor]:
+ if d is None:
+ d = cls.device
+ image = torch.rand(b, c, h, h).float().to(d)
+ mask = (torch.rand(b, h, h) > 0.5).float().to(d)
+ return image, mask
+
+ @staticmethod
+ def compare(values: Dict[Type[PConvLike], Tensor],
+ comparator: Callable[[Tensor, Tensor], bool]):
+ for (clazz1, out1), (clazz2, out2) in itertools.combinations(values.items(), 2):
+ eq = comparator(out1, out2)
+ if not eq:
+ pshape(out1, out2, heading=True)
+ assert eq, f"{clazz1.__name__ if hasattr(clazz1, '__name__') else 'class1'}'s doesn't match {clazz2.__name__ if hasattr(clazz2, '__name__') else 'class2'}'s output"
+
+ @classmethod
+ def run_pconvs(cls, pconvs: List[Type[PConvLike]], config: ConvConfig) -> Callable[
+ [Union[Dict[Type[PConvLike], Tensor], Tensor],
+ Union[Dict[Type[PConvLike], Tensor], Tensor]], Tuple[
+ Dict[Type[PConvLike], Tensor], Dict[Type[PConvLike], Tensor]]]:
+ """Returns a closure that :
+ Initialise each PConvLike class with the provided config,
+ set their weights and biases to be equal, and run each of them onto the
+ input(s) images/masks. Then saves the output in a dict that match the class to
+ the output. Returns that dict.
+ The closure can be called with either a specific input per class, or one input
+ which will be shared among every class.
+
+ This method's signature is admittedly a bit unwieldy...
+
+ :param pconvs: the list of PConvLike classes to run
+ :param config: the ConvConfig to use for those classes
+ :return: The returned closure takes either two tensors, or two dict of tensors
+ where keys are the corresponding PConv classes which to call it on
+ """
+
+ def inner(imgs: Union[Dict[Type[PConvLike], Tensor], Tensor],
+ masks: Union[Dict[Type[PConvLike], Tensor], Tensor]) -> \
+ Tuple[
+ Dict[Type[PConvLike], Tensor], Dict[Type[PConvLike], Tensor]]:
+ if not isinstance(imgs, dict):
+ imgs = {clazz: imgs for clazz in pconvs}
+ if not isinstance(masks, dict):
+ masks = {clazz: masks for clazz in pconvs}
+ outputs_imgs = dict()
+ outputs_masks = dict()
+ w = None
+ b = None
+ for clazz in pconvs:
+ # noinspection PyArgumentList
+ pconv = clazz(**config.dict).to(cls.device)
+ if config.bias:
+ if b is None:
+ b = pconv.get_bias()
+ else:
+ pconv.set_bias(b.clone())
+
+ if w is None:
+ w = pconv.get_weight()
+ else:
+ pconv.set_weight(w.clone())
+
+ out_img, out_mask = pconv(imgs[clazz].clone(), masks[clazz].clone())
+ outputs_imgs[clazz] = out_img
+ outputs_masks[clazz] = out_mask
+ return outputs_imgs, outputs_masks
+
+ return inner
+
+ @classmethod
+ def channelwise_allclose(cls, x):
+ close = True
+ for channel1, channel2 in itertools.combinations(x.transpose(0, 1), 2):
+ close &= cls.allclose(channel1, channel2)
+ return close
+
+ @classmethod
+ def channelwise_almost_eq(cls, x):
+ close = True
+ for channel1, channel2 in itertools.combinations(x.transpose(0, 1), 2):
+ close &= cls.almost_eq(channel1, channel2)
+ return close
+
+ @staticmethod
+ def almost_eq(x, y):
+ return torch.allclose(x, y, rtol=0, atol=2e-3)
+
+ @staticmethod
+ def allclose(x, y):
+ return torch.allclose(x, y, rtol=1e-5, atol=1e-8)
+
+ @staticmethod
+ def format_bytes(size):
+ # 2**10 = 1024
+ power = 2 ** 10
+ n = 0
+ power_labels = {0: '', 1: 'K', 2: 'M', 3: 'G', 4: 'T'}
+ while abs(size) > power:
+ size /= power
+ n += 1
+ suffix = power_labels[n] + 'iB'
+ return f"{size:.2f} {suffix}"
+
+ if __name__ == "__main__":
+ unittest.main()
diff --git a/torch_pconv/__init__.py b/torch_pconv/__init__.py
new file mode 100644
index 0000000..2c4b9ba
--- /dev/null
+++ b/torch_pconv/__init__.py
@@ -0,0 +1,5 @@
+from torch_pconv.pconv import PConv2d
+
+__all__ = [
+ "PConv2d",
+]
diff --git a/torch_pconv/pconv.py b/torch_pconv/pconv.py
new file mode 100644
index 0000000..2e35f68
--- /dev/null
+++ b/torch_pconv/pconv.py
@@ -0,0 +1,219 @@
+###############################################################################
+# BSD 3-Clause License
+#
+# Copyright (c) 2021, DesignStripe. All rights reserved.
+#
+# Author & Contact: Samuel Prevost (samuel@designstripe.com)
+###############################################################################
+
+from tensor_type import Tensor4d, Tensor3d, Tensor
+import math
+from typing import Tuple, Any
+import torch
+from torch import nn
+
+
+class PConv2d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 1,
+ stride: int = 1,
+ padding: int = 0,
+ dilation: int = 1,
+ bias: bool = False,
+ legacy_behaviour: bool = False,
+ ):
+ """Partial Convolution on 2D input.
+
+ :param in_channels: see torch.nn.Conv2d
+ :param out_channels: see torch.nn.Conv2d
+ :param kernel_size: see torch.nn.Conv2d
+ :param stride: see torch.nn.Conv2d
+ :param padding: see torch.nn.Conv2d
+ :param dilation: see torch.nn.Conv2d
+ :param bias: see torch.nn.Conv2d
+ :param legacy_behaviour: Tries to replicate Guilin's implementation's numerical error when handling the bias,
+ but in doing so, it does extraneous operations that could be avoided and still result in *almost* the same
+ result, at a tolerance of 0.00000458 % on the cuDNN 11.4 backend. Can safely be False for real life
+ applications.
+ """
+ super().__init__()
+
+ # Set this to True, and the output is guaranteed to be exactly the same as PConvGuilin and PConvRFR
+ # Set this to False, and the output will be very very close, but with some numerical errors removed/added,
+ # even though formally the maths are equivalent.
+ self.legacy_behaviour = legacy_behaviour
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = self._to_tuple(kernel_size)
+ self.stride = self._to_tuple(stride)
+ self.padding = self._to_tuple(padding)
+ self.dilation = self._to_tuple(dilation)
+ self.use_bias = bias
+
+ conv_kwargs = dict(
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ padding=self.padding,
+ dilation=self.dilation,
+ groups=1,
+ bias=False,
+ )
+
+ # Don't use a bias here, we handle the bias manually to speed up computation
+ self.regular_conv = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, **conv_kwargs)
+
+ # I found a way to avoid doing a in_channels --> out_channels conv and instead just do a
+ # 1 channel in --> 1 channel out conv and then just scale the output of the conv by the number
+ # of input channels, and repeat the resulting tensor to have "out channels"
+ # This saves 1) a lot of memory because no need to pad before the conv
+ # 2) a lot of computation because the convolution is way smaller (in_c * out_c times less operations)
+ # It's also possible to avoid repeating the tensor to have "out channels", and instead use broadcasting
+ # when doing operations. This further reduces the number of operations to do and is equivalent,
+ # and especially the amount of memory used.
+ self.mask_conv = nn.Conv2d(in_channels=1, out_channels=1, **conv_kwargs)
+
+ # Inits
+ self.regular_conv.apply(
+ lambda m: nn.init.kaiming_normal_(m.weight, a=0, mode="fan_in")
+ )
+
+ # the mask convolution should be a constant operation
+ torch.nn.init.constant_(self.mask_conv.weight, 1.0)
+ for param in self.mask_conv.parameters():
+ param.requires_grad = False
+
+ if self.use_bias:
+ self.bias = nn.Parameter(torch.empty(1, self.out_channels, 1, 1))
+ else:
+ self.register_parameter("bias", None)
+
+ with torch.no_grad():
+ # This is how nn._ConvNd initialises its weights
+ nn.init.kaiming_uniform_(self.regular_conv.weight, a=math.sqrt(5))
+
+ if self.bias is not None:
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
+ self.regular_conv.weight
+ )
+ bound = 1 / math.sqrt(fan_in)
+ nn.init.uniform_(self.bias.view(self.out_channels), -bound, bound)
+
+ def forward(self, x: Tensor4d, mask: Tensor3d) -> Tuple[Tensor4d, Tensor3d]:
+ """Performs the 2D partial convolution.
+
+ About the mask:
+ - its dtype should be torch.float32
+ - its values should be EITHER 0.0 OR 1.0, not in between
+ - it should not have a channel dimensions. Just (batch, height, width).
+ The returned mask is guaranteed to also match these criteria.
+
+ This returns a tuple containing:
+ - the result of the partial convolution on the input x.
+ - the "updated mask", which is slightly "closed off". It is a "binary" mask of dtype float,
+ containing values of either 0.0 or 1.0 (nothing in between).
+
+ :param x: The input image batch, a 4d tensor of traditional batch, channel, height, width.
+ :param mask: This takes as input a 3d binary (0.0 OR 1.0) mask of dtype=float
+
+ :return: a tuple (output, updated_mask)
+ """
+ Tensor4d.check(x)
+ batch, channels, h, w = x.shape
+ Tensor[batch, h, w].check(mask)
+
+ if mask.dtype != torch.float32:
+ raise TypeError(
+ "mask should have dtype=torch.float32 with values being either 0.0 or 1.0"
+ )
+
+ if x.dtype != torch.float32:
+ raise TypeError("x should have dtype=torch.float32")
+
+ # Create singleton channel dimension for broadcasting
+ mask = mask.unsqueeze(1)
+
+ output = self.regular_conv(x * mask)
+ _, _, conv_h, conv_w = output.shape
+
+ update_mask: Tensor[batch, 1, conv_h, conv_w]
+ mask_ratio: Tensor[batch, 1, conv_h, conv_w]
+ with torch.no_grad():
+ mask_ratio, update_mask = self.compute_masks(mask)
+
+ if self.use_bias:
+ if self.legacy_behaviour:
+ # Doing this is entirely pointless. However, the legacy Guilin's implementation does it and
+ # if I don't do it, I get a relative numerical error of about 0.00000458 %
+ output += self.bias
+ output -= self.bias
+
+ output *= mask_ratio # Multiply by the sum(1)/sum(mask) ratios
+ output += self.bias # Add the bias *after* mask_ratio, not before !
+ output *= update_mask # Nullify pixels outside the valid mask
+ else:
+ output *= mask_ratio
+
+ return output, update_mask[:, 0]
+
+ def compute_masks(self, mask: Tensor3d) -> Tuple[Tensor4d, Tensor4d]:
+ """
+ This computes two masks:
+ - the update_mask is a binary mask that has 1 if the pixel was used in the convolution, and 0 otherwise
+ - the mask_ratio which has value sum(1)/sum(mask) if the pixel was used in the convolution, and 0 otherwise
+
+ * sum(1) means the sum of a kernel full of ones of equivalent size as the self.regular_conv's kernel.
+ It is usually calculated as self.in_channels * self.kernel_size ** 2, assuming a square kernel.
+ * sum(mask) means the sum of ones and zeros of the mask in a particular region.
+ If the region is entirely valid, then sum(mask) = sum(1) but if the region is only partially within the mask,
+ then 0 < sum(mask) < sum(1).
+ sum(mask) is calculated specifically in the vicinity of the pixel, and is pixel dependant.
+
+ * mask_ratio is Tensor4d with the channel dimension as a singleton, and is NOT binary.
+ It has values between 0 and sum(1) (included).
+ * update_mask is a Tensor4d with the channel dimension as a singleton, and is "binary" (either 0.0 or 1.0).
+
+ :param mask: the input "binary" mask. It has to be a dtype=float32, but containing only values 0.0 or 1.0.
+ :return: mask_ratio, update_mask
+ """
+ update_mask = self.mask_conv(mask) * self.in_channels
+ # Make values where update_mask==0 be super high
+ # and otherwise computes the sum(ones)/sum(mask) value for other entries
+ # noinspection PyTypeChecker
+ mask_ratio = self.in_channels * self.kernel_size[0] * self.kernel_size[1] / (update_mask + 1e-8)
+ # Once we've normalised the values in update_mask and saved them elsewhere, we can now ignore their value
+ # and return update_mask to a binary mask
+ update_mask = torch.clamp(update_mask, 0, 1)
+ # Then multiplies those super high values by zero so we cancel them out
+ mask_ratio *= update_mask
+ # We can discard the extra channel dimension what was just there to help with broadcasting
+
+ return mask_ratio, update_mask
+
+ @staticmethod
+ def _to_tuple(v: Any) -> Tuple[Any, Any]:
+ if not isinstance(v, tuple):
+ return v, v
+ else:
+ return v
+
+ def set_weight(self, w):
+ with torch.no_grad():
+ self.regular_conv.weight.copy_(w)
+
+ return self
+
+ def set_bias(self, b):
+ with torch.no_grad():
+ self.bias.copy_(b.view(1, self.out_channels, 1, 1))
+
+ return self
+
+ def get_weight(self):
+ return self.regular_conv.weight
+
+ def get_bias(self):
+ return self.bias