Skip to content
This repository has been archived by the owner on Sep 11, 2024. It is now read-only.

Astrocat bugs #102

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
152 changes: 110 additions & 42 deletions astrodbkit/astrocat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@
import astropy.table as at
import astropy.coordinates as coord
import datetime
from sklearn.cluster import DBSCAN
from collections import Counter
from scipy.stats import norm
# from scipy.stats import norm
from astroquery.vizier import Vizier
from astroquery.xmatch import XMatch
from sklearn.externals import joblib
from astropy.coordinates import SkyCoord
import pandas as pd
from bokeh.plotting import ColumnDataSource, figure, output_file, show
from bokeh.io import output_notebook, show
from sklearn.cluster import DBSCAN
from sklearn.externals import joblib

Vizier.ROW_LIMIT = -1

Expand All @@ -33,7 +36,7 @@ def __init__(self, name='Test'):
The name of the database
"""
self.name = name
self.catalog = pd.DataFrame(columns=('id','ra','dec','flag','datasets'))
self.sources = pd.DataFrame(columns=('id', 'ra', 'dec', 'flag', 'datasets'))
self.n_sources = 0
self.history = "{}: Database created".format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
self.catalogs = {}
Expand All @@ -49,7 +52,40 @@ def info(self):
"""
print(self.history)

def add_source(self, ra, dec, flag='', radius=10*q.arcsec):
def plot(self, cat_name, x, y, **kwargs):
    """
    Plot the named columns of the given attribute if the value is a pandas.DataFrame

    Parameters
    ----------
    cat_name: str
        The attribute name
    x: str
        The name of the column to plot on the x-axis
    y: str
        The name of the column to plot on the y-axis
    **kwargs
        Additional keyword arguments forwarded to bokeh.plotting.figure
        (e.g. title, plot_width, plot_height)
    """
    # Get the attribute; bail out with a message rather than raising so this
    # stays safe to call interactively with a typo'd name
    if isinstance(cat_name, str) and hasattr(self, cat_name):
        attr = getattr(self, cat_name)
    else:
        print('No attribute named', cat_name)
        return

    # Only DataFrames can be fed to a ColumnDataSource here
    # (use the public pd.DataFrame alias, not the private pd.core.frame path)
    if isinstance(attr, pd.DataFrame):
        ds = ColumnDataSource(attr)

        # Forward user kwargs to the figure (previously accepted but ignored)
        fig = figure(**kwargs)
        fig.xaxis.axis_label = x
        fig.yaxis.axis_label = y
        fig.circle(x, y, source=ds)

        # notebook_handle=True lets callers push updates in a Jupyter session;
        # the handle itself is not needed here, so it is not retained
        show(fig, notebook_handle=True)

    else:
        print(cat_name, 'is not a Pandas DataFrame!')
        return

def add_source(self, ra, dec, flag='', radius=10*q.arcsec, catalogs={}):
"""
Add a source to the catalog manually and find data in existing catalogs

Expand All @@ -63,9 +99,12 @@ def add_source(self, ra, dec, flag='', radius=10*q.arcsec):
A flag for the source
radius: float
The cross match radius for the list of catalogs
catalogs: dict
Additional catalogs to search, e.g.
catalogs={'TMASS':{'cat_loc':'II/246/out', 'id_col':'id', 'ra_col':'RAJ2000', 'dec_col':'DEJ2000'}}
"""
# Get the id
id = int(len(self.catalog)+1)
# Set the id
self.n_sources += 1

# Check the coordinates
ra = ra.to(q.deg)
Expand All @@ -74,10 +113,16 @@ def add_source(self, ra, dec, flag='', radius=10*q.arcsec):

# Search the catalogs for this source
for cat_name,params in self.catalogs.items():
self.Vizier_query(params['cat_loc'], cat_name, ra, dec, radius, ra_col=params['ra_col'], dec_col=params['dec_col'], append=True, group=False)
self.Vizier_query(params['cat_loc'], cat_name, ra, dec, radius, ra_col=params['ra_col'], dec_col=params['dec_col'], append=True, force_id=self.n_sources, group=False)

