greenburstAux/testanalysis.py
2025-08-05 11:21:15 -04:00

236 lines
8.9 KiB
Python

import os
from collections import Counter
from time import sleep
import logging
import matplotlib.pyplot as plt
from matplotlib import rcParams
import numpy as np
import pandas as pd
# Send INFO-and-above records to the root handler; lower the level to
# logging.DEBUG to also see this script's logger.debug(...) output.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def linesplit(line):
    """
    Break one record line into its three fields.

    The line is split on "pc/cc": the part before it is parsed as an integer
    DM. The remainder is split on " s ": the part before that is parsed as a
    float pulse width, and the rest (whitespace-stripped, with a trailing
    "_injected" removed) is the originating file name.

    Returns a tuple (dm, pulse width, originating filename).
    Raises ValueError if a separator is missing or repeated, or if a numeric
    field cannot be parsed.
    """
    dmText, rest = line.split("pc/cc")
    dm = int(dmText)
    widthText, nameText = rest.split(" s ")
    return dm, float(widthText), nameText.strip().removesuffix("_injected")
def updateDic(lines, dic):
    """
    Parse record lines and append their fields to ``dic`` in place.

    Each non-blank line is parsed with linesplit(), and its originating file
    name, DM and pulse width are appended to dic["file"], dic["dm"] and
    dic["pulseWidth"] respectively. Blank or whitespace-only lines (e.g. an
    extra empty line at the end of a file) are skipped, because linesplit()
    cannot split them and would raise ValueError.

    Parameters:
        lines: iterable of str record lines.
        dic: dict with list values under "file", "dm" and "pulseWidth".
    Returns None; ``dic`` is modified in place.
    """
    for line in lines:
        if not line.strip():
            continue
        dm, pw, fname = linesplit(line)
        dic["file"].append(fname)
        dic["dm"].append(dm)
        dic["pulseWidth"].append(pw)
#Both tables are built column-wise: originating file, DM and pulse width.
injectedDic = {column: [] for column in ("file", "dm", "pulseWidth")}
detectedDic = {column: [] for column in ("file", "dm", "pulseWidth")}
#Fill each table from its text file: the injection log, then the detection output.
for path, target in (
    (os.path.join(".", "out", "2025-07-31T17-25-54.txt"), injectedDic),
    (os.path.join(".", "out", "plotOut.txt"), detectedDic),
):
    with open(path, "r") as file:
        lines = file.readlines()
    updateDic(lines, target)
#Main DataFrames: one row per parsed line of the injection log / detection output.
injections = pd.DataFrame(data=injectedDic)
detections = pd.DataFrame(data=detectedDic)
#Summary logging, called after each filtering step below.
def summary(stage):
    """
    Log file and pulse counts for the current injection/detection tables.

    Reads the module-level ``injections`` and ``detections`` DataFrames, so
    the output reflects whatever filtering has been applied to ``detections``
    at the time of the call.

    Parameters:
        stage: label (e.g. a step number) shown in the header line.
    Returns None.
    """
    injectedFiles = injections['file'].nunique()
    detectedFiles = detections['file'].nunique()
    logger.info(f"Summary Stage {stage}")
    logger.info(f"Number of files injected: {injectedFiles}")
    logger.info(f"Number of files with detections: {detectedFiles}")
    logger.info("=========================")
    logger.info(f"Number of pulses injected: {len(injections)}")
    logger.info(f"Number of pulses detected: {len(detections)}")
    logger.info("=========================")
    # An empty injection table would otherwise raise ZeroDivisionError here.
    fileRatio = round(detectedFiles / injectedFiles, 3) if injectedFiles else float("nan")
    logger.info(f"File ratio: {fileRatio}")
    logger.info("=========================")
