From e156b6fedca44b561dd1ca830457fe25788a24c5 Mon Sep 17 00:00:00 2001
From: Andrew Leech <andrew.leech@planetinnovation.com.au>
Date: Mon, 5 May 2025 21:03:14 +1000
Subject: [PATCH] micropython/aiorepl: Allow passing in stream to use.

Defaults to sys.stdin/out if not provided.

Signed-off-by: Andrew Leech <andrew.leech@planetinnovation.com.au>
---
 micropython/aiorepl/aiorepl.py | 115 +++++++++++++++++----------------
 1 file changed, 60 insertions(+), 55 deletions(-)

diff --git a/micropython/aiorepl/aiorepl.py b/micropython/aiorepl/aiorepl.py
index 3f437459d..470c1065a 100644
--- a/micropython/aiorepl/aiorepl.py
+++ b/micropython/aiorepl/aiorepl.py
@@ -93,13 +93,13 @@ async def kbd_intr_task(exec_task, s):
 
 
 # REPL task. Invoke this with an optional mutable globals dict.
-async def task(g=None, prompt="--> "):
+async def task(g=None, prompt="--> ", s_in=sys.stdin, s_out=sys.stdout):
     print("Starting asyncio REPL...")
     if g is None:
         g = __import__("__main__").__dict__
     try:
         micropython.kbd_intr(-1)
-        s = asyncio.StreamReader(sys.stdin)
+        s = asyncio.StreamReader(s_in)
         # clear = True
         hist = [None] * _HISTORY_LIMIT
         hist_i = 0  # Index of most recent entry.
@@ -108,7 +108,7 @@ async def task(g=None, prompt="--> "):
         t = 0  # timestamp of most recent character.
         while True:
             hist_b = 0  # How far back in the history are we currently.
-            sys.stdout.write(prompt)
+            s_out.write(prompt)
             cmd: str = ""
             paste = False
             curs = 0  # cursor offset from end of cmd buffer
@@ -122,7 +122,7 @@ async def task(g=None, prompt="--> "):
                     if c == 0x0A:
                         # LF
                         if paste:
-                            sys.stdout.write(b)
+                            s_out.write(b)
                             cmd += b
                             continue
                         # If the previous character was also LF, and was less
@@ -132,9 +132,9 @@ async def task(g=None, prompt="--> "):
                             continue
                         if curs:
                             # move cursor to end of the line
-                            sys.stdout.write("\x1b[{}C".format(curs))
+                            s_out.write("\x1b[{}C".format(curs))
                             curs = 0
-                        sys.stdout.write("\n")
+                        s_out.write("\n")
                         if cmd:
                             # Push current command.
                             hist[hist_i] = cmd
@@ -144,46 +144,46 @@ async def task(g=None, prompt="--> "):
 
                             result = await execute(cmd, g, s)
                             if result is not None:
-                                sys.stdout.write(repr(result))
-                                sys.stdout.write("\n")
+                                s_out.write(repr(result))
+                                s_out.write("\n")
                         break
                     elif c == 0x08 or c == 0x7F:
                         # Backspace.
                         if cmd:
                             if curs:
                                 cmd = "".join((cmd[: -curs - 1], cmd[-curs:]))
