Coverage for polars_analysis / noise_stability.py: 94%

126 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-16 15:00 -0400

1import concurrent.futures 

2import datetime 

3import logging 

4import multiprocessing as mp 

5import os 

6import subprocess as sp 

7import traceback 

8from concurrent.futures import ProcessPoolExecutor 

9from pathlib import Path 

10from typing import Optional, cast 

11from zoneinfo import ZoneInfo 

12 

13import polars as pl 

14 

15import polars_analysis.plotting.noise_stability_plotting as plotting 

16import polars_analysis.plotting.pedestal_plotting as pedestal_plotting 

17from polars_analysis import utils 

18from polars_analysis.analysis import constants 

19from polars_analysis.data_sources import DataSource 

20from polars_analysis.plotting.helper import Metadata 

21from polars_analysis.utils import get_columns_or_exit 

22 

# Module-level logger named after this module; handler/level configuration
# is left to the application entry point.
log = logging.getLogger(__name__)

25 

26 

def calc_all(raw_data: pl.DataFrame) -> pl.DataFrame:
    """Validate and derive per-row statistics for a noise_stability run.

    Args:
        raw_data: raw measurement rows; must contain a ``meas_type`` column.

    Returns:
        The derived dataframe produced by :func:`calc_derived`.

    Raises:
        ValueError: if no row has ``meas_type == "noise_stability"``.
    """
    if raw_data.filter(pl.col("meas_type") == "noise_stability").is_empty():
        log.critical("No rows in the dataframe correspond to a noise_stability run. Aborting.")
        # ValueError (a subclass of Exception) keeps existing broad handlers working
        # while letting new callers catch the specific failure.
        raise ValueError("Empty dataframe")

    return calc_derived(raw_data)

34 

35 

def calc_derived(df: pl.DataFrame) -> pl.DataFrame:
    """Compute per-readout mean and std of the raw sample lists.

    Keeps only noise_stability rows and the columns needed downstream, then
    appends ``mean`` and ``std`` columns derived from the ``samples`` lists.
    """
    keep_cols = (
        "run_number",
        "samples",
        "measurement",
        "board_id",
        "timestamp",
        "gain",
        "channel",
        "pas_mode",
        "att_val",
    )
    ns_rows = df.filter(pl.col("meas_type") == "noise_stability").select(*keep_cols)
    samples = pl.col("samples")
    return ns_rows.with_columns(
        samples.list.mean().alias("mean"),
        samples.list.std().alias("std"),
    )

48 

49 

def plot_all(
    raw_data: pl.DataFrame,
    monitoring_df: pl.DataFrame,
    derived_data: pl.DataFrame,
    lab_env_data: pl.DataFrame,
    plot_dir: Path,
    plot_all_temp_sources: Optional[bool] = False,
):
    """Produce every noise-stability plot serially (single process).

    Args:
        raw_data: raw per-readout samples; must contain the columns checked below.
        monitoring_df: board monitoring time series.
        derived_data: output of :func:`calc_derived`.
        lab_env_data: lab environment (temperature) readings.
        plot_dir: destination directory for the plot files.
        plot_all_temp_sources: if True, correlate against every known
            temperature source rather than the default subset.
    """
    ### Raw Samples Plots ###
    columns_to_get = [
        "run_number",
        "measurement",
        "channel",
        "gain",
        "samples",
        "board_id",
        "pas_mode",
        "trigger_rate",
    ]
    raw_data = get_columns_or_exit(raw_data, columns_to_get)

    pas_mode = raw_data["pas_mode"].unique().to_list()[0]

    # NaN is the only value unequal to itself: a NaN pas_mode means the column
    # carries no usable information, so drop it before aggregating.
    # (The previous `pas_mode = -1` sentinel was never read and has been removed.)
    if pas_mode != pas_mode:
        raw_data.drop_in_place("pas_mode")

    # One row per (channel, gain) with all samples concatenated, so each channel
    # can be plotted from a single flat list.
    aggregated_df = (
        raw_data.sort(["channel", "gain", "measurement"])  # Sort by all relevant columns
        .group_by("channel", "gain", maintain_order=True)
        .agg(
            pl.col("run_number").first(),
            pl.col("samples").explode(),
            pl.col("board_id").first(),
        )
    )

    temp_sources = constants.ALL_TEMPERATURE_SOURCES if plot_all_temp_sources else constants.TEMPERATURE_SOURCES
    for temp_source in temp_sources:
        log.debug(f"Processing temperature source: {temp_source}")
        try:
            plotting.plot_temp_correlation(
                derived_data, monitoring_df, lab_env_data, plot_dir, temp_source=temp_source, settling_time=None
            )
        except ValueError:
            # plot_temp_correlation raises ValueError when the overlap between
            # the datasets is too small to compute a correlation; best effort.
            log.warning(f"Not enough data to calculate correlation for {temp_source}")
    for channel_df in aggregated_df.iter_rows(named=True):
        all_samples = channel_df["samples"]
        channel_info = Metadata.fill_from_dataframe(pl.DataFrame(channel_df))
        pedestal_plotting.plot_raw(channel_info, all_samples, plot_dir)

    info = Metadata.fill_from_dataframe(pl.DataFrame(aggregated_df))
    plotting.plot_monitoring(monitoring_df, lab_env_data, plot_dir)
    plotting.avg_rms_mean_vs_channel(derived_data, plot_dir)
    plotting.plot_outliers(derived_data, plot_dir, info)
    plotting.plot_avg_sample_range(derived_data, plot_dir)
    plotting.plot_monitor_channel_correlation(derived_data, monitoring_df, lab_env_data, plot_dir)
    plotting.plot_monitor_monitor_correlation(monitoring_df, lab_env_data, plot_dir)
    for gain in ["lo", "hi"]:
        gain_filtered_df: pl.DataFrame = derived_data.filter(pl.col("gain") == gain)
        plotting.plot_mean_rms_vs_time(gain_filtered_df, plot_dir)
        plotting.plot_sample_range_vs_time(gain_filtered_df, plot_dir)

