Coverage for tcvx21/file_io/json_io_m.py: 86%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

28 statements  

1""" 

2Routines for writing dictionaries to JSON-format files, and reading dictionaries from JSON-format files 

3 

4Automatically handles numpy array I/O 

5""" 

6 

7import json 

8import numpy as np 

9from pathlib import Path 

10 

11 

class NumpyEncoder(json.JSONEncoder):
    """
    JSON encoder that also understands numpy types.

    Adapted from https://stackoverflow.com/questions/26646362/numpy-array-is-not-json-serializable

    - ``np.ndarray`` is converted to a (nested) Python list via ``tolist()``.
    - numpy scalars (``np.generic``, e.g. ``np.int64``, ``np.float32``, ``np.bool_``) are
      converted to the equivalent Python scalar via ``item()``; the stock encoder rejects
      these with a TypeError.

    Any other object is deferred to ``json.JSONEncoder.default``, which raises TypeError.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # Values obtained by indexing an array or by reductions (arr.max(), arr.sum())
            # are numpy scalars, not Python ones. item() returns a builtin Python value,
            # which the encoder either serialises or passes back here (and on to super()).
            return obj.item()
        return super().default(obj)

22 

23 

def recursive_convert_list_to_array(data: dict) -> dict:
    """
    Replace every list value of a (possibly nested) dictionary with ``np.asarray(value)``.

    Values that are themselves dictionaries are walked recursively. All other values are
    left untouched, and dictionaries sitting inside lists are not visited.

    The dictionary is modified in place; the same object is returned for convenience.
    """

    for name in data:
        entry = data[name]
        if isinstance(entry, dict):
            # The recursive call mutates `entry` and hands back the same object.
            data[name] = recursive_convert_list_to_array(entry)
        elif isinstance(entry, list):
            data[name] = np.asarray(entry)

    return data

38 

39 

def write_to_json(data: dict, filepath: Path, allow_overwrite: bool = False):
    """
    Write a dictionary of data to a JSON file at ``filepath``.

    numpy arrays in ``data`` are serialised through NumpyEncoder. The file is written as
    UTF-8 with non-ASCII characters kept verbatim and an indent of 4.

    :param data: dictionary to write
    :param filepath: destination path; must have the ``.json`` suffix. A ``str`` is also accepted.
    :param allow_overwrite: if False (default), refuse to replace an existing file
    :raises AssertionError: if the suffix is not ``.json``, or if the file already exists and
        ``allow_overwrite`` is False
    :raises FileExistsError: if ``allow_overwrite`` is False and the file appears between the
        existence check and the open (or the check was stripped by ``python -O``)
    """
    filepath = Path(filepath)

    assert filepath.suffix == '.json', f"JSON files should be identified with the .json suffix (got {filepath})"

    if not allow_overwrite:
        assert not filepath.exists(), f"File {filepath} already exists; pass allow_overwrite=True to replace it"

    # Exclusive-create mode ('x') keeps the no-overwrite guarantee even when the assertion
    # above is stripped by `python -O`, or another process creates the file in the meantime.
    mode = 'w' if allow_overwrite else 'x'
    with open(filepath, mode, encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=4, cls=NumpyEncoder)

52 

53 

def read_from_json(filepath: Path, convert_list_to_array: bool = True) -> dict:
    """
    Read a dictionary of data from the JSON file at ``filepath``.

    :param filepath: path of the file to read (decoded as UTF-8); a ``str`` is also accepted
    :param convert_list_to_array: if True (default), every list value in the (nested)
        dictionary is replaced by an np.array via recursive_convert_list_to_array
    :returns: the decoded data
    :raises AssertionError: if ``filepath`` does not exist
    """
    # Accept plain strings as well as Path objects; str has no .exists().
    filepath = Path(filepath)

    assert filepath.exists(), f"File {filepath} not found"

    with open(filepath, 'r', encoding='utf-8') as f:
        data = json.load(f)

    if convert_list_to_array:
        data = recursive_convert_list_to_array(data)

    return data