Coverage for tcvx21/plotting/tile2d_single_observable.py: 90%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

80 statements  

1import matplotlib.pyplot as plt 

2import tcvx21 

3plt.style.use(tcvx21.style_sheet) 

4 

5from pathlib import Path 

6from tcvx21 import Quantity 

7from tcvx21.quant_validation.ricci_metric_m import calculate_normalised_ricci_distance_from_observables 

8from tcvx21.plotting.labels_m import add_twinx_label, add_x_zero_line, make_observable_string, label_subplots 

9import numpy as np 

10from .save_figure_m import savefig 

11 

def tile2d_single_observable(experimental_data, simulation_data, diagnostic, observable,
                             fig_width=7.5, fig_height_per_row=1.5, title_height=1.0,
                             subplots_kwargs=None, offset=None,
                             save: bool = True, show: bool = False, close: bool = True,
                             output_path: Path = None, **kwargs):
    """
    Plots both field directions of a single 2d observable

    Builds a figure with two rows (forward field on top, reversed field below) and one column for
    the experimental reference plus one column per entry of simulation_data. It fills the panels
    via plot_2D_observables, adds a title with units and subplot labels, then hands the figure
    to savefig.

    Args:
        experimental_data: mapping with 'forward_field' and 'reversed_field' entries, each
            providing get_observable(diagnostic, observable)
        simulation_data: mapping of case name -> mapping with the same field-direction keys
        diagnostic, observable: together form the key of the observable to plot
        fig_width: figure width in inches
        fig_height_per_row: height of each subplot row in inches
        title_height: height in inches reserved above the subplots for the title
        subplots_kwargs: extra keyword arguments for Figure.subplots_adjust (None -> no extras).
            Must not contain 'top', which is derived from title_height.
        offset: optional scale factor. Colourbar tick labels are divided by it, and the title
            units are prefixed with 10^log10(offset).
        save: whether to save the figure. If False, output_path is not used.
        show, close: forwarded to savefig
        output_path: file to save to. When None and save is True, defaults to
            tcvx21.results_dir / 'summary_fig' / '{diagnostic}+{observable}.png'.
        **kwargs: forwarded to plot_2D_observables

    Returns:
        (fig, axs): the figure and the 2D array of axes with shape (2, len(simulation_data) + 1)
    """
    # Avoid a shared mutable default argument
    if subplots_kwargs is None:
        subplots_kwargs = {}

    nrows = 2
    key = (diagnostic, observable)

    # squeeze=False guarantees a 2D axes array even when simulation_data is empty (a single
    # column). plot_2D_observables indexes axs[row][col] and would otherwise fail.
    fig, axs = plt.subplots(ncols=len(simulation_data) + 1,
                            nrows=nrows,
                            figsize=(fig_width, title_height + nrows * fig_height_per_row),
                            sharex='col', sharey='row', squeeze=False)

    fig_height = fig.get_size_inches()[1]
    # Reserve title_height inches at the top of the figure for the title
    fig.subplots_adjust(top=1 - title_height / fig_height, **subplots_kwargs)

    plot_2D_observables(axs=axs, key=key, experimental_data=experimental_data,
                        simulation_data=simulation_data, offset=offset, **kwargs)

    reference = experimental_data['forward_field'].get_observable(*key)
    compact_units = reference.compact_units
    units_string = f"{Quantity(1, compact_units).units:~P}" if compact_units else "-"
    if offset is not None:
        # Prefix the units with the scale factor. lstrip('1') drops a leading '1' from the
        # formatted units (presumably pint's rendering of inverse units, e.g. '1/m³').
        units_string = f"$10^{{{np.log10(offset):n}}}${units_string.lstrip('1')}"

    # Centre the title vertically within the reserved band
    fig.suptitle(f"{make_observable_string(reference)} [{units_string}]",
                 fontsize='large', y=1 - title_height / 2 / fig_height)

    label_subplots(axs.flatten())

    if save and output_path is None:
        output_path = tcvx21.results_dir / 'summary_fig' / f"{diagnostic}+{observable}.png"

    # Hand savefig a path only when saving was requested, so that save=False is honoured even
    # when an explicit output_path is given. None is what the default save=False case already
    # passes, so it is presumably savefig's "don't write" signal.
    savefig(fig, output_path=output_path if save else None, show=show, close=close)

    return fig, axs

49 

50 

