Coverage for polars_analysis / analysis / cut_thresholds.py: 54%

98 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-13 13:37 -0400

1import logging 

2from copy import deepcopy 

3from pathlib import Path 

4from typing import Any, Literal, Optional 

5 

6import numpy as np 

7import polars as pl 

8import yaml 

9from matplotlib.axes import Axes 

10from matplotlib.patches import Rectangle 

11 

12from polars_analysis.analysis import constants 

13 

14# Instantiate logger 

15log = logging.getLogger(__name__) 

16 

17 

def str_dash_to_list(s: str):
    """
    Return numpy array of X...Y (inclusive) from 'X-Y'.

    A single value 'X' (string or integer, e.g. a bare YAML number) is also accepted
    and yields an array containing only X. Whitespace around the numbers is ignored.

    Args:
        s: range specification such as "0-127", "5", or 5.

    Returns:
        numpy integer array of the selected values.

    Raises:
        ValueError: if the parts are not integers.
    """
    parts = str(s).split("-")
    if len(parts) == 1:
        # No dash: a single value rather than a range
        return np.array([int(parts[0])])
    # Like the original, only the first two dash-separated fields are used
    return np.arange(int(parts[0]), int(parts[1]) + 1)

23 

24 

class CutThresholds:
    """
    Cut (pass/fail) thresholds loaded from a CSV or YAML constants file.

    ``self.bounds`` holds a polars DataFrame when a ``.csv`` file is loaded, the
    parsed YAML document (a dict) for ``.yaml``/``.yml`` files, and None when the
    file type is not recognised.
    """

    def __init__(self, constants_file: Path = constants.ROOTDIR / "polars_analysis/analysis/qc_thresholds.csv"):
        """
        Load thresholds from ``constants_file``.

        The loader is chosen from the file extension (case-insensitive), not from a
        substring search of the whole path, so directory names cannot change the parser.

        Args:
            constants_file: path to a ``.csv``, ``.yaml`` or ``.yml`` thresholds file.
        """
        # None until a supported file has been read; the other methods check its type
        self.bounds: Any = None
        suffix = Path(constants_file).suffix.lower()

        if suffix == ".csv":
            self.bounds = (
                pl.read_csv(constants_file, comment_prefix="#")
                # Strip whitespace from column headers
                .rename(lambda column_name: column_name.strip())
                # Strip whitespace from elements
                .with_columns(pl.col(pl.String).str.strip_chars())
                # Replace empty strings with null
                .with_columns(
                    pl.when(pl.col(pl.String).str.len_chars() == 0).then(None).otherwise(pl.col(pl.String)).name.keep()
                )
                # Convert requirements and awg_amp/att_val to float (unparseable values become null)
                .with_columns(
                    pl.col("requirement").cast(pl.Float64, strict=False),
                    pl.col("min_requirement").cast(pl.Float64, strict=False),
                    pl.col("max_requirement").cast(pl.Float64, strict=False),
                    pl.col("awg_amp").cast(pl.Float64, strict=False),
                    pl.col("att_val").cast(pl.Float64, strict=False),
                )
            )
        elif suffix in (".yaml", ".yml"):
            with open(constants_file, encoding="utf-8") as file:
                self.bounds = yaml.safe_load(file)
        else:
            log.error(f"Unknown constants file type {constants_file}")

    def walk_dictionary(self, input_dict: dict, decending_key: Any = None, known_values: Optional[dict] = None):
        """
        Flatten a nested dictionary into flat dictionaries, one per leaf key-value pair.

        Args:
            input_dict: nested dictionary (no lists or tuples)
            decending_key: flips back and forth between None and a key to handle
                key-val pairs through sub-dictionaries. When set, the keys of
                ``input_dict`` are treated as the values of ``decending_key`` in the
                parent dictionary.
            known_values: key-value pairs inherited from the enclosing levels. It is
                copied, never modified.

        Yields:
            One dictionary per leaf key-value pair: the inherited context plus that pair.
        """
        # Private copy. A shared mutable default that was updated in place would leak
        # values from one call (or sibling branch) into the rows of the next.
        context = {} if known_values is None else dict(known_values)

        # Last layer (no further dictionaries): one final dictionary per key
        if all(not isinstance(val, dict) for val in input_dict.values()):
            for key, val in input_dict.items():
                # deepcopy so every yielded dictionary is independent
                row = deepcopy(context)
                row[key] = val
                yield row
            return

        # Non-dictionary values at this level apply to every leaf below it
        for key, val in input_dict.items():
            if not isinstance(val, dict):
                context[key] = val

        # Descend, giving each branch its own context
        for key, val in input_dict.items():
            if not isinstance(val, dict):
                continue
            # The keys of this dictionary are the values for the key in the parent dictionary
            branch = context if decending_key is None else {**context, decending_key: key}
            yield from self.walk_dictionary(val, (key if decending_key is None else None), branch)

    @staticmethod
    def _expand_range(spec: Any) -> list:
        """Return the integers in 'X-Y' (inclusive), or [X] for a single value 'X'."""
        parts = str(spec).split("-")
        if len(parts) == 1:
            return [int(parts[0])]
        return list(range(int(parts[0]), int(parts[1]) + 1))

    def explode_cuts(self) -> pl.DataFrame:
        """
        Explode into single 'row' per cut / ch / gain / greater-than,less-than,equals.

        CSV bounds: "channel", "link" and "adc" may each hold a range "X-Y" or a single
        value; each is expanded to a list and exploded.
        YAML bounds: every entry of ``self.bounds["variables"]`` is flattened with
        ``walk_dictionary``, concatenated diagonally, and its "channels" range is
        expanded into one row per "channel".

        Raises:
            TypeError: if ``self.bounds`` is neither a DataFrame nor a dict.
        """
        if isinstance(self.bounds, pl.DataFrame):
            columns_to_expand = ["channel", "link", "adc"]
            # One map per column handles both forms. Polars evaluates every branch of
            # when/then/otherwise over the whole column, so a range-only lambda in the
            # `then` branch would still be called on (and fail for) single values.
            return (
                self.bounds.with_columns(
                    [
                        pl.col(c).map_elements(self._expand_range, return_dtype=pl.List(pl.Int64)).alias(c)
                        for c in columns_to_expand
                    ]
                )
                .explode("channel")
                .explode("link")
                .explode("adc")
            )

        if isinstance(self.bounds, dict):
            df = pl.concat(
                [pl.from_dict(d) for var in self.bounds["variables"] for d in self.walk_dictionary(var)],
                how="diagonal",
            )
            return (
                df.with_columns(
                    pl.col("channels").map_elements(self._expand_range, return_dtype=pl.List(pl.Int64)).alias("channel")
                )
                .explode("channel")
                .drop("channels")
            )

        raise TypeError("explode_cuts doesn't know what to do")

    def get_bounds(
        self,
        name: str,
        gain: Literal["hi", "lo"],
        amp: Optional[float] = None,
        attn: Optional[float] = None,
        board_type: Optional[str] = "EM",
        channel: Optional[int] = None,
    ) -> tuple:
        """
        Return tuple of (lower, upper) bounds based on the name, gain, and optionally amp.

        Only CSV-loaded bounds are supported. On any failure the error is logged and
        (None, None) is returned instead of raising.

        Args:
            name: cut name (matched against the "name" column).
            gain: "hi" or "lo".
            amp: optional "awg_amp" value to match.
            attn: optional "att_val" value to match.
            board_type: "board_type" to match; None disables this filter.
            channel: optional channel to match against the "channel" column.
        """
        try:
            if isinstance(self.bounds, pl.DataFrame):
                return self._get_bounds_csv(name, gain, amp, board_type=board_type, attn=attn, channel=channel)
            log.error("Unknown constants file type, returning (None,None)")
            return (None, None)
        except Exception as e:
            log.error(f"Failed to get bounds for {name=}, {gain=}, {amp=}, {attn=}")
            log.error(e)
            return (None, None)

    def _get_bounds_csv(
        self,
        name: str,
        gain: Literal["hi", "lo"],
        amp: Optional[float] = None,
        board_type: str = "EM",
        attn: Optional[float] = None,
        channel: Optional[int] = None,
    ) -> tuple:
        """
        Return tuple of (lower, upper) bounds based on the name, gain, and optionally amp.

        Returns (None, None) and logs an error when zero or more than one row matches.
        Either side may be None when the CSV leaves that requirement empty.
        """
        filters = [pl.col("name") == name, pl.col("gain") == gain]
        if amp is not None:
            filters.append(pl.col("awg_amp") == amp)
        if attn is not None:
            filters.append(pl.col("att_val") == attn)
        if board_type is not None:
            filters.append(pl.col("board_type") == board_type)
        if channel is not None:
            # Cast so the comparison works whether "channel" was read as text or integers.
            # NOTE(review): rows holding a range such as "0-127" are matched literally here.
            filters.append(pl.col("channel").cast(pl.String) == str(channel))

        bounds = self.bounds.filter(*filters)

        if bounds.height == 0:
            log.error(f"No cut thresholds for {name=}, {gain=}, {amp=}, {attn=}, {board_type=}, {channel=}")
            return (None, None)
        if bounds.height > 1:
            log.error("Multiple thresholds selected, perhaps you are missing the amp value for a pulse variable.")
            return (None, None)

        return (bounds["min_requirement"].item(), bounds["max_requirement"].item())

    def draw_on(
        self,
        ax: Axes,
        name: str,
        gain: Literal["hi", "lo"],
        amp: Optional[float] = None,
        axis: Optional[str] = "y",
        suppress_warnings: Optional[bool] = False,
    ) -> None:
        """
        Draw the bounds on a figure as a shaded "accept" rectangle.

        axis==y(x) means the bounds are in the y(x) variable; the rectangle spans the
        full current range of the other axis. A missing bound (one-sided cut) extends
        to the edge of the current view. The limits of the cut axis are restored after
        drawing. Failures are logged as warnings unless ``suppress_warnings`` is set.
        """
        try:
            low, high = self.get_bounds(name, gain, amp)
            if low is None and high is None:
                # No matching thresholds, or neither a min nor a max requirement
                if not suppress_warnings:
                    log.warning(f"Error looking up thresholds for {name}, {gain}, {amp}.")
                return

            if axis == "y":
                y_low_plot, y_high_plot = ax.get_ylim()
                y_low = y_low_plot if low is None else max(low, y_low_plot)
                y_high = y_high_plot if high is None else min(high, y_high_plot)
                x_low, x_high = ax.get_xlim()
            else:
                y_low, y_high = ax.get_ylim()
                x_low_plot, x_high_plot = ax.get_xlim()
                x_low = x_low_plot if low is None else max(low, x_low_plot)
                x_high = x_high_plot if high is None else min(high, x_high_plot)

            ax.add_patch(
                Rectangle(
                    (x_low, y_low),
                    x_high - x_low,
                    y_high - y_low,
                    alpha=0.2,
                    color=("blue" if gain == "lo" else "red"),
                    fill=True,
                    label=f"{gain} gain accept",
                    zorder=1,
                )
            )

            # Reset to previous limits
            if axis == "y":
                ax.set_ylim(y_low_plot, y_high_plot)
            else:
                ax.set_xlim(x_low_plot, x_high_plot)

            ax.legend()

        except Exception as e:
            if not suppress_warnings:
                log.warning(f"Error looking up thresholds for {name}, {gain}, {amp}.")
                log.warning(e)