1 #![allow(dead_code)]
2 use alloc::{boxed::Box, sync::Arc, vec::Vec};
3 use smoltcp::{
4     iface::{SocketHandle, SocketSet},
5     socket::{raw, tcp, udp},
6     wire,
7 };
8 
9 use crate::{
10     arch::rand::rand,
11     driver::net::NetDriver,
12     filesystem::vfs::{FileType, IndexNode, Metadata, PollStatus},
13     kerror, kwarn,
14     libs::{
15         spinlock::{SpinLock, SpinLockGuard},
16         wait_queue::WaitQueue,
17     },
18     syscall::SystemError,
19 };
20 
21 use super::{net_core::poll_ifaces, Endpoint, Protocol, Socket, NET_DRIVERS};
22 
lazy_static! {
    /// The set of all sockets.
    ///
    /// TODO: optimize this by implementing our own `SocketSet`. As it stands, no matter how many
    /// network interfaces exist globally, only one process can access the sockets at any moment,
    /// because every socket lives behind this single spin lock.
    pub static ref SOCKET_SET: SpinLock<SocketSet<'static >> = SpinLock::new(SocketSet::new(vec![]));
    /// Wait queue on which the blocking socket operations in this file sleep until they retry.
    pub static ref SOCKET_WAITQUEUE: WaitQueue = WaitQueue::INIT;
}
29 
/* For setsockopt(2) */
// See: linux-5.19.10/include/uapi/asm-generic/socket.h#9
/// Option level selecting socket-level options; the value mirrors the Linux header cited above.
pub const SOL_SOCKET: u8 = 1;
33 
34 /// @brief socket的句柄管理组件。
35 /// 它在smoltcp的SocketHandle上封装了一层,增加更多的功能。
36 /// 比如,在socket被关闭时,自动释放socket的资源,通知系统的其他组件。
37 #[derive(Debug)]
38 pub struct GlobalSocketHandle(SocketHandle);
39 
40 impl GlobalSocketHandle {
new(handle: SocketHandle) -> Self41     pub fn new(handle: SocketHandle) -> Self {
42         Self(handle)
43     }
44 }
45 
46 impl Clone for GlobalSocketHandle {
clone(&self) -> Self47     fn clone(&self) -> Self {
48         Self(self.0)
49     }
50 }
51 
52 impl Drop for GlobalSocketHandle {
drop(&mut self)53     fn drop(&mut self) {
54         let mut socket_set_guard = SOCKET_SET.lock();
55         socket_set_guard.remove(self.0); // 删除的时候,会发送一条FINISH的信息?
56         drop(socket_set_guard);
57         poll_ifaces();
58     }
59 }
60 
/// The kind of a socket.
#[derive(Debug)]
pub enum SocketType {
    /// A raw socket.
    RawSocket,
    /// A socket used for TCP communication.
    TcpSocket,
    /// A socket used for UDP communication.
    UdpSocket,
}
71 
bitflags! {
    /// Option flags of a socket.
    #[derive(Default)]
    pub struct SocketOptions: u32 {
        /// Whether the socket is blocking.
        const BLOCK = 1 << 0;
        /// Whether broadcast is allowed.
        const BROADCAST = 1 << 1;
        /// Whether multicast is allowed.
        const MULTICAST = 1 << 2;
        /// Whether address reuse is allowed.
        const REUSEADDR = 1 << 3;
        /// Whether port reuse is allowed.
        const REUSEPORT = 1 << 4;
    }
}
88 
#[derive(Debug)]
/// Description of a socket, returned by the `metadata` function of the `Socket` trait for use
/// by outside code.
pub struct SocketMetadata {
    /// Type of the socket.
    pub socket_type: SocketType,
    /// Size of the send buffer.
    pub send_buf_size: usize,
    /// Size of the receive buffer.
    pub recv_buf_size: usize,
    /// Size of the metadata buffer.
    pub metadata_buf_size: usize,
    /// Options of the socket.
    pub options: SocketOptions,
}
103 
/// A raw socket. Raw sockets bypass the transport-layer protocols (such as TCP or UDP) and
/// give direct access to the network-layer protocol (such as IP).
///
/// ref: https://man7.org/linux/man-pages/man7/raw.7.html
#[derive(Debug, Clone)]
pub struct RawSocket {
    /// Handle of the underlying smoltcp socket inside the global socket set.
    handle: GlobalSocketHandle,
    /// Whether the packets sent by the user already contain the IP header.
    /// If `true`, the user must supply the IP header together with the data.
    /// If `false`, the user supplies only the data and the IP header is built on write.
    header_included: bool,
    /// Options of the socket.
    options: SocketOptions,
}
117 
118 impl RawSocket {
119     /// 元数据的缓冲区的大小
120     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
121     /// 默认的发送缓冲区的大小 transmiss
122     pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024;
123     /// 默认的接收缓冲区的大小 receive
124     pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024;
125 
126     /// @brief 创建一个原始的socket
127     ///
128     /// @param protocol 协议号
129     /// @param options socket的选项
130     ///
131     /// @return 返回创建的原始的socket
new(protocol: Protocol, options: SocketOptions) -> Self132     pub fn new(protocol: Protocol, options: SocketOptions) -> Self {
133         let tx_buffer = raw::PacketBuffer::new(
134             vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
135             vec![0; Self::DEFAULT_TX_BUF_SIZE],
136         );
137         let rx_buffer = raw::PacketBuffer::new(
138             vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
139             vec![0; Self::DEFAULT_RX_BUF_SIZE],
140         );
141         let protocol: u8 = protocol.into();
142         let socket = raw::Socket::new(
143             smoltcp::wire::IpVersion::Ipv4,
144             wire::IpProtocol::from(protocol),
145             tx_buffer,
146             rx_buffer,
147         );
148 
149         // 把socket添加到socket集合中,并得到socket的句柄
150         let handle: GlobalSocketHandle = GlobalSocketHandle::new(SOCKET_SET.lock().add(socket));
151 
152         return Self {
153             handle,
154             header_included: false,
155             options,
156         };
157     }
158 }
159 
160 impl Socket for RawSocket {
read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint)161     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
162         poll_ifaces();
163         loop {
164             // 如何优化这里?
165             let mut socket_set_guard = SOCKET_SET.lock();
166             let socket = socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
167 
168             match socket.recv_slice(buf) {
169                 Ok(len) => {
170                     let packet = wire::Ipv4Packet::new_unchecked(buf);
171                     return (
172                         Ok(len),
173                         Endpoint::Ip(Some(smoltcp::wire::IpEndpoint {
174                             addr: wire::IpAddress::Ipv4(packet.src_addr()),
175                             port: 0,
176                         })),
177                     );
178                 }
179                 Err(smoltcp::socket::raw::RecvError::Exhausted) => {
180                     if !self.options.contains(SocketOptions::BLOCK) {
181                         // 如果是非阻塞的socket,就返回错误
182                         return (Err(SystemError::EAGAIN_OR_EWOULDBLOCK), Endpoint::Ip(None));
183                     }
184                 }
185             }
186             drop(socket);
187             drop(socket_set_guard);
188             SOCKET_WAITQUEUE.sleep();
189         }
190     }
191 
write(&self, buf: &[u8], to: Option<super::Endpoint>) -> Result<usize, SystemError>192     fn write(&self, buf: &[u8], to: Option<super::Endpoint>) -> Result<usize, SystemError> {
193         // 如果用户发送的数据包,包含IP头,则直接发送
194         if self.header_included {
195             let mut socket_set_guard = SOCKET_SET.lock();
196             let socket = socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
197             match socket.send_slice(buf) {
198                 Ok(_len) => {
199                     return Ok(buf.len());
200                 }
201                 Err(smoltcp::socket::raw::SendError::BufferFull) => {
202                     return Err(SystemError::ENOBUFS);
203                 }
204             }
205         } else {
206             // 如果用户发送的数据包,不包含IP头,则需要自己构造IP头
207 
208             if let Some(Endpoint::Ip(Some(endpoint))) = to {
209                 let mut socket_set_guard = SOCKET_SET.lock();
210                 let socket: &mut raw::Socket =
211                     socket_set_guard.get_mut::<raw::Socket>(self.handle.0);
212 
213                 // 暴力解决方案:只考虑0号网卡。 TODO:考虑多网卡的情况!!!
214                 let iface = NET_DRIVERS.read().get(&0).unwrap().clone();
215 
216                 // 构造IP头
217                 let ipv4_src_addr: Option<smoltcp::wire::Ipv4Address> =
218                     iface.inner_iface().lock().ipv4_addr();
219                 if ipv4_src_addr.is_none() {
220                     return Err(SystemError::ENETUNREACH);
221                 }
222                 let ipv4_src_addr = ipv4_src_addr.unwrap();
223 
224                 if let wire::IpAddress::Ipv4(ipv4_dst) = endpoint.addr {
225                     let len = buf.len();
226 
227                     // 创建20字节的IPv4头部
228                     let mut buffer: Vec<u8> = vec![0u8; len + 20];
229                     let mut packet: wire::Ipv4Packet<&mut Vec<u8>> =
230                         wire::Ipv4Packet::new_unchecked(&mut buffer);
231 
232                     // 封装ipv4 header
233                     packet.set_version(4);
234                     packet.set_header_len(20);
235                     packet.set_total_len((20 + len) as u16);
236                     packet.set_src_addr(ipv4_src_addr);
237                     packet.set_dst_addr(ipv4_dst);
238 
239                     // 设置ipv4 header的protocol字段
240                     packet.set_next_header(socket.ip_protocol().into());
241 
242                     // 获取IP数据包的负载字段
243                     let payload: &mut [u8] = packet.payload_mut();
244                     payload.copy_from_slice(buf);
245 
246                     // 填充checksum字段
247                     packet.fill_checksum();
248 
249                     // 发送数据包
250                     socket.send_slice(&buffer).unwrap();
251 
252                     drop(socket);
253 
254                     iface.poll(&mut socket_set_guard).ok();
255 
256                     drop(socket_set_guard);
257                     return Ok(len);
258                 } else {
259                     kwarn!("Unsupport Ip protocol type!");
260                     return Err(SystemError::EINVAL);
261                 }
262             } else {
263                 // 如果没有指定目的地址,则返回错误
264                 return Err(SystemError::ENOTCONN);
265             }
266         }
267     }
268 
connect(&mut self, _endpoint: super::Endpoint) -> Result<(), SystemError>269     fn connect(&mut self, _endpoint: super::Endpoint) -> Result<(), SystemError> {
270         return Ok(());
271     }
272 
metadata(&self) -> Result<SocketMetadata, SystemError>273     fn metadata(&self) -> Result<SocketMetadata, SystemError> {
274         todo!()
275     }
276 
box_clone(&self) -> alloc::boxed::Box<dyn Socket>277     fn box_clone(&self) -> alloc::boxed::Box<dyn Socket> {
278         return Box::new(self.clone());
279     }
280 }
281 
/// A UDP socket.
///
/// https://man7.org/linux/man-pages/man7/udp.7.html
#[derive(Debug, Clone)]
pub struct UdpSocket {
    /// Handle of the underlying smoltcp socket inside the global socket set.
    pub handle: GlobalSocketHandle,
    remote_endpoint: Option<Endpoint>, // Remote endpoint recorded by connect(); should be an IP address.
    /// Options of the socket.
    options: SocketOptions,
}
291 
292 impl UdpSocket {
293     /// 元数据的缓冲区的大小
294     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
295     /// 默认的发送缓冲区的大小 transmiss
296     pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024;
297     /// 默认的接收缓冲区的大小 receive
298     pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024;
299 
300     /// @brief 创建一个原始的socket
301     ///
302     /// @param protocol 协议号
303     /// @param options socket的选项
304     ///
305     /// @return 返回创建的原始的socket
new(options: SocketOptions) -> Self306     pub fn new(options: SocketOptions) -> Self {
307         let tx_buffer = udp::PacketBuffer::new(
308             vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
309             vec![0; Self::DEFAULT_TX_BUF_SIZE],
310         );
311         let rx_buffer = udp::PacketBuffer::new(
312             vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE],
313             vec![0; Self::DEFAULT_RX_BUF_SIZE],
314         );
315         let socket = udp::Socket::new(tx_buffer, rx_buffer);
316 
317         // 把socket添加到socket集合中,并得到socket的句柄
318         let handle: GlobalSocketHandle = GlobalSocketHandle::new(SOCKET_SET.lock().add(socket));
319 
320         return Self {
321             handle,
322             remote_endpoint: None,
323             options,
324         };
325     }
326 
do_bind(&self, socket: &mut udp::Socket, endpoint: Endpoint) -> Result<(), SystemError>327     fn do_bind(&self, socket: &mut udp::Socket, endpoint: Endpoint) -> Result<(), SystemError> {
328         if let Endpoint::Ip(Some(ip)) = endpoint {
329             let bind_res = if ip.addr.is_unspecified() {
330                 socket.bind(ip.port)
331             } else {
332                 socket.bind(ip)
333             };
334 
335             match bind_res {
336                 Ok(()) => return Ok(()),
337                 Err(_) => return Err(SystemError::EINVAL),
338             }
339         } else {
340             return Err(SystemError::EINVAL);
341         };
342     }
343 }
344 
345 impl Socket for UdpSocket {
346     /// @brief 在read函数执行之前,请先bind到本地的指定端口
read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint)347     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
348         loop {
349             // kdebug!("Wait22 to Read");
350             poll_ifaces();
351             let mut socket_set_guard = SOCKET_SET.lock();
352             let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.0);
353 
354             // kdebug!("Wait to Read");
355 
356             if socket.can_recv() {
357                 if let Ok((size, remote_endpoint)) = socket.recv_slice(buf) {
358                     drop(socket);
359                     drop(socket_set_guard);
360                     poll_ifaces();
361                     return (Ok(size), Endpoint::Ip(Some(remote_endpoint)));
362                 }
363             } else {
364                 // 如果socket没有连接,则忙等
365                 // return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
366             }
367             drop(socket);
368             drop(socket_set_guard);
369             SOCKET_WAITQUEUE.sleep();
370         }
371     }
372 
write(&self, buf: &[u8], to: Option<super::Endpoint>) -> Result<usize, SystemError>373     fn write(&self, buf: &[u8], to: Option<super::Endpoint>) -> Result<usize, SystemError> {
374         // kdebug!("udp to send: {:?}, len={}", to, buf.len());
375         let remote_endpoint: &wire::IpEndpoint = {
376             if let Some(Endpoint::Ip(Some(ref endpoint))) = to {
377                 endpoint
378             } else if let Some(Endpoint::Ip(Some(ref endpoint))) = self.remote_endpoint {
379                 endpoint
380             } else {
381                 return Err(SystemError::ENOTCONN);
382             }
383         };
384         // kdebug!("udp write: remote = {:?}", remote_endpoint);
385 
386         let mut socket_set_guard = SOCKET_SET.lock();
387         let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.0);
388         // kdebug!("is open()={}", socket.is_open());
389         // kdebug!("socket endpoint={:?}", socket.endpoint());
390         if socket.endpoint().port == 0 {
391             let temp_port = get_ephemeral_port();
392 
393             let local_ep = match remote_endpoint.addr {
394                 // 远程remote endpoint使用什么协议,发送的时候使用的协议是一样的吧
395                 // 否则就用 self.endpoint().addr.unwrap()
396                 wire::IpAddress::Ipv4(_) => Endpoint::Ip(Some(wire::IpEndpoint::new(
397                     smoltcp::wire::IpAddress::Ipv4(wire::Ipv4Address::UNSPECIFIED),
398                     temp_port,
399                 ))),
400                 wire::IpAddress::Ipv6(_) => Endpoint::Ip(Some(wire::IpEndpoint::new(
401                     smoltcp::wire::IpAddress::Ipv6(wire::Ipv6Address::UNSPECIFIED),
402                     temp_port,
403                 ))),
404             };
405             // kdebug!("udp write: local_ep = {:?}", local_ep);
406             self.do_bind(socket, local_ep)?;
407         }
408         // kdebug!("is open()={}", socket.is_open());
409         if socket.can_send() {
410             // kdebug!("udp write: can send");
411             match socket.send_slice(&buf, *remote_endpoint) {
412                 Ok(()) => {
413                     // kdebug!("udp write: send ok");
414                     drop(socket);
415                     drop(socket_set_guard);
416                     poll_ifaces();
417                     return Ok(buf.len());
418                 }
419                 Err(_) => {
420                     // kdebug!("udp write: send err");
421                     return Err(SystemError::ENOBUFS);
422                 }
423             }
424         } else {
425             // kdebug!("udp write: can not send");
426             return Err(SystemError::ENOBUFS);
427         };
428     }
429 
bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError>430     fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
431         let mut sockets = SOCKET_SET.lock();
432         let socket = sockets.get_mut::<udp::Socket>(self.handle.0);
433         // kdebug!("UDP Bind to {:?}", endpoint);
434         return self.do_bind(socket, endpoint);
435     }
436 
poll(&self) -> (bool, bool, bool)437     fn poll(&self) -> (bool, bool, bool) {
438         let sockets = SOCKET_SET.lock();
439         let socket = sockets.get::<udp::Socket>(self.handle.0);
440 
441         return (socket.can_send(), socket.can_recv(), false);
442     }
443 
444     /// @brief
connect(&mut self, endpoint: super::Endpoint) -> Result<(), SystemError>445     fn connect(&mut self, endpoint: super::Endpoint) -> Result<(), SystemError> {
446         if let Endpoint::Ip(_) = endpoint {
447             self.remote_endpoint = Some(endpoint);
448             return Ok(());
449         } else {
450             return Err(SystemError::EINVAL);
451         };
452     }
453 
    /// Device-control entry point for the socket.
    ///
    /// Not implemented yet: the body is `todo!()`, so calling it panics. All arguments are
    /// currently ignored.
    fn ioctl(
        &self,
        _cmd: usize,
        _arg0: usize,
        _arg1: usize,
        _arg2: usize,
    ) -> Result<usize, SystemError> {
        todo!()
    }
metadata(&self) -> Result<SocketMetadata, SystemError>463     fn metadata(&self) -> Result<SocketMetadata, SystemError> {
464         todo!()
465     }
466 
box_clone(&self) -> alloc::boxed::Box<dyn Socket>467     fn box_clone(&self) -> alloc::boxed::Box<dyn Socket> {
468         return Box::new(self.clone());
469     }
470 
endpoint(&self) -> Option<Endpoint>471     fn endpoint(&self) -> Option<Endpoint> {
472         let sockets = SOCKET_SET.lock();
473         let socket = sockets.get::<udp::Socket>(self.handle.0);
474         let listen_endpoint = socket.endpoint();
475 
476         if listen_endpoint.port == 0 {
477             return None;
478         } else {
479             // 如果listen_endpoint的address是None,意味着“监听所有的地址”。
480             // 这里假设所有的地址都是ipv4
481             // TODO: 支持ipv6
482             let result = wire::IpEndpoint::new(
483                 listen_endpoint
484                     .addr
485                     .unwrap_or(wire::IpAddress::v4(0, 0, 0, 0)),
486                 listen_endpoint.port,
487             );
488             return Some(Endpoint::Ip(Some(result)));
489         }
490     }
491 
peer_endpoint(&self) -> Option<Endpoint>492     fn peer_endpoint(&self) -> Option<Endpoint> {
493         return self.remote_endpoint.clone();
494     }
495 }
496 
/// A TCP socket.
///
/// https://man7.org/linux/man-pages/man7/tcp.7.html
#[derive(Debug, Clone)]
pub struct TcpSocket {
    /// Handle of the underlying smoltcp socket inside the global socket set.
    handle: GlobalSocketHandle,
    local_endpoint: Option<wire::IpEndpoint>, // save local endpoint for bind()
    /// Set once a listen call on this socket succeeded; cleared again by `bind`.
    is_listening: bool,
    /// Options of the socket.
    options: SocketOptions,
}
507 
508 impl TcpSocket {
509     /// 元数据的缓冲区的大小
510     pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
511     /// 默认的发送缓冲区的大小 transmiss
512     pub const DEFAULT_RX_BUF_SIZE: usize = 512 * 1024;
513     /// 默认的接收缓冲区的大小 receive
514     pub const DEFAULT_TX_BUF_SIZE: usize = 512 * 1024;
515 
516     /// @brief 创建一个原始的socket
517     ///
518     /// @param protocol 协议号
519     /// @param options socket的选项
520     ///
521     /// @return 返回创建的原始的socket
new(options: SocketOptions) -> Self522     pub fn new(options: SocketOptions) -> Self {
523         let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]);
524         let rx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_RX_BUF_SIZE]);
525         let socket = tcp::Socket::new(tx_buffer, rx_buffer);
526 
527         // 把socket添加到socket集合中,并得到socket的句柄
528         let handle: GlobalSocketHandle = GlobalSocketHandle::new(SOCKET_SET.lock().add(socket));
529 
530         return Self {
531             handle,
532             local_endpoint: None,
533             is_listening: false,
534             options,
535         };
536     }
do_listen( &mut self, socket: &mut smoltcp::socket::tcp::Socket, local_endpoint: smoltcp::wire::IpEndpoint, ) -> Result<(), SystemError>537     fn do_listen(
538         &mut self,
539         socket: &mut smoltcp::socket::tcp::Socket,
540         local_endpoint: smoltcp::wire::IpEndpoint,
541     ) -> Result<(), SystemError> {
542         let listen_result = if local_endpoint.addr.is_unspecified() {
543             // kdebug!("Tcp Socket Listen on port {}", local_endpoint.port);
544             socket.listen(local_endpoint.port)
545         } else {
546             // kdebug!("Tcp Socket Listen on {local_endpoint}");
547             socket.listen(local_endpoint)
548         };
549         // todo: 增加端口占用检查
550         return match listen_result {
551             Ok(()) => {
552                 // kdebug!(
553                 //     "Tcp Socket Listen on {local_endpoint}, open?:{}",
554                 //     socket.is_open()
555                 // );
556                 self.is_listening = true;
557 
558                 Ok(())
559             }
560             Err(_) => Err(SystemError::EINVAL),
561         };
562     }
563 }
564 
565 impl Socket for TcpSocket {
read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint)566     fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {
567         // kdebug!("tcp socket: read, buf len={}", buf.len());
568 
569         loop {
570             poll_ifaces();
571             let mut socket_set_guard = SOCKET_SET.lock();
572             let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handle.0);
573 
574             // 如果socket已经关闭,返回错误
575             if !socket.is_active() {
576                 // kdebug!("Tcp Socket Read Error, socket is closed");
577                 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
578             }
579 
580             if socket.may_recv() {
581                 let recv_res = socket.recv_slice(buf);
582 
583                 if let Ok(size) = recv_res {
584                     if size > 0 {
585                         let endpoint = if let Some(p) = socket.remote_endpoint() {
586                             p
587                         } else {
588                             return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
589                         };
590 
591                         drop(socket);
592                         drop(socket_set_guard);
593                         poll_ifaces();
594                         return (Ok(size), Endpoint::Ip(Some(endpoint)));
595                     }
596                 } else {
597                     let err = recv_res.unwrap_err();
598                     match err {
599                         tcp::RecvError::InvalidState => {
600                             kwarn!("Tcp Socket Read Error, InvalidState");
601                             return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
602                         }
603                         tcp::RecvError::Finished => {
604                             return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
605                         }
606                     }
607                 }
608             } else {
609                 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
610             }
611             drop(socket);
612             drop(socket_set_guard);
613             SOCKET_WAITQUEUE.sleep();
614         }
615     }
616 
write(&self, buf: &[u8], _to: Option<super::Endpoint>) -> Result<usize, SystemError>617     fn write(&self, buf: &[u8], _to: Option<super::Endpoint>) -> Result<usize, SystemError> {
618         let mut socket_set_guard = SOCKET_SET.lock();
619         let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handle.0);
620 
621         if socket.is_open() {
622             if socket.can_send() {
623                 match socket.send_slice(buf) {
624                     Ok(size) => {
625                         drop(socket);
626                         drop(socket_set_guard);
627                         poll_ifaces();
628                         return Ok(size);
629                     }
630                     Err(e) => {
631                         kerror!("Tcp Socket Write Error {e:?}");
632                         return Err(SystemError::ENOBUFS);
633                     }
634                 }
635             } else {
636                 return Err(SystemError::ENOBUFS);
637             }
638         }
639 
640         return Err(SystemError::ENOTCONN);
641     }
642 
poll(&self) -> (bool, bool, bool)643     fn poll(&self) -> (bool, bool, bool) {
644         let mut socket_set_guard = SOCKET_SET.lock();
645         let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handle.0);
646 
647         let mut input = false;
648         let mut output = false;
649         let mut error = false;
650         if self.is_listening && socket.is_active() {
651             input = true;
652         } else if !socket.is_open() {
653             error = true;
654         } else {
655             if socket.may_recv() {
656                 input = true;
657             }
658             if socket.can_send() {
659                 output = true;
660             }
661         }
662 
663         return (input, output, error);
664     }
665 
connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError>666     fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
667         let mut sockets = SOCKET_SET.lock();
668         let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
669 
670         if let Endpoint::Ip(Some(ip)) = endpoint {
671             let temp_port = get_ephemeral_port();
672             // kdebug!("temp_port: {}", temp_port);
673             let iface: Arc<dyn NetDriver> = NET_DRIVERS.write().get(&0).unwrap().clone();
674             let mut inner_iface = iface.inner_iface().lock();
675             // kdebug!("to connect: {ip:?}");
676 
677             match socket.connect(&mut inner_iface.context(), ip, temp_port) {
678                 Ok(()) => {
679                     // avoid deadlock
680                     drop(inner_iface);
681                     drop(iface);
682                     drop(socket);
683                     drop(sockets);
684                     loop {
685                         poll_ifaces();
686                         let mut sockets = SOCKET_SET.lock();
687                         let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
688 
689                         match socket.state() {
690                             tcp::State::Established => {
691                                 return Ok(());
692                             }
693                             tcp::State::SynSent => {
694                                 drop(socket);
695                                 drop(sockets);
696                                 SOCKET_WAITQUEUE.sleep();
697                             }
698                             _ => {
699                                 return Err(SystemError::ECONNREFUSED);
700                             }
701                         }
702                     }
703                 }
704                 Err(e) => {
705                     // kerror!("Tcp Socket Connect Error {e:?}");
706                     match e {
707                         tcp::ConnectError::InvalidState => return Err(SystemError::EISCONN),
708                         tcp::ConnectError::Unaddressable => return Err(SystemError::EADDRNOTAVAIL),
709                     }
710                 }
711             }
712         } else {
713             return Err(SystemError::EINVAL);
714         }
715     }
716 
717     /// @brief tcp socket 监听 local_endpoint 端口
718     ///
719     /// @param backlog 未处理的连接队列的最大长度. 由于smoltcp不支持backlog,所以这个参数目前无效
listen(&mut self, _backlog: usize) -> Result<(), SystemError>720     fn listen(&mut self, _backlog: usize) -> Result<(), SystemError> {
721         if self.is_listening {
722             return Ok(());
723         }
724 
725         let local_endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?;
726         let mut sockets = SOCKET_SET.lock();
727         let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
728 
729         if socket.is_listening() {
730             // kdebug!("Tcp Socket is already listening on {local_endpoint}");
731             return Ok(());
732         }
733         // kdebug!("Tcp Socket  before listen, open={}", socket.is_open());
734         return self.do_listen(socket, local_endpoint);
735     }
736 
bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError>737     fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
738         if let Endpoint::Ip(Some(mut ip)) = endpoint {
739             if ip.port == 0 {
740                 ip.port = get_ephemeral_port();
741             }
742 
743             self.local_endpoint = Some(ip);
744             self.is_listening = false;
745             return Ok(());
746         }
747         return Err(SystemError::EINVAL);
748     }
749 
shutdown(&self, _type: super::ShutdownType) -> Result<(), SystemError>750     fn shutdown(&self, _type: super::ShutdownType) -> Result<(), SystemError> {
751         let mut sockets = SOCKET_SET.lock();
752         let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);
753         socket.close();
754         return Ok(());
755     }
756 
    /// Waits for an incoming connection and returns it as a new socket together with the
    /// peer's endpoint.
    ///
    /// Requires a local endpoint stored by `bind`, otherwise `EINVAL` is returned. Once the
    /// smoltcp socket behind `self.handle` becomes active, that socket is handed to the
    /// returned `TcpSocket`, and `self` receives a freshly created socket that listens on the
    /// same endpoint. Until then the caller sleeps on `SOCKET_WAITQUEUE` between attempts.
    ///
    /// Panics if the replacement socket cannot be put into the listening state.
    fn accept(&mut self) -> Result<(Box<dyn Socket>, Endpoint), SystemError> {
        let endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?;
        loop {
            // kdebug!("tcp accept: poll_ifaces()");
            poll_ifaces();

            let mut sockets = SOCKET_SET.lock();

            let socket = sockets.get_mut::<tcp::Socket>(self.handle.0);

            if socket.is_active() {
                // kdebug!("tcp accept: socket.is_active()");
                let remote_ep = socket.remote_endpoint().ok_or(SystemError::ENOTCONN)?;
                drop(socket);

                let new_socket = {
                    // Initialize the buffers of the replacement TCP socket.
                    let rx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_RX_BUF_SIZE]);
                    let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]);
                    // The replacement smoltcp socket; it takes over listening for `self`.
                    let mut tcp_socket = tcp::Socket::new(rx_buffer, tx_buffer);
                    self.do_listen(&mut tcp_socket, endpoint)
                        .expect("do_listen failed");

                    // tcp_socket.listen(endpoint).unwrap();

                    // old_handle is stored in new_socket because, at this moment, smoltcp has
                    // already associated the socket behind old_handle with the remote endpoint.
                    // The current (listening) socket therefore needs a new handle of its own.
                    let new_handle = GlobalSocketHandle::new(sockets.add(tcp_socket));
                    let old_handle = ::core::mem::replace(&mut self.handle, new_handle);

                    Box::new(TcpSocket {
                        handle: old_handle,
                        local_endpoint: self.local_endpoint,
                        is_listening: false,
                        options: self.options,
                    })
                };
                // kdebug!("tcp accept: new socket: {:?}", new_socket);
                // Release the socket set before polling the interfaces.
                drop(sockets);
                poll_ifaces();

                return Ok((new_socket, Endpoint::Ip(Some(remote_ep))));
            }
            // No connection yet: release the socket set and sleep until the next attempt.
            drop(socket);
            drop(sockets);
            SOCKET_WAITQUEUE.sleep();
        }
    }
