Coverage for polars_analysis / pulse.py: 73%

150 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-16 14:47 -0400

1import concurrent.futures 

2import logging 

3import multiprocessing as mp 

4import os 

5import subprocess as sp 

6import traceback 

7from concurrent.futures import ProcessPoolExecutor 

8from pathlib import Path 

9from typing import List, Literal, Optional 

10 

11import polars as pl 

12 

13import polars_analysis.analysis.pulse_analysis as analysis 

14import polars_analysis.plotting.pulse_plotting as plotting 

15from polars_analysis.data_sources import DataSource 

16from polars_analysis.db_interface import prod_db_data_uploader 

17from polars_analysis.db_interface.production_test_db import ProductionTestDB 

18from polars_analysis.plotting.helper import Metadata 

19from polars_analysis.utils import add_run_info, get_columns_or_exit 

20 

21# Instantiate logger 

22log = logging.getLogger(__name__) 

23 

24""" 

25High level commands to run pulse loading, calculations, and plotting 

26""" 

27 

28 

def calc_derived(df: pl.DataFrame, all_phases: bool = False, OFC_quantile: float = 1.0) -> pl.DataFrame:
    """
    Run the pulse analysis pipeline on raw measurement rows.

    Only rows with ``meas_type == "pulse"`` and a truthy ``is_pulsed`` are kept,
    the frame is trimmed to the columns listed in ``kept_columns``, and the
    ``pulse_analysis`` pipes are then applied in a fixed order.

    Args:
        df: Raw data holding ``meas_type``, ``is_pulsed`` and every column in
            ``kept_columns``.
        all_phases: Forwarded to ``pipe_apply_OFCs`` and to the second
            ``pipe_energy_sigma`` call.
        OFC_quantile: Forwarded to ``pipe_OFCs`` as ``quantile``.

    Returns:
        The derived dataframe, with ``samples`` and ``samples_baseline`` dropped
        when they are present.
    """
    kept_columns = [
        "run_number",
        "measurement",
        "channel",
        "gain",
        "samples",
        "att_val",
        "awg_amp",
        "pas_mode",
        "board_id",
        "board_variant",
        "board_version",
    ]
    is_pulse_row = pl.col("meas_type") == "pulse"
    result = df.filter(is_pulse_row, pl.col("is_pulsed")).select(kept_columns)

    result = result.pipe(analysis.pipe_samples_interleaved)
    result = result.with_columns(
        max_pulse_amp=analysis.expr_max_pulse_amp(),
        amp=analysis.expr_awg_amp_to_amp(),
        max_phase_indices=analysis.expr_max_phase_indices(),
    )

    # Ordered (pipe function, keyword arguments) pairs; the order is significant
    # because each stage receives the output of the previous one.
    # NOTE(review): pipe_energy_sigma runs twice, first with its defaults and then
    # with all_phases forwarded -- confirm both passes are intended.
    stages = [
        (analysis.pipe_OFCs, {"quantile": OFC_quantile}),
        (analysis.pipe_apply_OFCs, {"all_phases": all_phases}),
        (analysis.pipe_rise_time, {}),
        (analysis.pipe_zero_crossing, {}),
        (analysis.pipe_ref_pulse_correlation, {}),
        (analysis.pipe_ref_pulse_rmse, {}),
        (analysis.pipe_inl, {"skip_last_n_hi": 1, "skip_last_n_lo": 2}),
        (analysis.pipe_gain_ratio, {}),
        (analysis.pipe_energy_sigma, {}),
        (analysis.pipe_energy_sigma, {"all_phases": all_phases}),
    ]
    for stage, options in stages:
        result = result.pipe(stage, **options)

    # strict=False: do not fail if either column is already absent
    return result.drop("samples", "samples_baseline", strict=False)

69 

70 

