-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrandMatrixMultiply.py
112 lines (86 loc) · 4.04 KB
/
randMatrixMultiply.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import numpy as np
# from scipy.sparse import random as sparse_random
# from scipy.sparse import csr_matrix
def randMatrixMultiply(A: np.matrix, B: np.matrix, c: int, P_k: np.array, rng=None) -> np.ndarray:
    """ This function returns an estimator for the product AB using the random matrix multiplication algorithm

    Draws c indices i_1..i_c independently (with replacement) from P_k, builds
    C (m x c) from the scaled columns A[:, i_t] / sqrt(c * P_k[i_t]) and R (c x p)
    from the identically scaled rows B[i_t, :], and returns C @ R. Each sampled
    outer product is therefore weighted by 1 / (c * P_k[i_t]), which makes the
    estimate unbiased (E[C @ R] = A @ B) provided P_k[k] > 0 wherever
    A[:, k] B[k, :] is nonzero.

    :param A: The first m x n matrix to be multiplied (ndarray or np.matrix)
    :param B: The second n x p matrix to be multiplied
    :param c: The number of column-row pairs to sample (c == 0 yields the zero matrix)
    :param P_k: Length-n probability vector over the column/row indices
    :param rng: Optional random source with a ``choice`` method (e.g. the result of
        ``np.random.default_rng(seed)`` or a ``np.random.RandomState``). Defaults to
        the global ``np.random`` state, as before.
    :return: The m x p estimator for the product AB, as a 2-D ndarray
    :raises ValueError: if A.shape[1] != B.shape[0], or if P_k is not a valid
        length-n distribution (raised by ``choice``)
    """
    # np.asarray turns np.matrix into a plain ndarray so column/row indexing
    # yields the expected 2-D slices instead of matrix-shaped (m, 1) objects.
    A = np.asarray(A)
    B = np.asarray(B)
    n = A.shape[1]
    if B.shape[0] != n:
        raise ValueError(f"inner dimensions do not match: A is {A.shape}, B is {B.shape}")
    P_k = np.asarray(P_k, dtype=float)
    sampler = np.random if rng is None else rng
    # With the legacy global np.random state, one vectorized draw consumes the same
    # uniform samples as c scalar draws, so seeded results match a per-pair loop.
    indices = sampler.choice(n, size=c, p=P_k)
    scale = 1.0 / np.sqrt(c * P_k[indices])   # shape (c,), one factor per sampled pair
    C = A[:, indices] * scale                  # m x c: column t scaled by scale[t]
    R = B[indices, :] * scale[:, np.newaxis]   # c x p: row t scaled by scale[t]
    return C @ R
def calculate_accuracy(A: np.matrix, B: np.matrix, c: int, P_k: np.array) -> float:
    """ This function calculates the accuracy of randMatrixMultiply compared to normal matrix multiplication

    Accuracy is 1 - ||AB - approx||_F / ||AB||_F for a single randomized estimate
    ``approx = randMatrixMultiply(A, B, c, P_k)``. A perfect estimate scores 1.0;
    the score decreases without bound as the relative error grows.

    :param A: The first m x n matrix to be multiplied
    :param B: The second n x p matrix to be multiplied
    :param c: The number of column-row pairs to choose
    :param P_k: The probability distribution used to sample column-row pairs
    :return: The accuracy of the randMatrixMultiply function. If A @ B is the zero
        matrix the relative error is undefined: 1.0 is returned when the estimate is
        also exactly zero, and -inf otherwise (instead of dividing by zero).
    """
    # Compute the product using normal matrix multiplication
    AB_exact = A @ B
    # Compute the product using randMatrixMultiply (one random draw)
    AB_approx = randMatrixMultiply(A, B, c, P_k)
    # Frobenius norm of the estimation error and of the exact product
    error_norm = np.linalg.norm(AB_exact - AB_approx, 'fro')
    exact_norm = np.linalg.norm(AB_exact, 'fro')
    if exact_norm == 0:
        # Avoid 0/0 -> nan (and x/0 -> inf with a RuntimeWarning).
        return 1.0 if error_norm == 0 else float('-inf')
    return float(1 - error_norm / exact_norm)
def calculate_loss(A: np.matrix, B: np.matrix, c: int, P_k: np.array) -> float:
    """ Measure the absolute error of one randomized estimate of the product AB.

    :param A: The first m x n matrix to be multiplied
    :param B: The second n x p matrix to be multiplied
    :param c: The number of column-row pairs to choose
    :param P_k: The probability distribution used to sample column-row pairs
    :return: ||A @ B - randMatrixMultiply(A, B, c, P_k)||_F, the Frobenius norm of
        the estimation error (the loss of the randMatrixMultiply function)
    """
    # The exact product is evaluated first, then the randomized estimate.
    residual = (A @ B) - randMatrixMultiply(A, B, c, P_k)
    return np.linalg.norm(residual, 'fro')
def main():
    """ Demo: estimate a random 1000 x 100 by 100 x 50 product from c = 2 sampled pairs.

    Samples column-row pairs with probability proportional to
    ||A[:, k]|| * ||B[k, :]||, prints that distribution, then prints the accuracy
    and loss of one randomized estimate.
    """
    # Set the seed for reproducibility
    seed = 42
    np.random.seed(seed)
    # Set the dimensions of the matrices: A is m x n, B is n x p
    m = 1000
    n = 100
    p = 50
    # Create the dense matrices A and B
    A = np.random.rand(m, n)
    B = np.random.rand(n, p)
    # Alternative: tall, sparse matrices (needs the scipy imports commented out at
    # the top of the file).
    # NOTE(review): confirm randMatrixMultiply handles scipy sparse inputs before enabling.
    # density = 0.01 # 1% non-zero entries
    # A = sparse_random(m, n, density=density, format='csr', random_state=seed)
    # B = sparse_random(n, p, density=density, format='csr', random_state=seed)
    # Number of column-row pairs to sample
    c = 2
    # Sampling distribution P_k over the n column/row indices.
    # Alternative: P_k = np.full(n, 1/n)  # Uniform distribution for simplicity
    # Optimal (norm-proportional) probability distribution:
    A_col_norms = np.linalg.norm(A, axis=0)
    B_row_norms = np.linalg.norm(B, axis=1)
    weights = A_col_norms * B_row_norms
    P_k = weights / np.sum(weights)
    print(P_k)
    # calculate_accuracy and calculate_loss each draw their own random sample.
    # Replay the same global RNG state for both so the two printed numbers describe
    # the same estimate rather than two unrelated draws.
    rng_state = np.random.get_state()
    accuracy = calculate_accuracy(A, B, c, P_k)
    np.random.set_state(rng_state)
    loss = calculate_loss(A, B, c, P_k)
    print(f"The accuracy of the randMatrixMultiply function is: {accuracy}")
    print(f"The loss of the randMatrixMultiply function is: {loss}")


if __name__ == "__main__":
    main()