# Source code for space_phot.util

from __future__ import annotations

from typing import Optional, Tuple

import numpy as np

try:
    import crds
except Exception as e:
    crds = None

import warnings
import os
import glob
import shutil
import urllib.request

import numpy as np
import scipy
import stpsf
import sncosmo

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

import astropy
from astropy import units as u
from astropy import wcs
from astropy.io import fits
from astropy.table import Table, vstack
from astropy.coordinates import SkyCoord
from astropy.stats import sigma_clipped_stats
from astropy.time import Time
from astropy.nddata import extract_array
from astropy.wcs.utils import skycoord_to_pixel, pixel_to_skycoord

import gwcs
from gwcs.utils import make_fitswcs_transform
from gwcs import coordinate_frames as cf
from astropy import coordinates as coord
from asdf import AsdfFile

import photutils
from photutils.aperture import (
    CircularAperture,
    CircularAnnulus,
    aperture_photometry,
)
from photutils.psf import EPSFModel

import jwst
#from jwst import datamodels
from jwst.pipeline import Detector1Pipeline, Image2Pipeline, Image3Pipeline
from jwst.associations import asn_from_list
from jwst.associations.lib.rules_level2_base import DMSLevel2bBase
from jwst.associations.lib.rules_level3_base import DMS_Level3_Base
from jwst import datamodels

#from jwst.pipeline import Detector1Pipeline, Image2Pipeline, Image3Pipeline
#from jwst.associations import asn_from_list
#from jwst.associations.lib.rules_level3_base import DMS_Level3_Base

from .wfc3_photometry.psf_tools.PSFUtils import make_models
from .wfc3_photometry.psf_tools.PSFPhot import get_standard_psf

# Public names exported by ``from space_phot.util import *``.
__all__ = [
    'get_jwst_psf',
    'get_hst_psf',
    'get_jwst3_psf',
    'get_hst3_psf',
    'get_jwst_psf_grid',
    'get_jwst_psf_from_grid',
]


def mjd_dict_from_list(filelist, tolerance=0):
    """
    Group a list of FITS files by their observation date.

    The MJD of each file is read from the primary-header keyword ``MJD-AVG``.
    If that keyword is absent, ``EXPSTART`` is used instead. The value is
    rounded to ``tolerance`` decimal places, and files that share a rounded
    value are grouped together.

    Parameters
    ----------
    filelist : list of str
        List of FITS filenames.
    tolerance : int, optional
        Number of decimal places to keep when rounding the MJD (passed to
        ``np.round``). Files whose rounded MJD match are grouped together.

    Returns
    -------
    mjd_dict : dict
        Dictionary mapping rounded MJD values to lists of filenames, in the
        order the files appear in ``filelist``.

    Raises
    ------
    KeyError
        If a primary header contains neither ``MJD-AVG`` nor ``EXPSTART``.
    """
    mjd_dict = {}
    for fname in filelist:
        with fits.open(fname) as dat:
            header = dat[0].header
            try:
                mjd = header["MJD-AVG"]
            except KeyError:
                # Only a missing keyword triggers the fallback; any other
                # error (e.g. an unreadable file) should surface to the caller.
                mjd = header["EXPSTART"]

        mjd_key = np.round(mjd, tolerance)
        mjd_dict.setdefault(mjd_key, []).append(fname)

    return mjd_dict


def filter_dict_from_list(
    filelist,
    sky_location=None,
    ext=1,
    buffer=0.0,
):
    """
    Group FITS files by FILTER keyword, with an optional check that a
    given sky position lies on the detector and is at least `buffer`
    pixels away from its edges.

    The FILTER keyword is read from extension ``ext``. If it is missing
    there (and ``ext`` is not the primary HDU), the primary header is tried.
    Files with no FILTER keyword in either place are skipped rather than
    grouped under ``None``.

    Parameters
    ----------
    filelist : list of str
        FITS filenames.
    sky_location : astropy.coordinates.SkyCoord or None, optional
        If provided, only include files in which this sky position can
        be transformed to valid detector pixel coordinates.
    ext : int, optional
        FITS extension to read the WCS / FILTER keyword from.
    buffer : float, optional
        Required minimum distance (in pixels) between the derived pixel
        coordinate and all image edges. Default: 0 (no edge exclusion).

        A file is only included if:
            buffer <= x <= nx - 1 - buffer
            buffer <= y <= ny - 1 - buffer

    Returns
    -------
    filt_dict : dict
        Dictionary mapping filter name -> list of filenames.

    Notes
    -----
    Files that cannot be opened or parsed are silently ignored
    (best-effort grouping).
    """
    filt_dict = {}

    for fname in filelist:
        try:
            with fits.open(fname) as hdul:
                header = hdul[ext].header
                filt = header.get("FILTER")
                if filt is None and ext != 0:
                    # Fall back to the primary header when the chosen
                    # extension does not carry FILTER itself.
                    filt = hdul[0].header.get("FILTER")
                if filt is None:
                    # No filter information anywhere: cannot be grouped.
                    continue

                if sky_location is not None:
                    try:
                        w = wcs.WCS(header)
                        x, y = w.world_to_pixel(sky_location)
                    except Exception:
                        # WCS transform failed -> exclude file
                        continue
                    # Reject NaNs or infs
                    if not (np.isfinite(x) and np.isfinite(y)):
                        continue

                    # Pixel must lie inside detector and outside edge buffer
                    ny, nx = hdul[ext].data.shape
                    inside = (
                        buffer <= x <= nx - 1 - buffer
                        and buffer <= y <= ny - 1 - buffer
                    )
                    if not inside:
                        continue

                # Passed all checks -> add file
                filt_dict.setdefault(filt, []).append(fname)

        except Exception:
            # unreadable FITS, bad header, etc -- ignore (best effort)
            continue

    return filt_dict



