Coverage for tcvx21/grillix_post/validation_writer_m.py: 97%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2Data extraction from GRILLIX is a two-step process. The first step is to make a "work file" of the
3observable signal at each time point and toroidal plane. The second is to take statistics over this file, and
4write a NetCDF which is compatible with the validation analysis
6This module performs the second step
7"""
8from pathlib import Path
9import xarray as xr
10from netCDF4 import Dataset
12import tcvx21
13from tcvx21.record_c.record_writer_m import RecordWriter
14from tcvx21.units_m import Quantity
15from tcvx21.file_io.json_io_m import read_from_json
17from tcvx21.analysis.statistics_m import compute_statistical_moment_with_bootstrap, compute_statistical_moment, \
18 strip_moment
def convert_work_file_to_validation_netcdf(work_file: Path, output_file: Path, simulation_hierarchy: dict,
                                           statistics_interval: int = 500):
    """
    Converts a work file from the WorkFileWriter into a StandardNetCDF which can be used
    with the validation analysis

    The default statistics_interval is 500 time points, which corresponds to about 1ms

    Arguments:
        work_file: path to an existing work file with the '.nc' suffix
        output_file: path handed to the RecordWriter (an existing file may be overwritten)
        simulation_hierarchy: dict which is forwarded to fill_standard_dict
        statistics_interval: number of trailing time points which are used for the statistics
    """
    assert work_file.exists() and work_file.suffix == '.nc'
    assert isinstance(simulation_hierarchy, dict), f"simulation_hierarchy should be dict, but was {type(simulation_hierarchy)}"

    standard_dictionary = fill_standard_dict(work_file, statistics_interval, simulation_hierarchy)

    # Global attributes which are copied unchanged from the work file into the output file
    copied_attributes = ['toroidal_planes', 'points_per_plane', 'n_points_R', 'n_points_Z', 'timestep',
                         'timestep_units', 'particle_source', 'particle_source_units', 'power_source',
                         'power_source_units']

    # Read everything which is needed from the work file inside a context manager, so that the
    # file handle is released even if one of the attributes is missing
    with Dataset(work_file) as dataset:
        time_variable = dataset['time']
        time = Quantity(time_variable[slice(-statistics_interval, None)], time_variable.units)
        description = dataset.description
        additional_attributes = {attribute: getattr(dataset, attribute) for attribute in copied_attributes}

    # Time spanned by the points which enter the statistics
    statistics_time = time.max() - time.min()

    additional_attributes['statistics_interval'] = statistics_time.magnitude
    additional_attributes['statistics_interval_units'] = f"{statistics_time.units:P}"

    writer = RecordWriter(file_path=output_file,
                          descriptor='GRX',
                          description=description,
                          allow_overwrite=True)

    writer.write_data_dict(standard_dictionary, additional_attributes)
def Q(netcdf_array):
    """Builds a Quantity from the 'values' and the 'units' attributes of the given array"""
    magnitude, units = netcdf_array.values, netcdf_array.units
    return Quantity(magnitude, units)
def fill_standard_dict(work_file: Path, statistics_interval: int, simulation_hierarchy: dict) -> dict:
    """
    Reads the standard dictionary template, adds an entry for the X-point region, and then
    calls write_observable for every observable of every diagnostic in the dictionary

    Returns the filled dictionary
    """
    print("Filling standard dict")

    template = read_from_json(tcvx21.template_file)

    # The X-point region is added here, on top of the entries read from the template
    xpoint_observables = {
        'density': {'name': 'Plasma density', 'units': '1 / meter ** 3'},
        'electron_temp': {'name': 'Electron temperature', 'units': 'electron_volt'},
        'potential': {'name': 'Plasma potential', 'units': 'volt'},
    }
    for entry in xpoint_observables.values():
        entry.update(dimensionality=2, experimental_hierarchy=-1, simulation_hierarchy=-1)

    template['Xpt'] = {'name': 'X-point', 'observables': xpoint_observables}

    # write_observable fills each observable entry in place
    for diagnostic_key, diagnostic_entry in template.items():
        observables = diagnostic_entry['observables']
        for observable_key, observable_entry in observables.items():
            write_observable(work_file, statistics_interval, simulation_hierarchy,
                             diagnostic_key, observable_key, observable_entry)

    print("Done")
    return template
def write_observable(work_file: Path, statistics_interval: int, simulation_hierarchy: dict,
                     diagnostic_key: str, observable_key: str, output_dict: dict):
    """
    Calculates statistical moments and fills values into the standard_dictionary, for writing
    to a standard NetCDF

    The following entries of output_dict are written in place
        'simulation_hierarchy': looked up in simulation_hierarchy (by moment for 'lambda_q',
                                otherwise by the observable key with the moment stripped)
        'values', 'errors': magnitudes of the moment and its error, converted to output_dict['units']
        one entry per variable of the diagnostic group, plus a matching '<key>_units' entry
        ('Rsep' is replaced by 'Ru' in these keys)

    Arguments:
        work_file: path to the work file, which is opened with xarray
        statistics_interval: number of trailing points along 'tau' which are used for the statistics
        simulation_hierarchy: dict which gives the hierarchy for each observable
        diagnostic_key: name of the diagnostic group in the work file
        observable_key: name of the observable, which may include a moment (see strip_moment)
        output_dict: the observable entry of the standard dictionary, which must already give 'units'
    """
    # Imported locally so that this function does not depend on a change to the module imports
    from contextlib import ExitStack

    print(f"\tProcessing {diagnostic_key}:{observable_key}")

    observable_key, moment = strip_moment(observable_key)

    # The opened datasets are registered on the stack, so that their file handles are closed when
    # the function returns or raises. Since xarray loads data lazily, everything which reads from
    # the datasets must stay inside this block
    with ExitStack() as stack:
        try:
            diagnostic = stack.enter_context(xr.open_dataset(work_file, group=diagnostic_key))
        except OSError as e:  # pragma: no cover
            # Catch an old name for the diagnostic
            if diagnostic_key == 'FHRP':
                diagnostic_key = 'RPTCV'
                diagnostic = stack.enter_context(xr.open_dataset(work_file, group=diagnostic_key))
            else:
                raise e from None

        try:
            observable = stack.enter_context(
                xr.open_dataset(work_file, group=f"{diagnostic_key}/observables"))[observable_key]
        except OSError:
            # Catch an old term for the observables
            observable = stack.enter_context(
                xr.open_dataset(work_file, group=f"{diagnostic_key}/measurements"))[observable_key]

        # Keep only the last statistics_interval points along tau
        observable = observable.isel(tau=slice(-statistics_interval, None)).persist()

        output_dict['simulation_hierarchy'] = simulation_hierarchy[observable_key if moment != 'lambda_q' else moment]

        if diagnostic_key == 'Xpt':
            # Don't need an error estimate for the X-point profile, and the memory requirements for
            # bootstrapping this can be very large
            value = compute_statistical_moment(observable, moment=moment)
            error = 0.0 * value
        else:
            value, error = compute_statistical_moment_with_bootstrap(observable, moment=moment)

        output_dict['values'] = Q(value).to(output_dict['units']).magnitude
        output_dict['errors'] = Q(error).to(output_dict['units']).magnitude

        # Copy every variable of the diagnostic group into the output, with its units if it has any
        for variable_key in diagnostic.variables.keys():
            variable = diagnostic[variable_key]
            output_key = variable_key.replace('Rsep', 'Ru')

            output_dict[output_key] = variable.values
            output_dict[f"{output_key}_units"] = getattr(variable, 'units', '')