// Copyright © 2020 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0
//

use std::fmt;
use std::io::{Read, Write};

/// Errors produced while talking to the HTTP API server.
#[derive(Debug)]
pub enum Error {
    /// Reading from or writing to the socket failed. This also covers the peer
    /// closing the connection before a complete response arrived
    /// (`UnexpectedEof`) and a response body that is not UTF-8 (`InvalidData`).
    Socket(std::io::Error),
    /// The status code in the status line is not a decimal integer.
    StatusCodeParsing(std::num::ParseIntError),
    /// No `HTTP/1.1` status line was found in the response headers.
    MissingProtocol,
    /// The `Content-Length` header value is not a decimal integer.
    ContentLengthParsing(std::num::ParseIntError),
    /// The server answered with a non-success status; the body is attached
    /// when the response declared one via `Content-Length`.
    ServerResponse(StatusCode, Option<String>),
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        use Error::*;
        match self {
            Socket(e) => write!(f, "Error writing to or reading from HTTP socket: {}", e),
            StatusCodeParsing(e) => write!(f, "Error parsing HTTP status code: {}", e),
            MissingProtocol => write!(f, "HTTP output is missing protocol statement"),
            ContentLengthParsing(e) => write!(f, "Error parsing HTTP Content-Length field: {}", e),
            ServerResponse(s, o) => {
                if let Some(o) = o {
                    write!(f, "Server responded with an error: {:?}: {}", s, o)
                } else {
                    write!(f, "Server responded with an error: {:?}", s)
                }
            }
        }
    }
}

impl std::error::Error for Error {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Error::Socket(e) => Some(e),
            Error::StatusCodeParsing(e) | Error::ContentLengthParsing(e) => Some(e),
            Error::MissingProtocol | Error::ServerResponse(..) => None,
        }
    }
}

/// HTTP status codes the client distinguishes; every other code is `Unknown`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum StatusCode {
    Continue,
    Ok,
    NoContent,
    BadRequest,
    NotFound,
    InternalServerError,
    NotImplemented,
    Unknown,
}

impl StatusCode {
    /// Maps a numeric status code to its variant (`Unknown` if unrecognised).
    fn from_raw(code: usize) -> StatusCode {
        match code {
            100 => StatusCode::Continue,
            200 => StatusCode::Ok,
            204 => StatusCode::NoContent,
            400 => StatusCode::BadRequest,
            404 => StatusCode::NotFound,
            500 => StatusCode::InternalServerError,
            501 => StatusCode::NotImplemented,
            _ => StatusCode::Unknown,
        }
    }

    /// Parses a decimal status code such as `"200"`; surrounding whitespace is ignored.
    fn parse(code: &str) -> Result<StatusCode, Error> {
        let raw = code.trim().parse().map_err(Error::StatusCodeParsing)?;
        Ok(StatusCode::from_raw(raw))
    }

    /// Returns `true` for every status other than `100`, `200` and `204`.
    ///
    /// This includes `Unknown`, so any unrecognised code (e.g. `201`) is
    /// reported to the caller as an error response.
    fn is_server_error(self) -> bool {
        !matches!(
            self,
            StatusCode::Ok | StatusCode::Continue | StatusCode::NoContent
        )
    }
}

/// Looks up the value of header `header` in the header section of `res`.
///
/// Only the part of `res` before the first blank line (`\r\n\r\n`) is
/// searched, so text in a body can never be mistaken for a header. The status
/// line is skipped. Header names are compared ASCII case-insensitively, as
/// HTTP requires. The returned value has surrounding whitespace trimmed.
fn get_header<'a>(res: &'a str, header: &str) -> Option<&'a str> {
    let head = res.find("\r\n\r\n").map_or(res, |end| &res[..end]);
    head.split("\r\n")
        .skip(1) // status line
        .filter_map(|line| {
            let colon = line.find(':')?;
            Some((&line[..colon], &line[colon + 1..]))
        })
        .find(|&(name, _)| name.trim().eq_ignore_ascii_case(header))
        .map(|(_, value)| value.trim())
}

/// Extracts the status code from the first `HTTP/1.1` status line in `res`.
///
/// Only the first token after the protocol is parsed. An optional reason
/// phrase (`HTTP/1.1 200 OK`) or a trailing space (`HTTP/1.1 200 `) is accepted.
///
/// # Errors
///
/// [`Error::MissingProtocol`] if `res` contains no `HTTP/1.1`, and
/// [`Error::StatusCodeParsing`] if the code is missing or not numeric.
fn get_status_code(res: &str) -> Result<StatusCode, Error> {
    const PROTOCOL: &str = "HTTP/1.1";
    let start = res.find(PROTOCOL).ok_or(Error::MissingProtocol)? + PROTOCOL.len();
    let rest = &res[start..];
    // The end index is computed relative to `rest`, not to `res`.
    let status_line = match rest.find('\r') {
        Some(end) => &rest[..end],
        None => rest,
    };
    StatusCode::parse(status_line.split_whitespace().next().unwrap_or(""))
}