def stpsf_setup_sim_to_match_file(filename_or_HDUList, verbose=True, plot=False,dateobs=None):
    """
    Set up an stpsf instrument instance configured to match a JWST exposure.

    The instrument, filter, aperture, and observation date (used to load the
    wavefront OPD) are read from the primary header. For NIRCam coronagraphy,
    MIRI coronagraphy, and MIRI LRS, the image- and pupil-plane masks are
    also set from the header.

    Parameters
    ----------
    filename_or_HDUList : str or astropy.io.fits.HDUList
        Path to a FITS file, or an already opened HDUList. Only the primary
        header is used.
    verbose : bool, optional
        Print progress/configuration messages, and pass ``verbose`` on to
        ``load_wss_opd_by_date``.
    plot : bool, optional
        Passed to ``load_wss_opd_by_date``.
    dateobs : optional
        Date used to select the OPD. Defaults to ``DATE-OBS`` + "T" +
        ``TIME-OBS`` from the header, converted with ``astropy.time.Time``.

    Returns
    -------
    inst
        The configured ``stpsf`` instrument object.
    """
    if not isinstance(filename_or_HDUList, str):
        header = filename_or_HDUList[0].header
        if verbose:
            print("Setting up sim to match provided FITS HDUList object")
    else:
        if verbose:
            print(f"Setting up sim to match {filename_or_HDUList}")
        header = fits.getheader(filename_or_HDUList)

    inst = stpsf.instrument(header['INSTRUME'])

    # stpsf has no model of the MIRI LRS prism response; F770W is used instead.
    lrs_prism = inst.name == 'MIRI' and header['FILTER'] == 'P750L'
    if lrs_prism:
        print("Please note, stpsf does not currently model the LRS spectral response. Setting filter to F770W instead.")
        inst.filter = 'F770W'
    else:
        inst.filter = header['filter']
    inst.set_position_from_aperture_name(header['APERNAME'])

    if dateobs is None:
        dateobs = astropy.time.Time(header['DATE-OBS'] + "T" + header['TIME-OBS'])
    inst.load_wss_opd_by_date(dateobs, verbose=verbose, plot=plot)

    # per-instrument specializations
    if inst.name == 'NIRCam':
        if header['PUPIL'].startswith('MASK'):
            inst.pupil_mask = header['PUPIL']
            # stpsf labels these masks 'MASK...' where the header has 'MASKA...'
            inst.image_mask = header['CORONMSK'].replace('MASKA', 'MASK')
    elif inst.name == 'MIRI':
        miri_coron_masks = {
            'F1065C': 'FQPM1065',
            'F1140C': 'FQPM1140',
            'F1550C': 'FQPM1550',
            'F2300C': 'LYOT2300',
        }
        if inst.filter in miri_coron_masks:
            inst.image_mask = miri_coron_masks[inst.filter]
        elif lrs_prism:
            inst.pupil_mask = 'P750L'
            if header['APERNAME'] == 'MIRIM_SLIT':
                inst.image_mask = 'LRS slit'

    # TODO add other per-instrument keyword checks

    if verbose:
        print(f"""
Configured simulation instrument for:
    Instrument: {inst.name}
    Filter: {inst.filter}
    Detector: {inst.detector}
    Apername: {inst.aperturename}
    Det. Pos.: {inst.detector_position} {'in subarray' if "FULL" not in inst.aperturename else ""}
    Image plane mask: {inst.image_mask}
    Pupil plane mask: {inst.pupil_mask}
    """)

    return inst

def get_jwst_psf_grid(st_obs,num_psfs=16,fname=None,dateobs=None):
    """
    Build an stpsf PSF grid matched to a JWST exposure.

    Parameters
    ----------
    st_obs : observation object
        Its first entry of ``exposure_fnames`` is used when ``fname`` is None.
    num_psfs : int, optional
        Number of PSFs in the grid, passed to ``inst.psf_grid``.
    fname : str, optional
        Exposure whose header configures the simulated instrument.
    dateobs : optional
        Date used to select the OPD (see ``stpsf_setup_sim_to_match_file``).

    Returns
    -------
    The grid returned by ``inst.psf_grid`` for the exposure's detector only.
    """
    source_file = st_obs.exposure_fnames[0] if fname is None else fname
    inst = stpsf_setup_sim_to_match_file(source_file, dateobs=dateobs, verbose=False)
    return inst.psf_grid(num_psfs=num_psfs, all_detectors=False, verbose=False)
def get_jwst_psf_from_grid(st_obs,sky_location,grid,psf_width=101):
    """
    Sample a PSF grid at the source position of every exposure.

    NOTE: ``grid.oversampling`` is overwritten with ``(1, 1)`` on the
    caller's grid object (side effect).

    Parameters
    ----------
    st_obs : observation object
        Must provide ``n_exposures`` and ``wcs_list``.
    sky_location : SkyCoord
        Source position.
    grid : PSF grid model
        Callable model with settable ``x_0``, ``y_0`` and ``oversampling``.
    psf_width : int, optional
        The stamp spans ``4 * psf_width`` integer offsets (about
        ``4*psf_width + 1`` samples) around the rounded source position.

    Returns
    -------
    list of photutils.psf.ImagePSF
        One model per exposure, each normalised to sum to 16 and declared
        with ``oversampling=4``.
    """
    # NOTE(review): forcing oversampling to (1, 1) appears intended to make
    # integer-offset evaluation return the grid's stored samples, which are
    # then declared 4x oversampled below -- confirm.
    grid.oversampling = (1, 1)
    models = []
    for exp_idx in range(st_obs.n_exposures):
        image_wcs = st_obs.wcs_list[exp_idx]
        xpix, ypix = astropy.wcs.utils.skycoord_to_pixel(sky_location, image_wcs)
        grid.x_0 = xpix
        grid.y_0 = ypix

        offsets = np.arange(-4*psf_width/2, psf_width/2*4+1, 1).astype(int)
        # meshgrid: first output varies along columns (x), second along rows (y)
        xf, yf = np.meshgrid(offsets + int(xpix + .5), offsets + int(ypix + .5))

        stamp = np.array(grid(xf, yf)).astype(float)
        stamp /= np.sum(stamp)
        stamp *= 16
        models.append(photutils.psf.ImagePSF(stamp, oversampling=4))
    return models
