From 8ed24c57e28db42b51c674f71d74c00f29799090 Mon Sep 17 00:00:00 2001 From: Andrew Palumbo Date: Wed, 19 Apr 2017 22:22:29 -0700 Subject: [PATCH 1/2] CUDA Feature Branch From e0738849f50de12ffa6eb01897adb48d4df18232 Mon Sep 17 00:00:00 2001 From: Nikolay Sakharnykh Date: Mon, 8 May 2017 09:37:30 -0700 Subject: [PATCH 2/2] MAHOUT-1974 initial CUDA support closes apache/mahout#310 --- cuda/pom.xml | 251 +++++++++++ .../apache/mahout/cuda/CompressedMatrix.scala | 83 ++++ .../org/apache/mahout/cuda/Context.scala | 36 ++ .../org/apache/mahout/cuda/GPUMMul.scala | 416 ++++++++++++++++++ .../org/apache/mahout/cuda/package.scala | 211 +++++++++ .../apache/mahout/cuda/CUDATestSuite.scala | 77 ++++ .../mahout/cuda/UserSetCUDATestSuite.scala | 98 +++++ .../math/backend/RootSolverFactory.scala | 9 + pom.xml | 13 + viennacl-omp/pom.xml | 3 - 10 files changed, 1194 insertions(+), 3 deletions(-) create mode 100644 cuda/pom.xml create mode 100644 cuda/src/main/scala/org/apache/mahout/cuda/CompressedMatrix.scala create mode 100644 cuda/src/main/scala/org/apache/mahout/cuda/Context.scala create mode 100644 cuda/src/main/scala/org/apache/mahout/cuda/GPUMMul.scala create mode 100644 cuda/src/main/scala/org/apache/mahout/cuda/package.scala create mode 100644 cuda/src/test/scala/org/apache/mahout/cuda/CUDATestSuite.scala create mode 100644 cuda/src/test/scala/org/apache/mahout/cuda/UserSetCUDATestSuite.scala diff --git a/cuda/pom.xml b/cuda/pom.xml new file mode 100644 index 0000000000..bb6bb9235d --- /dev/null +++ b/cuda/pom.xml @@ -0,0 +1,251 @@ + + + + + + 4.0.0 + + + org.apache.mahout + mahout + 0.13.1-SNAPSHOT + ../pom.xml + + + + 0.8.0 + + + mahout-native-cuda_${scala.compat.version} + + Mahout Native CUDA Bindings + Native Structures and interfaces to be used from Mahout math-scala. + + + jar + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + + test-jar + + package + + + + + + maven-javadoc-plugin + + + + maven-source-plugin + + + + net.alchim31.maven + scala-maven-plugin + + + add-scala-sources + initialize + + add-source + + + + scala-compile + process-resources + + compile + + + + scala-test-compile + process-test-resources + + testCompile + + + + + + + + + + + + + + + + + + org.scalatest + scalatest-maven-plugin + + + test + + test + + + + + -Xmx4g + + + + + org.apache.maven.plugins + maven-dependency-plugin + 2.3 + + + + properties + + + + + + + + maven-antrun-plugin + 1.4 + + + copy + package + + + + + + + run + + + + + + + + + + + + + ${project.groupId} + mahout-math-scala_${scala.compat.version} + + + + + log4j + log4j + + + + + org.scalatest + scalatest_${scala.compat.version} + + + + org.bytedeco + javacpp + 1.2.4 + + + + org.jcuda + jcuda + ${jcuda.jcudaVersion} + + + + org.jcuda + jcusparse + ${jcuda.jcudaVersion} + + + + + + + + mahout-release + + + + net.alchim31.maven + scala-maven-plugin + + + generate-scaladoc + + doc + + + + attach-scaladoc-jar + + doc-jar + + + + + + + + + travis + + + + org.apache.maven.plugins + maven-surefire-plugin + + + -Xmx4g + + + + + org.apache.maven.plugins + maven-failsafe-plugin + + + -Xmx4g + + + + + + + + diff --git a/cuda/src/main/scala/org/apache/mahout/cuda/CompressedMatrix.scala b/cuda/src/main/scala/org/apache/mahout/cuda/CompressedMatrix.scala new file mode 100644 index 0000000000..2aa27f1422 --- /dev/null +++ b/cuda/src/main/scala/org/apache/mahout/cuda/CompressedMatrix.scala @@ -0,0 +1,83 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. + +package org.apache.mahout.cuda + +import java.nio._ + +import jcuda.jcusparse.cusparseIndexBase.CUSPARSE_INDEX_BASE_ZERO +import jcuda.jcusparse.cusparseMatrixType.CUSPARSE_MATRIX_TYPE_GENERAL +import jcuda.jcusparse.cusparseOperation.CUSPARSE_OPERATION_NON_TRANSPOSE +import jcuda.jcusparse.JCusparse._ +import jcuda.jcusparse._ +import jcuda.runtime.JCuda._ +import jcuda.runtime.cudaMemcpyKind._ + +final class CompressedMatrix { + + var row_ptr = new jcuda.Pointer() + var col_ind = new jcuda.Pointer() + var vals = new jcuda.Pointer() + + var trans = CUSPARSE_OPERATION_NON_TRANSPOSE + var descr = new cusparseMatDescr() + + var nrows = 0 + var ncols = 0 + var nonz = 0 + + def this(ctx: Context, nrow: Int, ncol: Int, nonzeros: Int = 0) { + this() + + nrows = nrow + ncols = ncol + cudaMalloc(row_ptr, (nrow+1)*jcuda.Sizeof.INT) + + nonz = nonzeros + if (nonzeros > 0) { + cudaMalloc(col_ind, nonzeros*jcuda.Sizeof.INT) + cudaMalloc(vals, nonzeros*jcuda.Sizeof.DOUBLE) + } + + // create and setup matrix descriptor + cusparseCreateMatDescr(descr) + cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL) + cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO) + } + + def set(rowJumper: Array[Int], + colIndices: Array[Int], + elements: Array[Double], + nrow: Int, + ncol: Int, + nonzeros: Int) { + cudaMemcpy(row_ptr, jcuda.Pointer.to(rowJumper), (nrow+1)*jcuda.Sizeof.INT, cudaMemcpyHostToDevice) + cudaMemcpy(col_ind, jcuda.Pointer.to(colIndices), (nonzeros)*jcuda.Sizeof.INT, cudaMemcpyHostToDevice) + cudaMemcpy(vals, jcuda.Pointer.to(elements), (nonzeros)*jcuda.Sizeof.DOUBLE, cudaMemcpyHostToDevice) + } + + def close() { + cudaFree(row_ptr) + if (nonz > 0) { + cudaFree(col_ind) + cudaFree(vals) + } + } +} + diff --git a/cuda/src/main/scala/org/apache/mahout/cuda/Context.scala b/cuda/src/main/scala/org/apache/mahout/cuda/Context.scala new file mode 100644 index 0000000000..92b28ec330 --- /dev/null +++ b/cuda/src/main/scala/org/apache/mahout/cuda/Context.scala @@ -0,0 +1,36 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. + +package org.apache.mahout.cuda + +import jcuda.jcusparse.JCusparse._ +import jcuda.jcusparse._ +import jcuda.runtime.JCuda + +final class Context { + + // Enable exceptions for all CUDA libraries + JCuda.setExceptionsEnabled(true) + JCusparse.setExceptionsEnabled(true) + + // Initialize JCusparse library + var handle: jcuda.jcusparse.cusparseHandle = new cusparseHandle() + cusparseCreate(handle) +} + diff --git a/cuda/src/main/scala/org/apache/mahout/cuda/GPUMMul.scala b/cuda/src/main/scala/org/apache/mahout/cuda/GPUMMul.scala new file mode 100644 index 0000000000..02d02ea8cd --- /dev/null +++ b/cuda/src/main/scala/org/apache/mahout/cuda/GPUMMul.scala @@ -0,0 +1,416 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. + +package org.apache.mahout.cuda + +import org.apache.mahout.logging._ +import org.apache.mahout.math +import org.apache.mahout.math._ +import org.apache.mahout.math.backend.incore.MMulSolver +import org.apache.mahout.math.flavor.{BackEnum, TraversingStructureEnum} +import org.apache.mahout.math.function.Functions +import org.apache.mahout.math.scalabindings.RLikeOps._ +import org.apache.mahout.math.scalabindings._ + +import scala.collection.JavaConversions._ + +import jcuda._ +import jcuda.jcusparse._ +import jcuda.runtime.JCuda + +object GPUMMul extends MMBinaryFunc { + + private implicit val log = getLog(GPUMMul.getClass) + + override def apply(a: Matrix, b: Matrix, r: Option[Matrix]): Matrix = { + + require(a.ncol == b.nrow, "Incompatible matrix sizes in matrix multiplication.") + + val (af, bf) = (a.getFlavor, b.getFlavor) + val backs = (af.getBacking, bf.getBacking) + val sd = (af.getStructure, math.scalabindings.densityAnalysis(a), bf.getStructure, densityAnalysis(b)) + + + try { + + val alg: MMulAlg = backs match { + + // Both operands are jvm memory backs. + case (BackEnum.JVMMEM, BackEnum.JVMMEM) ⇒ + + sd match { + + // Multiplication cases by a diagonal matrix. + case (TraversingStructureEnum.VECTORBACKED, _, TraversingStructureEnum.COLWISE, _) + if a.isInstanceOf[DiagonalMatrix] ⇒ jvmDiagCW + case (TraversingStructureEnum.VECTORBACKED, _, TraversingStructureEnum.SPARSECOLWISE, _) + if a.isInstanceOf[DiagonalMatrix] ⇒ jvmDiagCW + case (TraversingStructureEnum.VECTORBACKED, _, TraversingStructureEnum.ROWWISE, _) + if a.isInstanceOf[DiagonalMatrix] ⇒ jvmDiagRW + case (TraversingStructureEnum.VECTORBACKED, _, TraversingStructureEnum.SPARSEROWWISE, _) + if a.isInstanceOf[DiagonalMatrix] ⇒ jvmDiagRW + + case (TraversingStructureEnum.COLWISE, _, TraversingStructureEnum.VECTORBACKED, _) + if b.isInstanceOf[DiagonalMatrix] ⇒ jvmCWDiag + case (TraversingStructureEnum.SPARSECOLWISE, _, TraversingStructureEnum.VECTORBACKED, _) + if b.isInstanceOf[DiagonalMatrix] ⇒ jvmCWDiag + case (TraversingStructureEnum.ROWWISE, _, TraversingStructureEnum.VECTORBACKED, _) + if b.isInstanceOf[DiagonalMatrix] ⇒ jvmRWDiag + case (TraversingStructureEnum.SPARSEROWWISE, _, TraversingStructureEnum.VECTORBACKED, _) + if b.isInstanceOf[DiagonalMatrix] ⇒ jvmRWDiag + + // Dense-dense cases + case (TraversingStructureEnum.ROWWISE, true, TraversingStructureEnum.COLWISE, true) if a eq b.t ⇒ gpuDRWAAt + case (TraversingStructureEnum.ROWWISE, true, TraversingStructureEnum.COLWISE, true) if a.t eq b ⇒ gpuDRWAAt + case (TraversingStructureEnum.ROWWISE, true, TraversingStructureEnum.COLWISE, true) ⇒ gpuRWCW + case (TraversingStructureEnum.ROWWISE, true, TraversingStructureEnum.ROWWISE, true) ⇒ jvmRWRW + case (TraversingStructureEnum.COLWISE, true, TraversingStructureEnum.COLWISE, true) ⇒ jvmCWCW + case (TraversingStructureEnum.COLWISE, true, TraversingStructureEnum.ROWWISE, true) if a eq b.t ⇒ jvmDCWAAt + case (TraversingStructureEnum.COLWISE, true, TraversingStructureEnum.ROWWISE, true) if a.t eq b ⇒ jvmDCWAAt + case (TraversingStructureEnum.COLWISE, true, TraversingStructureEnum.ROWWISE, true) ⇒ jvmCWRW + + // Sparse row matrix x sparse row matrix (array of vectors) + case (TraversingStructureEnum.ROWWISE, false, TraversingStructureEnum.ROWWISE, false) ⇒ gpuSparseRWRW + case (TraversingStructureEnum.ROWWISE, false, TraversingStructureEnum.COLWISE, false) ⇒ jvmSparseRWCW + case (TraversingStructureEnum.COLWISE, false, TraversingStructureEnum.ROWWISE, false) ⇒ jvmSparseCWRW + case (TraversingStructureEnum.COLWISE, false, TraversingStructureEnum.COLWISE, false) ⇒ jvmSparseCWCW + + // Sparse matrix x sparse matrix (hashtable of vectors) + case (TraversingStructureEnum.SPARSEROWWISE, false, TraversingStructureEnum.SPARSEROWWISE, false) ⇒ + gpuSparseRowRWRW + case (TraversingStructureEnum.SPARSEROWWISE, false, TraversingStructureEnum.SPARSECOLWISE, false) ⇒ + jvmSparseRowRWCW + case (TraversingStructureEnum.SPARSECOLWISE, false, TraversingStructureEnum.SPARSEROWWISE, false) ⇒ + jvmSparseRowCWRW + case (TraversingStructureEnum.SPARSECOLWISE, false, TraversingStructureEnum.SPARSECOLWISE, false) ⇒ + jvmSparseRowCWCW + + // Sparse matrix x non-like + case (TraversingStructureEnum.SPARSEROWWISE, false, TraversingStructureEnum.ROWWISE, _) ⇒ gpuSparseRowRWRW + case (TraversingStructureEnum.SPARSEROWWISE, false, TraversingStructureEnum.COLWISE, _) ⇒ jvmSparseRowRWCW + case (TraversingStructureEnum.SPARSECOLWISE, false, TraversingStructureEnum.ROWWISE, _) ⇒ jvmSparseRowCWRW + case (TraversingStructureEnum.SPARSECOLWISE, false, TraversingStructureEnum.COLWISE, _) ⇒ jvmSparseCWCW + case (TraversingStructureEnum.ROWWISE, _, TraversingStructureEnum.SPARSEROWWISE, false) ⇒ gpuSparseRWRW + case (TraversingStructureEnum.ROWWISE, _, TraversingStructureEnum.SPARSECOLWISE, false) ⇒ jvmSparseRWCW + case (TraversingStructureEnum.COLWISE, _, TraversingStructureEnum.SPARSEROWWISE, false) ⇒ jvmSparseCWRW + case (TraversingStructureEnum.COLWISE, _, TraversingStructureEnum.SPARSECOLWISE, false) ⇒ jvmSparseRowCWCW + + // Everything else including at least one sparse LHS or RHS argument + case (TraversingStructureEnum.ROWWISE, false, TraversingStructureEnum.ROWWISE, _) ⇒ gpuSparseRWRW + case (TraversingStructureEnum.ROWWISE, false, TraversingStructureEnum.COLWISE, _) ⇒ jvmSparseRWCW + case (TraversingStructureEnum.COLWISE, false, TraversingStructureEnum.ROWWISE, _) ⇒ jvmSparseCWRW + case (TraversingStructureEnum.COLWISE, false, TraversingStructureEnum.COLWISE, _) ⇒ jvmSparseCWCW2flips + + // Sparse methods are only effective if the first argument is sparse, so we need to do a swap. + case (_, _, _, false) ⇒ (a, b, r) ⇒ apply(b.t, a.t, r.map { + _.t + }).t + + // Default jvm-jvm case. + // for some reason a SrarseRowMatrix DRM %*% SrarseRowMatrix DRM was dumping off to here + case _ ⇒ gpuRWCW + } + } + + alg(a, b, r) + } catch { + // TODO FASTHACK: Revert to JVM if there is an exception. + // E.g., java.lang.nullPointerException if more openCL contexts + // have been created than number of GPU cards. + // Better option wuold be to fall back to OpenCL first. + case ex: Exception => + log.info(ex.getMessage + "falling back to JVM MMUL") + ex.printStackTrace + var res = MMul(a, b, r) + println("zsum = " + res.zSum().toString()) + return res + } + } + + type MMulAlg = MMBinaryFunc + + @inline + private def gpuRWCW(a: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = { + log.info("Using gpuRWCW method") + + val hasElementsA = a.zSum() > 0.0 + val hasElementsB = b.zSum() > 0.0 + + // A has a sparse matrix structure of unknown size. We do not want to + // simply convert it to a Dense Matrix which may result in an OOM error. + + // If it is empty use JVM MMul, since we can not convert it to a VCL CSR Matrix. + if (!hasElementsA) { + log.warn("Matrix a has zero elements can not convert to CSR") + return MMul(a, b, r) + } + + // CSR matrices are efficient up to 50% non-zero + if(b.getFlavor.isDense) { + log.warn("Dense matrices are not supported in CUDA backend, using JVM instead") + return MMul(a, b, r) + } else { + // Fall back to JVM based MMul if either matrix is sparse and empty + if (!hasElementsA || !hasElementsB) { + log.warn("Matrix a or b has zero elements can not convert to CSR") + return MMul(a, b, r) + } + + var ms = System.currentTimeMillis() + + val ctx = new Context() + val cudaA = toCudaCmpMatrix(a, ctx) + val cudaB = toCudaCmpMatrix(b, ctx) + val cudaC = prod(cudaA, cudaB, ctx) + val mxC = fromCudaCmpMatrix(cudaC) + + ms = System.currentTimeMillis() - ms + log.debug(s"CUDA multiplication time: $ms ms") + val gpuTrace = mxC.zSum() + log.debug(s"CUDA trace: $gpuTrace") + + // uncomment code below to verify results against JVM +/* + ms = System.currentTimeMillis() + val jvmTrace = MMul(a, b, r).zSum() + ms = System.currentTimeMillis() - ms + log.debug(s"JVM multiplication time: $ms ms") + log.debug(s"JVM trace: $jvmTrace") +*/ + cudaA.close() + cudaB.close() + cudaC.close() + + mxC + } + } + + + @inline + private def jvmRWRW(a: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = { + log.info("Using jvmRWRW method") + // A bit hackish: currently, this relies a bit on the fact that like produces RW(?) + val bclone = b.like(b.ncol, b.nrow).t + for (brow ← b) bclone(brow.index(), ::) := brow + + require(bclone.getFlavor.getStructure == TraversingStructureEnum.COLWISE || bclone.getFlavor.getStructure == + TraversingStructureEnum.SPARSECOLWISE, "COL wise conversion assumption of RHS is wrong, do over this code.") + + gpuRWCW(a, bclone, r) + } + + private def jvmCWCW(a: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = { + log.info("Using jvmCWCW method") + jvmRWRW(b.t, a.t, r.map(_.t)).t + } + + private def jvmCWRW(a: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = { + log.info("Using jvmCWRW method") + // This is a primary contender with Outer Prod sum algo. + // Here, we force-reorient both matrices and run RWCW. + // A bit hackish: currently, this relies a bit on the fact that clone always produces RW(?) + val aclone = a.cloned + + require(aclone.getFlavor.getStructure == TraversingStructureEnum.ROWWISE || aclone.getFlavor.getStructure == + TraversingStructureEnum.SPARSEROWWISE, "Row wise conversion assumption of RHS is wrong, do over this code.") + + jvmRWRW(aclone, b, r) + } + + // left is Sparse right is any + private def gpuSparseRWRW(a: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = { + log.info("Using gpuSparseRWRW method") + val mxR = r.getOrElse(b.like(a.nrow, b.ncol)) + + + /* This is very close to the algorithm from SparseMatrix.times + for (arow ← a; ael ← arow.nonZeroes) + mxR(arow.index(), ::).assign(b(ael.index, ::), Functions.plusMult(ael)) + mxR + + Make sure that the matrix is not empty. VCL {{compressed_matrix}}s must + have nnz > 0 + N.B. This method is horribly inefficent. However there is a difference between + getNumNonDefaultElements() and getNumNonZeroElements() which we do not always + have access to. We created MAHOUT-1882 for this. + */ + + val hasElementsA = a.zSum() > 0.0 + val hasElementsB = b.zSum() > 0.0 + + // A has a sparse matrix structure of unknown size. We do not want to + // simply convert it to a Dense Matrix which may result in an OOM error. + // If it is empty use JVM MMul, since we can not convert it to a VCL CSR Matrix. + if (!hasElementsA) { + log.warn("Matrix a has zero elements can not convert to CSR") + return MMul(a, b, r) + } + + // CSR matrices are efficient up to 50% non-zero + if(b.getFlavor.isDense) { + log.warn("Dense matrices are not supported in CUDA backend, using JVM instead") + return MMul(a, b, r) + } else { + // Fall back to JVM based MMul if either matrix is sparse and empty + if (!hasElementsA || !hasElementsB) { + log.warn("Matrix a or b has zero elements can not convert to CSR") + return MMul(a, b, r) + } + + var ms = System.currentTimeMillis() + + val ctx = new Context() + val cudaA = toCudaCmpMatrix(a, ctx) + val cudaB = toCudaCmpMatrix(b, ctx) + val cudaC = prod(cudaA, cudaB, ctx) + val mxC = fromCudaCmpMatrix(cudaC) + + ms = System.currentTimeMillis() - ms + log.debug(s"CUDA multiplication time: $ms ms") + val gpuTrace = mxC.zSum() + log.debug(s"CUDA trace: $gpuTrace") + + // uncomment code below to verify results against JVM +/* + ms = System.currentTimeMillis() + val jvmTrace = MMul(a, b, r).zSum() + ms = System.currentTimeMillis() - ms + log.debug(s"JVM multiplication time: $ms ms") + log.debug(s"JVM trace: $jvmTrace") +*/ + cudaA.close() + cudaB.close() + cudaC.close() + + mxC + } + + } + + // Sparse %*% dense + private def gpuSparseRowRWRW(a: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = { + log.info("Using gpuSparseRowRWRW method") + val hasElementsA = a.zSum() > 0 + + // A has a sparse matrix structure of unknown size. We do not want to + // simply convert it to a Dense Matrix which may result in an OOM error. + // If it is empty fall back to JVM MMul, since we can not convert it + // to a VCL CSR Matrix. + if (!hasElementsA) { + log.warn("Matrix a has zero elements can not convert to CSR") + return MMul(a, b, r) + } + + log.warn("Dense matrices are not supported in CUDA backend") + return MMul(a, b, r) + } + + private def jvmSparseRowCWCW(a: Matrix, b: Matrix, r: Option[Matrix] = None) = + gpuSparseRowRWRW(b.t, a.t, r.map(_.t)).t + + private def jvmSparseRowCWCW2flips(a: Matrix, b: Matrix, r: Option[Matrix] = None) = + gpuSparseRowRWRW(a cloned, b cloned, r) + + private def jvmSparseRowRWCW(a: Matrix, b: Matrix, r: Option[Matrix]) = + gpuSparseRowRWRW(a, b cloned, r) + + + private def jvmSparseRowCWRW(a: Matrix, b: Matrix, r: Option[Matrix]) = + gpuSparseRowRWRW(a cloned, b, r) + + private def jvmSparseRWCW(a: Matrix, b: Matrix, r: Option[Matrix] = None) = + gpuSparseRWRW(a, b.cloned, r) + + private def jvmSparseCWRW(a: Matrix, b: Matrix, r: Option[Matrix] = None) = + gpuSparseRWRW(a cloned, b, r) + + private def jvmSparseCWCW(a: Matrix, b: Matrix, r: Option[Matrix] = None) = + gpuSparseRWRW(b.t, a.t, r.map(_.t)).t + + private def jvmSparseCWCW2flips(a: Matrix, b: Matrix, r: Option[Matrix] = None) = + gpuSparseRWRW(a cloned, b cloned, r) + + private def jvmDiagRW(diagm:Matrix, b:Matrix, r:Option[Matrix] = None):Matrix = { + log.info("Using jvmDiagRW method") + val mxR = r.getOrElse(b.like(diagm.nrow, b.ncol)) + + for (del ← diagm.diagv.nonZeroes()) + mxR(del.index, ::).assign(b(del.index, ::), Functions.plusMult(del)) + + mxR + } + + private def jvmDiagCW(diagm: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = { + log.info("Using jvmDiagCW method") + val mxR = r.getOrElse(b.like(diagm.nrow, b.ncol)) + for (bcol ← b.t) mxR(::, bcol.index()) := bcol * diagm.diagv + mxR + } + + private def jvmCWDiag(a: Matrix, diagm: Matrix, r: Option[Matrix] = None) = + jvmDiagRW(diagm, a.t, r.map {_.t}).t + + private def jvmRWDiag(a: Matrix, diagm: Matrix, r: Option[Matrix] = None) = + jvmDiagCW(diagm, a.t, r.map {_.t}).t + + + /** Dense column-wise AA' */ + private def jvmDCWAAt(a:Matrix, b:Matrix, r:Option[Matrix] = None) = { + // a.t must be equiv. to b. Cloning must rewrite to row-wise. + gpuDRWAAt(a.cloned,null,r) + } + + /** Dense Row-wise AA' */ + // we probably will not want to use this for the actual release unless A is cached already + // but adding for testing purposes. + private def gpuDRWAAt(a:Matrix, b:Matrix, r:Option[Matrix] = None) = { + log.warn("Dense matrices are not supported in CUDA backend, using JVM instead") + MMul(a, b, r) + } + + private def jvmOuterProdSum(a: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = { + log.info("Using jvmOuterProdSum method") + // Need to check whether this is already laid out for outer product computation, which may be faster than + // reorienting both matrices. + val (m, n) = (a.nrow, b.ncol) + + // Prefer col-wise result iff a is dense and b is sparse. In all other cases default to row-wise. + val preferColWiseR = a.getFlavor.isDense && !b.getFlavor.isDense + + val mxR = r.getOrElse { + (a.getFlavor.isDense, preferColWiseR) match { + case (false, false) ⇒ b.like(m, n) + case (false, true) ⇒ b.like(n, m).t + case (true, false) ⇒ a.like(m, n) + case (true, true) ⇒ a.like(n, m).t + } + } + + // Loop outer products + if (preferColWiseR) { + // B is sparse and A is not, so we need to iterate over b values and update R columns with += + // one at a time. + for ((acol, brow) ← a.t.zip(b); bel ← brow.nonZeroes) mxR(::, bel.index()) += bel * acol + } else { + for ((acol, brow) ← a.t.zip(b); ael ← acol.nonZeroes()) mxR(ael.index(), ::) += ael * brow + } + + mxR + } +} diff --git a/cuda/src/main/scala/org/apache/mahout/cuda/package.scala b/cuda/src/main/scala/org/apache/mahout/cuda/package.scala new file mode 100644 index 0000000000..93a3b25af2 --- /dev/null +++ b/cuda/src/main/scala/org/apache/mahout/cuda/package.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. + +package org.apache.mahout + +import java.nio._ + +import org.apache.mahout.logging._ +import org.apache.mahout.math._ +import scalabindings._ +import RLikeOps._ +import org.apache.mahout.math.backend.incore._ +import scala.collection.JavaConversions._ + +import jcuda.runtime.JCuda._ +import jcuda.runtime.cudaMemcpyKind._ +import jcuda.jcusparse.JCusparse._ + +package object cuda { + + private implicit val log = getLog(GPUMMul.getClass) + + /** + * + * @param mxSrc + * @param ctx + * @return + */ + def toCudaCmpMatrix(mxSrc: Matrix, ctx: Context): CompressedMatrix = { + val (jumpers, colIdcs, els) = repackCSR(mxSrc) + val compMx = new CompressedMatrix(ctx, mxSrc.nrow, mxSrc.ncol, els.length) + compMx.set(jumpers, colIdcs, els, mxSrc.nrow, mxSrc.ncol, els.length) + compMx + } + + private def repackCSR(mx: Matrix): (Array[Int], Array[Int], Array[Double]) = { + val nzCnt = mx.map(_.getNumNonZeroElements).sum + val jumpers = new Array[Int](mx.nrow + 1) + val colIdcs = new Array[Int](nzCnt + 0) + val els = new Array[Double](nzCnt) + var posIdx = 0 + + var sortCols = false + + // Row-wise loop. Rows may not necessarily come in order. But we have to have them in-order. + for (irow ← 0 until mx.nrow) { + + val row = mx(irow, ::) + jumpers(irow) = posIdx + + // Remember row start index in case we need to restart conversion of this row if out-of-order + // column index is detected + val posIdxStart = posIdx + + // Retry loop: normally we are done in one pass thru it unless we need to re-run it because + // out-of-order column was detected. + var done = false + while (!done) { + + // Is the sorting mode on? + if (sortCols) { + + // Sorting of column indices is on. So do it. + row.nonZeroes() + // Need to convert to a strict collection out of iterator + .map(el ⇒ el.index → el.get) + // Sorting requires Sequence api + .toSeq + // Sort by column index + .sortBy(_._1) + // Flush to the CSR buffers. + .foreach { case (index, v) ⇒ + colIdcs(posIdx) = index + els(posIdx) = v + posIdx += 1 + } + + // Never need to retry if we are already in the sorting mode. + done = true + + } else { + + // Try to run unsorted conversion here, switch lazily to sorted if out-of-order column is + // detected. + var lastCol = 0 + val nzIter = row.nonZeroes().iterator() + var abortNonSorted = false + + while (nzIter.hasNext && !abortNonSorted) { + + val el = nzIter.next() + val index = el.index + + if (index < lastCol) { + + // Out of order detected: abort inner loop, reset posIdx and retry with sorting on. + abortNonSorted = true + sortCols = true + posIdx = posIdxStart + + } else { + + // Still in-order: save element and column, continue. + els(posIdx) = el + colIdcs(posIdx) = index + posIdx += 1 + + // Remember last column seen. + lastCol = index + } + } // inner non-sorted + + // Do we need to re-run this row with sorting? + done = !abortNonSorted + + } // if (sortCols) + + } // while (!done) retry loop + + } // row-wise loop + + // Make sure Mahout matrix did not cheat on non-zero estimate. + assert(posIdx == nzCnt) + + jumpers(mx.nrow) = nzCnt + + (jumpers, colIdcs, els) + } + + def prod(a: CompressedMatrix, b: CompressedMatrix, ctx: Context): CompressedMatrix = { + var m = a.nrows + var n = b.ncols + var k = b.nrows + + var c: CompressedMatrix = new CompressedMatrix(ctx, m, n) + + // step 1: compute nnz count + var nnzC = new Array[Int](1) + nnzC(0) = 0 + cusparseXcsrgemmNnz(ctx.handle, a.trans, b.trans, m, n, k, + a.descr, a.nonz, a.row_ptr, a.col_ind, + b.descr, b.nonz, b.row_ptr, b.col_ind, + c.descr, c.row_ptr, jcuda.Pointer.to(nnzC)) + c.nonz = nnzC(0) + if (c.nonz == 0) { + var baseC = new Array[Int](1) + cudaMemcpy(jcuda.Pointer.to(nnzC), c.row_ptr.withByteOffset(m * jcuda.Sizeof.INT), jcuda.Sizeof.INT, cudaMemcpyDeviceToHost) + cudaMemcpy(jcuda.Pointer.to(baseC), c.row_ptr, jcuda.Sizeof.INT, cudaMemcpyDeviceToHost) + c.nonz = nnzC(0) - baseC(0) + } + + // step 2: allocate and compute matrix product + cudaMalloc(c.col_ind, jcuda.Sizeof.INT * c.nonz); + cudaMalloc(c.vals, jcuda.Sizeof.DOUBLE * c.nonz); + cusparseDcsrgemm(ctx.handle, a.trans, b.trans, m, n, k, + a.descr, a.nonz, + a.vals, a.row_ptr, a.col_ind, + b.descr, b.nonz, + b.vals, b.row_ptr, b.col_ind, + c.descr, + c.vals, c.row_ptr, c.col_ind); + c + } + + def fromCudaCmpMatrix(src: CompressedMatrix): Matrix = { + val m = src.nrows + val n = src.ncols + val NNz = src.nonz + + log.debug("m=" + m.toString() + ", n=" + n.toString() + ", nnz=" + NNz.toString()) + + val row_ptr = new Array[Int](m + 1) + val col_idx = new Array[Int](NNz) + val values = new Array[Double](NNz) + + cudaMemcpy(jcuda.Pointer.to(row_ptr), src.row_ptr, (m+1)*jcuda.Sizeof.INT, cudaMemcpyDeviceToHost) + cudaMemcpy(jcuda.Pointer.to(col_idx), src.col_ind, (NNz)*jcuda.Sizeof.INT, cudaMemcpyDeviceToHost) + cudaMemcpy(jcuda.Pointer.to(values), src.vals, (NNz)*jcuda.Sizeof.DOUBLE, cudaMemcpyDeviceToHost) + + val srMx = new SparseRowMatrix(m, n) + + // read the values back into the matrix + var j = 0 + // row wise, copy any non-zero elements from row(i-1,::) + for (i <- 1 to m) { + // for each nonzero element, set column col(idx(j) value to vals(j) + while (j < row_ptr(i)) { + srMx(i - 1, col_idx(j)) = values(j) + j += 1 + } + } + srMx + } + +} diff --git a/cuda/src/test/scala/org/apache/mahout/cuda/CUDATestSuite.scala b/cuda/src/test/scala/org/apache/mahout/cuda/CUDATestSuite.scala new file mode 100644 index 0000000000..5222cc1106 --- /dev/null +++ b/cuda/src/test/scala/org/apache/mahout/cuda/CUDATestSuite.scala @@ -0,0 +1,77 @@ +package org.apache.mahout.cuda + +import org.scalatest.{FunSuite, Matchers} +import org.apache.mahout.math._ +import scalabindings.RLikeOps._ + +import scala.util.Random + +/** + * Created by andy on 3/29/17. + */ +class CUDATestSuite extends FunSuite with Matchers { + + + test("sparse mmul at geometry of 1000 x 1000 %*% 1000 x 1000 density = .2. 5 runs") { + CUDATestSuite.getAverageTime(1000, 1000, 1000, .20, 1234L, 3) + } + test("sparse mmul at geometry of 1000 x 1000 %*% 1000 x 1000 density = .02. 5 runs") { + CUDATestSuite.getAverageTime(1000, 1000, 1000, .02, 1234L, 3) + } + test("sparse mmul at geometry of 1000 x 1000 %*% 1000 x 1000 density = .002. 5 runs") { + CUDATestSuite.getAverageTime(1000, 1000, 1000, .002, 1234L, 3) + } +} + + +object CUDATestSuite { + def getAverageTime(m: Int = 1000, + s: Int = 1000, + n: Int = 1000, + density: Double = .2, + seed: Long = 1234L, + nruns: Int = 5): Long = { + + val r = new Random(seed) + val cudaCtx = new Context() + + // sparse row-wise + val mxA = new SparseRowMatrix(m, s, false) + val mxB = new SparseRowMatrix(s, n, true) + + // add some sparse data with the given threshold + mxA := { (_, _, v) => if (r.nextDouble() < density) r.nextDouble() else v } + mxB := { (_, _, v) => if (r.nextDouble() < density) r.nextDouble() else v } + + // run Mahout JVM - only math once + var mxC = mxA %*% mxB + + // run Mahout JVM - only math another {{nruns}} times and take average + var ms = System.currentTimeMillis() + for (i: Int <- 1 to nruns) { + mxC = mxA %*% mxB + } + ms = (System.currentTimeMillis() - ms) / nruns + print(s"Mahout JVM Sparse multiplication time: $ms ms.\n") + + + // run Mahout JCuda math bindings once + val cudaA = toCudaCmpMatrix(mxA, cudaCtx) + val cudaB = toCudaCmpMatrix(mxB, cudaCtx) + var mxCuda = prod(cudaA, cudaB, cudaCtx) + + // run Mahout JCuda another {{nruns}} times and take average + ms = System.currentTimeMillis() + for (i: Int <- 1 to nruns) { + mxCuda = prod(cudaA, cudaB, cudaCtx) + } + + ms = (System.currentTimeMillis() - ms) / nruns + print(s"Mahout JCuda Sparse multiplication time: $ms ms.\n") + + // TODO: Ensure that we've been working with the same matrices. + // (mxC - mxCuda).norm / mxC.nrow / mxC.ncol should be < 1e-16 + ms + } + +} diff --git a/cuda/src/test/scala/org/apache/mahout/cuda/UserSetCUDATestSuite.scala b/cuda/src/test/scala/org/apache/mahout/cuda/UserSetCUDATestSuite.scala new file mode 100644 index 0000000000..c718986178 --- /dev/null +++ b/cuda/src/test/scala/org/apache/mahout/cuda/UserSetCUDATestSuite.scala @@ -0,0 +1,98 @@ +package org.apache.mahout.cuda + +import org.scalatest.{FunSuite, Matchers} +import org.apache.mahout.math._ +import scalabindings.RLikeOps._ +import CUDATestSuite._ +import scala.util.Properties.envOrElse + +import scala.util.Random + + +import scala.util.Random +/** + * Created by andy on 3/29/17. + */ + +// some quickfixes as well +class UserSetCUDATestSuite extends FunSuite with Matchers { + + // defaults + var m: Int = 1000 + var s: Int = 1000 + var n: Int = 1000 + var density: Double = .2 + var seed: Long = 1234L + var num_runs: Int = 5 + + // grab the environment variables if set. + m = envOrElse("SIZE_M","1000").toInt + s = envOrElse("SIZE_S","1000").toInt + n = envOrElse("SIZE_N","1000").toInt + density = envOrElse("DENSITY",".02").toDouble + seed = envOrElse("SEED","1234").toLong + num_runs = envOrElse("NUM_RUNS","3").toInt + + test("User Defined sparse mmul at geometry of " + + m + " x " + s + " %*% " + s + " x " + n + " density = " + density + " " + num_runs + " runs \n") { + + val ms = getAverageTime(m, n, s, density, seed, num_runs) + + println("User Defined sparse mmul at geometry of " + + m + " x " + s + " %*% " + s + " x " + n + " density = " + density + " " + num_runs + " runs : "+ms +" ms") + } +} + + +object UserSetCUDATestSuite { + def getAverageTime(m: Int = 1000, + s: Int = 1000, + n: Int = 1000, + density: Double = .2, + seed: Long = 1234L, + nruns: Int = 5): Long = { + + val r = new Random(seed) + val cudaCtx = new Context() + + // sparse row-wise + val mxA = new SparseRowMatrix(m, s, false) + val mxB = new SparseRowMatrix(s, n, true) + + // add some sparse data with the given threshold + mxA := { (_, _, v) => if (r.nextDouble() < density) r.nextDouble() else v } + mxB := { (_, _, v) => if (r.nextDouble() < density) r.nextDouble() else v } + + // run Mahout JVM - only math once + var mxC = mxA %*% mxB + + // run Mahout JVM - only math another {{nruns}} times and take average + var ms = System.currentTimeMillis() + for (i: Int <- 1 to nruns) { + mxC = mxA %*% mxB + } + ms = (System.currentTimeMillis() - ms) / nruns + print(s"Mahout JVM Sparse multiplication time: $ms ms.") + + + // run Mahout JCuda math bindings once + val cudaA = toCudaCmpMatrix(mxA, cudaCtx) + val cudaB = toCudaCmpMatrix(mxB, cudaCtx) + var mxCuda = prod(cudaA, cudaB, cudaCtx) + + // run Mahout JCuda another {{nruns}} times and take average + ms = System.currentTimeMillis() + for (i: Int <- 1 to nruns) { + mxCuda = prod(cudaA, cudaB, cudaCtx) + } + + ms = (System.currentTimeMillis() - ms) / nruns + print(s"Mahout JCuda Sparse multiplication time: $ms ms.") + + // TODO: Ensure that we've been working with the same matrices. + // (mxC - mxCuda).norm / mxC.nrow / mxC.ncol should be < 1e-16 + ms + } + +} + diff --git a/math-scala/src/main/scala/org/apache/mahout/math/backend/RootSolverFactory.scala b/math-scala/src/main/scala/org/apache/mahout/math/backend/RootSolverFactory.scala index 0904ea5c11..7865bfacf2 100644 --- a/math-scala/src/main/scala/org/apache/mahout/math/backend/RootSolverFactory.scala +++ b/math-scala/src/main/scala/org/apache/mahout/math/backend/RootSolverFactory.scala @@ -56,6 +56,14 @@ final object RootSolverFactory extends SolverFactory { def getOperator[C: ClassTag]: MMBinaryFunc = { + try { + logger.info("Creating scala.cuda.GPUMMul solver") + clazz = Class.forName("scala.cuda.GPUMMul$").getField("MODULE$").get(null).asInstanceOf[MMBinaryFunc] + logger.info("Successfully created scala.cuda.GPUMMul solver") + + } catch { + case cudax: Exception => + logger.info("Unable to create class GPUMMul with CUDA: attempting OpenCL version") try { logger.info("Creating org.apache.mahout.viennacl.opencl.GPUMMul solver") clazz = Class.forName("org.apache.mahout.viennacl.opencl.GPUMMul$").getField("MODULE$").get(null).asInstanceOf[MMBinaryFunc] @@ -79,6 +87,7 @@ final object RootSolverFactory extends SolverFactory { clazz = MMul } } + } clazz } } diff --git a/pom.xml b/pom.xml index 8485c7cd6c..31fa578f04 100644 --- a/pom.xml +++ b/pom.xml @@ -244,6 +244,12 @@ ${project.version} + + mahout-native-cuda_${scala.compat.version} + ${project.groupId} + ${project.version} + + mahout-native-viennacl_${scala.compat.version} ${project.groupId} @@ -849,6 +855,13 @@ + + cuda + + cuda + + + viennacl diff --git a/viennacl-omp/pom.xml b/viennacl-omp/pom.xml index 296c5c30eb..0052e0ad65 100644 --- a/viennacl-omp/pom.xml +++ b/viennacl-omp/pom.xml @@ -202,11 +202,8 @@ - - -