Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ds_toolkit.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@ dependencies:
- imageio
- scikit-image
- optuna
- seaborn
- hydra-core
2 changes: 2 additions & 0 deletions ds_toolkit_metal.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,5 @@ dependencies:
- imageio
- scikit-image
- optuna
- seaborn
- hydra-core
5 changes: 5 additions & 0 deletions jlab_datascience_toolkit/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from jlab_datascience_toolkit.agents.registration import (
register,
make,
list_registered_modules,
)
109 changes: 109 additions & 0 deletions jlab_datascience_toolkit/agents/registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import importlib
import logging

module_log = logging.getLogger("Module Registry")


def load(name):
mod_name, attr_name = name.split(":")
print(f"Attempting to load {mod_name} with {attr_name}")
mod = importlib.import_module(mod_name)
fn = getattr(mod, attr_name)
return fn


class ModuleSpec(object):
def __init__(self, id, entry_point=None, kwargs=None):
self.id = id
self.entry_point = entry_point
self._kwargs = {} if kwargs is None else kwargs

def make(self, **kwargs):
"""Instantiates an instance of data module with appropriate kwargs"""
if self.entry_point is None:
module_log.error(
"Attempting to make deprecated module {}. \
(HINT: is there a newer registered version \
of this module?)".format(
self.id
)
)
raise RuntimeError

_kwargs = self._kwargs.copy()
_kwargs.update(kwargs)
if callable(self.entry_point):
gen = self.entry_point(**_kwargs)
else:
cls = load(self.entry_point)
gen = cls(**_kwargs)

return gen


class ModuleRegistry(object):
def __init__(self):
self.module_specs = {}

def make(self, path, **kwargs):
if len(kwargs) > 0:
module_log.info("Making new module: %s (%s)", path, kwargs)
else:
module_log.info("Making new module: %s", path)
module_spec = self.spec(path)
module = module_spec.make(**kwargs)

return module

def all(self):
return self.module_specs.values()

def spec(self, path):
if ":" in path:
mod_name, _sep, id = path.partition(":")
try:
importlib.import_module(mod_name)
except ImportError:
module_log.error(
"A module ({}) was specified for the module but was not found, \
make sure the package is installed with `pip install` before \
calling `module.make()`".format(
mod_name
)
)
raise

else:
id = path

try:
return self.module_specs[id]
except KeyError:
module_log.error("No registered module with id: {}".format(id))
raise

def register(self, id, **kwargs):
if id in self.module_specs:
module_log.error("Cannot re-register id: {}".format(id))
raise RuntimeError
self.module_specs[id] = ModuleSpec(id, **kwargs)


# Global registry
module_registry = ModuleRegistry()


def register(id, **kwargs):
return module_registry.register(id, **kwargs)


def make(id, **kwargs):
return module_registry.make(id, **kwargs)


def spec(id):
return module_registry.spec(id)