def get_jwst_psf(st_obs,sky_location,psf_width=61,pipeline_level=2,fname=None,dateobs=None):
    """
    Compute one stpsf PSF model per exposure at the source position.

    Parameters
    ----------
    st_obs : observation object
        Must provide ``exposure_fnames``, ``n_exposures``, ``wcs_list`` and
        ``detector``.
    sky_location : SkyCoord
        Source position.
    psf_width : int, optional
        Field of view of each PSF in detector pixels (``fov_pixels``).
    pipeline_level : int, optional
        3 -> detector-sampled PSF (``oversampling=1``); anything else -> 4x
        oversampled PSF.
    fname : str, optional
        Exposure used to configure the instrument; defaults to the first
        exposure of ``st_obs``.
    dateobs : optional
        Date used to select the OPD.

    Returns
    -------
    list of photutils.psf.ImagePSF
    """
    source_file = st_obs.exposure_fnames[0] if fname is None else fname
    inst = stpsf_setup_sim_to_match_file(source_file, dateobs=dateobs, verbose=False)

    oversampling = 1 if pipeline_level == 3 else 4

    psf_list = []
    for exp_idx in range(st_obs.n_exposures):
        image_wcs = st_obs.wcs_list[exp_idx]
        # skycoord_to_pixel returns (x, y); CreatePSFLibrary's psf_location
        # is given here in (y, x) order.
        xpix, ypix = astropy.wcs.utils.skycoord_to_pixel(sky_location, image_wcs)
        library = stpsf.gridded_library.CreatePSFLibrary(
            inst,
            inst.filter,
            num_psfs=1,
            psf_location=(ypix, xpix),
            fov_pixels=psf_width,
            detectors=st_obs.detector,
            save=False,
            verbose=False,
            add_distortion=True,
            use_detsampled_psf=(oversampling == 1),
        )
        psf_grid = library.create_grid()
        psf_list.append(photutils.psf.ImagePSF(psf_grid.data[0, :, :], oversampling=oversampling))
    return psf_list
def get_jwst3_psf_spike(st_obs,st_obs3,sky_location,temp_outdir='.',verbose=True,psf_width=31):
    """
    Build a drizzled level-3 JWST PSF with the optional ``spike-psf`` package.

    Parameters
    ----------
    st_obs : observation object
        Level-2 observation; its first exposure name is passed to spike.
    st_obs3 : observation object
        Level-3 observation providing ``instrument``, ``pixel_scale`` and ``wcs``.
    sky_location : SkyCoord
        Source position (formatted as "ra dec" in degrees).
    temp_outdir : str, optional
        Working directory handed to spike (``savedir`` and ``removedir``).
    verbose : bool, optional
        Passed to spike.
    psf_width : int, optional
        Cutout field of view (``cutout_fov``).

    Returns
    -------
    Whatever ``spike.psf.jwst`` returns for ``returnpsf='crop'``.

    Raises
    ------
    RuntimeError
        If spike-psf cannot be imported (the ImportError is chained).
    """
    try:
        # Works whether ``psf`` is exposed as an attribute or only as a submodule.
        from spike import psf as spike_psf
    except ImportError as err:
        raise RuntimeError('Must have spike-psf for level 3 psfs.') from err

    exposure = st_obs.exposure_fnames[0]
    # NOTE(review): removedir=temp_outdir with the default '.' -- confirm spike
    # does not delete the current working directory.
    psf_drz = spike_psf.jwst(
        os.path.basename(exposure),
        '%f %f' % (sky_location.ra.value, sky_location.dec.value),
        st_obs3.instrument,
        img_type=os.path.splitext(exposure)[1],
        camera=None,
        method='WebbPSF',
        usermethod=None,
        overwrite=True,
        savedir=temp_outdir,
        drizzleimgs=False,
        objonly=True,
        pretweaked=True,
        usecrds=True,
        keeporig=False,
        plot=False,
        verbose=verbose,
        parallel=False,
        out='fits',
        returnpsf='crop',
        cutout_fov=psf_width,
        savecutout=False,
        finalonly=True,
        removedir=temp_outdir,
        tweakparams={},
        drizzleparams={'pixel_scale': st_obs3.pixel_scale, 'output_wcs': st_obs3.wcs},
    )
    return psf_drz
def get_jwst3_psf(st_obs,st_obs3,sky_location,num_psfs=4,psf_width=31,temp_outdir='.'):
    """
    Build a level-3 (resampled) PSF model for a JWST source.

    A detector-sampled stpsf model (from ``get_jwst_psf`` with
    ``pipeline_level=3``) is injected at the source position into a copy of
    every level-2 exposure in ``st_obs``. Those copies are resampled with
    ``Image3Pipeline`` onto the WCS of the level-3 image ``st_obs3.fname``,
    and a square cutout around the source is returned.

    Parameters
    ----------
    st_obs : observation object
        Level-2 observation (``exposure_fnames``, ``n_exposures``, ``wcs_list``...).
    st_obs3 : observation object
        Level-3 observation; ``st_obs3.fname`` defines the output grid.
    sky_location : SkyCoord
        Source position.
    num_psfs : int, optional
        Unused; kept for backward compatibility.
    psf_width : int, optional
        Passed to ``get_jwst_psf``. The returned cutout spans offsets
        ``-2*psf_width .. 2*psf_width`` (``4*psf_width + 1`` pixels for an
        integer width) around the rounded source pixel.
    temp_outdir : str, optional
        Parent of the scratch directory ``temp_psf_dir`` (removed afterwards).

    Returns
    -------
    photutils.psf.ImagePSF
        Cutout of the resampled PSF image (NaN and negative pixels set to 0),
        declared with ``oversampling=1``.

    Raises
    ------
    RuntimeError
        Re-raised after printing a message if PSF creation fails; scratch
        files are removed in all cases.
    """
    log_cfg = './stpipe-log.cfg'
    # stpipe logging config routing log output to /dev/null while this runs.
    with open(log_cfg, 'w') as f:
        f.write('[*]\nhandler = file:/dev/null\nlevel = INFO\n')

    outdir = os.path.join(temp_outdir, 'temp_psf_dir')
    # Kept inside outdir so it is cleaned up together with the other scratch files.
    ref_wcs_path = os.path.join(outdir, 'ref_wcs.asdf')
    try:
        psfs = get_jwst_psf(st_obs, sky_location, psf_width=psf_width, pipeline_level=3)
        os.makedirs(outdir, exist_ok=True)

        out_fnames = []
        for i, fname in enumerate(st_obs.exposure_fnames):
            with fits.open(fname) as hdul:
                sci = hdul['SCI', 1]
                ny, nx = sci.data.shape
                # skycoord_to_pixel returns (x, y) = (column, row).
                xpix, ypix = skycoord_to_pixel(sky_location, wcs.WCS(sci))
                # mgrid returns (row, column) index grids; ImagePSF is called as
                # model(x, y), so columns are passed as x and rows as y.
                yy, xx = np.mgrid[0:ny, 0:nx]
                psfs[i].x_0 = xpix
                psfs[i].y_0 = ypix
                # Replace the science data with the noiseless model; headers
                # (and hence the WCS) are kept unchanged.
                sci.data = psfs[i](xx, yy)
                hdul.writeto(os.path.join(outdir, os.path.basename(fname)), overwrite=True)
            # The association lives in outdir, so members are listed by basename.
            out_fnames.append(os.path.basename(fname))

        asn = asn_from_list.asn_from_list(out_fnames, rule=DMS_Level3_Base, product_name='temp_psf_cals')
        asn_path = os.path.join(outdir, 'cal_data_asn.json')
        with open(asn_path, "w") as outfile:
            _, serialized = asn.dump(format='json')
            outfile.write(serialized)

        # Resample onto the exact pixel grid of the level-3 reference image.
        with datamodels.open(st_obs3.fname) as ref_dm:
            gw = ref_dm.meta.wcs
            ny3, nx3 = ref_dm.data.shape
            gw.bounding_box = ((0, nx3), (0, ny3))
            AsdfFile({"wcs": gw}).write_to(ref_wcs_path)

        params = {
            'assign_mtwcs': {'skip': True},
            'tweakreg': {'skip': True},
            'skymatch': {'skip': True},
            'outlier_detection': {'skip': True},
            'resample': {
                'pixfrac': 1.,
                'kernel': 'square',
                'fillval': 'indef',
                'weight_type': 'ivm',
                'output_wcs': ref_wcs_path,
                'in_memory': False,
                'save_results': True,
            },
            'source_catalog': {'skip': True},
        }
        Image3Pipeline.call(asn_path, steps=params, output_dir=outdir, save_results=True)

        with fits.open(os.path.join(outdir, 'temp_psf_cals_i2d.fits')) as hdul:
            imwcs = wcs.WCS(hdul['SCI', 1])
            # Copy so the array stays valid after the file is closed.
            level3 = np.array(hdul[1].data)
        level3[np.isnan(level3)] = 0
        level3[level3 < 0] = 0

        xpix, ypix = skycoord_to_pixel(sky_location, imwcs)
        offsets = np.arange(-4*psf_width/2, psf_width/2*4+1, 1).astype(int)
        rows = offsets[:, None] + int(ypix + .5)
        cols = offsets[None, :] + int(xpix + .5)
        level3_psf = photutils.psf.ImagePSF(level3[rows, cols], oversampling=1)
    except RuntimeError as err:
        print('Failed to create PSF model')
        print(err)
        raise
    finally:
        shutil.rmtree(outdir, ignore_errors=True)
        if os.path.exists(log_cfg):
            os.remove(log_cfg)
    return level3_psf
