# Source code for zfinder.plotter

""" Class to plot the results of the zfinder """

import csv
import warnings
from decimal import Decimal, InvalidOperation
from itertools import zip_longest
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from astropy.wcs import WCS
from astropy.io import fits

from zfinder.template import calc_template_params, gaussf, find_lines
from zfinder.fft import calc_fft_params, double_damped_sinusoid
from zfinder.uncertainty import z_uncert
from zfinder.utils import wcs2pix, generate_square_pix_coords, longest_decimal

# Module-import side effect: these filters are installed process-wide.
# NOTE(review): 'Some errors were detected !' looks like the message that
# numpy.genfromtxt emits when rows are skipped with invalid_raise=False
# (used by Plotter.plot_heatmap_fromcsv) -- confirm.
warnings.filterwarnings(action="ignore", message='Some errors were detected !')
# Ignore every warning raised from the astropy.wcs.wcs module.
warnings.filterwarnings("ignore", module='astropy.wcs.wcs')

class Plotter():
    """
    Class to plot the results of zfinder.

    Includes methods for plotting directly from computed arrays and for
    re-plotting from csv files previously exported by zfinder.

    Parameters
    ----------
    showfig : bool, optional
        Whether to show the figure after plotting. Default is True.

    savefig : bool, optional
        Whether to save the figure after plotting. Default is False.

    Attributes
    ----------
    unit_prefixes : dict
        Maps a power-of-ten exponent to its SI prefix, used in axis labels.

    Methods
    -------
    calc_best_z()
        Calculate the best redshift

    plot_chi2()
        Plot the chi-squared vs redshift

    plot_template_flux()
        Plot the template flux

    plot_fft_flux()
        Plot the fft flux

    plot_heatmap()
        Plot a heatmap of the redshifts

    plot_coords()
        Plot the distribution of random points around a target

    _export_template_data()
        Export the Template fit data to a csv file

    _export_fft_data()
        Export the FFT data to a csv file

    plot_chi2_fromcsv()
        Plot the chi-squared vs redshift from a csv file

    plot_template_flux_fromcsv()
        Plot the template flux from a csv file

    plot_fft_flux_fromcsv()
        Plot the fft flux from a csv file

    plot_heatmap_fromcsv()
        Plot a heatmap of the redshifts from a csv file

    plot_coords_fromcsv()
        Plot the distribution of random points from a csv file

    Examples
    --------
    >>> # After csv files exported with zfinder
    >>> source = Plotter()
    >>> source.plot_chi2_fromcsv('template.csv')
    >>> source.plot_template_flux_fromcsv()
    >>>
    >>> source.plot_chi2_fromcsv('fft.csv')
    >>> source.plot_fft_flux_fromcsv()
    >>>
    >>> source.plot_heatmap_fromcsv('template_per_pixel.csv')
    >>> source.plot_heatmap_fromcsv('fft_per_pixel.csv')
    """

    unit_prefixes = {
        -24: 'y', -21: 'z', -18: 'a', -15: 'f', -12: 'p', -9: 'n',
        -6: '\u03BC', -3: 'm', 0: '', 3: 'k', 6: 'M', 9: 'G', 12: 'T',
        15: 'P', 18: 'E', 21: 'Z', 24: 'Y',
    }

    def __init__(self, showfig=True, savefig=False):
        self._showfig = showfig
        self._savefig = savefig
        self._template_best_z = None
        self._fft_best_z = None
        self._round_to = 2  # decimal places used when printing redshifts
        self._freq_exp = None
        self._flux_exp = None

    # ------------------------------------------------------------------
    # Small private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _decimal_places(value):
        """ Return the number of decimal places in the printed form of `value`.

        Unlike ``len(str(value).split('.')[1])`` this also copes with
        integers (``1`` -> 0) and scientific notation (``1e-05`` -> 5).
        Anything that cannot be parsed as a number yields 0.
        """
        try:
            exponent = Decimal(str(value)).as_tuple().exponent
        except InvalidOperation:
            return 0
        if not isinstance(exponent, int):  # nan / inf report a letter code
            return 0
        return max(0, -exponent)

    @staticmethod
    def _as_column(value):
        """ Wrap a scalar in a list so it can be used as a single-cell csv column """
        return [value] if np.ndim(value) == 0 else value

    def _finish_figure(self, filename, **savefig_kwargs):
        """ Save and/or show the current figure according to the constructor flags """
        if self._savefig:
            plt.savefig(filename, **savefig_kwargs)
        if self._showfig:
            plt.show()

    # ------------------------------------------------------------------
    # Plotting from arrays
    # ------------------------------------------------------------------
    def calc_best_z(self, z, chi2, title=None):
        """ Calculate the best redshift

        Parameters
        ----------
        z : list
            Array of redshifts

        chi2 : list
            Array of chi-squared values

        title : string, optional
            Name of the method ('Template' or 'FFT', case-insensitive). When
            given, the best redshift is remembered for the matching flux
            plot. Default is None (nothing is stored).

        Returns
        -------
        best_z : float
            The redshift at the minimum chi-squared
        """
        best_z = z[np.argmin(chi2)]
        # Normalise the case so 'Template', 'template', 'FFT' and 'Fft' all
        # work, and tolerate the documented default of None.
        method = title.upper() if isinstance(title, str) else None
        if method == 'TEMPLATE':
            self._template_best_z = best_z
        elif method == 'FFT':
            self._fft_best_z = best_z
        return best_z

    def plot_chi2(self, z, dz, chi2, title):
        """ Plot the chi-squared vs redshift

        Parameters
        ----------
        z : list
            Array of redshifts

        dz : float
            Redshift step; its number of decimal places sets the rounding
            of the redshift shown in plot titles

        chi2 : list
            Array of chi-squared values

        title : string
            Name of the method, 'Template' or 'FFT'. Used in the plot title
            and the saved file name.
        """
        self._z = z
        self._chi2 = chi2
        min_chi2 = min(chi2)
        best_z = self.calc_best_z(z, chi2, title)
        self._round_to = self._decimal_places(dz)

        plt.figure(figsize=(15, 7))
        plt.plot(z, chi2, color='black', label=r'$\chi^2_r$')
        plt.plot(best_z, min_chi2, 'bo', markersize=5, label='Best Fit')
        plt.title(rf'{title} $\chi^2_r$ = {round(min_chi2, 2)} @ z={round(best_z, self._round_to)}', fontsize=15)
        plt.xlabel('Redshift', fontsize=15)
        plt.ylabel(r'$\chi^2_r$', x=0.01, fontsize=15)
        plt.xticks(fontsize=15)
        plt.yticks(fontsize=15)
        plt.yscale('log')
        plt.legend()
        self._finish_figure(f'{title.lower()}_chi2.png')

    @staticmethod
    def _plot_sslf_lines(frequency, flux):
        """ Helper function to plot sslf lines in the found flux """
        peaks, snrs, scales = find_lines(flux)
        text_offset_high = max(flux)/20
        text_offset_low = 0.4*text_offset_high

        for i, (line, snr, scale) in enumerate(zip(peaks, snrs, scales)):
            x = frequency[line]
            y = flux[line]
            plot, = plt.plot(x, y, 'bo')
            plt.text(x, y+text_offset_high, f'snr={snr}', color='blue')
            plt.text(x, y+text_offset_low, f'scale={scale}', color='blue')
            if i == 0:
                # Label only the first marker so the legend has one entry
                plot.set_label('Lines')

    def plot_template_flux(self, transition, frequency, freq_exp, flux, flux_exp):
        """ Plot the template flux

        Parameters
        ----------
        transition : float
            The transition frequency

        frequency : list
            Array of frequencies

        freq_exp : int
            The exponent of the frequency

        flux : list
            Array of fluxes

        flux_exp : int
            The exponent of the flux

        Raises
        ------
        ValueError
            If no Template best redshift has been stored yet
        """
        # Validate before touching matplotlib so no half-drawn figure leaks
        if self._template_best_z is None:
            raise ValueError("No best redshift found. Run Plotter.plot_chi2() first.")

        self._frequency = frequency
        self._flux = flux
        self._freq_exp = freq_exp
        self._flux_exp = flux_exp

        plt.figure(figsize=(15, 7))
        plt.plot(frequency, np.zeros(len(frequency)), color='black', linestyle=(0, (5, 5)))
        plt.plot(frequency, flux, color='black', drawstyle='steps-mid')

        x0 = transition/(1+self._template_best_z)
        self._params, covars = calc_template_params(frequency, flux, x0)
        self._p_err = np.sqrt(np.diag(covars))  # calculate the error on the parameters
        plt.plot(frequency, gaussf(frequency, *self._params, x0), color='red', label='Template Fit')
        self._plot_sslf_lines(frequency, flux)

        plt.margins(x=0)
        plt.fill_between(frequency, flux, 0, where=(np.array(flux) > 0), color='gold', alpha=0.75, label='Aperture Flux')
        plt.title(f'Template Fit z={round(self._template_best_z, self._round_to)}', fontsize=15)
        plt.xlabel(f'Frequency $({self.unit_prefixes[freq_exp]}Hz)$', fontsize=15)
        plt.ylabel(f'Flux $({self.unit_prefixes[flux_exp]}Jy)$', fontsize=15)
        plt.xticks(fontsize=15)
        plt.yticks(fontsize=15)
        plt.legend()
        self._finish_figure('template_flux.png')

    def plot_fft_flux(self, transition, frequency, ffreq, fflux):
        """ Plot the fft flux

        Parameters
        ----------
        transition : float
            The transition frequency

        frequency : list
            Array of frequencies

        ffreq : list
            Array of fft frequencies

        fflux : list
            Array of fft fluxes

        Raises
        ------
        ValueError
            If no FFT best redshift has been stored yet
        """
        # Validate before touching matplotlib so no half-drawn figure leaks
        if self._fft_best_z is None:
            raise ValueError("No best redshift found. Run Plotter.plot_chi2() first.")

        self._ffreq = ffreq
        self._fflux = fflux

        plt.figure(figsize=(15, 7))
        plt.plot(ffreq, fflux, color='black', drawstyle='steps-mid', label='FFT Flux')
        plt.plot(ffreq, np.zeros(len(fflux)), color='black', linestyle=(0, (5, 5)))

        self._params, covars = calc_fft_params(transition, ffreq, fflux, self._fft_best_z, frequency[0])
        self._p_err = np.sqrt(np.diag(covars))
        plt.plot(ffreq, double_damped_sinusoid(ffreq, *self._params, self._fft_best_z, frequency[0], transition),
            color='red', label='FFT Fit')

        plt.margins(x=0)
        plt.title(f'FFT Fit z={round(self._fft_best_z, self._round_to)}', fontsize=15)
        plt.xlabel('Scale', fontsize=15)
        plt.ylabel('Amplitude', fontsize=15)
        plt.xticks(fontsize=15)
        plt.yticks(fontsize=15)
        plt.legend()
        self._finish_figure('fft_flux.png')

    @staticmethod
    def _export_heatmap_data(filename, mode, data):
        """ Write the rows of `data` to a csv file, followed by one empty row """
        with open(filename, mode, newline='') as f:
            wr = csv.writer(f)
            wr.writerows(data)
            wr.writerow([])  # blank separator line between consecutive tables

    def plot_heatmap(self, ra, dec, hdr, data, size, z, title, aperture_radius, flux_limit, export=False, subsize=None, contfile=None):
        """ Plot a heatmap of the redshifts

        Parameters
        ----------
        ra : list[string]
            Target right ascension

        dec : list[string]
            Target declination

        hdr : astropy.io.fits.header.Header
            The header of the fits file

        data : array
            The data of the fits file. Negative values are clipped to zero
            and the result is summed along axis 0 (unless `contfile` is given).

        size : int
            The size of the square in pixels

        z : numpy.ndarray
            Array of redshifts, one per pixel of the square

        title : string
            The title method of the plot. Either 'Template' or 'FFT'

        aperture_radius : float
            The radius of the aperture in pixels

        flux_limit : float
            Pixels whose flux is below this limit are masked (drawn black)

        export : bool, optional
            Whether to append the velocities and fluxes to
            '<title>_per_pixel.csv'. Default is False.

        subsize : int, optional
            Plot a smaller heatmap of the redshifts. Default is None.

        contfile : string, optional
            Fits file whose data is used for the flux mask instead of the
            summed `data`. Default is None.
        """
        # Calculate the velocities
        target_z = np.take(z, z.size // 2)  # redshift of the target ra and dec
        velocities = 3*10**5*((((1 + target_z)**2 - 1) / ((1 + target_z)**2 + 1)) - (((1 + z)**2 - 1) / ((1 + z)**2 + 1)))  # km/s

        # Need to get x and y coordinates to plot the heatmap with bounds for correct ra and dec
        target_pix_ra_dec = wcs2pix(ra, dec, hdr)
        x, y = generate_square_pix_coords(size, *target_pix_ra_dec, aperture_radius)

        # Mask velocities lower than the flux limit
        if contfile is not None:
            data_summed = fits.getdata(contfile)
        else:
            data_summed = np.sum(np.maximum(data, 0), axis=0)
        uy = np.round(np.unique(y)).astype(int)
        ux = np.round(np.unique(x)).astype(int)
        fluxes = data_summed[uy][:, ux]
        mask = fluxes < flux_limit

        if subsize is not None:
            mask = self._extract_centered_subarray(mask, subsize)
            uy = self._extract_centered_subarray_1d(uy, subsize)
            ux = self._extract_centered_subarray_1d(ux, subsize)

        velocities[mask] = np.nan

        if export:
            csv_name = f'{title.lower()}_per_pixel.csv'
            self._export_heatmap_data(csv_name, 'a', velocities)  # export velocities to csv
            self._export_heatmap_data(csv_name, 'a', fluxes)  # export fluxes to csv

        # Work on a copy: set_bad() on plt.cm.seismic itself would alter the
        # shared colormap for every other plot in the process.
        cmap = plt.cm.seismic.copy()
        cmap.set_bad('black')
        scale_velo = np.nanmax(np.abs(velocities))
        w = WCS(hdr, naxis=2)

        plt.figure(figsize=(7, 5))
        plt.subplot(projection=w)
        hm = plt.imshow(np.flipud(velocities), cmap=cmap, interpolation='nearest', vmin=-scale_velo, vmax=scale_velo,
                extent=[ux[0], ux[-1], uy[0], uy[-1]], origin='lower')
        cbar = plt.colorbar(hm)
        cbar.ax.set_ylabel('km/s', fontsize=15)
        cbar.ax.tick_params(labelsize=15)
        plt.title(f'{title} Per Pixel', fontsize=15)
        plt.xlabel('RA', fontsize=15)
        plt.ylabel('DEC', fontsize=15)
        plt.xticks(fontsize=15)
        plt.yticks(fontsize=15)
        self._finish_figure(f'{title.lower()}_per_pixel.png')

    def plot_coords(self, x_centre, y_centre, x_coords, y_coords, radius, fitsfile=None):
        """ Plot the distribution of random points around a target

        Parameters
        ----------
        x_centre, y_centre : float
            Position of the target, drawn in black

        x_coords, y_coords : list
            Positions of the random points, drawn in blue

        radius : float
            Radius of the circle drawn around the target; also sets the
            axis limits

        fitsfile : string, optional
            Fits file whose header supplies a WCS projection for the axes.
            Default is None (plain axes).
        """
        if fitsfile is not None:
            hdr = fits.getheader(fitsfile)
            w = WCS(hdr, naxis=2)
            fig, ax = plt.subplots(subplot_kw={'projection': w})
        else:
            fig, ax = plt.subplots()

        circ = plt.Circle((x_centre, y_centre), radius, fill=False, color='blue', label='_nolegend_')
        fig.set_figwidth(7)
        fig.set_figheight(7)
        plt.subplots_adjust(left=0.2, right=0.8, top=0.8, bottom=0.2)
        ax.add_patch(circ)
        plt.scatter(x_coords, y_coords, color='blue', label='Random')
        plt.scatter(x_centre, y_centre, color='black', label='Target')
        plt.title(f'{len(x_coords)} random points')
        plt.xlim(-radius-1+x_centre, radius+1+x_centre)
        plt.ylim(-radius-1+y_centre, radius+1+y_centre)
        plt.xlabel('RA', fontsize=15)
        plt.ylabel('DEC', fontsize=15)
        plt.legend(loc='upper left')
        self._finish_figure('Point Distribution.png', dpi=200)

    # ------------------------------------------------------------------
    # Exporting to csv
    # ------------------------------------------------------------------
    @staticmethod
    def _write_method_to_csv(filename, headings, data):
        """ Write a heading row followed by the data rows to a new csv file """
        with open(filename, 'w', newline='') as f:
            wr = csv.writer(f)
            wr.writerow(headings)
            wr.writerows(data)

    def _z_uncert(self, sigma, flux):
        """ Calculate the uncertainty on the best fitting redshift """
        peaks, _, _ = find_lines(flux)
        reduced_sigma = sigma**2 / (len(flux) - 2*len(peaks) - 1)
        neg, pos = z_uncert(self._z, self._chi2, reduced_sigma)
        return neg, pos

    def _calculate_results(self, sigma, best_z, flux):
        """ Collect the single-valued result columns of an export

        Returns
        -------
        results : list[list]
            One single-element list per column: lower z error, rounded best
            z, upper z error, then the first two fitted parameters each
            followed by its error

        exponents : list[list] or None
            The frequency and flux exponents, or None when no frequency
            exponent has been stored
        """
        z_low_err, z_up_err = self._z_uncert(sigma, flux)
        results = [[z_low_err], [round(best_z, self._round_to)], [z_up_err],
            [self._params[0]], [self._p_err[0]], [self._params[1]], [self._p_err[1]]]

        if self._freq_exp is None:
            return results, None
        exponents = [[self._freq_exp], [self._flux_exp]]
        return results, exponents

    def _export_template_data(self, fitsfile, ra, dec, aperture_radius, transition,
            filename='template.csv', sigma=1, flux_uncert=1):
        """ Export the Template fit data to a csv file

        Ensure that Plotter.plot_template_flux() and Plotter.plot_chi2()
        have been run first.

        Parameters
        ----------
        fitsfile, ra, dec, aperture_radius, transition
            Written verbatim to the columns of the same name

        filename : string, optional
            The filename of the csv file. Default is 'template.csv'.

        sigma : float, optional
            The sigma value for the uncertainty on the redshift. Default is 1.

        flux_uncert : float or list, optional
            The uncertainty on the flux. A scalar fills a single cell.
            Default is 1.
        """
        results, exponents = self._calculate_results(sigma, self._template_best_z, self._flux)
        if exponents is None:
            exponents = ([], [])  # leave the exponent columns blank

        # NOTE: the 'dz' column holds the redshift array; the heading is kept
        # as-is because the *_fromcsv readers and existing files rely on it.
        headings = ['z_low_err', 'z', 'z_up_err', 'amp', 'amp_err', 'std_dev', 'std_dev_err',
            'fitsfile', 'ra', 'dec', 'aperture_radius', 'transition',
            'dz', 'chi2_r', 'freq', 'flux', 'flux_uncert', 'freq_exp', 'flux_exp']
        common = [[fitsfile], [ra], [dec], [aperture_radius], [transition]]

        data = [*zip_longest(*results, *common, self._z, self._chi2, self._frequency, self._flux,
            self._as_column(flux_uncert), *exponents, fillvalue='')]
        self._write_method_to_csv(filename, headings, data)

    def _export_fft_data(self, fitsfile, ra, dec, aperture_radius, transition,
            filename='fft.csv', sigma=1, frequency=None, flux=None, flux_uncert=1):
        """ Export the FFT data to a csv file

        Ensure that Plotter.plot_fft_flux() and Plotter.plot_chi2() have
        been run first.

        Parameters
        ----------
        fitsfile, ra, dec, aperture_radius, transition
            Written verbatim to the columns of the same name

        filename : string, optional
            The filename of the csv file. Default is 'fft.csv'.

        sigma : float, optional
            The sigma value for the uncertainty on the redshift. Default is 1.

        frequency : list
            Array of frequencies. Required despite the None default.

        flux : list
            Array of fluxes. Required despite the None default.

        flux_uncert : float or list, optional
            The uncertainty on the fft flux. A scalar fills a single cell.
            Default is 1.
        """
        results, _ = self._calculate_results(sigma, self._fft_best_z, flux)

        # NOTE: the 'dz' column holds the redshift array (see the readers).
        headings = ['z_low_err', 'z', 'z_up_err', 'amp', 'amp_err', 'std_dev', 'std_dev_err',
            'fitsfile', 'ra', 'dec', 'aperture_radius', 'transition',
            'dz', 'chi2_r', 'frequency', 'flux', 'ffreq', 'fflux', 'fflux_uncert']
        common = [[fitsfile], [ra], [dec], [aperture_radius], [transition]]

        data = [*zip_longest(*results, *common, self._z, self._chi2, frequency, flux,
            self._ffreq, self._fflux, self._as_column(flux_uncert), fillvalue='')]
        self._write_method_to_csv(filename, headings, data)

    # ------------------------------------------------------------------
    # Plotting from csv
    # ------------------------------------------------------------------
    def plot_chi2_fromcsv(self, filename):
        """ Plot the chi-squared vs redshift from a csv file

        The method name ('Template' or 'Fft') is taken from the file name up
        to its first '.', ignoring any directory part.
        """
        z, chi2 = np.genfromtxt(filename, delimiter=',', skip_header=1, usecols=(12, 13)).T
        z = z[~np.isnan(z)]
        chi2 = chi2[~np.isnan(chi2)]
        dz = longest_decimal(z)
        title = Path(filename).name.split('.')[0].capitalize()
        self.plot_chi2(z, dz, chi2, title)

    def plot_template_flux_fromcsv(self, filename='template.csv'):
        """ Plot the template flux from a csv file """
        transition, frequency, flux, freq_exp, flux_exp = np.genfromtxt(filename, delimiter=',', skip_header=1, usecols=(11, 14, 15, 17, 18)).T
        transition = transition[0]
        frequency = frequency[~np.isnan(frequency)]
        flux = flux[~np.isnan(flux)]
        freq_exp = int(freq_exp[0])
        flux_exp = int(flux_exp[0])
        self.plot_template_flux(transition, frequency, freq_exp, flux, flux_exp)

    def plot_fft_flux_fromcsv(self, filename='fft.csv'):
        """ Plot the fft flux from a csv file """
        transition, frequency, ffreq, fflux = np.genfromtxt(filename, delimiter=',', skip_header=1, usecols=(11, 14, 16, 17)).T
        transition = transition[0]
        frequency = frequency[~np.isnan(frequency)]
        ffreq = ffreq[~np.isnan(ffreq)]
        fflux = fflux[~np.isnan(fflux)]
        self.plot_fft_flux(transition, frequency, ffreq, fflux)

    def plot_heatmap_fromcsv(self, filename, subsize=None, flux_limit=None, contfile=None):
        """ Plot a heatmap of the redshifts from a csv file

        Parameters
        ----------
        filename : string
            The per-pixel csv file. The method name is taken from the file
            name up to its first '_', ignoring any directory part.

        subsize : int, optional
            Plot only the central `subsize` x `subsize` pixels. Default is None.

        flux_limit : float, optional
            Overrides the flux limit stored in the csv file. Default is None.

        contfile : string, optional
            Passed through to Plotter.plot_heatmap(). Default is None.
        """
        dtypes = [('fitsfile', 'U100'), ('ra', 'U100'), ('dec', 'U100'), ('aperture_radius', 'f8'),
            ('transition', 'f8'), ('size', 'i4'), ('flux_limit', 'f8')]
        fitsfile, ra, dec, aperture_radius, _, size, flux_lim = \
            np.genfromtxt(filename, delimiter=',', skip_header=1, dtype=dtypes, max_rows=2, invalid_raise=False).tolist()
        z = np.genfromtxt(filename, delimiter=',', skip_header=3, skip_footer=size*2)

        if flux_limit is not None:
            flux_lim = flux_limit
        if subsize is not None:
            z = self._extract_centered_subarray(z, subsize)

        hdr = fits.getheader(fitsfile)
        data = fits.getdata(fitsfile)[0]
        title = Path(filename).name.split('_')[0].capitalize()
        self.plot_heatmap(ra, dec, hdr, data, size, z, title,
            aperture_radius, flux_lim, subsize=subsize, contfile=contfile)

    def plot_coords_fromcsv(self, filename='fft_uncertainty.csv'):
        """ Plot the distribution of random points from a csv file """
        x_centre, y_centre, x_coords, y_coords, radius = np.genfromtxt(filename, delimiter=',', skip_header=1, usecols=(1, 2, 3, 4, 8)).T
        fitsfile = np.genfromtxt(filename, delimiter=',', skip_header=1, dtype='U100', usecols=(10)).T
        x_centre = x_centre[0]
        y_centre = y_centre[0]
        x_coords = x_coords[~np.isnan(x_coords)]
        y_coords = y_coords[~np.isnan(y_coords)]
        radius = radius[0]
        fitsfile = fitsfile[0]
        self.plot_coords(x_centre, y_centre, x_coords, y_coords, radius, fitsfile)

    # ------------------------------------------------------------------
    # Array helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _extract_centered_subarray(original_array, new_size):
        """ Extract a centered `new_size` x `new_size` subarray from a 2D array """
        start_row = (original_array.shape[0] - new_size) // 2
        end_row = start_row + new_size
        start_col = (original_array.shape[1] - new_size) // 2
        end_col = start_col + new_size
        return original_array[start_row:end_row, start_col:end_col]

    @staticmethod
    def _extract_centered_subarray_1d(original_array, new_size):
        """ Extract a centered subarray of length `new_size` from a 1D array """
        start_index = (original_array.shape[0] - new_size) // 2
        end_index = start_index + new_size
        return original_array[start_index:end_index]