def calc_all(raw_data: pl.DataFrame, all_phases: bool = False, OFC_quantile: float = 1.0) -> pl.DataFrame:
    """
    Validate that pulse rows exist, then compute all derived pulse quantities.

    Args:
        raw_data: Raw measurement rows; must contain a ``meas_type`` column.
        all_phases: Forwarded to ``calc_derived``.
        OFC_quantile: Forwarded to ``calc_derived``.

    Returns:
        The dataframe produced by ``calc_derived``.

    Raises:
        ValueError: If no row has ``meas_type == "pulse"``. ``ValueError`` is a
            subclass of ``Exception``, so callers that caught the previous bare
            ``Exception`` keep working.
    """
    if raw_data.filter(pl.col("meas_type") == "pulse").height == 0:
        log.critical("No rows in the dataframe correspond to a pulse run. Aborting.")
        raise ValueError("Empty dataframe")

    return calc_derived(raw_data, all_phases, OFC_quantile=OFC_quantile)

79 

80 

def plot_all(
    raw_data: pl.DataFrame,
    derived_data: pl.DataFrame,
    plot_dir: Path,
    all_phases: bool = False,
    qc_plotting: bool = False,
) -> None:
    """
    Produce all pulse plots one after another in the calling process.

    Args:
        raw_data: Raw measurement rows; only the columns selected below are used.
        derived_data: Derived pulse quantities for the same rows. Must contain
            ``channel``, ``gain`` and ``amp``.
        plot_dir: Directory handed to every plotting helper.
        all_phases: When true, also produce the per-amplitude all-phases energy
            and timing plots.
        qc_plotting: When true, the INL plot is asked to skip the last 1 (hi gain)
            or 2 (lo gain) points; otherwise ``skip_last_n`` is ``None``.

    Raises:
        subprocess.CalledProcessError: If ``git rev-parse HEAD`` exits non-zero.
    """
    # Permissions are only adjusted at the end when the directory held no PNGs beforehand
    plot_dir_filled = any(plot_dir.glob("*png"))

    raw_data = raw_data.select(
        ["run_number", "measurement", "channel", "gain", "samples", "att_val", "awg_amp", "board_id", "pas_mode"]
    )

    plotting.plot_pulse_means_rms(
        derived_data,
        plot_dir,
        raw_data["channel"].unique().to_list(),
        raw_data[0]["pas_mode"][0],
        raw_data[0]["board_id"][0],
        raw_data[0]["att_val"][0],
    )

    derived_data = derived_data.join(
        raw_data.select("run_number", "measurement", "channel", "gain", "board_id"),
        on=["run_number", "measurement", "channel", "gain"],
    )

    githash = sp.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()
    # Trailing points the INL plot skips per gain when QC plotting is enabled
    qc_skip_last_n = {"hi": 1, "lo": 2}
    gains: List[Literal["lo", "hi"]] = ["lo", "hi"]

    for channel in raw_data["channel"].unique():
        log.info(f"Plotting {channel=}")

        filtered_df: pl.DataFrame = derived_data.filter(pl.col("channel") == channel)
        # NOTE(review): `info` is filled in but never handed to a plotting helper
        # in this function -- confirm whether it is still needed.
        info = Metadata.fill_from_dataframe(filtered_df)
        info.githash = githash

        plotting.plot_pulse_overlay_all(filtered_df, channel, plot_dir)
        plotting.plot_gain_ratios(filtered_df, channel, plot_dir)

        # Iterating the column yields one value per row, so an amplitude shared by
        # several rows would redraw the same plot; de-duplicate first.
        for amp in filtered_df["amp"].unique(maintain_order=True):
            plotting.plot_pulse_gain_overlay(
                filtered_df.filter(pl.col("amp") == amp),
                channel,
                plot_dir,
            )

        for gain in gains:
            info.gain = gain
            gain_filtered_df: pl.DataFrame = filtered_df.filter(pl.col("gain") == gain)
            skip_last_n = qc_skip_last_n[gain] if qc_plotting else None

            plotting.plot_energy_resolution(gain_filtered_df, channel, plot_dir)
            plotting.plot_energy_resolution(gain_filtered_df, channel, plot_dir, plot_log_scale=True)

            plotting.plot_sigma_e(gain_filtered_df, channel, plot_dir)

            plotting.plot_sigma_T(gain_filtered_df, channel, plot_dir)
            plotting.plot_sigma_T(gain_filtered_df, channel, plot_dir, plot_log_scale=True)

            plotting.plot_timing_mean(gain_filtered_df, channel, plot_dir)

            plotting.plot_risetime(gain_filtered_df, channel, plot_dir)
            plotting.plot_zero_crossing(gain_filtered_df, channel, plot_dir)
            plotting.plot_autocorrelation(gain_filtered_df, channel, plot_dir)

            plotting.plot_INL(gain_filtered_df, channel, plot_dir, skip_last_n=skip_last_n)

            # Same de-duplication as above; the per-amplitude frame is built once
            # and shared by every plot for that amplitude.
            for amp in gain_filtered_df["amp"].unique(maintain_order=True):
                amp_df = gain_filtered_df.filter(pl.col("amp") == amp)
                plotting.plot_ofc_samples(amp_df, channel, plot_dir)
                plotting.plot_timing_hist(amp_df, channel, plot_dir)
                plotting.plot_energy_hist(amp_df, channel, plot_dir)
                if all_phases:
                    plotting.plot_all_phases_energy(amp_df, channel, plot_dir)
                    plotting.plot_all_phases_timing(amp_df, channel, plot_dir)

    if not plot_dir_filled:
        for f in plot_dir.glob("*png"):
            os.chmod(f, 0o664)
        for f in plot_dir.glob("*json"):
            os.chmod(f, 0o664)

