-
Notifications
You must be signed in to change notification settings - Fork 0
/
use_guvectorize.py
71 lines (49 loc) · 2.13 KB
/
use_guvectorize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import numpy as np
from numba import guvectorize, float64, njit
import common
# Tag identifying this implementation variant; it is embedded in every output
# file name below.
version = "use_guvectorize"

# File names for the real values, the imaginary values and the computed
# results, each built by substituting the version tag into its template.
realsFile, imagesFile, resultsFile = (
    template.format(version)
    for template in ("reals_{}.csv", "images_{}.csv", "results_{}.csv")
)
# Decorated with njit: without it, a warning is emitted.
@njit
def mandelbrotFunc(inReal: float, inImage: float, cReal: float,
                   cImage: float) -> tuple[float, float]:
    """Apply one step of z -> z*z + c using separate real/imaginary parts.

    Args:
        inReal: Real part of the current z.
        inImage: Imaginary part of the current z.
        cReal: Real part of the constant c.
        cImage: Imaginary part of the constant c.

    Returns:
        The (real, imaginary) parts of z*z + c.
    """
    # (a + bi)^2 = (a*a - b*b) + (2*a*b)i
    realSquared = inReal * inReal
    imageSquared = inImage * inImage
    crossTerm = 2 * inReal * inImage
    return realSquared - imageSquared + cReal, cImage + crossTerm
# Decorated with njit: without it, a warning is emitted.
@njit
def M(cReal, cImage):
    """Return the normalised escape step of the point c = cReal + i*cImage.

    Starting from z = 0, mandelbrotFunc is applied up to
    common.iterationLimit times. The first step (counted from 1) at which
    |z| exceeds common.threshold is divided by common.iterationLimit and
    returned. If z never exceeds the threshold, the result is
    iterationLimit / iterationLimit.

    NOTE(review): Numba documents globals as compile-time constants, so
    changes made to common.iterationLimit / common.threshold after the first
    compilation are presumably not picked up -- confirm if that matters.

    Args:
        cReal: Real part of c.
        cImage: Imaginary part of c.

    Returns:
        The escape step divided by the iteration limit, as a float.
    """
    limit = common.iterationLimit
    # Compare squared magnitudes so no square root is needed.
    escapeRadiusSquared = common.threshold * common.threshold

    zReal = 0.0
    zImage = 0.0
    for step in range(1, limit + 1):
        zReal, zImage = mandelbrotFunc(zReal, zImage, cReal, cImage)
        if zReal * zReal + zImage * zImage > escapeRadiusSquared:
            return float(step) / float(limit)

    # Never escaped: report the limit itself, normalised like any other step.
    return float(limit) / float(limit)
# Use guvectorize to compile a generalized ufunc that applies M to every element of the 2-D input arrays (with the default target this runs on a single thread, not in parallel).
@guvectorize([(float64[:, :], float64[:, :], float64[:, :])],
             '(m, n),(m, n)->(m, n)')
def doMForAll(reals2d, images2d, results):
    """Fill results[i, j] with M(reals2d[i, j], images2d[i, j]).

    Per the declared signature, all three arguments are float64 2-D arrays
    sharing the same (m, n) shape; the last one is the output and is written
    in place.

    Args:
        reals2d: Real parts of the points to evaluate.
        images2d: Imaginary parts of the points to evaluate.
        results: Output array receiving the value of M for each point.
    """
    rowCount, colCount = reals2d.shape
    for row in range(rowCount):
        for col in range(colCount):
            results[row, col] = M(reals2d[row, col], images2d[row, col])
def calcResults():
    """Generate the input values and evaluate M over the whole grid.

    The inputs come from common.genRealsImages() (presumably 1-D coordinate
    sequences plus their 2-D grids -- confirm against that helper). The output
    array has len(images) rows and len(reals) columns and is filled in place
    by doMForAll.

    Returns:
        The tuple (reals, images, reals2d, images2d, results).
    """
    realValues, imageValues, realGrid, imageGrid = common.genRealsImages()

    gridShape = (len(imageValues), len(realValues))
    escapeValues = np.zeros(gridShape, dtype=np.float64)
    doMForAll(realGrid, imageGrid, escapeValues)

    return realValues, imageValues, realGrid, imageGrid, escapeValues
def main():
    """Compute the results, save them, load them back and plot them.

    The data is written with common.saveDataToFile to the files named by
    realsFile, imagesFile and resultsFile, then read back from those same
    files with common.loadDataFromFile; the reloaded data is what gets passed
    to common.plotHot.
    """
    # Only the two value sequences and the result grid are needed here; the
    # 2-D input grids in the middle of the tuple are discarded.
    reals, images, *_, results = calcResults()

    filePaths = (realsFile, imagesFile, resultsFile)
    common.saveDataToFile(reals, images, results, *filePaths)

    loadedReals, loadedImages, loadedResults = common.loadDataFromFile(
        *filePaths)
    common.plotHot(loadedReals, loadedImages, loadedResults)
# Run the compute/save/load/plot pipeline only when this file is executed as
# a script, not when it is imported as a module.
if __name__ == "__main__":
    main()