Coverage for tcvx21/observable_c/observable_1d_m.py: 91%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

116 statements  

1""" 

2A class for 1D observables like the density profile along a lineout 

3""" 

4 

5from tcvx21 import Quantity 

6from tcvx21.units_m import pint 

7import numpy as np 

8import matplotlib.pyplot as plt 

9from scipy.interpolate import interp1d 

10from .observable_m import Observable, MissingDataError 

11 

12 

class Observable1D(Observable):
    """
    A 1D observable, such as a profile along a lineout.

    Points are located by the flux-surface label R^u - R^u_sep, read from the 'Rsep_omp'
    variable of the data group and stored in cm. A boolean mask (`self.mask`) selects which
    points are returned by `positions`; presumably the base-class `values`/`errors`
    properties apply the same mask (TODO confirm in observable_m).
    """

    def __init__(self, data, diagnostic, observable, label, color, linestyle):
        """
        Initialise attributes, let the base class read `data`, then build the position mask.

        `data` must provide an 'Rsep_omp' variable with a `units` attribute; the positions
        are converted to cm.
        """
        # Placeholder values; presumably filled in by Observable.__init__ (TODO confirm)
        self.name = ''
        self.label = ''
        self.color = ''
        self.linestyle = ''
        self.diagnostic = ''
        self.observable = ''
        self.dimensionality = -1
        self.experimental_hierarchy = -1
        self.simulation_hierarchy = -1
        self._values = []
        self._errors = []
        self.mask = []

        super().__init__(data, diagnostic, observable, label, color, linestyle)
        self._positions_rsep = Quantity(data['Rsep_omp'][:], data['Rsep_omp'].units).to('cm')
        self.set_mask()
        # Optional plot limits; None means "leave matplotlib's default" (see apply_plot_limits)
        self.xmin, self.xmax, self.ymin, self.ymax = None, None, None, None

    def check_dimensionality(self):
        """Assert that this observable is one-dimensional."""
        assert self.dimensionality == 1

    @property
    def positions(self) -> Quantity:
        """
        Returns the observable positions (using the flux-surface-label R^u - R^u_omp),
        with the mask applied.
        """
        return self._positions_rsep[self.mask]

    def _position_mask(self, position_min=Quantity(-np.inf, 'm'), position_max=Quantity(np.inf, 'm')):
        """Boolean array, True where position_min < position < position_max (strict, unmasked)."""
        return np.logical_and(
            self._positions_rsep > position_min,
            self._positions_rsep < position_max,
        )

    def set_mask(self, position_min=Quantity(-np.inf, 'm'), position_max=Quantity(np.inf, 'm')):
        """Constructs an array mask for returning values in a diagnostic of interest, and removing NaN values"""
        # nan_mask() is provided by the base class
        self.mask = np.logical_and.reduce((self._position_mask(position_min, position_max), self.nan_mask()))

    @property
    def xlim(self):
        """
        Returns a tuple (xmin, xmax) giving the x bounds.

        A limit set via set_plot_limits is used (in the units of `positions`) when present;
        otherwise the min/max of the masked positions is used.
        """
        return (Quantity(self.xmin, self.positions.units) if self.xmin is not None else np.min(self.positions),
                Quantity(self.xmax, self.positions.units) if self.xmax is not None else np.max(self.positions))

    def set_plot_limits(self, xmin=None, xmax=None, ymin=None, ymax=None):
        """Sets x and/or y plot limits; arguments left as None do not change the stored limit."""
        if xmin is not None:
            self.xmin = xmin
        if xmax is not None:
            self.xmax = xmax
        if ymin is not None:
            self.ymin = ymin
        if ymax is not None:
            self.ymax = ymax

    def apply_plot_limits(self, ax):
        """Sets the plot limits on `ax` for each stored limit that is non-None."""
        if self.xmin is not None:
            ax.set_xlim(left=self.xmin)

        if self.xmax is not None:
            ax.set_xlim(right=self.xmax)

        if self.ymin is not None:
            ax.set_ylim(bottom=self.ymin)

        if self.ymax is not None:
            ax.set_ylim(top=self.ymax)

    def trim_mask(self, trim_to_x: tuple, trim_expansion=0.1):
        """
        Return the current mask further restricted to a position window.

        The window is (trim_to_x[0], trim_to_x[1]), widened on each side by
        `trim_expansion` times its width. Entries of `trim_to_x` that are not Quantities
        are interpreted as cm. `self.mask` itself is not modified.
        """
        trim_to_x = list(trim_to_x)
        assert len(trim_to_x) == 2

        if not isinstance(trim_to_x[0], Quantity):
            trim_to_x[0] = Quantity(trim_to_x[0], 'cm')
        if not isinstance(trim_to_x[1], Quantity):
            trim_to_x[1] = Quantity(trim_to_x[1], 'cm')

        width = trim_to_x[1] - trim_to_x[0]
        trim_range = trim_to_x[0] - trim_expansion * width, trim_to_x[1] + trim_expansion * width
        return np.logical_and(self.mask, self._position_mask(*trim_range)).astype(bool)

    def values_in_trim(self, trim_to_x: tuple, trim_expansion=0.1):
        """Return the raw values selected by trim_mask(trim_to_x, trim_expansion)."""
        mask = self.trim_mask(trim_to_x, trim_expansion)
        return self._values[mask]

    def ylim_in_trim(self, trim_to_x: tuple, trim_expansion=0.1):
        """Return (min, max) of the values inside the trimmed position window."""
        values = self.values_in_trim(trim_to_x, trim_expansion)
        return values.min(), values.max()

    def ylims_in_trim(self, others, trim_to_x: tuple, trim_expansion=0.1, range_expansion=0.1):
        """
        Return y-limits covering this observable and every observable in `others`.

        Each observable's range is taken inside the trimmed window (see ylim_in_trim). The
        combined range is padded on each side by `range_expansion` times its span. Results are
        stored in `self.units` (defined in the base class).
        """
        ymin, ymax = Quantity(np.zeros(len(others)+1), self.units), Quantity(np.zeros(len(others)+1), self.units)

        ymin[0], ymax[0] = self.ylim_in_trim(trim_to_x, trim_expansion)

        for i, case in enumerate(others):
            ymin[i+1], ymax[i+1] = case.ylim_in_trim(trim_to_x, trim_expansion)

        ymin, ymax = ymin.min(), ymax.max()

        return ymin - range_expansion * (ymax-ymin), ymax + range_expansion * (ymax-ymin)

    def plot(self, ax: plt.Axes = None, plot_type: str = 'region', trim_to_x: tuple = tuple(),
             trim_expansion=0.1, **kwargs):
        """
        Plots the observable

        Can set "plot_type" to give either an errorbar plot ('errorbar') or a shaded region
        around a mean line ('region'); any other value raises NotImplementedError.

        If trim_to_x is passed, it should be a tuple of length 2, giving the min and max positions for the
        values to plot. The trimming only applies to this call: the mask is restored afterwards,
        even if plotting raises.

        Returns the line artist for the mean line.
        """
        if ax is None:
            ax = plt.gca()

        original_mask = np.copy(self.mask)

        if trim_to_x:
            self.mask = self.trim_mask(trim_to_x, trim_expansion)

        # Restore the mask in `finally` so that a failed plot cannot leave the observable trimmed
        try:
            if plot_type == 'errorbar':
                line = self.plot_errorbar(ax, **kwargs)
            elif plot_type == 'region':
                line = self.plot_region(ax, **kwargs)
            else:
                raise NotImplementedError(f"No implementation for plot_type = {plot_type}")

            self.apply_plot_limits(ax)

            ax.set_xlabel(r'$R^u - R^u_{sep}$' + f' [{ax.xaxis.units}]')
        finally:
            self.mask = original_mask

        return line

    def plot_errorbar(self, ax: plt.Axes, color=None, label=None, linestyle=None,
                      errorevery=5, **kwargs):
        """
        Makes an errorbar plot of the observable (not recommended)

        Values and errors are converted to `compact_units`; color/label/linestyle default to
        the observable's own attributes. Returns the data-line artist.
        """
        return ax.errorbar(self.positions,
                           self.values.to(self.compact_units),
                           self.errors.to(self.compact_units),
                           color=self.color if color is None else color,
                           label=self.label if label is None else label,
                           linestyle=self.linestyle if linestyle is None else linestyle,
                           errorevery=errorevery,
                           **kwargs)[0]

    def plot_region(self, ax: plt.Axes, color=None, label=None, linestyle=None, **kwargs):
        """
        Makes a plot of the observable, with the error represented by a shaded region around the line

        Both the line and the shaded band are drawn in `compact_units`, so that they agree with
        each other (and with plot_errorbar). Returns the line artist.
        """
        color = self.color if color is None else color
        positions = self.positions
        values = self.values.to(self.compact_units)
        errors = self.errors.to(self.compact_units)

        line, = ax.plot(positions,
                        values,
                        color=color,
                        label=self.label if label is None else label,
                        linestyle=self.linestyle if linestyle is None else linestyle,
                        **kwargs)

        ax.fill_between(positions,
                        values + errors,
                        values - errors,
                        alpha=0.25, color=color)
        return line

    @staticmethod
    def _interpolate_points(x_source, y_source, x_query):
        """
        Cubic interpolation of y_source(x_source) at x_query, for pint Quantities.

        x_source is converted to the units of x_query; the result carries the units of
        y_source. Because bounds_error=True, scipy raises ValueError if any x_query lies
        outside the range of x_source. Cubic interpolation needs at least 4 source points.
        """
        interpolator = interp1d(x=x_source.to(x_query.units).magnitude,
                                y=y_source.magnitude,
                                kind='cubic', bounds_error=True)

        interpolated = interpolator(np.array(x_query.magnitude))
        return Quantity(interpolated, y_source.units)

    def points_overlap(self, reference_positions):
        """
        Returns a boolean array which can be used to mask reference values so that interpolate_onto_positions
        only interpolates (no extrapolation). True where a reference position lies strictly inside
        the range of this observable's (masked) positions.
        """
        return np.logical_and(
            reference_positions > self.positions.min(),
            reference_positions < self.positions.max())

    def interpolate_onto_positions(self, reference_positions):
        """
        Interpolates a 1D array of simulation values (and errors) onto the points where reference
        data is given.

        Extrapolation of simulation data is not allowed. This method does not crop anything
        itself: if any reference position lies outside the range of this observable's positions,
        scipy raises ValueError. Crop the reference data first, e.g. with `points_overlap`.

        Returns a new instance of this class located at `reference_positions`, with all points
        unmasked. Other attributes are copied via `fill_attributes` (defined in the base class).
        """
        source_positions = self.positions

        interpolated_value = self._interpolate_points(source_positions,
                                                      self.values,
                                                      reference_positions)

        interpolated_error = self._interpolate_points(source_positions,
                                                      self.errors,
                                                      reference_positions)

        # One output entry per reference position. (np.allclose's third positional argument
        # is rtol, so it cannot express a three-way equality.)
        assert reference_positions.size == interpolated_value.size == interpolated_error.size

        # Bypass __init__, which requires a data group
        result = object.__new__(self.__class__)
        self.fill_attributes(result)

        result._positions_rsep = reference_positions
        result._values = interpolated_value
        result._errors = interpolated_error
        result.mask = np.ones(reference_positions.shape, dtype=bool)

        return result