def get_hst_psf_grid(st_obs):
    """
    Return the first model built from the bundled standard WFC3 PSFs.

    The PSFs are looked up in the package's ``wfc3_photometry/psfs``
    directory for ``st_obs.filter`` and ``st_obs.detector``.
    """
    package_dir = os.path.abspath(os.path.dirname(__file__))
    psf_dir = os.path.join(package_dir, 'wfc3_photometry/psfs')
    standard_psf = get_standard_psf(psf_dir, st_obs.filter, st_obs.detector)
    models = make_models(standard_psf)
    return models[0]
def get_hst_psf(st_obs,sky_location,psf_width=25,pipeline_level=2):
    """
    Build one HST PSF model per exposure at the source position.

    The standard WFC3 PSF grid for ``st_obs.filter`` / ``st_obs.detector``
    is evaluated on an oversampled stamp centred on ``sky_location``. The
    stamp is then:

    - normalised to unit sum inside a 5.6-native-pixel radius,
    - multiplied by ``oversampling**2``,
    - multiplied by ``hst_apcorr`` for that radius.

    Parameters
    ----------
    st_obs : observation object
        Must provide ``filter``, ``detector``, ``instrument``, ``px_scale``,
        ``n_exposures`` and ``wcs_list``.
    sky_location : SkyCoord
        Source position.
    psf_width : int, optional
        Stamp width in native pixels. The stamp is ``psf_width *
        oversampling`` oversampled pixels wide, forced to be odd.
    pipeline_level : int, optional
        2 -> ImagePSF declared with the grid's oversampling. Anything else ->
        declared with ``oversampling=1`` (as used by ``get_hst3_psf``).

    Returns
    -------
    list of photutils.psf.ImagePSF
        Models with ``x_0`` / ``y_0`` set to the source pixel (x, y).
    """
    psf_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'wfc3_photometry/psfs')
    grid = make_models(get_standard_psf(psf_dir, st_obs.filter, st_obs.detector))[0]
    _, oversamp = np.array(grid.oversampling, dtype=float)

    # Stamp size in oversampled pixels; made odd so it has a central pixel.
    size_os = psf_width * oversamp
    if size_os % 2 == 0:
        size_os += 1
    half_os = size_os // 2
    # (row, column) offsets in oversampled pixels.
    jj, ii = np.mgrid[-half_os:half_os + 1, -half_os:half_os + 1]
    stamp_shape = (int(size_os), int(size_os))
    # photutils places pixel centres at integer coordinates, so the central
    # pixel of an odd n-pixel stamp is at (n - 1) / 2 == half_os.
    center = half_os
    # 5.6 native pixels expressed in oversampled pixels.
    norm_radius = 5.6 * oversamp

    psf_list = []
    for i in range(st_obs.n_exposures):
        imwcs = st_obs.wcs_list[i]
        # skycoord_to_pixel returns (x, y).
        xpix, ypix = astropy.wcs.utils.skycoord_to_pixel(sky_location, imwcs)
        xpix, ypix = float(xpix), float(ypix)

        # 1 oversampled pixel = 1 / oversamp detector pixels. The grid model
        # interpolates its PSFs for the detector location (x_0, y_0).
        x_coords = xpix + ii / oversamp
        y_coords = ypix + jj / oversamp
        stamp = grid.evaluate(x_coords, y_coords, 1, xpix, ypix).reshape(stamp_shape)

        stamp /= simple_aperture_sum(stamp, [[center, center]], norm_radius)
        stamp *= oversamp ** 2
        stamp *= hst_apcorr(5.6 * st_obs.px_scale, st_obs.filter, st_obs.instrument)

        if pipeline_level == 2:
            psfmodel = photutils.psf.ImagePSF(stamp, oversampling=grid.oversampling)
        else:
            psfmodel = photutils.psf.ImagePSF(stamp, oversampling=1)
        psfmodel.x_0 = xpix
        psfmodel.y_0 = ypix
        psf_list.append(psfmodel)
    return psf_list
