Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file removed example_notebooks/data/for_plotting.h5
Binary file not shown.
2 changes: 1 addition & 1 deletion src/paretobench/analyze_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def val_counts_to_summary_str(counts):
n_equal = int(counts.get("=", 0))
if (n_minus + n_plus + n_equal) == 0:
return " "
return " \multicolumn{1}{c}{%d/%d/%d} " % (n_plus, n_minus, n_equal)
return r" \multicolumn{1}{c}{%d/%d/%d} " % (n_plus, n_minus, n_equal)

comparisons = df.map(lambda x: (x[-1] if len(x) > 4 else "")).apply(pd.Series.value_counts).fillna(0)
comparisons = comparisons.apply(val_counts_to_summary_str)
Expand Down
3 changes: 1 addition & 2 deletions src/paretobench/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class Population(BaseModel):

# Total number of function evaluations performed during optimization after this population was completed
fevals: int

# Optional lists of names for decision variables, objectives, and constraints
names_x: Optional[List[str]] = None
names_f: Optional[List[str]] = None
Expand All @@ -47,7 +46,7 @@ class Population(BaseModel):
constraint_targets: np.ndarray

# Pydantic config
model_config = ConfigDict(arbitrary_types_allowed=True)
model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)

@model_validator(mode="before")
@classmethod
Expand Down
3 changes: 2 additions & 1 deletion src/paretobench/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __call__(self, x: np.ndarray, check_bounds=True):
if check_bounds and ((x > self.var_upper_bounds).all() or (x < self.var_lower_bounds).all()):
raise InputError("Input lies outside of problem bounds.")
pop = self._call(x[None, :])
pop.x = np.reshape(x, (1, -1))

# If batched input is used
elif len(x.shape) == 2:
Expand All @@ -57,13 +58,13 @@ def __call__(self, x: np.ndarray, check_bounds=True):
if check_bounds and ((x > self.var_upper_bounds).all() or (x < self.var_lower_bounds).all()):
raise InputError("Input lies outside of problem bounds.")
pop = self._call(x)
pop.x = x

# If user provided something not usable
else:
raise ValueError(f"Incompatible shape of input array x: {x.shape}")

# Set the decision variables
pop.x = x
return pop

def _call(self, x: np.ndarray):
Expand Down
Binary file not shown.
Binary file not shown.
6 changes: 6 additions & 0 deletions tests/test_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,12 @@ def test_population_invalid_dimensions():
Population(x=x, f=f, g=g, fevals=5)


def test_field_assignment_validation():
with pytest.raises(ValueError, match="Expected array with 2 dimensions for field 'x'"):
pop = Population(f=np.random.random((256, 2)))
pop.x = np.random.random((2))


def test_overwrite():
# Create a randomized Experiment object
experiment1 = Experiment.from_random(
Expand Down
10 changes: 10 additions & 0 deletions tests/test_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,13 @@ def test_pareto_front(problem_name, npoints=1000):
# Make sure the right size array is returned and it doesn't give bad values
assert p.n_objs == f.shape[1]
assert not np.isnan(f).any()


def test_unbatched_problem_evaluation():
"""
Confirms that the population object is correctly formated on unbatched calls to problem
"""
prob = pb.WFG1()
x = np.random.random((prob.n))
pop = prob(x)
assert len(pop.x.shape) == 2
Loading