xref: /cloud-hypervisor/tpm/src/socket.rs (revision 88a9f799449c04180c6b9a21d3b9c0c4b57e2bd6)
1 // Copyright © 2022, Microsoft Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5 
6 use std::io::Read;
7 use std::os::unix::io::{AsRawFd, RawFd};
8 use std::os::unix::net::UnixStream;
9 
10 use anyhow::anyhow;
11 use thiserror::Error;
12 use vmm_sys_util::sock_ctrl_msg::ScmSocket;
13 
14 #[derive(Error, Debug)]
15 pub enum Error {
16     #[error("Cannot connect to tpm Socket: {0}")]
17     ConnectToSocket(#[source] anyhow::Error),
18     #[error("Failed to read from socket: {0}")]
19     ReadFromSocket(#[source] anyhow::Error),
20     #[error("Failed to write to socket: {0}")]
21     WriteToSocket(#[source] anyhow::Error),
22 }
23 type Result<T> = anyhow::Result<T, Error>;
24 
/// Connection state of a [`SocketDev`].
///
/// `Connecting` is only held while a connection attempt is in progress.
/// `Connected` is the only state in which writes are allowed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SocketDevState {
    Disconnected,
    Connecting,
    Connected,
}
31 
32 pub struct SocketDev {
33     state: SocketDevState,
34     stream: Option<UnixStream>,
35     // Fd sent to swtpm process for Data Channel
36     write_msgfd: RawFd,
37     // Data Channel used by Cloud-Hypervisor
38     data_fd: RawFd,
39     // Control Channel used by Cloud-Hypervisor
40     control_fd: RawFd,
41 }
42 
43 impl Default for SocketDev {
44     fn default() -> Self {
45         Self::new()
46     }
47 }
48 
49 impl SocketDev {
50     pub fn new() -> Self {
51         Self {
52             state: SocketDevState::Disconnected,
53             stream: None,
54             write_msgfd: -1,
55             control_fd: -1,
56             data_fd: -1,
57         }
58     }
59 
60     pub fn init(&mut self, path: String) -> Result<()> {
61         self.connect(&path)?;
62         Ok(())
63     }
64 
65     pub fn connect(&mut self, socket_path: &str) -> Result<()> {
66         self.state = SocketDevState::Connecting;
67 
68         let s = UnixStream::connect(socket_path).map_err(|e| {
69             Error::ConnectToSocket(anyhow!("Failed to connect to tpm Socket. Error: {:?}", e))
70         })?;
71         self.control_fd = s.as_raw_fd();
72         self.stream = Some(s);
73         self.state = SocketDevState::Connected;
74         debug!("Connected to tpm socket path : {:?}", socket_path);
75         Ok(())
76     }
77 
78     pub fn set_datafd(&mut self, fd: RawFd) {
79         self.data_fd = fd;
80     }
81 
82     pub fn set_msgfd(&mut self, fd: RawFd) {
83         self.write_msgfd = fd;
84     }
85 
86     pub fn send_full(&self, buf: &[u8]) -> Result<usize> {
87         let write_fd = self.write_msgfd;
88 
89         let size = self
90             .stream
91             .as_ref()
92             .unwrap()
93             .send_with_fd(buf, write_fd)
94             .map_err(|e| {
95                 Error::WriteToSocket(anyhow!("Failed to write to Socket. Error: {:?}", e))
96             })?;
97 
98         Ok(size)
99     }
100 
101     pub fn write(&mut self, buf: &[u8]) -> Result<usize> {
102         if self.stream.is_none() {
103             return Err(Error::WriteToSocket(anyhow!(
104                 "Stream for tpm socket was not initialized"
105             )));
106         }
107 
108         if matches!(self.state, SocketDevState::Connected) {
109             let ret = self.send_full(buf)?;
110             // swtpm will receive data Fd after a successful send
111             // Reset cached write_msgfd after a successful send
112             // Ideally, write_msgfd is reset after first Ctrl Command
113             if ret > 0 && self.write_msgfd != 0 {
114                 self.write_msgfd = 0;
115             }
116             Ok(ret)
117         } else {
118             Err(Error::WriteToSocket(anyhow!(
119                 "TPM Socket was not in Connected State"
120             )))
121         }
122     }
123 
124     pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
125         if self.stream.is_none() {
126             return Err(Error::ReadFromSocket(anyhow!(
127                 "Stream for tpm socket was not initialized"
128             )));
129         }
130         let mut socket = self.stream.as_ref().unwrap();
131         let size: usize = socket.read(buf).map_err(|e| {
132             Error::ReadFromSocket(anyhow!("Failed to read from socket. Error Code {:?}", e))
133         })?;
134         Ok(size)
135     }
136 }
137