Coverage for tcvx21/record_c/record_m.py: 85%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2Interface to a standard NetCDF file
3"""
4import netCDF4
5from pathlib import Path
6import tcvx21
7from tcvx21 import Quantity
8from tcvx21.observable_c.observable_m import MissingDataError
9from tcvx21.observable_c.observable_empty_m import EmptyObservable
10from tcvx21.observable_c.observable_1d_m import Observable1D
11from tcvx21.observable_c.observable_2d_m import Observable2D
class Record:
    """
    Interface to a standard NetCDF file

    On construction, the template loaded by tcvx21.read_from_json(tcvx21.template_file) is walked, and every
    (diagnostic, observable) entry in it is replaced by an Observable1D, Observable2D or EmptyObservable built
    from the matching group of the NetCDF file.
    """

    def __init__(self, file_path: Path, label: str = None, color: str = 'C0', linestyle: str = '-'):
        """
        Adds a link to a standard NetCDF file

        file_path should point to an existing file with a '.nc' suffix

        label should be a short label for use in a legend. If None, the 'descriptor' attribute of the NetCDF
        dataset is used instead

        color should be a colour code that will be used to represent observables from this dataset

        linestyle should correspond to a linestyle in code (see link below). Helpful for colour-blindness or
        black & white printing
        https://matplotlib.org/stable/gallery/lines_bars_and_markers/linestyles.html

        Raises AssertionError if file_path does not exist or does not have a '.nc' suffix
        """
        file_path = Path(file_path)
        # Still an AssertionError (callers may rely on the type), but now with a diagnosable message
        assert file_path.exists() and file_path.suffix == '.nc', \
            f"Expected an existing NetCDF file with a '.nc' suffix, but got {file_path}"

        self.file_path = file_path
        # NOTE(review): the dataset is opened here and is never explicitly closed by this class
        self.dataset = netCDF4.Dataset(file_path)

        self.label = self.dataset.descriptor if label is None else label
        self.color = color
        self.linestyle = linestyle

        self.data_template = tcvx21.read_from_json(tcvx21.template_file)

        # Replace each template entry with an observable object. The observable names are copied into a list
        # because make_observable writes the new objects back into this same dictionary.
        for diagnostic, diagnostic_template in self.data_template.items():
            for observable in list(diagnostic_template['observables']):
                self.make_observable(diagnostic, observable)

    @property
    def is_empty(self):
        """Always False: a Record is backed by a NetCDF file (see EmptyRecord for the empty case)"""
        return False

    def set_error_to_zero(self):
        """For all observables, overwrite the error with zeros"""
        for diagnostic, observable in self.keys():
            try:
                m = self.get_observable(diagnostic, observable)
            except AssertionError:
                # Ignore non-standard data (i.e. the X-point lineout), which is not part of the template
                continue
            if not m.is_empty:
                m._errors = 0.0 * m._errors

    def make_observable(self, diagnostic: str, observable: str):
        """
        Makes an 'observable' object to represent the data stored in the NetCDF, and stores it in
        self.data_template in place of the corresponding template entry

        The dimensionality is taken from the 'dimensionality' attribute of the NetCDF group if it has one,
        otherwise from the template entry. If constructing the observable raises MissingDataError, an
        EmptyObservable is stored instead.

        Raises NotImplementedError if the dimensionality is not 1 or 2
        """
        data = self.dataset[diagnostic]['observables'][observable]
        data_template = self.data_template[diagnostic]['observables'][observable]

        try:
            dimensionality = data.dimensionality
        except AttributeError:
            dimensionality = data_template['dimensionality']

        try:
            # NOTE: 0D observables are not supported yet (they would need an Observable0D class)
            if dimensionality == 1:
                observable_ = Observable1D(data, diagnostic, observable,
                                           label=self.label, color=self.color, linestyle=self.linestyle)

            elif dimensionality == 2:
                observable_ = Observable2D(data, diagnostic, observable,
                                           label=self.label, color=self.color, linestyle=self.linestyle)

            else:
                raise NotImplementedError(f"No implementation for dimensionality {dimensionality}")

        except MissingDataError:
            observable_ = EmptyObservable()

        # Overwrite the entry in the observables dictionary with the observable object
        self.data_template[diagnostic]['observables'][observable] = observable_

    def keys(self):
        """
        Makes a (diagnostic, observable) iterator for iterating over all of the observables

        The pairs come from the groups actually present in the NetCDF file, which may include entries that are
        not in the template
        """
        for diagnostic in self.dataset.groups:
            for observable in self.dataset[diagnostic]['observables'].groups:
                yield diagnostic, observable

    def __getitem__(self, item) -> "netCDF4.Group | netCDF4.Variable":
        """Index on the underlying dataset, returning the raw NetCDF object (not a Quantity)"""
        return self.dataset[item]

    def _check_legal_access(self, diagnostic: str = None, observable: str = None):
        """
        Makes sure that the requested data is a valid entry, and if not provide a helpful error

        A None diagnostic (or observable) skips the corresponding check.

        Raises AssertionError if diagnostic is not in the template, or observable is not in that diagnostic
        """
        if diagnostic is None:
            return
        allowed_diagnostics = list(self.data_template.keys())
        assert diagnostic in allowed_diagnostics, \
            f"diagnostic should be in {allowed_diagnostics} but was {diagnostic}"

        if observable is None:
            return
        allowed_observables = list(self.data_template[diagnostic]['observables'].keys())
        assert observable in allowed_observables, \
            f"{diagnostic}:observable should be in {allowed_observables} but was {observable}"

    def get_observable(self, diagnostic: str, observable: str, with_check: bool = False):
        """
        Safe access to observables

        Makes sure that you are accessing a valid observable, and then returns the stored observable object

        Observables named '<base>_fluct' are not stored. They are computed on request as
        '<base>_std' / '<base>', and an EmptyObservable is returned if either of those is empty.

        with_check: if True, additionally assert that the stored entry is an observable object
        (Observable1D, Observable2D or EmptyObservable) rather than a raw template entry

        Note that you can also get attributes from the NetCDF via square-bracket indexing on the named keys

        Raises AssertionError for an unknown diagnostic/observable, or if with_check fails
        """
        fluct_suffix = '_fluct'
        if observable.endswith(fluct_suffix):
            # Slice off the exact suffix (str.rstrip removes a *set of characters*, not a suffix)
            base_observable = observable[:-len(fluct_suffix)]
            std = self.get_observable(diagnostic, f"{base_observable}_std", with_check)
            mean = self.get_observable(diagnostic, base_observable, with_check)
            if not (std.is_empty or mean.is_empty):
                fluct = std / mean
                fluct.name = f"Fluctuation of {mean.name}"
                if fluct.dimensionality == 1:
                    fluct.set_plot_limits(ymin=0.0, ymax=1.0)
                return fluct
            else:
                return EmptyObservable()

        self._check_legal_access(diagnostic, observable)

        observable_ = self.data_template[diagnostic]['observables'][observable]

        if with_check:
            # Previously isinstance(observable_, observable) was called with a str, which always raised TypeError
            assert isinstance(observable_, (Observable1D, Observable2D, EmptyObservable)), \
                f"Requested observable but returned type {type(observable_)}"

        return observable_

    def get_nc_group(self, diagnostic: str = None, observable: str = None):
        """
        Returns the raw NetCDF group (useful when there is missing data)

        With no arguments this returns the whole dataset, with only a diagnostic it returns that diagnostic's
        group, and with both it returns the observable's group. Raises AssertionError for entries not in the
        template.
        """
        self._check_legal_access(diagnostic, observable)

        if diagnostic is None:
            return self.dataset
        elif observable is None:
            return self.dataset[diagnostic]
        else:
            return self.dataset[diagnostic]['observables'][observable]
class EmptyRecord(Record):
    """
    Placeholder Record with no underlying NetCDF file

    Record.__init__ is intentionally not run, so no file is opened and no attributes are set.
    Every requested observable comes back as an EmptyObservable.
    """

    def __init__(self):
        """Deliberately skip Record.__init__: there is nothing to load."""

    @property
    def is_empty(self):
        """Always True for an EmptyRecord"""
        return True

    def get_observable(self, *args, **kwargs):
        """Return a fresh EmptyObservable, whatever arguments are given"""
        return EmptyObservable()