def get_hst3_psf(st_obs,st_obs3,sky_location,psf_width=25):
    """
    Build a drizzled (level-3) PSF model for an HST source.

    The models from ``get_hst_psf(..., pipeline_level=3)`` are injected at
    the source position into copies of every level-2 exposure (DQ zeroed,
    ERR set to 1). The copies are combined with AstroDrizzle at
    ``st_obs3.pixel_scale``, and a square cutout around the source is
    returned. The cutout is normalised to unit sum inside 5.6*4 drizzled
    pixels, then multiplied by 16 and by ``hst_apcorr`` for
    ``5.6 * st_obs3.px_scale``.

    Parameters
    ----------
    st_obs : observation object
        Level-2 observation (``exposure_fnames``, ``sci_ext``, ...).
    st_obs3 : observation object
        Level-3 observation (``pixel_scale``, ``px_scale``, ``filter``, ``instrument``).
    sky_location : SkyCoord
        Source position.
    psf_width : int, optional
        Passed to ``get_hst_psf``. The cutout spans offsets
        ``-2*psf_width .. 2*psf_width`` around the rounded source pixel.

    Returns
    -------
    photutils.psf.ImagePSF
        Declared with ``oversampling=4``.

    Raises
    ------
    RuntimeError
        Re-raised after printing a message. The scratch directory is always
        removed.
    """
    import tempfile

    from drizzlepac import astrodrizzle

    oversamp = 4  # factor the returned ImagePSF is declared oversampled by
    psfs = get_hst_psf(st_obs, sky_location, psf_width=psf_width, pipeline_level=3)
    sci_ext = st_obs.sci_ext
    # Private scratch directory: the installed package directory may be
    # read-only, and a random 3-digit suffix could collide between runs.
    outdir = tempfile.mkdtemp(prefix='temp_psf_')
    try:
        out_fnames = []
        for i, fname in enumerate(st_obs.exposure_fnames):
            with fits.open(fname) as hdul:
                sci = hdul['SCI', sci_ext]
                ny, nx = sci.data.shape
                # skycoord_to_pixel returns (x, y) = (column, row).
                xpix, ypix = skycoord_to_pixel(sky_location, wcs.WCS(sci, hdul))
                # ImagePSF is called as model(x, y): columns as x, rows as y.
                yy, xx = np.mgrid[0:ny, 0:nx]
                psfs[i].x_0 = int(xpix) + .5
                psfs[i].y_0 = int(ypix) + .5
                model_image = psfs[i](xx, yy)
                model_image[model_image < 0] = 0
                sci.data = model_image
                hdul['DQ', sci_ext].data = np.zeros((ny, nx)).astype(int)
                hdul['ERR', sci_ext].data = np.ones((ny, nx))
                out_path = os.path.join(outdir, os.path.basename(fname))
                hdul.writeto(out_path, overwrite=True)
            out_fnames.append(out_path)

        astrodrizzle.AstroDrizzle(
            ','.join(out_fnames),
            output=os.path.join(outdir, 'temp_psf'),
            build=True, median=False, skysub=False, sky_bits=None,
            driz_cr_corr=False, final_wht_type='ERR', driz_separate=False,
            driz_cr=False, blot=False, clean=True,
            group='sci,' + str(sci_ext),
            final_scale=st_obs3.pixel_scale,
        )

        drz_path = os.path.join(outdir, 'temp_psf_drz.fits')
        if not os.path.exists(drz_path):
            drz_path = os.path.join(outdir, 'temp_psf_drc.fits')
        with fits.open(drz_path) as hdul:
            imwcs = wcs.WCS(hdul[1], hdul)
            # Copy so the array stays valid after the file is closed.
            level3 = np.array(hdul[1].data)
        level3[np.isnan(level3)] = 0
        level3[level3 < 0] = 0

        xpix, ypix = skycoord_to_pixel(sky_location, imwcs)
        offsets = np.arange(-oversamp*psf_width/2, psf_width/2*oversamp+1, 1).astype(int)
        rows = offsets[:, None] + int(ypix + .5)
        cols = offsets[None, :] + int(xpix + .5)

        # Aperture sum around the source (photutils positions are (x, y)).
        norm = float(np.asarray(simple_aperture_sum(level3, [xpix, ypix], 5.6*oversamp)).ravel()[0])
        apcorr = float(np.asarray(
            hst_apcorr(5.6*st_obs3.px_scale, st_obs3.filter, st_obs3.instrument)
        ).ravel()[0])
        cutout = level3[rows, cols] / norm
        cutout *= oversamp ** 2
        cutout *= apcorr
        level3_psf = photutils.psf.ImagePSF(cutout, oversampling=oversamp)
    except RuntimeError:
        print('Failed to create PSF model')
        raise
    finally:
        shutil.rmtree(outdir, ignore_errors=True)
    return level3_psf
def jwst_apcorr_interp(fname, radius, alternate_ref: Optional[str] = None):
    """
    Interpolate JWST imaging aperture correction as a function of aperture
    radius (pixels), bypassing JWST pipeline step machinery.

    Parameters
    ----------
    fname : str
        JWST image (e.g. *_cal.fits, *_i2d.fits). Its FILTER/PUPIL select the
        APCORR table rows. Unless ``alternate_ref`` is given, it is also used
        to select the APCORR reference file via CRDS.
    radius : float
        Aperture radius in PIXELS.
    alternate_ref : str, optional
        Either an APCORR reference file to use directly (an .asdf/.fits path
        whose name contains 'apcorr'), or another image whose metadata is
        used both to select the APCORR reference and to pick FILTER/PUPIL.

    Returns
    -------
    tuple
        (ee_percent, apcorr, skyin_pix, skyout_pix)
        - ee_percent: interpolated EE in percent (0-100)
        - apcorr: interpolated aperture correction factor
        - skyin_pix, skyout_pix: background annulus radii in pixels, taken
          from the first matching table row

    Raises
    ------
    ValueError
        If ``radius`` is not positive, no rows match, or ``radius`` lies
        outside the tabulated radii.
    """
    radius = float(radius)
    if radius <= 0:
        raise ValueError(f"radius must be positive (pixels), got {radius}")

    # Explicit submodule import: ``import scipy`` alone does not guarantee
    # ``scipy.interpolate`` is loaded on older SciPy versions.
    from scipy import interpolate

    apcorr_path, filt, pup = _select_apcorr_inputs(fname, alternate_ref)

    with datamodels.open(apcorr_path) as apm:
        tab = apm.apcorr_table  # often a numpy recarray
        names = tab.dtype.names

        # Build mask for this filter/pupil
        filt_col = np.array([str(x).upper() for x in tab["filter"]])
        m = filt_col == filt
        if "pupil" in names and pup:
            pup_col = np.array([str(x).upper() for x in tab["pupil"]])
            m &= pup_col == pup
        if not np.any(m):
            # Relax the pupil requirement before giving up.
            m = filt_col == filt
            if not np.any(m):
                raise ValueError(
                    f"No APCORR rows match FILTER={filt!r}, PUPIL={pup!r}. Ref: {apcorr_path}"
                )

        sub = tab[m]
        r = np.array(sub["radius"], dtype=float)  # pixels
        c = np.array(sub["apcorr"], dtype=float)
        ee_frac = np.array(sub["eefraction"], dtype=float)
        skyin = float(np.array(sub["skyin"], dtype=float)[0])
        skyout = float(np.array(sub["skyout"], dtype=float)[0])

    # Sort by radius (interp1d requires monotonic x)
    idx = np.argsort(r)
    r, c, ee_frac = r[idx], c[idx], ee_frac[idx]

    rmin, rmax = float(np.nanmin(r)), float(np.nanmax(r))
    if radius < rmin or radius > rmax:
        raise ValueError(
            f"radius={radius} px is outside APCORR bounds [{rmin}, {rmax}] px "
            f"for FILTER={filt}, PUPIL={pup}. Ref: {apcorr_path}"
        )

    apcorr = float(interpolate.interp1d(r, c)(radius))
    ee_percent = float(interpolate.interp1d(r, ee_frac)(radius) * 100.0)
    return ee_percent, apcorr, skyin, skyout


