Skip to content

Commit

Permalink
MPSMatrix from SubArray
Browse files Browse the repository at this point in the history
# Conflicts:
#	lib/mps/linalg.jl
  • Loading branch information
tgymnich committed Sep 5, 2023
1 parent 69aa51e commit fa4421c
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions lib/mps/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,32 @@ function MPSMatrix(arr::MtlArray{T,3}) where T
return obj
end


"""
MPSMatrix(arr::MtlMatrix)
Metal matrix representation used in Performance Shaders.
Note that this results in a transposed view of the input,
as Metal stores matrices row-major instead of column-major.
"""
function MPSMatrix(arr::SubArray{T,2,MtlArray{T,3}}) where T
n_cols, n_rows = size(arr)
row_bytes = sizeof(T)*n_cols
index = parentindices(arr)[3]
offset = row_bytes * n_cols * (index-1)
desc = MPSMatrixDescriptor(n_rows, n_cols, row_bytes, T)
mat = @objc [MPSMatrix alloc]::id{MPSMatrix}
obj = MPSMatrix(mat)
finalizer(release, obj)
@objc [obj::id{MPSMatrix} initWithBuffer:parent(arr).buffer::id{MTLBuffer}
offset:offset::NSUInteger
descriptor:desc::id{MPSMatrixDescriptor}]::id{MPSMatrix}
return obj
end

### parentindices(A)

#
# matrix multiplication
#
Expand Down

0 comments on commit fa4421c

Please sign in to comment.