Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions pyomo/common/pyomo_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,20 @@ def _get_fullqual_name(func: typing.Callable) -> str:
return f"{func.__module__}.{func.__qualname__}"


def overload(func: typing.Callable):
"""Wrap typing.overload that remembers the overloaded signatures
if typing.TYPE_CHECKING:
from typing import overload as overload
else:
Comment on lines +21 to +23
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need to disable the overload?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A function calling typing.overload is not recognized as an overload by the type checker.
We need to replace it with typing.overload during type checking.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. I went digging and it turns out that logic is primarily needed by a dependent project -- but only for Python versions through 3.10. I think we should actually document that better in the code, with something like:

if sys.version_info[:2] <= (3, 10) and not TYPE_CHECKING:
    def overload(func: typing.Callable):
        """Wrap typing.overload that remembers the overloaded signatures

        This provides a custom implementation of typing.overload that
        remembers the overloaded signatures so that they are available for
        runtime inspection (backporting `get_overloads` from Python 3.11+).

        """
        _overloads.setdefault(_get_fullqual_name(func), []).append(func)
        return typing.overload(func)

    def get_overloads_for(func: typing.Callable):
        return _overloads.get(_get_fullqual_name(func), [])
else:
    from typing import overload, get_overloads as get_overloads_for

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typing_extensions might be helpful.


This provides a custom implementation of typing.overload that
remembers the overloaded signatures so that they are available for
runtime inspection.
def overload(func: typing.Callable):
"""Wrap typing.overload that remembers the overloaded signatures

"""
_overloads.setdefault(_get_fullqual_name(func), []).append(func)
return typing.overload(func)
This provides a custom implementation of typing.overload that
remembers the overloaded signatures so that they are available for
runtime inspection.

"""
_overloads.setdefault(_get_fullqual_name(func), []).append(func)
return typing.overload(func)


def get_overloads_for(func: typing.Callable):
Expand Down
5 changes: 5 additions & 0 deletions pyomo/contrib/appsi/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import weakref

