greenburstAux/signalgen.py
2025-08-12 11:17:37 -04:00

311 lines
10 KiB
Python
Executable file

#! /minish/keh00032/.conda/envs/keh00032/bin/python
import logging
from os import path, chdir, getcwd, makedirs, remove, listdir
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from typing import Union
from datetime import datetime
from time import sleep
import numpy as np
import matplotlib.pyplot as plt
from jess.dispersion import dedisperse
from jess.fitters import median_fitter
from scipy.stats import median_abs_deviation
from your import Your, Writer
from your.formats.filwriter import make_sigproc_object
from will import create, inject
from scipy import signal
# Logging: configure the root logger once, then keep a module-level handle to it.
logFmt = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=logFmt, level=logging.INFO)
logger = logging.getLogger()

# Bounds used when pulses are drawn at random (see randomDMandWidth).
# DM is sampled uniformly in log10 space, i.e. between 10**2.5 and 10**3.5.
LOG_MIN_DM, LOG_MAX_DM = 2.5, 3.5
# Width exponents are log10 of the width in *milliseconds* (1 ms .. 1000 ms).
LOG_MIN_WIDTH, LOG_MAX_WIDTH = 0, 3
# Reference width, as log10 of *seconds*; addBurst scales the number of
# pulse samples by pWidth / 10**LOG_FIDUCIAL_WIDTH.
LOG_FIDUCIAL_WIDTH = -2.6
# Largest number of pulses injected into one file in random mode (inclusive).
MAX_BURSTS = 4

# Shared, unseeded generator: every run draws different pulses.
rng = np.random.default_rng()
#This function is adapted from https://josephwkania.github.io/will/examples/inject_pulse.html#Show-how-we-can-inject-a-pulse-into-a-GREENBURST-filterbank.
def show_dynamic(
    dynamic_spectra: np.ndarray,
    title: Union[str, None] = None,
    save: Union[bool, None] = False
) -> None:
    """
    Show a dynamic spectra by first flattening it
    in frequency. Do this by getting the medians of
    each channel and then run a median filter along the
    bandpass.
    Then set the limits of the imshow so we get good detail
    for the majority of the data.

    Long inputs are decimated along axis 0 (plotted as the time axis)
    before plotting so the image stays at a manageable size.

    Args:
        dynamic_spectra - the dynamic spectra to plot
        title - Title of plot
        save - Save the plot as `title` + `.png` (spaces become underscores).
               The figure is closed after saving so repeated calls do not
               accumulate open figures.

    Raises:
        ValueError - if `save` is requested without a `title`, since the
                     title is the only source of the output file name.
    """
    # Fail before doing any work: previously this crashed with an
    # AttributeError on None only after the whole figure had been built.
    if save and title is None:
        raise ValueError("show_dynamic: a title is required when save=True (it is used as the file name).")

    #downsample the spectra first
    downFac = np.shape(dynamic_spectra)[0] // 8920
    if downFac > 1:
        logging.info(f"Downsampling with factor {downFac}...")
        # NOTE(review): SciPy recommends calling decimate repeatedly for
        # factors above 13 when using the default IIR filter - confirm the
        # single-pass result is acceptable for plotting purposes.
        decSpectra = signal.decimate(dynamic_spectra, downFac, axis=0)
    else:
        decSpectra = dynamic_spectra

    # Remove the bandpass shape: per-channel medians smoothed by median_fitter.
    spectra_mads = median_fitter(np.median(decSpectra, axis=0))
    flat = decSpectra - spectra_mads

    # Robust colour limits so a few bright samples do not wash out the image.
    std = median_abs_deviation(flat, axis=None)
    med = np.median(flat)

    fig = plt.figure(figsize=(20, 10))
    plt.imshow(flat.T, vmin=med - 3 * std, vmax=med + 6 * std, aspect="auto")
    plt.xlabel("Time Sample #", size=20)
    plt.ylabel("Channel #", size=20)
    plt.colorbar()
    plt.tight_layout()
    if title is not None:
        plt.title(title, size=28)
    if save:
        plt.savefig(title.replace(" ", "_") + ".png", dpi=75, bbox_inches="tight")
        # Release the figure; otherwise every call leaves a 20x10 inch figure open.
        plt.close(fig)
def processList(values):
    """
    Entry point if values.file is None and values.listfile is set.

    Reads one path per line from values.listfile, then for each path sets
    values.file and either plots it (values.plot) or calls addBurst.

    Blank lines are ignored. A failure on one entry is logged as a warning
    and processing continues with the next entry, matching processDir.

    Args:
        values - namespace providing `listfile` and `plot`; `file` is
                 overwritten for every entry.
    """
    with open(values.listfile, "r") as listfile:
        # Drop empty / whitespace-only lines (e.g. a trailing blank line),
        # which would otherwise be handed on as an empty file name.
        filenames = [line.strip() for line in listfile if line.strip()]

    for filename in filenames:
        values.file = filename
        try:
            if values.plot:
                filterbankObj = Your(values.file)
                spectra = filterbankObj.get_data(0, 524288)
                show_dynamic(spectra, f"{values.file} Dynamic Spectra", save=True)
            else:
                addBurst(values)
        except Exception as error:
            # Keep going: one unreadable file should not abort the whole list.
            logger.warning(f"Encountered {error} while processing {filename}.")