def plot_2D_observables(axs, key, experimental_data, simulation_data,
                        cbar_lim_=(None, None), experiment_sets_cbar=True, add_labels=True,
                        cbar_pad=0.075, diverging: bool = False, log_cbar: bool = False, offset=None,
                        ticks=None,
                        **kwargs):
    """
    Routine to plot and style a 2D observable

    key = (diagnostic, observable)

    Layout of axs (a 2D array of axes):
      - row 0 is the forward field and row 1 the reversed field;
      - column 0 is the experimental reference;
      - columns 1.. are the simulations, in the iteration order of simulation_data.
    All panels share a single colourbar.

    Args:
        axs: 2D array of matplotlib axes, shape (2, len(simulation_data) + 1)
        key: (diagnostic, observable), unpacked into get_observable
        experimental_data: mapping with 'forward_field' and 'reversed_field' entries
        simulation_data: mapping of case name -> mapping with the same field-direction keys
        cbar_lim_: (lower, upper) overrides for the colourbar limits; None keeps the computed value
        experiment_sets_cbar: if True, the limits are computed from the two experimental field
            directions. Otherwise they come from the forward experimental reference and every
            simulation.
        add_labels: if True, titles show the normalised Ricci distance d of each simulation to the
            experiment, and the right-most column gets a field-direction label. If False, titles
            and axis labels are cleared and the first column is labelled by field direction.
        cbar_pad: padding passed to plt.colorbar
        diverging: if True, make the colourbar limits symmetric about zero
        log_cbar: forwarded to every plot call; also selects log-spaced ticks and 10^x tick labels
        offset: if given, tick labels show value / offset, for both linear and log colourbars
        ticks: explicit colourbar ticks. The default is 5 ticks spanning the limits
            (log-spaced if log_cbar).
        **kwargs: forwarded to every observable's plot method

    Returns:
        (axs, cbar): the axes array and the shared colourbar
    """
    field_directions = ('forward_field', 'reversed_field')

    # Load the experimental data
    reference = {direction: experimental_data[direction].get_observable(*key)
                 for direction in field_directions}

    # Calculate the common colorbar
    cbar_lim = _calculate_common_cbar_lim(reference, simulation_data, key,
                                          experiment_sets_cbar, cbar_lim_, diverging)

    # Plot the experimental data. Every panel is drawn with the same cbar_lim and log_cbar, so the
    # reversed-field image is presumably a valid mappable for the shared colourbar.
    reference['forward_field'].plot(axs[0][0], cbar_lim=cbar_lim, log_cbar=log_cbar, **kwargs)
    image = reference['reversed_field'].plot(axs[1][0], cbar_lim=cbar_lim, log_cbar=log_cbar, **kwargs)

    for row, field_direction in enumerate(field_directions):
        experiment = reference[field_direction]

        # Set limits and labels for the experimental data
        axs[row][0].set_ylim(experiment.positions_zx.min(), experiment.positions_zx.max())
        if add_labels:
            title_text = axs[row][0].title.get_text()
            axs[row][0].title.set_text('')
            axs[row][0].set_title(f"{title_text}", fontsize='small', pad=2.0)

        simulations = [case[field_direction].get_observable(*key) for case in simulation_data.values()]
        for simulation, ax in zip(simulations, axs[row, 1:]):
            # Hide panels for cases that have no data for this observable
            if simulation.is_empty:
                ax.set_axis_off()
                ax.set_visible(False)
                continue

            # Plot the simulation data in the experiment's units, on the shared colour scale
            simulation.plot(ax, units=experiment.compact_units, cbar_lim=cbar_lim, log_cbar=log_cbar, **kwargs)
            ax.set_ylabel('')
            add_x_zero_line(ax)

            if add_labels:
                d = calculate_normalised_ricci_distance_from_observables(simulation=simulation,
                                                                         experiment=experiment)
                title_text = ax.title.get_text()
                ax.title.set_text('')
                ax.set_title(f"{title_text}: d={d:2.1f}", fontsize='small', pad=2.0)

    if add_labels:
        # Columns share x, so only the bottom row keeps its x label
        for ax in axs[0, :].flatten():
            ax.set_xlabel('')
        for field_direction, ax in zip(['Forward field', 'Reversed field'], axs[:, -1].flatten()):
            add_twinx_label(ax, field_direction, labelpad=15, visible=ax.get_visible())
    else:
        for ax in axs.flatten():
            ax.set_title('')
            ax.set_xlabel('')
            ax.set_ylabel('')
        for field_direction, ax in zip(['Forward', 'Reversed'], axs[:, 0].flatten()):
            ax.set_ylabel(field_direction)

    if ticks is None:
        lower, upper = cbar_lim[0].magnitude, cbar_lim[1].magnitude
        if log_cbar:
            ticks = np.logspace(np.log10(lower), np.log10(upper), num=5)
        else:
            ticks = np.linspace(lower, upper, num=5)

    cbar = plt.colorbar(image, ax=axs.flatten(), pad=cbar_pad, ticks=ticks,
                        format=plt.FuncFormatter(_make_tick_formatter(log_cbar, offset)))

    return axs, cbar


def _calculate_common_cbar_lim(reference, simulation_data, key, experiment_sets_cbar, cbar_lim_, diverging):
    """
    Return the [lower, upper] colourbar limits shared by all panels of plot_2D_observables.

    The result is converted to the forward-field reference's compact_units. It is presumably a
    pint Quantity array, given the .to()/.magnitude usage.
    """
    forward = reference['forward_field']
    if experiment_sets_cbar:
        others = (reference['reversed_field'],)
    else:
        others = ([case['forward_field'].get_observable(*key) for case in simulation_data.values()]
                  + [case['reversed_field'].get_observable(*key) for case in simulation_data.values()])
    cbar_lim = forward.calculate_cbar_limits(others)

    # User-supplied limits override the computed ones, entry by entry
    for index in (0, 1):
        if cbar_lim_[index] is not None:
            cbar_lim[index] = cbar_lim_[index]

    if diverging:
        # Symmetric about zero, using the larger magnitude of the two limits
        bound = max(np.abs(cbar_lim))
        cbar_lim[0], cbar_lim[1] = -bound, bound

    return cbar_lim.to(forward.compact_units)


def _make_tick_formatter(log_cbar: bool, offset=None):
    """
    Return a (value, tick_number) -> str colourbar tick-label function.

    Values are divided by offset (1.0 if None) in both modes, which keeps them consistent with a
    10^log10(offset) unit prefix. Log colourbars are labelled as powers of ten with one decimal
    in the exponent; linear ones with 3 significant figures.
    """
    scale = 1.0 if offset is None else offset

    if log_cbar:
        def format_func(value, tick_number):
            return f'$10^{{{np.log10(value / scale):2.1f}}}$'
    else:
        def format_func(value, tick_number):
            return f'{value / scale:.3g}'

    return format_func