Skip to content

Commit

Permalink
Merge branch 'CUDA' into MAHOUT-1974
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewpalumbo authored May 9, 2017
2 parents aa8fdcf + e073884 commit dce10d7
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
1 change: 0 additions & 1 deletion cuda/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@
<artifactId>jcusparse</artifactId>
<version>${jcuda.jcudaVersion}</version>
</dependency>

</dependencies>


Expand Down
2 changes: 1 addition & 1 deletion cuda/src/main/scala/org/apache/mahout/cuda/Context.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import jcublas._
import JCublas._

final class Context {

// Enable exceptions for all CUDA libraries
JCuda.setExceptionsEnabled(true)
JCusparse.setExceptionsEnabled(true)
Expand All @@ -42,6 +43,5 @@ final class Context {
//TODO: is this needed somehow- via the cusparse library?
// cusparseCreate(denseHandle)


}

9 changes: 9 additions & 0 deletions cuda/src/main/scala/org/apache/mahout/cuda/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,16 @@ import scala.collection.JavaConversions._
import jcuda.runtime.JCuda._
import jcuda.runtime.cudaMemcpyKind._


import jcuda._
import jcuda.jcublas._




import jcuda.jcusparse.JCusparse._


package object cuda {

private implicit val log = getLog(GPUMMul.getClass)
Expand Down Expand Up @@ -124,6 +128,7 @@ package object cuda {
// }
// }


/**
*
* @param mxSrc
Expand Down Expand Up @@ -231,6 +236,7 @@ package object cuda {
(jumpers, colIdcs, els)
}


/**
* Dense %*% Dense
* @param a
Expand Down Expand Up @@ -285,10 +291,12 @@ package object cuda {
// step 1: compute nnz count
var nnzC = new Array[Int](1)
nnzC(0) = 0

cusparseXcsrgemmNnz(ctx.sparseHandle, a.trans, b.trans, m, n, k,
a.descr, a.nonz, a.row_ptr, a.col_ind,
b.descr, b.nonz, b.row_ptr, b.col_ind,
c.descr, c.row_ptr, jcuda.Pointer.to(nnzC))

c.nonz = nnzC(0)
if (c.nonz == 0) {
var baseC = new Array[Int](1)
Expand All @@ -300,6 +308,7 @@ package object cuda {
// step 2: allocate and compute matrix product
cudaMalloc(c.col_ind, jcuda.Sizeof.INT * c.nonz);
cudaMalloc(c.vals, jcuda.Sizeof.DOUBLE * c.nonz);

cusparseDcsrgemm(ctx.sparseHandle, a.trans, b.trans, m, n, k,
a.descr, a.nonz,
a.vals, a.row_ptr, a.col_ind,
Expand Down

0 comments on commit dce10d7

Please sign in to comment.