def processDir(values, fileList):
    """
    Entry point if values.file and values.listfile are None and values.fileDir is set.

    For every name in fileList, points values.file at that name inside
    values.fileDir and runs addBurst on it. An exception raised for one
    file is reported as a warning and the remaining files are still processed.

    Args:
        values - namespace providing `fileDir`; `file` is overwritten per entry.
        fileList - file names (not full paths) located in values.fileDir.
    """
    for entry in fileList:
        target = path.join(values.fileDir, entry)
        values.file = target
        try:
            addBurst(values)
        except Exception as problem:
            logger.warning(f"Encountered {problem} while processing {entry}.")
            continue
def randomDMandWidth():
    """
    Draw a random (DM, pulseWidth) pair.

    Both quantities are sampled uniformly in log10 space from the module-level
    generator, using the LOG_MIN_*/LOG_MAX_* bounds set at the start of the
    script. The DM exponent is drawn first, then the width exponent.

    Returns:
        (dm, pw) tuple; pw is in seconds.
    """
    dmExponent = rng.uniform(low=LOG_MIN_DM, high=LOG_MAX_DM)
    widthExponentMs = rng.uniform(low=LOG_MIN_WIDTH, high=LOG_MAX_WIDTH)
    # The width bounds are log10(ms); subtracting 3 from the exponent gives seconds.
    return (10**dmExponent, 10**(widthExponentMs - 3))
