Coverage for polars_analysis / cross_talk.py: 47%

73 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-16 14:47 -0400

1import logging 

2import os 

3import sys 

4from pathlib import Path 

5from typing import Optional 

6 

7import polars as pl 

8 

9import polars_analysis.analysis.cross_talk_analysis as analysis 

10import polars_analysis.plotting.cross_talk_plotting as plotting 

11import polars_analysis.plotting.pulse_plotting as pulse_plotting 

12from polars_analysis import frame_utils, utils 

13from polars_analysis.data_sources import DataSource 

14from polars_analysis.db_interface import prod_db_data_uploader 

15from polars_analysis.db_interface.production_test_db import ProductionTestDB 

16 

17# Instantiate logger 

18log = logging.getLogger(__name__) 

19 

20""" 

21High level commands to run cross-talk loading, calculations, and plotting 

22""" 

23 

24 

def load_derived_data(derived_dir: Path, run_number: int) -> pl.DataFrame:
    """
    Load the derived cross-talk values of one run from parquet files.

    Lazily scans every file matching ``cross_talk_derived_values_*.parquet``
    inside ``derived_dir`` and keeps only rows whose ``run_number`` column
    equals ``run_number``.

    Args:
        derived_dir: Directory containing the derived-value parquet files.
        run_number: Run number to select.

    Returns:
        The matching rows as an eager DataFrame. It is never empty, because
        the empty case terminates the process instead.

    Raises:
        SystemExit: With exit code 1 when no row matches ``run_number``.
    """
    # NOTE(review): recent polars releases deprecate ``streaming=True`` in
    # favour of ``engine="streaming"`` -- confirm against the pinned version.
    derived_df = (
        pl.scan_parquet(derived_dir / "cross_talk_derived_values_*.parquet")
        .filter(pl.col("run_number") == run_number)
        .collect(streaming=True)
    )

    if derived_df.is_empty():
        # This loader reads cross-talk values, so the message names crosstalk.
        log.error(f"No derived crosstalk data for run {run_number} found in {derived_dir}, exiting...")
        log.error("Have you run [yellow]calc-save-runs[/yellow]?")
        log.error(f"Does run {run_number} contain [yellow]crosstalk[/yellow] measurements?")
        sys.exit(1)

    return derived_df

39 

40 

def calc_derived(df: pl.DataFrame, all_phases: bool = False) -> pl.DataFrame:
    """
    Compute the derived cross-talk quantities from raw measurement rows.

    Only rows with ``meas_type == "crosstalk"`` are used. The frame is then
    passed through the expressions and pipe functions of the cross-talk
    analysis module, in a fixed order, and the sample columns are dropped
    from the result.

    Args:
        df: Raw data; must provide ``meas_type`` and every column listed in
            ``kept_columns`` below.
        all_phases: Forwarded unchanged to ``analysis.pipe_apply_OFCs``.

    Returns:
        The derived DataFrame without the ``samples`` column (and without
        ``samples_baseline`` when that column exists).
    """
    # Only these raw columns are handed to the analysis steps.
    kept_columns = [
        "run_number",
        "board_id",
        "board_variant",
        "board_version",
        "measurement",
        "channel",
        "gain",
        "samples",
        "att_val",
        "awg_amp",
        "meas_chan",
        "pas_mode",
    ]

    # A single name is rebound at every step so that no intermediate frame
    # is kept alive longer than it would be in a method chain.
    frame = df.filter(pl.col("meas_type") == "crosstalk")
    frame = frame.select(kept_columns)

    # FIXME: we should use "is_pulsed" here when ready
    frame = frame.with_columns(is_reference_pulse=analysis.expr_is_reference_pulse())

    # Analysis objects are samples interleaved...
    frame = frame.pipe(analysis.pipe_samples_interleaved)

    frame = frame.with_columns(
        max_pulse_amp=analysis.expr_max_pulse_amp(),
        amp=analysis.expr_awg_amp_to_amp(),
        max_phase_indices=analysis.expr_max_phase_indices(),
    )

    frame = frame.pipe(analysis.pipe_OFCs)
    frame = frame.pipe(analysis.pipe_apply_OFCs, all_phases=all_phases)
    frame = frame.pipe(analysis.pipe_rise_time)

    # strict=False: a missing column is skipped rather than raising.
    return frame.drop("samples", "samples_baseline", strict=False)

79 

80 

def calc_all(raw_data: pl.DataFrame, all_phases: bool = False) -> pl.DataFrame:
    """
    Validate that ``raw_data`` holds cross-talk rows, then derive all values.

    Args:
        raw_data: Raw data; must provide a ``meas_type`` column.
        all_phases: Forwarded unchanged to ``calc_derived``.

    Returns:
        The DataFrame produced by ``calc_derived``.

    Raises:
        ValueError: When no row has ``meas_type == "crosstalk"``.
    """
    if raw_data.filter(pl.col("meas_type") == "crosstalk").is_empty():
        log.critical("No rows in the dataframe correspond to a crosstalk run. Wrong run number? Aborting.")
        # ValueError is a subclass of Exception, so callers that catch
        # Exception keep working while new callers can be more specific.
        raise ValueError("Empty dataframe")

    return calc_derived(raw_data, all_phases)

