# Source code for fastspecfit.templates

"""
fastspecfit.templates
=====================

Tools for handling templates.

"""
import os
import numpy as np
import numba

import fitsio
from astropy.table import Table

from fastspecfit.logger import log

# Module-level defaults for the stellar velocity dispersion.  They are bound as
# default argument values of ``Templates.__init__`` below.
VDISP_NOMINAL = 250.  # [km/s]
VDISP_BOUNDS = (75., 500.)  # [km/s]


class Templates:
    """Stellar population synthesis templates for continuum fitting.

    Reads SPS template spectra from a FITS file and pre-caches FFTs for
    velocity-dispersion convolution up to :attr:`MAX_PRE_VDISP` km/s.

    Parameters
    ----------
    template_file : :class:`str` or None, optional
        Full path to the SPS template FITS file. Auto-detected from
        ``FTEMPLATES_DIR`` when ``None``.
    template_version : :class:`str` or None, optional
        Template version string used to build the filename when
        ``template_file`` is ``None``. Defaults to
        :attr:`DEFAULT_TEMPLATEVERSION`.
    imf : :class:`str` or None, optional
        Initial mass function used to build the filename when
        ``template_file`` is ``None``. Defaults to :attr:`DEFAULT_IMF`.
    mintemplatewave : :class:`float` or None, optional
        Minimum wavelength to load into memory (Angstroms). Uses the minimum
        available wavelength when ``None``.
    maxtemplatewave : :class:`float`, optional
        Maximum wavelength to load into memory (Angstroms). Default is
        400 000 Å.
    vdisp_nominal : :class:`float`, optional
        Nominal velocity dispersion in km/s used to pre-broaden the
        templates. Default is :data:`VDISP_NOMINAL`.
    vdisp_bounds : tuple of float, optional
        ``(min, max)`` velocity dispersion bounds in km/s. Default is
        :data:`VDISP_BOUNDS`.
    fastphot : :class:`bool`, optional
        If ``True``, load in photometry-only mode. Default is ``False``.
        The constructor accepts this flag but does not currently read it.
    read_linefluxes : :class:`bool`, optional
        If ``True``, also read the model emission-line flux arrays. Default
        is ``False``.

    Raises
    ------
    IOError
        If the template file cannot be located, or if it lacks the
        ``DUSTFLUX`` / ``AGNFLUX`` extensions.

    """
    # SPS template constants (used by build-templates)
    # https://github.com/cconroy20/fsps/tree/master/SPECTRA/C3K#readme
    PIXKMS = 25.  # [km/s]
    # Wavelength range (same units as the template wavelength vector) inside
    # which the templates are convolved with the velocity dispersion.
    PIXKMS_BOUNDS = (2750., 9100.)

    AGN_PIXKMS = 75.  # [km/s]
    AGN_PIXKMS_BOUNDS = (1075., 3090.)

    DEFAULT_TEMPLATEVERSION = '2.0.0'
    DEFAULT_IMF = 'chabrier'

    # highest vdisp for which we attempt to use cached FFTs
    MAX_PRE_VDISP = 500.

    def __init__(self, template_file=None, template_version=None, imf=None,
                 mintemplatewave=None, maxtemplatewave=40e4,
                 vdisp_nominal=VDISP_NOMINAL, vdisp_bounds=VDISP_BOUNDS,
                 fastphot=False, read_linefluxes=False):

        self.init_ffts()

        if template_file is None:
            if template_version is None:
                template_version = Templates.DEFAULT_TEMPLATEVERSION
            if imf is None:
                imf = Templates.DEFAULT_IMF
            template_file = self.get_templates_filename(
                template_version=template_version, imf=imf)

        if not os.path.isfile(template_file):
            errmsg = f'Templates file {template_file} not found.'
            log.critical(errmsg)
            raise IOError(errmsg)

        self.file = template_file

        # Do all of the file I/O inside a context manager so the FITS handle
        # is always released, even if one of the reads fails.
        with fitsio.FITS(template_file) as T:
            # Fail early (before the expensive convolutions below) when the
            # mandatory extensions are absent.
            if 'DUSTFLUX' not in T or 'AGNFLUX' not in T:
                errmsg = f'Templates file {template_file} missing mandatory extensions DUSTFLUX and AGNFLUX.'
                log.critical(errmsg)
                raise IOError(errmsg)

            templatewave = T['WAVE'].read()          # [npix]
            templateflux = T['FLUX'].read()          # [npix x nsed]
            templatelineflux = T['LINEFLUX'].read()  # [npix x nsed]
            templateinfo = T['METADATA'].read()
            templatehdr = T['METADATA'].read_header()
            self.version = T[0].read_header()['VERSION']

            # Dust emission spectrum; it is used as stored in the file (no
            # renormalization is applied here).
            dustflux = T['DUSTFLUX'].read()

            # NOTE: AGNFLUX is required to be present (see the check above),
            # but the AGN and Fe templates are not loaded into the object.

            # Read the model emission-line fluxes; only present for
            # template_version>=1.1.1 and generally only useful to a
            # power-user.
            if read_linefluxes:
                self.linewaves = T['LINEWAVES'].read()
                self.linefluxes = T['LINEFLUXES'].read()

        # Store the spectra as [nsed x npix], contiguous in memory.
        templateflux = np.transpose(templateflux).copy()
        templatelineflux = np.transpose(templatelineflux).copy()

        self.imf = templatehdr['IMF']
        self.ntemplates = len(templateinfo)

        # Trim the wavelength range kept in memory.
        if mintemplatewave is None:
            mintemplatewave = np.min(templatewave)

        keeplo = np.searchsorted(templatewave, mintemplatewave, 'left')
        keephi = np.searchsorted(templatewave, maxtemplatewave, 'right')
        self.wave = templatewave[keeplo:keephi]
        self.flux = templateflux[:, keeplo:keephi]
        self.flux_nolines = self.flux - templatelineflux[:, keeplo:keephi]
        self.npix = len(self.wave)

        # dust attenuation curve
        self.dust_klambda = Templates.klambda(self.wave)

        self.vdisp_nominal = vdisp_nominal  # [km/s]
        self.vdisp_bounds = vdisp_bounds    # [km/s]

        # Pixel indices bracketing the wavelength range that gets convolved.
        self.pixkms_bounds = np.searchsorted(
            self.wave, Templates.PIXKMS_BOUNDS, 'left')

        # Cache the FFTs and the templates broadened to the nominal vdisp,
        # both with and without the emission-line contribution.
        self.conv_pre = self.convolve_vdisp_pre(self.flux)
        self.flux_nomvdisp = self.convolve_vdisp(self.flux, vdisp_nominal)

        self.conv_pre_nolines = self.convolve_vdisp_pre(self.flux_nolines)
        self.flux_nolines_nomvdisp = self.convolve_vdisp(
            self.flux_nolines, vdisp_nominal)

        self.info = Table(templateinfo)
        self.dustflux = dustflux[keeplo:keephi]

    def init_ffts(self):
        """Configure the FFT backend and set the convolution function.

        Sets ``self.convolve`` to :func:`scipy.signal.oaconvolve` by default.
        When the ``mkl_fft`` package is available, its SciPy FFT backend is
        installed globally and ``self.convolve`` is set to
        :func:`scipy.signal.convolve` instead.

        """
        import scipy.fft as sc_fft
        import scipy.signal as sc_sig
        from importlib.util import find_spec

        self.convolve = sc_sig.oaconvolve

        if find_spec("mkl_fft") is None:
            return

        try:
            import mkl_fft._scipy_fft_backend as be
        except ImportError:
            # mkl_fft is installed but does not provide this backend module
            # (e.g. a release with a different layout); keep the default.
            log.debug('mkl_fft found but its SciPy backend could not be imported')
            return

        sc_fft.set_global_backend(be)
        self.convolve = sc_sig.convolve
        log.debug('Using mkl_fft library for FFTs')

    @staticmethod
    def get_templates_filename(template_version, imf):
        """Build the SPS template filename from ``FTEMPLATES_DIR``, version, and IMF.

        Parameters
        ----------
        template_version : :class:`str`
            Template version string, e.g. ``'2.0.0'``.
        imf : :class:`str`
            Initial mass function name, e.g. ``'chabrier'``.

        Returns
        -------
        template_file : :class:`str`
            ``$FTEMPLATES_DIR/<version>/ftemplates-<imf>-<version>.fits``.

        Raises
        ------
        IOError
            If the ``FTEMPLATES_DIR`` environment variable is not set.

        """
        template_dir = os.environ.get('FTEMPLATES_DIR')
        if template_dir is None:
            errmsg = 'Environment variable FTEMPLATES_DIR is not set; cannot locate the templates file.'
            log.critical(errmsg)
            raise IOError(errmsg)

        template_dir = os.path.expandvars(template_dir)
        template_file = os.path.join(
            template_dir, template_version,
            f'ftemplates-{imf}-{template_version}.fits')
        return template_file

    def convolve_vdisp_pre(self, templateflux):
        """Precompute FFT data to accelerate repeated velocity-dispersion convolutions.

        Parameters
        ----------
        templateflux : :class:`numpy.ndarray`, shape (ntemplates, nwavelengths)
            Array of template flux spectra.

        Returns
        -------
        conv_pre : :class:`tuple`
            ``(flux_lohi, ft_flux_mid, fft_len)`` where ``flux_lohi``
            contains the raw fluxes outside ``self.pixkms_bounds`` (low side
            followed by high side), ``ft_flux_mid`` is the real FFT of the
            central wavelength range, and ``fft_len`` is the padded FFT
            length.

        """
        import scipy.fft as sc_fft

        # determine largest kernel we will support
        # based on the maximum supported vdisp.
        sigma = Templates.MAX_PRE_VDISP / Templates.PIXKMS  # [pixels]
        radius = Templates._gaussian_radius(sigma)
        kernel_size = 2 * radius + 1

        lo, hi = self.pixkms_bounds

        # extract the un-convolved ranges of templateflux
        flux_lo = templateflux[:, :lo]
        flux_hi = templateflux[:, hi:]

        # Pad to the full linear-convolution length for the largest kernel so
        # that any supported kernel can be applied without wrap-around.
        flux_mid = templateflux[:, lo:hi]
        mid_len = flux_mid.shape[1]
        fft_len = sc_fft.next_fast_len(mid_len + kernel_size - 1, real=True)
        ft_flux_mid = sc_fft.rfft(flux_mid, n=fft_len)

        return (np.hstack((flux_lo, flux_hi)), ft_flux_mid, fft_len)

    @staticmethod
    def conv_pre_select(conv_pre, rows):
        """Select a subset of templates from a :meth:`convolve_vdisp_pre` result.

        Parameters
        ----------
        conv_pre : :class:`tuple` or None
            Preprocessing structure from :meth:`convolve_vdisp_pre`, or
            ``None``.
        rows : array-like of int
            Indices of the templates to keep.

        Returns
        -------
        conv_pre_subset : :class:`tuple` or None
            Preprocessing data restricted to the selected rows, or ``None``
            if ``conv_pre`` is ``None``.

        """
        if conv_pre is None:
            return None

        flux_lohi, ft_flux_mid, fft_len = conv_pre
        return (flux_lohi[rows, :], ft_flux_mid[rows, :], fft_len)

    def convolve_vdisp_from_pre(self, flux_lohi, ft_flux_mid, flux_len, fft_len, vdisp):
        """Convolve a pre-decomposed spectrum with a velocity dispersion.

        Uses precomputed FFT data for the central wavelength range (from
        :meth:`convolve_vdisp_pre`) and raw fluxes for the peripheral ranges.

        Parameters
        ----------
        flux_lohi : :class:`numpy.ndarray`, shape (flux_len - (hi - lo),)
            Concatenated raw fluxes for the pixels outside
            ``self.pixkms_bounds = (lo, hi)``: the first ``lo`` entries are
            the low side, the remainder the high side.
        ft_flux_mid : :class:`numpy.ndarray` of complex
            FFT of the central wavelength range.
        flux_len : :class:`int`
            Total length of the output flux array.
        fft_len : :class:`int`
            Padded FFT length used for ``ft_flux_mid``.
        vdisp : :class:`float`
            Velocity dispersion in km/s; must be ≤ :attr:`MAX_PRE_VDISP`.
            A non-positive value leaves the spectrum unconvolved, matching
            :meth:`convolve_vdisp`.

        Returns
        -------
        output : :class:`numpy.ndarray`, shape (flux_len,)
            Convolved flux spectrum.

        Raises
        ------
        ValueError
            If ``vdisp`` exceeds :attr:`MAX_PRE_VDISP` (or is NaN). The
            cached FFT is only padded for kernels up to that width; a wider
            kernel would wrap around.

        """
        import scipy.fft as sc_fft

        # Written as "not <=" so that NaN is rejected as well.
        if not vdisp <= Templates.MAX_PRE_VDISP:
            raise ValueError(
                f'vdisp={vdisp} km/s exceeds the maximum pre-cached value '
                f'of {Templates.MAX_PRE_VDISP} km/s.')

        lo, hi = self.pixkms_bounds
        output = np.empty(flux_len)

        if vdisp > 0.:
            sigma = vdisp / Templates.PIXKMS  # [pixels]
            radius = Templates._gaussian_radius(sigma)
            kernel = Templates._gaussian_kernel1d(sigma, radius)

            # compute FFT of Gaussian kernel, then complete convolution
            ft_kernel = sc_fft.rfft(kernel, n=fft_len)
            conv = sc_fft.irfft(ft_flux_mid * ft_kernel, n=fft_len)

            # offset of the middle of the convolution (eqv to mode='same')
            start = len(kernel) // 2
        else:
            # No broadening: just invert the cached transform.
            conv = sc_fft.irfft(ft_flux_mid, n=fft_len)
            start = 0

        output[lo:hi] = conv[start:start + hi - lo]
        output[:lo] = flux_lohi[:lo]
        output[hi:] = flux_lohi[lo:]

        return output

    def convolve_vdisp(self, templateflux, vdisp):
        """Convolve one or more template spectra with a velocity dispersion.

        Only the pixel range ``self.pixkms_bounds`` (derived from
        :attr:`PIXKMS_BOUNDS`) is convolved; pixels outside that range are
        copied unchanged. The convolution itself is delegated to
        ``self.convolve`` (set by :meth:`init_ffts`).

        Parameters
        ----------
        templateflux : :class:`numpy.ndarray`
            Either a 1D spectrum or a 2D array of shape
            ``(ntemplates, nwavelengths)``.
        vdisp : :class:`float`
            Velocity dispersion in km/s. A non-positive value returns a copy
            of the input.

        Returns
        -------
        output : :class:`numpy.ndarray`
            Convolved spectrum or array of spectra, same shape as
            ``templateflux``.

        """
        if vdisp <= 0.:
            return templateflux.copy()

        sigma = vdisp / Templates.PIXKMS  # [pixels]
        radius = Templates._gaussian_radius(sigma)
        kernel = Templates._gaussian_kernel1d(sigma, radius)

        lo, hi = self.pixkms_bounds
        output = np.empty_like(templateflux)

        if templateflux.ndim == 1:
            output[lo:hi] = self.convolve(
                templateflux[lo:hi], kernel, mode='same')
            output[:lo] = templateflux[:lo]
            output[hi:] = templateflux[hi:]
        else:
            # One template (row) at a time, all with the same kernel.
            for ii in range(templateflux.shape[0]):
                output[ii, lo:hi] = self.convolve(
                    templateflux[ii, lo:hi], kernel, mode='same')
            output[:, :lo] = templateflux[:, :lo]
            output[:, hi:] = templateflux[:, hi:]

        return output

    @staticmethod
    def _gaussian_radius(sigma):
        """Return the truncated radius (in pixels) of a Gaussian kernel with stddev ``sigma``.

        The kernel is truncated at four standard deviations, rounded to the
        nearest pixel.

        """
        truncate = 4.
        return int(truncate * sigma + 0.5)

    @staticmethod
    def _gaussian_kernel1d(sigma, radius, order=0):
        """Compute a 1-D Gaussian convolution kernel of width ``2*radius + 1``.

        Parameters
        ----------
        sigma : :class:`float`
            Standard deviation of the Gaussian, in pixels.
        radius : :class:`int`
            Half-width of the kernel, in pixels.
        order : :class:`int`, optional
            ``0`` (default) returns the unit-sum Gaussian itself; a positive
            value returns its ``order``-th derivative.

        Returns
        -------
        kernel : :class:`numpy.ndarray`, shape (2*radius + 1,)
            The sampled kernel.

        Raises
        ------
        ValueError
            If ``order`` is negative.

        """
        if order < 0:
            raise ValueError('order must be non-negative')

        sigma2 = sigma * sigma
        x = np.arange(-radius, radius + 1, dtype=np.float64)
        phi_x = np.exp(-0.5 / sigma2 * x ** 2)
        phi_x /= phi_x.sum()

        if order == 0:
            return phi_x

        # f(x) = q(x) * phi(x) = q(x) * exp(p(x))
        # f'(x) = (q'(x) + q(x) * p'(x)) * phi(x)
        # p'(x) = -x / sigma ** 2
        # Implement q'(x) + q(x) * p'(x) as a matrix operator and apply to the
        # coefficients of q(x)
        exponent_range = np.arange(order + 1)
        q = np.zeros(order + 1)
        q[0] = 1
        D = np.diag(exponent_range[1:], 1)  # D @ q(x) = q'(x)
        P = np.diag(np.ones(order) / -sigma2, -1)  # P @ q(x) = q(x) * p'(x)
        Q_deriv = D + P
        for _ in range(order):
            q = Q_deriv.dot(q)
        q = (x[:, None] ** exponent_range).dot(q)
        return q * phi_x

    @staticmethod
    def klambda(wave):
        """Construct the total-to-selective dust attenuation curve k(lambda).

        The curve is a power law, ``(wave / 5500)**-0.7``.

        Parameters
        ----------
        wave : :class:`numpy.ndarray`
            Rest-frame wavelength array in Angstroms.

        Returns
        -------
        klambda : :class:`numpy.ndarray`
            Total-to-selective attenuation curve, same shape as ``wave``.

        """
        dust_power = -0.7       # power-law slope
        dust_normwave = 5500.   # pivot wavelength
        return (wave / dust_normwave)**dust_power