def _is_apcorr_file(path) -> bool:
    """Return True if ``path`` names an .asdf/.fits file containing 'apcorr'."""
    return bool(path) and path.lower().endswith((".asdf", ".fits")) and "apcorr" in path.lower()


def _select_apcorr_inputs(fname, alternate_ref):
    """Return ``(apcorr_path, FILTER, PUPIL)`` for an APCORR lookup on ``fname``."""
    direct_ref = _is_apcorr_file(alternate_ref)
    # An APCORR reference file is not the exposure being measured, so when
    # one is passed directly the FILTER/PUPIL still come from ``fname``.
    selector = fname if direct_ref else (alternate_ref or fname)
    with datamodels.open(selector) as model:
        apcorr_path = alternate_ref if direct_ref else _get_best_apcorr_reffile(model)
        filt = (model.meta.instrument.filter or "").upper()
        pup = (model.meta.instrument.pupil or "").upper()
    return apcorr_path, filt, pup


def _get_best_apcorr_reffile(model) -> str:
    """Return the local path of the CRDS-selected APCORR reference for ``model``."""
    if crds is None:
        raise ImportError(
            "crds is required to look up the APCORR reference file; "
            "install crds or pass an APCORR file as alternate_ref."
        )
    # get_crds_parameters() (stdatamodels) is what the JWST pipeline hands
    # to CRDS; fall back to the full flattened tree if it is unavailable.
    if hasattr(model, "get_crds_parameters"):
        params = model.get_crds_parameters()
    else:
        params = model.to_flat_dict()
    ctx = crds.get_context_name("jwst")
    return crds.getreferences(params, reftypes=["apcorr"], context=ctx)["apcorr"]


def jwst_apcorr(
    fname: str,
    ee: float = 70,
    alternate_ref: Optional[str] = None,
):
    """
    Lookup JWST imaging aperture correction directly from the APCORR
    reference file, bypassing the JWST pipeline machinery.

    Parameters
    ----------
    fname : str
        JWST image (e.g. *_cal.fits, *_i2d.fits); supplies FILTER/PUPIL.
    ee : float
        Encircled energy percentage (e.g. 70 for 70% EE). Must match a
        tabulated ``eefraction`` (compared with ``np.isclose``).
    alternate_ref : str, optional
        An APCORR file to use directly, or another image used to select the
        APCORR reference (see ``jwst_apcorr_interp``).

    Returns
    -------
    list
        [radius_pix, apcorr, skyin_pix, skyout_pix] from the first matching row.

    Raises
    ------
    ValueError
        If no row matches the filter/pupil/EE.
    """
    ee_fraction = float(ee) / 100.0
    apcorr_path, filt, pup = _select_apcorr_inputs(fname, alternate_ref)

    with datamodels.open(apcorr_path) as apm:
        tab = apm.apcorr_table
        names = tab.dtype.names

        filt_col = np.array([str(x).upper() for x in tab["filter"]])
        ee_col = np.array(tab["eefraction"], dtype=float)
        m = (filt_col == filt) & np.isclose(ee_col, ee_fraction)
        if "pupil" in names and pup:
            pup_col = np.array([str(x).upper() for x in tab["pupil"]])
            m &= pup_col == pup

        if not np.any(m):
            raise ValueError(
                f"No APCORR match for FILTER={filt}, PUPIL={pup}, EE={ee}% "
                f"in ref file {apcorr_path}"
            )

        row = tab[m][0]
        return [
            float(row["radius"]),
            float(row["apcorr"]),
            float(row["skyin"]),
            float(row["skyout"]),
        ]


def estimate_bkg(data,position,inner, outer,model_psf=None,corr=None):
    """
    Return the fraction of a model PSF's flux that falls inside a sky annulus.

    Parameters
    ----------
    data : ndarray
        2D image; only its shape is used, to lay out the evaluation grid.
    position : sequence of 2 floats
        Source position in (y, x) order. It is flipped to photutils' (x, y)
        order for the annulus.
    inner, outer : float
        Inner/outer annulus radii in pixels; ``inner`` must be < ``outer``.
    model_psf : photutils PSF model
        Its ``x_0``/``y_0`` are set to ``position`` (mutated in place).
    corr : optional
        NOTE(review): accepted for API compatibility but not used.

    Returns
    -------
    float
        sum(model inside annulus) / sum(model over the ``data`` grid).

    Raises
    ------
    ValueError
        If neither ``model_psf`` nor ``corr`` is given, or ``inner >= outer``.
    NotImplementedError
        If only ``corr`` is given.
    """
    if model_psf is None and corr is None:
        raise ValueError('Must supply model_psf or corr')
    if not inner < outer:
        raise ValueError(f"inner ({inner}) must be smaller than outer ({outer})")
    if model_psf is None:
        raise NotImplementedError('estimate_bkg currently requires model_psf')

    annulus_aperture = CircularAnnulus(np.flip(position), r_in=inner, r_out=outer)
    annulus_mask = annulus_aperture.to_mask(method='center')

    model_psf.x_0 = position[1]
    model_psf.y_0 = position[0]
    yf, xf = np.mgrid[0:data.shape[0], 0:data.shape[1]].astype(int)
    psf = np.array(model_psf(xf, yf)).astype(float)
    annulus_psf = annulus_mask.multiply(psf)
    return float(np.sum(annulus_psf) / np.sum(psf))


