Source code for nwaylib.fastskymatch

"""
Functions for finding pairs within a certain search distance "err".

Very fast method based on hashing.
"""
from __future__ import division, print_function

import itertools
import os
from collections import defaultdict

import astropy.io.fits as pyfits
import astropy.units as u
import healpy
import joblib
import numpy
import tqdm
from astropy.coordinates import SkyCoord, SkyOffsetFrame
from numpy import arccos, arcsin, arctan2, cos, exp, hypot, log, pi, sin

cachedir = 'cache'
os.makedirs(cachedir, exist_ok=True)
mem = joblib.Memory(cachedir, verbose=False)


def dist(apos, bpos):
    """
    Angular separation (in degrees) between two points on a sphere.

    http://en.wikipedia.org/wiki/Great-circle_distance
    """
    (a_ra, a_dec), (b_ra, b_dec) = apos, bpos
    lon1 = a_ra / 180 * pi
    lat1 = a_dec / 180 * pi
    lon2 = b_ra / 180 * pi
    lat2 = b_dec / 180 * pi
    sdlon = sin(lon2 - lon1)
    cdlon = cos(lon2 - lon1)
    slat1 = sin(lat1)
    slat2 = sin(lat2)
    clat1 = cos(lat1)
    clat2 = cos(lat2)
    num1 = clat2 * sdlon
    num2 = clat1 * slat2 - slat1 * clat2 * cdlon
    denominator = slat1 * slat2 + clat1 * clat2 * cdlon
    return arctan2(hypot(num1, num2), denominator) * 180 / pi
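# Illustrative example (not part of the module; coordinates are made up):
# dist() works elementwise on scalars or numpy arrays of RA/Dec in degrees.
#
#     sep = dist((10.0, 20.0), (10.0, 21.0))
#     # -> ~1.0: one degree of declination offset at the same RA
#     seps = dist((numpy.array([10.0, 50.0]), numpy.array([0.0, 0.0])),
#                 (numpy.array([11.0, 50.0]), numpy.array([0.0, 1.0])))
#     # -> array of per-pair separations, here ~[1.0, 1.0]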
def dist3d(apos, bpos):
    """
    Angular separation between two points on a sphere,
    split into RA & Dec offset components.

    Returns (separation, dra, ddec), all in degrees.
    """
    (a_ra, a_dec), (b_ra, b_dec) = apos, bpos
    a_ra = numpy.where(a_ra == -99, numpy.nan, a_ra)
    a_dec = numpy.where(a_dec == -99, numpy.nan, a_dec)
    b_ra = numpy.where(b_ra == -99, numpy.nan, b_ra)
    b_dec = numpy.where(b_dec == -99, numpy.nan, b_dec)
    with numpy.errstate(divide='ignore', invalid='ignore'):
        a = SkyCoord(a_ra, a_dec, frame="icrs", unit="deg")
        b = SkyCoord(b_ra, b_dec, frame="icrs", unit="deg")
        localframe = SkyOffsetFrame(origin=a)
        na = a.transform_to(localframe)
        nb = b.transform_to(localframe)
        dra = na.lon - nb.lon
        ddec = na.lat - nb.lat
        separation = a.separation(b).to(u.degree).value
        dra = dra.to(u.degree).value
        ddec = ddec.to(u.degree).value
        return separation, dra, ddec
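# Illustrative example (not part of the module; values are made up): dist3d()
# additionally returns the RA/Dec offsets in a tangent-plane frame centred
# on the first position; -99 placeholder coordinates propagate as NaN.
#
#     sep, dra, ddec = dist3d((numpy.array([150.0]), numpy.array([2.0])),
#                             (numpy.array([150.001]), numpy.array([2.0004])))
#     # sep: total separation in degrees; dra, ddec: longitude/latitude
#     # offsets in degrees within the SkyOffsetFrame of the first position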
def get_tablekeys(table, name, tablename=''):
    """Return the column of `table` whose name best matches `name`
    (exact match preferred, then prefix match). `name` is compared
    in upper case (e.g. 'RA', 'DEC')."""
    keys = sorted(table.dtype.names, key=lambda k: 0 if k.upper() == name else 1 if k.upper().startswith(name) else 2)
    assert len(keys) > 0 and name in keys[0].upper(), 'ERROR: No "%s" column found in input catalogue "%s". Only have: %s' % (name, tablename, ', '.join(table.dtype.names))
    return keys[0]
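# Illustrative example (not part of the module): with columns named
# ('ID', 'RA', 'RA_ERR', 'DEC'), the exact match 'RA' beats the prefix
# match 'RA_ERR':
#
#     tbl = numpy.zeros(1, dtype=[('ID', int), ('RA', float),
#                                 ('RA_ERR', float), ('DEC', float)])
#     get_tablekeys(tbl, 'RA')   # -> 'RA'
#     get_tablekeys(tbl, 'DEC')  # -> 'DEC'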
def get_healpix_resolution_degrees(nside):
    """Largest distance (in degrees) still fully covered by a healpix
    pixel of the given nside together with its neighbors."""
    resol = healpy.pixelfunc.nside2resol(nside) / pi * 180
    # according to monte carlo simulations, distances up to this factor
    # are completely contained within the pixel and its neighbors
    resfactor = 0.7
    return resfactor * resol
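# Illustrative example (not part of the module): for nside=64, healpy's
# nside2resol gives ~0.92 deg, so this returns roughly 0.64 deg, the
# largest separation guaranteed to stay within a pixel plus its neighbors.
#
#     get_healpix_resolution_degrees(64)   # -> ~0.64 (degrees)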
@mem.cache(ignore=['logger'])
def crossproduct(radectables, err, logger, pairwise_errs=[]):
    # check whether all sources are well away from the poles and from the
    # RA=0 wrap-around; if so, a flat-sky approximation is safe
    use_flat_bins = True
    for ra, dec in radectables:
        if not (err < 1 and (ra > 10 * err).all() and (ra < 360 - 10 * err).all() and (numpy.abs(dec) < 45).all()):
            use_flat_bins = False
            break

    if use_flat_bins:
        logger.log('matching: using fast flat-sky approximation for this match')
    else:
        # choose appropriate nside for err (in deg)
        nside = 1
        for nside_next in range(30):
            # largest distance still contained within pixels
            dist_neighbors_complete = get_healpix_resolution_degrees(2**nside_next)
            # we are looking for a pixel size which ensures bigger distances than
            # the error radius, but we want the smallest pixels possible,
            # to reduce the cartesian product
            if dist_neighbors_complete < err:
                # too small, do not accept:
                # sources within err would lie outside the neighbor pixels
                break
            nside = 2**nside_next

        resol = get_healpix_resolution_degrees(nside) * 60 * 60
        logger.log('matching: healpix hashing on pixel resolution ~ %f arcsec (nside=%d)' % (resol, nside))

    buckets = defaultdict(lambda: [[] for _ in range(len(radectables))])
    primary_cat_keys = None

    pbar = tqdm.tqdm(total=sum([len(t[0]) for t in radectables]))
    for ti, (ra_table, dec_table) in enumerate(radectables):
        if use_flat_bins:
            for ei, (ra, dec) in enumerate(zip(ra_table, dec_table)):
                i, j = int(ra / err), int(dec / err)
                # put in bucket, and neighbors
                for jj, ii in (j, i), (j, i + 1), (j + 1, i), (j + 1, i + 1):
                    k = (ii, jj)
                    # only the primary catalogue is allowed to define new buckets
                    if ti == 0 or k in buckets:
                        buckets[k][ti].append(ei)
                pbar.update()
        else:
            # get healpixels
            ra, dec = ra_table, dec_table
            phi = ra / 180 * pi
            theta = dec / 180 * pi + pi / 2.
            i = healpy.pixelfunc.ang2pix(nside, phi=phi, theta=theta, nest=True)
            j = healpy.pixelfunc.get_all_neighbours(nside, phi=phi, theta=theta, nest=True)
            # considering only the four neighbours in one direction (N) does not
            # work: sometimes A is south of B, but B is east of A. So we need to
            # consider all neighbors, and deduplicate later.
            neighbors = numpy.hstack((i.reshape((-1, 1)), j.transpose()))
            # put in bucket, and neighbors
            if ti == 0:
                # only the primary catalogue is allowed to define new buckets
                for ei, keys in enumerate(neighbors):
                    for k in keys:
                        buckets[k][ti].append(ei)
                    pbar.update()
            else:
                for ei, keys in enumerate(neighbors):
                    for k in keys:
                        if k in primary_cat_keys:
                            buckets[k][ti].append(ei)
                    pbar.update()
        if ti == 0:
            primary_cat_keys = set(buckets.keys())
    pbar.close()

    results = set()
    # now combine within buckets, adding no-counterpart options
    logger.log('matching: collecting from %d buckets, creating cartesian products ...' % len(buckets))
    pbar = tqdm.tqdm(total=len(buckets))
    while buckets:
        k, lists = buckets.popitem()
        pbar.update()
        # add for secondary catalogues the option of a missing source
        for li in lists[1:]:
            li.insert(0, -1)
        # create the cartesian product
        local_results = itertools.product(*[sorted(li) for li in lists])
        # if pairwise filtering is requested, use it to trim down solutions
        if pairwise_errs:
            local_results = numpy.array(list(local_results))
            for tablei, tablej, errij in pairwise_errs:
                indicesi = local_results[:, tablei]
                indicesj = local_results[:, tablej]
                # first find entries that actually have both counterparts
                mask_both = numpy.logical_and(indicesi >= 0, indicesj >= 0)
                # get the RA/Dec
                rai, deci = radectables[tablei]
                raj, decj = radectables[tablej]
                rai, deci = rai[indicesi[mask_both]], deci[indicesi[mask_both]]
                raj, decj = raj[indicesj[mask_both]], decj[indicesj[mask_both]]
                # compute distances
                mask_good = dist((rai, deci), (raj, decj)) < errij * 60 * 60
                # keep the ones where one counterpart is missing, or those
                # within errij. Write back via integer indices: chained boolean
                # indexing (mask_good2[mask_both][mask_good] = True) would
                # assign into a copy and be lost.
                mask_good2 = ~mask_both
                mask_good2[numpy.flatnonzero(mask_both)[mask_good]] = True
                local_results = local_results[mask_good2, :]
            results.update([tuple(li) for li in local_results])
        else:
            results.update(local_results)
        del local_results
    pbar.close()
    n = len(results)
    logger.log('matching: %6d unique matches from cartesian product. sorting ...' % n)
    # now make results unique by sorting
    results = numpy.array(sorted(results))
    return results


# use preferred newer astropy command if available
if hasattr(pyfits.BinTableHDU, 'from_columns'):
    fits_from_columns = pyfits.BinTableHDU.from_columns
else:
    fits_from_columns = pyfits.new_table
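# Illustrative use of crossproduct() (not part of the module; coordinates and
# the minimal logger below are made up -- any object with a .log(msg) method
# will do):
#
#     class PrintLogger(object):
#         def log(self, msg):
#             print(msg)
#
#     ra1 = numpy.array([150.0000, 150.1000])
#     dec1 = numpy.array([2.0000, 2.1000])
#     ra2 = numpy.array([150.0002, 151.0000])
#     dec2 = numpy.array([2.0001, 2.5000])
#     pairs = crossproduct([(ra1, dec1), (ra2, dec2)],
#                          err=3.0 / 60 / 60, logger=PrintLogger())
#     # each row holds one index per catalogue; -1 in a secondary-catalogue
#     # column means "no counterpart" for that combination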
def match_multiple(tables, table_names, err, fits_formats, logger, circular=True, pairwise_errs=[]):
    """
    Computes the cartesian product of all possible matches,
    limited to a maximum distance of err (in degrees).

    tables: input FITS tables
    table_names: names of the tables
    fits_formats: FITS data types of the columns of each table

    returns
    results: cartesian product of all possible matches (smaller than err)
    cat_columns: table with separation distances in arcsec
    header: which columns were used in each table for RA/DEC
    """
    logger.log('')
    logger.log('matching with %f arcsec radius' % (err * 60 * 60))
    logger.log('matching: %6d naive possibilities' % numpy.prod([len(t) for t in tables]))

    logger.log('matching: hashing')

    ra_keys = [get_tablekeys(table, 'RA', tablename=tablename) for table, tablename in zip(tables, table_names)]
    logger.log('    using RA columns: %s' % ', '.join(ra_keys))
    dec_keys = [get_tablekeys(table, 'DEC', tablename=tablename) for table, tablename in zip(tables, table_names)]
    logger.log('    using DEC columns: %s' % ', '.join(dec_keys))

    ratables = [(t[ra_key], t[dec_key]) for t, ra_key, dec_key in zip(tables, ra_keys, dec_keys)]

    resultstable = crossproduct(ratables, err, logger=logger, pairwise_errs=pairwise_errs)
    results = resultstable.view(dtype=[(table_name, resultstable.dtype) for table_name in table_names]).reshape((-1,))

    keys = []
    for table_name, table in zip(table_names, tables):
        keys += ["%s_%s" % (table_name, n) for n in table.dtype.names]

    logger.log('merging in %d columns from input catalogues ...' % sum([1 + len(table.dtype.names) for table in tables]))
    cat_columns = []
    pbar = tqdm.tqdm(total=sum([1 + len(table.dtype.names) for table in tables]))
    for table, table_name, fits_format in zip(tables, table_names, fits_formats):
        tbl = table[results[table_name]]
        # entries with index -1 have no counterpart in this catalogue
        mask_missing = results[table_name] == -1
        pbar.update()

        for n, format in zip(table.dtype.names, fits_format):
            k = "%s_%s" % (table_name, n)
            keys.append(k)
            col = tbl[n]
            # mark missing counterparts with -99
            try:
                col[mask_missing] = -99
            except Exception as e:
                logger.log('    setting "%s" to -99 failed (%d affected; column format "%s"): %s' % (k, mask_missing.sum(), format, e))
            fitscol = pyfits.Column(name=k, format=format, array=col)
            cat_columns.append(fitscol)
            pbar.update()
    pbar.close()

    tbhdu = fits_from_columns(pyfits.ColDefs(cat_columns))
    header = dict(
        COLS_RA=' '.join(["%s_%s" % (ti, ra_key) for ti, ra_key in zip(table_names, ra_keys)]),
        COLS_DEC=' '.join(["%s_%s" % (ti, dec_key) for ti, dec_key in zip(table_names, dec_keys)]),
    )

    logger.log('    adding angular separation columns')
    max_separation = numpy.zeros(len(results))
    for i in range(len(tables)):
        a_ra = tbhdu.data["%s_%s" % (table_names[i], ra_keys[i])]
        a_dec = tbhdu.data["%s_%s" % (table_names[i], dec_keys[i])]
        for j in range(i):
            k = "Separation_%s_%s" % (table_names[i], table_names[j])
            k1 = k + "_ra"
            k2 = k + "_dec"
            # the _ra/_dec component columns only exist in the non-circular case
            if circular:
                keys += [k]
            else:
                keys += [k, k1, k2]
            b_ra = tbhdu.data["%s_%s" % (table_names[j], ra_keys[j])]
            b_dec = tbhdu.data["%s_%s" % (table_names[j], dec_keys[j])]

            if circular:
                col = dist((a_ra, a_dec), (b_ra, b_dec))
            else:
                col, col_ra, col_dec = dist3d((a_ra, a_dec), (b_ra, b_dec))
            valid_input = numpy.logical_and(a_ra != -99, b_ra != -99)
            assert not numpy.isnan(col[valid_input]).any(), [
                '%d distances are nan' % numpy.isnan(col[valid_input]).sum(),
                a_ra[numpy.isnan(col)], a_dec[numpy.isnan(col)],
                b_ra[numpy.isnan(col)], b_dec[numpy.isnan(col)]]
            col[a_ra == -99] = numpy.nan
            col[b_ra == -99] = numpy.nan
            if not circular:
                col_ra[a_ra == -99] = numpy.nan
                col_ra[b_ra == -99] = numpy.nan
                col_dec[a_ra == -99] = numpy.nan
                col_dec[b_ra == -99] = numpy.nan

            max_separation = numpy.nanmax([col * 60 * 60, max_separation], axis=0)
            # store distances in arcsec
            cat_columns.append(pyfits.Column(name=k, format='E', array=col * 60 * 60))
            if not circular:
                cat_columns.append(pyfits.Column(name=k1, format='E', array=col_ra * 60 * 60))
                cat_columns.append(pyfits.Column(name=k2, format='E', array=col_dec * 60 * 60))

    cat_columns.append(pyfits.Column(name="Separation_max", format='E', array=max_separation))
    cat_columns.append(pyfits.Column(name="ncat", format='I', array=(resultstable > -1).sum(axis=1)))
    keys.append("Separation_max")

    mask = max_separation < err * 60 * 60
    for c in cat_columns:
        c.array = c.array[mask]

    logger.log('matching: %6d matches after filtering by search radius' % mask.sum())
    logger.log('')

    return results[mask], cat_columns, header
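# Illustrative call (not part of the module; file and table names are
# hypothetical, and all-float columns are assumed for the 'E' formats).
# fits_formats supplies one FITS format code per column of each table,
# matching the zip() over table.dtype.names above. PrintLogger is the
# minimal logger sketched earlier.
#
#     opt = pyfits.open('optical.fits')[1].data   # catalogue with RA/DEC
#     xrs = pyfits.open('xray.fits')[1].data      # catalogue with RA/DEC
#     fits_formats = [['E'] * len(opt.dtype.names),
#                     ['E'] * len(xrs.dtype.names)]
#     results, cat_columns, header = match_multiple(
#         [opt, xrs], ['OPT', 'XRAY'], err=5.0 / 60 / 60,
#         fits_formats=fits_formats, logger=PrintLogger())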
def wraptable2fits(cat_columns, extname):
    tbhdu = fits_from_columns(pyfits.ColDefs(cat_columns))
    hdu = pyfits.PrimaryHDU()
    import datetime
    import time
    now = datetime.datetime.fromtimestamp(time.time())
    nowstr = now.isoformat()
    nowstr = nowstr[:nowstr.rfind('.')]
    hdu.header['DATE'] = nowstr
    hdu.header['ANALYSIS'] = 'NWAY matching'
    tbhdu.header['EXTNAME'] = extname
    hdulist = pyfits.HDUList([hdu, tbhdu])
    return hdulist
def array2fits(table, extname):
    cat_columns = pyfits.ColDefs([pyfits.Column(name=n, format='E', array=table[n])
                                  for n in table.dtype.names])
    return wraptable2fits(cat_columns, extname)
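# Illustrative usage (not part of the module; the output file name is made up):
#
#     tbl = numpy.zeros(2, dtype=[('RA', float), ('DEC', float)])
#     hdulist = array2fits(tbl, 'MATCHES')
#     hdulist.writeto('example_out.fits', overwrite=True)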