xref: /DragonOS/kernel/src/net/socket/inet.rs (revision 634349e0ebfca487e6aa2761a796f04895908718)
1 use alloc::{boxed::Box, sync::Arc, vec::Vec};
2 use log::{error, warn};
3 use smoltcp::{
4     socket::{raw, tcp, udp},
5     wire,
6 };
7 use system_error::SystemError;
8 
9 use crate::{
10     driver::net::NetDevice,
11     libs::rwlock::RwLock,
12     net::{
13         event_poll::EPollEventType, net_core::poll_ifaces, Endpoint, Protocol, ShutdownType,
14         NET_DEVICES,
15     },
16 };
17 
18 use super::{
19     handle::GlobalSocketHandle, PosixSocketHandleItem, Socket, SocketHandleItem, SocketMetadata,
20     SocketOptions, SocketPollMethod, SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET,
21 };
22 
23 /// @brief 表示原始的socket。原始套接字绕过传输层协议(如 TCP 或 UDP)并提供对网络层协议(如 IP)的直接访问。
24 ///
25 /// ref: https://man7.org/linux/man-pages/man7/raw.7.html
26 #[derive(Debug, Clone)]
27 pub struct RawSocket {
28     handle: GlobalSocketHandle,
29     /// 用户发送的数据包是否包含了IP头.
30     /// 如果是true,用户发送的数据包,必须包含IP头。(即用户要自行设置IP头+数据)
31     /// 如果是false,用户发送的数据包,不包含IP头。(即用户只要设置数据)
32     header_included: bool,
33     /// socket的metadata
34     metadata: SocketMetadata,
35     posix_item: Arc<PosixSocketHandleItem>,
36 }
37 
38 impl RawSocket {
39     /// 元数据的缓冲区的大小
40     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
41     /// 默认的接收缓冲区的大小 receive
42     pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024;
43     /// 默认的发送缓冲区的大小 transmiss
44     pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024;
45 
46     /// @brief 创建一个原始的socket
47     ///
48     /// @param protocol 协议号
49     /// @param options socket的选项
50     ///
51     /// @return 返回创建的原始的socket
52     pub fn new(protocol: Protocol, options: SocketOptions) -> Self {
53         let rx_buffer = raw::PacketBuffer::new(
54             vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
55             vec![0; Self::DEFAULT_RX_BUF_SIZE],
56         );
57         let tx_buffer = raw::PacketBuffer::new(
58             vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
59             vec![0; Self::DEFAULT_TX_BUF_SIZE],
60         );
61         let protocol: u8 = protocol.into();
62         let socket = raw::Socket::new(
63             wire::IpVersion::Ipv4,
64             wire::IpProtocol::from(protocol),
65             rx_buffer,
66             tx_buffer,
67         );
68 
69         // 把socket添加到socket集合中,并得到socket的句柄
70         let handle = GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket));
71 
72         let metadata = SocketMetadata::new(
73             SocketType::Raw,
74             Self::DEFAULT_RX_BUF_SIZE,
75             Self::DEFAULT_TX_BUF_SIZE,
76             Self::DEFAULT_METADATA_BUF_SIZE,
77             options,
78         );
79 
80         let posix_item = Arc::new(PosixSocketHandleItem::new(None));
81 
82         return Self {
83             handle,
84             header_included: false,
85             metadata,
86             posix_item,
87         };
88     }
89 }
90 
91 impl Socket for RawSocket {
92     fn posix_item(&self) -> Arc<PosixSocketHandleItem> {
93         self.posix_item.clone()
94     }
95 
96     fn close(&mut self) {
97         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
98         if let smoltcp::socket::Socket::Udp(mut sock) =
99             socket_set_guard.remove(self.handle.smoltcp_handle().unwrap())
100         {
101             sock.close();
102         }
103         drop(socket_set_guard);
104         poll_ifaces();
105     }
106 
107     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
108         poll_ifaces();
109         loop {
110             // 如何优化这里?
111             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
112             let socket =
113                 socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap());
114 
115             match socket.recv_slice(buf) {
116                 Ok(len) => {
117                     let packet = wire::Ipv4Packet::new_unchecked(buf);
118                     return (
119                         Ok(len),
120                         Endpoint::Ip(Some(wire::IpEndpoint {
121                             addr: wire::IpAddress::Ipv4(packet.src_addr()),
122                             port: 0,
123                         })),
124                     );
125                 }
126                 Err(_) => {
127                     if !self.metadata.options.contains(SocketOptions::BLOCK) {
128                         // 如果是非阻塞的socket,就返回错误
129                         return (Err(SystemError::EAGAIN_OR_EWOULDBLOCK), Endpoint::Ip(None));
130                     }
131                 }
132             }
133             drop(socket_set_guard);
134             self.posix_item.sleep(EPollEventType::EPOLLIN.bits() as u64);
135         }
136     }
137 
138     fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError> {
139         // 如果用户发送的数据包,包含IP头,则直接发送
140         if self.header_included {
141             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
142             let socket =
143                 socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap());
144             match socket.send_slice(buf) {
145                 Ok(_) => {
146                     return Ok(buf.len());
147                 }
148                 Err(raw::SendError::BufferFull) => {
149                     return Err(SystemError::ENOBUFS);
150                 }
151             }
152         } else {
153             // 如果用户发送的数据包,不包含IP头,则需要自己构造IP头
154 
155             if let Some(Endpoint::Ip(Some(endpoint))) = to {
156                 let mut socket_set_guard = SOCKET_SET.lock_irqsave();
157                 let socket: &mut raw::Socket =
158                     socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap());
159 
160                 // 暴力解决方案:只考虑0号网卡。 TODO:考虑多网卡的情况!!!
161                 let iface = NET_DEVICES.read_irqsave().get(&0).unwrap().clone();
162 
163                 // 构造IP头
164                 let ipv4_src_addr: Option<wire::Ipv4Address> =
165                     iface.inner_iface().lock().ipv4_addr();
166                 if ipv4_src_addr.is_none() {
167                     return Err(SystemError::ENETUNREACH);
168                 }
169                 let ipv4_src_addr = ipv4_src_addr.unwrap();
170 
171                 if let wire::IpAddress::Ipv4(ipv4_dst) = endpoint.addr {
172                     let len = buf.len();
173 
174                     // 创建20字节的IPv4头部
175                     let mut buffer: Vec<u8> = vec![0u8; len + 20];
176                     let mut packet: wire::Ipv4Packet<&mut Vec<u8>> =
177                         wire::Ipv4Packet::new_unchecked(&mut buffer);
178 
179                     // 封装ipv4 header
180                     packet.set_version(4);
181                     packet.set_header_len(20);
182                     packet.set_total_len((20 + len) as u16);
183                     packet.set_src_addr(ipv4_src_addr);
184                     packet.set_dst_addr(ipv4_dst);
185 
186                     // 设置ipv4 header的protocol字段
187                     packet.set_next_header(socket.ip_protocol());
188 
189                     // 获取IP数据包的负载字段
190                     let payload: &mut [u8] = packet.payload_mut();
191                     payload.copy_from_slice(buf);
192 
193                     // 填充checksum字段
194                     packet.fill_checksum();
195 
196                     // 发送数据包
197                     socket.send_slice(&buffer).unwrap();
198 
199                     iface.poll(&mut socket_set_guard).ok();
200 
201                     drop(socket_set_guard);
202                     return Ok(len);
203                 } else {
204                     warn!("Unsupport Ip protocol type!");
205                     return Err(SystemError::EINVAL);
206                 }
207             } else {
208                 // 如果没有指定目的地址,则返回错误
209                 return Err(SystemError::ENOTCONN);
210             }
211         }
212     }
213 
214     fn connect(&mut self, _endpoint: Endpoint) -> Result<(), SystemError> {
215         Ok(())
216     }
217 
218     fn metadata(&self) -> SocketMetadata {
219         self.metadata.clone()
220     }
221 
222     fn box_clone(&self) -> Box<dyn Socket> {
223         Box::new(self.clone())
224     }
225 
226     fn socket_handle(&self) -> GlobalSocketHandle {
227         self.handle
228     }
229 
230     fn as_any_ref(&self) -> &dyn core::any::Any {
231         self
232     }
233 
234     fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
235         self
236     }
237 }
238 
239 /// @brief 表示udp socket
240 ///
241 /// https://man7.org/linux/man-pages/man7/udp.7.html
242 #[derive(Debug, Clone)]
243 pub struct UdpSocket {
244     pub handle: GlobalSocketHandle,
245     remote_endpoint: Option<Endpoint>, // 记录远程endpoint提供给connect(), 应该使用IP地址。
246     metadata: SocketMetadata,
247     posix_item: Arc<PosixSocketHandleItem>,
248 }
249 
250 impl UdpSocket {
251     /// 元数据的缓冲区的大小
252     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
253     /// 默认的接收缓冲区的大小 receive
254     pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024;
255     /// 默认的发送缓冲区的大小 transmiss
256     pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024;
257 
258     /// @brief 创建一个udp的socket
259     ///
260     /// @param options socket的选项
261     ///
262     /// @return 返回创建的udp的socket
263     pub fn new(options: SocketOptions) -> Self {
264         let rx_buffer = udp::PacketBuffer::new(
265             vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
266             vec![0; Self::DEFAULT_RX_BUF_SIZE],
267         );
268         let tx_buffer = udp::PacketBuffer::new(
269             vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
270             vec![0; Self::DEFAULT_TX_BUF_SIZE],
271         );
272         let socket = udp::Socket::new(rx_buffer, tx_buffer);
273 
274         // 把socket添加到socket集合中,并得到socket的句柄
275         let handle: GlobalSocketHandle =
276             GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket));
277 
278         let metadata = SocketMetadata::new(
279             SocketType::Udp,
280             Self::DEFAULT_RX_BUF_SIZE,
281             Self::DEFAULT_TX_BUF_SIZE,
282             Self::DEFAULT_METADATA_BUF_SIZE,
283             options,
284         );
285 
286         let posix_item = Arc::new(PosixSocketHandleItem::new(None));
287 
288         return Self {
289             handle,
290             remote_endpoint: None,
291             metadata,
292             posix_item,
293         };
294     }
295 
296     fn do_bind(&self, socket: &mut udp::Socket, endpoint: Endpoint) -> Result<(), SystemError> {
297         if let Endpoint::Ip(Some(mut ip)) = endpoint {
298             // 端口为0则分配随机端口
299             if ip.port == 0 {
300                 ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
301             }
302             // 检测端口是否已被占用
303             PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port)?;
304 
305             let bind_res = if ip.addr.is_unspecified() {
306                 socket.bind(ip.port)
307             } else {
308                 socket.bind(ip)
309             };
310 
311             match bind_res {
312                 Ok(()) => return Ok(()),
313                 Err(_) => return Err(SystemError::EINVAL),
314             }
315         } else {
316             return Err(SystemError::EINVAL);
317         }
318     }
319 }
320 
321 impl Socket for UdpSocket {
322     fn posix_item(&self) -> Arc<PosixSocketHandleItem> {
323         self.posix_item.clone()
324     }
325 
326     fn close(&mut self) {
327         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
328         if let smoltcp::socket::Socket::Udp(mut sock) =
329             socket_set_guard.remove(self.handle.smoltcp_handle().unwrap())
330         {
331             sock.close();
332         }
333         drop(socket_set_guard);
334         poll_ifaces();
335     }
336 
337     /// @brief 在read函数执行之前,请先bind到本地的指定端口
338     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
339         loop {
340             // debug!("Wait22 to Read");
341             poll_ifaces();
342             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
343             let socket =
344                 socket_set_guard.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
345 
346             // debug!("Wait to Read");
347 
348             if socket.can_recv() {
349                 if let Ok((size, metadata)) = socket.recv_slice(buf) {
350                     drop(socket_set_guard);
351                     poll_ifaces();
352                     return (Ok(size), Endpoint::Ip(Some(metadata.endpoint)));
353                 }
354             } else {
355                 // 如果socket没有连接,则忙等
356                 // return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
357             }
358             drop(socket_set_guard);
359             self.posix_item.sleep(EPollEventType::EPOLLIN.bits() as u64);
360         }
361     }
362 
363     fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError> {
364         // debug!("udp to send: {:?}, len={}", to, buf.len());
365         let remote_endpoint: &wire::IpEndpoint = {
366             if let Some(Endpoint::Ip(Some(ref endpoint))) = to {
367                 endpoint
368             } else if let Some(Endpoint::Ip(Some(ref endpoint))) = self.remote_endpoint {
369                 endpoint
370             } else {
371                 return Err(SystemError::ENOTCONN);
372             }
373         };
374         // debug!("udp write: remote = {:?}", remote_endpoint);
375 
376         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
377         let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
378         // debug!("is open()={}", socket.is_open());
379         // debug!("socket endpoint={:?}", socket.endpoint());
380         if socket.can_send() {
381             // debug!("udp write: can send");
382             match socket.send_slice(buf, *remote_endpoint) {
383                 Ok(()) => {
384                     // debug!("udp write: send ok");
385                     drop(socket_set_guard);
386                     poll_ifaces();
387                     return Ok(buf.len());
388                 }
389                 Err(_) => {
390                     // debug!("udp write: send err");
391                     return Err(SystemError::ENOBUFS);
392                 }
393             }
394         } else {
395             // debug!("udp write: can not send");
396             return Err(SystemError::ENOBUFS);
397         };
398     }
399 
400     fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
401         let mut sockets = SOCKET_SET.lock_irqsave();
402         let socket = sockets.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
403         // debug!("UDP Bind to {:?}", endpoint);
404         return self.do_bind(socket, endpoint);
405     }
406 
407     fn poll(&self) -> EPollEventType {
408         let sockets = SOCKET_SET.lock_irqsave();
409         let socket = sockets.get::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
410 
411         return SocketPollMethod::udp_poll(
412             socket,
413             HANDLE_MAP
414                 .read_irqsave()
415                 .get(&self.socket_handle())
416                 .unwrap()
417                 .shutdown_type(),
418         );
419     }
420 
421     fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
422         if let Endpoint::Ip(_) = endpoint {
423             self.remote_endpoint = Some(endpoint);
424             Ok(())
425         } else {
426             Err(SystemError::EINVAL)
427         }
428     }
429 
430     fn ioctl(
431         &self,
432         _cmd: usize,
433         _arg0: usize,
434         _arg1: usize,
435         _arg2: usize,
436     ) -> Result<usize, SystemError> {
437         todo!()
438     }
439 
440     fn metadata(&self) -> SocketMetadata {
441         self.metadata.clone()
442     }
443 
444     fn box_clone(&self) -> Box<dyn Socket> {
445         return Box::new(self.clone());
446     }
447 
448     fn endpoint(&self) -> Option<Endpoint> {
449         let sockets = SOCKET_SET.lock_irqsave();
450         let socket = sockets.get::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
451         let listen_endpoint = socket.endpoint();
452 
453         if listen_endpoint.port == 0 {
454             return None;
455         } else {
456             // 如果listen_endpoint的address是None,意味着“监听所有的地址”。
457             // 这里假设所有的地址都是ipv4
458             // TODO: 支持ipv6
459             let result = wire::IpEndpoint::new(
460                 listen_endpoint
461                     .addr
462                     .unwrap_or(wire::IpAddress::v4(0, 0, 0, 0)),
463                 listen_endpoint.port,
464             );
465             return Some(Endpoint::Ip(Some(result)));
466         }
467     }
468 
469     fn peer_endpoint(&self) -> Option<Endpoint> {
470         return self.remote_endpoint.clone();
471     }
472 
473     fn socket_handle(&self) -> GlobalSocketHandle {
474         self.handle
475     }
476 
477     fn as_any_ref(&self) -> &dyn core::any::Any {
478         self
479     }
480 
481     fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
482         self
483     }
484 }
485 
486 /// @brief 表示 tcp socket
487 ///
488 /// https://man7.org/linux/man-pages/man7/tcp.7.html
489 #[derive(Debug, Clone)]
490 pub struct TcpSocket {
491     handles: Vec<GlobalSocketHandle>,
492     local_endpoint: Option<wire::IpEndpoint>, // save local endpoint for bind()
493     is_listening: bool,
494     metadata: SocketMetadata,
495     posix_item: Arc<PosixSocketHandleItem>,
496 }
497 
498 impl TcpSocket {
499     /// 元数据的缓冲区的大小
500     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
501     /// 默认的接收缓冲区的大小 receive
502     pub const DEFAULT_RX_BUF_SIZE: usize = 512 * 1024;
503     /// 默认的发送缓冲区的大小 transmiss
504     pub const DEFAULT_TX_BUF_SIZE: usize = 512 * 1024;
505 
506     /// TcpSocket的特殊事件,用于在事件等待队列上sleep
507     pub const CAN_CONNECT: u64 = 1u64 << 63;
508     pub const CAN_ACCPET: u64 = 1u64 << 62;
509 
510     /// @brief 创建一个tcp的socket
511     ///
512     /// @param options socket的选项
513     ///
514     /// @return 返回创建的tcp的socket
515     pub fn new(options: SocketOptions) -> Self {
516         // 创建handles数组并把socket添加到socket集合中,并得到socket的句柄
517         let handles: Vec<GlobalSocketHandle> = vec![GlobalSocketHandle::new_smoltcp_handle(
518             SOCKET_SET.lock_irqsave().add(Self::create_new_socket()),
519         )];
520 
521         let metadata = SocketMetadata::new(
522             SocketType::Tcp,
523             Self::DEFAULT_RX_BUF_SIZE,
524             Self::DEFAULT_TX_BUF_SIZE,
525             Self::DEFAULT_METADATA_BUF_SIZE,
526             options,
527         );
528         let posix_item = Arc::new(PosixSocketHandleItem::new(None));
529         // debug!("when there's a new tcp socket,its'len: {}",handles.len());
530 
531         return Self {
532             handles,
533             local_endpoint: None,
534             is_listening: false,
535             metadata,
536             posix_item,
537         };
538     }
539 
540     fn do_listen(
541         &mut self,
542         socket: &mut tcp::Socket,
543         local_endpoint: wire::IpEndpoint,
544     ) -> Result<(), SystemError> {
545         let listen_result = if local_endpoint.addr.is_unspecified() {
546             socket.listen(local_endpoint.port)
547         } else {
548             socket.listen(local_endpoint)
549         };
550         return match listen_result {
551             Ok(()) => {
552                 // debug!(
553                 //     "Tcp Socket Listen on {local_endpoint}, open?:{}",
554                 //     socket.is_open()
555                 // );
556                 self.is_listening = true;
557 
558                 Ok(())
559             }
560             Err(_) => Err(SystemError::EINVAL),
561         };
562     }
563 
564     /// # create_new_socket - 创建新的TCP套接字
565     ///
566     /// 该函数用于创建一个新的TCP套接字,并返回该套接字的引用。
567     fn create_new_socket() -> tcp::Socket<'static> {
568         // 初始化tcp的buffer
569         let rx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_RX_BUF_SIZE]);
570         let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]);
571         tcp::Socket::new(rx_buffer, tx_buffer)
572     }
573 
574     /// listening状态的posix socket是需要特殊处理的
575     fn tcp_poll_listening(&self) -> EPollEventType {
576         let socketset_guard = SOCKET_SET.lock_irqsave();
577 
578         let can_accept = self.handles.iter().any(|h| {
579             if let Some(sh) = h.smoltcp_handle() {
580                 let socket = socketset_guard.get::<tcp::Socket>(sh);
581                 socket.is_active()
582             } else {
583                 false
584             }
585         });
586 
587         if can_accept {
588             return EPollEventType::EPOLL_LISTEN_CAN_ACCEPT;
589         } else {
590             return EPollEventType::empty();
591         }
592     }
593 }
594 
595 impl Socket for TcpSocket {
596     fn posix_item(&self) -> Arc<PosixSocketHandleItem> {
597         self.posix_item.clone()
598     }
599 
600     fn close(&mut self) {
601         for handle in self.handles.iter() {
602             {
603                 let mut socket_set_guard = SOCKET_SET.lock_irqsave();
604                 let smoltcp_handle = handle.smoltcp_handle().unwrap();
605                 socket_set_guard
606                     .get_mut::<smoltcp::socket::tcp::Socket>(smoltcp_handle)
607                     .close();
608                 drop(socket_set_guard);
609             }
610             poll_ifaces();
611             SOCKET_SET
612                 .lock_irqsave()
613                 .remove(handle.smoltcp_handle().unwrap());
614             // debug!("[Socket] [TCP] Close: {:?}", handle);
615         }
616     }
617 
618     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
619         if HANDLE_MAP
620             .read_irqsave()
621             .get(&self.socket_handle())
622             .unwrap()
623             .shutdown_type()
624             .contains(ShutdownType::RCV_SHUTDOWN)
625         {
626             return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
627         }
628         // debug!("tcp socket: read, buf len={}", buf.len());
629         // debug!("tcp socket:read, socket'len={}",self.handle.len());
630         loop {
631             poll_ifaces();
632             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
633 
634             let socket = socket_set_guard
635                 .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
636 
637             // 如果socket已经关闭,返回错误
638             if !socket.is_active() {
639                 // debug!("Tcp Socket Read Error, socket is closed");
640                 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
641             }
642 
643             if socket.may_recv() {
644                 match socket.recv_slice(buf) {
645                     Ok(size) => {
646                         if size > 0 {
647                             let endpoint = if let Some(p) = socket.remote_endpoint() {
648                                 p
649                             } else {
650                                 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
651                             };
652 
653                             drop(socket_set_guard);
654                             poll_ifaces();
655                             return (Ok(size), Endpoint::Ip(Some(endpoint)));
656                         }
657                     }
658                     Err(tcp::RecvError::InvalidState) => {
659                         warn!("Tcp Socket Read Error, InvalidState");
660                         return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
661                     }
662                     Err(tcp::RecvError::Finished) => {
663                         // 对端写端已关闭,我们应该关闭读端
664                         HANDLE_MAP
665                             .write_irqsave()
666                             .get_mut(&self.socket_handle())
667                             .unwrap()
668                             .shutdown_type_writer()
669                             .insert(ShutdownType::RCV_SHUTDOWN);
670                         return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
671                     }
672                 }
673             } else {
674                 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
675             }
676             drop(socket_set_guard);
677             self.posix_item
678                 .sleep((EPollEventType::EPOLLIN | EPollEventType::EPOLLHUP).bits() as u64);
679         }
680     }
681 
682     fn write(&self, buf: &[u8], _to: Option<Endpoint>) -> Result<usize, SystemError> {
683         if HANDLE_MAP
684             .read_irqsave()
685             .get(&self.socket_handle())
686             .unwrap()
687             .shutdown_type()
688             .contains(ShutdownType::RCV_SHUTDOWN)
689         {
690             return Err(SystemError::ENOTCONN);
691         }
692         // debug!("tcp socket:write, socket'len={}",self.handle.len());
693 
694         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
695 
696         let socket = socket_set_guard
697             .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
698 
699         if socket.is_open() {
700             if socket.can_send() {
701                 match socket.send_slice(buf) {
702                     Ok(size) => {
703                         drop(socket_set_guard);
704                         poll_ifaces();
705                         return Ok(size);
706                     }
707                     Err(e) => {
708                         error!("Tcp Socket Write Error {e:?}");
709                         return Err(SystemError::ENOBUFS);
710                     }
711                 }
712             } else {
713                 return Err(SystemError::ENOBUFS);
714             }
715         }
716 
717         return Err(SystemError::ENOTCONN);
718     }
719 
720     fn poll(&self) -> EPollEventType {
721         // 处理listen的快速路径
722         if self.is_listening {
723             return self.tcp_poll_listening();
724         }
725         // 由于上面处理了listening状态,所以这里只处理非listening状态,这种情况下只有一个handle
726 
727         assert!(self.handles.len() == 1);
728 
729         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
730         // debug!("tcp socket:poll, socket'len={}",self.handle.len());
731 
732         let socket = socket_set_guard
733             .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
734         let handle_map_guard = HANDLE_MAP.read_irqsave();
735         let handle_item = handle_map_guard.get(&self.socket_handle()).unwrap();
736         let shutdown_type = handle_item.shutdown_type();
737         let is_posix_listen = handle_item.is_posix_listen;
738         drop(handle_map_guard);
739 
740         return SocketPollMethod::tcp_poll(socket, shutdown_type, is_posix_listen);
741     }
742 
743     fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
744         let mut sockets = SOCKET_SET.lock_irqsave();
745         // debug!("tcp socket:connect, socket'len={}", self.handles.len());
746 
747         let socket =
748             sockets.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
749 
750         if let Endpoint::Ip(Some(ip)) = endpoint {
751             let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
752             // 检测端口是否被占用
753             PORT_MANAGER.bind_port(self.metadata.socket_type, temp_port)?;
754 
755             // debug!("temp_port: {}", temp_port);
756             let iface: Arc<dyn NetDevice> = NET_DEVICES.write_irqsave().get(&0).unwrap().clone();
757             let mut inner_iface = iface.inner_iface().lock();
758             // debug!("to connect: {ip:?}");
759 
760             match socket.connect(inner_iface.context(), ip, temp_port) {
761                 Ok(()) => {
762                     // avoid deadlock
763                     drop(inner_iface);
764                     drop(iface);
765                     drop(sockets);
766                     loop {
767                         poll_ifaces();
768                         let mut sockets = SOCKET_SET.lock_irqsave();
769                         let socket = sockets.get_mut::<tcp::Socket>(
770                             self.handles.get(0).unwrap().smoltcp_handle().unwrap(),
771                         );
772 
773                         match socket.state() {
774                             tcp::State::Established => {
775                                 return Ok(());
776                             }
777                             tcp::State::SynSent => {
778                                 drop(sockets);
779                                 self.posix_item.sleep(Self::CAN_CONNECT);
780                             }
781                             _ => {
782                                 return Err(SystemError::ECONNREFUSED);
783                             }
784                         }
785                     }
786                 }
787                 Err(e) => {
788                     // error!("Tcp Socket Connect Error {e:?}");
789                     match e {
790                         tcp::ConnectError::InvalidState => return Err(SystemError::EISCONN),
791                         tcp::ConnectError::Unaddressable => return Err(SystemError::EADDRNOTAVAIL),
792                     }
793                 }
794             }
795         } else {
796             return Err(SystemError::EINVAL);
797         }
798     }
799 
800     /// @brief tcp socket 监听 local_endpoint 端口
801     ///
802     /// @param backlog 未处理的连接队列的最大长度
803     fn listen(&mut self, backlog: usize) -> Result<(), SystemError> {
804         if self.is_listening {
805             return Ok(());
806         }
807 
808         // debug!(
809         //     "tcp socket:listen, socket'len={}, backlog = {backlog}",
810         //     self.handles.len()
811         // );
812 
813         let local_endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?;
814         let mut sockets = SOCKET_SET.lock_irqsave();
815         // 获取handle的数量
816         let handlen = self.handles.len();
817         let backlog = handlen.max(backlog);
818 
819         // 添加剩余需要构建的socket
820         // debug!("tcp socket:before listen, socket'len={}", self.handle_list.len());
821         let mut handle_guard = HANDLE_MAP.write_irqsave();
822         let socket_handle_item_0 = handle_guard.get_mut(&self.socket_handle()).unwrap();
823         socket_handle_item_0.is_posix_listen = true;
824 
825         self.handles.extend((handlen..backlog).map(|_| {
826             let socket = Self::create_new_socket();
827             let handle = GlobalSocketHandle::new_smoltcp_handle(sockets.add(socket));
828             let mut handle_item = SocketHandleItem::new(Arc::downgrade(&self.posix_item));
829             handle_item.is_posix_listen = true;
830             handle_guard.insert(handle, handle_item);
831             handle
832         }));
833 
834         // debug!("tcp socket:listen, socket'len={}", self.handles.len());
835         // debug!("tcp socket:listen, backlog={backlog}");
836 
837         // 监听所有的socket
838         for i in 0..backlog {
839             let handle = self.handles.get(i).unwrap();
840 
841             let socket = sockets.get_mut::<tcp::Socket>(handle.smoltcp_handle().unwrap());
842 
843             if !socket.is_listening() {
844                 // debug!("Tcp Socket is already listening on {local_endpoint}");
845                 self.do_listen(socket, local_endpoint)?;
846             }
847             // debug!("Tcp Socket  before listen, open={}", socket.is_open());
848         }
849 
850         return Ok(());
851     }
852 
853     fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
854         if let Endpoint::Ip(Some(mut ip)) = endpoint {
855             if ip.port == 0 {
856                 ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
857             }
858 
859             // 检测端口是否已被占用
860             PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port)?;
861             // debug!("tcp socket:bind, socket'len={}",self.handle.len());
862 
863             self.local_endpoint = Some(ip);
864             self.is_listening = false;
865 
866             return Ok(());
867         }
868         return Err(SystemError::EINVAL);
869     }
870 
871     fn shutdown(&mut self, shutdown_type: super::ShutdownType) -> Result<(), SystemError> {
872         // TODO:目前只是在表层判断,对端不知晓,后续需使用tcp实现
873         HANDLE_MAP
874             .write_irqsave()
875             .get_mut(&self.socket_handle())
876             .unwrap()
877             .shutdown_type = RwLock::new(shutdown_type);
878         return Ok(());
879     }
880 
881     fn accept(&mut self) -> Result<(Box<dyn Socket>, Endpoint), SystemError> {
882         if !self.is_listening {
883             return Err(SystemError::EINVAL);
884         }
885         let endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?;
886         loop {
887             // debug!("tcp accept: poll_ifaces()");
888             poll_ifaces();
889             // debug!("tcp socket:accept, socket'len={}", self.handle_list.len());
890 
891             let mut sockset = SOCKET_SET.lock_irqsave();
892             // Get the corresponding activated handler
893             let global_handle_index = self.handles.iter().position(|handle| {
894                 let con_smol_sock = sockset.get::<tcp::Socket>(handle.smoltcp_handle().unwrap());
895                 con_smol_sock.is_active()
896             });
897 
898             if let Some(handle_index) = global_handle_index {
899                 let con_smol_sock = sockset
900                     .get::<tcp::Socket>(self.handles[handle_index].smoltcp_handle().unwrap());
901 
902                 // debug!("[Socket] [TCP] Accept: {:?}", handle);
903                 // handle is connected socket's handle
904                 let remote_ep = con_smol_sock
905                     .remote_endpoint()
906                     .ok_or(SystemError::ENOTCONN)?;
907 
908                 let tcp_socket = Self::create_new_socket();
909 
910                 let new_handle = GlobalSocketHandle::new_smoltcp_handle(sockset.add(tcp_socket));
911 
912                 // let handle in TcpSock be the new empty handle, and return the old connected handle
913                 let old_handle = core::mem::replace(&mut self.handles[handle_index], new_handle);
914 
915                 let metadata = SocketMetadata::new(
916                     SocketType::Tcp,
917                     Self::DEFAULT_TX_BUF_SIZE,
918                     Self::DEFAULT_RX_BUF_SIZE,
919                     Self::DEFAULT_METADATA_BUF_SIZE,
920                     self.metadata.options,
921                 );
922 
923                 let sock_ret = Box::new(TcpSocket {
924                     handles: vec![old_handle],
925                     local_endpoint: self.local_endpoint,
926                     is_listening: false,
927                     metadata,
928                     posix_item: Arc::new(PosixSocketHandleItem::new(None)),
929                 });
930 
931                 {
932                     let mut handle_guard = HANDLE_MAP.write_irqsave();
933                     // 先删除原来的
934                     let item = handle_guard.remove(&old_handle).unwrap();
935                     item.reset_shutdown_type();
936                     assert!(item.is_posix_listen);
937 
938                     // 按照smoltcp行为,将新的handle绑定到原来的item
939                     let new_item = SocketHandleItem::new(Arc::downgrade(&sock_ret.posix_item));
940                     handle_guard.insert(old_handle, new_item);
941                     // 插入新的item
942                     handle_guard.insert(new_handle, item);
943 
944                     let socket = sockset.get_mut::<tcp::Socket>(
945                         self.handles[handle_index].smoltcp_handle().unwrap(),
946                     );
947 
948                     if !socket.is_listening() {
949                         self.do_listen(socket, endpoint)?;
950                     }
951 
952                     drop(handle_guard);
953                 }
954 
955                 return Ok((sock_ret, Endpoint::Ip(Some(remote_ep))));
956             }
957 
958             drop(sockset);
959 
960             // debug!("[TCP] [Accept] sleeping socket with handle: {:?}", self.handles.get(0).unwrap().smoltcp_handle().unwrap());
961             self.posix_item.sleep(Self::CAN_ACCPET);
962             // debug!("tcp socket:after sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len());
963         }
964     }
965 
966     fn endpoint(&self) -> Option<Endpoint> {
967         let mut result: Option<Endpoint> = self.local_endpoint.map(|x| Endpoint::Ip(Some(x)));
968 
969         if result.is_none() {
970             let sockets = SOCKET_SET.lock_irqsave();
971             // debug!("tcp socket:endpoint, socket'len={}",self.handle.len());
972 
973             let socket =
974                 sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
975             if let Some(ep) = socket.local_endpoint() {
976                 result = Some(Endpoint::Ip(Some(ep)));
977             }
978         }
979         return result;
980     }
981 
982     fn peer_endpoint(&self) -> Option<Endpoint> {
983         let sockets = SOCKET_SET.lock_irqsave();
984         // debug!("tcp socket:peer_endpoint, socket'len={}",self.handle.len());
985 
986         let socket =
987             sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
988         return socket.remote_endpoint().map(|x| Endpoint::Ip(Some(x)));
989     }
990 
991     fn metadata(&self) -> SocketMetadata {
992         self.metadata.clone()
993     }
994 
995     fn box_clone(&self) -> Box<dyn Socket> {
996         Box::new(self.clone())
997     }
998 
999     fn socket_handle(&self) -> GlobalSocketHandle {
1000         // debug!("tcp socket:socket_handle, socket'len={}",self.handle.len());
1001 
1002         *self.handles.get(0).unwrap()
1003     }
1004 
1005     fn as_any_ref(&self) -> &dyn core::any::Any {
1006         self
1007     }
1008 
1009     fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
1010         self
1011     }
1012 }
1013