good plots

This commit is contained in:
Sakimori 2025-08-04 17:22:20 -04:00
parent 488ea2d6e4
commit 54c128b45f
No known key found for this signature in database

View file

@ -1,10 +1,16 @@
import os
from collections import Counter
from time import sleep
import logging
import matplotlib.pyplot as plt
from matplotlib import rcParams
import numpy as np
import pandas as pd
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO) #change for debug prints
def linesplit(line):
"""
Splits line of file into useful components.
@ -58,28 +64,43 @@ detections = pd.DataFrame(data=detectedDic) #this is our main object for detecti
#define summary printing for multiple steps
def summary(stage):
print(f"Summary Stage {stage}")
print(injections.head())
print(f"Number of files injected: {len(Counter(injections['file']))}")
print(f"Number of files with detections: {len(Counter(detections['file']))}")
print("=========================")
print(f"Number of pulses injected: {len(injections['dm'])}")
print(f"Number of pulses detected: {len(detections['dm'])}")
print("=========================")
print(f"File ratio: {round(len(Counter(detections['file']))/len(Counter(injections['file'])), 3)}")
print(f"Detection ratio: {round(len(detections['dm'])/len(injections['dm']), 3)}")
print("=========================")
logger.info(f"Summary Stage {stage}")
logger.info(f"Number of files injected: {len(Counter(injections['file']))}")
logger.info(f"Number of files with detections: {len(Counter(detections['file']))}")
logger.info("=========================")
logger.info(f"Number of pulses injected: {len(injections['dm'])}")
logger.info(f"Number of pulses detected: {len(detections['dm'])}")
logger.info("=========================")
logger.info(f"File ratio: {round(len(Counter(detections['file']))/len(Counter(injections['file'])), 3)}")
logger.info("=========================")
#print initial summary
summary(1)
#let's track how many detections get removed by filtering
preFilterCount = len(detections['dm'])
#many files contained pulsar bursts, so we filter those out via DM
minDM = (10**2.5) * 0.95 #as per signal generation, plus a bit of wiggle room
print(f"Filtering out pulsars (DM below {int(minDM)}...)")
logger.info(f"Filtering out pulsars (DM below {int(minDM)}...)")
detections = detections[detections['dm'] > minDM]
detections = detections.reset_index(drop=True)
summary(2)
#five files have SO MANY FALSE POSITIVES so get rid of them here?
remFiles = ["data_2025-04-30_07-53-07", "data_2025-05-01_07-47-34", "data_2025-04-24_07-36-04", "data_2025-04-29_07-50-16", "data_2025-04-30_08-18-17"]
logger.info("Filtering out the five problem files...")
remMask = [True] * len(detections['dm'])
for detection in detections.itertuples():
if detection.file in remFiles:
remMask[detection.Index] = False
detections = detections[remMask]
detections = detections.reset_index(drop=True)
summary(3)
postFilterCount = len(detections['dm'])
logger.info(f"Removed {preFilterCount-postFilterCount} detections by filtering.")
#Let's do detection matching! Yaaaay!
#What detections line up to which injections? This will determine which ones got missed entirely.
#Define some kind of epsilon for DM and pulse width; if detection is within epsilon in DM we can match it.
@ -97,29 +118,119 @@ for detection in detections.itertuples():
)
matches = injections.query(qstring)
if len(matches) > 0:
print(f"Detection: DM {detection.dm} and PW {detection.pulseWidth}")
print(matches)
logger.debug(f"Detection: DM {detection.dm} and PW {detection.pulseWidth}")
logger.debug(matches)
if len(matches) == 1:
i = matches.index[0]
matchCount[i] += 1
print("======")
logger.debug("======")
elif len(matches) > 1:
raise ValueError("MULTIPLE MATCHES OHNO")
else: #no matching injection...
falsePositiveMask[detection.Index] = True
print(f"NO MATCH FOR: DM {detection.dm} and PW {detection.pulseWidth}")
print("Injections in file:")
print(injections.query(f"(file == '{detection.file}')"))
matchMaskInj = [matchCount > 0]
logger.debug(f"NO MATCH FOR: DM {detection.dm} and PW {detection.pulseWidth}")
logger.debug("Injections in file:")
logger.debug(injections.query(f"(file == '{detection.file}')"))
matchMaskInj = matchCount > 0
matchMaskDet = np.logical_not(falsePositiveMask)
missedMask = [matchCount == 0]
missedMask = matchCount == 0
#So where are we?
#We have multiple datasets.
#1. List of all injected pulses. [injections]
#2. List of detections with pulsars filtered out. [detections]
#3. Number of times each injection was detected [matchCount]
#1. Dataframe of all injected pulses. [injections]
#2. Dataframe of detections with pulsars filtered out. [detections]
#3. List of number of times each injection was detected [matchCount]
#4. A mask for only detected injections [matchMaskInj]
#5. A mask for only true positives [matchMaskDet]
#6. A mask for only missed injections [missedMask]
#7. A mask for false positives [falsePositiveMask]
#7. A mask for false positives [falsePositiveMask]
logger.info(f"Successful detection ratio: {Counter(matchMaskInj)[True]/(len(injections['dm']))}")
#Let's try to figure out if certain files are responsible for the weird amount of false
#positives at around 10^2.7 pc/cc
fpFileCounts = Counter(detections[falsePositiveMask]['file'])
sortedFPCounts = [(k, v) for k,v in sorted(fpFileCounts.items(), key=lambda value: -value[1])]
logger.info("False positive counts:")
for f, c in sortedFPCounts:
logger.info(f"{f}: {c}")
#Let's set a matplotlib default to make figures a bit bigger cos i like them
rcParams["figure.figsize"] = [7,5]
plt.rcParams['figure.constrained_layout.use'] = True
# #Injected pulses
# plt.figure(figsize=(7,10))
# allAx = plt.subplot(311)
# _, bins, _ = allAx.hist(np.log10(injections['dm']), bins=15)
# allAx.set_title("All injected pulses")
# detAx = plt.subplot(312, sharex=allAx, sharey=allAx)
# detAx.hist(np.log10(injections['dm'][matchMaskInj]), bins=bins)
# detAx.set_title("Detected pulses")
# misAx = plt.subplot(313, sharex=allAx, sharey=allAx)
# misAx.hist(np.log10(injections['dm'][missedMask]), bins=bins)
# misAx.set_title("Missed pulses")
# plt.ylabel("Count")
# plt.xlabel(r"DM (log pc cm$^3$)")
# plt.draw()
# #Detected pulses unstacked
# plt.figure(figsize=(7,10))
# allAx = plt.subplot(311)
# _, bins, _ = allAx.hist(np.log10(detections['dm']), bins=15)
# allAx.set_title("All detections")
# detAx = plt.subplot(312, sharex=allAx, sharey=allAx)
# detAx.hist(np.log10(detections['dm'][matchMaskDet]), bins=bins)
# detAx.set_title("Detected pulses")
# misAx = plt.subplot(313, sharex=allAx, sharey=allAx)
# misAx.hist(np.log10(detections['dm'][falsePositiveMask]), bins=bins)
# misAx.set_title("False positives")
# plt.ylabel("Count")
# plt.xlabel(r"DM (log pc cm$^3$)")
# plt.draw()
#Stacked histogram of injections
plt.figure(figsize=(7,10))
ax = plt.subplot(17,1,(1,8))
ax.hist([np.log10(injections['dm'][matchMaskInj]), np.log10(injections['dm'][missedMask])],
stacked=True, label=['Detected', 'Missed'], bins=15)
ax.yaxis.get_major_locator().set_params(integer=True)
ax.legend(loc='upper center')
ax.set_xlabel(r"DM (log pc cm$^3$)")
ax.tick_params(labelbottom=True)
ax2 = plt.subplot(17,1,(9,16))
ax2.hist([np.log10(injections['pulseWidth'][matchMaskInj]), np.log10(injections['pulseWidth'][missedMask])],
stacked=True, label=['Detected', 'Missed'], bins=15)
ax2.yaxis.get_major_locator().set_params(integer=True)
ax2.set_xlabel(r"Pulse width (log s)")
plt.ylabel("Count")
plt.title("Injected pulses")
wordAx = plt.subplot(17,1,17)
wordAx.text(.3,.5,f"Overall detection rate: {round(Counter(matchMaskInj)[True]/(len(injections['dm']))*100,1)}%", size=14)
wordAx.set_axis_off()
plt.draw()
#Stacked histogram of detections
plt.figure(figsize=(7,10))
ax = plt.subplot(17,1,(1,8))
ax.hist([np.log10(detections['dm'][matchMaskDet]), np.log10(detections['dm'][falsePositiveMask])],
stacked=True, label=['True detection', 'False positive'], bins=15)
ax.yaxis.get_major_locator().set_params(integer=True)
ax.set_xlabel(r"DM (log pc cm$^3$)")
ax.legend(loc='upper center')
ax2 = plt.subplot(17,1,(9,16))
ax2.hist([np.log10(detections['pulseWidth'][matchMaskDet]), np.log10(detections['pulseWidth'][falsePositiveMask])],
stacked=True, label=['True detection', 'False positive'], bins=15)
ax2.yaxis.get_major_locator().set_params(integer=True)
ax2.set_xlabel("Pulse width (log s)")
plt.ylabel("Count")
plt.title("Detections")
wordAx = plt.subplot(17,1,17)
wordAx.text(.3,.5,f"False positive rate: {round(Counter(falsePositiveMask)[True]/(len(detections['dm']))*100,1)}%", size=14)
wordAx.set_axis_off()
plt.draw()
#block end of script
plt.show()