Coverage for tcvx21/grillix_post/work_file_writer_m.py: 90%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

207 statements  

"""
Data extraction from GRILLIX is a two-step process. The first step is to make a "work file" of the
observable signal at each tau point and toroidal plane. The second is to take statistics over this file, and
write a NetCDF file which is compatible with the validation analysis.

This module performs the first step.
"""

8import matplotlib.pyplot as plt 

9from pathlib import Path 

10import xarray as xr 

11from netCDF4 import Dataset 

12import numpy as np 

13 

14import tcvx21 

15from tcvx21.units_m import Quantity, Dimensionless, convert_xarray_to_quantity 

16from tcvx21.file_io.json_io_m import read_from_json 

17 

18from tcvx21.grillix_post.components import Grid, Equi, Normalisation,\ 

19 read_snaps_from_file, get_snap_length_from_file, read_fortran_namelist, convert_params_filepaths, integrated_sources 

20 

21from tcvx21.grillix_post.lineouts import OutboardMidplaneMap, outboard_midplane_chord, penalisation_contour,\ 

22 thomson_scattering, rdpa, xpoint 

23 

24from tcvx21.grillix_post.observables import floating_potential, ion_saturation_current, sound_speed, \ 

25 compute_parallel_gradient, initialise_lineout_for_parallel_gradient, total_parallel_heat_flux, \ 

26 effective_parallel_exb_velocity 

27 

# Keep DataArray attributes when xarray operations produce new arrays (set globally for this process,
# since the observables below attach attributes such as 'norm' to their results)
xr.set_options(keep_attrs=True)

# Relative locations, tried in this order by filepath_resolver, where run input files may live
search_paths = ['.', 'trunk', '..']

30 

31 

def filepath_resolver(directory: Path, search_file: str, relative_paths=None):
    """
    Checks relative paths specified in search_paths for a specified file or folder

    :param directory: base directory (Path or str) against which the relative search paths are resolved
    :param search_file: exact name of the file or folder to look for
    :param relative_paths: optional sequence of relative paths to try, in order. Defaults to the
        module-level search_paths
    :return: absolute path of the first match
    :raises FileNotFoundError: if search_file is not found in any of the searched directories
    """
    # Accept plain strings as well as Path objects; the path operators below require a Path
    directory = Path(directory)
    relative_paths = list(search_paths if relative_paths is None else relative_paths)

    for search_path in relative_paths:
        search_directory = directory / search_path

        # is_dir() rather than exists(): a plain file sharing a search path's name cannot be iterated
        if not search_directory.is_dir():
            continue

        # Compare entry names exactly, so matching stays case-sensitive even on case-insensitive file systems
        if any(entry.name == search_file for entry in search_directory.iterdir()):
            found_file = search_directory / search_file
            # Skip dangling entries (e.g. broken symlinks) and keep searching instead of asserting
            if found_file.exists():
                return found_file.absolute()

    raise FileNotFoundError(f"{search_file} not found in {directory}")

48 

49 

