diff --git a/Cargo.lock b/Cargo.lock index 2d8db98..f5cd23b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -240,6 +240,18 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "itoa" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" + [[package]] name = "jiff" version = "0.2.10" @@ -266,16 +278,19 @@ dependencies = [ [[package]] name = "krunkit" -version = "1.3.0" +version = "1.3.1" dependencies = [ "anyhow", "clap", "core-foundation", "env_logger", + "httparse", "log", "mac_address", "objc2-io-kit", "regex", + "serde", + "serde_json", "sysinfo", ] @@ -462,24 +477,47 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "serde" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", "syn", ] +[[package]] +name = "serde_json" +version = "1.0.150" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8014e44b4736ed0538adeecded0fce2a272f22dc9578a7eb6b2d9993c74cfb9" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + [[package]] name = "strsim" version = "0.11.0" @@ -670,3 +708,9 @@ name = "windows_x86_64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/Cargo.toml b/Cargo.toml index b60e15f..33b68b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "krunkit" -version = "1.3.0" +version = "1.3.1" authors = ["The krunkit Authors"] edition = "2021" description = "CLI tool to start VMs with libkrun" @@ -18,3 +18,6 @@ env_logger = "0.11.8" regex = "1.11.1" core-foundation = "0.10.1" objc2-io-kit = "0.3.2" +httparse = "1.10.1" +serde = { version = "1.0.228", features = ["derive"] } +serde_json = "1.0.150" diff --git a/src/status.rs b/src/status.rs index 4eb2235..ab6a961 100644 --- a/src/status.rs +++ b/src/status.rs @@ -12,17 +12,14 @@ use std::{ }; use anyhow::{anyhow, Context}; +use serde::{Deserialize, Serialize}; #[link(name = "krun")] extern "C" { fn krun_get_shutdown_eventfd(ctx_id: u32) -> i32; } -const HTTP_RUNNING: &str = - "HTTP/1.1 200 OK\r\nContent-type: application/json\r\n\r\n{\"state\": \"VirtualMachineStateRunning\"}\0"; - -const HTTP_STOPPING: &str = - "HTTP/1.1 200 OK\r\nContent-type: application/json\r\n\r\n{\"state\": \"VirtualMachineStateStopping\"}\0"; +const VM_STATE_PATH: &str = "/vm/state"; #[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] pub enum UriScheme { @@ -109,26 +106,166 @@ pub unsafe fn get_shutdown_eventfd(ctx_id: u32) -> i32 { fd } -fn handle_incoming_stream(stream: &mut T, shutdown_fd: &mut File) { +fn write_http_response(stream: &mut T, status: u16, reason: &str, body: &str) { + let response = format!( + "HTTP/1.1 {status} {reason}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{body}", + body.len(), + ); + if let Err(e) = stream.write_all(response.as_bytes()) { + log::error!("Failed to write HTTP response: {e}"); + } +} + +fn write_http_error(stream: &mut T, status: u16, reason: &str) { + let response = format!("HTTP/1.1 {status} {reason}\r\nContent-Length: 0\r\n\r\n"); + if let Err(e) = stream.write_all(response.as_bytes()) { + log::error!("Failed to write HTTP response: {e}"); + } +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct VmStateResponse { + state: String, + can_start: bool, + can_pause: bool, + can_resume: bool, + can_stop: bool, + can_hard_stop: bool, +} + +impl VmStateResponse { + fn new(state: &str, can_stop: bool) -> Self { + Self { + state: state.to_string(), + can_start: false, + can_pause: false, + can_resume: false, + can_stop, + can_hard_stop: can_stop, + } + } +} + +#[derive(Deserialize)] +struct VmStateRequest { + state: String, +} + +fn normalize_path(path: &str) -> &str { + let path = path.split('?').next().unwrap_or(path); + path.strip_suffix('/').unwrap_or(path) +} + +fn content_length(headers: &[httparse::Header]) -> usize { + headers + .iter() + .find(|h| h.name.eq_ignore_ascii_case("Content-Length")) + .and_then(|h| std::str::from_utf8(h.value).ok()) + .and_then(|v| v.parse().ok()) + .unwrap_or(0) +} + +fn handle_incoming_stream( + stream: &mut T, + shutdown_fd: &mut File, + stopping: &mut bool, +) { let mut buf = [0u8; 4096]; - match stream.read(&mut buf) { - Ok(_sz) => { - let request = String::from_utf8_lossy(&buf); - if request.contains("POST") { - // Send a VirtualMachineStateStopping message to the client. - if let Err(e) = stream.write_all(HTTP_STOPPING.as_bytes()) { - log::error!("Failure writing POST response: {e}"); - } + let sz = match stream.read(&mut buf) { + Ok(0) => return, + Ok(sz) => sz, + Err(e) => { + log::error!("Failed to read from stream: {e}"); + return; + } + }; + + let mut headers = [httparse::EMPTY_HEADER; 32]; + let mut req = httparse::Request::new(&mut headers); + let header_len = match req.parse(&buf[..sz]) { + Ok(httparse::Status::Complete(len)) => len, + Ok(httparse::Status::Partial) => { + write_http_error(stream, 400, "Bad Request"); + return; + } + Err(_) => { + write_http_error(stream, 400, "Bad Request"); + return; + } + }; + + let method = match req.method { + Some(m) => m, + None => { + write_http_error(stream, 400, "Bad Request"); + return; + } + }; - // Shut down the VM. - if let Err(e) = shutdown_fd.write_all(&1u64.to_le_bytes()) { - log::error!("Failure writing to shutdown fd: {e}"); + let path = match req.path { + Some(p) => p, + None => { + write_http_error(stream, 400, "Bad Request"); + return; + } + }; + + if normalize_path(path) != VM_STATE_PATH { + write_http_error(stream, 404, "Not Found"); + return; + } + + match method { + "GET" => { + let (state, can_stop) = if *stopping { + ("VirtualMachineStateStopping", false) + } else { + ("VirtualMachineStateRunning", true) + }; + let body = serde_json::to_string(&VmStateResponse::new(state, can_stop)).unwrap(); + write_http_response(stream, 200, "OK", &body); + } + "POST" => { + let body_len = content_length(req.headers); + let body_end = std::cmp::min(header_len + body_len, sz); + let body = &buf[header_len..body_end]; + + let state_req: VmStateRequest = match serde_json::from_slice(body) { + Ok(r) => r, + Err(_) => { + write_http_response( + stream, + 400, + "Bad Request", + "{\"error\":\"missing or invalid 'state' field\"}", + ); + return; + } + }; + + match state_req.state.as_str() { + "Stop" | "HardStop" => { + *stopping = true; + let body = serde_json::to_string(&VmStateResponse::new( + "VirtualMachineStateStopping", + false, + )) + .unwrap(); + write_http_response(stream, 200, "OK", &body); + if let Err(e) = shutdown_fd.write_all(&1u64.to_le_bytes()) { + log::error!("Failed to write to shutdown fd: {e}"); + } + } + other => { + let error = format!("{{\"error\":\"unsupported state change: {other}\"}}"); + write_http_response(stream, 400, "Bad Request", &error); } - } else if let Err(e) = stream.write_all(HTTP_RUNNING.as_bytes()) { - log::error!("Failure writing GET response: {e}"); } } - Err(e) => log::error!("Failure reading stream: {e}"), + _ => { + write_http_error(stream, 405, "Method Not Allowed"); + } } } @@ -142,13 +279,15 @@ pub fn status_listener( let addr = addr.unwrap_or_default(); + let mut stopping = false; + match addr { RestfulUri::Tcp(addr, port) => { let listener = TcpListener::bind((addr, port)) .map_err(|e| anyhow!("Unable to bind to TCP listener: {}", e))?; for stream in listener.incoming() { - handle_incoming_stream(&mut stream.unwrap(), &mut shutdown) + handle_incoming_stream(&mut stream.unwrap(), &mut shutdown, &mut stopping) } } RestfulUri::Unix(path) => { @@ -161,7 +300,7 @@ pub fn status_listener( .map_err(|e| anyhow!("Unable to bind to unix socket: {}", e))?; for stream in listener.incoming() { - handle_incoming_stream(&mut stream.unwrap(), &mut shutdown) + handle_incoming_stream(&mut stream.unwrap(), &mut shutdown, &mut stopping) } } RestfulUri::None => unreachable!(), @@ -170,9 +309,190 @@ pub fn status_listener( Ok(()) } -#[allow(unused_imports)] +#[cfg(test)] mod tests { use super::*; + use std::io::Cursor; + + fn make_request(method: &str, path: &str, body: Option<&str>) -> Vec { + let mut req = format!("{method} {path} HTTP/1.1\r\nHost: localhost\r\n"); + if let Some(b) = body { + req.push_str(&format!("Content-Length: {}\r\n", b.len())); + req.push_str("Content-Type: application/json\r\n"); + } + req.push_str("\r\n"); + if let Some(b) = body { + req.push_str(b); + } + req.into_bytes() + } + + struct MockStream { + read: Cursor>, + written: Vec, + } + + impl std::io::Read for MockStream { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + self.read.read(buf) + } + } + + impl std::io::Write for MockStream { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.written.extend_from_slice(buf); + Ok(buf.len()) + } + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } + } + + fn handle_request(request: &[u8], stopping: &mut bool) -> String { + let mut stream = MockStream { + read: Cursor::new(request.to_vec()), + written: Vec::new(), + }; + + let (sock_a, _sock_b) = std::os::unix::net::UnixStream::pair().unwrap(); + let mut shutdown_fd = + unsafe { File::from_raw_fd(std::os::fd::AsRawFd::as_raw_fd(&sock_a)) }; + + handle_incoming_stream(&mut stream, &mut shutdown_fd, stopping); + + std::mem::forget(shutdown_fd); + + String::from_utf8(stream.written).unwrap() + } + + fn response_status(response: &str) -> u16 { + response + .split_whitespace() + .nth(1) + .and_then(|s| s.parse().ok()) + .unwrap() + } + + fn response_body(response: &str) -> &str { + response.split("\r\n\r\n").nth(1).unwrap_or("") + } + + #[test] + fn get_vm_state_returns_running() { + let req = make_request("GET", "/vm/state", None); + let mut stopping = false; + let resp = handle_request(&req, &mut stopping); + assert_eq!(response_status(&resp), 200); + let body = response_body(&resp); + assert!(body.contains("\"state\":\"VirtualMachineStateRunning\"")); + assert!(body.contains("\"canStop\":true")); + } + + #[test] + fn get_vm_state_trailing_slash() { + let req = make_request("GET", "/vm/state/", None); + let mut stopping = false; + let resp = handle_request(&req, &mut stopping); + assert_eq!(response_status(&resp), 200); + assert!(response_body(&resp).contains("VirtualMachineStateRunning")); + } + + #[test] + fn get_vm_state_while_stopping() { + let req = make_request("GET", "/vm/state", None); + let mut stopping = true; + let resp = handle_request(&req, &mut stopping); + assert_eq!(response_status(&resp), 200); + let body = response_body(&resp); + assert!(body.contains("\"state\":\"VirtualMachineStateStopping\"")); + assert!(body.contains("\"canStop\":false")); + } + + #[test] + fn post_stop_returns_stopping() { + let req = make_request("POST", "/vm/state", Some("{\"state\":\"Stop\"}")); + let mut stopping = false; + let resp = handle_request(&req, &mut stopping); + assert_eq!(response_status(&resp), 200); + assert!(response_body(&resp).contains("VirtualMachineStateStopping")); + assert!(stopping); + } + + #[test] + fn post_hardstop_returns_stopping() { + let req = make_request("POST", "/vm/state", Some("{\"state\":\"HardStop\"}")); + let mut stopping = false; + let resp = handle_request(&req, &mut stopping); + assert_eq!(response_status(&resp), 200); + assert!(stopping); + } + + #[test] + fn post_invalid_state_returns_400() { + let req = make_request("POST", "/vm/state", Some("{\"state\":\"Pause\"}")); + let mut stopping = false; + let resp = handle_request(&req, &mut stopping); + assert_eq!(response_status(&resp), 400); + assert!(!stopping); + } + + #[test] + fn post_missing_body_returns_400() { + let req = make_request("POST", "/vm/state", None); + let mut stopping = false; + let resp = handle_request(&req, &mut stopping); + assert_eq!(response_status(&resp), 400); + } + + #[test] + fn unknown_path_returns_404() { + let req = make_request("GET", "/vm/inspect", None); + let mut stopping = false; + let resp = handle_request(&req, &mut stopping); + assert_eq!(response_status(&resp), 404); + } + + #[test] + fn unknown_method_returns_405() { + let req = make_request("DELETE", "/vm/state", None); + let mut stopping = false; + let resp = handle_request(&req, &mut stopping); + assert_eq!(response_status(&resp), 405); + } + + #[test] + fn path_with_query_string() { + let req = make_request("GET", "/vm/state?foo=bar", None); + let mut stopping = false; + let resp = handle_request(&req, &mut stopping); + assert_eq!(response_status(&resp), 200); + } + + #[test] + fn deserialize_vm_state_request() { + let r: VmStateRequest = serde_json::from_str("{\"state\":\"Stop\"}").unwrap(); + assert_eq!(r.state, "Stop"); + let r: VmStateRequest = serde_json::from_str("{ \"state\" : \"HardStop\" }").unwrap(); + assert_eq!(r.state, "HardStop"); + } + + #[test] + fn deserialize_vm_state_request_invalid() { + assert!(serde_json::from_str::("").is_err()); + assert!(serde_json::from_str::("{}").is_err()); + assert!(serde_json::from_str::("not json").is_err()); + } + + #[test] + fn serialize_vm_state_response() { + let resp = VmStateResponse::new("VirtualMachineStateRunning", true); + let json: serde_json::Value = + serde_json::from_str(&serde_json::to_string(&resp).unwrap()).unwrap(); + assert_eq!(json["state"], "VirtualMachineStateRunning"); + assert_eq!(json["canStop"], true); + assert_eq!(json["canHardStop"], true); + assert_eq!(json["canPause"], false); + } #[test] fn parse_valid_unix_scheme() {