212 

213 

def parallel_plot_all(
    raw_data: pl.DataFrame,
    derived_data: pl.DataFrame,
    plot_dir: Path,
    all_phases: bool = False,
    qc_plotting: bool = False,
) -> None:
    """
    Produce all pulse plots by submitting each one to a spawn-based process pool.

    Exceptions raised inside a plot job are logged and their traceback printed;
    they are not re-raised, so one failing plot does not stop the others.

    Args:
        raw_data: Raw measurement rows; reduced to ``columns_to_get`` via
            ``get_columns_or_exit``.
        derived_data: Derived pulse quantities for the same rows. Must contain
            ``channel``, ``gain`` and ``amp``.
        plot_dir: Directory handed to every plotting helper.
        all_phases: When true, also queue the per-amplitude all-phases energy
            and timing plots.
        qc_plotting: When true, the INL plot is asked to skip the last 1 (hi gain)
            or 2 (lo gain) points; otherwise ``skip_last_n`` is ``None``.

    Raises:
        subprocess.CalledProcessError: If ``git rev-parse HEAD`` exits non-zero.
    """
    # Permissions are only adjusted at the end when the directory held no PNGs beforehand
    plot_dir_filled = any(plot_dir.glob("*png"))
    githash = sp.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()

    with ProcessPoolExecutor(mp_context=mp.get_context("spawn")) as executor:
        # Maps every submitted future to the label used when reporting its failure
        job_handles = dict()

        def submit(job_name: str, plot_func, *args, **kwargs) -> None:
            """Queue ``plot_func(*args, **kwargs)`` on the pool under ``job_name``."""
            job_handles[executor.submit(plot_func, *args, **kwargs)] = job_name

        columns_to_get = [
            "run_number",
            "measurement",
            "channel",
            "gain",
            "samples",
            "att_val",
            "awg_amp",
            "board_id",
            "pas_mode",
        ]
        raw_data = get_columns_or_exit(raw_data, columns_to_get)

        submit(
            "plot_pulse_means_rms",
            plotting.plot_pulse_means_rms,
            derived_data.clone(),
            plot_dir,
            raw_data["channel"].unique().to_list(),
            raw_data[0]["pas_mode"][0],
            raw_data[0]["board_id"][0],
            raw_data[0]["att_val"][0],
        )

        derived_data = derived_data.join(
            raw_data.select("run_number", "measurement", "channel", "gain", "board_id"),
            on=["run_number", "measurement", "channel", "gain"],
        )
        for channel in raw_data["channel"].unique():
            log.info(f"Plotting {channel=}")

            filtered_df: pl.DataFrame = derived_data.filter(pl.col("channel") == channel)

            submit("plot_pulse_overlay_all", plotting.plot_pulse_overlay_all, filtered_df, channel, plot_dir)
            submit("plot_gain_ratios", plotting.plot_gain_ratios, filtered_df, channel, plot_dir)

            # Iterating the column yields one value per row, so an amplitude shared
            # by several rows would queue identical jobs that then run concurrently;
            # de-duplicate first.
            for amp in filtered_df["amp"].unique(maintain_order=True):
                submit(
                    "plot_pulse_gain_overlay",
                    plotting.plot_pulse_gain_overlay,
                    filtered_df.filter(pl.col("amp") == amp),
                    channel,
                    plot_dir,
                )

            for gain in ["lo", "hi"]:
                gain_filtered_df: pl.DataFrame = filtered_df.filter(pl.col("gain") == gain)
                # NOTE(review): `info_g` is filled in but never handed to a plotting
                # job in this function -- confirm whether it is still needed.
                info_g = Metadata.fill_from_dataframe(gain_filtered_df)
                info_g.githash = githash

                submit("plot_energy_resolution", plotting.plot_energy_resolution, gain_filtered_df, channel, plot_dir)
                submit(
                    "plot_energy_resolution",
                    plotting.plot_energy_resolution,
                    gain_filtered_df,
                    channel,
                    plot_dir,
                    plot_log_scale=True,
                )

                submit("plot_sigma_e", plotting.plot_sigma_e, gain_filtered_df, channel, plot_dir)

                submit("plot_sigma_T", plotting.plot_sigma_T, gain_filtered_df, channel, plot_dir)
                submit(
                    "plot_sigma_T",
                    plotting.plot_sigma_T,
                    gain_filtered_df,
                    channel,
                    plot_dir,
                    plot_log_scale=True,
                )

                submit("plot_timing_mean", plotting.plot_timing_mean, gain_filtered_df, channel, plot_dir)
                submit("plot_risetime", plotting.plot_risetime, gain_filtered_df, channel, plot_dir)
                submit("plot_zero_crossing", plotting.plot_zero_crossing, gain_filtered_df, channel, plot_dir)
                submit("plot_autocorrelation", plotting.plot_autocorrelation, gain_filtered_df, channel, plot_dir)

                submit(
                    "plot_INL",
                    plotting.plot_INL,
                    gain_filtered_df,
                    channel,
                    plot_dir,
                    skip_last_n=((1 if gain == "hi" else 2) if qc_plotting else None),
                )

                # Same de-duplication as above; the per-amplitude frame is built
                # once and shared by every job for that amplitude.
                for amp in gain_filtered_df["amp"].unique(maintain_order=True):
                    amp_df = gain_filtered_df.filter(pl.col("amp") == amp)
                    submit("plot_ofc_samples", plotting.plot_ofc_samples, amp_df, channel, plot_dir)
                    submit("plot_timing_hist", plotting.plot_timing_hist, amp_df, channel, plot_dir)
                    submit("plot_energy_hist", plotting.plot_energy_hist, amp_df, channel, plot_dir)
                    if all_phases:
                        submit("plot_all_phases_energy", plotting.plot_all_phases_energy, amp_df, channel, plot_dir)
                        submit("plot_all_phases_timing", plotting.plot_all_phases_timing, amp_df, channel, plot_dir)

        # Check for exceptions
        for future in concurrent.futures.as_completed(job_handles):
            job = job_handles[future]
            try:
                future.result()
            except Exception as exc:
                # Best effort: report the failing plot and keep collecting the rest
                log.error(f"{job} generated an exception: {exc}")
                print(traceback.format_exc())

    if not plot_dir_filled:
        for f in plot_dir.glob("*png"):
            os.chmod(f, 0o664)
        for f in plot_dir.glob("*json"):
            os.chmod(f, 0o664)

