xref: /DragonOS/kernel/src/net/socket/inet.rs (revision 1ea2daad8121b77ed704e6d7c3a09f478147441d)
1 use alloc::{boxed::Box, sync::Arc, vec::Vec};
2 use log::{error, warn};
3 use smoltcp::{
4     socket::{raw, tcp, udp},
5     wire,
6 };
7 use system_error::SystemError;
8 
9 use crate::{
10     driver::net::NetDevice,
11     libs::rwlock::RwLock,
12     net::{
13         event_poll::EPollEventType, net_core::poll_ifaces, Endpoint, Protocol, ShutdownType,
14         NET_DEVICES,
15     },
16 };
17 
18 use super::{
19     handle::GlobalSocketHandle, Socket, SocketHandleItem, SocketMetadata, SocketOptions,
20     SocketPollMethod, SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET,
21 };
22 
23 /// @brief 表示原始的socket。原始套接字绕过传输层协议(如 TCP 或 UDP)并提供对网络层协议(如 IP)的直接访问。
24 ///
25 /// ref: https://man7.org/linux/man-pages/man7/raw.7.html
26 #[derive(Debug, Clone)]
27 pub struct RawSocket {
28     handle: GlobalSocketHandle,
29     /// 用户发送的数据包是否包含了IP头.
30     /// 如果是true,用户发送的数据包,必须包含IP头。(即用户要自行设置IP头+数据)
31     /// 如果是false,用户发送的数据包,不包含IP头。(即用户只要设置数据)
32     header_included: bool,
33     /// socket的metadata
34     metadata: SocketMetadata,
35 }
36 
37 impl RawSocket {
38     /// 元数据的缓冲区的大小
39     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
40     /// 默认的接收缓冲区的大小 receive
41     pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024;
42     /// 默认的发送缓冲区的大小 transmiss
43     pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024;
44 
45     /// @brief 创建一个原始的socket
46     ///
47     /// @param protocol 协议号
48     /// @param options socket的选项
49     ///
50     /// @return 返回创建的原始的socket
51     pub fn new(protocol: Protocol, options: SocketOptions) -> Self {
52         let rx_buffer = raw::PacketBuffer::new(
53             vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
54             vec![0; Self::DEFAULT_RX_BUF_SIZE],
55         );
56         let tx_buffer = raw::PacketBuffer::new(
57             vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
58             vec![0; Self::DEFAULT_TX_BUF_SIZE],
59         );
60         let protocol: u8 = protocol.into();
61         let socket = raw::Socket::new(
62             wire::IpVersion::Ipv4,
63             wire::IpProtocol::from(protocol),
64             rx_buffer,
65             tx_buffer,
66         );
67 
68         // 把socket添加到socket集合中,并得到socket的句柄
69         let handle = GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket));
70 
71         let metadata = SocketMetadata::new(
72             SocketType::Raw,
73             Self::DEFAULT_RX_BUF_SIZE,
74             Self::DEFAULT_TX_BUF_SIZE,
75             Self::DEFAULT_METADATA_BUF_SIZE,
76             options,
77         );
78 
79         return Self {
80             handle,
81             header_included: false,
82             metadata,
83         };
84     }
85 }
86 
87 impl Socket for RawSocket {
88     fn close(&mut self) {
89         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
90         if let smoltcp::socket::Socket::Udp(mut sock) =
91             socket_set_guard.remove(self.handle.smoltcp_handle().unwrap())
92         {
93             sock.close();
94         }
95         drop(socket_set_guard);
96         poll_ifaces();
97     }
98 
99     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
100         poll_ifaces();
101         loop {
102             // 如何优化这里?
103             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
104             let socket =
105                 socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap());
106 
107             match socket.recv_slice(buf) {
108                 Ok(len) => {
109                     let packet = wire::Ipv4Packet::new_unchecked(buf);
110                     return (
111                         Ok(len),
112                         Endpoint::Ip(Some(wire::IpEndpoint {
113                             addr: wire::IpAddress::Ipv4(packet.src_addr()),
114                             port: 0,
115                         })),
116                     );
117                 }
118                 Err(_) => {
119                     if !self.metadata.options.contains(SocketOptions::BLOCK) {
120                         // 如果是非阻塞的socket,就返回错误
121                         return (Err(SystemError::EAGAIN_OR_EWOULDBLOCK), Endpoint::Ip(None));
122                     }
123                 }
124             }
125             drop(socket_set_guard);
126             SocketHandleItem::sleep(
127                 self.socket_handle(),
128                 EPollEventType::EPOLLIN.bits() as u64,
129                 HANDLE_MAP.read_irqsave(),
130             );
131         }
132     }
133 
134     fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError> {
135         // 如果用户发送的数据包,包含IP头,则直接发送
136         if self.header_included {
137             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
138             let socket =
139                 socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap());
140             match socket.send_slice(buf) {
141                 Ok(_) => {
142                     return Ok(buf.len());
143                 }
144                 Err(raw::SendError::BufferFull) => {
145                     return Err(SystemError::ENOBUFS);
146                 }
147             }
148         } else {
149             // 如果用户发送的数据包,不包含IP头,则需要自己构造IP头
150 
151             if let Some(Endpoint::Ip(Some(endpoint))) = to {
152                 let mut socket_set_guard = SOCKET_SET.lock_irqsave();
153                 let socket: &mut raw::Socket =
154                     socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap());
155 
156                 // 暴力解决方案:只考虑0号网卡。 TODO:考虑多网卡的情况!!!
157                 let iface = NET_DEVICES.read_irqsave().get(&0).unwrap().clone();
158 
159                 // 构造IP头
160                 let ipv4_src_addr: Option<wire::Ipv4Address> =
161                     iface.inner_iface().lock().ipv4_addr();
162                 if ipv4_src_addr.is_none() {
163                     return Err(SystemError::ENETUNREACH);
164                 }
165                 let ipv4_src_addr = ipv4_src_addr.unwrap();
166 
167                 if let wire::IpAddress::Ipv4(ipv4_dst) = endpoint.addr {
168                     let len = buf.len();
169 
170                     // 创建20字节的IPv4头部
171                     let mut buffer: Vec<u8> = vec![0u8; len + 20];
172                     let mut packet: wire::Ipv4Packet<&mut Vec<u8>> =
173                         wire::Ipv4Packet::new_unchecked(&mut buffer);
174 
175                     // 封装ipv4 header
176                     packet.set_version(4);
177                     packet.set_header_len(20);
178                     packet.set_total_len((20 + len) as u16);
179                     packet.set_src_addr(ipv4_src_addr);
180                     packet.set_dst_addr(ipv4_dst);
181 
182                     // 设置ipv4 header的protocol字段
183                     packet.set_next_header(socket.ip_protocol());
184 
185                     // 获取IP数据包的负载字段
186                     let payload: &mut [u8] = packet.payload_mut();
187                     payload.copy_from_slice(buf);
188 
189                     // 填充checksum字段
190                     packet.fill_checksum();
191 
192                     // 发送数据包
193                     socket.send_slice(&buffer).unwrap();
194 
195                     iface.poll(&mut socket_set_guard).ok();
196 
197                     drop(socket_set_guard);
198                     return Ok(len);
199                 } else {
200                     warn!("Unsupport Ip protocol type!");
201                     return Err(SystemError::EINVAL);
202                 }
203             } else {
204                 // 如果没有指定目的地址,则返回错误
205                 return Err(SystemError::ENOTCONN);
206             }
207         }
208     }
209 
210     fn connect(&mut self, _endpoint: Endpoint) -> Result<(), SystemError> {
211         Ok(())
212     }
213 
214     fn metadata(&self) -> SocketMetadata {
215         self.metadata.clone()
216     }
217 
218     fn box_clone(&self) -> Box<dyn Socket> {
219         Box::new(self.clone())
220     }
221 
222     fn socket_handle(&self) -> GlobalSocketHandle {
223         self.handle
224     }
225 
226     fn as_any_ref(&self) -> &dyn core::any::Any {
227         self
228     }
229 
230     fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
231         self
232     }
233 }
234 
235 /// @brief 表示udp socket
236 ///
237 /// https://man7.org/linux/man-pages/man7/udp.7.html
238 #[derive(Debug, Clone)]
239 pub struct UdpSocket {
240     pub handle: GlobalSocketHandle,
241     remote_endpoint: Option<Endpoint>, // 记录远程endpoint提供给connect(), 应该使用IP地址。
242     metadata: SocketMetadata,
243 }
244 
245 impl UdpSocket {
246     /// 元数据的缓冲区的大小
247     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
248     /// 默认的接收缓冲区的大小 receive
249     pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024;
250     /// 默认的发送缓冲区的大小 transmiss
251     pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024;
252 
253     /// @brief 创建一个udp的socket
254     ///
255     /// @param options socket的选项
256     ///
257     /// @return 返回创建的udp的socket
258     pub fn new(options: SocketOptions) -> Self {
259         let rx_buffer = udp::PacketBuffer::new(
260             vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
261             vec![0; Self::DEFAULT_RX_BUF_SIZE],
262         );
263         let tx_buffer = udp::PacketBuffer::new(
264             vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
265             vec![0; Self::DEFAULT_TX_BUF_SIZE],
266         );
267         let socket = udp::Socket::new(rx_buffer, tx_buffer);
268 
269         // 把socket添加到socket集合中,并得到socket的句柄
270         let handle: GlobalSocketHandle =
271             GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket));
272 
273         let metadata = SocketMetadata::new(
274             SocketType::Udp,
275             Self::DEFAULT_RX_BUF_SIZE,
276             Self::DEFAULT_TX_BUF_SIZE,
277             Self::DEFAULT_METADATA_BUF_SIZE,
278             options,
279         );
280 
281         return Self {
282             handle,
283             remote_endpoint: None,
284             metadata,
285         };
286     }
287 
288     fn do_bind(&self, socket: &mut udp::Socket, endpoint: Endpoint) -> Result<(), SystemError> {
289         if let Endpoint::Ip(Some(mut ip)) = endpoint {
290             // 端口为0则分配随机端口
291             if ip.port == 0 {
292                 ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
293             }
294             // 检测端口是否已被占用
295             PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port)?;
296 
297             let bind_res = if ip.addr.is_unspecified() {
298                 socket.bind(ip.port)
299             } else {
300                 socket.bind(ip)
301             };
302 
303             match bind_res {
304                 Ok(()) => return Ok(()),
305                 Err(_) => return Err(SystemError::EINVAL),
306             }
307         } else {
308             return Err(SystemError::EINVAL);
309         }
310     }
311 }
312 
313 impl Socket for UdpSocket {
314     fn close(&mut self) {
315         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
316         if let smoltcp::socket::Socket::Udp(mut sock) =
317             socket_set_guard.remove(self.handle.smoltcp_handle().unwrap())
318         {
319             sock.close();
320         }
321         drop(socket_set_guard);
322         poll_ifaces();
323     }
324 
325     /// @brief 在read函数执行之前,请先bind到本地的指定端口
326     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
327         loop {
328             // debug!("Wait22 to Read");
329             poll_ifaces();
330             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
331             let socket =
332                 socket_set_guard.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
333 
334             // debug!("Wait to Read");
335 
336             if socket.can_recv() {
337                 if let Ok((size, metadata)) = socket.recv_slice(buf) {
338                     drop(socket_set_guard);
339                     poll_ifaces();
340                     return (Ok(size), Endpoint::Ip(Some(metadata.endpoint)));
341                 }
342             } else {
343                 // 如果socket没有连接,则忙等
344                 // return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
345             }
346             drop(socket_set_guard);
347             SocketHandleItem::sleep(
348                 self.socket_handle(),
349                 EPollEventType::EPOLLIN.bits() as u64,
350                 HANDLE_MAP.read_irqsave(),
351             );
352         }
353     }
354 
355     fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError> {
356         // debug!("udp to send: {:?}, len={}", to, buf.len());
357         let remote_endpoint: &wire::IpEndpoint = {
358             if let Some(Endpoint::Ip(Some(ref endpoint))) = to {
359                 endpoint
360             } else if let Some(Endpoint::Ip(Some(ref endpoint))) = self.remote_endpoint {
361                 endpoint
362             } else {
363                 return Err(SystemError::ENOTCONN);
364             }
365         };
366         // debug!("udp write: remote = {:?}", remote_endpoint);
367 
368         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
369         let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
370         // debug!("is open()={}", socket.is_open());
371         // debug!("socket endpoint={:?}", socket.endpoint());
372         if socket.can_send() {
373             // debug!("udp write: can send");
374             match socket.send_slice(buf, *remote_endpoint) {
375                 Ok(()) => {
376                     // debug!("udp write: send ok");
377                     drop(socket_set_guard);
378                     poll_ifaces();
379                     return Ok(buf.len());
380                 }
381                 Err(_) => {
382                     // debug!("udp write: send err");
383                     return Err(SystemError::ENOBUFS);
384                 }
385             }
386         } else {
387             // debug!("udp write: can not send");
388             return Err(SystemError::ENOBUFS);
389         };
390     }
391 
392     fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
393         let mut sockets = SOCKET_SET.lock_irqsave();
394         let socket = sockets.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
395         // debug!("UDP Bind to {:?}", endpoint);
396         return self.do_bind(socket, endpoint);
397     }
398 
399     fn poll(&self) -> EPollEventType {
400         let sockets = SOCKET_SET.lock_irqsave();
401         let socket = sockets.get::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
402 
403         return SocketPollMethod::udp_poll(
404             socket,
405             HANDLE_MAP
406                 .read_irqsave()
407                 .get(&self.socket_handle())
408                 .unwrap()
409                 .shutdown_type(),
410         );
411     }
412 
413     fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
414         if let Endpoint::Ip(_) = endpoint {
415             self.remote_endpoint = Some(endpoint);
416             Ok(())
417         } else {
418             Err(SystemError::EINVAL)
419         }
420     }
421 
422     fn ioctl(
423         &self,
424         _cmd: usize,
425         _arg0: usize,
426         _arg1: usize,
427         _arg2: usize,
428     ) -> Result<usize, SystemError> {
429         todo!()
430     }
431 
432     fn metadata(&self) -> SocketMetadata {
433         self.metadata.clone()
434     }
435 
436     fn box_clone(&self) -> Box<dyn Socket> {
437         return Box::new(self.clone());
438     }
439 
440     fn endpoint(&self) -> Option<Endpoint> {
441         let sockets = SOCKET_SET.lock_irqsave();
442         let socket = sockets.get::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
443         let listen_endpoint = socket.endpoint();
444 
445         if listen_endpoint.port == 0 {
446             return None;
447         } else {
448             // 如果listen_endpoint的address是None,意味着“监听所有的地址”。
449             // 这里假设所有的地址都是ipv4
450             // TODO: 支持ipv6
451             let result = wire::IpEndpoint::new(
452                 listen_endpoint
453                     .addr
454                     .unwrap_or(wire::IpAddress::v4(0, 0, 0, 0)),
455                 listen_endpoint.port,
456             );
457             return Some(Endpoint::Ip(Some(result)));
458         }
459     }
460 
461     fn peer_endpoint(&self) -> Option<Endpoint> {
462         return self.remote_endpoint.clone();
463     }
464 
465     fn socket_handle(&self) -> GlobalSocketHandle {
466         self.handle
467     }
468 
469     fn as_any_ref(&self) -> &dyn core::any::Any {
470         self
471     }
472 
473     fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
474         self
475     }
476 }
477 
478 /// @brief 表示 tcp socket
479 ///
480 /// https://man7.org/linux/man-pages/man7/tcp.7.html
481 #[derive(Debug, Clone)]
482 pub struct TcpSocket {
483     handles: Vec<GlobalSocketHandle>,
484     local_endpoint: Option<wire::IpEndpoint>, // save local endpoint for bind()
485     is_listening: bool,
486     metadata: SocketMetadata,
487 }
488 
489 impl TcpSocket {
490     /// 元数据的缓冲区的大小
491     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
492     /// 默认的接收缓冲区的大小 receive
493     pub const DEFAULT_RX_BUF_SIZE: usize = 512 * 1024;
494     /// 默认的发送缓冲区的大小 transmiss
495     pub const DEFAULT_TX_BUF_SIZE: usize = 512 * 1024;
496 
497     /// TcpSocket的特殊事件,用于在事件等待队列上sleep
498     pub const CAN_CONNECT: u64 = 1u64 << 63;
499     pub const CAN_ACCPET: u64 = 1u64 << 62;
500 
501     /// @brief 创建一个tcp的socket
502     ///
503     /// @param options socket的选项
504     ///
505     /// @return 返回创建的tcp的socket
506     pub fn new(options: SocketOptions) -> Self {
507         // 创建handles数组并把socket添加到socket集合中,并得到socket的句柄
508         let handles: Vec<GlobalSocketHandle> = vec![GlobalSocketHandle::new_smoltcp_handle(
509             SOCKET_SET.lock_irqsave().add(Self::create_new_socket()),
510         )];
511 
512         let metadata = SocketMetadata::new(
513             SocketType::Tcp,
514             Self::DEFAULT_RX_BUF_SIZE,
515             Self::DEFAULT_TX_BUF_SIZE,
516             Self::DEFAULT_METADATA_BUF_SIZE,
517             options,
518         );
519         // debug!("when there's a new tcp socket,its'len: {}",handles.len());
520 
521         return Self {
522             handles,
523             local_endpoint: None,
524             is_listening: false,
525             metadata,
526         };
527     }
528 
529     fn do_listen(
530         &mut self,
531         socket: &mut tcp::Socket,
532         local_endpoint: wire::IpEndpoint,
533     ) -> Result<(), SystemError> {
534         let listen_result = if local_endpoint.addr.is_unspecified() {
535             // debug!("Tcp Socket Listen on port {}", local_endpoint.port);
536             socket.listen(local_endpoint.port)
537         } else {
538             // debug!("Tcp Socket Listen on {local_endpoint}");
539             socket.listen(local_endpoint)
540         };
541         return match listen_result {
542             Ok(()) => {
543                 // debug!(
544                 //     "Tcp Socket Listen on {local_endpoint}, open?:{}",
545                 //     socket.is_open()
546                 // );
547                 self.is_listening = true;
548 
549                 Ok(())
550             }
551             Err(_) => Err(SystemError::EINVAL),
552         };
553     }
554 
555     /// # create_new_socket - 创建新的TCP套接字
556     ///
557     /// 该函数用于创建一个新的TCP套接字,并返回该套接字的引用。
558     fn create_new_socket() -> tcp::Socket<'static> {
559         // 初始化tcp的buffer
560         let rx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_RX_BUF_SIZE]);
561         let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]);
562         tcp::Socket::new(rx_buffer, tx_buffer)
563     }
564 }
565 
566 impl Socket for TcpSocket {
567     fn close(&mut self) {
568         for handle in self.handles.iter() {
569             {
570                 let mut socket_set_guard = SOCKET_SET.lock_irqsave();
571                 let smoltcp_handle = handle.smoltcp_handle().unwrap();
572                 socket_set_guard
573                     .get_mut::<smoltcp::socket::tcp::Socket>(smoltcp_handle)
574                     .close();
575                 drop(socket_set_guard);
576             }
577             poll_ifaces();
578             SOCKET_SET
579                 .lock_irqsave()
580                 .remove(handle.smoltcp_handle().unwrap());
581             // debug!("[Socket] [TCP] Close: {:?}", handle);
582         }
583     }
584 
585     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
586         if HANDLE_MAP
587             .read_irqsave()
588             .get(&self.socket_handle())
589             .unwrap()
590             .shutdown_type()
591             .contains(ShutdownType::RCV_SHUTDOWN)
592         {
593             return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
594         }
595         // debug!("tcp socket: read, buf len={}", buf.len());
596         // debug!("tcp socket:read, socket'len={}",self.handle.len());
597         loop {
598             poll_ifaces();
599             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
600 
601             let socket = socket_set_guard
602                 .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
603 
604             // 如果socket已经关闭,返回错误
605             if !socket.is_active() {
606                 // debug!("Tcp Socket Read Error, socket is closed");
607                 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
608             }
609 
610             if socket.may_recv() {
611                 match socket.recv_slice(buf) {
612                     Ok(size) => {
613                         if size > 0 {
614                             let endpoint = if let Some(p) = socket.remote_endpoint() {
615                                 p
616                             } else {
617                                 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
618                             };
619 
620                             drop(socket_set_guard);
621                             poll_ifaces();
622                             return (Ok(size), Endpoint::Ip(Some(endpoint)));
623                         }
624                     }
625                     Err(tcp::RecvError::InvalidState) => {
626                         warn!("Tcp Socket Read Error, InvalidState");
627                         return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
628                     }
629                     Err(tcp::RecvError::Finished) => {
630                         // 对端写端已关闭,我们应该关闭读端
631                         HANDLE_MAP
632                             .write_irqsave()
633                             .get_mut(&self.socket_handle())
634                             .unwrap()
635                             .shutdown_type_writer()
636                             .insert(ShutdownType::RCV_SHUTDOWN);
637                         return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
638                     }
639                 }
640             } else {
641                 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
642             }
643             drop(socket_set_guard);
644             SocketHandleItem::sleep(
645                 self.socket_handle(),
646                 (EPollEventType::EPOLLIN.bits() | EPollEventType::EPOLLHUP.bits()) as u64,
647                 HANDLE_MAP.read_irqsave(),
648             );
649         }
650     }
651 
652     fn write(&self, buf: &[u8], _to: Option<Endpoint>) -> Result<usize, SystemError> {
653         if HANDLE_MAP
654             .read_irqsave()
655             .get(&self.socket_handle())
656             .unwrap()
657             .shutdown_type()
658             .contains(ShutdownType::RCV_SHUTDOWN)
659         {
660             return Err(SystemError::ENOTCONN);
661         }
662         // debug!("tcp socket:write, socket'len={}",self.handle.len());
663 
664         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
665 
666         let socket = socket_set_guard
667             .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
668 
669         if socket.is_open() {
670             if socket.can_send() {
671                 match socket.send_slice(buf) {
672                     Ok(size) => {
673                         drop(socket_set_guard);
674                         poll_ifaces();
675                         return Ok(size);
676                     }
677                     Err(e) => {
678                         error!("Tcp Socket Write Error {e:?}");
679                         return Err(SystemError::ENOBUFS);
680                     }
681                 }
682             } else {
683                 return Err(SystemError::ENOBUFS);
684             }
685         }
686 
687         return Err(SystemError::ENOTCONN);
688     }
689 
690     fn poll(&self) -> EPollEventType {
691         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
692         // debug!("tcp socket:poll, socket'len={}",self.handle.len());
693 
694         let socket = socket_set_guard
695             .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
696         return SocketPollMethod::tcp_poll(
697             socket,
698             HANDLE_MAP
699                 .read_irqsave()
700                 .get(&self.socket_handle())
701                 .unwrap()
702                 .shutdown_type(),
703         );
704     }
705 
706     fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
707         let mut sockets = SOCKET_SET.lock_irqsave();
708         // debug!("tcp socket:connect, socket'len={}",self.handle.len());
709 
710         let socket =
711             sockets.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
712 
713         if let Endpoint::Ip(Some(ip)) = endpoint {
714             let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
715             // 检测端口是否被占用
716             PORT_MANAGER.bind_port(self.metadata.socket_type, temp_port)?;
717 
718             // debug!("temp_port: {}", temp_port);
719             let iface: Arc<dyn NetDevice> = NET_DEVICES.write_irqsave().get(&0).unwrap().clone();
720             let mut inner_iface = iface.inner_iface().lock();
721             // debug!("to connect: {ip:?}");
722 
723             match socket.connect(inner_iface.context(), ip, temp_port) {
724                 Ok(()) => {
725                     // avoid deadlock
726                     drop(inner_iface);
727                     drop(iface);
728                     drop(sockets);
729                     loop {
730                         poll_ifaces();
731                         let mut sockets = SOCKET_SET.lock_irqsave();
732                         let socket = sockets.get_mut::<tcp::Socket>(
733                             self.handles.get(0).unwrap().smoltcp_handle().unwrap(),
734                         );
735 
736                         match socket.state() {
737                             tcp::State::Established => {
738                                 return Ok(());
739                             }
740                             tcp::State::SynSent => {
741                                 drop(sockets);
742                                 SocketHandleItem::sleep(
743                                     self.socket_handle(),
744                                     Self::CAN_CONNECT,
745                                     HANDLE_MAP.read_irqsave(),
746                                 );
747                             }
748                             _ => {
749                                 return Err(SystemError::ECONNREFUSED);
750                             }
751                         }
752                     }
753                 }
754                 Err(e) => {
755                     // error!("Tcp Socket Connect Error {e:?}");
756                     match e {
757                         tcp::ConnectError::InvalidState => return Err(SystemError::EISCONN),
758                         tcp::ConnectError::Unaddressable => return Err(SystemError::EADDRNOTAVAIL),
759                     }
760                 }
761             }
762         } else {
763             return Err(SystemError::EINVAL);
764         }
765     }
766 
767     /// @brief tcp socket 监听 local_endpoint 端口
768     ///
769     /// @param backlog 未处理的连接队列的最大长度
770     fn listen(&mut self, backlog: usize) -> Result<(), SystemError> {
771         if self.is_listening {
772             return Ok(());
773         }
774 
775         let local_endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?;
776         let mut sockets = SOCKET_SET.lock_irqsave();
777         // 获取handle的数量
778         let handlen = self.handles.len();
779         let backlog = handlen.max(backlog);
780 
781         // 添加剩余需要构建的socket
782         // debug!("tcp socket:before listen, socket'len={}", self.handle_list.len());
783         let mut handle_guard = HANDLE_MAP.write_irqsave();
784         let wait_queue = Arc::clone(&handle_guard.get(&self.socket_handle()).unwrap().wait_queue);
785 
786         self.handles.extend((handlen..backlog).map(|_| {
787             let socket = Self::create_new_socket();
788             let handle = GlobalSocketHandle::new_smoltcp_handle(sockets.add(socket));
789             let handle_item = SocketHandleItem::new(Some(wait_queue.clone()));
790             handle_guard.insert(handle, handle_item);
791             handle
792         }));
793         // debug!("tcp socket:listen, socket'len={}",self.handle.len());
794         // debug!("tcp socket:listen, backlog={backlog}");
795 
796         // 监听所有的socket
797         for i in 0..backlog {
798             let handle = self.handles.get(i).unwrap();
799 
800             let socket = sockets.get_mut::<tcp::Socket>(handle.smoltcp_handle().unwrap());
801 
802             if !socket.is_listening() {
803                 // debug!("Tcp Socket is already listening on {local_endpoint}");
804                 self.do_listen(socket, local_endpoint)?;
805             }
806             // debug!("Tcp Socket  before listen, open={}", socket.is_open());
807         }
808         return Ok(());
809     }
810 
811     fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
812         if let Endpoint::Ip(Some(mut ip)) = endpoint {
813             if ip.port == 0 {
814                 ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
815             }
816 
817             // 检测端口是否已被占用
818             PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port)?;
819             // debug!("tcp socket:bind, socket'len={}",self.handle.len());
820 
821             self.local_endpoint = Some(ip);
822             self.is_listening = false;
823             return Ok(());
824         }
825         return Err(SystemError::EINVAL);
826     }
827 
828     fn shutdown(&mut self, shutdown_type: super::ShutdownType) -> Result<(), SystemError> {
829         // TODO:目前只是在表层判断,对端不知晓,后续需使用tcp实现
830         HANDLE_MAP
831             .write_irqsave()
832             .get_mut(&self.socket_handle())
833             .unwrap()
834             .shutdown_type = RwLock::new(shutdown_type);
835         return Ok(());
836     }
837 
838     fn accept(&mut self) -> Result<(Box<dyn Socket>, Endpoint), SystemError> {
839         if !self.is_listening {
840             return Err(SystemError::EINVAL);
841         }
842         let endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?;
843         loop {
844             // debug!("tcp accept: poll_ifaces()");
845             poll_ifaces();
846             // debug!("tcp socket:accept, socket'len={}", self.handle_list.len());
847 
848             let mut sockset = SOCKET_SET.lock_irqsave();
849             // Get the corresponding activated handler
850             let global_handle_index = self.handles.iter().position(|handle| {
851                 let con_smol_sock = sockset.get::<tcp::Socket>(handle.smoltcp_handle().unwrap());
852                 con_smol_sock.is_active()
853             });
854 
855             if let Some(handle_index) = global_handle_index {
856                 let con_smol_sock = sockset
857                     .get::<tcp::Socket>(self.handles[handle_index].smoltcp_handle().unwrap());
858 
859                 // debug!("[Socket] [TCP] Accept: {:?}", handle);
860                 // handle is connected socket's handle
861                 let remote_ep = con_smol_sock
862                     .remote_endpoint()
863                     .ok_or(SystemError::ENOTCONN)?;
864 
865                 let mut tcp_socket = Self::create_new_socket();
866                 self.do_listen(&mut tcp_socket, endpoint)?;
867 
868                 let new_handle = GlobalSocketHandle::new_smoltcp_handle(sockset.add(tcp_socket));
869 
870                 // let handle in TcpSock be the new empty handle, and return the old connected handle
871                 let old_handle = core::mem::replace(&mut self.handles[handle_index], new_handle);
872 
873                 let metadata = SocketMetadata::new(
874                     SocketType::Tcp,
875                     Self::DEFAULT_TX_BUF_SIZE,
876                     Self::DEFAULT_RX_BUF_SIZE,
877                     Self::DEFAULT_METADATA_BUF_SIZE,
878                     self.metadata.options,
879                 );
880 
881                 let sock_ret = Box::new(TcpSocket {
882                     handles: vec![old_handle],
883                     local_endpoint: self.local_endpoint,
884                     is_listening: false,
885                     metadata,
886                 });
887 
888                 {
889                     let mut handle_guard = HANDLE_MAP.write_irqsave();
890                     // 先删除原来的
891                     let item = handle_guard.remove(&old_handle).unwrap();
892 
893                     // 按照smoltcp行为,将新的handle绑定到原来的item
894                     let new_item = SocketHandleItem::new(None);
895                     handle_guard.insert(old_handle, new_item);
896                     // 插入新的item
897                     handle_guard.insert(new_handle, item);
898                     drop(handle_guard);
899                 }
900                 return Ok((sock_ret, Endpoint::Ip(Some(remote_ep))));
901             }
902 
903             drop(sockset);
904 
905             // debug!("[TCP] [Accept] sleeping socket with handle: {:?}", self.handles.get(0).unwrap().smoltcp_handle().unwrap());
906             SocketHandleItem::sleep(
907                 self.socket_handle(), // NOTICE
908                 Self::CAN_ACCPET,
909                 HANDLE_MAP.read_irqsave(),
910             );
911             // debug!("tcp socket:after sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len());
912         }
913     }
914 
915     fn endpoint(&self) -> Option<Endpoint> {
916         let mut result: Option<Endpoint> = self.local_endpoint.map(|x| Endpoint::Ip(Some(x)));
917 
918         if result.is_none() {
919             let sockets = SOCKET_SET.lock_irqsave();
920             // debug!("tcp socket:endpoint, socket'len={}",self.handle.len());
921 
922             let socket =
923                 sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
924             if let Some(ep) = socket.local_endpoint() {
925                 result = Some(Endpoint::Ip(Some(ep)));
926             }
927         }
928         return result;
929     }
930 
931     fn peer_endpoint(&self) -> Option<Endpoint> {
932         let sockets = SOCKET_SET.lock_irqsave();
933         // debug!("tcp socket:peer_endpoint, socket'len={}",self.handle.len());
934 
935         let socket =
936             sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
937         return socket.remote_endpoint().map(|x| Endpoint::Ip(Some(x)));
938     }
939 
940     fn metadata(&self) -> SocketMetadata {
941         self.metadata.clone()
942     }
943 
944     fn box_clone(&self) -> Box<dyn Socket> {
945         Box::new(self.clone())
946     }
947 
948     fn socket_handle(&self) -> GlobalSocketHandle {
949         // debug!("tcp socket:socket_handle, socket'len={}",self.handle.len());
950 
951         *self.handles.get(0).unwrap()
952     }
953 
954     fn as_any_ref(&self) -> &dyn core::any::Any {
955         self
956     }
957 
958     fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
959         self
960     }
961 }
962