Coverage for tcvx21/record_c/record_writer_m.py: 89%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

121 statements  

1""" 

2A class to write a standard NetCDF file from any arbitrary data source. 

3""" 

4from pathlib import Path 

5from netCDF4 import Dataset, Group 

6import numpy as np 

7from tcvx21 import Quantity 

8 

9 

def strip_units(array_in):
    """
    Return the plain numerical payload of array_in.

    If array_in exposes a ``magnitude`` attribute, that attribute is returned. Otherwise array_in
    must already be a numpy array, and it is handed back unchanged.

    Raises AssertionError if array_in has no ``magnitude`` attribute and is not a np.ndarray.
    """
    # A unique sentinel lets us tell "attribute missing" apart from any real magnitude value
    # (including None or 0).
    not_found = object()
    stripped = getattr(array_in, "magnitude", not_found)

    if stripped is not not_found:
        return stripped

    assert isinstance(array_in, np.ndarray)
    return array_in

16 

class RecordWriter:
    """
    Writes a nested dictionary of diagnostics and observables into a NetCDF file.

    The constructor validates and stores the target path and the file-level metadata.
    write_data_dict then creates the file and writes one group per diagnostic; each diagnostic
    group holds an 'observables' group with one sub-group per observable.
    """

    def __init__(self,
                 file_path: Path,
                 descriptor: str,
                 description: str,
                 allow_overwrite: bool = False):
        """
        Initialises a new dataset object and prepares to write into it

        file_path should be a pathlib.Path object pointing to a valid filepath where to write the NetCDF file to.
        It must use the '.nc' suffix.
        if allow_overwrite, then an existing file is deleted here (and a message is printed), otherwise a
        FileExistsError is raised if the filepath already exists.

        descriptor should be a short descriptor. It is used to automatically produce a figure label

        description is not used for the analysis. It should be a string with a long-form description of the
        simulation contained in the validation file
        For simulations, it should detail things such as;
            * the code used to generate the data
            * what git hash was used to run the cases
            * the name and contact email address of who generated the data
            * how long the simulation was run before starting to gather statistics
            * how long statistics were gathered for
            * any "tricks" used to stabilise the simulations

        Raises AssertionError if an argument has the wrong type or file_path has the wrong suffix.
        """
        # NOTE: validation intentionally stays assert-based so that the exception type seen by
        # existing callers does not change.
        assert isinstance(file_path, Path), \
            f"Should pass file_path as a Path object but received {type(file_path)}"

        assert isinstance(descriptor, str), \
            f"Should pass a string as descriptor, but received type {type(descriptor)}"

        # Report the type of the offending argument (description), not of file_path
        assert isinstance(description, str), \
            f"Should pass description as a string but received {type(description)}"

        assert file_path.suffix == ".nc", f"Should use a .nc suffix, but suffix was {file_path.suffix}"

        if file_path.exists():
            if not allow_overwrite:
                raise FileExistsError(f"The requested file_path {file_path} already "
                                      f"exists and allow_overwrite is False")
            print(f"Overwriting {file_path}")
            file_path.unlink()

        self.file_path = file_path
        self.description = description
        self.descriptor = descriptor

    def write_data_dict(self, data_dict: dict, additional_attributes: dict = None):
        """
        Recursively writes entries from a standard-formatted dictionary (based on the observables.json template) into
        a NetCDF file

        data_dict maps each diagnostic key to the dictionary describing that diagnostic (see write_diagnostic).
        additional_attributes is an optional mapping of extra attribute names to values which are set on the
        root of the dataset. None (the default) means that no extra attributes are written.
        """
        # A None default avoids sharing one mutable dict between calls
        extra_attributes = {} if additional_attributes is None else additional_attributes

        # The context manager closes the file even when one of the writes below raises
        with Dataset(self.file_path, 'w') as dataset:
            dataset.description = self.description
            dataset.descriptor = self.descriptor
            dataset.build_url = 'gitlab.mpcdf.mpg.de/tcv-x21/tcv-x21'
            for attribute, value in extra_attributes.items():
                setattr(dataset, attribute, value)

            for diagnostic, diagnostic_dict in data_dict.items():
                self.write_diagnostic(dataset, diagnostic_dict, diagnostic)

    def write_diagnostic(self, dataset: Dataset, diagnostic_dict: dict, diagnostic: str):
        """
        Writes the contents of a diagnostic into the NetCDF

        A group named after diagnostic is created in dataset. Every entry of diagnostic_dict other than
        'name' and 'observables' is stored as an attribute of that group, 'name' is stored as the
        'diagnostic_name' attribute, and each entry of 'observables' is written via write_observable into an
        'observables' sub-group.

        If writing an observable fails, a message naming it is printed and the original exception is re-raised.
        """
        diagnostic_group = dataset.createGroup(diagnostic)

        for key, val in diagnostic_dict.items():
            if key in ('observables', 'name'):
                continue
            setattr(diagnostic_group, key, val)

        diagnostic_group.diagnostic_name = diagnostic_dict['name']
        observables_group = diagnostic_group.createGroup('observables')

        for observable, observable_dict in diagnostic_dict['observables'].items():
            try:
                self.write_observable(observables_group, observable_dict, observable)
            except Exception:
                print(f"Failed to write {diagnostic_dict['name']}:{observable}. Reraising error")
                # A bare raise keeps the original traceback and any chained context intact
                raise

    @staticmethod
    def plain_text_units(units: str):
        """
        Converts a unit string to long-form ASCII text

        The result is the string formatting of the units of Quantity(1, units).
        """
        return f"{Quantity(1, units).units}"

    @staticmethod
    def _write_variable(group: Group, varname: str, dimension: str, values, units: str):
        """
        Creates a float64 variable called varname along dimension in group, fills it with values and
        sets its 'units' attribute. Returns the new variable.
        """
        variable = group.createVariable(varname=varname, datatype=np.float64, dimensions=(dimension,))
        variable[:] = values
        variable.units = units
        return variable

    def write_observable(self, observables_group: Group, data: dict, observable: str):
        """
        Write a observable into the NetCDF

        A group named after observable is created in observables_group, and the entries 'name',
        'dimensionality', 'experimental_hierarchy' and (only if positive) 'simulation_hierarchy' of data are
        stored as attributes of it. If data['values'] is empty a message is printed and nothing else is written.

        Otherwise the following float64 variables are written, each with a 'units' attribute
            * 'value' from data['values'], and 'error' from data['errors'] if that is non-empty
              (both passed through strip_units, with units from data['units'])
            * for dimensionality 1 and 2: 'Rsep_omp' from data['Ru'], plus 'R' and 'Z' if present in data
            * for dimensionality 2 only: 'Zx' from data['Zx']
        The units of a position array stored under key are read from data[key + '_units'].

        Raises NotImplementedError for any other dimensionality, and AssertionError if a required position
        array does not have the same number of points as data['values'].
        """
        observable_group = observables_group.createGroup(observable)
        observable_group.observable_name = data['name']

        dimensionality = data['dimensionality']
        n_points = np.size(data['values'])
        has_error = np.size(data['errors']) != 0

        observable_group.dimensionality = dimensionality
        observable_group.experimental_hierarchy = data['experimental_hierarchy']

        if data['simulation_hierarchy'] > 0:
            observable_group.simulation_hierarchy = data['simulation_hierarchy']

        if n_points == 0:
            print(f'Missing data for {observable_group.path}')
            return

        plain_text_units = self.plain_text_units(data['units'])

        if dimensionality == 0:
            # A single value is stored along a length-1 dimension
            dimension, size = 'point', 1
        elif dimensionality in (1, 2):
            # 2D (RDPA) data is written flattened, so it shares the 1D layout
            dimension, size = 'points', n_points
        else:
            raise NotImplementedError(f"Haven't implemented a write method for dimensionality {dimensionality} (yet)")

        is_2d = dimensionality == 2

        if dimensionality != 0:
            # Check the required position arrays before anything is created for this observable
            assert np.size(data['Ru']) == n_points, f"{n_points}, {np.size(data['Ru'])}, {data}"
            if is_2d:
                assert np.size(data['Zx']) == n_points, f"{n_points}, {np.size(data['Zx'])}, {data}"

        observable_group.createDimension(dimname=dimension, size=size)
        self._write_variable(observable_group, 'value', dimension,
                             strip_units(data['values']), plain_text_units)

        if dimensionality != 0:
            def write_position(varname: str, key: str):
                # Position arrays are written as given; their units come from '<key>_units'
                self._write_variable(observable_group, varname, dimension,
                                     data[key], self.plain_text_units(data[f'{key}_units']))

            # Upstream-mapped radial position (flux-surface label)
            write_position('Rsep_omp', 'Ru')

            # Radial position -- not required, but nice for reference
            if 'R' in data:
                write_position('R', 'R')

            # Vertical position relative to the X-point
            if is_2d:
                write_position('Zx', 'Zx')

            # Vertical position
            if 'Z' in data:
                write_position('Z', 'Z')

        if has_error:
            self._write_variable(observable_group, 'error', dimension,
                                 strip_units(data['errors']), plain_text_units)