Skip to content

Commit

Permalink
fixing mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
jezsadler committed Dec 10, 2024
1 parent c58d370 commit 8c05da0
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 21 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ check_untyped_defs = true
disallow_untyped_decorators = true
warn_unreachable = true
disallow_any_generics = true
disable_error_code = "attr-defined,misc"

[[tool.mypy.overrides]]
module = [
Expand All @@ -168,6 +169,7 @@ module = [
"keras.*",
"tensorflow.*",
"torch_geometric.*",
"juliacall.*",
]
ignore_missing_imports = true

Expand Down
2 changes: 1 addition & 1 deletion src/omlt/base/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,4 @@ def new_expression(
list(self.exprs.keys()),
)
raise KeyError(msg)
return self.exprs[lang](**kwargs)
return self.exprs[lang](**kwargs) # type: ignore[abstract]
26 changes: 11 additions & 15 deletions src/omlt/base/julia.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
from typing import Any

from numpy import float32
Expand Down Expand Up @@ -88,10 +89,10 @@ def to_jump(self):


class JumpVar(OmltElement):
def __init__(self, varinfo: JuMPVarInfo, name):
def __init__(self, varinfo: JuMPVarInfo, name: None | str):
self.info = varinfo
self.name = name
self.omltvar = None
self.omltvar: None | OmltScalarJuMP | OmltIndexedJuMP = None
self.index = None
self.construct()

Expand Down Expand Up @@ -178,7 +179,7 @@ def tanh(self):
class OmltScalarJuMP(OmltScalar):
format = "jump"

def __init__(self, *, binary=False, **kwargs: Any):
def __init__(self, *, binary: bool = False, **kwargs: Any):
super().__init__()

self._bounds = kwargs.pop("bounds", None)
Expand All @@ -204,8 +205,8 @@ def __init__(self, *, binary=False, **kwargs: Any):

self.binary = binary

self._value : None | int | float
_initialize = kwargs.pop("initialize", None)

if _initialize:
if isinstance(_initialize, (int, float)):
self._value = _initialize
Expand All @@ -214,7 +215,8 @@ def __init__(self, *, binary=False, **kwargs: Any):
else:
msg = (
"Initial value for JuMP variables must be an int"
f" or float, but {type(_initialize)} was provided."
" or float, but %s was provided.",
type(_initialize),
)
raise ValueError(msg)
else:
Expand Down Expand Up @@ -335,7 +337,7 @@ def __init__(self, *indexes: Any, binary: bool = False, **kwargs: Any):
self._vars[idx] = JumpVar(self._varinfo[idx], str(idx))
self._vars[idx].omltvar = self
self._vars[idx].index = idx
self._varrefs = {}
self._varrefs: dict[Any, Any] = {}
self._constructed = False
self._parent = None
self._name = None
Expand Down Expand Up @@ -364,18 +366,12 @@ def __setitem__(self, item, value):
self.construct()

def keys(self):
if self._parent is not None:
return self._varrefs.keys()
return self._vars.keys()

def values(self):
if self._parent is not None:
return self._varrefs.values()
def values(self, sort=Any): # noqa: ARG002
return self._vars.values()

def items(self):
if self._parent is not None:
return self._varrefs.items()
return self._vars.items()

def __len__(self):
Expand All @@ -400,7 +396,7 @@ def construct(self, *, data=None): # noqa: ARG002
self._vars[idx].omltvar = self
self._vars[idx].index = idx
if self._parent is not None:
block = self._parent()
block = self._parent() # type: ignore[unreachable]
if block._format == "jump" and block._jumpmodel is not None:
self._varrefs[idx] = self._vars[idx].add_to_model(block._jumpmodel)

Expand Down Expand Up @@ -454,7 +450,7 @@ def keys(self, sort=False): # noqa: ARG002, FBT002
def __setitem__(self, label, item):
self._jumpcons[label] = item
if self.model is not None and self.name is not None:
jump.add_constraint(
jump.add_constraint( # type: ignore[unreachable]
self.model._jumpmodel, item._jumpcon, self.name + "[" + str(label) + "]"
)

Expand Down
4 changes: 2 additions & 2 deletions src/omlt/base/pyomo.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ class OmltExprScalarPyomo(OmltExpr, pyo.Expression):
def __init__(self, expr=None):
self._index_set = {}
if isinstance(expr, OmltExprScalarPyomo):
self._expression = expr._expression
self._expression : pyo.Expression = expr._expression
elif isinstance(expr, (pyo.Expression, pyo.NumericValue)):
self._expression = expr
elif isinstance(expr, tuple):
Expand All @@ -336,7 +336,7 @@ def _parse_expression_tuple_term(self, term):
return term._pyovar
if isinstance(term, (pyo.Expression, pyo.Var, VarData, int, float, float32)):
return term
msg = ("Term of expression %s is an unsupported type. %s", term, type(term))
msg = ("Term of expression %s is an unsupported type. %s", term, type(term)) # type: ignore[unreachable]
raise TypeError(msg)

def _parse_expression_tuple(self, expr):
Expand Down
6 changes: 3 additions & 3 deletions src/omlt/base/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def valid_model_component(self):
class OmltScalar(OmltVar):
format: str | None = None

def __init__(self):
def __init__(self, binary=False): # noqa: ARG002, FBT002
self.expr_factory = expression.OmltExprFactory()

def is_indexed(self):
Expand Down Expand Up @@ -291,7 +291,7 @@ def new_var(
list(self.indexed.keys()),
)
raise KeyError(msg)
return self.indexed[lang](*indexes, binary=binary, **kwargs)
return self.indexed[lang](*indexes, binary=binary, **kwargs) # type: ignore[abstract, call-arg]
if lang not in self.scalars:
msg = (
"Variable format %s not recognized. Supported formats are %s",
Expand All @@ -300,4 +300,4 @@ def new_var(
)
raise KeyError(msg)

return self.scalars[lang](binary=binary, **kwargs)
return self.scalars[lang](binary=binary, **kwargs) # type: ignore[abstract]

0 comments on commit 8c05da0

Please sign in to comment.