Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable wheels for MacOS #35

Merged
merged 35 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
a6acc01
test setup python
mavenlin Dec 4, 2023
6909e20
introducing multi-platform
mavenlin Dec 4, 2023
4f2aa9b
test build for macos
mavenlin Dec 4, 2023
a855700
trigger build
mavenlin Dec 4, 2023
cad212b
test build on macos
mavenlin Dec 4, 2023
eced1f8
fix
mavenlin Dec 4, 2023
2372769
test macos-11
mavenlin Dec 4, 2023
348003f
remove workflow since it is verified
mavenlin Dec 4, 2023
e9a2726
test setup python on mac
mavenlin Dec 4, 2023
16306dd
build on mac and linux
mavenlin Dec 4, 2023
bf5ec28
rename workflow
mavenlin Dec 4, 2023
07dc4ec
test windows pipeline
mavenlin Dec 4, 2023
3363b9f
validate python version
mavenlin Dec 4, 2023
7e45b35
try out this
mavenlin Dec 4, 2023
075ad92
remove validation
mavenlin Dec 4, 2023
b85d202
debug
mavenlin Dec 4, 2023
1315323
debug
mavenlin Dec 4, 2023
c70648f
debug
mavenlin Dec 4, 2023
09c9371
fix abi error for windows
mavenlin Dec 4, 2023
c94b46d
restore
mavenlin Dec 4, 2023
d802912
f**k windows
mavenlin Dec 4, 2023
ce49956
Update .bazelrc
mavenlin Dec 4, 2023
5e020f1
Use clang
mavenlin Dec 4, 2023
6e5db42
split into many rules for better caching
mavenlin Dec 4, 2023
6185968
pass different flag for windows
mavenlin Dec 4, 2023
0b4de9b
debug 3.9 only
mavenlin Dec 4, 2023
613ecc3
cat link params
mavenlin Dec 4, 2023
85a20d8
verbose
mavenlin Dec 4, 2023
ee33187
export all symbols
mavenlin Dec 4, 2023
4810c8f
directly add features
mavenlin Dec 4, 2023
2e04b0b
try cat again
mavenlin Dec 4, 2023
5ae2d8a
don't run windows build
mavenlin Dec 4, 2023
4624871
nothing changed, but just build for more platforms
mavenlin Dec 4, 2023
d031b81
fix indent
mavenlin Dec 4, 2023
c723727
fix order
mavenlin Dec 4, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions .bazelrc
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
build --copt=-g0 --copt=-O3 --copt=-DNDEBUG
build --action_env=BAZEL_LINKLIBS=-l%:libstdc++.a:-lm
build --action_env=BAZEL_LINKOPTS=-static-libgcc
build:linux --copt=-g0 --copt=-O3 --copt=-DNDEBUG
build:macos_x86_64 --copt=-g0 --copt=-O3 --copt=-DNDEBUG
build:macos_arm64 --copt=-g0 --copt=-O3 --copt=-DNDEBUG
build:windows_x86_64 -c opt --compiler=clang-cl

build:linux --action_env=BAZEL_LINKLIBS=-l%:libstdc++.a:-lm
build:linux --action_env=BAZEL_LINKOPTS=-static-libgcc
build:linux --define os=linux --define cpu=x86_64

build:macos_x86_64 --define os=macos --define cpu=x86_64
build:macos_arm64 --define os=macos --define cpu=arm64

build:windows_x86_64 --define os=windows --define cpu=x86_64
build:windows_arm64 --define os=windows --define cpu=arm64
19 changes: 10 additions & 9 deletions .github/workflows/release.yml → .github/workflows/linux.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Action name
name: Release Wheel
name: Release for Linux

