diff --git a/example_notebooks/data/for_plotting.h5 b/example_notebooks/data/for_plotting.h5 deleted file mode 100644 index 1a46cc77..00000000 Binary files a/example_notebooks/data/for_plotting.h5 and /dev/null differ diff --git a/src/paretobench/analyze_metrics.py b/src/paretobench/analyze_metrics.py index 59cb32ad..feeea7ea 100644 --- a/src/paretobench/analyze_metrics.py +++ b/src/paretobench/analyze_metrics.py @@ -532,7 +532,7 @@ def val_counts_to_summary_str(counts): n_equal = int(counts.get("=", 0)) if (n_minus + n_plus + n_equal) == 0: return " " - return " \multicolumn{1}{c}{%d/%d/%d} " % (n_plus, n_minus, n_equal) + return r" \multicolumn{1}{c}{%d/%d/%d} " % (n_plus, n_minus, n_equal) comparisons = df.map(lambda x: (x[-1] if len(x) > 4 else "")).apply(pd.Series.value_counts).fillna(0) comparisons = comparisons.apply(val_counts_to_summary_str) diff --git a/src/paretobench/containers.py b/src/paretobench/containers.py index db3b4594..7b337c48 100644 --- a/src/paretobench/containers.py +++ b/src/paretobench/containers.py @@ -33,7 +33,6 @@ class Population(BaseModel): # Total number of function evaluations performed during optimization after this population was completed fevals: int - # Optional lists of names for decision variables, objectives, and constraints names_x: Optional[List[str]] = None names_f: Optional[List[str]] = None @@ -47,7 +46,7 @@ class Population(BaseModel): constraint_targets: np.ndarray # Pydantic config - model_config = ConfigDict(arbitrary_types_allowed=True) + model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True) @model_validator(mode="before") @classmethod diff --git a/src/paretobench/problem.py b/src/paretobench/problem.py index 3192f334..2ce5fb32 100644 --- a/src/paretobench/problem.py +++ b/src/paretobench/problem.py @@ -46,6 +46,7 @@ def __call__(self, x: np.ndarray, check_bounds=True): if check_bounds and ((x > self.var_upper_bounds).all() or (x < self.var_lower_bounds).all()): raise InputError("Input lies outside of problem bounds.") pop = self._call(x[None, :]) + pop.x = np.reshape(x, (1, -1)) # If batched input is used elif len(x.shape) == 2: @@ -57,13 +58,13 @@ def __call__(self, x: np.ndarray, check_bounds=True): if check_bounds and ((x > self.var_upper_bounds).all() or (x < self.var_lower_bounds).all()): raise InputError("Input lies outside of problem bounds.") pop = self._call(x) + pop.x = x # If user provided something not usable else: raise ValueError(f"Incompatible shape of input array x: {x.shape}") # Set the decision variables - pop.x = x return pop def _call(self, x: np.ndarray): diff --git a/tests/legacy_file_formats/paretobench_file_format_v1.0.0.h5 b/tests/legacy_file_formats/paretobench_file_format_v1.0.0.h5 deleted file mode 100644 index 04fd3ff8..00000000 Binary files a/tests/legacy_file_formats/paretobench_file_format_v1.0.0.h5 and /dev/null differ diff --git a/tests/legacy_file_formats/paretobench_file_format_v1.1.0.h5 b/tests/legacy_file_formats/paretobench_file_format_v1.1.0.h5 deleted file mode 100644 index 831337a0..00000000 Binary files a/tests/legacy_file_formats/paretobench_file_format_v1.1.0.h5 and /dev/null differ diff --git a/tests/test_containers.py b/tests/test_containers.py index 43c7b615..de024d94 100644 --- a/tests/test_containers.py +++ b/tests/test_containers.py @@ -288,6 +288,12 @@ def test_population_invalid_dimensions(): Population(x=x, f=f, g=g, fevals=5) +def test_field_assignment_validation(): + with pytest.raises(ValueError, match="Expected array with 2 dimensions for field 'x'"): + pop = Population(f=np.random.random((256, 2))) + pop.x = np.random.random((2)) + + def test_overwrite(): # Create a randomized Experiment object experiment1 = Experiment.from_random( diff --git a/tests/test_problems.py b/tests/test_problems.py index 26e555b4..8f2a7aed 100644 --- a/tests/test_problems.py +++ b/tests/test_problems.py @@ -100,3 +100,13 @@ def test_pareto_front(problem_name, npoints=1000): # Make sure the right size array is returned and it doesn't give bad values assert p.n_objs == f.shape[1] assert not np.isnan(f).any() + + +def test_unbatched_problem_evaluation(): + """ + Confirms that the population object is correctly formated on unbatched calls to problem + """ + prob = pb.WFG1() + x = np.random.random((prob.n)) + pop = prob(x) + assert len(pop.x.shape) == 2