from typing import (
TYPE_CHECKING,
Sequence,
Dict,
Optional,
Expand Down Expand Up @@ -1714,5 +1715,9 @@ class LegacySolver(LegacySolverInterface, cls):

return decorator

if TYPE_CHECKING:
# NOTE: `Factory.__call__` can return None, but for the common case
def __call__(self, name, **kwds) -> Solver: ...


SolverFactory = SolverFactoryClass()
10 changes: 9 additions & 1 deletion pyomo/contrib/solver/common/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
# ___________________________________________________________________________


from pyomo.opt.base.solvers import LegacySolverFactory
from typing import TYPE_CHECKING

from pyomo.common.factory import Factory
from pyomo.contrib.solver.common.base import LegacySolverWrapper
from pyomo.opt.base.solvers import LegacySolverFactory


class SolverFactoryClass(Factory):
Expand Down Expand Up @@ -107,6 +109,12 @@ class LegacySolver(LegacySolverWrapper, cls):

return decorator

if TYPE_CHECKING:
from pyomo.contrib.solver.common.base import SolverBase

# NOTE: `Factory.__call__` can return None, but for the common case
def __call__(self, name, **kwds) -> SolverBase: ...


#: Global registry/factory for "v2" solver interfaces.
SolverFactory: SolverFactoryClass = SolverFactoryClass()
9 changes: 7 additions & 2 deletions pyomo/core/base/PyomoModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from weakref import ref as weakref_ref
import gc
import math
from typing import TypeVar

from pyomo.common import timing
from pyomo.common.collections import Bunch
Expand Down Expand Up @@ -572,6 +573,10 @@ def select(
StaleFlagManager.mark_all_as_stale(delayed=True)


# NOTE: Python 3.11+ use `typing.Self`
ModelT = TypeVar("ModelT", bound="Model")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We haven't really settled on a typing convention for naming types, but appending T seems confusing.

  • would standardizing on appendint Type be more clear?
  • should local TypeVar objects be private by default (i.e., _modelType here)?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_ModelType is better.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed.



@ModelComponentFactory.register(
'Model objects can be used as a component of other models.'
)
Expand All @@ -583,9 +588,9 @@ class Model(ScalarBlock):

_Block_reserved_words = set()

def __new__(cls, *args, **kwds):
def __new__(cls: type[ModelT], *args, **kwds) -> ModelT:
if cls != Model:
return super(Model, cls).__new__(cls)
return super(Model, cls).__new__(cls) # type: ignore

raise TypeError(
"Directly creating the 'Model' class is not allowed. Please use the "
Expand Down
10 changes: 5 additions & 5 deletions pyomo/core/base/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -2085,17 +2085,17 @@ class Block(ActiveIndexedComponent):
_ComponentDataClass = BlockData
_private_data_initializers = defaultdict(lambda: dict)

@overload
def __new__(
cls: Type[Block], *args, **kwds
) -> Union[ScalarBlock, IndexedBlock]: ...

@overload
def __new__(cls: Type[ScalarBlock], *args, **kwds) -> ScalarBlock: ...

@overload
def __new__(cls: Type[IndexedBlock], *args, **kwds) -> IndexedBlock: ...

@overload
def __new__(
cls: Type[Block], *args, **kwds
) -> Union[ScalarBlock, IndexedBlock]: ...

def __new__(cls, *args, **kwds):
if cls != Block:
return super(Block, cls).__new__(cls)
Expand Down
10 changes: 5 additions & 5 deletions pyomo/core/base/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,17 +638,17 @@ class Constraint(ActiveIndexedComponent):
Violated = Infeasible
Satisfied = Feasible

@overload
def __new__(
cls: Type[Constraint], *args, **kwds
) -> Union[ScalarConstraint, IndexedConstraint]: ...

@overload
def __new__(cls: Type[ScalarConstraint], *args, **kwds) -> ScalarConstraint: ...

@overload
def __new__(cls: Type[IndexedConstraint], *args, **kwds) -> IndexedConstraint: ...

@overload
def __new__(
cls: Type[Constraint], *args, **kwds
) -> Union[ScalarConstraint, IndexedConstraint]: ...

def __new__(cls, *args, **kwds):
if cls != Constraint:
return super().__new__(cls)
Expand Down
10 changes: 5 additions & 5 deletions pyomo/core/base/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,17 +309,17 @@ class NoValue:

pass

@overload
def __new__(
cls: Type[Param], *args, **kwds
) -> Union[ScalarParam, IndexedParam]: ...

@overload
def __new__(cls: Type[ScalarParam], *args, **kwds) -> ScalarParam: ...

@overload
def __new__(cls: Type[IndexedParam], *args, **kwds) -> IndexedParam: ...

@overload
def __new__(
cls: Type[Param], *args, **kwds
) -> Union[ScalarParam, IndexedParam]: ...

def __new__(cls, *args, **kwds):
if cls != Param:
return super(Param, cls).__new__(cls)
Expand Down
4 changes: 2 additions & 2 deletions pyomo/core/base/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -2128,10 +2128,10 @@ class SortedOrder:
_UnorderedInitializers = {set}

@overload
def __new__(cls: Type[Set], *args, **kwds) -> Union[SetData, IndexedSet]: ...
def __new__(cls: Type[OrderedScalarSet], *args, **kwds) -> OrderedScalarSet: ...

@overload
def __new__(cls: Type[OrderedScalarSet], *args, **kwds) -> OrderedScalarSet: ...
def __new__(cls: Type[Set], *args, **kwds) -> Union[SetData, IndexedSet]: ...

def __new__(cls, *args, **kwds):
if cls is not Set:
Expand Down
6 changes: 3 additions & 3 deletions pyomo/core/base/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,15 +575,15 @@ class Var(IndexedComponent, IndexedComponent_NDArrayMixin):

_ComponentDataClass = VarData

@overload
def __new__(cls: Type[Var], *args, **kwargs) -> Union[ScalarVar, IndexedVar]: ...

@overload
def __new__(cls: Type[ScalarVar], *args, **kwargs) -> ScalarVar: ...

@overload
def __new__(cls: Type[IndexedVar], *args, **kwargs) -> IndexedVar: ...

@overload
def __new__(cls: Type[Var], *args, **kwargs) -> Union[ScalarVar, IndexedVar]: ...

def __new__(cls, *args, **kwargs):
if cls is not Var:
return super(Var, cls).__new__(cls)
Expand Down
35 changes: 29 additions & 6 deletions pyomo/future.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,15 @@
# This software is distributed under the 3-clause BSD License.
# ___________________________________________________________________________

from typing import TYPE_CHECKING, Any, Literal, overload

import pyomo.environ as _environ

if TYPE_CHECKING:
import pyomo.contrib.appsi.base as _appsi
import pyomo.contrib.solver.common.factory as _contrib
import pyomo.opt.base.solvers as _solvers

__doc__ = """
Preview capabilities through ``pyomo.__future__``
=================================================
Expand All @@ -28,13 +35,29 @@

"""

solver_factory_v1: "_solvers.SolverFactoryClass"
solver_factory_v2: "_appsi.SolverFactoryClass"
solver_factory_v3: "_contrib.SolverFactoryClass"


def __getattr__(name):
if name in ('solver_factory_v1', 'solver_factory_v2', 'solver_factory_v3'):
if name in ("solver_factory_v1", "solver_factory_v2", "solver_factory_v3"):
return solver_factory(int(name[-1]))
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")


@overload
def solver_factory(version: None = None) -> int: ...
@overload
def solver_factory(version: Literal[1]) -> "_solvers.SolverFactoryClass": ...
@overload
def solver_factory(version: Literal[2]) -> "_appsi.SolverFactoryClass": ...
@overload
def solver_factory(version: Literal[3]) -> "_contrib.SolverFactoryClass": ...
@overload
def solver_factory(version: int) -> Any: ...


def solver_factory(version=None):
"""Get (or set) the active implementation of the SolverFactory

Expand Down Expand Up @@ -90,19 +113,19 @@ def solver_factory(version=None):
if current is None:
for ver, cls in versions.items():
if cls._cls is _environ.SolverFactory._cls:
solver_factory._active_version = ver
solver_factory._active_version = ver # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't need these annotations, should we?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we cannot add attributes to FunctionType.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can't annotate the function, plus the attribute whereit is first instantiated at the bottom of the file?

def solver_factory(version: int | None = None) -> int:
# ...
solver_factory._active_version: int = solver_factory()

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is allowed at runtime, but it's a type violation.
ref: microsoft/pyright#8838

break
return solver_factory._active_version
return solver_factory._active_version # type: ignore
#
# The user is just asking what the current SolverFactory is; tell them.
if version is None:
return solver_factory._active_version
return solver_factory._active_version # type: ignore
#
# Update the current SolverFactory to be a shim around (shallow copy
# of) the new active factory
src = versions.get(version, None)
if version is not None:
solver_factory._active_version = version
solver_factory._active_version = version # type: ignore
for attr in ('_description', '_cls', '_doc'):
setattr(_environ.SolverFactory, attr, getattr(src, attr))
else:
Expand All @@ -113,4 +136,4 @@ def solver_factory(version=None):
return src


solver_factory._active_version = solver_factory()
solver_factory._active_version = solver_factory() # type: ignore
6 changes: 6 additions & 0 deletions pyomo/opt/base/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import time
import logging
import shlex
from typing import overload

from pyomo.common import Factory
from pyomo.common.enums import SolverAPIVersion
Expand Down Expand Up @@ -144,6 +145,11 @@ def _solver_error(self, method_name):


class SolverFactoryClass(Factory):
@overload
def __call__(self, _name: None = None, **kwds) -> "SolverFactoryClass": ...
@overload
def __call__(self, _name, **kwds) -> "OptSolver": ...
Comment on lines +148 to +151
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Python documentation indicates that you should only use the Ellipsis notation in .pyi files. I think this should be:

Suggested change
@overload
def __call__(self, _name: None = None, **kwds) -> "SolverFactoryClass": ...
@overload
def __call__(self, _name, **kwds) -> "OptSolver": ...
@overload
def __call__(self, _name: None = None, **kwds) -> "SolverFactoryClass":
pass
@overload
def __call__(self, _name, **kwds) -> "OptSolver":
pass

Copy link
Author

@n-takumasa n-takumasa Nov 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not believe the documentation states that you must not use Ellipsis in .py files.
In fact, it is uncommon to see pass used as the body for an @overloaded function.


def __call__(self, _name=None, **kwds):
if _name is None:
return self
Expand Down