Coverage for polars_analysis / pulse.py: 73%
150 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-16 14:47 -0400
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-16 14:47 -0400
1import concurrent.futures
2import logging
3import multiprocessing as mp
4import os
5import subprocess as sp
6import traceback
7from concurrent.futures import ProcessPoolExecutor
8from pathlib import Path
9from typing import List, Literal, Optional
11import polars as pl
13import polars_analysis.analysis.pulse_analysis as analysis
14import polars_analysis.plotting.pulse_plotting as plotting
15from polars_analysis.data_sources import DataSource
16from polars_analysis.db_interface import prod_db_data_uploader
17from polars_analysis.db_interface.production_test_db import ProductionTestDB
18from polars_analysis.plotting.helper import Metadata
19from polars_analysis.utils import add_run_info, get_columns_or_exit
# Module-level logger named after this module.
log = logging.getLogger(__name__)

# High level commands to run pulse loading, calculations, and plotting.
# (Kept as a comment: a string literal placed after the imports is not the
# module docstring -- only the first statement of a module becomes __doc__.)
def calc_derived(df: pl.DataFrame, all_phases: bool = False, OFC_quantile: float = 1.0) -> pl.DataFrame:
    """Run the pulse-analysis pipeline on raw data and return the derived frame.

    Only rows with ``meas_type == "pulse"`` whose ``is_pulsed`` flag is true are
    analysed, restricted to the columns in ``wanted``. After sample interleaving
    and the amplitude/phase helper columns are added, the frame is passed through
    each ``analysis.pipe_*`` stage in order. The bulky ``samples`` /
    ``samples_baseline`` columns are dropped at the end (``strict=False``, so a
    missing column is not an error).

    Args:
        df: Raw measurement frame containing ``meas_type``, ``is_pulsed`` and the
            columns listed in ``wanted``.
        all_phases: Forwarded to ``pipe_apply_OFCs`` and to the second
            ``pipe_energy_sigma`` stage.
        OFC_quantile: Forwarded to ``pipe_OFCs`` as ``quantile``.

    Returns:
        The derived-data frame produced by the pipeline.
    """
    wanted = [
        "run_number",
        "measurement",
        "channel",
        "gain",
        "samples",
        "att_val",
        "awg_amp",
        "pas_mode",
        "board_id",
        "board_variant",
        "board_version",
    ]
    is_pulse_row = pl.col("meas_type") == "pulse"
    pulsed_rows = df.filter(is_pulse_row, pl.col("is_pulsed")).select(wanted)

    frame = analysis.pipe_samples_interleaved(pulsed_rows)
    frame = frame.with_columns(
        max_pulse_amp=analysis.expr_max_pulse_amp(),
        amp=analysis.expr_awg_amp_to_amp(),
        max_phase_indices=analysis.expr_max_phase_indices(),
    )

    # (stage, keyword arguments), applied in this order. Later stages presumably
    # read columns produced by earlier ones, so keep the sequence intact.
    stages = [
        (analysis.pipe_OFCs, {"quantile": OFC_quantile}),
        (analysis.pipe_apply_OFCs, {"all_phases": all_phases}),
        (analysis.pipe_rise_time, {}),
        (analysis.pipe_zero_crossing, {}),
        (analysis.pipe_ref_pulse_correlation, {}),
        (analysis.pipe_ref_pulse_rmse, {}),
        (analysis.pipe_inl, {"skip_last_n_hi": 1, "skip_last_n_lo": 2}),
        (analysis.pipe_gain_ratio, {}),
        # NOTE(review): pipe_energy_sigma runs twice -- once with its own defaults,
        # once with all_phases. If its default is all_phases=False, the second pass
        # repeats the first whenever all_phases is False; confirm before removing.
        (analysis.pipe_energy_sigma, {}),
        (analysis.pipe_energy_sigma, {"all_phases": all_phases}),
    ]
    for stage, options in stages:
        frame = stage(frame, **options)

    return frame.drop("samples", "samples_baseline", strict=False)