87 

88 

def plot_all(derived_data: pl.DataFrame, plot_dir: Path):
    """
    Produce every cross-talk plot for ``derived_data`` inside ``plot_dir``.

    Draws the overlay and table plots, then one energy histogram per
    (gain, channel) combination that has data. Combinations without rows are
    reported with a warning and skipped.

    Args:
        derived_data: Derived cross-talk data; must provide the ``channel``
            and ``gain`` columns.
        plot_dir: Directory the plots are written to.
    """
    # Snapshot taken before anything is plotted: permissions are only
    # adjusted afterwards when the directory held no PNG at this point.
    # NOTE(review): presumably this avoids chmod on files owned by another
    # user -- confirm.
    had_pngs_before = any(plot_dir.glob("*png"))

    plotting.plot_ct_overlay(derived_data, plot_dir)
    plotting.plot_ct_table(derived_data, plot_dir)

    for gain in ("lo", "hi"):
        for channel in derived_data["channel"].unique():
            selection = derived_data.filter(pl.col("channel") == channel, pl.col("gain") == gain)
            if len(selection) == 0:
                log.warning(f"Channel {channel} {gain} gain missing and not plotted. Maybe saturated.")
                continue
            pulse_plotting.plot_energy_hist(selection, channel, plot_dir)

    if had_pngs_before:
        return

    # Same order as before: all PNG files first, then all JSON files.
    for pattern in ("*png", "*json"):
        for produced in plot_dir.glob(pattern):
            os.chmod(produced, 0o664)

108 

109 

def calc_plot_all(
    loader: DataSource,
    run_number: int,
    plot_dir: Path,
    check_frames: bool = False,
    swap_frame18: bool = False,
    uri: Optional[str] = None,
    all_phases: bool = False,
):
    """
    Run the complete cross-talk chain for one run: load, derive, save, plot.

    Steps, in order:
      1. Create ``plot_dir`` (mode 0o775) when it does not exist yet.
      2. Load the raw data, optionally through the frame alignment wrapper,
         and record the alignment outcome as run info for every board.
      3. Compute the derived data and store it through ``loader``.
      4. When ``uri`` is given, upload the derived data to the production
         test database and log whether that succeeded.
      5. Produce all plots.

    Args:
        loader: Data source used to list boards, load raw data and save the
            derived data.
        run_number: Run to process.
        plot_dir: Output directory for plots and run info.
        check_frames: When true, raw data comes from
            ``frame_utils.check_and_align_frames_wrapper`` instead of
            ``loader.load_raw_data``.
        swap_frame18: Forwarded unchanged to the alignment wrapper.
        uri: Production database URI; ``None`` skips the upload.
        all_phases: Forwarded unchanged to ``calc_all``.
    """
    if not plot_dir.exists():
        plot_dir.mkdir(parents=True, exist_ok=True)
        os.chmod(plot_dir, 0o775)

    boards = loader.get_boards_list(run_number)["board_id"].to_list()

    if not check_frames:
        alignment_info = "Alignment not checked"
        raw_data = loader.load_raw_data(run_number)
    else:
        raw_data, alignment = frame_utils.check_and_align_frames_wrapper(
            loader, run_number, swap_frame18, plot_dir=plot_dir
        )

        # alignment[0] is read as channels, [1] as first samples, [2] as offsets.
        if len(alignment[0]) == 0:
            alignment_info = "Frames already aligned"
        else:
            channels = [int(i) for i in alignment[0]]
            alignment_info = f"Alignment performed on ch {channels}"
            for board in boards:
                utils.add_run_info("channels", channels, board, plot_dir, print_to_website=False)
                for key, position in (("first_sample", 1), ("offset", 2)):
                    utils.add_run_info(
                        key, [int(i) for i in alignment[position]], board, plot_dir, print_to_website=False
                    )

    for board in boards:
        utils.add_run_info("alignment_info", alignment_info, board, plot_dir)

    derived_data = calc_all(raw_data, all_phases)
    loader.save_derived_data(derived_data, run_number=run_number, meas_type="crosstalk")

    if uri is not None:
        prod_db = ProductionTestDB(uri)
        uploaded = prod_db_data_uploader.upload_derived_data(derived_data, prod_db, {}, "crosstalk")
        if not uploaded:
            log.error(f"Failed to upload run {run_number} to production data to db: {prod_db}")
        else:
            log.info(f"Uploaded run {run_number} production data to db: {prod_db}")

    plot_all(derived_data, plot_dir)