Coverage for polars_analysis / plotting / helper.py: 73%

288 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-16 15:00 -0400

1from __future__ import annotations 

2 

3import errno 

4import logging 

5import math 

6import os 

7from dataclasses import dataclass 

8from itertools import groupby 

9from time import time 

10from typing import Any, Callable, List, Literal, Optional, Tuple, Union, cast, overload 

11 

12import matplotlib as mpl 

13import numpy as np 

14import numpy.typing as npt 

15import polars as pl 

16from matplotlib import pyplot as plt 

17from scipy.optimize import curve_fit # type: ignore 

18 

# Instantiate logger
log = logging.getLogger(__name__)


# GLOBAL CONSTANTS  :(
START = 0
N_PHASES = 30
pulses_per_train = 30
start_meas = 0

# colors: module-level `jet` colormap
if hasattr(mpl, "colormaps"):
    jet = mpl.colormaps["jet"]  # type: ignore
else:
    # Older matplotlib without the `mpl.colormaps` registry.
    # (This branch previously bound the colormap to `get`, leaving `jet` undefined.)
    from matplotlib import cm

    # NOTE(review): this fallback requests 30 levels while the registry branch
    # returns the default-resolution colormap — confirm which callers expect.
    jet = cm.get_cmap("jet", 30)

FIVE_PERCENT_RISETIME = 30  # ns
t_align = 490

38 

39 

