Coverage for tcvx21/grillix_post/components/grid_m.py: 96%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2Interface to the simulation grid
3"""
4from pathlib import Path
5import numpy as np
6import xarray as xr
7import pandas as pd
# Integer district codes (as read from the grid file's 'info' variable by Grid)
# mapped to human-readable district names. -1000 is the code substituted for NaN.
district_dict = {
    # point off the grid (i.e. normally 'nan')
    -1000: "OFF_GRID",
    # point located in core (outside actual computational domain, rho<rhomin)
    813: "CORE",
    # point located in closed field line region (within computational domain)
    814: "CLOSED",
    # point located in scrape-off layer (within computational domain)
    815: "SOL",
    # point located in private flux region (within computational domain)
    816: "PRIVFLUX",
    # point located in wall (outside computational domain, rho>rhomax)
    817: "WALL",
    # point located in divertor dome (outside computational domain, e.g. rho<rhomin_privflux)
    818: "DOME",
    # point located outside additional masks, i.e. shadow region (outside computational domain)
    819: "OUT",
}
def calculate_shaping(r_u, z_u):
    """
    Calculates the constant shaping arrays used to map unstructured (tricolumn) data
    onto a 2D (Z, R) matrix.

    Args:
        r_u: 1D array of the R coordinates of the unstructured points
        z_u: 1D array of the Z coordinates of the unstructured points (same length as r_u)

    Returns:
        mask: (Z, R) DataArray which is 0 in cells that hold a point and NaN elsewhere.
            Adding it to a shaped array blanks the off-grid cells.
        shaped_index: (Z, R) integer DataArray giving, for each cell, the index of the
            point located there. Off-grid cells hold 0 and must be blanked with `mask`.
        point_index: DataArray along 'points' holding 0 .. N-1
    """
    r_s, z_s = np.unique(r_u), np.unique(z_u)

    point_index = xr.DataArray(np.arange(r_u.size), dims='points')
    tricolumn_data = np.column_stack((r_u, z_u, point_index))

    # Pivot the (R, Z, index) triples into a table with one row per unique Z and one
    # column per unique R (both sorted); cells without a point become NaN
    index_table = pd.DataFrame(tricolumn_data, columns=['x', 'y', 'z']).pivot_table(
        values='z', index='y', columns='x', dropna=False).to_numpy()

    coords = {'Z': z_s, 'R': r_s}

    mask = xr.DataArray(np.where(np.isnan(index_table), np.nan, 0),
                        dims=['Z', 'R'], coords=coords)

    # The fill value must be given by keyword: the second positional argument of
    # np.nan_to_num is `copy`, not `nan`
    shaped_index = xr.DataArray(np.nan_to_num(index_table, nan=0).astype(int),
                                dims=['Z', 'R'], coords=coords)

    return mask, shaped_index, point_index
def shape_single(r_u, z_u, input_array):
    """
    Shapes a single array from tricolumn data onto a 2D (Z, R) matrix.

    Does the same as Grid.shape, but recomputes the shaping arrays from the
    coordinates on every call.

    Args:
        r_u: unstructured R coordinates of the points
        z_u: unstructured Z coordinates of the points
        input_array: DataArray with a 'points' dimension

    Returns:
        the input gathered onto (Z, R), with off-grid cells set to NaN by the mask
    """
    mask, index_map, _ = calculate_shaping(r_u, z_u)
    reshaped = input_array.isel(points=index_map)
    return reshaped + mask
class Grid:
    """
    Simulation grid loaded from a netCDF grid file.

    Stores the unstructured ('points') coordinates of the grid, the district of every
    (Z, R) cell, and the arrays needed to convert between unstructured (tricolumn)
    data and 2D (Z, R) matrices.
    """

    def __init__(self, grid_file: Path):
        """
        Initialises the grid object from a grid file

        Args:
            grid_file: path to the grid file (a str is also accepted)

        Raises:
            AssertionError: if the grid file does not exist, or if the grid is not
                equally spaced
        """
        grid_file = Path(grid_file)
        assert grid_file.exists(), f"Grid file {grid_file} does not exist"

        # Open the dataset in a context manager so the file handle is released.
        # Everything needed from the file is read into memory inside this block.
        with xr.open_dataset(grid_file) as grid_dataset:
            r_u, z_u, self.spacing = self.get_unstructured_grid_arrays_from_file(grid_dataset)
            r_s, z_s = np.unique(r_u), np.unique(z_u)
            self.size = r_u.size
            self.grid_size = grid_dataset.sizes['dim_nl']
            self.r_u, self.z_u, self.r_s, self.z_s = r_u, z_u, r_s, z_s

            mask, shaped_index, point_index = calculate_shaping(r_u, z_u)

            self.point_index, self.shaped_index, self.mask = point_index, shaped_index, mask

            # Read and shape the districts array (off-grid cells become NaN via the mask)
            raw_districts = grid_dataset['info'][0, :]
            raw_districts = self.shape(raw_districts.rename({'dim_nl': 'points'}))
            # NaN (off-grid) cells are given the OFF_GRID code. Computed once, not per key.
            district_codes = np.nan_to_num(raw_districts.values, nan=-1000).astype(int)

        # Start from empty names rather than uninitialised memory, so that a code
        # missing from district_dict yields a well-defined value ("")
        name_width = max(len(name) for name in district_dict.values())
        districts = np.full(district_codes.shape, "", dtype=f"<U{name_width}")
        for code, name in district_dict.items():
            districts[district_codes == code] = name

        self.districts = xr.DataArray(districts, dims=('Z', 'R'), coords={'Z': z_s, 'R': r_s})

    @staticmethod
    def get_unstructured_grid_arrays_from_file(grid_ds: xr.Dataset):
        """
        Given a netcdf grid file, return the x_unstructured and y_unstructured arrays

        Args:
            grid_ds: dataset providing `xmin`, `ymin`, `hf` and the 'li'/'lj' variables
                (presumably 1-based integer grid indices and `hf` the grid spacing,
                given the `- 1` offset -- confirm against the grid file format)

        Returns:
            (x_unstructured, y_unstructured, spacing)

        Raises:
            AssertionError: if the x or y coordinates are not equally spaced with the
                same spacing
        """
        x_unstructured = grid_ds.xmin + (grid_ds['li'].values - 1) * grid_ds.hf
        y_unstructured = grid_ds.ymin + (grid_ds['lj'].values - 1) * grid_ds.hf

        x_steps = np.diff(np.unique(x_unstructured))
        y_steps = np.diff(np.unique(y_unstructured))

        spacing = np.mean(x_steps)
        assert np.allclose(x_steps, spacing), "Error: x basis vector is not equally spaced."
        # The y coordinates must also share the x spacing
        assert np.allclose(y_steps, spacing), "Error: y basis vector is not equally spaced."

        return x_unstructured, y_unstructured, spacing

    def shape(self, input_array: xr.DataArray):
        """
        Converts from tricolumn data to a 2D matrix

        Gathers each (Z, R) cell from the 'points' dimension of input_array using
        shaped_index, then adds the mask so that off-grid cells become NaN.
        """
        return input_array.isel(points=self.shaped_index) + self.mask

    def flatten(self, input_array: xr.DataArray):
        """
        Converts from a 2D matrix to tricolumn data

        This is the inverse of `shape`. Each on-grid (Z, R) cell is written back to the
        point given by shaped_index, so the output follows the original point ordering
        whatever the layout of the points. Points with no cell of their own stay NaN.
        Dimensions other than (Z, R) are vectorised over, and dask-backed inputs are
        handled through apply_ufunc's 'parallelized' mode. The output is float64.
        """
        # The mask is 0 on the grid and NaN off it
        on_grid = np.logical_not(np.isnan(self.mask.values))
        point_of_cell = self.shaped_index.values[on_grid]

        def flatten_2D(array):
            flat = np.full(self.size, np.nan, dtype=np.float64)
            flat[point_of_cell] = array[on_grid]
            return flat

        return xr.apply_ufunc(flatten_2D, input_array,
                              input_core_dims=[['Z', 'R'], ],
                              output_core_dims=[['points']],
                              vectorize=True,
                              dask='parallelized',
                              dask_gufunc_kwargs={"output_sizes": {"points": self.size}},
                              output_dtypes=[np.float64]
                              )