1 // Copyright © 2022, Microsoft Corporation 2 // 3 // SPDX-License-Identifier: Apache-2.0 4 // 5 6 use anyhow::anyhow; 7 use std::io::Read; 8 use std::os::unix::io::{AsRawFd, RawFd}; 9 use std::os::unix::net::UnixStream; 10 use thiserror::Error; 11 use vmm_sys_util::sock_ctrl_msg::ScmSocket; 12 13 #[derive(Error, Debug)] 14 pub enum Error { 15 #[error("Cannot connect to tpm Socket: {0}")] 16 ConnectToSocket(#[source] anyhow::Error), 17 #[error("Failed to read from socket: {0}")] 18 ReadFromSocket(#[source] anyhow::Error), 19 #[error("Failed to write to socket: {0}")] 20 WriteToSocket(#[source] anyhow::Error), 21 } 22 type Result<T> = anyhow::Result<T, Error>; 23 24 #[derive(PartialEq)] 25 enum SocketDevState { 26 Disconnected, 27 Connecting, 28 Connected, 29 } 30 31 pub struct SocketDev { 32 state: SocketDevState, 33 stream: Option<UnixStream>, 34 // Fd sent to swtpm process for Data Channel 35 write_msgfd: RawFd, 36 // Data Channel used by Cloud-Hypervisor 37 data_fd: RawFd, 38 // Control Channel used by Cloud-Hypervisor 39 control_fd: RawFd, 40 } 41 42 impl Default for SocketDev { 43 fn default() -> Self { 44 Self::new() 45 } 46 } 47 48 impl SocketDev { 49 pub fn new() -> Self { 50 Self { 51 state: SocketDevState::Disconnected, 52 stream: None, 53 write_msgfd: -1, 54 control_fd: -1, 55 data_fd: -1, 56 } 57 } 58 59 pub fn init(&mut self, path: String) -> Result<()> { 60 self.connect(&path)?; 61 Ok(()) 62 } 63 64 pub fn connect(&mut self, socket_path: &str) -> Result<()> { 65 self.state = SocketDevState::Connecting; 66 67 let s = UnixStream::connect(socket_path).map_err(|e| { 68 Error::ConnectToSocket(anyhow!("Failed to connect to tpm Socket. Error: {:?}", e)) 69 })?; 70 self.control_fd = s.as_raw_fd(); 71 self.stream = Some(s); 72 self.state = SocketDevState::Connected; 73 debug!("Connected to tpm socket path : {:?}", socket_path); 74 Ok(()) 75 } 76 77 pub fn set_datafd(&mut self, fd: RawFd) { 78 self.data_fd = fd; 79 } 80 81 pub fn set_msgfd(&mut self, fd: RawFd) { 82 self.write_msgfd = fd; 83 } 84 85 pub fn send_full(&self, buf: &[u8]) -> Result<usize> { 86 let write_fd = self.write_msgfd; 87 88 let size = self 89 .stream 90 .as_ref() 91 .unwrap() 92 .send_with_fd(buf, write_fd) 93 .map_err(|e| { 94 Error::WriteToSocket(anyhow!("Failed to write to Socket. Error: {:?}", e)) 95 })?; 96 97 Ok(size) 98 } 99 100 pub fn write(&mut self, buf: &[u8]) -> Result<usize> { 101 if self.stream.is_none() { 102 return Err(Error::WriteToSocket(anyhow!( 103 "Stream for tpm socket was not initialized" 104 ))); 105 } 106 107 if matches!(self.state, SocketDevState::Connected) { 108 let ret = self.send_full(buf)?; 109 // swtpm will receive data Fd after a successful send 110 // Reset cached write_msgfd after a successful send 111 // Ideally, write_msgfd is reset after first Ctrl Command 112 if ret > 0 && self.write_msgfd != 0 { 113 self.write_msgfd = 0; 114 } 115 Ok(ret) 116 } else { 117 Err(Error::WriteToSocket(anyhow!( 118 "TPM Socket was not in Connected State" 119 ))) 120 } 121 } 122 123 pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> { 124 if self.stream.is_none() { 125 return Err(Error::ReadFromSocket(anyhow!( 126 "Stream for tpm socket was not initialized" 127 ))); 128 } 129 let mut socket = self.stream.as_ref().unwrap(); 130 let size: usize = socket.read(buf).map_err(|e| { 131 Error::ReadFromSocket(anyhow!("Failed to read from socket. Error Code {:?}", e)) 132 })?; 133 Ok(size) 134 } 135 } 136