on:
push:
Expand All @@ -11,24 +11,25 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11']
container:
image: ghcr.io/sail-sg/jax-xc-image:latest
python-version: ['3.9', '3.10', '3.11', '3.12']
steps:
- name: Cancel previous run
uses: styfle/[email protected]
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- name: Setup Python-${{ matrix.python-version }} and Build
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Build
run: |
eval "$(pyenv init -)" && pyenv global ${{ matrix.python-version }}-dev
pip install -r requirements.txt
bazel build --remote_cache=http://${{ secrets.BAZEL_CACHE }}:8080 @maple2jax//:jax_xc_wheel
bazel build --config=linux --remote_cache=http://${{ secrets.BAZEL_CACHE }}:8080 @maple2jax//:jax_xc_wheel
- name: Upload artifact
uses: actions/upload-artifact@main
with:
name: wheel
name: linux_wheel
path: bazel-bin/external/maple2jax/*.whl
publish:
runs-on: ubuntu-latest
Expand All @@ -37,7 +38,7 @@ jobs:
- name: Download artifact
uses: actions/download-artifact@main
with:
name: wheel
name: linux_wheel
path: dist
- name: Publish package to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
Expand Down
45 changes: 45 additions & 0 deletions .github/workflows/macos.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
name: Release for Macos

on:
push:
tags:
- 'v[0-9]+\.[0-9]+\.[0-9]+'

jobs:
release:
runs-on: macos-11
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
steps:
- name: Cancel previous run
uses: styfle/[email protected]
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Build
run: |
pip install -r requirements.txt
bazel build --config=macos_x86_64 --remote_cache=http://${{ secrets.BAZEL_CACHE }}:8080 @maple2jax//:jax_xc_wheel
- name: Upload artifact
uses: actions/upload-artifact@main
with:
name: macos_wheel
path: bazel-bin/external/maple2jax/*.whl
publish:
runs-on: ubuntu-latest
needs: release
steps:
- name: Download artifact
uses: actions/download-artifact@main
with:
name: macos_wheel
path: dist
- name: Publish package to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
password: ${{ secrets.PYPI_API_TOKEN }}
11 changes: 6 additions & 5 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@ on: [push, pull_request]
jobs:
test:
runs-on: ubuntu-latest
container:
image: ghcr.io/sail-sg/jax-xc-image:latest
steps:
- name: Cancel previous run
uses: styfle/[email protected]
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.12
- name: Test
run: |
eval "$(pyenv init -)" && pyenv global 3.11-dev
pip install --upgrade -r requirements.txt
bazel test --test_output=all --remote_cache=http://${{ secrets.BAZEL_CACHE }}:8080 //tests/...
pip install -r requirements.txt
bazel test --config=linux --test_output=all --remote_cache=http://${{ secrets.BAZEL_CACHE }}:8080 //tests/...
31 changes: 31 additions & 0 deletions .github/workflows/windows.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: Release for Windows

# Don't run yet, it fails at linking, very difficult to debug with only a remote runner.
# on: [push]

jobs:
release:
runs-on: windows-2019
strategy:
matrix:
python-version: ['3.9']
steps:
- name: Cancel previous run
uses: styfle/[email protected]
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Build
run: |
pip install -r requirements.txt
bazel build --config=windows_x86_64 --remote_cache=http://${{ secrets.BAZEL_CACHE }}:8080 @maple2jax//:jax_xc_wheel
shell: pwsh
- name: Upload artifact
uses: actions/upload-artifact@main
with:
name: macos_wheel
path: bazel-bin/external/maple2jax/*.whl
13 changes: 13 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,19 @@ workspace(name = "jax_xc")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

http_archive(
name = "bazel_skylib",
sha256 = "cd55a062e763b9349921f0f5db8c3933288dc8ba4f76dd9416aac68acee3cb94",
urls = [
"https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.5.0/bazel-skylib-1.5.0.tar.gz",
"https://github.com/bazelbuild/bazel-skylib/releases/download/1.5.0/bazel-skylib-1.5.0.tar.gz",
],
)

load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace")

bazel_skylib_workspace()

http_archive(
name = "rules_python",
sha256 = "8c8fe44ef0a9afc256d1e75ad5f448bb59b81aba149b8958f02f7b3a98f5d9b4",
Expand Down
2 changes: 1 addition & 1 deletion maple2jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
from .functionals import * # noqa
from . import experimental # noqa

__version__ = "0.0.9"
__version__ = "0.0.10"
50 changes: 35 additions & 15 deletions maple2jax/libxc/build.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ cc_library(
visibility = ["//visibility:public"],
)

cc_library(
name = "register",
hdrs = [
"src_cc/register.h",
],
deps = [
":xc_inc",
"@visit_struct",
"@pybind11",
],
)

{% for src in c_file_basenames %}
genrule(
name = "gen_{{ src }}c",
Expand All @@ -44,16 +56,14 @@ genrule(
tools = [":wrap", ":wrap.cc.jinja"],
)

{% endfor %}

cc_binary(
name = "libxc.so",
copts = ["-std=c++14", "-fexceptions"],
features = [
"-use_header_modules", # Required for pybind11.
"-parse_headers",
cc_library(
name = "{{ src }}c.obj",
srcs = ["src_cc/{{ src }}c"],
features = ["windows_export_all_symbols"],
deps = [
":xc_inc",
":register",
],
linkshared = 1,
includes = [
".",
"src",
Expand All @@ -64,19 +74,29 @@ cc_binary(
"XC_DONT_COMPILE_KXC",
"XC_DONT_COMPILE_LXC",
],
alwayslink = True,
)

{% endfor %}

pybind_extension(
name = "libxc",
copts = select({
"@platforms//os:windows": [],
"//conditions:default": ["-std=c++14"],
}),
features = ["windows_export_all_symbols"],
deps = [
":xc_inc",
"@visit_struct",
"@pybind11",
"@local_config_python//:python_headers",
":register",
{% for basename in c_file_basenames %}
":{{ basename }}c.obj",
{% endfor %}
],
srcs = [
"src_cc/register.h",
"src_cc/register.cc",
"src_cc/libxc.cc",
{% for basename in c_file_basenames %}
"src_cc/{{ basename }}c",
{% endfor %}
],
visibility = ["//visibility:public"],
)
Expand Down
4 changes: 2 additions & 2 deletions maple2jax/libxc/gen_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ def main(_):
else:
c_file_basenames.append(basename)

with open(FLAGS.template, "r") as f:
with open(FLAGS.template, "r", encoding="utf8") as f:
template = Template(f.read(), trim_blocks=True, lstrip_blocks=True)
build = template.render(c_file_basenames=c_file_basenames)
with open(FLAGS.build, "w") as out:
with open(FLAGS.build, "w", encoding="utf8") as out:
out.write(build)


Expand Down
6 changes: 3 additions & 3 deletions maple2jax/libxc/wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


def wrap_file(filename, out):
with open(filename, "r") as f:
with open(filename, "r", encoding="utf8") as f:
content = f.read()
# find all init function and the corresponding param struct name
results = re.findall(
Expand Down Expand Up @@ -74,14 +74,14 @@ def wrap_file(filename, out):
fields.extend(members)
register_struct.append((s, fields, struct_to_init[s]))

with open(FLAGS.template, "r") as f:
with open(FLAGS.template, "r", encoding="utf8") as f:
t = Template(f.read(), trim_blocks=True, lstrip_blocks=True)
content = t.render(
filename=os.path.basename(filename),
register_struct=register_struct,
register_maple=register_maple,
)
with open(FLAGS.out, "wt") as fout:
with open(FLAGS.out, "wt", encoding="utf8") as fout:
fout.write(content)


Expand Down
6 changes: 5 additions & 1 deletion maple2jax/python.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ def _get_abi_tag(rctx, python_bin):
"version = platform.python_version_tuple();" +
"print(f'cp{version[0]}{version[1]}{sys.abiflags}')",
])
return result.stdout.splitlines()[0]
lines = result.stdout.splitlines()
if len(lines) == 0:
return ""
else:
return lines[0]

def _declare_python_abi_impl(rctx):
python_bin = _get_python_bin(rctx)
Expand Down
43 changes: 41 additions & 2 deletions maple2jax/wheel.BUILD
Original file line number Diff line number Diff line change
@@ -1,6 +1,39 @@
load("@bazel_skylib//lib:selects.bzl", "selects")
load("@python_abi//:abi.bzl", "abi_tag", "python_tag")
load("@rules_python//python:packaging.bzl", "py_wheel")

selects.config_setting_group(
name = "macos_arm64",
match_all = [
"@platforms//os:macos",
"@platforms//cpu:arm64",
],
)

selects.config_setting_group(
name = "macos_x86_64",
match_all = [
"@platforms//os:macos",
"@platforms//cpu:x86_64",
],
)

selects.config_setting_group(
name = "linux_x86_64",
match_all = [
"@platforms//os:linux",
"@platforms//cpu:x86_64",
],
)

selects.config_setting_group(
name = "windows_x86_64",
match_all = [
"@platforms//os:windows",
"@platforms//cpu:x86_64",
],
)

py_wheel(
name = "jax_xc_wheel",
abi = abi_tag(),
Expand All @@ -12,7 +45,13 @@ py_wheel(
],
description_file = "@jax_xc//:README.rst",
distribution = "jax_xc",
platform = "manylinux_2_17_x86_64",
platform = select({
":macos_arm64": "macosx_11_0_arm64",
":macos_x86_64": "macosx_11_0_x86_64",
":linux_x86_64": "manylinux_2_17_x86_64",
":windows_x86_64": "win_amd64",
"//conditions:default": "manylinux_2_17_x86_64",
}),
python_requires = ">=3.9",
python_tag = python_tag(),
requires = [
Expand All @@ -21,7 +60,7 @@ py_wheel(
"tensorflow-probability",
"autofd",
],
version = "0.0.9",
version = "0.0.10",
deps = [
"@maple2jax//jax_xc",
"@maple2jax//jax_xc:experimental",
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ tensorflow-probability
jinja2
absl-py
numpy
pyscf
regex
jaxtyping
autofd
Loading