xref: /cloud-hypervisor/api_client/src/lib.rs (revision 496ceed1d02b5884e2a4c570ef231d2c90b64fc0)
1 // Copyright © 2020 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5 
6 use std::fmt;
7 use std::io::{Read, Write};
8 
9 #[derive(Debug)]
10 pub enum Error {
11     Socket(std::io::Error),
12     StatusCodeParsing(std::num::ParseIntError),
13     MissingProtocol,
14     ContentLengthParsing(std::num::ParseIntError),
15     ServerResponse(StatusCode, Option<String>),
16 }
17 
18 impl fmt::Display for Error {
19     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
20         use Error::*;
21         match self {
22             Socket(e) => write!(f, "Error writing to or reading from HTTP socket: {}", e),
23             StatusCodeParsing(e) => write!(f, "Error parsing HTTP status code: {}", e),
24             MissingProtocol => write!(f, "HTTP output is missing protocol statement"),
25             ContentLengthParsing(e) => write!(f, "Error parsing HTTP Content-Length field: {}", e),
26             ServerResponse(s, o) => {
27                 if let Some(o) = o {
28                     write!(f, "Server responded with an error: {:?}: {}", s, o)
29                 } else {
30                     write!(f, "Server responded with an error: {:?}", s)
31                 }
32             }
33         }
34     }
35 }
36 
/// HTTP status codes the API client distinguishes; any other numeric code is
/// represented as [`StatusCode::Unknown`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum StatusCode {
    /// 100 Continue.
    Continue,
    /// 200 OK.
    Ok,
    /// 204 No Content.
    NoContent,
    /// 400 Bad Request.
    BadRequest,
    /// 404 Not Found.
    NotFound,
    /// 500 Internal Server Error.
    InternalServerError,
    /// 501 Not Implemented.
    NotImplemented,
    /// Any status code without a dedicated variant.
    Unknown,
}
48 
49 impl StatusCode {
50     fn from_raw(code: usize) -> StatusCode {
51         match code {
52             100 => StatusCode::Continue,
53             200 => StatusCode::Ok,
54             204 => StatusCode::NoContent,
55             400 => StatusCode::BadRequest,
56             404 => StatusCode::NotFound,
57             500 => StatusCode::InternalServerError,
58             501 => StatusCode::NotImplemented,
59             _ => StatusCode::Unknown,
60         }
61     }
62 
63     fn parse(code: &str) -> Result<StatusCode, Error> {
64         Ok(StatusCode::from_raw(
65             code.trim().parse().map_err(Error::StatusCodeParsing)?,
66         ))
67     }
68 
69     fn is_server_error(self) -> bool {
70         !matches!(
71             self,
72             StatusCode::Ok | StatusCode::Continue | StatusCode::NoContent
73         )
74     }
75 }
76 
/// Returns the value of the HTTP header named `header` in the response `res`.
///
/// Only the header section is searched. That is the lines before the first
/// blank line or, if the blank line has not been received yet, the complete
/// (CRLF-terminated) lines received so far. Body text can therefore never be
/// mistaken for a header.
///
/// Header names are compared case-insensitively, because HTTP field names are
/// case-insensitive. The whole name must match, so `X-Content-Length` does not
/// match `Content-Length`. Surrounding whitespace is trimmed from the returned
/// value.
///
/// Returns `None` when no such header is present.
fn get_header<'a>(res: &'a str, header: &str) -> Option<&'a str> {
    let head_end = match res.find("\r\n\r\n") {
        // Keep the CRLF of the last header line; drop the blank line and body.
        Some(blank) => blank + 2,
        // Headers still incomplete: ignore a trailing, partially received line.
        None => res.rfind("\r\n").map_or(0, |last| last + 2),
    };

    res[..head_end].split_terminator("\r\n").find_map(|line| {
        // Lines without a colon (e.g. the status line) cannot be headers.
        let colon = line.find(':')?;
        if line[..colon].eq_ignore_ascii_case(header) {
            Some(line[colon + 1..].trim())
        } else {
            None
        }
    })
}
82 
83 fn get_status_code(res: &str) -> Result<StatusCode, Error> {
84     if let Some(o) = res.find("HTTP/1.1") {
85         Ok(StatusCode::parse(
86             &res[o + "HTTP/1.1 ".len()..res[o..].find('\r').unwrap()],
87         )?)
88     } else {
89         Err(Error::MissingProtocol)
90     }
91 }
92 
93 fn parse_http_response(socket: &mut dyn Read) -> Result<Option<String>, Error> {
94     let mut res = String::new();
95     let mut body_offset = None;
96     let mut content_length: Option<usize> = None;
97     loop {
98         let mut bytes = vec![0; 256];
99         let count = socket.read(&mut bytes).map_err(Error::Socket)?;
100         res.push_str(std::str::from_utf8(&bytes[0..count]).unwrap());
101 
102         // End of headers
103         if let Some(o) = res.find("\r\n\r\n") {
104             body_offset = Some(o + "\r\n\r\n".len());
105 
106             // With all headers available we can see if there is any body
107             content_length = if let Some(length) = get_header(&res, "Content-Length") {
108                 Some(length.trim().parse().map_err(Error::ContentLengthParsing)?)
109             } else {
110                 None
111             };
112 
113             if content_length.is_none() {
114                 break;
115             }
116         }
117 
118         if let Some(body_offset) = body_offset {
119             if let Some(content_length) = content_length {
120                 if res.len() >= content_length + body_offset {
121                     break;
122                 }
123             }
124         }
125     }
126     let body_string = content_length.and(Some(String::from(&res[body_offset.unwrap()..])));
127     let status_code = get_status_code(&res)?;
128 
129     if status_code.is_server_error() {
130         Err(Error::ServerResponse(status_code, body_string))
131     } else {
132         Ok(body_string)
133     }
134 }
135 
136 pub fn simple_api_command<T: Read + Write>(
137     socket: &mut T,
138     method: &str,
139     c: &str,
140     request_body: Option<&str>,
141 ) -> Result<(), Error> {
142     socket
143         .write_all(
144             format!(
145                 "{} /api/v1/vm.{} HTTP/1.1\r\nHost: localhost\r\nAccept: */*\r\n",
146                 method, c
147             )
148             .as_bytes(),
149         )
150         .map_err(Error::Socket)?;
151 
152     if let Some(request_body) = request_body {
153         socket
154             .write_all(format!("Content-Length: {}\r\n", request_body.len()).as_bytes())
155             .map_err(Error::Socket)?;
156     }
157 
158     socket.write_all(b"\r\n").map_err(Error::Socket)?;
159 
160     if let Some(request_body) = request_body {
161         socket
162             .write_all(request_body.as_bytes())
163             .map_err(Error::Socket)?;
164     }
165 
166     socket.flush().map_err(Error::Socket)?;
167 
168     if let Some(body) = parse_http_response(socket)? {
169         println!("{}", body);
170     }
171     Ok(())
172 }
173