def calc_all(raw_data: pl.DataFrame, all_phases: bool = False, OFC_quantile: float = 1.0) -> pl.DataFrame:
    """Validate that ``raw_data`` holds pulse rows, then compute the derived data.

    Args:
        raw_data: Raw measurement frame; must contain a ``meas_type`` column.
        all_phases: Forwarded to ``calc_derived``.
        OFC_quantile: Forwarded to ``calc_derived``.

    Returns:
        The frame returned by ``calc_derived``.

    Raises:
        ValueError: if no row has ``meas_type == "pulse"``. ``ValueError`` is a
            subclass of ``Exception``, so existing ``except Exception`` handlers
            still catch it.
    """
    if raw_data.filter(pl.col("meas_type") == "pulse").is_empty():
        log.critical("No rows in the dataframe correspond to a pulse run. Aborting.")
        raise ValueError("Empty dataframe")

    return calc_derived(raw_data, all_phases, OFC_quantile=OFC_quantile)
def plot_all(
    raw_data: pl.DataFrame,
    derived_data: pl.DataFrame,
    plot_dir: Path,
    all_phases: bool = False,
    qc_plotting: bool = False,
):
    """Produce every pulse plot for a run serially, in the current process.

    Draws the run-wide means/RMS plot. Then, per channel, it draws the pulse
    overlay, the gain-ratio plot and one gain overlay per distinct amplitude. Per
    gain it draws the energy-resolution, sigma, timing, rise-time, zero-crossing,
    autocorrelation and INL plots, plus per-amplitude OFC-sample, timing and
    energy histograms (and the all-phase plots when ``all_phases`` is set).

    Args:
        raw_data: Raw measurement frame; only the columns selected below are used.
        derived_data: Derived frame (e.g. from ``calc_all``) with ``channel``,
            ``gain``, ``amp`` and the join-key columns.
        plot_dir: Directory the plotting functions write into.
        all_phases: Also draw the per-phase energy/timing plots.
        qc_plotting: Pass ``skip_last_n=1`` (hi gain) / ``2`` (lo gain) to
            ``plot_INL``; otherwise ``None``.

    If ``plot_dir`` held no PNGs on entry, the ``*png`` and ``*json`` files found
    there afterwards are chmod-ed to 0o664.
    """
    # Permissions are only relaxed when this call populated an empty directory.
    plot_dir_filled = any(plot_dir.glob("*png"))

    raw_data = raw_data.select(
        ["run_number", "measurement", "channel", "gain", "samples", "att_val", "awg_amp", "board_id", "pas_mode"]
    )

    # Run-level header values are taken from the first raw row.
    plotting.plot_pulse_means_rms(
        derived_data,
        plot_dir,
        raw_data["channel"].unique().to_list(),
        raw_data[0]["pas_mode"][0],
        raw_data[0]["board_id"][0],
        raw_data[0]["att_val"][0],
    )

    # Attach board_id from the raw frame (polars' default inner join on the keys).
    derived_data = derived_data.join(
        raw_data.select("run_number", "measurement", "channel", "gain", "board_id"),
        on=["run_number", "measurement", "channel", "gain"],
    )

    for channel in raw_data["channel"].unique():
        log.info(f"Plotting {channel=}")

        filtered_df: pl.DataFrame = derived_data.filter(pl.col("channel") == channel)

        plotting.plot_pulse_overlay_all(filtered_df, channel, plot_dir)
        plotting.plot_gain_ratios(filtered_df, channel, plot_dir)

        # ``amp`` repeats across rows (e.g. once per gain); iterate distinct values
        # so each gain overlay is drawn once instead of once per row.
        for amp in filtered_df["amp"].unique(maintain_order=True):
            plotting.plot_pulse_gain_overlay(
                filtered_df.filter(pl.col("amp") == amp),
                channel,
                plot_dir,
            )

        gains: List[Literal["lo", "hi"]] = ["lo", "hi"]
        for gain in gains:
            gain_filtered_df: pl.DataFrame = filtered_df.filter(pl.col("gain") == gain)
            skip_last_n = (1 if gain == "hi" else 2) if qc_plotting else None

            plotting.plot_energy_resolution(gain_filtered_df, channel, plot_dir)
            plotting.plot_energy_resolution(gain_filtered_df, channel, plot_dir, plot_log_scale=True)

            plotting.plot_sigma_e(gain_filtered_df, channel, plot_dir)

            plotting.plot_sigma_T(gain_filtered_df, channel, plot_dir)
            plotting.plot_sigma_T(gain_filtered_df, channel, plot_dir, plot_log_scale=True)

            plotting.plot_timing_mean(gain_filtered_df, channel, plot_dir)
            plotting.plot_risetime(gain_filtered_df, channel, plot_dir)
            plotting.plot_zero_crossing(gain_filtered_df, channel, plot_dir)
            plotting.plot_autocorrelation(gain_filtered_df, channel, plot_dir)

            plotting.plot_INL(gain_filtered_df, channel, plot_dir, skip_last_n=skip_last_n)

            for amp in gain_filtered_df["amp"].unique(maintain_order=True):
                # Filter once per amplitude and reuse the slice for every plot.
                amp_df = gain_filtered_df.filter(pl.col("amp") == amp)
                plotting.plot_ofc_samples(amp_df, channel, plot_dir)
                plotting.plot_timing_hist(amp_df, channel, plot_dir)
                plotting.plot_energy_hist(amp_df, channel, plot_dir)
                if all_phases:
                    plotting.plot_all_phases_energy(amp_df, channel, plot_dir)
                    plotting.plot_all_phases_timing(amp_df, channel, plot_dir)

    if not plot_dir_filled:
        for pattern in ("*png", "*json"):
            for f in plot_dir.glob(pattern):
                os.chmod(f, 0o664)
