xref: /DragonOS/kernel/src/net/socket/inet.rs (revision 4b0170bd6bb374d0e9699a0076cc23b976ad6db7)
1 use alloc::{boxed::Box, sync::Arc, vec::Vec};
2 use smoltcp::{
3     iface::SocketHandle,
4     socket::{raw, tcp, udp},
5     wire,
6 };
7 use system_error::SystemError;
8 
9 use crate::{
10     driver::net::NetDriver,
11     kerror, kwarn,
12     libs::rwlock::RwLock,
13     net::{
14         event_poll::EPollEventType, net_core::poll_ifaces, Endpoint, Protocol, ShutdownType,
15         NET_DRIVERS,
16     },
17 };
18 
19 use super::{
20     GlobalSocketHandle, Socket, SocketHandleItem, SocketMetadata, SocketOptions, SocketPollMethod,
21     SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET,
22 };
23 
24 /// @brief 表示原始的socket。原始套接字绕过传输层协议(如 TCP 或 UDP)并提供对网络层协议(如 IP)的直接访问。
25 ///
26 /// ref: https://man7.org/linux/man-pages/man7/raw.7.html
27 #[derive(Debug, Clone)]
28 pub struct RawSocket {
29     handle: Arc<GlobalSocketHandle>,
30     /// 用户发送的数据包是否包含了IP头.
31     /// 如果是true,用户发送的数据包,必须包含IP头。(即用户要自行设置IP头+数据)
32     /// 如果是false,用户发送的数据包,不包含IP头。(即用户只要设置数据)
33     header_included: bool,
34     /// socket的metadata
35     metadata: SocketMetadata,
36 }
37 
38 impl RawSocket {
39     /// 元数据的缓冲区的大小
40     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
41     /// 默认的接收缓冲区的大小 receive
42     pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024;
43     /// 默认的发送缓冲区的大小 transmiss
44     pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024;
45 
46     /// @brief 创建一个原始的socket
47     ///
48     /// @param protocol 协议号
49     /// @param options socket的选项
50     ///
51     /// @return 返回创建的原始的socket
52     pub fn new(protocol: Protocol, options: SocketOptions) -> Self {
53         let rx_buffer = raw::PacketBuffer::new(
54             vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
55             vec![0; Self::DEFAULT_RX_BUF_SIZE],
56         );
57         let tx_buffer = raw::PacketBuffer::new(
58             vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
59             vec![0; Self::DEFAULT_TX_BUF_SIZE],
60         );
61         let protocol: u8 = protocol.into();
62         let socket = raw::Socket::new(
63             wire::IpVersion::Ipv4,
64             wire::IpProtocol::from(protocol),
65             rx_buffer,
66             tx_buffer,
67         );
68 
69         // 把socket添加到socket集合中,并得到socket的句柄
70         let handle: Arc<GlobalSocketHandle> =
71             GlobalSocketHandle::new(SOCKET_SET.lock_irqsave().add(socket));
72 
73         let metadata = SocketMetadata::new(
74             SocketType::Raw,
75             Self::DEFAULT_RX_BUF_SIZE,
76             Self::DEFAULT_TX_BUF_SIZE,
77             Self::DEFAULT_METADATA_BUF_SIZE,
78             options,
79         );
80 
81         return Self {
82             handle,
83             header_included: false,
84             metadata,
85         };
86     }
87 }
88 
89 impl Socket for RawSocket {
90     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
91         poll_ifaces();
92         loop {
93             // 如何优化这里?
94             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
95             let socket = socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
96 
97             match socket.recv_slice(buf) {
98                 Ok(len) => {
99                     let packet = wire::Ipv4Packet::new_unchecked(buf);
100                     return (
101                         Ok(len),
102                         Endpoint::Ip(Some(wire::IpEndpoint {
103                             addr: wire::IpAddress::Ipv4(packet.src_addr()),
104                             port: 0,
105                         })),
106                     );
107                 }
108                 Err(raw::RecvError::Exhausted) => {
109                     if !self.metadata.options.contains(SocketOptions::BLOCK) {
110                         // 如果是非阻塞的socket,就返回错误
111                         return (Err(SystemError::EAGAIN_OR_EWOULDBLOCK), Endpoint::Ip(None));
112                     }
113                 }
114             }
115             drop(socket_set_guard);
116             SocketHandleItem::sleep(
117                 self.socket_handle(),
118                 EPollEventType::EPOLLIN.bits() as u64,
119                 HANDLE_MAP.read_irqsave(),
120             );
121         }
122     }
123 
124     fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError> {
125         // 如果用户发送的数据包,包含IP头,则直接发送
126         if self.header_included {
127             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
128             let socket = socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
129             match socket.send_slice(buf) {
130                 Ok(_) => {
131                     return Ok(buf.len());
132                 }
133                 Err(raw::SendError::BufferFull) => {
134                     return Err(SystemError::ENOBUFS);
135                 }
136             }
137         } else {
138             // 如果用户发送的数据包,不包含IP头,则需要自己构造IP头
139 
140             if let Some(Endpoint::Ip(Some(endpoint))) = to {
141                 let mut socket_set_guard = SOCKET_SET.lock_irqsave();
142                 let socket: &mut raw::Socket =
143                     socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
144 
145                 // 暴力解决方案:只考虑0号网卡。 TODO:考虑多网卡的情况!!!
146                 let iface = NET_DRIVERS.read_irqsave().get(&0).unwrap().clone();
147 
148                 // 构造IP头
149                 let ipv4_src_addr: Option<wire::Ipv4Address> =
150                     iface.inner_iface().lock().ipv4_addr();
151                 if ipv4_src_addr.is_none() {
152                     return Err(SystemError::ENETUNREACH);
153                 }
154                 let ipv4_src_addr = ipv4_src_addr.unwrap();
155 
156                 if let wire::IpAddress::Ipv4(ipv4_dst) = endpoint.addr {
157                     let len = buf.len();
158 
159                     // 创建20字节的IPv4头部
160                     let mut buffer: Vec<u8> = vec![0u8; len + 20];
161                     let mut packet: wire::Ipv4Packet<&mut Vec<u8>> =
162                         wire::Ipv4Packet::new_unchecked(&mut buffer);
163 
164                     // 封装ipv4 header
165                     packet.set_version(4);
166                     packet.set_header_len(20);
167                     packet.set_total_len((20 + len) as u16);
168                     packet.set_src_addr(ipv4_src_addr);
169                     packet.set_dst_addr(ipv4_dst);
170 
171                     // 设置ipv4 header的protocol字段
172                     packet.set_next_header(socket.ip_protocol());
173 
174                     // 获取IP数据包的负载字段
175                     let payload: &mut [u8] = packet.payload_mut();
176                     payload.copy_from_slice(buf);
177 
178                     // 填充checksum字段
179                     packet.fill_checksum();
180 
181                     // 发送数据包
182                     socket.send_slice(&buffer).unwrap();
183 
184                     iface.poll(&mut socket_set_guard).ok();
185 
186                     drop(socket_set_guard);
187                     return Ok(len);
188                 } else {
189                     kwarn!("Unsupport Ip protocol type!");
190                     return Err(SystemError::EINVAL);
191                 }
192             } else {
193                 // 如果没有指定目的地址,则返回错误
194                 return Err(SystemError::ENOTCONN);
195             }
196         }
197     }
198 
199     fn connect(&mut self, _endpoint: Endpoint) -> Result<(), SystemError> {
200         Ok(())
201     }
202 
203     fn metadata(&self) -> SocketMetadata {
204         self.metadata.clone()
205     }
206 
207     fn box_clone(&self) -> Box<dyn Socket> {
208         Box::new(self.clone())
209     }
210 
211     fn socket_handle(&self) -> SocketHandle {
212         self.handle.0
213     }
214 
215     fn as_any_ref(&self) -> &dyn core::any::Any {
216         self
217     }
218 
219     fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
220         self
221     }
222 }
223 
224 /// @brief 表示udp socket
225 ///
226 /// https://man7.org/linux/man-pages/man7/udp.7.html
227 #[derive(Debug, Clone)]
228 pub struct UdpSocket {
229     pub handle: Arc<GlobalSocketHandle>,
230     remote_endpoint: Option<Endpoint>, // 记录远程endpoint提供给connect(), 应该使用IP地址。
231     metadata: SocketMetadata,
232 }
233 
234 impl UdpSocket {
235     /// 元数据的缓冲区的大小
236     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
237     /// 默认的接收缓冲区的大小 receive
238     pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024;
239     /// 默认的发送缓冲区的大小 transmiss
240     pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024;
241 
242     /// @brief 创建一个udp的socket
243     ///
244     /// @param options socket的选项
245     ///
246     /// @return 返回创建的udp的socket
247     pub fn new(options: SocketOptions) -> Self {
248         let rx_buffer = udp::PacketBuffer::new(
249             vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
250             vec![0; Self::DEFAULT_RX_BUF_SIZE],
251         );
252         let tx_buffer = udp::PacketBuffer::new(
253             vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
254             vec![0; Self::DEFAULT_TX_BUF_SIZE],
255         );
256         let socket = udp::Socket::new(rx_buffer, tx_buffer);
257 
258         // 把socket添加到socket集合中,并得到socket的句柄
259         let handle: Arc<GlobalSocketHandle> =
260             GlobalSocketHandle::new(SOCKET_SET.lock_irqsave().add(socket));
261 
262         let metadata = SocketMetadata::new(
263             SocketType::Udp,
264             Self::DEFAULT_RX_BUF_SIZE,
265             Self::DEFAULT_TX_BUF_SIZE,
266             Self::DEFAULT_METADATA_BUF_SIZE,
267             options,
268         );
269 
270         return Self {
271             handle,
272             remote_endpoint: None,
273             metadata,
274         };
275     }
276 
277     fn do_bind(&self, socket: &mut udp::Socket, endpoint: Endpoint) -> Result<(), SystemError> {
278         if let Endpoint::Ip(Some(mut ip)) = endpoint {
279             // 端口为0则分配随机端口
280             if ip.port == 0 {
281                 ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
282             }
283             // 检测端口是否已被占用
284             PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port, self.handle.clone())?;
285 
286             let bind_res = if ip.addr.is_unspecified() {
287                 socket.bind(ip.port)
288             } else {
289                 socket.bind(ip)
290             };
291 
292             match bind_res {
293                 Ok(()) => return Ok(()),
294                 Err(_) => return Err(SystemError::EINVAL),
295             }
296         } else {
297             return Err(SystemError::EINVAL);
298         }
299     }
300 }
301 
302 impl Socket for UdpSocket {
303     /// @brief 在read函数执行之前,请先bind到本地的指定端口
304     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
305         loop {
306             // kdebug!("Wait22 to Read");
307             poll_ifaces();
308             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
309             let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.0);
310 
311             // kdebug!("Wait to Read");
312 
313             if socket.can_recv() {
314                 if let Ok((size, remote_endpoint)) = socket.recv_slice(buf) {
315                     drop(socket_set_guard);
316                     poll_ifaces();
317                     return (Ok(size), Endpoint::Ip(Some(remote_endpoint)));
318                 }
319             } else {
320                 // 如果socket没有连接,则忙等
321                 // return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
322             }
323             drop(socket_set_guard);
324             SocketHandleItem::sleep(
325                 self.socket_handle(),
326                 EPollEventType::EPOLLIN.bits() as u64,
327                 HANDLE_MAP.read_irqsave(),
328             );
329         }
330     }
331 
332     fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError> {
333         // kdebug!("udp to send: {:?}, len={}", to, buf.len());
334         let remote_endpoint: &wire::IpEndpoint = {
335             if let Some(Endpoint::Ip(Some(ref endpoint))) = to {
336                 endpoint
337             } else if let Some(Endpoint::Ip(Some(ref endpoint))) = self.remote_endpoint {
338                 endpoint
339             } else {
340                 return Err(SystemError::ENOTCONN);
341             }
342         };
343         // kdebug!("udp write: remote = {:?}", remote_endpoint);
344 
345         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
346         let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.0);
347         // kdebug!("is open()={}", socket.is_open());
348         // kdebug!("socket endpoint={:?}", socket.endpoint());
349         if socket.can_send() {
350             // kdebug!("udp write: can send");
351             match socket.send_slice(buf, *remote_endpoint) {
352                 Ok(()) => {
353                     // kdebug!("udp write: send ok");
354                     drop(socket_set_guard);
355                     poll_ifaces();
356                     return Ok(buf.len());
357                 }
358                 Err(_) => {
359                     // kdebug!("udp write: send err");
360                     return Err(SystemError::ENOBUFS);
361                 }
362             }
363         } else {
364             // kdebug!("udp write: can not send");
365             return Err(SystemError::ENOBUFS);
366         };
367     }
368 
369     fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
370         let mut sockets = SOCKET_SET.lock_irqsave();
371         let socket = sockets.get_mut::<udp::Socket>(self.handle.0);
372         // kdebug!("UDP Bind to {:?}", endpoint);
373         return self.do_bind(socket, endpoint);
374     }
375 
376     fn poll(&self) -> EPollEventType {
377         let sockets = SOCKET_SET.lock_irqsave();
378         let socket = sockets.get::<udp::Socket>(self.handle.0);
379 
380         return SocketPollMethod::udp_poll(
381             socket,
382             HANDLE_MAP
383                 .read_irqsave()
384                 .get(&self.socket_handle())
385                 .unwrap()
386                 .shutdown_type(),
387         );
388     }
389 
390     fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
391         if let Endpoint::Ip(_) = endpoint {
392             self.remote_endpoint = Some(endpoint);
393             Ok(())
394         } else {
395             Err(SystemError::EINVAL)
396         }
397     }
398 
399     fn ioctl(
400         &self,
401         _cmd: usize,
402         _arg0: usize,
403         _arg1: usize,
404         _arg2: usize,
405     ) -> Result<usize, SystemError> {
406         todo!()
407     }
408 
409     fn metadata(&self) -> SocketMetadata {
410         self.metadata.clone()
411     }
412 
413     fn box_clone(&self) -> Box<dyn Socket> {
414         return Box::new(self.clone());
415     }
416 
417     fn endpoint(&self) -> Option<Endpoint> {
418         let sockets = SOCKET_SET.lock_irqsave();
419         let socket = sockets.get::<udp::Socket>(self.handle.0);
420         let listen_endpoint = socket.endpoint();
421 
422         if listen_endpoint.port == 0 {
423             return None;
424         } else {
425             // 如果listen_endpoint的address是None,意味着“监听所有的地址”。
426             // 这里假设所有的地址都是ipv4
427             // TODO: 支持ipv6
428             let result = wire::IpEndpoint::new(
429                 listen_endpoint
430                     .addr
431                     .unwrap_or(wire::IpAddress::v4(0, 0, 0, 0)),
432                 listen_endpoint.port,
433             );
434             return Some(Endpoint::Ip(Some(result)));
435         }
436     }
437 
438     fn peer_endpoint(&self) -> Option<Endpoint> {
439         return self.remote_endpoint.clone();
440     }
441 
442     fn socket_handle(&self) -> SocketHandle {
443         self.handle.0
444     }
445 
446     fn as_any_ref(&self) -> &dyn core::any::Any {
447         self
448     }
449 
450     fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
451         self
452     }
453 }
454 
455 /// @brief 表示 tcp socket
456 ///
457 /// https://man7.org/linux/man-pages/man7/tcp.7.html
458 #[derive(Debug, Clone)]
459 pub struct TcpSocket {
460     handle: Arc<GlobalSocketHandle>,
461     local_endpoint: Option<wire::IpEndpoint>, // save local endpoint for bind()
462     is_listening: bool,
463     metadata: SocketMetadata,
464 }
465 
466 impl TcpSocket {
467     /// 元数据的缓冲区的大小
468     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
469     /// 默认的接收缓冲区的大小 receive
470     pub const DEFAULT_RX_BUF_SIZE: usize = 512 * 1024;
471     /// 默认的发送缓冲区的大小 transmiss
472     pub const DEFAULT_TX_BUF_SIZE: usize = 512 * 1024;
473 
474     /// TcpSocket的特殊事件,用于在事件等待队列上sleep
475     pub const CAN_CONNECT: u64 = 1u64 << 63;
476     pub const CAN_ACCPET: u64 = 1u64 << 62;
477 
478     /// @brief 创建一个tcp的socket
479     ///
480     /// @param options socket的选项
481     ///
482     /// @return 返回创建的tcp的socket
483     pub fn new(options: SocketOptions) -> Self {
484         let rx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_RX_BUF_SIZE]);
485         let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]);
486         let socket = tcp::Socket::new(rx_buffer, tx_buffer);
487 
488         // 把socket添加到socket集合中,并得到socket的句柄
489         let handle: Arc<GlobalSocketHandle> =
490             GlobalSocketHandle::new(SOCKET_SET.lock_irqsave().add(socket));
491 
492         let metadata = SocketMetadata::new(
493             SocketType::Tcp,
494             Self::DEFAULT_RX_BUF_SIZE,
495             Self::DEFAULT_TX_BUF_SIZE,
496             Self::DEFAULT_METADATA_BUF_SIZE,
497             options,
498         );
499 
500         return Self {
501             handle,
502             local_endpoint: None,
503             is_listening: false,
504             metadata,
505         };
506     }
507 
508     fn do_listen(
509         &mut self,
510         socket: &mut tcp::Socket,
511         local_endpoint: wire::IpEndpoint,
512     ) -> Result<(), SystemError> {
513         let listen_result = if local_endpoint.addr.is_unspecified() {
514             // kdebug!("Tcp Socket Listen on port {}", local_endpoint.port);
515             socket.listen(local_endpoint.port)
516         } else {
517             // kdebug!("Tcp Socket Listen on {local_endpoint}");
518             socket.listen(local_endpoint)
519         };
520         // TODO: 增加端口占用检查
521         return match listen_result {
522             Ok(()) => {
523                 // kdebug!(
524                 //     "Tcp Socket Listen on {local_endpoint}, open?:{}",
525                 //     socket.is_open()
526                 // );
527                 self.is_listening = true;
528 
529                 Ok(())
530             }
531             Err(_) => Err(SystemError::EINVAL),
532         };
533     }
534 }
535 
536 impl Socket for TcpSocket {
537     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
538         if HANDLE_MAP
539             .read_irqsave()
540             .get(&self.socket_handle())
541             .unwrap()
542             .shutdown_type()
543             .contains(ShutdownType::RCV_SHUTDOWN)
544         {
545             return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
546         }
547         // kdebug!("tcp socket: read, buf len={}", buf.len());
548 
549         loop {
550             poll_ifaces();
551             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
552             let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handle.0);
553 
554             // 如果socket已经关闭,返回错误
555             if !socket.is_active() {
556                 // kdebug!("Tcp Socket Read Error, socket is closed");
557                 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
558             }
559 
560             if socket.may_recv() {
561                 let recv_res = socket.recv_slice(buf);
562 
563                 if let Ok(size) = recv_res {
564                     if size > 0 {
565                         let endpoint = if let Some(p) = socket.remote_endpoint() {
566                             p
567                         } else {
568                             return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
569                         };
570 
571                         drop(socket_set_guard);
572                         poll_ifaces();
573                         return (Ok(size), Endpoint::Ip(Some(endpoint)));
574                     }
575                 } else {
576                     let err = recv_res.unwrap_err();
577                     match err {
578                         tcp::RecvError::InvalidState => {
579                             kwarn!("Tcp Socket Read Error, InvalidState");
580                             return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
581                         }
582                         tcp::RecvError::Finished => {
583                             // 对端写端已关闭,我们应该关闭读端
584                             HANDLE_MAP
585                                 .write_irqsave()
586                                 .get_mut(&self.socket_handle())
587                                 .unwrap()
588                                 .shutdown_type_writer()
589                                 .insert(ShutdownType::RCV_SHUTDOWN);
590                             return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
591                         }
592                     }
593                 }
594             } else {
595                 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
596             }
597             drop(socket_set_guard);
598             SocketHandleItem::sleep(
599                 self.socket_handle(),
600                 EPollEventType::EPOLLIN.bits() as u64,
601                 HANDLE_MAP.read_irqsave(),
602             );
603         }
604     }
605 
606     fn write(&self, buf: &[u8], _to: Option<Endpoint>) -> Result<usize, SystemError> {
607         if HANDLE_MAP
608             .read_irqsave()
609             .get(&self.socket_handle())
610             .unwrap()
611             .shutdown_type()
612             .contains(ShutdownType::RCV_SHUTDOWN)
613         {
614             return Err(SystemError::ENOTCONN);
615         }
616         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
617         let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handle.0);
618 
619         if socket.is_open() {
620             if socket.can_send() {
621                 match socket.send_slice(buf) {
622                     Ok(size) => {
623                         drop(socket_set_guard);
624                         poll_ifaces();
625                         return Ok(size);
626                     }
627                     Err(e) => {
628                         kerror!("Tcp Socket Write Error {e:?}");
629                         return Err(SystemError::ENOBUFS);
630                     }
631                 }
632             } else {
633                 return Err(SystemError::ENOBUFS);
634             }
635         }
636 
637         return Err(SystemError::ENOTCONN);
638     }
639 
640     fn poll(&self) -> EPollEventType {
641         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
642         let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handle.0);
643 
644         return SocketPollMethod::tcp_poll(
645             socket,
646             HANDLE_MAP
647                 .read_irqsave()
648                 .get(&self.socket_handle())
649                 .unwrap()
650                 .shutdown_type(),
651         );
652     }
653 
654     fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
655         let mut sockets = SOCKET_SET.lock_irqsave();
656         let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
657 
658         if let Endpoint::Ip(Some(ip)) = endpoint {
659             let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
660             // 检测端口是否被占用
661             PORT_MANAGER.bind_port(self.metadata.socket_type, temp_port, self.handle.clone())?;
662 
663             // kdebug!("temp_port: {}", temp_port);
664             let iface: Arc<dyn NetDriver> = NET_DRIVERS.write_irqsave().get(&0).unwrap().clone();
665             let mut inner_iface = iface.inner_iface().lock();
666             // kdebug!("to connect: {ip:?}");
667 
668             match socket.connect(inner_iface.context(), ip, temp_port) {
669                 Ok(()) => {
670                     // avoid deadlock
671                     drop(inner_iface);
672                     drop(iface);
673                     drop(sockets);
674                     loop {
675                         poll_ifaces();
676                         let mut sockets = SOCKET_SET.lock_irqsave();
677                         let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
678 
679                         match socket.state() {
680                             tcp::State::Established => {
681                                 return Ok(());
682                             }
683                             tcp::State::SynSent => {
684                                 drop(sockets);
685                                 SocketHandleItem::sleep(
686                                     self.socket_handle(),
687                                     Self::CAN_CONNECT,
688                                     HANDLE_MAP.read_irqsave(),
689                                 );
690                             }
691                             _ => {
692                                 return Err(SystemError::ECONNREFUSED);
693                             }
694                         }
695                     }
696                 }
697                 Err(e) => {
698                     // kerror!("Tcp Socket Connect Error {e:?}");
699                     match e {
700                         tcp::ConnectError::InvalidState => return Err(SystemError::EISCONN),
701                         tcp::ConnectError::Unaddressable => return Err(SystemError::EADDRNOTAVAIL),
702                     }
703                 }
704             }
705         } else {
706             return Err(SystemError::EINVAL);
707         }
708     }
709 
710     /// @brief tcp socket 监听 local_endpoint 端口
711     ///
712     /// @param backlog 未处理的连接队列的最大长度. 由于smoltcp不支持backlog,所以这个参数目前无效
713     fn listen(&mut self, _backlog: usize) -> Result<(), SystemError> {
714         if self.is_listening {
715             return Ok(());
716         }
717 
718         let local_endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?;
719         let mut sockets = SOCKET_SET.lock_irqsave();
720         let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
721 
722         if socket.is_listening() {
723             // kdebug!("Tcp Socket is already listening on {local_endpoint}");
724             return Ok(());
725         }
726         // kdebug!("Tcp Socket  before listen, open={}", socket.is_open());
727         return self.do_listen(socket, local_endpoint);
728     }
729 
730     fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
731         if let Endpoint::Ip(Some(mut ip)) = endpoint {
732             if ip.port == 0 {
733                 ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
734             }
735 
736             // 检测端口是否已被占用
737             PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port, self.handle.clone())?;
738 
739             self.local_endpoint = Some(ip);
740             self.is_listening = false;
741             return Ok(());
742         }
743         return Err(SystemError::EINVAL);
744     }
745 
746     fn shutdown(&mut self, shutdown_type: super::ShutdownType) -> Result<(), SystemError> {
747         // TODO:目前只是在表层判断,对端不知晓,后续需使用tcp实现
748         HANDLE_MAP
749             .write_irqsave()
750             .get_mut(&self.socket_handle())
751             .unwrap()
752             .shutdown_type = RwLock::new(shutdown_type);
753         return Ok(());
754     }
755 
756     fn accept(&mut self) -> Result<(Box<dyn Socket>, Endpoint), SystemError> {
757         let endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?;
758         loop {
759             // kdebug!("tcp accept: poll_ifaces()");
760             poll_ifaces();
761 
762             let mut sockets = SOCKET_SET.lock_irqsave();
763 
764             let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
765 
766             if socket.is_active() {
767                 // kdebug!("tcp accept: socket.is_active()");
768                 let remote_ep = socket.remote_endpoint().ok_or(SystemError::ENOTCONN)?;
769 
770                 let new_socket = {
771                     // Initialize the TCP socket's buffers.
772                     let rx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_RX_BUF_SIZE]);
773                     let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]);
774                     // The new TCP socket used for sending and receiving data.
775                     let mut tcp_socket = tcp::Socket::new(rx_buffer, tx_buffer);
776                     self.do_listen(&mut tcp_socket, endpoint)
777                         .expect("do_listen failed");
778 
779                     // tcp_socket.listen(endpoint).unwrap();
780 
781                     // 之所以把old_handle存入new_socket, 是因为当前时刻,smoltcp已经把old_handle对应的socket与远程的endpoint关联起来了
782                     // 因此需要再为当前的socket分配一个新的handle
783                     let new_handle = GlobalSocketHandle::new(sockets.add(tcp_socket));
784                     let old_handle = ::core::mem::replace(&mut self.handle, new_handle.clone());
785 
786                     // 更新端口与 handle 的绑定
787                     if let Some(Endpoint::Ip(Some(ip))) = self.endpoint() {
788                         PORT_MANAGER.unbind_port(self.metadata.socket_type, ip.port)?;
789                         PORT_MANAGER.bind_port(
790                             self.metadata.socket_type,
791                             ip.port,
792                             new_handle.clone(),
793                         )?;
794                     }
795 
796                     let metadata = SocketMetadata::new(
797                         SocketType::Tcp,
798                         Self::DEFAULT_TX_BUF_SIZE,
799                         Self::DEFAULT_RX_BUF_SIZE,
800                         Self::DEFAULT_METADATA_BUF_SIZE,
801                         self.metadata.options,
802                     );
803 
804                     let new_socket = Box::new(TcpSocket {
805                         handle: old_handle.clone(),
806                         local_endpoint: self.local_endpoint,
807                         is_listening: false,
808                         metadata,
809                     });
810 
811                     // 更新handle表
812                     let mut handle_guard = HANDLE_MAP.write_irqsave();
813                     // 先删除原来的
814                     let item = handle_guard.remove(&old_handle.0).unwrap();
815                     // 按照smoltcp行为,将新的handle绑定到原来的item
816                     handle_guard.insert(new_handle.0, item);
817                     let new_item = SocketHandleItem::new();
818                     // 插入新的item
819                     handle_guard.insert(old_handle.0, new_item);
820 
821                     new_socket
822                 };
823                 // kdebug!("tcp accept: new socket: {:?}", new_socket);
824                 drop(sockets);
825                 poll_ifaces();
826 
827                 return Ok((new_socket, Endpoint::Ip(Some(remote_ep))));
828             }
829             drop(sockets);
830 
831             SocketHandleItem::sleep(
832                 self.socket_handle(),
833                 Self::CAN_ACCPET,
834                 HANDLE_MAP.read_irqsave(),
835             );
836         }
837     }
838 
839     fn endpoint(&self) -> Option<Endpoint> {
840         let mut result: Option<Endpoint> = self.local_endpoint.map(|x| Endpoint::Ip(Some(x)));
841 
842         if result.is_none() {
843             let sockets = SOCKET_SET.lock_irqsave();
844             let socket = sockets.get::<tcp::Socket>(self.handle.0);
845             if let Some(ep) = socket.local_endpoint() {
846                 result = Some(Endpoint::Ip(Some(ep)));
847             }
848         }
849         return result;
850     }
851 
852     fn peer_endpoint(&self) -> Option<Endpoint> {
853         let sockets = SOCKET_SET.lock_irqsave();
854         let socket = sockets.get::<tcp::Socket>(self.handle.0);
855         return socket.remote_endpoint().map(|x| Endpoint::Ip(Some(x)));
856     }
857 
858     fn metadata(&self) -> SocketMetadata {
859         self.metadata.clone()
860     }
861 
862     fn box_clone(&self) -> Box<dyn Socket> {
863         Box::new(self.clone())
864     }
865 
866     fn socket_handle(&self) -> SocketHandle {
867         self.handle.0
868     }
869 
870     fn as_any_ref(&self) -> &dyn core::any::Any {
871         self
872     }
873 
874     fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
875         self
876     }
877 }
878