def list_registered_modules():
return list(module_registry.module_specs.keys())
2 changes: 1 addition & 1 deletion jlab_datascience_toolkit/analyses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from jlab_datascience_toolkit.utils.registration import (
from jlab_datascience_toolkit.analyses.registration import (
register,
make,
list_registered_modules,
Expand Down
4 changes: 2 additions & 2 deletions jlab_datascience_toolkit/analyses/multiclass_analysis_v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@


class Analysis:
def __init__(self, configs: dict):
self.configs = configs
def __init__(self, config: dict):
self.configs = config

def run(
self,
Expand Down
109 changes: 109 additions & 0 deletions jlab_datascience_toolkit/analyses/registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import importlib
import logging

module_log = logging.getLogger("Module Registry")


def load(name):
mod_name, attr_name = name.split(":")
print(f"Attempting to load {mod_name} with {attr_name}")
mod = importlib.import_module(mod_name)
fn = getattr(mod, attr_name)
return fn


class ModuleSpec(object):
def __init__(self, id, entry_point=None, kwargs=None):
self.id = id
self.entry_point = entry_point
self._kwargs = {} if kwargs is None else kwargs

def make(self, **kwargs):
"""Instantiates an instance of data module with appropriate kwargs"""
if self.entry_point is None:
module_log.error(
"Attempting to make deprecated module {}. \
(HINT: is there a newer registered version \
of this module?)".format(
self.id
)
)
raise RuntimeError

_kwargs = self._kwargs.copy()
_kwargs.update(kwargs)
if callable(self.entry_point):
gen = self.entry_point(**_kwargs)
else:
cls = load(self.entry_point)
gen = cls(**_kwargs)

return gen


class ModuleRegistry(object):
def __init__(self):
self.module_specs = {}

def make(self, path, **kwargs):
if len(kwargs) > 0:
module_log.info("Making new module: %s (%s)", path, kwargs)
else:
module_log.info("Making new module: %s", path)
module_spec = self.spec(path)
module = module_spec.make(**kwargs)

return module

def all(self):
return self.module_specs.values()

def spec(self, path):
if ":" in path:
mod_name, _sep, id = path.partition(":")
try:
importlib.import_module(mod_name)
except ImportError:
module_log.error(
"A module ({}) was specified for the module but was not found, \
make sure the package is installed with `pip install` before \
calling `module.make()`".format(
mod_name
)
)
raise

else:
id = path

try:
return self.module_specs[id]
except KeyError:
module_log.error("No registered module with id: {}".format(id))
raise

def register(self, id, **kwargs):
if id in self.module_specs:
module_log.error("Cannot re-register id: {}".format(id))
raise RuntimeError
self.module_specs[id] = ModuleSpec(id, **kwargs)


# Global registry
module_registry = ModuleRegistry()


def register(id, **kwargs):
return module_registry.register(id, **kwargs)


def make(id, **kwargs):
return module_registry.make(id, **kwargs)


def spec(id):
return module_registry.spec(id)


def list_registered_modules():
return list(module_registry.module_specs.keys())
13 changes: 13 additions & 0 deletions jlab_datascience_toolkit/cores/jdst_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from jlab_datascience_toolkit.cores.jdst_module import JDSTModule
from abc import ABC, abstractmethod


class JDSTAgent(JDSTModule, ABC):
"""
Base class for an agent. This class inherits from the module base class.
"""

# Get a prediction:
@abstractmethod
def predict(self):
raise NotImplementedError
2 changes: 1 addition & 1 deletion jlab_datascience_toolkit/data_parsers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from jlab_datascience_toolkit.utils.registration import (
from jlab_datascience_toolkit.data_parsers.registration import (
register,
make,
list_registered_modules,
Expand Down
15 changes: 10 additions & 5 deletions jlab_datascience_toolkit/data_parsers/famous_datasets_v0.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from jlab_datascience_toolkit.cores.jdst_data_parser import JDSTDataParser
import seaborn as sns
import inspect
import logging
import yaml

test_log = logging.Logger(__name__)

class FamousDatasetsV0(JDSTDataParser):
"""Returns one of the example famous datasets such as iris.
Expand Down Expand Up @@ -43,11 +45,14 @@ class FamousDatasetsV0(JDSTDataParser):
save_data()
Does nothing
"""
def __init__(self, configs: dict):
self.configs = configs
self.dataset_name = configs['dataset_name']
self.settings = {k: v for k, v in configs.items() if k not in {'dataset_name', 'registered_name'}}

def __init__(self, config: dict):
self.configs = config
try:
self.dataset_name = config['dataset_name']
self.settings = {k: v for k, v in config.items() if k not in {'dataset_name', 'registered_name'}}
except:
test_log.error(">>> No valid configuration provided <<<")

def load_data(self):
if self.dataset_name == 'iris':
return sns.load_dataset('iris', **self.settings)
Expand Down
25 changes: 12 additions & 13 deletions jlab_datascience_toolkit/data_parsers/numpy_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class NumpyParser(JDSTDataParser):
ii) Combine single .npy files into one

Input(s):
i) Full path to .yaml configuration file
i) Configuration file
ii) Optional: User configuration, i.e. a python dict with additonal / alternative settings

Output(s):
Expand All @@ -22,20 +22,20 @@ class NumpyParser(JDSTDataParser):

# Initialize:
# *********************************************
def __init__(self, path_to_cfg, user_config={}):
def __init__(self, config, user_config={}):
# Set the name specific to this module:
self.module_name = "numpy_parser"

# Load the configuration:
self.config = self.load_config(path_to_cfg, user_config)
self.config = self.load_config(config, user_config)

# Save this config, if a path is provided:
if "store_cfg_loc" in self.config:
self.save_config(self.config["store_cfg_loc"])

# Run sanity check(s):
# i) Make sure that the provide data path(s) are list objects:
if isinstance(self.config["data_loc"], list) == False:
if bool(self.config) and isinstance(self.config["data_loc"], list) == False:
logging.error(
">>> "
+ self.module_name
Expand All @@ -54,25 +54,24 @@ def get_info(self):
# Handle configurations:
# *********************************************
# Load the config:
def load_config(self, path_to_cfg, user_config):
with open(path_to_cfg, "r") as file:
cfg = yaml.safe_load(file)

# Overwrite config with user settings, if provided
try:
def load_config(self, config, user_config):
if config is not None:
try:
if bool(user_config):
# ++++++++++++++++++++++++
for key in user_config:
cfg[key] = user_config[key]
config[key] = user_config[key]
# ++++++++++++++++++++++++
except:
except:
logging.exception(
">>> "
+ self.module_name
+ ": Invalid user config. Please make sure that a dictionary is provided <<<"
)
else:
config = {}

return cfg
return config

# -----------------------------

Expand Down
Loading