Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions imod/mf6/aggregate/aggregate_schemes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable
from typing import Any, Callable

import numpy as np
from pydantic.dataclasses import dataclass
Expand Down Expand Up @@ -37,10 +37,10 @@ class RiverAggregationMethod(DataclassType):

"""

stage: Callable = np.nanmean
conductance: Callable = np.nansum
bottom_elevation: Callable = np.nanmean
concentration: Callable = np.nanmean
stage: Callable[..., Any] = np.nanmean
conductance: Callable[..., Any] = np.nansum
bottom_elevation: Callable[..., Any] = np.nanmean
concentration: Callable[..., Any] = np.nanmean


@dataclass(config=_CONFIG)
Expand All @@ -65,9 +65,9 @@ class DrainageAggregationMethod(DataclassType):

"""

elevation: Callable = np.nanmean
conductance: Callable = np.nansum
concentration: Callable = np.nanmean
elevation: Callable[..., Any] = np.nanmean
conductance: Callable[..., Any] = np.nansum
concentration: Callable[..., Any] = np.nanmean


@dataclass(config=_CONFIG)
Expand All @@ -92,9 +92,9 @@ class GeneralHeadBoundaryAggregationMethod(DataclassType):

"""

head: Callable = np.nanmean
conductance: Callable = np.nansum
concentration: Callable = np.nanmean
head: Callable[..., Any] = np.nanmean
conductance: Callable[..., Any] = np.nansum
concentration: Callable[..., Any] = np.nanmean


@dataclass(config=_CONFIG)
Expand All @@ -118,5 +118,5 @@ class RechargeAggregationMethod(DataclassType):

"""

rate: Callable = np.nansum
concentration: Callable = np.nanmean
rate: Callable[..., Any] = np.nansum
concentration: Callable[..., Any] = np.nanmean
4 changes: 2 additions & 2 deletions imod/mf6/ats.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from imod.schemata import AllValueSchema, DimsSchema, DTypeSchema
from imod.typing import GridDataset

_PeriodDataType: TypeAlias = dict[np.int64, list]
_PeriodDataType: TypeAlias = dict[np.int64, list[Any]]
_PeriodDataVarNames: TypeAlias = tuple[str, str, str, str, str]


