Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#49 ability to plot requirements curves #112

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
7/13/18: Add ability to overplot requirements curves (issue #49)
7/11/18: Update code to python3
3/17/16: Update documentation to Sphinx standard and add documentation build files (issue #17)
2/17/16: Changes to correlation function plots & documentation (issue #77)
Expand Down
111 changes: 108 additions & 3 deletions stile/sys_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,10 @@ def compensateDefault(self, data, data2, random, random2, both=False):
return 'compensated'

def plot(self, data, colors=['r', 'b'], log_yscale=False,
plot_bmode=True, plot_data_only=True, plot_random_only=True):
plot_bmode=True, plot_data_only=True, plot_random_only=True,
requirement_x=None, requirement_emode=None, requirement_bmode=None,
requirement_emode_range=None, requirement_bmode_range=None,
requirement_color=['gray', 'green'], requirement_linestyle='dashed'):
"""
Plot the data returned from a :class:`BaseCorrelationFunctionSysTest` object. This chooses
some sensible defaults, but much of its behavior can be changed.
Expand All @@ -580,12 +583,67 @@ def plot(self, data, colors=['r', 'b'], log_yscale=False,
[default: True]
:param plot_random_only: Whether to plot the random-only correlation functions, if present
[default: True]
:param requirement_x: The x-axis points of a requirement curve [default: None, meaning
do not plot a requirement curve]
:param requirement_emode: The e-mode points of a requirement curve [default: None, meaning
do not plot an e-mode requirement curve]
:param requirement_bmode: The b-mode points of a requirement curve [default: None, meaning
do not plot a b-mode requirement curve]
:param requirement_emode_range: A 2-item iterable defining the top and bottom (y-axis) edges
of an e-mode requirement range [default: None, meaning do
not plot]
:param requirement_bmode_range: A 2-item iterable defining the top and bottom (y-axis) edges
of a b-mode requirement range [default: None, meaning do
not plot]
:param requirement_color: The color(s) of the requirement plots. If iterable, the first
item will be used for the e-mode requirements and the second
for b-mode, if both are plotted; otherwise the first item
will be used for whichever mode is plotted. [default:
'gray', 'green']
:param requirement_linestyle: The linestyle of the requirements curves. If iterable, this
will follow the ordering rules of `requirement_color`.
[default: 'dashed']
:returns: A matplotlib ``Figure`` which may be written to a file with
:func:`.savefig()`, if matplotlib can be imported; else None.
"""

if not has_matplotlib:
return None
if requirement_emode_range and len(requirement_emode_range)!=2:
raise ValueError("requirement_emode_range must be a 2-item tuple")
if requirement_bmode_range and len(requirement_bmode_range)!=2:
raise ValueError("requirement_bmode_range must be a 2-item tuple")
if requirement_x is None and (requirement_emode is not None or requirement_bmode is not None
or requirement_emode_range is not None or requirement_bmode_range is not None):
print("Cannot plot requirement curves without requirement_x--skipping")
if (requirement_emode is not None or requirement_emode_range is not None):
if hasattr(requirement_color, '__iter__'):
reqecolor = requirement_color[0]
else:
reqecolor = requirement_color
if hasattr(requirement_linestyle, '__iter__') and not isinstance(requirement_linestyle, str):
reqels = requirement_linestyle[0]
else:
reqels = requirement_linestyle
else:
reqecolor = None
reqels = None
if (requirement_bmode is not None or requirement_bmode_range is not None):
if hasattr(requirement_color, '__iter__'):
if reqecolor:
reqbcolor = requirement_color[1]
else:
reqbcolor = requirement_color[0]
else:
reqbcolor = requirement_color
if hasattr(requirement_color, '__iter__') and not isinstance(requirement_linestyle, str):
if reqels:
reqbls = requirement_linestyle[1]
else:
reqbls = requirement_linestyle[0]
else:
reqbls = requirement_linestyle

fields = data.dtype.names
# Pick which radius measurement to use
# TreeCorr changed the name of the output columns
Expand Down Expand Up @@ -635,11 +693,22 @@ def plot(self, data, colors=['r', 'b'], log_yscale=False,
curr_plot = 0
ax = fig.add_subplot(nrows, 1, 1)
ax.axhline(0, alpha=0.7, color='gray')
if requirement_emode_range is not None:
ax.fill_between(requirement_x, requirement_emode_range[0], requirement_emode_range[1],
color=reqecolor, alpha=0.5)
if requirement_emode is not None:
ax.plot(requirement_x, requirement_emode, color=reqecolor, ls=reqels)
ax.errorbar(data[r], data[pd.t_field], yerr=data[pd.sigma_field], color=colors[0],
label=pd.t_title)
if pd.x_title and plot_bmode:
if requirement_bmode_range is not None:
ax.fill_between(requirement_x, requirement_bmode_range[0], requirement_bmode_range[1],
color=reqbcolor, alpha=0.5)
if requirement_bmode is not None:
ax.plot(requirement_x, requirement_bmode, color=reqbcolor, ls=reqbls)
ax.errorbar(data[r], data[pd.x_field], yerr=data[pd.sigma_field], color=colors[1],
label=pd.x_title)

elif pd.t_im_title: # Plot y and y_im if not plotting yb (else it goes on a separate plot)
ax.errorbar(data[r], data[pd.t_im_field], yerr=data[pd.sigma_field], color=colors[1],
label=pd.t_im_title)
Expand Down Expand Up @@ -2104,10 +2173,10 @@ def scatterPlot(self, x, y, yerr=None, z=None, xlabel=None, ylabel=None, zlabel=
if linear_regression:
if yerr is None:
m, c = self.linearRegression(x, y)
ax.plot(xtmp, m*xtmp+c, "--%s" % used_color)
ax.plot(xtmp, m*xtmp+c, linestyle="--", color=used_color)
else:
m, c, cov_m, cov_c, cov_mc = self.linearRegression(x, y, err=yerr)
ax.plot(xtmp, m*xtmp+c, "--%s" % used_color)
ax.plot(xtmp, m*xtmp+c, linestyle="--", color=used_color)
y = m*xtmp+c
# calculate yerr using the covariance
yerr = numpy.sqrt(xtmp**2*cov_m + 2.*xtmp*cov_mc + cov_c)
Expand Down Expand Up @@ -2229,6 +2298,42 @@ def getStatisticsPerCCD(self, ccds, x, y, yerr=None, z=None, stat="median"):
else:
raise ValueError('stat should be mean or median.')

def plot(self, results, requirement_x=None, requirement_y=None, requirement_y_range=None,
requirement_color='gray', requirement_linestyle='dashed'):
"""
This function can be used to add requirements to a plot generated by a previous call to
:class:`BaseScatterPlotSysTest`.

:param requirement_x: The x-axis points of a requirement curve [default: None, meaning do
not plot a requirement curve]
:param requirement_y: yhe x-axis points of a requirement curve [default: None, meaning do
not plot a requirement curve]
:param requirement_y_range: A 2-item iterable defining the top and bottom (y-axis) edges
the requirement range [default: None, meaning do not plot]
:param requirement_color: The color of the requirement curves. [default: 'gray']
:param requirement_linestyle: The linestyle of the requirements curves. [default: 'dashed']
:returns: A matplotlib ``Figure`` which may be written to a file with
:func:`.savefig()`, if matplotlib can be imported; else None.
"""
if hasattr(results, 'savefig'):
# "is not None" because numpy yells about needing any or all for arrays
if requirement_x is not None and (requirement_y is not None or requirement_y_range is not None):
ax = results.gca()
if requirement_y is not None:
ax.plot(requirement_x, requirement_y,
color=requirement_color, ls=requirement_linestyle)
if requirement_y_range is not None:
if not (hasattr(requirement_y_range, '__iter__')
and len(requirement_y_range)==2):
raise ValueError("requirement_y_range must be a 2-item iterable")
ax.fill_between(requirement_x, requirement_y_range[0], requirement_y_range[1],
color=requirement_color, alpha=0.5)
return results
else:
return PlotNone()



class ScatterPlotStarVsPSFG1SysTest(BaseScatterPlotSysTest):
"""
A class to make ScatterPlots of star vs PSF g1 values
Expand Down
34 changes: 34 additions & 0 deletions tests/test_correlation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,40 @@ def test_plot(self):
pl = obj.plot(results)
self.assertIsInstance(pl, matplotlib.figure.Figure)
pl.savefig('examine.png')

def test_requirements_curves(self):
""" Test that requirements curves can be plotted. """
stile_args = {'ra_units': 'degrees', 'dec_units': 'degrees', 'min_sep': 0.05, 'max_sep': 1,
'sep_units': 'degrees', 'nbins': 20}
lens_data = stile.ReadASCIITable('../examples/example_lens_catalog.dat',
fields={'id': 0, 'ra': 1, 'dec': 2, 'z': 3, 'g1': 4, 'g2': 5})
source_data = stile.ReadASCIITable('../examples/example_source_catalog.dat',
fields={'id': 0, 'ra': 1, 'dec': 2, 'z': 3, 'g1': 4, 'g2': 5})
cf = stile.CorrelationFunctionSysTest("GalaxyShear")
results = cf(lens_data, source_data, config=stile_args)
x = results['R_nom [deg]']
y = 0.001/x
yrange = [0.9*y, 1.1*y]
plot = cf.plot(results)
plot = cf.plot(results, requirement_x=x, requirement_emode=y)
plot = cf.plot(results, requirement_x=x, requirement_bmode=y)
plot = cf.plot(results, requirement_x=x, requirement_emode=y, requirement_bmode=y)
plot = cf.plot(results, requirement_x=x, requirement_emode=y,
requirement_emode_range=yrange)
plot = cf.plot(results, requirement_x=x, requirement_bmode=y,
requirement_bmode_range=yrange)
plot = cf.plot(results, requirement_x=x, requirement_emode=y,
requirement_emode_range=yrange, requirement_bmode=y,
requirement_bmode_range=yrange)
plot = cf.plot(results, requirement_x=x, requirement_emode=y,
requirement_bmode_range=yrange)
plot = cf.plot(results, requirement_x=x, requirement_emode=y, requirement_color='blue',
requirement_linestyle='dotted')
plot = cf.plot(results, requirement_x=x, requirement_emode=y, requirement_bmode=y,
requirement_color='blue', requirement_linestyle='dotted')
plot = cf.plot(results, requirement_x=x, requirement_emode=y, requirement_bmode=y,
requirement_color=['blue', 'orange'], requirement_linestyle=['dotted', 'solid'])

def test_generator(self):
"""Make sure the CorrelationFunctionSysTest() generator returns the right objects"""
object_list = ['GalaxyShear', 'BrightStarShear', 'StarXGalaxyDensity', 'StarXGalaxyShear',
Expand Down
36 changes: 36 additions & 0 deletions tests/test_sys_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import numpy
import unittest

try:
import stile
except ImportError:
sys.path.append('..')
import stile


class TestSysTests(unittest.TestCase):
def test_ScatterPlotRequirements(self):
data = stile.ReadASCIITable('../examples/example_source_catalog.dat',
fields={'id': 0, 'ra': 1, 'dec': 2, 'z': 3, 'g1': 4, 'g2': 5})
test_data = numpy.rec.fromarrays([data['ra'], data['dec'],
data['g1'], data['g2'], data['g1']+0.05, data['g2']+0.01, 0.1*data['g1'], 0.1*data['g2']],
names=['ra', 'dec', 'g1', 'g2', 'psf_g1', 'psf_g2', 'g1_err', 'g2_err'])
req_x = numpy.array([numpy.min(data['g1']), numpy.max(data['g1'])])
req_y = req_x
req_y_range = [req_y*0.9, req_y*1.1]

scatterplot = stile.ScatterPlotSysTest('StarVsPSFG1')
results = scatterplot(test_data)
results.savefig('p1.png')
res = scatterplot.plot(results, requirement_x=req_x, requirement_y=req_y)
res.savefig('p2.png')
res = scatterplot.plot(results, requirement_x=req_x, requirement_y=req_y, requirement_color='orange', requirement_linestyle='dotted')
res.savefig('p3.png')
res = scatterplot.plot(results, requirement_x=req_x, requirement_y=req_y, requirement_y_range=req_y_range)
res.savefig('p4.png')
res = scatterplot.plot(results, requirement_x=req_x, requirement_y=req_y, requirement_y_range=req_y_range, requirement_color='orange', requirement_linestyle='dotted')
res.savefig('p5.png')

if __name__=='__main__':
unittest.main()