#Log the unfiltered state, then apply two filters to the detections.
summary(1)
#Count detections before any filtering so the total removed can be reported.
preFilterCount = len(detections)
#Many files contain pulsar bursts, so drop low-DM detections. Per the signal
#generation the smallest injected DM is 10**2.5 pc/cc; the 0.95 factor leaves
#a little room below that.
minDM = (10**2.5) * 0.95
logger.info(f"Filtering out pulsars (DM below {int(minDM)}...)")
detections = detections[detections['dm'] > minDM].reset_index(drop=True)
summary(2)
#Five files produce an overwhelming number of false positives; drop every
#detection that came from them.
remFiles = ["data_2025-04-30_07-53-07", "data_2025-05-01_07-47-34", "data_2025-04-24_07-36-04", "data_2025-04-29_07-50-16", "data_2025-04-30_08-18-17"]
logger.info("Filtering out the five problem files...")
detections = detections[~detections['file'].isin(remFiles)].reset_index(drop=True)
summary(3)
postFilterCount = len(detections)
#NOTE: preFilterCount was taken before the pulsar cut, so this total counts the
#low-DM removals as well as the problem-file removals listed below.
logger.info(f"Removed {preFilterCount-postFilterCount} detections by filtering the following:")
for file in remFiles:
    logger.info(file+"_injected.fil")
#Detection matching: pair each detection with the injection it came from.
#Injections that no detection matched were missed entirely; detections that
#match no injection are false positives.
#A detection matches an injection from the same file whose DM is strictly
#within dmEps pc/cc of the detected DM. Pulse width is not used for matching.
dmEps = 5
#Number of detections matched to each injection (positionally aligned with injections).
matchCount = np.zeros(len(injections), dtype=int)
#True for every detection (positionally aligned with detections) that matched no injection.
falsePositiveMask = [False] * len(detections)
for detection in detections.itertuples():
    sameFile = injections['file'] == detection.file
    #Same window as (dm - dmEps) < detection.dm < (dm + dmEps). Built as boolean
    #masks instead of a query string so file names containing quotes can't
    #break the expression and nothing is re-parsed per row.
    inWindow = ((injections['dm'] - dmEps) < detection.dm) & ((injections['dm'] + dmEps) > detection.dm)
    matches = injections[sameFile & inWindow]
    if len(matches) == 0:  #no matching injection -> false positive
        falsePositiveMask[detection.Index] = True
        logger.debug(f"NO MATCH FOR: DM {detection.dm} and PW {detection.pulseWidth}")
        logger.debug("Injections in file:")
        logger.debug(injections[sameFile])
        continue
    logger.debug(f"Detection: DM {detection.dm} and PW {detection.pulseWidth}")
    logger.debug(matches)
    if len(matches) > 1:
        #Ambiguous: two injections in this file lie within dmEps of the detection.
        raise ValueError("MULTIPLE MATCHES OHNO")
    #injections has a default RangeIndex, so the index label is also the position.
    matchCount[matches.index[0]] += 1
    logger.debug("======")
matchMaskInj = matchCount > 0
matchMaskDet = np.logical_not(falsePositiveMask)
missedMask = matchCount == 0
#Datasets available from here on:
#1. Dataframe of all injected pulses. [injections]
#2. Dataframe of detections with pulsars and problem files filtered out. [detections]
#3. Number of times each injection was detected [matchCount]
#4. A mask for only detected injections [matchMaskInj]
#5. A mask for only true positives [matchMaskDet]
#6. A mask for only missed injections [missedMask]
#7. A mask for false positives [falsePositiveMask]
logger.info(f"Successful detection ratio: {np.count_nonzero(matchMaskInj)/len(injections)}")
#Per-file false-positive counts, most to least. This is how the five problem
#files removed earlier were found (false positives clustered near 10^2.7 pc/cc).
fpFileCounts = Counter(detections[falsePositiveMask]['file'])
sortedFPCounts = fpFileCounts.most_common()
logger.info("False positive counts:")
for f, c in sortedFPCounts:
    logger.info(f"{f}: {c}")
