Coverage for polars_analysis / plotting / helper.py: 73%
288 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-16 15:00 -0400
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-16 15:00 -0400
1from __future__ import annotations
3import errno
4import logging
5import math
6import os
7from dataclasses import dataclass
8from itertools import groupby
9from time import time
10from typing import Any, Callable, List, Literal, Optional, Tuple, Union, cast, overload
12import matplotlib as mpl
13import numpy as np
14import numpy.typing as npt
15import polars as pl
16from matplotlib import pyplot as plt
17from scipy.optimize import curve_fit # type: ignore
19# Instantiate logger
20log = logging.getLogger(__name__)
"""GLOBAL CONSTANTS"""  # :(
START = 0
N_PHASES = 30
pulses_per_train = 30
start_meas = 0
# colors
# Matplotlib builds that expose the ``mpl.colormaps`` registry use it directly;
# otherwise fall back to ``cm.get_cmap``.
if hasattr(mpl, "colormaps"):
    jet = mpl.colormaps["jet"]  # type: ignore
else:
    from matplotlib import cm

    # The fallback previously bound only ``get``, leaving ``jet`` undefined on
    # this branch. Bind ``jet`` so both branches define the same name, and keep
    # ``get`` as an alias for anything that imported the old name.
    jet = cm.get_cmap("jet", 30)
    get = jet

