diff --git a/DynSpecMS/ClassDynSpecMS.py b/DynSpecMS/ClassDynSpecMS.py index 8a3b1e7..51b7c73 100644 --- a/DynSpecMS/ClassDynSpecMS.py +++ b/DynSpecMS/ClassDynSpecMS.py @@ -11,15 +11,12 @@ from DDFacet.Array import shared_dict from DDFacet.Other import AsyncProcessPool -from DDFacet.Other import Multiprocessing from DDFacet.Other import ModColor from DDFacet.Other.progressbar import ProgressBar import numpy as np from astropy.time import Time from DDFacet.Other import ClassTimeIt from astropy import constants as const -import os -from killMS.Other import reformat from DDFacet.Other import AsyncProcessPool from .dynspecms_version import version import glob @@ -37,6 +34,10 @@ import DDFacet.Other.ClassJonesDomains import psutil from . import ClassGiveCatalog +from DynSpecMS.kernels.phase_and_sum_vis import phase_and_sum_direction +from DynSpecMS.kernels.t_idx_jones import compute_jones_diag_for_time, extract_row_jones_jax +from jax import jit, vmap +import jax.numpy as jnp def print_memory_info(): mem_info = psutil.virtual_memory() @@ -255,8 +256,7 @@ def InitFromCatalog(self): print("Selected %i target [out of the %i in the original list]"%(self.NDirSelected,CGC.NOrig), file=log) if self.NDirSelected==0: - print(ModColor.Str(" Have found no sources - returning"), file=log) - self.killWorkers() + print(ModColor.Str(f" Have found no sources within the specified {self.Radius}-degree radius - returning without executing"), file=log) return NOff=self.NOff @@ -326,15 +326,13 @@ def give_iFacet_iTessel(l,m): print_memory_info() try: - shape = (self.NDir, self.NChan, self.NTimesGrid, 4) + shape = (self.NDir, self.NChan, self.NTimesGrid, self.npol_grid) log.print(f"Allocating GridLinPol with shape {shape}, memory usage: {compute_memory_usage_gb(shape):.2f} GB") self.DicoGrids["GridLinPol"] = np.zeros(shape, np.complex128) - shape = (self.NDir, self.NChan, self.NTimesGrid, 4) log.print(f"Allocating GridWeight with shape {shape}, memory usage: 
{compute_memory_usage_gb(shape):.2f} GB") self.DicoGrids["GridWeight"] = np.zeros(shape, np.complex128) - shape = (self.NDir, self.NChan, self.NTimesGrid, 4) log.print(f"Allocating GridWeight2 with shape {shape}, memory usage: {compute_memory_usage_gb(shape):.2f} GB") self.DicoGrids["GridWeight2"] = np.zeros(shape, np.complex128) except Exception as e: @@ -352,16 +350,10 @@ def give_iFacet_iTessel(l,m): if self.BeamModel is not None or self.DDFParset!="": self.DoJonesCorr_Beam=True - AsyncProcessPool.APP=None - # AsyncProcessPool.init(ncpu=self.NCPU, - # num_io_processes=1, - # affinity="disable") - AsyncProcessPool._init_default() - AsyncProcessPool.init((self.NCPU or psutil.cpu_count(logical=False)-2), + self.APP=AsyncProcessPool.init((self.NCPU or psutil.cpu_count(logical=False)-2), affinity=0, num_io_processes=1, verbose=0) - self.APP=AsyncProcessPool.APP self.APP.registerJobHandlers(self) self.APP.startWorkers() @@ -513,21 +505,17 @@ def give_in_points(P,NRand=5): CatOff=CatOff[CatOff.ra!=0] else: - while NDone ["I", "Q"]) + raw_stokes = list(self.options.stokes.upper()) if hasattr(self.options, 'stokes') else ["I", "Q", "U", "V"] + self.stokes_list = list(dict.fromkeys([s for s in raw_stokes if s in "IQUV"])) + + if not self.stokes_list: + raise ValueError("No valid Stokes parameters requested. 
Use combinations of I, Q, U, V.") + + if CorrType == [9, 10, 11, 12]: + if set(self.stokes_list).issubset({'I', 'Q'}): + self.ms_pol_indices = (0, 3) + else: + self.ms_pol_indices = (0, 1, 2, 3) + elif CorrType == [9, 12]: + self.ms_pol_indices = (0, 1) + if 'U' in self.stokes_list or 'V' in self.stokes_list: + raise ValueError(f"Requested {self.stokes_list} but MS only has XX and YY.") else: raise ValueError("Pols should be XX, XY, YX, YY or XX, YY") + + # The new grid size is based on exactly what we keep + self.npol_grid = len(self.ms_pol_indices) + + # For the JAX kernel, it will now process a densely packed array of either 2 or 4 pols + self.kernel_pol_indices = tuple(range(self.npol_grid)) + tp.close() chFreq=tf.getcol("CHAN_FREQ").ravel() @@ -661,15 +672,13 @@ def ReadMSInfos(self): raise ValueError("should have the same chan width") pBAR.render(iMS+1, self.nMS) + if tmin is None or tmax is None: + raise RuntimeError("None of the provided MS files could be read. Please check the logs above for specific errors.") + self.NTimesGrid=int(np.ceil((tmax-tmin)/dtBin)) self.timesGrid=tmin+np.arange(self.NTimesGrid)*dtBin self.tmin=tmin self.tmax=tmax - - for iMS in range(self.nMS): - if not DicoMSInfos[iMS]["Readable"]: - print(ModColor.Str("Problem reading %s"%MSName), file=log) - print(ModColor.Str(" %s"%DicoMSInfos[iMS]["Exception"]), file=log) t.close() @@ -751,13 +760,15 @@ def LoadMS(self,iJob): data = np.zeros((NROW,nch,npol),np.complex64) t.getcolnp(self.ColName,data,ROW0,NROW) if RevertChans: data=data[:,::-1,:] + + data = data[:, :, self.ms_pol_indices] # keep only the pols we need if self.ModelName: print(" Substracting %s from %s"%(self.ModelName,self.ColName), file=log) model=np.zeros((NROW,nch,npol),np.complex64) t.getcolnp(self.ModelName,model,ROW0,NROW) if RevertChans: model=model[:,::-1,:] - + model = model[:, :, self.ms_pol_indices] # Slice model data-=model del(model) @@ -783,6 +794,7 @@ def LoadMS(self,iJob): flag=np.zeros((NROW,nch,npol),bool) 
t.getcolnp("FLAG",flag,ROW0,NROW) if RevertChans: flag=flag[:,::-1] + flag = flag[:, :, self.ms_pol_indices] # Slice flag # data[:,:,:]=0 @@ -845,7 +857,8 @@ def setJones(self,DicoDATA): "At":"tessel", "DtBeamMin":5., "NBand":self.BeamNBand, - "CenterNorm":1} + "CenterNorm":1, + "ForceScalar": False} SolsName=self.SolsName if SolsName is not None and "[" in SolsName: @@ -858,7 +871,8 @@ def setJones(self,DicoDATA): "SolsDir":self.SolsDir, "GlobalNorm":None, "JonesNormList":"AP"}, - "Cache":{"Dir":self.CacheDir} + "Cache":{"Dir":self.CacheDir}, + "Parallel":{"NCPU": self.NCPU} } print("Reading Jones matrices solution file:", file=log) @@ -1065,18 +1079,31 @@ def killWorkers(self): def Finalise(self): - G=self.DicoGrids["GridLinPol"] W=self.DicoGrids["GridWeight"].copy() W[W == 0] = 1 Gn = G/W self.Gn=Gn - GOut=np.zeros_like(G) - GOut[..., 0] = 0.5*(Gn[..., 0] + Gn[..., 3]) # I = 0.5(XX + YY) - GOut[..., 1] = 0.5*(Gn[..., 0] - Gn[..., 3]) # Q = 0.5(XX - YY) - GOut[..., 2] = 0.5*(Gn[..., 1] + Gn[..., 2]) # U = 0.5(XY + YX) - GOut[..., 3] = -0.5j*(Gn[..., 1] - Gn[..., 2]) # V = -0.5i(XY - YX) + # Allocate GOut specifically to the number of requested Stokes parameters + GOut = np.zeros(G.shape[:-1] + (len(self.stokes_list),), dtype=np.complex128) + + for i, s in enumerate(self.stokes_list): + if s == 'I': + if self.npol_grid == 4: + GOut[..., i] = 0.5 * (Gn[..., 0] + Gn[..., 3]) # I = 0.5(XX + YY) + else: + GOut[..., i] = 0.5 * (Gn[..., 0] + Gn[..., 1]) # Tight-packed XX, YY + elif s == 'Q': + if self.npol_grid == 4: + GOut[..., i] = 0.5 * (Gn[..., 0] - Gn[..., 3]) # Q = 0.5(XX - YY) + else: + GOut[..., i] = 0.5 * (Gn[..., 0] - Gn[..., 1]) + elif s == 'U': + GOut[..., i] = 0.5 * (Gn[..., 1] + Gn[..., 2]) # U = 0.5(XY + YX) + elif s == 'V': + GOut[..., i] = -0.5j * (Gn[..., 1] - Gn[..., 2]) # V = -0.5i(XY - YX) + self.GOut = GOut # def Stack_SingleTime(self,DicoDATA,iTime): @@ -1097,21 +1124,19 @@ def Stack_SingleTimeAllDir(self,iJob,iTime): if indRow.size==0: return 
ThisTime=self.DicoMSInfos[iMS]["times"][iTime] - nrow,nch,npol=DicoDATA["data"].shape - indCh=np.int64(np.arange(nch)).reshape((1,nch,1)) - indPol=np.int64(np.arange(npol)).reshape((1,1,npol)) - indR=indRow.reshape((indRow.size,1,1)) - nRowOut=indRow.size - indArr=nch*npol*np.int64(indR)+npol*np.int64(indCh)+np.int64(indPol) + nrow, nch, npol_grid = DicoDATA["data"].shape + indCh = np.int64(np.arange(nch)).reshape((1,nch,1)) + indPol = np.int64(np.arange(npol_grid)).reshape((1,1,npol_grid)) + indR = indRow.reshape((indRow.size,1,1)) + nRowOut = indRow.size + indArr = nch*npol_grid*np.int64(indR) + npol_grid*np.int64(indCh) + np.int64(indPol) - #indRow = np.where(DicoDATA["times"]>0)[0] - #f = DicoDATA["flag"][indRow, :, :] - #d = DicoDATA["data"][indRow, :, :] - - T=ClassTimeIt.ClassTimeIt("SingleTimeAllDir") + T = ClassTimeIt.ClassTimeIt("SingleTimeAllDir") T.disable() - d = np.array((DicoDATA["data"].flat[indArr.flat[:]]).reshape((nRowOut,nch,npol))).copy() - f = np.array((DicoDATA["flag"].flat[indArr.flat[:]]).reshape((nRowOut,nch,npol))).copy() + + # FIX HERE: Replace 'npol' with 'npol_grid' in the reshape tuples + d = np.array((DicoDATA["data"].flat[indArr.flat[:]]).reshape((nRowOut, nch, npol_grid))).copy() + f = np.array((DicoDATA["flag"].flat[indArr.flat[:]]).reshape((nRowOut, nch, npol_grid))).copy() T.timeit("first") # for i in range(10): @@ -1144,168 +1169,64 @@ def Stack_SingleTimeAllDir(self,iJob,iTime): dcorr=d.copy() f0, _ = self.Freq_minmax ich0 = int( (ChanFreqs - f0)/self.ChanWidth ) - OneMinusF=(1-f).copy() - - W=np.zeros((nRowOut,nch,npol),np.float32) - for ipol in range(npol): - W[:,:,ipol]=weights[:,:,0] - W[f]=0 - Wc=W.copy() - # weights=weights*np.ones((1,1,npol)) - # W=weights - - kk=np.zeros_like(d) - T.timeit("third") - for iDir in range(self.NDir): - ra=self.PosArray.ra[iDir] - dec=self.PosArray.dec[iDir] - ra0,dec0=self.DicoMSInfos[iMS]["ra0dec0"] - l, m = self.radec2lm(ra, dec,ra0,dec0) - n = np.sqrt(1. - l**2. - m**2.) 
- - - T.timeit("lmn") - kkk = np.exp(-2.*np.pi*1j* chfreq/const.c.value *(u0*l + v0*m + w0*(n-1)) ) # Phasing term - T.timeit("kkk") + chfreq_1d = self.DicoMSInfos[iMS]["ChanFreq"].ravel() - for ipol in range(npol): - kk[:,:,ipol]=kkk[:,:,0] - T.timeit("kkk copy") - - # #ind=np.where((A0s==0)&(A1s==10))[0] - # ind=np.where((A0s!=1000))[0] - # import pylab - # pylab.ion() - # pylab.clf() - # pylab.plot(np.angle(d[ind,2,0])) - # pylab.plot(np.angle(kk[ind,2,0].conj())) - # pylab.draw() - # pylab.show(False) - # pylab.pause(0.1) - - - - #DicoMSInfos = self.DicoMSInfos - - #_,nch,_=DicoDATA["data"].shape - - dcorr[:]=d[:] - W=Wc.copy() - #W2=Wc.copy() - dcorr*=W - wdcorr=np.ones(dcorr.shape,np.float64) - #kk=kk*np.ones((1,1,npol)) + # evaluate domain-wide Jones diagnostics EXACTLY ONCE prior to direction iteration + J_diag, J_ch_domain_idx = None, None + if self.DoJonesCorr_kMS or self.DoJonesCorr_Beam: + DicoJones = shared_dict.attach("DicoJones_%i"%iJob) + DicoJones.reload() - T.timeit("corr") + # kernel computes the full diagonal cube for the timestep + G_arr = DicoJones["G"] + tm_arr = DicoJones["tm"] + fd_arr = DicoJones["FreqDomains"] - if self.DoJonesCorr_kMS or self.DoJonesCorr_Beam: - T1=ClassTimeIt.ClassTimeIt(" DoJonesCorr") - T1.disable() - DicoJones=shared_dict.attach("DicoJones_%i"%iJob) - DicoJones.reload() - T1.timeit("Load") - tm = DicoJones['tm'] - # Time slot for the solution - iTJones=np.argmin(np.abs(tm-ThisTime))#self.timesGrid[iTime])) - - #iDJones=np.argmin(AngDist(ra,DicoJones['ra'],dec,DicoJones['dec'])) - lJones, mJones = self.CoordMachine.radec2lm(DicoJones['ra'], DicoJones['dec']) - iDJones=np.argmin(np.sqrt((l-lJones)**2+(m-mJones)**2)) + # pass native primitives & memory arrays - not dicts + J_diag, iTJones, J_ch_domain_idx = compute_jones_diag_for_time( + G_arr, tm_arr, fd_arr, chfreq_1d, ThisTime + ) - - _,nchJones,_,_,_,_=DicoJones['G'].shape - T1.timeit("argmin") - - + # Simplify weights: phase_and_sum_direction casts to pols locally 
+ w_scalar = weights[:, :, 0] - for iFJones in range(nchJones): - - nu0,nu1=DicoJones['FreqDomains'][iFJones] - fData=self.DicoMSInfos[iMS]["ChanFreq"].ravel() - indCh=np.where((fData>=nu0) & (fData=nu0) & (fData= (Radius*np.pi/180))[0] + for i_out in ind_out: + name_str = self.PosArray.Name[i_out] if hasattr(self.PosArray, 'Name') else "Target" + if isinstance(name_str, bytes): + name_str = name_str.decode('utf-8', errors='ignore') + dist_deg = Dist[i_out] * 180 / np.pi + log.print(f"Skipping {name_str}: distance from phase center ({dist_deg:.3f} deg) > radius ({Radius} deg)") + ind=np.where(Dist<(Radius*np.pi/180))[0] self.PosArray=self.PosArray[ind] diff --git a/DynSpecMS/ClassSaveResults.py b/DynSpecMS/ClassSaveResults.py index be791fd..b2ab68e 100644 --- a/DynSpecMS/ClassSaveResults.py +++ b/DynSpecMS/ClassSaveResults.py @@ -3,15 +3,11 @@ from __future__ import absolute_import from builtins import range from builtins import object -from distutils.spawn import find_executable from astropy.time import Time -from astropy import units as uni from astropy.io import fits from astropy.wcs import WCS -from astropy import coordinates as coord -from astropy import constants as const import numpy as np -import glob, os +import os #import pylab from DDFacet.Other import logger log=logger.getLogger("ClassSaveResults") @@ -30,17 +26,14 @@ def GiveMAD(X): return np.median(np.abs(X-np.median(X))) class ClassSaveResults(object): - def __init__(self, DynSpecMS,DIRNAME=None): + def __init__(self, DynSpecMS, DIRNAME=None): self.DynSpecMS=DynSpecMS - self.DIRNAME=DIRNAME - if self.DIRNAME is None or self.DIRNAME=="MSName": + + # Respect the DIRNAME handed down from the ms2dynspec cleanly + if DIRNAME is None or DIRNAME=="MSName": self.DIRNAME="DynSpecs_%s"%self.DynSpecMS.OutName else: - self.DIRNAME=os.path.join(self.DIRNAME,"_DynSpecs_%s"%(self.DynSpecMS.OutName)) - - - #image = self.DynSpecMS.Image - #self.ImageData=np.squeeze(fits.getdata(image, ext=0)) + self.DIRNAME=DIRNAME 
self.ImageI=self.DynSpecMS.ImageI if self.ImageI and os.path.isfile(self.DynSpecMS.ImageI): @@ -178,13 +171,18 @@ def WriteFitsThisDir(self,iDir,Weight=False): prihdr.set('CRVAL2', self.DynSpecMS.fMin*1e-6, 'Frequency at the reference pixel (MHz)') prihdr.set('CDELT2', self.DynSpecMS.ChanWidth*1e-6, 'Delta freq (MHz)') prihdr.set('CUNIT2', 'MHz', 'unit') - prihdr.set('CTYPE3', 'Stokes parameter', '1=I, 2=Q, 3=U, 4=V') + stokes_labels = [] + if hasattr(self.DynSpecMS, 'stokes_list'): + stokes_labels = [f"{i+1}={s}" for i, s in enumerate(self.DynSpecMS.stokes_list)] + else: + stokes_labels = ['1=I', '2=Q', '3=U', '4=V'] + prihdr.set('CTYPE3', 'Stokes parameter', ', '.join(stokes_labels)) prihdr.set('CRPIX3', 1., 'Reference') prihdr.set('CRVAL3', 1., 'frequence at the reference pixel') prihdr.set('CDELT3', 1., 'Delta stokes') prihdr.set('CUNIT3', '', 'unit') prihdr.set('DATE-CRE', Time.now().iso.split()[0], 'Date of file generation') - prihdr.set('OBSID', self.DynSpecMS.OutName, 'LOFAR Observation ID') + prihdr.set('OBSID', self.DynSpecMS.OutName, 'Observation ID') prihdr.set('CHAN-WID', self.DynSpecMS.ChanWidth, 'Frequency channel width') prihdr.set('FRQ-MIN', self.DynSpecMS.fMin, 'Minimal frequency') prihdr.set('FRQ-MAX', self.DynSpecMS.fMax, 'Maximal frequency') @@ -193,6 +191,8 @@ def WriteFitsThisDir(self,iDir,Weight=False): prihdr.set('RA_RAD', ra, 'Pixel right ascension') prihdr.set('DEC_RAD', dec, 'Pixel declination') prihdr.set('TEL_NAME', self.DynSpecMS.TELESCOPE_NAME, 'Telescope Name') + prihdr.set('PROJECT', self.DynSpecMS.PROJECT, 'Project ID') + prihdr.set('OBSERVER', self.DynSpecMS.OBSERVER, 'Observer') name=self.DynSpecMS.PosArray.Name[iDir] if not isinstance(name,str): @@ -222,7 +222,7 @@ def WriteFitsThisDir(self,iDir,Weight=False): # Gn=np.sqrt(Gn0)/Gn1 # Gn[Gn0==0]=0 else: - Gn = self.DynSpecMS.GOut[iDir,:, :, :].real + Gn = self.DynSpecMS.GOut[iDir, :, :, :].real hdu = fits.PrimaryHDU(np.rollaxis(Gn, 2), header=prihdr) #print(f"Fits being written: 
{fitsname}") @@ -289,11 +289,6 @@ def PlotSpecSingleDir(self, iDir=0, BoxArcSec=300.): pylab.clf() - # if find_executable("latex") is not None: - # pylab.rc('text', usetex=True) - # font = {'family':'serif', 'serif': ['Times']} - # pylab.rc('font', **font) - # Figure properties bigfont = 8 diff --git a/DynSpecMS/MakeDBImagesDynSpec.py b/DynSpecMS/MakeDBImagesDynSpec.py index 495de8b..27f65e5 100755 --- a/DynSpecMS/MakeDBImagesDynSpec.py +++ b/DynSpecMS/MakeDBImagesDynSpec.py @@ -34,8 +34,6 @@ log=MyLogger.getLogger("ClassInterpol") IdSharedMem=str(int(os.getpid()))+"." from pyrap.tables import table -from killMS.Other.ClassTimeIt import ClassTimeIt -from killMS.Other.least_squares import least_squares import copy import astropy.io.fits as pyfits import glob diff --git a/DynSpecMS/dynspecms_version.py b/DynSpecMS/dynspecms_version.py index 286bea7..486de34 100644 --- a/DynSpecMS/dynspecms_version.py +++ b/DynSpecMS/dynspecms_version.py @@ -7,11 +7,17 @@ def version(): path = os.path.dirname(os.path.abspath(__file__)) os.chdir(path) try: - result=subprocess.check_output('git describe --tags', shell=True,universal_newlines=True).rstrip() - except: - result='unknown' - os.chdir(prevdir) + result = subprocess.check_output( + 'git describe --tags', + shell=True, + universal_newlines=True, + stderr=subprocess.DEVNULL + ).rstrip() + except Exception: + result = 'unknown' + finally: + os.chdir(prevdir) return result if __name__=='__main__': - print(version()) + print(version()) \ No newline at end of file diff --git a/DynSpecMS/kernels/__init__.py b/DynSpecMS/kernels/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/DynSpecMS/kernels/phase_and_sum_vis.py b/DynSpecMS/kernels/phase_and_sum_vis.py new file mode 100644 index 0000000..ebab40a --- /dev/null +++ b/DynSpecMS/kernels/phase_and_sum_vis.py @@ -0,0 +1,152 @@ +""" +JAX-enabled kernel to phase visibilities to directions and sum them. 
+ +Behavior: +- Phases visibilities to a target (ra,dec) given a phase centre (ra0,dec0). +- Optionally applies Jones (gain) corrections (precomputed or via a callable). +- Sums visibilities across rows -> returns per-channel, per-polarisation sums + and weight sums (and weight-squared sums) compatible with existing code. + +Usage sketch: + ds, ws, w2s = phase_and_sum_direction(vis, flag, weights, + u, v, w, A0s, A1s, + chan_freqs, ra, dec, ra0, dec0, + slicePol=slice(None), + Jones=None) +""" +from __future__ import division +import jax.numpy as jnp +from jax import jit +JAX_AVAILABLE = True +from typing import Optional +from functools import partial + +import numpy as np + +# speed of light (m/s) +C = 299792458.0 + + +def radec2lm(ra : jnp.ndarray, dec: jnp.ndarray, ra0: jnp.ndarray, dec0: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]: + """ + Convert RA/DEC to direction cosines l,m for a given phase centre. + + Parameters + ---------- + ra, dec, ra0, dec0 : scalar or array-like (radians) + Input coordinates. + + Returns + ------- + l, m : same array type as inputs (jnp or np) + """ + # using the same formula as ClassDynSpecMS.radec2lm + l = jnp.cos(dec) * jnp.sin(ra - ra0) + m = jnp.sin(dec) * jnp.cos(dec0) - jnp.cos(dec) * jnp.sin(dec0) * jnp.cos(ra - ra0) + return l, m + + +@jit +def _compute_phase(chfreq : jnp.ndarray, u: jnp.ndarray, v: jnp.ndarray, w: jnp.ndarray, l: jnp.ndarray, m: jnp.ndarray, n: jnp.ndarray): + """Compute phasor exp(-2pi i nu/c * (u*l + v*m + w*(n-1))) + + Parameters + ---------- + chfreq : array_like, shape (nch,), float + Channel centre frequencies in Hz. + u, v, w : array_like, shape (nrow, 1, 1) + Baseline coordinates in metres. + l, m, n : scalar or array-like + Direction cosines. + Returns + ------- + phase : array_like, shape (nrow, nch, 1), complex + Phasor values. 
+ """ + chf = chfreq.reshape((1, -1, 1)) + kterm = -2.0 * jnp.pi * 1j * chf / C + uvw_dot = u * l + v * m + w * (n - 1.0) + return jnp.exp(kterm * uvw_dot) + +@jit +def phase_and_sum_direction(vis : jnp.ndarray, flag : jnp.ndarray, weights : jnp.ndarray, u : jnp.ndarray, v : jnp.ndarray, w : jnp.ndarray, A0s : jnp.ndarray, A1s : jnp.ndarray, + chan_freqs : jnp.ndarray, ra : float, dec : float, ra0 : float, dec0 : float, + slicePol=(0,1,2,3), Jones=Optional[dict])-> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """ + Phase visibilities to a direction, apply optional Jones corrections, and sum. + + Parameters + ---------- + vis : array_like, shape (nrow, nch, npol), complex + Visibilities for selected time rows. + flag : array_like, shape (nrow, nch, npol), bool + Flag array (True => flagged). + weights : array_like, shape (nrow, nch) or (nrow, nch, 1) + Weights (scalar per visibility, broadcasted to pols). + u, v, w : array_like, shape (nrow,) or (nrow,1,1) + Baseline coordinates in metres. + A0s, A1s : array_like, shape (nrow,) + Antenna indices per visibility row (used for Jones lookups if provided). + chan_freqs : array_like, shape (nch,), float + Channel centre frequencies in Hz. + ra, dec : float + Target direction (radians). + ra0, dec0 : float + Phase centre (radians). + slicePol : slice or sequence + Polarisation mapping (like self.slicePol). + Jones : tuple of arrays with J0 and J1 + tuple (J0, J1, ch_mask) : J0,J1 are arrays broadcastable to + (nrow, n_ch_mask, 1) and ch_mask selects channels they apply to. 
+ + Returns + ------- + ds : ndarray (nch, noutpol) complex (numpy) + ws : ndarray (nch, noutpol) float (numpy) + w2s: ndarray (nch, noutpol) float (numpy) + """ + nrow, nch, npol = vis.shape + + # build per-pol weights array + W = jnp.zeros((nrow, nch, npol), dtype=jnp.float32) + # broadcast the (nrow,nch) weights to each pol + for i_p in range(npol): + W = W.at[:, :, i_p].set(weights) + + # zero weights where flagged + W = jnp.where(flag, 0.0, W) + + dcorr = jnp.asarray(vis) + + # compute l,m,n + l, m = radec2lm(ra, dec, ra0, dec0) + n = jnp.sqrt(jnp.clip(1.0 - l * l - m * m)) + + # compute phasing + phase = _compute_phase(chan_freqs, u, v, w, l, m, n) + + if Jones is not None: + J0, J1, ch_mask = Jones + + dcorr = dcorr.at[:, ch_mask, :].set(J0.conj() * dcorr[:, ch_mask, :] * J1) + W = W.at[:, ch_mask, :].set(W[:, ch_mask, :] * (jnp.abs(J0) * jnp.abs(J1)) ** 2) + + # Apply phase and sum across rows + dcorr = dcorr * phase + ds = jnp.sum(dcorr, axis=0) # (nch, npol) + ws = jnp.sum(W, axis=0) + w2s = jnp.sum(W ** 2, axis=0) + + # Apply requested polarisation mapping + if isinstance(slicePol, slice): + ds_out = ds[:, slicePol] + ws_out = ws[:, slicePol] + w2s_out = w2s[:, slicePol] + else: + # assume slicePol is sequence of indices + idx = list(slicePol) + ds_out = ds[:, idx] + ws_out = ws[:, idx] + w2s_out = w2s[:, idx] + + return ds_out, ws_out, w2s_out \ No newline at end of file diff --git a/DynSpecMS/kernels/t_idx_jones.py b/DynSpecMS/kernels/t_idx_jones.py new file mode 100644 index 0000000..9a90195 --- /dev/null +++ b/DynSpecMS/kernels/t_idx_jones.py @@ -0,0 +1,177 @@ +""" +- compute_jones_diag_for_time(DicoJones, chan_freqs, time_value, default_val=1+0j) + -> (J_diag, iTJones, ch_domain_idx) + Compute per-antenna, per-channel diagonal Jones (J_00) for the nearest + Jones solution time to `time_value`. J_diag has shape + (nAnt, nChan, nDirJones) and is returned as a JAX array (jnp.ndarray). 
+ +- extract_row_jones_jax(J_diag, A0s, A1s, dir_idx) + -> (J0, J1, ch_indices) + Given J_diag, produce per-row, per-channel J0/J1 arrays of shape + (nRow, nChan, 1) usable directly inside a JAX phasing kernel. dir_idx + can be a scalar (same Jones direction for all rows) or an array (one per row). + +Assumptions about DicoJones (match what's produced by ClassDynSpecMS.setJones): +- DicoJones['G'] shape == (nTimeJones, nFreqJones, nAnt, nDirJones, 2, 2) + (so indexing like G[iTJones, iFJones, Aindex, iDJones, 0, 0] is valid) +- DicoJones['tm'] shape == (nTimeJones,) +- DicoJones['FreqDomains'] is an iterable/list/array of (nu0, nu1) in Hz, + length nFreqJones + +Usage sketch: + J_diag, iT = compute_jones_diag_for_time(DicoJones, chan_freqs, ThisTime) + J0, J1, ch_mask = extract_row_jones_jax(J_diag, A0s, A1s, dir_idx) + # Pass (J0, J1, ch_mask) into your phase_and_sum kernel (which can be jitted). +""" +from __future__ import division + +import jax +import jax.numpy as jnp +from jax import jit, vmap +JAX = True + +@jit +def _build_channel_domain_mapping(chan_freqs : jnp.ndarray, freq_domains : jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]: + """ + Vectorized mapping of each channel to a Jones freq-domain index. 
+ + - chan_freqs: (nChan,) + - freq_domains: array-like (nFreqJones, 2) of (nu0, nu1) + + Returns: + - ch_domain_idx: (nChan,) int32, index of domain for each channel or -1 + - ch_any_mask: (nChan,) bool indicates channels matched to any domain + """ + chan_freqs = jnp.asarray(chan_freqs) + fd = jnp.asarray(freq_domains) + if fd.size == 0: + nChan = chan_freqs.shape[0] + return -jnp.ones((nChan,), dtype=jnp.int32), jnp.zeros((nChan,), dtype=bool) + + nu0 = fd[:, 0].reshape((-1, 1)) + nu1 = fd[:, 1].reshape((-1, 1)) + + mask_all = (chan_freqs.reshape((1, -1)) >= nu0) & (chan_freqs.reshape((1, -1)) < nu1) + + ch_any = jnp.any(mask_all, axis=0) + + ch_idx = jnp.argmax(mask_all, axis=0).astype(jnp.int32) + + ch_idx = jnp.where(ch_any, ch_idx, -1) + return ch_idx, ch_any + + +@jit +def compute_jones_diag_for_time(G, tm, freq_domains, chan_freqs, time_value, default_val=1+0j): + """ + Compute per-antenna, per-channel diagonal Jones J_00 for nearest Jones time. + + Parameters + ---------- + DicoJones : dict-like (supports indexing with keys) + 'G' : (nTimeJones, nFreqJones, nAnt, nDirJones, 2, 2) + 'tm': (nTimeJones,) + 'FreqDomains' : iterable/lists of (nu0, nu1) length nFreqJones + chan_freqs : array-like (nChan,) frequencies in Hz + time_value : scalar (same units as DicoJones['tm']) + default_val : complex scalar used for channels not covered by any domain + + Returns + ------- + J_diag : jnp.ndarray shape (nAnt, nChan, nDirJones) complex + The diagonal element J[0,0] for each antenna, channel, and Jones direction. + iTJones : int + Index of the selected Jones time slice (nearest to time_value). + ch_domain_idx : jnp.ndarray shape (nChan,) int32 + Index of the freq-domain each channel was assigned to, or -1. 
+ """ + + G = jnp.asarray(G) + tm = jnp.asarray(tm) + freq_domains = jnp.asarray(freq_domains) + chan_freqs = jnp.asarray(chan_freqs) + + # 1) choose nearest Jones time index + # Using argmin on abs difference + time_diffs = jnp.abs(tm - time_value) + iTJones = jnp.argmin(time_diffs).astype(jnp.int32) + + # 2) slice the time => G_time shape (nFreqJones, nAnt, nDirJones, 2, 2) + G_time = G[iTJones] # (nFreqJones, nAnt, nDirJones, 2, 2) + + # 3) extract diagonal (0,0) per freq-domain -> (nFreqJones, nAnt, nDirJones) + J00_domains = G_time[:, :, :, 0, 0] + + # 4) determine channel -> freq-domain mapping + # ch_domain_idx shape (nChan,), -1 if not in any domain + ch_domain_idx, ch_any_mask = _build_channel_domain_mapping(chan_freqs, freq_domains) + + # 5) safe gather: for channels with -1 put index 0 temporarily then overwrite with default + safe_idx = jnp.where(ch_domain_idx >= 0, ch_domain_idx, 0) # (nChan,) + + # gather J00 for each channel using safe_idx -> shape (nChan, nAnt, nDirJones) + J_chan_ant_dir = jnp.take(J00_domains, safe_idx, axis=0) + J_chan_ant_dir = jnp.transpose(J_chan_ant_dir, (1, 0, 2)) # (nAnt, nChan, nDirJones) + + # Branchless JAX: Unconditionally apply the mask using jnp.where + nAnt, nChan, nDirJones = J_chan_ant_dir.shape + default_block = jnp.full((nAnt, nChan, nDirJones), default_val, dtype=J_chan_ant_dir.dtype) + + # Broadcast the invalid mask so it shapes exactly to (1, nChan, 1) to match (nAnt, nChan, nDirJones) + mask_broadcast = (~ch_any_mask).reshape((1, nChan, 1)) + + # Fill in the default values wherever the mask failed + J_diag = jnp.where(mask_broadcast, default_block, J_chan_ant_dir) + + return J_diag, iTJones, ch_domain_idx + + +def extract_row_jones_jax(J_diag : jnp.ndarray, A0s : jnp.ndarray, A1s : jnp.ndarray, dir_idx : int) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """ + Build per-row per-channel J0 and J1 arrays suitable for immediate use. 
+ + Parameters + ---------- + J_diag : jnp.ndarray (nAnt, nChan, nDirJones) + A0s, A1s : array-like (nRow,) antenna indices per row + dir_idx : scalar int or array-like (nRow,) Jones-direction index per row + + Returns + ------- + J0 : jnp.ndarray (nRow, nChan, 1) complex + J1 : jnp.ndarray (nRow, nChan, 1) complex + ch_indices : jnp.ndarray (nChan,) int (0..nChan-1) + """ + J_diag = jnp.asarray(J_diag) + A0s = jnp.asarray(A0s, dtype=jnp.int32) + A1s = jnp.asarray(A1s, dtype=jnp.int32) + + nRow = A0s.shape[0] + nChan = J_diag.shape[1] + + # scalar dir_idx for all rows -> simple indexing + if jnp.ndim(dir_idx) == 0: + J_dir = J_diag[:, :, dir_idx] # (nAnt, nChan) + # Use advanced indexing on antenna axis to select per-row antennas + # J_dir[A0s, :] -> (nRow, nChan) + J0 = J_dir[A0s, :] + J1 = J_dir[A1s, :] + + else: + # dir_idx is per-row array; we need J_diag[A_ant, :, d_idx] for each row + dir_idx = jnp.asarray(dir_idx, dtype=jnp.int32) # (nRow,) + + # vmapped over rows to gather row-specific (nChan,) arrays + def _row_fetch(a_idx, d_idx): + # returns shape (nChan,) + return J_diag[a_idx, :, d_idx] + + vmap_row_fetch = vmap(_row_fetch, in_axes=(0, 0), out_axes=0) + J0 = vmap_row_fetch(A0s, dir_idx) # (nRow, nChan) + J1 = vmap_row_fetch(A1s, dir_idx) # (nRow, nChan) + + # Add trailing singleton axis to match kernel broadcasting expectations + J0 = J0[..., jnp.newaxis] # (nRow, nChan, 1) + J1 = J1[..., jnp.newaxis] # (nRow, nChan, 1) + ch_indices = jnp.arange(nChan, dtype=jnp.int32) + return J0, J1, ch_indices \ No newline at end of file diff --git a/DynSpecMS/logo.py b/DynSpecMS/logo.py index 88f1da3..bb05712 100644 --- a/DynSpecMS/logo.py +++ b/DynSpecMS/logo.py @@ -1,12 +1,10 @@ from __future__ import print_function def PrintLogo(version): - print(""" ______ _____ ___ ___ _____ """) - print(""" | _ \ / ___| | \/ |/ ___| """) - print(""" | | | |_ _ _ __ \ `--. _ __ ___ ___| . . |\ `--. """) - print(""" | | | | | | | '_ \ `--. \ '_ \ / _ \/ __| |\/| | `--. 
\ """) - print(""" | |/ /| |_| | | | /\__/ / |_) | __/ (__| | | |/\__/ / """) - print(""" |___/ \__, |_| |_\____/| .__/ \___|\___\_| |_/\____/ """) - print(""" __/ | | | """) - print(""" |___/ |_| """) + print(""" _____ _____ __ __ _____ """) + print(""" | __ \|_ _| \/ |/ ____|""") + print(""" | |__) | | | | \ / | (___ """) + print(""" | _ / | | | |\/| |\___ \ """) + print(""" | | \ \ _| |_| | | |____) |""") + print(""" |_| \_\_____|_| |_|_____/ """) print(" version",version) diff --git a/DynSpecMS/schema/__init__.py b/DynSpecMS/schema/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/DynSpecMS/schema/kronicle_rims_schema.py b/DynSpecMS/schema/kronicle_rims_schema.py new file mode 100644 index 0000000..b2c099e --- /dev/null +++ b/DynSpecMS/schema/kronicle_rims_schema.py @@ -0,0 +1,269 @@ +import re +from datetime import datetime, timezone +from typing import Any, Literal, Optional + +from kronicle_sdk.models.data.kronicable_sample import KronicableSample +from pydantic import ( + BaseModel, + ConfigDict, + EmailStr, + Field, + field_validator, + model_validator, +) + + +class DataDimensions(BaseModel): + time_start_utc: datetime # will automatically convert a unix timestamp or string to a datetime object + time_end_utc: datetime + time_resolution_s: float = Field( + ..., ge=0, description="Time resolution (t delta) in seconds" + ) + frequency_min_mhz: float = Field(..., ge=0, description="Minimum frequency in MHz") + frequency_max_mhz: float = Field(..., ge=0, description="Maximum frequency in MHz") + frequency_resolution_khz: float = Field( + ..., ge=0, description="Frequency resolution (channel width) in kHz" + ) + + # Enforce exact stokes coverage + stokes: list[Literal["I", "Q", "U", "V"]] + + @field_validator("time_start_utc", "time_end_utc") + @classmethod + def check_not_in_future(cls, v: datetime) -> datetime: + # Check timezone to make sure it does not "look" like it is in the future + now = datetime.now(timezone.utc) + if v > now: + 
raise ValueError("Observation time cannot be in the future.") + return v + + @model_validator(mode="after") + def validate_time_order(self) -> "DataDimensions": + if self.time_end_utc <= self.time_start_utc: + raise ValueError("time_end_utc must be after time_start_utc") + return self + + +class AccessPolicy(BaseModel): + visibility: str = Field( + ..., description="e.g., 'public' or 'LOFAR KSP'" + ) # can access data whilst still under embargo + + embargo_months: int = Field( + default=0, + ge=0, + le=24, + description="Embargo period in full months (max 24 months)", + ) + + +class IdentifiedPerson(BaseModel): + email: Optional[EmailStr] = None + orcid: Optional[str] = None + name: Optional[str] = None + + @field_validator("orcid") + @classmethod + def validate_orcid(cls, v: Optional[str]) -> Optional[str]: + if v is None or v.strip(' "\'`') is None: + return None + v = v.strip(' "\'`') + + # Match optional https://, optional http://, then orcid.org/, + # then the 4-digit groups + match = re.fullmatch(r"(?:https?://)?(?:orcid\.org/)?((\d{4}-){3}\d{4})", v) + if not match: + raise ValueError("not a valid ORCID") + # Return only the 4 quadruplets + return match[1] + + @field_validator("email") + @classmethod + def normalize_email(cls, v: str) -> str: + v = v.strip(' "\'`').lower() + return v + + @field_validator("name") + @classmethod + def normalize_name(cls, v: str) -> str: + v = v.strip(' "\'`') + return v + + def model_dump(self, **params): + d = super().model_dump(**params) + return {k: v for k, v in d.items() if v is not None} + +class RimsSource(KronicableSample): + """ + Details about the dataset used as a source of the computation + """ + + model_config = ConfigDict(populate_by_name=True, alias_generator=None) + + instrument_name: str = Field( + ..., + description="Name of the instrument used for the observation i.e. 
MeerKAT, LOFAR, etc.", + ) # required + dataset_id: str = Field( + ..., + description="Unique identifier for the dataset - may just be the measurement set name if unknown", + ) # required + observer: Optional[IdentifiedPerson] = Field( + default=IdentifiedPerson(email="community@kronicle.org"), + description="Identifier/email/name of the person or system adding the data to Kronicle", + ) + # data_format: str = Field(default="FITS") + + +class RimsProduct(KronicableSample): + """ + One of the files produced by the computation + """ + + model_config = ConfigDict(populate_by_name=True, alias_generator=None) + + name: str + uri: str + source_type: Optional[str] = Field( + None, alias="type", description="e.g., star, pulsar, or bright source" + ) + file_extension: Optional[str] = Field( + default="application/fits", alias="mime", description="MIME type of the file" + ) + # Coordinates & Motion + ra_deg: float = Field( + ..., ge=0.0, lt=360.0, description="Right ascension in degrees [0, 360)" + ) + dec_deg: float = Field( + ..., ge=-90.0, le=90.0, description="Declination in degrees [-90, 90]" + ) + # pmra: Optional[float] = Field( + # None, description="Proper motion in RA (mas/yr). Can be positive or negative." + # ) + # pmdec: Optional[float] = Field( + # None, description="Proper motion in Dec (mas/yr). Can be positive or negative." 
+ # ) + + +class AppService(KronicableSample): + """ + Details about the app used for the computation + """ + + model_config = ConfigDict(populate_by_name=True, alias_generator=None) + + hash: str = Field( + ..., + alias="RIMS client version", + description="Version of the RIMS client used to generate this payload, ideally a commit hash for reproducibility", + ) + maintainer: Optional[IdentifiedPerson] + + computing_infrastructure: Optional[str] = Field( + None, + description="Name of the computing infrastructure used for data processing, e.g., 'SURF', 'AWS', 'Google Cloud', etc.", + ) + + +class RimsBatch(KronicableSample): + """ + Details of the computations + """ + + model_config = ConfigDict(populate_by_name=True, alias_generator=None) + + tags: list[str] = Field(default_factory=list) + publisher: IdentifiedPerson + + # ----- Data filters + # Target Information and origin + catalog_key: Optional[str] = Field( + default=None, description="Target catalog ID, if target is from a known catalog" + ) + catalog_name: Optional[str] = Field( + default=None, description="Catalog Name, if target is from a known catalog" + ) + # Data Characteristics + data_dimensions: DataDimensions = Field( + ..., + alias="data dimensions", + description="This will be the time, frequency and polarization coverage", + ) # and antenna/baseline? + + # ----- Data filters + # Publication and Versioning + publication_details: Optional[str] = Field( + None, + alias="publication details", + description="Free-form string for BibTeX entry or ORCID ID. 
Recommended if data result is used in a publication", + ) + + # Access Policy + batch_access_policy: AccessPolicy = Field(..., alias="batch access policy") + + +class RimsObservationPayload(KronicableSample): + """ + Gathers all the different information about the computation + """ + + source: RimsSource + app: AppService + batch: RimsBatch + product: RimsProduct + + def get_fields(self) -> list[KronicableSample]: + return [self.source, self.app, self.batch, self.product] + + @classmethod + def get_field_classes(cls) -> list[type[KronicableSample]]: + return [RimsSource, AppService, RimsBatch, RimsProduct] + + @classmethod + def get_all_fields(cls): + aggregated: dict[str, str] = {} + for component_cls in cls.get_field_classes(): + if component_cls is not None: + aggregated.update(component_cls.get_all_fields()) + return aggregated + + @classmethod + def _get_channel_schema(cls) -> dict[str, str]: + """ + Aggregate the channel_schema from all subcomponents into a single dict. + Later keys overwrite earlier keys in case of conflicts. + """ + aggregated: dict[str, str] = {} + for component_cls in cls.get_field_classes(): + if component_cls is not None: + aggregated.update(component_cls._get_channel_schema()) + return aggregated + + @classmethod + def get_field_descriptions(cls) -> dict[str, str]: + aggregated: dict[str, str] = {} + for component in cls.get_field_classes(): + if component is not None: + aggregated.update(component.get_field_descriptions()) + return aggregated + + def to_row(self) -> dict[str, Any]: + aggregated: dict[str, Any] = {} + for component in self.get_fields(): + if component is not None: + aggregated.update(component.to_row()) + return aggregated + + +def get_field_descriptions(obj: KronicableSample) -> dict[str, str]: + """ + Return a dict mapping field names to their description, if a description was provided. + Works safely for both ModelField and FieldInfo. 
+ """ + descriptions = {} + for name, field in obj.__class__.model_fields.items(): + # If it's a ModelField, grab its field_info; else assume it's already FieldInfo + info = getattr(field, "field_info", field) + if info.description is not None: + descriptions[name] = info.description + return descriptions diff --git a/DynSpecMS/scripts/cli.py b/DynSpecMS/scripts/cli.py new file mode 100644 index 0000000..74043d3 --- /dev/null +++ b/DynSpecMS/scripts/cli.py @@ -0,0 +1,40 @@ +# DynSpecMS/scripts/cli.py +import argparse +import sys + +from DynSpecMS.scripts import ms2dynspec +from DynSpecMS.scripts import dynspec_upload + +def main(): + parser = argparse.ArgumentParser( + prog="rims", + description="The Dynamic Spectra Extraction and Publishing Tool for Measurement Sets", + ) + + subparsers = parser.add_subparsers( + title="subcommands", + dest="command", + help="Choose a command to run" + ) + + subparsers.add_parser("run", help="Run the ms2dynspec extraction process", add_help=False) + subparsers.add_parser("publish", help="Publish dynamic spectra to RIMS Online", add_help=False) + # Future commands can easily be added here: + + if len(sys.argv) == 1: + parser.print_help() + sys.exit(1) + + args, remaining_argv = parser.parse_known_args() + + sys.argv = [f"rims {args.command}"] + remaining_argv + + if args.command == "run": + ms2dynspec.main() + elif args.command == "publish": + dynspec_upload.main() + else: + parser.print_help() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/DynSpecMS/scripts/dynspec_upload.py b/DynSpecMS/scripts/dynspec_upload.py new file mode 100644 index 0000000..f3cd174 --- /dev/null +++ b/DynSpecMS/scripts/dynspec_upload.py @@ -0,0 +1,347 @@ +import json +import os +import glob +import requests +import numpy as np +from astropy.io import fits +from astropy.time import Time +from datetime import timezone + +from ..schema.kronicle_rims_schema import ( + RimsObservationPayload, + DataDimensions, + AccessPolicy, + 
RimsProduct, + RimsSource, + AppService, + RimsBatch, + IdentifiedPerson +) + +from kronicle_sdk.connectors.channel.channel_writer import KronicleWriter +from kronicle_sdk.utils.conf_utils import read_ini_conf + +def file_upload(filename: str, host: str, token: str)-> str: + """ Upload file filename to host host using authorization token token. + If successful, returns the URL for the publicly visible uploaded data, else + raises RuntimeError. + + Args: + filename: Path to the file to upload + host: URL of the upload server + token: Authorization token for the upload server + Returns: + URL of the uploaded file if successful, False otherwise. + """ + try: + with open(filename, 'rb') as infile: + r = requests.post(host, data={'auth':token}, files={'file': infile}) + status = r.json().get('status') + if status != 'success': + raise RuntimeError(f'Upload failed with status {status}') + else: + return r.json().get('url') + except Exception as e: + print(f"Error uploading {filename} to {host}: {e}") + return False + +def parse_dynspec_for_metadata( + filename: str, + run_metadata: dict, + visibility: str, + embargo_months: int, + publishing_info: str = None, + publisher_name: str = None, + publisher_email: str = None, + publisher_orcid: str = None, + maintainer_name: str = None, + maintainer_email: str = None, + maintainer_orcid: str = None, + file_url: str = "" +) -> RimsObservationPayload: + """ + Given a fits filename, parsed run_metadata, and an uploaded file URL, + parse the header information and return an RimsObservationPayload object + composed of Source, App, Batch, and Product schemas. + + Args: + filename: Path to the FITS file to parse. + run_metadata: Dictionary containing metadata about the run, read from run_metadata.json. + visibility: "public" or "private" string for access policy. + embargo_months: Integer number of months for embargo in access policy. 
+ publishing_info: Optional string with additional publishing information to include in batch metadata. + publisher_name: Optional name of the publisher to include in batch metadata. + publisher_email: Optional email of the publisher to include in batch metadata. + publisher_orcid: Optional ORCID of the publisher to include in batch metadata. + maintainer_name: Optional name of the maintainer to include in app metadata. + maintainer_email: Optional email of the maintainer to include in app metadata. + maintainer_orcid: Optional ORCID of the maintainer to include in app metadata. + file_url: URL of the uploaded FITS file to include in product metadata. + Returns: + RimsObservationPayload object containing the parsed metadata. + """ + header = fits.getheader(filename) + + # Time parsing + obs_start = Time(header.get("OBS-STAR"), format="isot", scale="utc").datetime.replace(tzinfo=timezone.utc) + obs_stop = Time(header.get("OBS-STOP"), format="isot", scale="utc").datetime.replace(tzinfo=timezone.utc) + + # Frequency parsing (from FRQ-MIN and FRQ-MAX in Hz) + freq_min_mhz = header.get("FRQ-MIN", 0.0) / 1e6 + freq_max_mhz = header.get("FRQ-MAX", 0.0) / 1e6 + + # Frequency resolution (from CHAN-WID in Hz) + freq_resolution_khz = header.get("CHAN-WID", 0.0) / 1e3 + + # Time resolution (from CDELT1 in seconds) + time_resolution_s = abs(header.get("CDELT1", 0.0)) + + # Determine stokes parameters present based on NAXIS3 + n_stokes = header.get("NAXIS3", 4) + stokes_map = ["I", "Q", "U", "V"] + stokes = stokes_map[:n_stokes] + + data_dimensions = DataDimensions( + time_start_utc=obs_start, + time_end_utc=obs_stop, + time_resolution_s=time_resolution_s, + frequency_min_mhz=freq_min_mhz, + frequency_max_mhz=freq_max_mhz, + frequency_resolution_khz=freq_resolution_khz, + stokes=stokes + ) + + access_policy = AccessPolicy( + visibility=visibility, + embargo_months=embargo_months + ) + + ra_rad = header.get("RA_RAD", 0.0) + dec_rad = header.get("DEC_RAD", 0.0) + + ra_deg = 
float(np.rad2deg(ra_rad)) % 360.0 # modulus ensures [0, 360) range + dec_deg = float(np.rad2deg(dec_rad)) + dec_deg = max(-90.0, min(90.0, dec_deg)) + + sw_meta = run_metadata.get("software_metadata", {}) + client_version = sw_meta.get("version", "1.0.0") + if client_version == "unknown" and sw_meta.get("git_hash") != "Unknown": + client_version = sw_meta.get("git_hash") + + maintainer_person = IdentifiedPerson( + email=maintainer_email or "maintainer@kronicle.org", + name=maintainer_name, + orcid=maintainer_orcid + ) + + app_service = AppService( + **{"RIMS client version": client_version}, + maintainer=maintainer_person, + computing_infrastructure=sw_meta.get("os_platform", None) + ) + + publisher_person = IdentifiedPerson( + email=publisher_email or "community_user@kronicle.org", + name=publisher_name, + orcid=publisher_orcid + ) + + source = RimsSource( + dataset_id=header.get("OBSID", os.path.basename(filename).replace(".fits", "")).strip(), + instrument_name=header.get("TEL_NAME", "Unknown").strip(), + observer=IdentifiedPerson(name=header.get("OBSERVER", "Unknown").strip()) + ) + + batch = RimsBatch( + name=os.path.basename(run_metadata.get("arguments", {}).get("OutDirName", "unknown_batch")), + tags=[''], + publisher=publisher_person, + data_dimensions=data_dimensions, + batch_access_policy=access_policy, + **({"publication details": publishing_info} if publishing_info else {}) + ) + + product = RimsProduct( + name=header.get("NAME", "Unknown Target").strip(), + uri=file_url, + type=header.get("SRC-TYPE", "Unknown").strip(), + ra_deg=ra_deg, + dec_deg=dec_deg, + access_policy=access_policy, + file_extension=filename.split(".")[-1].lower() + ) + + payload = RimsObservationPayload( + source=source, + batch=batch, + app=app_service, + product=product + ) + + return payload + +def publish_to_kronicle(payload: RimsObservationPayload, kronicle_user: str, kronicle_pass: str, kronicle_host: str) -> bool: + """ + Function to publish the parsed RimsObservationPayload 
to Kronicle. + + Args: + payload: RimsObservationPayload object containing the metadata to publish. + kronicle_user: Username for Kronicle authentication. + kronicle_pass: Password for Kronicle authentication. + kronicle_host: URL of the Kronicle instance to publish to. + Returns: + True if publish is successful, False otherwise. + """ + kronicle_writer = KronicleWriter(kronicle_host, kronicle_user, kronicle_pass) + kronicle_payload = { + "channel_id" : 'bf88c5a1-6c6a-4766-b7c6-7c05b44702ec', + "channel_name" : "RIMS network", + "channel_schema" : payload.channel_schema, + "metadata": {"description": payload.get_field_descriptions()}, + "rows" : [payload.to_row()] + } + + # #write payload to file: + # with open("kronicle_payload.json", "w") as f: + # json.dump(kronicle_payload, f, indent=2, default=str) + try: + kronicle_writer.insert_rows_and_upsert_channel(kronicle_payload) + return True + except Exception as e: + return False + +def process_dynspec_directory( + root_dir: str, + upload_host: str, + upload_token: str, + kronicle_user: str, + kronicle_pass: str, + kronicle_host: str, + visibility: str, + embargo_months: int, + publishing_info: str = None, + publisher_name: str = None, + publisher_email: str = None, + publisher_orcid: str = None, + maintainer_name: str = None, + maintainer_email: str = None, + maintainer_orcid: str = None +): + """ + Iterates through TARGET, TARGET_W, OFF, OFF_W directories under root_dir, processing FITS files. 
+ """ + metadata_file = os.path.join(root_dir, "run_metadata.json") + run_metadata = {} + if os.path.isfile(metadata_file): + try: + with open(metadata_file, "r") as f: + run_metadata = json.load(f) + except Exception as e: + print(f"Warning: Failed to read {metadata_file}: {e}") + else: + print(f"Warning: No run_metadata.json found in {root_dir}") + + subdirs_to_check = ["TARGET", "TARGET_W", "OFF", "OFF_W"] + + for subdir in subdirs_to_check: + target_path = os.path.join(root_dir, subdir) + if not os.path.isdir(target_path): + print(f"Skipping {subdir}: Directory not found.") + continue + + fits_files = glob.glob(os.path.join(target_path, "*.fits")) + + for fits_file in fits_files: + print(f"Processing: {fits_file}") + + # 1. Upload to host (file server) + file_url = file_upload(fits_file, upload_host, upload_token) + if file_url is False: + print(f"Failed to upload {fits_file}. Skipping.") + continue + print(f"Uploaded successfully. URL: {file_url}") + + # 2. Parse Metadata generation + payload = parse_dynspec_for_metadata( + fits_file, + run_metadata, + visibility, + embargo_months, + publishing_info=publishing_info, + publisher_name=publisher_name, + publisher_email=publisher_email, + publisher_orcid=publisher_orcid, + maintainer_name=maintainer_name, + maintainer_email=maintainer_email, + maintainer_orcid=maintainer_orcid, + file_url=file_url + ) + + # 3. 
Publish to Kronicle + result = publish_to_kronicle(payload, kronicle_user, kronicle_pass, kronicle_host) + if result: + print(f"Successfully processed {fits_file}.\n") + else: + print(f"Failed to publish metadata for {fits_file}.") + +import argparse + +def main(): + parser = argparse.ArgumentParser(description="Upload dynamic spectra FITS files and publish metadata to Kronicle.") + parser.add_argument("root_dir", help="Root directory containing TARGET, TARGET_W, OFF, OFF_W subdirectories.") + parser.add_argument("--server-conf", default="upload_server.ini", help="Path to server configuration INI file.") + parser.add_argument("--publisher-conf", default="upload_details.ini", help="Path to publisher details INI file.") + + args = parser.parse_args() + + server_conf = read_ini_conf(args.server_conf) + upload_host = server_conf.get("upload", "host") + upload_token = server_conf.get("upload", "token") + kronicle_user = server_conf.get("kronicle", "username") + kronicle_pass = server_conf.get("kronicle", "password") + kronicle_host = server_conf.get("kronicle", "host") + + print(f"Upload host: {upload_host}" + f"\nKronicle host: {kronicle_host}" + f"\nKronicle user: {kronicle_user}" + f"\nKronicle pass: {kronicle_pass}" + ) + + publisher_conf = read_ini_conf(args.publisher_conf) + visibility = publisher_conf.get("publishing_details", "visibility", fallback="private") + embargo_months = publisher_conf.getint("publishing_details", "embargo_months", fallback=0) + publshing_info = publisher_conf.get("publishing_details", "publishing_info", fallback=None) + publisher_name = publisher_conf.get("publishing_details", "publisher_name", fallback=None) + publisher_email = publisher_conf.get("publishing_details", "publisher_email", fallback=None) + publisher_orcid = publisher_conf.get("publishing_details", "publisher_orcid", fallback=None) + maintainer_name = publisher_conf.get("publishing_details", "maintainer_name", fallback=None) + maintainer_email = 
publisher_conf.get("publishing_details", "maintainer_email", fallback=None) + maintainer_orcid = publisher_conf.get("publishing_details", "maintainer_orcid", fallback=None) + + pub_details = None + if publshing_info and os.path.isfile(publshing_info): + with open(publshing_info, 'r') as f: + pub_details = f.read() + elif publshing_info: + print(f"Warning: publishing info file {publshing_info} not found.") + + process_dynspec_directory( + root_dir=args.root_dir, + upload_host=upload_host, + upload_token=upload_token, + kronicle_user=kronicle_user, + kronicle_pass=kronicle_pass, + kronicle_host=kronicle_host, + visibility=visibility, + embargo_months=embargo_months, + publishing_info=pub_details, + publisher_name=publisher_name, + publisher_email=publisher_email, + publisher_orcid=publisher_orcid, + maintainer_name=maintainer_name, + maintainer_email=maintainer_email, + maintainer_orcid=maintainer_orcid + ) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/DynSpecMS/scripts/ms2dynspec.py b/DynSpecMS/scripts/ms2dynspec.py index bea188d..c6df20e 100755 --- a/DynSpecMS/scripts/ms2dynspec.py +++ b/DynSpecMS/scripts/ms2dynspec.py @@ -34,37 +34,21 @@ ========================================================================= """ -import sys import os import argparse -from distutils.spawn import find_executable -from matplotlib import rc -#import matplotlib -#matplotlib.rcParams['font.family'] = 'sans-serif' -#fontsize=12 -#rc('font',**{'family':'serif','serif':['Times'],'size':fontsize}) -#if find_executable("latex") is not None: -# rc('text', usetex=True) from DDFacet.Other import Multiprocessing try: - import dask.array as da - from daskms import xds_from_table as table + from daskms import xds_from_table as dasktable HAS_DASK=True except: HAS_DASK=False from pyrap.tables import table -from astropy.time import Time -from astropy import units as uni -from astropy.io import fits -from astropy import coordinates as coord -from astropy import 
constants as const import numpy as np -import glob, os -import pylab +import os from DDFacet.Other import MyPickle from DynSpecMS import logo logo.PrintLogo(__version__) @@ -142,6 +126,86 @@ def angSep(ra1, dec1, ra2, dec2): temp = 1.0 * cmp(temp, 0) return np.degrees(np.arccos(temp)) +import sys +import platform +import subprocess +import DynSpecMS +import json + +def get_run_metadata(): + """Attempt to get hardware, Python, and Git information.""" + try: + package_dir = os.path.dirname(DynSpecMS.__file__) + except Exception: + package_dir = os.path.dirname(os.path.abspath(__file__)) + + git_hash = "Unknown" + git_remote = "Unknown" + + try: + git_hash = subprocess.check_output( + ['git', 'rev-parse', '--short', 'HEAD'], + cwd=package_dir, + stderr=subprocess.DEVNULL + ).decode('utf-8').strip() + + git_remote = subprocess.check_output( + ['git', 'config', '--get', 'remote.origin.url'], + cwd=package_dir, + stderr=subprocess.DEVNULL + ).decode('utf-8').strip() + except (subprocess.CalledProcessError, OSError, FileNotFoundError): + pass + + return { + "version": __version__, + "git_hash": git_hash, + "git_remote": git_remote, + "python_version": sys.version.split()[0], # e.g., '3.9.10' + "os_platform": f"{platform.system()} {platform.machine()}" + } + +def save_run_metadata(args, out_dir): + """Saves the software metadata and argument list to a JSON file, masking absolute paths.""" + meta = get_run_metadata() + + args_dict = dict(vars(args)) + path_keys = [ + 'ms', 'srclist', 'FitsCatalog', 'DicoFacet', 'imageI', + 'imageV', 'BaseDirSpecs', 'SolsDir', 'CacheDir', + 'DDFParset', 'OutDirName' + ] + + for k in path_keys: + val = args_dict.get(k) + if isinstance(val, str) and val.strip(): + safe_paths = [] + for p in val.split(','): + p = p.strip() + if not p: continue + + basename = os.path.basename(p.rstrip('/\\')) + if basename: + safe_paths.append(f"<{k.lower()}_path>/{basename}") + else: + safe_paths.append(f"<{k.lower()}_path>") + + if safe_paths: + args_dict[k] = 
",".join(safe_paths) + + run_info = { + "software_metadata": meta, + "arguments": args_dict + } + + os.makedirs(out_dir, exist_ok=True) + out_file = os.path.join(out_dir, "run_metadata.json") + try: + with open(out_file, "w") as f: + json.dump(run_info, f, indent=4) + log.print(f"Saved run metadata to {out_file}") + except Exception as e: + log.print(f"Failed to save run metadata: {e}") def ms2dynspec(args=None, messages=[]): @@ -159,7 +223,8 @@ def ms2dynspec(args=None, messages=[]): DT={} for MSName in MSList: if HAS_DASK: - t = table(MSName) + t = dasktable(MSName) + print(type(t[0]["TIME"].values)) Times=np.unique((t[0]["TIME"].values)) else: t = table(MSName,ack=False) @@ -183,7 +248,7 @@ def ms2dynspec(args=None, messages=[]): field_decs=[] for MSName in MSList: if HAS_DASK: - tField = table(f"{MSName}::FIELD") + tField = dasktable(f"{MSName}::FIELD") ra0, dec0 = np.ravel(tField[0]["PHASE_DIR"].values) else: tField = table(f"{MSName}::FIELD",ack=False) @@ -192,12 +257,6 @@ def ms2dynspec(args=None, messages=[]): if ra0<0.: ra0+=2.*np.pi field_ras.append(ra0) field_decs.append(dec0) - tField.close() - - # L_radec=list(set(L_radec)) - # if len(L_radec)>1: stop - # ra0,dec0=L_radec[0] - field_ras=np.array(field_ras) field_decs=np.array(field_decs) ra_different=np.any(np.abs(field_ras-np.mean(field_ras))>args.tolerance*np.pi/(180*3600)) @@ -229,14 +288,15 @@ def ms2dynspec(args=None, messages=[]): SubSet=(iChunk,NChunk) for ik,k in enumerate(sorted(list(DT.keys()))): MSList=DT[k] - OutDirName=args.OutDirName - if OutDirName=="MSName": - OutDirName=args.ms + if args.OutDirName=="MSName": + # Old behaviour: exact folder name, no _T splits + DIRNAME=os.path.abspath("DynSpecs_%s"%args.ms) else: - OutDirName=os.path.basename(os.path.abspath(OutDirName)) - DIRNAME=os.path.abspath("DynSpecs_%s"%OutDirName) - if len(DT)>1: - DIRNAME = "%s_T%i"%(DIRNAME,ik) + DIRNAME=os.path.abspath(args.OutDirName) + # Elegantly split custom folders ONLY if multiple MSs exist + if 
len(DT)>1: + DIRNAME = "%s_T%i"%(DIRNAME,ik) + if NChunk>1: DIRNAME = "%s_RandChunk%i"%(DIRNAME,iChunk) D = ClassDynSpecMS(ListMSName=MSList, @@ -269,6 +329,9 @@ def ms2dynspec(args=None, messages=[]): if D.Mode=="Spec": D.StackAll() SaveMachine=ClassSaveResults.ClassSaveResults(D,DIRNAME=DIRNAME) + + save_run_metadata(args, SaveMachine.DIRNAME) + if D.Mode=="Spec": SaveMachine.WriteFits() SaveMachine.SaveCatalog() @@ -323,6 +386,7 @@ def main(): parser.add_argument("--SourceCatOff", type=str, default="", help="Read the code", required=False) parser.add_argument("--SourceCatOff_FluxMean", type=float, default=0, help="Read the code", required=False) parser.add_argument("--SourceCatOff_dFluxMean", type=float, default=0, help="Read the code", required=False) + parser.add_argument("--stokes", type=str, default="IQUV", help="Stokes params to compute, e.g., I, IV, IQUV") parser.add_argument("--NMaxTargets", type=int, default=0, help="Read the code", required=False) args = parser.parse_args() diff --git a/README.md b/README.md index 21e0c75..e3d9d67 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,8 @@ A Radio Astronomy Software tool for creating the _Dynamic Spectra_ for a specifi ### Installation #### Requirements: -- Ubuntu 20.04 -- Python3.9 virtual environment +- Ubuntu 20.04 or 22.04 +- <= Python3.12 virtual environment - DDFacet installed to virtual environment - Casa5 installation @@ -14,14 +14,17 @@ git clone https://github.com/saopicc/RIMS.git pip install ./RIMS ``` +### RIMS run +For a given Measurement set and list of targets you're interested in, generate dynamic spectra for every individual target. 
+ ### Usage View help: ``` -ms2dynspec --help +rims run --help ``` Run job: ``` -ms2dynspec --ms --data --model --srclist --rad --noff --DDFParset --CacheDir --OutDirName +rims run --ms --data --model --srclist --rad --noff --DDFParset --CacheDir --OutDirName ``` Notes: @@ -42,8 +45,9 @@ or in fits table format ### Output: The output folder should countain the following files/folders: ``` -├── 1541903773.reg +├── .reg ├── Catalog.npy +├── run_metadata.json ├── OFF ├── OFF_W ├── TARGET @@ -51,6 +55,47 @@ The output folder should countain the following files/folders: ``` Where the OFF and OFF_W directories contain the off target and off target weights fits files respectively, and the TARGET and TARGET_W directories contain the target and target weights fits files respectively. +### RIMS publish +Should you wish to contribute to the global "RIMS online" project, we encourage you to upload your generated dynamic spectra. To do so, use the `rims publish` subcommand as described below. + +### Usage +View help: +``` +rims publish --help +``` +Run job: +``` +rims publish /path/to/output_dir --server-conf upload_server.ini --publisher-conf upload_details.ini +``` + +Where the contents of the server conf file look like: +``` +[upload] +host= +token= + +[kronicle] +host=https://kronicle.aqmo.org +username= +password= +``` + +And the contents of your details conf file look like: +``` +[publishing_details] +visibility= +embargo_months= +publishing_info= +publisher_name= +publisher_email= +publisher_orcid= +maintainer_name= +maintainer_email= +maintainer_orcid= +``` + +For details on locally creating an upload server to host your dynamic spectra products, visit [rims-upload-server](https://github.com/mhardcastle/rims-upload-server). Otherwise, contact the authors for a server to which you may upload your dynamic spectra. 
+ ### Licensing: MIT License diff --git a/setup.py b/setup.py index 1c33b89..56271f1 100644 --- a/setup.py +++ b/setup.py @@ -6,23 +6,34 @@ packages=find_packages(include=['DynSpecMS', 'DynSpecMS.*']), include_package_data=True, install_requires=[ - 'dask[array]<=2023.5.0', - 'dask-ms==0.2.21', - 'xarray==2024.7.0', - 'psutil<=5.9.3' + 'dask[array]', + 'dask-ms', + 'xarray', + 'psutil', + 'numpy', + 'matplotlib', + 'astropy', + 'future', + 'scipy', + 'jax', + 'jaxlib', + 'pydantic>=2.0.0', + 'pydantic[email]', + 'requests', + 'kronicle_sdk' # other dependencies ], - entry_points={ + entry_points={ 'console_scripts': [ - 'ms2dynspec=DynSpecMS.scripts.ms2dynspec:main', + 'rims=DynSpecMS.scripts.cli:main', ], }, - author='Cyril Tasse', + author='Cyril Tasse and the RIMS team', author_email='cyril.tasse@obspm.fr', description='Extract Dynamic Spectra from Measurement Sets', long_description=open('README.md').read(), long_description_content_type='text/markdown', - url='https://github.com/cyriltasse/DynSpecMS', + url='https://github.com/saopicc/RIMS', classifiers=[ 'Programming Language :: Python :: 3', ],