#Slightly larger default figures (personal preference), and constrained layout
#on every figure so panels and labels don't overlap.
plt.rc("figure", figsize=[7, 5])
plt.rc("figure.constrained_layout", use=True)
# Earlier, unstacked versions of the histograms below, kept for reference.
# #Injected pulses
# plt.figure(figsize=(7,10))
# allAx = plt.subplot(311)
# _, bins, _ = allAx.hist(np.log10(injections['dm']), bins=15)
# allAx.set_title("All injected pulses")
# detAx = plt.subplot(312, sharex=allAx, sharey=allAx)
# detAx.hist(np.log10(injections['dm'][matchMaskInj]), bins=bins)
# detAx.set_title("Detected pulses")
# misAx = plt.subplot(313, sharex=allAx, sharey=allAx)
# misAx.hist(np.log10(injections['dm'][missedMask]), bins=bins)
# misAx.set_title("Missed pulses")
# plt.ylabel("Count")
# plt.xlabel(r"DM (log pc cm$^{-3}$)")
# plt.draw()
# #Detected pulses unstacked
# plt.figure(figsize=(7,10))
# allAx = plt.subplot(311)
# _, bins, _ = allAx.hist(np.log10(detections['dm']), bins=15)
# allAx.set_title("All detections")
# detAx = plt.subplot(312, sharex=allAx, sharey=allAx)
# detAx.hist(np.log10(detections['dm'][matchMaskDet]), bins=bins)
# detAx.set_title("Detected pulses")
# misAx = plt.subplot(313, sharex=allAx, sharey=allAx)
# misAx.hist(np.log10(detections['dm'][falsePositiveMask]), bins=bins)
# misAx.set_title("False positives")
# plt.ylabel("Count")
# plt.xlabel(r"DM (log pc cm$^{-3}$)")
# plt.draw()
#Stacked histograms of the results: one figure for injections, one for detections.
def _stackedHistFigure(dms, widths, masks, labels, title, footer, outPath):
    """
    Draw and save one two-panel stacked-histogram figure.

    The upper panel stacks log10(dms) and the lower panel log10(widths). Each
    panel has one stack layer per boolean mask in ``masks``, labelled by the
    matching entry of ``labels``. A text-only strip under the panels shows
    ``footer``. The figure is saved to ``outPath`` and then closed.

    Parameters:
        dms, widths: Series of DM and pulse-width values sharing one index.
        masks: sequence of boolean masks selecting rows of dms/widths.
        labels: legend label for each mask.
        title: figure-level title.
        footer: text drawn in the bottom strip.
        outPath: file path the figure is written to.
    Returns None.
    """
    fig = plt.figure(figsize=(7, 10))
    #17-row grid: rows 1-8 hold the DM panel, rows 9-16 the pulse-width panel,
    #and row 17 the footer text.
    dmAx = plt.subplot(17, 1, (1, 8))
    dmAx.hist([np.log10(dms[m]) for m in masks], stacked=True, label=labels, bins=15)
    dmAx.yaxis.get_major_locator().set_params(integer=True)  #counts are whole numbers
    dmAx.set_xlabel(r"DM (log pc cm$^{-3}$)")
    dmAx.set_ylabel("Count")
    dmAx.legend(loc='upper center')
    pwAx = plt.subplot(17, 1, (9, 16))
    pwAx.hist([np.log10(widths[m]) for m in masks], stacked=True, label=labels, bins=15)
    pwAx.yaxis.get_major_locator().set_params(integer=True)
    pwAx.set_xlabel("Pulse width (log s)")
    pwAx.set_ylabel("Count")
    #Figure-level title so it sits above both panels; plt.title would attach
    #it to whichever panel is current (here the pulse-width one).
    fig.suptitle(title)
    wordAx = plt.subplot(17, 1, 17)
    wordAx.text(.3, .5, footer, size=14)
    wordAx.set_axis_off()
    fig.savefig(outPath)
    #Release the figure once saved so open figures don't accumulate.
    plt.close(fig)
#Percentage of injected pulses matched by at least one detection (nan if none injected).
nInjected = len(injections)
detectionRate = round(np.count_nonzero(matchMaskInj)/nInjected*100, 1) if nInjected else float("nan")
_stackedHistFigure(
    injections['dm'], injections['pulseWidth'],
    [matchMaskInj, missedMask], ['Detected', 'Missed'],
    "Injected pulses",
    f"Overall detection rate: {detectionRate}%",
    os.path.join("out", "injections.png"),
)
#Percentage of the filtered detections that matched no injection (nan if no detections).
nDetected = len(detections)
falsePositiveRate = round(np.count_nonzero(falsePositiveMask)/nDetected*100, 1) if nDetected else float("nan")
_stackedHistFigure(
    detections['dm'], detections['pulseWidth'],
    [matchMaskDet, falsePositiveMask], ['True detection', 'False positive'],
    "Detections",
    f"False positive rate: {falsePositiveRate}%",
    os.path.join("out", "detections.png"),
)