# Search additional catalogs
for cat_name,params in catalogs.items():
if cat_name not in self.catalogs:
self.Vizier_query(params['cat_loc'], cat_name, ra, dec, radius, ra_col=params['ra_col'], dec_col=params['dec_col'], force_id=self.n_sources, group=False)

# Add the source to the catalog
self.catalog = self.catalog.append([id, ra.value, dec.value, flag, datasets], ignore_index=True)
new_cat = pd.DataFrame([[self.n_sources, ra.value, dec.value, flag, datasets]], columns=self.sources.columns)
self.sources = self.sources.append(new_cat, ignore_index=True)

def delete_source(self, id):
"""
Expand All @@ -89,18 +134,18 @@ def delete_source(self, id):
The id of the source in the catalog
"""
# Set the index
self.catalog.set_index('id')
self.sources.set_index('id')

# Exclude the unwanted source
self.catalog = self.catalog[self.catalog.id!=id]
self.sources = self.sources[self.sources.id!=id]

# Remove the records from the catalogs
for cat_name in self.catalogs:
new_cat = getattr(self, cat_name)[getattr(self, cat_name).source_id!=id]
print('{} records removed from {} catalog'.format(int(len(getattr(self, cat_name))-len(new_cat)), cat_name))
setattr(self, cat_name, new_cat)

def ingest_data(self, data, cat_name, id_col, ra_col='_RAJ2000', dec_col='_DEJ2000', cat_loc='', append=False, count=-1):
def ingest_data(self, data, cat_name, id_col, ra_col='_RAJ2000', dec_col='_DEJ2000', cat_loc='', append=False, delimiter='\t', force_id='', count=-1):
"""
Ingest a data file and regroup sources

Expand All @@ -120,6 +165,8 @@ def ingest_data(self, data, cat_name, id_col, ra_col='_RAJ2000', dec_col='_DEJ20
The location of the original catalog data
append: bool
Append the catalog rather than replace
force_id: int
Assigns a specific id in the catalog
count: int
The number of table rows to add
(This is mainly for testing purposes)
Expand All @@ -133,7 +180,7 @@ def ingest_data(self, data, cat_name, id_col, ra_col='_RAJ2000', dec_col='_DEJ20

if isinstance(data, str):
cat_loc = cat_loc or data
data = pd.read_csv(data, sep='\t', comment='#', engine='python')[:count]
data = pd.read_csv(data, sep=delimiter, comment='#', engine='python')[:count]

elif isinstance(data, pd.core.frame.DataFrame):
cat_loc = cat_loc or type(data)
Expand Down Expand Up @@ -167,7 +214,7 @@ def ingest_data(self, data, cat_name, id_col, ra_col='_RAJ2000', dec_col='_DEJ20
data.insert(0,'catID', ['{}_{}'.format(cat_name,n+1) for n in range(last,last+len(data))])
data.insert(0,'dec_corr', data['dec'])
data.insert(0,'ra_corr', data['ra'])
data.insert(0,'source_id', np.nan)
data.insert(0,'source_id', force_id or np.nan)

print('Ingesting {} rows from {} catalog...'.format(len(data),cat_name))

Expand All @@ -185,7 +232,7 @@ def ingest_data(self, data, cat_name, id_col, ra_col='_RAJ2000', dec_col='_DEJ20
except AttributeError:
print("No catalog named '{}'. Set 'append=False' to create it.".format(cat_name))

def inventory(self, source_id):
def inventory(self, source_id, return_inventory=False):
"""
Look at the inventory for a given source

Expand All @@ -203,15 +250,30 @@ def inventory(self, source_id):
print('Please enter an integer between 1 and',self.n_sources)

else:

print('Source:')
print(at.Table.from_pandas(self.catalog[self.catalog['id']==source_id]).pprint())

# Empty inventory
inv = {}