-                                sys.stdout.write(
+                                s_out.write(
                                     "\x08\x1b[K"
                                 )  # move cursor back, erase to end of line
-                                sys.stdout.write(cmd[-curs:])  # redraw line
-                                sys.stdout.write("\x1b[{}D".format(curs))  # reset cursor location
+                                s_out.write(cmd[-curs:])  # redraw line
+                                s_out.write("\x1b[{}D".format(curs))  # reset cursor location
                             else:
                                 cmd = cmd[:-1]
-                                sys.stdout.write("\x08 \x08")
+                                s_out.write("\x08 \x08")
                     elif c == CHAR_CTRL_A:
-                        await raw_repl(s, g)
+                        await raw_repl(s_in, s_out, g)
                         break
                     elif c == CHAR_CTRL_B:
                         continue
                     elif c == CHAR_CTRL_C:
                         if paste:
                             break
-                        sys.stdout.write("\n")
+                        s_out.write("\n")
                         break
                     elif c == CHAR_CTRL_D:
                         if paste:
                             result = await execute(cmd, g, s)
                             if result is not None:
-                                sys.stdout.write(repr(result))
-                                sys.stdout.write("\n")
+                                s_out.write(repr(result))
+                                s_out.write("\n")
                             break
 
-                        sys.stdout.write("\n")
+                        s_out.write("\n")
                         # Shutdown asyncio.
                         asyncio.new_event_loop()
                         return
                     elif c == CHAR_CTRL_E:
-                        sys.stdout.write("paste mode; Ctrl-C to cancel, Ctrl-D to finish\n===\n")
+                        s_out.write("paste mode; Ctrl-C to cancel, Ctrl-D to finish\n===\n")
                         paste = True
                     elif c == 0x1B:
                         # Start of escape sequence.
@@ -193,9 +193,9 @@ async def task(g=None, prompt="--> "):
                             hist[(hist_i - hist_b) % _HISTORY_LIMIT] = cmd
                             # Clear current command.
                             b = "\x08" * len(cmd)
-                            sys.stdout.write(b)
-                            sys.stdout.write(" " * len(cmd))
-                            sys.stdout.write(b)
+                            s_out.write(b)
+                            s_out.write(" " * len(cmd))
+                            s_out.write(b)
                             # Go backwards or forwards in the history.
                             if key == "[A":
                                 hist_b = min(hist_n, hist_b + 1)
@@ -203,56 +203,56 @@ async def task(g=None, prompt="--> "):
                                 hist_b = max(0, hist_b - 1)
                             # Update current command.
                             cmd = hist[(hist_i - hist_b) % _HISTORY_LIMIT]
-                            sys.stdout.write(cmd)
+                            s_out.write(cmd)
                         elif key == "[D":  # left
                             if curs < len(cmd) - 1:
                                 curs += 1
-                                sys.stdout.write("\x1b")
-                                sys.stdout.write(key)
+                                s_out.write("\x1b")
+                                s_out.write(key)
                         elif key == "[C":  # right
                             if curs:
                                 curs -= 1
-                                sys.stdout.write("\x1b")
-                                sys.stdout.write(key)
+                                s_out.write("\x1b")
+                                s_out.write(key)
                         elif key == "[H":  # home
                             pcurs = curs
                             curs = len(cmd)
-                            sys.stdout.write("\x1b[{}D".format(curs - pcurs))  # move cursor left
+                            s_out.write("\x1b[{}D".format(curs - pcurs))  # move cursor left
                         elif key == "[F":  # end
                             pcurs = curs
                             curs = 0
-                            sys.stdout.write("\x1b[{}C".format(pcurs))  # move cursor right
+                            s_out.write("\x1b[{}C".format(pcurs))  # move cursor right
                     else:
-                        # sys.stdout.write("\\x")
-                        # sys.stdout.write(hex(c))
+                        # s_out.write("\\x")
+                        # s_out.write(hex(c))
                         pass
                 else:
                     if curs:
                         # inserting into middle of line
                         cmd = "".join((cmd[:-curs], b, cmd[-curs:]))
-                        sys.stdout.write(cmd[-curs - 1 :])  # redraw line to end
-                        sys.stdout.write("\x1b[{}D".format(curs))  # reset cursor location
+                        s_out.write(cmd[-curs - 1 :])  # redraw line to end
+                        s_out.write("\x1b[{}D".format(curs))  # reset cursor location
                     else:
-                        sys.stdout.write(b)
+                        s_out.write(b)
                         cmd += b
     finally:
         micropython.kbd_intr(3)
 
 
-async def raw_paste(s, g, window=512):
-    sys.stdout.write("R\x01")  # supported
-    sys.stdout.write(bytearray([window & 0xFF, window >> 8, 0x01]).decode())
+def raw_paste(s_in, s_out, window=512):
+    s_out.write("R\x01")  # supported
+    s_out.write(bytearray([window & 0xFF, window >> 8, 0x01]).decode())
     eof = False
     idx = 0
     buff = bytearray(window)
     file = b""
     while not eof:
         for idx in range(window):
-            b = await s.read(1)
+            b = s_in.read(1)
             c = ord(b)
             if c == CHAR_CTRL_C or c == CHAR_CTRL_D:
                 # end of file
-                sys.stdout.write(chr(CHAR_CTRL_D))
+                s_out.write(chr(CHAR_CTRL_D))
                 if c == CHAR_CTRL_C:
                     raise KeyboardInterrupt
                 file += buff[:idx]
@@ -262,21 +262,26 @@ async def raw_paste(s, g, window=512):
 
         if not eof:
             file += buff
-            sys.stdout.write("\x01")  # indicate window available to host
+            s_out.write("\x01")  # indicate window available to host
 
     return file
 
 
-async def raw_repl(s: asyncio.StreamReader, g: dict):
+async def raw_repl(s_in: io.IOBase, s_out: io.IOBase, g: dict):
+    """
+    This function is blocking to prevent other
+    async tasks from writing to the stdio stream and
+    breaking the raw repl session.
+    """
     heading = "raw REPL; CTRL-B to exit\n"
     line = ""
-    sys.stdout.write(heading)
+    s_out.write(heading)
 
     while True:
         line = ""
-        sys.stdout.write(">")
+        s_out.write(">")
         while True:
-            b = await s.read(1)
+            b = s_in.read(1)
             c = ord(b)
             if c == CHAR_CTRL_A:
                 rline = line
@@ -284,16 +289,16 @@ async def raw_repl(s: asyncio.StreamReader, g: dict):
 
                 if len(rline) == 2 and ord(rline[0]) == CHAR_CTRL_E:
                     if rline[1] == "A":
-                        line = await raw_paste(s, g)
+                        line = raw_paste(s_in, s_out)
                         break
                 else:
                     # reset raw REPL
-                    sys.stdout.write(heading)
-                    sys.stdout.write(">")
+                    s_out.write(heading)
+                    s_out.write(">")
                 continue
             elif c == CHAR_CTRL_B:
                 # exit raw REPL
-                sys.stdout.write("\n")
+                s_out.write("\n")
                 return 0
             elif c == CHAR_CTRL_C:
                 # clear line
@@ -301,7 +306,7 @@ async def raw_repl(s: asyncio.StreamReader, g: dict):
             elif c == CHAR_CTRL_D:
                 # entry finished
                 # indicate reception of command
-                sys.stdout.write("OK")
+                s_out.write("OK")
                 break
             else:
                 # let through any other raw 8-bit value
@@ -310,16 +315,16 @@ async def raw_repl(s: asyncio.StreamReader, g: dict):
         if len(line) == 0:
             # Normally used to trigger soft-reset but stay in raw mode.
             # Fake it for aiorepl / mpremote.
-            sys.stdout.write("Ignored: soft reboot\n")
-            sys.stdout.write(heading)
+            s_out.write("Ignored: soft reboot\n")
+            s_out.write(heading)
 
         try:
             result = exec(line, g)
             if result is not None:
-                sys.stdout.write(repr(result))
-            sys.stdout.write(chr(CHAR_CTRL_D))
+                s_out.write(repr(result))
+            s_out.write(chr(CHAR_CTRL_D))
         except Exception as ex:
             print(line)
-            sys.stdout.write(chr(CHAR_CTRL_D))
-            sys.print_exception(ex, sys.stdout)
-        sys.stdout.write(chr(CHAR_CTRL_D))
+            s_out.write(chr(CHAR_CTRL_D))
+            sys.print_exception(ex, s_out)
+        s_out.write(chr(CHAR_CTRL_D))