def parallel_plot_all(
    raw_data: pl.DataFrame,
    derived_data: pl.DataFrame,
    plot_dir: Path,
    all_phases: bool = False,
    qc_plotting: bool = False,
):
    """Produce every pulse plot for a run, fanning the plots out to worker processes.

    Submits the same set of plots as ``plot_all`` to a ``ProcessPoolExecutor``
    using the "spawn" start method. It then waits for every job; an exception
    raised by a job is logged (with traceback) and not re-raised.

    Args:
        raw_data: Raw measurement frame; the required columns are checked with
            ``get_columns_or_exit``.
        derived_data: Derived frame (e.g. from ``calc_all``) with ``channel``,
            ``gain``, ``amp`` and the join-key columns.
        plot_dir: Directory the plotting functions write into.
        all_phases: Also draw the per-phase energy/timing plots.
        qc_plotting: Pass ``skip_last_n=1`` (hi gain) / ``2`` (lo gain) to
            ``plot_INL``; otherwise ``None``.

    If ``plot_dir`` held no PNGs on entry, the ``*png`` and ``*json`` files found
    there afterwards are chmod-ed to 0o664.
    """
    # Permissions are only relaxed when this call populated an empty directory.
    plot_dir_filled = any(plot_dir.glob("*png"))

    with ProcessPoolExecutor(mp_context=mp.get_context("spawn")) as executor:
        # future -> name of the plotting function it runs, used in error reports.
        job_handles = {}

        def submit(fn, *args, **kwargs) -> None:
            """Schedule ``fn(*args, **kwargs)`` in the pool and record it by name."""
            job_handles[executor.submit(fn, *args, **kwargs)] = getattr(fn, "__name__", repr(fn))

        columns_to_get = [
            "run_number",
            "measurement",
            "channel",
            "gain",
            "samples",
            "att_val",
            "awg_amp",
            "board_id",
            "pas_mode",
        ]
        raw_data = get_columns_or_exit(raw_data, columns_to_get)

        # Run-level header values are taken from the first raw row.
        submit(
            plotting.plot_pulse_means_rms,
            derived_data.clone(),
            plot_dir,
            raw_data["channel"].unique().to_list(),
            raw_data[0]["pas_mode"][0],
            raw_data[0]["board_id"][0],
            raw_data[0]["att_val"][0],
        )

        # Attach board_id from the raw frame (polars' default inner join on the keys).
        derived_data = derived_data.join(
            raw_data.select("run_number", "measurement", "channel", "gain", "board_id"),
            on=["run_number", "measurement", "channel", "gain"],
        )

        for channel in raw_data["channel"].unique():
            log.info(f"Plotting {channel=}")

            filtered_df: pl.DataFrame = derived_data.filter(pl.col("channel") == channel)

            submit(plotting.plot_pulse_overlay_all, filtered_df, channel, plot_dir)
            submit(plotting.plot_gain_ratios, filtered_df, channel, plot_dir)

            # ``amp`` repeats across rows (e.g. once per gain). Iterating raw rows
            # would submit identical jobs that presumably write the same output file
            # concurrently from different processes, so iterate distinct values.
            for amp in filtered_df["amp"].unique(maintain_order=True):
                submit(
                    plotting.plot_pulse_gain_overlay,
                    filtered_df.filter(pl.col("amp") == amp),
                    channel,
                    plot_dir,
                )

            for gain in ("lo", "hi"):
                gain_filtered_df: pl.DataFrame = filtered_df.filter(pl.col("gain") == gain)
                skip_last_n = (1 if gain == "hi" else 2) if qc_plotting else None

                submit(plotting.plot_energy_resolution, gain_filtered_df, channel, plot_dir)
                submit(plotting.plot_energy_resolution, gain_filtered_df, channel, plot_dir, plot_log_scale=True)

                submit(plotting.plot_sigma_e, gain_filtered_df, channel, plot_dir)

                submit(plotting.plot_sigma_T, gain_filtered_df, channel, plot_dir)
                submit(plotting.plot_sigma_T, gain_filtered_df, channel, plot_dir, plot_log_scale=True)

                submit(plotting.plot_timing_mean, gain_filtered_df, channel, plot_dir)
                submit(plotting.plot_risetime, gain_filtered_df, channel, plot_dir)
                submit(plotting.plot_zero_crossing, gain_filtered_df, channel, plot_dir)
                submit(plotting.plot_autocorrelation, gain_filtered_df, channel, plot_dir)

                submit(plotting.plot_INL, gain_filtered_df, channel, plot_dir, skip_last_n=skip_last_n)

                for amp in gain_filtered_df["amp"].unique(maintain_order=True):
                    # Filter once per amplitude; each submission pickles its own copy.
                    amp_df = gain_filtered_df.filter(pl.col("amp") == amp)
                    submit(plotting.plot_ofc_samples, amp_df, channel, plot_dir)
                    submit(plotting.plot_timing_hist, amp_df, channel, plot_dir)
                    submit(plotting.plot_energy_hist, amp_df, channel, plot_dir)
                    if all_phases:
                        submit(plotting.plot_all_phases_energy, amp_df, channel, plot_dir)
                        submit(plotting.plot_all_phases_timing, amp_df, channel, plot_dir)

        # Wait for every job; report failures without aborting the remaining plots.
        for future in concurrent.futures.as_completed(job_handles):
            job = job_handles[future]
            try:
                future.result()
            except Exception as exc:
                # log.exception attaches the traceback (including the worker's
                # remote traceback chained as the cause).
                log.exception(f"{job} generated an exception: {exc}")

    if not plot_dir_filled:
        for pattern in ("*png", "*json"):
            for f in plot_dir.glob(pattern):
                os.chmod(f, 0o664)
