Coverage for tcvx21/quant_validation/latex_table_writer_m.py: 92%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2Apologies for hard-coding. This makes a latex table which has conditional formatting applied and
3expands the diagnostic, observable into nice compact symbolic forms
5Not suitable for picking up directly
6"""
7import tcvx21
8from tcvx21.quant_validation.ricci_metric_m import level_of_agreement_function
9from matplotlib import colors
10import numpy as np
11import matplotlib.pyplot as plt
12from pathlib import Path
# LaTeX cell text for each diagnostic, keyed by the diagnostic identifier.
# The doubled backslash in each label (written '\\\\' in Python source) is a
# LaTeX line break; the labels are placed inside a \makecell by
# write_cases_to_latex. Insertion order sets the order of the table sections.
diagnostics_latex = {
    'FHRP': 'Fast\\\\horizontally-\\\\reciprocating\\\\probe (FHRP)\\\\for outboard\\\\midplane',
    'TS': 'Thomson scattering\\\\(TS) for divertor\\\\entrance',
    'RDPA': 'Reciprocating\\\\divertor probe\\\\array (RDPA)\\\\for divertor\\\\volume',
    'LFS-IR': 'Infrared camera (IR)\\\\for low-field-side target',
    'LFS-LP': 'Wall Langmuir\\\\probes for\\\\low-field-side\\\\target',
    'HFS-LP': 'Wall Langmuir\\\\probes for\\\\high-field-side\\\\target',
}

# Math-mode symbol for each observable, keyed by the observable identifier.
# Insertion order sets the order of the rows within each diagnostic section.
observable_latex = {
    'density': '$n$',
    'electron_temp': '$T_e$',
    'potential': '$V_{pl}$',
    'jsat': '$J_{sat}$',
    'jsat_std': '$\\sigma\\left(J_{sat}\\right)$',
    'jsat_skew': '$\\mathrm{skew}\\left(J_{sat}\\right)$',
    'jsat_kurtosis': '$\\mathrm{kurt}\\left(J_{sat}\\right)$',
    'current': '$J_\\parallel$',
    'current_std': '$\\sigma\\left(J_\\parallel\\right)$',
    'vfloat': '$V_{fl}$',
    'vfloat_std': '$\\sigma\\left(V_{fl}\\right)$',
    'q_parallel': '$q_\\parallel$',
    'lambda_q': '$\\lambda_q$',
    'mach_number': '$M_\\parallel$',
}

# Chi as a latex symbol, used to label the overall-agreement row
chi_latex = '$\\chi$'
def tuple_to_string(tuple_in: tuple) -> str:
    """
    Converts an RGB tuple to a latex-xcolor compatible string

    Every component is written in fixed-point form, 6 characters wide with 3 decimals,
    and the components are separated by ', '
    """
    formatted_components = (format(component, '6.3f') for component in tuple_in)
    return ', '.join(formatted_components)
def write_number_w_color(value, file, cell_clip=None, cmap='RdYlGn_r', text_clip_fraction=False, expand=1.4,
                         debug=False):
    """
    Writes a 'value' to a file, applying conditional cell colouring and text colouring

    It is assumed that all values are >= 0.0, or np.nan (NaN is written as 'NaN' in a grey cell
    with white text).

    file is any object with a write(str) method.

    cell_clip gives the 100% intensity value, above which all values are set to 100%. This also sets the
    normalisation of the range < cell_clip: a linear ramp from 0 to cell_clip is used. If cell_clip is
    None, no conditional colouring is applied (black text on a white cell).

    cmap is the name of a matplotlib colormap, default set to the reversed red-yellow-green
    diverging colormap ('RdYlGn_r')

    text_clip_fraction switches from black to white text for values above cell_clip * text_clip_fraction.
    If it is falsy (the default), the text is always black.

    expand sets the 0% and 100% intensity of the colormap to outside the normalised value space. This is useful
    to make the cell color closer to white, which makes the cells more readable

    debug writes only the formatted number (padded to 6 characters), without any colour commands

    Numbers are written with 3 significant figures. If that form would use exponent notation, the
    rounded value is written as a truncated integer instead.
    """
    black, white = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)

    if np.isnan(value):
        cellcolor, fontcolor = (0.4, 0.4, 0.4), white
        text = 'NaN'
    else:
        if cell_clip is None:
            cellcolor, fontcolor = white, black
        else:
            assert value >= 0.0

            norm = colors.Normalize(vmin=cell_clip * (1 - expand), vmax=cell_clip * expand)
            normalised = norm(np.min((value, cell_clip)))
            # plt.get_cmap instead of plt.cm.get_cmap: the latter (matplotlib.cm.get_cmap)
            # was removed in Matplotlib 3.9. The trailing [:-1] drops the alpha channel.
            cellcolor = plt.get_cmap(cmap)(normalised)[:-1]

            if text_clip_fraction:
                # Set value < text_clip to black, and > text_clip to white
                fontcolor = black if value < text_clip_fraction * cell_clip else white
            else:
                fontcolor = black

        # float() so that integer inputs are accepted: a precision-only format
        # spec such as '.3' raises ValueError for int
        text = f"{float(value):.3}"
        if 'e' in text:
            text = f"{int(float(text))}"

    if debug:
        file.write(f"{text:6}")
    else:
        file.write(f"\\cellcolor[rgb]{{{tuple_to_string(cellcolor)}}}"
                   f"\\textcolor[rgb]{{{tuple_to_string(fontcolor)}}}"
                   f"{{{text:6}}}")
def process_case_key(case_key):
    """
    Convert the toroidal field sign to math-type, to get it to print nicely

    Every '+' becomes '($+$)' and every '-' becomes '($-$)'; all other characters are kept
    """
    sign_to_math = str.maketrans({'+': '($+$)', '-': '($-$)'})
    return case_key.translate(sign_to_math)
def make_colorbar(cell_clip, label, output_path=None, cmap='RdYlGn_r', expand=1.4, eps=1E-2):
    """
    Draw a standalone horizontal colorbar using the same normalisation as the table cells

    cell_clip, cmap and expand have the same meaning as in write_number_w_color: the colormap is
    normalised between cell_clip * (1 - expand) and cell_clip * expand, and the colorbar is drawn
    over the range 0 to cell_clip (extended slightly by eps so that the end point is included)

    label is written as the y-label of the colorbar axis
    output_path, if not None, is passed together with the figure to tcvx21.plotting.savefig

    Note that this applies the tcvx21 style sheet via plt.style.use
    """
    plt.style.use(tcvx21.style_sheet)
    color_norm = colors.Normalize(vmin=cell_clip * (1 - expand), vmax=cell_clip * expand)

    fig, ax = plt.subplots(figsize=(5, 1))

    # A 1x2 image spanning [0, cell_clip]: only needed as the mappable for the colorbar,
    # so the axis it is drawn on is hidden
    mappable = ax.imshow(np.array([[0, cell_clip]]), cmap=cmap, norm=color_norm)
    ax.set_visible(False)

    # left, bottom, width, height as fractions of the figure
    colorbar_axis = plt.axes([0.1, 0.2, 0.8, 0.1])

    upper_limit = cell_clip + eps
    plt.colorbar(mappable, orientation="horizontal", cax=colorbar_axis, extend='max',
                 ticks=np.arange(upper_limit), boundaries=np.linspace(0, upper_limit, num=1000))
    colorbar_axis.set_ylabel(label, rotation=0, labelpad=10, y=-0.5, size=plt.rcParams['axes.titlesize'])

    if output_path is None:
        return

    tcvx21.plotting.savefig(fig, output_path)
def write_cases_to_latex(cases: dict, output_file: Path):
    """
    Write a formatted conditionally colored table to a file

    cases maps a case key (used as the column heading) to a case object. Each case object must
    provide calculate_metric_terms(), compute_chi() (returning chi and Q), and the nested
    mappings hierarchy[diagnostic][observable], distance[diagnostic][observable] and
    sensitivity[diagnostic][observable].

    output_file is the path of the LaTeX file to write. A matching colorbar is also drawn by
    make_colorbar, with the output path 'colorbar.png' in the same directory.

    The table has one section per entry of diagnostics_latex, one row per observable (in the
    order of observable_latex) and two columns (d_j and S) per case. Each section ends with a
    per-diagnostic (chi; Q) row, and the table ends with the overall (chi; Q) of each case.
    """

    def summary_cell(chi_value, q_value):
        # Bold "(chi; Q)" entry spanning both columns (d_j and S) of a single case
        return f"\\multicolumn{{2}}{{c}}{{$ \\textbf{{({chi_value:.2}; \\ {q_value:.3})}} $ }}"

    for case in cases.values():
        case.calculate_metric_terms()

    case_keys = list(cases.keys())
    n_cases = len(cases)
    last_column = n_cases - 1

    make_colorbar(cell_clip=5.0, label='$d_j$', output_path=output_file.parent/'colorbar.png')

    with open(output_file, 'w') as f:

        # Header: two label columns, then two centred columns per case
        f.write(f"\\begin{{tabular}}{{ll{n_cases * 'cc'}}}\n")
        f.write("\\toprule\n")

        key_header = ' & '.join(f"\\multicolumn{{2}}{{c}}{{{process_case_key(key)}}}" for key in case_keys)
        f.write(f" & & {key_header} \\\\ \n")

        column_header = ' & '.join(["$d_j$ & $S$"] * n_cases)
        f.write(f"Diagnostic & observable & {column_header} \\\\ \n")
        f.write("\\midrule\n")

        # Write the results in
        for diagnostic, diagnostic_styled in diagnostics_latex.items():

            # Assume that all cases have the same hierarchies, so take
            # the first value
            hierarchies = cases[case_keys[0]].hierarchy[diagnostic]
            hierarchy_keys = hierarchies.keys()
            n_observables = len(hierarchies)

            # The diagnostic label spans the observable rows plus the summary row
            f.write(f"\\multirow{{{n_observables + 1}}}{{*}}{{\\makecell{{{diagnostic_styled}}}}}\n")

            # Rows which are never filled stay NaN, and are ignored by the nansum below
            distances = np.full((n_observables, n_cases), np.nan)
            sensitivities = np.full((n_observables, n_cases), np.nan)
            weights = np.full(n_observables, np.nan)

            # Only observables which are in the hierarchy of this diagnostic get a row
            listed_observables = [
                (name, styled) for name, styled in observable_latex.items() if name in hierarchy_keys
            ]

            for row, (observable, observable_styled) in enumerate(listed_observables):
                weights[row] = 1.0 / hierarchies[observable]
                f.write(f"& {observable_styled:40} & ")

                for column, case in enumerate(cases.values()):
                    distances[row, column] = case.distance[diagnostic][observable]
                    write_number_w_color(distances[row, column], file=f, cell_clip=5.0)
                    f.write(' & ')

                    sensitivities[row, column] = case.sensitivity[diagnostic][observable]
                    write_number_w_color(sensitivities[row, column], file=f)

                    if column != last_column:
                        # Don't write an alignment character for the last entry on a line
                        f.write(' & ')

                f.write('\\\\\n')

            f.write(f"\\cline{{2-{2 * n_cases + 2}}}\n")

            # Per-case sums over the observables of this diagnostic
            weighted_sensitivity = weights[:, np.newaxis] * sensitivities
            Q = np.nansum(weighted_sensitivity, axis=0)
            chi = np.nansum(weighted_sensitivity * level_of_agreement_function(distances), axis=0) / Q

            # Write diagnostic chi and Q
            diagnostic_summary = "$\\left(\\chi; Q\\right)$\\textsubscript" + f"{{{diagnostic}}}"
            f.write(f"& {diagnostic_summary:40} & ")
            f.write(' & '.join(summary_cell(chi[column], Q[column]) for column in range(n_cases)))

            f.write('\\\\\n')
            f.write('\\midrule\n')

        # Write the overall agreement
        f.write(f"Overall\n& {chi_latex + '; $Q$':40} & ")

        for column, case in enumerate(cases.values()):
            chi, Q = case.compute_chi()
            f.write(summary_cell(chi, Q))

            # Alignment character between entries, end-of-row after the last entry
            f.write(' & ' if column != last_column else '\\\\\n')

        # Footer
        f.write("\\bottomrule\n")
        f.write("\\end{tabular}")