def addBurst(values):
    """
    Inject simulated pulse(s) into a truncated copy of the filterbank values.file.

    Entry point if values.file is set.
    --listfile will enter into this function multiple times for each line.

    The first 524288 samples of the input are written to `<basename>_trunc.fil`
    in the current working directory. Pulses are then injected one after
    another (each pass reads the output of the previous pass) and the final
    result is left in `<basename>_injected.fil`. The truncated copy and every
    intermediate `<basename>_injected_<n>.fil` are deleted, even when an
    error interrupts the injection.

    Args:
        values - namespace providing `file`, `rsamp` and `nsamp`, plus either
                 `dm` / `pWidth` (fixed pulse) or `thisRunName` (text file
                 that receives the randomly drawn DM and width of every pulse
                 when `rsamp` is set).

    Side effects:
        values.file is rewritten without its `.fil` extension.
    """
    filterbankObj = Your(values.file)
    #removing extension for text manipulation later
    # str.strip(".fil") removes any of the characters '.', 'f', 'i', 'l' from
    # BOTH ends (e.g. "burst_full.fil" -> "burst_fu"), so split the suffix instead.
    stem, extension = path.splitext(values.file)
    if extension == ".fil":
        values.file = stem
    basename = path.basename(values.file)
    samples = 524288
    truncPath = path.join("./", f"{basename}_trunc.fil")
    intermediate = None  # most recent <basename>_injected_<n>.fil still on disk

    try:
        #the full filterbanks use 64GB in RAM when injecting burst, so we write out a truncated version and load that instead.
        filWriter = Writer(filterbankObj, outdir="./", outname=f"{basename}_trunc", nstart=0, nsamp=samples)
        filWriter.to_fil()
        #replace filterbankObj object and reload spectra (spectra should be the same but just in case)
        filterbankObj = Your(truncPath)
        spectra = filterbankObj.get_data(0, samples)
        #get bandpass and store in bpWeights
        bpWeights = create.filter_weights(spectra)

        if values.rsamp:
            pulseNum = rng.integers(1, high=MAX_BURSTS, endpoint=True)
        else:
            pulseNum = 1
        logger.info(f"{basename} loaded. Injecting {pulseNum} pulses.")

        # Spread the pulse start samples evenly over the middle half of the data.
        starts = np.linspace(samples//4, 3*(samples//4), num=pulseNum, dtype=int)
        for run, start in enumerate(starts):
            #create pulse
            #check if this is part of a rng run
            if values.rsamp:
                #generate the values AND save the values to file for comparison later
                dm, pWidth = randomDMandWidth()
                with open(values.thisRunName, "a") as runLog:
                    runLog.write(f"{str(round(dm)).ljust(6)} pc/cc {str(round(pWidth,3))} s {basename}\n")
            else:
                dm = values.dm
                pWidth = values.pWidth

            #first version is very simple, plan on adding more complex injections in future
            pulseObj = create.SimpleGaussPulse(
                sigma_time=pWidth,
                sigma_freq=350,
                center_freq=filterbankObj.your_header.center_freq,
                dm=dm,
                tau=20,
                phi=np.pi / 3, #does nothing if nscint = 0
                spectral_index_alpha=0,
                chan_freqs=filterbankObj.chan_freqs,
                tsamp=filterbankObj.your_header.tsamp,
                nscint=0,
                bandpass=bpWeights
            )

            #We need to scale the number of samples with pulse width to (hopefully) maintain constant peak flux
            scaleFac = pWidth/(10**(LOG_FIDUCIAL_WIDTH))
            scaledSamps = int(values.nsamp * scaleFac)
            logger.info(f"Sampling pulse # {run+1} with width {round(pWidth,3)} s {scaledSamps} times.")
            # values.nsamp is 300000 by default on the command line.
            pulse = pulseObj.sample_pulse(nsamp=scaledSamps)
            logger.info("Injecting pulse and saving file.")

            #inject pulse
            isLast = run == pulseNum - 1
            if isLast:
                savename = f"{basename}_injected.fil"
            else:
                savename = f"{basename}_injected_{run}.fil"
            outPath = path.join("./", savename)
            inject.inject_constant_into_file(
                yr_input=filterbankObj,
                pulse=pulse,
                start=start,
                out_fil=outPath
            )

            previous = intermediate
            if isLast:
                intermediate = None
            else:
                #reload into object so the next pulse is added on top of this one
                filterbankObj = Your(outPath)
                intermediate = outPath
            # The previous pass's output has been consumed; drop it right away.
            if previous is not None:
                remove(previous)
    finally:
        # Runs on success and on failure so multi-GB scratch files never pile up.
        if intermediate is not None and path.exists(intermediate):
            remove(intermediate)
        if path.exists(truncPath):
            remove(truncPath) #delete truncated file to save 2GB of disk space
            logger.info("Truncated file removed.")
if __name__ == "__main__":
    # Command-line entry point: parse options, move into the output directory,
    # then dispatch on the first input option given (file > listfile > directory).
    parser = ArgumentParser(
        description="Insert simulated bursts into GREENBURST filterbank files.",
        formatter_class=ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "-l", "--listfile", dest="listfile", type=str, default=None,
        help="File containing list of greenburst filterbank directories."
    )
    parser.add_argument(
        "-f", "--file", dest="file", type=str, default=None, help="Single filterbank file."
    )
    parser.add_argument(
        "-r", "--random", dest="rsamp", action="store_true", default=False,
        help="Use random DM and pulse widths with bounds set by file. Ignores -d."
    )
    parser.add_argument(
        "-d", "--dm", dest="dm", type=float, default=250.0, help="DM of injected pulse."
    )
    parser.add_argument(
        "-w", "--width", dest="pWidth", type=float, default=0.001, help="Width of desired pulse in seconds."
    )
    parser.add_argument(
        "-n", "--nsamp", type=int, default=int(3e5), help="Number of samples to take of the generated pulse."
    )
    parser.add_argument(
        "-p", "--plot", action="store_true", default=False, help="Just plot file and quit."
    )
    parser.add_argument(
        "-D", "--directory", dest="fileDir", type=str, default=None, help="Directory containing filterbank files."
    )
    parser.add_argument(
        "-o", "--output", dest="output", type=str, default=None, help="Set output directory."
    )
    parser.add_argument(
        "-s", "--start", type=int, default=-1, help="Zoomed plot start time bin."
    )
    values = parser.parse_args()

    # Without any input there is nothing to do; say so instead of exiting silently.
    if values.file is None and values.listfile is None and values.fileDir is None:
        parser.error("one of -f/--file, -l/--listfile or -D/--directory is required")

    # Resolve input paths BEFORE changing directory: after chdir a relative
    # path would be looked up inside the output directory and not be found.
    # (Relative paths written inside a list file are not adjusted here.)
    for option in ("file", "listfile", "fileDir"):
        given = getattr(values, option)
        if given is not None:
            setattr(values, option, path.abspath(given))

    #set working directory to ignored directory or set output
    if values.output is None:
        outdir = path.join(getcwd(),"out","")
    else:
        outdir = values.output
    makedirs(outdir, exist_ok=True)
    chdir(outdir)

    # Name of the text file that records randomly drawn pulse parameters for this run.
    values.thisRunName = datetime.now().isoformat(timespec='seconds').replace(":", "-") + ".txt"

    if values.file is not None: #single file takes priority
        logging.info(f"Running with file {values.file}")
        if values.plot:
            filterbankObj = Your(values.file)
            if values.start == -1:
                # No zoom requested: plot the same span that addBurst would process.
                spectra = filterbankObj.get_data(0, 524288)
            else:
                spectra = filterbankObj.get_data(values.start, 8900)
            show_dynamic(spectra, f"{values.file} Dynamic Spectra", save=True)
        else:
            addBurst(values)
    elif values.listfile is not None: #list of files
        logging.info(f"Looking for filenames in file {values.listfile}")
        processList(values)
    elif values.fileDir is not None: #directory
        filePaths = [f for f in listdir(values.fileDir) if f.endswith(".fil")]
        logging.info(f"Found {len(filePaths)} files.")
        processDir(values, filePaths)