Coverage for tcvx21/observable_c/observable_1d_m.py: 91%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2A class for 1D observables like the density profile along a lineout
3"""
5from tcvx21 import Quantity
6from tcvx21.units_m import pint
7import numpy as np
8import matplotlib.pyplot as plt
9from scipy.interpolate import interp1d
10from .observable_m import Observable, MissingDataError
class Observable1D(Observable):
    """
    A 1D observable: values and errors sampled at positions along a lineout.

    Positions are read from ``data['Rsep_omp']`` and stored in cm. A boolean mask
    (a position window combined with the NaN mask) selects which points are
    exposed through ``positions``.

    NOTE(review): ``values``, ``errors``, ``units``, ``compact_units``, ``nan_mask``
    and ``fill_attributes`` are not defined in this class; they are presumably
    provided by ``Observable`` -- confirm against the base class.
    """

    def __init__(self, data, diagnostic, observable, label, color, linestyle):
        """
        Initialises the observable via the base class, then reads the positions

        data must support ``data['Rsep_omp'][:]`` and ``data['Rsep_omp'].units``.
        All arguments are forwarded unchanged to ``Observable.__init__``.
        """
        # Placeholder attributes, declared before the base-class initialiser runs
        self.name = ''
        self.label = ''
        self.color = ''
        self.linestyle = ''
        self.diagnostic = ''
        self.observable = ''
        self.dimensionality = -1
        self.experimental_hierarchy = -1
        self.simulation_hierarchy = -1
        self._values = []
        self._errors = []
        self.mask = []

        super().__init__(data, diagnostic, observable, label, color, linestyle)

        rsep = data['Rsep_omp']
        self._positions_rsep = Quantity(rsep[:], rsep.units).to('cm')
        self.set_mask()

        # Plot limits are unset (None) until set_plot_limits is called
        self.xmin, self.xmax, self.ymin, self.ymax = None, None, None, None

    def check_dimensionality(self):
        """Asserts that this observable was declared as one-dimensional"""
        assert self.dimensionality == 1

    @property
    def positions(self) -> Quantity:
        """
        Returns the observable positions (using the flux-surface-label R^u - R^u_omp),
        with the mask applied if applicable
        """
        return self._positions_rsep[self.mask]

    def _position_mask(self, position_min=Quantity(-np.inf, 'm'), position_max=Quantity(np.inf, 'm')):
        """
        Returns a boolean array which is True where the (unmasked) positions lie strictly
        between position_min and position_max
        """
        above_min = self._positions_rsep > position_min
        below_max = self._positions_rsep < position_max
        return np.logical_and(above_min, below_max)

    def set_mask(self, position_min=Quantity(-np.inf, 'm'), position_max=Quantity(np.inf, 'm')):
        """Constructs an array mask for returning values in a diagnostic of interest, and removing NaN values"""
        in_window = self._position_mask(position_min, position_max)
        self.mask = np.logical_and.reduce((in_window, self.nan_mask()))

    @property
    def xlim(self):
        """
        Returns a tuple giving the bounds of the positions (after the mask is applied)

        An explicitly set xmin/xmax takes precedence over the data bound, and is
        interpreted in the units of the positions.
        """
        if self.xmin is not None:
            lower = Quantity(self.xmin, self.positions.units)
        else:
            lower = np.min(self.positions)

        if self.xmax is not None:
            upper = Quantity(self.xmax, self.positions.units)
        else:
            upper = np.max(self.positions)

        return lower, upper

    def set_plot_limits(self, xmin=None, xmax=None, ymin=None, ymax=None):
        """
        Sets the position (x) and value (y) limits used for plotting

        Arguments left as None do not change the currently stored limit.
        """
        if xmin is not None:
            self.xmin = xmin
        if xmax is not None:
            self.xmax = xmax
        if ymin is not None:
            self.ymin = ymin
        if ymax is not None:
            self.ymax = ymax

    def apply_plot_limits(self, ax):
        """Sets the plot limits if they are non-None"""
        if self.xmin is not None:
            ax.set_xlim(left=self.xmin)

        if self.xmax is not None:
            ax.set_xlim(right=self.xmax)

        if self.ymin is not None:
            ax.set_ylim(bottom=self.ymin)

        if self.ymax is not None:
            ax.set_ylim(top=self.ymax)

    def trim_mask(self, trim_to_x: tuple, trim_expansion=0.1):
        """
        Returns (without storing) the current mask restricted to a position window

        trim_to_x gives the (min, max) of the window. Entries which are not Quantities
        are interpreted as cm. The window is widened on both sides by trim_expansion
        times its width before the mask is computed.
        """
        trim_to_x = list(trim_to_x)
        assert len(trim_to_x) == 2

        lower, upper = (
            bound if isinstance(bound, Quantity) else Quantity(bound, 'cm')
            for bound in trim_to_x
        )

        margin = trim_expansion * (upper - lower)
        window = self._position_mask(lower - margin, upper + margin)

        return np.logical_and(self.mask, window).astype(bool)

    def values_in_trim(self, trim_to_x: tuple, trim_expansion=0.1):
        """Returns the stored values (self._values) selected by trim_mask"""
        mask = self.trim_mask(trim_to_x, trim_expansion)
        return self._values[mask]

    def ylim_in_trim(self, trim_to_x: tuple, trim_expansion=0.1):
        """Returns the (min, max) of the values inside the trimmed position window"""
        values = self.values_in_trim(trim_to_x, trim_expansion)
        return values.min(), values.max()

    def ylims_in_trim(self, others, trim_to_x: tuple, trim_expansion=0.1, range_expansion=0.1):
        """
        Returns value limits which enclose this observable and all observables in others
        inside the trimmed position window

        The returned range is widened on both sides by range_expansion times the
        combined (max - min) range.
        """
        n_cases = len(others) + 1
        ymin = Quantity(np.zeros(n_cases), self.units)
        ymax = Quantity(np.zeros(n_cases), self.units)

        # Slot 0 holds this observable, the remaining slots hold the others in order
        for i, case in enumerate((self, *others)):
            ymin[i], ymax[i] = case.ylim_in_trim(trim_to_x, trim_expansion)

        ymin, ymax = ymin.min(), ymax.max()
        padding = range_expansion * (ymax - ymin)

        return ymin - padding, ymax + padding

    def plot(self, ax: plt.Axes = None, plot_type: str = 'region', trim_to_x: tuple = tuple(),
             trim_expansion=0.1, **kwargs):
        """
        Plots the observable

        Can set "plot_type" to give either an errorbar plot or a shaded region around a mean line

        If trim_to_x is passed, it should be a tuple of length 2, giving the min and max positions for the
        values to plot.

        Raises NotImplementedError if plot_type is neither 'errorbar' nor 'region'.
        The mask is only modified temporarily: it is restored before returning, also
        if plotting raises.
        """
        if ax is None:
            ax = plt.gca()

        original_mask = np.copy(self.mask)

        try:
            if trim_to_x:
                self.mask = self.trim_mask(trim_to_x, trim_expansion)

            if plot_type == 'errorbar':
                line = self.plot_errorbar(ax, **kwargs)
            elif plot_type == 'region':
                line = self.plot_region(ax, **kwargs)
            else:
                raise NotImplementedError(f"No implementation for plot_type = {plot_type}")

            self.apply_plot_limits(ax)

            ax.set_xlabel(r'$R^u - R^u_{sep}$' + f' [{ax.xaxis.units}]')
        finally:
            # Always undo the temporary trim, so that a failed plot cannot leave
            # the observable with a narrowed mask
            self.mask = original_mask

        return line

    def plot_errorbar(self, ax: plt.Axes, color=None, label=None, linestyle=None,
                      errorevery=5, **kwargs):
        """
        Makes an errorbar plot of the observable (not recommended)

        color, label and linestyle default to the values stored on the observable.
        Returns the first element of the container returned by ax.errorbar.
        """
        return ax.errorbar(self.positions,
                           self.values.to(self.compact_units),
                           self.errors.to(self.compact_units),
                           color=self.color if color is None else color,
                           label=self.label if label is None else label,
                           linestyle=self.linestyle if linestyle is None else linestyle,
                           errorevery=errorevery,
                           **kwargs)[0]

    def plot_region(self, ax: plt.Axes, color=None, label=None, linestyle=None, **kwargs):
        """
        Makes a plot of the observable, with the error represented by a shaded region around the line

        color, label and linestyle default to the values stored on the observable.
        Returns the line drawn by ax.plot.
        """
        plot_color = self.color if color is None else color

        line, = ax.plot(self.positions,
                        self.values.to(self.compact_units),
                        color=plot_color,
                        label=self.label if label is None else label,
                        linestyle=self.linestyle if linestyle is None else linestyle,
                        **kwargs)

        # The band is converted to the same units as the line, so that both are
        # drawn on the same scale
        upper = (self.values + self.errors).to(self.compact_units)
        lower = (self.values - self.errors).to(self.compact_units)

        ax.fill_between(self.positions, upper, lower, alpha=0.25, color=plot_color)
        return line

    @staticmethod
    def _interpolate_points(x_source, y_source, x_query):
        """
        Calls interp1d for quantities

        x_source is converted to the units of x_query before interpolating, and the
        result carries the units of y_source. Cubic interpolation is used with
        bounds_error=True, so query points outside of the range of x_source are
        rejected by interp1d rather than extrapolated.
        """
        interpolator = interp1d(x=x_source.to(x_query.units).magnitude,
                                y=y_source.magnitude,
                                kind='cubic', bounds_error=True)

        interpolated = interpolator(np.array(x_query.magnitude))
        return Quantity(interpolated, y_source.units)

    def points_overlap(self, reference_positions):
        """
        Returns a boolean array which can be used to mask reference values so that interpolate_onto_positions
        only interpolates (no extrapolation)

        The comparison is strict: reference positions equal to the first or last
        (masked) position are excluded.
        """
        return np.logical_and(
            reference_positions > self.positions.min(),
            reference_positions < self.positions.max())

    def interpolate_onto_positions(self, reference_positions):
        """
        Interpolates the values and errors of this observable onto the points where
        reference data is given, and returns the result as a new observable.

        Extrapolation is not allowed, and no cropping is performed here: the
        reference positions must already lie within the range of the positions of
        this observable (see points_overlap), since _interpolate_points uses
        bounds_error=True.
        """
        interpolated_value = self._interpolate_points(self.positions,
                                                      self.values,
                                                      reference_positions)

        interpolated_error = self._interpolate_points(self.positions,
                                                      self.errors,
                                                      reference_positions)

        # All three sizes must agree. (np.allclose is not suitable for this check:
        # its third positional argument is the relative tolerance.)
        assert reference_positions.size == interpolated_value.size == interpolated_error.size

        # Build the result without calling __init__, then copy the attributes over
        result = object.__new__(self.__class__)
        self.fill_attributes(result)

        result._positions_rsep = reference_positions
        result._values = interpolated_value
        result._errors = interpolated_error
        result.mask = np.ones_like(reference_positions).astype(bool)

        return result