# Add the record from the source table
inv['source'] = at.Table.from_pandas(self.sources[self.sources['id']==source_id])

for cat_name in self.catalogs:
cat = getattr(self, cat_name)
rows = cat[cat['source_id']==source_id]
if not rows.empty:
print('\n{}:'.format(cat_name))
at.Table.from_pandas(rows).pprint()
inv[cat_name] = at.Table.from_pandas(rows)

if return_inventory:

# Return the data
return inv

else:

# Print out the data in each catalog
for cat_name, data in inv.items():
print('\n',cat_name,':')
data.pprint()

def _catalog_check(self, cat_name, append=False):
"""
Expand Down Expand Up @@ -262,7 +324,7 @@ def SDSS_spectra_query(self, cat_name, ra, dec, radius, group=True, **kwargs):
if self._catalog_check(cat_name):

# Prep the current catalog as an astropy.QTable
tab = at.Table.from_pandas(self.catalog)
tab = at.Table.from_pandas(self.sources)

# Cone search Vizier
print("Searching SDSS for sources within {} of ({}, {}). Please be patient...".format(viz_cat, radius, ra, dec))
Expand All @@ -280,7 +342,7 @@ def SDSS_spectra_query(self, cat_name, ra, dec, radius, group=True, **kwargs):
if len(self.catalogs)>1 and group:
self.group_sources(self.xmatch_radius)

def Vizier_query(self, viz_cat, cat_name, ra, dec, radius, ra_col='RAJ2000', dec_col='DEJ2000', columns=["**"], append=False, group=True, **kwargs):
def Vizier_query(self, viz_cat, cat_name, ra, dec, radius, ra_col='RAJ2000', dec_col='DEJ2000', columns=["**", "+_r"], append=False, force_id='', group=True, nrows=-1, **kwargs):
"""
Use astroquery to search a catalog for sources within a search cone