113 

114 

def parallel_plot_all(
    raw_data: pl.DataFrame,
    monitoring_df: pl.DataFrame,
    derived_data: pl.DataFrame,
    lab_env_data: pl.DataFrame,
    plot_dir: Path,
    plot_all_temp_sources: bool = False,
):
    """Produce every noise-stability plot in parallel worker processes.

    Mirrors :func:`plot_all` but fans each plotting call out to a
    ``ProcessPoolExecutor`` (spawn context, so worker state is clean).
    Exceptions from workers are logged rather than propagated, so one failing
    plot does not abort the rest.

    Args:
        raw_data: raw per-readout samples; must contain the columns checked below.
        monitoring_df: board monitoring time series.
        derived_data: output of :func:`calc_derived`.
        lab_env_data: lab environment (temperature) readings.
        plot_dir: destination directory for the plot files.
        plot_all_temp_sources: if True, correlate against every known
            temperature source rather than the default subset.
    """
    # Remember whether plot_dir already held plots so we only relax file
    # permissions on a fresh population (see the chmod loop at the end).
    plot_dir_filled = any(plot_dir.glob("*png"))
    githash = sp.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()
    columns_to_get = [
        "run_number",
        "measurement",
        "channel",
        "gain",
        "samples",
        "board_id",
        "pas_mode",
    ]
    raw_data = get_columns_or_exit(raw_data, columns_to_get)
    # Attach board_id to the derived rows; the plotting helpers need it.
    derived_data = derived_data.join(
        raw_data.select("run_number", "measurement", "channel", "gain", "board_id"),
        on=["run_number", "measurement", "channel", "gain"],
    )

    info_derived = Metadata.fill_from_dataframe(derived_data)

    with ProcessPoolExecutor(mp_context=mp.get_context("spawn")) as executor:
        # future -> job name, used for error reporting below.
        job_handles: dict = {}

        # One row per (channel, gain) with all samples concatenated.
        aggregated_df = (
            raw_data.sort(["channel", "gain", "measurement"])  # Sort by all relevant columns
            .group_by("channel", "gain", maintain_order=True)
            .agg(
                pl.col("run_number").first(),
                pl.col("samples").explode(),
                pl.col("board_id").first(),
                pl.col("pas_mode").first(),
            )
        )
        for channel_df in aggregated_df.iter_rows(named=True):
            all_samples = channel_df["samples"]
            info = Metadata.fill_from_dataframe(pl.DataFrame(channel_df))
            job_handles[
                executor.submit(
                    pedestal_plotting.plot_raw,
                    info,
                    all_samples,
                    plot_dir,
                )
            ] = "plot_raw"

        temp_sources = constants.ALL_TEMPERATURE_SOURCES if plot_all_temp_sources else constants.TEMPERATURE_SOURCES
        for temp_source in temp_sources:
            job_handles[
                executor.submit(
                    plotting.plot_temp_correlation,
                    derived_data.clone(),  # clone: avoid sharing one frame across pickled submissions
                    monitoring_df,
                    lab_env_data,
                    plot_dir,
                    temp_source=temp_source,
                )
            ] = "plot_temp_correlation"
        job_handles[
            executor.submit(
                plotting.plot_monitor_channel_correlation,
                derived_data.clone(),
                monitoring_df,
                lab_env_data,
                plot_dir,
            )
        ] = "plot_monitor_channel_correlation"

        job_handles[
            executor.submit(
                plotting.plot_monitor_monitor_correlation,
                monitoring_df,
                lab_env_data,
                plot_dir,
            )
        ] = "plot_monitor_monitor_correlation"
        job_handles[
            executor.submit(
                plotting.plot_outliers,
                derived_data.clone(),
                plot_dir,
                info_derived,
            )
        ] = "plot_outliers"
        job_handles[
            executor.submit(
                plotting.plot_avg_sample_range,
                derived_data.clone(),
                plot_dir,
            )
        ] = "plot_avg_sample_range"

        job_handles[
            executor.submit(
                plotting.plot_monitoring,
                monitoring_df,
                lab_env_data,
                plot_dir,
            )
        ] = "plot_monitoring"
        job_handles[
            executor.submit(
                plotting.avg_rms_mean_vs_channel,
                derived_data.clone(),
                plot_dir,
            )
        ] = "avg_rms_mean_vs_channel"
        for gain in ["lo", "hi"]:
            # filter() already returns a fresh frame; no extra clone needed.
            gain_df = derived_data.filter(pl.col("gain") == gain)
            info_g = Metadata.fill_from_dataframe(gain_df)
            info_g.githash = githash

            job_handles[
                executor.submit(
                    plotting.plot_mean_rms_vs_time,
                    gain_df,
                    plot_dir,
                )
            ] = "plot_mean_rms_vs_time"

            job_handles[
                executor.submit(
                    plotting.plot_sample_range_vs_time,
                    gain_df,
                    plot_dir,
                )
            ] = "plot_sample_range_vs_time"
        # Check for exceptions: log failures per job instead of aborting the pool.
        for future in concurrent.futures.as_completed(job_handles):
            job = job_handles[future]
            try:
                future.result()
            except Exception as exc:
                log.error(f"{job} generated an exception: {exc}")
                print(traceback.format_exc())

    if not plot_dir_filled:
        # Freshly created outputs: make plots/metadata group-writable.
        for f in plot_dir.glob("*png"):
            os.chmod(f, 0o664)
        for f in plot_dir.glob("*json"):
            os.chmod(f, 0o664)