806 
endpoint(&self) -> Option<Endpoint>807     fn endpoint(&self) -> Option<Endpoint> {
808         let mut result: Option<Endpoint> =
809             self.local_endpoint.clone().map(|x| Endpoint::Ip(Some(x)));
810 
811         if result.is_none() {
812             let sockets = SOCKET_SET.lock();
813             let socket = sockets.get::<tcp::Socket>(self.handle.0);
814             if let Some(ep) = socket.local_endpoint() {
815                 result = Some(Endpoint::Ip(Some(ep)));
816             }
817         }
818         return result;
819     }
820 
peer_endpoint(&self) -> Option<Endpoint>821     fn peer_endpoint(&self) -> Option<Endpoint> {
822         let sockets = SOCKET_SET.lock();
823         let socket = sockets.get::<tcp::Socket>(self.handle.0);
824         return socket.remote_endpoint().map(|x| Endpoint::Ip(Some(x)));
825     }
826 
    /// Returns the metadata of this socket.
    ///
    /// Not implemented yet: the body is `todo!()`, so calling this method panics.
    fn metadata(&self) -> Result<SocketMetadata, SystemError> {
        todo!()
    }
830 
box_clone(&self) -> alloc::boxed::Box<dyn Socket>831     fn box_clone(&self) -> alloc::boxed::Box<dyn Socket> {
832         return Box::new(self.clone());
833     }
834 }
835 
836 /// @breif 自动分配一个未被使用的PORT
837 ///
838 /// TODO: 增加ListenTable, 用于检查端口是否被占用
get_ephemeral_port() -> u16839 pub fn get_ephemeral_port() -> u16 {
840     // TODO selects non-conflict high port
841 
842     static mut EPHEMERAL_PORT: u16 = 0;
843     unsafe {
844         if EPHEMERAL_PORT == 0 {
845             EPHEMERAL_PORT = (49152 + rand() % (65536 - 49152)) as u16;
846         }
847         if EPHEMERAL_PORT == 65535 {
848             EPHEMERAL_PORT = 49152;
849         } else {
850             EPHEMERAL_PORT = EPHEMERAL_PORT + 1;
851         }
852         EPHEMERAL_PORT
853     }
854 }
855 
/// @brief Enumeration of socket address families.
///
/// The discriminants are the numeric `AF_*` values; see the Linux header
/// referenced below.
///
/// Reference: https://opengrok.ringotek.cn/xref/linux-5.19.10/include/linux/socket.h#180
#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)]
pub enum AddressFamily {
    /// AF_UNSPEC: address family not specified
    Unspecified = 0,
    /// AF_UNIX: Unix domain sockets (same as AF_LOCAL)
    Unix = 1,
    /// AF_INET: IPv4 sockets
    INet = 2,
    /// AF_AX25: AMPR AX.25 sockets
    AX25 = 3,
    /// AF_IPX: IPX sockets
    IPX = 4,
    /// AF_APPLETALK: Appletalk sockets
    Appletalk = 5,
    /// AF_NETROM: AMPR NET/ROM sockets
    Netrom = 6,
    /// AF_BRIDGE: multiprotocol bridge sockets
    Bridge = 7,
    /// AF_ATMPVC: ATM PVC sockets
    Atmpvc = 8,
    /// AF_X25: X.25 sockets
    X25 = 9,
    /// AF_INET6: IPv6 sockets
    INet6 = 10,
    /// AF_ROSE: AMPR ROSE sockets
    Rose = 11,
    /// AF_DECnet: reserved for the DECnet project
    Decnet = 12,
    /// AF_NETBEUI: reserved for the 802.2LLC project
    Netbeui = 13,
    /// AF_SECURITY: pseudo address family for security callbacks
    Security = 14,
    /// AF_KEY: key management API
    Key = 15,
    /// AF_NETLINK: Netlink sockets
    Netlink = 16,
    /// AF_PACKET: low-level packet interface
    Packet = 17,
    /// AF_ASH: Ash
    Ash = 18,
    /// AF_ECONET: Acorn Econet
    Econet = 19,
    /// AF_ATMSVC: ATM SVCs
    Atmsvc = 20,
    /// AF_RDS: Reliable Datagram Sockets
    Rds = 21,
    /// AF_SNA: Linux SNA Project
    Sna = 22,
    /// AF_IRDA: IRDA sockets
    Irda = 23,
    /// AF_PPPOX: PPPoX sockets
    Pppox = 24,
    /// AF_WANPIPE: WANPIPE API sockets
    WanPipe = 25,
    /// AF_LLC: Linux LLC
    Llc = 26,
    /// AF_IB: native InfiniBand address
    /// Overview: https://access.redhat.com/documentation/en-us/red_hat_enterprise_linux/9/html-single/configuring_infiniband_and_rdma_networks/index#understanding-infiniband-and-rdma_configuring-infiniband-and-rdma-networks
    Ib = 27,
    /// AF_MPLS: MPLS
    Mpls = 28,
    /// AF_CAN: Controller Area Network
    Can = 29,
    /// AF_TIPC: TIPC sockets
    Tipc = 30,
    /// AF_BLUETOOTH: Bluetooth sockets
    Bluetooth = 31,
    /// AF_IUCV: IUCV sockets
    Iucv = 32,
    /// AF_RXRPC: RxRPC sockets
    Rxrpc = 33,
    /// AF_ISDN: mISDN sockets
    Isdn = 34,
    /// AF_PHONET: Phonet sockets
    Phonet = 35,
    /// AF_IEEE802154: IEEE 802.15.4 sockets
    Ieee802154 = 36,
    /// AF_CAIF: CAIF sockets
    Caif = 37,
    /// AF_ALG: algorithm sockets
    Alg = 38,
    /// AF_NFC: NFC sockets
    Nfc = 39,
    /// AF_VSOCK: vSockets
    Vsock = 40,
    /// AF_KCM: Kernel Connection Multiplexor
    Kcm = 41,
    /// AF_QIPCRTR: Qualcomm IPC Router
    Qipcrtr = 42,
    /// AF_SMC: SMC-R sockets.
    /// Reserved number for the PF_SMC protocol family, which reuses the AF_INET address family.
    Smc = 43,
    /// AF_XDP: XDP sockets
    Xdp = 44,
    /// AF_MCTP: Management Component Transport Protocol
    Mctp = 45,
    /// AF_MAX: the largest address family value
    Max = 46,
}
958 
959 impl TryFrom<u16> for AddressFamily {
960     type Error = SystemError;
try_from(x: u16) -> Result<Self, Self::Error>961     fn try_from(x: u16) -> Result<Self, Self::Error> {
962         use num_traits::FromPrimitive;
963         return <Self as FromPrimitive>::from_u16(x).ok_or_else(|| SystemError::EINVAL);
964     }
965 }
966 
/// @brief Enumeration of POSIX socket types (the values match those used by the Linux kernel).
#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)]
pub enum PosixSocketType {
    /// Stream socket (Linux `SOCK_STREAM`).
    Stream = 1,
    /// Datagram socket (Linux `SOCK_DGRAM`).
    Datagram = 2,
    /// Raw socket (Linux `SOCK_RAW`).
    Raw = 3,
    /// Reliably-delivered message socket (Linux `SOCK_RDM`).
    Rdm = 4,
    /// Sequenced packet socket (Linux `SOCK_SEQPACKET`).
    SeqPacket = 5,
    /// Datagram Congestion Control Protocol socket (Linux `SOCK_DCCP`).
    Dccp = 6,
    /// Packet socket (Linux `SOCK_PACKET`); note the gap in numbering before 10.
    Packet = 10,
}
978 
979 impl TryFrom<u8> for PosixSocketType {
980     type Error = SystemError;
try_from(x: u8) -> Result<Self, Self::Error>981     fn try_from(x: u8) -> Result<Self, Self::Error> {
982         use num_traits::FromPrimitive;
983         return <Self as FromPrimitive>::from_u8(x).ok_or_else(|| SystemError::EINVAL);
984     }
985 }
986 
/// @brief Inode wrapper that represents a socket inside the filesystem layer.
///
/// Holds the boxed `Socket` trait object behind a `SpinLock`.
#[derive(Debug)]
pub struct SocketInode(SpinLock<Box<dyn Socket>>);
990 
991 impl SocketInode {
new(socket: Box<dyn Socket>) -> Arc<Self>992     pub fn new(socket: Box<dyn Socket>) -> Arc<Self> {
993         return Arc::new(Self(SpinLock::new(socket)));
994     }
995 
996     #[inline]
inner(&self) -> SpinLockGuard<Box<dyn Socket>>997     pub fn inner(&self) -> SpinLockGuard<Box<dyn Socket>> {
998         return self.0.lock();
999     }
1000 }
1001 
1002 impl IndexNode for SocketInode {
open( &self, _data: &mut crate::filesystem::vfs::FilePrivateData, _mode: &crate::filesystem::vfs::file::FileMode, ) -> Result<(), SystemError>1003     fn open(
1004         &self,
1005         _data: &mut crate::filesystem::vfs::FilePrivateData,
1006         _mode: &crate::filesystem::vfs::file::FileMode,
1007     ) -> Result<(), SystemError> {
1008         return Ok(());
1009     }
1010 
close( &self, _data: &mut crate::filesystem::vfs::FilePrivateData, ) -> Result<(), SystemError>1011     fn close(
1012         &self,
1013         _data: &mut crate::filesystem::vfs::FilePrivateData,
1014     ) -> Result<(), SystemError> {
1015         return Ok(());
1016     }
1017 
read_at( &self, _offset: usize, len: usize, buf: &mut [u8], _data: &mut crate::filesystem::vfs::FilePrivateData, ) -> Result<usize, SystemError>1018     fn read_at(
1019         &self,
1020         _offset: usize,
1021         len: usize,
1022         buf: &mut [u8],
1023         _data: &mut crate::filesystem::vfs::FilePrivateData,
1024     ) -> Result<usize, SystemError> {
1025         return self.0.lock().read(&mut buf[0..len]).0;
1026     }
1027 
write_at( &self, _offset: usize, len: usize, buf: &[u8], _data: &mut crate::filesystem::vfs::FilePrivateData, ) -> Result<usize, SystemError>1028     fn write_at(
1029         &self,
1030         _offset: usize,
1031         len: usize,
1032         buf: &[u8],
1033         _data: &mut crate::filesystem::vfs::FilePrivateData,
1034     ) -> Result<usize, SystemError> {
1035         return self.0.lock().write(&buf[0..len], None);
1036     }
1037 
poll(&self) -> Result<crate::filesystem::vfs::PollStatus, SystemError>1038     fn poll(&self) -> Result<crate::filesystem::vfs::PollStatus, SystemError> {
1039         let (read, write, error) = self.0.lock().poll();
1040         let mut result = PollStatus::empty();
1041         if read {
1042             result.insert(PollStatus::READ);
1043         }
1044         if write {
1045             result.insert(PollStatus::WRITE);
1046         }
1047         if error {
1048             result.insert(PollStatus::ERROR);
1049         }
1050         return Ok(result);
1051     }
1052 
fs(&self) -> alloc::sync::Arc<dyn crate::filesystem::vfs::FileSystem>1053     fn fs(&self) -> alloc::sync::Arc<dyn crate::filesystem::vfs::FileSystem> {
1054         todo!()
1055     }
1056 
as_any_ref(&self) -> &dyn core::any::Any1057     fn as_any_ref(&self) -> &dyn core::any::Any {
1058         self
1059     }
1060 
list(&self) -> Result<Vec<alloc::string::String>, SystemError>1061     fn list(&self) -> Result<Vec<alloc::string::String>, SystemError> {
1062         return Err(SystemError::ENOTDIR);
1063     }
1064 
metadata(&self) -> Result<crate::filesystem::vfs::Metadata, SystemError>1065     fn metadata(&self) -> Result<crate::filesystem::vfs::Metadata, SystemError> {
1066         let meta = Metadata {
1067             mode: 0o777,
1068             file_type: FileType::Socket,
1069             ..Default::default()
1070         };
1071 
1072         return Ok(meta);
1073     }
1074 }
1075