Coverage for tcvx21/plotting/tile1d_m.py: 89%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2Routine for making a tiled plot
4This is helpful for comparing results with each other
6Can make tile plots either for a single diagnostic and many observables, or for a single observable at many
7diagnostic positions
8"""
9import matplotlib.pyplot as plt
10import tcvx21
11plt.style.use(tcvx21.style_sheet)
13import numpy as np
14from pathlib import Path
16from tcvx21.plotting.labels_m import add_twinx_label, add_x_zero_line, add_y_zero_line, \
17 make_observable_string, make_diagnostic_string, make_field_direction_string, label_subplots, format_yaxis
18from tcvx21.quant_validation.ricci_metric_m import calculate_normalised_ricci_distance_from_observables
19from .save_figure_m import savefig
21from typing import Union
22from tcvx21.quant_validation.latex_table_writer_m import observable_latex
25def tile1d(experimental_data, simulation_data,
26 diagnostics: Union[np.ndarray, str, tuple],
27 observables: Union[np.ndarray, str, tuple],
28 overplot: tuple = None,
29 fig_width=7.5, fig_height_per_row=2.0, title_height=1.0, manual_title: str = '',
30 legend_loc=('best',), make_title: bool = True,
31 show: bool = False, save: bool = True, close: bool = True,
32 output_path: Path = None):
33 """
34 Make a tiled plot showing either several observables for a single diagnostics, or several diagnostics (positions)
35 for the same observable
36 """
37 observables = np.atleast_1d(observables)
38 diagnostics = np.atleast_1d(diagnostics)
40 assert observables.ndim == 1 and diagnostics.ndim == 1
42 if observables.size == 1:
43 mode = 'observable'
44 observables = np.broadcast_to(observables, diagnostics.size)
45 elif diagnostics.size == 1:
46 mode = 'diagnostic'
47 diagnostics = np.broadcast_to(diagnostics, observables.size)
48 else:
49 raise NotImplementedError(f"Could not make tile plot with observables {observables.shape} "
50 f"and diagnostics {diagnostics.shape}")
52 ncols = 2
53 nrows = observables.size
55 fig, axs = plt.subplots(ncols=ncols, nrows=nrows,
56 figsize=(fig_width, title_height + nrows * fig_height_per_row),
57 sharex='col', sharey='row',
58 squeeze=False)
60 for column, field_direction in enumerate(['forward_field', 'reversed_field']):
61 # Always plot forward field in the left column, and reversed field in the
62 # right column
63 for row, (observable, diagnostic) in enumerate(zip(observables, diagnostics)):
64 cases = []
65 for fd in ['forward_field', 'reversed_field']:
66 c = experimental_data[fd].get_observable(diagnostic, observable)
67 if not c.is_empty:
68 cases.append(c)
70 ax = axs[row, column]
72 reference = experimental_data[field_direction].get_observable(diagnostic, observable)
73 if reference.is_empty:
74 continue
76 if mode == 'diagnostic':
77 label = diagnostic
78 else:
79 label = f"TCV {observable_latex[observable]}"
81 reference.plot(ax, label=label)
83 for case in simulation_data.values():
84 for fd in ['forward_field', 'reversed_field']:
85 c = case[fd].get_observable(diagnostic, observable)
86 if not c.is_empty:
87 cases.append(c)
89 simulation = case[field_direction].get_observable(diagnostic, observable)
90 if simulation.is_empty:
91 continue
93 line = simulation.plot(ax, trim_to_x=reference.xlim)
95 d = calculate_normalised_ricci_distance_from_observables(simulation=simulation,
96 experiment=reference)
98 line.set_label(f"{line.get_label()}: d={d:3.2f}")
100 if overplot is not None:
101 if mode == 'observable':
102 raise NotImplementedError('Overplotting of observables is not supported (units must match)')
103 for (overplot_observable, overplot_diagnostic, overplot_color) in overplot:
104 if mode == 'diagnostic' and observable == overplot_observable:
105 overplot_ = experimental_data[field_direction].get_observable(overplot_diagnostic, observable)
106 if overplot_.is_empty:
107 continue
108 overplot_.plot(ax, color=overplot_color, label=overplot_diagnostic)
110 # Set the ylim to show all data
111 ax.set_ylim(*reference.ylims_in_trim(cases, trim_to_x=reference.xlim))
113 # Apply custom limits if they are defined
114 reference.apply_plot_limits(ax)
116 ax.legend(loc=legend_loc[0] if len(legend_loc) == 1 else legend_loc[2 * row + column])
117 add_x_zero_line(ax)
118 add_y_zero_line(ax)
119 if 'kurtosis' in observable:
120 add_y_zero_line(ax, level=3.0)
122 if row != nrows - 1:
123 ax.set_xlabel('')
124 if column != 0:
125 ax.set_ylabel('')
126 if mode == 'observable':
127 add_twinx_label(ax, make_diagnostic_string(diagnostic))
128 else:
129 add_twinx_label(ax, make_observable_string(reference))
131 if row == 0:
132 ax.set_title(make_field_direction_string(field_direction))
134 plt.draw()
135 for row in range(nrows):
136 format_yaxis(axs[row, 0])
138 if manual_title:
139 title_string = manual_title
140 elif mode == 'observable':
141 reference = experimental_data['forward_field'].get_observable(diagnostics[0], observables[0])
142 title_string = make_observable_string(reference)
143 else:
144 title_string = make_diagnostic_string(diagnostics[0])
146 if make_title:
147 _, fig_height = fig.get_size_inches()
148 plt.subplots_adjust(top=1 - title_height / fig_height)
150 plt.suptitle(title_string, fontsize='large', y=1 - title_height / 2 / fig_height)
152 label_subplots(axs.flatten())
154 if output_path is None and save:
155 if mode == 'observable':
156 filename = f"{title_string}+{','.join(list(diagnostics))}".replace(' ', '_')
157 else:
158 filename = f"{title_string}+{','.join(list(observables))}".replace(' ', '_')
160 output_path = tcvx21.results_dir / 'summary_fig' / f"{filename}.png"
162 savefig(fig, output_path=output_path, show=show, close=close)
164 return fig, axs