def generic_aperture_phot(data, positions, radius, sky, epadu=1, error=None):
    """
    Perform circular aperture photometry with a local sky annulus.

    Parameters
    ----------
    data : ndarray
        2D image array.
    positions : array_like
        Position or positions in pixel coordinates (x, y). A single (x, y)
        pair is accepted and treated as one position.
    radius : float
        Aperture radius in pixels.
    sky : dict
        Dictionary with keys 'sky_in' and 'sky_out' giving inner and outer
        radii of the sky annulus in pixels.
    epadu : float, optional
        Electrons per ADU for Poisson error estimation.
    error : ndarray or None, optional
        Per-pixel 1-sigma uncertainties. If provided, photutils uses this for
        error propagation. Otherwise errors are estimated from Poisson noise
        and sky scatter.

    Returns
    -------
    phot : astropy.table.Table
        Photometry table with columns 'aperture_sum', 'aperture_sum_err',
        'annulus_median', 'aper_bkg' and 'aper_sum_bkgsub'.
    """
    # Always an (N, 2) array, so that to_mask() returns a list of masks even
    # for a single position.
    positions = np.atleast_2d(positions)
    aperture = CircularAperture(positions, r=radius)
    annulus_aperture = CircularAnnulus(
        positions, r_in=sky["sky_in"], r_out=sky["sky_out"]
    )

    bkg_median = []
    bkg_stdev = []
    for mask in annulus_aperture.to_mask(method="center"):
        try:
            annulus_data = mask.multiply(data)
            annulus_data_1d = annulus_data[mask.data > 0]
            _, median_sigclip, stdev_sigclip = sigma_clipped_stats(annulus_data_1d)
        except Exception:
            # e.g. multiply() returning None when the annulus misses the image
            median_sigclip = np.nan
            stdev_sigclip = np.nan
        bkg_median.append(median_sigclip)
        bkg_stdev.append(stdev_sigclip)

    bkg_median = np.array(bkg_median)
    bkg_stdev = np.array(bkg_stdev)

    phot = aperture_photometry(data, aperture, method="exact", error=error)
    phot["annulus_median"] = bkg_median
    phot["aper_bkg"] = bkg_median * aperture.area
    phot["aper_sum_bkgsub"] = phot["aperture_sum"] - phot["aper_bkg"]

    if error is None:
        # Poisson term on background-subtracted counts; negative net counts
        # (common in forced photometry) contribute no Poisson variance
        # instead of producing NaN.
        error_poisson = np.sqrt(np.clip(phot["aper_sum_bkgsub"], 0, None))
        # Scatter inside the sky annulus
        error_scatter_sky = aperture.area * bkg_stdev**2
        # Error on the mean sky level
        error_mean_sky = bkg_stdev**2 * aperture.area**2 / annulus_aperture.area

        fluxerr = np.sqrt(
            (error_poisson**2) / epadu + error_scatter_sky + error_mean_sky
        )
        phot["aperture_sum_err"] = fluxerr

    return phot


def jwst_aperture_phot(fname,ra,dec, filt,ee='r70'):
    """
    Forced circular-aperture photometry on a JWST image.

    Parameters
    ----------
    fname : str
        JWST image with SCI and ERR extensions in MJy/sr.
    ra, dec : float or str
        Source position. Numeric values are interpreted as degrees; anything
        else is parsed with units (hourangle, deg).
    filt : str
        Unused; kept for backward compatibility.
    ee : str or float, optional
        'rNN' -> radius, aperture correction and sky annulus for NN% encircled
        energy from the APCORR reference (via ``jwst_apcorr``). A number ->
        that radius in pixels, no correction, and sky annulus ee+1 .. ee+3.

    Returns
    -------
    astropy.table.Table
        ``generic_aperture_phot`` output plus 'aper_sum_corrected', 'magerr'
        and 'mag' (AB).
    """
    try:
        coord_args = (float(ra), float(dec))
        unit = u.deg
    except (TypeError, ValueError):
        coord_args = (ra, dec)
        unit = (u.hourangle, u.deg)
    sky_location = SkyCoord(*coord_args, unit=unit)

    if isinstance(ee, str):
        radius, apcorr, skyan_in, skyan_out = jwst_apcorr(fname, int(ee[1:]))
    else:
        radius, apcorr, skyan_in, skyan_out = ee, 1, ee + 1, ee + 3

    with fits.open(fname) as image:
        # Copies, so they remain valid (and writable) after the file closes.
        data = np.array(image['SCI', 1].data)
        err = np.array(image['ERR', 1].data)
        imwcs = wcs.WCS(image[1])

    positions = np.atleast_2d(astropy.wcs.utils.skycoord_to_pixel(sky_location, imwcs))
    data[data < 0] = 0

    sky = {'sky_in': skyan_in, 'sky_out': skyan_out}
    phot = generic_aperture_phot(data, positions, radius, sky, error=err)

    phot['aper_sum_corrected'] = phot['aper_sum_bkgsub'] * apcorr
    phot['aperture_sum_err'] *= apcorr
    # The fractional error is invariant under the aperture correction, so
    # the corrected error is divided by the corrected flux.
    phot['magerr'] = 2.5 * np.log10(1.0 + (phot['aperture_sum_err'] / phot['aper_sum_corrected']))

    pixel_scale = wcs.utils.proj_plane_pixel_scales(imwcs)[0] * imwcs.wcs.cunit[0].to('arcsec')
    flux_units = u.MJy / u.sr * (pixel_scale * u.arcsec)**2
    flux = phot['aper_sum_corrected'] * flux_units
    phot['mag'] = flux.to(u.ABmag).value
    return phot