/// Reads one chunk from `socket` into `chunk` and appends it to `raw`.
///
/// Reads interrupted by a signal are retried. A zero-length read means the
/// peer closed the connection before the response was complete. That case is
/// reported as an `UnexpectedEof` socket error rather than spinning forever.
fn read_chunk(socket: &mut dyn Read, chunk: &mut [u8], raw: &mut Vec<u8>) -> Result<(), Error> {
    let count = loop {
        match socket.read(chunk) {
            Ok(n) => break n,
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(Error::Socket(e)),
        }
    };
    if count == 0 {
        return Err(Error::Socket(std::io::Error::new(
            std::io::ErrorKind::UnexpectedEof,
            "connection closed before the HTTP response was complete",
        )));
    }
    raw.extend_from_slice(&chunk[..count]);
    Ok(())
}

/// Reads one HTTP response from `socket` and returns its body, if any.
///
/// Bytes are read until the header terminator (`\r\n\r\n`) has arrived.
/// Without a `Content-Length` header, the response is treated as having no
/// body. Otherwise exactly `Content-Length` bytes after the headers form the
/// body.
///
/// # Errors
///
/// * [`Error::Socket`] if reading fails, the peer closes the connection early,
///   or the body is not valid UTF-8.
/// * [`Error::ContentLengthParsing`], [`Error::MissingProtocol`] or
///   [`Error::StatusCodeParsing`] for malformed headers.
/// * [`Error::ServerResponse`] when the status is not a success code; the
///   body (if any) is attached.
fn parse_http_response(socket: &mut dyn Read) -> Result<Option<String>, Error> {
    const HEADER_END: &[u8] = b"\r\n\r\n";

    // Keep raw bytes until the response is complete. Decoding each read on its
    // own would fail whenever a multi-byte UTF-8 sequence is split across reads.
    let mut raw: Vec<u8> = Vec::new();
    let mut chunk = [0u8; 256];

    // Read until the blank line that terminates the headers has arrived.
    let header_len = loop {
        if let Some(pos) = raw.windows(HEADER_END.len()).position(|w| w == HEADER_END) {
            break pos;
        }
        read_chunk(socket, &mut chunk, &mut raw)?;
    };
    let body_offset = header_len + HEADER_END.len();

    let (status_code, content_length) = {
        // Header text is ASCII in practice. A lossy decode cannot fail and
        // leaves the status line and header names intact.
        let headers = String::from_utf8_lossy(&raw[..header_len]);
        let content_length = match get_header(&headers, "Content-Length") {
            Some(value) => Some(
                value
                    .parse::<usize>()
                    .map_err(Error::ContentLengthParsing)?,
            ),
            None => None,
        };
        (get_status_code(&headers)?, content_length)
    };

    let body = match content_length {
        // No Content-Length: the response is treated as bodiless.
        None => None,
        Some(len) => {
            // Saturate so an absurd Content-Length cannot overflow. Such a
            // response ends with an UnexpectedEof error instead.
            let body_end = body_offset.saturating_add(len);
            while raw.len() < body_end {
                read_chunk(socket, &mut chunk, &mut raw)?;
            }
            // Only the declared number of bytes belongs to this response.
            let text = std::str::from_utf8(&raw[body_offset..body_end]).map_err(|e| {
                Error::Socket(std::io::Error::new(std::io::ErrorKind::InvalidData, e))
            })?;
            Some(text.to_owned())
        }
    };

    if status_code.is_server_error() {
        Err(Error::ServerResponse(status_code, body))
    } else {
        Ok(body)
    }
}

/// Sends `<method> /api/v1/vm.<c>` over `socket` and prints the response body.
///
/// When `request_body` is given, it is sent with a matching `Content-Length`
/// header (its length in bytes). The response body, if any, is printed to
/// stdout.
///
/// # Errors
///
/// Any error from writing the request or from reading/parsing the response.
/// A non-success status yields [`Error::ServerResponse`].
pub fn simple_api_command<T: Read + Write>(
    socket: &mut T,
    method: &str,
    c: &str,
    request_body: Option<&str>,
) -> Result<(), Error> {
    socket
        .write_all(
            format!(
                "{} /api/v1/vm.{} HTTP/1.1\r\nHost: localhost\r\nAccept: */*\r\n",
                method, c
            )
            .as_bytes(),
        )
        .map_err(Error::Socket)?;

    if let Some(request_body) = request_body {
        socket
            .write_all(format!("Content-Length: {}\r\n", request_body.len()).as_bytes())
            .map_err(Error::Socket)?;
    }

    // Blank line ends the request headers.
    socket.write_all(b"\r\n").map_err(Error::Socket)?;

    if let Some(request_body) = request_body {
        socket
            .write_all(request_body.as_bytes())
            .map_err(Error::Socket)?;
    }

    socket.flush().map_err(Error::Socket)?;

    if let Some(body) = parse_http_response(socket)? {
        println!("{}", body);
    }
    Ok(())
}