Coverage for tcvx21/plotting/plot_comparison_2d_m.py: 90%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

84 statements  

1""" 

2Plotting routines for 2D observables 

3""" 

4import matplotlib.pyplot as plt 

5import tcvx21 

6plt.style.use(tcvx21.style_sheet) 

7 

8from tcvx21.quant_validation.ricci_metric_m import calculate_normalised_ricci_distance_from_observables 

9from .labels_m import add_twinx_label, make_labels, add_x_zero_line, make_colorbar 

10import numpy as np 

11 

def plot_2D_comparison(field_direction, diagnostic, observable,
                       experimental_data, simulation_data,
                       fig_width=7.5, fig_height_per_row=1.5, title_height=1.0,
                       common_colorbar: bool = True, experiment_sets_colorbar: bool = False,
                       diverging: bool = False,
                       **kwargs):
    """
    Plots a 2D (area data) observable, with different values represented via a constant color

    The experimental observable is drawn in the left-most panel, followed by one panel per
    entry of simulation_data. Each simulation panel title is annotated with the normalised
    Ricci distance d between that simulation and the experiment.

    common_colorbar will use the same colorbar for all figures, which is good for a quantitative side-by-side
    comparison, but which is sensitive to extreme values

    Args:
        field_direction: key used to index experimental_data and each simulation case
        diagnostic: passed as the 'diagnostic' keyword to get_observable
        observable: passed as the 'observable' keyword to get_observable
        experimental_data: indexable by field_direction; the result must provide get_observable
        simulation_data: dict-like of simulation cases; each value is indexable by field_direction
        fig_width: figure width (matplotlib figsize units)
        fig_height_per_row: height of the row of panels
        title_height: extra figure height reserved above the panels for titles
        common_colorbar: if True, share colour limits between all panels and draw one colorbar
        experiment_sets_colorbar: if True (and common_colorbar), the shared limits come from
            reference.calculate_cbar_limits() alone; otherwise the simulations are passed to it too
        diverging: if True, shared limits are made symmetric about zero; also forwarded to plot()
        **kwargs: forwarded to each observable's plot() method

    Returns:
        (fig, axs): the figure and a 1D array of axes, experiment first
    """
    fig, axs = plt.subplots(ncols=len(simulation_data) + 1,
                            figsize=(fig_width, title_height + fig_height_per_row),
                            sharex=True, sharey=True)
    # With no simulations, ncols == 1 and plt.subplots returns a single Axes rather than
    # an array; normalise so that flatten() and indexing below always work.
    axs = np.atleast_1d(axs)

    for ax in axs.flatten():
        add_x_zero_line(ax)

    key = {'diagnostic': diagnostic, 'observable': observable}

    reference = experimental_data[field_direction].get_observable(**key)
    simulations = [case[field_direction].get_observable(**key) for case in simulation_data.values()]

    if common_colorbar:
        if experiment_sets_colorbar:
            cbar_lim = reference.calculate_cbar_limits().to(reference.compact_units)
        else:
            cbar_lim = reference.calculate_cbar_limits(simulations).to(reference.compact_units)

        if diverging:
            # Make the limits symmetric about zero
            bound = max(np.abs(cbar_lim))
            cbar_lim[0], cbar_lim[1] = -bound, bound
    else:
        cbar_lim = None

    image = reference.plot(axs[0], cbar_lim=cbar_lim, diverging=diverging, **kwargs)
    # Replace whatever title plot() set with a small, left-aligned version of the same text
    title_text = axs[0].title.get_text()
    axs[0].title.set_text('')
    axs[0].set_title(f"{title_text}", loc='left', fontsize='small')
    axs[0].set_ylim(reference.positions_zx.min(), reference.positions_zx.max())

    for simulation, ax in zip(simulations, axs[1:]):
        if simulation.is_empty:
            continue
        # Plot in the experiment's units so that the shared colour limits are comparable
        image = simulation.plot(ax, units=reference.compact_units, cbar_lim=cbar_lim, diverging=diverging, **kwargs)
        # The y-axis is shared, so only the left-most panel keeps its label
        ax.set_ylabel('')

        d = calculate_normalised_ricci_distance_from_observables(simulation=simulation, experiment=reference)

        title_text = ax.title.get_text()
        ax.title.set_text('')
        ax.set_title(f"{title_text}: d={d:2.1f}", loc='left', fontsize='small')

    # Reserve title_height (in figure-size units) at the top of the figure for the titles
    _, fig_height = fig.get_size_inches()
    plt.subplots_adjust(top=1 - title_height / fig_height, wspace=0.25)

    field_direction_string, _, observable_string = make_labels(field_direction, diagnostic, reference)

    if common_colorbar:
        plt.suptitle(f"{observable_string}: {field_direction_string}",
                     fontsize='large', y=1 - title_height / 2 / fig_height)
        cbar = make_colorbar(ax=axs.flatten(), mappable=image, units=reference.compact_units, as_title=True)
        ticks = np.round(np.linspace(cbar_lim[0].magnitude, cbar_lim[1].magnitude, num=5), decimals=1)
        tick_labels = [f'{x:4.3}' for x in ticks]
        cbar.set_ticks(ticks)
        cbar.set_ticklabels(tick_labels)
    else:
        plt.suptitle(observable_string, fontsize='large', y=1 - title_height / 2 / fig_height)
        add_twinx_label(axs[-1], field_direction_string)

    return fig, axs

