diff --git a/nemo_run/core/tunnel/client.py b/nemo_run/core/tunnel/client.py index 316dbc78..750defa4 100644 --- a/nemo_run/core/tunnel/client.py +++ b/nemo_run/core/tunnel/client.py @@ -1,11 +1,10 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2024-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -196,6 +195,7 @@ class SSHTunnel(Tunnel): host: str user: str + port: Optional[int] = None identity: Optional[str] = None shell: Optional[str] = None pre_command: Optional[str] = None @@ -263,6 +263,7 @@ def _authenticate(self): config = Config(overrides={"run": {"in_stream": False}}) self.session = Connection( self.host, + port=self.port, user=self.user, connect_kwargs=connect_kwargs, forward_agent=False, diff --git a/test/core/tunnel/test_client.py b/test/core/tunnel/test_client.py index b14dfeef..e2f5f50e 100644 --- a/test/core/tunnel/test_client.py +++ b/test/core/tunnel/test_client.py @@ -1,11 +1,10 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2024-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -162,6 +161,7 @@ def test_connect_with_identity(self, mock_config, mock_connection): mock_connection.assert_called_once_with( "test.host", + port=None, user="test_user", connect_kwargs={"key_filename": ["/path/to/key"]}, forward_agent=False,