Coverage for tcvx21/observable_c/observable_2d_m.py: 93%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

122 statements  

1from tcvx21 import Quantity 

2import numpy as np 

3import matplotlib.pyplot as plt 

4from scipy.interpolate import griddata 

5from .observable_m import Observable, MissingDataError 

6from matplotlib.colors import Normalize, LogNorm 

7 

8 

class Observable2D(Observable):
    """
    A 2D observable whose sample points are given in (R^u - R^u_sep, Z - Z_X) coordinates.

    Positions are stored unmasked in `_positions_rsep` (cm) and `_positions_zx` (m); the public
    `positions_*` properties, like the values, are returned with `self.mask` applied.
    """

    def __init__(self, data, diagnostic, observable, label, color, linestyle):
        """
        Loads a 2D observable from `data`.

        `data` is indexed by variable name (e.g. data['Rsep_omp']) and each variable provides `.units`.
        Raises MissingDataError if a required variable or attribute is unavailable (signalled by AttributeError).
        """
        # Default values so these attributes exist even if loading fails below
        # (presumably overwritten by Observable.__init__ — confirm against the base class)
        self.name = ''
        self.label = ''
        self.color = ''
        self.linestyle = ''
        self.diagnostic = ''
        self.observable = ''
        self.dimensionality = -1
        self.experimental_hierarchy = -1
        self.simulation_hierarchy = -1
        self._values = []
        self._errors = []
        self.mask = []

        try:
            super().__init__(data, diagnostic, observable, label, color, linestyle)
            self._positions_rsep = Quantity(data['Rsep_omp'][:], data['Rsep_omp'].units).to('cm')
            self._positions_zx = Quantity(data['Zx'][:], data['Zx'].units).to('m')

            try:
                # If the (R, Z) coordinates for the reference equilibrium are available, also load them
                self._positions_r = Quantity(data['R'][:], data['R'].units).to('m')
                self._positions_z = Quantity(data['Z'][:], data['Z'].units).to('m')
            except IndexError:
                # (R, Z) are optional; the data container is assumed to raise IndexError for a missing variable
                pass

            self.set_mask()

        except AttributeError as err:
            raise MissingDataError from err

    def check_dimensionality(self):
        """Raises AssertionError unless this observable is two-dimensional"""
        # Explicit raise rather than `assert`, so the check survives `python -O`
        if self.dimensionality != 2:
            raise AssertionError(f"Expected a 2D observable, got dimensionality = {self.dimensionality}")

    @property
    def positions_rsep(self) -> Quantity:
        """
        Returns the radial observable positions (using the flux-surface-label R^u - R^u_sep)
        with a mask applied if applicable
        """
        return self._positions_rsep[self.mask]

    @property
    def positions_zx(self) -> Quantity:
        """
        Returns the vertical observable positions (using the x-point displacement Z - Z_x)
        with a mask applied if applicable
        """
        return self._positions_zx[self.mask]

    def set_mask(self, rsep_min=Quantity(-np.inf, 'm'), rsep_max=Quantity(np.inf, 'm'),
                 zx_min=Quantity(-np.inf, 'm'), zx_max=Quantity(np.inf, 'm')):
        """
        Constructs an array mask for returning values in a diagnostic of interest, and removing NaN values

        Points are kept if they lie strictly inside (rsep_min, rsep_max) x (zx_min, zx_max) and are
        not flagged by self.nan_mask().
        """

        position_mask = np.logical_and.reduce((
            self._positions_rsep > rsep_min,
            self._positions_rsep < rsep_max,
            self._positions_zx > zx_min,
            self._positions_zx < zx_max,
        ))

        nan_mask = self.nan_mask()

        self.mask = np.logical_and.reduce((position_mask, nan_mask))

    def plot(self, ax: plt.Axes = None, plot_type: str = 'values', units: str = None, cbar_lim=None,
             log_cbar: bool = False, **kwargs):
        """
        Plots the observable

        plot_type selects 'values' (filled contours of the values), 'errors' (filled contours of the
        errors) or 'sample_points' (scatter of the measurement positions). Returns the matplotlib artist.
        Raises NotImplementedError for any other plot_type.
        """
        if ax is None:
            ax = plt.gca()

        if plot_type == 'values':
            image = self.plot_values(ax, units=units, cbar_lim=cbar_lim, log_cbar=log_cbar, **kwargs)
        elif plot_type == 'errors':
            image = self.plot_values(ax, show_errors=True, units=units, cbar_lim=cbar_lim, log_cbar=log_cbar, **kwargs)
        elif plot_type == 'sample_points':
            image = self.plot_sample_points(ax)
        else:
            raise NotImplementedError(f"No implementation for plot_type = {plot_type}")

        ax.set_xlabel(r'$R^u - R^u_{sep}$' + f' [{ax.xaxis.units}]')
        ax.set_ylabel(r'$Z - Z_X$' + f' [{ax.yaxis.units}]')

        return image

    def plot_sample_points(self, ax):
        """
        Plot the positions where the data is defined
        """
        return ax.scatter(self.positions_rsep, self.positions_zx, color=self.color, label=self.label)

    def _get_gridded_values(self, rsep_samples: int = 100, zx_samples: int = 150,
                            errors: bool = False, units: str = None) -> 'tuple[Quantity, Quantity, Quantity]':
        """
        Interpolates the values (or errors, if errors=True) onto a regular 2D mesh

        Returns (rsep_basis, zx_basis, gridded_values), where gridded_values has shape
        (zx_samples, rsep_samples) following np.meshgrid's default 'xy' indexing, and is converted
        to `units` if given, otherwise to self.compact_units. Points outside the convex hull of the
        sample positions are NaN (scipy griddata, linear method).
        """
        rsep_basis = np.linspace(self.positions_rsep.min(), self.positions_rsep.max(), num=rsep_samples)
        zx_basis = np.linspace(self.positions_zx.min(), self.positions_zx.max(), num=zx_samples)

        rsep_mesh, zx_mesh = np.meshgrid(rsep_basis, zx_basis)

        gridded_values = griddata((self.positions_rsep.magnitude, self.positions_zx.magnitude),
                                  self.values.magnitude if not errors else self.errors.magnitude,
                                  (rsep_mesh.magnitude, zx_mesh.magnitude),
                                  # rescale=True normalises the two axes, which have very different ranges
                                  method='linear', rescale=True
                                  )

        gridded_values = Quantity(gridded_values, self.units).to(self.compact_units if not units else units)

        return rsep_basis, zx_basis, gridded_values

    def plot_values(self, ax, show_errors: bool = False,
                    units: str = None, cbar_lim=None, log_cbar: bool = False,
                    n_contours=15, diverging: bool = False, robust: bool = True,
                    **kwargs):
        """
        Plot the values or errors as a 2D filled grid

        The colour range is the 2%-98% quantile range if robust, else the full range.
        It is overridden by cbar_lim (a 2-element Quantity) if given, and made symmetric about
        zero if diverging. A colorbar is only drawn when cbar_lim is None.
        Raises NotImplementedError when log_cbar is requested with non-positive limits or diverging=True.
        """
        from tcvx21.plotting.labels_m import make_colorbar

        rsep_basis, zx_basis, gridded_values = self._get_gridded_values(errors=show_errors, units=units)

        if robust:
            # Use the 2% and 98% quantiles, rather than min and max, to remove outliers
            vmin, vmax = np.nanquantile(gridded_values, q=[0.02, 0.98]).magnitude
        else:
            vmin, vmax = np.nanquantile(gridded_values, q=[0.0, 1.0]).magnitude
        if cbar_lim is not None:
            vmin, vmax = cbar_lim.to(gridded_values.units).magnitude
        if diverging:
            abs_vmax = max(np.abs(vmin), np.abs(vmax))
            vmin, vmax = -abs_vmax, abs_vmax

        if not log_cbar:
            cnorm = Normalize(vmin, vmax)
            levels = np.linspace(vmin, vmax, num=n_contours, endpoint=True)
        elif vmin <= 0.0 or diverging:
            # LogNorm and np.log10 need strictly positive limits: vmin == 0 would give log10(0) = -inf
            raise NotImplementedError(
                f'Symmetric logarithmic plots not supported (log colorbar needs vmin > 0, got vmin = {vmin})')
        else:
            cnorm = LogNorm(vmin, vmax)
            levels = np.logspace(np.log10(vmin), np.log10(vmax), num=n_contours, endpoint=True)

        image = ax.contourf(rsep_basis, zx_basis, gridded_values.magnitude,
                            norm=cnorm, levels=levels, extend='both', **kwargs)

        # Stored so that plot() can put the units in the axis labels
        ax.xaxis.units, ax.yaxis.units = str(rsep_basis.units), str(zx_basis.units)
        ax.set_title(self.label)

        if cbar_lim is None:
            cbar = make_colorbar(ax, mappable=image, units=gridded_values.units, as_title=True)
            if log_cbar:
                cbar.ax.set_yticklabels('')

        return image

    def calculate_cbar_limits(self, others: list = None, robust: bool = True) -> Quantity:
        """
        Calculates a constant colormap normalisation which captures the range of all values, including from other
        observables passed as others

        Empty observables in `others` are skipped. Returns [low, high] as the 2% and 98% quantiles
        if robust, else the min and max.
        """
        if others is None:
            others = []

        values = self.values

        for other in others:
            if other.is_empty:
                continue

            values = np.append(values, other.values)

        if robust:
            return np.quantile(values, q=[0.02, 0.98])
        else:
            return np.quantile(values, q=[0.0, 1.0])

    def interpolate_onto_reference(self, reference, plot_comparison: bool = False):
        """
        Linearly interpolates a 2D array of simulation values onto the points where reference data is
        given. Extrapolation of simulation data is not allowed. NaNs will be returned for points
        outside the convex hull of points

        It is assumed that the simulation data is at a superset of the experimental points.

        If you would like to check the interpolation, you can set the plot_comparison flag to true

        Returns a new observable of the same class, positioned at the (masked) reference points, whose
        mask excludes points where either the interpolated value or error is NaN.
        """
        ref_rsep, ref_zx = reference.positions_rsep, reference.positions_zx

        # Express both point clouds in the reference's units before comparing raw magnitudes
        rs_ref, rs_test = ref_rsep.magnitude, self.positions_rsep.to(ref_rsep.units).magnitude
        zx_ref, zx_test = ref_zx.magnitude, self.positions_zx.to(ref_zx.units).magnitude

        interpolated_value = Quantity(griddata((rs_test, zx_test), self.values.magnitude, (rs_ref, zx_ref)), self.units)
        interpolated_error = Quantity(griddata((rs_test, zx_test), self.errors.magnitude, (rs_ref, zx_ref)), self.units)

        if plot_comparison:
            plt.plot(np.arange(reference.npts), reference.values, label='reference')
            plt.fill_between(np.arange(reference.npts),
                             reference.values + reference.errors,
                             reference.values - reference.errors)
            plt.plot(np.arange(reference.npts), interpolated_value, label='test')
            plt.fill_between(np.arange(reference.npts),
                             interpolated_value + interpolated_error,
                             interpolated_value - interpolated_error)
            plt.legend()

        # Bypass __init__ (which expects a data file) and copy attributes across from self
        result = object.__new__(self.__class__)
        self.fill_attributes(result)
        # Keep the positions as Quantities so the result behaves like any other Observable2D
        # (e.g. positions_rsep.magnitude and set_mask() both need units)
        result._positions_rsep = ref_rsep
        result._positions_zx = ref_zx
        result._values = interpolated_value
        result._errors = interpolated_error
        result.mask = ~np.logical_or(np.isnan(interpolated_value), np.isnan(interpolated_error))

        return result