From efccb9292874481a4f285b114dec408158c6294c Mon Sep 17 00:00:00 2001
From: kazk <kazk.dev@gmail.com>
Date: Sun, 12 Jul 2020 17:41:09 -0700
Subject: [PATCH 1/2] Format with Black

https://github.com/psf/black
---
 codewars_test/test_framework.py | 51 ++++++++++++++++++++-------------
 1 file changed, 31 insertions(+), 20 deletions(-)

diff --git a/codewars_test/test_framework.py b/codewars_test/test_framework.py
index 50b9f85..c035ca0 100644
--- a/codewars_test/test_framework.py
+++ b/codewars_test/test_framework.py
@@ -10,16 +10,19 @@ def format_message(message):
 
 
 def display(type, message, label="", mode=""):
-    print("\n<{0}:{1}:{2}>{3}".format(
-        type.upper(), mode.upper(), label, format_message(message)))
+    print(
+        "\n<{0}:{1}:{2}>{3}".format(
+            type.upper(), mode.upper(), label, format_message(message)
+        )
+    )
 
 
 def expect(passed=None, message=None, allow_raise=False):
     if passed:
-        display('PASSED', 'Test Passed')
+        display("PASSED", "Test Passed")
     else:
         message = message or "Value is not what was expected"
-        display('FAILED', message)
+        display("FAILED", message)
         if allow_raise:
             raise AssertException(message)
 
@@ -67,14 +70,17 @@ def expect_no_error(message, function, exception=BaseException):
     pass_()
 
 
-def pass_(): expect(True)
+def pass_():
+    expect(True)
 
 
-def fail(message): expect(False, message)
+def fail(message):
+    expect(False, message)
 
 
 def assert_approx_equals(
-        actual, expected, margin=1e-9, message=None, allow_raise=False):
+    actual, expected, margin=1e-9, message=None, allow_raise=False
+):
     msg = "{0} should be close to {1} with absolute or relative margin of {2}"
     equals_msg = msg.format(repr(actual), repr(expected), repr(margin))
     if message is None:
@@ -85,14 +91,14 @@ def assert_approx_equals(
     expect(abs((actual - expected) / div) < margin, message, allow_raise)
 
 
-'''
+"""
 Usage:
 @describe('describe text')
 def describe1():
     @it('it text')
     def it1():
         # some test cases...
-'''
+"""
 
 
 def _timed_block_factory(opening_text):
@@ -110,44 +116,49 @@ def wrapper(func):
             try:
                 func()
             except AssertionError as e:
-                display('FAILED', str(e))
+                display("FAILED", str(e))
             except Exception:
-                fail('Unexpected exception raised')
-                tb_str = ''.join(format_exception(*exc_info()))
-                display('ERROR', tb_str)
-            display('COMPLETEDIN', '{:.2f}'.format((timer() - time) * 1000))
+                fail("Unexpected exception raised")
+                tb_str = "".join(format_exception(*exc_info()))
+                display("ERROR", tb_str)
+            display("COMPLETEDIN", "{:.2f}".format((timer() - time) * 1000))
             if callable(after):
                 after()
+
         return wrapper
+
     return _timed_block_decorator
 
 
-describe = _timed_block_factory('DESCRIBE')
-it = _timed_block_factory('IT')
+describe = _timed_block_factory("DESCRIBE")
+it = _timed_block_factory("IT")
 
 
-'''
+"""
 Timeout utility
 Usage:
 @timeout(sec)
 def some_tests():
     any code block...
 Note: Timeout value can be a float.
-'''
+"""
 
 
 def timeout(sec):
     def wrapper(func):
         from multiprocessing import Process
-        msg = 'Should not throw any exceptions inside timeout'
+
+        msg = "Should not throw any exceptions inside timeout"
 
         def wrapped():
             expect_no_error(msg, func)
+
         process = Process(target=wrapped)
         process.start()
         process.join(sec)
         if process.is_alive():
-            fail('Exceeded time limit of {:.3f} seconds'.format(sec))
+            fail("Exceeded time limit of {:.3f} seconds".format(sec))
             process.terminate()
             process.join()
+
     return wrapper

From fe170b526e88d8eb8b047559e33069d59e1bebe5 Mon Sep 17 00:00:00 2001
From: kazk <kazk.dev@gmail.com>
Date: Sun, 12 Jul 2020 21:21:05 -0700
Subject: [PATCH 2/2] Use `AssertionError` when inside decorated test case

- Stops at first failure like most other test frameworks
- Single test passed message regardless of the number of assertions
- Discourage fat test cases
- Allow using external assertions
  - `np.testing.assert_equal`
  - `pd.testing.assert_frame_equal`
  - assertion packages
---
 codewars_test/test_framework.py               | 83 +++++++++++++++----
 tests/fixtures/custom_assertion.expected.txt  | 22 +++++
 tests/fixtures/custom_assertion.py            | 21 +++++
 .../fixtures/expect_error_sample.expected.txt | 52 +-----------
 tests/test_outputs.py                         |  7 +-
 5 files changed, 119 insertions(+), 66 deletions(-)
 create mode 100644 tests/fixtures/custom_assertion.expected.txt
 create mode 100644 tests/fixtures/custom_assertion.py

diff --git a/codewars_test/test_framework.py b/codewars_test/test_framework.py
index c035ca0..cfad8ed 100644
--- a/codewars_test/test_framework.py
+++ b/codewars_test/test_framework.py
@@ -1,4 +1,5 @@
 from __future__ import print_function
+import inspect
 
 
 class AssertException(Exception):
@@ -17,14 +18,43 @@ def display(type, message, label="", mode=""):
     )
 
 
-def expect(passed=None, message=None, allow_raise=False):
+# TODO Currently this only works if assertion functions are written directly in the test case.
+def _is_in_test_case():
+    frame = inspect.currentframe()
+    caller_frame = frame.f_back
+    test_case_frame = caller_frame.f_back
+    decorator_frame = test_case_frame.f_back
+    if not decorator_frame:
+        return False
+    if not "func" in decorator_frame.f_locals:
+        return False
+    func = decorator_frame.f_locals["func"]
+    code = test_case_frame.f_code
+    if func and func.__code__ == code and func.test_case_func:
+        return True
+    return False
+
+
+def _handle_test_result(passed, message=None, allow_raise=False, in_test_case=False):
     if passed:
-        display("PASSED", "Test Passed")
+        if not in_test_case:
+            display("PASSED", "Test Passed")
     else:
-        message = message or "Value is not what was expected"
-        display("FAILED", message)
-        if allow_raise:
-            raise AssertException(message)
+        if not message:
+            message = "Value is not what was expected"
+        if in_test_case:
+            raise AssertionError(message)
+        else:
+            display("FAILED", message)
+            if allow_raise:
+                # TODO Use AssertionError?
+                raise AssertException(message)
+
+
+def expect(passed=None, message=None, allow_raise=False):
+    _handle_test_result(
+        passed, message, allow_raise, _is_in_test_case(),
+    )
 
 
 def assert_equals(actual, expected, message=None, allow_raise=False):
@@ -34,7 +64,9 @@ def assert_equals(actual, expected, message=None, allow_raise=False):
     else:
         message += ": " + equals_msg
 
-    expect(actual == expected, message, allow_raise)
+    _handle_test_result(
+        actual == expected, message, allow_raise, _is_in_test_case(),
+    )
 
 
 def assert_not_equals(actual, expected, message=None, allow_raise=False):
@@ -45,7 +77,9 @@ def assert_not_equals(actual, expected, message=None, allow_raise=False):
     else:
         message += ": " + equals_msg
 
-    expect(not (actual == expected), message, allow_raise)
+    _handle_test_result(
+        not (actual == expected), message, allow_raise, _is_in_test_case(),
+    )
 
 
 def expect_error(message, function, exception=Exception):
@@ -56,26 +90,35 @@ def expect_error(message, function, exception=Exception):
         passed = True
     except Exception:
         pass
-    expect(passed, message)
+    _handle_test_result(
+        passed, message, False, _is_in_test_case(),
+    )
 
 
 def expect_no_error(message, function, exception=BaseException):
+    passed = True
     try:
         function()
     except exception as e:
-        fail("{}: {}".format(message or "Unexpected exception", repr(e)))
-        return
+        passed = False
+        message = "{}: {}".format(message or "Unexpected exception", repr(e))
     except Exception:
         pass
-    pass_()
+    _handle_test_result(
+        passed, message, False, _is_in_test_case(),
+    )
 
 
 def pass_():
-    expect(True)
+    if not _is_in_test_case():
+        display("PASSED", "Test Passed")
 
 
 def fail(message):
-    expect(False, message)
+    if _is_in_test_case():
+        raise AssertionError(message)
+    else:
+        display("FAILED", message)
 
 
 def assert_approx_equals(
@@ -88,7 +131,12 @@ def assert_approx_equals(
     else:
         message += ": " + equals_msg
     div = max(abs(actual), abs(expected), 1)
-    expect(abs((actual - expected) / div) < margin, message, allow_raise)
+    _handle_test_result(
+        abs((actual - expected) / div) < margin,
+        message,
+        allow_raise,
+        _is_in_test_case(),
+    )
 
 
 """
@@ -108,13 +156,18 @@ def _timed_block_factory(opening_text):
 
     def _timed_block_decorator(s, before=None, after=None):
         display(opening_text, s)
+        is_test_case = opening_text == "IT"
 
         def wrapper(func):
             if callable(before):
                 before()
             time = timer()
+            if is_test_case:
+                func.test_case_func = True
             try:
                 func()
+                if is_test_case:
+                    display("PASSED", "Test Passed")
             except AssertionError as e:
                 display("FAILED", str(e))
             except Exception:
diff --git a/tests/fixtures/custom_assertion.expected.txt b/tests/fixtures/custom_assertion.expected.txt
new file mode 100644
index 0000000..8c0f14b
--- /dev/null
+++ b/tests/fixtures/custom_assertion.expected.txt
@@ -0,0 +1,22 @@
+
+<DESCRIBE::>group 1
+
+<IT::>test 1
+
+<PASSED::>Test Passed
+
+<COMPLETEDIN::>0.00
+
+<IT::>test 2
+
+<FAILED::>Expected 1 to equal 2
+
+<COMPLETEDIN::>0.01
+
+<IT::>test 3
+
+<FAILED::>using assert
+
+<COMPLETEDIN::>0.00
+
+<COMPLETEDIN::>0.03
diff --git a/tests/fixtures/custom_assertion.py b/tests/fixtures/custom_assertion.py
new file mode 100644
index 0000000..5db6634
--- /dev/null
+++ b/tests/fixtures/custom_assertion.py
@@ -0,0 +1,21 @@
+import codewars_test as test
+
+
+def custom_assert_equal(a, b):
+    if a != b:
+        raise AssertionError("Expected {} to equal {}".format(a, b))
+
+
+@test.describe("group 1")
+def group_1():
+    @test.it("test 1")
+    def test_1():
+        custom_assert_equal(1, 1)
+
+    @test.it("test 2")
+    def test_2():
+        custom_assert_equal(1, 2)
+
+    @test.it("test 3")
+    def test_3():
+        assert 1 == 2, "using assert"
diff --git a/tests/fixtures/expect_error_sample.expected.txt b/tests/fixtures/expect_error_sample.expected.txt
index b2c7f56..5968378 100644
--- a/tests/fixtures/expect_error_sample.expected.txt
+++ b/tests/fixtures/expect_error_sample.expected.txt
@@ -5,72 +5,24 @@
 
 <FAILED::>f0 did not raise any exception
 
-<FAILED::>f0 did not raise Exception
-
-<FAILED::>f0 did not raise ArithmeticError
-
-<FAILED::>f0 did not raise ZeroDivisionError
-
-<FAILED::>f0 did not raise LookupError
-
-<FAILED::>f0 did not raise KeyError
-
-<FAILED::>f0 did not raise OSError
-
-<COMPLETEDIN::>0.03
+<COMPLETEDIN::>0.02
 
 <IT::>f1 raises Exception
 
-<PASSED::>Test Passed
-
-<PASSED::>Test Passed
-
 <FAILED::>f1 did not raise ArithmeticError
 
-<FAILED::>f1 did not raise ZeroDivisionError
-
-<FAILED::>f1 did not raise LookupError
-
-<FAILED::>f1 did not raise KeyError
-
-<FAILED::>f1 did not raise OSError
-
 <COMPLETEDIN::>0.02
 
 <IT::>f2 raises Exception >> ArithmeticError >> ZeroDivisionError
 
-<PASSED::>Test Passed
-
-<PASSED::>Test Passed
-
-<PASSED::>Test Passed
-
-<PASSED::>Test Passed
-
 <FAILED::>f2 did not raise LookupError
 
-<FAILED::>f2 did not raise KeyError
-
-<FAILED::>f2 did not raise OSError
-
 <COMPLETEDIN::>0.02
 
 <IT::>f3 raises Exception >> LookupError >> KeyError
 
-<PASSED::>Test Passed
-
-<PASSED::>Test Passed
-
 <FAILED::>f3 did not raise ArithmeticError
 
-<FAILED::>f3 did not raise ZeroDivisionError
-
-<PASSED::>Test Passed
-
-<PASSED::>Test Passed
-
-<FAILED::>f3 did not raise OSError
-
 <COMPLETEDIN::>0.02
 
-<COMPLETEDIN::>0.11
+<COMPLETEDIN::>0.10
diff --git a/tests/test_outputs.py b/tests/test_outputs.py
index bfe8396..43c40e9 100644
--- a/tests/test_outputs.py
+++ b/tests/test_outputs.py
@@ -23,7 +23,12 @@ def test(self):
             expected = re.sub(
                 r"(?<=<COMPLETEDIN::>)\d+(?:\.\d+)?", r"\\d+(?:\\.\\d+)?", r.read()
             )
-            self.assertRegex(result.stdout.decode("utf-8"), expected)
+            actual = result.stdout.decode("utf-8")
+            self.assertRegex(
+                actual,
+                expected,
+                "Expected Pattern:\n{}\n\nGot:\n{}\n".format(expected, actual),
+            )
 
     return test