Coverage for tcvx21/record_c/record_m.py: 85%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

97 statements  

1""" 

2Interface to a standard NetCDF file 

3""" 

4import netCDF4 

5from pathlib import Path 

6import tcvx21 

7from tcvx21 import Quantity 

8from tcvx21.observable_c.observable_m import MissingDataError 

9from tcvx21.observable_c.observable_empty_m import EmptyObservable 

10from tcvx21.observable_c.observable_1d_m import Observable1D 

11from tcvx21.observable_c.observable_2d_m import Observable2D 

12 

13 

class Record:
    """
    Link to a standard NetCDF file, exposing each (diagnostic, observable) entry as an observable object
    """

    def __init__(self, file_path: Path, label: str = None, color: str = 'C0', linestyle: str = '-'):
        """
        Adds a link to a standard NetCDF file

        label should be a short label for use in a legend. If None, the 'descriptor' attribute of the
        NetCDF file is used instead

        color should be a colour code that will be used to represent observables from this dataset

        linestyle should correspond to a linestyle in code (see link below). Helpful for colour-blindness or
        black & white printing
        https://matplotlib.org/stable/gallery/lines_bars_and_markers/linestyles.html

        Raises AssertionError if file_path does not exist or does not have a '.nc' suffix
        """
        file_path = Path(file_path)
        assert file_path.exists() and file_path.suffix == '.nc', \
            f"Expected an existing '.nc' file, but got {file_path}"

        self.file_path = file_path
        # Kept open on the instance: keys(), __getitem__, make_observable and get_nc_group all read from it
        self.dataset = netCDF4.Dataset(file_path)

        self.label = self.dataset.descriptor if label is None else label
        self.color = color
        self.linestyle = linestyle

        # Start from the standard template; make_observable replaces each template entry in place
        # with an observable object built from the matching NetCDF group
        self.data_template = tcvx21.read_from_json(tcvx21.template_file)

        for diagnostic, diagnostic_entry in self.data_template.items():
            # Snapshot the keys, since make_observable reassigns entries of this dict during the loop
            for observable in list(diagnostic_entry['observables']):
                self.make_observable(diagnostic, observable)

    @property
    def is_empty(self):
        return False

    def set_error_to_zero(self):
        """For all observables, overwrite the error with zeros"""
        for diagnostic, observable in self.keys():
            try:
                m = self.get_observable(diagnostic, observable)
            except (AssertionError, KeyError):
                # Ignore non-standard data (i.e. the X-point lineout). _check_legal_access flags these with an
                # AssertionError; the KeyError covers the same lookup when asserts are stripped (python -O)
                continue
            if not m.is_empty:
                m._errors = 0.0 * m._errors

    def make_observable(self, diagnostic: str, observable: str):
        """
        Makes an 'observable' object to represent the data stored in the NetCDF, and stores it in
        self.data_template in place of the corresponding template entry

        The dimensionality is read from the NetCDF group's 'dimensionality' attribute if present, and
        otherwise from the template. If the observable constructor raises MissingDataError, an
        EmptyObservable is stored instead.

        Raises NotImplementedError for dimensionalities other than 1 or 2
        """
        data = self.dataset[diagnostic]['observables'][observable]
        data_template = self.data_template[diagnostic]['observables'][observable]

        try:
            dimensionality = data.dimensionality
        except AttributeError:
            # No 'dimensionality' attribute on the NetCDF group: fall back to the template value
            dimensionality = data_template['dimensionality']

        try:
            if dimensionality == 1:
                observable_ = Observable1D(data, diagnostic, observable,
                                           label=self.label, color=self.color, linestyle=self.linestyle)

            elif dimensionality == 2:
                observable_ = Observable2D(data, diagnostic, observable,
                                           label=self.label, color=self.color, linestyle=self.linestyle)

            else:
                # 0D observables are not supported yet
                raise NotImplementedError(f"No implementation for dimensionality {dimensionality}")

        except MissingDataError:
            observable_ = EmptyObservable()

        # Overwrite the entry in the observables dictionary with the observable object
        self.data_template[diagnostic]['observables'][observable] = observable_

    def keys(self):
        """
        Makes a (diagnostic, observable) iterator for iterating over all of the observables in the NetCDF file

        Note that this iterates over the groups present in the file, which may include non-standard
        entries that are not in the template
        """
        for diagnostic in self.dataset.groups:
            for observable in self.dataset[diagnostic]['observables'].groups:
                yield diagnostic, observable

    def __getitem__(self, item):
        """Index on the underlying NetCDF dataset, returning whatever self.dataset[item] returns"""
        return self.dataset[item]

    def _check_legal_access(self, diagnostic: str = None, observable: str = None):
        """
        Makes sure that the requested data is a valid entry, and if not provide a helpful error

        Raises AssertionError if diagnostic (or observable within diagnostic) is not in the template
        """
        if diagnostic is None:
            return
        allowed_diagnostics = list(self.data_template.keys())
        assert diagnostic in allowed_diagnostics, \
            f"diagnostic should be in {allowed_diagnostics} but was {diagnostic}"

        if observable is None:
            return
        allowed_observables = list(self.data_template[diagnostic]['observables'].keys())
        assert observable in allowed_observables, \
            f"{diagnostic}:observable should be in {allowed_observables} but was {observable}"

    def get_observable(self, diagnostic: str, observable: str, with_check: bool = False):
        """
        Safe access to observables

        Makes sure that you are accessing a valid observable, and then returns the observable object
        that make_observable stored for it

        Observables named '<base>_fluct' are computed on the fly as '<base>_std' / '<base>'. An
        EmptyObservable is returned if either of these is empty.

        If with_check is True, also asserts that the stored entry is an observable object

        Note that you can also get attributes from the NetCDF via square-bracket indexing on the named keys
        """
        fluct_suffix = '_fluct'
        if observable.endswith(fluct_suffix):
            # Remove the exact suffix (str.rstrip would strip a *set of characters*, not a suffix)
            base_observable = observable[:-len(fluct_suffix)]
            std = self.get_observable(diagnostic, f"{base_observable}_std", with_check)
            mean = self.get_observable(diagnostic, base_observable, with_check)
            if not (std.is_empty or mean.is_empty):
                fluct = std/mean
                fluct.name = f"Fluctuation of {mean.name}"
                if fluct.dimensionality == 1:
                    fluct.set_plot_limits(ymin=0.0, ymax=1.0)
                return fluct
            else:
                return EmptyObservable()

        self._check_legal_access(diagnostic, observable)

        observable_ = self.data_template[diagnostic]['observables'][observable]

        if with_check:
            # isinstance needs classes, not the observable name string; check against the types
            # that make_observable stores in the template
            assert isinstance(observable_, (Observable1D, Observable2D, EmptyObservable)), \
                f"Requested observable but returned type {type(observable_)}"

        return observable_

    def get_nc_group(self, diagnostic: str = None, observable: str = None):
        """
        Returns the raw NetCDF group (useful when there is missing data)

        With no arguments, returns the whole dataset; with only diagnostic, returns that diagnostic's group
        """
        self._check_legal_access(diagnostic, observable)

        if diagnostic is None:
            return self.dataset
        elif observable is None:
            return self.dataset[diagnostic]
        else:
            return self.dataset[diagnostic]['observables'][observable]

169 

class EmptyRecord(Record):
    """
    Placeholder record for a missing dataset: every requested observable is an EmptyObservable
    """

    def __init__(self):
        # Deliberately does not call Record.__init__: there is no NetCDF file to open
        pass

    def keys(self):
        """
        Empty (diagnostic, observable) iterator, since there is no underlying dataset

        Without this override, Record.keys would fail on the missing 'dataset' attribute, so methods
        built on keys() (such as set_error_to_zero) would raise AttributeError
        """
        return iter(())

    def get_observable(self, *args, **kwargs):
        """Always returns an EmptyObservable, whatever is requested"""
        return EmptyObservable()

    @property
    def is_empty(self):
        return True