Skip to content

Commit

Permalink
test+feat: minimalist cuda, pybind11 support
Browse files Browse the repository at this point in the history
  • Loading branch information
bionicles committed Jun 4, 2024
1 parent d292249 commit eedb838
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 7 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ test_parallel:
time (py.test --durations=0 -n $(N_WORKERS) --cov=tree_plus_src --cov-report=term-missing --cov-report=lcov:coverage/lcov.info -vv tests/test_*.py)

# sequential unit tests (for CI)
test_sequential:p
test_sequential:
pytest tests/test_more_language_units.py tests/test_units.py tests/test_engine.py -vv

# just to crank on language features, easy to debug on this
test_more_languages:
pytest tests/test_more_language_units.py -vv

test: test_parallel
test: test_sequential
# test: test_sequential test_tp_dotdot test_e2e test_cli test_programs test_deploy

# first we'll do our unit tests (most likely to need fast debug)
Expand Down
10 changes: 10 additions & 0 deletions tests/more_languages/group6/cpp_examples_impl.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#include "pybind11/pybind11.h"

#include "cpp_examples_impl.h"

PYBIND11_MODULE(cpp_examples, m)
{
m.doc() = "pybind11 cpp_examples plugin"; // module docstring

m.def("add", &add<int>, "An example function to add two numbers.");
}
10 changes: 10 additions & 0 deletions tests/more_languages/group6/cpp_examples_impl.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#include "cpp_examples_impl.h"

template <typename T>
T add(T a, T b) { return a + b; }

template <>
int add<int>(int a, int b)
{
return a + b;
}
7 changes: 7 additions & 0 deletions tests/more_languages/group6/cpp_examples_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#pragma once

template <typename T>
T add(T a, T b);

template <>
int add<int>(int, int);
30 changes: 28 additions & 2 deletions tests/test_more_language_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
"void globalGreet()",
"int main()",
"void printMessage(const std::string &message)",
"template<typename T>",
"void printVector(const std::vector<T>& vec)",
"""template<typename T>
void printVector(const std::vector<T>& vec)""",
"struct Point",
" Point(int x, int y) : x(x), y(y)",
"class Animal",
Expand Down Expand Up @@ -1853,6 +1853,31 @@ def test_more_languages_group_lisp(
"tests/more_languages/group6/python_complex_class.py",
["class Box(Space[NDArray[Any]])"],
),
(
"tests/more_languages/group6/cpp_examples_impl.cu",
[
"""template <typename T>
T add(T a, T b)""",
"""template <>
int add<int>(int a, int b)""",
],
),
(
"tests/more_languages/group6/cpp_examples_impl.h",
[
"""template <typename T>
T add(T a, T b)""",
"""template <>
int add<int>(int, int)""",
],
),
(
"tests/more_languages/group6/cpp_examples_impl.cc",
[
"PYBIND11_MODULE(cpp_examples, m)",
' m.def("add", &add<int>, "An example function to add two numbers.")',
],
),
],
)
def test_more_languages_group_6(
Expand Down Expand Up @@ -1976,6 +2001,7 @@ def test_more_languages_isabelle_symbol_replacement():
"void EnableXlaCompilation()",
"bool FailOnXlaCompilation()",
"#define TF_PY_DECLARE_FLAG(flag_name)",
"PYBIND11_MODULE(flags_pybind, m)",
]


Expand Down
24 changes: 21 additions & 3 deletions tree_plus_src/parse_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
TEXTCHARS = bytearray({7, 8, 9, 10, 12, 13, 27} | set(range(0x20, 0x100)) - {0x7F})
LISP_EXTENSIONS = {".lisp", ".clj", ".scm", ".el", ".rkt"}
JS_EXTENSIONS = {".js", ".jsx", ".ts", ".tsx"}
C_EXTENSIONS = {".c", ".cpp", ".cc", ".h"}
C_EXTENSIONS = {".c", ".cpp", ".cc", ".h", ".cu"}
COBOL_EXTENSIONS = {".cbl", ".cobol"}
FORTRAN_EXTENSIONS = {
".f",
Expand Down Expand Up @@ -352,7 +352,8 @@ def parse_c(contents: str) -> List[str]:
# Functions first (most common)
r"^(?P<function> *(?P<modifier>[\w:]+ )?(?P<function_return_type>[\w:*&]+(?P<generics>\s?<[^>]*>\s?)? )(?P<function_name>[\w*&[\]]+)\([^\)]*\)(?=\s{))|"
# templates
r"^(?P<template>template ?<.*?>[^\{^;^=\n]*(?=\s))|"
# r"^(?P<template>template ?<.*?>[^\{^;^=\n]*(?=\s))|"
r"^(?P<template>(?P<template_body>template ?<.*?>[\s\S]*?)(?P<tail>;|{|\)))|"
# hashtag macros
r"^(?P<macro>#(?:define)(?P<invocation>\s\w+( ?\w* ?\(.*\))?)?)|"
# Methods
Expand All @@ -368,6 +369,10 @@ def parse_c(contents: str) -> List[str]:
r"^(?P<enum>(?:enum) (?:class )?\w+(?: : \w+)?)|"
# public or private sections
r"^(?P<public_or_private> *(public|private):)|"
# pybind modules
r"^(?P<pybind11>PYBIND11_MODULE[\s\S]*?)(?={)|"
# pybind defs
r"^(?P<def> *\w+\.def\(\"[\s\S]*?(?=;))|"
# static definitions seem important
r"^(?P<other_static>static (struct )?(?P<static_kind>\w+) \w+(\[\])?(?= =))",
# functions
Expand Down Expand Up @@ -413,16 +418,29 @@ def parse_c(contents: str) -> List[str]:
component = groups["enum"]
public_or_private = None # right?
elif "template" in groups:
component = groups["template"].rstrip("\n")
component = (
groups["template_body"]
.rstrip("\n")
.rstrip(" ")
.rstrip("\n")
.rstrip("\n")
)
if groups["tail"].startswith(")"):
component += ")"
public_or_private = None # right?
elif "macro" in groups:
component = groups["macro"]
public_or_private = None # right?
elif "other_static" in groups:
component = groups["other_static"]
elif "pybind11" in groups:
component = groups["pybind11"]
elif "def" in groups:
component = groups["def"]
if component:
debug_print(f"{component=}")
component = component
component = component.rstrip("\n").rstrip(" ")
components.append(component)

return components
Expand Down

0 comments on commit eedb838

Please sign in to comment.