Coverage for tcvx21/grillix_post/work_file_writer_m.py: 90%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2Data extraction from GRILLIX is a two-step process. The first step is to make a "work file" of the
3observable signal at each tau point and toroidal plane. The second is to take statistics over this file, and
4write a NetCDF which is compatible with the validation analysis
6This module performs the first step
7"""
8import matplotlib.pyplot as plt
9from pathlib import Path
10import xarray as xr
11from netCDF4 import Dataset
12import numpy as np
14import tcvx21
15from tcvx21.units_m import Quantity, Dimensionless, convert_xarray_to_quantity
16from tcvx21.file_io.json_io_m import read_from_json
18from tcvx21.grillix_post.components import Grid, Equi, Normalisation,\
19 read_snaps_from_file, get_snap_length_from_file, read_fortran_namelist, convert_params_filepaths, integrated_sources
21from tcvx21.grillix_post.lineouts import OutboardMidplaneMap, outboard_midplane_chord, penalisation_contour,\
22 thomson_scattering, rdpa, xpoint
24from tcvx21.grillix_post.observables import floating_potential, ion_saturation_current, sound_speed, \
25 compute_parallel_gradient, initialise_lineout_for_parallel_gradient, total_parallel_heat_flux, \
26 effective_parallel_exb_velocity
28xr.set_options(keep_attrs=True)
search_paths = ['.', 'trunk', '..']


def filepath_resolver(directory: Path, search_file: str) -> Path:
    """
    Locates a file or folder by name in the locations listed in search_paths

    Each entry of search_paths is taken relative to ``directory`` and tried in order. The first
    location which directly contains an entry named ``search_file`` wins.

    Args:
        directory: base directory to search from (a plain string is also accepted)
        search_file: name of the file or folder to look for

    Returns:
        The absolute path of the first match

    Raises:
        FileNotFoundError: if none of the search locations contains ``search_file``
    """
    directory = Path(directory)

    for search_path in search_paths:
        search_directory = directory / search_path

        # is_dir() rather than exists(): a regular file which happens to share the name of a
        # search path (e.g. a file called 'trunk') cannot be listed and must simply be skipped
        if not search_directory.is_dir():
            continue

        if any(entry.name == search_file for entry in search_directory.iterdir()):
            return (search_directory / search_file).absolute()

    raise FileNotFoundError(f"{search_file} not found in {directory}")
50class WorkFileWriter:
51 """
52 Class to extract data from the raw GRILLIX snapshots.
54 Writes to a "work file"
55 """
57 def __init__(self, file_path: Path,
58 work_file: Path,
59 equi_file: str = 'TCV_ortho.nc',
60 toroidal_field_direction: str = 'reverse',
61 data_length: int = 1000,
62 n_points: int = 500,
63 make_work_file: bool = True
64 ):
65 """
66 Initialises an object to store general information about a run, as well as lineouts for the tcvx21 case, and
67 methods to calculate experimental quantities
68 """
69 assert toroidal_field_direction in ['forward', 'reverse'], \
70 f"toroidal_field_direction must be forward or reverse, but was {toroidal_field_direction}"
72 self.file_path = Path(file_path)
73 self.work_file = Path(work_file)
74 self.data_length = data_length
76 assert (self.file_path / 'description.txt').exists(), f"Need to have a description.txt file in file_path"
78 print("Setting up core analysis")
79 self.grid = Grid(filepath_resolver(file_path, 'vgrid.nc'))
80 self.norm = Normalisation.initialise_from_normalisation_file(
81 filepath_resolver(file_path, 'physical_parameters.nml'))
82 self.equi = Equi(equi_file=filepath_resolver(file_path, equi_file),
83 penalisation_file=filepath_resolver(file_path, 'pen_metainfo.nc'),
84 flip_z=False if toroidal_field_direction == 'reverse' else True)
86 parameter_filepath = filepath_resolver(file_path, 'params.in')
87 self.params = convert_params_filepaths(parameter_filepath, read_fortran_namelist(parameter_filepath))
89 self.diagnostic_to_lineout = {'LFS-LP': 'lfs', 'LFS-IR': 'lfs', 'HFS-LP': 'hfs', 'FHRP': 'omp',
90 'TS': 'ts', 'RDPA': 'rdpa', 'Xpt': 'xpt'}
92 print("Making lineouts")
94 self.omp_map = OutboardMidplaneMap(self.grid, self.equi, self.norm)
95 omp = outboard_midplane_chord(self.grid, self.equi, n_points=n_points)
96 lfs = penalisation_contour(self.grid, self.equi, level=0.0, contour_index=0, n_points=n_points)
97 hfs = penalisation_contour(self.grid, self.equi, level=0.0, contour_index=1, n_points=n_points)
99 ts = thomson_scattering(self.grid, self.equi,
100 tcvx21.thomson_coords_json, n_points=n_points)
101 rdpa_ = rdpa(self.grid, self.equi, self.omp_map, self.norm,
102 tcvx21.rdpa_coords_json)
103 xpt = xpoint(self.grid, self.equi, self.norm)
105 print("Tracing for parallel gradient")
106 initialise_lineout_for_parallel_gradient(lfs, self.grid, self.equi, self.norm,
107 npol=self.params['params_grid']['npol'],
108 stored_trace=file_path/f'lfs_trace_{n_points}.nc')
110 self.lineouts = {
111 'omp': omp,
112 'lfs': lfs,
113 'hfs': hfs,
114 'ts': ts,
115 'rdpa': rdpa_,
116 'xpt': xpt
117 }
119 if not work_file.exists() and make_work_file:
120 print("Initialising work file")
121 self.initialise_work_file()
122 print("Work file initialised")
124 print('Done')
126 def initialise_work_file(self):
127 """
128 Initialises a NetCDF file which can be iteratively filled with observables
129 """
131 dataset = Dataset(self.work_file, 'w', format="NETCDF4")
132 snaps_length = get_snap_length_from_file(file_path=self.file_path)
134 dataset.first_snap = np.max((snaps_length - self.data_length, 0))
135 print(f"Reading data from snap {dataset.first_snap} to last snap"
136 f"(currently {snaps_length}), length {self.data_length}")
137 dataset.last_snap = dataset.first_snap
139 self.write_summary(dataset.last_snap)
141 dataset.createDimension(dimname='tau', size=None)
142 dataset.createDimension(dimname='phi', size=self.params['params_grid']['npol'])
144 standard_dict = read_from_json(tcvx21.template_file)
146 # Expand the dictionary template to also include a region around the X-point
147 standard_dict['Xpt'] = {'observables': {
148 'density': {'units': '1 / meter ** 3'},
149 'electron_temp': {'units': 'electron_volt'},
150 'potential': {'units': 'volt'}}}
152 dataset.createVariable(varname='time', datatype=np.float64, dimensions=('tau',))
153 dataset['time'].units = 'milliseconds'
154 dataset.createVariable(varname='snap_index', datatype=np.float64, dimensions=('tau',))
156 for diagnostic, diagnostic_dict in standard_dict.items():
157 diagnostic_group = dataset.createGroup(diagnostic)
159 lineout_key = self.diagnostic_to_lineout[diagnostic]
160 diagnostic_group.lineout_key = lineout_key
161 lineout = self.lineouts[diagnostic_group.lineout_key]
163 diagnostic_group.createDimension(dimname='points', size=lineout.r_points.size)
164 diagnostic_group.createVariable(varname='R', datatype=np.float64, dimensions=('points',))
165 diagnostic_group.createVariable(varname='Z', datatype=np.float64, dimensions=('points',))
166 diagnostic_group['R'][:] = lineout.r_points * self.norm.R0.to('m').magnitude
167 diagnostic_group['R'].units = 'meter'
168 diagnostic_group['Z'][:] = lineout.z_points * self.norm.R0.to('m').magnitude * \
169 (-1.0 if self.equi.flipped_z else 1.0)
170 diagnostic_group['Z'].units = 'meter'
172 diagnostic_group.createVariable(varname='rho', datatype=np.float64, dimensions=('points',))
173 diagnostic_group.createVariable(varname='Rsep', datatype=np.float64, dimensions=('points',))
174 diagnostic_group['rho'][:] = self.lineout_rho(lineout_key).values
175 diagnostic_group['Rsep'][:] = self.lineout_rsep(lineout_key).to('meter').magnitude
176 diagnostic_group['Rsep'].units = 'meter'
178 if hasattr(lineout, 'coords'):
179 diagnostic_group.createVariable(varname='Zx', datatype=np.float64, dimensions=('points',))
180 diagnostic_group['Zx'][:] = lineout.coords['Zx'].to('m').magnitude
181 diagnostic_group['Zx'].units = 'meter'
183 observables_group = diagnostic_group.createGroup('observables')
185 for observable, observable_dict in diagnostic_dict['observables'].items():
187 if observable.endswith('_std') or observable.endswith('_skew') or observable.endswith(
188 '_kurtosis'):
189 # Don't calculate statistics or fits at this point
190 continue
191 observable_var = observables_group.createVariable(
192 varname=observable, datatype=np.float64, dimensions=('tau', 'phi', 'points'))
194 observable_var.units = observable_dict['units']
196 dataset.close()
198 def write_summary(self, snap):
200 snaps = read_snaps_from_file(self.file_path, self.norm, time_slice=[snap], all_planes=True).persist()
201 sources = integrated_sources(self.grid, self.equi, self.norm, self.params, snaps)
203 with Dataset(self.work_file, 'a') as dataset:
204 dataset.description = (self.file_path / 'description.txt').read_text()
205 dataset.toroidal_planes = snaps.sizes['phi']
206 dataset.points_per_plane = snaps.sizes['points']
207 dataset.n_points_R = self.grid.r_s.size
208 dataset.n_points_Z = self.grid.z_s.size
210 timestep = (self.params['params_tstep']['dtau_max'] * self.norm.tau_0).to_compact()
211 dataset.timestep = timestep.magnitude
212 dataset.timestep_units = str(f"{timestep.units:P}")
214 dataset.particle_source = sources['density_source'].magnitude
215 dataset.particle_source_units = str(f"{sources['density_source'].units:P}")
216 dataset.power_source = sources['power_source'].magnitude
217 dataset.power_source_units = str(f"{sources['power_source'].units:P}")
219 snaps.close()
221 def fill_work_file(self):
222 """
223 Iterates over the snaps dataset, to fill the complete dataset. Closes after each time segment, to prevent
224 data loss. This isn't super efficient, but the data filling only has to be run once
226 Summary elements like the description is updated each tau the script is executed
227 """
229 with Dataset(self.work_file, 'a') as dataset:
230 snaps_length = get_snap_length_from_file(file_path=self.file_path)
231 first_snap = dataset.first_snap
232 current_snap = dataset.last_snap
233 last_snap = snaps_length
235 print(f"First snap: {first_snap}, current_snap: {current_snap}, last_snap: {last_snap}")
237 for snap in range(current_snap, last_snap):
238 t = snap - first_snap
240 snaps = read_snaps_from_file(self.file_path,
241 self.norm,
242 time_slice=[snap],
243 all_planes=True).persist()
245 with Dataset(self.work_file, 'a') as dataset:
246 # Open and close the dataset every write, to prevent data-loss
247 # from timing out
249 print(f"Processing snap {snap} of {last_snap} (t={t}/{last_snap - first_snap})")
251 dataset['time'][t] = self.time(snaps=snaps).magnitude
252 dataset['snap_index'][t] = dataset.first_snap + t
254 for diagnostic_group in dataset.groups.values():
255 lineout_key = diagnostic_group.lineout_key
256 for observable_key, observable_var in diagnostic_group['observables'].variables.items():
257 observable = getattr(self, observable_key)
258 values = observable(lineout_key, snaps=snaps).compute()
259 values = values.transpose("tau", "phi", "interp_points")
260 values = convert_xarray_to_quantity(values).to(observable_var.units).magnitude
262 observable_var[t, :, :] = values
264 dataset.last_snap = snap + 1
266 snaps.close()
268 print('Done')
270 def time(self, snaps: xr.Dataset):
271 """Returns an array of tau values (in ms)"""
272 return convert_xarray_to_quantity(snaps.tau).to('ms')
274 def lineout_rho(self, lineout: str) -> xr.DataArray:
275 """Returns the flux-surface labels of a lineout"""
276 lineout_ = self.lineouts[lineout]
277 return self.equi.normalised_flux_surface_label(lineout_.r_points, lineout_.z_points, grid=False)
279 def lineout_rsep(self, lineout: str) -> Quantity:
280 """Returns the OMP-mapped distance along a lineout"""
281 return convert_xarray_to_quantity(self.omp_map.convert_rho_to_distance(self.lineout_rho(lineout)))
283 def density(self, lineout: str, snaps: xr.Dataset) -> xr.DataArray:
284 """Returns the plasma density values along a lineout"""
285 return self.lineouts[lineout].interpolate(snaps.density)
287 def electron_temp(self, lineout: str, snaps: xr.Dataset) -> xr.DataArray:
288 """Returns the electron temperature values along a lineout"""
289 return self.lineouts[lineout].interpolate(snaps.electron_temp)
291 def ion_temp(self, lineout: str, snaps: xr.Dataset) -> xr.DataArray:
292 """Returns the ion temperature values along a lineout"""
293 return self.lineouts[lineout].interpolate(snaps.ion_temp)
295 def potential(self, lineout: str, snaps: xr.Dataset) -> xr.DataArray:
296 """Returns the electrostatic potential values along a lineout"""
297 return self.lineouts[lineout].interpolate(snaps.potential)
299 def velocity(self, lineout: str, snaps: xr.Dataset) -> xr.DataArray:
300 """Returns the ion velocity values along a lineout"""
301 return self.lineouts[lineout].interpolate(snaps.velocity)
303 def penalisation_direction_function(self, lineout: str) -> xr.DataArray:
304 """Returns the penalisation direction function, which gives which field direction is 'towards' the target"""
305 return self.lineouts[lineout].interpolate(self.equi.penalisation_direction)
307 def current(self, lineout: str, snaps: xr.Dataset) -> xr.DataArray:
308 """Returns the plasma current"""
309 return self.lineouts[lineout].interpolate(snaps.current) * self.penalisation_direction_function(lineout)
311 def sound_speed(self, lineout: str, snaps: xr.Dataset) -> xr.DataArray:
312 """Calculates and returns the sound speed"""
313 return sound_speed(self.electron_temp(lineout, snaps), self.ion_temp(lineout, snaps), self.norm)
315 def jsat(self, lineout: str, snaps: xr.Dataset) -> xr.DataArray:
316 """
317 Calculates and returns the ion saturation current, with a factor of 0.5 depending on whether the
318 lineout is immersed (for omp lineout, snaps) or a wall probe (for other hfs/lfs)
319 """
320 if lineout == 'omp':
321 return ion_saturation_current(self.density(lineout, snaps), self.sound_speed(lineout, snaps), self.norm,
322 wall_probe=False)
323 elif lineout in ['hfs', 'lfs', 'rdpa']:
324 return ion_saturation_current(self.density(lineout, snaps), self.sound_speed(lineout, snaps), self.norm,
325 wall_probe=True)
326 else:
327 raise NotImplementedError(f"Lineout {lineout} not recognised")
329 def vfloat(self, lineout: str, snaps: xr.Dataset) -> xr.DataArray:
330 """Calculates and returns the floating potential"""
331 return floating_potential(self.potential(lineout, snaps), self.electron_temp(lineout, snaps), self.norm)
333 def mach_number(self, lineout: str, snaps: xr.Dataset) -> xr.DataArray:
334 """Calculates and returns the mach_number"""
335 mach = self.velocity(lineout, snaps) / self.sound_speed(lineout, snaps)
336 return mach.assign_attrs({'norm': Dimensionless})
338 def electron_temp_parallel_gradient(self, lineout: str, snaps: xr.Dataset):
339 """
340 Parallel gradient of electron temperature (defined only for a subset of the lineouts, since tracing
341 is performed in advance)
342 """
343 return compute_parallel_gradient(self.lineouts[lineout], snaps.electron_temp)
345 def ion_temp_parallel_gradient(self, lineout: str, snaps: xr.Dataset):
346 """
347 Parallel gradient of ion temperature (defined only for a subset of the lineouts, since tracing
348 is performed in advance)
349 """
350 return compute_parallel_gradient(self.lineouts[lineout], snaps.ion_temp)
352 def effective_parallel_exb(self, lineout:str, snaps: xr.Dataset):
353 """
354 The effective parallel ExB velocity (i.e. the parallel velocity that would cause as much poloidal
355 transport as the poloidal ExB velocity)
356 """
357 effective_parallel_exb = effective_parallel_exb_velocity(self.grid, self.equi, self.norm, snaps.potential)
358 return self.lineouts[lineout].interpolate(effective_parallel_exb)
360 def q_parallel(self, lineout: str, snaps: xr.Dataset) -> xr.DataArray:
361 """Calculates and returns the parallel heat flux"""
362 density = self.density(lineout, snaps)
363 electron_temp = self.electron_temp(lineout, snaps)
364 electron_temp_parallel_gradient = self.electron_temp_parallel_gradient(lineout, snaps)
365 ion_temp = self.ion_temp(lineout, snaps)
366 ion_temp_parallel_gradient = self.ion_temp_parallel_gradient(lineout, snaps)
367 ion_velocity = self.velocity(lineout, snaps)
368 current = self.current(lineout, snaps)
369 effective_parallel_exb = self.effective_parallel_exb(lineout, snaps)
371 q_par = total_parallel_heat_flux(density, electron_temp, electron_temp_parallel_gradient,
372 ion_temp, ion_temp_parallel_gradient,
373 ion_velocity, current, effective_parallel_exb, self.norm)
375 return q_par
377 def plot_lineouts(self):
378 """Plots the magnetic geometry and the lineouts (as a sanity check)"""
380 divertor_ = read_from_json(tcvx21.divertor_coords_json)
382 _, ax = plt.subplots(figsize=(10, 10))
383 plt.contour(self.grid.r_s, self.grid.z_s, self.equi.normalised_flux_surface_label(self.grid.r_s, self.grid.z_s))
385 plt.scatter(self.lineouts['xpt'].r_points, self.lineouts['xpt'].z_points,
386 s=1, marker='.', color='r', label='xpt')
387 plt.scatter(self.lineouts['rdpa'].r_points, self.lineouts['rdpa'].z_points,
388 s=1, marker='+', color='k', label='rdpa')
390 for key, lineout in self.lineouts.items():
392 if key in ['rdpa', 'xpt']:
393 continue
395 plt.plot(lineout.r_points, lineout.z_points, label=key, linewidth=2.5)
397 if hasattr(lineout, 'forward_lineout'):
398 plt.plot(lineout.forward_lineout.r_points, lineout.forward_lineout.z_points,
399 label=f"{key}+", linewidth=2.5)
400 if hasattr(lineout, 'reverse_lineout'):
401 plt.plot(lineout.reverse_lineout.r_points, lineout.reverse_lineout.z_points,
402 label=f"{key}-", linewidth=2.5)
404 plt.legend()
405 if self.equi.flipped_z:
406 ax.invert_yaxis()
407 ax.set_aspect('equal')
409 plt.plot(divertor_['r_points'] / self.equi.axis_r.values,
410 divertor_['z_points'] / self.equi.axis_r.values * -1.0 if self.equi.flipped_z else 1.0,
411 color='k')