def upload_derived_data(derived_data: pl.DataFrame, uri: str):
    """Upload a run's pulse derived quantities to the production-test database.

    Adds a ``risetime_mean`` column (mean of ``rise_time`` within each ``gain``
    group) and hands the frame, the column mapping below and the measurement type
    ``"pulse"`` to ``prod_db_data_uploader.upload_derived_data``. A falsy return
    from the uploader is logged as an error rather than raised.

    Args:
        derived_data: Derived pulse frame with ``run_number``, ``gain``,
            ``rise_time`` and the source columns named in ``column_map``.
        uri: Connection URI passed to ``ProductionTestDB``.
    """
    prod_db = ProductionTestDB(uri)
    run_number = derived_data["run_number"].unique().to_list()

    # NOTE(review): the mean is taken per gain across every channel in the frame,
    # not per channel -- confirm this is the intended aggregation.
    upload_df = derived_data.with_columns(pl.col("rise_time").mean().over("gain").alias("risetime_mean"))

    # Presumably derived_data column -> database field name; verify against the uploader.
    column_map = {
        "energy_mean": "energy_mean",
        "energy_std": "energy_std",
        "time_mean": "time_mean",
        "time_std": "time_std",
        "zero_crossing_time": "zero_crossing_time",
        "ref_pulse_corr": "ref_corr",
        "ref_pulse_rmse": "ref_rmse",
        "gain_ratio": "gain_ratio",
        "INL": "simple_INL",
        "risetime_mean": "risetime_mean",
    }

    success = prod_db_data_uploader.upload_derived_data(upload_df, prod_db, column_map, "pulse")
    if success:
        log.info(f"Uploaded run {run_number} production data to db: {prod_db}")
    else:
        log.error(f"Failed to upload run {run_number} production data to db: {prod_db}")
