use alloc::{boxed::Box, sync::Arc, vec::Vec}; use log::{error, warn}; use smoltcp::{ socket::{raw, tcp, udp}, wire, }; use system_error::SystemError; use crate::{ driver::net::NetDevice, libs::rwlock::RwLock, net::{ event_poll::EPollEventType, net_core::poll_ifaces, Endpoint, Protocol, ShutdownType, NET_DEVICES, }, }; use super::{ handle::GlobalSocketHandle, Socket, SocketHandleItem, SocketMetadata, SocketOptions, SocketPollMethod, SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET, }; /// @brief 表示原始的socket。原始套接字绕过传输层协议(如 TCP 或 UDP)并提供对网络层协议(如 IP)的直接访问。 /// /// ref: https://man7.org/linux/man-pages/man7/raw.7.html #[derive(Debug, Clone)] pub struct RawSocket { handle: GlobalSocketHandle, /// 用户发送的数据包是否包含了IP头. /// 如果是true,用户发送的数据包,必须包含IP头。(即用户要自行设置IP头+数据) /// 如果是false,用户发送的数据包,不包含IP头。(即用户只要设置数据) header_included: bool, /// socket的metadata metadata: SocketMetadata, } impl RawSocket { /// 元数据的缓冲区的大小 pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; /// 默认的接收缓冲区的大小 receive pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024; /// 默认的发送缓冲区的大小 transmiss pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024; /// @brief 创建一个原始的socket /// /// @param protocol 协议号 /// @param options socket的选项 /// /// @return 返回创建的原始的socket pub fn new(protocol: Protocol, options: SocketOptions) -> Self { let rx_buffer = raw::PacketBuffer::new( vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], vec![0; Self::DEFAULT_RX_BUF_SIZE], ); let tx_buffer = raw::PacketBuffer::new( vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], vec![0; Self::DEFAULT_TX_BUF_SIZE], ); let protocol: u8 = protocol.into(); let socket = raw::Socket::new( wire::IpVersion::Ipv4, wire::IpProtocol::from(protocol), rx_buffer, tx_buffer, ); // 把socket添加到socket集合中,并得到socket的句柄 let handle = GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket)); let metadata = SocketMetadata::new( SocketType::Raw, Self::DEFAULT_RX_BUF_SIZE, Self::DEFAULT_TX_BUF_SIZE, Self::DEFAULT_METADATA_BUF_SIZE, options, ); return Self { handle, header_included: false, metadata, }; } } impl Socket for RawSocket { fn close(&mut self) { let mut socket_set_guard = SOCKET_SET.lock_irqsave(); if let smoltcp::socket::Socket::Udp(mut sock) = socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()) { sock.close(); } drop(socket_set_guard); poll_ifaces(); } fn read(&self, buf: &mut [u8]) -> (Result, Endpoint) { poll_ifaces(); loop { // 如何优化这里? let mut socket_set_guard = SOCKET_SET.lock_irqsave(); let socket = socket_set_guard.get_mut::(self.handle.smoltcp_handle().unwrap()); match socket.recv_slice(buf) { Ok(len) => { let packet = wire::Ipv4Packet::new_unchecked(buf); return ( Ok(len), Endpoint::Ip(Some(wire::IpEndpoint { addr: wire::IpAddress::Ipv4(packet.src_addr()), port: 0, })), ); } Err(_) => { if !self.metadata.options.contains(SocketOptions::BLOCK) { // 如果是非阻塞的socket,就返回错误 return (Err(SystemError::EAGAIN_OR_EWOULDBLOCK), Endpoint::Ip(None)); } } } drop(socket_set_guard); SocketHandleItem::sleep( self.socket_handle(), EPollEventType::EPOLLIN.bits() as u64, HANDLE_MAP.read_irqsave(), ); } } fn write(&self, buf: &[u8], to: Option) -> Result { // 如果用户发送的数据包,包含IP头,则直接发送 if self.header_included { let mut socket_set_guard = SOCKET_SET.lock_irqsave(); let socket = socket_set_guard.get_mut::(self.handle.smoltcp_handle().unwrap()); match socket.send_slice(buf) { Ok(_) => { return Ok(buf.len()); } Err(raw::SendError::BufferFull) => { return Err(SystemError::ENOBUFS); } } } else { // 如果用户发送的数据包,不包含IP头,则需要自己构造IP头 if let Some(Endpoint::Ip(Some(endpoint))) = to { let mut socket_set_guard = SOCKET_SET.lock_irqsave(); let socket: &mut raw::Socket = socket_set_guard.get_mut::(self.handle.smoltcp_handle().unwrap()); // 暴力解决方案:只考虑0号网卡。 TODO:考虑多网卡的情况!!! let iface = NET_DEVICES.read_irqsave().get(&0).unwrap().clone(); // 构造IP头 let ipv4_src_addr: Option = iface.inner_iface().lock().ipv4_addr(); if ipv4_src_addr.is_none() { return Err(SystemError::ENETUNREACH); } let ipv4_src_addr = ipv4_src_addr.unwrap(); if let wire::IpAddress::Ipv4(ipv4_dst) = endpoint.addr { let len = buf.len(); // 创建20字节的IPv4头部 let mut buffer: Vec = vec![0u8; len + 20]; let mut packet: wire::Ipv4Packet<&mut Vec> = wire::Ipv4Packet::new_unchecked(&mut buffer); // 封装ipv4 header packet.set_version(4); packet.set_header_len(20); packet.set_total_len((20 + len) as u16); packet.set_src_addr(ipv4_src_addr); packet.set_dst_addr(ipv4_dst); // 设置ipv4 header的protocol字段 packet.set_next_header(socket.ip_protocol()); // 获取IP数据包的负载字段 let payload: &mut [u8] = packet.payload_mut(); payload.copy_from_slice(buf); // 填充checksum字段 packet.fill_checksum(); // 发送数据包 socket.send_slice(&buffer).unwrap(); iface.poll(&mut socket_set_guard).ok(); drop(socket_set_guard); return Ok(len); } else { warn!("Unsupport Ip protocol type!"); return Err(SystemError::EINVAL); } } else { // 如果没有指定目的地址,则返回错误 return Err(SystemError::ENOTCONN); } } } fn connect(&mut self, _endpoint: Endpoint) -> Result<(), SystemError> { Ok(()) } fn metadata(&self) -> SocketMetadata { self.metadata.clone() } fn box_clone(&self) -> Box { Box::new(self.clone()) } fn socket_handle(&self) -> GlobalSocketHandle { self.handle } fn as_any_ref(&self) -> &dyn core::any::Any { self } fn as_any_mut(&mut self) -> &mut dyn core::any::Any { self } } /// @brief 表示udp socket /// /// https://man7.org/linux/man-pages/man7/udp.7.html #[derive(Debug, Clone)] pub struct UdpSocket { pub handle: GlobalSocketHandle, remote_endpoint: Option, // 记录远程endpoint提供给connect(), 应该使用IP地址。 metadata: SocketMetadata, } impl UdpSocket { /// 元数据的缓冲区的大小 pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; /// 默认的接收缓冲区的大小 receive pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024; /// 默认的发送缓冲区的大小 transmiss pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024; /// @brief 创建一个udp的socket /// /// @param options socket的选项 /// /// @return 返回创建的udp的socket pub fn new(options: SocketOptions) -> Self { let rx_buffer = udp::PacketBuffer::new( vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], vec![0; Self::DEFAULT_RX_BUF_SIZE], ); let tx_buffer = udp::PacketBuffer::new( vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], vec![0; Self::DEFAULT_TX_BUF_SIZE], ); let socket = udp::Socket::new(rx_buffer, tx_buffer); // 把socket添加到socket集合中,并得到socket的句柄 let handle: GlobalSocketHandle = GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket)); let metadata = SocketMetadata::new( SocketType::Udp, Self::DEFAULT_RX_BUF_SIZE, Self::DEFAULT_TX_BUF_SIZE, Self::DEFAULT_METADATA_BUF_SIZE, options, ); return Self { handle, remote_endpoint: None, metadata, }; } fn do_bind(&self, socket: &mut udp::Socket, endpoint: Endpoint) -> Result<(), SystemError> { if let Endpoint::Ip(Some(mut ip)) = endpoint { // 端口为0则分配随机端口 if ip.port == 0 { ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?; } // 检测端口是否已被占用 PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port)?; let bind_res = if ip.addr.is_unspecified() { socket.bind(ip.port) } else { socket.bind(ip) }; match bind_res { Ok(()) => return Ok(()), Err(_) => return Err(SystemError::EINVAL), } } else { return Err(SystemError::EINVAL); } } } impl Socket for UdpSocket { fn close(&mut self) { let mut socket_set_guard = SOCKET_SET.lock_irqsave(); if let smoltcp::socket::Socket::Udp(mut sock) = socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()) { sock.close(); } drop(socket_set_guard); poll_ifaces(); } /// @brief 在read函数执行之前,请先bind到本地的指定端口 fn read(&self, buf: &mut [u8]) -> (Result, Endpoint) { loop { // debug!("Wait22 to Read"); poll_ifaces(); let mut socket_set_guard = SOCKET_SET.lock_irqsave(); let socket = socket_set_guard.get_mut::(self.handle.smoltcp_handle().unwrap()); // debug!("Wait to Read"); if socket.can_recv() { if let Ok((size, metadata)) = socket.recv_slice(buf) { drop(socket_set_guard); poll_ifaces(); return (Ok(size), Endpoint::Ip(Some(metadata.endpoint))); } } else { // 如果socket没有连接,则忙等 // return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); } drop(socket_set_guard); SocketHandleItem::sleep( self.socket_handle(), EPollEventType::EPOLLIN.bits() as u64, HANDLE_MAP.read_irqsave(), ); } } fn write(&self, buf: &[u8], to: Option) -> Result { // debug!("udp to send: {:?}, len={}", to, buf.len()); let remote_endpoint: &wire::IpEndpoint = { if let Some(Endpoint::Ip(Some(ref endpoint))) = to { endpoint } else if let Some(Endpoint::Ip(Some(ref endpoint))) = self.remote_endpoint { endpoint } else { return Err(SystemError::ENOTCONN); } }; // debug!("udp write: remote = {:?}", remote_endpoint); let mut socket_set_guard = SOCKET_SET.lock_irqsave(); let socket = socket_set_guard.get_mut::(self.handle.smoltcp_handle().unwrap()); // debug!("is open()={}", socket.is_open()); // debug!("socket endpoint={:?}", socket.endpoint()); if socket.can_send() { // debug!("udp write: can send"); match socket.send_slice(buf, *remote_endpoint) { Ok(()) => { // debug!("udp write: send ok"); drop(socket_set_guard); poll_ifaces(); return Ok(buf.len()); } Err(_) => { // debug!("udp write: send err"); return Err(SystemError::ENOBUFS); } } } else { // debug!("udp write: can not send"); return Err(SystemError::ENOBUFS); }; } fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { let mut sockets = SOCKET_SET.lock_irqsave(); let socket = sockets.get_mut::(self.handle.smoltcp_handle().unwrap()); // debug!("UDP Bind to {:?}", endpoint); return self.do_bind(socket, endpoint); } fn poll(&self) -> EPollEventType { let sockets = SOCKET_SET.lock_irqsave(); let socket = sockets.get::(self.handle.smoltcp_handle().unwrap()); return SocketPollMethod::udp_poll( socket, HANDLE_MAP .read_irqsave() .get(&self.socket_handle()) .unwrap() .shutdown_type(), ); } fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { if let Endpoint::Ip(_) = endpoint { self.remote_endpoint = Some(endpoint); Ok(()) } else { Err(SystemError::EINVAL) } } fn ioctl( &self, _cmd: usize, _arg0: usize, _arg1: usize, _arg2: usize, ) -> Result { todo!() } fn metadata(&self) -> SocketMetadata { self.metadata.clone() } fn box_clone(&self) -> Box { return Box::new(self.clone()); } fn endpoint(&self) -> Option { let sockets = SOCKET_SET.lock_irqsave(); let socket = sockets.get::(self.handle.smoltcp_handle().unwrap()); let listen_endpoint = socket.endpoint(); if listen_endpoint.port == 0 { return None; } else { // 如果listen_endpoint的address是None,意味着“监听所有的地址”。 // 这里假设所有的地址都是ipv4 // TODO: 支持ipv6 let result = wire::IpEndpoint::new( listen_endpoint .addr .unwrap_or(wire::IpAddress::v4(0, 0, 0, 0)), listen_endpoint.port, ); return Some(Endpoint::Ip(Some(result))); } } fn peer_endpoint(&self) -> Option { return self.remote_endpoint.clone(); } fn socket_handle(&self) -> GlobalSocketHandle { self.handle } fn as_any_ref(&self) -> &dyn core::any::Any { self } fn as_any_mut(&mut self) -> &mut dyn core::any::Any { self } } /// @brief 表示 tcp socket /// /// https://man7.org/linux/man-pages/man7/tcp.7.html #[derive(Debug, Clone)] pub struct TcpSocket { handles: Vec, local_endpoint: Option, // save local endpoint for bind() is_listening: bool, metadata: SocketMetadata, } impl TcpSocket { /// 元数据的缓冲区的大小 pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; /// 默认的接收缓冲区的大小 receive pub const DEFAULT_RX_BUF_SIZE: usize = 512 * 1024; /// 默认的发送缓冲区的大小 transmiss pub const DEFAULT_TX_BUF_SIZE: usize = 512 * 1024; /// TcpSocket的特殊事件,用于在事件等待队列上sleep pub const CAN_CONNECT: u64 = 1u64 << 63; pub const CAN_ACCPET: u64 = 1u64 << 62; /// @brief 创建一个tcp的socket /// /// @param options socket的选项 /// /// @return 返回创建的tcp的socket pub fn new(options: SocketOptions) -> Self { // 创建handles数组并把socket添加到socket集合中,并得到socket的句柄 let handles: Vec = vec![GlobalSocketHandle::new_smoltcp_handle( SOCKET_SET.lock_irqsave().add(Self::create_new_socket()), )]; let metadata = SocketMetadata::new( SocketType::Tcp, Self::DEFAULT_RX_BUF_SIZE, Self::DEFAULT_TX_BUF_SIZE, Self::DEFAULT_METADATA_BUF_SIZE, options, ); // debug!("when there's a new tcp socket,its'len: {}",handles.len()); return Self { handles, local_endpoint: None, is_listening: false, metadata, }; } fn do_listen( &mut self, socket: &mut tcp::Socket, local_endpoint: wire::IpEndpoint, ) -> Result<(), SystemError> { let listen_result = if local_endpoint.addr.is_unspecified() { // debug!("Tcp Socket Listen on port {}", local_endpoint.port); socket.listen(local_endpoint.port) } else { // debug!("Tcp Socket Listen on {local_endpoint}"); socket.listen(local_endpoint) }; return match listen_result { Ok(()) => { // debug!( // "Tcp Socket Listen on {local_endpoint}, open?:{}", // socket.is_open() // ); self.is_listening = true; Ok(()) } Err(_) => Err(SystemError::EINVAL), }; } /// # create_new_socket - 创建新的TCP套接字 /// /// 该函数用于创建一个新的TCP套接字,并返回该套接字的引用。 fn create_new_socket() -> tcp::Socket<'static> { // 初始化tcp的buffer let rx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_RX_BUF_SIZE]); let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]); tcp::Socket::new(rx_buffer, tx_buffer) } } impl Socket for TcpSocket { fn close(&mut self) { for handle in self.handles.iter() { { let mut socket_set_guard = SOCKET_SET.lock_irqsave(); let smoltcp_handle = handle.smoltcp_handle().unwrap(); socket_set_guard .get_mut::(smoltcp_handle) .close(); drop(socket_set_guard); } poll_ifaces(); SOCKET_SET .lock_irqsave() .remove(handle.smoltcp_handle().unwrap()); // debug!("[Socket] [TCP] Close: {:?}", handle); } } fn read(&self, buf: &mut [u8]) -> (Result, Endpoint) { if HANDLE_MAP .read_irqsave() .get(&self.socket_handle()) .unwrap() .shutdown_type() .contains(ShutdownType::RCV_SHUTDOWN) { return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); } // debug!("tcp socket: read, buf len={}", buf.len()); // debug!("tcp socket:read, socket'len={}",self.handle.len()); loop { poll_ifaces(); let mut socket_set_guard = SOCKET_SET.lock_irqsave(); let socket = socket_set_guard .get_mut::(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); // 如果socket已经关闭,返回错误 if !socket.is_active() { // debug!("Tcp Socket Read Error, socket is closed"); return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); } if socket.may_recv() { match socket.recv_slice(buf) { Ok(size) => { if size > 0 { let endpoint = if let Some(p) = socket.remote_endpoint() { p } else { return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); }; drop(socket_set_guard); poll_ifaces(); return (Ok(size), Endpoint::Ip(Some(endpoint))); } } Err(tcp::RecvError::InvalidState) => { warn!("Tcp Socket Read Error, InvalidState"); return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); } Err(tcp::RecvError::Finished) => { // 对端写端已关闭,我们应该关闭读端 HANDLE_MAP .write_irqsave() .get_mut(&self.socket_handle()) .unwrap() .shutdown_type_writer() .insert(ShutdownType::RCV_SHUTDOWN); return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); } } } else { return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); } drop(socket_set_guard); SocketHandleItem::sleep( self.socket_handle(), (EPollEventType::EPOLLIN.bits() | EPollEventType::EPOLLHUP.bits()) as u64, HANDLE_MAP.read_irqsave(), ); } } fn write(&self, buf: &[u8], _to: Option) -> Result { if HANDLE_MAP .read_irqsave() .get(&self.socket_handle()) .unwrap() .shutdown_type() .contains(ShutdownType::RCV_SHUTDOWN) { return Err(SystemError::ENOTCONN); } // debug!("tcp socket:write, socket'len={}",self.handle.len()); let mut socket_set_guard = SOCKET_SET.lock_irqsave(); let socket = socket_set_guard .get_mut::(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); if socket.is_open() { if socket.can_send() { match socket.send_slice(buf) { Ok(size) => { drop(socket_set_guard); poll_ifaces(); return Ok(size); } Err(e) => { error!("Tcp Socket Write Error {e:?}"); return Err(SystemError::ENOBUFS); } } } else { return Err(SystemError::ENOBUFS); } } return Err(SystemError::ENOTCONN); } fn poll(&self) -> EPollEventType { let mut socket_set_guard = SOCKET_SET.lock_irqsave(); // debug!("tcp socket:poll, socket'len={}",self.handle.len()); let socket = socket_set_guard .get_mut::(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); return SocketPollMethod::tcp_poll( socket, HANDLE_MAP .read_irqsave() .get(&self.socket_handle()) .unwrap() .shutdown_type(), ); } fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { let mut sockets = SOCKET_SET.lock_irqsave(); // debug!("tcp socket:connect, socket'len={}",self.handle.len()); let socket = sockets.get_mut::(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); if let Endpoint::Ip(Some(ip)) = endpoint { let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?; // 检测端口是否被占用 PORT_MANAGER.bind_port(self.metadata.socket_type, temp_port)?; // debug!("temp_port: {}", temp_port); let iface: Arc = NET_DEVICES.write_irqsave().get(&0).unwrap().clone(); let mut inner_iface = iface.inner_iface().lock(); // debug!("to connect: {ip:?}"); match socket.connect(inner_iface.context(), ip, temp_port) { Ok(()) => { // avoid deadlock drop(inner_iface); drop(iface); drop(sockets); loop { poll_ifaces(); let mut sockets = SOCKET_SET.lock_irqsave(); let socket = sockets.get_mut::( self.handles.get(0).unwrap().smoltcp_handle().unwrap(), ); match socket.state() { tcp::State::Established => { return Ok(()); } tcp::State::SynSent => { drop(sockets); SocketHandleItem::sleep( self.socket_handle(), Self::CAN_CONNECT, HANDLE_MAP.read_irqsave(), ); } _ => { return Err(SystemError::ECONNREFUSED); } } } } Err(e) => { // error!("Tcp Socket Connect Error {e:?}"); match e { tcp::ConnectError::InvalidState => return Err(SystemError::EISCONN), tcp::ConnectError::Unaddressable => return Err(SystemError::EADDRNOTAVAIL), } } } } else { return Err(SystemError::EINVAL); } } /// @brief tcp socket 监听 local_endpoint 端口 /// /// @param backlog 未处理的连接队列的最大长度 fn listen(&mut self, backlog: usize) -> Result<(), SystemError> { if self.is_listening { return Ok(()); } let local_endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?; let mut sockets = SOCKET_SET.lock_irqsave(); // 获取handle的数量 let handlen = self.handles.len(); let backlog = handlen.max(backlog); // 添加剩余需要构建的socket // debug!("tcp socket:before listen, socket'len={}", self.handle_list.len()); let mut handle_guard = HANDLE_MAP.write_irqsave(); let wait_queue = Arc::clone(&handle_guard.get(&self.socket_handle()).unwrap().wait_queue); self.handles.extend((handlen..backlog).map(|_| { let socket = Self::create_new_socket(); let handle = GlobalSocketHandle::new_smoltcp_handle(sockets.add(socket)); let handle_item = SocketHandleItem::new(Some(wait_queue.clone())); handle_guard.insert(handle, handle_item); handle })); // debug!("tcp socket:listen, socket'len={}",self.handle.len()); // debug!("tcp socket:listen, backlog={backlog}"); // 监听所有的socket for i in 0..backlog { let handle = self.handles.get(i).unwrap(); let socket = sockets.get_mut::(handle.smoltcp_handle().unwrap()); if !socket.is_listening() { // debug!("Tcp Socket is already listening on {local_endpoint}"); self.do_listen(socket, local_endpoint)?; } // debug!("Tcp Socket before listen, open={}", socket.is_open()); } return Ok(()); } fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { if let Endpoint::Ip(Some(mut ip)) = endpoint { if ip.port == 0 { ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?; } // 检测端口是否已被占用 PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port)?; // debug!("tcp socket:bind, socket'len={}",self.handle.len()); self.local_endpoint = Some(ip); self.is_listening = false; return Ok(()); } return Err(SystemError::EINVAL); } fn shutdown(&mut self, shutdown_type: super::ShutdownType) -> Result<(), SystemError> { // TODO:目前只是在表层判断,对端不知晓,后续需使用tcp实现 HANDLE_MAP .write_irqsave() .get_mut(&self.socket_handle()) .unwrap() .shutdown_type = RwLock::new(shutdown_type); return Ok(()); } fn accept(&mut self) -> Result<(Box, Endpoint), SystemError> { if !self.is_listening { return Err(SystemError::EINVAL); } let endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?; loop { // debug!("tcp accept: poll_ifaces()"); poll_ifaces(); // debug!("tcp socket:accept, socket'len={}", self.handle_list.len()); let mut sockset = SOCKET_SET.lock_irqsave(); // Get the corresponding activated handler let global_handle_index = self.handles.iter().position(|handle| { let con_smol_sock = sockset.get::(handle.smoltcp_handle().unwrap()); con_smol_sock.is_active() }); if let Some(handle_index) = global_handle_index { let con_smol_sock = sockset .get::(self.handles[handle_index].smoltcp_handle().unwrap()); // debug!("[Socket] [TCP] Accept: {:?}", handle); // handle is connected socket's handle let remote_ep = con_smol_sock .remote_endpoint() .ok_or(SystemError::ENOTCONN)?; let mut tcp_socket = Self::create_new_socket(); self.do_listen(&mut tcp_socket, endpoint)?; let new_handle = GlobalSocketHandle::new_smoltcp_handle(sockset.add(tcp_socket)); // let handle in TcpSock be the new empty handle, and return the old connected handle let old_handle = core::mem::replace(&mut self.handles[handle_index], new_handle); let metadata = SocketMetadata::new( SocketType::Tcp, Self::DEFAULT_TX_BUF_SIZE, Self::DEFAULT_RX_BUF_SIZE, Self::DEFAULT_METADATA_BUF_SIZE, self.metadata.options, ); let sock_ret = Box::new(TcpSocket { handles: vec![old_handle], local_endpoint: self.local_endpoint, is_listening: false, metadata, }); { let mut handle_guard = HANDLE_MAP.write_irqsave(); // 先删除原来的 let item = handle_guard.remove(&old_handle).unwrap(); // 按照smoltcp行为,将新的handle绑定到原来的item let new_item = SocketHandleItem::new(None); handle_guard.insert(old_handle, new_item); // 插入新的item handle_guard.insert(new_handle, item); drop(handle_guard); } return Ok((sock_ret, Endpoint::Ip(Some(remote_ep)))); } drop(sockset); // debug!("[TCP] [Accept] sleeping socket with handle: {:?}", self.handles.get(0).unwrap().smoltcp_handle().unwrap()); SocketHandleItem::sleep( self.socket_handle(), // NOTICE Self::CAN_ACCPET, HANDLE_MAP.read_irqsave(), ); // debug!("tcp socket:after sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len()); } } fn endpoint(&self) -> Option { let mut result: Option = self.local_endpoint.map(|x| Endpoint::Ip(Some(x))); if result.is_none() { let sockets = SOCKET_SET.lock_irqsave(); // debug!("tcp socket:endpoint, socket'len={}",self.handle.len()); let socket = sockets.get::(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); if let Some(ep) = socket.local_endpoint() { result = Some(Endpoint::Ip(Some(ep))); } } return result; } fn peer_endpoint(&self) -> Option { let sockets = SOCKET_SET.lock_irqsave(); // debug!("tcp socket:peer_endpoint, socket'len={}",self.handle.len()); let socket = sockets.get::(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); return socket.remote_endpoint().map(|x| Endpoint::Ip(Some(x))); } fn metadata(&self) -> SocketMetadata { self.metadata.clone() } fn box_clone(&self) -> Box { Box::new(self.clone()) } fn socket_handle(&self) -> GlobalSocketHandle { // debug!("tcp socket:socket_handle, socket'len={}",self.handle.len()); *self.handles.get(0).unwrap() } fn as_any_ref(&self) -> &dyn core::any::Any { self } fn as_any_mut(&mut self) -> &mut dyn core::any::Any { self } }