Expand All @@ -304,6 +366,8 @@ def Vizier_query(self, viz_cat, cat_name, ra, dec, radius, ra_col='RAJ2000', dec
The list of columns to pass to astroquery
append: bool
Append the catalog rather than replace
force_id: int
Assigns a specific id in the catalog
"""
# Verify the cat_name
if self._catalog_check(cat_name, append=append):
Expand All @@ -312,24 +376,28 @@ def Vizier_query(self, viz_cat, cat_name, ra, dec, radius, ra_col='RAJ2000', dec
print("Searching {} for sources within {} of ({}, {}). Please be patient...".format(viz_cat, radius, ra, dec))
crds = coord.SkyCoord(ra=ra, dec=dec, frame='icrs')
V = Vizier(columns=columns, **kwargs)
V.ROW_LIMIT = -1
V.ROW_LIMIT = nrows

try:
data = V.query_region(crds, radius=radius, catalog=viz_cat)[0]

# Add the link to original record
data['record'] = ['http://vizier.u-strasbg.fr/viz-bin/VizieR-5?-ref=VIZ5b17f9660734&-out.add=.&-source={}&recno={}'.format(viz_cat,n+1) for n in range(len(data))]

except:
print("No data found in {} within {} of ({}, {}).".format(viz_cat, radius, ra, dec))
return

# Ingest the data
self.ingest_data(data, cat_name, 'id', ra_col=ra_col, dec_col=dec_col, cat_loc=viz_cat, append=append)
self.ingest_data(data, cat_name, 'id', ra_col=ra_col, dec_col=dec_col, cat_loc=viz_cat, append=append, force_id=force_id)

# Regroup
if len(self.catalogs)>1 and group:
self.group_sources(self.xmatch_radius)

def Vizier_xmatch(self, viz_cat, cat_name, ra_col='_RAJ2000', dec_col='_DEJ2000', radius='', group=True):
"""
Use astroquery to pull in and cross match a catalog with sources in self.catalog
Use astroquery to pull in and cross match a catalog with sources in self.sources

Parameters
----------
Expand All @@ -341,7 +409,7 @@ def Vizier_xmatch(self, viz_cat, cat_name, ra_col='_RAJ2000', dec_col='_DEJ2000'
The matching radius
"""
# Make sure sources have been grouped
if self.catalog.empty:
if self.sources.empty:
print('Please run group_sources() before cross matching.')
return

Expand All @@ -351,7 +419,7 @@ def Vizier_xmatch(self, viz_cat, cat_name, ra_col='_RAJ2000', dec_col='_DEJ2000'
viz_cat = "vizier:{}".format(viz_cat)

# Prep the current catalog as an astropy.QTable
tab = at.Table.from_pandas(self.catalog)
tab = at.Table.from_pandas(self.sources)

# Crossmatch with Vizier
print("Cross matching {} sources with {} catalog. Please be patient...".format(len(tab), viz_cat))
Expand Down Expand Up @@ -413,12 +481,12 @@ def group_sources(self, radius='', plot=False):
unique_coords = np.asarray([np.mean(coords[source_ids==id], axis=0) for id in list(set(source_ids))])

# Generate a source catalog
self.catalog = pd.DataFrame(columns=('id','ra','dec','flag','datasets'))
self.catalog['id'] = unique_source_ids
self.catalog[['ra','dec']] = unique_coords
self.catalog['flag'] = [None]*len(unique_source_ids)
# self.catalog['flag'] = ['d{}'.format(i) if i>1 else '' for i in Counter(source_ids).values()]
self.catalog['datasets'] = Counter(source_ids).values()
self.sources = pd.DataFrame(columns=('id','ra','dec','flag','datasets'))
self.sources['id'] = unique_source_ids
self.sources[['ra','dec']] = unique_coords
self.sources['flag'] = [None]*len(unique_source_ids)
# self.sources['flag'] = ['d{}'.format(i) if i>1 else '' for i in Counter(source_ids).values()]
self.sources['datasets'] = Counter(source_ids).values()

# Update history
self.history += "\n{}: Catalog grouped with radius {} arcsec.".format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), self.xmatch_radius)
Expand Down Expand Up @@ -495,7 +563,7 @@ def load(self, path):
DB = joblib.load(path)

# Load the attributes
self.catalog = DB.catalog
self.sources = DB.catalog
self.n_sources = DB.n_sources
self.name = DB.name
self.history = DB.history
Expand Down Expand Up @@ -535,11 +603,11 @@ def correct_offsets(self, cat_name, truth='ACS'):
else:

# First, remove any previous catalog correction
self.catalog.loc[self.catalog['cat_name']==cat_name, 'ra_corr'] = self.catalog.loc[self.catalog['cat_name']==cat_name, '_RAJ2000']
self.catalog.loc[self.catalog['cat_name']==cat_name, 'dec_corr'] = self.catalog.loc[self.catalog['cat_name']==cat_name, '_DEJ2000']
self.sources.loc[self.sources['cat_name']==cat_name, 'ra_corr'] = self.sources.loc[self.sources['cat_name']==cat_name, '_RAJ2000']
self.sources.loc[self.sources['cat_name']==cat_name, 'dec_corr'] = self.sources.loc[self.sources['cat_name']==cat_name, '_DEJ2000']

# Copy the catalog
onc_gr = self.catalog.copy()
onc_gr = self.sources.copy()

# restrict to one-to-one matches, sort by oncID so that matches are paired
o2o_new = onc_gr.loc[(onc_gr['oncflag'].str.contains('o')) & (onc_gr['cat_name'] == cat_name) ,:].sort_values('oncID')
Expand Down Expand Up @@ -582,8 +650,8 @@ def correct_offsets(self, cat_name, truth='ACS'):

# Update the coordinates of the appropriate sources
print('Shifting {} sources by {}" in RA and {}" in Dec...'.format(cat_name,mu_ra,mu_dec))
self.catalog.loc[self.catalog['cat_name']==cat_name, 'ra_corr'] += mu_ra
self.catalog.loc[self.catalog['cat_name']==cat_name, 'dec_corr'] += mu_dec
self.sources.loc[self.sources['cat_name']==cat_name, 'ra_corr'] += mu_ra
self.sources.loc[self.sources['cat_name']==cat_name, 'dec_corr'] += mu_dec

# Update history
now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
Expand Down Expand Up @@ -625,7 +693,7 @@ def default_rename_columns(cat_name):
defaults = {'2MASS':{'JD':'epoch', 'Qflg':'flags', 'Jmag':'2MASS.J', 'Hmag':'2MASS.H', 'Kmag':'2MASS.Ks', 'e_Jmag':'2MASS.J_unc', 'e_Hmag':'2MASS.H_unc', 'e_Kmag':'2MASS.Ks_unc'},
'WISE':{'qph':'flags', 'W1mag':'WISE.W1', 'W2mag':'WISE.W2', 'W3mag':'WISE.W3', 'W4mag':'WISE.W4', 'e_W1mag':'WISE.W1_unc', 'e_W2mag':'WISE.W2_unc', 'e_W3mag':'WISE.W3_unc', 'e_W4mag':'WISE.W4_unc'},
'SDSS':{'ObsDate':'epoch', 'flags':'oflags', 'Q':'flags', 'umag':'SDSS.u', 'gmag':'SDSS.g', 'rmag':'SDSS.r', 'imag':'SDSS.i', 'zmag':'SDSS.z', 'e_umag':'SDSS.u_unc', 'e_gmag':'SDSS.g_unc', 'e_rmag':'SDSS.r_unc', 'e_imag':'SDSS.i_unc', 'e_zmag':'SDSS.z_unc'},
'TGAS':{'Epoch':'epoch', 'Plx':'parallax', 'e_Plx':'parallax_unc'}}
'GAIA':{'Epoch':'epoch', 'Plx':'parallax', 'e_Plx':'parallax_unc', 'Gmag':'Gaia.G', 'e_Gmag':'Gaia.G_unc', 'BPmag':'Gaia.BP', 'e_BPmag':'Gaia.BP_unc', 'RPmag':'Gaia.RP', 'e_RPmag':'Gaia.RP_unc'}}

return defaults[cat_name]

Expand All @@ -646,7 +714,7 @@ def default_column_fill(cat_name):
defaults = {'2MASS':{'publication_shortname':'Cutr03', 'telescope_id':2, 'instrument_id':5, 'system_id':2},
'WISE':{'publication_shortname':'Cutr13', 'telescope_id':3, 'instrument_id':6, 'system_id':2},
'SDSS':{'publication_shortname':'Alam15', 'telescope_id':6, 'instrument_id':9, 'system_id':2},
'TGAS':{'publication_shortname':'Gaia16', 'telescope_id':4, 'instrument_id':7, 'system_id':1}}
'GAIA':{'publication_shortname':'Gaia18', 'telescope_id':4, 'instrument_id':7, 'system_id':1}}

return defaults[cat_name]

12 changes: 10 additions & 2 deletions astrodbkit/astrodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def create_database(dbpath, schema='', overwrite=True):

# Load the schema if given
if schema:
os.system("cat {} | sqlite3 {}".format(schema,dbpath))
os.system("cat {} | sqlite3 {}".format(schema, dbpath))

# Otherwise just make an empty SOURCES table
else:
Expand Down Expand Up @@ -326,8 +326,16 @@ def add_data(self, data, table, delimiter='|', bands='', clean_up=True, rename_c
# Rename columns
if isinstance(rename_columns,str):
rename_columns = astrocat.default_rename_columns(rename_columns)

try_again = []
for input_name,new_name in rename_columns.items():
data.rename_column(input_name,new_name)
try:
data.rename_column(input_name,new_name)
except KeyError:
try_again.append(input_name)

for input_name in try_again:
data.rename_column(input_name,rename_columns[input_name])

# Add column fills
if isinstance(column_fill,str):
Expand Down
Loading