Skip to content

Commit 21df1ba

Browse files
committed
remove unnecessary print statements
1 parent eedb416 commit 21df1ba

4 files changed

Lines changed: 39 additions & 83 deletions

File tree

emulator/src/data/climate_dataset.py

Lines changed: 23 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,18 @@
1-
import copy
2-
import logging
31
import os
42
import glob
5-
import pickle
6-
import shutil
73
import zipfile
8-
from typing import Dict, Optional, List, Callable, Tuple, Union
4+
from typing import Dict, Optional, List, Tuple, Union
95

106
import numpy as np
117
import xarray as xr
128
import torch
13-
from torch import Tensor
149

1510
from emulator.src.utils.utils import get_logger, map_variables_targetmip
1611
from emulator.src.data.constants import (
1712
LAT,
1813
LON,
1914
SEQ_LEN,
20-
INPUT4MIPS_TEMP_RES,
2115
CMIP6_TEMP_RES,
22-
INPUT4MIPS_NOM_RES,
2316
CMIP6_NOM_RES,
2417
DATA_DIR,
2518
OPENBURNING_MODEL_MAPPING,
@@ -130,9 +123,9 @@ def __init__(
130123
seq_len=seq_len,
131124
)
132125
# creates on cmip and on input4mip dataset
133-
print("Creating input4mips...")
126+
#print("Creating input4mips...")
134127
self.input4mips_ds = Input4MipsDataset(variables=in_variables_im, **ds_kwargs)
135-
print("Creating cmip6...")
128+
#print("Creating cmip6...")
136129
self.cmip6_ds = CMIP6Dataset(
137130
climate_model=climate_model,
138131
num_ensembles=num_ensembles,
@@ -151,7 +144,7 @@ def load_into_mem(
151144
): # -> np.ndarray():
152145
array_list = []
153146
for vlist in paths:
154-
print("Number of files per var:", len(vlist))
147+
#print("Number of files per var:", len(vlist))
155148
temp_data = xr.open_mfdataset(
156149
vlist, concat_dim="time", combine="nested"
157150
).compute() # .compute is not necessary but eh, doesn't hurt
@@ -162,9 +155,9 @@ def load_into_mem(
162155
temp_data = np.concatenate(array_list, axis=0)
163156

164157
if seq_len != SEQ_LEN:
165-
print(
166-
"Choosing a sequence length greater or lesser than the data sequence length."
167-
)
158+
# print(
159+
# "Choosing a sequence length greater or lesser than the data sequence length."
160+
# )
168161
new_num_years = int(
169162
np.floor(temp_data.shape[1] / seq_len / len(self.scenarios))
170163
)
@@ -272,10 +265,7 @@ def get_years_list(self, years: str, give_list: Optional[bool] = False):
272265
273266
"""
274267
if len(years) != 9:
275-
log.warn(
276-
"Years string must be in the format xxxx-yyyy eg. 2015-2100 with string length 9. Please check the year string."
277-
)
278-
raise ValueError
268+
raise ValueError("Years string must be in the format xxxx-yyyy eg. 2015-2100 with string length 9. Please check the year string.")
279269
splits = years.split("-")
280270
min_year, max_year = int(splits[0]), int(splits[1])
281271

@@ -292,9 +282,9 @@ def get_dataset_statistics(self, data, mode, type="z-norm", mips="cmip6"):
292282
min_val, max_val = self.get_min_max(data)
293283
return min_val, max_val
294284
else:
295-
print("Normalizing of type {0} has not been implemented!".format(type))
285+
raise NotImplementedError(f"Normalizing of type {type} has not been implemented!")
296286
else:
297-
print("In testing mode, skipping statistics calculations.")
287+
log.warning("In testing mode, skipping statistics calculations.")
298288

299289
def get_mean_std(self, data):
300290
# data shape (years*scenarios, seq, vars, lat, lon)
@@ -330,7 +320,7 @@ def normalize_data(self, data, stats, type="z-norm"):
330320
# z-norm: (data-mean)/(std + eps); eps=1e-9
331321
# min-max = (v - v.min()) / (v.max() - v.min())
332322

333-
print("Normalizing data...")
323+
#print("Normalizing data...")
334324
if self.channels_last:
335325
data = np.moveaxis(
336326
data, -1, 0
@@ -340,7 +330,7 @@ def normalize_data(self, data, stats, type="z-norm"):
340330
data, 2, 0
341331
) # shape (years, seq_len, num_vars, lat, lon) -> (num_vars, years, seq_len, lat, lon)
342332

343-
print("mean", stats["mean"].shape, "std", stats["std"].shape)
333+
#print("mean", stats["mean"].shape, "std", stats["std"].shape)
344334
norm_data = (data - stats["mean"]) / (stats["std"])
345335

346336
if self.channels_last:
@@ -395,9 +385,9 @@ def __str__(self):
395385
return s
396386

397387
def __len__(self):
398-
print(
399-
"Input4mips", self.input4mips_ds.length, "CMIP6 data", self.cmip6_ds.length
400-
)
388+
# print(
389+
# "Input4mips", self.input4mips_ds.length, "CMIP6 data", self.cmip6_ds.length
390+
# )
401391
# cmip must be num_ensemble members times input4mips
402392
assert (
403393
self.input4mips_ds.length * self.num_ensembles == self.cmip6_ds.length
@@ -453,18 +443,15 @@ def __init__( # inherits all the stuff from Base
453443
if isinstance(climate_model, str):
454444
self.root_dir = os.path.join(self.root_dir, climate_model)
455445
else:
456-
log.warn(
457-
"For loading multiple climate models, please make sure to use the Super Climate Dataset Class."
458-
)
459-
raise NotImplementedError
446+
raise NotImplementedError("For loading multiple climate models, please make sure to use the Super Climate Dataset Class.")
460447

461448
if num_ensembles == 1:
462449
ensembles = os.listdir(self.root_dir)
463450
self.ensemble_dir = [
464451
os.path.join(self.root_dir, ensembles[0])
465452
] # Taking first ensemble member
466453
else:
467-
print("Multiple ensembles", num_ensembles)
454+
#print("Multiple ensembles", num_ensembles)
468455
self.ensemble_dir = []
469456
ensembles = os.listdir(self.root_dir)
470457
for i, folder in enumerate(ensembles):
@@ -484,7 +471,7 @@ def __init__( # inherits all the stuff from Base
484471
os.path.join(output_save_dir, fname)
485472
): # we first need to get the name here to test that...
486473
self.data_path = os.path.join(output_save_dir, fname)
487-
print("path exists, reloading")
474+
#print("path exists, reloading")
488475
self.Data = self._reload_data(self.data_path)
489476

490477
# Load stats and normalize
@@ -516,13 +503,7 @@ def __init__( # inherits all the stuff from Base
516503
)
517504
files = glob.glob(var_dir + f"/*.nc", recursive=True)
518505
if len(files) == 0:
519-
print(
520-
"No files for this scenario, year, ensemble member pairing:",
521-
exp,
522-
y,
523-
em,
524-
)
525-
exit(0)
506+
raise FileNotFoundError(f"No files could be found for scenario {exp}, year {y}, and ensemble member {em}. Check if climate model runs for that pairing actually exist.")
526507
# loads all years!
527508
output_nc_files += files
528509
files_per_var.append(output_nc_files)
@@ -540,7 +521,6 @@ def __init__( # inherits all the stuff from Base
540521
)
541522

542523
if os.path.isfile(stats_fname):
543-
print("Stats file already exists! Loading from memory.")
544524
stats = self.load_statistics_data(stats_fname)
545525
self.norm_data = self.normalize_data(self.raw_data, stats)
546526

@@ -552,7 +532,7 @@ def __init__( # inherits all the stuff from Base
552532
self.norm_data = self.normalize_data(self.raw_data, stats)
553533

554534
save_file_name = self.write_dataset_statistics(stats_fname, stats)
555-
print("WROTE STATISTICS", save_file_name)
535+
#print("WROTE STATISTICS", save_file_name)
556536

557537
self.norm_data = self.normalize_data(self.raw_data, stats)
558538

@@ -630,14 +610,14 @@ def __init__(
630610
os.path.join(output_save_dir, fname)
631611
): # we first need to get the name here to test that...
632612
self.data_path = os.path.join(output_save_dir, fname)
633-
print("path exists, reloading")
613+
#print("path exists, reloading")
634614
self.Data = self._reload_data(self.data_path)
635615

636616
# Load stats and normalize
637617
stats_fname = self.get_save_name_from_kwargs(
638618
mode=mode, file="statistics", kwargs=fname_kwargs
639619
)
640-
print(stats_fname)
620+
#print(stats_fname)
641621
stats = self.load_dataset_statistics(
642622
os.path.join(self.output_save_dir, stats_fname),
643623
mode=self.mode,
@@ -705,7 +685,7 @@ def __init__(
705685
)
706686

707687
if os.path.isfile(stats_fname):
708-
print("Stats file already exists! Loading from mempory.")
688+
#print("Stats file already exists! Loading from mempory.")
709689
stats = self.load_statistics_data(stats_fname)
710690
self.norm_data = self.normalize_data(self.raw_data, stats)
711691

emulator/src/data/super_climate_dataset.py

Lines changed: 13 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,21 @@
1-
import copy
2-
import logging
31
import os
42
import glob
5-
import pickle
6-
import shutil
73
import zipfile
8-
from typing import Dict, Optional, List, Callable, Tuple, Union
9-
import copy
4+
from typing import Dict, Optional, List, Tuple, Union
105
import numpy as np
116
import xarray as xr
127
import torch
13-
from torch import Tensor
14-
import threading
158

169

1710
from emulator.src.utils.utils import get_logger, all_equal, map_variables_targetmip
1811
from emulator.src.data.constants import (
1912
LON,
2013
LAT,
2114
SEQ_LEN,
22-
INPUT4MIPS_TEMP_RES,
2315
CMIP6_TEMP_RES,
24-
INPUT4MIPS_NOM_RES,
2516
CMIP6_NOM_RES,
2617
DATA_DIR,
27-
OPENBURNING_MODEL_MAPPING,
2818
NO_OPENBURNING_VARS,
29-
AVAILABLE_MODELS_FIRETYPE,
3019
)
3120
log = get_logger()
3221
from abc import ABC, abstractmethod
@@ -267,9 +256,9 @@ def get_dataset_statistics(self, data: np.ndarray, mode: str, norm_type: str = "
267256
elif norm_type == "minmax":
268257
return self.get_min_max(data)
269258
else:
270-
print(f"Normalization of type {norm_type} has not been implemented!")
259+
raise NotImplementedError(f"Normalization of type {norm_type} has not been implemented!")
271260
else:
272-
print("In testing mode, skipping statistics calculations.")
261+
log.warning("In testing mode, skipping statistics calculations.")
273262

274263
def get_mean_std(self, data: np.ndarray):
275264
"""
@@ -329,8 +318,7 @@ def __len__(self) -> int:
329318
elif self.mode == 'val':
330319
return len(self.index_manager.val_indexes)
331320
else:
332-
print(f"Unknown mode: {self.mode}")
333-
raise ValueError
321+
raise ValueError(f"Unknown mode: {self.mode}")
334322

335323

336324
class SuperClimateDataset(ABC_Climate_Dataset):
@@ -473,10 +461,7 @@ def get_years_list(self, years: str, give_list: Optional[bool] = False):
473461
474462
"""
475463
if len(years) != 9:
476-
log.warn(
477-
"Years string must be in the format xxxx-yyyy eg. 2015-2100 with string length 9. Please check the year string."
478-
)
479-
raise ValueError
464+
raise ValueError("Years string must be in the format xxxx-yyyy eg. 2015-2100 with string length 9. Please check the year string.")
480465
splits = years.split("-")
481466
min_year, max_year = int(splits[0]), int(splits[1])
482467

@@ -519,8 +504,7 @@ def __getitem__(self, index): # Dict[str, Tensor]):
519504
return X, Y, model_id
520505

521506
def __str__(self):
522-
s = f" Super Emulator dataset: {len(self.index_manager.climate_models)} climate models with {self.index_manager.num_ensembles} ensemble members and {self.n_years} years used, with a total size of {len(self)} examples (in, out)."
523-
return s
507+
return f" Super Emulator dataset: {len(self.index_manager.climate_models)} climate models with {self.index_manager.num_ensembles} ensemble members and {self.n_years} years used, with a total size of {len(self)} examples (in, out)."
524508

525509

526510
def __len__(self):
@@ -531,8 +515,7 @@ def __len__(self):
531515
# elif self.mode=='train+val':
532516
# return self.get_initial_length()
533517
else:
534-
print("Unknown mode.", self.mode)
535-
raise ValueError
518+
raise ValueError(f"Unknown mode: {self.mode}")
536519

537520

538521

@@ -588,7 +571,7 @@ def __init__( # inherits all the stuff from Base
588571
os.path.join(output_save_dir, fname)
589572
): # we first need to get the name here to test that...
590573
self.data_path = os.path.join(output_save_dir, fname)
591-
print("path exists, reloading")
574+
#print("path exists, reloading")
592575
self.Data = self._reload_data(self.data_path)
593576

594577
# Load stats and normalize
@@ -621,16 +604,7 @@ def __init__( # inherits all the stuff from Base
621604
)
622605
files = glob.glob(var_dir + f"/*.nc", recursive=True)
623606
if len(files) == 0:
624-
print(
625-
"No files for this climate model, ensemble member, var, year ,scenario:",
626-
climate_model,
627-
data_dir.split("/")[-1],
628-
var,
629-
y,
630-
exp,
631-
)
632-
print("Exiting! Please fix the data issue.")
633-
exit(0)
607+
raise FileNotFoundError(f"No files for climate model {climate_model}, ensemble member {data_dir.split("/")[-1]}, var {var}, year {y}, scenario {exp}. Please check if climate model runs for this exact pairing actually exist.")
634608
# loads all years! implement splitting
635609
output_nc_files += files
636610
files_per_var.append(output_nc_files)
@@ -648,7 +622,7 @@ def __init__( # inherits all the stuff from Base
648622
)
649623

650624
if os.path.isfile(fname):
651-
print("Stats file already exists! Loading from memory.")
625+
#print("Stats file already exists! Loading from memory.")
652626
stats = self.load_statistics_data(stats_fname)
653627
self.norm_data = self.normalize_data(self.raw_data, stats)
654628

@@ -659,7 +633,7 @@ def __init__( # inherits all the stuff from Base
659633
stats = {"mean": stat1, "std": stat2}
660634
self.norm_data = self.normalize_data(self.raw_data, stats)
661635
save_file_name = self.write_dataset_statistics(stats_fname, stats)
662-
print("WROTE STATISTICS", save_file_name)
636+
#print("WROTE STATISTICS", save_file_name)
663637

664638
self.norm_data = self.normalize_data(self.raw_data, stats)
665639

@@ -742,7 +716,7 @@ def __init__( # inherits all the stuff from Base
742716
os.path.join(output_save_dir, fname)
743717
): # we first need to get the name here to test that...
744718
self.data_path = os.path.join(output_save_dir, fname)
745-
print("path exists, reloading")
719+
#print("path exists, reloading")
746720
self.Data = self._reload_data(self.data_path)
747721

748722
# Load stats and normalize
@@ -813,7 +787,7 @@ def __init__( # inherits all the stuff from Base
813787
)
814788

815789
if os.path.isfile(stats_fname):
816-
print("Stats file already exists! Loading from mempory.")
790+
#print("Stats file already exists! Loading from mempory.")
817791
stats = self.load_statistics_data(stats_fname)
818792
self.norm_data = self.normalize_data(self.raw_data, stats)
819793

emulator/src/datamodules/climate_datamodule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def __init__(
104104
for model in self.test_models
105105
]
106106
self.emissions_tracker = self.hparams.emissions_tracker
107-
print("Test Sets: ", self.test_set_names)
107+
#print("Test Sets: ", self.test_set_names)
108108

109109
self._data_train = None
110110
self._data_val = None

emulator/src/utils/interface.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ def get_datamodule(config: DictConfig) -> DummyDataModule:
112112
# _recursive_=False
113113
# )
114114

115+
# hydra automaticall instantiates the right class type (specified in the config)
116+
# to test this you can run print(type(data_module).__name__)
115117
data_module: DummyDataModule = hydra.utils.instantiate(
116118
config.datamodule,
117119
# input_transform=config.model.get("input_transform"),

0 commit comments

Comments
 (0)