-
Notifications
You must be signed in to change notification settings - Fork 0
/
core_fit2_mcmc.py
77 lines (63 loc) · 2.04 KB
/
core_fit2_mcmc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import numpy as np
from core_fit2_util import *
import emcee
import pickle
def lnprior(theta):
    """Uniform (top-hat) log-prior on the three model parameters.

    Parameters
    ----------
    theta : sequence of 3 numbers
        ``(dis_len, merg_len, infall_m)``.

    Returns
    -------
    float
        ``0.0`` when every parameter lies strictly inside its box
        (-3 < dis_len < 1, 0 < merg_len < 1, 10 < infall_m < 15),
        otherwise ``-np.inf``. The bounds are exclusive.
    """
    dis_len, merg_len, infall_m = theta
    # Each test is a strict chained comparison; any failure leaves the box.
    if not -3.0 < dis_len < 1.0:
        return -np.inf
    if not 0.0 < merg_len < 1.0:
        return -np.inf
    if not 10.0 < infall_m < 15:
        return -np.inf
    return 0.0
# Module-level copies of the fit data. They start as None and are filled in
# by set_my(); lnlike2 reads them instead of taking them as arguments.
my_zmr_sdss = None
my_zmr_valid = None
my_clstrs = None


def set_my(a, b, c):
    """Store the fit inputs in this module's globals.

    Parameters
    ----------
    a
        Stored as ``my_zmr_sdss``.
    b
        Stored as ``my_zmr_valid``.
    c
        Stored as ``my_clstrs``.
    """
    global my_zmr_sdss, my_zmr_valid, my_clstrs
    my_zmr_sdss, my_zmr_valid, my_clstrs = a, b, c
def lnlike(theta, zmr_sdss, zmr_valid, clstrs):
    """Log-likelihood: the negated galaxy-density cost of the model at ``theta``.

    Parameters
    ----------
    theta : sequence of 3 numbers
        ``(dis_len, merg_len, infall_m)``.
    zmr_sdss
        Observed data passed as the second argument of ``calc_gal_density_cost2``.
    zmr_valid, clstrs
        Passed through to ``zmr_from_clusters``.

    Returns
    -------
    The value of ``-calc_gal_density_cost2(...)``.
    """
    dis_len, merg_len, infall_m = theta
    model_zmr = zmr_from_clusters(dis_len, merg_len, clstrs, zmr_valid, infall_m)
    return -calc_gal_density_cost2(model_zmr, zmr_sdss)
def lnprob(theta, zmr_sdss, zmr_valid, clstrs):
    """Log-posterior: ``lnprior(theta) + lnlike(theta, ...)``.

    The likelihood is only evaluated when the prior is finite; otherwise
    ``-np.inf`` is returned straight away.
    """
    prior = lnprior(theta)
    if np.isfinite(prior):
        return prior + lnlike(theta, zmr_sdss, zmr_valid, clstrs)
    return -np.inf
def lnlike2(theta):
    """Log-likelihood at ``theta`` using the data stored by ``set_my()``.

    This computes the same quantity as ``lnlike``. The difference is that
    zmr_sdss / zmr_valid / clstrs are read from the module globals
    ``my_zmr_sdss`` / ``my_zmr_valid`` / ``my_clstrs``, so only ``theta``
    needs to be passed. It delegates to ``lnlike`` instead of duplicating
    its body, so the two variants cannot drift apart.

    Parameters
    ----------
    theta : sequence of 3 numbers
        ``(dis_len, merg_len, infall_m)``.
    """
    return lnlike(theta, my_zmr_sdss, my_zmr_valid, my_clstrs)
def lnprob2(theta):
    """Log-posterior at ``theta`` using the data stored by ``set_my()``.

    This is the single-argument form handed to ``emcee.EnsembleSampler`` by
    ``run_mcmc``. It delegates to ``lnprob`` with the module globals
    ``my_zmr_sdss`` / ``my_zmr_valid`` / ``my_clstrs``, so the prior
    short-circuit and the likelihood are defined in one place only.

    Parameters
    ----------
    theta : sequence of 3 numbers
        ``(dis_len, merg_len, infall_m)``.
    """
    return lnprob(theta, my_zmr_sdss, my_zmr_valid, my_clstrs)
def run_mcmc(theta_0, nwalkers, niter, zmr_sdss, zmr_valid, clstrs, threads=24):
    """Sample the posterior with emcee and plot every walker's trace.

    Parameters
    ----------
    theta_0 : sequence of 3 numbers
        Starting point ``(dis_len, merg_len, infall_m)``. Walkers are started
        at ``theta_0`` plus Gaussian noise with standard deviation 1e-2.
    nwalkers : int
        Number of emcee walkers.
    niter : int
        Number of steps each walker takes.
    zmr_sdss, zmr_valid, clstrs
        Fit data. They are stored in module globals via ``set_my`` because
        the sampler's log-probability, ``lnprob2``, takes only ``theta``.
    threads : int, optional
        Forwarded to ``emcee.EnsembleSampler(threads=...)``. The default of 24
        is the value that used to be hard-coded.
        NOTE(review): ``threads`` is the emcee 2.x API; later emcee releases
        deprecate it in favour of ``pool=``. Confirm the installed version.

    Returns
    -------
    The ``emcee.EnsembleSampler`` after ``niter`` steps.

    Raises
    ------
    ValueError
        If ``theta_0`` does not contain exactly three values.
    """
    if len(theta_0) != 3:
        raise ValueError(
            "theta_0 must be (dis_len, merg_len, infall_m); got %d values"
            % len(theta_0)
        )
    ndim = len(theta_0)
    center = np.asarray(theta_0, dtype=float)
    # Small Gaussian ball around theta_0: one starting position per walker.
    pos = [center + 1e-2 * np.random.randn(ndim) for _ in range(nwalkers)]
    # lnprob2 reads the data from module globals, so store them before sampling.
    set_my(zmr_sdss, zmr_valid, clstrs)
    sampler = emcee.EnsembleSampler(nwalkers, ndim, lnprob2, threads=threads)
    sampler.run_mcmc(pos, niter)
    _plot_chains(sampler, niter)
    return sampler


def _plot_chains(sampler, niter):
    """Draw one new figure per parameter, overlaying every walker's trace in blue."""
    # NOTE(review): `plt` is not imported in this module's visible imports; it
    # presumably arrives via `from core_fit2_util import *`. Confirm.
    steps = range(niter)
    labels = ("dis_len", "merg_len", "infall_m")
    for k, label in enumerate(labels):
        plt.figure()
        plt.ylabel(label)
        # sampler.chain[:, :, k] is (walker, step); transpose so that each
        # column (one walker) is drawn as its own line against `steps`.
        plt.plot(steps, sampler.chain[:, :, k].T, "b")