421 

422 

def upload_derived_data(derived_data: pl.DataFrame, uri: str):
    """
    Send derived pulse quantities to the production test database.

    A ``risetime_mean`` column (mean of ``rise_time`` within each ``gain``) is
    added before the frame is handed to ``prod_db_data_uploader``. The outcome
    is logged; nothing is returned or raised by this function itself.

    Args:
        derived_data: Derived pulse data containing ``run_number``, ``gain``,
            ``rise_time`` and the columns named in ``column_map``.
        uri: Passed to ``ProductionTestDB`` to open the database handle.
    """
    prod_db = ProductionTestDB(uri)
    run_number = derived_data["run_number"].unique().to_list()

    with_risetime_mean = derived_data.with_columns(risetime_mean=pl.col("rise_time").mean().over("gain"))

    # Name mapping handed to the uploader.
    # NOTE(review): presumably derived-column name -> database field name;
    # confirm against prod_db_data_uploader.upload_derived_data.
    column_map = {
        "energy_mean": "energy_mean",
        "energy_std": "energy_std",
        "time_mean": "time_mean",
        "time_std": "time_std",
        "zero_crossing_time": "zero_crossing_time",
        "ref_pulse_corr": "ref_corr",
        "ref_pulse_rmse": "ref_rmse",
        "gain_ratio": "gain_ratio",
        "INL": "simple_INL",
        "risetime_mean": "risetime_mean",
    }

    uploaded = prod_db_data_uploader.upload_derived_data(with_risetime_mean, prod_db, column_map, "pulse")
    if not uploaded:
        log.error(f"Failed to upload run {run_number} to production data to db: {prod_db}")
        return
    log.info(f"Uploaded run {run_number} production data to db: {prod_db}")