Expand Down Expand Up @@ -208,7 +208,7 @@ def _get_render_dictionary(
d["perioddata"] = perioddata
return d

def _validate(self, schemata: dict, **kwargs):
def _validate(self, schemata: dict[str, Any], **kwargs):
# Insert additional kwargs
kwargs["dt_max"] = self["dt_max"]
errors = super()._validate(schemata, **kwargs)
Expand Down
18 changes: 12 additions & 6 deletions imod/mf6/boundary_condition.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
import pathlib
from copy import copy, deepcopy
from typing import Mapping, Optional, Union
from typing import Any, MutableMapping, Optional, Union

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -69,13 +69,15 @@ class BoundaryCondition(Package, abc.ABC):
not the array input which is used in :class:`Package`.
"""

def __init__(self, allargs: Mapping[str, GridDataArray | float | int | bool | str]):
def __init__(
self, allargs: MutableMapping[str, GridDataArray | float | int | bool | str]
):
# Convert repeat_stress in dict to a xr.DataArray in the right shape if
# necessary, which is required to merge it into the dataset.
if "repeat_stress" in allargs.keys() and isinstance(
allargs["repeat_stress"], dict
):
allargs["repeat_stress"] = get_repeat_stress(allargs["repeat_stress"]) # type: ignore
allargs["repeat_stress"] = get_repeat_stress(allargs["repeat_stress"])
# Call the Package constructor, this merges the arguments into a dataset.
super().__init__(allargs)
if "concentration" in allargs.keys() and allargs["concentration"] is None:
Expand Down Expand Up @@ -197,7 +199,9 @@ def _period_paths(
return periods

def _get_unfiltered_pkg_options(
self, predefined_options: dict, not_options: Optional[list] = None
self,
predefined_options: dict[str, Any],
not_options: Optional[list[str]] = None,
):
options = copy(predefined_options)

Expand All @@ -208,11 +212,13 @@ def _get_unfiltered_pkg_options(
if varname in not_options:
continue
v = self.dataset[varname].values[()]
options[varname] = v
options[str(varname)] = v
return options

def _get_pkg_options(
self, predefined_options: dict, not_options: Optional[list] = None
self,
predefined_options: dict[str, Any],
not_options: Optional[list[str]] = None,
):
unfiltered_options = self._get_unfiltered_pkg_options(
predefined_options, not_options=not_options
Expand Down
6 changes: 4 additions & 2 deletions imod/mf6/evt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Optional

import numpy as np

Expand Down Expand Up @@ -250,7 +250,9 @@ def _validate(self, schemata, **kwargs):
return errors

def _get_pkg_options(
self, predefined_options: dict, not_options: Optional[list] = None
self,
predefined_options: dict[str, Any],
not_options: Optional[list[str]] = None,
):
options = super()._get_pkg_options(predefined_options, not_options=not_options)
# Add amount of segments
Expand Down
22 changes: 11 additions & 11 deletions imod/mf6/hfb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Self, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Self, Tuple

import cftime
import numpy as np
Expand Down Expand Up @@ -186,7 +186,7 @@ def to_connected_cells_dataset(
idomain: GridDataArray,
grid: xu.Ugrid2d,
edge_index: np.ndarray,
edge_values: dict,
edge_values: dict[str, Any],
) -> xr.Dataset:
"""
Converts a cell edge grid with values defined on the edges to a dataset with the cell ids of the connected cells,
Expand Down Expand Up @@ -416,7 +416,7 @@ def _prepare_barrier_dataset_for_mf6_adapter(dataset: xr.Dataset) -> xr.Dataset:

def _snap_to_grid_and_aggregate(
barrier_dataframe: GeoDataFrameType, grid2d: xu.Ugrid2d, vardict_agg: dict[str, str]
) -> tuple[xu.UgridDataset, npt.NDArray]:
) -> tuple[xu.UgridDataset, npt.NDArray[Any]]:
"""
Snap barrier dataframe to grid and aggregate multiple lines with a list of
methods per variable.
Expand Down Expand Up @@ -481,7 +481,7 @@ def __init__(
geometry: "gpd.GeoDataFrame",
print_input: bool = False,
) -> None:
dict_dataset = {"print_input": print_input}
dict_dataset: dict[str, Any] = {"print_input": print_input}
super().__init__(dict_dataset)
self.line_data = geometry

Expand Down Expand Up @@ -859,7 +859,7 @@ def _get_variable_name(self) -> str:
raise NotImplementedError

@abc.abstractmethod
def _get_vertical_variables(self) -> list:
def _get_vertical_variables(self) -> list[str]:
raise NotImplementedError

def clip_box(
Expand Down Expand Up @@ -1143,7 +1143,7 @@ def _get_barrier_type(self):
def _get_variable_name(self) -> str:
return "hydraulic_characteristic"

def _get_vertical_variables(self) -> list:
def _get_vertical_variables(self) -> list[str]:
return []

def _compute_barrier_values(
Expand Down Expand Up @@ -1220,7 +1220,7 @@ def _get_barrier_type(self):
def _get_variable_name(self) -> str:
return "hydraulic_characteristic"

def _get_vertical_variables(self) -> list:
def _get_vertical_variables(self) -> list[str]:
return ["layer"]

def _compute_barrier_values(
Expand Down Expand Up @@ -1294,7 +1294,7 @@ def _get_barrier_type(self):
def _get_variable_name(self) -> str:
return "multiplier"

def _get_vertical_variables(self) -> list:
def _get_vertical_variables(self) -> list[str]:
return []

def _compute_barrier_values(
Expand Down Expand Up @@ -1373,7 +1373,7 @@ def _get_barrier_type(self):
def _get_variable_name(self) -> str:
return "multiplier"

def _get_vertical_variables(self) -> list:
def _get_vertical_variables(self) -> list[str]:
return ["layer"]

def _compute_barrier_values(
Expand Down Expand Up @@ -1467,7 +1467,7 @@ def _get_barrier_type(self):
def _get_variable_name(self) -> str:
return "resistance"

def _get_vertical_variables(self) -> list:
def _get_vertical_variables(self) -> list[str]:
return []

def _compute_barrier_values(
Expand Down Expand Up @@ -1539,7 +1539,7 @@ def _get_barrier_type(self):
def _get_variable_name(self) -> str:
return "resistance"

def _get_vertical_variables(self) -> list:
def _get_vertical_variables(self) -> list[str]:
return ["layer"]

def _compute_barrier_values(
Expand Down
4 changes: 2 additions & 2 deletions imod/mf6/mf6_hfb_adapter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Union
from typing import Any, Union

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -115,7 +115,7 @@ def __init__(
print_input: Union[bool, xr.DataArray] = False,
validate: Union[bool, xr.DataArray] = True,
):
dict_dataset = {
dict_dataset: dict[str, Any] = {
"cell_id1": cell_id1,
"cell_id2": cell_id2,
"layer": layer,
Expand Down
12 changes: 6 additions & 6 deletions imod/mf6/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _create_boundary_condition_clipped_boundary(
original_model: Modflow6Model,
clipped_model: Modflow6Model,
state_for_boundary: Optional[GridDataArray],
clip_box_args: tuple,
clip_box_args: tuple[Any, ...],
) -> Optional[StateType]:
# Create temporary boundary condition for the original model boundary. This
# is used later to see which boundaries can be ignored as they were already
Expand Down Expand Up @@ -146,7 +146,7 @@ def _create_boundary_condition_clipped_boundary(
return bc_constant_pkg


class Modflow6Model(collections.UserDict, IModel, abc.ABC):
class Modflow6Model(collections.UserDict[str, Package], IModel, abc.ABC):
_mandatory_packages: tuple[str, ...] = ()
_init_schemata: SchemataDict = {}
_model_id: Optional[str] = None
Expand All @@ -165,7 +165,7 @@ def __init__(self):

@standard_log_decorator()
def _validate_options(
self, schemata: dict, **kwargs
self, schemata: dict[str, Any], **kwargs
) -> dict[str, list[ValidationError]]:
return validate_schemata_dict(schemata, self._options, **kwargs)

Expand Down Expand Up @@ -568,7 +568,7 @@ def _write(
globaltimes=globaltimes,
write_context=pkg_write_context,
)
elif issubclass(type(pkg), imod.mf6.HorizontalFlowBarrierBase):
elif isinstance(pkg, imod.mf6.HorizontalFlowBarrierBase):
mf6_hfb_ls.append(pkg)
else:
pkg._write(
Expand Down Expand Up @@ -656,7 +656,7 @@ def dump(
if statusinfo.has_errors():
raise ValidationError(statusinfo.to_string())

toml_content: dict = collections.defaultdict(dict)
toml_content: dict[str, Any] = collections.defaultdict(dict)

for pkgname, pkg in self.items():
pkg_path = pkg.to_file(
Expand Down Expand Up @@ -696,7 +696,7 @@ def from_file(cls, toml_path):
return instance

@property
def options(self) -> dict:
def options(self) -> dict[str, Any]:
if self._options is None:
raise ValueError("Model id has not been set")
return self._options
Expand Down
4 changes: 2 additions & 2 deletions imod/mf6/out/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def open_cbc(
"disu": disu.open_hds,
}

_OPEN_CBC: Dict[str, Callable] = {
_OPEN_CBC: Dict[str, Callable[..., Any]] = {
"dis": dis.open_cbc,
"disv": disv.open_cbc,
"disu": disu.open_cbc,
Expand All @@ -76,7 +76,7 @@ def open_cbc(
}


def _get_function(d: Dict[str, Callable], key: str) -> Callable:
def _get_function(d: Dict[str, Callable[..., Any]], key: str) -> Callable[..., Any]:
try:
func = d[key]
except KeyError:
Expand Down
2 changes: 1 addition & 1 deletion imod/mf6/out/cbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def read_imeth6_budgets_dense(
dtype: np.dtype,
pos: int,
size: int,
shape: tuple,
shape: Tuple[int, ...],
return_variable: str,
indices: np.ndarray | None,
) -> FloatArray:
Expand Down
8 changes: 4 additions & 4 deletions imod/mf6/out/dis.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def open_dvs(


def open_imeth1_budgets(
cbc_path: FilePath, grb_content: dict, header_list: List[cbc.Imeth1Header]
cbc_path: FilePath, grb_content: Dict[str, Any], header_list: List[cbc.Imeth1Header]
) -> xr.DataArray:
"""
Open the data for an imeth==1 budget section. Data is read lazily per
Expand Down Expand Up @@ -252,7 +252,7 @@ def open_imeth1_budgets(

def open_imeth6_budgets(
cbc_path: FilePath,
grb_content: dict,
grb_content: Dict[str, Any],
header_list: List[cbc.Imeth6Header],
return_variable: str = "budget",
indices: np.ndarray | None = None,
Expand Down Expand Up @@ -374,7 +374,7 @@ def dis_indices(


def dis_to_right_front_lower_indices(
grb_content: dict,
grb_content: Dict[str, Any],
) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
"""
Infer the indices to extract right, front, and lower face flows from the
Expand Down Expand Up @@ -442,7 +442,7 @@ def dis_extract_face_budgets(


def dis_open_face_budgets(
cbc_path: FilePath, grb_content: dict, header_list: List[cbc.Imeth1Header]
cbc_path: FilePath, grb_content: Dict[str, Any], header_list: List[cbc.Imeth1Header]
) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
"""
Open the flow-ja-face, and extract right, front, and lower face flows.
Expand Down
4 changes: 2 additions & 2 deletions imod/mf6/out/disu.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,13 @@ def open_dvs(


def open_imeth1_budgets(
cbc_path: FilePath, grb_content: dict, header_list: List[cbc.Imeth1Header]
cbc_path: FilePath, grb_content: Dict[str, Any], header_list: List[cbc.Imeth1Header]
) -> xr.DataArray:
raise NotImplementedError


def open_imeth6_budgets(
cbc_path: FilePath, grb_content: dict, header_list: List[cbc.Imeth6Header]
cbc_path: FilePath, grb_content: Dict[str, Any], header_list: List[cbc.Imeth6Header]
) -> xr.DataArray:
raise NotImplementedError

Expand Down
Loading
Loading