Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions autotest/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from modflow_devtools.misc import set_dir

from modflowapi import Callbacks, ModflowApi, run_simulation
from modflowapi.extensions.pakbase import AdvancedPackage, ArrayPackage, ListPackage
from modflowapi.extensions.pakbase import Package

data_pth = Path("../docs/examples/data")
pytestmark = pytest.mark.extensions
Expand Down Expand Up @@ -71,16 +71,16 @@ def callback(sim, step):
raise AssertionError("ApiModel has advanced prior to initialization callback")

dis = model.dis
if not isinstance(dis, ArrayPackage):
raise TypeError("DIS package has incorrect base class type")
if "idomain" not in dis.variable_names:
raise TypeError("DIS package should have grid array variables")

wel = model.wel
if not isinstance(wel, ListPackage):
raise TypeError("WEL package has incorrect base class type")
if wel.stress_period_data is None:
raise TypeError("WEL package should have stress period data")

gnc = model.gnc
if not isinstance(gnc, AdvancedPackage):
raise TypeError("GNC package has incorrect base class type")
if not isinstance(gnc, Package):
raise TypeError("GNC package has incorrect type")

rch = model.rch
if len(rch) != 2:
Expand Down Expand Up @@ -158,16 +158,16 @@ def callback(sim, step):
raise AssertionError("ApiModel has advanced prior to initialization callback")

dis = model.dis
if not isinstance(dis, ArrayPackage):
raise TypeError("DIS package has incorrect base class type")
if "idomain" not in dis.variable_names:
raise TypeError("DIS package should have grid array variables")

chd = model.chd_left
if not isinstance(chd, ListPackage):
raise TypeError("CHD package has incorrect base class type")
if chd.stress_period_data is None:
raise TypeError("CHD package should have stress period data")

hfb = model.hfb
if not isinstance(hfb, AdvancedPackage):
raise TypeError("HFB package has incorrect base class type")
if not isinstance(hfb, Package):
raise TypeError("HFB package has incorrect type")

chd = model.chd
if len(chd) != 2:
Expand Down Expand Up @@ -236,16 +236,16 @@ def callback(sim, step):
raise AssertionError("ApiModel has advanced prior to initialization callback")

dis = model.dis
if not isinstance(dis, ArrayPackage):
raise TypeError("DIS package has incorrect base class type")
if "idomain" not in dis.variable_names:
raise TypeError("DIS package should have grid array variables")

rch = model.rch
if not isinstance(rch, ListPackage):
raise TypeError("RCH package has incorrect base class type")
if rch.stress_period_data is None:
raise TypeError("RCH package should have stress period data")

mvr = model.mvr
if not isinstance(mvr, AdvancedPackage):
raise TypeError("MVR package has incorrect base class type")
if not isinstance(mvr, Package):
raise TypeError("MVR package has incorrect type")

top = dis.top.values
if not isinstance(top, np.ndarray):
Expand Down
111 changes: 6 additions & 105 deletions modflowapi/extensions/advpaks.py
Original file line number Diff line number Diff line change
@@ -1,107 +1,8 @@
import numpy as np

from .data import ListInput
from .pakbase import AdvancedPackage


class SfrPakage(AdvancedPackage):
"""
Container for SFR and SFR like packages

Parameters
----------
model : ApiModel
modflowapi model object
pkg_type : str
package type. Ex. "SFR"
pkg_name : str
package name (in the mf6 variables)
sim_package : bool
boolean flag for simulation level packages. Ex. TDIS, IMS
"""

def __init__(self, model, pkg_type, pkg_name, sim_package=False):
super().__init__(model, pkg_type, pkg_name, sim_package)

self._diversion_var_arrs = []
self._set_advanced_variable_addrs("diversions", "_diversion_var_addrs")
self._diversion_vars = ListInput(self, self._diversion_var_arrs, spd=False)

@property
def diversions(self):
return self._diversion_vars

@diversions.setter
def diversions(self, recarray):
"""
Setter object to update the diversions data

"""
if isinstance(recarray, np.recarray):
self._diversion_vars.values = recarray
elif isinstance(recarray, ListInput):
self._diversion_vars.values = recarray.values
elif recarray is None:
self._diversion_vars.values = recarray
else:
raise TypeError(f"{type(recarray)} is not a supported diversions type")


class LakPackage(AdvancedPackage):
"""
Container for LAK and LAK like packages

Parameters
----------
model : ApiModel
modflowapi model object
pkg_type : str
package type. Ex. "LAK"
pkg_name : str
package name (in the mf6 variables)
sim_package : bool
boolean flag for simulation level packages. Ex. TDIS, IMS
"""

def __init__(self, model, pkg_type, pkg_name, sim_package=False):
super().__init__(model, pkg_type, pkg_name, sim_package)


class MawPackage(AdvancedPackage):
"""
Container for MAW and MAW like packages

Parameters
----------
model : ApiModel
modflowapi model object
pkg_type : str
package type. Ex. "MAW"
pkg_name : str
package name (in the mf6 variables)
sim_package : bool
boolean flag for simulation level packages. Ex. TDIS, IMS
"""

def __init__(self, model, pkg_type, pkg_name, sim_package=False):
super().__init__(model, pkg_type, pkg_name, sim_package)


class UzfPackage(AdvancedPackage):
"""
Container for UZF and UZF like packages

Parameters
----------
model : ApiModel
modflowapi model object
pkg_type : str
package type. Ex. "UZF"
pkg_name : str
package name (in the mf6 variables)
sim_package : bool
boolean flag for simulation level packages. Ex. TDIS, IMS
"""

