Skip to content

Commit

Permalink
feat(#48): Implement Matrix Multiplication Using Virtual Threads #48
Browse files Browse the repository at this point in the history
  • Loading branch information
Samuel Abramov committed Oct 11, 2023
1 parent d5f45ec commit 6918ed4
Show file tree
Hide file tree
Showing 9 changed files with 253 additions and 12 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/codeql-analysis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ jobs:
- name: Checkout repository
uses: actions/checkout@v4

- name: Set up JDK 17
- name: Set up JDK 21
uses: actions/setup-java@v3
with:
java-version: '17'
java-version: '21'
distribution: 'adopt'

# Initializes the CodeQL tools for scanning.
Expand Down
2 changes: 1 addition & 1 deletion example/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ version = '1.0.5'

java {
toolchain {
languageVersion = JavaLanguageVersion.of(17)
languageVersion = JavaLanguageVersion.of(21)
}
}

Expand Down
11 changes: 4 additions & 7 deletions lib/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ version = '1.0.5'

java {
toolchain {
languageVersion = JavaLanguageVersion.of(17)
languageVersion = JavaLanguageVersion.of(21)
}
}

Expand All @@ -18,6 +18,9 @@ repositories {
}

dependencies {
implementation 'org.ejml:ejml-core:0.43.1'
implementation 'org.ejml:ejml-ddense:0.43.1'
implementation 'org.ejml:ejml-simple:0.43.1'
implementation 'com.opencsv:opencsv:5.8'
implementation 'org.apache.logging.log4j:log4j-api:2.20.0'
implementation 'org.apache.logging.log4j:log4j-core:2.20.0'
Expand All @@ -38,12 +41,6 @@ testing {
}
}

java {
toolchain {
languageVersion = JavaLanguageVersion.of(17)
}
}

task sourceJar(type: Jar) {
from sourceSets.main.allSource
archiveClassifier.set('sources')
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package de.edux.util.math;

public interface ConcurrentMatrixMultiplication {

/**
* Multiplies two matrices and returns the resulting matrix.
*
* @param a The first matrix.
* @param b The second matrix.
* @return The product of the two matrices.
* @throws IllegalArgumentException If the matrices cannot be multiplied due to incompatible dimensions.
*/
double[][] multiplyMatrices(double[][] a, double[][] b) throws IllegalArgumentException, IncompatibleDimensionsException;



}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package de.edux.util.math;

public class IncompatibleDimensionsException extends Exception{
public IncompatibleDimensionsException(String message) {
super(message);
}
}
58 changes: 58 additions & 0 deletions lib/src/main/java/de/edux/util/math/MathMatrix.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package de.edux.util.math;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class MathMatrix implements ConcurrentMatrixMultiplication {
private static final Logger LOG = LoggerFactory.getLogger(MathMatrix.class);

@Override
public double[][] multiplyMatrices(double[][] a, double[][] b) throws IncompatibleDimensionsException {
LOG.info("Multiplying matrices of size {}x{} and {}x{}", a.length, a[0].length, b.length, b[0].length);
int aRows = a.length;
int aCols = a[0].length;
int bCols = b[0].length;

if (aCols != b.length) {
throw new IncompatibleDimensionsException("Cannot multiply matrices with incompatible dimensions");
}

double[][] result = new double[aRows][bCols];

try(var executor = Executors.newVirtualThreadPerTaskExecutor()) {
List<Future<Void>> futures = new ArrayList<>(aRows);

for (int i = 0; i < aRows; i++) {
final int rowIndex = i;
futures.add(executor.submit(() -> {
for (int colIndex = 0; colIndex < bCols; colIndex++) {
result[rowIndex][colIndex] = multiplyMatrixRowByColumn(a, b, rowIndex, colIndex);
}
return null;
}));
}
for (var future : futures) {
future.get();
}
} catch (ExecutionException | InterruptedException e) {
LOG.error("Error while multiplying matrices", e);
}

LOG.info("Finished multiplying matrices");
return result;
}

private double multiplyMatrixRowByColumn(double[][] a, double[][] b, int row, int col) {
double sum = 0;
for (int i = 0; i < a[0].length; i++) {
sum += a[row][i] * b[i][col];
}
return sum;
}
}
49 changes: 49 additions & 0 deletions lib/src/main/java/de/edux/util/math/MatrixOperations.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package de.edux.util.math;

public interface MatrixOperations {
/**
* Adds two matrices and returns the resulting matrix.
*
* @param a The first matrix.
* @param b The second matrix.
* @return The sum of the two matrices.
* @throws IllegalArgumentException If the matrices are not of the same dimension.
*/
double[][] addMatrices(double[][] a, double[][] b) throws IllegalArgumentException;

/**
* Subtracts matrix b from matrix a and returns the resulting matrix.
*
* @param a The first matrix.
* @param b The second matrix.
* @return The result of a - b.
* @throws IllegalArgumentException If the matrices are not of the same dimension.
*/
double[][] subtractMatrices(double[][] a, double[][] b) throws IllegalArgumentException;

/**
* Transposes the given matrix and returns the resulting matrix.
*
* @param a The matrix to transpose.
* @return The transposed matrix.
*/
double[][] transposeMatrix(double[][] a);

/**
* Inverts the given matrix and returns the resulting matrix.
*
* @param a The matrix to invert.
* @return The inverted matrix.
* @throws IllegalArgumentException If the matrix is not invertible.
*/
double[][] invertMatrix(double[][] a) throws IllegalArgumentException;

/**
* Calculates and returns the determinant of the given matrix.
*
* @param a The matrix.
* @return The determinant of the matrix.
* @throws IllegalArgumentException If the matrix is not square.
*/
double determinant(double[][] a) throws IllegalArgumentException;
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package de.edux.ml.nn.network;

import de.edux.data.provider.Penguin;
import de.edux.data.provider.SeabornDataProcessor;
import de.edux.data.provider.SeabornProvider;
import de.edux.functions.activation.ActivationFunction;
Expand All @@ -9,7 +8,6 @@
import de.edux.ml.nn.config.NetworkConfiguration;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;

import java.io.File;
import java.net.URL;
Expand Down
115 changes: 115 additions & 0 deletions lib/src/test/java/de/edux/util/math/MathMatrixTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package de.edux.util.math;

import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import static org.junit.jupiter.api.Assertions.assertEquals;

class MathMatrixTest {
private static final long someMaximumValue = 1_000_000_000; // Example value
private static final Logger LOG = LoggerFactory.getLogger(MathMatrixTest.class);

@Test
void multiplyMatrices() throws IncompatibleDimensionsException {
long startTime = System.currentTimeMillis();
int size = 500;

double[][] matrixA = generateMatrix(size);
double[][] matrixB = generateMatrix(size);

ConcurrentMatrixMultiplication matrixMultiplier = new MathMatrix();
double[][] resultMatrix = matrixMultiplier.multiplyMatrices(matrixA, matrixB);

assertEquals(size, resultMatrix.length);
assertEquals(size, resultMatrix[0].length);

long endTime = System.currentTimeMillis();
long timeElapsed = endTime - startTime;
LOG.info("Time elapsed: " + timeElapsed / 1000 + " seconds");
}

@Test
void multiplyMatricesSmall() throws IncompatibleDimensionsException {
double[][] matrixA = {
{1, 2},
{3, 4}
};

double[][] matrixB = {
{2, 0},
{1, 3}
};

ConcurrentMatrixMultiplication matrixMultiplier = new MathMatrix();
double[][] resultMatrix = matrixMultiplier.multiplyMatrices(matrixA, matrixB);

double[][] expectedMatrix = {
{4, 6},
{10, 12}
};

assertArrayEquals(expectedMatrix, resultMatrix);
}

static void assertArrayEquals(double[][] expected, double[][] actual) {
assertEquals(expected.length, actual.length);

for (int i = 0; i < expected.length; i++) {
assertArrayEquals(expected[i], actual[i]);
}
}

static void assertArrayEquals(double[] expected, double[] actual) {
assertEquals(expected.length, actual.length);

for (int i = 0; i < expected.length; i++) {
assertEquals(expected[i], actual[i]);
}
}

double[][] generateMatrix(int size) {
double[][] matrix = new double[size][size];
final int MAX_THREADS = 32;

ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor();
List<Future<Void>> futures = new ArrayList<>();

try {
int rowsPerThread = Math.max(size / MAX_THREADS, 1);

for (int i = 0; i < MAX_THREADS && i * rowsPerThread < size; i++) {
final int startRow = i * rowsPerThread;
final int endRow = Math.min((i + 1) * rowsPerThread, size);

futures.add(executor.submit(() -> {
for (int row = startRow; row < endRow; row++) {
for (int col = 0; col < size; col++) {
matrix[row][col] = Math.random() * 10; // Random values between 0 and 10
}
}
return null;
}));
}

for (Future<Void> future : futures) {
future.get();
}
} catch (InterruptedException | ExecutionException e) {
e.printStackTrace();
} finally {
executor.shutdown();
}

LOG.info("Generated matrix with size: " + size);
return matrix;
}

}

0 comments on commit 6918ed4

Please sign in to comment.