1 // Copyright © 2022, Microsoft Corporation 2 // 3 // SPDX-License-Identifier: Apache-2.0 4 // 5 6 use std::io::Read; 7 use std::os::unix::io::{AsRawFd, RawFd}; 8 use std::os::unix::net::UnixStream; 9 10 use anyhow::anyhow; 11 use thiserror::Error; 12 use vmm_sys_util::sock_ctrl_msg::ScmSocket; 13 14 #[derive(Error, Debug)] 15 pub enum Error { 16 #[error("Cannot connect to tpm Socket")] 17 ConnectToSocket(#[source] anyhow::Error), 18 #[error("Failed to read from socket")] 19 ReadFromSocket(#[source] anyhow::Error), 20 #[error("Failed to write to socket")] 21 WriteToSocket(#[source] anyhow::Error), 22 } 23 type Result<T> = anyhow::Result<T, Error>; 24 25 #[derive(PartialEq)] 26 enum SocketDevState { 27 Disconnected, 28 Connecting, 29 Connected, 30 } 31 32 pub struct SocketDev { 33 state: SocketDevState, 34 stream: Option<UnixStream>, 35 // Fd sent to swtpm process for Data Channel 36 write_msgfd: RawFd, 37 // Data Channel used by Cloud-Hypervisor 38 data_fd: RawFd, 39 // Control Channel used by Cloud-Hypervisor 40 control_fd: RawFd, 41 } 42 43 impl Default for SocketDev { 44 fn default() -> Self { 45 Self::new() 46 } 47 } 48 49 impl SocketDev { 50 pub fn new() -> Self { 51 Self { 52 state: SocketDevState::Disconnected, 53 stream: None, 54 write_msgfd: -1, 55 control_fd: -1, 56 data_fd: -1, 57 } 58 } 59 60 pub fn init(&mut self, path: String) -> Result<()> { 61 self.connect(&path)?; 62 Ok(()) 63 } 64 65 pub fn connect(&mut self, socket_path: &str) -> Result<()> { 66 self.state = SocketDevState::Connecting; 67 68 let s = UnixStream::connect(socket_path).map_err(|e| { 69 Error::ConnectToSocket(anyhow!("Failed to connect to tpm Socket. Error: {:?}", e)) 70 })?; 71 self.control_fd = s.as_raw_fd(); 72 self.stream = Some(s); 73 self.state = SocketDevState::Connected; 74 debug!("Connected to tpm socket path : {:?}", socket_path); 75 Ok(()) 76 } 77 78 pub fn set_datafd(&mut self, fd: RawFd) { 79 self.data_fd = fd; 80 } 81 82 pub fn set_msgfd(&mut self, fd: RawFd) { 83 self.write_msgfd = fd; 84 } 85 86 pub fn send_full(&self, buf: &[u8]) -> Result<usize> { 87 let write_fd = self.write_msgfd; 88 89 let size = self 90 .stream 91 .as_ref() 92 .unwrap() 93 .send_with_fd(buf, write_fd) 94 .map_err(|e| { 95 Error::WriteToSocket(anyhow!("Failed to write to Socket. Error: {:?}", e)) 96 })?; 97 98 Ok(size) 99 } 100 101 pub fn write(&mut self, buf: &[u8]) -> Result<usize> { 102 if self.stream.is_none() { 103 return Err(Error::WriteToSocket(anyhow!( 104 "Stream for tpm socket was not initialized" 105 ))); 106 } 107 108 if matches!(self.state, SocketDevState::Connected) { 109 let ret = self.send_full(buf)?; 110 // swtpm will receive data Fd after a successful send 111 // Reset cached write_msgfd after a successful send 112 // Ideally, write_msgfd is reset after first Ctrl Command 113 if ret > 0 && self.write_msgfd != 0 { 114 self.write_msgfd = 0; 115 } 116 Ok(ret) 117 } else { 118 Err(Error::WriteToSocket(anyhow!( 119 "TPM Socket was not in Connected State" 120 ))) 121 } 122 } 123 124 pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> { 125 if self.stream.is_none() { 126 return Err(Error::ReadFromSocket(anyhow!( 127 "Stream for tpm socket was not initialized" 128 ))); 129 } 130 let mut socket = self.stream.as_ref().unwrap(); 131 let size: usize = socket.read(buf).map_err(|e| { 132 Error::ReadFromSocket(anyhow!("Failed to read from socket. Error Code {:?}", e)) 133 })?; 134 Ok(size) 135 } 136 } 137