diff --git a/autotest/test_interface.py b/autotest/test_interface.py index 3562761..0061ac4 100644 --- a/autotest/test_interface.py +++ b/autotest/test_interface.py @@ -7,7 +7,7 @@ from modflow_devtools.misc import set_dir from modflowapi import Callbacks, ModflowApi, run_simulation -from modflowapi.extensions.pakbase import AdvancedPackage, ArrayPackage, ListPackage +from modflowapi.extensions.pakbase import Package data_pth = Path("../docs/examples/data") pytestmark = pytest.mark.extensions @@ -71,16 +71,16 @@ def callback(sim, step): raise AssertionError("ApiModel has advanced prior to initialization callback") dis = model.dis - if not isinstance(dis, ArrayPackage): - raise TypeError("DIS package has incorrect base class type") + if "idomain" not in dis.variable_names: + raise TypeError("DIS package should have grid array variables") wel = model.wel - if not isinstance(wel, ListPackage): - raise TypeError("WEL package has incorrect base class type") + if wel.stress_period_data is None: + raise TypeError("WEL package should have stress period data") gnc = model.gnc - if not isinstance(gnc, AdvancedPackage): - raise TypeError("GNC package has incorrect base class type") + if not isinstance(gnc, Package): + raise TypeError("GNC package has incorrect type") rch = model.rch if len(rch) != 2: @@ -158,16 +158,16 @@ def callback(sim, step): raise AssertionError("ApiModel has advanced prior to initialization callback") dis = model.dis - if not isinstance(dis, ArrayPackage): - raise TypeError("DIS package has incorrect base class type") + if "idomain" not in dis.variable_names: + raise TypeError("DIS package should have grid array variables") chd = model.chd_left - if not isinstance(chd, ListPackage): - raise TypeError("CHD package has incorrect base class type") + if chd.stress_period_data is None: + raise TypeError("CHD package should have stress period data") hfb = model.hfb - if not isinstance(hfb, AdvancedPackage): - raise TypeError("HFB package has incorrect base class type") + if not isinstance(hfb, Package): + raise TypeError("HFB package has incorrect type") chd = model.chd if len(chd) != 2: @@ -236,16 +236,16 @@ def callback(sim, step): raise AssertionError("ApiModel has advanced prior to initialization callback") dis = model.dis - if not isinstance(dis, ArrayPackage): - raise TypeError("DIS package has incorrect base class type") + if "idomain" not in dis.variable_names: + raise TypeError("DIS package should have grid array variables") rch = model.rch - if not isinstance(rch, ListPackage): - raise TypeError("RCH package has incorrect base class type") + if rch.stress_period_data is None: + raise TypeError("RCH package should have stress period data") mvr = model.mvr - if not isinstance(mvr, AdvancedPackage): - raise TypeError("MVR package has incorrect base class type") + if not isinstance(mvr, Package): + raise TypeError("MVR package has incorrect type") top = dis.top.values if not isinstance(top, np.ndarray): diff --git a/modflowapi/extensions/advpaks.py b/modflowapi/extensions/advpaks.py index f0c6ae0..3b99984 100644 --- a/modflowapi/extensions/advpaks.py +++ b/modflowapi/extensions/advpaks.py @@ -1,107 +1,8 @@ -import numpy as np - -from .data import ListInput from .pakbase import AdvancedPackage - -class SfrPakage(AdvancedPackage): - """ - Container for SFR and SFR like packages - - Parameters - ---------- - model : ApiModel - modflowapi model object - pkg_type : str - package type. Ex. "SFR" - pkg_name : str - package name (in the mf6 variables) - sim_package : bool - boolean flag for simulation level packages. Ex. TDIS, IMS - """ - - def __init__(self, model, pkg_type, pkg_name, sim_package=False): - super().__init__(model, pkg_type, pkg_name, sim_package) - - self._diversion_var_arrs = [] - self._set_advanced_variable_addrs("diversions", "_diversion_var_addrs") - self._diversion_vars = ListInput(self, self._diversion_var_arrs, spd=False) - - @property - def diversions(self): - return self._diversion_vars - - @diversions.setter - def diversions(self, recarray): - """ - Setter object to update the diversions data - - """ - if isinstance(recarray, np.recarray): - self._diversion_vars.values = recarray - elif isinstance(recarray, ListInput): - self._diversion_vars.values = recarray.values - elif recarray is None: - self._diversion_vars.values = recarray - else: - raise TypeError(f"{type(recarray)} is not a supported diversions type") - - -class LakPackage(AdvancedPackage): - """ - Container for LAK and LAK like packages - - Parameters - ---------- - model : ApiModel - modflowapi model object - pkg_type : str - package type. Ex. "LAK" - pkg_name : str - package name (in the mf6 variables) - sim_package : bool - boolean flag for simulation level packages. Ex. TDIS, IMS - """ - - def __init__(self, model, pkg_type, pkg_name, sim_package=False): - super().__init__(model, pkg_type, pkg_name, sim_package) - - -class MawPackage(AdvancedPackage): - """ - Container for MAW and MAW like packages - - Parameters - ---------- - model : ApiModel - modflowapi model object - pkg_type : str - package type. Ex. "MAW" - pkg_name : str - package name (in the mf6 variables) - sim_package : bool - boolean flag for simulation level packages. Ex. TDIS, IMS - """ - - def __init__(self, model, pkg_type, pkg_name, sim_package=False): - super().__init__(model, pkg_type, pkg_name, sim_package) - - -class UzfPackage(AdvancedPackage): - """ - Container for UZF and UZF like packages - - Parameters - ---------- - model : ApiModel - modflowapi model object - pkg_type : str - package type. Ex. "UZF" - pkg_name : str - package name (in the mf6 variables) - sim_package : bool - boolean flag for simulation level packages. Ex. TDIS, IMS - """ - - def __init__(self, model, pkg_type, pkg_name, sim_package=False): - super().__init__(model, pkg_type, pkg_name, sim_package) +# Backward compatibility aliases +SfrPackage = AdvancedPackage +SfrPakage = AdvancedPackage # preserved old typo spelling +LakPackage = AdvancedPackage +MawPackage = AdvancedPackage +UzfPackage = AdvancedPackage diff --git a/modflowapi/extensions/apiexchange.py b/modflowapi/extensions/apiexchange.py index de9c3b3..6e5f7d5 100644 --- a/modflowapi/extensions/apiexchange.py +++ b/modflowapi/extensions/apiexchange.py @@ -1,5 +1,4 @@ from .apimodel import ApiMbase -from .pakbase import ListPackage class ApiExchange(ApiMbase): @@ -15,10 +14,7 @@ class ApiExchange(ApiMbase): modflow exchange name. ex. "GWF-GWF_1" """ + sim_level = True # exchange packages are simulation-level, not model-level + def __init__(self, mf6, name): - pkg_types = { - "gwf-gwf": ListPackage, - "gwt-gwt": ListPackage, - "gwe-gwe": ListPackage, - } - super().__init__(mf6, name, pkg_types) + super().__init__(mf6, name) diff --git a/modflowapi/extensions/apimodel.py b/modflowapi/extensions/apimodel.py index 7936d84..0535c57 100644 --- a/modflowapi/extensions/apimodel.py +++ b/modflowapi/extensions/apimodel.py @@ -1,7 +1,9 @@ +import math + import numpy as np -from .datamodel import get_package_type, gridshape -from .pakbase import AdvancedPackage, ArrayPackage, ListPackage, package_factory +from .datamodel import gridshape +from .pakbase import Package class ApiMbase: @@ -18,6 +20,8 @@ class ApiMbase: optional dictionary of package types and ApiPackage class types """ + sim_level = False # True for sim/exchange containers; False for model containers + def __init__(self, mf6, name, pkg_types=None): self.mf6 = mf6 self.name = name @@ -33,7 +37,7 @@ def package_list(self): """ Returns a list of package objects for the model """ - return [package for _, package in self.package_dict.items()] + return list(self.package_dict.values()) @property def package_names(self): @@ -44,7 +48,7 @@ def package_names(self): @property def package_types(self): - return list(set([package.pkg_type for package in self.package_list])) + return list({p.pkg_type for p in self.package_list}) def _set_package_names(self): """ @@ -56,11 +60,9 @@ def _set_package_names(self): if addr.endswith("PACKAGE_TYPE") and tmp[0] == self.name: pak_types[tmp[1]] = self.mf6.get_value(addr)[0] elif tmp[0] == self.name and len(tmp) == 2: - if tmp[0].startswith("GWF-GWF"): - pak_types[tmp[0]] = "GWF-GWF" - pak_types.pop("dis", None) - elif tmp[0].startswith("GWT-GWT"): - pak_types[tmp[0]] = "GWT-GWT" + parts = tmp[0].rsplit("_", 1)[0].split("-") + if len(parts) == 2 and parts[0] == parts[1]: + pak_types[tmp[0]] = tmp[0].rsplit("_", 1)[0] pak_types.pop("dis", None) self._pak_type = list(pak_types.values()) @@ -72,26 +74,15 @@ def _create_package_list(self): """ for ix, pkg_name in enumerate(self._pkg_names): pkg_type = self._pak_type[ix].lower() - if self._pkg_types is None: - basepackage = get_package_type(pkg_type) - else: - if pkg_type in self._pkg_types: - basepackage = self._pkg_types[pkg_type] - else: - basepackage = AdvancedPackage - - package = package_factory(pkg_type, basepackage) - adj_pkg_name = "".join(pkg_type.split("-")) - - if adj_pkg_name.lower() in ("gwfgwf", "gwtgwt"): - adj_pkg_name = "" + if self._pkg_types is not None and pkg_type in self._pkg_types: + pkg_cls = self._pkg_types[pkg_type] else: - adj_pkg_name = pkg_name + pkg_cls = Package - package = package(basepackage, self, pkg_type, adj_pkg_name) + package = pkg_cls(self, pkg_type, pkg_name, sim_package=self.sim_level) self.package_dict[pkg_name.lower()] = package - def get_package(self, pkg_name) -> ListPackage or ArrayPackage or AdvancedPackage: + def get_package(self, pkg_name) -> "Package": """ Method to get a package @@ -160,16 +151,9 @@ def __repr__(self): else: pass - s += "Packages accessible include: \n" - for typ, baseobj in [ - ("ArrayPackage", ArrayPackage), - ("ListPackage", ListPackage), - ("AdvancedPackage", AdvancedPackage), - ]: - s += f" {typ} objects:\n" - for name, obj in self.package_dict.items(): - if isinstance(obj, baseobj): - s += f" {name}: {type(obj)}\n" + s += "Packages accessible include:\n" + for name, pkg in self.package_dict.items(): + s += f" {name}: {pkg.pkg_type.upper()}\n" return s @@ -263,8 +247,8 @@ def shape(self): """ Returns a tuple of the model shape """ - ivn = self.mf6.get_input_var_names() if self._shape is None: + ivn = self.mf6.get_input_var_names() shape_vars = gridshape[self.dis_type] shape = [] for var in shape_vars: @@ -283,10 +267,7 @@ def size(self): Returns the number of nodes in the model """ if self._size is None: - size = 1 - for dim in self.shape: - size *= dim - self._size = size + self._size = math.prod(self.shape) return self._size @property diff --git a/modflowapi/extensions/apisimulation.py b/modflowapi/extensions/apisimulation.py index f9c98f5..e128330 100644 --- a/modflowapi/extensions/apisimulation.py +++ b/modflowapi/extensions/apisimulation.py @@ -2,7 +2,7 @@ from .apiexchange import ApiExchange from .apimodel import ApiMbase, ApiModel -from .pakbase import ApiSlnPackage, ListPackage, ScalarPackage, package_factory +from .pakbase import ApiSlnPackage, Package class ApiSimulation: @@ -35,9 +35,6 @@ def __init__(self, mf6, models, solutions, exchanges, tdis, ats): self.tdis = tdis self.ats = ats - self._ats_active = True - if ats is None: - self._ats_active = False def __getattr__(self, item): """ @@ -61,7 +58,7 @@ def __repr__(self): if self._exchanges: s += "\tExchanges include:\n" for name, exchange in self._exchanges.items(): - f"\t\t{name}: {type(exchange)}\n" + s += f"\t\t{name}: {type(exchange)}\n" return s @@ -71,7 +68,7 @@ def ats_active(self): Returns a boolean to indicate if the ATS package is used in this simulation. """ - return self._ats_active + return self.ats is not None @property def ats_period(self): @@ -133,9 +130,7 @@ def sln(self): """ if len(self._solutions) > 1: return list(self._solutions.values()) - else: - for sln in self._solutions.values(): - return sln + return next(iter(self._solutions.values())) @property def model_names(self): @@ -149,15 +144,14 @@ def exchange_names(self): """ Returns a list of exchange GWF-GWF names """ - if self._exchanges.keys(): - return list(self._exchanges.keys()) + return list(self._exchanges.keys()) @property def models(self): """ Returns a list of ApiModel objects associated with the simulation """ - return [v for _, v in self._models.items()] + return list(self._models.values()) @property def iteration(self): @@ -256,8 +250,7 @@ def get_exchange(self, exchange_name=None): raise AssertionError("No exchanges are present in this simulation") if exchange_name is None: - for _, exg in self._exchanges: - return exg + return next(iter(self._exchanges.values())) else: if exchange_name in self._exchanges: @@ -283,25 +276,7 @@ def load(mf6): id_var_addr = mf6.get_var_address("ID", name) if name.startswith("SLN"): continue - elif ( - name.startswith("GWFIM") - or name.startswith("GWTIM") - or name.startswith("GWEIM") - or name.startswith("PRTIM") - or name.startswith("CHFIM") - or name.startswith("OLFIM") - or name.startswith("SWFIM") - ): - continue - elif ( - name.startswith("GWFCON") - or name.startswith("GWTCON") - or name.startswith("GWECON") - or name.startswith("PRTCON") - or name.startswith("CHFCON") - or name.startswith("OLFCON") - or name.startswith("SWFCON") - ): + elif name[3:5] in ("IM", "CO"): continue if id_var_addr not in variables: continue @@ -321,22 +296,14 @@ def load(mf6): id_var_addr = mf6.get_var_address("ID", name) if name.lower() in models or name == "TDIS": continue - if ( - name.startswith("GWFIM") - or name.startswith("GWTIM") - or name.startswith("GWEIM") - or name.startswith("PRTIM") - or name.startswith("CHFIM") - or name.startswith("OLFIM") - or name.startswith("SWFIM") - ): + if name[3:5] == "IM": continue if id_var_addr not in variables: continue solution_names.append(t[0]) - idp_names = [i for i in mf6.get_value("__INPUT__/SIM/NAM/SLNMNAMES")] + idp_names = list(mf6.get_value("__INPUT__/SIM/NAM/SLNMNAMES")) solution_types = [ i[:-1].lower() for ix, i in enumerate(mf6.get_value("__INPUT__/SIM/NAM/SLNTYPE")) if idp_names[ix] ] @@ -353,30 +320,24 @@ def load(mf6): solutions = solution_dict - # TDIS package construction - tdis_constructor = package_factory("tdis", ScalarPackage) - tdis = tdis_constructor(ScalarPackage, tmpmdl, "tdis", "tdis", sim_package=True) + tdis = Package(tmpmdl, "tdis", "tdis", sim_package=True) ats = None - # ATS package construction for variable in variables: if variable.startswith("ATS"): - ats_constructor = package_factory("ats", ListPackage) - ats = ats_constructor(ListPackage, tmpmdl, "ats", "ats", sim_package=True) + ats = Package(tmpmdl, "ats", "ats", sim_package=True) break # get the exchanges exchange_names = [] for variable in variables: - if variable.startswith("GWF-GWF") or variable.startswith("GWT-GWT"): - exchange_name = variable.split("/")[0] - if exchange_name not in exchange_names: - exchange_names.append(exchange_name) + exchange_name = variable.split("/")[0] + parts = exchange_name.rsplit("_", 1)[0].split("-") + if len(parts) == 2 and parts[0] == parts[1] and exchange_name not in exchange_names: + exchange_names.append(exchange_name) - # sim_packages: tdis, gwf-gwf, sln exchanges = {} - for exchange_name in exchanges: - exchange = ApiExchange(mf6, exchange_name) - exchanges[exchange_name.lower()] = exchange + for exchange_name in exchange_names: + exchanges[exchange_name.lower()] = ApiExchange(mf6, exchange_name) return ApiSimulation(mf6, models, solutions, exchanges, tdis, ats) diff --git a/modflowapi/extensions/data.py b/modflowapi/extensions/data.py index c104da2..595fb90 100644 --- a/modflowapi/extensions/data.py +++ b/modflowapi/extensions/data.py @@ -494,7 +494,6 @@ def values(self, array): raise ValueError( f"{self.name} size {array.size} is not equal to modflow variable size {self.parent.model.size}" ) - array = array.ravel() if self._ptr.size != array.size: array = array[self.parent.model.nodetouser] @@ -888,3 +887,32 @@ def set_value(self, item, value): self._ptrs[item][0] = value else: raise KeyError(f"{item} is not accessible in this package") + + +class ScalarVar: + """ + A single scalar variable from the MODFLOW 6 memory manager. + + Parameters + ---------- + name : str + variable name + ptr : np.ndarray + 1-element pointer array from the MODFLOW 6 memory manager + """ + + def __init__(self, name: str, ptr): + self._name = name + self._ptr = ptr + + @property + def name(self) -> str: + return self._name + + @property + def values(self): + return self._ptr[0] + + @values.setter + def values(self, v): + self._ptr[0] = v diff --git a/modflowapi/extensions/datamodel.py b/modflowapi/extensions/datamodel.py index 669a0d5..77105cc 100644 --- a/modflowapi/extensions/datamodel.py +++ b/modflowapi/extensions/datamodel.py @@ -160,7 +160,7 @@ "gwe-gwe": ["nexg", "nodem1", "nodem2", "cl1", "cl2", "ihc", "hwva"], # simulation "ats": [ - "maxats", + ("maxats", ()), "iperats", "dt0", "dtmin", @@ -274,48 +274,3 @@ }, "maw": {"packagedata": ["nmawwells", ("ifno:range:nmawwells", "radius", "bot", "strt", "ngwfnodes")]}, } - - -def get_package_type(pkg_type): - from .advpaks import LakPackage, MawPackage, SfrPakage, UzfPackage - from .pakbase import AdvancedPackage, ArrayPackage, ListPackage, ScalarPackage - - pkg_types = { - "dis": ArrayPackage, - "chd": ListPackage, - "drn": ListPackage, - "evt": ListPackage, - "ghb": ListPackage, - "ic": ArrayPackage, - "npf": ArrayPackage, - "rch": ListPackage, - "riv": ListPackage, - "sto": ArrayPackage, - "wel": ListPackage, - # advanced - "sfr": SfrPakage, - "uzf": UzfPackage, - "lak": LakPackage, - "maw": MawPackage, - # "csub": None, - # gwt - "dsp": ArrayPackage, - "cnc": ListPackage, - "ist": ArrayPackage, - "mst": ArrayPackage, - "src": ListPackage, - # gwe - "cnd": ArrayPackage, - "est": ArrayPackage, - "cpt": ListPackage, - "esl": ListPackage, - # prt - "mip": ArrayPackage, - # sim_level pkgs - "tdis": ScalarPackage, - "ats": ListPackage, - } - if pkg_type in pkg_types: - return pkg_types[pkg_type] - else: - return AdvancedPackage diff --git a/modflowapi/extensions/pakbase.py b/modflowapi/extensions/pakbase.py index 2a48c9b..3a24d4c 100644 --- a/modflowapi/extensions/pakbase.py +++ b/modflowapi/extensions/pakbase.py @@ -1,632 +1,436 @@ import numpy as np -from .data import AdvancedInput, ArrayInput, ListInput, ScalarInput +from .data import AdvancedInput, ArrayPointer, ListInput, ScalarVar from .datamodel import adv_pkgvars, pkgvars +_BASE_ATTRS = frozenset({"model", "pkg_name", "pkg_type"}) +_ADV_BLOCK_NAMES = frozenset({"griddata", "packagedata", "perioddata"}) -class PackageBase: - """ - Base class for packages within the modflow-6 api + +def _unwrap_list_input(value): + """Return value.values if value is a ListInput, otherwise return value as-is.""" + return value.values if isinstance(value, ListInput) else value +class Package: + """ + Package object for MODFLOW 6 API packages. + Parameters ---------- model : ApiModel - modflowapi ApiModel object + modflowapi model object pkg_type : str - package type name. ex. 'wel' + package type name, e.g. 'wel' pkg_name : str - modflow package name. ex. 'wel_0' - child_type : str - type of child input package + package name in the MF6 variables, e.g. 'wel_0' sim_package : bool - flag to indicate this is a simulation level package + flag indicating this is a simulation-level package """ - def __init__(self, model, pkg_type, pkg_name, child_type, sim_package): + def __init__(self, model, pkg_type, pkg_name, sim_package=False): self.model = model - self.pkg_name = pkg_name self.pkg_type = pkg_type - self._child_type = child_type + self.pkg_name = pkg_name.upper() self._sim_package = sim_package - self._rhs = None - self._hcof = None self._bound_vars = [] self._advanced_var_names = None self._idm_enabled = False + self._rhs = None + self._hcof = None + self._vars = {} + self._list_vars = {} + self._variables_adv = None + self._build_inputs() + + # ------------------------------------------------------------------ + # Input construction + # ------------------------------------------------------------------ + + def _build_inputs(self): + if self.pkg_type in adv_pkgvars: + self._build_advanced_inputs() + if self.pkg_type in pkgvars: + vars_list = pkgvars[self.pkg_type] + if any(isinstance(v, tuple) for v in vars_list): + self._build_list_inputs(vars_list) + else: + self._build_plain_inputs(vars_list) + def _build_list_inputs(self, vars_list): var_addrs = [] - if self._child_type != "advanced": - for var in pkgvars[self.pkg_type]: - if isinstance(var, tuple): - bound_vars = [] - for bv in var[-1]: - t = bv.split(":") - if len(t) == 2: - # this is a repeating variable - addr = self.model.mf6.get_var_address(t[-1].upper(), self.model.name, self.pkg_name) - nrep = self.model.mf6.get_value(addr)[0] - if nrep > 1: - for rep in range(nrep): - bound_vars.append(f"{t[0]}{rep}") - else: - bound_vars.append(t[0]) - else: - bound_vars.append(t[0]) - - self._bound_vars = var[-1] - var = var[0] - - if sim_package: - var_addrs.append(self.model.mf6.get_var_address(var.upper(), self.pkg_name)) - else: - var_addrs.append(self.model.mf6.get_var_address(var.upper(), self.model.name, self.pkg_name)) + for var in vars_list: + if isinstance(var, tuple): + self._bound_vars = var[-1] + var = var[0] + if self._sim_package: + var_addrs.append(self.model.mf6.get_var_address(var.upper(), self.pkg_name)) + else: + var_addrs.append(self.model.mf6.get_var_address(var.upper(), self.model.name, self.pkg_name)) for var in self._bound_vars: addr_chk = self.model.mf6.get_var_address(var.upper(), self.model.name, self.pkg_name) if addr_chk in self.model.mf6.get_input_var_names(): - # change this to use idm self._idm_enabled = True var_addrs.append(addr_chk) - self.var_addrs = var_addrs - self._variables_adv = AdvancedInput(self) + self._list_vars["stress_period_data"] = ListInput(self, var_addrs, spd=True) - @property - def advanced_vars(self): - """ - Returns a list of additional "advanced" variables that are - accessible through the API - """ - if self._advanced_var_names is None: - adv_vars = [] - for var_addr in self.model.mf6.get_input_var_names(): - is_advanced = False - t = var_addr.split("/") - if not self._sim_package: - if t[0] == self.model.name and t[1] == self.pkg_name: - is_advanced = self._check_if_advanced_var(t[-1]) + def _build_plain_inputs(self, vars_list): + ivn = self.model.mf6.get_input_var_names() + for var in vars_list: + if self._sim_package: + addr = self.model.mf6.get_var_address(var.upper(), self.pkg_name) + else: + addr = self.model.mf6.get_var_address(var.upper(), self.model.name, self.pkg_name) + if addr not in ivn: + continue + name = var.lower() + if self._sim_package: + self._vars[name] = ScalarVar(name, self.model.mf6.get_value_ptr(addr)) + else: + arr_var = ArrayPointer(self, addr) + if arr_var.name is not None: + self._vars[name] = arr_var + + def _build_advanced_inputs(self): + adv_var_dict = adv_pkgvars[self.pkg_type] + + if "griddata" in adv_var_dict: + self._build_plain_inputs(adv_var_dict["griddata"]) + + pkg_var_addrs = self._collect_adv_var_addrs(adv_var_dict, "packagedata") + if pkg_var_addrs: + self._list_vars["packagedata"] = ListInput(self, pkg_var_addrs, spd=False) + + sp_var_addrs = [] + if "perioddata" in adv_var_dict: + for var in adv_var_dict["perioddata"]: + if isinstance(var, tuple): + use_bound = all(":" not in v for v in var[-1]) + if use_bound: + self._bound_vars = var[-1] + var = var[0] + else: + for v in var[-1]: + if ":" in v: + self._bound_vars.append(v.split(":")[0]) + else: + self._bound_vars.append(v) + sp_var_addrs.append(self._adv_var_addr(v)) + var = None + if var is not None: + sp_var_addrs.append(self.model.mf6.get_var_address(var.upper(), self.model.name, self.pkg_name)) + + if sp_var_addrs: + self._list_vars["stress_period_data"] = ListInput(self, sp_var_addrs, spd=True) + + for block in adv_var_dict: + if block in _ADV_BLOCK_NAMES: + continue + var_addrs = self._collect_adv_var_addrs(adv_var_dict, block) + if var_addrs: + self._list_vars[block] = ListInput(self, var_addrs, spd=False) + + def _collect_adv_var_addrs(self, adv_var_dict, block): + var_addrs = [] + if block in adv_var_dict: + for var in adv_var_dict[block]: + if not isinstance(var, tuple): + var_addrs.append(self._adv_var_addr(var)) else: - if t[0] == self.pkg_name: - is_advanced = self._check_if_advanced_var(t[-1]) + for v in var: + var_addrs.append(self._adv_var_addr(v)) + return var_addrs - if is_advanced: - adv_vars.append(t[-1].lower()) + def _adv_var_addr(self, var_str): + return f"{self.model.name}/{self.pkg_name}/{var_str.upper()}" - self._advanced_var_names = adv_vars - return self._advanced_var_names + # ------------------------------------------------------------------ + # Attribute dispatch + # ------------------------------------------------------------------ - def _check_if_advanced_var(self, variable_name): - """ - Method to check if a variable is an advanced variable - - Parameters - ---------- - variable_name : str - variable name to check - - Returns - ------- - bool - """ - if variable_name.lower() in self._bound_vars: - is_advanced = False - elif self.pkg_type not in pkgvars: - is_advanced = True - elif variable_name.lower() in pkgvars[self.pkg_type]: - is_advanced = False + def __repr__(self): + s = f"{self.pkg_type.upper()} Package: {self.pkg_name}\n" + names = self.variable_names + if names: + s += " Accessible variables include:\n" + for name in names: + s += f" {name}\n" + return s + + def _try_discover_var(self, name): + """Try to build an ArrayPointer for a package-scoped variable by name, return None if unavailable.""" + if self._sim_package: + return None + var_addr = self.model.mf6.get_var_address(name.upper(), self.model.name, self.pkg_name) + arr_var = ArrayPointer(self, var_addr) + return arr_var if arr_var.name is not None else None + + def __getattr__(self, item): + try: + vars_ = object.__getattribute__(self, "_vars") + except AttributeError: + raise AttributeError(item) + if item in vars_: + v = vars_[item] + return v.values if isinstance(v, ScalarVar) else v + try: + list_vars = object.__getattribute__(self, "_list_vars") + except AttributeError: + raise AttributeError(item) + if item in list_vars: + return list_vars[item] + var = self._try_discover_var(item) + if var is not None: + vars_[item] = var + return var + raise AttributeError(item) + + def __setattr__(self, item, value): + if item.startswith("_") or item in _BASE_ATTRS: + object.__setattr__(self, item, value) + return + for cls in type(self).__mro__: + desc = cls.__dict__.get(item) + if desc is not None and hasattr(desc, "__set__"): + desc.__set__(self, value) + return + try: + vars_ = object.__getattribute__(self, "_vars") + except AttributeError: + object.__setattr__(self, item, value) + return + if item in vars_: + vars_[item].values = value + return + try: + list_vars = object.__getattribute__(self, "_list_vars") + except AttributeError: + pass else: - is_advanced = True - return is_advanced + if item in list_vars: + list_vars[item].values = _unwrap_list_input(value) + return + var = self._try_discover_var(item) + if var is not None: + vars_[item] = var + vars_[item].values = value + return + raise AttributeError(f"{item} is not a valid attribute for {self.pkg_type}") + + # ------------------------------------------------------------------ + # Static properties + # ------------------------------------------------------------------ - def get_advanced_var(self, name): - """ - Method to get an advanced variable that is not automatically - accessible through stress period data or as an array name - """ - name = name.lower() - if name not in self.advanced_vars: - raise AssertionError(f"{name} is not accessible as an advanced variable for this package") + @property + def variable_names(self): + """Returns a sorted list of variable names accessible through the API.""" + return sorted(list(self._vars) + list(self._list_vars)) - values = self._variables_adv.get_variable(name) - if not self._sim_package: - if values.size == self.model.nodetouser.size and self._child_type == "array": - array = np.full(self.model.size, np.nan) - array[self.model.nodetouser] = values - return array + @property + def stress_period_data(self): + """Returns the ListInput for stress period data, or None if not present.""" + return self._list_vars.get("stress_period_data") - return values + @stress_period_data.setter + def stress_period_data(self, recarray): + lv = self._list_vars.get("stress_period_data") + if lv is not None: + lv.values = _unwrap_list_input(recarray) - def set_advanced_var(self, name, values): - """ - Method to set data to an advanced variable - - Parameters - ---------- - name : str - parameter name - values : np.ndarray - numpy array - """ - if not self._sim_package: - if self._child_type == "array" and values.size == self.model.size: - values = values[self.model.nodetouser] + @property + def packagedata(self): + """Returns the ListInput for packagedata, or None if not present.""" + return self._list_vars.get("packagedata") - self._variables_adv.set_variable(name, values) + @packagedata.setter + def packagedata(self, recarray): + lv = self._list_vars.get("packagedata") + if lv is not None: + lv.values = _unwrap_list_input(recarray) @property - def rhs(self): - if not self._sim_package: - if self._rhs is None: - var_addr = self.model.mf6.get_var_address("RHS", self.model.name, self.pkg_name) - if var_addr in self.model.mf6.get_input_var_names(): - self._rhs = self.model.mf6.get_value_ptr(var_addr) - else: - return + def nbound(self): + """Returns the number of active boundaries for the current stress period.""" + lv = self._list_vars.get("stress_period_data") + return lv._nbound[0] if lv is not None else None + @property + def maxbound(self): + """Returns the maximum number of boundaries.""" + lv = self._list_vars.get("stress_period_data") + return lv._maxbound[0] if lv is not None else None + + @property + def rhs(self): + if self._sim_package: + return None + if self._rhs is None: + var_addr = self.model.mf6.get_var_address("RHS", self.model.name, self.pkg_name) + if var_addr in self.model.mf6.get_input_var_names(): + self._rhs = self.model.mf6.get_value_ptr(var_addr) + else: + return None return np.copy(self._rhs) @rhs.setter def rhs(self, values): if self._rhs is None: - rhs = self.rhs - if rhs is None: + _ = self.rhs + if self._rhs is None: raise Exception(f"{self.pkg_type} does not have a rhs array") - self._rhs[:] = values[:] @property def hcof(self): - if not self._sim_package: - if self._hcof is None: - var_addr = self.model.mf6.get_var_address("HCOF", self.model.name, self.pkg_name) - if var_addr in self.model.mf6.get_input_var_names(): - self._hcof = self.model.mf6.get_value_ptr(var_addr) - else: - return - + if self._sim_package: + return None + if self._hcof is None: + var_addr = self.model.mf6.get_var_address("HCOF", self.model.name, self.pkg_name) + if var_addr in self.model.mf6.get_input_var_names(): + self._hcof = self.model.mf6.get_value_ptr(var_addr) + else: + return None return np.copy(self._hcof) @hcof.setter def hcof(self, values): if self._hcof is None: - hcof = self.hcof - if hcof is None: + _ = self.hcof + if self._hcof is None: raise Exception(f"{self.pkg_type} does not have an hcof array") - self._hcof[:] = values[:] - -class ListPackage(PackageBase): - """ - Package object for "list based" input packages such as WEL, DRN, RCH - - Parameters - ---------- - model : ApiModel - modflowapi model object - pkg_type : str - package type. Ex. "RCH" - pkg_name : str - package name (in the mf6 variables) - sim_package : bool - flag to indicate this is a simulation level package - """ - - def __init__(self, model, pkg_type, pkg_name, sim_package=False): - super().__init__(model, pkg_type, pkg_name.upper(), "list", sim_package) - - self._variables = ListInput(self) - - def __repr__(self): - s = f"{self.pkg_type.upper()} Package: {self.pkg_name}" - return s - @property - def nbound(self): - """ - Returns the "nbound" value for the stress period - """ - return self._variables._nbound[0] - - @property - def maxbound(self): - """ - Returns the "maxbound" value for the stress period - """ - return self._variables._maxbound[0] - - @property - def stress_period_data(self): - """ - Returns a ListInput object of the current stress_period_data - """ - return self._variables - - @stress_period_data.setter - def stress_period_data(self, recarray): - """ - Setter method to update the current stress_period_data - """ - if isinstance(recarray, np.recarray): - self._variables.values = recarray - elif isinstance(recarray, ListInput): - self._variables.values = recarray.values - elif recarray is None: - self._variables.values = recarray - else: - raise TypeError(f"{type(recarray)} is not a supported stress_period_data type") - - -class ArrayPackage(PackageBase): - """ - Package object for "array based" input packages such as NPF, DIS, - - Parameters - ---------- - model : ApiModel - modflowapi model object - pkg_type : str - package type. Ex. "DIS" - pkg_name : str - package name (in the mf6 variables) - sim_package : bool - flag to indicate this is a simulation level package - """ - - def __init__(self, model, pkg_type, pkg_name, sim_package=False): - super().__init__(model, pkg_type, pkg_name.upper(), "array", sim_package) - - self._variables = ArrayInput(self) - - def __repr__(self): - s = f"{self.pkg_type.upper()} Package: {self.pkg_name} \n" - s += " Accessible variables include:\n" - for var_name in self.variable_names: - s += f" {var_name} \n" - return s - - def __setattr__(self, item, value): - """ - Method that enables dynamic variable setting and distributes - modflow variable storage and updates to the data object class - """ - if item in ("model", "pkg_name", "pkg_type", "var_addrs"): - super().__setattr__(item, value) - - elif item.startswith("_"): - super().__setattr__(item, value) + def advanced_vars(self): + """Returns a list of additional variables accessible through get/set_advanced_var.""" + if self._advanced_var_names is None: + adv_vars = [] + for var_addr in self.model.mf6.get_input_var_names(): + t = var_addr.split("/") + is_advanced = False + if not self._sim_package: + if t[0] == self.model.name and t[1] == self.pkg_name: + is_advanced = self._check_if_advanced_var(t[-1]) + else: + if t[0] == self.pkg_name: + is_advanced = self._check_if_advanced_var(t[-1]) + if is_advanced: + adv_vars.append(t[-1].lower()) + self._advanced_var_names = adv_vars + return self._advanced_var_names - elif item in self._variables._ptrs: - self._variables.set_ptr(item, value) + def _check_if_advanced_var(self, variable_name): + if variable_name.lower() in self._bound_vars: + return False + if self.pkg_type not in pkgvars: + return True + if variable_name.lower() in pkgvars[self.pkg_type]: + return False + return True - else: - raise AttributeError(f"{item}") + def get_advanced_var(self, name): + """Get a variable not surfaced through stress_period_data or variable_names.""" + name = name.lower() + if name not in self.advanced_vars: + raise AssertionError(f"{name} is not accessible as an advanced variable for this package") + if self._variables_adv is None: + self._variables_adv = AdvancedInput(self) + values = self._variables_adv.get_variable(name) + if not self._sim_package: + if values.size == self.model.nodetouser.size: + array = np.full(self.model.size, np.nan) + array[self.model.nodetouser] = values + return array + return values - def __getattr__(self, item): - """ - Method to dynamically get modflow variables by attribute - """ - if item in self._variables._ptrs: - return self._variables.get_ptr(item) - else: - return super().__getattribute__(item) + def set_advanced_var(self, name, values): + """Set a variable not surfaced through stress_period_data or variable_names.""" + if isinstance(values, ArrayPointer): + values = np.asarray(values.values) + if not self._sim_package: + if values.size == self.model.size: + values = values[self.model.nodetouser] + if self._variables_adv is None: + self._variables_adv = AdvancedInput(self) + self._variables_adv.set_variable(name, values) - @property - def variable_names(self): - """ - Returns a list of valid modflow variable names that the user can access - """ - return self._variables.variable_names + # ------------------------------------------------------------------ + # Explicit accessor methods (backward compatibility) + # ------------------------------------------------------------------ def get_array(self, item): - """ - Method to get an array from modflow - - Parameters - ---------- - item : str - modflow variable name. Ex. "k11" - - Returns - ------- - np.array of modflow data - """ - return self._variables.get_array(item) + """Get a grid-shaped array variable by name.""" + v = self._vars.get(item) + if v is None or not isinstance(v, ArrayPointer): + raise KeyError(f"{item} is not accessible in this package") + return v.values def set_array(self, item, array): - """ - Method to update the modflow pointer arrays - - Parameters - ---------- - item : str - modflow variable name. Ex. "k11" - array : np.array - numpy array - - """ - self._variables.set_array(item, array) - - -class ScalarPackage(PackageBase): - """ - Container for advanced data packages - - Parameters - ---------- - model : ApiModel - modflowapi model object - pkg_type : str - package type. Ex. "RCH" - pkg_name : str - package name (in the mf6 variables) - sim_package : bool - boolean flag for simulation level packages. Ex. TDIS, IMS - """ - - def __init__(self, model, pkg_type, pkg_name, sim_package=False): - super().__init__(model, pkg_type, pkg_name.upper(), "scalar", sim_package) - - self._variables = ScalarInput(self) - - def __repr__(self): - s = f"{self.pkg_type.upper()} Package: {self.pkg_name} \n" - s += " Accessible variables include:\n" - for var_name in self.variable_names: - s += f" {var_name} \n" - return s - - def __setattr__(self, item, value): - """ - Method that enables dynamic variable setting and distributes - modflow variable storage and updates to the data object class - """ - if item in ("model", "pkg_name", "pkg_type", "var_addrs"): - super().__setattr__(item, value) - - elif item.startswith("_"): - super().__setattr__(item, value) - - elif item in self._variables._ptrs: - self._variables.set_value(item, value) - - elif item in ("mxiter",): - # hack for sln-ems - super().__setattr__(item, value) - - else: - raise AttributeError(f"{item}") - - def __getattr__(self, item): - """ - Method to dynamically get modflow variables by attribute - """ - if item in self._variables._ptrs: - return self._variables.get_value(item) - else: - return super().__getattribute__(item) - - @property - def variable_names(self): - """ - Returns a list of valid modflow variable names that the user can access - """ - return self._variables.variable_names + """Set a grid-shaped array variable by name.""" + v = self._vars.get(item) + if v is None or not isinstance(v, ArrayPointer): + raise KeyError(f"{item} is not a valid variable name for this package") + v.values = array def get_value(self, item): - """ - Method to get a scalar value from modflow - - Parameters - ---------- - item : str - modflow variable name. Ex. "NBOUND" - - Returns - ------- - np.array of modflow data - """ - return self._variables.get_value(item) + """Get a scalar variable by name.""" + v = self._vars.get(item) + if v is None or not isinstance(v, ScalarVar): + raise KeyError(f"{item} is not accessible in this package") + return v.values def set_value(self, item, value): - """ - Method to update the modflow pointer arrays + """Set a scalar variable by name.""" + v = self._vars.get(item) + if v is None or not isinstance(v, ScalarVar): + raise KeyError(f"{item} is not accessible in this package") + v.values = value - Parameters - ---------- - item : str - modflow variable name. Ex. "k11" - array : str, int, float - scalar value - """ - self._variables.set_value(item, value) +# ------------------------------------------------------------------ +# Marker subclasses — preserve isinstance compatibility +# ------------------------------------------------------------------ -class AdvancedPackage(PackageBase): - """ - Container for advanced data packages - - Parameters - ---------- - model : ApiModel - modflowapi model object - pkg_type : str - package type. Ex. "RCH" - pkg_name : str - package name (in the mf6 variables) - sim_package : bool - boolean flag for simulation level packages. Ex. TDIS, IMS - """ - - def __init__(self, model, pkg_type, pkg_name, sim_package=False): - super().__init__(model, pkg_type, pkg_name.upper(), "advanced", sim_package) - - self._idm_enabled = False - self._package_var_addrs = [] - self._sp_var_addrs = [] - self._package_vars = None - self._sp_vars = None - - if pkg_type in adv_pkgvars: - self._adv_var_dict = adv_pkgvars[pkg_type] - - self._set_advanced_variable_addrs("packagedata", "_package_var_addrs") - - if "perioddata" in self._adv_var_dict: - # create variable addresses!!!! - for var in self._adv_var_dict["perioddata"]: - if isinstance(var, tuple): - use_bound = True - for v in var[-1]: - if ":" in v: - use_bound = False - - if use_bound: - self._bound_vars = var[-1] - var = var[0] - else: - for v in var[-1]: - if ":" in v: - tmp = v.split(":")[0] - self._bound_vars.append(tmp) - else: - self._bound_vars.append(v) - var_addr = self._get_advanced_variable_addr(v) - self._sp_var_addrs.append(var_addr) - var = None - - if var is not None: - var_addr = self.model.mf6.get_var_address(var.upper(), self.model.name, self.pkg_name) - self._sp_var_addrs.append(var_addr) - - self._package_vars = ListInput(self, self._package_var_addrs, spd=False) - self._sp_vars = ListInput(self, self._sp_var_addrs) - - def __repr__(self): - s = f"{self.pkg_type.upper()} Package: {self.pkg_name} \n" - s += " Advanced Package, variables only accessible through\n" - s += " get_advanced_var() and set_advanced_var() methods" - return s +class PackageBase(Package): + pass - def _set_advanced_variable_addrs(self, block, attr): - """ - General method for setting advanced variable block addresses - to their attributes. Method is used to reduce code duplication - - Parameters - ---------- - block : str - data block key - attr : str - attribute name - - Returns - ------- - None - """ - var_addrs = [] - if block in self._adv_var_dict: - for var in self._adv_var_dict[block]: - if not isinstance(var, tuple): - var_addrs.append(self._get_advanced_variable_addr(var)) - else: - for v in var: - var_addrs.append(self._get_advanced_variable_addr(v)) - setattr(self, attr, var_addrs) +class ListPackage(Package): + pass - def _get_advanced_variable_addr(self, var_str): - """ - Method to create variable addresses for advanced packages that can - include non-standard logic and processing instructions - Parameters - ---------- - var_str : str +class ArrayPackage(Package): + pass - Returns - ------- - var_addr : str - """ - s = f"{self.model.name}/{self.pkg_name}/{var_str.upper()}" - return s - @property - def packagedata(self): - """ - Returns a BlockInput object of the packagedata - """ - return self._package_vars +class ScalarPackage(Package): + pass - @packagedata.setter - def packagedata(self, recarray): - """ - Setter method to update the packagedata - - Parameters - ---------- - recarray : np.recarray, ListInput, or None - - """ - if self._package_vars is not None: - if isinstance(recarray, np.recarray): - self._package_vars.values = recarray - elif isinstance(recarray, ListInput): - self._package_vars.values = recarray.values - elif recarray is None: - self._package_vars.values = recarray - else: - raise TypeError(f"{type(recarray)} is not a supported stress_period_data type") - @property - def maxbound(self): - """ - Returns the "maxbound" value for the stress period - """ - if self._sp_vars is not None: - return self._sp_vars._maxbound[0] +class AdvancedPackage(Package): + pass - @property - def stress_period_data(self): - """ - Returns a ListInput object of the current stress_period_data - """ - return self._sp_vars - @stress_period_data.setter - def stress_period_data(self, recarray): - """ - Setter method to update the current stress_period_data - """ - if self._sp_vars is not None: - if isinstance(recarray, np.recarray): - self._sp_vars.values = recarray - elif isinstance(recarray, ListInput): - self._sp_vars.values = recarray.values - elif recarray is None: - self._sp_vars.values = recarray - else: - raise TypeError(f"{type(recarray)} is not a supported stress_period_data type") +# ------------------------------------------------------------------ +# Solution package +# ------------------------------------------------------------------ -class ApiSlnPackage(ScalarPackage): +class ApiSlnPackage(Package): """ - Class to access solution packages + Class to access solution packages. Parameters ---------- - model : ApiModel - modflowapi model object - pkg_type : str - package type. Ex. "RCH" + sim : ApiSimulation or ApiMbase + simulation object pkg_name : str - package name (in the mf6 variables) - sim_package : bool - boolean flag for simulation level packages. Ex. TDIS, IMS - sln_type : str - ackronymn for the solution package type, default is "ims" + package name in the MF6 variables + pkg_type : str + solution type abbreviation, default 'ims' """ def __init__(self, sim, pkg_name, pkg_type="ims"): @@ -634,40 +438,12 @@ def __init__(self, sim, pkg_name, pkg_type="ims"): super().__init__(sim, f"sln-{pkg_type}", pkg_name, sim_package=True) - if pkg_type in ("ims",): - mdl = ApiMbase(sim.mf6, pkg_name.upper(), pkg_types={pkg_type: ScalarPackage}) - imslin = ScalarPackage(mdl, "ims", "IMSLINEAR") - for key, ptr in imslin._variables._ptrs.items(): - if key in self._variables._ptrs: - key = f"{imslin.pkg_type}_{key}".lower() - self._variables._ptrs[key] = ptr + if pkg_type == "ims": + mdl = ApiMbase(sim.mf6, pkg_name.upper(), pkg_types={pkg_type: Package}) + imslin = Package(mdl, "ims", "IMSLINEAR", sim_package=True) + for key, var in imslin._vars.items(): + if key in self._vars: + key = f"{imslin.pkg_type}_{key}" + self._vars[key] = var else: - self.mxiter = 10 - - -def package_factory(pkg_type, basepackage): - """ - Method to autogenerate unique package "types" from the base packages: - ArrayPackage, ListPackage, and AdvancedPackage - - Parameters - ---------- - pkg_type : str - package type - basepackage : ArrayPackage, ListPackage, or AdvancedPackage - a base package type - - Returns - Package object : ex. ApiWelPackage - """ - - # hack for now. need a pkg_type variable for robustness - def __init__(self, obj, model, pkg_type, pkg_name, sim_package=False): - obj.__init__(self, model, pkg_type, pkg_name, sim_package=sim_package) - - cls_str = "".join(pkg_type.split("-")) - cls_str = f"{cls_str[0].upper()}{cls_str[1:]}" - - package = type(f"Api{cls_str}Package", (basepackage,), {"__init__": __init__}) - - return package + object.__setattr__(self, "mxiter", 10)