added random sampling to test pipeline in various situations

This commit is contained in:
Sakimori 2025-07-29 14:33:09 -04:00
parent 2182108909
commit c4d88037f3
No known key found for this signature in database

View file

@ -4,6 +4,7 @@ import logging
from os import path, chdir, getcwd, makedirs from os import path, chdir, getcwd, makedirs
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from typing import Union from typing import Union
from datetime import datetime
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -22,6 +23,13 @@ logger = logging.getLogger()
logFmt = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" logFmt = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=logFmt) logging.basicConfig(level=logging.INFO, format=logFmt)
#Set values for random sampling
MIN_DM = 300
MAX_DM = 10000
MIN_WIDTH = 0.001 #1 ms
MAX_WIDTH = 1 #1 s
rng = np.random.default_rng()
#This function from https://josephwkania.github.io/will/examples/inject_pulse.html#Show-how-we-can-inject-a-pulse-into-a-GREENBURST-filterbank. #This function from https://josephwkania.github.io/will/examples/inject_pulse.html#Show-how-we-can-inject-a-pulse-into-a-GREENBURST-filterbank.
def show_dynamic( def show_dynamic(
dynamic_spectra: np.ndarray, dynamic_spectra: np.ndarray,
@ -79,6 +87,13 @@ def processList(values):
dataName = path.dirname(dir[:-1]) #strip the trailing slash for dirname dataName = path.dirname(dir[:-1]) #strip the trailing slash for dirname
values.file = path.join(dir,dataname+".fil") values.file = path.join(dir,dataname+".fil")
addBurst(values) addBurst(values)
def randomDMandWidth():
"""
Helper function that returns (DM, pulseWidth) tuple with bounds set at start of script.
"""
return (rng.uniform(low=MIN_DM, high=MAX_DM), rng.uniform(low=MIN_WIDTH, high=MAX_WIDTH))
def addBurst(values): def addBurst(values):
""" """
@ -97,22 +112,29 @@ def addBurst(values):
filterbankObj = Your(path.join("./", f"{basename}_trunc.fil")) filterbankObj = Your(path.join("./", f"{basename}_trunc.fil"))
spectra = filterbankObj.get_data(0, samples) spectra = filterbankObj.get_data(0, samples)
#save pre-injection spectra plot
if not values.skipplot:
show_dynamic(spectra, f"{basename} Pre-injection Dynamic Spectra", save=True)
#get bandpass and store in bpWeights #get bandpass and store in bpWeights
bpWeights = create.filter_weights(spectra) bpWeights = create.filter_weights(spectra)
logger.info(f"{basename} loaded. Sampling pulse {values.nsamp} times.") logger.info(f"{basename} loaded. Sampling pulse {values.nsamp} times.")
#create pulse #create pulse
#check if this is part of a rng run
if values.rsamp:
#generate the values AND save the values to file for comparison later
dm, pWidth = randomDMandWidth()
with open(values.thisRunName, "a") as file:
file.write(f"{str(round(dm)).ljust(6)} pc/cc {str(round(pWidth,3))} s {basename}\n")
else:
dm = values.dm
pWidth = 0.001
#first version is very simple, plan on adding more complex injections in future #first version is very simple, plan on adding more complex injections in future
pulseObj = create.SimpleGaussPulse( pulseObj = create.SimpleGaussPulse(
sigma_time=0.001, sigma_time=pWidth,
sigma_freq=350, sigma_freq=350,
center_freq = filterbankObj.your_header.center_freq, center_freq = filterbankObj.your_header.center_freq,
dm = values.dm, dm = dm,
tau = 20, tau = 20,
phi=np.pi / 3, #does nothing if nscint = 0 phi=np.pi / 3, #does nothing if nscint = 0
spectral_index_alpha=0, spectral_index_alpha=0,
@ -131,10 +153,6 @@ def addBurst(values):
start = samples // 3 start = samples // 3
) )
#and save the new plot
if not values.skipplot:
logger.info("Saving plot...")
show_dynamic(injectedSpectra, f"{basename} Dynamic Spectra and {values.dm} DM Pulse", save=True)
#now generate new filterbank file #now generate new filterbank file
newName = f"{basename}_injected.fil" newName = f"{basename}_injected.fil"
sigprocObj = make_sigproc_object( sigprocObj = make_sigproc_object(
@ -177,6 +195,9 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"-f", "--file", dest="file", type=str, help="Single filterbank file." "-f", "--file", dest="file", type=str, help="Single filterbank file."
) )
parser.add_argument(
"-r", "--random", dest="rsamp", action="store_true", help="Use random DM and pulse widths with bounds set by file. Ignores -d."
)
parser.add_argument( parser.add_argument(
"-d", "--dm", dest="dm", type=float, help="DM of injected pulse." "-d", "--dm", dest="dm", type=float, help="DM of injected pulse."
) )
@ -186,15 +207,12 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"-p", "--plot", action="store_true", help="Just plot file and quit." "-p", "--plot", action="store_true", help="Just plot file and quit."
) )
parser.add_argument(
"-s", "--skipplot", action="store_true", help="Skip plotting for large filterbanks."
)
parser.set_defaults(dm=250.0) parser.set_defaults(dm=250.0)
parser.set_defaults(nsamp=int(3e5)) parser.set_defaults(nsamp=int(3e5))
parser.set_defaults(listfile=None) parser.set_defaults(listfile=None)
parser.set_defaults(file=None) parser.set_defaults(file=None)
parser.set_defaults(plot=False) parser.set_defaults(plot=False)
parser.set_defaults(skipplot=False) parser.set_defaults(rsamp=False)
values = parser.parse_args() values = parser.parse_args()
#set working directory to ignored directory #set working directory to ignored directory
@ -203,6 +221,8 @@ if __name__ == "__main__":
makedirs(outdir) makedirs(outdir)
chdir(outdir) chdir(outdir)
values.thisRunName = datetime.now().isoformat(timespec='seconds').replace(":", "-") + ".txt"
if values.file is not None: #single file takes priority if values.file is not None: #single file takes priority
logging.info(f"Running with file {values.file}") logging.info(f"Running with file {values.file}")
if values.plot: if values.plot: