1 use alloc::{boxed::Box, sync::Arc, vec::Vec};
2 use smoltcp::{
3     iface::SocketHandle,
4     socket::{raw, tcp, udp},
5     wire,
6 };
7 use system_error::SystemError;
8 
9 use crate::{
10     driver::net::NetDriver,
11     kerror, kwarn,
12     libs::{rwlock::RwLock, spinlock::SpinLock},
13     net::{
14         event_poll::EPollEventType, net_core::poll_ifaces, Endpoint, Protocol, ShutdownType,
15         NET_DRIVERS,
16     },
17 };
18 
19 use super::{
20     GlobalSocketHandle, Socket, SocketHandleItem, SocketMetadata, SocketOptions, SocketPollMethod,
21     SocketType, SocketpairOps, HANDLE_MAP, PORT_MANAGER, SOCKET_SET,
22 };
23 
/// @brief A raw socket. Raw sockets bypass the transport-layer protocols (such as TCP or
/// UDP) and give direct access to the network-layer protocol (such as IP).
///
/// Cloning a `RawSocket` shares the same underlying socket: the handle is reference
/// counted through `Arc`.
///
/// ref: https://man7.org/linux/man-pages/man7/raw.7.html
#[derive(Debug, Clone)]
pub struct RawSocket {
    /// Handle of the underlying smoltcp raw socket registered in the global `SOCKET_SET`.
    handle: Arc<GlobalSocketHandle>,
    /// Whether packets supplied by the user already contain the IP header.
    /// If true, the user must provide the IP header followed by the data.
    /// If false, the user provides only the data and `write` builds the IP header.
    header_included: bool,
    /// Socket metadata (socket type, buffer sizes and options).
    metadata: SocketMetadata,
}
37 
38 impl RawSocket {
39     /// 元数据的缓冲区的大小
40     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
41     /// 默认的接收缓冲区的大小 receive
42     pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024;
43     /// 默认的发送缓冲区的大小 transmiss
44     pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024;
45 
46     /// @brief 创建一个原始的socket
47     ///
48     /// @param protocol 协议号
49     /// @param options socket的选项
50     ///
51     /// @return 返回创建的原始的socket
new(protocol: Protocol, options: SocketOptions) -> Self52     pub fn new(protocol: Protocol, options: SocketOptions) -> Self {
53         let rx_buffer = raw::PacketBuffer::new(
54             vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
55             vec![0; Self::DEFAULT_RX_BUF_SIZE],
56         );
57         let tx_buffer = raw::PacketBuffer::new(
58             vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
59             vec![0; Self::DEFAULT_TX_BUF_SIZE],
60         );
61         let protocol: u8 = protocol.into();
62         let socket = raw::Socket::new(
63             wire::IpVersion::Ipv4,
64             wire::IpProtocol::from(protocol),
65             rx_buffer,
66             tx_buffer,
67         );
68 
69         // 把socket添加到socket集合中,并得到socket的句柄
70         let handle: Arc<GlobalSocketHandle> =
71             GlobalSocketHandle::new(SOCKET_SET.lock_irqsave().add(socket));
72 
73         let metadata = SocketMetadata::new(
74             SocketType::RawSocket,
75             Self::DEFAULT_RX_BUF_SIZE,
76             Self::DEFAULT_TX_BUF_SIZE,
77             Self::DEFAULT_METADATA_BUF_SIZE,
78             options,
79         );
80 
81         return Self {
82             handle,
83             header_included: false,
84             metadata,
85         };
86     }
87 }
88 
89 impl Socket for RawSocket {
as_any_ref(&self) -> &dyn core::any::Any90     fn as_any_ref(&self) -> &dyn core::any::Any {
91         self
92     }
93 
as_any_mut(&mut self) -> &mut dyn core::any::Any94     fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
95         self
96     }
97 
read(&mut self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint)98     fn read(&mut self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
99         poll_ifaces();
100         loop {
101             // 如何优化这里?
102             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
103             let socket = socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
104 
105             match socket.recv_slice(buf) {
106                 Ok(len) => {
107                     let packet = wire::Ipv4Packet::new_unchecked(buf);
108                     return (
109                         Ok(len),
110                         Endpoint::Ip(Some(wire::IpEndpoint {
111                             addr: wire::IpAddress::Ipv4(packet.src_addr()),
112                             port: 0,
113                         })),
114                     );
115                 }
116                 Err(raw::RecvError::Exhausted) => {
117                     if !self.metadata.options.contains(SocketOptions::BLOCK) {
118                         // 如果是非阻塞的socket,就返回错误
119                         return (Err(SystemError::EAGAIN_OR_EWOULDBLOCK), Endpoint::Ip(None));
120                     }
121                 }
122             }
123             drop(socket_set_guard);
124             SocketHandleItem::sleep(
125                 self.socket_handle(),
126                 EPollEventType::EPOLLIN.bits() as u64,
127                 HANDLE_MAP.read_irqsave(),
128             );
129         }
130     }
131 
write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError>132     fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError> {
133         // 如果用户发送的数据包,包含IP头,则直接发送
134         if self.header_included {
135             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
136             let socket = socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
137             match socket.send_slice(buf) {
138                 Ok(_) => {
139                     return Ok(buf.len());
140                 }
141                 Err(raw::SendError::BufferFull) => {
142                     return Err(SystemError::ENOBUFS);
143                 }
144             }
145         } else {
146             // 如果用户发送的数据包,不包含IP头,则需要自己构造IP头
147 
148             if let Some(Endpoint::Ip(Some(endpoint))) = to {
149                 let mut socket_set_guard = SOCKET_SET.lock_irqsave();
150                 let socket: &mut raw::Socket =
151                     socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
152 
153                 // 暴力解决方案:只考虑0号网卡。 TODO:考虑多网卡的情况!!!
154                 let iface = NET_DRIVERS.read_irqsave().get(&0).unwrap().clone();
155 
156                 // 构造IP头
157                 let ipv4_src_addr: Option<wire::Ipv4Address> =
158                     iface.inner_iface().lock().ipv4_addr();
159                 if ipv4_src_addr.is_none() {
160                     return Err(SystemError::ENETUNREACH);
161                 }
162                 let ipv4_src_addr = ipv4_src_addr.unwrap();
163 
164                 if let wire::IpAddress::Ipv4(ipv4_dst) = endpoint.addr {
165                     let len = buf.len();
166 
167                     // 创建20字节的IPv4头部
168                     let mut buffer: Vec<u8> = vec![0u8; len + 20];
169                     let mut packet: wire::Ipv4Packet<&mut Vec<u8>> =
170                         wire::Ipv4Packet::new_unchecked(&mut buffer);
171 
172                     // 封装ipv4 header
173                     packet.set_version(4);
174                     packet.set_header_len(20);
175                     packet.set_total_len((20 + len) as u16);
176                     packet.set_src_addr(ipv4_src_addr);
177                     packet.set_dst_addr(ipv4_dst);
178 
179                     // 设置ipv4 header的protocol字段
180                     packet.set_next_header(socket.ip_protocol().into());
181 
182                     // 获取IP数据包的负载字段
183                     let payload: &mut [u8] = packet.payload_mut();
184                     payload.copy_from_slice(buf);
185 
186                     // 填充checksum字段
187                     packet.fill_checksum();
188 
189                     // 发送数据包
190                     socket.send_slice(&buffer).unwrap();
191 
192                     iface.poll(&mut socket_set_guard).ok();
193 
194                     drop(socket_set_guard);
195                     return Ok(len);
196                 } else {
197                     kwarn!("Unsupport Ip protocol type!");
198                     return Err(SystemError::EINVAL);
199                 }
200             } else {
201                 // 如果没有指定目的地址,则返回错误
202                 return Err(SystemError::ENOTCONN);
203             }
204         }
205     }
206 
connect(&mut self, _endpoint: Endpoint) -> Result<(), SystemError>207     fn connect(&mut self, _endpoint: Endpoint) -> Result<(), SystemError> {
208         Ok(())
209     }
210 
metadata(&self) -> Result<SocketMetadata, SystemError>211     fn metadata(&self) -> Result<SocketMetadata, SystemError> {
212         Ok(self.metadata.clone())
213     }
214 
box_clone(&self) -> Box<dyn Socket>215     fn box_clone(&self) -> Box<dyn Socket> {
216         return Box::new(self.clone());
217     }
218 
socket_handle(&self) -> SocketHandle219     fn socket_handle(&self) -> SocketHandle {
220         self.handle.0
221     }
222 }
223 
/// @brief A UDP socket.
///
/// Cloning a `UdpSocket` shares the same underlying socket: the handle is reference
/// counted through `Arc`.
///
/// https://man7.org/linux/man-pages/man7/udp.7.html
#[derive(Debug, Clone)]
pub struct UdpSocket {
    /// Handle of the underlying smoltcp UDP socket registered in the global `SOCKET_SET`.
    pub handle: Arc<GlobalSocketHandle>,
    /// Remote endpoint recorded by `connect()`. `write` uses it as the default destination
    /// when no explicit target is given, and `peer_endpoint` returns it. It is expected
    /// to hold an IP endpoint.
    remote_endpoint: Option<Endpoint>,
    /// Socket metadata (socket type, buffer sizes and options).
    metadata: SocketMetadata,
}
233 
234 impl UdpSocket {
235     /// 元数据的缓冲区的大小
236     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
237     /// 默认的接收缓冲区的大小 receive
238     pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024;
239     /// 默认的发送缓冲区的大小 transmiss
240     pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024;
241 
242     /// @brief 创建一个udp的socket
243     ///
244     /// @param options socket的选项
245     ///
246     /// @return 返回创建的udp的socket
new(options: SocketOptions) -> Self247     pub fn new(options: SocketOptions) -> Self {
248         let rx_buffer = udp::PacketBuffer::new(
249             vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
250             vec![0; Self::DEFAULT_RX_BUF_SIZE],
251         );
252         let tx_buffer = udp::PacketBuffer::new(
253             vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
254             vec![0; Self::DEFAULT_TX_BUF_SIZE],
255         );
256         let socket = udp::Socket::new(rx_buffer, tx_buffer);
257 
258         // 把socket添加到socket集合中,并得到socket的句柄
259         let handle: Arc<GlobalSocketHandle> =
260             GlobalSocketHandle::new(SOCKET_SET.lock_irqsave().add(socket));
261 
262         let metadata = SocketMetadata::new(
263             SocketType::UdpSocket,
264             Self::DEFAULT_RX_BUF_SIZE,
265             Self::DEFAULT_TX_BUF_SIZE,
266             Self::DEFAULT_METADATA_BUF_SIZE,
267             options,
268         );
269 
270         return Self {
271             handle,
272             remote_endpoint: None,
273             metadata,
274         };
275     }
276 
do_bind(&self, socket: &mut udp::Socket, endpoint: Endpoint) -> Result<(), SystemError>277     fn do_bind(&self, socket: &mut udp::Socket, endpoint: Endpoint) -> Result<(), SystemError> {
278         if let Endpoint::Ip(Some(ip)) = endpoint {
279             // 检测端口是否已被占用
280             PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port, self.handle.clone())?;
281 
282             let bind_res = if ip.addr.is_unspecified() {
283                 socket.bind(ip.port)
284             } else {
285                 socket.bind(ip)
286             };
287 
288             match bind_res {
289                 Ok(()) => return Ok(()),
290                 Err(_) => return Err(SystemError::EINVAL),
291             }
292         } else {
293             return Err(SystemError::EINVAL);
294         }
295     }
296 }
297 
298 impl Socket for UdpSocket {
as_any_ref(&self) -> &dyn core::any::Any299     fn as_any_ref(&self) -> &dyn core::any::Any {
300         self
301     }
302 
as_any_mut(&mut self) -> &mut dyn core::any::Any303     fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
304         self
305     }
306 
307     /// @brief 在read函数执行之前,请先bind到本地的指定端口
read(&mut self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint)308     fn read(&mut self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
309         loop {
310             // kdebug!("Wait22 to Read");
311             poll_ifaces();
312             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
313             let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.0);
314 
315             // kdebug!("Wait to Read");
316 
317             if socket.can_recv() {
318                 if let Ok((size, remote_endpoint)) = socket.recv_slice(buf) {
319                     drop(socket_set_guard);
320                     poll_ifaces();
321                     return (Ok(size), Endpoint::Ip(Some(remote_endpoint)));
322                 }
323             } else {
324                 // 如果socket没有连接,则忙等
325                 // return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
326             }
327             drop(socket_set_guard);
328             SocketHandleItem::sleep(
329                 self.socket_handle(),
330                 EPollEventType::EPOLLIN.bits() as u64,
331                 HANDLE_MAP.read_irqsave(),
332             );
333         }
334     }
335 
write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError>336     fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError> {
337         // kdebug!("udp to send: {:?}, len={}", to, buf.len());
338         let remote_endpoint: &wire::IpEndpoint = {
339             if let Some(Endpoint::Ip(Some(ref endpoint))) = to {
340                 endpoint
341             } else if let Some(Endpoint::Ip(Some(ref endpoint))) = self.remote_endpoint {
342                 endpoint
343             } else {
344                 return Err(SystemError::ENOTCONN);
345             }
346         };
347         // kdebug!("udp write: remote = {:?}", remote_endpoint);
348 
349         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
350         let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.0);
351         // kdebug!("is open()={}", socket.is_open());
352         // kdebug!("socket endpoint={:?}", socket.endpoint());
353         if socket.endpoint().port == 0 {
354             let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
355 
356             let local_ep = match remote_endpoint.addr {
357                 // 远程remote endpoint使用什么协议,发送的时候使用的协议是一样的吧
358                 // 否则就用 self.endpoint().addr.unwrap()
359                 wire::IpAddress::Ipv4(_) => Endpoint::Ip(Some(wire::IpEndpoint::new(
360                     wire::IpAddress::Ipv4(wire::Ipv4Address::UNSPECIFIED),
361                     temp_port,
362                 ))),
363                 wire::IpAddress::Ipv6(_) => Endpoint::Ip(Some(wire::IpEndpoint::new(
364                     wire::IpAddress::Ipv6(wire::Ipv6Address::UNSPECIFIED),
365                     temp_port,
366                 ))),
367             };
368             // kdebug!("udp write: local_ep = {:?}", local_ep);
369             self.do_bind(socket, local_ep)?;
370         }
371         // kdebug!("is open()={}", socket.is_open());
372         if socket.can_send() {
373             // kdebug!("udp write: can send");
374             match socket.send_slice(&buf, *remote_endpoint) {
375                 Ok(()) => {
376                     // kdebug!("udp write: send ok");
377                     drop(socket_set_guard);
378                     poll_ifaces();
379                     return Ok(buf.len());
380                 }
381                 Err(_) => {
382                     // kdebug!("udp write: send err");
383                     return Err(SystemError::ENOBUFS);
384                 }
385             }
386         } else {
387             // kdebug!("udp write: can not send");
388             return Err(SystemError::ENOBUFS);
389         };
390     }
391 
bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError>392     fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
393         let mut sockets = SOCKET_SET.lock_irqsave();
394         let socket = sockets.get_mut::<udp::Socket>(self.handle.0);
395         // kdebug!("UDP Bind to {:?}", endpoint);
396         return self.do_bind(socket, endpoint);
397     }
398 
poll(&self) -> EPollEventType399     fn poll(&self) -> EPollEventType {
400         let sockets = SOCKET_SET.lock_irqsave();
401         let socket = sockets.get::<udp::Socket>(self.handle.0);
402 
403         return SocketPollMethod::udp_poll(
404             socket,
405             HANDLE_MAP
406                 .read_irqsave()
407                 .get(&self.socket_handle())
408                 .unwrap()
409                 .shutdown_type(),
410         );
411     }
412 
connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError>413     fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
414         if let Endpoint::Ip(_) = endpoint {
415             self.remote_endpoint = Some(endpoint);
416             return Ok(());
417         } else {
418             return Err(SystemError::EINVAL);
419         };
420     }
421 
ioctl( &self, _cmd: usize, _arg0: usize, _arg1: usize, _arg2: usize, ) -> Result<usize, SystemError>422     fn ioctl(
423         &self,
424         _cmd: usize,
425         _arg0: usize,
426         _arg1: usize,
427         _arg2: usize,
428     ) -> Result<usize, SystemError> {
429         todo!()
430     }
431 
metadata(&self) -> Result<SocketMetadata, SystemError>432     fn metadata(&self) -> Result<SocketMetadata, SystemError> {
433         Ok(self.metadata.clone())
434     }
435 
box_clone(&self) -> Box<dyn Socket>436     fn box_clone(&self) -> Box<dyn Socket> {
437         return Box::new(self.clone());
438     }
439 
endpoint(&self) -> Option<Endpoint>440     fn endpoint(&self) -> Option<Endpoint> {
441         let sockets = SOCKET_SET.lock_irqsave();
442         let socket = sockets.get::<udp::Socket>(self.handle.0);
443         let listen_endpoint = socket.endpoint();
444 
445         if listen_endpoint.port == 0 {
446             return None;
447         } else {
448             // 如果listen_endpoint的address是None,意味着“监听所有的地址”。
449             // 这里假设所有的地址都是ipv4
450             // TODO: 支持ipv6
451             let result = wire::IpEndpoint::new(
452                 listen_endpoint
453                     .addr
454                     .unwrap_or(wire::IpAddress::v4(0, 0, 0, 0)),
455                 listen_endpoint.port,
456             );
457             return Some(Endpoint::Ip(Some(result)));
458         }
459     }
460 
peer_endpoint(&self) -> Option<Endpoint>461     fn peer_endpoint(&self) -> Option<Endpoint> {
462         return self.remote_endpoint.clone();
463     }
464 
socket_handle(&self) -> SocketHandle465     fn socket_handle(&self) -> SocketHandle {
466         self.handle.0
467     }
468 }
469 
/// @brief A TCP socket.
///
/// https://man7.org/linux/man-pages/man7/tcp.7.html
#[derive(Debug, Clone)]
pub struct TcpSocket {
    /// Handle of the underlying smoltcp TCP socket in the global `SOCKET_SET`.
    /// `accept()` replaces it with a freshly allocated listening socket.
    handle: Arc<GlobalSocketHandle>,
    /// Local endpoint saved by `bind()`; `listen()` and `accept()` return EINVAL while
    /// it is unset.
    local_endpoint: Option<wire::IpEndpoint>,
    /// Set to true by `do_listen` on success; reset to false by `bind()`.
    is_listening: bool,
    /// Socket metadata (socket type, buffer sizes and options).
    metadata: SocketMetadata,
}
480 
481 impl TcpSocket {
482     /// 元数据的缓冲区的大小
483     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
484     /// 默认的接收缓冲区的大小 receive
485     pub const DEFAULT_RX_BUF_SIZE: usize = 512 * 1024;
486     /// 默认的发送缓冲区的大小 transmiss
487     pub const DEFAULT_TX_BUF_SIZE: usize = 512 * 1024;
488 
489     /// TcpSocket的特殊事件,用于在事件等待队列上sleep
490     pub const CAN_CONNECT: u64 = 1u64 << 63;
491     pub const CAN_ACCPET: u64 = 1u64 << 62;
492 
493     /// @brief 创建一个tcp的socket
494     ///
495     /// @param options socket的选项
496     ///
497     /// @return 返回创建的tcp的socket
new(options: SocketOptions) -> Self498     pub fn new(options: SocketOptions) -> Self {
499         let rx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_RX_BUF_SIZE]);
500         let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]);
501         let socket = tcp::Socket::new(rx_buffer, tx_buffer);
502 
503         // 把socket添加到socket集合中,并得到socket的句柄
504         let handle: Arc<GlobalSocketHandle> =
505             GlobalSocketHandle::new(SOCKET_SET.lock_irqsave().add(socket));
506 
507         let metadata = SocketMetadata::new(
508             SocketType::TcpSocket,
509             Self::DEFAULT_RX_BUF_SIZE,
510             Self::DEFAULT_TX_BUF_SIZE,
511             Self::DEFAULT_METADATA_BUF_SIZE,
512             options,
513         );
514 
515         return Self {
516             handle,
517             local_endpoint: None,
518             is_listening: false,
519             metadata,
520         };
521     }
do_listen( &mut self, socket: &mut tcp::Socket, local_endpoint: wire::IpEndpoint, ) -> Result<(), SystemError>522     fn do_listen(
523         &mut self,
524         socket: &mut tcp::Socket,
525         local_endpoint: wire::IpEndpoint,
526     ) -> Result<(), SystemError> {
527         let listen_result = if local_endpoint.addr.is_unspecified() {
528             // kdebug!("Tcp Socket Listen on port {}", local_endpoint.port);
529             socket.listen(local_endpoint.port)
530         } else {
531             // kdebug!("Tcp Socket Listen on {local_endpoint}");
532             socket.listen(local_endpoint)
533         };
534         // TODO: 增加端口占用检查
535         return match listen_result {
536             Ok(()) => {
537                 // kdebug!(
538                 //     "Tcp Socket Listen on {local_endpoint}, open?:{}",
539                 //     socket.is_open()
540                 // );
541                 self.is_listening = true;
542 
543                 Ok(())
544             }
545             Err(_) => Err(SystemError::EINVAL),
546         };
547     }
548 }
549 
550 impl Socket for TcpSocket {
as_any_ref(&self) -> &dyn core::any::Any551     fn as_any_ref(&self) -> &dyn core::any::Any {
552         self
553     }
554 
as_any_mut(&mut self) -> &mut dyn core::any::Any555     fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
556         self
557     }
558 
read(&mut self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint)559     fn read(&mut self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
560         if HANDLE_MAP
561             .read_irqsave()
562             .get(&self.socket_handle())
563             .unwrap()
564             .shutdown_type()
565             .contains(ShutdownType::RCV_SHUTDOWN)
566         {
567             return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
568         }
569         // kdebug!("tcp socket: read, buf len={}", buf.len());
570 
571         loop {
572             poll_ifaces();
573             let mut socket_set_guard = SOCKET_SET.lock_irqsave();
574             let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handle.0);
575 
576             // 如果socket已经关闭,返回错误
577             if !socket.is_active() {
578                 // kdebug!("Tcp Socket Read Error, socket is closed");
579                 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
580             }
581 
582             if socket.may_recv() {
583                 let recv_res = socket.recv_slice(buf);
584 
585                 if let Ok(size) = recv_res {
586                     if size > 0 {
587                         let endpoint = if let Some(p) = socket.remote_endpoint() {
588                             p
589                         } else {
590                             return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
591                         };
592 
593                         drop(socket_set_guard);
594                         poll_ifaces();
595                         return (Ok(size), Endpoint::Ip(Some(endpoint)));
596                     }
597                 } else {
598                     let err = recv_res.unwrap_err();
599                     match err {
600                         tcp::RecvError::InvalidState => {
601                             kwarn!("Tcp Socket Read Error, InvalidState");
602                             return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
603                         }
604                         tcp::RecvError::Finished => {
605                             // 对端写端已关闭,我们应该关闭读端
606                             HANDLE_MAP
607                                 .write_irqsave()
608                                 .get_mut(&self.socket_handle())
609                                 .unwrap()
610                                 .shutdown_type_writer()
611                                 .insert(ShutdownType::RCV_SHUTDOWN);
612                             return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
613                         }
614                     }
615                 }
616             } else {
617                 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
618             }
619             drop(socket_set_guard);
620             SocketHandleItem::sleep(
621                 self.socket_handle(),
622                 EPollEventType::EPOLLIN.bits() as u64,
623                 HANDLE_MAP.read_irqsave(),
624             );
625         }
626     }
627 
write(&self, buf: &[u8], _to: Option<Endpoint>) -> Result<usize, SystemError>628     fn write(&self, buf: &[u8], _to: Option<Endpoint>) -> Result<usize, SystemError> {
629         if HANDLE_MAP
630             .read_irqsave()
631             .get(&self.socket_handle())
632             .unwrap()
633             .shutdown_type()
634             .contains(ShutdownType::RCV_SHUTDOWN)
635         {
636             return Err(SystemError::ENOTCONN);
637         }
638         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
639         let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handle.0);
640 
641         if socket.is_open() {
642             if socket.can_send() {
643                 match socket.send_slice(buf) {
644                     Ok(size) => {
645                         drop(socket_set_guard);
646                         poll_ifaces();
647                         return Ok(size);
648                     }
649                     Err(e) => {
650                         kerror!("Tcp Socket Write Error {e:?}");
651                         return Err(SystemError::ENOBUFS);
652                     }
653                 }
654             } else {
655                 return Err(SystemError::ENOBUFS);
656             }
657         }
658 
659         return Err(SystemError::ENOTCONN);
660     }
661 
poll(&self) -> EPollEventType662     fn poll(&self) -> EPollEventType {
663         let mut socket_set_guard = SOCKET_SET.lock_irqsave();
664         let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handle.0);
665 
666         return SocketPollMethod::tcp_poll(
667             socket,
668             HANDLE_MAP
669                 .read_irqsave()
670                 .get(&self.socket_handle())
671                 .unwrap()
672                 .shutdown_type(),
673         );
674     }
675 
connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError>676     fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
677         let mut sockets = SOCKET_SET.lock_irqsave();
678         let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
679 
680         if let Endpoint::Ip(Some(ip)) = endpoint {
681             let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
682             // 检测端口是否被占用
683             PORT_MANAGER.bind_port(self.metadata.socket_type, temp_port, self.handle.clone())?;
684 
685             // kdebug!("temp_port: {}", temp_port);
686             let iface: Arc<dyn NetDriver> = NET_DRIVERS.write_irqsave().get(&0).unwrap().clone();
687             let mut inner_iface = iface.inner_iface().lock();
688             // kdebug!("to connect: {ip:?}");
689 
690             match socket.connect(&mut inner_iface.context(), ip, temp_port) {
691                 Ok(()) => {
692                     // avoid deadlock
693                     drop(inner_iface);
694                     drop(iface);
695                     drop(sockets);
696                     loop {
697                         poll_ifaces();
698                         let mut sockets = SOCKET_SET.lock_irqsave();
699                         let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
700 
701                         match socket.state() {
702                             tcp::State::Established => {
703                                 return Ok(());
704                             }
705                             tcp::State::SynSent => {
706                                 drop(sockets);
707                                 SocketHandleItem::sleep(
708                                     self.socket_handle(),
709                                     Self::CAN_CONNECT,
710                                     HANDLE_MAP.read_irqsave(),
711                                 );
712                             }
713                             _ => {
714                                 return Err(SystemError::ECONNREFUSED);
715                             }
716                         }
717                     }
718                 }
719                 Err(e) => {
720                     // kerror!("Tcp Socket Connect Error {e:?}");
721                     match e {
722                         tcp::ConnectError::InvalidState => return Err(SystemError::EISCONN),
723                         tcp::ConnectError::Unaddressable => return Err(SystemError::EADDRNOTAVAIL),
724                     }
725                 }
726             }
727         } else {
728             return Err(SystemError::EINVAL);
729         }
730     }
731 
732     /// @brief tcp socket 监听 local_endpoint 端口
733     ///
734     /// @param backlog 未处理的连接队列的最大长度. 由于smoltcp不支持backlog,所以这个参数目前无效
listen(&mut self, _backlog: usize) -> Result<(), SystemError>735     fn listen(&mut self, _backlog: usize) -> Result<(), SystemError> {
736         if self.is_listening {
737             return Ok(());
738         }
739 
740         let local_endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?;
741         let mut sockets = SOCKET_SET.lock_irqsave();
742         let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
743 
744         if socket.is_listening() {
745             // kdebug!("Tcp Socket is already listening on {local_endpoint}");
746             return Ok(());
747         }
748         // kdebug!("Tcp Socket  before listen, open={}", socket.is_open());
749         return self.do_listen(socket, local_endpoint);
750     }
751 
bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError>752     fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
753         if let Endpoint::Ip(Some(mut ip)) = endpoint {
754             if ip.port == 0 {
755                 ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?;
756             }
757 
758             // 检测端口是否已被占用
759             PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port, self.handle.clone())?;
760 
761             self.local_endpoint = Some(ip);
762             self.is_listening = false;
763             return Ok(());
764         }
765         return Err(SystemError::EINVAL);
766     }
767 
shutdown(&mut self, shutdown_type: super::ShutdownType) -> Result<(), SystemError>768     fn shutdown(&mut self, shutdown_type: super::ShutdownType) -> Result<(), SystemError> {
769         // TODO:目前只是在表层判断,对端不知晓,后续需使用tcp实现
770         HANDLE_MAP
771             .write_irqsave()
772             .get_mut(&self.socket_handle())
773             .unwrap()
774             .shutdown_type = RwLock::new(shutdown_type);
775         return Ok(());
776     }
777 
accept(&mut self) -> Result<(Box<dyn Socket>, Endpoint), SystemError>778     fn accept(&mut self) -> Result<(Box<dyn Socket>, Endpoint), SystemError> {
779         let endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?;
780         loop {
781             // kdebug!("tcp accept: poll_ifaces()");
782             poll_ifaces();
783 
784             let mut sockets = SOCKET_SET.lock_irqsave();
785 
786             let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
787 
788             if socket.is_active() {
789                 // kdebug!("tcp accept: socket.is_active()");
790                 let remote_ep = socket.remote_endpoint().ok_or(SystemError::ENOTCONN)?;
791 
792                 let new_socket = {
793                     // Initialize the TCP socket's buffers.
794                     let rx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_RX_BUF_SIZE]);
795                     let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]);
796                     // The new TCP socket used for sending and receiving data.
797                     let mut tcp_socket = tcp::Socket::new(rx_buffer, tx_buffer);
798                     self.do_listen(&mut tcp_socket, endpoint)
799                         .expect("do_listen failed");
800 
801                     // tcp_socket.listen(endpoint).unwrap();
802 
803                     // 之所以把old_handle存入new_socket, 是因为当前时刻,smoltcp已经把old_handle对应的socket与远程的endpoint关联起来了
804                     // 因此需要再为当前的socket分配一个新的handle
805                     let new_handle = GlobalSocketHandle::new(sockets.add(tcp_socket));
806                     let old_handle = ::core::mem::replace(&mut self.handle, new_handle.clone());
807 
808                     // 更新端口与 handle 的绑定
809                     if let Some(Endpoint::Ip(Some(ip))) = self.endpoint() {
810                         PORT_MANAGER.unbind_port(self.metadata.socket_type, ip.port)?;
811                         PORT_MANAGER.bind_port(
812                             self.metadata.socket_type,
813                             ip.port,
814                             new_handle.clone(),
815                         )?;
816                     }
817 
818                     let metadata = SocketMetadata::new(
819                         SocketType::TcpSocket,
820                         Self::DEFAULT_TX_BUF_SIZE,
821                         Self::DEFAULT_RX_BUF_SIZE,
822                         Self::DEFAULT_METADATA_BUF_SIZE,
823                         self.metadata.options,
824                     );
825 
826                     let new_socket = Box::new(TcpSocket {
827                         handle: old_handle.clone(),
828                         local_endpoint: self.local_endpoint,
829                         is_listening: false,
830                         metadata,
831                     });
832 
833                     // 更新handle表
834                     let mut handle_guard = HANDLE_MAP.write_irqsave();
835                     // 先删除原来的
836                     let item = handle_guard.remove(&old_handle.0).unwrap();
837                     // 按照smoltcp行为,将新的handle绑定到原来的item
838                     handle_guard.insert(new_handle.0, item);
839                     let new_item = SocketHandleItem::from_socket(&new_socket);
840                     // 插入新的item
841                     handle_guard.insert(old_handle.0, new_item);
842 
843                     new_socket
844                 };
845                 // kdebug!("tcp accept: new socket: {:?}", new_socket);
846                 drop(sockets);
847                 poll_ifaces();
848 
849                 return Ok((new_socket, Endpoint::Ip(Some(remote_ep))));
850             }
851             drop(sockets);
852 
853             SocketHandleItem::sleep(
854                 self.socket_handle(),
855                 Self::CAN_ACCPET,
856                 HANDLE_MAP.read_irqsave(),
857             );
858         }
859     }
860 
endpoint(&self) -> Option<Endpoint>861     fn endpoint(&self) -> Option<Endpoint> {
862         let mut result: Option<Endpoint> =
863             self.local_endpoint.clone().map(|x| Endpoint::Ip(Some(x)));
864 
865         if result.is_none() {
866             let sockets = SOCKET_SET.lock_irqsave();
867             let socket = sockets.get::<tcp::Socket>(self.handle.0);
868             if let Some(ep) = socket.local_endpoint() {
869                 result = Some(Endpoint::Ip(Some(ep)));
870             }
871         }
872         return result;
873     }
874 
peer_endpoint(&self) -> Option<Endpoint>875     fn peer_endpoint(&self) -> Option<Endpoint> {
876         let sockets = SOCKET_SET.lock_irqsave();
877         let socket = sockets.get::<tcp::Socket>(self.handle.0);
878         return socket.remote_endpoint().map(|x| Endpoint::Ip(Some(x)));
879     }
880 
metadata(&self) -> Result<SocketMetadata, SystemError>881     fn metadata(&self) -> Result<SocketMetadata, SystemError> {
882         Ok(self.metadata.clone())
883     }
884 
box_clone(&self) -> Box<dyn Socket>885     fn box_clone(&self) -> Box<dyn Socket> {
886         return Box::new(self.clone());
887     }
888 
socket_handle(&self) -> SocketHandle889     fn socket_handle(&self) -> SocketHandle {
890         self.handle.0
891     }
892 }
893 
/// # Seqpacket socket
///
/// Currently only usable as one end of a socketpair. `write` fails with `ENOSYS` until a
/// peer buffer has been attached. Cloning shares both buffers, because they are `Arc`s.
#[derive(Debug, Clone)]
#[cast_to(Socket)]
pub struct SeqpacketSocket {
    /// Metadata returned by `Socket::metadata`.
    metadata: SocketMetadata,
    /// This socket's own receive buffer. It is shared (via `Arc`) with the peer so the
    /// peer's `write` can deliver into it.
    buffer: Arc<SpinLock<Vec<u8>>>,
    /// The peer's receive buffer, which `write` targets. `None` until paired.
    peer_buffer: Option<Arc<SpinLock<Vec<u8>>>>,
}
902 
903 impl SeqpacketSocket {
904     /// 默认的元数据缓冲区大小
905     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
906     /// 默认的缓冲区大小
907     pub const DEFAULT_BUF_SIZE: usize = 64 * 1024;
908 
909     /// # 创建一个seqpacket的socket
910     ///
911     /// ## 参数
912     /// - `options`: socket的选项
new(options: SocketOptions) -> Self913     pub fn new(options: SocketOptions) -> Self {
914         let buffer = Vec::with_capacity(Self::DEFAULT_BUF_SIZE);
915 
916         let metadata = SocketMetadata::new(
917             SocketType::SeqpacketSocket,
918             Self::DEFAULT_BUF_SIZE,
919             0,
920             Self::DEFAULT_METADATA_BUF_SIZE,
921             options,
922         );
923 
924         return Self {
925             metadata,
926             buffer: Arc::new(SpinLock::new(buffer)),
927             peer_buffer: None,
928         };
929     }
930 
buffer(&self) -> Arc<SpinLock<Vec<u8>>>931     fn buffer(&self) -> Arc<SpinLock<Vec<u8>>> {
932         self.buffer.clone()
933     }
934 
set_peer_buffer(&mut self, peer_buffer: Arc<SpinLock<Vec<u8>>>)935     fn set_peer_buffer(&mut self, peer_buffer: Arc<SpinLock<Vec<u8>>>) {
936         self.peer_buffer = Some(peer_buffer);
937     }
938 }
939 
940 impl Socket for SeqpacketSocket {
as_any_ref(&self) -> &dyn core::any::Any941     fn as_any_ref(&self) -> &dyn core::any::Any {
942         self
943     }
944 
as_any_mut(&mut self) -> &mut dyn core::any::Any945     fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
946         self
947     }
948 
read(&mut self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint)949     fn read(&mut self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
950         let buffer = self.buffer.lock_irqsave();
951 
952         let len = core::cmp::min(buf.len(), buffer.len());
953         buf[..len].copy_from_slice(&buffer[..len]);
954 
955         (Ok(len), Endpoint::Unused)
956     }
957 
write(&self, buf: &[u8], _to: Option<Endpoint>) -> Result<usize, SystemError>958     fn write(&self, buf: &[u8], _to: Option<Endpoint>) -> Result<usize, SystemError> {
959         if self.peer_buffer.is_none() {
960             kwarn!("SeqpacketSocket is now just for socketpair");
961             return Err(SystemError::ENOSYS);
962         }
963 
964         let binding = self.peer_buffer.clone().unwrap();
965         let mut peer_buffer = binding.lock_irqsave();
966 
967         let len = buf.len();
968         if peer_buffer.capacity() - peer_buffer.len() < len {
969             return Err(SystemError::ENOBUFS);
970         }
971         peer_buffer[..len].copy_from_slice(buf);
972 
973         Ok(len)
974     }
975 
socketpair_ops(&self) -> Option<&'static dyn SocketpairOps>976     fn socketpair_ops(&self) -> Option<&'static dyn SocketpairOps> {
977         Some(&SeqpacketSocketpairOps)
978     }
979 
metadata(&self) -> Result<SocketMetadata, SystemError>980     fn metadata(&self) -> Result<SocketMetadata, SystemError> {
981         Ok(self.metadata.clone())
982     }
983 
box_clone(&self) -> Box<dyn Socket>984     fn box_clone(&self) -> Box<dyn Socket> {
985         Box::new(self.clone())
986     }
987 }
988 
/// Stateless `SocketpairOps` implementation that cross-links two `SeqpacketSocket`s
/// through their buffers.
struct SeqpacketSocketpairOps;
990 
991 impl SocketpairOps for SeqpacketSocketpairOps {
socketpair(&self, socket0: &mut Box<dyn Socket>, socket1: &mut Box<dyn Socket>)992     fn socketpair(&self, socket0: &mut Box<dyn Socket>, socket1: &mut Box<dyn Socket>) {
993         let pair0 = socket0
994             .as_mut()
995             .as_any_mut()
996             .downcast_mut::<SeqpacketSocket>()
997             .unwrap();
998 
999         let pair1 = socket1
1000             .as_mut()
1001             .as_any_mut()
1002             .downcast_mut::<SeqpacketSocket>()
1003             .unwrap();
1004         pair0.set_peer_buffer(pair1.buffer());
1005         pair1.set_peer_buffer(pair0.buffer());
1006     }
1007 }
1008