xref: /cloud-hypervisor/tpm/src/socket.rs (revision 3ce0fef7fd546467398c914dbc74d8542e45cf6f)
1 // Copyright © 2022, Microsoft Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5 
6 use anyhow::anyhow;
7 use std::io::Read;
8 use std::os::unix::io::{AsRawFd, RawFd};
9 use std::os::unix::net::UnixStream;
10 use thiserror::Error;
11 use vmm_sys_util::sock_ctrl_msg::ScmSocket;
12 
13 #[derive(Error, Debug)]
14 pub enum Error {
15     #[error("Cannot connect to tpm Socket: {0}")]
16     ConnectToSocket(#[source] anyhow::Error),
17     #[error("Failed to read from socket: {0}")]
18     ReadFromSocket(#[source] anyhow::Error),
19     #[error("Failed to write to socket: {0}")]
20     WriteToSocket(#[source] anyhow::Error),
21 }
22 type Result<T> = anyhow::Result<T, Error>;
23 
/// Connection lifecycle of a [`SocketDev`].
///
/// `Clone`/`Copy`/`Debug`/`Eq` are derived in addition to `PartialEq` so the
/// state can be copied out, compared exhaustively and included in log output.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SocketDevState {
    /// Initial state of a freshly created device.
    Disconnected,
    /// Entered at the start of `SocketDev::connect`.
    Connecting,
    /// Entered once `SocketDev::connect` has opened the stream.
    Connected,
}
30 
31 pub struct SocketDev {
32     state: SocketDevState,
33     stream: Option<UnixStream>,
34     // Fd sent to swtpm process for Data Channel
35     write_msgfd: RawFd,
36     // Data Channel used by Cloud-Hypervisor
37     data_fd: RawFd,
38     // Control Channel used by Cloud-Hypervisor
39     control_fd: RawFd,
40 }
41 
42 impl Default for SocketDev {
43     fn default() -> Self {
44         Self::new()
45     }
46 }
47 
48 impl SocketDev {
49     pub fn new() -> Self {
50         Self {
51             state: SocketDevState::Disconnected,
52             stream: None,
53             write_msgfd: -1,
54             control_fd: -1,
55             data_fd: -1,
56         }
57     }
58 
59     pub fn init(&mut self, path: String) -> Result<()> {
60         self.connect(&path)?;
61         Ok(())
62     }
63 
64     pub fn connect(&mut self, socket_path: &str) -> Result<()> {
65         self.state = SocketDevState::Connecting;
66 
67         let s = UnixStream::connect(socket_path).map_err(|e| {
68             Error::ConnectToSocket(anyhow!("Failed to connect to tpm Socket. Error: {:?}", e))
69         })?;
70         self.control_fd = s.as_raw_fd();
71         self.stream = Some(s);
72         self.state = SocketDevState::Connected;
73         debug!("Connected to tpm socket path : {:?}", socket_path);
74         Ok(())
75     }
76 
77     pub fn set_datafd(&mut self, fd: RawFd) {
78         self.data_fd = fd;
79     }
80 
81     pub fn set_msgfd(&mut self, fd: RawFd) {
82         self.write_msgfd = fd;
83     }
84 
85     pub fn send_full(&self, buf: &[u8]) -> Result<usize> {
86         let write_fd = self.write_msgfd;
87 
88         let size = self
89             .stream
90             .as_ref()
91             .unwrap()
92             .send_with_fd(buf, write_fd)
93             .map_err(|e| {
94                 Error::WriteToSocket(anyhow!("Failed to write to Socket. Error: {:?}", e))
95             })?;
96 
97         Ok(size)
98     }
99 
100     pub fn write(&mut self, buf: &[u8]) -> Result<usize> {
101         if self.stream.is_none() {
102             return Err(Error::WriteToSocket(anyhow!(
103                 "Stream for tpm socket was not initialized"
104             )));
105         }
106 
107         if matches!(self.state, SocketDevState::Connected) {
108             let ret = self.send_full(buf)?;
109             // swtpm will receive data Fd after a successful send
110             // Reset cached write_msgfd after a successful send
111             // Ideally, write_msgfd is reset after first Ctrl Command
112             if ret > 0 && self.write_msgfd != 0 {
113                 self.write_msgfd = 0;
114             }
115             Ok(ret)
116         } else {
117             Err(Error::WriteToSocket(anyhow!(
118                 "TPM Socket was not in Connected State"
119             )))
120         }
121     }
122 
123     pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
124         if self.stream.is_none() {
125             return Err(Error::ReadFromSocket(anyhow!(
126                 "Stream for tpm socket was not initialized"
127             )));
128         }
129         let mut socket = self.stream.as_ref().unwrap();
130         let size: usize = socket.read(buf).map_err(|e| {
131             Error::ReadFromSocket(anyhow!("Failed to read from socket. Error Code {:?}", e))
132         })?;
133         Ok(size)
134     }
135 }
136