1 use alloc::{boxed::Box, sync::Arc, vec::Vec}; 2 use log::{error, warn}; 3 use smoltcp::{ 4 socket::{raw, tcp, udp}, 5 wire, 6 }; 7 use system_error::SystemError; 8 9 use crate::{ 10 driver::net::NetDevice, 11 libs::rwlock::RwLock, 12 net::{ 13 event_poll::EPollEventType, net_core::poll_ifaces, Endpoint, Protocol, ShutdownType, 14 NET_DEVICES, 15 }, 16 }; 17 18 use super::{ 19 handle::GlobalSocketHandle, PosixSocketHandleItem, Socket, SocketHandleItem, SocketMetadata, 20 SocketOptions, SocketPollMethod, SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET, 21 }; 22 23 /// @brief 表示原始的socket。原始套接字绕过传输层协议(如 TCP 或 UDP)并提供对网络层协议(如 IP)的直接访问。 24 /// 25 /// ref: https://man7.org/linux/man-pages/man7/raw.7.html 26 #[derive(Debug, Clone)] 27 pub struct RawSocket { 28 handle: GlobalSocketHandle, 29 /// 用户发送的数据包是否包含了IP头. 30 /// 如果是true,用户发送的数据包,必须包含IP头。(即用户要自行设置IP头+数据) 31 /// 如果是false,用户发送的数据包,不包含IP头。(即用户只要设置数据) 32 header_included: bool, 33 /// socket的metadata 34 metadata: SocketMetadata, 35 posix_item: Arc<PosixSocketHandleItem>, 36 } 37 38 impl RawSocket { 39 /// 元数据的缓冲区的大小 40 pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; 41 /// 默认的接收缓冲区的大小 receive 42 pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024; 43 /// 默认的发送缓冲区的大小 transmiss 44 pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024; 45 46 /// @brief 创建一个原始的socket 47 /// 48 /// @param protocol 协议号 49 /// @param options socket的选项 50 /// 51 /// @return 返回创建的原始的socket 52 pub fn new(protocol: Protocol, options: SocketOptions) -> Self { 53 let rx_buffer = raw::PacketBuffer::new( 54 vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], 55 vec![0; Self::DEFAULT_RX_BUF_SIZE], 56 ); 57 let tx_buffer = raw::PacketBuffer::new( 58 vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], 59 vec![0; Self::DEFAULT_TX_BUF_SIZE], 60 ); 61 let protocol: u8 = protocol.into(); 62 let socket = raw::Socket::new( 63 wire::IpVersion::Ipv4, 64 wire::IpProtocol::from(protocol), 65 rx_buffer, 66 tx_buffer, 67 ); 68 69 // 把socket添加到socket集合中,并得到socket的句柄 70 let handle = GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket)); 71 72 let metadata = SocketMetadata::new( 73 SocketType::Raw, 74 Self::DEFAULT_RX_BUF_SIZE, 75 Self::DEFAULT_TX_BUF_SIZE, 76 Self::DEFAULT_METADATA_BUF_SIZE, 77 options, 78 ); 79 80 let posix_item = Arc::new(PosixSocketHandleItem::new(None)); 81 82 return Self { 83 handle, 84 header_included: false, 85 metadata, 86 posix_item, 87 }; 88 } 89 } 90 91 impl Socket for RawSocket { 92 fn posix_item(&self) -> Arc<PosixSocketHandleItem> { 93 self.posix_item.clone() 94 } 95 96 fn close(&mut self) { 97 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 98 if let smoltcp::socket::Socket::Udp(mut sock) = 99 socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()) 100 { 101 sock.close(); 102 } 103 drop(socket_set_guard); 104 poll_ifaces(); 105 } 106 107 fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) { 108 poll_ifaces(); 109 loop { 110 // 如何优化这里? 111 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 112 let socket = 113 socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap()); 114 115 match socket.recv_slice(buf) { 116 Ok(len) => { 117 let packet = wire::Ipv4Packet::new_unchecked(buf); 118 return ( 119 Ok(len), 120 Endpoint::Ip(Some(wire::IpEndpoint { 121 addr: wire::IpAddress::Ipv4(packet.src_addr()), 122 port: 0, 123 })), 124 ); 125 } 126 Err(_) => { 127 if !self.metadata.options.contains(SocketOptions::BLOCK) { 128 // 如果是非阻塞的socket,就返回错误 129 return (Err(SystemError::EAGAIN_OR_EWOULDBLOCK), Endpoint::Ip(None)); 130 } 131 } 132 } 133 drop(socket_set_guard); 134 self.posix_item.sleep(EPollEventType::EPOLLIN.bits() as u64); 135 } 136 } 137 138 fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError> { 139 // 如果用户发送的数据包,包含IP头,则直接发送 140 if self.header_included { 141 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 142 let socket = 143 socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap()); 144 match socket.send_slice(buf) { 145 Ok(_) => { 146 return Ok(buf.len()); 147 } 148 Err(raw::SendError::BufferFull) => { 149 return Err(SystemError::ENOBUFS); 150 } 151 } 152 } else { 153 // 如果用户发送的数据包,不包含IP头,则需要自己构造IP头 154 155 if let Some(Endpoint::Ip(Some(endpoint))) = to { 156 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 157 let socket: &mut raw::Socket = 158 socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap()); 159 160 // 暴力解决方案:只考虑0号网卡。 TODO:考虑多网卡的情况!!! 161 let iface = NET_DEVICES.read_irqsave().get(&0).unwrap().clone(); 162 163 // 构造IP头 164 let ipv4_src_addr: Option<wire::Ipv4Address> = 165 iface.inner_iface().lock().ipv4_addr(); 166 if ipv4_src_addr.is_none() { 167 return Err(SystemError::ENETUNREACH); 168 } 169 let ipv4_src_addr = ipv4_src_addr.unwrap(); 170 171 if let wire::IpAddress::Ipv4(ipv4_dst) = endpoint.addr { 172 let len = buf.len(); 173 174 // 创建20字节的IPv4头部 175 let mut buffer: Vec<u8> = vec![0u8; len + 20]; 176 let mut packet: wire::Ipv4Packet<&mut Vec<u8>> = 177 wire::Ipv4Packet::new_unchecked(&mut buffer); 178 179 // 封装ipv4 header 180 packet.set_version(4); 181 packet.set_header_len(20); 182 packet.set_total_len((20 + len) as u16); 183 packet.set_src_addr(ipv4_src_addr); 184 packet.set_dst_addr(ipv4_dst); 185 186 // 设置ipv4 header的protocol字段 187 packet.set_next_header(socket.ip_protocol()); 188 189 // 获取IP数据包的负载字段 190 let payload: &mut [u8] = packet.payload_mut(); 191 payload.copy_from_slice(buf); 192 193 // 填充checksum字段 194 packet.fill_checksum(); 195 196 // 发送数据包 197 socket.send_slice(&buffer).unwrap(); 198 199 iface.poll(&mut socket_set_guard).ok(); 200 201 drop(socket_set_guard); 202 return Ok(len); 203 } else { 204 warn!("Unsupport Ip protocol type!"); 205 return Err(SystemError::EINVAL); 206 } 207 } else { 208 // 如果没有指定目的地址,则返回错误 209 return Err(SystemError::ENOTCONN); 210 } 211 } 212 } 213 214 fn connect(&mut self, _endpoint: Endpoint) -> Result<(), SystemError> { 215 Ok(()) 216 } 217 218 fn metadata(&self) -> SocketMetadata { 219 self.metadata.clone() 220 } 221 222 fn box_clone(&self) -> Box<dyn Socket> { 223 Box::new(self.clone()) 224 } 225 226 fn socket_handle(&self) -> GlobalSocketHandle { 227 self.handle 228 } 229 230 fn as_any_ref(&self) -> &dyn core::any::Any { 231 self 232 } 233 234 fn as_any_mut(&mut self) -> &mut dyn core::any::Any { 235 self 236 } 237 } 238 239 /// @brief 表示udp socket 240 /// 241 /// https://man7.org/linux/man-pages/man7/udp.7.html 242 #[derive(Debug, Clone)] 243 pub struct UdpSocket { 244 pub handle: GlobalSocketHandle, 245 remote_endpoint: Option<Endpoint>, // 记录远程endpoint提供给connect(), 应该使用IP地址。 246 metadata: SocketMetadata, 247 posix_item: Arc<PosixSocketHandleItem>, 248 } 249 250 impl UdpSocket { 251 /// 元数据的缓冲区的大小 252 pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; 253 /// 默认的接收缓冲区的大小 receive 254 pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024; 255 /// 默认的发送缓冲区的大小 transmiss 256 pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024; 257 258 /// @brief 创建一个udp的socket 259 /// 260 /// @param options socket的选项 261 /// 262 /// @return 返回创建的udp的socket 263 pub fn new(options: SocketOptions) -> Self { 264 let rx_buffer = udp::PacketBuffer::new( 265 vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], 266 vec![0; Self::DEFAULT_RX_BUF_SIZE], 267 ); 268 let tx_buffer = udp::PacketBuffer::new( 269 vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], 270 vec![0; Self::DEFAULT_TX_BUF_SIZE], 271 ); 272 let socket = udp::Socket::new(rx_buffer, tx_buffer); 273 274 // 把socket添加到socket集合中,并得到socket的句柄 275 let handle: GlobalSocketHandle = 276 GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket)); 277 278 let metadata = SocketMetadata::new( 279 SocketType::Udp, 280 Self::DEFAULT_RX_BUF_SIZE, 281 Self::DEFAULT_TX_BUF_SIZE, 282 Self::DEFAULT_METADATA_BUF_SIZE, 283 options, 284 ); 285 286 let posix_item = Arc::new(PosixSocketHandleItem::new(None)); 287 288 return Self { 289 handle, 290 remote_endpoint: None, 291 metadata, 292 posix_item, 293 }; 294 } 295 296 fn do_bind(&self, socket: &mut udp::Socket, endpoint: Endpoint) -> Result<(), SystemError> { 297 if let Endpoint::Ip(Some(mut ip)) = endpoint { 298 // 端口为0则分配随机端口 299 if ip.port == 0 { 300 ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?; 301 } 302 // 检测端口是否已被占用 303 PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port)?; 304 305 let bind_res = if ip.addr.is_unspecified() { 306 socket.bind(ip.port) 307 } else { 308 socket.bind(ip) 309 }; 310 311 match bind_res { 312 Ok(()) => return Ok(()), 313 Err(_) => return Err(SystemError::EINVAL), 314 } 315 } else { 316 return Err(SystemError::EINVAL); 317 } 318 } 319 } 320 321 impl Socket for UdpSocket { 322 fn posix_item(&self) -> Arc<PosixSocketHandleItem> { 323 self.posix_item.clone() 324 } 325 326 fn close(&mut self) { 327 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 328 if let smoltcp::socket::Socket::Udp(mut sock) = 329 socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()) 330 { 331 sock.close(); 332 } 333 drop(socket_set_guard); 334 poll_ifaces(); 335 } 336 337 /// @brief 在read函数执行之前,请先bind到本地的指定端口 338 fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) { 339 loop { 340 // debug!("Wait22 to Read"); 341 poll_ifaces(); 342 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 343 let socket = 344 socket_set_guard.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap()); 345 346 // debug!("Wait to Read"); 347 348 if socket.can_recv() { 349 if let Ok((size, metadata)) = socket.recv_slice(buf) { 350 drop(socket_set_guard); 351 poll_ifaces(); 352 return (Ok(size), Endpoint::Ip(Some(metadata.endpoint))); 353 } 354 } else { 355 // 如果socket没有连接,则忙等 356 // return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 357 } 358 drop(socket_set_guard); 359 self.posix_item.sleep(EPollEventType::EPOLLIN.bits() as u64); 360 } 361 } 362 363 fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError> { 364 // debug!("udp to send: {:?}, len={}", to, buf.len()); 365 let remote_endpoint: &wire::IpEndpoint = { 366 if let Some(Endpoint::Ip(Some(ref endpoint))) = to { 367 endpoint 368 } else if let Some(Endpoint::Ip(Some(ref endpoint))) = self.remote_endpoint { 369 endpoint 370 } else { 371 return Err(SystemError::ENOTCONN); 372 } 373 }; 374 // debug!("udp write: remote = {:?}", remote_endpoint); 375 376 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 377 let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap()); 378 // debug!("is open()={}", socket.is_open()); 379 // debug!("socket endpoint={:?}", socket.endpoint()); 380 if socket.can_send() { 381 // debug!("udp write: can send"); 382 match socket.send_slice(buf, *remote_endpoint) { 383 Ok(()) => { 384 // debug!("udp write: send ok"); 385 drop(socket_set_guard); 386 poll_ifaces(); 387 return Ok(buf.len()); 388 } 389 Err(_) => { 390 // debug!("udp write: send err"); 391 return Err(SystemError::ENOBUFS); 392 } 393 } 394 } else { 395 // debug!("udp write: can not send"); 396 return Err(SystemError::ENOBUFS); 397 }; 398 } 399 400 fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { 401 let mut sockets = SOCKET_SET.lock_irqsave(); 402 let socket = sockets.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap()); 403 // debug!("UDP Bind to {:?}", endpoint); 404 return self.do_bind(socket, endpoint); 405 } 406 407 fn poll(&self) -> EPollEventType { 408 let sockets = SOCKET_SET.lock_irqsave(); 409 let socket = sockets.get::<udp::Socket>(self.handle.smoltcp_handle().unwrap()); 410 411 return SocketPollMethod::udp_poll( 412 socket, 413 HANDLE_MAP 414 .read_irqsave() 415 .get(&self.socket_handle()) 416 .unwrap() 417 .shutdown_type(), 418 ); 419 } 420 421 fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { 422 if let Endpoint::Ip(_) = endpoint { 423 self.remote_endpoint = Some(endpoint); 424 Ok(()) 425 } else { 426 Err(SystemError::EINVAL) 427 } 428 } 429 430 fn ioctl( 431 &self, 432 _cmd: usize, 433 _arg0: usize, 434 _arg1: usize, 435 _arg2: usize, 436 ) -> Result<usize, SystemError> { 437 todo!() 438 } 439 440 fn metadata(&self) -> SocketMetadata { 441 self.metadata.clone() 442 } 443 444 fn box_clone(&self) -> Box<dyn Socket> { 445 return Box::new(self.clone()); 446 } 447 448 fn endpoint(&self) -> Option<Endpoint> { 449 let sockets = SOCKET_SET.lock_irqsave(); 450 let socket = sockets.get::<udp::Socket>(self.handle.smoltcp_handle().unwrap()); 451 let listen_endpoint = socket.endpoint(); 452 453 if listen_endpoint.port == 0 { 454 return None; 455 } else { 456 // 如果listen_endpoint的address是None,意味着“监听所有的地址”。 457 // 这里假设所有的地址都是ipv4 458 // TODO: 支持ipv6 459 let result = wire::IpEndpoint::new( 460 listen_endpoint 461 .addr 462 .unwrap_or(wire::IpAddress::v4(0, 0, 0, 0)), 463 listen_endpoint.port, 464 ); 465 return Some(Endpoint::Ip(Some(result))); 466 } 467 } 468 469 fn peer_endpoint(&self) -> Option<Endpoint> { 470 return self.remote_endpoint.clone(); 471 } 472 473 fn socket_handle(&self) -> GlobalSocketHandle { 474 self.handle 475 } 476 477 fn as_any_ref(&self) -> &dyn core::any::Any { 478 self 479 } 480 481 fn as_any_mut(&mut self) -> &mut dyn core::any::Any { 482 self 483 } 484 } 485 486 /// @brief 表示 tcp socket 487 /// 488 /// https://man7.org/linux/man-pages/man7/tcp.7.html 489 #[derive(Debug, Clone)] 490 pub struct TcpSocket { 491 handles: Vec<GlobalSocketHandle>, 492 local_endpoint: Option<wire::IpEndpoint>, // save local endpoint for bind() 493 is_listening: bool, 494 metadata: SocketMetadata, 495 posix_item: Arc<PosixSocketHandleItem>, 496 } 497 498 impl TcpSocket { 499 /// 元数据的缓冲区的大小 500 pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; 501 /// 默认的接收缓冲区的大小 receive 502 pub const DEFAULT_RX_BUF_SIZE: usize = 512 * 1024; 503 /// 默认的发送缓冲区的大小 transmiss 504 pub const DEFAULT_TX_BUF_SIZE: usize = 512 * 1024; 505 506 /// TcpSocket的特殊事件,用于在事件等待队列上sleep 507 pub const CAN_CONNECT: u64 = 1u64 << 63; 508 pub const CAN_ACCPET: u64 = 1u64 << 62; 509 510 /// @brief 创建一个tcp的socket 511 /// 512 /// @param options socket的选项 513 /// 514 /// @return 返回创建的tcp的socket 515 pub fn new(options: SocketOptions) -> Self { 516 // 创建handles数组并把socket添加到socket集合中,并得到socket的句柄 517 let handles: Vec<GlobalSocketHandle> = vec![GlobalSocketHandle::new_smoltcp_handle( 518 SOCKET_SET.lock_irqsave().add(Self::create_new_socket()), 519 )]; 520 521 let metadata = SocketMetadata::new( 522 SocketType::Tcp, 523 Self::DEFAULT_RX_BUF_SIZE, 524 Self::DEFAULT_TX_BUF_SIZE, 525 Self::DEFAULT_METADATA_BUF_SIZE, 526 options, 527 ); 528 let posix_item = Arc::new(PosixSocketHandleItem::new(None)); 529 // debug!("when there's a new tcp socket,its'len: {}",handles.len()); 530 531 return Self { 532 handles, 533 local_endpoint: None, 534 is_listening: false, 535 metadata, 536 posix_item, 537 }; 538 } 539 540 fn do_listen( 541 &mut self, 542 socket: &mut tcp::Socket, 543 local_endpoint: wire::IpEndpoint, 544 ) -> Result<(), SystemError> { 545 let listen_result = if local_endpoint.addr.is_unspecified() { 546 socket.listen(local_endpoint.port) 547 } else { 548 socket.listen(local_endpoint) 549 }; 550 return match listen_result { 551 Ok(()) => { 552 // debug!( 553 // "Tcp Socket Listen on {local_endpoint}, open?:{}", 554 // socket.is_open() 555 // ); 556 self.is_listening = true; 557 558 Ok(()) 559 } 560 Err(_) => Err(SystemError::EINVAL), 561 }; 562 } 563 564 /// # create_new_socket - 创建新的TCP套接字 565 /// 566 /// 该函数用于创建一个新的TCP套接字,并返回该套接字的引用。 567 fn create_new_socket() -> tcp::Socket<'static> { 568 // 初始化tcp的buffer 569 let rx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_RX_BUF_SIZE]); 570 let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]); 571 tcp::Socket::new(rx_buffer, tx_buffer) 572 } 573 574 /// listening状态的posix socket是需要特殊处理的 575 fn tcp_poll_listening(&self) -> EPollEventType { 576 let socketset_guard = SOCKET_SET.lock_irqsave(); 577 578 let can_accept = self.handles.iter().any(|h| { 579 if let Some(sh) = h.smoltcp_handle() { 580 let socket = socketset_guard.get::<tcp::Socket>(sh); 581 socket.is_active() 582 } else { 583 false 584 } 585 }); 586 587 if can_accept { 588 return EPollEventType::EPOLL_LISTEN_CAN_ACCEPT; 589 } else { 590 return EPollEventType::empty(); 591 } 592 } 593 } 594 595 impl Socket for TcpSocket { 596 fn posix_item(&self) -> Arc<PosixSocketHandleItem> { 597 self.posix_item.clone() 598 } 599 600 fn close(&mut self) { 601 for handle in self.handles.iter() { 602 { 603 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 604 let smoltcp_handle = handle.smoltcp_handle().unwrap(); 605 socket_set_guard 606 .get_mut::<smoltcp::socket::tcp::Socket>(smoltcp_handle) 607 .close(); 608 drop(socket_set_guard); 609 } 610 poll_ifaces(); 611 SOCKET_SET 612 .lock_irqsave() 613 .remove(handle.smoltcp_handle().unwrap()); 614 // debug!("[Socket] [TCP] Close: {:?}", handle); 615 } 616 } 617 618 fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) { 619 if HANDLE_MAP 620 .read_irqsave() 621 .get(&self.socket_handle()) 622 .unwrap() 623 .shutdown_type() 624 .contains(ShutdownType::RCV_SHUTDOWN) 625 { 626 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 627 } 628 // debug!("tcp socket: read, buf len={}", buf.len()); 629 // debug!("tcp socket:read, socket'len={}",self.handle.len()); 630 loop { 631 poll_ifaces(); 632 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 633 634 let socket = socket_set_guard 635 .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); 636 637 // 如果socket已经关闭,返回错误 638 if !socket.is_active() { 639 // debug!("Tcp Socket Read Error, socket is closed"); 640 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 641 } 642 643 if socket.may_recv() { 644 match socket.recv_slice(buf) { 645 Ok(size) => { 646 if size > 0 { 647 let endpoint = if let Some(p) = socket.remote_endpoint() { 648 p 649 } else { 650 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 651 }; 652 653 drop(socket_set_guard); 654 poll_ifaces(); 655 return (Ok(size), Endpoint::Ip(Some(endpoint))); 656 } 657 } 658 Err(tcp::RecvError::InvalidState) => { 659 warn!("Tcp Socket Read Error, InvalidState"); 660 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 661 } 662 Err(tcp::RecvError::Finished) => { 663 // 对端写端已关闭,我们应该关闭读端 664 HANDLE_MAP 665 .write_irqsave() 666 .get_mut(&self.socket_handle()) 667 .unwrap() 668 .shutdown_type_writer() 669 .insert(ShutdownType::RCV_SHUTDOWN); 670 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 671 } 672 } 673 } else { 674 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 675 } 676 drop(socket_set_guard); 677 self.posix_item 678 .sleep((EPollEventType::EPOLLIN | EPollEventType::EPOLLHUP).bits() as u64); 679 } 680 } 681 682 fn write(&self, buf: &[u8], _to: Option<Endpoint>) -> Result<usize, SystemError> { 683 if HANDLE_MAP 684 .read_irqsave() 685 .get(&self.socket_handle()) 686 .unwrap() 687 .shutdown_type() 688 .contains(ShutdownType::RCV_SHUTDOWN) 689 { 690 return Err(SystemError::ENOTCONN); 691 } 692 // debug!("tcp socket:write, socket'len={}",self.handle.len()); 693 694 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 695 696 let socket = socket_set_guard 697 .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); 698 699 if socket.is_open() { 700 if socket.can_send() { 701 match socket.send_slice(buf) { 702 Ok(size) => { 703 drop(socket_set_guard); 704 poll_ifaces(); 705 return Ok(size); 706 } 707 Err(e) => { 708 error!("Tcp Socket Write Error {e:?}"); 709 return Err(SystemError::ENOBUFS); 710 } 711 } 712 } else { 713 return Err(SystemError::ENOBUFS); 714 } 715 } 716 717 return Err(SystemError::ENOTCONN); 718 } 719 720 fn poll(&self) -> EPollEventType { 721 // 处理listen的快速路径 722 if self.is_listening { 723 return self.tcp_poll_listening(); 724 } 725 // 由于上面处理了listening状态,所以这里只处理非listening状态,这种情况下只有一个handle 726 727 assert!(self.handles.len() == 1); 728 729 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 730 // debug!("tcp socket:poll, socket'len={}",self.handle.len()); 731 732 let socket = socket_set_guard 733 .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); 734 let handle_map_guard = HANDLE_MAP.read_irqsave(); 735 let handle_item = handle_map_guard.get(&self.socket_handle()).unwrap(); 736 let shutdown_type = handle_item.shutdown_type(); 737 let is_posix_listen = handle_item.is_posix_listen; 738 drop(handle_map_guard); 739 740 return SocketPollMethod::tcp_poll(socket, shutdown_type, is_posix_listen); 741 } 742 743 fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { 744 let mut sockets = SOCKET_SET.lock_irqsave(); 745 // debug!("tcp socket:connect, socket'len={}", self.handles.len()); 746 747 let socket = 748 sockets.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); 749 750 if let Endpoint::Ip(Some(ip)) = endpoint { 751 let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?; 752 // 检测端口是否被占用 753 PORT_MANAGER.bind_port(self.metadata.socket_type, temp_port)?; 754 755 // debug!("temp_port: {}", temp_port); 756 let iface: Arc<dyn NetDevice> = NET_DEVICES.write_irqsave().get(&0).unwrap().clone(); 757 let mut inner_iface = iface.inner_iface().lock(); 758 // debug!("to connect: {ip:?}"); 759 760 match socket.connect(inner_iface.context(), ip, temp_port) { 761 Ok(()) => { 762 // avoid deadlock 763 drop(inner_iface); 764 drop(iface); 765 drop(sockets); 766 loop { 767 poll_ifaces(); 768 let mut sockets = SOCKET_SET.lock_irqsave(); 769 let socket = sockets.get_mut::<tcp::Socket>( 770 self.handles.get(0).unwrap().smoltcp_handle().unwrap(), 771 ); 772 773 match socket.state() { 774 tcp::State::Established => { 775 return Ok(()); 776 } 777 tcp::State::SynSent => { 778 drop(sockets); 779 self.posix_item.sleep(Self::CAN_CONNECT); 780 } 781 _ => { 782 return Err(SystemError::ECONNREFUSED); 783 } 784 } 785 } 786 } 787 Err(e) => { 788 // error!("Tcp Socket Connect Error {e:?}"); 789 match e { 790 tcp::ConnectError::InvalidState => return Err(SystemError::EISCONN), 791 tcp::ConnectError::Unaddressable => return Err(SystemError::EADDRNOTAVAIL), 792 } 793 } 794 } 795 } else { 796 return Err(SystemError::EINVAL); 797 } 798 } 799 800 /// @brief tcp socket 监听 local_endpoint 端口 801 /// 802 /// @param backlog 未处理的连接队列的最大长度 803 fn listen(&mut self, backlog: usize) -> Result<(), SystemError> { 804 if self.is_listening { 805 return Ok(()); 806 } 807 808 // debug!( 809 // "tcp socket:listen, socket'len={}, backlog = {backlog}", 810 // self.handles.len() 811 // ); 812 813 let local_endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?; 814 let mut sockets = SOCKET_SET.lock_irqsave(); 815 // 获取handle的数量 816 let handlen = self.handles.len(); 817 let backlog = handlen.max(backlog); 818 819 // 添加剩余需要构建的socket 820 // debug!("tcp socket:before listen, socket'len={}", self.handle_list.len()); 821 let mut handle_guard = HANDLE_MAP.write_irqsave(); 822 let socket_handle_item_0 = handle_guard.get_mut(&self.socket_handle()).unwrap(); 823 socket_handle_item_0.is_posix_listen = true; 824 825 self.handles.extend((handlen..backlog).map(|_| { 826 let socket = Self::create_new_socket(); 827 let handle = GlobalSocketHandle::new_smoltcp_handle(sockets.add(socket)); 828 let mut handle_item = SocketHandleItem::new(Arc::downgrade(&self.posix_item)); 829 handle_item.is_posix_listen = true; 830 handle_guard.insert(handle, handle_item); 831 handle 832 })); 833 834 // debug!("tcp socket:listen, socket'len={}", self.handles.len()); 835 // debug!("tcp socket:listen, backlog={backlog}"); 836 837 // 监听所有的socket 838 for i in 0..backlog { 839 let handle = self.handles.get(i).unwrap(); 840 841 let socket = sockets.get_mut::<tcp::Socket>(handle.smoltcp_handle().unwrap()); 842 843 if !socket.is_listening() { 844 // debug!("Tcp Socket is already listening on {local_endpoint}"); 845 self.do_listen(socket, local_endpoint)?; 846 } 847 // debug!("Tcp Socket before listen, open={}", socket.is_open()); 848 } 849 850 return Ok(()); 851 } 852 853 fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { 854 if let Endpoint::Ip(Some(mut ip)) = endpoint { 855 if ip.port == 0 { 856 ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?; 857 } 858 859 // 检测端口是否已被占用 860 PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port)?; 861 // debug!("tcp socket:bind, socket'len={}",self.handle.len()); 862 863 self.local_endpoint = Some(ip); 864 self.is_listening = false; 865 866 return Ok(()); 867 } 868 return Err(SystemError::EINVAL); 869 } 870 871 fn shutdown(&mut self, shutdown_type: super::ShutdownType) -> Result<(), SystemError> { 872 // TODO:目前只是在表层判断,对端不知晓,后续需使用tcp实现 873 HANDLE_MAP 874 .write_irqsave() 875 .get_mut(&self.socket_handle()) 876 .unwrap() 877 .shutdown_type = RwLock::new(shutdown_type); 878 return Ok(()); 879 } 880 881 fn accept(&mut self) -> Result<(Box<dyn Socket>, Endpoint), SystemError> { 882 if !self.is_listening { 883 return Err(SystemError::EINVAL); 884 } 885 let endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?; 886 loop { 887 // debug!("tcp accept: poll_ifaces()"); 888 poll_ifaces(); 889 // debug!("tcp socket:accept, socket'len={}", self.handle_list.len()); 890 891 let mut sockset = SOCKET_SET.lock_irqsave(); 892 // Get the corresponding activated handler 893 let global_handle_index = self.handles.iter().position(|handle| { 894 let con_smol_sock = sockset.get::<tcp::Socket>(handle.smoltcp_handle().unwrap()); 895 con_smol_sock.is_active() 896 }); 897 898 if let Some(handle_index) = global_handle_index { 899 let con_smol_sock = sockset 900 .get::<tcp::Socket>(self.handles[handle_index].smoltcp_handle().unwrap()); 901 902 // debug!("[Socket] [TCP] Accept: {:?}", handle); 903 // handle is connected socket's handle 904 let remote_ep = con_smol_sock 905 .remote_endpoint() 906 .ok_or(SystemError::ENOTCONN)?; 907 908 let tcp_socket = Self::create_new_socket(); 909 910 let new_handle = GlobalSocketHandle::new_smoltcp_handle(sockset.add(tcp_socket)); 911 912 // let handle in TcpSock be the new empty handle, and return the old connected handle 913 let old_handle = core::mem::replace(&mut self.handles[handle_index], new_handle); 914 915 let metadata = SocketMetadata::new( 916 SocketType::Tcp, 917 Self::DEFAULT_TX_BUF_SIZE, 918 Self::DEFAULT_RX_BUF_SIZE, 919 Self::DEFAULT_METADATA_BUF_SIZE, 920 self.metadata.options, 921 ); 922 923 let sock_ret = Box::new(TcpSocket { 924 handles: vec![old_handle], 925 local_endpoint: self.local_endpoint, 926 is_listening: false, 927 metadata, 928 posix_item: Arc::new(PosixSocketHandleItem::new(None)), 929 }); 930 931 { 932 let mut handle_guard = HANDLE_MAP.write_irqsave(); 933 // 先删除原来的 934 let item = handle_guard.remove(&old_handle).unwrap(); 935 item.reset_shutdown_type(); 936 assert!(item.is_posix_listen); 937 938 // 按照smoltcp行为,将新的handle绑定到原来的item 939 let new_item = SocketHandleItem::new(Arc::downgrade(&sock_ret.posix_item)); 940 handle_guard.insert(old_handle, new_item); 941 // 插入新的item 942 handle_guard.insert(new_handle, item); 943 944 let socket = sockset.get_mut::<tcp::Socket>( 945 self.handles[handle_index].smoltcp_handle().unwrap(), 946 ); 947 948 if !socket.is_listening() { 949 self.do_listen(socket, endpoint)?; 950 } 951 952 drop(handle_guard); 953 } 954 955 return Ok((sock_ret, Endpoint::Ip(Some(remote_ep)))); 956 } 957 958 drop(sockset); 959 960 // debug!("[TCP] [Accept] sleeping socket with handle: {:?}", self.handles.get(0).unwrap().smoltcp_handle().unwrap()); 961 self.posix_item.sleep(Self::CAN_ACCPET); 962 // debug!("tcp socket:after sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len()); 963 } 964 } 965 966 fn endpoint(&self) -> Option<Endpoint> { 967 let mut result: Option<Endpoint> = self.local_endpoint.map(|x| Endpoint::Ip(Some(x))); 968 969 if result.is_none() { 970 let sockets = SOCKET_SET.lock_irqsave(); 971 // debug!("tcp socket:endpoint, socket'len={}",self.handle.len()); 972 973 let socket = 974 sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); 975 if let Some(ep) = socket.local_endpoint() { 976 result = Some(Endpoint::Ip(Some(ep))); 977 } 978 } 979 return result; 980 } 981 982 fn peer_endpoint(&self) -> Option<Endpoint> { 983 let sockets = SOCKET_SET.lock_irqsave(); 984 // debug!("tcp socket:peer_endpoint, socket'len={}",self.handle.len()); 985 986 let socket = 987 sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); 988 return socket.remote_endpoint().map(|x| Endpoint::Ip(Some(x))); 989 } 990 991 fn metadata(&self) -> SocketMetadata { 992 self.metadata.clone() 993 } 994 995 fn box_clone(&self) -> Box<dyn Socket> { 996 Box::new(self.clone()) 997 } 998 999 fn socket_handle(&self) -> GlobalSocketHandle { 1000 // debug!("tcp socket:socket_handle, socket'len={}",self.handle.len()); 1001 1002 *self.handles.get(0).unwrap() 1003 } 1004 1005 fn as_any_ref(&self) -> &dyn core::any::Any { 1006 self 1007 } 1008 1009 fn as_any_mut(&mut self) -> &mut dyn core::any::Any { 1010 self 1011 } 1012 } 1013