84 

def plot_2D_comparison_simulation_only(field_direction, diagnostic, observable,
                                       simulation_data,
                                       fig_width=7.5, fig_height_per_row=1.5, title_height=1.0,
                                       common_colorbar: bool = True, experiment_sets_colorbar: bool = False,
                                       diverging: bool = False,
                                       **kwargs):
    """
    Plots a 2D (area data) observable, with different values represented via a constant color

    One panel is drawn per entry of simulation_data; no experimental data is shown. The first
    simulation acts as the reference for units, y-limits and labels.

    common_colorbar will use the same colorbar for all figures, which is good for a quantitative side-by-side
    comparison, but which is sensitive to extreme values

    Args:
        field_direction: key used to index each simulation case
        diagnostic: passed as the 'diagnostic' keyword to get_observable
        observable: passed as the 'observable' keyword to get_observable
        simulation_data: dict-like of simulation cases; each value is indexable by field_direction
        fig_width: figure width (matplotlib figsize units)
        fig_height_per_row: height of the row of panels
        title_height: extra figure height reserved above the panels for titles
        common_colorbar: if True, share colour limits between all panels and draw one colorbar
        experiment_sets_colorbar: accepted for signature compatibility with plot_2D_comparison;
            there is no experiment here, so it is ignored (a warning is issued) and the shared
            limits are computed from the simulations as usual
        diverging: if True, shared limits are made symmetric about zero; also forwarded to plot()
        **kwargs: forwarded to each observable's plot() method

    Returns:
        (fig, axs): the figure and a 1D array of axes, one per simulation
    """
    import warnings

    fig, axs = plt.subplots(ncols=len(simulation_data),
                            figsize=(fig_width, title_height + fig_height_per_row),
                            sharex=True, sharey=True)
    # With a single simulation plt.subplots returns a bare Axes; normalise to an array
    axs = np.atleast_1d(axs)

    for ax in axs.flatten():
        add_x_zero_line(ax)

    key = {'diagnostic': diagnostic, 'observable': observable}

    reference = list(simulation_data.values())[0][field_direction].get_observable(**key)
    simulations = [case[field_direction].get_observable(**key) for case in simulation_data.values()]

    if common_colorbar:
        if experiment_sets_colorbar:
            warnings.warn("experiment_sets_colorbar ignored in plot_2D_comparison_simulation_only")
        # Always assign cbar_lim here: previously the experiment_sets_colorbar branch left it
        # unbound, causing an UnboundLocalError below.
        cbar_lim = reference.calculate_cbar_limits(simulations).to(reference.compact_units)

        if diverging:
            # Make the limits symmetric about zero
            bound = max(np.abs(cbar_lim))
            cbar_lim[0], cbar_lim[1] = -bound, bound
    else:
        cbar_lim = None

    axs[0].set_ylim(reference.positions_zx.min(), reference.positions_zx.max())

    # Stays None if every simulation is empty, in which case no colorbar can be drawn
    image = None
    for simulation, ax in zip(simulations, axs):
        if simulation.is_empty:
            continue
        # Plot in the reference's units so that the shared colour limits are comparable
        image = simulation.plot(ax, units=reference.compact_units, cbar_lim=cbar_lim, diverging=diverging, **kwargs)
        ax.set_ylabel('')

        # Replace whatever title plot() set with a small, left-aligned version of the same text
        title_text = ax.title.get_text()
        ax.title.set_text('')
        ax.set_title(f"{title_text}", loc='left', fontsize='small')

    # Reserve title_height (in figure-size units) at the top of the figure for the titles
    _, fig_height = fig.get_size_inches()
    plt.subplots_adjust(top=1 - title_height / fig_height, wspace=0.25)

    field_direction_string, _, observable_string = make_labels(field_direction, diagnostic, reference)

    if common_colorbar:
        plt.suptitle(f"{observable_string}: {field_direction_string}",
                     fontsize='large', y=1 - title_height / 2 / fig_height)
        if image is not None:
            cbar = make_colorbar(ax=axs.flatten(), mappable=image, units=reference.compact_units, as_title=True)
            ticks = np.round(np.linspace(cbar_lim[0].magnitude, cbar_lim[1].magnitude, num=5), decimals=1)
            tick_labels = [f'{x:4.3}' for x in ticks]
            cbar.set_ticks(ticks)
            cbar.set_ticklabels(tick_labels)
    else:
        plt.suptitle(observable_string, fontsize='large', y=1 - title_height / 2 / fig_height)
        add_twinx_label(axs[-1], field_direction_string)

    return fig, axs