xref: /DragonOS/kernel/src/net/socket/inet.rs (revision dd8e74ef0d7f91a141bd217736bef4fe7dc6df3d)
1 use alloc::{boxed::Box, sync::Arc, vec::Vec};
2 use smoltcp::{
3     socket::{raw, tcp, udp, AnySocket},
4     wire,
5 };
6 use system_error::SystemError;
7 
8 use crate::{
9     arch::rand::rand,
10     driver::net::NetDevice,
11     kerror, kwarn,
12     libs::rwlock::RwLock,
13     net::{
14         event_poll::EPollEventType, net_core::poll_ifaces, Endpoint, Protocol, ShutdownType,
15         NET_DEVICES,
16     },
17 };
18 
19 use super::{
20     handle::GlobalSocketHandle, Socket, SocketHandleItem, SocketMetadata, SocketOptions,
21     SocketPollMethod, SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET,
22 };
23 
24 /// @brief 表示原始的socket。原始套接字绕过传输层协议(如 TCP 或 UDP)并提供对网络层协议(如 IP)的直接访问。
25 ///
26 /// ref: https://man7.org/linux/man-pages/man7/raw.7.html
27 #[derive(Debug, Clone)]
28 pub struct RawSocket {
29     handle: GlobalSocketHandle,
30     /// 用户发送的数据包是否包含了IP头.
31     /// 如果是true,用户发送的数据包,必须包含IP头。(即用户要自行设置IP头+数据)
32     /// 如果是false,用户发送的数据包,不包含IP头。(即用户只要设置数据)
33     header_included: bool,
34     /// socket的metadata
35     metadata: SocketMetadata,
36 }
37 
38 impl RawSocket {
39     /// 元数据的缓冲区的大小
40     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
41     /// 默认的接收缓冲区的大小 receive
42     pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024;
43     /// 默认的发送缓冲区的大小 transmiss
44     pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024;
45 
46     /// @brief 创建一个原始的socket
47     ///
48     /// @param protocol 协议号
49     /// @param options socket的选项
50     ///
51     /// @return 返回创建的原始的socket
52     pub fn new(protocol: Protocol, options: SocketOptions) -> Self {
53         let rx_buffer = raw::PacketBuffer::new(
54             vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
55             vec![0; Self::DEFAULT_RX_BUF_SIZE],
56         );
57         let tx_buffer = raw::PacketBuffer::new(
58             vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
59             vec![0; Self::DEFAULT_TX_BUF_SIZE],
60         );
61         let protocol: u8 = protocol.into();
62         let socket = raw::Socket::new(
63             wire::IpVersion::Ipv4,
64             wire::IpProtocol::from(protocol),
65             rx_buffer,
66             tx_buffer,
67         );
68 
69         // 把socket添加到socket集合中,并得到socket的句柄
70         let handle = GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket));
71 
72         let metadata = SocketMetadata::new(
73             SocketType::Raw,
74             Self::DEFAULT_RX_BUF_SIZE,
75             Self::DEFAULT_TX_BUF_SIZE,
76             Self::DEFAULT_METADATA_BUF_SIZE,
77             options,
78         );
79 
80         return Self {
81             handle,
82             header_included: false,
83             metadata,
84         };
85     }
86 }
87 
88 impl Socket for RawSocket {
89     fn close(&mut self) {
90         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
91         socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()); // 删除的时候,会发送一条FINISH的信息?
92         drop(socket_set_guard);
93         poll_ifaces();
94     }
95 
96     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
97         poll_ifaces();
98         loop {
99             // 如何优化这里?
100             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
101             let socket =
102                 socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap());
103 
104             match socket.recv_slice(buf) {
105                 Ok(len) => {
106                     let packet = wire::Ipv4Packet::new_unchecked(buf);
107                     return (
108                         Ok(len),
109                         Endpoint::Ip(Some(wire::IpEndpoint {
110                             addr: wire::IpAddress::Ipv4(packet.src_addr()),
111                             port: 0,
112                         })),
113                     );
114                 }
115                 Err(_) => {
116                     if !self.metadata.options.contains(SocketOptions::BLOCK) {
117                         // 如果是非阻塞的socket,就返回错误
118                         return (Err(SystemError::EAGAIN_OR_EWOULDBLOCK), Endpoint::Ip(None));
119                     }
120                 }
121             }
122             drop(socket_set_guard);
123             SocketHandleItem::sleep(
124                 self.socket_handle(),
125                 EPollEventType::EPOLLIN.bits() as u64,
126                 HANDLE_MAP.read_irqsave(),
127             );
128         }
129     }
130 
131     fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError> {
132         // 如果用户发送的数据包,包含IP头,则直接发送
133         if self.header_included {
134             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
135             let socket =
136                 socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap());
137             match socket.send_slice(buf) {
138                 Ok(_) => {
139                     return Ok(buf.len());
140                 }
141                 Err(raw::SendError::BufferFull) => {
142                     return Err(SystemError::ENOBUFS);
143                 }
144             }
145         } else {
146             // 如果用户发送的数据包,不包含IP头,则需要自己构造IP头
147 
148             if let Some(Endpoint::Ip(Some(endpoint))) = to {
149                 let mut socket_set_guard = SOCKET_SET.lock_irqsave();
150                 let socket: &mut raw::Socket =
151                     socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap());
152 
153                 // 暴力解决方案:只考虑0号网卡。 TODO:考虑多网卡的情况!!!
154                 let iface = NET_DEVICES.read_irqsave().get(&0).unwrap().clone();
155 
156                 // 构造IP头
157                 let ipv4_src_addr: Option<wire::Ipv4Address> =
158                     iface.inner_iface().lock().ipv4_addr();
159                 if ipv4_src_addr.is_none() {
160                     return Err(SystemError::ENETUNREACH);
161                 }
162                 let ipv4_src_addr = ipv4_src_addr.unwrap();
163 
164                 if let wire::IpAddress::Ipv4(ipv4_dst) = endpoint.addr {
165                     let len = buf.len();
166 
167                     // 创建20字节的IPv4头部
168                     let mut buffer: Vec<u8> = vec![0u8; len + 20];
169                     let mut packet: wire::Ipv4Packet<&mut Vec<u8>> =
170                         wire::Ipv4Packet::new_unchecked(&mut buffer);
171 
172                     // 封装ipv4 header
173                     packet.set_version(4);
174                     packet.set_header_len(20);
175                     packet.set_total_len((20 + len) as u16);
176                     packet.set_src_addr(ipv4_src_addr);
177                     packet.set_dst_addr(ipv4_dst);
178 
179                     // 设置ipv4 header的protocol字段
180                     packet.set_next_header(socket.ip_protocol());
181 
182                     // 获取IP数据包的负载字段
183                     let payload: &mut [u8] = packet.payload_mut();
184                     payload.copy_from_slice(buf);
185 
186                     // 填充checksum字段
187                     packet.fill_checksum();
188 
189                     // 发送数据包
190                     socket.send_slice(&buffer).unwrap();
191 
192                     iface.poll(&mut socket_set_guard).ok();
193 
194                     drop(socket_set_guard);
195                     return Ok(len);
196                 } else {
197                     kwarn!("Unsupport Ip protocol type!");
198                     return Err(SystemError::EINVAL);
199                 }
200             } else {
201                 // 如果没有指定目的地址,则返回错误
202                 return Err(SystemError::ENOTCONN);
203             }
204         }
205     }
206 
207     fn connect(&mut self, _endpoint: Endpoint) -> Result<(), SystemError> {
208         Ok(())
209     }
210 
211     fn metadata(&self) -> SocketMetadata {
212         self.metadata.clone()
213     }
214 
215     fn box_clone(&self) -> Box<dyn Socket> {
216         Box::new(self.clone())
217     }
218 
219     fn socket_handle(&self) -> GlobalSocketHandle {
220         self.handle
221     }
222 
223     fn as_any_ref(&self) -> &dyn core::any::Any {
224         self
225     }
226 
227     fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
228         self
229     }
230 }
231 
232 /// @brief 表示udp socket
233 ///
234 /// https://man7.org/linux/man-pages/man7/udp.7.html
235 #[derive(Debug, Clone)]
236 pub struct UdpSocket {
237     pub handle: GlobalSocketHandle,
238     remote_endpoint: Option<Endpoint>, // 记录远程endpoint提供给connect(), 应该使用IP地址。
239     metadata: SocketMetadata,
240 }
241 
242 impl UdpSocket {
243     /// 元数据的缓冲区的大小
244     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
245     /// 默认的接收缓冲区的大小 receive
246     pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024;
247     /// 默认的发送缓冲区的大小 transmiss
248     pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024;
249 
250     /// @brief 创建一个udp的socket
251     ///
252     /// @param options socket的选项
253     ///
254     /// @return 返回创建的udp的socket
255     pub fn new(options: SocketOptions) -> Self {
256         let rx_buffer = udp::PacketBuffer::new(
257             vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
258             vec![0; Self::DEFAULT_RX_BUF_SIZE],
259         );
260         let tx_buffer = udp::PacketBuffer::new(
261             vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
262             vec![0; Self::DEFAULT_TX_BUF_SIZE],
263         );
264         let socket = udp::Socket::new(rx_buffer, tx_buffer);
265 
266         // 把socket添加到socket集合中,并得到socket的句柄
267         let handle: GlobalSocketHandle =
268             GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket));
269 
270         let metadata = SocketMetadata::new(
271             SocketType::Udp,
272             Self::DEFAULT_RX_BUF_SIZE,
273             Self::DEFAULT_TX_BUF_SIZE,
274             Self::DEFAULT_METADATA_BUF_SIZE,
275             options,
276         );
277 
278         return Self {
279             handle,
280             remote_endpoint: None,
281             metadata,
282         };
283     }
284 
285     fn do_bind(&self, socket: &mut udp::Socket, endpoint: Endpoint) -> Result<(), SystemError> {
286         if let Endpoint::Ip(Some(mut ip)) = endpoint {
287             // 端口为0则分配随机端口
288             if ip.port == 0 {
289                 ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
290             }
291             // 检测端口是否已被占用
292             PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port, self.clone())?;
293 
294             let bind_res = if ip.addr.is_unspecified() {
295                 socket.bind(ip.port)
296             } else {
297                 socket.bind(ip)
298             };
299 
300             match bind_res {
301                 Ok(()) => return Ok(()),
302                 Err(_) => return Err(SystemError::EINVAL),
303             }
304         } else {
305             return Err(SystemError::EINVAL);
306         }
307     }
308 }
309 
310 impl Socket for UdpSocket {
311     fn close(&mut self) {
312         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
313         socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()); // 删除的时候,会发送一条FINISH的信息?
314         drop(socket_set_guard);
315         poll_ifaces();
316     }
317 
318     /// @brief 在read函数执行之前,请先bind到本地的指定端口
319     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
320         loop {
321             // kdebug!("Wait22 to Read");
322             poll_ifaces();
323             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
324             let socket =
325                 socket_set_guard.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
326 
327             // kdebug!("Wait to Read");
328 
329             if socket.can_recv() {
330                 if let Ok((size, metadata)) = socket.recv_slice(buf) {
331                     drop(socket_set_guard);
332                     poll_ifaces();
333                     return (Ok(size), Endpoint::Ip(Some(metadata.endpoint)));
334                 }
335             } else {
336                 // 如果socket没有连接,则忙等
337                 // return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
338             }
339             drop(socket_set_guard);
340             SocketHandleItem::sleep(
341                 self.socket_handle(),
342                 EPollEventType::EPOLLIN.bits() as u64,
343                 HANDLE_MAP.read_irqsave(),
344             );
345         }
346     }
347 
348     fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError> {
349         // kdebug!("udp to send: {:?}, len={}", to, buf.len());
350         let remote_endpoint: &wire::IpEndpoint = {
351             if let Some(Endpoint::Ip(Some(ref endpoint))) = to {
352                 endpoint
353             } else if let Some(Endpoint::Ip(Some(ref endpoint))) = self.remote_endpoint {
354                 endpoint
355             } else {
356                 return Err(SystemError::ENOTCONN);
357             }
358         };
359         // kdebug!("udp write: remote = {:?}", remote_endpoint);
360 
361         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
362         let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
363         // kdebug!("is open()={}", socket.is_open());
364         // kdebug!("socket endpoint={:?}", socket.endpoint());
365         if socket.can_send() {
366             // kdebug!("udp write: can send");
367             match socket.send_slice(buf, *remote_endpoint) {
368                 Ok(()) => {
369                     // kdebug!("udp write: send ok");
370                     drop(socket_set_guard);
371                     poll_ifaces();
372                     return Ok(buf.len());
373                 }
374                 Err(_) => {
375                     // kdebug!("udp write: send err");
376                     return Err(SystemError::ENOBUFS);
377                 }
378             }
379         } else {
380             // kdebug!("udp write: can not send");
381             return Err(SystemError::ENOBUFS);
382         };
383     }
384 
385     fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
386         let mut sockets = SOCKET_SET.lock_irqsave();
387         let socket = sockets.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
388         // kdebug!("UDP Bind to {:?}", endpoint);
389         return self.do_bind(socket, endpoint);
390     }
391 
392     fn poll(&self) -> EPollEventType {
393         let sockets = SOCKET_SET.lock_irqsave();
394         let socket = sockets.get::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
395 
396         return SocketPollMethod::udp_poll(
397             socket,
398             HANDLE_MAP
399                 .read_irqsave()
400                 .get(&self.socket_handle())
401                 .unwrap()
402                 .shutdown_type(),
403         );
404     }
405 
406     fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
407         if let Endpoint::Ip(_) = endpoint {
408             self.remote_endpoint = Some(endpoint);
409             Ok(())
410         } else {
411             Err(SystemError::EINVAL)
412         }
413     }
414 
415     fn ioctl(
416         &self,
417         _cmd: usize,
418         _arg0: usize,
419         _arg1: usize,
420         _arg2: usize,
421     ) -> Result<usize, SystemError> {
422         todo!()
423     }
424 
425     fn metadata(&self) -> SocketMetadata {
426         self.metadata.clone()
427     }
428 
429     fn box_clone(&self) -> Box<dyn Socket> {
430         return Box::new(self.clone());
431     }
432 
433     fn endpoint(&self) -> Option<Endpoint> {
434         let sockets = SOCKET_SET.lock_irqsave();
435         let socket = sockets.get::<udp::Socket>(self.handle.smoltcp_handle().unwrap());
436         let listen_endpoint = socket.endpoint();
437 
438         if listen_endpoint.port == 0 {
439             return None;
440         } else {
441             // 如果listen_endpoint的address是None,意味着“监听所有的地址”。
442             // 这里假设所有的地址都是ipv4
443             // TODO: 支持ipv6
444             let result = wire::IpEndpoint::new(
445                 listen_endpoint
446                     .addr
447                     .unwrap_or(wire::IpAddress::v4(0, 0, 0, 0)),
448                 listen_endpoint.port,
449             );
450             return Some(Endpoint::Ip(Some(result)));
451         }
452     }
453 
454     fn peer_endpoint(&self) -> Option<Endpoint> {
455         return self.remote_endpoint.clone();
456     }
457 
458     fn socket_handle(&self) -> GlobalSocketHandle {
459         self.handle
460     }
461 
462     fn as_any_ref(&self) -> &dyn core::any::Any {
463         self
464     }
465 
466     fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
467         self
468     }
469 }
470 
471 /// @brief 表示 tcp socket
472 ///
473 /// https://man7.org/linux/man-pages/man7/tcp.7.html
474 #[derive(Debug, Clone)]
475 pub struct TcpSocket {
476     handles: Vec<GlobalSocketHandle>,
477     local_endpoint: Option<wire::IpEndpoint>, // save local endpoint for bind()
478     is_listening: bool,
479     metadata: SocketMetadata,
480 }
481 
482 impl TcpSocket {
483     /// 元数据的缓冲区的大小
484     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
485     /// 默认的接收缓冲区的大小 receive
486     pub const DEFAULT_RX_BUF_SIZE: usize = 512 * 1024;
487     /// 默认的发送缓冲区的大小 transmiss
488     pub const DEFAULT_TX_BUF_SIZE: usize = 512 * 1024;
489 
490     /// TcpSocket的特殊事件,用于在事件等待队列上sleep
491     pub const CAN_CONNECT: u64 = 1u64 << 63;
492     pub const CAN_ACCPET: u64 = 1u64 << 62;
493 
494     /// @brief 创建一个tcp的socket
495     ///
496     /// @param options socket的选项
497     ///
498     /// @return 返回创建的tcp的socket
499     pub fn new(options: SocketOptions) -> Self {
500         // 创建handles数组并把socket添加到socket集合中,并得到socket的句柄
501         let handles: Vec<GlobalSocketHandle> = vec![GlobalSocketHandle::new_smoltcp_handle(
502             SOCKET_SET.lock_irqsave().add(Self::create_new_socket()),
503         )];
504 
505         let metadata = SocketMetadata::new(
506             SocketType::Tcp,
507             Self::DEFAULT_RX_BUF_SIZE,
508             Self::DEFAULT_TX_BUF_SIZE,
509             Self::DEFAULT_METADATA_BUF_SIZE,
510             options,
511         );
512         // kdebug!("when there's a new tcp socket,its'len: {}",handles.len());
513 
514         return Self {
515             handles,
516             local_endpoint: None,
517             is_listening: false,
518             metadata,
519         };
520     }
521 
522     fn do_listen(
523         &mut self,
524         socket: &mut tcp::Socket,
525         local_endpoint: wire::IpEndpoint,
526     ) -> Result<(), SystemError> {
527         let listen_result = if local_endpoint.addr.is_unspecified() {
528             // kdebug!("Tcp Socket Listen on port {}", local_endpoint.port);
529             socket.listen(local_endpoint.port)
530         } else {
531             // kdebug!("Tcp Socket Listen on {local_endpoint}");
532             socket.listen(local_endpoint)
533         };
534         return match listen_result {
535             Ok(()) => {
536                 // kdebug!(
537                 //     "Tcp Socket Listen on {local_endpoint}, open?:{}",
538                 //     socket.is_open()
539                 // );
540                 self.is_listening = true;
541 
542                 Ok(())
543             }
544             Err(_) => Err(SystemError::EINVAL),
545         };
546     }
547 
548     /// # create_new_socket - 创建新的TCP套接字
549     ///
550     /// 该函数用于创建一个新的TCP套接字,并返回该套接字的引用。
551     fn create_new_socket() -> tcp::Socket<'static> {
552         // 初始化tcp的buffer
553         let rx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_RX_BUF_SIZE]);
554         let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]);
555         tcp::Socket::new(rx_buffer, tx_buffer)
556     }
557 }
558 
559 impl Socket for TcpSocket {
560     fn close(&mut self) {
561         for handle in self.handles.iter() {
562             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
563             socket_set_guard.remove(handle.smoltcp_handle().unwrap()); // 删除的时候,会发送一条FINISH的信息?
564             drop(socket_set_guard);
565         }
566         poll_ifaces();
567     }
568 
569     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
570         if HANDLE_MAP
571             .read_irqsave()
572             .get(&self.socket_handle())
573             .unwrap()
574             .shutdown_type()
575             .contains(ShutdownType::RCV_SHUTDOWN)
576         {
577             return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
578         }
579         // kdebug!("tcp socket: read, buf len={}", buf.len());
580         // kdebug!("tcp socket:read, socket'len={}",self.handle.len());
581         loop {
582             poll_ifaces();
583             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
584 
585             let socket = socket_set_guard
586                 .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
587 
588             // 如果socket已经关闭,返回错误
589             if !socket.is_active() {
590                 // kdebug!("Tcp Socket Read Error, socket is closed");
591                 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
592             }
593 
594             if socket.may_recv() {
595                 match socket.recv_slice(buf) {
596                     Ok(size) => {
597                         if size > 0 {
598                             let endpoint = if let Some(p) = socket.remote_endpoint() {
599                                 p
600                             } else {
601                                 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
602                             };
603 
604                             drop(socket_set_guard);
605                             poll_ifaces();
606                             return (Ok(size), Endpoint::Ip(Some(endpoint)));
607                         }
608                     }
609                     Err(tcp::RecvError::InvalidState) => {
610                         kwarn!("Tcp Socket Read Error, InvalidState");
611                         return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
612                     }
613                     Err(tcp::RecvError::Finished) => {
614                         // 对端写端已关闭,我们应该关闭读端
615                         HANDLE_MAP
616                             .write_irqsave()
617                             .get_mut(&self.socket_handle())
618                             .unwrap()
619                             .shutdown_type_writer()
620                             .insert(ShutdownType::RCV_SHUTDOWN);
621                         return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
622                     }
623                 }
624             } else {
625                 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
626             }
627             drop(socket_set_guard);
628             SocketHandleItem::sleep(
629                 self.socket_handle(),
630                 EPollEventType::EPOLLIN.bits() as u64,
631                 HANDLE_MAP.read_irqsave(),
632             );
633         }
634     }
635 
636     fn write(&self, buf: &[u8], _to: Option<Endpoint>) -> Result<usize, SystemError> {
637         if HANDLE_MAP
638             .read_irqsave()
639             .get(&self.socket_handle())
640             .unwrap()
641             .shutdown_type()
642             .contains(ShutdownType::RCV_SHUTDOWN)
643         {
644             return Err(SystemError::ENOTCONN);
645         }
646         // kdebug!("tcp socket:write, socket'len={}",self.handle.len());
647 
648         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
649 
650         let socket = socket_set_guard
651             .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
652 
653         if socket.is_open() {
654             if socket.can_send() {
655                 match socket.send_slice(buf) {
656                     Ok(size) => {
657                         drop(socket_set_guard);
658                         poll_ifaces();
659                         return Ok(size);
660                     }
661                     Err(e) => {
662                         kerror!("Tcp Socket Write Error {e:?}");
663                         return Err(SystemError::ENOBUFS);
664                     }
665                 }
666             } else {
667                 return Err(SystemError::ENOBUFS);
668             }
669         }
670 
671         return Err(SystemError::ENOTCONN);
672     }
673 
674     fn poll(&self) -> EPollEventType {
675         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
676         // kdebug!("tcp socket:poll, socket'len={}",self.handle.len());
677 
678         let socket = socket_set_guard
679             .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
680         return SocketPollMethod::tcp_poll(
681             socket,
682             HANDLE_MAP
683                 .read_irqsave()
684                 .get(&self.socket_handle())
685                 .unwrap()
686                 .shutdown_type(),
687         );
688     }
689 
690     fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
691         let mut sockets = SOCKET_SET.lock_irqsave();
692         // kdebug!("tcp socket:connect, socket'len={}",self.handle.len());
693 
694         let socket =
695             sockets.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
696 
697         if let Endpoint::Ip(Some(ip)) = endpoint {
698             let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
699             // 检测端口是否被占用
700             PORT_MANAGER.bind_port(self.metadata.socket_type, temp_port, self.clone())?;
701 
702             // kdebug!("temp_port: {}", temp_port);
703             let iface: Arc<dyn NetDevice> = NET_DEVICES.write_irqsave().get(&0).unwrap().clone();
704             let mut inner_iface = iface.inner_iface().lock();
705             // kdebug!("to connect: {ip:?}");
706 
707             match socket.connect(inner_iface.context(), ip, temp_port) {
708                 Ok(()) => {
709                     // avoid deadlock
710                     drop(inner_iface);
711                     drop(iface);
712                     drop(sockets);
713                     loop {
714                         poll_ifaces();
715                         let mut sockets = SOCKET_SET.lock_irqsave();
716                         let socket = sockets.get_mut::<tcp::Socket>(
717                             self.handles.get(0).unwrap().smoltcp_handle().unwrap(),
718                         );
719 
720                         match socket.state() {
721                             tcp::State::Established => {
722                                 return Ok(());
723                             }
724                             tcp::State::SynSent => {
725                                 drop(sockets);
726                                 SocketHandleItem::sleep(
727                                     self.socket_handle(),
728                                     Self::CAN_CONNECT,
729                                     HANDLE_MAP.read_irqsave(),
730                                 );
731                             }
732                             _ => {
733                                 return Err(SystemError::ECONNREFUSED);
734                             }
735                         }
736                     }
737                 }
738                 Err(e) => {
739                     // kerror!("Tcp Socket Connect Error {e:?}");
740                     match e {
741                         tcp::ConnectError::InvalidState => return Err(SystemError::EISCONN),
742                         tcp::ConnectError::Unaddressable => return Err(SystemError::EADDRNOTAVAIL),
743                     }
744                 }
745             }
746         } else {
747             return Err(SystemError::EINVAL);
748         }
749     }
750 
751     /// @brief tcp socket 监听 local_endpoint 端口
752     ///
753     /// @param backlog 未处理的连接队列的最大长度. 由于smoltcp不支持backlog,所以这个参数目前无效
754     fn listen(&mut self, backlog: usize) -> Result<(), SystemError> {
755         if self.is_listening {
756             return Ok(());
757         }
758 
759         let local_endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?;
760         let mut sockets = SOCKET_SET.lock_irqsave();
761         // 获取handle的数量
762         let handlen = self.handles.len();
763         let backlog = handlen.max(backlog);
764 
765         // 添加剩余需要构建的socket
766         // kdebug!("tcp socket:before listen, socket'len={}",self.handle.len());
767         let mut handle_guard = HANDLE_MAP.write_irqsave();
768         self.handles.extend((handlen..backlog).map(|_| {
769             let socket = Self::create_new_socket();
770             let handle = GlobalSocketHandle::new_smoltcp_handle(sockets.add(socket));
771             let handle_item = SocketHandleItem::new();
772             handle_guard.insert(handle, handle_item);
773             handle
774         }));
775         // kdebug!("tcp socket:listen, socket'len={}",self.handle.len());
776         // kdebug!("tcp socket:listen, backlog={backlog}");
777 
778         // 监听所有的socket
779         for i in 0..backlog {
780             let handle = self.handles.get(i).unwrap();
781 
782             let socket = sockets.get_mut::<tcp::Socket>(handle.smoltcp_handle().unwrap());
783 
784             if !socket.is_listening() {
785                 // kdebug!("Tcp Socket is already listening on {local_endpoint}");
786                 self.do_listen(socket, local_endpoint)?;
787             }
788             // kdebug!("Tcp Socket  before listen, open={}", socket.is_open());
789         }
790         return Ok(());
791     }
792 
793     fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
794         if let Endpoint::Ip(Some(mut ip)) = endpoint {
795             if ip.port == 0 {
796                 ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
797             }
798 
799             // 检测端口是否已被占用
800             PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port, self.clone())?;
801             // kdebug!("tcp socket:bind, socket'len={}",self.handle.len());
802 
803             self.local_endpoint = Some(ip);
804             self.is_listening = false;
805             return Ok(());
806         }
807         return Err(SystemError::EINVAL);
808     }
809 
810     fn shutdown(&mut self, shutdown_type: super::ShutdownType) -> Result<(), SystemError> {
811         // TODO:目前只是在表层判断,对端不知晓,后续需使用tcp实现
812         HANDLE_MAP
813             .write_irqsave()
814             .get_mut(&self.socket_handle())
815             .unwrap()
816             .shutdown_type = RwLock::new(shutdown_type);
817         return Ok(());
818     }
819 
820     fn accept(&mut self) -> Result<(Box<dyn Socket>, Endpoint), SystemError> {
821         let endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?;
822         loop {
823             // kdebug!("tcp accept: poll_ifaces()");
824             poll_ifaces();
825             // kdebug!("tcp socket:accept, socket'len={}",self.handle.len());
826 
827             let mut sockets = SOCKET_SET.lock_irqsave();
828 
829             // 随机获取访问的socket的handle
830             let index: usize = rand() % self.handles.len();
831             let handle = self.handles.get(index).unwrap();
832 
833             let socket = sockets
834                 .iter_mut()
835                 .find(|y| {
836                     tcp::Socket::downcast(y.1)
837                         .map(|y| y.is_active())
838                         .unwrap_or(false)
839                 })
840                 .map(|y| tcp::Socket::downcast_mut(y.1).unwrap());
841             if let Some(socket) = socket {
842                 if socket.is_active() {
843                     // kdebug!("tcp accept: socket.is_active()");
844                     let remote_ep = socket.remote_endpoint().ok_or(SystemError::ENOTCONN)?;
845 
846                     let new_socket = {
847                         // The new TCP socket used for sending and receiving data.
848                         let mut tcp_socket = Self::create_new_socket();
849                         self.do_listen(&mut tcp_socket, endpoint)
850                             .expect("do_listen failed");
851 
852                         // tcp_socket.listen(endpoint).unwrap();
853 
854                         // 之所以把old_handle存入new_socket, 是因为当前时刻,smoltcp已经把old_handle对应的socket与远程的endpoint关联起来了
855                         // 因此需要再为当前的socket分配一个新的handle
856                         let new_handle =
857                             GlobalSocketHandle::new_smoltcp_handle(sockets.add(tcp_socket));
858                         let old_handle = ::core::mem::replace(
859                             &mut *self.handles.get_mut(index).unwrap(),
860                             new_handle,
861                         );
862 
863                         let metadata = SocketMetadata::new(
864                             SocketType::Tcp,
865                             Self::DEFAULT_TX_BUF_SIZE,
866                             Self::DEFAULT_RX_BUF_SIZE,
867                             Self::DEFAULT_METADATA_BUF_SIZE,
868                             self.metadata.options,
869                         );
870 
871                         let new_socket = Box::new(TcpSocket {
872                             handles: vec![old_handle],
873                             local_endpoint: self.local_endpoint,
874                             is_listening: false,
875                             metadata,
876                         });
877                         // kdebug!("tcp socket:after accept, socket'len={}",new_socket.handle.len());
878 
879                         // 更新端口与 socket 的绑定
880                         if let Some(Endpoint::Ip(Some(ip))) = self.endpoint() {
881                             PORT_MANAGER.unbind_port(self.metadata.socket_type, ip.port)?;
882                             PORT_MANAGER.bind_port(
883                                 self.metadata.socket_type,
884                                 ip.port,
885                                 *new_socket.clone(),
886                             )?;
887                         }
888 
889                         // 更新handle表
890                         let mut handle_guard = HANDLE_MAP.write_irqsave();
891                         // 先删除原来的
892 
893                         let item = handle_guard.remove(&old_handle).unwrap();
894 
895                         // 按照smoltcp行为,将新的handle绑定到原来的item
896                         handle_guard.insert(new_handle, item);
897                         let new_item = SocketHandleItem::new();
898 
899                         // 插入新的item
900                         handle_guard.insert(old_handle, new_item);
901 
902                         new_socket
903                     };
904                     // kdebug!("tcp accept: new socket: {:?}", new_socket);
905                     drop(sockets);
906                     poll_ifaces();
907 
908                     return Ok((new_socket, Endpoint::Ip(Some(remote_ep))));
909                 }
910             }
911             // kdebug!("tcp socket:before sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len());
912 
913             drop(sockets);
914             SocketHandleItem::sleep(*handle, Self::CAN_ACCPET, HANDLE_MAP.read_irqsave());
915             // kdebug!("tcp socket:after sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len());
916         }
917     }
918 
919     fn endpoint(&self) -> Option<Endpoint> {
920         let mut result: Option<Endpoint> = self.local_endpoint.map(|x| Endpoint::Ip(Some(x)));
921 
922         if result.is_none() {
923             let sockets = SOCKET_SET.lock_irqsave();
924             // kdebug!("tcp socket:endpoint, socket'len={}",self.handle.len());
925 
926             let socket =
927                 sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
928             if let Some(ep) = socket.local_endpoint() {
929                 result = Some(Endpoint::Ip(Some(ep)));
930             }
931         }
932         return result;
933     }
934 
935     fn peer_endpoint(&self) -> Option<Endpoint> {
936         let sockets = SOCKET_SET.lock_irqsave();
937         // kdebug!("tcp socket:peer_endpoint, socket'len={}",self.handle.len());
938 
939         let socket =
940             sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
941         return socket.remote_endpoint().map(|x| Endpoint::Ip(Some(x)));
942     }
943 
944     fn metadata(&self) -> SocketMetadata {
945         self.metadata.clone()
946     }
947 
948     fn box_clone(&self) -> Box<dyn Socket> {
949         Box::new(self.clone())
950     }
951 
952     fn socket_handle(&self) -> GlobalSocketHandle {
953         // kdebug!("tcp socket:socket_handle, socket'len={}",self.handle.len());
954 
955         *self.handles.get(0).unwrap()
956     }
957 
958     fn as_any_ref(&self) -> &dyn core::any::Any {
959         self
960     }
961 
962     fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
963         self
964     }
965 }
966