class WorkFileWriter:
    """
    Class to extract data from the raw GRILLIX snapshots.

    Writes to a "work file": a NetCDF file that initialise_work_file lays out and fill_work_file fills one
    snap at a time.
    """

    def __init__(self, file_path: Path,
                 work_file: Path,
                 equi_file: str = 'TCV_ortho.nc',
                 toroidal_field_direction: str = 'reverse',
                 data_length: int = 1000,
                 n_points: int = 500,
                 make_work_file: bool = True
                 ):
        """
        Initialises an object to store general information about a run, as well as lineouts for the tcvx21 case, and
        methods to calculate experimental quantities

        :param file_path: run directory; must contain description.txt. The other inputs (vgrid.nc,
            physical_parameters.nml, the equilibrium file, pen_metainfo.nc and params.in) are located with
            filepath_resolver relative to this directory
        :param work_file: path of the NetCDF work file
        :param equi_file: name of the equilibrium file
        :param toroidal_field_direction: 'forward' or 'reverse'; anything other than 'reverse' passes flip_z=True
            to Equi
        :param data_length: number of snaps, counted back from the latest snap available when the work file is
            initialised, at which extraction starts
        :param n_points: number of points for the omp, lfs, hfs and ts lineouts
        :param make_work_file: if True and work_file does not exist yet, initialise it
        """
        assert toroidal_field_direction in ['forward', 'reverse'], \
            f"toroidal_field_direction must be forward or reverse, but was {toroidal_field_direction}"

        # Normalise to Path once and use these attributes everywhere below: the arguments may be plain strings,
        # which do not support the / operator or .exists()
        self.file_path = Path(file_path)
        self.work_file = Path(work_file)
        self.data_length = data_length

        assert (self.file_path / 'description.txt').exists(), \
            f"Need to have a description.txt file in file_path ({self.file_path})"

        print("Setting up core analysis")
        self.grid = Grid(filepath_resolver(self.file_path, 'vgrid.nc'))
        self.norm = Normalisation.initialise_from_normalisation_file(
            filepath_resolver(self.file_path, 'physical_parameters.nml'))
        self.equi = Equi(equi_file=filepath_resolver(self.file_path, equi_file),
                         penalisation_file=filepath_resolver(self.file_path, 'pen_metainfo.nc'),
                         flip_z=toroidal_field_direction != 'reverse')

        parameter_filepath = filepath_resolver(self.file_path, 'params.in')
        self.params = convert_params_filepaths(parameter_filepath, read_fortran_namelist(parameter_filepath))

        # Diagnostic group name (as used in the template file) -> key of the lineout it is sampled on
        self.diagnostic_to_lineout = {'LFS-LP': 'lfs', 'LFS-IR': 'lfs', 'HFS-LP': 'hfs', 'FHRP': 'omp',
                                      'TS': 'ts', 'RDPA': 'rdpa', 'Xpt': 'xpt'}

        print("Making lineouts")

        self.omp_map = OutboardMidplaneMap(self.grid, self.equi, self.norm)
        omp = outboard_midplane_chord(self.grid, self.equi, n_points=n_points)
        lfs = penalisation_contour(self.grid, self.equi, level=0.0, contour_index=0, n_points=n_points)
        hfs = penalisation_contour(self.grid, self.equi, level=0.0, contour_index=1, n_points=n_points)

        ts = thomson_scattering(self.grid, self.equi,
                                tcvx21.thomson_coords_json, n_points=n_points)
        rdpa_ = rdpa(self.grid, self.equi, self.omp_map, self.norm,
                     tcvx21.rdpa_coords_json)
        xpt = xpoint(self.grid, self.equi, self.norm)

        print("Tracing for parallel gradient")
        # Only the lfs lineout is prepared here, so the *_parallel_gradient observables (and q_parallel) can only
        # be evaluated on it. The trace file path is stored in the run directory (presumably reused if present)
        initialise_lineout_for_parallel_gradient(lfs, self.grid, self.equi, self.norm,
                                                 npol=self.params['params_grid']['npol'],
                                                 stored_trace=self.file_path / f'lfs_trace_{n_points}.nc')

        self.lineouts = {
            'omp': omp,
            'lfs': lfs,
            'hfs': hfs,
            'ts': ts,
            'rdpa': rdpa_,
            'xpt': xpt
        }

        if make_work_file and not self.work_file.exists():
            print("Initialising work file")
            self.initialise_work_file()
            print("Work file initialised")

        print('Done')

    def initialise_work_file(self):
        """
        Initialises a NetCDF file which can be iteratively filled with observables

        Creates an unlimited 'tau' dimension, a 'phi' dimension sized by params_grid npol, the 'time' and
        'snap_index' variables, and one group per diagnostic of the template file (plus an added 'Xpt' group).
        Each group stores the lineout geometry (R, Z, rho, Rsep and, when available, Zx) and an empty
        (tau, phi, points) variable for every template observable except the _std/_skew/_kurtosis statistics.

        The global attributes first_snap and last_snap are both set to the starting snap; fill_work_file
        advances last_snap. The run summary is then written with write_summary.
        """
        snaps_length = get_snap_length_from_file(file_path=self.file_path)
        first_snap = np.max((snaps_length - self.data_length, 0))
        print(f"Reading data from snap {first_snap} to last snap "
              f"(currently {snaps_length}), length {self.data_length}")

        standard_dict = read_from_json(tcvx21.template_file)

        # Expand the dictionary template to also include a region around the X-point
        standard_dict['Xpt'] = {'observables': {
            'density': {'units': '1 / meter ** 3'},
            'electron_temp': {'units': 'electron_volt'},
            'potential': {'units': 'volt'}}}

        # Lineout coordinates are multiplied by R0 to convert them to metres; Z also takes the sign of the flip
        r0_in_metres = self.norm.R0.to('m').magnitude
        z_sign = -1.0 if self.equi.flipped_z else 1.0

        # The context manager guarantees the file is closed even if building the layout fails
        with Dataset(self.work_file, 'w', format="NETCDF4") as dataset:
            dataset.first_snap = first_snap
            dataset.last_snap = first_snap

            dataset.createDimension(dimname='tau', size=None)
            dataset.createDimension(dimname='phi', size=self.params['params_grid']['npol'])

            dataset.createVariable(varname='time', datatype=np.float64, dimensions=('tau',))
            dataset['time'].units = 'milliseconds'
            dataset.createVariable(varname='snap_index', datatype=np.float64, dimensions=('tau',))

            for diagnostic, diagnostic_dict in standard_dict.items():
                diagnostic_group = dataset.createGroup(diagnostic)

                lineout_key = self.diagnostic_to_lineout[diagnostic]
                # Stored on the group so that fill_work_file can find which lineout to sample
                diagnostic_group.lineout_key = lineout_key
                lineout = self.lineouts[lineout_key]

                diagnostic_group.createDimension(dimname='points', size=lineout.r_points.size)
                diagnostic_group.createVariable(varname='R', datatype=np.float64, dimensions=('points',))
                diagnostic_group.createVariable(varname='Z', datatype=np.float64, dimensions=('points',))
                diagnostic_group['R'][:] = lineout.r_points * r0_in_metres
                diagnostic_group['R'].units = 'meter'
                diagnostic_group['Z'][:] = lineout.z_points * r0_in_metres * z_sign
                diagnostic_group['Z'].units = 'meter'

                diagnostic_group.createVariable(varname='rho', datatype=np.float64, dimensions=('points',))
                diagnostic_group.createVariable(varname='Rsep', datatype=np.float64, dimensions=('points',))
                diagnostic_group['rho'][:] = self.lineout_rho(lineout_key).values
                diagnostic_group['Rsep'][:] = self.lineout_rsep(lineout_key).to('meter').magnitude
                diagnostic_group['Rsep'].units = 'meter'

                if hasattr(lineout, 'coords'):
                    diagnostic_group.createVariable(varname='Zx', datatype=np.float64, dimensions=('points',))
                    diagnostic_group['Zx'][:] = lineout.coords['Zx'].to('m').magnitude
                    diagnostic_group['Zx'].units = 'meter'

                observables_group = diagnostic_group.createGroup('observables')

                for observable, observable_dict in diagnostic_dict['observables'].items():
                    if observable.endswith(('_std', '_skew', '_kurtosis')):
                        # Don't calculate statistics or fits at this point
                        continue
                    observable_var = observables_group.createVariable(
                        varname=observable, datatype=np.float64, dimensions=('tau', 'phi', 'points'))

                    observable_var.units = observable_dict['units']

        # Called only after the dataset above is closed: write_summary re-opens the file itself, and the file
        # should not be open through two handles at once
        self.write_summary(first_snap)

    def write_summary(self, snap):
        """
        Writes run-level summary attributes into the existing work file.

        :param snap: index of the snap that is read to obtain the grid sizes and the integrated sources

        Attributes written: description (text of description.txt), toroidal_planes, points_per_plane,
        n_points_R, n_points_Z, timestep, particle_source and power_source (the last three with *_units).
        """
        snaps = read_snaps_from_file(self.file_path, self.norm, time_slice=[snap], all_planes=True).persist()
        try:
            sources = integrated_sources(self.grid, self.equi, self.norm, self.params, snaps)

            with Dataset(self.work_file, 'a') as dataset:
                dataset.description = (self.file_path / 'description.txt').read_text()
                dataset.toroidal_planes = snaps.sizes['phi']
                dataset.points_per_plane = snaps.sizes['points']
                dataset.n_points_R = self.grid.r_s.size
                dataset.n_points_Z = self.grid.z_s.size

                timestep = (self.params['params_tstep']['dtau_max'] * self.norm.tau_0).to_compact()
                dataset.timestep = timestep.magnitude
                dataset.timestep_units = f"{timestep.units:P}"

                dataset.particle_source = sources['density_source'].magnitude
                dataset.particle_source_units = f"{sources['density_source'].units:P}"
                dataset.power_source = sources['power_source'].magnitude
                dataset.power_source_units = f"{sources['power_source'].units:P}"
        finally:
            # Release the snapshot data even if computing or writing the summary fails
            snaps.close()

    def fill_work_file(self):
        """
        Iterates over the snaps dataset, to fill the complete dataset. Closes after each time segment, to prevent
        data loss. This isn't super efficient, but the data filling only has to be run once

        Processing resumes at the work file's last_snap attribute and runs up to the latest snap currently on disk,
        so re-running extends an existing work file. This method does not update the summary attributes; those are
        written by write_summary (called from initialise_work_file).
        """
        snaps_length = get_snap_length_from_file(file_path=self.file_path)
        # Only attributes are read here, so the file is opened read-only
        with Dataset(self.work_file, 'r') as dataset:
            first_snap = dataset.first_snap
            current_snap = dataset.last_snap
        last_snap = snaps_length

        print(f"First snap: {first_snap}, current_snap: {current_snap}, last_snap: {last_snap}")

        for snap in range(current_snap, last_snap):
            # Index of this snap along the work file's tau dimension
            t = snap - first_snap

            snaps = read_snaps_from_file(self.file_path,
                                         self.norm,
                                         time_slice=[snap],
                                         all_planes=True).persist()
            try:
                # Open and close the dataset every write, to prevent data-loss
                # from timing out
                with Dataset(self.work_file, 'a') as dataset:
                    print(f"Processing snap {snap} of {last_snap} (t={t}/{last_snap - first_snap})")

                    dataset['time'][t] = self.time(snaps=snaps).magnitude
                    dataset['snap_index'][t] = dataset.first_snap + t

                    for diagnostic_group in dataset.groups.values():
                        lineout_key = diagnostic_group.lineout_key
                        for observable_key, observable_var in diagnostic_group['observables'].variables.items():
                            # Each observable variable name must match a method of this class (e.g. 'density')
                            observable = getattr(self, observable_key)
                            values = observable(lineout_key, snaps=snaps).compute()
                            values = values.transpose("tau", "phi", "interp_points")
                            values = convert_xarray_to_quantity(values).to(observable_var.units).magnitude

                            observable_var[t, :, :] = values

                    # Advanced only after every observable of this snap is written, so an interrupted run
                    # restarts from this snap
                    dataset.last_snap = snap + 1
            finally:
                snaps.close()

        print('Done')

    def time(self, snaps: xr.Dataset):
        """Returns the tau values of snaps, converted to milliseconds"""
        return convert_xarray_to_quantity(snaps.tau).to('ms')

    def lineout_rho(self, lineout: str) -> xr.DataArray:
        """Returns the flux-surface labels of a lineout"""
        lineout_ = self.lineouts[lineout]
        return self.equi.normalised_flux_surface_label(lineout_.r_points, lineout_.z_points, grid=False)

    def lineout_rsep(self, lineout: str) -> Quantity:
        """Returns the OMP-mapped distance along a lineout"""
        return convert_xarray_to_quantity(self.omp_map.convert_rho_to_distance(self.lineout_rho(lineout)))

    def density(self, lineout: str, snaps: xr.Dataset) -> xr.DataArray:
        """Returns the plasma density values along a lineout"""
        return self.lineouts[lineout].interpolate(snaps.density)

    def electron_temp(self, lineout: str, snaps: xr.Dataset) -> xr.DataArray:
        """Returns the electron temperature values along a lineout"""
        return self.lineouts[lineout].interpolate(snaps.electron_temp)

    def ion_temp(self, lineout: str, snaps: xr.Dataset) -> xr.DataArray:
        """Returns the ion temperature values along a lineout"""
        return self.lineouts[lineout].interpolate(snaps.ion_temp)

    def potential(self, lineout: str, snaps: xr.Dataset) -> xr.DataArray:
        """Returns the electrostatic potential values along a lineout"""
        return self.lineouts[lineout].interpolate(snaps.potential)

    def velocity(self, lineout: str, snaps: xr.Dataset) -> xr.DataArray:
        """Returns the ion velocity values along a lineout"""
        return self.lineouts[lineout].interpolate(snaps.velocity)

    def penalisation_direction_function(self, lineout: str) -> xr.DataArray:
        """Returns the penalisation direction function, which gives which field direction is 'towards' the target"""
        return self.lineouts[lineout].interpolate(self.equi.penalisation_direction)

    def current(self, lineout: str, snaps: xr.Dataset) -> xr.DataArray:
        """Returns the plasma current along a lineout, multiplied by the penalisation direction function"""
        return self.lineouts[lineout].interpolate(snaps.current) * self.penalisation_direction_function(lineout)

    def sound_speed(self, lineout: str, snaps: xr.Dataset) -> xr.DataArray:
        """Calculates and returns the sound speed"""
        return sound_speed(self.electron_temp(lineout, snaps), self.ion_temp(lineout, snaps), self.norm)

    def jsat(self, lineout: str, snaps: xr.Dataset) -> xr.DataArray:
        """
        Calculates and returns the ion saturation current, with a factor of 0.5 depending on whether the
        lineout is immersed (the omp lineout, wall_probe=False) or a wall probe (the hfs, lfs and rdpa lineouts,
        wall_probe=True)

        :raises NotImplementedError: for any other lineout
        """
        if lineout == 'omp':
            wall_probe = False
        elif lineout in ['hfs', 'lfs', 'rdpa']:
            wall_probe = True
        else:
            raise NotImplementedError(f"Lineout {lineout} not recognised")

        return ion_saturation_current(self.density(lineout, snaps), self.sound_speed(lineout, snaps), self.norm,
                                      wall_probe=wall_probe)

    def vfloat(self, lineout: str, snaps: xr.Dataset) -> xr.DataArray:
        """Calculates and returns the floating potential"""
        return floating_potential(self.potential(lineout, snaps), self.electron_temp(lineout, snaps), self.norm)

    def mach_number(self, lineout: str, snaps: xr.Dataset) -> xr.DataArray:
        """Calculates and returns the mach_number (ion velocity over sound speed, tagged as dimensionless)"""
        mach = self.velocity(lineout, snaps) / self.sound_speed(lineout, snaps)
        return mach.assign_attrs({'norm': Dimensionless})

    def electron_temp_parallel_gradient(self, lineout: str, snaps: xr.Dataset):
        """
        Parallel gradient of electron temperature (defined only for a subset of the lineouts, since tracing
        is performed in advance)
        """
        return compute_parallel_gradient(self.lineouts[lineout], snaps.electron_temp)

    def ion_temp_parallel_gradient(self, lineout: str, snaps: xr.Dataset):
        """
        Parallel gradient of ion temperature (defined only for a subset of the lineouts, since tracing
        is performed in advance)
        """
        return compute_parallel_gradient(self.lineouts[lineout], snaps.ion_temp)

    def effective_parallel_exb(self, lineout: str, snaps: xr.Dataset):
        """
        The effective parallel ExB velocity (i.e. the parallel velocity that would cause as much poloidal
        transport as the poloidal ExB velocity)
        """
        effective_parallel_exb = effective_parallel_exb_velocity(self.grid, self.equi, self.norm, snaps.potential)
        return self.lineouts[lineout].interpolate(effective_parallel_exb)

    def q_parallel(self, lineout: str, snaps: xr.Dataset) -> xr.DataArray:
        """Calculates and returns the parallel heat flux"""
        density = self.density(lineout, snaps)
        electron_temp = self.electron_temp(lineout, snaps)
        electron_temp_parallel_gradient = self.electron_temp_parallel_gradient(lineout, snaps)
        ion_temp = self.ion_temp(lineout, snaps)
        ion_temp_parallel_gradient = self.ion_temp_parallel_gradient(lineout, snaps)
        ion_velocity = self.velocity(lineout, snaps)
        current = self.current(lineout, snaps)
        effective_parallel_exb = self.effective_parallel_exb(lineout, snaps)

        q_par = total_parallel_heat_flux(density, electron_temp, electron_temp_parallel_gradient,
                                         ion_temp, ion_temp_parallel_gradient,
                                         ion_velocity, current, effective_parallel_exb, self.norm)

        return q_par

    def plot_lineouts(self):
        """Plots the magnetic geometry and the lineouts (as a sanity check)"""

        divertor_ = read_from_json(tcvx21.divertor_coords_json)

        _, ax = plt.subplots(figsize=(10, 10))
        plt.contour(self.grid.r_s, self.grid.z_s, self.equi.normalised_flux_surface_label(self.grid.r_s, self.grid.z_s))

        plt.scatter(self.lineouts['xpt'].r_points, self.lineouts['xpt'].z_points,
                    s=1, marker='.', color='r', label='xpt')
        plt.scatter(self.lineouts['rdpa'].r_points, self.lineouts['rdpa'].z_points,
                    s=1, marker='+', color='k', label='rdpa')

        for key, lineout in self.lineouts.items():

            if key in ['rdpa', 'xpt']:
                # Already drawn as scatter plots above
                continue

            plt.plot(lineout.r_points, lineout.z_points, label=key, linewidth=2.5)

            if hasattr(lineout, 'forward_lineout'):
                plt.plot(lineout.forward_lineout.r_points, lineout.forward_lineout.z_points,
                         label=f"{key}+", linewidth=2.5)
            if hasattr(lineout, 'reverse_lineout'):
                plt.plot(lineout.reverse_lineout.r_points, lineout.reverse_lineout.z_points,
                         label=f"{key}-", linewidth=2.5)

        plt.legend()
        if self.equi.flipped_z:
            ax.invert_yaxis()
        ax.set_aspect('equal')

        # The sign is a separate factor: written inline as `z * -1.0 if flipped else 1.0`, the conditional
        # expression would bind to the whole argument and plot the constant 1.0 when not flipped
        z_sign = -1.0 if self.equi.flipped_z else 1.0
        plt.plot(divertor_['r_points'] / self.equi.axis_r.values,
                 divertor_['z_points'] / self.equi.axis_r.values * z_sign,
                 color='k')