Coverage for tcvx21/observable_c/observable_m.py: 87%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

121 statements  

1""" 

2Simple data container for a observable 

3""" 

4from tcvx21 import Quantity 

5import numpy as np 

6 

7 

class MissingDataError(Exception):
    """Raised when an observable cannot be built because the data it needs is missing."""

11 

12 

class Observable:
    """Container for a single observable: values, errors, a point mask and plotting attributes.

    Values and errors are stored as ``Quantity`` objects sharing the same units. ``mask`` is a
    boolean array selecting the points exposed by ``values``/``errors``. Arithmetic between
    observables of the same type returns a new observable with uncorrelated errors propagated
    in quadrature.
    """

    def __init__(self, data, diagnostic, observable, label, color, linestyle):
        """Simple container for individual observables.

        Args:
            data: source of the observable. It must provide the attributes ``observable_name``,
                ``dimensionality`` and ``experimental_hierarchy`` (``simulation_hierarchy`` is
                optional, defaulting to None), and the item ``'value'`` (``'error'`` is optional).
                Items must support ``[:]`` and have a ``.units`` attribute.
                NOTE(review): presumably a netCDF4 group, which raises IndexError for missing
                variables — confirm against callers.
            diagnostic: diagnostic key, stored and used in error messages.
            observable: observable key, stored and used in error messages.
            label: plotting label, stored as-is.
            color: plotting colour, stored as-is.
            linestyle: plotting line style, stored as-is.

        Raises:
            MissingDataError: if a required attribute or item of ``data`` is missing
                (signalled by AttributeError or IndexError).
        """
        try:
            self.name = data.observable_name
            self.label = label
            self.color = color
            self.linestyle = linestyle

            self.diagnostic, self.observable = diagnostic, observable
            self.dimensionality = data.dimensionality
            # Subclass hook; the base class raises NotImplementedError.
            self.check_dimensionality()
            self.experimental_hierarchy = data.experimental_hierarchy

            self.simulation_hierarchy = getattr(data, 'simulation_hierarchy', None)

            self._values = Quantity(data['value'][:], data['value'].units)
            try:
                self._errors = Quantity(data['error'][:], data['error'].units).to(self._values.units)
            except IndexError:
                # No 'error' item: treat the observable as having zero uncertainty.
                self._errors = Quantity(np.zeros_like(self._values.magnitude), self._values.units)

            # Start with every point unmasked.
            self.mask = np.ones(self._values.shape, dtype=bool)

        except (AttributeError, IndexError) as err:
            raise MissingDataError(f"Missing data for {diagnostic}:{observable}. Data available is {data}") from err

    def check_dimensionality(self):
        """Validate ``self.dimensionality``; must be implemented by subclasses."""
        raise NotImplementedError()

    @property
    def values(self) -> Quantity:
        """Returns the observable values, with a mask applied if applicable"""
        return self._values[self.mask]

    @property
    def errors(self) -> Quantity:
        """Returns the observable errors, with a mask applied if applicable"""
        return self._errors[self.mask]

    @property
    def units(self) -> str:
        """Returns the units of the values and errors, as a string"""
        return str(self._values.units)

    @property
    def is_empty(self):
        """Always False for this class."""
        return False

    @property
    def has_errors(self):
        """True if any unmasked error is non-zero."""
        return bool(np.count_nonzero(self.errors))

    @property
    def compact_units(self) -> str:
        """Units with a compact SI prefix, chosen from the largest absolute unmasked value."""
        values = self.values
        if values.check('[length]^-3') or values.size == 0:
            # Densities: don't convert 10^19 m^-3 to ~10 1/µm^3.
            # Empty data: there is no value to pick a prefix from (np.max would raise).
            return str(values.units)
        return str(np.max(np.abs(values)).to_compact().units)

    @property
    def npts(self):
        """Returns the number of unmasked observable points"""
        return self.values.size

    def nan_mask(self):
        """Returns a mask which will remove NaN values (in either values or errors)"""
        return np.logical_and(~np.isnan(self._values), ~np.isnan(self._errors))

    def check_attributes(self, other):
        """Merge ``other``'s mask into this one and assert the two observables are compatible.

        Rebinds ``self.mask`` to the element-wise AND of both masks. It then asserts that colour,
        label, dimensionality and linestyle match, and that positions match (NaN == NaN) when
        ``self`` has them.
        """
        self.mask = np.logical_and(self.mask, other.mask)
        assert self.color == other.color
        assert self.label == other.label
        assert self.dimensionality == other.dimensionality
        assert self.linestyle == other.linestyle
        if hasattr(self, '_positions_rsep'):
            assert np.allclose(self._positions_rsep, other._positions_rsep, equal_nan=True)
        if hasattr(self, '_positions_zx'):
            assert np.allclose(self._positions_zx, other._positions_zx, equal_nan=True)

    def fill_attributes(self, result):
        """Fills the attributes when copying to make a new object.

        Copies the mask (not aliased, so later in-place edits on one object don't leak into the
        other), the plotting attributes and dimensionality. It also copies x-limits (resetting
        y-limits to None) and positions, when present.
        """
        result.mask = np.copy(self.mask)
        result.color = self.color
        result.label = self.label
        result.dimensionality = self.dimensionality
        result.linestyle = self.linestyle

        if hasattr(self, 'xmin') and hasattr(self, 'xmax'):
            result.xmin, result.xmax, result.ymin, result.ymax = self.xmin, self.xmax, None, None
        if hasattr(self, '_positions_rsep'):
            result._positions_rsep = self._positions_rsep
        if hasattr(self, '_positions_zx'):
            result._positions_zx = self._positions_zx

    def __add__(self, other):
        """Element-wise sum of two observables of the same type; errors added in quadrature."""
        assert type(self) is type(other)
        result = object.__new__(self.__class__)
        result._values = self._values + other._values
        result._errors = np.sqrt(self._errors**2 + other._errors**2)
        self.fill_attributes(result)
        result.check_attributes(other)

        return result

    def __sub__(self, other):
        """Element-wise difference of two observables of the same type; errors added in quadrature."""
        assert type(self) is type(other)
        result = object.__new__(self.__class__)
        result._values = self._values - other._values
        result._errors = np.sqrt(self._errors**2 + other._errors**2)
        self.fill_attributes(result)
        result.check_attributes(other)

        return result

    def __mul__(self, other):
        """Multiply by a scalar (number or Quantity) or element-wise by an observable of the same type.

        Scalar case: values are scaled by ``other`` and errors by ``abs(other)``, so errors stay
        non-negative.
        Observable case: uncorrelated errors are propagated as
            sigma(a*b) = sqrt((sigma_a * b)**2 + (a * sigma_b)**2),
        which equals |a*b| * sqrt((sigma_a/a)**2 + (sigma_b/b)**2) for non-zero a, b. Unlike the
        relative form, it stays finite when a or b is zero and is never negative.
        """
        result = object.__new__(self.__class__)
        if isinstance(other, (int, float, np.number, Quantity)):
            # Scalar multiplication
            result._values = self._values * other
            result._errors = self._errors * abs(other)
            self.fill_attributes(result)

        else:
            assert type(self) is type(other)
            result._values = self._values * other._values
            result._errors = np.sqrt(
                (self._errors * other._values) ** 2 + (self._values * other._errors) ** 2
            )
            self.fill_attributes(result)
            result.check_attributes(other)

        return result

    def __rmul__(self, other):
        """Support ``scalar * observable``; scalar multiplication commutes, so defer to __mul__."""
        if not isinstance(other, (int, float, np.number, Quantity)):
            return NotImplemented
        return self.__mul__(other)

    def __truediv__(self, other):
        """Divide by a scalar (number or Quantity) or element-wise by an observable of the same type.

        Scalar case: values are divided by ``other`` and errors by ``abs(other)``.
        Observable case: uncorrelated errors are propagated as
            sigma(a/b) = sqrt((sigma_a / b)**2 + (a * sigma_b / b**2)**2),
        which equals |a/b| * sqrt((sigma_a/a)**2 + (sigma_b/b)**2) for non-zero a, b. Unlike the
        relative form, it stays finite when a is zero and is never negative.
        """
        result = object.__new__(self.__class__)
        if isinstance(other, (int, float, np.number, Quantity)):
            # Scalar division
            result._values = self._values / other
            result._errors = self._errors / abs(other)
            self.fill_attributes(result)

        else:
            assert type(self) is type(other)
            assert self._values.size == other._values.size

            result._values = self._values / other._values
            result._errors = np.sqrt(
                (self._errors / other._values) ** 2
                + (self._values * other._errors / other._values ** 2) ** 2
            )
            self.fill_attributes(result)
            result.check_attributes(other)

        return result

    def trim_to_mask(self, mask):
        """Return a new observable holding only the points selected by ``mask``.

        Unlike ``self.mask``, which hides points without removing them, this drops the points
        from the stored values, errors and positions (when present). The new object's mask is
        all True.
        """
        result = object.__new__(self.__class__)
        result._values = self._values[mask]
        result._errors = self._errors[mask]
        self.fill_attributes(result)
        result.mask = np.ones(result._values.shape, dtype=bool)

        if hasattr(self, '_positions_rsep'):
            result._positions_rsep = self._positions_rsep[mask]
        if hasattr(self, '_positions_zx'):
            result._positions_zx = self._positions_zx[mask]

        return result