def __init__(self, model, pkg_type, pkg_name, sim_package=False):
super().__init__(model, pkg_type, pkg_name, sim_package)
# Backward compatibility aliases
SfrPackage = AdvancedPackage
SfrPakage = AdvancedPackage # preserved old typo spelling
LakPackage = AdvancedPackage
MawPackage = AdvancedPackage
UzfPackage = AdvancedPackage
10 changes: 3 additions & 7 deletions modflowapi/extensions/apiexchange.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .apimodel import ApiMbase
from .pakbase import ListPackage


class ApiExchange(ApiMbase):
Expand All @@ -15,10 +14,7 @@ class ApiExchange(ApiMbase):
modflow exchange name. ex. "GWF-GWF_1"
"""

sim_level = True # exchange packages are simulation-level, not model-level

def __init__(self, mf6, name):
pkg_types = {
"gwf-gwf": ListPackage,
"gwt-gwt": ListPackage,
"gwe-gwe": ListPackage,
}
super().__init__(mf6, name, pkg_types)
super().__init__(mf6, name)
61 changes: 21 additions & 40 deletions modflowapi/extensions/apimodel.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import math

import numpy as np

from .datamodel import get_package_type, gridshape
from .pakbase import AdvancedPackage, ArrayPackage, ListPackage, package_factory
from .datamodel import gridshape
from .pakbase import Package


class ApiMbase:
Expand All @@ -18,6 +20,8 @@ class ApiMbase:
optional dictionary of package types and ApiPackage class types
"""

sim_level = False # True for sim/exchange containers; False for model containers

def __init__(self, mf6, name, pkg_types=None):
self.mf6 = mf6
self.name = name
Expand All @@ -33,7 +37,7 @@ def package_list(self):
"""
Returns a list of package objects for the model
"""
return [package for _, package in self.package_dict.items()]
return list(self.package_dict.values())

@property
def package_names(self):
Expand All @@ -44,7 +48,7 @@ def package_names(self):

@property
def package_types(self):
return list(set([package.pkg_type for package in self.package_list]))
return list({p.pkg_type for p in self.package_list})

def _set_package_names(self):
"""
Expand All @@ -56,11 +60,9 @@ def _set_package_names(self):
if addr.endswith("PACKAGE_TYPE") and tmp[0] == self.name:
pak_types[tmp[1]] = self.mf6.get_value(addr)[0]
elif tmp[0] == self.name and len(tmp) == 2:
if tmp[0].startswith("GWF-GWF"):
pak_types[tmp[0]] = "GWF-GWF"
pak_types.pop("dis", None)
elif tmp[0].startswith("GWT-GWT"):
pak_types[tmp[0]] = "GWT-GWT"
parts = tmp[0].rsplit("_", 1)[0].split("-")
if len(parts) == 2 and parts[0] == parts[1]:
pak_types[tmp[0]] = tmp[0].rsplit("_", 1)[0]
pak_types.pop("dis", None)

self._pak_type = list(pak_types.values())
Expand All @@ -72,26 +74,15 @@ def _create_package_list(self):
"""
for ix, pkg_name in enumerate(self._pkg_names):
pkg_type = self._pak_type[ix].lower()
if self._pkg_types is None:
basepackage = get_package_type(pkg_type)
else:
if pkg_type in self._pkg_types:
basepackage = self._pkg_types[pkg_type]
else:
basepackage = AdvancedPackage

package = package_factory(pkg_type, basepackage)
adj_pkg_name = "".join(pkg_type.split("-"))

if adj_pkg_name.lower() in ("gwfgwf", "gwtgwt"):
adj_pkg_name = ""
if self._pkg_types is not None and pkg_type in self._pkg_types:
pkg_cls = self._pkg_types[pkg_type]
else:
adj_pkg_name = pkg_name
pkg_cls = Package

package = package(basepackage, self, pkg_type, adj_pkg_name)
package = pkg_cls(self, pkg_type, pkg_name, sim_package=self.sim_level)
self.package_dict[pkg_name.lower()] = package

def get_package(self, pkg_name) -> ListPackage or ArrayPackage or AdvancedPackage:
def get_package(self, pkg_name) -> "Package":
"""
Method to get a package

Expand Down Expand Up @@ -160,16 +151,9 @@ def __repr__(self):
else:
pass

s += "Packages accessible include: \n"
for typ, baseobj in [
("ArrayPackage", ArrayPackage),
("ListPackage", ListPackage),
("AdvancedPackage", AdvancedPackage),
]:
s += f" {typ} objects:\n"
for name, obj in self.package_dict.items():
if isinstance(obj, baseobj):
s += f" {name}: {type(obj)}\n"
s += "Packages accessible include:\n"
for name, pkg in self.package_dict.items():
s += f" {name}: {pkg.pkg_type.upper()}\n"

return s

Expand Down Expand Up @@ -263,8 +247,8 @@ def shape(self):
"""
Returns a tuple of the model shape
"""
ivn = self.mf6.get_input_var_names()
if self._shape is None:
ivn = self.mf6.get_input_var_names()
shape_vars = gridshape[self.dis_type]
shape = []
for var in shape_vars:
Expand All @@ -283,10 +267,7 @@ def size(self):
Returns the number of nodes in the model
"""
if self._size is None:
size = 1
for dim in self.shape:
size *= dim
self._size = size
self._size = math.prod(self.shape)
return self._size

@property
Expand Down
Loading
Loading