Coverage for tcvx21/grillix_post/validation_writer_m.py: 97%


1""" 

2Data extraction from GRILLIX is a two-step process. The first step is to make a "work file" of the 

3observable signal at each time point and toroidal plane. The second is take statistics over this file, and 

4write a NetCDF which is compatible with the validation analysis 

5 

6This module performs the second step 

7""" 

from pathlib import Path
import xarray as xr
from netCDF4 import Dataset

import tcvx21
from tcvx21.record_c.record_writer_m import RecordWriter
from tcvx21.units_m import Quantity
from tcvx21.file_io.json_io_m import read_from_json

from tcvx21.analysis.statistics_m import (
    compute_statistical_moment_with_bootstrap,
    compute_statistical_moment,
    strip_moment,
)


def convert_work_file_to_validation_netcdf(work_file: Path, output_file: Path, simulation_hierarchy: dict,
                                           statistics_interval: int = 500):
    """
    Converts a work file from the WorkFileWriter into a standard NetCDF which can be used
    with the validation analysis (a usage sketch follows this function).

    The default statistics_interval is 500 time points, which corresponds to about 1 ms.
    """
    assert work_file.exists() and work_file.suffix == '.nc'
    assert isinstance(simulation_hierarchy, dict), \
        f"simulation_hierarchy should be dict, but was {type(simulation_hierarchy)}"

    standard_dictionary = fill_standard_dict(work_file, statistics_interval, simulation_hierarchy)

    dataset = Dataset(work_file)

    # Physical duration of the statistics window (the last statistics_interval time points)
    time = dataset['time'][slice(-statistics_interval, None)]
    statistics_time = Quantity(time, dataset['time'].units).max() - Quantity(time, dataset['time'].units).min()

    # Copy the global attributes of the work file across to the output file
    additional_attributes = {}
    for attribute in ['toroidal_planes', 'points_per_plane', 'n_points_R', 'n_points_Z',
                      'timestep', 'timestep_units',
                      'particle_source', 'particle_source_units',
                      'power_source', 'power_source_units']:
        additional_attributes[attribute] = getattr(dataset, attribute)

    additional_attributes['statistics_interval'] = statistics_time.magnitude
    additional_attributes['statistics_interval_units'] = f"{statistics_time.units:P}"

    writer = RecordWriter(file_path=output_file,
                          descriptor='GRX',
                          description=dataset.description,
                          allow_overwrite=True)

    writer.write_data_dict(standard_dictionary, additional_attributes)
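
# A minimal usage sketch. The paths and hierarchy values here are illustrative
# only; the simulation_hierarchy keys must match the base observable names (and
# 'lambda_q') used in the work file:
#
#     convert_work_file_to_validation_netcdf(
#         work_file=Path('GRILLIX_work_file.nc'),
#         output_file=Path('GRILLIX_standard.nc'),
#         simulation_hierarchy={'density': 1, 'electron_temp': 1, 'potential': 1, 'lambda_q': 2},
#     )
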

def Q(netcdf_array):
    """Converts netCDF arrays to Quantities"""
    return Quantity(netcdf_array.values, netcdf_array.units)


def fill_standard_dict(work_file: Path, statistics_interval: int, simulation_hierarchy: dict) -> dict:
    """
    Iterates over the elements in the standard dictionary and fills each observable
    """
    print("Filling standard dict")

    standard_dict = read_from_json(tcvx21.template_file)

    # Expand the dictionary template to also include a region around the X-point
    standard_dict['Xpt'] = {'name': 'X-point', 'observables': {
        'density': {'name': 'Plasma density', 'units': '1 / meter ** 3'},
        'electron_temp': {'name': 'Electron temperature', 'units': 'electron_volt'},
        'potential': {'name': 'Plasma potential', 'units': 'volt'}}}

    for observable in standard_dict['Xpt']['observables'].values():
        observable['dimensionality'] = 2
        # -1 marks the hierarchy entries as not set: there is no experimental
        # reference for the X-point region
        observable['experimental_hierarchy'] = -1
        observable['simulation_hierarchy'] = -1

    for diagnostic, diagnostic_dict in standard_dict.items():
        for observable, observable_dict in diagnostic_dict['observables'].items():
            write_observable(work_file, statistics_interval, simulation_hierarchy,
                             diagnostic, observable, observable_dict)

    print("Done")
    return standard_dict
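
# For orientation, a filled observable entry has roughly this shape (a sketch only:
# the diagnostic key and field values are illustrative, and the actual template
# fields come from tcvx21.template_file):
#
#     standard_dict['FHRP']['observables']['density'] = {
#         'name': 'Plasma density',
#         'units': '1 / meter ** 3',
#         'simulation_hierarchy': 1,
#         'values': ...,               # filled in by write_observable
#         'errors': ...,               # bootstrap error estimate
#         'Ru': ..., 'Ru_units': ...,  # coordinates copied from the work file
#     }
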

90def write_observable(work_file: Path, statistics_interval: int, simulation_hierarchy: dict, 

91 diagnostic_key: str, observable_key: str, output_dict: dict): 

92 """ 

93 Calculates statistical moments and fills values into the standard_dictionary, for writing 

94 to a standard NetCDF 

95 """ 

96 

97 print(f"\tProcessing {diagnostic_key}:{observable_key}") 

98 

99 observable_key, moment = strip_moment(observable_key) 

100 try: 

101 diagnostic = xr.open_dataset(work_file, group=diagnostic_key) 

102 except OSError as e: # pragma: no cover 

103 # Catch an old name for the diagnostic 

104 if diagnostic_key == 'FHRP': 

105 diagnostic_key = 'RPTCV' 

106 diagnostic = xr.open_dataset(work_file, group=diagnostic_key) 

107 else: 

108 raise e from None 

109 try: 

110 observable = xr.open_dataset(work_file, group=f"{diagnostic_key}/observables")[observable_key] 

111 except OSError: 

112 # Catch an old term for the observables 

113 observable = xr.open_dataset(work_file, group=f"{diagnostic_key}/measurements")[observable_key] 

114 

115 observable = observable.isel(tau=slice(-statistics_interval, None)).persist() 

116 

117 output_dict['simulation_hierarchy'] = simulation_hierarchy[observable_key if moment != 'lambda_q' else moment] 

118 

119 if diagnostic_key == 'Xpt': 

120 # Don't need an error estimate for the X-point profile, and the memory requirements for bootstrapping this 

121 # can be very large 

122 value = compute_statistical_moment(observable, moment=moment) 

123 error = 0.0 * value 

124 else: 

125 value, error = compute_statistical_moment_with_bootstrap(observable, moment=moment) 

126 

127 output_dict['values'] = Q(value).to(output_dict['units']).magnitude 

128 output_dict['errors'] = Q(error).to(output_dict['units']).magnitude 

129 

130 for variable_key in diagnostic.variables.keys(): 

131 variable = diagnostic[variable_key] 

132 output_key = variable_key.replace('Rsep', 'Ru') 

133 

134 output_dict[output_key] = variable.values 

135 output_dict[f"{output_key}_units"] = getattr(variable, 'units', '')