1 use alloc::{boxed::Box, sync::Arc, vec::Vec}; 2 use smoltcp::{ 3 socket::{raw, tcp, udp}, 4 wire, 5 }; 6 use system_error::SystemError; 7 8 use crate::{ 9 driver::net::NetDevice, 10 kerror, kwarn, 11 libs::rwlock::RwLock, 12 net::{ 13 event_poll::EPollEventType, net_core::poll_ifaces, Endpoint, Protocol, ShutdownType, 14 NET_DEVICES, 15 }, 16 }; 17 18 use super::{ 19 handle::GlobalSocketHandle, Socket, SocketHandleItem, SocketMetadata, SocketOptions, 20 SocketPollMethod, SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET, 21 }; 22 23 /// @brief 表示原始的socket。原始套接字绕过传输层协议(如 TCP 或 UDP)并提供对网络层协议(如 IP)的直接访问。 24 /// 25 /// ref: https://man7.org/linux/man-pages/man7/raw.7.html 26 #[derive(Debug, Clone)] 27 pub struct RawSocket { 28 handle: GlobalSocketHandle, 29 /// 用户发送的数据包是否包含了IP头. 30 /// 如果是true,用户发送的数据包,必须包含IP头。(即用户要自行设置IP头+数据) 31 /// 如果是false,用户发送的数据包,不包含IP头。(即用户只要设置数据) 32 header_included: bool, 33 /// socket的metadata 34 metadata: SocketMetadata, 35 } 36 37 impl RawSocket { 38 /// 元数据的缓冲区的大小 39 pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; 40 /// 默认的接收缓冲区的大小 receive 41 pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024; 42 /// 默认的发送缓冲区的大小 transmiss 43 pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024; 44 45 /// @brief 创建一个原始的socket 46 /// 47 /// @param protocol 协议号 48 /// @param options socket的选项 49 /// 50 /// @return 返回创建的原始的socket 51 pub fn new(protocol: Protocol, options: SocketOptions) -> Self { 52 let rx_buffer = raw::PacketBuffer::new( 53 vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], 54 vec![0; Self::DEFAULT_RX_BUF_SIZE], 55 ); 56 let tx_buffer = raw::PacketBuffer::new( 57 vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], 58 vec![0; Self::DEFAULT_TX_BUF_SIZE], 59 ); 60 let protocol: u8 = protocol.into(); 61 let socket = raw::Socket::new( 62 wire::IpVersion::Ipv4, 63 wire::IpProtocol::from(protocol), 64 rx_buffer, 65 tx_buffer, 66 ); 67 68 // 把socket添加到socket集合中,并得到socket的句柄 69 let handle = GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket)); 70 71 let metadata = SocketMetadata::new( 72 SocketType::Raw, 73 Self::DEFAULT_RX_BUF_SIZE, 74 Self::DEFAULT_TX_BUF_SIZE, 75 Self::DEFAULT_METADATA_BUF_SIZE, 76 options, 77 ); 78 79 return Self { 80 handle, 81 header_included: false, 82 metadata, 83 }; 84 } 85 } 86 87 impl Socket for RawSocket { 88 fn close(&mut self) { 89 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 90 if let smoltcp::socket::Socket::Udp(mut sock) = 91 socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()) 92 { 93 sock.close(); 94 } 95 drop(socket_set_guard); 96 poll_ifaces(); 97 } 98 99 fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) { 100 poll_ifaces(); 101 loop { 102 // 如何优化这里? 103 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 104 let socket = 105 socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap()); 106 107 match socket.recv_slice(buf) { 108 Ok(len) => { 109 let packet = wire::Ipv4Packet::new_unchecked(buf); 110 return ( 111 Ok(len), 112 Endpoint::Ip(Some(wire::IpEndpoint { 113 addr: wire::IpAddress::Ipv4(packet.src_addr()), 114 port: 0, 115 })), 116 ); 117 } 118 Err(_) => { 119 if !self.metadata.options.contains(SocketOptions::BLOCK) { 120 // 如果是非阻塞的socket,就返回错误 121 return (Err(SystemError::EAGAIN_OR_EWOULDBLOCK), Endpoint::Ip(None)); 122 } 123 } 124 } 125 drop(socket_set_guard); 126 SocketHandleItem::sleep( 127 self.socket_handle(), 128 EPollEventType::EPOLLIN.bits() as u64, 129 HANDLE_MAP.read_irqsave(), 130 ); 131 } 132 } 133 134 fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError> { 135 // 如果用户发送的数据包,包含IP头,则直接发送 136 if self.header_included { 137 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 138 let socket = 139 socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap()); 140 match socket.send_slice(buf) { 141 Ok(_) => { 142 return Ok(buf.len()); 143 } 144 Err(raw::SendError::BufferFull) => { 145 return Err(SystemError::ENOBUFS); 146 } 147 } 148 } else { 149 // 如果用户发送的数据包,不包含IP头,则需要自己构造IP头 150 151 if let Some(Endpoint::Ip(Some(endpoint))) = to { 152 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 153 let socket: &mut raw::Socket = 154 socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap()); 155 156 // 暴力解决方案:只考虑0号网卡。 TODO:考虑多网卡的情况!!! 157 let iface = NET_DEVICES.read_irqsave().get(&0).unwrap().clone(); 158 159 // 构造IP头 160 let ipv4_src_addr: Option<wire::Ipv4Address> = 161 iface.inner_iface().lock().ipv4_addr(); 162 if ipv4_src_addr.is_none() { 163 return Err(SystemError::ENETUNREACH); 164 } 165 let ipv4_src_addr = ipv4_src_addr.unwrap(); 166 167 if let wire::IpAddress::Ipv4(ipv4_dst) = endpoint.addr { 168 let len = buf.len(); 169 170 // 创建20字节的IPv4头部 171 let mut buffer: Vec<u8> = vec![0u8; len + 20]; 172 let mut packet: wire::Ipv4Packet<&mut Vec<u8>> = 173 wire::Ipv4Packet::new_unchecked(&mut buffer); 174 175 // 封装ipv4 header 176 packet.set_version(4); 177 packet.set_header_len(20); 178 packet.set_total_len((20 + len) as u16); 179 packet.set_src_addr(ipv4_src_addr); 180 packet.set_dst_addr(ipv4_dst); 181 182 // 设置ipv4 header的protocol字段 183 packet.set_next_header(socket.ip_protocol()); 184 185 // 获取IP数据包的负载字段 186 let payload: &mut [u8] = packet.payload_mut(); 187 payload.copy_from_slice(buf); 188 189 // 填充checksum字段 190 packet.fill_checksum(); 191 192 // 发送数据包 193 socket.send_slice(&buffer).unwrap(); 194 195 iface.poll(&mut socket_set_guard).ok(); 196 197 drop(socket_set_guard); 198 return Ok(len); 199 } else { 200 kwarn!("Unsupport Ip protocol type!"); 201 return Err(SystemError::EINVAL); 202 } 203 } else { 204 // 如果没有指定目的地址,则返回错误 205 return Err(SystemError::ENOTCONN); 206 } 207 } 208 } 209 210 fn connect(&mut self, _endpoint: Endpoint) -> Result<(), SystemError> { 211 Ok(()) 212 } 213 214 fn metadata(&self) -> SocketMetadata { 215 self.metadata.clone() 216 } 217 218 fn box_clone(&self) -> Box<dyn Socket> { 219 Box::new(self.clone()) 220 } 221 222 fn socket_handle(&self) -> GlobalSocketHandle { 223 self.handle 224 } 225 226 fn as_any_ref(&self) -> &dyn core::any::Any { 227 self 228 } 229 230 fn as_any_mut(&mut self) -> &mut dyn core::any::Any { 231 self 232 } 233 } 234 235 /// @brief 表示udp socket 236 /// 237 /// https://man7.org/linux/man-pages/man7/udp.7.html 238 #[derive(Debug, Clone)] 239 pub struct UdpSocket { 240 pub handle: GlobalSocketHandle, 241 remote_endpoint: Option<Endpoint>, // 记录远程endpoint提供给connect(), 应该使用IP地址。 242 metadata: SocketMetadata, 243 } 244 245 impl UdpSocket { 246 /// 元数据的缓冲区的大小 247 pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; 248 /// 默认的接收缓冲区的大小 receive 249 pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024; 250 /// 默认的发送缓冲区的大小 transmiss 251 pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024; 252 253 /// @brief 创建一个udp的socket 254 /// 255 /// @param options socket的选项 256 /// 257 /// @return 返回创建的udp的socket 258 pub fn new(options: SocketOptions) -> Self { 259 let rx_buffer = udp::PacketBuffer::new( 260 vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], 261 vec![0; Self::DEFAULT_RX_BUF_SIZE], 262 ); 263 let tx_buffer = udp::PacketBuffer::new( 264 vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], 265 vec![0; Self::DEFAULT_TX_BUF_SIZE], 266 ); 267 let socket = udp::Socket::new(rx_buffer, tx_buffer); 268 269 // 把socket添加到socket集合中,并得到socket的句柄 270 let handle: GlobalSocketHandle = 271 GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket)); 272 273 let metadata = SocketMetadata::new( 274 SocketType::Udp, 275 Self::DEFAULT_RX_BUF_SIZE, 276 Self::DEFAULT_TX_BUF_SIZE, 277 Self::DEFAULT_METADATA_BUF_SIZE, 278 options, 279 ); 280 281 return Self { 282 handle, 283 remote_endpoint: None, 284 metadata, 285 }; 286 } 287 288 fn do_bind(&self, socket: &mut udp::Socket, endpoint: Endpoint) -> Result<(), SystemError> { 289 if let Endpoint::Ip(Some(mut ip)) = endpoint { 290 // 端口为0则分配随机端口 291 if ip.port == 0 { 292 ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?; 293 } 294 // 检测端口是否已被占用 295 PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port)?; 296 297 let bind_res = if ip.addr.is_unspecified() { 298 socket.bind(ip.port) 299 } else { 300 socket.bind(ip) 301 }; 302 303 match bind_res { 304 Ok(()) => return Ok(()), 305 Err(_) => return Err(SystemError::EINVAL), 306 } 307 } else { 308 return Err(SystemError::EINVAL); 309 } 310 } 311 } 312 313 impl Socket for UdpSocket { 314 fn close(&mut self) { 315 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 316 if let smoltcp::socket::Socket::Udp(mut sock) = 317 socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()) 318 { 319 sock.close(); 320 } 321 drop(socket_set_guard); 322 poll_ifaces(); 323 } 324 325 /// @brief 在read函数执行之前,请先bind到本地的指定端口 326 fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) { 327 loop { 328 // kdebug!("Wait22 to Read"); 329 poll_ifaces(); 330 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 331 let socket = 332 socket_set_guard.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap()); 333 334 // kdebug!("Wait to Read"); 335 336 if socket.can_recv() { 337 if let Ok((size, metadata)) = socket.recv_slice(buf) { 338 drop(socket_set_guard); 339 poll_ifaces(); 340 return (Ok(size), Endpoint::Ip(Some(metadata.endpoint))); 341 } 342 } else { 343 // 如果socket没有连接,则忙等 344 // return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 345 } 346 drop(socket_set_guard); 347 SocketHandleItem::sleep( 348 self.socket_handle(), 349 EPollEventType::EPOLLIN.bits() as u64, 350 HANDLE_MAP.read_irqsave(), 351 ); 352 } 353 } 354 355 fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError> { 356 // kdebug!("udp to send: {:?}, len={}", to, buf.len()); 357 let remote_endpoint: &wire::IpEndpoint = { 358 if let Some(Endpoint::Ip(Some(ref endpoint))) = to { 359 endpoint 360 } else if let Some(Endpoint::Ip(Some(ref endpoint))) = self.remote_endpoint { 361 endpoint 362 } else { 363 return Err(SystemError::ENOTCONN); 364 } 365 }; 366 // kdebug!("udp write: remote = {:?}", remote_endpoint); 367 368 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 369 let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap()); 370 // kdebug!("is open()={}", socket.is_open()); 371 // kdebug!("socket endpoint={:?}", socket.endpoint()); 372 if socket.can_send() { 373 // kdebug!("udp write: can send"); 374 match socket.send_slice(buf, *remote_endpoint) { 375 Ok(()) => { 376 // kdebug!("udp write: send ok"); 377 drop(socket_set_guard); 378 poll_ifaces(); 379 return Ok(buf.len()); 380 } 381 Err(_) => { 382 // kdebug!("udp write: send err"); 383 return Err(SystemError::ENOBUFS); 384 } 385 } 386 } else { 387 // kdebug!("udp write: can not send"); 388 return Err(SystemError::ENOBUFS); 389 }; 390 } 391 392 fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { 393 let mut sockets = SOCKET_SET.lock_irqsave(); 394 let socket = sockets.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap()); 395 // kdebug!("UDP Bind to {:?}", endpoint); 396 return self.do_bind(socket, endpoint); 397 } 398 399 fn poll(&self) -> EPollEventType { 400 let sockets = SOCKET_SET.lock_irqsave(); 401 let socket = sockets.get::<udp::Socket>(self.handle.smoltcp_handle().unwrap()); 402 403 return SocketPollMethod::udp_poll( 404 socket, 405 HANDLE_MAP 406 .read_irqsave() 407 .get(&self.socket_handle()) 408 .unwrap() 409 .shutdown_type(), 410 ); 411 } 412 413 fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { 414 if let Endpoint::Ip(_) = endpoint { 415 self.remote_endpoint = Some(endpoint); 416 Ok(()) 417 } else { 418 Err(SystemError::EINVAL) 419 } 420 } 421 422 fn ioctl( 423 &self, 424 _cmd: usize, 425 _arg0: usize, 426 _arg1: usize, 427 _arg2: usize, 428 ) -> Result<usize, SystemError> { 429 todo!() 430 } 431 432 fn metadata(&self) -> SocketMetadata { 433 self.metadata.clone() 434 } 435 436 fn box_clone(&self) -> Box<dyn Socket> { 437 return Box::new(self.clone()); 438 } 439 440 fn endpoint(&self) -> Option<Endpoint> { 441 let sockets = SOCKET_SET.lock_irqsave(); 442 let socket = sockets.get::<udp::Socket>(self.handle.smoltcp_handle().unwrap()); 443 let listen_endpoint = socket.endpoint(); 444 445 if listen_endpoint.port == 0 { 446 return None; 447 } else { 448 // 如果listen_endpoint的address是None,意味着“监听所有的地址”。 449 // 这里假设所有的地址都是ipv4 450 // TODO: 支持ipv6 451 let result = wire::IpEndpoint::new( 452 listen_endpoint 453 .addr 454 .unwrap_or(wire::IpAddress::v4(0, 0, 0, 0)), 455 listen_endpoint.port, 456 ); 457 return Some(Endpoint::Ip(Some(result))); 458 } 459 } 460 461 fn peer_endpoint(&self) -> Option<Endpoint> { 462 return self.remote_endpoint.clone(); 463 } 464 465 fn socket_handle(&self) -> GlobalSocketHandle { 466 self.handle 467 } 468 469 fn as_any_ref(&self) -> &dyn core::any::Any { 470 self 471 } 472 473 fn as_any_mut(&mut self) -> &mut dyn core::any::Any { 474 self 475 } 476 } 477 478 /// @brief 表示 tcp socket 479 /// 480 /// https://man7.org/linux/man-pages/man7/tcp.7.html 481 #[derive(Debug, Clone)] 482 pub struct TcpSocket { 483 handles: Vec<GlobalSocketHandle>, 484 local_endpoint: Option<wire::IpEndpoint>, // save local endpoint for bind() 485 is_listening: bool, 486 metadata: SocketMetadata, 487 } 488 489 impl TcpSocket { 490 /// 元数据的缓冲区的大小 491 pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; 492 /// 默认的接收缓冲区的大小 receive 493 pub const DEFAULT_RX_BUF_SIZE: usize = 512 * 1024; 494 /// 默认的发送缓冲区的大小 transmiss 495 pub const DEFAULT_TX_BUF_SIZE: usize = 512 * 1024; 496 497 /// TcpSocket的特殊事件,用于在事件等待队列上sleep 498 pub const CAN_CONNECT: u64 = 1u64 << 63; 499 pub const CAN_ACCPET: u64 = 1u64 << 62; 500 501 /// @brief 创建一个tcp的socket 502 /// 503 /// @param options socket的选项 504 /// 505 /// @return 返回创建的tcp的socket 506 pub fn new(options: SocketOptions) -> Self { 507 // 创建handles数组并把socket添加到socket集合中,并得到socket的句柄 508 let handles: Vec<GlobalSocketHandle> = vec![GlobalSocketHandle::new_smoltcp_handle( 509 SOCKET_SET.lock_irqsave().add(Self::create_new_socket()), 510 )]; 511 512 let metadata = SocketMetadata::new( 513 SocketType::Tcp, 514 Self::DEFAULT_RX_BUF_SIZE, 515 Self::DEFAULT_TX_BUF_SIZE, 516 Self::DEFAULT_METADATA_BUF_SIZE, 517 options, 518 ); 519 // kdebug!("when there's a new tcp socket,its'len: {}",handles.len()); 520 521 return Self { 522 handles, 523 local_endpoint: None, 524 is_listening: false, 525 metadata, 526 }; 527 } 528 529 fn do_listen( 530 &mut self, 531 socket: &mut tcp::Socket, 532 local_endpoint: wire::IpEndpoint, 533 ) -> Result<(), SystemError> { 534 let listen_result = if local_endpoint.addr.is_unspecified() { 535 // kdebug!("Tcp Socket Listen on port {}", local_endpoint.port); 536 socket.listen(local_endpoint.port) 537 } else { 538 // kdebug!("Tcp Socket Listen on {local_endpoint}"); 539 socket.listen(local_endpoint) 540 }; 541 return match listen_result { 542 Ok(()) => { 543 // kdebug!( 544 // "Tcp Socket Listen on {local_endpoint}, open?:{}", 545 // socket.is_open() 546 // ); 547 self.is_listening = true; 548 549 Ok(()) 550 } 551 Err(_) => Err(SystemError::EINVAL), 552 }; 553 } 554 555 /// # create_new_socket - 创建新的TCP套接字 556 /// 557 /// 该函数用于创建一个新的TCP套接字,并返回该套接字的引用。 558 fn create_new_socket() -> tcp::Socket<'static> { 559 // 初始化tcp的buffer 560 let rx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_RX_BUF_SIZE]); 561 let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]); 562 tcp::Socket::new(rx_buffer, tx_buffer) 563 } 564 } 565 566 impl Socket for TcpSocket { 567 fn close(&mut self) { 568 for handle in self.handles.iter() { 569 { 570 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 571 let smoltcp_handle = handle.smoltcp_handle().unwrap(); 572 socket_set_guard 573 .get_mut::<smoltcp::socket::tcp::Socket>(smoltcp_handle) 574 .close(); 575 drop(socket_set_guard); 576 } 577 poll_ifaces(); 578 SOCKET_SET 579 .lock_irqsave() 580 .remove(handle.smoltcp_handle().unwrap()); 581 // kdebug!("[Socket] [TCP] Close: {:?}", handle); 582 } 583 } 584 585 fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) { 586 if HANDLE_MAP 587 .read_irqsave() 588 .get(&self.socket_handle()) 589 .unwrap() 590 .shutdown_type() 591 .contains(ShutdownType::RCV_SHUTDOWN) 592 { 593 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 594 } 595 // kdebug!("tcp socket: read, buf len={}", buf.len()); 596 // kdebug!("tcp socket:read, socket'len={}",self.handle.len()); 597 loop { 598 poll_ifaces(); 599 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 600 601 let socket = socket_set_guard 602 .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); 603 604 // 如果socket已经关闭,返回错误 605 if !socket.is_active() { 606 // kdebug!("Tcp Socket Read Error, socket is closed"); 607 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 608 } 609 610 if socket.may_recv() { 611 match socket.recv_slice(buf) { 612 Ok(size) => { 613 if size > 0 { 614 let endpoint = if let Some(p) = socket.remote_endpoint() { 615 p 616 } else { 617 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 618 }; 619 620 drop(socket_set_guard); 621 poll_ifaces(); 622 return (Ok(size), Endpoint::Ip(Some(endpoint))); 623 } 624 } 625 Err(tcp::RecvError::InvalidState) => { 626 kwarn!("Tcp Socket Read Error, InvalidState"); 627 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 628 } 629 Err(tcp::RecvError::Finished) => { 630 // 对端写端已关闭,我们应该关闭读端 631 HANDLE_MAP 632 .write_irqsave() 633 .get_mut(&self.socket_handle()) 634 .unwrap() 635 .shutdown_type_writer() 636 .insert(ShutdownType::RCV_SHUTDOWN); 637 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 638 } 639 } 640 } else { 641 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 642 } 643 drop(socket_set_guard); 644 SocketHandleItem::sleep( 645 self.socket_handle(), 646 (EPollEventType::EPOLLIN.bits() | EPollEventType::EPOLLHUP.bits()) as u64, 647 HANDLE_MAP.read_irqsave(), 648 ); 649 } 650 } 651 652 fn write(&self, buf: &[u8], _to: Option<Endpoint>) -> Result<usize, SystemError> { 653 if HANDLE_MAP 654 .read_irqsave() 655 .get(&self.socket_handle()) 656 .unwrap() 657 .shutdown_type() 658 .contains(ShutdownType::RCV_SHUTDOWN) 659 { 660 return Err(SystemError::ENOTCONN); 661 } 662 // kdebug!("tcp socket:write, socket'len={}",self.handle.len()); 663 664 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 665 666 let socket = socket_set_guard 667 .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); 668 669 if socket.is_open() { 670 if socket.can_send() { 671 match socket.send_slice(buf) { 672 Ok(size) => { 673 drop(socket_set_guard); 674 poll_ifaces(); 675 return Ok(size); 676 } 677 Err(e) => { 678 kerror!("Tcp Socket Write Error {e:?}"); 679 return Err(SystemError::ENOBUFS); 680 } 681 } 682 } else { 683 return Err(SystemError::ENOBUFS); 684 } 685 } 686 687 return Err(SystemError::ENOTCONN); 688 } 689 690 fn poll(&self) -> EPollEventType { 691 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 692 // kdebug!("tcp socket:poll, socket'len={}",self.handle.len()); 693 694 let socket = socket_set_guard 695 .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); 696 return SocketPollMethod::tcp_poll( 697 socket, 698 HANDLE_MAP 699 .read_irqsave() 700 .get(&self.socket_handle()) 701 .unwrap() 702 .shutdown_type(), 703 ); 704 } 705 706 fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { 707 let mut sockets = SOCKET_SET.lock_irqsave(); 708 // kdebug!("tcp socket:connect, socket'len={}",self.handle.len()); 709 710 let socket = 711 sockets.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); 712 713 if let Endpoint::Ip(Some(ip)) = endpoint { 714 let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?; 715 // 检测端口是否被占用 716 PORT_MANAGER.bind_port(self.metadata.socket_type, temp_port)?; 717 718 // kdebug!("temp_port: {}", temp_port); 719 let iface: Arc<dyn NetDevice> = NET_DEVICES.write_irqsave().get(&0).unwrap().clone(); 720 let mut inner_iface = iface.inner_iface().lock(); 721 // kdebug!("to connect: {ip:?}"); 722 723 match socket.connect(inner_iface.context(), ip, temp_port) { 724 Ok(()) => { 725 // avoid deadlock 726 drop(inner_iface); 727 drop(iface); 728 drop(sockets); 729 loop { 730 poll_ifaces(); 731 let mut sockets = SOCKET_SET.lock_irqsave(); 732 let socket = sockets.get_mut::<tcp::Socket>( 733 self.handles.get(0).unwrap().smoltcp_handle().unwrap(), 734 ); 735 736 match socket.state() { 737 tcp::State::Established => { 738 return Ok(()); 739 } 740 tcp::State::SynSent => { 741 drop(sockets); 742 SocketHandleItem::sleep( 743 self.socket_handle(), 744 Self::CAN_CONNECT, 745 HANDLE_MAP.read_irqsave(), 746 ); 747 } 748 _ => { 749 return Err(SystemError::ECONNREFUSED); 750 } 751 } 752 } 753 } 754 Err(e) => { 755 // kerror!("Tcp Socket Connect Error {e:?}"); 756 match e { 757 tcp::ConnectError::InvalidState => return Err(SystemError::EISCONN), 758 tcp::ConnectError::Unaddressable => return Err(SystemError::EADDRNOTAVAIL), 759 } 760 } 761 } 762 } else { 763 return Err(SystemError::EINVAL); 764 } 765 } 766 767 /// @brief tcp socket 监听 local_endpoint 端口 768 /// 769 /// @param backlog 未处理的连接队列的最大长度 770 fn listen(&mut self, backlog: usize) -> Result<(), SystemError> { 771 if self.is_listening { 772 return Ok(()); 773 } 774 775 let local_endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?; 776 let mut sockets = SOCKET_SET.lock_irqsave(); 777 // 获取handle的数量 778 let handlen = self.handles.len(); 779 let backlog = handlen.max(backlog); 780 781 // 添加剩余需要构建的socket 782 // kdebug!("tcp socket:before listen, socket'len={}", self.handle_list.len()); 783 let mut handle_guard = HANDLE_MAP.write_irqsave(); 784 let wait_queue = Arc::clone(&handle_guard.get(&self.socket_handle()).unwrap().wait_queue); 785 786 self.handles.extend((handlen..backlog).map(|_| { 787 let socket = Self::create_new_socket(); 788 let handle = GlobalSocketHandle::new_smoltcp_handle(sockets.add(socket)); 789 let handle_item = SocketHandleItem::new(Some(wait_queue.clone())); 790 handle_guard.insert(handle, handle_item); 791 handle 792 })); 793 // kdebug!("tcp socket:listen, socket'len={}",self.handle.len()); 794 // kdebug!("tcp socket:listen, backlog={backlog}"); 795 796 // 监听所有的socket 797 for i in 0..backlog { 798 let handle = self.handles.get(i).unwrap(); 799 800 let socket = sockets.get_mut::<tcp::Socket>(handle.smoltcp_handle().unwrap()); 801 802 if !socket.is_listening() { 803 // kdebug!("Tcp Socket is already listening on {local_endpoint}"); 804 self.do_listen(socket, local_endpoint)?; 805 } 806 // kdebug!("Tcp Socket before listen, open={}", socket.is_open()); 807 } 808 return Ok(()); 809 } 810 811 fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { 812 if let Endpoint::Ip(Some(mut ip)) = endpoint { 813 if ip.port == 0 { 814 ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?; 815 } 816 817 // 检测端口是否已被占用 818 PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port)?; 819 // kdebug!("tcp socket:bind, socket'len={}",self.handle.len()); 820 821 self.local_endpoint = Some(ip); 822 self.is_listening = false; 823 return Ok(()); 824 } 825 return Err(SystemError::EINVAL); 826 } 827 828 fn shutdown(&mut self, shutdown_type: super::ShutdownType) -> Result<(), SystemError> { 829 // TODO:目前只是在表层判断,对端不知晓,后续需使用tcp实现 830 HANDLE_MAP 831 .write_irqsave() 832 .get_mut(&self.socket_handle()) 833 .unwrap() 834 .shutdown_type = RwLock::new(shutdown_type); 835 return Ok(()); 836 } 837 838 fn accept(&mut self) -> Result<(Box<dyn Socket>, Endpoint), SystemError> { 839 if !self.is_listening { 840 return Err(SystemError::EINVAL); 841 } 842 let endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?; 843 loop { 844 // kdebug!("tcp accept: poll_ifaces()"); 845 poll_ifaces(); 846 // kdebug!("tcp socket:accept, socket'len={}", self.handle_list.len()); 847 848 let mut sockset = SOCKET_SET.lock_irqsave(); 849 // Get the corresponding activated handler 850 let global_handle_index = self.handles.iter().position(|handle| { 851 let con_smol_sock = sockset.get::<tcp::Socket>(handle.smoltcp_handle().unwrap()); 852 con_smol_sock.is_active() 853 }); 854 855 if let Some(handle_index) = global_handle_index { 856 let con_smol_sock = sockset 857 .get::<tcp::Socket>(self.handles[handle_index].smoltcp_handle().unwrap()); 858 859 // kdebug!("[Socket] [TCP] Accept: {:?}", handle); 860 // handle is connected socket's handle 861 let remote_ep = con_smol_sock 862 .remote_endpoint() 863 .ok_or(SystemError::ENOTCONN)?; 864 865 let mut tcp_socket = Self::create_new_socket(); 866 self.do_listen(&mut tcp_socket, endpoint)?; 867 868 let new_handle = GlobalSocketHandle::new_smoltcp_handle(sockset.add(tcp_socket)); 869 870 // let handle in TcpSock be the new empty handle, and return the old connected handle 871 let old_handle = core::mem::replace(&mut self.handles[handle_index], new_handle); 872 873 let metadata = SocketMetadata::new( 874 SocketType::Tcp, 875 Self::DEFAULT_TX_BUF_SIZE, 876 Self::DEFAULT_RX_BUF_SIZE, 877 Self::DEFAULT_METADATA_BUF_SIZE, 878 self.metadata.options, 879 ); 880 881 let sock_ret = Box::new(TcpSocket { 882 handles: vec![old_handle], 883 local_endpoint: self.local_endpoint, 884 is_listening: false, 885 metadata, 886 }); 887 888 { 889 let mut handle_guard = HANDLE_MAP.write_irqsave(); 890 // 先删除原来的 891 let item = handle_guard.remove(&old_handle).unwrap(); 892 893 // 按照smoltcp行为,将新的handle绑定到原来的item 894 let new_item = SocketHandleItem::new(None); 895 handle_guard.insert(old_handle, new_item); 896 // 插入新的item 897 handle_guard.insert(new_handle, item); 898 drop(handle_guard); 899 } 900 return Ok((sock_ret, Endpoint::Ip(Some(remote_ep)))); 901 } 902 903 drop(sockset); 904 905 // kdebug!("[TCP] [Accept] sleeping socket with handle: {:?}", self.handles.get(0).unwrap().smoltcp_handle().unwrap()); 906 SocketHandleItem::sleep( 907 self.socket_handle(), // NOTICE 908 Self::CAN_ACCPET, 909 HANDLE_MAP.read_irqsave(), 910 ); 911 // kdebug!("tcp socket:after sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len()); 912 } 913 } 914 915 fn endpoint(&self) -> Option<Endpoint> { 916 let mut result: Option<Endpoint> = self.local_endpoint.map(|x| Endpoint::Ip(Some(x))); 917 918 if result.is_none() { 919 let sockets = SOCKET_SET.lock_irqsave(); 920 // kdebug!("tcp socket:endpoint, socket'len={}",self.handle.len()); 921 922 let socket = 923 sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); 924 if let Some(ep) = socket.local_endpoint() { 925 result = Some(Endpoint::Ip(Some(ep))); 926 } 927 } 928 return result; 929 } 930 931 fn peer_endpoint(&self) -> Option<Endpoint> { 932 let sockets = SOCKET_SET.lock_irqsave(); 933 // kdebug!("tcp socket:peer_endpoint, socket'len={}",self.handle.len()); 934 935 let socket = 936 sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); 937 return socket.remote_endpoint().map(|x| Endpoint::Ip(Some(x))); 938 } 939 940 fn metadata(&self) -> SocketMetadata { 941 self.metadata.clone() 942 } 943 944 fn box_clone(&self) -> Box<dyn Socket> { 945 Box::new(self.clone()) 946 } 947 948 fn socket_handle(&self) -> GlobalSocketHandle { 949 // kdebug!("tcp socket:socket_handle, socket'len={}",self.handle.len()); 950 951 *self.handles.get(0).unwrap() 952 } 953 954 fn as_any_ref(&self) -> &dyn core::any::Any { 955 self 956 } 957 958 fn as_any_mut(&mut self) -> &mut dyn core::any::Any { 959 self 960 } 961 } 962