xref: /DragonOS/kernel/src/net/socket/inet.rs (revision ceeb2e943ca7645609920ec7ad8bfceea2b13de6)
1 use alloc::{boxed::Box, sync::Arc, vec::Vec};
2 use smoltcp::{
3     iface::SocketHandle,
4     socket::{raw, tcp, udp},
5     wire,
6 };
7 use system_error::SystemError;
8 
9 use crate::{
10     arch::rand::rand,
11     driver::net::NetDriver,
12     kerror, kwarn,
13     libs::rwlock::RwLock,
14     net::{
15         event_poll::EPollEventType, net_core::poll_ifaces, Endpoint, Protocol, ShutdownType,
16         NET_DRIVERS,
17     },
18 };
19 
20 use super::{
21     GlobalSocketHandle, Socket, SocketHandleItem, SocketMetadata, SocketOptions, SocketPollMethod,
22     SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET,
23 };
24 
25 /// @brief 表示原始的socket。原始套接字绕过传输层协议(如 TCP 或 UDP)并提供对网络层协议(如 IP)的直接访问。
26 ///
27 /// ref: https://man7.org/linux/man-pages/man7/raw.7.html
28 #[derive(Debug, Clone)]
29 pub struct RawSocket {
30     handle: Arc<GlobalSocketHandle>,
31     /// 用户发送的数据包是否包含了IP头.
32     /// 如果是true,用户发送的数据包,必须包含IP头。(即用户要自行设置IP头+数据)
33     /// 如果是false,用户发送的数据包,不包含IP头。(即用户只要设置数据)
34     header_included: bool,
35     /// socket的metadata
36     metadata: SocketMetadata,
37 }
38 
39 impl RawSocket {
40     /// 元数据的缓冲区的大小
41     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
42     /// 默认的接收缓冲区的大小 receive
43     pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024;
44     /// 默认的发送缓冲区的大小 transmiss
45     pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024;
46 
47     /// @brief 创建一个原始的socket
48     ///
49     /// @param protocol 协议号
50     /// @param options socket的选项
51     ///
52     /// @return 返回创建的原始的socket
53     pub fn new(protocol: Protocol, options: SocketOptions) -> Self {
54         let rx_buffer = raw::PacketBuffer::new(
55             vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
56             vec![0; Self::DEFAULT_RX_BUF_SIZE],
57         );
58         let tx_buffer = raw::PacketBuffer::new(
59             vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
60             vec![0; Self::DEFAULT_TX_BUF_SIZE],
61         );
62         let protocol: u8 = protocol.into();
63         let socket = raw::Socket::new(
64             wire::IpVersion::Ipv4,
65             wire::IpProtocol::from(protocol),
66             rx_buffer,
67             tx_buffer,
68         );
69 
70         // 把socket添加到socket集合中,并得到socket的句柄
71         let handle: Arc<GlobalSocketHandle> =
72             GlobalSocketHandle::new(SOCKET_SET.lock_irqsave().add(socket));
73 
74         let metadata = SocketMetadata::new(
75             SocketType::Raw,
76             Self::DEFAULT_RX_BUF_SIZE,
77             Self::DEFAULT_TX_BUF_SIZE,
78             Self::DEFAULT_METADATA_BUF_SIZE,
79             options,
80         );
81 
82         return Self {
83             handle,
84             header_included: false,
85             metadata,
86         };
87     }
88 }
89 
90 impl Socket for RawSocket {
91     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
92         poll_ifaces();
93         loop {
94             // 如何优化这里?
95             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
96             let socket = socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
97 
98             match socket.recv_slice(buf) {
99                 Ok(len) => {
100                     let packet = wire::Ipv4Packet::new_unchecked(buf);
101                     return (
102                         Ok(len),
103                         Endpoint::Ip(Some(wire::IpEndpoint {
104                             addr: wire::IpAddress::Ipv4(packet.src_addr()),
105                             port: 0,
106                         })),
107                     );
108                 }
109                 Err(raw::RecvError::Exhausted) => {
110                     if !self.metadata.options.contains(SocketOptions::BLOCK) {
111                         // 如果是非阻塞的socket,就返回错误
112                         return (Err(SystemError::EAGAIN_OR_EWOULDBLOCK), Endpoint::Ip(None));
113                     }
114                 }
115             }
116             drop(socket_set_guard);
117             SocketHandleItem::sleep(
118                 self.socket_handle(),
119                 EPollEventType::EPOLLIN.bits() as u64,
120                 HANDLE_MAP.read_irqsave(),
121             );
122         }
123     }
124 
125     fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError> {
126         // 如果用户发送的数据包,包含IP头,则直接发送
127         if self.header_included {
128             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
129             let socket = socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
130             match socket.send_slice(buf) {
131                 Ok(_) => {
132                     return Ok(buf.len());
133                 }
134                 Err(raw::SendError::BufferFull) => {
135                     return Err(SystemError::ENOBUFS);
136                 }
137             }
138         } else {
139             // 如果用户发送的数据包,不包含IP头,则需要自己构造IP头
140 
141             if let Some(Endpoint::Ip(Some(endpoint))) = to {
142                 let mut socket_set_guard = SOCKET_SET.lock_irqsave();
143                 let socket: &mut raw::Socket =
144                     socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
145 
146                 // 暴力解决方案:只考虑0号网卡。 TODO:考虑多网卡的情况!!!
147                 let iface = NET_DRIVERS.read_irqsave().get(&0).unwrap().clone();
148 
149                 // 构造IP头
150                 let ipv4_src_addr: Option<wire::Ipv4Address> =
151                     iface.inner_iface().lock().ipv4_addr();
152                 if ipv4_src_addr.is_none() {
153                     return Err(SystemError::ENETUNREACH);
154                 }
155                 let ipv4_src_addr = ipv4_src_addr.unwrap();
156 
157                 if let wire::IpAddress::Ipv4(ipv4_dst) = endpoint.addr {
158                     let len = buf.len();
159 
160                     // 创建20字节的IPv4头部
161                     let mut buffer: Vec<u8> = vec![0u8; len + 20];
162                     let mut packet: wire::Ipv4Packet<&mut Vec<u8>> =
163                         wire::Ipv4Packet::new_unchecked(&mut buffer);
164 
165                     // 封装ipv4 header
166                     packet.set_version(4);
167                     packet.set_header_len(20);
168                     packet.set_total_len((20 + len) as u16);
169                     packet.set_src_addr(ipv4_src_addr);
170                     packet.set_dst_addr(ipv4_dst);
171 
172                     // 设置ipv4 header的protocol字段
173                     packet.set_next_header(socket.ip_protocol());
174 
175                     // 获取IP数据包的负载字段
176                     let payload: &mut [u8] = packet.payload_mut();
177                     payload.copy_from_slice(buf);
178 
179                     // 填充checksum字段
180                     packet.fill_checksum();
181 
182                     // 发送数据包
183                     socket.send_slice(&buffer).unwrap();
184 
185                     iface.poll(&mut socket_set_guard).ok();
186 
187                     drop(socket_set_guard);
188                     return Ok(len);
189                 } else {
190                     kwarn!("Unsupport Ip protocol type!");
191                     return Err(SystemError::EINVAL);
192                 }
193             } else {
194                 // 如果没有指定目的地址,则返回错误
195                 return Err(SystemError::ENOTCONN);
196             }
197         }
198     }
199 
200     fn connect(&mut self, _endpoint: Endpoint) -> Result<(), SystemError> {
201         Ok(())
202     }
203 
204     fn metadata(&self) -> SocketMetadata {
205         self.metadata.clone()
206     }
207 
208     fn box_clone(&self) -> Box<dyn Socket> {
209         Box::new(self.clone())
210     }
211 
212     fn socket_handle(&self) -> SocketHandle {
213         self.handle.0
214     }
215 
216     fn as_any_ref(&self) -> &dyn core::any::Any {
217         self
218     }
219 
220     fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
221         self
222     }
223 }
224 
225 /// @brief 表示udp socket
226 ///
227 /// https://man7.org/linux/man-pages/man7/udp.7.html
228 #[derive(Debug, Clone)]
229 pub struct UdpSocket {
230     pub handle: Arc<GlobalSocketHandle>,
231     remote_endpoint: Option<Endpoint>, // 记录远程endpoint提供给connect(), 应该使用IP地址。
232     metadata: SocketMetadata,
233 }
234 
235 impl UdpSocket {
236     /// 元数据的缓冲区的大小
237     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
238     /// 默认的接收缓冲区的大小 receive
239     pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024;
240     /// 默认的发送缓冲区的大小 transmiss
241     pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024;
242 
243     /// @brief 创建一个udp的socket
244     ///
245     /// @param options socket的选项
246     ///
247     /// @return 返回创建的udp的socket
248     pub fn new(options: SocketOptions) -> Self {
249         let rx_buffer = udp::PacketBuffer::new(
250             vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
251             vec![0; Self::DEFAULT_RX_BUF_SIZE],
252         );
253         let tx_buffer = udp::PacketBuffer::new(
254             vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
255             vec![0; Self::DEFAULT_TX_BUF_SIZE],
256         );
257         let socket = udp::Socket::new(rx_buffer, tx_buffer);
258 
259         // 把socket添加到socket集合中,并得到socket的句柄
260         let handle: Arc<GlobalSocketHandle> =
261             GlobalSocketHandle::new(SOCKET_SET.lock_irqsave().add(socket));
262 
263         let metadata = SocketMetadata::new(
264             SocketType::Udp,
265             Self::DEFAULT_RX_BUF_SIZE,
266             Self::DEFAULT_TX_BUF_SIZE,
267             Self::DEFAULT_METADATA_BUF_SIZE,
268             options,
269         );
270 
271         return Self {
272             handle,
273             remote_endpoint: None,
274             metadata,
275         };
276     }
277 
278     fn do_bind(&self, socket: &mut udp::Socket, endpoint: Endpoint) -> Result<(), SystemError> {
279         if let Endpoint::Ip(Some(mut ip)) = endpoint {
280             // 端口为0则分配随机端口
281             if ip.port == 0 {
282                 ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
283             }
284             // 检测端口是否已被占用
285             PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port, self.clone())?;
286 
287             let bind_res = if ip.addr.is_unspecified() {
288                 socket.bind(ip.port)
289             } else {
290                 socket.bind(ip)
291             };
292 
293             match bind_res {
294                 Ok(()) => return Ok(()),
295                 Err(_) => return Err(SystemError::EINVAL),
296             }
297         } else {
298             return Err(SystemError::EINVAL);
299         }
300     }
301 }
302 
303 impl Socket for UdpSocket {
304     /// @brief 在read函数执行之前,请先bind到本地的指定端口
305     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
306         loop {
307             // kdebug!("Wait22 to Read");
308             poll_ifaces();
309             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
310             let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.0);
311 
312             // kdebug!("Wait to Read");
313 
314             if socket.can_recv() {
315                 if let Ok((size, remote_endpoint)) = socket.recv_slice(buf) {
316                     drop(socket_set_guard);
317                     poll_ifaces();
318                     return (Ok(size), Endpoint::Ip(Some(remote_endpoint)));
319                 }
320             } else {
321                 // 如果socket没有连接,则忙等
322                 // return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
323             }
324             drop(socket_set_guard);
325             SocketHandleItem::sleep(
326                 self.socket_handle(),
327                 EPollEventType::EPOLLIN.bits() as u64,
328                 HANDLE_MAP.read_irqsave(),
329             );
330         }
331     }
332 
333     fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError> {
334         // kdebug!("udp to send: {:?}, len={}", to, buf.len());
335         let remote_endpoint: &wire::IpEndpoint = {
336             if let Some(Endpoint::Ip(Some(ref endpoint))) = to {
337                 endpoint
338             } else if let Some(Endpoint::Ip(Some(ref endpoint))) = self.remote_endpoint {
339                 endpoint
340             } else {
341                 return Err(SystemError::ENOTCONN);
342             }
343         };
344         // kdebug!("udp write: remote = {:?}", remote_endpoint);
345 
346         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
347         let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.0);
348         // kdebug!("is open()={}", socket.is_open());
349         // kdebug!("socket endpoint={:?}", socket.endpoint());
350         if socket.can_send() {
351             // kdebug!("udp write: can send");
352             match socket.send_slice(buf, *remote_endpoint) {
353                 Ok(()) => {
354                     // kdebug!("udp write: send ok");
355                     drop(socket_set_guard);
356                     poll_ifaces();
357                     return Ok(buf.len());
358                 }
359                 Err(_) => {
360                     // kdebug!("udp write: send err");
361                     return Err(SystemError::ENOBUFS);
362                 }
363             }
364         } else {
365             // kdebug!("udp write: can not send");
366             return Err(SystemError::ENOBUFS);
367         };
368     }
369 
370     fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
371         let mut sockets = SOCKET_SET.lock_irqsave();
372         let socket = sockets.get_mut::<udp::Socket>(self.handle.0);
373         // kdebug!("UDP Bind to {:?}", endpoint);
374         return self.do_bind(socket, endpoint);
375     }
376 
377     fn poll(&self) -> EPollEventType {
378         let sockets = SOCKET_SET.lock_irqsave();
379         let socket = sockets.get::<udp::Socket>(self.handle.0);
380 
381         return SocketPollMethod::udp_poll(
382             socket,
383             HANDLE_MAP
384                 .read_irqsave()
385                 .get(&self.socket_handle())
386                 .unwrap()
387                 .shutdown_type(),
388         );
389     }
390 
391     fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
392         if let Endpoint::Ip(_) = endpoint {
393             self.remote_endpoint = Some(endpoint);
394             Ok(())
395         } else {
396             Err(SystemError::EINVAL)
397         }
398     }
399 
400     fn ioctl(
401         &self,
402         _cmd: usize,
403         _arg0: usize,
404         _arg1: usize,
405         _arg2: usize,
406     ) -> Result<usize, SystemError> {
407         todo!()
408     }
409 
410     fn metadata(&self) -> SocketMetadata {
411         self.metadata.clone()
412     }
413 
414     fn box_clone(&self) -> Box<dyn Socket> {
415         return Box::new(self.clone());
416     }
417 
418     fn endpoint(&self) -> Option<Endpoint> {
419         let sockets = SOCKET_SET.lock_irqsave();
420         let socket = sockets.get::<udp::Socket>(self.handle.0);
421         let listen_endpoint = socket.endpoint();
422 
423         if listen_endpoint.port == 0 {
424             return None;
425         } else {
426             // 如果listen_endpoint的address是None,意味着“监听所有的地址”。
427             // 这里假设所有的地址都是ipv4
428             // TODO: 支持ipv6
429             let result = wire::IpEndpoint::new(
430                 listen_endpoint
431                     .addr
432                     .unwrap_or(wire::IpAddress::v4(0, 0, 0, 0)),
433                 listen_endpoint.port,
434             );
435             return Some(Endpoint::Ip(Some(result)));
436         }
437     }
438 
439     fn peer_endpoint(&self) -> Option<Endpoint> {
440         return self.remote_endpoint.clone();
441     }
442 
443     fn socket_handle(&self) -> SocketHandle {
444         self.handle.0
445     }
446 
447     fn as_any_ref(&self) -> &dyn core::any::Any {
448         self
449     }
450 
451     fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
452         self
453     }
454 }
455 
456 /// @brief 表示 tcp socket
457 ///
458 /// https://man7.org/linux/man-pages/man7/tcp.7.html
459 #[derive(Debug, Clone)]
460 pub struct TcpSocket {
461     handles: Vec<Arc<GlobalSocketHandle>>,
462     local_endpoint: Option<wire::IpEndpoint>, // save local endpoint for bind()
463     is_listening: bool,
464     metadata: SocketMetadata,
465 }
466 
467 impl TcpSocket {
468     /// 元数据的缓冲区的大小
469     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
470     /// 默认的接收缓冲区的大小 receive
471     pub const DEFAULT_RX_BUF_SIZE: usize = 512 * 1024;
472     /// 默认的发送缓冲区的大小 transmiss
473     pub const DEFAULT_TX_BUF_SIZE: usize = 512 * 1024;
474 
475     /// TcpSocket的特殊事件,用于在事件等待队列上sleep
476     pub const CAN_CONNECT: u64 = 1u64 << 63;
477     pub const CAN_ACCPET: u64 = 1u64 << 62;
478 
479     /// @brief 创建一个tcp的socket
480     ///
481     /// @param options socket的选项
482     ///
483     /// @return 返回创建的tcp的socket
484     pub fn new(options: SocketOptions) -> Self {
485         // 创建handles数组并把socket添加到socket集合中,并得到socket的句柄
486         let handles: Vec<Arc<GlobalSocketHandle>> = vec![GlobalSocketHandle::new(
487             SOCKET_SET.lock_irqsave().add(Self::create_new_socket()),
488         )];
489 
490         let metadata = SocketMetadata::new(
491             SocketType::Tcp,
492             Self::DEFAULT_RX_BUF_SIZE,
493             Self::DEFAULT_TX_BUF_SIZE,
494             Self::DEFAULT_METADATA_BUF_SIZE,
495             options,
496         );
497         // kdebug!("when there's a new tcp socket,its'len: {}",handles.len());
498 
499         return Self {
500             handles,
501             local_endpoint: None,
502             is_listening: false,
503             metadata,
504         };
505     }
506 
507     fn do_listen(
508         &mut self,
509         socket: &mut tcp::Socket,
510         local_endpoint: wire::IpEndpoint,
511     ) -> Result<(), SystemError> {
512         let listen_result = if local_endpoint.addr.is_unspecified() {
513             // kdebug!("Tcp Socket Listen on port {}", local_endpoint.port);
514             socket.listen(local_endpoint.port)
515         } else {
516             // kdebug!("Tcp Socket Listen on {local_endpoint}");
517             socket.listen(local_endpoint)
518         };
519         return match listen_result {
520             Ok(()) => {
521                 // kdebug!(
522                 //     "Tcp Socket Listen on {local_endpoint}, open?:{}",
523                 //     socket.is_open()
524                 // );
525                 self.is_listening = true;
526 
527                 Ok(())
528             }
529             Err(_) => Err(SystemError::EINVAL),
530         };
531     }
532 
533     /// # create_new_socket - 创建新的TCP套接字
534     ///
535     /// 该函数用于创建一个新的TCP套接字,并返回该套接字的引用。
536     fn create_new_socket() -> tcp::Socket<'static> {
537         // 初始化tcp的buffer
538         let rx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_RX_BUF_SIZE]);
539         let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]);
540         tcp::Socket::new(rx_buffer, tx_buffer)
541     }
542 }
543 
544 impl Socket for TcpSocket {
545     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
546         if HANDLE_MAP
547             .read_irqsave()
548             .get(&self.socket_handle())
549             .unwrap()
550             .shutdown_type()
551             .contains(ShutdownType::RCV_SHUTDOWN)
552         {
553             return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
554         }
555         // kdebug!("tcp socket: read, buf len={}", buf.len());
556         // kdebug!("tcp socket:read, socket'len={}",self.handle.len());
557         loop {
558             poll_ifaces();
559             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
560 
561             let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().0);
562 
563             // 如果socket已经关闭,返回错误
564             if !socket.is_active() {
565                 // kdebug!("Tcp Socket Read Error, socket is closed");
566                 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
567             }
568 
569             if socket.may_recv() {
570                 let recv_res = socket.recv_slice(buf);
571 
572                 if let Ok(size) = recv_res {
573                     if size > 0 {
574                         let endpoint = if let Some(p) = socket.remote_endpoint() {
575                             p
576                         } else {
577                             return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
578                         };
579 
580                         drop(socket_set_guard);
581                         poll_ifaces();
582                         return (Ok(size), Endpoint::Ip(Some(endpoint)));
583                     }
584                 } else {
585                     let err = recv_res.unwrap_err();
586                     match err {
587                         tcp::RecvError::InvalidState => {
588                             kwarn!("Tcp Socket Read Error, InvalidState");
589                             return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
590                         }
591                         tcp::RecvError::Finished => {
592                             // 对端写端已关闭,我们应该关闭读端
593                             HANDLE_MAP
594                                 .write_irqsave()
595                                 .get_mut(&self.socket_handle())
596                                 .unwrap()
597                                 .shutdown_type_writer()
598                                 .insert(ShutdownType::RCV_SHUTDOWN);
599                             return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
600                         }
601                     }
602                 }
603             } else {
604                 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
605             }
606             drop(socket_set_guard);
607             SocketHandleItem::sleep(
608                 self.socket_handle(),
609                 EPollEventType::EPOLLIN.bits() as u64,
610                 HANDLE_MAP.read_irqsave(),
611             );
612         }
613     }
614 
615     fn write(&self, buf: &[u8], _to: Option<Endpoint>) -> Result<usize, SystemError> {
616         if HANDLE_MAP
617             .read_irqsave()
618             .get(&self.socket_handle())
619             .unwrap()
620             .shutdown_type()
621             .contains(ShutdownType::RCV_SHUTDOWN)
622         {
623             return Err(SystemError::ENOTCONN);
624         }
625         // kdebug!("tcp socket:write, socket'len={}",self.handle.len());
626 
627         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
628 
629         let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().0);
630 
631         if socket.is_open() {
632             if socket.can_send() {
633                 match socket.send_slice(buf) {
634                     Ok(size) => {
635                         drop(socket_set_guard);
636                         poll_ifaces();
637                         return Ok(size);
638                     }
639                     Err(e) => {
640                         kerror!("Tcp Socket Write Error {e:?}");
641                         return Err(SystemError::ENOBUFS);
642                     }
643                 }
644             } else {
645                 return Err(SystemError::ENOBUFS);
646             }
647         }
648 
649         return Err(SystemError::ENOTCONN);
650     }
651 
652     fn poll(&self) -> EPollEventType {
653         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
654         // kdebug!("tcp socket:poll, socket'len={}",self.handle.len());
655 
656         let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().0);
657         return SocketPollMethod::tcp_poll(
658             socket,
659             HANDLE_MAP
660                 .read_irqsave()
661                 .get(&self.socket_handle())
662                 .unwrap()
663                 .shutdown_type(),
664         );
665     }
666 
667     fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
668         let mut sockets = SOCKET_SET.lock_irqsave();
669         // kdebug!("tcp socket:connect, socket'len={}",self.handle.len());
670 
671         let socket = sockets.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().0);
672 
673         if let Endpoint::Ip(Some(ip)) = endpoint {
674             let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
675             // 检测端口是否被占用
676             PORT_MANAGER.bind_port(self.metadata.socket_type, temp_port, self.clone())?;
677 
678             // kdebug!("temp_port: {}", temp_port);
679             let iface: Arc<dyn NetDriver> = NET_DRIVERS.write_irqsave().get(&0).unwrap().clone();
680             let mut inner_iface = iface.inner_iface().lock();
681             // kdebug!("to connect: {ip:?}");
682 
683             match socket.connect(inner_iface.context(), ip, temp_port) {
684                 Ok(()) => {
685                     // avoid deadlock
686                     drop(inner_iface);
687                     drop(iface);
688                     drop(sockets);
689                     loop {
690                         poll_ifaces();
691                         let mut sockets = SOCKET_SET.lock_irqsave();
692                         let socket = sockets.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().0);
693 
694                         match socket.state() {
695                             tcp::State::Established => {
696                                 return Ok(());
697                             }
698                             tcp::State::SynSent => {
699                                 drop(sockets);
700                                 SocketHandleItem::sleep(
701                                     self.socket_handle(),
702                                     Self::CAN_CONNECT,
703                                     HANDLE_MAP.read_irqsave(),
704                                 );
705                             }
706                             _ => {
707                                 return Err(SystemError::ECONNREFUSED);
708                             }
709                         }
710                     }
711                 }
712                 Err(e) => {
713                     // kerror!("Tcp Socket Connect Error {e:?}");
714                     match e {
715                         tcp::ConnectError::InvalidState => return Err(SystemError::EISCONN),
716                         tcp::ConnectError::Unaddressable => return Err(SystemError::EADDRNOTAVAIL),
717                     }
718                 }
719             }
720         } else {
721             return Err(SystemError::EINVAL);
722         }
723     }
724 
725     /// @brief tcp socket 监听 local_endpoint 端口
726     ///
727     /// @param backlog 未处理的连接队列的最大长度. 由于smoltcp不支持backlog,所以这个参数目前无效
728     fn listen(&mut self, backlog: usize) -> Result<(), SystemError> {
729         if self.is_listening {
730             return Ok(());
731         }
732 
733         let local_endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?;
734         let mut sockets = SOCKET_SET.lock_irqsave();
735         // 获取handle的数量
736         let handlen = self.handles.len();
737         let backlog = handlen.max(backlog);
738 
739         // 添加剩余需要构建的socket
740         // kdebug!("tcp socket:before listen, socket'len={}",self.handle.len());
741         let mut handle_guard = HANDLE_MAP.write_irqsave();
742         self.handles.extend((handlen..backlog).map(|_| {
743             let socket = Self::create_new_socket();
744             let handle = GlobalSocketHandle::new(sockets.add(socket));
745             let handle_item = SocketHandleItem::new();
746             handle_guard.insert(handle.0, handle_item);
747             handle
748         }));
749         // kdebug!("tcp socket:listen, socket'len={}",self.handle.len());
750         // kdebug!("tcp socket:listen, backlog={backlog}");
751 
752         // 监听所有的socket
753         for i in 0..backlog {
754             let handle = self.handles.get(i).unwrap();
755 
756             let socket = sockets.get_mut::<tcp::Socket>(handle.0);
757 
758             if !socket.is_listening() {
759                 // kdebug!("Tcp Socket is already listening on {local_endpoint}");
760                 self.do_listen(socket, local_endpoint)?;
761             }
762             // kdebug!("Tcp Socket  before listen, open={}", socket.is_open());
763         }
764         return Ok(());
765     }
766 
767     fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
768         if let Endpoint::Ip(Some(mut ip)) = endpoint {
769             if ip.port == 0 {
770                 ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
771             }
772 
773             // 检测端口是否已被占用
774             PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port, self.clone())?;
775             // kdebug!("tcp socket:bind, socket'len={}",self.handle.len());
776 
777             self.local_endpoint = Some(ip);
778             self.is_listening = false;
779             return Ok(());
780         }
781         return Err(SystemError::EINVAL);
782     }
783 
784     fn shutdown(&mut self, shutdown_type: super::ShutdownType) -> Result<(), SystemError> {
785         // TODO:目前只是在表层判断,对端不知晓,后续需使用tcp实现
786         HANDLE_MAP
787             .write_irqsave()
788             .get_mut(&self.socket_handle())
789             .unwrap()
790             .shutdown_type = RwLock::new(shutdown_type);
791         return Ok(());
792     }
793 
794     fn accept(&mut self) -> Result<(Box<dyn Socket>, Endpoint), SystemError> {
795         let endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?;
796         loop {
797             // kdebug!("tcp accept: poll_ifaces()");
798             poll_ifaces();
799             // kdebug!("tcp socket:accept, socket'len={}",self.handle.len());
800 
801             let mut sockets = SOCKET_SET.lock_irqsave();
802 
803             // 随机获取访问的socket的handle
804             let index: usize = rand() % self.handles.len();
805             let handle = self.handles.get(index).unwrap();
806             let socket = sockets.get_mut::<tcp::Socket>(handle.0);
807 
808             if socket.is_active() {
809                 // kdebug!("tcp accept: socket.is_active()");
810                 let remote_ep = socket.remote_endpoint().ok_or(SystemError::ENOTCONN)?;
811 
812                 let new_socket = {
813                     // The new TCP socket used for sending and receiving data.
814                     let mut tcp_socket = Self::create_new_socket();
815                     self.do_listen(&mut tcp_socket, endpoint)
816                         .expect("do_listen failed");
817 
818                     // tcp_socket.listen(endpoint).unwrap();
819 
820                     // 之所以把old_handle存入new_socket, 是因为当前时刻,smoltcp已经把old_handle对应的socket与远程的endpoint关联起来了
821                     // 因此需要再为当前的socket分配一个新的handle
822                     let new_handle = GlobalSocketHandle::new(sockets.add(tcp_socket));
823                     let old_handle = ::core::mem::replace(
824                         &mut *self.handles.get_mut(index).unwrap(),
825                         new_handle.clone(),
826                     );
827 
828                     let metadata = SocketMetadata::new(
829                         SocketType::Tcp,
830                         Self::DEFAULT_TX_BUF_SIZE,
831                         Self::DEFAULT_RX_BUF_SIZE,
832                         Self::DEFAULT_METADATA_BUF_SIZE,
833                         self.metadata.options,
834                     );
835 
836                     let new_socket = Box::new(TcpSocket {
837                         handles: vec![old_handle.clone()],
838                         local_endpoint: self.local_endpoint,
839                         is_listening: false,
840                         metadata,
841                     });
842                     // kdebug!("tcp socket:after accept, socket'len={}",new_socket.handle.len());
843 
844                     // 更新端口与 socket 的绑定
845                     if let Some(Endpoint::Ip(Some(ip))) = self.endpoint() {
846                         PORT_MANAGER.unbind_port(self.metadata.socket_type, ip.port)?;
847                         PORT_MANAGER.bind_port(
848                             self.metadata.socket_type,
849                             ip.port,
850                             *new_socket.clone(),
851                         )?;
852                     }
853 
854                     // 更新handle表
855                     let mut handle_guard = HANDLE_MAP.write_irqsave();
856                     // 先删除原来的
857 
858                     let item = handle_guard.remove(&old_handle.0).unwrap();
859 
860                     // 按照smoltcp行为,将新的handle绑定到原来的item
861                     handle_guard.insert(new_handle.0, item);
862                     let new_item = SocketHandleItem::new();
863 
864                     // 插入新的item
865                     handle_guard.insert(old_handle.0, new_item);
866 
867                     new_socket
868                 };
869                 // kdebug!("tcp accept: new socket: {:?}", new_socket);
870                 drop(sockets);
871                 poll_ifaces();
872 
873                 return Ok((new_socket, Endpoint::Ip(Some(remote_ep))));
874             }
875             // kdebug!("tcp socket:before sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len());
876 
877             drop(sockets);
878             SocketHandleItem::sleep(handle.0, Self::CAN_ACCPET, HANDLE_MAP.read_irqsave());
879             // kdebug!("tcp socket:after sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len());
880         }
881     }
882 
883     fn endpoint(&self) -> Option<Endpoint> {
884         let mut result: Option<Endpoint> = self.local_endpoint.map(|x| Endpoint::Ip(Some(x)));
885 
886         if result.is_none() {
887             let sockets = SOCKET_SET.lock_irqsave();
888             // kdebug!("tcp socket:endpoint, socket'len={}",self.handle.len());
889 
890             let socket = sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().0);
891             if let Some(ep) = socket.local_endpoint() {
892                 result = Some(Endpoint::Ip(Some(ep)));
893             }
894         }
895         return result;
896     }
897 
898     fn peer_endpoint(&self) -> Option<Endpoint> {
899         let sockets = SOCKET_SET.lock_irqsave();
900         // kdebug!("tcp socket:peer_endpoint, socket'len={}",self.handle.len());
901 
902         let socket = sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().0);
903         return socket.remote_endpoint().map(|x| Endpoint::Ip(Some(x)));
904     }
905 
906     fn metadata(&self) -> SocketMetadata {
907         self.metadata.clone()
908     }
909 
910     fn box_clone(&self) -> Box<dyn Socket> {
911         Box::new(self.clone())
912     }
913 
914     fn socket_handle(&self) -> SocketHandle {
915         // kdebug!("tcp socket:socket_handle, socket'len={}",self.handle.len());
916 
917         self.handles.get(0).unwrap().0
918     }
919 
920     fn as_any_ref(&self) -> &dyn core::any::Any {
921         self
922     }
923 
924     fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
925         self
926     }
927 }
928