447 

448 

def calc_plot_all(
    loader: DataSource,
    run_number: int,
    plot_dir: Path,
    uri: Optional[str] = None,
    all_phases: bool = False,
    OFC_quantile: float = 1.0,
):
    """
    Load a pulse run, compute and save its derived data, then plot it.

    Steps, in order: create ``plot_dir`` if missing, load and sort the raw data,
    compute and save the derived data, record which AWG amplitudes have null
    entries in ``times``, optionally upload to the database, and make the plots.

    Args:
        loader: Data source used to load raw data, save derived data and list boards.
        run_number: Run to process.
        plot_dir: Output directory; created with mode ``0o775`` when absent.
        uri: When not ``None``, the derived data is uploaded via
            ``upload_derived_data`` using this URI.
        all_phases: Forwarded to ``calc_all`` and to the plotting entry point.
        OFC_quantile: Forwarded to ``calc_all``.
    """
    if not plot_dir.exists():
        plot_dir.mkdir(parents=True, exist_ok=True)
        os.chmod(plot_dir, 0o775)

    raw_data = loader.load_raw_data(run_number, require_unsaturated=True).sort(by="awg_amp", descending=True)
    derived_data = calc_all(raw_data, all_phases, OFC_quantile)
    loader.save_derived_data(derived_data, run_number=run_number, meas_type="pulse")

    # Rows whose `times` list holds at least one null entry
    has_null_times = pl.col("times").list.eval(pl.element().explode().null_count()).explode() != 0
    skipped_times = derived_data.filter(has_null_times).select("gain", "channel", "awg_amp")

    board_ids = loader.get_boards_list(run_number)["board_id"].to_list()

    # One run-info entry per (gain, channel) pair and board, listing the affected AWG amplitudes
    for gain, channel in skipped_times.select("gain", "channel").unique().iter_rows():
        affected_amps = skipped_times.filter(gain=gain, channel=channel)["awg_amp"].unique().sort().to_list()
        for board_id in board_ids:
            add_run_info(
                f"Channel {channel} {gain.upper()} gain AWG amps with discard times",
                affected_amps,
                board_id,
                plot_dir,
                True,
            )

    if uri is not None:
        upload_derived_data(derived_data, uri)

    log.info("Making pulse plots")
    # Plot serially only when the effective log level is exactly DEBUG; otherwise use the process pool
    plotter = plot_all if log.getEffectiveLevel() == logging.DEBUG else parallel_plot_all
    plotter(raw_data, derived_data, plot_dir, all_phases)