def hst_apcorr(ap,filt,inst):
    """
    Interpolate the WFC3 encircled-energy fraction for an aperture radius.

    The STScI WFC3 encircled-energy tables are downloaded into the current
    directory on first use. They are interpolated over wavelength and
    aperture radius (columns named 'APER#<radius>').

    Parameters
    ----------
    ap : float
        Aperture radius in arcsec.
    filt : str
        Filter name, resolved with ``sncosmo.get_bandpass`` (retried with a
        'uv' prefix for UVIS bandpass names).
    inst : str
        'ir' for WFC3/IR; anything else uses the UVIS2 table.

    Returns
    -------
    ndarray
        The ``RectBivariateSpline`` evaluation at (effective wavelength, ap),
        shape (1, 1) for scalar input.
    """
    from scipy.interpolate import RectBivariateSpline

    base_url = ('https://www.stsci.edu/files/live/sites/www/files/home/hst/'
                'instrumentation/wfc3/data-analysis/photometric-calibration/')
    if inst == 'ir':
        table_file = 'ir_ee_corrections.csv'
        url = base_url + 'ir-encircled-energy/_documents/ir_ee_corrections.csv'
        wave_col = 'PIVOT'
    else:
        table_file = 'wfc3uvis2_aper_007_syn.csv'
        url = base_url + 'uvis-encircled-energy/_documents/wfc3uvis2_aper_007_syn.csv'
        wave_col = 'WAVELENGTH'

    if not os.path.exists(table_file):
        urllib.request.urlretrieve(url, table_file)
    ee = Table.read(table_file, format='ascii')
    ee.remove_column('FILTER')
    waves = ee[wave_col]
    ee.remove_column(wave_col)

    ee_arr = np.array([ee[col] for col in ee.colnames])
    apps = [float(x.split('#')[1]) for x in ee.colnames]
    interp = RectBivariateSpline(waves, apps, ee_arr.T)
    try:
        filt_wave = sncosmo.get_bandpass(filt).wave_eff
    except Exception:
        filt_wave = sncosmo.get_bandpass('uv' + filt).wave_eff
    return interp(filt_wave, ap)


def hst_get_zp(filt,zpsys='ab'):
    """
    Return the tabulated WFC3/IR zeropoint for ``filt``.

    Raises KeyError for filters not in the IR table; ``hst_aperture_phot``
    relies on this to tell IR from UVIS. Prints a message and returns None
    for an unknown ``zpsys``.
    """
    system = zpsys.lower()
    if system == 'ab':
        return {'F098M': 25.666, 'F105W': 26.264, 'F110W': 26.819,
                'F125W': 26.232, 'F140W': 26.450, 'F160W': 25.936}[filt]
    if system == 'vega':
        return {'F098M': 25.090, 'F105W': 25.603, 'F110W': 26.042,
                'F125W': 25.312, 'F140W': 25.353, 'F160W': 24.662}[filt]
    print('unknown zpsys')
    return None


def hst_aperture_phot(fname,force_ra,force_dec,filt,radius=3, skyan_in=4,skyan_out=8):
    """
    Forced circular-aperture photometry on an HST (WFC3) image.

    Parameters
    ----------
    fname : str
        Image with a ('SCI', 1) extension.
    force_ra, force_dec : float or str
        Source position. Numeric values are interpreted as degrees; anything
        else is parsed with units (hourangle, deg).
    filt : str
        Filter name. Filters tabulated by ``hst_get_zp`` are treated as IR;
        all others as UVIS, using PHOTFLAM/PHOTPLAM from the headers.
    radius, skyan_in, skyan_out : float, optional
        Aperture and sky-annulus radii in pixels.

    Returns
    -------
    astropy.table.Table
        ``generic_aperture_phot`` output plus 'magerr', 'aper_sum_corrected'
        (divided by the encircled-energy fraction) and 'mag' (AB).
    """
    try:
        coord_args = (float(force_ra), float(force_dec))
        unit = u.deg
    except (TypeError, ValueError):
        coord_args = (force_ra, force_dec)
        unit = (u.hourangle, u.deg)
    sky_location = SkyCoord(*coord_args, unit=unit)

    with fits.open(fname) as data_file:
        drc_dat = data_file['SCI', 1]
        data = np.array(drc_dat.data)
        sci_header = drc_dat.header
        primary_header = data_file[0].header
        # Conversion to electrons for the Poisson term: 1 for BUNIT ==
        # 'ELECTRON', otherwise EXPTIME (presumably per-second units -- confirm).
        if data_file[1].header['BUNIT'] == 'ELECTRON':
            epadu = 1
        else:
            epadu = primary_header['EXPTIME']
        imwcs = wcs.WCS(sci_header, data_file)

    x, y = astropy.wcs.utils.skycoord_to_pixel(sky_location, imwcs)
    px_scale = wcs.utils.proj_plane_pixel_scales(imwcs)[0] * imwcs.wcs.cunit[0].to('arcsec')

    try:
        zp = hst_get_zp(filt, 'ab')
        inst = 'ir'
    except KeyError:
        inst = 'uvis'
        if 'PHOTFLAM' in sci_header:
            photflam = sci_header['PHOTFLAM']
        else:
            photflam = primary_header['PHOTFLAM']
        photplam = sci_header['PHOTPLAM']
        zp = -2.5*np.log10(photflam) - 5*np.log10(photplam) - 2.408

    phot = generic_aperture_phot(data, np.atleast_2d([x, y]), radius,
                                 {'sky_in': skyan_in, 'sky_out': skyan_out}, epadu=epadu)
    phot['magerr'] = 1.086 * phot['aperture_sum_err'] / phot['aper_sum_bkgsub']

    # Encircled-energy fraction inside the aperture (radius in arcsec); the
    # spline returns a (1, 1) array, so reduce it to a scalar.
    ee_frac = float(np.squeeze(hst_apcorr(radius * px_scale, filt, inst)))
    phot['aper_sum_corrected'] = phot['aper_sum_bkgsub'] / ee_frac
    phot['aperture_sum_err'] /= ee_frac
    phot['mag'] = -2.5 * np.log10(phot['aper_sum_corrected']) + zp
    return phot


def simple_aperture_sum(data, positions, radius):
    """
    Compute a simple circular-aperture sum at one or more positions.

    Parameters
    ----------
    data : ndarray
        2D image array.
    positions : array_like
        Position or list of positions in pixel coordinates.
        Follows the photutils convention (x, y).
    radius : float
        Aperture radius in pixels.

    Returns
    -------
    aperture_sum : astropy.table.Column
        The 'aperture_sum' column from photutils.aperture_photometry.
    """
    aperture = CircularAperture(positions, r=radius)
    phot = aperture_photometry(data, aperture, method="exact")
    return phot["aperture_sum"]