263 

264 

def calc_plot_all(
    loader: DataSource,
    run_number: int,
    plot_dir: Path,
    plot_all_temp_sources: bool = False,
):
    """Load a run, compute derived noise-stability data, and make all plots.

    Args:
        loader: data source used to load raw/monitoring/lab-environment data
            and to persist the derived dataframe.
        run_number: run to process.
        plot_dir: output directory; created (group-writable) if missing.
        plot_all_temp_sources: forwarded to the plotting entry points.
    """
    if not plot_dir.exists():
        plot_dir.mkdir(parents=True, exist_ok=True)
        os.chmod(plot_dir, 0o775)

    raw_data = loader.load_raw_data(run_number)
    monitoring_df = loader.load_monitoring_data(run_number)

    derived_data = calc_all(raw_data)
    loader.save_derived_data(derived_data, run_number=run_number, meas_type="noise_stability")

    n_readouts = cast(int, raw_data["measurement"].max())
    start_time = cast(datetime.datetime, raw_data["timestamp"].min())
    end_time = cast(datetime.datetime, raw_data["timestamp"].max())
    start_time_NY = start_time.astimezone(ZoneInfo("America/New_York"))
    end_time_NY = end_time.astimezone(ZoneInfo("America/New_York"))
    duration_per_readout = (end_time - start_time) / n_readouts

    board_ids = loader.get_boards_list(run_number)["board_id"].to_list()
    for board_id in board_ids:
        # measurement is 0-indexed, so the readout count is max + 1.
        utils.add_run_info("n_readouts", n_readouts + 1, board_id, plot_dir)
        utils.add_run_info(
            "duration_between_readouts", f"{round(duration_per_readout.total_seconds())} s", board_id, plot_dir
        )
        utils.add_run_info("start_time", start_time_NY.strftime("%Y-%m-%d %H:%M:%S %Z"), board_id, plot_dir)
        utils.add_run_info("end_time", end_time_NY.strftime("%Y-%m-%d %H:%M:%S %Z"), board_id, plot_dir)

    # Restrict lab environment data to this run's time window and lab.
    lab_env_data_all = loader.load_lab_environment_data()
    lab_env_data = (
        lab_env_data_all.filter(pl.col("timestamp") >= start_time)
        .filter(pl.col("timestamp") <= end_time)
        .filter(pl.col("lab_name") == "crate_lab")
    )

    log.info("Making noise stability plots")
    # At debug level run serially so tracebacks and pdb are usable;
    # otherwise fan out to worker processes.
    if log.getEffectiveLevel() == logging.DEBUG:
        plot_all(
            raw_data, monitoring_df, derived_data, lab_env_data, plot_dir, plot_all_temp_sources=plot_all_temp_sources
        )
    else:
        parallel_plot_all(
            raw_data, monitoring_df, derived_data, lab_env_data, plot_dir, plot_all_temp_sources=plot_all_temp_sources
        )