xref: /DragonOS/kernel/src/net/socket/inet.rs (revision 911132c4b8ea0e9c49a4e84b9fa1db114102acbb)
1 use alloc::{boxed::Box, sync::Arc, vec::Vec};
2 use smoltcp::{
3     iface::SocketHandle,
4     socket::{raw, tcp, udp},
5     wire,
6 };
7 use system_error::SystemError;
8 
9 use crate::{
10     driver::net::NetDriver,
11     kerror, kwarn,
12     libs::rwlock::RwLock,
13     net::{
14         event_poll::EPollEventType, net_core::poll_ifaces, Endpoint, Protocol, ShutdownType,
15         NET_DRIVERS,
16     },
17 };
18 
19 use super::{
20     GlobalSocketHandle, Socket, SocketHandleItem, SocketMetadata, SocketOptions, SocketPollMethod,
21     SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET,
22 };
23 
24 /// @brief 表示原始的socket。原始套接字绕过传输层协议(如 TCP 或 UDP)并提供对网络层协议(如 IP)的直接访问。
25 ///
26 /// ref: https://man7.org/linux/man-pages/man7/raw.7.html
27 #[derive(Debug, Clone)]
28 pub struct RawSocket {
29     handle: Arc<GlobalSocketHandle>,
30     /// 用户发送的数据包是否包含了IP头.
31     /// 如果是true,用户发送的数据包,必须包含IP头。(即用户要自行设置IP头+数据)
32     /// 如果是false,用户发送的数据包,不包含IP头。(即用户只要设置数据)
33     header_included: bool,
34     /// socket的metadata
35     metadata: SocketMetadata,
36 }
37 
38 impl RawSocket {
39     /// 元数据的缓冲区的大小
40     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
41     /// 默认的接收缓冲区的大小 receive
42     pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024;
43     /// 默认的发送缓冲区的大小 transmiss
44     pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024;
45 
46     /// @brief 创建一个原始的socket
47     ///
48     /// @param protocol 协议号
49     /// @param options socket的选项
50     ///
51     /// @return 返回创建的原始的socket
52     pub fn new(protocol: Protocol, options: SocketOptions) -> Self {
53         let rx_buffer = raw::PacketBuffer::new(
54             vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
55             vec![0; Self::DEFAULT_RX_BUF_SIZE],
56         );
57         let tx_buffer = raw::PacketBuffer::new(
58             vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
59             vec![0; Self::DEFAULT_TX_BUF_SIZE],
60         );
61         let protocol: u8 = protocol.into();
62         let socket = raw::Socket::new(
63             wire::IpVersion::Ipv4,
64             wire::IpProtocol::from(protocol),
65             rx_buffer,
66             tx_buffer,
67         );
68 
69         // 把socket添加到socket集合中,并得到socket的句柄
70         let handle: Arc<GlobalSocketHandle> =
71             GlobalSocketHandle::new(SOCKET_SET.lock_irqsave().add(socket));
72 
73         let metadata = SocketMetadata::new(
74             SocketType::Raw,
75             Self::DEFAULT_RX_BUF_SIZE,
76             Self::DEFAULT_TX_BUF_SIZE,
77             Self::DEFAULT_METADATA_BUF_SIZE,
78             options,
79         );
80 
81         return Self {
82             handle,
83             header_included: false,
84             metadata,
85         };
86     }
87 }
88 
89 impl Socket for RawSocket {
90     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
91         poll_ifaces();
92         loop {
93             // 如何优化这里?
94             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
95             let socket = socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
96 
97             match socket.recv_slice(buf) {
98                 Ok(len) => {
99                     let packet = wire::Ipv4Packet::new_unchecked(buf);
100                     return (
101                         Ok(len),
102                         Endpoint::Ip(Some(wire::IpEndpoint {
103                             addr: wire::IpAddress::Ipv4(packet.src_addr()),
104                             port: 0,
105                         })),
106                     );
107                 }
108                 Err(raw::RecvError::Exhausted) => {
109                     if !self.metadata.options.contains(SocketOptions::BLOCK) {
110                         // 如果是非阻塞的socket,就返回错误
111                         return (Err(SystemError::EAGAIN_OR_EWOULDBLOCK), Endpoint::Ip(None));
112                     }
113                 }
114             }
115             drop(socket_set_guard);
116             SocketHandleItem::sleep(
117                 self.socket_handle(),
118                 EPollEventType::EPOLLIN.bits() as u64,
119                 HANDLE_MAP.read_irqsave(),
120             );
121         }
122     }
123 
124     fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError> {
125         // 如果用户发送的数据包,包含IP头,则直接发送
126         if self.header_included {
127             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
128             let socket = socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
129             match socket.send_slice(buf) {
130                 Ok(_) => {
131                     return Ok(buf.len());
132                 }
133                 Err(raw::SendError::BufferFull) => {
134                     return Err(SystemError::ENOBUFS);
135                 }
136             }
137         } else {
138             // 如果用户发送的数据包,不包含IP头,则需要自己构造IP头
139 
140             if let Some(Endpoint::Ip(Some(endpoint))) = to {
141                 let mut socket_set_guard = SOCKET_SET.lock_irqsave();
142                 let socket: &mut raw::Socket =
143                     socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
144 
145                 // 暴力解决方案:只考虑0号网卡。 TODO:考虑多网卡的情况!!!
146                 let iface = NET_DRIVERS.read_irqsave().get(&0).unwrap().clone();
147 
148                 // 构造IP头
149                 let ipv4_src_addr: Option<wire::Ipv4Address> =
150                     iface.inner_iface().lock().ipv4_addr();
151                 if ipv4_src_addr.is_none() {
152                     return Err(SystemError::ENETUNREACH);
153                 }
154                 let ipv4_src_addr = ipv4_src_addr.unwrap();
155 
156                 if let wire::IpAddress::Ipv4(ipv4_dst) = endpoint.addr {
157                     let len = buf.len();
158 
159                     // 创建20字节的IPv4头部
160                     let mut buffer: Vec<u8> = vec![0u8; len + 20];
161                     let mut packet: wire::Ipv4Packet<&mut Vec<u8>> =
162                         wire::Ipv4Packet::new_unchecked(&mut buffer);
163 
164                     // 封装ipv4 header
165                     packet.set_version(4);
166                     packet.set_header_len(20);
167                     packet.set_total_len((20 + len) as u16);
168                     packet.set_src_addr(ipv4_src_addr);
169                     packet.set_dst_addr(ipv4_dst);
170 
171                     // 设置ipv4 header的protocol字段
172                     packet.set_next_header(socket.ip_protocol());
173 
174                     // 获取IP数据包的负载字段
175                     let payload: &mut [u8] = packet.payload_mut();
176                     payload.copy_from_slice(buf);
177 
178                     // 填充checksum字段
179                     packet.fill_checksum();
180 
181                     // 发送数据包
182                     socket.send_slice(&buffer).unwrap();
183 
184                     iface.poll(&mut socket_set_guard).ok();
185 
186                     drop(socket_set_guard);
187                     return Ok(len);
188                 } else {
189                     kwarn!("Unsupport Ip protocol type!");
190                     return Err(SystemError::EINVAL);
191                 }
192             } else {
193                 // 如果没有指定目的地址,则返回错误
194                 return Err(SystemError::ENOTCONN);
195             }
196         }
197     }
198 
199     fn connect(&mut self, _endpoint: Endpoint) -> Result<(), SystemError> {
200         Ok(())
201     }
202 
203     fn metadata(&self) -> SocketMetadata {
204         self.metadata.clone()
205     }
206 
207     fn box_clone(&self) -> Box<dyn Socket> {
208         Box::new(self.clone())
209     }
210 
211     fn socket_handle(&self) -> SocketHandle {
212         self.handle.0
213     }
214 
215     fn as_any_ref(&self) -> &dyn core::any::Any {
216         self
217     }
218 
219     fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
220         self
221     }
222 }
223 
224 /// @brief 表示udp socket
225 ///
226 /// https://man7.org/linux/man-pages/man7/udp.7.html
227 #[derive(Debug, Clone)]
228 pub struct UdpSocket {
229     pub handle: Arc<GlobalSocketHandle>,
230     remote_endpoint: Option<Endpoint>, // 记录远程endpoint提供给connect(), 应该使用IP地址。
231     metadata: SocketMetadata,
232 }
233 
234 impl UdpSocket {
235     /// 元数据的缓冲区的大小
236     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
237     /// 默认的接收缓冲区的大小 receive
238     pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024;
239     /// 默认的发送缓冲区的大小 transmiss
240     pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024;
241 
242     /// @brief 创建一个udp的socket
243     ///
244     /// @param options socket的选项
245     ///
246     /// @return 返回创建的udp的socket
247     pub fn new(options: SocketOptions) -> Self {
248         let rx_buffer = udp::PacketBuffer::new(
249             vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
250             vec![0; Self::DEFAULT_RX_BUF_SIZE],
251         );
252         let tx_buffer = udp::PacketBuffer::new(
253             vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
254             vec![0; Self::DEFAULT_TX_BUF_SIZE],
255         );
256         let socket = udp::Socket::new(rx_buffer, tx_buffer);
257 
258         // 把socket添加到socket集合中,并得到socket的句柄
259         let handle: Arc<GlobalSocketHandle> =
260             GlobalSocketHandle::new(SOCKET_SET.lock_irqsave().add(socket));
261 
262         let metadata = SocketMetadata::new(
263             SocketType::Udp,
264             Self::DEFAULT_RX_BUF_SIZE,
265             Self::DEFAULT_TX_BUF_SIZE,
266             Self::DEFAULT_METADATA_BUF_SIZE,
267             options,
268         );
269 
270         return Self {
271             handle,
272             remote_endpoint: None,
273             metadata,
274         };
275     }
276 
277     fn do_bind(&self, socket: &mut udp::Socket, endpoint: Endpoint) -> Result<(), SystemError> {
278         if let Endpoint::Ip(Some(ip)) = endpoint {
279             // 检测端口是否已被占用
280             PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port, self.handle.clone())?;
281 
282             let bind_res = if ip.addr.is_unspecified() {
283                 socket.bind(ip.port)
284             } else {
285                 socket.bind(ip)
286             };
287 
288             match bind_res {
289                 Ok(()) => return Ok(()),
290                 Err(_) => return Err(SystemError::EINVAL),
291             }
292         } else {
293             return Err(SystemError::EINVAL);
294         }
295     }
296 }
297 
298 impl Socket for UdpSocket {
299     /// @brief 在read函数执行之前,请先bind到本地的指定端口
300     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
301         loop {
302             // kdebug!("Wait22 to Read");
303             poll_ifaces();
304             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
305             let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.0);
306 
307             // kdebug!("Wait to Read");
308 
309             if socket.can_recv() {
310                 if let Ok((size, remote_endpoint)) = socket.recv_slice(buf) {
311                     drop(socket_set_guard);
312                     poll_ifaces();
313                     return (Ok(size), Endpoint::Ip(Some(remote_endpoint)));
314                 }
315             } else {
316                 // 如果socket没有连接,则忙等
317                 // return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
318             }
319             drop(socket_set_guard);
320             SocketHandleItem::sleep(
321                 self.socket_handle(),
322                 EPollEventType::EPOLLIN.bits() as u64,
323                 HANDLE_MAP.read_irqsave(),
324             );
325         }
326     }
327 
328     fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError> {
329         // kdebug!("udp to send: {:?}, len={}", to, buf.len());
330         let remote_endpoint: &wire::IpEndpoint = {
331             if let Some(Endpoint::Ip(Some(ref endpoint))) = to {
332                 endpoint
333             } else if let Some(Endpoint::Ip(Some(ref endpoint))) = self.remote_endpoint {
334                 endpoint
335             } else {
336                 return Err(SystemError::ENOTCONN);
337             }
338         };
339         // kdebug!("udp write: remote = {:?}", remote_endpoint);
340 
341         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
342         let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.0);
343         // kdebug!("is open()={}", socket.is_open());
344         // kdebug!("socket endpoint={:?}", socket.endpoint());
345         if socket.endpoint().port == 0 {
346             let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
347 
348             let local_ep = match remote_endpoint.addr {
349                 // 远程remote endpoint使用什么协议,发送的时候使用的协议是一样的吧
350                 // 否则就用 self.endpoint().addr.unwrap()
351                 wire::IpAddress::Ipv4(_) => Endpoint::Ip(Some(wire::IpEndpoint::new(
352                     wire::IpAddress::Ipv4(wire::Ipv4Address::UNSPECIFIED),
353                     temp_port,
354                 ))),
355                 wire::IpAddress::Ipv6(_) => Endpoint::Ip(Some(wire::IpEndpoint::new(
356                     wire::IpAddress::Ipv6(wire::Ipv6Address::UNSPECIFIED),
357                     temp_port,
358                 ))),
359             };
360             // kdebug!("udp write: local_ep = {:?}", local_ep);
361             self.do_bind(socket, local_ep)?;
362         }
363         // kdebug!("is open()={}", socket.is_open());
364         if socket.can_send() {
365             // kdebug!("udp write: can send");
366             match socket.send_slice(buf, *remote_endpoint) {
367                 Ok(()) => {
368                     // kdebug!("udp write: send ok");
369                     drop(socket_set_guard);
370                     poll_ifaces();
371                     return Ok(buf.len());
372                 }
373                 Err(_) => {
374                     // kdebug!("udp write: send err");
375                     return Err(SystemError::ENOBUFS);
376                 }
377             }
378         } else {
379             // kdebug!("udp write: can not send");
380             return Err(SystemError::ENOBUFS);
381         };
382     }
383 
384     fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
385         let mut sockets = SOCKET_SET.lock_irqsave();
386         let socket = sockets.get_mut::<udp::Socket>(self.handle.0);
387         // kdebug!("UDP Bind to {:?}", endpoint);
388         return self.do_bind(socket, endpoint);
389     }
390 
391     fn poll(&self) -> EPollEventType {
392         let sockets = SOCKET_SET.lock_irqsave();
393         let socket = sockets.get::<udp::Socket>(self.handle.0);
394 
395         return SocketPollMethod::udp_poll(
396             socket,
397             HANDLE_MAP
398                 .read_irqsave()
399                 .get(&self.socket_handle())
400                 .unwrap()
401                 .shutdown_type(),
402         );
403     }
404 
405     fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
406         if let Endpoint::Ip(_) = endpoint {
407             self.remote_endpoint = Some(endpoint);
408             Ok(())
409         } else {
410             Err(SystemError::EINVAL)
411         }
412     }
413 
414     fn ioctl(
415         &self,
416         _cmd: usize,
417         _arg0: usize,
418         _arg1: usize,
419         _arg2: usize,
420     ) -> Result<usize, SystemError> {
421         todo!()
422     }
423 
424     fn metadata(&self) -> SocketMetadata {
425         self.metadata.clone()
426     }
427 
428     fn box_clone(&self) -> Box<dyn Socket> {
429         return Box::new(self.clone());
430     }
431 
432     fn endpoint(&self) -> Option<Endpoint> {
433         let sockets = SOCKET_SET.lock_irqsave();
434         let socket = sockets.get::<udp::Socket>(self.handle.0);
435         let listen_endpoint = socket.endpoint();
436 
437         if listen_endpoint.port == 0 {
438             return None;
439         } else {
440             // 如果listen_endpoint的address是None,意味着“监听所有的地址”。
441             // 这里假设所有的地址都是ipv4
442             // TODO: 支持ipv6
443             let result = wire::IpEndpoint::new(
444                 listen_endpoint
445                     .addr
446                     .unwrap_or(wire::IpAddress::v4(0, 0, 0, 0)),
447                 listen_endpoint.port,
448             );
449             return Some(Endpoint::Ip(Some(result)));
450         }
451     }
452 
453     fn peer_endpoint(&self) -> Option<Endpoint> {
454         return self.remote_endpoint.clone();
455     }
456 
457     fn socket_handle(&self) -> SocketHandle {
458         self.handle.0
459     }
460 
461     fn as_any_ref(&self) -> &dyn core::any::Any {
462         self
463     }
464 
465     fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
466         self
467     }
468 }
469 
470 /// @brief 表示 tcp socket
471 ///
472 /// https://man7.org/linux/man-pages/man7/tcp.7.html
473 #[derive(Debug, Clone)]
474 pub struct TcpSocket {
475     handle: Arc<GlobalSocketHandle>,
476     local_endpoint: Option<wire::IpEndpoint>, // save local endpoint for bind()
477     is_listening: bool,
478     metadata: SocketMetadata,
479 }
480 
481 impl TcpSocket {
482     /// 元数据的缓冲区的大小
483     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
484     /// 默认的接收缓冲区的大小 receive
485     pub const DEFAULT_RX_BUF_SIZE: usize = 512 * 1024;
486     /// 默认的发送缓冲区的大小 transmiss
487     pub const DEFAULT_TX_BUF_SIZE: usize = 512 * 1024;
488 
489     /// TcpSocket的特殊事件,用于在事件等待队列上sleep
490     pub const CAN_CONNECT: u64 = 1u64 << 63;
491     pub const CAN_ACCPET: u64 = 1u64 << 62;
492 
493     /// @brief 创建一个tcp的socket
494     ///
495     /// @param options socket的选项
496     ///
497     /// @return 返回创建的tcp的socket
498     pub fn new(options: SocketOptions) -> Self {
499         let rx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_RX_BUF_SIZE]);
500         let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]);
501         let socket = tcp::Socket::new(rx_buffer, tx_buffer);
502 
503         // 把socket添加到socket集合中,并得到socket的句柄
504         let handle: Arc<GlobalSocketHandle> =
505             GlobalSocketHandle::new(SOCKET_SET.lock_irqsave().add(socket));
506 
507         let metadata = SocketMetadata::new(
508             SocketType::Tcp,
509             Self::DEFAULT_RX_BUF_SIZE,
510             Self::DEFAULT_TX_BUF_SIZE,
511             Self::DEFAULT_METADATA_BUF_SIZE,
512             options,
513         );
514 
515         return Self {
516             handle,
517             local_endpoint: None,
518             is_listening: false,
519             metadata,
520         };
521     }
522 
523     fn do_listen(
524         &mut self,
525         socket: &mut tcp::Socket,
526         local_endpoint: wire::IpEndpoint,
527     ) -> Result<(), SystemError> {
528         let listen_result = if local_endpoint.addr.is_unspecified() {
529             // kdebug!("Tcp Socket Listen on port {}", local_endpoint.port);
530             socket.listen(local_endpoint.port)
531         } else {
532             // kdebug!("Tcp Socket Listen on {local_endpoint}");
533             socket.listen(local_endpoint)
534         };
535         // TODO: 增加端口占用检查
536         return match listen_result {
537             Ok(()) => {
538                 // kdebug!(
539                 //     "Tcp Socket Listen on {local_endpoint}, open?:{}",
540                 //     socket.is_open()
541                 // );
542                 self.is_listening = true;
543 
544                 Ok(())
545             }
546             Err(_) => Err(SystemError::EINVAL),
547         };
548     }
549 }
550 
551 impl Socket for TcpSocket {
552     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
553         if HANDLE_MAP
554             .read_irqsave()
555             .get(&self.socket_handle())
556             .unwrap()
557             .shutdown_type()
558             .contains(ShutdownType::RCV_SHUTDOWN)
559         {
560             return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
561         }
562         // kdebug!("tcp socket: read, buf len={}", buf.len());
563 
564         loop {
565             poll_ifaces();
566             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
567             let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handle.0);
568 
569             // 如果socket已经关闭,返回错误
570             if !socket.is_active() {
571                 // kdebug!("Tcp Socket Read Error, socket is closed");
572                 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
573             }
574 
575             if socket.may_recv() {
576                 let recv_res = socket.recv_slice(buf);
577 
578                 if let Ok(size) = recv_res {
579                     if size > 0 {
580                         let endpoint = if let Some(p) = socket.remote_endpoint() {
581                             p
582                         } else {
583                             return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
584                         };
585 
586                         drop(socket_set_guard);
587                         poll_ifaces();
588                         return (Ok(size), Endpoint::Ip(Some(endpoint)));
589                     }
590                 } else {
591                     let err = recv_res.unwrap_err();
592                     match err {
593                         tcp::RecvError::InvalidState => {
594                             kwarn!("Tcp Socket Read Error, InvalidState");
595                             return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
596                         }
597                         tcp::RecvError::Finished => {
598                             // 对端写端已关闭,我们应该关闭读端
599                             HANDLE_MAP
600                                 .write_irqsave()
601                                 .get_mut(&self.socket_handle())
602                                 .unwrap()
603                                 .shutdown_type_writer()
604                                 .insert(ShutdownType::RCV_SHUTDOWN);
605                             return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
606                         }
607                     }
608                 }
609             } else {
610                 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
611             }
612             drop(socket_set_guard);
613             SocketHandleItem::sleep(
614                 self.socket_handle(),
615                 EPollEventType::EPOLLIN.bits() as u64,
616                 HANDLE_MAP.read_irqsave(),
617             );
618         }
619     }
620 
621     fn write(&self, buf: &[u8], _to: Option<Endpoint>) -> Result<usize, SystemError> {
622         if HANDLE_MAP
623             .read_irqsave()
624             .get(&self.socket_handle())
625             .unwrap()
626             .shutdown_type()
627             .contains(ShutdownType::RCV_SHUTDOWN)
628         {
629             return Err(SystemError::ENOTCONN);
630         }
631         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
632         let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handle.0);
633 
634         if socket.is_open() {
635             if socket.can_send() {
636                 match socket.send_slice(buf) {
637                     Ok(size) => {
638                         drop(socket_set_guard);
639                         poll_ifaces();
640                         return Ok(size);
641                     }
642                     Err(e) => {
643                         kerror!("Tcp Socket Write Error {e:?}");
644                         return Err(SystemError::ENOBUFS);
645                     }
646                 }
647             } else {
648                 return Err(SystemError::ENOBUFS);
649             }
650         }
651 
652         return Err(SystemError::ENOTCONN);
653     }
654 
655     fn poll(&self) -> EPollEventType {
656         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
657         let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handle.0);
658 
659         return SocketPollMethod::tcp_poll(
660             socket,
661             HANDLE_MAP
662                 .read_irqsave()
663                 .get(&self.socket_handle())
664                 .unwrap()
665                 .shutdown_type(),
666         );
667     }
668 
669     fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
670         let mut sockets = SOCKET_SET.lock_irqsave();
671         let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
672 
673         if let Endpoint::Ip(Some(ip)) = endpoint {
674             let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
675             // 检测端口是否被占用
676             PORT_MANAGER.bind_port(self.metadata.socket_type, temp_port, self.handle.clone())?;
677 
678             // kdebug!("temp_port: {}", temp_port);
679             let iface: Arc<dyn NetDriver> = NET_DRIVERS.write_irqsave().get(&0).unwrap().clone();
680             let mut inner_iface = iface.inner_iface().lock();
681             // kdebug!("to connect: {ip:?}");
682 
683             match socket.connect(inner_iface.context(), ip, temp_port) {
684                 Ok(()) => {
685                     // avoid deadlock
686                     drop(inner_iface);
687                     drop(iface);
688                     drop(sockets);
689                     loop {
690                         poll_ifaces();
691                         let mut sockets = SOCKET_SET.lock_irqsave();
692                         let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
693 
694                         match socket.state() {
695                             tcp::State::Established => {
696                                 return Ok(());
697                             }
698                             tcp::State::SynSent => {
699                                 drop(sockets);
700                                 SocketHandleItem::sleep(
701                                     self.socket_handle(),
702                                     Self::CAN_CONNECT,
703                                     HANDLE_MAP.read_irqsave(),
704                                 );
705                             }
706                             _ => {
707                                 return Err(SystemError::ECONNREFUSED);
708                             }
709                         }
710                     }
711                 }
712                 Err(e) => {
713                     // kerror!("Tcp Socket Connect Error {e:?}");
714                     match e {
715                         tcp::ConnectError::InvalidState => return Err(SystemError::EISCONN),
716                         tcp::ConnectError::Unaddressable => return Err(SystemError::EADDRNOTAVAIL),
717                     }
718                 }
719             }
720         } else {
721             return Err(SystemError::EINVAL);
722         }
723     }
724 
725     /// @brief tcp socket 监听 local_endpoint 端口
726     ///
727     /// @param backlog 未处理的连接队列的最大长度. 由于smoltcp不支持backlog,所以这个参数目前无效
728     fn listen(&mut self, _backlog: usize) -> Result<(), SystemError> {
729         if self.is_listening {
730             return Ok(());
731         }
732 
733         let local_endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?;
734         let mut sockets = SOCKET_SET.lock_irqsave();
735         let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
736 
737         if socket.is_listening() {
738             // kdebug!("Tcp Socket is already listening on {local_endpoint}");
739             return Ok(());
740         }
741         // kdebug!("Tcp Socket  before listen, open={}", socket.is_open());
742         return self.do_listen(socket, local_endpoint);
743     }
744 
745     fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
746         if let Endpoint::Ip(Some(mut ip)) = endpoint {
747             if ip.port == 0 {
748                 ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
749             }
750 
751             // 检测端口是否已被占用
752             PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port, self.handle.clone())?;
753 
754             self.local_endpoint = Some(ip);
755             self.is_listening = false;
756             return Ok(());
757         }
758         return Err(SystemError::EINVAL);
759     }
760 
761     fn shutdown(&mut self, shutdown_type: super::ShutdownType) -> Result<(), SystemError> {
762         // TODO:目前只是在表层判断,对端不知晓,后续需使用tcp实现
763         HANDLE_MAP
764             .write_irqsave()
765             .get_mut(&self.socket_handle())
766             .unwrap()
767             .shutdown_type = RwLock::new(shutdown_type);
768         return Ok(());
769     }
770 
771     fn accept(&mut self) -> Result<(Box<dyn Socket>, Endpoint), SystemError> {
772         let endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?;
773         loop {
774             // kdebug!("tcp accept: poll_ifaces()");
775             poll_ifaces();
776 
777             let mut sockets = SOCKET_SET.lock_irqsave();
778 
779             let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
780 
781             if socket.is_active() {
782                 // kdebug!("tcp accept: socket.is_active()");
783                 let remote_ep = socket.remote_endpoint().ok_or(SystemError::ENOTCONN)?;
784 
785                 let new_socket = {
786                     // Initialize the TCP socket's buffers.
787                     let rx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_RX_BUF_SIZE]);
788                     let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]);
789                     // The new TCP socket used for sending and receiving data.
790                     let mut tcp_socket = tcp::Socket::new(rx_buffer, tx_buffer);
791                     self.do_listen(&mut tcp_socket, endpoint)
792                         .expect("do_listen failed");
793 
794                     // tcp_socket.listen(endpoint).unwrap();
795 
796                     // 之所以把old_handle存入new_socket, 是因为当前时刻,smoltcp已经把old_handle对应的socket与远程的endpoint关联起来了
797                     // 因此需要再为当前的socket分配一个新的handle
798                     let new_handle = GlobalSocketHandle::new(sockets.add(tcp_socket));
799                     let old_handle = ::core::mem::replace(&mut self.handle, new_handle.clone());
800 
801                     // 更新端口与 handle 的绑定
802                     if let Some(Endpoint::Ip(Some(ip))) = self.endpoint() {
803                         PORT_MANAGER.unbind_port(self.metadata.socket_type, ip.port)?;
804                         PORT_MANAGER.bind_port(
805                             self.metadata.socket_type,
806                             ip.port,
807                             new_handle.clone(),
808                         )?;
809                     }
810 
811                     let metadata = SocketMetadata::new(
812                         SocketType::Tcp,
813                         Self::DEFAULT_TX_BUF_SIZE,
814                         Self::DEFAULT_RX_BUF_SIZE,
815                         Self::DEFAULT_METADATA_BUF_SIZE,
816                         self.metadata.options,
817                     );
818 
819                     let new_socket = Box::new(TcpSocket {
820                         handle: old_handle.clone(),
821                         local_endpoint: self.local_endpoint,
822                         is_listening: false,
823                         metadata,
824                     });
825 
826                     // 更新handle表
827                     let mut handle_guard = HANDLE_MAP.write_irqsave();
828                     // 先删除原来的
829                     let item = handle_guard.remove(&old_handle.0).unwrap();
830                     // 按照smoltcp行为,将新的handle绑定到原来的item
831                     handle_guard.insert(new_handle.0, item);
832                     let new_item = SocketHandleItem::new();
833                     // 插入新的item
834                     handle_guard.insert(old_handle.0, new_item);
835 
836                     new_socket
837                 };
838                 // kdebug!("tcp accept: new socket: {:?}", new_socket);
839                 drop(sockets);
840                 poll_ifaces();
841 
842                 return Ok((new_socket, Endpoint::Ip(Some(remote_ep))));
843             }
844             drop(sockets);
845 
846             SocketHandleItem::sleep(
847                 self.socket_handle(),
848                 Self::CAN_ACCPET,
849                 HANDLE_MAP.read_irqsave(),
850             );
851         }
852     }
853 
854     fn endpoint(&self) -> Option<Endpoint> {
855         let mut result: Option<Endpoint> = self.local_endpoint.map(|x| Endpoint::Ip(Some(x)));
856 
857         if result.is_none() {
858             let sockets = SOCKET_SET.lock_irqsave();
859             let socket = sockets.get::<tcp::Socket>(self.handle.0);
860             if let Some(ep) = socket.local_endpoint() {
861                 result = Some(Endpoint::Ip(Some(ep)));
862             }
863         }
864         return result;
865     }
866 
867     fn peer_endpoint(&self) -> Option<Endpoint> {
868         let sockets = SOCKET_SET.lock_irqsave();
869         let socket = sockets.get::<tcp::Socket>(self.handle.0);
870         return socket.remote_endpoint().map(|x| Endpoint::Ip(Some(x)));
871     }
872 
873     fn metadata(&self) -> SocketMetadata {
874         self.metadata.clone()
875     }
876 
877     fn box_clone(&self) -> Box<dyn Socket> {
878         Box::new(self.clone())
879     }
880 
881     fn socket_handle(&self) -> SocketHandle {
882         self.handle.0
883     }
884 
885     fn as_any_ref(&self) -> &dyn core::any::Any {
886         self
887     }
888 
889     fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
890         self
891     }
892 }
893