def calc_plot_all(
    loader: DataSource,
    run_number: int,
    plot_dir: Path,
    uri: Optional[str] = None,
    all_phases: bool = False,
    OFC_quantile: float = 1.0,
    qc_plotting: bool = False,
):
    """Load a pulse run, compute and persist derived data, then make all plots.

    Steps:
      1. Create ``plot_dir`` (mode 0o775) if it does not exist.
      2. Load unsaturated raw data, sort it by ``awg_amp`` descending, and compute
         derived data with ``calc_all``.
      3. Save the derived data through ``loader``.
      4. Record, via ``add_run_info`` for every board in the run, which AWG
         amplitudes had discarded times.
      5. Upload to the production DB when ``uri`` is given.
      6. Plot serially when the logger's effective level is DEBUG, otherwise in
         parallel.

    Args:
        loader: Data source used to load raw data, save derived data and list boards.
        run_number: Run to process.
        plot_dir: Output directory for plots.
        uri: Production DB URI; no upload when ``None``.
        all_phases: Forwarded to ``calc_all`` and the plotting functions.
        OFC_quantile: Forwarded to ``calc_all``.
        qc_plotting: Forwarded to the plotting functions (INL point skipping).
            Defaults to ``False``, matching the previous behaviour.
    """
    if not plot_dir.exists():
        plot_dir.mkdir(parents=True, exist_ok=True)
        os.chmod(plot_dir, 0o775)

    raw_data = loader.load_raw_data(run_number, require_unsaturated=True)
    raw_data = raw_data.sort(by="awg_amp", descending=True)
    derived_data = calc_all(raw_data, all_phases, OFC_quantile)
    loader.save_derived_data(derived_data, run_number=run_number, meas_type="pulse")

    # Rows whose ``times`` entries contain nulls. ``times`` appears to be a nested
    # list column, and nulls presumably mark discarded times -- confirm upstream.
    skipped_times = derived_data.filter(
        pl.col("times").list.eval(pl.element().explode().null_count()).explode() != 0
    )["gain", "channel", "awg_amp"]

    board_ids = loader.get_boards_list(run_number)["board_id"].to_list()

    for gain, channel in skipped_times["gain", "channel"].unique().iter_rows():
        times = skipped_times.filter(gain=gain, channel=channel)["awg_amp"].unique().sort().to_list()
        for board_id in board_ids:
            add_run_info(
                f"Channel {channel} {gain.upper()} gain AWG amps with discard times", times, board_id, plot_dir, True
            )

    if uri is not None:
        upload_derived_data(derived_data, uri)

    log.info("Making pulse plots")
    # Serial plotting at DEBUG level, presumably so failures surface in-process.
    if log.getEffectiveLevel() == logging.DEBUG:
        plot_all(raw_data, derived_data, plot_dir, all_phases, qc_plotting=qc_plotting)
    else:
        parallel_plot_all(raw_data, derived_data, plot_dir, all_phases, qc_plotting=qc_plotting)