FIVE_PERCENT_RISETIME = 30  # ns
t_align = 490
@overload
def lin(x: float, a: float, b: float) -> float: ...
@overload
def lin(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray: ...
def lin(
    x: Union[float, np.ndarray], a: Union[float, np.ndarray], b: Union[float, np.ndarray]
) -> Union[float, np.ndarray]:
    """Evaluate the straight line ``a * x + b``.

    Args:
        x: Point(s) at which to evaluate the line.
        a: Slope.
        b: Intercept.

    Returns:
        ``a * x + b``; a float for scalar inputs, an array for array inputs.
    """
    scaled = a * x
    return scaled + b
def timer_func(func) -> Callable[..., Any]:
    """Decorator that prints the wall-clock time taken by each call to ``func``.

    Args:
        func: The callable to time.

    Returns:
        A wrapper that forwards all arguments to ``func``, prints the elapsed
        time in seconds, and returns ``func``'s result unchanged. The wrapper
        carries over ``func``'s ``__name__``, ``__doc__`` and related metadata.
    """
    # Imported here so the block does not rely on a new module-level import.
    from functools import wraps

    @wraps(func)  # keep the decorated function's name/docstring intact
    def wrap_func(*args, **kwargs) -> Any:
        t1 = time()
        result = func(*args, **kwargs)
        t2 = time()
        print(f"Function {func.__name__!r} executed in {(t2 - t1):.4f}s")
        return result

    return wrap_func
def make_xlabel(plt, label: str) -> None:
    """Set the x-axis label, right-aligned at the end of the axis.

    Args:
        plt: Object exposing an ``xlabel`` method (the pyplot module in practice).
        label: Text of the axis label.
    """
    # The former ``r"" + label + ""`` concatenation was a no-op for strings.
    plt.xlabel(label, x=1.0, ha="right", size=12, labelpad=0)
def make_ylabel(plt, label: str) -> None:
    """Set the y-axis label, aligned to the top end of the axis.

    Args:
        plt: Object exposing a ``ylabel`` method (the pyplot module in practice).
        label: Text of the axis label.
    """
    # The former ``r"" + label + ""`` concatenation was a no-op for strings.
    plt.ylabel(label, y=1.0, ha="right", size=12, labelpad=0)
71# These binned stats methods are adapted from https://stackoverflow.com/a/53022701
def hist_mean(centers: np.ndarray, counts: np.ndarray):
    """Return the mean of a binned distribution.

    Args:
        centers: Bin centers.
        counts: Bin contents, used as the weight of each center.

    Returns:
        The average of ``centers`` weighted by ``counts``.
    """
    weighted_mean = np.average(centers, weights=counts)
    return weighted_mean
def hist_var(centers: np.ndarray, counts: np.ndarray):
    """Return the variance of a binned distribution.

    Args:
        centers: Bin centers.
        counts: Bin contents, used as the weight of each center.

    Returns:
        The count-weighted mean squared deviation of ``centers`` from the
        weighted mean.
    """
    mean = hist_mean(centers, counts)
    weighted_sq_dev = counts * (centers - mean) ** 2
    return weighted_sq_dev.sum() / counts.sum()
def hist_moment(centers: np.ndarray, counts: np.ndarray, n: int):
    """Return the n-th standardized moment of a binned distribution.

    Args:
        centers: Bin centers.
        counts: Bin contents, used as the weight of each center.
        n: Order of the moment.

    Returns:
        The count-weighted n-th central moment divided by the variance raised
        to the power ``n / 2``.
    """
    mean = hist_mean(centers, counts)
    central_moment = (counts * (centers - mean) ** n).sum() / counts.sum()
    normalisation = hist_var(centers, counts) ** (n / 2)
    return central_moment / normalisation
def exp_decay(x, A, tau, C):
    """Evaluate the exponential decay ``A * exp(-x / tau) + C``.

    Args:
        x: Point(s) at which to evaluate.
        A: Amplitude.
        tau: Decay constant, in the same units as ``x``.
        C: Constant offset.
    """
    decay = np.exp(-x / tau)
    return A * decay + C
def est_exp_fit_pars(x, y):
    """Return starting values ``(A, tau, C)`` for an exponential-decay fit.

    Args:
        x: Abscissa values.
        y: Ordinate values.

    Returns:
        A tuple of the maximum of ``y``, one third of the maximum of ``x``,
        and the minimum of ``y``.
    """
    amplitude_guess = np.max(y)
    tau_guess = np.max(x) / 3
    offset_guess = np.min(y)
    return (amplitude_guess, tau_guess, offset_guess)
def fit_exp_decay(x, y):
    """Fit ``exp_decay`` to ``(x, y)`` and return parameters with uncertainties.

    Args:
        x: Abscissa values.
        y: Ordinate values.

    Returns:
        ``(A, dA, tau, dTau, C, dC)``. When fewer than three points are given,
        the first ``x`` is NaN, or the fit does not converge, the initial
        estimates from ``est_exp_fit_pars`` are returned with zero
        uncertainties.
    """
    initial = est_exp_fit_pars(x, y)

    too_few_points = len(x) < 3
    if too_few_points or math.isnan(x[0]):
        log.info("Len(x) < 3; using initial values for exponential fit")
        A, tau, C = initial
        return A, 0, tau, 0, C, 0

    try:
        pars, cov = curve_fit(exp_decay, x, y, p0=initial)
    except RuntimeError:
        # curve_fit signals non-convergence with RuntimeError
        log.warning("Exponential fit did not converge, using initial values")
        A, tau, C = initial
        return A, 0, tau, 0, C, 0

    # Uncertainties are the square roots of the covariance diagonal.
    A, tau, C = pars[0], pars[1], pars[2]
    dA = np.sqrt(cov[0, 0])
    dTau = np.sqrt(cov[1, 1])
    dC = np.sqrt(cov[2, 2])
    return A, dA, tau, dTau, C, dC
def gauss(x, mu, sigma, N):
    """Evaluate the Gaussian ``N * exp(-(x - mu)^2 / (2 * sigma^2))``.

    Args:
        x: Point(s) at which to evaluate.
        mu: Mean.
        sigma: Standard deviation.
        N: Peak height (value at ``x == mu``).
    """
    exponent = -((x - mu) ** 2.0) / (2 * sigma**2)
    return N * np.exp(exponent)
def calc_gaussian(data: npt.ArrayLike, bins: np.ndarray) -> Tuple[float, float, float, float, float, float]:
    """Histogram ``data`` with ``bins`` and fit a Gaussian to the result.

    Args:
        data: Values to histogram.
        bins: Bin edges passed to ``np.histogram``.

    Returns:
        The tuple produced by ``calc_gaussian_from_bins`` for the histogram
        counts and ``bins``.
    """
    counts = np.histogram(data, bins=bins)[0]
    return calc_gaussian_from_bins(counts, bins)
def calc_gaussian_from_bins(n: np.ndarray, bins: np.ndarray) -> Tuple[float, float, float, float, float, float]:
    """Fit a Gaussian to histogram counts and return parameters with errors.

    Args:
        n: Bin contents.
        bins: Bin edges; bin centers are taken midway between adjacent edges.

    Returns:
        ``(mu, dmu, sigma, dsigma, N, dN)``. When there are fewer than three
        bins, the first center is NaN, or the fit does not converge, the
        defaults ``(0, 0, 1, 0, 1, 0)`` are returned.
    """
    fallback = (0, 0, 1, 0, 1, 0)
    centers = (bins[:-1] + bins[1:]) * 0.5

    # Seed the fit from the binned mean and standard deviation.
    mean_est = hist_mean(centers, n)
    std_est = np.sqrt(hist_var(centers, n))
    # Fit errors don't work when mean is near 0 but not 0....
    zero_within_one_sigma = mean_est - std_est < 0 < mean_est + std_est
    if zero_within_one_sigma:
        mean_est = np.float64(0.0)

    guess: List[np.floating] = [mean_est, std_est, np.max(n)]

    if len(centers) < 3 or math.isnan(centers[0]):
        log.info("Len(centers) < 3; using default values for gaussian fit")
        return fallback

    try:
        pars, cov = curve_fit(gauss, centers, n, p0=guess)
    except RuntimeError:
        # In case the fit doesn't converge
        log.warning("Gaussian fit did not converge, using default values")
        return fallback

    # Uncertainties are the square roots of the covariance diagonal.
    mu, sigma, N = pars[0], pars[1], pars[2]
    dmu = np.sqrt(cov[0, 0])
    dsigma = np.sqrt(cov[1, 1])
    dN = np.sqrt(cov[2, 2])
    return mu, dmu, sigma, dsigma, N, dN
def drawLabel(ax, adc, channel, runType, pos, gain, filter_name="", filter_info="", maxamp="") -> None:
    """Draw the multi-line annotation block in a corner of ``ax``.

    Args:
        ax: Axes whose ``transAxes`` transform positions the text.
        adc: ADC identifier shown after "ADC ".
        channel: Channel identifier. If its text contains "coherent" the
            channel is omitted from the label.
        runType: Text of the last label line when ``maxamp`` is empty and the
            channel is not coherent.
        pos: Position code. Containing "b" places the label at the bottom
            (otherwise top); containing "l" places it at the left (otherwise
            right).
        gain: Text appended after "Slice v1.1 ".
        filter_name: Optional filter name; shown only together with ``filter_info``.
        filter_info: Optional filter description.
        maxamp: If non-empty, shown on the last line after "RunType: ".
    """
    # Test against the text form so non-string channels (e.g. integers) do not
    # raise TypeError on the substring check.
    is_coherent = "coherent" in str(channel)

    # make label
    adc_line = "ADC " + str(adc)
    if not is_coherent:
        adc_line += ", Channel " + str(channel)
    label_lines = ["ATLAS Upgrade", "Slice v1.1 " + gain, adc_line]

    if filter_name != "" and filter_info != "":
        label_lines.append(str(filter_name) + ": " + str(filter_info))

    if maxamp != "":
        label_lines.append("RunType: " + maxamp)
    elif is_coherent:
        label_lines.append("Coherent noise")
    else:
        label_lines.append(runType)
    label = "\n".join(label_lines)

    # position label (axes-fraction coordinates)
    if "b" in pos:
        y, va = 0.05, "bottom"
    else:  # Top
        y, va = 0.95, "top"
    if "l" in pos:
        x, ha = 0.05, "left"
    else:  # Right
        x, ha = 0.95, "right"

    plt.text(
        x,
        y,
        label,
        size=14,
        transform=ax.transAxes,
        multialignment=ha,
        verticalalignment=va,
        horizontalalignment=ha,
    )
def plot_aesthetics(runType, adc, channel, xtitle, ytitle, title, gain="GAIN") -> None:
    """Apply the common grid, axis labels, title and corner label to a plot.

    Args:
        runType: Run type text; capitalized before being passed to ``drawLabel``.
        adc: ADC index; ``adc + 1`` is what is shown in the label.
        channel: Channel identifier forwarded to ``drawLabel``.
        xtitle: x-axis label.
        ytitle: y-axis label.
        title: Plot title.
        gain: Gain text forwarded to ``drawLabel``.
    """
    axes = plt.subplot(111)
    plt.minorticks_on()
    plt.grid(linestyle="--")
    axes.set_axisbelow(True)

    make_xlabel(plt, xtitle)
    make_ylabel(plt, ytitle)
    plt.title(title)

    # TODO gainA
    drawLabel(
        ax=axes,
        adc=adc + 1,
        channel=channel,
        runType=runType.capitalize(),
        pos="tl",
        gain=gain,
    )
def checkDir(directory):
    """Create ``directory`` if it does not exist and return its path.

    A newly created directory gets mode ``0o775``. Only the final path
    component is created (``os.mkdir``).

    Args:
        directory: Path of the directory to ensure.

    Returns:
        ``directory``, unchanged.

    Raises:
        OSError: If creation fails for any reason other than the directory
            already existing.
    """
    if os.path.exists(directory):
        return directory

    try:
        os.mkdir(directory)
        os.chmod(directory, 0o775)
    except OSError as exc:  # Guard against race condition
        already_exists = exc.errno == errno.EEXIST
        if not already_exists:
            raise
    return directory
def remove_outliers(data: np.ndarray) -> np.ndarray:
    """Drop entries lying five or more standard deviations from the mean.

    Args:
        data: Input values.

    Returns:
        The entries of ``data`` within five standard deviations of the mean.
        ``data`` itself is returned when it is empty or when no entry would
        survive the cut.
    """
    if np.size(data) < 1:
        return data

    max_deviations = 5  # remove events more than XX sigma away
    offsets = abs(data - np.mean(data))
    keep_mask = offsets < max_deviations * np.std(data)
    kept = data[keep_mask]
    return kept if len(kept) >= 1 else data
def list_to_text_ranges(runs: Union[int, List[int], np.ndarray, None]) -> str:
    """Render run numbers as a compact string of consecutive ranges.

    Args:
        runs: A single integer, a sequence of integers, or None.

    Returns:
        The number itself for an int or a one-element sequence, ``""`` for
        None, and otherwise comma-separated ``start-end`` ranges of the sorted
        values (an isolated value ``v`` appears as ``v-v``). An empty sequence
        gives ``"Missing Run #"``.
    """
    if isinstance(runs, int):
        return f"{runs}"
    if runs is None:
        return ""
    if len(runs) == 1:
        return f"{runs[0]}"

    # from stack overflow, reduce lists to minimum list of ranges:
    # within a consecutive run, (index - value) is constant, so it serves as
    # the grouping key.
    ordered = sorted(runs)
    spans = []
    for _, members in groupby(enumerate(ordered), key=lambda pair: pair[0] - pair[1]):
        values = [value for _, value in members]
        spans.append((values[0], values[-1]))

    if not spans:
        return "Missing Run #"
    return ", ".join(f"{first}-{last}" for first, last in spans)
def datetime_to_hr(timestamp_data: pl.Series):
    """Convert timestamps to hours elapsed since the earliest timestamp.

    Args:
        timestamp_data: Series of timestamps.

    Returns:
        A series named ``"hours"`` holding, for each entry, the total seconds
        since the series minimum divided by 3600.
    """
    elapsed = timestamp_data - timestamp_data.min()
    hours = elapsed.dt.total_seconds() / 3600
    return hours.alias("hours")
@dataclass
class Metadata:
    """Descriptive fields used to label plots; every field defaults to None."""

    run_numbers: Optional[str] = None
    gain: Optional[Literal["hi", "lo", "hi-lo"]] = None
    channels: Optional[str] = None
    board_id: Optional[str] = None
    pas_mode: Optional[str] = None
    att_val: Optional[str] = None
    load: Optional[str] = None
    amp: Optional[Literal["loamp", "hiamp"]] = None
    githash: Optional[str] = None
    board_type: Optional[Literal["EM", "HEC"]] = None
    board_version: Optional[str] = None

    @classmethod
    def fill_from_dataframe(cls, df: pl.DataFrame) -> Metadata:
        """Build a ``Metadata`` from the distinct values found in ``df``.

        Only columns present in ``df`` are used. Most columns are summarised
        with ``list_to_text_ranges``; ``gain`` values are joined with "-";
        ``board_variant`` fills ``board_type`` only when exactly one distinct
        value exists (otherwise an error is logged and the field stays None).

        Args:
            df: Dataframe to read the metadata columns from.

        Returns:
            A new ``Metadata``; fields without a matching column remain None.
        """
        result = cls()

        # Columns summarised as text ranges, mapped to the field they populate.
        ranged_fields = {
            "run_number": "run_numbers",
            "channel": "channels",
            "board_id": "board_id",
            "pas_mode": "pas_mode",
            "att_val": "att_val",
            "board_version": "board_version",
        }
        wanted = (
            "run_number",
            "gain",
            "channel",
            "board_id",
            "pas_mode",
            "att_val",
            "board_variant",
            "board_version",
        )
        present = [column for column in wanted if column in df.columns]
        selected = df.select(present)

        def distinct_sorted(column: str) -> list:
            # Sorted distinct values of one column, as a Python list
            return selected[column].unique().sort().to_list()

        for column in present:
            if column == "gain":
                joined = "-".join(distinct_sorted("gain"))
                result.gain = cast(Literal["hi", "lo", "hi-lo"], joined)
            elif column == "board_variant":
                variants = distinct_sorted("board_variant")
                if len(variants) == 1:
                    result.board_type = variants[0]
                else:
                    log.error(f"Found more than one board_variant in dataframe ({variants}), not including...")
            else:
                setattr(result, ranged_fields[column], list_to_text_ranges(distinct_sorted(column)))

        return result
def plot_summary_string(
    ATLAS: bool = False,
    name: Optional[str] = None,
    info: Optional[Metadata] = None,
    run_numbers: Optional[Union[str, int]] = None,
    gain: Optional[Literal["hi", "lo"]] = None,
    channels: Optional[Union[str, int]] = None,
    board_id: Optional[Union[str, int]] = None,
    attenuation: Optional[Union[str, float, List[float]]] = None,
    pas_mode: Optional[Union[str, int]] = None,
    load: Optional[Union[str, int]] = None,
    amp: Optional[Literal["loamp", "hiamp"]] = None,
    board_type: Optional[Literal["EM", "HEC"]] = None,
    board_version: Optional[str] = None,
) -> str:
    """Build the summary text describing the data shown in a plot.

    For each field, the value on ``info`` is used when it is not None;
    otherwise the matching keyword argument is used. Fields that end up None
    are left out of the string.

    Args:
        ATLAS: Prefix the string with "ATLAS Upgrade: ".
        name: Optional heading, emitted as "<name>: " followed by a newline.
        info: Optional metadata object whose non-None fields take precedence.
        run_numbers: Run number(s); rendered as "Runs ..." when given as a
            string containing "-", else "Run ...".
        gain: Gain label, rendered as "<gain> gain".
        channels: Channel(s); an int is zero-padded to three digits.
        board_id: Board identifier.
        attenuation: Attenuation in dB; numbers are shown with one decimal,
            and a list with more than one entry is bracketed.
        pas_mode: Rendered as "<value>Ω" for 25/50, else "PS gain <value>".
        load: Rendered as "<load>pF".
        amp: Amplitude label, appended as-is.
        board_type: Accepted but currently not rendered.
        board_version: Accepted but currently not rendered.

    Returns:
        The assembled summary string (empty when nothing is set).
    """

    def pick(attr: str, fallback: Any) -> Any:
        # Value stored on ``info`` when set, else the keyword-argument value
        if info is None:
            return fallback
        stored = getattr(info, attr)
        return fallback if stored is None else stored

    # Use info by default, otherwise argument
    _run_numbers = pick("run_numbers", run_numbers)
    _gain = pick("gain", gain)
    _channels = pick("channels", channels)
    _board_id = pick("board_id", board_id)
    _attenuation = pick("att_val", attenuation)
    _pas_mode = pick("pas_mode", pas_mode)
    _load = pick("load", load)
    _amp = pick("amp", amp)
    # board_type / board_version: don't use for now, maybe add later

    # Construct string
    pieces: List[str] = []
    if ATLAS:
        pieces.append("ATLAS Upgrade: ")
    if name is not None:
        pieces.append(f"{name}: \n")

    if _run_numbers is not None:
        plural = isinstance(_run_numbers, str) and "-" in _run_numbers
        pieces.append(f"Runs {_run_numbers}" if plural else f"Run {_run_numbers}")

    if _board_id is not None:
        pieces.append(f", Board ID {_board_id}")

    if _channels is not None:
        if isinstance(_channels, int):
            pieces.append(f", CH {_channels:03}")
        else:
            pieces.append(f", CH {_channels}")

    if _gain is not None:
        pieces.append(f", {_gain} gain")

    if _attenuation is not None:
        if isinstance(_attenuation, list):
            att_string = ",".join(f"{a:.1f}" if isinstance(a, (int, float)) else str(a) for a in _attenuation)
            if len(_attenuation) > 1:
                att_string = f"[{att_string}]"
        elif isinstance(_attenuation, (int, float)):
            att_string = f"{_attenuation:.1f}"
        else:
            att_string = str(_attenuation)
        pieces.append(f", {att_string}dB")

    if _pas_mode is not None:
        if _pas_mode in [25, 50, "25", "50"]:
            pieces.append(f", {_pas_mode}Ω")
        else:
            pieces.append(f", PS gain {_pas_mode}")

    if _load is not None:
        pieces.append(f", {_load}pF")

    # Guard on the amplitude itself. This was previously keyed on ``_load``,
    # which dropped amp when no load was given and printed "None" otherwise.
    if _amp is not None:
        pieces.append(f", {_amp}")

    return "".join(pieces)