Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/pysatl_core/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ class Distribution(Protocol):
analytical_computations : Mapping[str, AnalyticalComputation]
Direct analytical computations provided by the distribution.
sampling_strategy : SamplingStrategy
Strategy for generating random samples.
Strategy for generating random samples. Such an object is unique for each distribution.
computation_strategy : ComputationStrategy
Strategy for computing characteristics and conversions.
Such an object is unique for each distribution.
support : Support or None
Support of the distribution, if defined.
"""
Expand Down
151 changes: 133 additions & 18 deletions src/pysatl_core/families/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"


from dataclasses import dataclass
from typing import TYPE_CHECKING, cast

from pysatl_core.distributions.distribution import Distribution
from pysatl_core.distributions.strategies import (
DefaultComputationStrategy,
DefaultSamplingUnivariateStrategy,
)
from pysatl_core.families.registry import ParametricFamilyRegister
from pysatl_core.types import NumericArray

Expand All @@ -40,8 +42,9 @@
ParametrizationName,
)

_KEEP: object = object()


@dataclass(slots=True)
class ParametricFamilyDistribution(Distribution):
"""
A specific distribution instance from a parametric family.
Expand All @@ -53,18 +56,45 @@ class ParametricFamilyDistribution(Distribution):
----------
family_name : str
Name of the distribution family.
_distribution_type : DistributionType
distribution_type : DistributionType
Type of this distribution.
_parametrization : Parametrization
parametrization : Parametrization
Parameter values for this distribution.
_support : Support or None
support : Support or None
Support of this distribution.
sampling_strategy : SamplingStrategy
Strategy for generating random samples. Such an object is unique for each distribution.
computation_strategy : ComputationStrategy
Strategy for computing characteristics and conversions.
Such an object is unique for each distribution.
"""

family_name: str
_distribution_type: DistributionType
_parametrization: Parametrization
_support: Support | None
def __init__(
self,
family_name: str,
distribution_type: DistributionType,
parametrization: Parametrization,
support: Support | None,
sampling_strategy: SamplingStrategy | None = None,
computation_strategy: ComputationStrategy[Any, Any] | None = None,
):
self._distribution_type = distribution_type
self._family_name = family_name
self._parametrization = parametrization
self._support = support

self._computation_strategy = computation_strategy or DefaultComputationStrategy()
self._sampling_strategy = sampling_strategy or DefaultSamplingUnivariateStrategy()

self._analytical_cache_key: tuple[int, GenericCharacteristicName] | None = None
self._analytical_cache_val: (
Mapping[GenericCharacteristicName, AnalyticalComputation[Any, Any]] | None
) = None

@property
def family_name(self) -> str:
"Get the name of the family this distribution belongs to."
return self._family_name

@property
def distribution_type(self) -> DistributionType:
Expand Down Expand Up @@ -142,25 +172,110 @@ def analytical_computations(
parametrization object or name changes.
"""
key = (id(self.parametrization), self.parametrization_name)
cache_key = getattr(self, "_analytical_cache_key", None)
cache_val = getattr(self, "_analytical_cache_val", None)

if cache_key != key or cache_val is None:
cache_val = self.family._build_analytical_computations(self.parametrization)
if self._analytical_cache_key != key or self._analytical_cache_val is None:
self._analytical_cache_val = self.family.build_analytical_computations(
self.parametrization
)
self._analytical_cache_key = key
self._analytical_cache_val = cache_val

return cache_val
return self._analytical_cache_val

@property
def sampling_strategy(self) -> SamplingStrategy:
"""Get the sampling strategy for this distribution."""
return self.family.sampling_strategy
return self._sampling_strategy

@property
def computation_strategy(self) -> ComputationStrategy[Any, Any]:
"""Get the computation strategy for this distribution."""
return self.family.computation_strategy
return self._computation_strategy

def with_sampling_strategy(
self, sampling_strategy: SamplingStrategy | None
) -> ParametricFamilyDistribution:
"""
Return a copy of this distribution with an updated sampling strategy.

Parameters
----------
sampling_strategy : SamplingStrategy | None
New sampling strategy. If ``None``, the default sampling strategy is used.

Returns
-------
ParametricFamilyDistribution
New distribution instance with the same parameters and updated strategy.
"""
return ParametricFamilyDistribution(
family_name=self._family_name,
distribution_type=self._distribution_type,
parametrization=self._parametrization,
support=self._support,
sampling_strategy=sampling_strategy,
computation_strategy=self._computation_strategy,
)

def with_computation_strategy(
self, computation_strategy: ComputationStrategy[Any, Any] | None
) -> ParametricFamilyDistribution:
"""
Return a copy of this distribution with an updated computation strategy.

Parameters
----------
computation_strategy : ComputationStrategy[Any, Any] | None
New computation strategy. If ``None``, the default computation strategy is used.

Returns
-------
ParametricFamilyDistribution
New distribution instance with the same parameters and updated strategy.
"""
return ParametricFamilyDistribution(
family_name=self._family_name,
distribution_type=self._distribution_type,
parametrization=self._parametrization,
support=self._support,
sampling_strategy=self._sampling_strategy,
computation_strategy=computation_strategy,
)

def with_strategies(
self,
*,
sampling_strategy: SamplingStrategy | None = None,
computation_strategy: ComputationStrategy[Any, Any] | None = None,
) -> ParametricFamilyDistribution:
"""
Return a copy of this distribution with updated strategies.

Parameters
----------
sampling_strategy : SamplingStrategy | None | object, optional
New sampling strategy. If not provided, the current strategy is preserved.
If explicitly set to ``None``, the default sampling strategy is used.
computation_strategy : ComputationStrategy[Any, Any] | None | object, optional
New computation strategy. If not provided, the current strategy is preserved.
If explicitly set to ``None``, the default computation strategy is used.

Returns
-------
ParametricFamilyDistribution
New distribution instance with the same parameters and updated strategies.
"""
new_sampling = self._sampling_strategy if sampling_strategy is _KEEP else sampling_strategy
new_computation = (
self._computation_strategy if computation_strategy is _KEEP else computation_strategy
)
return ParametricFamilyDistribution(
family_name=self._family_name,
distribution_type=self._distribution_type,
parametrization=self._parametrization,
support=self._support,
sampling_strategy=new_sampling,
computation_strategy=new_computation,
)

@property
def support(self) -> Support | None:
Expand Down
Loading
Loading