Coverage for tcvx21/record_c/record_writer_m.py: 89%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2A class to write a standard NetCDF file from any arbitrary data source.
3"""
4from pathlib import Path
5from netCDF4 import Dataset, Group
6import numpy as np
7from tcvx21 import Quantity
def strip_units(array_in):
    """
    Return the bare numerical content of array_in.

    If array_in exposes a ``magnitude`` attribute, that attribute is returned.
    Otherwise array_in must already be a numpy array, and it is returned unchanged
    (the very same object, not a copy).
    """
    # Private sentinel: lets us distinguish "no magnitude attribute" from a magnitude
    # that happens to be None/0/empty.
    not_found = object()
    magnitude = getattr(array_in, "magnitude", not_found)

    if magnitude is not not_found:
        return magnitude

    # No units attached: only plain numpy arrays are accepted beyond this point
    assert isinstance(array_in, np.ndarray)
    return array_in
class RecordWriter:
    """
    Writes a standard-formatted NetCDF file from a nested dictionary of diagnostics and observables.

    Typical use: construct the writer with the target path and the file-level metadata, then call
    write_data_dict once with the complete data dictionary.
    """

    def __init__(self,
                 file_path: Path,
                 descriptor: str,
                 description: str,
                 allow_overwrite: bool = False):
        """
        Initialises a new dataset object and prepares to write into it

        file_path should be a pathlib.Path object pointing to a valid filepath where to write the NetCDF file to.
        It must use a ".nc" suffix. If allow_overwrite, then an existing file is deleted (and a notice is printed),
        otherwise a FileExistsError is raised if the filepath already exists.

        descriptor should be a short descriptor. It is used to automatically produce a figure label

        description is not used for the analysis. It should be a string with a long-form description of the
        simulation contained in the validation file
        For simulations, it should detail things such as;
        * the code used to generate the data
        * what git hash was used to run the cases
        * the name and contact email address of who generated the data
        * how long the simulation was run before starting to gather statistics
        * how long statistics were gathered for
        * any "tricks" used to stabilise the simulations
        """
        assert isinstance(file_path, Path), \
            f"Should pass file_path as a Path object but received {type(file_path)}"

        assert isinstance(descriptor, str), \
            f"Should pass a string as descriptor, but received type {type(descriptor)}"

        # Report the type of the offending argument (description), not of file_path
        assert isinstance(description, str), \
            f"Should pass description as a string but received {type(description)}"

        assert file_path.suffix == ".nc", f"Should use a .nc suffix, but suffix was {file_path.suffix}"

        if file_path.exists():
            if not allow_overwrite:
                raise FileExistsError(f"The requested file_path {file_path} already "
                                      f"exists and allow_overwrite is False")
            print(f"Overwriting {file_path}")
            file_path.unlink()

        self.file_path = file_path
        self.description = description
        self.descriptor = descriptor

    def write_data_dict(self, data_dict: dict, additional_attributes: "dict | None" = None):
        """
        Recursively writes entries from a standard-formatted dictionary (based on the observables.json template) into
        a NetCDF file

        data_dict maps each diagnostic key to the dictionary which is passed on to write_diagnostic.

        additional_attributes is an optional dictionary of extra attributes which are set on the root of the
        dataset, next to description, descriptor and build_url.
        """
        # None instead of a shared mutable {} default
        extra_attributes = {} if additional_attributes is None else additional_attributes

        # The context manager closes the file even if writing one of the diagnostics fails
        with Dataset(self.file_path, 'w') as dataset:
            dataset.description = self.description
            dataset.descriptor = self.descriptor
            dataset.build_url = 'gitlab.mpcdf.mpg.de/tcv-x21/tcv-x21'
            for attribute, value in extra_attributes.items():
                setattr(dataset, attribute, value)

            for diagnostic, diagnostic_dict in data_dict.items():
                self.write_diagnostic(dataset, diagnostic_dict, diagnostic)

    def write_diagnostic(self, dataset: "Dataset", diagnostic_dict: dict, diagnostic: str):
        """
        Writes the contents of a diagnostic into the NetCDF

        A group named after diagnostic is created in dataset. Every entry of diagnostic_dict other than
        'observables' and 'name' is stored as an attribute of that group, the 'name' entry is stored as the
        attribute diagnostic_name, and each entry of diagnostic_dict['observables'] is written into an
        'observables' sub-group via write_observable.

        If writing an observable fails, a message naming the diagnostic and observable is printed and the
        original exception is re-raised.
        """
        diagnostic_group = dataset.createGroup(diagnostic)

        for key, val in diagnostic_dict.items():
            if key in ('observables', 'name'):
                continue
            setattr(diagnostic_group, key, val)

        diagnostic_group.diagnostic_name = diagnostic_dict['name']
        observables_group = diagnostic_group.createGroup('observables')

        for observable, observable_dict in diagnostic_dict['observables'].items():
            try:
                self.write_observable(observables_group, observable_dict, observable)
            except Exception:
                print(f"Failed to write {diagnostic_dict['name']}:{observable}. Reraising error")
                # Bare raise keeps the original traceback and any chained context
                raise

    @staticmethod
    def plain_text_units(units: str):
        """
        Converts a unit string to long-form ASCII text
        """
        return f"{Quantity(1, units).units}"

    @staticmethod
    def _write_variable(group, varname: str, dimension: str, values, units: str):
        """
        Creates a float64 variable along a single dimension, fills it with values and sets its units attribute
        """
        variable = group.createVariable(varname=varname, datatype=np.float64, dimensions=(dimension,))
        variable[:] = values
        variable.units = units
        return variable

    def _write_positions(self, observable_group, data: dict, dimension: str, with_zx: bool):
        """
        Writes the position arrays which accompany the values of a 1D or 2D observable

        'Ru' is always written (as 'Rsep_omp'); 'R' and 'Z' are written only if present in data;
        'Zx' is written if with_zx is set. The units of each array are read from data['<key>_units'].
        """
        def write(varname, key):
            self._write_variable(observable_group, varname, dimension, data[key],
                                 self.plain_text_units(data[f'{key}_units']))

        # Upstream-mapped radial position (flux-surface label)
        write('Rsep_omp', 'Ru')

        # Radial position -- not required, but nice for reference
        if 'R' in data:
            write('R', 'R')

        # Vertical position relative to the X-point
        if with_zx:
            write('Zx', 'Zx')

        # Vertical position
        if 'Z' in data:
            write('Z', 'Z')

    def write_observable(self, observables_group: "Group", data: dict, observable: str):
        """
        Write a observable into the NetCDF

        A group named after observable is created in observables_group, carrying the attributes
        observable_name, dimensionality, experimental_hierarchy and (only if it is positive)
        simulation_hierarchy.

        If data['values'] is empty, a 'Missing data' message is printed and no variables are written.
        Otherwise the values are written as the variable 'value' and, if data['errors'] is not empty, the
        errors as the variable 'error':
        * dimensionality 0: both use a dimension 'point' of size 1
        * dimensionality 1: both use a dimension 'points'; additionally 'Rsep_omp' (from data['Ru']) and,
          if present in data, 'R' and 'Z' are written
        * dimensionality 2: flattened data, written as for dimensionality 1 plus 'Zx'

        Raises NotImplementedError for any other dimensionality, and AssertionError if 'Ru' (or 'Zx' for
        dimensionality 2) does not have the same number of points as the values.
        """
        observable_group = observables_group.createGroup(observable)
        observable_group.observable_name = data['name']

        dimensionality = data['dimensionality']
        n_points = np.size(data['values'])
        has_error = np.size(data['errors']) != 0

        observable_group.dimensionality = dimensionality
        observable_group.experimental_hierarchy = data['experimental_hierarchy']

        if data['simulation_hierarchy'] > 0:
            observable_group.simulation_hierarchy = data['simulation_hierarchy']

        if n_points == 0:
            print(f'Missing data for {observable_group.path}')
            return

        plain_text_units = self.plain_text_units(data['units'])

        if dimensionality == 0:
            dimension, size = 'point', 1
        elif dimensionality in (1, 2):
            dimension, size = 'points', n_points
        else:
            raise NotImplementedError(f"Haven't implemented a write method for dimensionality {dimensionality} (yet)")

        observable_group.createDimension(dimname=dimension, size=size)

        if dimensionality != 0:
            # Every position array must supply exactly one entry per value
            required_positions = ('Ru',) if dimensionality == 1 else ('Ru', 'Zx')
            for key in required_positions:
                assert np.size(data[key]) == n_points, f"{n_points}, {np.size(data[key])}, {data}"

        self._write_variable(observable_group, 'value', dimension, strip_units(data['values']), plain_text_units)

        if dimensionality != 0:
            self._write_positions(observable_group, data, dimension, with_zx=(dimensionality == 2))

        if has_error:
            self._write_variable(observable_group, 'error', dimension, strip_units(data['errors']),
                                 plain_text_units)