diff --git a/CHANGELOG.md b/CHANGELOG.md index d07c664..f08b29f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +7/13/18: Fix usage of randoms in correlation functions (issue #85) 7/11/18: Update code to python3 3/17/16: Update documentation to Sphinx standard and add documentation build files (issue #17) 2/17/16: Changes to correlation function plots & documentation (issue #77) diff --git a/stile/sys_tests.py b/stile/sys_tests.py index 276dbb9..bffe1ff 100644 --- a/stile/sys_tests.py +++ b/stile/sys_tests.py @@ -364,6 +364,16 @@ def makeCatalog(self, data, config=None, use_as_k=None, use_chip_coords=False): catalog_kwargs['config'] = config return treecorr.Catalog(**catalog_kwargs) + def hasObjects(self, cat): + if hasattr(cat, '__len__'): + try: + return len(cat)>0 + except TypeError: + return len(numpy.atleast_1d(cat))>0 + elif hasattr(cat, 'nobj'): + return cat.nobj>0 + return False + def getCF(self, correlation_function_type, data, data2=None, random=None, random2=None, use_as_k=None, use_chip_coords=False, config=None, **kwargs): @@ -429,11 +439,12 @@ def getCF(self, correlation_function_type, data, data2=None, # First, pull out the TreeCorr-relevant parameters from the stile_args dict, and add # anything passed as a kwarg to that dict. - if (random and len(random)) or (random2 and len(random2)): - treecorr_kwargs[correlation_function_type+'_statistic'] = \ - treecorr_kwargs.get(correlation_function_type+'_statistic', 'compensated') treecorr_kwargs = stile.treecorr_utils.PickTreeCorrKeys(config) treecorr_kwargs.update(stile.treecorr_utils.PickTreeCorrKeys(kwargs)) + if ((self.hasObjects(random) or self.hasObjects(random2)) + and correlation_function_type in ['nn', 'ng', 'nk']): + treecorr_kwargs[correlation_function_type+'_statistic'] = \ + treecorr_kwargs.get(correlation_function_type+'_statistic', 'compensated') treecorr.config.check_config(treecorr_kwargs, corr2_valid_params) if data is None: @@ -500,11 +511,11 @@ def getCF(self, correlation_function_type, data, data2=None, func_dr = None elif correlation_function_type == 'nn': func_random = treecorr_func_dict[correlation_function_type](treecorr_kwargs) - if len(random2): + if self.hasObjects(random2): func_random.process(random, random2) else: func_random.process(random) - if not len(data2): + if not self.hasObjects(data2): func_rr = treecorr_func_dict['nn'](treecorr_kwargs) func_rr.process(data, random) if treecorr_kwargs.get(['nn_statistic'], @@ -519,7 +530,7 @@ def getCF(self, correlation_function_type, data, data2=None, else: func_rr = treecorr_func_dict['nn'](treecorr_kwargs) func_rr.process(random, random2) - if treecorr_kwargs.get(['nn_statistic'], + if treecorr_kwargs.get('nn_statistic', self.compensateDefault(data, data2, random, random2, both=True) ) == 'compensated': func_dr = treecorr_func_dict['nn'](treecorr_kwargs) @@ -555,10 +566,10 @@ def compensateDefault(self, data, data2, random, random2, both=False): indicates that both data sets if present must have randoms; the default, False, means only the first data set must have an associated random. """ - if not random or (random and not len(random)): # No random + if not self.hasObjects(random): # No random return 'simple' - elif both and data2 and len(data2): # Second data set exists and must have a random - if random2 and len(random2): + elif both and self.hasObjects(data2): # Second data set exists and must have a random + if self.hasObjects(random2): return 'compensated' else: return 'simple' diff --git a/tests/test_correlation_functions.py b/tests/test_correlation_functions.py index 90d191f..7206ff6 100755 --- a/tests/test_correlation_functions.py +++ b/tests/test_correlation_functions.py @@ -124,6 +124,21 @@ def test_generator(self): self.assertEqual(type(stile.sys_tests.BaseCorrelationFunctionSysTest()), type(stile.CorrelationFunctionSysTest())) + def test_randoms(self): + """ Run some correlation functions with random catalogs to make sure that also works. """ + stile_args = {'ra_units': 'degrees', 'dec_units': 'degrees', 'min_sep': 0.05, 'max_sep': 1, + 'sep_units': 'degrees', 'nbins': 20} + lens_data = stile.ReadASCIITable('../examples/example_lens_catalog.dat', + fields={'id': 0, 'ra': 1, 'dec': 2, 'z': 3, 'g1': 4, 'g2': 5}) + source_data = stile.ReadASCIITable('../examples/example_source_catalog.dat', + fields={'id': 0, 'ra': 1, 'dec': 2, 'z': 3, 'g1': 4, 'g2': 5}) + + object_list = ['GalaxyShear', 'BrightStarShear', 'StarXGalaxyDensity', 'StarXGalaxyShear', + 'StarXStarShear', 'GalaxyDensityCorrelation', 'StarDensityCorrelation'] + for object_type in object_list: + object_1 = stile.CorrelationFunctionSysTest(object_type) + results = object_1(lens_data, source_data, lens_data, source_data, config=stile_args) + if __name__=='__main__': unittest.main()