Coverage for tcvx21/plotting/tile2d_single_observable.py: 90%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import matplotlib.pyplot as plt
2import tcvx21
3plt.style.use(tcvx21.style_sheet)
5from pathlib import Path
6from tcvx21 import Quantity
7from tcvx21.quant_validation.ricci_metric_m import calculate_normalised_ricci_distance_from_observables
8from tcvx21.plotting.labels_m import add_twinx_label, add_x_zero_line, make_observable_string, label_subplots
9import numpy as np
10from .save_figure_m import savefig
def tile2d_single_observable(experimental_data, simulation_data, diagnostic, observable,
                             fig_width=7.5, fig_height_per_row=1.5, title_height=1.0,
                             subplots_kwargs=None, offset=None,
                             save: bool = True, show: bool = False, close: bool = True,
                             output_path: Path = None, **kwargs):
    """
    Plots both field directions of a single 2d observable

    Builds a 2-row grid of axes (row 0: forward field, row 1: reversed field) with the experimental
    reference in the first column and one further column per entry of ``simulation_data``, draws the
    observable via ``plot_2D_observables``, adds a suptitle naming the observable and its units,
    labels the subplots and hands the figure to ``savefig``.

    Parameters
    ----------
    experimental_data: mapping with 'forward_field' and 'reversed_field' entries, each providing
        ``get_observable(diagnostic, observable)``
    simulation_data: mapping of simulation cases, each indexable by field direction in the same way
        (only its values and its length are used here and in ``plot_2D_observables``)
    diagnostic, observable: together form the key identifying the observable to plot
    fig_width: figure width in inches
    fig_height_per_row: height of each subplot row in inches
    title_height: height in inches reserved above the subplots for the suptitle
    subplots_kwargs: extra keyword arguments for ``Figure.subplots_adjust`` (default: none)
    offset: if given, shown as a power-of-ten factor in the title units and forwarded to
        ``plot_2D_observables``
    save: if False, ``savefig`` receives ``output_path=None`` (presumably meaning "do not write";
        confirm against ``savefig``), even when an ``output_path`` was supplied
    show, close: forwarded to ``savefig``
    output_path: where to write the figure; when saving and not given, defaults to
        ``tcvx21.results_dir / 'summary_fig' / '{diagnostic}+{observable}.png'``
    **kwargs: forwarded to ``plot_2D_observables``

    Returns
    -------
    (fig, axs): the figure and its 2 x (len(simulation_data) + 1) array of axes
    """
    # Avoid a shared mutable default; an empty dict adds nothing to subplots_adjust
    subplots_kwargs = {} if subplots_kwargs is None else subplots_kwargs
    nrows = 2
    key = (diagnostic, observable)

    # squeeze=False keeps axs two-dimensional even with a single column (no simulations),
    # which plot_2D_observables relies on when indexing axs[row][col]
    fig, axs = plt.subplots(ncols=len(simulation_data) + 1,
                            nrows=nrows,
                            figsize=(fig_width, title_height + nrows * fig_height_per_row),
                            sharex='col', sharey='row', squeeze=False)
    fig_height = fig.get_size_inches()[1]

    # Reserve title_height inches at the top of the figure for the suptitle
    fig.subplots_adjust(top=1 - title_height / fig_height, **subplots_kwargs)

    plot_2D_observables(axs=axs, key=key, experimental_data=experimental_data,
                        simulation_data=simulation_data, offset=offset, **kwargs)

    reference = experimental_data['forward_field'].get_observable(*key)
    compact_units = reference.compact_units
    units_string = f"{Quantity(1, compact_units).units:~P}" if compact_units else "-"
    if offset is not None:
        # Prefix the power-of-ten factor and strip any leading '1' from the units string
        units_string = f"$10^{{{np.log10(offset):n}}}${units_string.lstrip('1')}"
    # Centre the title vertically within the reserved title band
    fig.suptitle(f"{make_observable_string(reference)} [{units_string}]",
                 fontsize='large', y=1 - title_height / 2 / fig_height)

    label_subplots(axs.flatten())

    if not save:
        # savefig does not take a `save` flag, so an explicit output_path must not leak through
        output_path = None
    elif output_path is None:
        output_path = tcvx21.results_dir / 'summary_fig' / f"{diagnostic}+{observable}.png"

    savefig(fig, output_path=output_path, show=show, close=close)

    return fig, axs
