1 // Copyright © 2020 Intel Corporation 2 // 3 // SPDX-License-Identifier: Apache-2.0 4 // 5 6 use std::fmt; 7 use std::io::{Read, Write}; 8 use std::os::unix::io::RawFd; 9 use vmm_sys_util::sock_ctrl_msg::ScmSocket; 10 11 #[derive(Debug)] 12 pub enum Error { 13 Socket(std::io::Error), 14 SocketSendFds(vmm_sys_util::errno::Error), 15 StatusCodeParsing(std::num::ParseIntError), 16 MissingProtocol, 17 ContentLengthParsing(std::num::ParseIntError), 18 ServerResponse(StatusCode, Option<String>), 19 } 20 21 impl fmt::Display for Error { 22 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 23 use Error::*; 24 match self { 25 Socket(e) => write!(f, "Error writing to or reading from HTTP socket: {}", e), 26 SocketSendFds(e) => write!(f, "Error writing to or reading from HTTP socket: {}", e), 27 StatusCodeParsing(e) => write!(f, "Error parsing HTTP status code: {}", e), 28 MissingProtocol => write!(f, "HTTP output is missing protocol statement"), 29 ContentLengthParsing(e) => write!(f, "Error parsing HTTP Content-Length field: {}", e), 30 ServerResponse(s, o) => { 31 if let Some(o) = o { 32 write!(f, "Server responded with an error: {:?}: {}", s, o) 33 } else { 34 write!(f, "Server responded with an error: {:?}", s) 35 } 36 } 37 } 38 } 39 } 40 41 #[derive(Clone, Copy, Debug)] 42 pub enum StatusCode { 43 Continue, 44 Ok, 45 NoContent, 46 BadRequest, 47 NotFound, 48 InternalServerError, 49 NotImplemented, 50 Unknown, 51 } 52 53 impl StatusCode { 54 fn from_raw(code: usize) -> StatusCode { 55 match code { 56 100 => StatusCode::Continue, 57 200 => StatusCode::Ok, 58 204 => StatusCode::NoContent, 59 400 => StatusCode::BadRequest, 60 404 => StatusCode::NotFound, 61 500 => StatusCode::InternalServerError, 62 501 => StatusCode::NotImplemented, 63 _ => StatusCode::Unknown, 64 } 65 } 66 67 fn parse(code: &str) -> Result<StatusCode, Error> { 68 Ok(StatusCode::from_raw( 69 code.trim().parse().map_err(Error::StatusCodeParsing)?, 70 )) 71 } 72 73 fn is_server_error(self) -> bool { 74 !matches!( 75 self, 76 StatusCode::Ok | StatusCode::Continue | StatusCode::NoContent 77 ) 78 } 79 } 80 81 fn get_header<'a>(res: &'a str, header: &'a str) -> Option<&'a str> { 82 let header_str = format!("{}: ", header); 83 res.find(&header_str) 84 .map(|o| &res[o + header_str.len()..o + res[o..].find('\r').unwrap()]) 85 } 86 87 fn get_status_code(res: &str) -> Result<StatusCode, Error> { 88 if let Some(o) = res.find("HTTP/1.1") { 89 Ok(StatusCode::parse( 90 &res[o + "HTTP/1.1 ".len()..res[o..].find('\r').unwrap()], 91 )?) 92 } else { 93 Err(Error::MissingProtocol) 94 } 95 } 96 97 fn parse_http_response(socket: &mut dyn Read) -> Result<Option<String>, Error> { 98 let mut res = String::new(); 99 let mut body_offset = None; 100 let mut content_length: Option<usize> = None; 101 loop { 102 let mut bytes = vec![0; 256]; 103 let count = socket.read(&mut bytes).map_err(Error::Socket)?; 104 res.push_str(std::str::from_utf8(&bytes[0..count]).unwrap()); 105 106 // End of headers 107 if let Some(o) = res.find("\r\n\r\n") { 108 body_offset = Some(o + "\r\n\r\n".len()); 109 110 // With all headers available we can see if there is any body 111 content_length = if let Some(length) = get_header(&res, "Content-Length") { 112 Some(length.trim().parse().map_err(Error::ContentLengthParsing)?) 113 } else { 114 None 115 }; 116 117 if content_length.is_none() { 118 break; 119 } 120 } 121 122 if let Some(body_offset) = body_offset { 123 if let Some(content_length) = content_length { 124 if res.len() >= content_length + body_offset { 125 break; 126 } 127 } 128 } 129 } 130 let body_string = content_length.and(Some(String::from(&res[body_offset.unwrap()..]))); 131 let status_code = get_status_code(&res)?; 132 133 if status_code.is_server_error() { 134 Err(Error::ServerResponse(status_code, body_string)) 135 } else { 136 Ok(body_string) 137 } 138 } 139 140 pub fn simple_api_command_with_fds<T: Read + Write + ScmSocket>( 141 socket: &mut T, 142 method: &str, 143 c: &str, 144 request_body: Option<&str>, 145 request_fds: Vec<RawFd>, 146 ) -> Result<(), Error> { 147 socket 148 .send_with_fds( 149 &[format!( 150 "{} /api/v1/vm.{} HTTP/1.1\r\nHost: localhost\r\nAccept: */*\r\n", 151 method, c 152 ) 153 .as_bytes()], 154 &request_fds, 155 ) 156 .map_err(Error::SocketSendFds)?; 157 158 if let Some(request_body) = request_body { 159 socket 160 .write_all(format!("Content-Length: {}\r\n", request_body.len()).as_bytes()) 161 .map_err(Error::Socket)?; 162 } 163 164 socket.write_all(b"\r\n").map_err(Error::Socket)?; 165 166 if let Some(request_body) = request_body { 167 socket 168 .write_all(request_body.as_bytes()) 169 .map_err(Error::Socket)?; 170 } 171 172 socket.flush().map_err(Error::Socket)?; 173 174 if let Some(body) = parse_http_response(socket)? { 175 println!("{}", body); 176 } 177 Ok(()) 178 } 179 180 pub fn simple_api_command<T: Read + Write + ScmSocket>( 181 socket: &mut T, 182 method: &str, 183 c: &str, 184 request_body: Option<&str>, 185 ) -> Result<(), Error> { 186 simple_api_command_with_fds(socket, method, c, request_body, Vec::new()) 187 } 188