Coverage for tcvx21/observable_c/observable_2d_m.py: 93%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from tcvx21 import Quantity
2import numpy as np
3import matplotlib.pyplot as plt
4from scipy.interpolate import griddata
5from .observable_m import Observable, MissingDataError
6from matplotlib.colors import Normalize, LogNorm
class Observable2D(Observable):
    """
    Observable sampled on a two-dimensional set of points.

    Each sample point is located by its 'Rsep_omp' coordinate (stored in cm) and its 'Zx'
    coordinate (stored in m), both read from the data passed to the initialiser.
    """

    def __init__(self, data, diagnostic, observable, label, color, linestyle):
        """
        Load the sample positions from data and build the initial mask.

        All arguments are forwarded unchanged to the base-class initialiser. data must provide
        'Rsep_omp' and 'Zx' entries which can be sliced with [:] and which carry a .units
        attribute. 'R' and 'Z' entries are loaded as well when they are available.

        Raises MissingDataError if an AttributeError occurs while loading.
        """
        # Placeholder values so that every attribute is declared up-front.
        # NOTE(review): presumably the base-class initialiser overwrites these; confirm against Observable.
        self.name = ''
        self.label = ''
        self.color = ''
        self.linestyle = ''
        self.diagnostic = ''
        self.observable = ''
        self.dimensionality = -1
        self.experimental_hierarchy = -1
        self.simulation_hierarchy = -1
        self._values = []
        self._errors = []
        self.mask = []

        try:
            super().__init__(data, diagnostic, observable, label, color, linestyle)
            self._positions_rsep = Quantity(data['Rsep_omp'][:], data['Rsep_omp'].units).to('cm')
            self._positions_zx = Quantity(data['Zx'][:], data['Zx'].units).to('m')

            try:
                # If the (R, Z) coordinates for the reference equilibrium are available, also load them
                self._positions_r = Quantity(data['R'][:], data['R'].units).to('m')
                self._positions_z = Quantity(data['Z'][:], data['Z'].units).to('m')
            except IndexError:
                # The (R, Z) coordinates are optional: leave them unset when the lookup fails
                pass

            self.set_mask()

        except AttributeError as err:
            # Keep the original error attached as the explicit cause to ease debugging
            raise MissingDataError from err

    def check_dimensionality(self):
        """Raises AssertionError unless the observable is two-dimensional"""
        # Raised explicitly (rather than via an assert statement) so that the check is not
        # stripped when Python runs with optimisations enabled (python -O)
        if self.dimensionality != 2:
            raise AssertionError(f"Observable2D requires dimensionality == 2, got {self.dimensionality}")

    @property
    def positions_rsep(self) -> Quantity:
        """
        Returns the radial observable positions (using the flux-surface-label R^u - R^u_omp)
        with a mask applied if applicable
        """
        return self._positions_rsep[self.mask]

    @property
    def positions_zx(self) -> Quantity:
        """
        Returns the vertical observable positions (using the x-point displacement Z - Z_x)
        with a mask applied if applicable
        """
        return self._positions_zx[self.mask]

    def set_mask(self, rsep_min=Quantity(-np.inf, 'm'), rsep_max=Quantity(np.inf, 'm'),
                 zx_min=Quantity(-np.inf, 'm'), zx_max=Quantity(np.inf, 'm')):
        """
        Constructs an array mask for returning values in a diagnostic of interest, and removing NaN values

        A point is kept if it lies strictly inside the window (rsep_min, rsep_max) x (zx_min, zx_max)
        and is also selected by self.nan_mask(). The default limits are infinite, so that by default
        only the NaN mask restricts the points.
        """
        position_mask = np.logical_and.reduce((
            self._positions_rsep > rsep_min,
            self._positions_rsep < rsep_max,
            self._positions_zx > zx_min,
            self._positions_zx < zx_max,
        ))

        nan_mask = self.nan_mask()

        self.mask = np.logical_and.reduce((position_mask, nan_mask))

    def plot(self, ax: plt.Axes = None, plot_type: str = 'values', units: str = None, cbar_lim=None,
             log_cbar: bool = False, **kwargs):
        """
        Plots the observable

        plot_type selects what is drawn: 'values' and 'errors' draw filled contours via plot_values
        (units, cbar_lim, log_cbar and kwargs are forwarded), 'sample_points' draws a scatter of the
        sample positions. Any other value raises NotImplementedError.

        If ax is None, the current matplotlib axes are used. Returns the object returned by the
        selected plotting routine.
        """
        if ax is None:
            ax = plt.gca()

        if plot_type == 'values':
            image = self.plot_values(ax, units=units, cbar_lim=cbar_lim, log_cbar=log_cbar, **kwargs)
        elif plot_type == 'errors':
            image = self.plot_values(ax, show_errors=True, units=units, cbar_lim=cbar_lim, log_cbar=log_cbar, **kwargs)
        elif plot_type == 'sample_points':
            image = self.plot_sample_points(ax)
        else:
            raise NotImplementedError(f"No implementation for plot_type = {plot_type}")

        # The axis labels report whichever units are registered on the axes after plotting
        ax.set_xlabel(r'$R^u - R^u_{sep}$' + f' [{ax.xaxis.units}]')
        ax.set_ylabel(r'$Z - Z_X$' + f' [{ax.yaxis.units}]')

        return image

    def plot_sample_points(self, ax):
        """
        Plot the positions where the data is defined
        """
        return ax.scatter(self.positions_rsep, self.positions_zx, color=self.color, label=self.label)

    def _get_gridded_values(self, rsep_samples: int = 100, zx_samples: int = 150,
                            errors: bool = False, units: str = None) -> "tuple[Quantity, Quantity, Quantity]":
        """
        Interpolates the values onto a regular 2D mesh

        The mesh spans the range of the masked sample positions, with rsep_samples points along
        rsep and zx_samples points along zx. The errors are interpolated instead of the values if
        errors is True. The result is converted to units if given, otherwise to self.compact_units.

        Returns (rsep_basis, zx_basis, gridded_values).
        """
        rsep_basis = np.linspace(self.positions_rsep.min(), self.positions_rsep.max(), num=rsep_samples)
        zx_basis = np.linspace(self.positions_zx.min(), self.positions_zx.max(), num=zx_samples)

        rsep_mesh, zx_mesh = np.meshgrid(rsep_basis, zx_basis)

        # griddata works on bare arrays, so units are stripped here and re-attached afterwards
        gridded_values = griddata((self.positions_rsep.magnitude, self.positions_zx.magnitude),
                                  self.values.magnitude if not errors else self.errors.magnitude,
                                  (rsep_mesh.magnitude, zx_mesh.magnitude),
                                  method='linear', rescale=True
                                  )

        gridded_values = Quantity(gridded_values, self.units).to(self.compact_units if not units else units)

        return rsep_basis, zx_basis, gridded_values

    def plot_values(self, ax, show_errors: bool = False,
                    units: str = None, cbar_lim=None, log_cbar: bool = False,
                    n_contours=15, diverging: bool = False, robust: bool = True,
                    **kwargs):
        """
        Plot the values or errors as a 2D filled grid

        show_errors: plot the errors instead of the values
        units: units to convert the plotted field to (see _get_gridded_values)
        cbar_lim: if given, a quantity holding (vmin, vmax) which overrides the automatic limits.
            No colorbar is added to the axes in this case.
        log_cbar: use a logarithmic colour normalisation. Raises NotImplementedError if the lower
            limit is negative or if diverging is set.
        n_contours: number of contour levels
        diverging: make the colour limits symmetric about zero
        robust: use the 2% and 98% quantiles instead of the minimum and maximum as limits
        kwargs: forwarded to ax.contourf

        Returns the object returned by ax.contourf.
        """
        from tcvx21.plotting.labels_m import make_colorbar

        rsep_basis, zx_basis, gridded_values = self._get_gridded_values(errors=show_errors, units=units)

        if robust:
            # Use the 2% and 98% quantiles, rather than min and max, to remove outliers
            vmin, vmax = np.nanquantile(gridded_values, q=[0.02, 0.98]).magnitude
        else:
            vmin, vmax = np.nanquantile(gridded_values, q=[0.0, 1.0]).magnitude
        if cbar_lim is not None:
            vmin, vmax = cbar_lim.to(gridded_values.units).magnitude
        if diverging:
            abs_vmax = max(np.abs(vmin), np.abs(vmax))
            vmin, vmax = -abs_vmax, abs_vmax

        if not log_cbar:
            cnorm = Normalize(vmin, vmax)
            levels = np.linspace(vmin, vmax, num=n_contours, endpoint=True)
        elif vmin < 0.0 or diverging:
            raise NotImplementedError('Symmetric logarithmic plots not supported')
        else:
            cnorm = LogNorm(vmin, vmax)
            levels = np.logspace(np.log10(vmin), np.log10(vmax), num=n_contours, endpoint=True)

        image = ax.contourf(rsep_basis, zx_basis, gridded_values.magnitude,
                            norm=cnorm, levels=levels, extend='both', **kwargs)

        # Record the units on the axes so that plot() can build the axis labels from them
        ax.xaxis.units, ax.yaxis.units = str(rsep_basis.units), str(zx_basis.units)
        ax.set_title(self.label)

        if cbar_lim is None:
            cbar = make_colorbar(ax, mappable=image, units=gridded_values.units, as_title=True)
            # Only touch the colorbar where it exists: cbar is not created when cbar_lim is given
            if log_cbar:
                cbar.ax.set_yticklabels('')

        return image

    def calculate_cbar_limits(self, others: list = None, robust: bool = True) -> Quantity:
        """
        Calculates a constant colormap normalisation which captures the range of all values, including from other
        observables passed as others

        Observables in others which report is_empty are skipped. The limits are the 2% and 98%
        quantiles of the pooled values if robust is True, otherwise the minimum and maximum.
        """
        if others is None:
            others = []

        values = self.values

        for other in others:
            if other.is_empty:
                continue

            values = np.append(values, other.values)

        if robust:
            return np.quantile(values, q=[0.02, 0.98])
        else:
            return np.quantile(values, q=[0.0, 1.0])

    def interpolate_onto_reference(self, reference, plot_comparison: bool = False):
        """
        Linearly interpolates a 2D array of simulation values onto the points where reference data is
        given. Extrapolation of simulation data is not allowed. NaNs will be returned for points
        outside the convex hull of points

        It is assumed that the simulation data is at a superset of the experimental points.

        If you would like to check the interpolation, you can set the plot_comparison flag to true

        Returns a new observable of the same class, located at the (masked) reference positions,
        holding the interpolated values and errors. Its mask removes every point where the
        interpolated value or error is NaN.
        """
        ref_rsep, ref_zx = reference.positions_rsep, reference.positions_zx

        # griddata works on bare arrays: express both point sets in the same units before
        # stripping them, so that the magnitudes are directly comparable
        rs_ref, zx_ref = ref_rsep.magnitude, ref_zx.magnitude
        rs_test = self.positions_rsep.to(ref_rsep.units).magnitude
        zx_test = self.positions_zx.to(ref_zx.units).magnitude

        interpolated_value = Quantity(griddata((rs_test, zx_test), self.values.magnitude, (rs_ref, zx_ref)), self.units)
        interpolated_error = Quantity(griddata((rs_test, zx_test), self.errors.magnitude, (rs_ref, zx_ref)), self.units)

        if plot_comparison:
            plt.plot(np.arange(reference.npts), reference.values, label='reference')
            plt.fill_between(np.arange(reference.npts),
                             reference.values + reference.errors,
                             reference.values - reference.errors)
            plt.plot(np.arange(reference.npts), interpolated_value, label='test')
            plt.fill_between(np.arange(reference.npts),
                             interpolated_value + interpolated_error,
                             interpolated_value - interpolated_error)
            plt.legend()

        # Build the result without running __init__ (there is no data file to load from)
        result = object.__new__(self.__class__)
        self.fill_attributes(result)
        # Store the positions with their units attached: positions_rsep and positions_zx must
        # return quantities, since the gridding and plotting routines read .magnitude and .units
        result._positions_rsep = ref_rsep
        result._positions_zx = ref_zx
        result._values = interpolated_value
        result._errors = interpolated_error
        result.mask = ~np.logical_or(np.isnan(interpolated_value), np.isnan(interpolated_error))

        return result