From a79e7265b665e69b652a529f041098bfc0d7add8 Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Jul 2024 14:37:28 +0800 Subject: [PATCH] chore: add test for #28 --- .../pythonast/PythonFullIdentListenerTest.kt | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/chapi-ast-python/src/test/kotlin/chapi/ast/pythonast/PythonFullIdentListenerTest.kt b/chapi-ast-python/src/test/kotlin/chapi/ast/pythonast/PythonFullIdentListenerTest.kt index 760e2fbc..f02d71e4 100644 --- a/chapi-ast-python/src/test/kotlin/chapi/ast/pythonast/PythonFullIdentListenerTest.kt +++ b/chapi-ast-python/src/test/kotlin/chapi/ast/pythonast/PythonFullIdentListenerTest.kt @@ -230,6 +230,65 @@ class Employee: self.name = name emp = Employee("Zara") +""" + PythonAnalyser().analysis(code, "") + } + + @Test + internal fun shouldHandleForImportError() { + val code = """ +from dsp.utils import normalize_text +from dspy.primitives.prediction import Completions, Prediction + +default_normalize = lambda s: normalize_text(s) or None + + +def majority(prediction_or_completions, normalize=default_normalize, field=None): + ""${'"'} + Returns the most common completion for the target field (or the last field) in the signature. + When normalize returns None, that completion is ignored. + In case of a tie, earlier completion are prioritized. + ""${'"'} + + assert any(isinstance(prediction_or_completions, t) for t in [Prediction, Completions, list]) + input_type = type(prediction_or_completions) + + # Get the completions + if isinstance(prediction_or_completions, Prediction): + completions = prediction_or_completions.completions + else: + completions = prediction_or_completions + + try: + signature = completions.signature + except: + signature = None + + if not field: + if signature: + field = signature.output_fields[-1] + else: + field = list(completions[0].keys())[-1] + + # Normalize + normalize = normalize if normalize else lambda x: x + normalized_values = [normalize(completion[field]) for completion in completions] + normalized_values_ = [x for x in normalized_values if x is not None] + + # Count + value_counts = {} + for value in (normalized_values_ or normalized_values): + value_counts[value] = value_counts.get(value, 0) + 1 + + majority_value = max(value_counts, key=value_counts.get) + + # Return the first completion with the majority value in the field + for completion in completions: + if normalize(completion[field]) == majority_value: + break + + # if input_type == Prediction: + return Prediction.from_completions([completion], signature=signature) """ PythonAnalyser().analysis(code, "") }