Coverage for tcvx21/plotting/tile1d_m.py: 89%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

96 statements  

1""" 

2Routine for making a tiled plot 

3 

4This is helpful for comparing results with each other 

5 

6Can make tile plots either for a single diagnostic and many observables, or for a single observable at many 

7diagnostic positions 

8""" 

9import matplotlib.pyplot as plt 

10import tcvx21 

11plt.style.use(tcvx21.style_sheet) 

12 

13import numpy as np 

14from pathlib import Path 

15 

16from tcvx21.plotting.labels_m import add_twinx_label, add_x_zero_line, add_y_zero_line, \ 

17 make_observable_string, make_diagnostic_string, make_field_direction_string, label_subplots, format_yaxis 

18from tcvx21.quant_validation.ricci_metric_m import calculate_normalised_ricci_distance_from_observables 

19from .save_figure_m import savefig 

20 

21from typing import Union 

22from tcvx21.quant_validation.latex_table_writer_m import observable_latex 

23 

24 

def tile1d(experimental_data, simulation_data,
           diagnostics: Union[np.ndarray, str, tuple],
           observables: Union[np.ndarray, str, tuple],
           overplot: tuple = None,
           fig_width=7.5, fig_height_per_row=2.0, title_height=1.0, manual_title: str = '',
           legend_loc=('best',), make_title: bool = True,
           show: bool = False, save: bool = True, close: bool = True,
           output_path: Path = None):
    """
    Make a tiled plot showing either several observables for a single diagnostic, or several diagnostics (positions)
    for the same observable.

    The figure has two columns: forward field on the left, reversed field on the right. Each row corresponds to one
    (observable, diagnostic) pair. Each tile shows the experimental reference, every simulation, and (optionally)
    extra experimental diagnostics. Each simulation legend entry is annotated with its normalised Ricci distance d
    to the experiment. Tiles whose experimental reference is empty are left blank.

    Args:
        experimental_data: indexable by 'forward_field' / 'reversed_field'. Each entry provides
            get_observable(diagnostic, observable).
        simulation_data: mapping whose values are indexable by 'forward_field' / 'reversed_field' in the same
            way as experimental_data.
        diagnostics: a single diagnostic name, or a 1D sequence of names.
        observables: a single observable name, or a 1D sequence of names. At least one of diagnostics and
            observables must contain exactly one entry; it is repeated for every row.
        overplot: optional iterable of (observable, diagnostic, color) triples. Only supported when a single
            diagnostic is given. On each row whose observable matches, the experimental data at the extra
            diagnostic is drawn in the given color.
        fig_width, fig_height_per_row, title_height: figure sizing, in inches.
        manual_title: if non-empty, used instead of the automatically generated title (and filename stem).
        legend_loc: either a single legend location for every tile, or one per tile in row-major order
            (index 2 * row + column).
        make_title: whether to add a suptitle in a band of height title_height above the tiles.
        show, close: passed through to savefig.
        save: if False, savefig receives output_path=None, even if an explicit output_path was given.
        output_path: where to write the figure. If None and save is True, a path under
            tcvx21.results_dir / 'summary_fig' is generated from the title and the per-row names.

    Returns:
        (fig, axs), where axs is an (nrows, 2) array of axes.

    Raises:
        ValueError: if observables or diagnostics is more than 1D, or either is empty.
        NotImplementedError: if both observables and diagnostics have more than one entry, or if overplot is
            requested together with a single observable.
    """
    observables = np.atleast_1d(observables)
    diagnostics = np.atleast_1d(diagnostics)

    # Explicit checks rather than `assert`, so the validation still runs under `python -O`
    if observables.ndim != 1 or diagnostics.ndim != 1:
        raise ValueError(f"observables and diagnostics must be scalars or 1D, got shapes "
                         f"{observables.shape} and {diagnostics.shape}")
    if observables.size == 0 or diagnostics.size == 0:
        raise ValueError("Nothing to plot: observables and diagnostics must both be non-empty")

    if observables.size == 1:
        mode = 'observable'
        observables = np.broadcast_to(observables, diagnostics.size)
    elif diagnostics.size == 1:
        mode = 'diagnostic'
        diagnostics = np.broadcast_to(diagnostics, observables.size)
    else:
        raise NotImplementedError(f"Could not make tile plot with observables {observables.shape} "
                                  f"and diagnostics {diagnostics.shape}")

    # Reject unsupported requests before a figure exists, so a half-drawn figure is never left open
    if overplot is not None and mode == 'observable':
        raise NotImplementedError('Overplotting of observables is not supported (units must match)')

    ncols = 2
    nrows = observables.size

    fig, axs = plt.subplots(ncols=ncols, nrows=nrows,
                            figsize=(fig_width, title_height + nrows * fig_height_per_row),
                            sharex='col', sharey='row',
                            squeeze=False)

    # Always plot forward field in the left column, and reversed field in the right column
    for column, field_direction in enumerate(('forward_field', 'reversed_field')):
        for row, (observable, diagnostic) in enumerate(zip(observables, diagnostics)):
            ax = axs[row, column]

            reference = experimental_data[field_direction].get_observable(diagnostic, observable)
            if reference.is_empty:
                continue

            label = diagnostic if mode == 'diagnostic' else f"TCV {observable_latex[observable]}"
            reference.plot(ax, label=label)

            for case in simulation_data.values():
                simulation = case[field_direction].get_observable(diagnostic, observable)
                if simulation.is_empty:
                    continue

                line = simulation.plot(ax, trim_to_x=reference.xlim)
                d = calculate_normalised_ricci_distance_from_observables(simulation=simulation,
                                                                         experiment=reference)
                line.set_label(f"{line.get_label()}: d={d:3.2f}")

            if overplot is not None:
                # mode is 'diagnostic' here: the observable-mode case was rejected before creating the figure
                for overplot_observable, overplot_diagnostic, overplot_color in overplot:
                    if observable != overplot_observable:
                        continue
                    overplot_ = experimental_data[field_direction].get_observable(overplot_diagnostic, observable)
                    if overplot_.is_empty:
                        continue
                    overplot_.plot(ax, color=overplot_color, label=overplot_diagnostic)

            # Set the ylim to show all data. Both field directions are included because the tiles in a row
            # share their y-axis (sharey='row').
            cases = _collect_nonempty_observables(experimental_data, simulation_data, diagnostic, observable)
            ax.set_ylim(*reference.ylims_in_trim(cases, trim_to_x=reference.xlim))

            # Apply custom limits if they are defined
            reference.apply_plot_limits(ax)

            ax.legend(loc=legend_loc[0] if len(legend_loc) == 1 else legend_loc[2 * row + column])
            add_x_zero_line(ax)
            add_y_zero_line(ax)
            if 'kurtosis' in observable:
                # Extra horizontal line at 3.0 for kurtosis observables (presumably the Gaussian value)
                add_y_zero_line(ax, level=3.0)

            # Only the bottom row keeps x labels and only the left column keeps y labels. The right column
            # instead gets a twin-axis label naming what varies between rows.
            if row != nrows - 1:
                ax.set_xlabel('')
            if column != 0:
                ax.set_ylabel('')
                if mode == 'observable':
                    add_twinx_label(ax, make_diagnostic_string(diagnostic))
                else:
                    add_twinx_label(ax, make_observable_string(reference))

            if row == 0:
                ax.set_title(make_field_direction_string(field_direction))

    # Request a draw of this figure before formatting the y-axes
    # (presumably so tick labels exist for format_yaxis)
    fig.canvas.draw_idle()
    for left_ax in axs[:, 0]:
        format_yaxis(left_ax)

    if manual_title:
        title_string = manual_title
    elif mode == 'observable':
        reference = experimental_data['forward_field'].get_observable(diagnostics[0], observables[0])
        title_string = make_observable_string(reference)
    else:
        title_string = make_diagnostic_string(diagnostics[0])

    if make_title:
        _, fig_height = fig.get_size_inches()
        # Reserve title_height inches at the top of the figure and centre the title in that band
        fig.subplots_adjust(top=1 - title_height / fig_height)
        fig.suptitle(title_string, fontsize='large', y=1 - title_height / 2 / fig_height)

    label_subplots(axs.flatten())

    if not save:
        # Honour save=False even when an explicit path was given. The original already calls savefig with
        # output_path=None when save is False, so savefig is expected to accept None.
        output_path = None
    elif output_path is None:
        # Name the file after the title plus the names that vary between rows
        row_names = diagnostics if mode == 'observable' else observables
        filename = f"{title_string}+{','.join(row_names)}".replace(' ', '_')
        output_path = tcvx21.results_dir / 'summary_fig' / f"{filename}.png"

    savefig(fig, output_path=output_path, show=show, close=close)

    return fig, axs


def _collect_nonempty_observables(experimental_data, simulation_data, diagnostic, observable) -> list:
    """
    Return every non-empty observable for (diagnostic, observable), covering both field directions.

    The order is: experiment (forward, reversed), then each simulation in simulation_data.values() order
    (forward, reversed).
    """
    collected = []
    for source in [experimental_data, *simulation_data.values()]:
        for field_direction in ('forward_field', 'reversed_field'):
            candidate = source[field_direction].get_observable(diagnostic, observable)
            if not candidate.is_empty:
                collected.append(candidate)
    return collected