Skip to content

Commit 688f88b

Browse files
authored
added support for auth_method in push() and authenticate()
1 parent d4b15c9 commit 688f88b

File tree

2 files changed

+40
-19
lines changed

2 files changed

+40
-19
lines changed

lib50/_api.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
DEFAULT_FILE_LIMIT = 10000
4040

4141

42-
def push(tool, slug, config_loader, repo=None, data=None, prompt=lambda question, included, excluded: True, file_limit=DEFAULT_FILE_LIMIT):
42+
def push(tool, slug, config_loader, repo=None, data=None, prompt=lambda question, included, excluded: True, file_limit=DEFAULT_FILE_LIMIT, auth_method=None):
4343
"""
4444
Pushes to Github in name of a tool.
4545
What should be pushed is configured by the tool and its configuration in the .cs50.yml file identified by the slug.
@@ -61,6 +61,10 @@ def push(tool, slug, config_loader, repo=None, data=None, prompt=lambda question
6161
:type prompt: lambda str, list, list => bool, optional
6262
:param file_limit: maximum number of files to be matched by any globbing pattern.
6363
:type file_limit: int
64+
:param auth_method: The authentication method to use. Accepts `"https"` or `"ssh"`. \
65+
If any other value is provided, attempts SSH \
66+
authentication first and fall back to HTTPS if SSH fails.
67+
:type auth_method: str
6468
:return: GitHub username and the commit hash
6569
:type: tuple(str, str)
6670
@@ -89,7 +93,7 @@ def push(tool, slug, config_loader, repo=None, data=None, prompt=lambda question
8993
remote, (honesty, included, excluded) = connect(slug, config_loader, file_limit=DEFAULT_FILE_LIMIT)
9094

9195
# Authenticate the user with GitHub, and prepare the submission
92-
with authenticate(remote["org"], repo=repo) as user, prepare(tool, slug, user, included):
96+
with authenticate(remote["org"], repo=repo, auth_method=auth_method) as user, prepare(tool, slug, user, included):
9397

9498
# Show any prompt if specified
9599
if prompt(honesty, included, excluded):
@@ -465,7 +469,7 @@ def batch_files(files, size=100):
465469
files_list = list(files)
466470
for i in range(0, len(files_list), size):
467471
yield files_list[i:i + size]
468-
472+
469473
for batch in batch_files(included):
470474
quoted_files = ' '.join(shlex.quote(f) for f in batch)
471475
run(git(f"add -f -- {quoted_files}"))
@@ -962,7 +966,7 @@ def run(command, quiet=False, timeout=None):
962966
ProgressBar.stop_all()
963967
passphrase = _prompt_password("Enter passphrase for SSH key: ")
964968
child.sendline(passphrase)
965-
969+
966970
# Get the full output by reading until EOF
967971
full_output = child.before + child.after + child.read()
968972
command_output = full_output.strip().replace("\r\n", "\n")

lib50/authentication.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,18 @@ class User:
3232
init=False)
3333

3434
@contextlib.contextmanager
35-
def authenticate(org, repo=None):
35+
def authenticate(org, repo=None, auth_method=None):
3636
"""
37-
A contextmanager that authenticates a user with GitHub via SSH if possible, otherwise via HTTPS.
37+
A contextmanager that authenticates a user with GitHub.
3838
3939
:param org: GitHub organisation to authenticate with
4040
:type org: str
4141
:param repo: GitHub repo (part of the org) to authenticate with. Default is the user's GitHub login.
4242
:type repo: str, optional
43+
:param auth_method: The authentication method to use. Accepts `"https"` or `"ssh"`. \
44+
If any other value is provided, attempts SSH \
45+
authentication first and fall back to HTTPS if SSH fails.
46+
:type auth_method: str, optional
4347
:return: an authenticated user
4448
:type: lib50.User
4549
@@ -51,21 +55,34 @@ def authenticate(org, repo=None):
5155
print(user.name)
5256
5357
"""
54-
with api.ProgressBar(_("Authenticating")) as progress_bar:
55-
# Both authentication methods can require user input, best stop the bar
56-
progress_bar.stop()
58+
def try_https(org, repo):
59+
with _authenticate_https(org, repo=repo) as user:
60+
return user
5761

58-
# Try auth through SSH
62+
def try_ssh(org, repo):
5963
user = _authenticate_ssh(org, repo=repo)
60-
61-
# SSH auth failed, fallback to HTTPS
6264
if user is None:
63-
with _authenticate_https(org, repo=repo) as user:
64-
yield user
65-
# yield SSH user
66-
else:
67-
yield user
65+
raise ConnectionError
66+
return user
67+
68+
# Showcase the type of authentication based on input
69+
method_label = f" ({auth_method.upper()})" if auth_method in ("https", "ssh") else ""
70+
with api.ProgressBar(_("Authenticating{}").format(method_label)) as progress_bar:
71+
# Both authentication methods can require user input, best stop the bar
72+
progress_bar.stop()
6873

74+
match auth_method:
75+
case "https":
76+
yield try_https(org, repo)
77+
case "ssh":
78+
yield try_ssh(org, repo)
79+
case _:
80+
# Try auth through SSH
81+
try:
82+
yield try_ssh(org, repo)
83+
except ConnectionError:
84+
# SSH auth failed, fallback to HTTPS
85+
yield try_https(org, repo)
6986

7087
def logout():
7188
"""
@@ -104,7 +121,7 @@ def run_authenticated(user, command, quiet=False, timeout=None):
104121
# Try to extract the conflicting branch prefix from the error message
105122
# Pattern: 'refs/heads/cs50/problems/2025/x' exists
106123
branch_prefix_match = re.search(r"'refs/heads/([^']+)' exists", command_output)
107-
124+
108125
if branch_prefix_match:
109126
conflicting_prefix = branch_prefix_match.group(1)
110127
error_msg = _("Looks like you're trying to push to a branch that conflicts with an existing one in the repository.\n"
@@ -195,7 +212,7 @@ class State(enum.Enum):
195212
else:
196213
if not os.environ.get("CODESPACES"):
197214
_show_gh_changes_warning()
198-
215+
199216
return None
200217
finally:
201218
child.close()

0 commit comments

Comments
 (0)