Coverage for polars_analysis / analysis / cut_thresholds.py: 54%
98 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-13 13:37 -0400
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-13 13:37 -0400
1import logging
2from copy import deepcopy
3from pathlib import Path
4from typing import Any, Literal, Optional
6import numpy as np
7import polars as pl
8import yaml
9from matplotlib.axes import Axes
10from matplotlib.patches import Rectangle
12from polars_analysis.analysis import constants
# Module-level logger, named after this module so its level/handlers can be
# configured through the standard logging hierarchy.
log: logging.Logger = logging.getLogger(name=__name__)
def str_dash_to_list(s: "str | int") -> np.ndarray:
    """
    Return a numpy array of the integers X..Y (inclusive) parsed from 'X-Y'.

    A bare value 'X' (no dash) is treated as the single-element range 'X-X'.
    Non-string input (e.g. an int loaded from YAML) is converted with ``str()``
    first. Whitespace around the numbers is ignored.

    Raises:
        ValueError: if either end of the range is not an integer.
    """
    start, dash, stop = str(s).partition("-")
    first = int(start)
    # Without a dash there is only one number: the range is just that value.
    last = int(stop) if dash else first
    return np.arange(first, last + 1)
class CutThresholds:
    """
    Load QC cut thresholds from a CSV or YAML file, query them, and draw them on plots.

    For a CSV file ``self.bounds`` is a polars DataFrame with whitespace-stripped headers and
    values, empty strings replaced by null, and the ``requirement``, ``min_requirement``,
    ``max_requirement``, ``awg_amp`` and ``att_val`` columns cast to Float64 (unparseable values
    become null). For a YAML file ``self.bounds`` is whatever ``yaml.safe_load`` returns. For any
    other file type an error is logged and ``self.bounds`` is ``None``.
    """

    def __init__(self, constants_file: Path = constants.ROOTDIR / "polars_analysis/analysis/qc_thresholds.csv"):
        """
        Load the thresholds from ``constants_file``.

        The loader is chosen from the file *name's* extension(s) only, so a directory name such
        as ``csv/`` or ``yaml_configs/`` cannot send a file to the wrong parser.
        """
        # "polars_analysis/analysis/pass_fail_constants.csv"
        suffixes = {suffix.lower() for suffix in Path(constants_file).suffixes}
        # Always define the attribute so later calls fail gracefully instead of AttributeError.
        self.bounds: Any = None

        if ".csv" in suffixes:
            self.bounds = (
                pl.read_csv(constants_file, comment_prefix="#")
                # Strip whitespace from column headers
                .rename(lambda column_name: column_name.strip())
                # Strip whitespace from elements
                .with_columns(pl.col(pl.String).str.strip_chars())
                # Replace empty strings with null
                .with_columns(
                    pl.when(pl.col(pl.String).str.len_chars() == 0).then(None).otherwise(pl.col(pl.String)).name.keep()
                )
                # Convert requirements and awg_amp/att_val to float
                .with_columns(
                    pl.col("requirement").cast(pl.Float64, strict=False),
                    pl.col("min_requirement").cast(pl.Float64, strict=False),
                    pl.col("max_requirement").cast(pl.Float64, strict=False),
                    pl.col("awg_amp").cast(pl.Float64, strict=False),
                    pl.col("att_val").cast(pl.Float64, strict=False),
                )
            )
        elif suffixes & {".yaml", ".yml"}:
            with open(constants_file, encoding="utf-8") as file:
                self.bounds = yaml.safe_load(file)
        else:
            log.error(f"Unknown constants file type {constants_file}")

    def walk_dictionary(self, input_dict: dict, decending_key: Any = None, known_values: Optional[dict] = None):
        """
        Flatten a nested dictionary into flat dictionaries, one per innermost key-value pair.

        Args:
            input_dict: nested dictionary (no lists or tuples)
            decending_key: alternates between None and a key of the parent dictionary. When it is
                not None, the keys of ``input_dict`` are recorded as the values of
                ``decending_key`` (e.g. ``{"gain": {"hi": ..., "lo": ...}}`` records
                ``gain="hi"`` and then ``gain="lo"``).
            known_values: key-value pairs accumulated from the enclosing levels. It is updated in
                place while descending; a fresh dictionary is used when it is omitted.

        Yields:
            For each key-value pair of an innermost dictionary, a copy of the accumulated
            ``known_values`` with that pair added.
        """
        # Fresh accumulator per top-level call. A ``{}`` default would be created once, shared by
        # every call, and mutated below, leaking keys from one variable into the next.
        if known_values is None:
            known_values = {}

        # No further dictionaries
        last_layer = all(not isinstance(val, dict) for val in input_dict.values())

        if last_layer:
            # One final dictionary per key
            for key, val in input_dict.items():
                # Deep copy: each yielded dictionary must be independent of the shared accumulator
                temp = deepcopy(known_values)
                temp[key] = val
                yield temp
            return

        # Record every key at this level that does not point to a dictionary
        for key, val in input_dict.items():
            if not isinstance(val, dict):
                known_values[key] = val

        # Descend
        for key, val in input_dict.items():
            if decending_key is not None:
                # We assume that the keys in this dictionary are the values for the key in the parent dictionary.
                known_values[decending_key] = key

            if isinstance(val, dict):
                yield from self.walk_dictionary(val, (key if decending_key is None else None), known_values)

    @staticmethod
    def _expand_range(column: str) -> pl.Expr:
        """Return an expression turning 'X-Y' (or a bare 'X') in ``column`` into the list [X, ..., Y]."""
        # Cast first so integer-inferred CSV columns (e.g. all single channels) also work.
        parts = pl.col(column).cast(pl.String).str.split("-")
        start = parts.list.first().str.strip_chars().cast(pl.Int64)
        # For a bare 'X' the split has one element, so first == last and the range is [X].
        stop = parts.list.last().str.strip_chars().cast(pl.Int64)
        return pl.int_ranges(start, stop + 1).alias(column)

    def explode_cuts(self) -> pl.DataFrame:
        """
        Explode into single 'row' per cut / ch / gain / greater-than,less-than,equals

        CSV bounds: the ``channel``, ``link`` and ``adc`` columns ('X-Y' ranges or single values)
        are expanded and exploded one after another, giving every combination.
        YAML bounds: each entry of ``variables`` is flattened with ``walk_dictionary`` and its
        ``channels`` range is exploded into a ``channel`` column.

        Raises:
            TypeError: if ``self.bounds`` is neither a DataFrame nor a dictionary.
        """
        if isinstance(self.bounds, pl.DataFrame):
            columns_to_expand = ["channel", "link", "adc"]

            df = self.bounds.with_columns([self._expand_range(c) for c in columns_to_expand])
            # Explode separately (not jointly) so the result is the product of the three ranges
            for c in columns_to_expand:
                df = df.explode(c)

            return df
        elif isinstance(self.bounds, dict):
            df = pl.concat(
                [pl.from_dict(d) for var in self.bounds["variables"] for d in self.walk_dictionary(var)], how="diagonal"
            )

            df = (
                df.with_columns(
                    pl.col("channels").map_elements(str_dash_to_list, return_dtype=pl.List(pl.Int64)).alias("channel")
                )
                .explode("channel")
                .drop("channels")
            )

            return df
        else:
            raise TypeError("explode_cuts doesn't know what to do")

    def get_bounds(
        self,
        name: str,
        gain: Literal["hi", "lo"],
        amp: Optional[float] = None,
        attn: Optional[float] = None,
        board_type: Optional[str] = "EM",
        channel: Optional[int] = None,
    ) -> tuple:
        """
        Return tuple of (lower, upper) bounds based on the name, gain, and optionally amp/attn.

        ``board_type`` and ``channel`` are forwarded to the CSV lookup (``None`` disables that
        filter). Only CSV thresholds are supported. On any failure the problem is logged and
        ``(None, None)`` is returned instead of raising.
        """
        try:
            if isinstance(self.bounds, pl.DataFrame):
                return self._get_bounds_csv(name, gain, amp, board_type=board_type, attn=attn, channel=channel)
            else:
                log.error("Unsupported constants file type for get_bounds (CSV only), returning (None,None)")
                return (None, None)
        except Exception as e:
            log.error(f"Failed to get bounds for {name=}, {gain=}, {amp=}, {attn=}")
            log.error(e)
            return (None, None)

    def _get_bounds_csv(
        self,
        name: str,
        gain: Literal["hi", "lo"],
        amp: Optional[float] = None,
        board_type: Optional[str] = "EM",
        attn: Optional[float] = None,
        channel: Optional[int] = None,
    ) -> tuple:
        """
        Return tuple of (lower, upper) bounds based on the name, gain, and optionally amp

        Rows are filtered on ``name`` and ``gain``, and on ``awg_amp``, ``att_val``,
        ``board_type`` and ``channel`` (compared as a string) when those arguments are not None.
        Exactly one row must match; otherwise an error is logged and ``(None, None)`` returned.
        Either returned bound may be None if that cell is empty in the CSV.
        """
        filters = [pl.col("name") == name, pl.col("gain") == gain]
        if amp is not None:
            filters.append(pl.col("awg_amp") == amp)
        if attn is not None:
            filters.append(pl.col("att_val") == attn)
        if board_type is not None:
            filters.append(pl.col("board_type") == board_type)
        if channel is not None:
            filters.append(pl.col("channel") == str(channel))

        bounds = self.bounds.filter(*filters)

        if bounds.height == 0:
            log.error(f"No cut thresholds for {name=}, {gain=}, {amp=}, {attn=}, {board_type=}, {channel=}")
            return (None, None)
        if bounds.height > 1:
            log.error("Multiple thresholds selected, perhaps you are missing the amp value for a pulse variable.")
            return (None, None)

        return (bounds["min_requirement"].item(), bounds["max_requirement"].item())

    @staticmethod
    def _clip_to_limits(low: Optional[float], high: Optional[float], lim_low: float, lim_high: float) -> tuple:
        """Clip (low, high) to the plot limits; a None bound is open-ended and becomes the limit."""
        clipped_low = lim_low if low is None else max(low, lim_low)
        clipped_high = lim_high if high is None else min(high, lim_high)
        return clipped_low, clipped_high

    def draw_on(
        self,
        ax: Axes,
        name: str,
        gain: Literal["hi", "lo"],
        amp: Optional[float] = None,
        axis: Optional[str] = "y",
        suppress_warnings: Optional[bool] = False,
        attn: Optional[float] = None,
    ) -> None:
        """
        Draw the bounds on a figure.
        axis==y(x) means the bounds are in the y(x) variable

        The accepted region is drawn as a translucent rectangle (blue for low gain, red
        otherwise) clipped to the current limits of the bounded axis, which are restored
        afterwards. Any ``axis`` value other than "y" is treated as "x". A one-sided cut (only a
        lower or only an upper bound) extends to the edge of the plot. When no bound is found, or
        drawing fails, a warning is logged unless ``suppress_warnings`` is set.
        """
        try:
            low, high = self.get_bounds(name, gain, amp, attn)
            if low is None and high is None:
                if not suppress_warnings:
                    log.warning(f"Error looking up thresholds for {name}, {gain}, {amp}.")
                return

            x_lim = ax.get_xlim()
            y_lim = ax.get_ylim()
            if axis == "y":
                x_low, x_high = x_lim
                y_low, y_high = self._clip_to_limits(low, high, *y_lim)
            else:
                x_low, x_high = self._clip_to_limits(low, high, *x_lim)
                y_low, y_high = y_lim

            ax.add_patch(
                Rectangle(
                    (x_low, y_low),
                    x_high - x_low,
                    y_high - y_low,
                    alpha=0.2,
                    color=("blue" if gain == "lo" else "red"),
                    fill=True,
                    label=f"{gain} gain accept",
                    zorder=1,
                )
            )

            # Reset the bounded axis to its previous limits (adding the patch may autoscale it)
            if axis == "y":
                ax.set_ylim(*y_lim)
            else:
                ax.set_xlim(*x_lim)

            ax.legend()

        except Exception as e:
            if not suppress_warnings:
                log.warning(f"Error looking up thresholds for {name}, {gain}, {amp}.")
                log.warning(e)