def plot_2D_observables(axs, key, experimental_data, simulation_data,
                        cbar_lim_=(None, None), experiment_sets_cbar=True, add_labels=True,
                        cbar_pad=0.075, diverging: bool = False, log_cbar: bool = False, offset=None,
                        ticks=None,
                        **kwargs):
    """
    Draw one 2D observable for both field directions onto an existing grid of axes.

    Row 0 holds the forward-field case and row 1 the reversed-field case. Column 0 shows the
    experimental reference; each further column shows one entry of ``simulation_data``. Every
    panel uses the same colour limits, and a single colorbar is attached to all axes.

    key: (diagnostic, observable) tuple passed to ``get_observable``.
    cbar_lim_: optional (lower, upper) overrides; a ``None`` entry keeps the computed limit.
    experiment_sets_cbar: if True, the limits come from the forward and reversed experimental data;
        otherwise from the simulations (via the forward reference's ``calculate_cbar_limits``).
    add_labels: if True, keep titles (annotating simulations with their normalised Ricci
        distance ``d``) and add field-direction labels on the right; if False, strip
        titles/axis labels and mark rows 'Forward'/'Reversed' on the left.
    cbar_pad: padding passed to ``plt.colorbar``.
    diverging: make the limits symmetric about zero.
    log_cbar: logarithmic colour scaling, log-spaced default ticks and power-of-ten tick labels.
    offset: divisor applied to linear tick labels (not used for log tick labels).
    ticks: explicit colorbar tick positions; when None, 5 ticks spanning the limits.
    **kwargs: forwarded to every observable's ``plot`` call.

    Returns (axs, cbar).
    """
    field_directions = ('forward_field', 'reversed_field')

    # Experimental reference observable per field direction
    reference = {direction: experimental_data[direction].get_observable(*key)
                 for direction in field_directions}
    forward_reference = reference['forward_field']

    # Shared colour range for all panels
    if experiment_sets_cbar:
        comparison = (reference['reversed_field'],)
    else:
        comparison = [case[direction].get_observable(*key)
                      for direction in field_directions
                      for case in simulation_data.values()]
    cbar_lim = forward_reference.calculate_cbar_limits(comparison)

    for bound_index in (0, 1):
        override = cbar_lim_[bound_index]
        cbar_lim[bound_index] = cbar_lim[bound_index] if override is None else override

    if diverging:
        half_range = max(np.abs(cbar_lim))
        cbar_lim[0], cbar_lim[1] = -half_range, half_range
    cbar_lim = cbar_lim.to(forward_reference.compact_units)

    # Experimental panels (column 0) are drawn first; the last image drawn feeds the colorbar
    for direction, row_axes in zip(field_directions, axs):
        image = reference[direction].plot(row_axes[0], cbar_lim=cbar_lim, log_cbar=log_cbar, **kwargs)

    for row, direction in enumerate(field_directions):
        experiment = reference[direction]
        experiment_ax = axs[row][0]
        experiment_ax.set_ylim(experiment.positions_zx.min(), experiment.positions_zx.max())
        if add_labels:
            existing_title = experiment_ax.title.get_text()
            experiment_ax.title.set_text('')
            experiment_ax.set_title(f"{existing_title}", fontsize='small', pad=2.0)

        simulations = [case[direction].get_observable(*key) for case in simulation_data.values()]
        for simulation, sim_ax in zip(simulations, axs[row, 1:]):
            if simulation.is_empty:
                # Hide panels for simulations that have no data for this observable
                sim_ax.set_axis_off()
                sim_ax.set_visible(False)
                continue

            simulation.plot(sim_ax, units=experiment.compact_units, cbar_lim=cbar_lim,
                            log_cbar=log_cbar, **kwargs)
            sim_ax.set_ylabel('')
            add_x_zero_line(sim_ax)

            if add_labels:
                distance = calculate_normalised_ricci_distance_from_observables(simulation=simulation,
                                                                                experiment=experiment)
                existing_title = sim_ax.title.get_text()
                sim_ax.title.set_text('')
                sim_ax.set_title(f"{existing_title}: d={distance:2.1f}", fontsize='small', pad=2.0)

    if add_labels:
        for top_ax in axs[0, :].flatten():
            top_ax.set_xlabel('')
        for row_label, edge_ax in zip(['Forward field', 'Reversed field'], axs[:, -1].flatten()):
            add_twinx_label(edge_ax, row_label, labelpad=15, visible=edge_ax.get_visible())
    else:
        for any_ax in axs.flatten():
            any_ax.set_title('')
            any_ax.set_xlabel('')
            any_ax.set_ylabel('')
        for row_label, edge_ax in zip(['Forward', 'Reversed'], axs[:, 0].flatten()):
            edge_ax.set_ylabel(row_label)

    if ticks is None:
        lower, upper = cbar_lim[0].magnitude, cbar_lim[1].magnitude
        if log_cbar:
            ticks = np.logspace(np.log10(lower), np.log10(upper), num=5)
        else:
            ticks = np.linspace(lower, upper, num=5)

    if log_cbar:
        def tick_label(value, tick_number):
            return f'$10^{{{np.log10(value):2.1f}}}$'
    else:
        divisor = offset if offset is not None else 1.0

        def tick_label(value, tick_number):
            return f'{value/ divisor:.3g}'

    cbar = plt.colorbar(image, ax=axs.flatten(), pad=cbar_pad, ticks=ticks,
                        format=plt.FuncFormatter(tick_label))

    return axs, cbar