diff --git a/ds_toolkit.yaml b/ds_toolkit.yaml index 4e14491..24a68a1 100644 --- a/ds_toolkit.yaml +++ b/ds_toolkit.yaml @@ -23,3 +23,5 @@ dependencies: - imageio - scikit-image - optuna + - seaborn + - hydra-core diff --git a/ds_toolkit_metal.yaml b/ds_toolkit_metal.yaml index 057a4d8..d469403 100644 --- a/ds_toolkit_metal.yaml +++ b/ds_toolkit_metal.yaml @@ -26,3 +26,5 @@ dependencies: - imageio - scikit-image - optuna + - seaborn + - hydra-core diff --git a/jlab_datascience_toolkit/agents/__init__.py b/jlab_datascience_toolkit/agents/__init__.py new file mode 100644 index 0000000..28eb64b --- /dev/null +++ b/jlab_datascience_toolkit/agents/__init__.py @@ -0,0 +1,5 @@ +from jlab_datascience_toolkit.agents.registration import ( + register, + make, + list_registered_modules, +) diff --git a/jlab_datascience_toolkit/agents/registration.py b/jlab_datascience_toolkit/agents/registration.py new file mode 100644 index 0000000..29a0000 --- /dev/null +++ b/jlab_datascience_toolkit/agents/registration.py @@ -0,0 +1,109 @@ +import importlib +import logging + +module_log = logging.getLogger("Module Registry") + + +def load(name): + mod_name, attr_name = name.split(":") + print(f"Attempting to load {mod_name} with {attr_name}") + mod = importlib.import_module(mod_name) + fn = getattr(mod, attr_name) + return fn + + +class ModuleSpec(object): + def __init__(self, id, entry_point=None, kwargs=None): + self.id = id + self.entry_point = entry_point + self._kwargs = {} if kwargs is None else kwargs + + def make(self, **kwargs): + """Instantiates an instance of data module with appropriate kwargs""" + if self.entry_point is None: + module_log.error( + "Attempting to make deprecated module {}. \ + (HINT: is there a newer registered version \ + of this module?)".format( + self.id + ) + ) + raise RuntimeError + + _kwargs = self._kwargs.copy() + _kwargs.update(kwargs) + if callable(self.entry_point): + gen = self.entry_point(**_kwargs) + else: + cls = load(self.entry_point) + gen = cls(**_kwargs) + + return gen + + +class ModuleRegistry(object): + def __init__(self): + self.module_specs = {} + + def make(self, path, **kwargs): + if len(kwargs) > 0: + module_log.info("Making new module: %s (%s)", path, kwargs) + else: + module_log.info("Making new module: %s", path) + module_spec = self.spec(path) + module = module_spec.make(**kwargs) + + return module + + def all(self): + return self.module_specs.values() + + def spec(self, path): + if ":" in path: + mod_name, _sep, id = path.partition(":") + try: + importlib.import_module(mod_name) + except ImportError: + module_log.error( + "A module ({}) was specified for the module but was not found, \ + make sure the package is installed with `pip install` before \ + calling `module.make()`".format( + mod_name + ) + ) + raise + + else: + id = path + + try: + return self.module_specs[id] + except KeyError: + module_log.error("No registered module with id: {}".format(id)) + raise + + def register(self, id, **kwargs): + if id in self.module_specs: + module_log.error("Cannot re-register id: {}".format(id)) + raise RuntimeError + self.module_specs[id] = ModuleSpec(id, **kwargs) + + +# Global registry +module_registry = ModuleRegistry() + + +def register(id, **kwargs): + return module_registry.register(id, **kwargs) + + +def make(id, **kwargs): + return module_registry.make(id, **kwargs) + + +def spec(id): + return module_registry.spec(id) + + +def list_registered_modules(): + return list(module_registry.module_specs.keys()) diff --git a/jlab_datascience_toolkit/analyses/__init__.py b/jlab_datascience_toolkit/analyses/__init__.py index 142d64b..28ca97f 100644 --- a/jlab_datascience_toolkit/analyses/__init__.py +++ b/jlab_datascience_toolkit/analyses/__init__.py @@ -1,4 +1,4 @@ -from jlab_datascience_toolkit.utils.registration import ( +from jlab_datascience_toolkit.analyses.registration import ( register, make, list_registered_modules, diff --git a/jlab_datascience_toolkit/analyses/multiclass_analysis_v0.py b/jlab_datascience_toolkit/analyses/multiclass_analysis_v0.py index 1027f88..5960398 100644 --- a/jlab_datascience_toolkit/analyses/multiclass_analysis_v0.py +++ b/jlab_datascience_toolkit/analyses/multiclass_analysis_v0.py @@ -5,8 +5,8 @@ class Analysis: - def __init__(self, configs: dict): - self.configs = configs + def __init__(self, config: dict): + self.configs = config def run( self, diff --git a/jlab_datascience_toolkit/analyses/registration.py b/jlab_datascience_toolkit/analyses/registration.py new file mode 100644 index 0000000..29a0000 --- /dev/null +++ b/jlab_datascience_toolkit/analyses/registration.py @@ -0,0 +1,109 @@ +import importlib +import logging + +module_log = logging.getLogger("Module Registry") + + +def load(name): + mod_name, attr_name = name.split(":") + print(f"Attempting to load {mod_name} with {attr_name}") + mod = importlib.import_module(mod_name) + fn = getattr(mod, attr_name) + return fn + + +class ModuleSpec(object): + def __init__(self, id, entry_point=None, kwargs=None): + self.id = id + self.entry_point = entry_point + self._kwargs = {} if kwargs is None else kwargs + + def make(self, **kwargs): + """Instantiates an instance of data module with appropriate kwargs""" + if self.entry_point is None: + module_log.error( + "Attempting to make deprecated module {}. \ + (HINT: is there a newer registered version \ + of this module?)".format( + self.id + ) + ) + raise RuntimeError + + _kwargs = self._kwargs.copy() + _kwargs.update(kwargs) + if callable(self.entry_point): + gen = self.entry_point(**_kwargs) + else: + cls = load(self.entry_point) + gen = cls(**_kwargs) + + return gen + + +class ModuleRegistry(object): + def __init__(self): + self.module_specs = {} + + def make(self, path, **kwargs): + if len(kwargs) > 0: + module_log.info("Making new module: %s (%s)", path, kwargs) + else: + module_log.info("Making new module: %s", path) + module_spec = self.spec(path) + module = module_spec.make(**kwargs) + + return module + + def all(self): + return self.module_specs.values() + + def spec(self, path): + if ":" in path: + mod_name, _sep, id = path.partition(":") + try: + importlib.import_module(mod_name) + except ImportError: + module_log.error( + "A module ({}) was specified for the module but was not found, \ + make sure the package is installed with `pip install` before \ + calling `module.make()`".format( + mod_name + ) + ) + raise + + else: + id = path + + try: + return self.module_specs[id] + except KeyError: + module_log.error("No registered module with id: {}".format(id)) + raise + + def register(self, id, **kwargs): + if id in self.module_specs: + module_log.error("Cannot re-register id: {}".format(id)) + raise RuntimeError + self.module_specs[id] = ModuleSpec(id, **kwargs) + + +# Global registry +module_registry = ModuleRegistry() + + +def register(id, **kwargs): + return module_registry.register(id, **kwargs) + + +def make(id, **kwargs): + return module_registry.make(id, **kwargs) + + +def spec(id): + return module_registry.spec(id) + + +def list_registered_modules(): + return list(module_registry.module_specs.keys()) diff --git a/jlab_datascience_toolkit/cores/jdst_agent.py b/jlab_datascience_toolkit/cores/jdst_agent.py new file mode 100644 index 0000000..2035965 --- /dev/null +++ b/jlab_datascience_toolkit/cores/jdst_agent.py @@ -0,0 +1,13 @@ +from jlab_datascience_toolkit.cores.jdst_module import JDSTModule +from abc import ABC, abstractmethod + + +class JDSTAgent(JDSTModule, ABC): + """ + Base class for an agent. This class inherits from the module base class. + """ + + # Get a prediction: + @abstractmethod + def predict(self): + raise NotImplementedError \ No newline at end of file diff --git a/jlab_datascience_toolkit/data_parsers/__init__.py b/jlab_datascience_toolkit/data_parsers/__init__.py index 1322e2a..57a3e8d 100644 --- a/jlab_datascience_toolkit/data_parsers/__init__.py +++ b/jlab_datascience_toolkit/data_parsers/__init__.py @@ -1,4 +1,4 @@ -from jlab_datascience_toolkit.utils.registration import ( +from jlab_datascience_toolkit.data_parsers.registration import ( register, make, list_registered_modules, diff --git a/jlab_datascience_toolkit/data_parsers/famous_datasets_v0.py b/jlab_datascience_toolkit/data_parsers/famous_datasets_v0.py index 6ed7878..728f129 100644 --- a/jlab_datascience_toolkit/data_parsers/famous_datasets_v0.py +++ b/jlab_datascience_toolkit/data_parsers/famous_datasets_v0.py @@ -1,8 +1,10 @@ from jlab_datascience_toolkit.cores.jdst_data_parser import JDSTDataParser import seaborn as sns import inspect +import logging import yaml +test_log = logging.Logger(__name__) class FamousDatasetsV0(JDSTDataParser): """Returns one of the example famous datasets such as iris. @@ -43,11 +45,14 @@ class FamousDatasetsV0(JDSTDataParser): save_data() Does nothing """ - def __init__(self, configs: dict): - self.configs = configs - self.dataset_name = configs['dataset_name'] - self.settings = {k: v for k, v in configs.items() if k not in {'dataset_name', 'registered_name'}} - + def __init__(self, config: dict): + self.configs = config + try: + self.dataset_name = config['dataset_name'] + self.settings = {k: v for k, v in config.items() if k not in {'dataset_name', 'registered_name'}} + except: + test_log.error(">>> No valid configuration provided <<<") + def load_data(self): if self.dataset_name == 'iris': return sns.load_dataset('iris', **self.settings) diff --git a/jlab_datascience_toolkit/data_parsers/numpy_parser.py b/jlab_datascience_toolkit/data_parsers/numpy_parser.py index 73a61fc..77826b7 100644 --- a/jlab_datascience_toolkit/data_parsers/numpy_parser.py +++ b/jlab_datascience_toolkit/data_parsers/numpy_parser.py @@ -13,7 +13,7 @@ class NumpyParser(JDSTDataParser): ii) Combine single .npy files into one Input(s): - i) Full path to .yaml configuration file + i) Configuration file ii) Optional: User configuration, i.e. a python dict with additonal / alternative settings Output(s): @@ -22,12 +22,12 @@ class NumpyParser(JDSTDataParser): # Initialize: # ********************************************* - def __init__(self, path_to_cfg, user_config={}): + def __init__(self, config, user_config={}): # Set the name specific to this module: self.module_name = "numpy_parser" # Load the configuration: - self.config = self.load_config(path_to_cfg, user_config) + self.config = self.load_config(config, user_config) # Save this config, if a path is provided: if "store_cfg_loc" in self.config: @@ -35,7 +35,7 @@ def __init__(self, path_to_cfg, user_config={}): # Run sanity check(s): # i) Make sure that the provide data path(s) are list objects: - if isinstance(self.config["data_loc"], list) == False: + if bool(self.config) and isinstance(self.config["data_loc"], list) == False: logging.error( ">>> " + self.module_name @@ -54,25 +54,24 @@ def get_info(self): # Handle configurations: # ********************************************* # Load the config: - def load_config(self, path_to_cfg, user_config): - with open(path_to_cfg, "r") as file: - cfg = yaml.safe_load(file) - - # Overwrite config with user settings, if provided - try: + def load_config(self, config, user_config): + if config is not None: + try: if bool(user_config): # ++++++++++++++++++++++++ for key in user_config: - cfg[key] = user_config[key] + config[key] = user_config[key] # ++++++++++++++++++++++++ - except: + except: logging.exception( ">>> " + self.module_name + ": Invalid user config. Please make sure that a dictionary is provided <<<" ) + else: + config = {} - return cfg + return config # ----------------------------- diff --git a/jlab_datascience_toolkit/data_parsers/registration.py b/jlab_datascience_toolkit/data_parsers/registration.py new file mode 100644 index 0000000..29a0000 --- /dev/null +++ b/jlab_datascience_toolkit/data_parsers/registration.py @@ -0,0 +1,109 @@ +import importlib +import logging + +module_log = logging.getLogger("Module Registry") + + +def load(name): + mod_name, attr_name = name.split(":") + print(f"Attempting to load {mod_name} with {attr_name}") + mod = importlib.import_module(mod_name) + fn = getattr(mod, attr_name) + return fn + + +class ModuleSpec(object): + def __init__(self, id, entry_point=None, kwargs=None): + self.id = id + self.entry_point = entry_point + self._kwargs = {} if kwargs is None else kwargs + + def make(self, **kwargs): + """Instantiates an instance of data module with appropriate kwargs""" + if self.entry_point is None: + module_log.error( + "Attempting to make deprecated module {}. \ + (HINT: is there a newer registered version \ + of this module?)".format( + self.id + ) + ) + raise RuntimeError + + _kwargs = self._kwargs.copy() + _kwargs.update(kwargs) + if callable(self.entry_point): + gen = self.entry_point(**_kwargs) + else: + cls = load(self.entry_point) + gen = cls(**_kwargs) + + return gen + + +class ModuleRegistry(object): + def __init__(self): + self.module_specs = {} + + def make(self, path, **kwargs): + if len(kwargs) > 0: + module_log.info("Making new module: %s (%s)", path, kwargs) + else: + module_log.info("Making new module: %s", path) + module_spec = self.spec(path) + module = module_spec.make(**kwargs) + + return module + + def all(self): + return self.module_specs.values() + + def spec(self, path): + if ":" in path: + mod_name, _sep, id = path.partition(":") + try: + importlib.import_module(mod_name) + except ImportError: + module_log.error( + "A module ({}) was specified for the module but was not found, \ + make sure the package is installed with `pip install` before \ + calling `module.make()`".format( + mod_name + ) + ) + raise + + else: + id = path + + try: + return self.module_specs[id] + except KeyError: + module_log.error("No registered module with id: {}".format(id)) + raise + + def register(self, id, **kwargs): + if id in self.module_specs: + module_log.error("Cannot re-register id: {}".format(id)) + raise RuntimeError + self.module_specs[id] = ModuleSpec(id, **kwargs) + + +# Global registry +module_registry = ModuleRegistry() + + +def register(id, **kwargs): + return module_registry.register(id, **kwargs) + + +def make(id, **kwargs): + return module_registry.make(id, **kwargs) + + +def spec(id): + return module_registry.spec(id) + + +def list_registered_modules(): + return list(module_registry.module_specs.keys()) diff --git a/jlab_datascience_toolkit/data_preps/__init__.py b/jlab_datascience_toolkit/data_preps/__init__.py index 303b12b..04148a6 100644 --- a/jlab_datascience_toolkit/data_preps/__init__.py +++ b/jlab_datascience_toolkit/data_preps/__init__.py @@ -1,9 +1,10 @@ -from jlab_datascience_toolkit.utils.registration import ( +from jlab_datascience_toolkit.data_preps.registration import ( register, make, list_registered_modules, ) + register( id="NumpyMinMaxScaler_v0", entry_point="jlab_datascience_toolkit.data_preps.numpy_minmax_scaler:NumpyMinMaxScaler", diff --git a/jlab_datascience_toolkit/data_preps/numpy_minmax_scaler.py b/jlab_datascience_toolkit/data_preps/numpy_minmax_scaler.py index 191f98f..3250121 100644 --- a/jlab_datascience_toolkit/data_preps/numpy_minmax_scaler.py +++ b/jlab_datascience_toolkit/data_preps/numpy_minmax_scaler.py @@ -10,21 +10,22 @@ class NumpyMinMaxScaler(JDSTDataPrep): # Initialize: # ********************************************* - def __init__(self, path_to_cfg, user_config={}): + def __init__(self, config, user_config={}): # Set the name specific to this module: self.module_name = "numpy_minmax_scaler" # Load the configuration: - self.config = self.load_config(path_to_cfg, user_config) + self.config = self.load_config(config, user_config) # Save this config, if a path is provided: if "store_cfg_loc" in self.config: self.save_config(self.config["store_cfg_loc"]) # Set up the scaler: - try: + if bool(self.config): + try: self.scaler = MinMaxScaler(self.config["feature_range"]) - except: + except: logging.exception( ">>> " + self.module_name @@ -39,7 +40,7 @@ def get_info(self): print(" ") print("*** Info: NumpyMinMaxScaler ***") print("Input(s):") - print("i) Full path to .yaml configuration file ") + print("i) Full configuration file ") print( "ii) Optional: User configuration, i.e. a python dict with additonal / alternative settings" ) @@ -62,25 +63,27 @@ def get_info(self): # Handle configurations: # ********************************************* # Load the config: - def load_config(self, path_to_cfg, user_config): - with open(path_to_cfg, "r") as file: - cfg = yaml.safe_load(file) + def load_config(self, config, user_config): + if config is not None: - # Overwrite config with user settings, if provided - try: + # Overwrite config with user settings, if provided + try: if bool(user_config): # ++++++++++++++++++++++++ for key in user_config: - cfg[key] = user_config[key] + config[key] = user_config[key] # ++++++++++++++++++++++++ - except: + except: logging.exception( ">>> " + self.module_name + ": Invalid user config. Please make sure that a dictionary is provided <<<" ) - return cfg + return config + + else: + return {} # ----------------------------- diff --git a/jlab_datascience_toolkit/data_preps/registration.py b/jlab_datascience_toolkit/data_preps/registration.py new file mode 100644 index 0000000..29a0000 --- /dev/null +++ b/jlab_datascience_toolkit/data_preps/registration.py @@ -0,0 +1,109 @@ +import importlib +import logging + +module_log = logging.getLogger("Module Registry") + + +def load(name): + mod_name, attr_name = name.split(":") + print(f"Attempting to load {mod_name} with {attr_name}") + mod = importlib.import_module(mod_name) + fn = getattr(mod, attr_name) + return fn + + +class ModuleSpec(object): + def __init__(self, id, entry_point=None, kwargs=None): + self.id = id + self.entry_point = entry_point + self._kwargs = {} if kwargs is None else kwargs + + def make(self, **kwargs): + """Instantiates an instance of data module with appropriate kwargs""" + if self.entry_point is None: + module_log.error( + "Attempting to make deprecated module {}. \ + (HINT: is there a newer registered version \ + of this module?)".format( + self.id + ) + ) + raise RuntimeError + + _kwargs = self._kwargs.copy() + _kwargs.update(kwargs) + if callable(self.entry_point): + gen = self.entry_point(**_kwargs) + else: + cls = load(self.entry_point) + gen = cls(**_kwargs) + + return gen + + +class ModuleRegistry(object): + def __init__(self): + self.module_specs = {} + + def make(self, path, **kwargs): + if len(kwargs) > 0: + module_log.info("Making new module: %s (%s)", path, kwargs) + else: + module_log.info("Making new module: %s", path) + module_spec = self.spec(path) + module = module_spec.make(**kwargs) + + return module + + def all(self): + return self.module_specs.values() + + def spec(self, path): + if ":" in path: + mod_name, _sep, id = path.partition(":") + try: + importlib.import_module(mod_name) + except ImportError: + module_log.error( + "A module ({}) was specified for the module but was not found, \ + make sure the package is installed with `pip install` before \ + calling `module.make()`".format( + mod_name + ) + ) + raise + + else: + id = path + + try: + return self.module_specs[id] + except KeyError: + module_log.error("No registered module with id: {}".format(id)) + raise + + def register(self, id, **kwargs): + if id in self.module_specs: + module_log.error("Cannot re-register id: {}".format(id)) + raise RuntimeError + self.module_specs[id] = ModuleSpec(id, **kwargs) + + +# Global registry +module_registry = ModuleRegistry() + + +def register(id, **kwargs): + return module_registry.register(id, **kwargs) + + +def make(id, **kwargs): + return module_registry.make(id, **kwargs) + + +def spec(id): + return module_registry.spec(id) + + +def list_registered_modules(): + return list(module_registry.module_specs.keys()) diff --git a/jlab_datascience_toolkit/data_preps/split_dataframe_v0.py b/jlab_datascience_toolkit/data_preps/split_dataframe_v0.py index 7cedb58..4a2d38f 100644 --- a/jlab_datascience_toolkit/data_preps/split_dataframe_v0.py +++ b/jlab_datascience_toolkit/data_preps/split_dataframe_v0.py @@ -12,17 +12,18 @@ class SplitDataFrame(JDSTDataPrep): Each array is then splitted by rows according to the given rows_fractions (which must add up to one). """ - def __init__(self, configs: dict): - self.configs = configs - self.feature_columns = configs.get( + def __init__(self, config: dict): + self.configs = config + if bool(self.configs): + self.feature_columns = config.get( "feature_columns", None - ) # If None, all columns are considered - self.target_columns = configs.get( + ) # If None, all columns are considered + self.target_columns = config.get( "target_columns", None - ) # If None, there will be no target array - self.rows_fractions = configs.get("rows_fractions", [1.0]) - self.random_state = configs.get("random_state", None) - assert sum(self.rows_fractions) == 1, "Fractions must add up to 1 !!!" + ) # If None, there will be no target array + self.rows_fractions = config.get("rows_fractions", [1.0]) + self.random_state = config.get("random_state", None) + assert sum(self.rows_fractions) == 1, "Fractions must add up to 1 !!!" @staticmethod def split_by_columns( diff --git a/jlab_datascience_toolkit/models/__init__.py b/jlab_datascience_toolkit/models/__init__.py index 30ffe9f..6481ae5 100644 --- a/jlab_datascience_toolkit/models/__init__.py +++ b/jlab_datascience_toolkit/models/__init__.py @@ -1,4 +1,4 @@ -from jlab_datascience_toolkit.utils.registration import ( +from jlab_datascience_toolkit.models.registration import ( register, make, list_registered_modules, diff --git a/jlab_datascience_toolkit/models/keras_mlp_v0.py b/jlab_datascience_toolkit/models/keras_mlp_v0.py index 5e53e39..4b05ff2 100644 --- a/jlab_datascience_toolkit/models/keras_mlp_v0.py +++ b/jlab_datascience_toolkit/models/keras_mlp_v0.py @@ -10,7 +10,7 @@ class KerasMLP(JDSTModel): Defines an MLP model. self is not a keras.Model itself. Instead, it has a "model" attribute which is a keras.Model. """ - def __init__(self, configs: dict): + def __init__(self, config: dict): """ configs has the following keywords: 1) 'input_dim' @@ -19,10 +19,11 @@ def __init__(self, configs: dict): 2.2) 'layer_type': 'Dropout', 'layer_configs': keras Dropout layer configs 2.3) 'layer_type': 'BatchNormalization', 'layer_configs': keras BN layer cnfigs """ - self.configs = configs - inputs = keras.layers.Input(shape=(configs["input_dim"],)) - outputs = inputs - for layer_dict in configs["layers_dicts"]: + self.configs = config + if bool(self.configs): + inputs = keras.layers.Input(shape=(config["input_dim"],)) + outputs = inputs + for layer_dict in config["layers_dicts"]: layer_type = layer_dict["layer_type"] layer_configs = layer_dict.get("layer_configs", {}) if layer_type == "Dense": @@ -33,7 +34,7 @@ def __init__(self, configs: dict): outputs = keras.layers.BatchNormalization(**layer_configs)(outputs) else: raise NameError("Unrecognized layer_type !!!") - self.model = keras.models.Model(inputs=inputs, outputs=outputs) + self.model = keras.models.Model(inputs=inputs, outputs=outputs) def predict(self, x): y = self.model.predict(x) diff --git a/jlab_datascience_toolkit/models/registration.py b/jlab_datascience_toolkit/models/registration.py new file mode 100644 index 0000000..29a0000 --- /dev/null +++ b/jlab_datascience_toolkit/models/registration.py @@ -0,0 +1,109 @@ +import importlib +import logging + +module_log = logging.getLogger("Module Registry") + + +def load(name): + mod_name, attr_name = name.split(":") + print(f"Attempting to load {mod_name} with {attr_name}") + mod = importlib.import_module(mod_name) + fn = getattr(mod, attr_name) + return fn + + +class ModuleSpec(object): + def __init__(self, id, entry_point=None, kwargs=None): + self.id = id + self.entry_point = entry_point + self._kwargs = {} if kwargs is None else kwargs + + def make(self, **kwargs): + """Instantiates an instance of data module with appropriate kwargs""" + if self.entry_point is None: + module_log.error( + "Attempting to make deprecated module {}. \ + (HINT: is there a newer registered version \ + of this module?)".format( + self.id + ) + ) + raise RuntimeError + + _kwargs = self._kwargs.copy() + _kwargs.update(kwargs) + if callable(self.entry_point): + gen = self.entry_point(**_kwargs) + else: + cls = load(self.entry_point) + gen = cls(**_kwargs) + + return gen + + +class ModuleRegistry(object): + def __init__(self): + self.module_specs = {} + + def make(self, path, **kwargs): + if len(kwargs) > 0: + module_log.info("Making new module: %s (%s)", path, kwargs) + else: + module_log.info("Making new module: %s", path) + module_spec = self.spec(path) + module = module_spec.make(**kwargs) + + return module + + def all(self): + return self.module_specs.values() + + def spec(self, path): + if ":" in path: + mod_name, _sep, id = path.partition(":") + try: + importlib.import_module(mod_name) + except ImportError: + module_log.error( + "A module ({}) was specified for the module but was not found, \ + make sure the package is installed with `pip install` before \ + calling `module.make()`".format( + mod_name + ) + ) + raise + + else: + id = path + + try: + return self.module_specs[id] + except KeyError: + module_log.error("No registered module with id: {}".format(id)) + raise + + def register(self, id, **kwargs): + if id in self.module_specs: + module_log.error("Cannot re-register id: {}".format(id)) + raise RuntimeError + self.module_specs[id] = ModuleSpec(id, **kwargs) + + +# Global registry +module_registry = ModuleRegistry() + + +def register(id, **kwargs): + return module_registry.register(id, **kwargs) + + +def make(id, **kwargs): + return module_registry.make(id, **kwargs) + + +def spec(id): + return module_registry.spec(id) + + +def list_registered_modules(): + return list(module_registry.module_specs.keys()) diff --git a/jlab_datascience_toolkit/trainers/keras_trainer_v0.py b/jlab_datascience_toolkit/trainers/keras_trainer_v0.py index f605067..0f0fc7d 100644 --- a/jlab_datascience_toolkit/trainers/keras_trainer_v0.py +++ b/jlab_datascience_toolkit/trainers/keras_trainer_v0.py @@ -52,10 +52,10 @@ class Trainer(JDSTTrainer): Satatic method loading configurations from a given path """ - def __init__(self, configs: dict): - self.configs = configs + def __init__(self, config: dict): + self.configs = config self.settings = ( - configs.copy() + config.copy() ) # Must be separate from configs as it can include actual keras callback objects self.settings.pop("registered_name") diff --git a/jlab_datascience_toolkit/verify_installation.py b/jlab_datascience_toolkit/verify_installation.py new file mode 100644 index 0000000..b7d10bf --- /dev/null +++ b/jlab_datascience_toolkit/verify_installation.py @@ -0,0 +1,18 @@ +import jlab_datascience_toolkit.data_parsers as parsers +import jlab_datascience_toolkit.data_preps as preps +import jlab_datascience_toolkit.models as models +import jlab_datascience_toolkit.agents as agents + + +modules = { + 'data parsers': parsers, + 'data preps': preps, + 'models': models, + 'agents': agents +} + +print(" ") +for mod in modules: + print(f"Available {mod}:") + print(modules[mod].list_registered_modules()) + print(" ") \ No newline at end of file diff --git a/jlab_datascience_toolkit/workflows/example_workflow_v0.py b/jlab_datascience_toolkit/workflows/example_workflow_v0.py index d785c31..dd22003 100644 --- a/jlab_datascience_toolkit/workflows/example_workflow_v0.py +++ b/jlab_datascience_toolkit/workflows/example_workflow_v0.py @@ -21,13 +21,13 @@ def main(configs: DictConfig): analysis_configs = configs["analysis_configs"] # 1) Load Data - parser = make_parser(parser_configs["registered_name"], configs=parser_configs) + parser = make_parser(parser_configs["registered_name"], config=parser_configs) df = parser.load_data() classes_list = [(c, i) for i, c in enumerate(df["species"].unique().tolist())] df["species_int"] = df["species"].map(dict(classes_list)) # 2) Split Data - prep = make_prep(prep_configs["registered_name"], configs=prep_configs) + prep = make_prep(prep_configs["registered_name"], config=prep_configs) x_train, x_val, x_test, y_train, y_val, y_test = prep.run(df) # 3) Scaling @@ -37,10 +37,10 @@ def main(configs: DictConfig): x_test = scaler.transform(x_test) # 4) Define Model - model = make_model(model_configs["registered_name"], configs=model_configs) + model = make_model(model_configs["registered_name"], config=model_configs) # 5) Train Model - trainer = make_trainer(trainer_configs["registered_name"], configs=trainer_configs) + trainer = make_trainer(trainer_configs["registered_name"], config=trainer_configs) history = trainer.fit( model=model, x=x_train, y=y_train, validation_data=(x_val, y_val), logdir=logdir ) @@ -49,7 +49,7 @@ def main(configs: DictConfig): y_pred = model.predict(x_test) # (n_samples, c_classes) y_pred = y_pred.argmax(axis=1) # (n_samples) multiclass_ana = make_analysis( - analysis_configs["registered_name"], configs=analysis_configs + analysis_configs["registered_name"], config=analysis_configs ) results = multiclass_ana.run( y_test, diff --git a/utests/utest_famous_datasets_v0.py b/utests/utest_famous_datasets_v0.py index 24a0a97..a4d22bb 100644 --- a/utests/utest_famous_datasets_v0.py +++ b/utests/utest_famous_datasets_v0.py @@ -5,7 +5,7 @@ class TestSplitDataFrame(unittest.TestCase): def test_iris(self): - parser = FamousDatasetsV0(configs={'dataset_name': 'iris'}) + parser = FamousDatasetsV0(config={'dataset_name': 'iris'}) df = parser.load_data() self.assertTrue(isinstance(df, pd.DataFrame)) diff --git a/utests/utest_keras_mlp_v0.py b/utests/utest_keras_mlp_v0.py index f74d7bf..56381ac 100644 --- a/utests/utest_keras_mlp_v0.py +++ b/utests/utest_keras_mlp_v0.py @@ -23,7 +23,7 @@ def setUp(cls): }, ], } - cls.model = make_model(cls.configs["registered_name"], configs=cls.configs) + cls.model = make_model(cls.configs["registered_name"], config=cls.configs) cls.x = np.random.rand(100, 4) cls.model_folder = "./model_folder/" @@ -34,7 +34,7 @@ def test_predict(self): def test_save_and_load(self): y_pred_old = self.model.predict(self.x) self.model.save(self.model_folder) - model_new = make_model(self.configs["registered_name"], configs=self.configs) + model_new = make_model(self.configs["registered_name"], config=self.configs) model_new.load(self.model_folder) y_pred_new = model_new.predict(self.x) self.assertTrue(np.array_equal(y_pred_old, y_pred_new)) diff --git a/utests/utest_numpy_minmax_scaler.py b/utests/utest_numpy_minmax_scaler.py index 0fdaa8b..82dee94 100644 --- a/utests/utest_numpy_minmax_scaler.py +++ b/utests/utest_numpy_minmax_scaler.py @@ -38,15 +38,9 @@ def test_drive_numpy_minmax_scaler(self): # Now load the scaler by defining a user config first: print("Load numpy min max scaler...") - this_file_loc = os.path.dirname(__file__) - cfg_loc = os.path.join( - this_file_loc, - "../jlab_datascience_toolkit/cfgs/defaults/numpy_minmax_scaler_cfg.yaml", - ) - param_store_loc = this_file_loc + "/numpy_minmax_scaler_params" - scaler_cfg = {"feature_range": (-1.0, 1.0), "store_loc": param_store_loc} + scaler_cfg = {"feature_range": (-1.0, 1.0), "store_loc": "numpy_minmax_scaler_params"} npy_scaler = preps.make( - "NumpyMinMaxScaler_v0", path_to_cfg=cfg_loc, user_config=scaler_cfg + "NumpyMinMaxScaler_v0", config=scaler_cfg ) # Print the module info: diff --git a/utests/utest_numpy_parser.py b/utests/utest_numpy_parser.py index 6d66d48..5c3d174 100644 --- a/utests/utest_numpy_parser.py +++ b/utests/utest_numpy_parser.py @@ -53,14 +53,13 @@ def test_drive_numpy_parser(self): # so we need to provide an additional config that allows us to overwrite the default setting (which is simply "") print("Load numpy parser...") - parser_cfg = {"data_loc": data_locs} - this_file_loc = os.path.dirname(__file__) - cfg_loc = os.path.join( - this_file_loc, - "../jlab_datascience_toolkit/cfgs/defaults/numpy_parser_cfg.yaml", - ) + parser_cfg = { + "data_loc": data_locs, + "event_axis":0, + "dtype":"float32" + } npy_parser = parsers.make( - "NumpyParser_v0", path_to_cfg=cfg_loc, user_config=parser_cfg + "NumpyParser_v0", config=parser_cfg ) # Lets see if we can call the information about this module: diff --git a/utests/utest_registry.py b/utests/utest_registry.py new file mode 100644 index 0000000..ff7b294 --- /dev/null +++ b/utests/utest_registry.py @@ -0,0 +1,33 @@ +import unittest +import jlab_datascience_toolkit.data_parsers as parsers +import jlab_datascience_toolkit.data_preps as preps +import jlab_datascience_toolkit.models as models +import jlab_datascience_toolkit.agents as agents +import yaml + +class TestRegistry(unittest.TestCase): + + def __init__(self, *args, **kwargs): + super(TestRegistry, self).__init__(*args, **kwargs) + self.modules = { + 'data parsers': parsers, + 'data preps': preps, + 'models': models, + 'agents': agents + } + + def test_registry(self): + for mod in self.modules: + available_modules = self.modules[mod].list_registered_modules() + + if len(available_modules) > 0: + for id in available_modules: + print(f"Making module: {id}...") + self.modules[mod].make(id,config=None) + print("...done!") + print(" ") + else: + print(f"Not modules registered for: {mod}") + +if __name__ == "__main__": + unittest.main()