@overload
def lin(x: float, a: float, b: float) -> float: ...
@overload
def lin(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray: ...
def lin(
    x: Union[float, np.ndarray], a: Union[float, np.ndarray], b: Union[float, np.ndarray]
) -> Union[float, np.ndarray]:
    """Evaluate the straight line ``a * x + b`` for scalars or element-wise for arrays."""
    slope_term = a * x
    return slope_term + b

48 

49 

def timer_func(func) -> Callable[..., Any]:
    """Decorate *func* so every call prints how long it took to run.

    The wrapper forwards all positional and keyword arguments, returns the
    wrapped function's result unchanged, and prints the elapsed time after a
    successful call. ``functools.wraps`` keeps the wrapped function's
    ``__name__``, ``__doc__`` and other metadata visible to callers.
    """
    import functools
    from time import perf_counter

    @functools.wraps(func)
    def wrap_func(*args, **kwargs) -> Any:
        # perf_counter is monotonic and high-resolution, unlike wall-clock time().
        t1 = perf_counter()
        result = func(*args, **kwargs)
        t2 = perf_counter()
        print(f"Function {func.__name__!r} executed in {(t2 - t1):.4f}s")
        return result

    return wrap_func

61 

62 

def make_xlabel(plt, label: str) -> None:
    """Set the x-axis label on *plt*, right-aligned at the far end of the axis."""
    label_style = dict(x=1.0, ha="right", size=12, labelpad=0)
    plt.xlabel("" + label, **label_style)

65 

66 

def make_ylabel(plt, label: str) -> None:
    """Set the y-axis label on *plt*, right-aligned at the top end of the axis."""
    label_style = dict(y=1.0, ha="right", size=12, labelpad=0)
    plt.ylabel("" + label, **label_style)

69 

70 

71# These binned stats methods are adapted from https://stackoverflow.com/a/53022701 

def hist_mean(centers: np.ndarray, counts: np.ndarray):
    """Return the mean of a histogram: bin *centers* weighted by their *counts*."""
    weighted_mean = np.average(centers, weights=counts)
    return weighted_mean

74 

75 

def hist_var(centers: np.ndarray, counts: np.ndarray):
    """Return the count-weighted (population) variance of the histogram bin centers."""
    mu = hist_mean(centers, counts)
    squared_offsets = (centers - mu) ** 2
    return (counts * squared_offsets).sum() / counts.sum()

79 

80 

def hist_moment(centers: np.ndarray, counts: np.ndarray, n: int):
    """Return the n-th standardized moment of a histogram.

    The count-weighted n-th central moment divided by variance ** (n / 2),
    e.g. n=3 gives the skewness and n=4 the kurtosis.
    """
    mu = hist_mean(centers, counts)
    central_moment = (counts * (centers - mu) ** n).sum() / counts.sum()
    normalisation = hist_var(centers, counts) ** (n / 2)
    return central_moment / normalisation

85 

86 

def exp_decay(x, A, tau, C):
    """Exponential decay model ``A * exp(-x / tau) + C`` (the function fitted by fit_exp_decay)."""
    decay_factor = np.exp(-x / tau)
    return A * decay_factor + C

89 

90 

def est_exp_fit_pars(x, y):
    """Return rough starting values ``(A, tau, C)`` for an exponential-decay fit.

    A is the largest y, tau is a third of the largest x, and C is the smallest y.
    """
    amplitude_guess = np.max(y)
    tau_guess = np.max(x) / 3
    offset_guess = np.min(y)
    return (amplitude_guess, tau_guess, offset_guess)

93 

94 

def fit_exp_decay(x, y):
    """Fit ``exp_decay`` to the points (x, y).

    Returns ``(A, dA, tau, dTau, C, dC)``. The ``d*`` values are the square
    roots of the diagonal of the covariance matrix returned by ``curve_fit``.

    Fallbacks, where all ``d*`` values are 0:

    * empty ``x`` or ``y``: no starting estimate can be formed, so NaN is
      returned for A, tau and C;
    * fewer than 3 points, or ``x[0]`` is NaN: the unfitted starting estimates
      from ``est_exp_fit_pars`` are returned;
    * ``curve_fit`` raises ``RuntimeError`` (no convergence): the starting
      estimates are returned.
    """
    if len(x) == 0 or len(y) == 0:
        # est_exp_fit_pars uses np.max/np.min, which raise on empty arrays,
        # so this must be handled before any estimate is computed.
        log.info("Empty input; using NaN values for exponential fit")
        return math.nan, 0, math.nan, 0, math.nan, 0

    fit_pars_est = est_exp_fit_pars(x, y)
    if len(x) < 3 or math.isnan(x[0]):
        log.info("Len(x) < 3; using initial values for exponential fit")
        A, tau, C = fit_pars_est
        return A, 0, tau, 0, C, 0

    try:
        pars, cov = curve_fit(exp_decay, x, y, p0=fit_pars_est)
    except RuntimeError:
        log.warning("Exponential fit did not converge, using initial values")
        A, tau, C = fit_pars_est
        return A, 0, tau, 0, C, 0

    A, dA = pars[0], np.sqrt(cov[0, 0])
    tau, dTau = pars[1], np.sqrt(cov[1, 1])
    C, dC = pars[2], np.sqrt(cov[2, 2])
    return A, dA, tau, dTau, C, dC

112 

113 

def gauss(x, mu, sigma, N):
    """Unnormalised Gaussian with peak height *N*, centre *mu* and width *sigma*."""
    exponent = -((x - mu) ** 2.0) / (2 * sigma**2)
    return N * np.exp(exponent)

116 

117 

def calc_gaussian(data: npt.ArrayLike, bins: np.ndarray) -> Tuple[float, float, float, float, float, float]:
    """Histogram *data* over the bin edges *bins* and fit a Gaussian to the counts.

    Returns whatever ``calc_gaussian_from_bins`` returns for those counts.
    """
    counts = np.histogram(data, bins=bins)[0]
    return calc_gaussian_from_bins(counts, bins)

121 

122 

def calc_gaussian_from_bins(n: np.ndarray, bins: np.ndarray) -> Tuple[float, float, float, float, float, float]:
    """Fit ``gauss`` to histogram counts *n* defined over bin edges *bins*.

    Returns ``(mu, dmu, sigma, dsigma, N, dN)``. The ``d*`` values are the
    square roots of the diagonal of the covariance matrix from ``curve_fit``.

    The defaults ``(0, 0, 1, 0, 1, 0)`` are returned when any of these holds:

    * there are fewer than 3 bins;
    * the first bin center is NaN;
    * all counts are zero;
    * ``curve_fit`` raises ``RuntimeError``.
    """
    defaults = (0, 0, 1, 0, 1, 0)
    centers = 0.5 * (bins[1:] + bins[:-1])

    # Guard first: the moment estimates below would otherwise run (and can
    # raise) before these cases are detected.
    if len(centers) < 3 or math.isnan(centers[0]):
        log.info("Len(centers) < 3; using default values for gaussian fit")
        return defaults
    if np.sum(n) == 0:
        # hist_mean uses np.average, which raises ZeroDivisionError when the
        # weights (counts) sum to zero, e.g. for a histogram of empty data.
        log.info("Histogram has no entries; using default values for gaussian fit")
        return defaults

    mean_est = hist_mean(centers, n)
    std_est = np.sqrt(hist_var(centers, n))
    # Fit errors don't work when mean is near 0 but not 0....
    # so snap the starting mean to exactly 0 when 0 lies within +-1 std.
    if mean_est - std_est < 0 < mean_est + std_est:
        mean_est = np.float64(0.0)

    guess: List[np.floating] = [mean_est, std_est, np.max(n)]

    try:
        pars, cov = curve_fit(gauss, centers, n, p0=guess)
    except RuntimeError:
        # In case the fit doesn't converge
        log.warning("Gaussian fit did not converge, using default values")
        return defaults

    errors = np.sqrt(np.diag(cov))
    return pars[0], errors[0], pars[1], errors[1], pars[2], errors[2]

150 

151 

def drawLabel(ax, adc, channel, runType, pos, gain, filter_name="", filter_info="", maxamp="") -> None:
    """Draw the multi-line "ATLAS Upgrade" info label inside the axes *ax*.

    Label contents:

    * the gain and ADC number;
    * the channel, unless the channel text contains "coherent";
    * an optional ``filter_name: filter_info`` line (only when both are given);
    * a final line with one of:
      * ``RunType: <maxamp>`` when *maxamp* is given;
      * "Coherent noise" for coherent channels;
      * otherwise *runType*.

    *pos* selects the corner in axes-fraction coordinates: it is bottom when
    it contains "b", otherwise top; it is left when it contains "l",
    otherwise right.
    """
    # channel may be an int (it is stringified for display), so run the
    # "coherent" substring test on its string form to avoid a TypeError.
    channel_str = str(channel)
    is_coherent = "coherent" in channel_str

    # make label
    label = r"ATLAS Upgrade"
    label += "\n" + r"Slice v1.1 " + gain + "\n" + r"ADC " + str(adc)
    if not is_coherent:
        label += ", Channel " + channel_str
    if filter_name != "" and filter_info != "":
        label += "\n" + str(filter_name) + ": " + str(filter_info)
    if maxamp != "":
        label += "\n" + r"RunType: " + maxamp
    elif is_coherent:
        label += "\n" + "Coherent noise"
    else:
        label += "\n" + runType

    # position label
    x = 0.05
    y = 0.05
    va = "bottom"
    ha = "left"
    if "b" not in pos:  # Top
        y = 0.95
        va = "top"
    if "l" not in pos:  # Right
        x = 0.95
        ha = "right"

    # Draw on the requested axes itself (not pyplot's current axes) so the
    # text and its ax.transAxes transform always belong to the same axes.
    ax.text(
        x,
        y,
        label,
        size=14,
        transform=ax.transAxes,
        multialignment=ha,
        verticalalignment=va,
        horizontalalignment=ha,
    )

189 

190 

def plot_aesthetics(runType, adc, channel, xtitle, ytitle, title, gain="GAIN") -> None:
    """Set up a single-panel pyplot figure with the standard styling and info label.

    This does the following:

    * creates subplot 111;
    * enables minor ticks and a dashed grid drawn below the data;
    * sets the axis labels and title;
    * draws the info label in the top-left corner.

    The label shows ``adc + 1`` (presumably turning a 0-based index into a
    1-based one — confirm) and the capitalised run type.
    """
    axes = plt.subplot(111)
    plt.minorticks_on()
    plt.grid(linestyle="--")
    axes.set_axisbelow(True)

    make_xlabel(plt, xtitle)
    make_ylabel(plt, ytitle)
    plt.title(title)

    shown_adc = adc + 1
    shown_run_type = runType.capitalize()
    drawLabel(axes, shown_adc, channel, shown_run_type, "tl", gain)  # TODO gainA

200 

201 

def checkDir(directory):
    """Ensure *directory* exists (creating it with mode 0o775 if needed) and return it.

    Only the final path component is created; a missing parent directory
    raises. If another process creates the directory between the existence
    check and ``os.mkdir``, the resulting EEXIST error is ignored.
    """
    if os.path.exists(directory):
        return directory
    try:
        os.mkdir(directory)
        os.chmod(directory, 0o775)
    except OSError as err:
        # Only the "already exists" race is benign; anything else propagates.
        if err.errno != errno.EEXIST:
            raise
    return directory

211 

212 

def remove_outliers(data: np.ndarray, *, max_deviations: float = 5) -> np.ndarray:
    """Drop entries lying *max_deviations* or more standard deviations from the mean.

    Args:
        data: Values to filter.
        max_deviations: Cut width in units of ``np.std(data)`` (default 5).

    Returns:
        The entries of *data* strictly closer to the mean than the cut. The
        input is returned unchanged when it is empty or when the cut would
        remove every entry (e.g. all values identical, so the std is 0).
    """
    if np.size(data) < 1:
        return data
    distance_from_mean = np.abs(data - np.mean(data))
    not_outlier = distance_from_mean < max_deviations * np.std(data)
    if not np.any(not_outlier):
        return data
    return data[not_outlier]

222 

223 

def list_to_text_ranges(runs: Union[int, List[int], np.ndarray, None]) -> str:
    """Summarise run numbers as comma-separated ``start-end`` ranges.

    Returns:

    * ``None`` gives ``""``;
    * a single integer (Python or numpy) gives that number;
    * a collection with one distinct value gives that value;
    * an empty collection gives ``"Missing Run #"``.

    Otherwise the distinct values are sorted and every run of consecutive
    integers becomes ``"start-end"``; isolated values appear as ``"v-v"``.
    Example: ``[5, 1, 2, 3]`` gives ``"1-3, 5-5"``.
    """
    if runs is None:
        return ""
    if isinstance(runs, (int, np.integer)):
        return f"{runs}"

    # Deduplicate before grouping: repeated values would otherwise break the
    # consecutive-run detection below (e.g. [1, 1, 2] -> "1-1, 1-2").
    values = sorted(set(runs))
    if not values:
        return "Missing Run #"
    if len(values) == 1:
        return f"{values[0]}"

    # Consecutive integers share the same (index - value) key, so each
    # groupby group is one contiguous range.
    ranges = []
    for _, group in groupby(enumerate(values), lambda pair: pair[0] - pair[1]):
        members = [value for _, value in group]
        ranges.append(f"{members[0]}-{members[-1]}")
    return ", ".join(ranges)

250 

251 

def datetime_to_hr(timestamp_data: pl.Series):
    """Return the time elapsed since the earliest timestamp, in hours, as a Series named "hours"."""
    elapsed = timestamp_data - timestamp_data.min()
    hours = elapsed.dt.total_seconds() / 3600
    return hours.alias("hours")

254 

255 

@dataclass
class Metadata:
    """Optional display metadata for a set of runs, used when building plot captions.

    Every field defaults to None. ``fill_from_dataframe`` populates the fields
    it can derive from a DataFrame. It never sets ``load``, ``amp`` or
    ``githash``.
    """

    run_numbers: Optional[str] = None
    gain: Optional[Literal["hi", "lo", "hi-lo"]] = None
    channels: Optional[str] = None
    board_id: Optional[str] = None
    pas_mode: Optional[str] = None
    att_val: Optional[str] = None
    load: Optional[str] = None
    amp: Optional[Literal["loamp", "hiamp"]] = None
    githash: Optional[str] = None
    board_type: Optional[Literal["EM", "HEC"]] = None
    board_version: Optional[str] = None

    @classmethod
    def fill_from_dataframe(cls, df: pl.DataFrame) -> Metadata:
        """Build a Metadata from the distinct non-null values found in *df*.

        Handling of each column, when present:

        * run_number, channel, board_id, pas_mode, att_val, board_version:
          summarised with ``list_to_text_ranges``;
        * gain: distinct values joined with "-" (e.g. "hi-lo");
        * board_variant: copied to ``board_type`` only if exactly one
          distinct value exists; otherwise an error is logged and the field
          is left as None.

        Missing columns leave the corresponding field as None.
        """
        metadata = cls()

        def distinct(column: str) -> list:
            # Sorted distinct values; nulls are dropped because they cannot be
            # joined into strings or compared/subtracted in list_to_text_ranges.
            return df[column].drop_nulls().unique().sort().to_list()

        # DataFrame column -> Metadata attribute, formatted as text ranges.
        range_fields = {
            "run_number": "run_numbers",
            "channel": "channels",
            "board_id": "board_id",
            "pas_mode": "pas_mode",
            "att_val": "att_val",
            "board_version": "board_version",
        }
        for column, attribute in range_fields.items():
            if column in df.columns:
                setattr(metadata, attribute, list_to_text_ranges(distinct(column)))

        if "gain" in df.columns:
            metadata.gain = cast(Literal["hi", "lo", "hi-lo"], "-".join(distinct("gain")))

        if "board_variant" in df.columns:
            board_variant = distinct("board_variant")
            if len(board_variant) == 1:
                metadata.board_type = board_variant[0]
            elif board_variant:
                log.error(f"Found more than one board_variant in dataframe ({board_variant}), not including...")
            else:
                log.error("Found no board_variant values in dataframe, not including...")

        return metadata

310 

311 

def plot_summary_string(
    ATLAS: bool = False,
    name: Optional[str] = None,
    info: Optional[Metadata] = None,
    run_numbers: Optional[Union[str, int]] = None,
    gain: Optional[Literal["hi", "lo"]] = None,
    channels: Optional[Union[str, int]] = None,
    board_id: Optional[Union[str, int]] = None,
    attenuation: Optional[Union[str, float, List[float]]] = None,
    pas_mode: Optional[Union[str, int]] = None,
    load: Optional[Union[str, int]] = None,
    amp: Optional[Literal["loamp", "hiamp"]] = None,
    board_type: Optional[Literal["EM", "HEC"]] = None,
    board_version: Optional[str] = None,
) -> str:
    """Build a one-line plot caption describing runs and settings.

    Each value is taken from *info* when the corresponding field there is not
    None, otherwise from the explicit argument. An optional prefix comes
    first: "ATLAS Upgrade: " when *ATLAS*, then "<name>: " plus a newline
    when *name* is given. The non-None items follow, joined with ", ":

    * runs ("Runs" when the string looks like several runs, i.e. contains
      "-" or ",");
    * board ID;
    * channel (integers are zero-padded to 3 digits);
    * gain;
    * attenuation in dB;
    * PAS mode ("25Ω"/"50Ω" for 25/50, otherwise "PS gain <value>");
    * load in pF;
    * amp.

    *board_type* and *board_version* are accepted but not rendered yet.
    """

    def pick(field: str, fallback: Any) -> Any:
        # The value stored on `info` takes precedence over the explicit argument.
        value = getattr(info, field) if info is not None else None
        return fallback if value is None else value

    _run_numbers = pick("run_numbers", run_numbers)
    _gain = pick("gain", gain)
    _channels = pick("channels", channels)
    _board_id = pick("board_id", board_id)
    _attenuation = pick("att_val", attenuation)
    _pas_mode = pick("pas_mode", pas_mode)
    _load = pick("load", load)
    _amp = pick("amp", amp)
    # board_type / board_version: don't use for now, maybe add later.

    header = ""
    if ATLAS:
        header = "ATLAS Upgrade: "
    if name is not None:
        header += f"{name}: \n"

    # Collect items and join once, so no leading ", " appears when an
    # earlier item (e.g. the run number) is absent.
    parts: List[str] = []
    if _run_numbers is not None:
        is_multi = isinstance(_run_numbers, str) and ("-" in _run_numbers or "," in _run_numbers)
        parts.append(f"Runs {_run_numbers}" if is_multi else f"Run {_run_numbers}")
    if _board_id is not None:
        parts.append(f"Board ID {_board_id}")
    if _channels is not None:
        parts.append(f"CH {_channels:03}" if isinstance(_channels, int) else f"CH {_channels}")
    if _gain is not None:
        parts.append(f"{_gain} gain")
    if _attenuation is not None:
        if isinstance(_attenuation, list):
            att_string = ",".join(f"{a:.1f}" if isinstance(a, (int, float)) else str(a) for a in _attenuation)
            if len(_attenuation) > 1:
                att_string = f"[{att_string}]"
        elif isinstance(_attenuation, (int, float)):
            att_string = f"{_attenuation:.1f}"
        else:
            att_string = str(_attenuation)
        parts.append(f"{att_string}dB")
    if _pas_mode is not None:
        if _pas_mode in [25, 50, "25", "50"]:
            parts.append(f"{_pas_mode}Ω")
        else:
            parts.append(f"PS gain {_pas_mode}")
    if _load is not None:
        parts.append(f"{_load}pF")
    # Previously guarded by `_load`, which printed "None" for a missing amp
    # and dropped amp entirely when no load was given.
    if _amp is not None:
        parts.append(f"{_amp}")

    return header + ", ".join(parts)