-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
206 additions
and
100 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,8 +3,10 @@ language: julia | |
os: | ||
- linux | ||
- osx | ||
- windows | ||
julia: | ||
- 1.0 | ||
- 1 | ||
- nightly | ||
notifications: | ||
email: false | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,9 @@ uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" | |
authors = ["Tim Holy <[email protected]>"] | ||
version = "0.1.0" | ||
|
||
[deps] | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
|
||
[compat] | ||
julia = "1" | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
# TensorCore | ||
|
||
[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://JuliaMath.github.io/TensorCore.jl/stable) | ||
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://JuliaMath.github.io/TensorCore.jl/dev) | ||
[![Build Status](https://travis-ci.com/JuliaMath/TensorCore.jl.svg?branch=master)](https://travis-ci.com/JuliaMath/TensorCore.jl) | ||
[![Codecov](https://codecov.io/gh/JuliaMath/TensorCore.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaMath/TensorCore.jl) | ||
|
||
This package is intended as a lightweight foundation for tensor operations across the Julia ecosystem. | ||
Currently it exports two core operations, `hadamard` and `tensor`, and corresponding unicode operators `⊙` and `⊗`, respectively. |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,5 @@ | ||
[deps] | ||
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" | ||
|
||
[compat] | ||
Documenter = "0.24" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,6 @@ makedocs(; | |
repo="https://github.com/JuliaMath/TensorCore.jl/blob/{commit}{path}#L{line}", | ||
sitename="TensorCore.jl", | ||
authors="Tim Holy <[email protected]>", | ||
assets=String[], | ||
) | ||
|
||
deploydocs(; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,16 @@ | ||
# TensorCore.jl | ||
|
||
This package is intended as a lightweight foundation for tensor operations across the Julia ecosystem. | ||
Currently it exports two core operations, `hadamard` and `tensor`, and corresponding unicode operators `⊙` and `⊗`, respectively. | ||
|
||
## API | ||
|
||
```@index | ||
``` | ||
|
||
```@autodocs | ||
Modules = [TensorCore] | ||
```@docs | ||
hadamard | ||
hadamard! | ||
tensor | ||
tensor! | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,126 @@ | ||
module TensorCore | ||
|
||
greet() = print("Hello World!") | ||
using LinearAlgebra | ||
|
||
end # module | ||
export ⊙, hadamard, hadamard! | ||
export ⊗, tensor, tensor! | ||
|
||
""" | ||
hadamard(a, b) | ||
a ⊙ b | ||
For arrays `a` and `b`, perform elementwise multiplication. | ||
`a` and `b` must have identical `axes`. | ||
`⊙` can be passed as an operator to higher-order functions. | ||
```jldoctest; setup=:(using TensorCore) | ||
julia> a = [2, 3]; b = [5, 7]; | ||
julia> a ⊙ b | ||
2-element Array{$Int,1}: | ||
10 | ||
21 | ||
julia> a ⊙ [5] | ||
ERROR: DimensionMismatch("Axes of `A` and `B` must match, got (Base.OneTo(2),) and (Base.OneTo(1),)") | ||
[...] | ||
``` | ||
""" | ||
function hadamard(A::AbstractArray, B::AbstractArray) | ||
@noinline throw_dmm(axA, axB) = throw(DimensionMismatch("Axes of `A` and `B` must match, got $axA and $axB")) | ||
|
||
axA, axB = axes(A), axes(B) | ||
axA == axB || throw_dmm(axA, axB) | ||
return map(*, A, B) | ||
end | ||
const ⊙ = hadamard | ||
|
||
""" | ||
hadamard!(dest, A, B) | ||
Similar to `hadamard(A, B)` (which can also be written `A ⊙ B`), but stores its results in | ||
the pre-allocated array `dest`. | ||
""" | ||
function hadamard!(dest::AbstractArray, A::AbstractArray, B::AbstractArray) | ||
@noinline function throw_dmm(axA, axB, axdest) | ||
throw(DimensionMismatch("`axes(dest) = $axdest` must be equal to `axes(A) = $axA` and `axes(B) = $axB`")) | ||
end | ||
|
||
axA, axB, axdest = axes(A), axes(B), axes(dest) | ||
((axdest == axA) & (axdest == axB)) || throw_dmm(axA, axB, axdest) | ||
@simd for I in eachindex(dest, A, B) | ||
@inbounds dest[I] = A[I] * B[I] | ||
end | ||
return dest | ||
end | ||
|
||
""" | ||
tensor(A, B) | ||
A ⊗ B | ||
Compute the tensor product of `A` and `B`. | ||
If `C = A ⊗ B`, then `C[i1, ..., im, j1, ..., jn] = A[i1, ... im] * B[j1, ..., jn]`. | ||
```jldoctest; setup=:(using TensorCore) | ||
julia> a = [2, 3]; b = [5, 7, 11]; | ||
julia> a ⊗ b | ||
2×3 Array{$Int,2}: | ||
10 14 22 | ||
15 21 33 | ||
``` | ||
For vectors `v` and `w`, the Kronecker product is related to the tensor product by | ||
`kron(v,w) == vec(w ⊗ v)` or `w ⊗ v == reshape(kron(v,w), (length(w), length(v)))`. | ||
""" | ||
tensor(A::AbstractArray, B::AbstractArray) = [a*b for a in A, b in B] | ||
const ⊗ = tensor | ||
|
||
const CovectorLike{T} = Union{Adjoint{T,<:AbstractVector},Transpose{T,<:AbstractVector}} | ||
function tensor(u::AbstractArray, v::CovectorLike) | ||
# If `v` is thought of as a covector, you might want this to be two-dimensional, | ||
# but thought of as a matrix it should be three-dimensional. | ||
# The safest is to avoid supporting it at all. See discussion in #35150. | ||
error("`tensor` is not defined for co-vectors, perhaps you meant `*`?") | ||
end | ||
function tensor(u::CovectorLike, v::AbstractArray) | ||
error("`tensor` is not defined for co-vectors, perhaps you meant `*`?") | ||
end | ||
function tensor(u::CovectorLike, v::CovectorLike) | ||
error("`tensor` is not defined for co-vectors, perhaps you meant `*`?") | ||
end | ||
|
||
""" | ||
tensor!(dest, A, B) | ||
Similar to `tensor(A, B)` (which can also be written `A ⊗ B`), but stores its results in | ||
the pre-allocated array `dest`. | ||
""" | ||
function tensor!(dest::AbstractArray, A::AbstractArray, B::AbstractArray) | ||
@noinline function throw_dmm(axA, axB, axdest) | ||
throw(DimensionMismatch("`axes(dest) = $axdest` must concatenate `axes(A) = $axA` and `axes(B) = $axB`")) | ||
end | ||
|
||
axA, axB, axdest = axes(A), axes(B), axes(dest) | ||
axes(dest) == (axA..., axB...) || throw_dmm(axA, axB, axdest) | ||
if IndexStyle(dest) === IndexCartesian() | ||
for IB in CartesianIndices(axB) | ||
@inbounds b = B[IB] | ||
@simd for IA in CartesianIndices(axA) | ||
@inbounds dest[IA,IB] = A[IA]*b | ||
end | ||
end | ||
else | ||
i = firstindex(dest) | ||
@inbounds for b in B | ||
@simd for a in A | ||
dest[i] = a*b | ||
i += 1 | ||
end | ||
end | ||
end | ||
return dest | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,67 @@ | ||
using TensorCore | ||
using LinearAlgebra | ||
using Test | ||
|
||
@testset "Ambiguities" begin | ||
@test isempty(detect_ambiguities(TensorCore, Base, Core, LinearAlgebra)) | ||
end | ||
|
||
@testset "TensorCore.jl" begin | ||
# Write your own tests here. | ||
for T in (Int, Float32, Float64, BigFloat) | ||
a = [T[1, 2], T[-3, 7]] | ||
b = [T[5, 11], T[-13, 17]] | ||
@test map(⋅, a, b) == map(dot, a, b) == [27, 158] | ||
@test map(⊙, a, b) == map(hadamard, a, b) == [a[1].*b[1], a[2].*b[2]] | ||
@test map(⊗, a, b) == map(tensor, a, b) == [a[1]*transpose(b[1]), a[2]*transpose(b[2])] | ||
@test hadamard!(fill(typemax(Int), 2), T[1, 2], T[-3, 7]) == [-3, 14] | ||
@test tensor!(fill(typemax(Int), 2, 2), T[1, 2], T[-3, 7]) == [-3 7; -6 14] | ||
end | ||
|
||
@test_throws DimensionMismatch [1,2] ⊙ [3] | ||
@test_throws DimensionMismatch hadamard!([0, 0, 0], [1,2], [-3,7]) | ||
@test_throws DimensionMismatch hadamard!([0, 0], [1,2], [-3]) | ||
@test_throws DimensionMismatch hadamard!([0, 0], [1], [-3,7]) | ||
@test_throws DimensionMismatch tensor!(Matrix{Int}(undef, 2, 2), [1], [-3,7]) | ||
@test_throws DimensionMismatch tensor!(Matrix{Int}(undef, 2, 2), [1,2], [-3]) | ||
|
||
u, v = [2+2im, 3+5im], [1-3im, 7+3im] | ||
@test u ⋅ v == conj(u[1])*v[1] + conj(u[2])*v[2] | ||
@test u ⊙ v == [u[1]*v[1], u[2]*v[2]] | ||
@test u ⊗ v == [u[1]*v[1] u[1]*v[2]; u[2]*v[1] u[2]*v[2]] | ||
@test hadamard(u, v) == u ⊙ v | ||
@test tensor(u, v) == u ⊗ v | ||
dest = similar(u) | ||
@test hadamard!(dest, u, v) == u ⊙ v | ||
dest = Matrix{Complex{Int}}(undef, 2, 2) | ||
@test tensor!(dest, u, v) == u ⊗ v | ||
|
||
for (A, B, b) in (([1 2; 3 4], [5 6; 7 8], [5,6]), | ||
([1+0.8im 2+0.7im; 3+0.6im 4+0.5im], | ||
[5+0.4im 6+0.3im; 7+0.2im 8+0.1im], | ||
[5+0.6im,6+0.3im])) | ||
@test A ⊗ b == cat(A*b[1], A*b[2]; dims=3) | ||
@test A ⊗ B == cat(cat(A*B[1,1], A*B[2,1]; dims=3), | ||
cat(A*B[1,2], A*B[2,2]; dims=3); dims=4) | ||
end | ||
|
||
A, B = reshape(1:27, 3, 3, 3), reshape(1:4, 2, 2) | ||
@test A ⊗ B == [a*b for a in A, b in B] | ||
|
||
# Adjoint/transpose is a dual vector, not an AbstractMatrix | ||
v = [1,2] | ||
@test_throws ErrorException v ⊗ v' | ||
@test_throws ErrorException v ⊗ transpose(v) | ||
@test_throws ErrorException v' ⊗ v | ||
@test_throws ErrorException transpose(v) ⊗ v | ||
@test_throws ErrorException v' ⊗ v' | ||
@test_throws ErrorException transpose(v) ⊗ transpose(v) | ||
@test_throws ErrorException v' ⊗ transpose(v) | ||
@test_throws ErrorException transpose(v) ⊗ v' | ||
@test_throws ErrorException A ⊗ v' | ||
@test_throws ErrorException A ⊗ transpose(v) | ||
|
||
# Docs comparison to `kron` | ||
v, w = [1,2,3], [5,7] | ||
@test kron(v,w) == vec(w ⊗ v) | ||
@test w ⊗ v == reshape(kron(v,w), (length(w), length(v))) | ||
end |
7ab18c8
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register
7ab18c8
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registration pull request created: JuliaRegistries/General/14192
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via: