Skip to content

Commit

Permalink
chore: add test for #28
Browse files Browse the repository at this point in the history
  • Loading branch information
phodal committed Jul 10, 2024
1 parent 7c3f1db commit a79e726
Showing 1 changed file with 59 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,65 @@ class Employee:
self.name = name
emp = Employee("Zara")
"""
PythonAnalyser().analysis(code, "")
}

@Test
internal fun shouldHandleForImportError() {
val code = """
from dsp.utils import normalize_text
from dspy.primitives.prediction import Completions, Prediction
default_normalize = lambda s: normalize_text(s) or None
def majority(prediction_or_completions, normalize=default_normalize, field=None):
""${'"'}
Returns the most common completion for the target field (or the last field) in the signature.
When normalize returns None, that completion is ignored.
In case of a tie, earlier completion are prioritized.
""${'"'}
assert any(isinstance(prediction_or_completions, t) for t in [Prediction, Completions, list])
input_type = type(prediction_or_completions)
# Get the completions
if isinstance(prediction_or_completions, Prediction):
completions = prediction_or_completions.completions
else:
completions = prediction_or_completions
try:
signature = completions.signature
except:
signature = None
if not field:
if signature:
field = signature.output_fields[-1]
else:
field = list(completions[0].keys())[-1]
# Normalize
normalize = normalize if normalize else lambda x: x
normalized_values = [normalize(completion[field]) for completion in completions]
normalized_values_ = [x for x in normalized_values if x is not None]
# Count
value_counts = {}
for value in (normalized_values_ or normalized_values):
value_counts[value] = value_counts.get(value, 0) + 1
majority_value = max(value_counts, key=value_counts.get)
# Return the first completion with the majority value in the field
for completion in completions:
if normalize(completion[field]) == majority_value:
break
# if input_type == Prediction:
return Prediction.from_completions([completion], signature=signature)
"""
PythonAnalyser().analysis(code, "")
}
Expand Down

0 comments on commit a79e726

Please sign in to comment.