Skip to content

Commit

Permalink
[WIP] Pandas.jl port
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerjthomas9 committed Nov 25, 2023
1 parent 9ee5bf5 commit db03b13
Show file tree
Hide file tree
Showing 19 changed files with 1,243 additions and 174 deletions.
2 changes: 1 addition & 1 deletion CondaPkg.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ cucim = "=23.10"
cuspatial = "=23.10"
cugraph = "=23.10"
cuml = "=23.10"
python = ">=3.9,<=3.10"
python = ">=3.9,<3.11"

[deps.cuda-version]
channel = "conda-forge"
Expand Down
18 changes: 16 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,29 @@ version = "0.5.0"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DataValues = "e7dc6d0d-1eca-5fa6-8ad6-5aecde8b7ea5"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
Lazy = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TableTraits = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
TableTraitsUtils = "382cd787-c1b6-5bf2-a167-d5b971a19bda"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
Aqua = "0.7"
Aqua = "0.8"
CUDA = "3, 4, 5"
CondaPkg = "0.2"
DataFrames = "1.6"
MLJBase = "1"
MLJModelInterface = "1"
PythonCall = "0.9"
Expand All @@ -23,9 +36,10 @@ julia = "1.8"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Aqua", "MLJBase", "MLJTestInterface"]
test = ["Test", "Aqua", "DataFrames", "MLJBase", "MLJTestInterface"]
83 changes: 9 additions & 74 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
[![Lifecycle:Maturing](https://img.shields.io/badge/Lifecycle-Maturing-007EC6)](https://github.com/bcgov/repomountie/blob/master/doc/lifecycle-badges.md)
[![Code Style: YASGuide](https://img.shields.io/badge/code%20style-yas-violet.svg)](https://github.com/jrevels/YASGuide)


:warning: RAPIDS.jl is only supported on Julia 1.8.5+. For previous Julia versions, you have to manually upgrade to libraries from GCC 12.

# RAPIDS.jl
Unofficial Julia wrapper for the [RAPIDS.ai](https://rapids.ai/index.html) ecosystem.

Expand Down Expand Up @@ -42,9 +39,14 @@ julia> ]add https://github.com/tylerjthomas9/RAPIDS.jl
julia> using Pkg; Pkg.add(url="https://github.com/tylerjthomas9/RAPIDS.jl")
```

## Julia Interfaces

- `CuDF`
- `CuML`

## Python API

You can access the following python libraries with their standard syntax:
You can access the following python libraries with their standard Python syntax:
- `cupy`
- `cudf`
- `cuml`
Expand All @@ -55,76 +57,9 @@ You can access the following python libraries with their standard syntax:
- `dask_cuda`
- `dask_cudf`
- `numpy`
- `pandas` (cudf pandas)
- `pickle`

Here is an example of using `LogisticRegression`, `make_classification` via the Python API.

```julia
using RAPIDS
const make_classification = cuml.datasets.classification.make_classification

X_py, y_py = make_classification(n_samples=200, n_features=4,
n_informative=2, n_classes=2)
lr = cuml.LogisticRegression(max_iter=100)
lr.fit(X_py, y_py)
preds = lr.predict(X_py)

print(lr.coef_)
```

## MLJ Interface

A MLJ interface is also available for supported models. The model hyperparameters are the same as described in the [cuML docs](https://docs.rapids.ai/api/cuml/stable/api.html). The only difference is that the models will always input/output numpy arrays, which will be converted back to Julia arrays (`output_type="input"`).

```julia
using MLJBase
using RAPIDS.CuML
const make_classification = cuml.datasets.classification.make_classification

X_py, y_py = make_classification(n_samples=200, n_features=4,
n_informative=2, n_classes=2)
X = RAPIDS.pyconvert(Matrix{Float32}, X_py.get())
y = RAPIDS.pyconvert(Vector{Float32}, y_py.get().flatten())

lr = LogisticRegression(max_iter=100)
mach = machine(lr, X, y)
fit!(mach)
preds = predict(mach, X)

print(mach.fitresult.coef_)
```

MLJ Support:
- Clustering
- `KMeans`
- `DBSCAN`
- `AgglomerativeClustering`
- `HDBSCAN`
- Classification
- `LogisticRegression`
- `MBSGDClassifier`
- `RandomForestClassifier`
- `SVC`
- `LinearSVC`
- `KNeighborsClassifier`
- Regression
- `LinearRegression`
- `Ridge`
- `Lasso`
- `ElasticNet`
- `MBSGDRegressor`
- `RandomForestRegressor`
- `CD`
- `SVR`
- `LinearSVR`
- `KNeighborsRegressor`
- Dimensionality Reduction
- `PCA`
- `IncrementalPCA`
- `TruncatedSVD`
- `UMAP`
- `TSNE`
- `GaussianRandomProjection`
- Time Series
- `ExponentialSmoothing`
- `ARIMA`
## Known Issues
- RAPIDS.jl is only supported on Julia 1.8.5+. For previous Julia versions, you have to manually upgrade to libraries from GCC 12.
104 changes: 46 additions & 58 deletions format/Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.8.5"
julia_version = "1.9.4"
manifest_format = "2.0"
project_hash = "30b405be1c677184b7703a9bfb3d2100029ccad0"

Expand All @@ -21,21 +21,23 @@ uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
version = "3.3.6"

[[deps.CommonMark]]
deps = ["Crayons", "JSON", "SnoopPrecompile", "URIs"]
git-tree-sha1 = "e2f4627b0d3f2c1876360e0b242a7c23923b469d"
deps = ["Crayons", "JSON", "PrecompileTools", "URIs"]
git-tree-sha1 = "532c4185d3c9037c0237546d817858b23cf9e071"
uuid = "a80b9123-70ca-4bc0-993e-6e3bcb318db6"
version = "0.8.10"
version = "0.8.12"

[[deps.Compat]]
deps = ["Dates", "LinearAlgebra", "UUIDs"]
git-tree-sha1 = "7a60c856b9fa189eb34f5f8a6f6b5529b7942957"
deps = ["UUIDs"]
git-tree-sha1 = "8a62af3e248a8c4bad6b32cbbe663ae02275e32c"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "4.6.1"
version = "4.10.0"

[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.0.1+0"
[deps.Compat.extensions]
CompatLinearAlgebraExt = "LinearAlgebra"

[deps.Compat.weakdeps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[deps.Crayons]]
git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15"
Expand All @@ -44,9 +46,9 @@ version = "4.1.1"

[[deps.DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "d1fff3a548102f48987a52a2e0d114fa97d730f0"
git-tree-sha1 = "3dbd312d370723b6bb43ba9d02fc36abade4518d"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.18.13"
version = "0.18.15"

[[deps.Dates]]
deps = ["Printf"]
Expand Down Expand Up @@ -76,20 +78,20 @@ uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.21.4"

[[deps.JuliaFormatter]]
deps = ["CSTParser", "CommonMark", "DataStructures", "Glob", "Pkg", "SnoopPrecompile", "Tokenize"]
git-tree-sha1 = "0f6545dd63fec03d0cfe0c1d28f851e2d804e942"
deps = ["CSTParser", "CommonMark", "DataStructures", "Glob", "Pkg", "PrecompileTools", "Tokenize"]
git-tree-sha1 = "3d5b5b539e4606dcca0e6a467b98a64c8da4850b"
uuid = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
version = "1.0.25"
version = "1.0.42"

[[deps.LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
version = "0.6.3"
version = "0.6.4"

[[deps.LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
version = "7.84.0+0"
version = "8.4.0+0"

[[deps.LibGit2]]
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
Expand All @@ -98,15 +100,11 @@ uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
[[deps.LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
version = "1.10.2+0"
version = "1.11.0+1"

[[deps.Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

[[deps.LinearAlgebra]]
deps = ["Libdl", "libblastrampoline_jll"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

Expand All @@ -117,45 +115,46 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
[[deps.MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
version = "2.28.0+0"
version = "2.28.2+0"

[[deps.Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[deps.MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
version = "2022.2.1"
version = "2022.10.11"

[[deps.NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
version = "1.2.0"

[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
version = "0.3.20+0"

[[deps.OrderedCollections]]
git-tree-sha1 = "d321bf2de576bf25ec4d3e4360faca399afca282"
git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.6.0"
version = "1.6.2"

[[deps.Parsers]]
deps = ["Dates", "SnoopPrecompile"]
git-tree-sha1 = "478ac6c952fddd4399e71d4779797c538d0ff2bf"
deps = ["Dates", "PrecompileTools", "UUIDs"]
git-tree-sha1 = "a935806434c9d4c506ba941871b327b96d41f2bf"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "2.5.8"
version = "2.8.0"

[[deps.Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
version = "1.8.0"
version = "1.9.2"

[[deps.PrecompileTools]]
deps = ["Preferences"]
git-tree-sha1 = "03b4c25b43cb84cee5c90aa9b5ea0a78fd848d2f"
uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
version = "1.2.0"

[[deps.Preferences]]
deps = ["TOML"]
git-tree-sha1 = "47e5f437cc0e7ef2ce8406ce1e7e24d44915f88d"
git-tree-sha1 = "00805cd429dcb4870060ff49ef443486c262e38e"
uuid = "21216c6a-2e73-6563-6e65-726566657250"
version = "1.3.0"
version = "1.4.1"

[[deps.Printf]]
deps = ["Unicode"]
Expand All @@ -176,34 +175,28 @@ version = "0.7.0"
[[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[[deps.SnoopPrecompile]]
deps = ["Preferences"]
git-tree-sha1 = "e760a70afdcd461cf01a575947738d359234665c"
uuid = "66db9d55-30c0-4569-8b51-7e840670fc0c"
version = "1.0.3"

[[deps.Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

[[deps.TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
version = "1.0.0"
version = "1.0.3"

[[deps.Tar]]
deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
version = "1.10.1"
version = "1.10.0"

[[deps.Tokenize]]
git-tree-sha1 = "90538bf898832b6ebd900fa40f223e695970e3a5"
git-tree-sha1 = "0454d9a9bad2400c7ccad19ca832a2ef5a8bc3a1"
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
version = "0.5.25"
version = "0.5.26"

[[deps.URIs]]
git-tree-sha1 = "074f993b0ca030848b897beff716d93aca60f06a"
git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b"
uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4"
version = "1.4.2"
version = "1.5.1"

[[deps.UUIDs]]
deps = ["Random", "SHA"]
Expand All @@ -215,17 +208,12 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[deps.Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
version = "1.2.12+3"

[[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl", "OpenBLAS_jll"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
version = "5.1.1+0"
version = "1.2.13+0"

[[deps.nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
version = "1.48.0+0"
version = "1.52.0+1"

[[deps.p7zip_jll]]
deps = ["Artifacts", "Libdl"]
Expand Down
Loading

0 comments on commit db03b13

Please sign in to comment.