Skip to content

Commit

Permalink
add an option to stack the wavelet transforms for different ells
Browse files Browse the repository at this point in the history
  • Loading branch information
tgastine committed Dec 1, 2023
1 parent bec3045 commit d1dd004
Showing 1 changed file with 60 additions and 20 deletions.
80 changes: 60 additions & 20 deletions python/magic/coeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -1090,9 +1090,14 @@ def movieRad(self, cut=0.5, levels=12, cm='RdYlBu_r', png=False, step=1,
else:
fig.savefig(filename, dpi=dpi)

def fft(self):
def fft(self, pcolor=False, cm='turbo'):
"""
Fourier transform of the poloidal potential
:param pcolor: this is a switch to use pcolormesh instead of contourf
:type pcolor: bool
:param cm: the name of the colormap (default is 'turbo')
:type cm: char
"""

dt = np.diff(self.time)
Expand Down Expand Up @@ -1133,8 +1138,12 @@ def fft(self):
levs = np.linspace(vmin, vmax, 129)
fig = plt.figure()
ax = fig.add_subplot(111)
im = ax.contourf(ls, self.omega, dat, levs, cmap=plt.get_cmap('turbo'),
extend='both')
if pcolor:
im = ax.pcolormesh(ls, self.omega, dat, cmap=plt.get_cmap(cm),
vmin=vmin, vmax=vmax)
else:
im = ax.contourf(ls, self.omega, dat, levs, cmap=plt.get_cmap(cm),
extend='both')

ax.set_yscale('log')
cbar = fig.colorbar(im)
Expand All @@ -1144,7 +1153,8 @@ def fft(self):

fig.tight_layout()

def cwt(self, ell, w0=20, nfreq=256, fmin_fac=8):
def cwt(self, ell, w0=20, nfreq=256, fmin_fac=8, fmax_fac=0.5,
cm='turbo', logscale=False):
"""
Build a time-frequency spectrum at a given degree :math:`\ell` using
a continuous wavelet transform with morlet wavelets.
Expand All @@ -1156,11 +1166,21 @@ def cwt(self, ell, w0=20, nfreq=256, fmin_fac=8):
given by fmin=1/(time[-1]-time[0]), such that
the minimum frequency retained is fmin_fac*fmin
:type fmin_fac: float
:param fmax_fac: a factor to adjust the maximum frequency retained
in the time-frequency domain. Maximum frequency is
given by fmax=fmax_fac*fcut, where fcut is 1/dt.
:type fmax_fac: float
:param ell: spherical harmonic degree at which ones want to build
the time frequency diagram
the time frequency diagram. If one gives a negative
number, then the sum over all mode is computed.
:type ell: int
:param nfreq: number of frequency bins
:type nfreq: int
:param cm: the name of the colormap (default is 'turbo')
:type cm: char
:param logscale: when turned to True, this displays the amplitude in
logarithmic scale (linear by default)
:type logscale: bool
"""
assert version > '1.4.0'

Expand All @@ -1179,33 +1199,53 @@ def cwt(self, ell, w0=20, nfreq=256, fmin_fac=8):
fmin = 1./(time[-1]-time[0]) # Minimum frequency

#self.omega = 2.*np.pi*np.linspace(fmin*fmin_fac, fcut/2, 100)
self.omega = np.logspace(np.log10(fmin*fmin_fac), np.log10(fcut/2), nfreq)
self.omega = np.logspace(np.log10(fmin*fmin_fac), np.log10(fmax_fac*fcut), nfreq)
self.omega *= 2.*np.pi
# Define the widths of the wavelets (related to their frequency)
widths = w0*fcut/self.omega

#widths = np.arange(1, len(self.time)//8)
self.ek_time_omega = np.zeros((len(widths), self.nstep), np.float64)
for m in range(0, ell+1, self.minc):
print(m)
lm = self.idx[ell, m]
out = signal.cwt(wlm[:, lm], signal.morlet2, widths, w=w0)
if ell < 0: # Then the sum is computed over all ell's
for ll in range(1, self.l_max_r+1):
print(ll)
for m in range(0, ll+1, self.minc):
lm = self.idx[ll, m]
out = signal.cwt(wlm[:, lm], signal.morlet2, widths, w=w0)

if m == 0:
tmp = 0.5*abs(out)**2
else:
tmp = abs(out)**2
if m == 0:
tmp = 0.5*abs(out)**2
else:
tmp = abs(out)**2

self.ek_time_omega += tmp
else:
for m in range(0, ell+1, self.minc):
print(m)
lm = self.idx[ell, m]
out = signal.cwt(wlm[:, lm], signal.morlet2, widths, w=w0)

self.ek_time_omega += tmp
if m == 0:
tmp = 0.5*abs(out)**2
else:
tmp = abs(out)**2

dat = np.log10(self.ek_time_omega)
vmax = dat.max()#-0.5
vmin = max(vmax-10, dat.min()+2)
levs = np.linspace(vmin, vmax, 129)
self.ek_time_omega += tmp

if logscale:
dat = np.log10(self.ek_time_omega)
vmax = dat.max()#-0.5
vmin = max(vmax-10, dat.min()+2)
levs = np.linspace(vmin, vmax, 64)
else:
dat = self.ek_time_omega
levs = 64
fig = plt.figure()
ax = fig.add_subplot(111)
im = ax.contourf(time, self.omega, dat, levs, cmap=plt.get_cmap('turbo'),
im = ax.contourf(time, self.omega, dat, levs, cmap=plt.get_cmap(cm),
extend='both')
#im = ax.pcolormesh(time, self.omega, dat, cmap=plt.get_cmap('turbo'),
# shading='gouraud')

ax.set_yscale('log')
cbar = fig.colorbar(im)
Expand Down

0 comments on commit d1dd004

Please sign in to comment.