1 use alloc::{boxed::Box, sync::Arc, vec::Vec}; 2 use smoltcp::{ 3 socket::{raw, tcp, udp, AnySocket}, 4 wire, 5 }; 6 use system_error::SystemError; 7 8 use crate::{ 9 arch::rand::rand, 10 driver::net::NetDevice, 11 kerror, kwarn, 12 libs::rwlock::RwLock, 13 net::{ 14 event_poll::EPollEventType, net_core::poll_ifaces, Endpoint, Protocol, ShutdownType, 15 NET_DEVICES, 16 }, 17 }; 18 19 use super::{ 20 handle::GlobalSocketHandle, Socket, SocketHandleItem, SocketMetadata, SocketOptions, 21 SocketPollMethod, SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET, 22 }; 23 24 /// @brief 表示原始的socket。原始套接字绕过传输层协议(如 TCP 或 UDP)并提供对网络层协议(如 IP)的直接访问。 25 /// 26 /// ref: https://man7.org/linux/man-pages/man7/raw.7.html 27 #[derive(Debug, Clone)] 28 pub struct RawSocket { 29 handle: GlobalSocketHandle, 30 /// 用户发送的数据包是否包含了IP头. 31 /// 如果是true,用户发送的数据包,必须包含IP头。(即用户要自行设置IP头+数据) 32 /// 如果是false,用户发送的数据包,不包含IP头。(即用户只要设置数据) 33 header_included: bool, 34 /// socket的metadata 35 metadata: SocketMetadata, 36 } 37 38 impl RawSocket { 39 /// 元数据的缓冲区的大小 40 pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; 41 /// 默认的接收缓冲区的大小 receive 42 pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024; 43 /// 默认的发送缓冲区的大小 transmiss 44 pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024; 45 46 /// @brief 创建一个原始的socket 47 /// 48 /// @param protocol 协议号 49 /// @param options socket的选项 50 /// 51 /// @return 返回创建的原始的socket 52 pub fn new(protocol: Protocol, options: SocketOptions) -> Self { 53 let rx_buffer = raw::PacketBuffer::new( 54 vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], 55 vec![0; Self::DEFAULT_RX_BUF_SIZE], 56 ); 57 let tx_buffer = raw::PacketBuffer::new( 58 vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], 59 vec![0; Self::DEFAULT_TX_BUF_SIZE], 60 ); 61 let protocol: u8 = protocol.into(); 62 let socket = raw::Socket::new( 63 wire::IpVersion::Ipv4, 64 wire::IpProtocol::from(protocol), 65 rx_buffer, 66 tx_buffer, 67 ); 68 69 // 把socket添加到socket集合中,并得到socket的句柄 70 let handle = GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket)); 71 72 let metadata = SocketMetadata::new( 73 SocketType::Raw, 74 Self::DEFAULT_RX_BUF_SIZE, 75 Self::DEFAULT_TX_BUF_SIZE, 76 Self::DEFAULT_METADATA_BUF_SIZE, 77 options, 78 ); 79 80 return Self { 81 handle, 82 header_included: false, 83 metadata, 84 }; 85 } 86 } 87 88 impl Socket for RawSocket { 89 fn close(&mut self) { 90 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 91 socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()); // 删除的时候,会发送一条FINISH的信息? 92 drop(socket_set_guard); 93 poll_ifaces(); 94 } 95 96 fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) { 97 poll_ifaces(); 98 loop { 99 // 如何优化这里? 100 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 101 let socket = 102 socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap()); 103 104 match socket.recv_slice(buf) { 105 Ok(len) => { 106 let packet = wire::Ipv4Packet::new_unchecked(buf); 107 return ( 108 Ok(len), 109 Endpoint::Ip(Some(wire::IpEndpoint { 110 addr: wire::IpAddress::Ipv4(packet.src_addr()), 111 port: 0, 112 })), 113 ); 114 } 115 Err(_) => { 116 if !self.metadata.options.contains(SocketOptions::BLOCK) { 117 // 如果是非阻塞的socket,就返回错误 118 return (Err(SystemError::EAGAIN_OR_EWOULDBLOCK), Endpoint::Ip(None)); 119 } 120 } 121 } 122 drop(socket_set_guard); 123 SocketHandleItem::sleep( 124 self.socket_handle(), 125 EPollEventType::EPOLLIN.bits() as u64, 126 HANDLE_MAP.read_irqsave(), 127 ); 128 } 129 } 130 131 fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError> { 132 // 如果用户发送的数据包,包含IP头,则直接发送 133 if self.header_included { 134 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 135 let socket = 136 socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap()); 137 match socket.send_slice(buf) { 138 Ok(_) => { 139 return Ok(buf.len()); 140 } 141 Err(raw::SendError::BufferFull) => { 142 return Err(SystemError::ENOBUFS); 143 } 144 } 145 } else { 146 // 如果用户发送的数据包,不包含IP头,则需要自己构造IP头 147 148 if let Some(Endpoint::Ip(Some(endpoint))) = to { 149 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 150 let socket: &mut raw::Socket = 151 socket_set_guard.get_mut::<raw::Socket>(self.handle.smoltcp_handle().unwrap()); 152 153 // 暴力解决方案:只考虑0号网卡。 TODO:考虑多网卡的情况!!! 154 let iface = NET_DEVICES.read_irqsave().get(&0).unwrap().clone(); 155 156 // 构造IP头 157 let ipv4_src_addr: Option<wire::Ipv4Address> = 158 iface.inner_iface().lock().ipv4_addr(); 159 if ipv4_src_addr.is_none() { 160 return Err(SystemError::ENETUNREACH); 161 } 162 let ipv4_src_addr = ipv4_src_addr.unwrap(); 163 164 if let wire::IpAddress::Ipv4(ipv4_dst) = endpoint.addr { 165 let len = buf.len(); 166 167 // 创建20字节的IPv4头部 168 let mut buffer: Vec<u8> = vec![0u8; len + 20]; 169 let mut packet: wire::Ipv4Packet<&mut Vec<u8>> = 170 wire::Ipv4Packet::new_unchecked(&mut buffer); 171 172 // 封装ipv4 header 173 packet.set_version(4); 174 packet.set_header_len(20); 175 packet.set_total_len((20 + len) as u16); 176 packet.set_src_addr(ipv4_src_addr); 177 packet.set_dst_addr(ipv4_dst); 178 179 // 设置ipv4 header的protocol字段 180 packet.set_next_header(socket.ip_protocol()); 181 182 // 获取IP数据包的负载字段 183 let payload: &mut [u8] = packet.payload_mut(); 184 payload.copy_from_slice(buf); 185 186 // 填充checksum字段 187 packet.fill_checksum(); 188 189 // 发送数据包 190 socket.send_slice(&buffer).unwrap(); 191 192 iface.poll(&mut socket_set_guard).ok(); 193 194 drop(socket_set_guard); 195 return Ok(len); 196 } else { 197 kwarn!("Unsupport Ip protocol type!"); 198 return Err(SystemError::EINVAL); 199 } 200 } else { 201 // 如果没有指定目的地址,则返回错误 202 return Err(SystemError::ENOTCONN); 203 } 204 } 205 } 206 207 fn connect(&mut self, _endpoint: Endpoint) -> Result<(), SystemError> { 208 Ok(()) 209 } 210 211 fn metadata(&self) -> SocketMetadata { 212 self.metadata.clone() 213 } 214 215 fn box_clone(&self) -> Box<dyn Socket> { 216 Box::new(self.clone()) 217 } 218 219 fn socket_handle(&self) -> GlobalSocketHandle { 220 self.handle 221 } 222 223 fn as_any_ref(&self) -> &dyn core::any::Any { 224 self 225 } 226 227 fn as_any_mut(&mut self) -> &mut dyn core::any::Any { 228 self 229 } 230 } 231 232 /// @brief 表示udp socket 233 /// 234 /// https://man7.org/linux/man-pages/man7/udp.7.html 235 #[derive(Debug, Clone)] 236 pub struct UdpSocket { 237 pub handle: GlobalSocketHandle, 238 remote_endpoint: Option<Endpoint>, // 记录远程endpoint提供给connect(), 应该使用IP地址。 239 metadata: SocketMetadata, 240 } 241 242 impl UdpSocket { 243 /// 元数据的缓冲区的大小 244 pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; 245 /// 默认的接收缓冲区的大小 receive 246 pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024; 247 /// 默认的发送缓冲区的大小 transmiss 248 pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024; 249 250 /// @brief 创建一个udp的socket 251 /// 252 /// @param options socket的选项 253 /// 254 /// @return 返回创建的udp的socket 255 pub fn new(options: SocketOptions) -> Self { 256 let rx_buffer = udp::PacketBuffer::new( 257 vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], 258 vec![0; Self::DEFAULT_RX_BUF_SIZE], 259 ); 260 let tx_buffer = udp::PacketBuffer::new( 261 vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], 262 vec![0; Self::DEFAULT_TX_BUF_SIZE], 263 ); 264 let socket = udp::Socket::new(rx_buffer, tx_buffer); 265 266 // 把socket添加到socket集合中,并得到socket的句柄 267 let handle: GlobalSocketHandle = 268 GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket)); 269 270 let metadata = SocketMetadata::new( 271 SocketType::Udp, 272 Self::DEFAULT_RX_BUF_SIZE, 273 Self::DEFAULT_TX_BUF_SIZE, 274 Self::DEFAULT_METADATA_BUF_SIZE, 275 options, 276 ); 277 278 return Self { 279 handle, 280 remote_endpoint: None, 281 metadata, 282 }; 283 } 284 285 fn do_bind(&self, socket: &mut udp::Socket, endpoint: Endpoint) -> Result<(), SystemError> { 286 if let Endpoint::Ip(Some(mut ip)) = endpoint { 287 // 端口为0则分配随机端口 288 if ip.port == 0 { 289 ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?; 290 } 291 // 检测端口是否已被占用 292 PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port, self.clone())?; 293 294 let bind_res = if ip.addr.is_unspecified() { 295 socket.bind(ip.port) 296 } else { 297 socket.bind(ip) 298 }; 299 300 match bind_res { 301 Ok(()) => return Ok(()), 302 Err(_) => return Err(SystemError::EINVAL), 303 } 304 } else { 305 return Err(SystemError::EINVAL); 306 } 307 } 308 } 309 310 impl Socket for UdpSocket { 311 fn close(&mut self) { 312 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 313 socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()); // 删除的时候,会发送一条FINISH的信息? 314 drop(socket_set_guard); 315 poll_ifaces(); 316 } 317 318 /// @brief 在read函数执行之前,请先bind到本地的指定端口 319 fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) { 320 loop { 321 // kdebug!("Wait22 to Read"); 322 poll_ifaces(); 323 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 324 let socket = 325 socket_set_guard.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap()); 326 327 // kdebug!("Wait to Read"); 328 329 if socket.can_recv() { 330 if let Ok((size, metadata)) = socket.recv_slice(buf) { 331 drop(socket_set_guard); 332 poll_ifaces(); 333 return (Ok(size), Endpoint::Ip(Some(metadata.endpoint))); 334 } 335 } else { 336 // 如果socket没有连接,则忙等 337 // return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 338 } 339 drop(socket_set_guard); 340 SocketHandleItem::sleep( 341 self.socket_handle(), 342 EPollEventType::EPOLLIN.bits() as u64, 343 HANDLE_MAP.read_irqsave(), 344 ); 345 } 346 } 347 348 fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError> { 349 // kdebug!("udp to send: {:?}, len={}", to, buf.len()); 350 let remote_endpoint: &wire::IpEndpoint = { 351 if let Some(Endpoint::Ip(Some(ref endpoint))) = to { 352 endpoint 353 } else if let Some(Endpoint::Ip(Some(ref endpoint))) = self.remote_endpoint { 354 endpoint 355 } else { 356 return Err(SystemError::ENOTCONN); 357 } 358 }; 359 // kdebug!("udp write: remote = {:?}", remote_endpoint); 360 361 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 362 let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap()); 363 // kdebug!("is open()={}", socket.is_open()); 364 // kdebug!("socket endpoint={:?}", socket.endpoint()); 365 if socket.can_send() { 366 // kdebug!("udp write: can send"); 367 match socket.send_slice(buf, *remote_endpoint) { 368 Ok(()) => { 369 // kdebug!("udp write: send ok"); 370 drop(socket_set_guard); 371 poll_ifaces(); 372 return Ok(buf.len()); 373 } 374 Err(_) => { 375 // kdebug!("udp write: send err"); 376 return Err(SystemError::ENOBUFS); 377 } 378 } 379 } else { 380 // kdebug!("udp write: can not send"); 381 return Err(SystemError::ENOBUFS); 382 }; 383 } 384 385 fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { 386 let mut sockets = SOCKET_SET.lock_irqsave(); 387 let socket = sockets.get_mut::<udp::Socket>(self.handle.smoltcp_handle().unwrap()); 388 // kdebug!("UDP Bind to {:?}", endpoint); 389 return self.do_bind(socket, endpoint); 390 } 391 392 fn poll(&self) -> EPollEventType { 393 let sockets = SOCKET_SET.lock_irqsave(); 394 let socket = sockets.get::<udp::Socket>(self.handle.smoltcp_handle().unwrap()); 395 396 return SocketPollMethod::udp_poll( 397 socket, 398 HANDLE_MAP 399 .read_irqsave() 400 .get(&self.socket_handle()) 401 .unwrap() 402 .shutdown_type(), 403 ); 404 } 405 406 fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { 407 if let Endpoint::Ip(_) = endpoint { 408 self.remote_endpoint = Some(endpoint); 409 Ok(()) 410 } else { 411 Err(SystemError::EINVAL) 412 } 413 } 414 415 fn ioctl( 416 &self, 417 _cmd: usize, 418 _arg0: usize, 419 _arg1: usize, 420 _arg2: usize, 421 ) -> Result<usize, SystemError> { 422 todo!() 423 } 424 425 fn metadata(&self) -> SocketMetadata { 426 self.metadata.clone() 427 } 428 429 fn box_clone(&self) -> Box<dyn Socket> { 430 return Box::new(self.clone()); 431 } 432 433 fn endpoint(&self) -> Option<Endpoint> { 434 let sockets = SOCKET_SET.lock_irqsave(); 435 let socket = sockets.get::<udp::Socket>(self.handle.smoltcp_handle().unwrap()); 436 let listen_endpoint = socket.endpoint(); 437 438 if listen_endpoint.port == 0 { 439 return None; 440 } else { 441 // 如果listen_endpoint的address是None,意味着“监听所有的地址”。 442 // 这里假设所有的地址都是ipv4 443 // TODO: 支持ipv6 444 let result = wire::IpEndpoint::new( 445 listen_endpoint 446 .addr 447 .unwrap_or(wire::IpAddress::v4(0, 0, 0, 0)), 448 listen_endpoint.port, 449 ); 450 return Some(Endpoint::Ip(Some(result))); 451 } 452 } 453 454 fn peer_endpoint(&self) -> Option<Endpoint> { 455 return self.remote_endpoint.clone(); 456 } 457 458 fn socket_handle(&self) -> GlobalSocketHandle { 459 self.handle 460 } 461 462 fn as_any_ref(&self) -> &dyn core::any::Any { 463 self 464 } 465 466 fn as_any_mut(&mut self) -> &mut dyn core::any::Any { 467 self 468 } 469 } 470 471 /// @brief 表示 tcp socket 472 /// 473 /// https://man7.org/linux/man-pages/man7/tcp.7.html 474 #[derive(Debug, Clone)] 475 pub struct TcpSocket { 476 handles: Vec<GlobalSocketHandle>, 477 local_endpoint: Option<wire::IpEndpoint>, // save local endpoint for bind() 478 is_listening: bool, 479 metadata: SocketMetadata, 480 } 481 482 impl TcpSocket { 483 /// 元数据的缓冲区的大小 484 pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; 485 /// 默认的接收缓冲区的大小 receive 486 pub const DEFAULT_RX_BUF_SIZE: usize = 512 * 1024; 487 /// 默认的发送缓冲区的大小 transmiss 488 pub const DEFAULT_TX_BUF_SIZE: usize = 512 * 1024; 489 490 /// TcpSocket的特殊事件,用于在事件等待队列上sleep 491 pub const CAN_CONNECT: u64 = 1u64 << 63; 492 pub const CAN_ACCPET: u64 = 1u64 << 62; 493 494 /// @brief 创建一个tcp的socket 495 /// 496 /// @param options socket的选项 497 /// 498 /// @return 返回创建的tcp的socket 499 pub fn new(options: SocketOptions) -> Self { 500 // 创建handles数组并把socket添加到socket集合中,并得到socket的句柄 501 let handles: Vec<GlobalSocketHandle> = vec![GlobalSocketHandle::new_smoltcp_handle( 502 SOCKET_SET.lock_irqsave().add(Self::create_new_socket()), 503 )]; 504 505 let metadata = SocketMetadata::new( 506 SocketType::Tcp, 507 Self::DEFAULT_RX_BUF_SIZE, 508 Self::DEFAULT_TX_BUF_SIZE, 509 Self::DEFAULT_METADATA_BUF_SIZE, 510 options, 511 ); 512 // kdebug!("when there's a new tcp socket,its'len: {}",handles.len()); 513 514 return Self { 515 handles, 516 local_endpoint: None, 517 is_listening: false, 518 metadata, 519 }; 520 } 521 522 fn do_listen( 523 &mut self, 524 socket: &mut tcp::Socket, 525 local_endpoint: wire::IpEndpoint, 526 ) -> Result<(), SystemError> { 527 let listen_result = if local_endpoint.addr.is_unspecified() { 528 // kdebug!("Tcp Socket Listen on port {}", local_endpoint.port); 529 socket.listen(local_endpoint.port) 530 } else { 531 // kdebug!("Tcp Socket Listen on {local_endpoint}"); 532 socket.listen(local_endpoint) 533 }; 534 return match listen_result { 535 Ok(()) => { 536 // kdebug!( 537 // "Tcp Socket Listen on {local_endpoint}, open?:{}", 538 // socket.is_open() 539 // ); 540 self.is_listening = true; 541 542 Ok(()) 543 } 544 Err(_) => Err(SystemError::EINVAL), 545 }; 546 } 547 548 /// # create_new_socket - 创建新的TCP套接字 549 /// 550 /// 该函数用于创建一个新的TCP套接字,并返回该套接字的引用。 551 fn create_new_socket() -> tcp::Socket<'static> { 552 // 初始化tcp的buffer 553 let rx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_RX_BUF_SIZE]); 554 let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]); 555 tcp::Socket::new(rx_buffer, tx_buffer) 556 } 557 } 558 559 impl Socket for TcpSocket { 560 fn close(&mut self) { 561 for handle in self.handles.iter() { 562 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 563 socket_set_guard.remove(handle.smoltcp_handle().unwrap()); // 删除的时候,会发送一条FINISH的信息? 564 drop(socket_set_guard); 565 } 566 poll_ifaces(); 567 } 568 569 fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) { 570 if HANDLE_MAP 571 .read_irqsave() 572 .get(&self.socket_handle()) 573 .unwrap() 574 .shutdown_type() 575 .contains(ShutdownType::RCV_SHUTDOWN) 576 { 577 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 578 } 579 // kdebug!("tcp socket: read, buf len={}", buf.len()); 580 // kdebug!("tcp socket:read, socket'len={}",self.handle.len()); 581 loop { 582 poll_ifaces(); 583 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 584 585 let socket = socket_set_guard 586 .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); 587 588 // 如果socket已经关闭,返回错误 589 if !socket.is_active() { 590 // kdebug!("Tcp Socket Read Error, socket is closed"); 591 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 592 } 593 594 if socket.may_recv() { 595 match socket.recv_slice(buf) { 596 Ok(size) => { 597 if size > 0 { 598 let endpoint = if let Some(p) = socket.remote_endpoint() { 599 p 600 } else { 601 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 602 }; 603 604 drop(socket_set_guard); 605 poll_ifaces(); 606 return (Ok(size), Endpoint::Ip(Some(endpoint))); 607 } 608 } 609 Err(tcp::RecvError::InvalidState) => { 610 kwarn!("Tcp Socket Read Error, InvalidState"); 611 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 612 } 613 Err(tcp::RecvError::Finished) => { 614 // 对端写端已关闭,我们应该关闭读端 615 HANDLE_MAP 616 .write_irqsave() 617 .get_mut(&self.socket_handle()) 618 .unwrap() 619 .shutdown_type_writer() 620 .insert(ShutdownType::RCV_SHUTDOWN); 621 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 622 } 623 } 624 } else { 625 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 626 } 627 drop(socket_set_guard); 628 SocketHandleItem::sleep( 629 self.socket_handle(), 630 EPollEventType::EPOLLIN.bits() as u64, 631 HANDLE_MAP.read_irqsave(), 632 ); 633 } 634 } 635 636 fn write(&self, buf: &[u8], _to: Option<Endpoint>) -> Result<usize, SystemError> { 637 if HANDLE_MAP 638 .read_irqsave() 639 .get(&self.socket_handle()) 640 .unwrap() 641 .shutdown_type() 642 .contains(ShutdownType::RCV_SHUTDOWN) 643 { 644 return Err(SystemError::ENOTCONN); 645 } 646 // kdebug!("tcp socket:write, socket'len={}",self.handle.len()); 647 648 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 649 650 let socket = socket_set_guard 651 .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); 652 653 if socket.is_open() { 654 if socket.can_send() { 655 match socket.send_slice(buf) { 656 Ok(size) => { 657 drop(socket_set_guard); 658 poll_ifaces(); 659 return Ok(size); 660 } 661 Err(e) => { 662 kerror!("Tcp Socket Write Error {e:?}"); 663 return Err(SystemError::ENOBUFS); 664 } 665 } 666 } else { 667 return Err(SystemError::ENOBUFS); 668 } 669 } 670 671 return Err(SystemError::ENOTCONN); 672 } 673 674 fn poll(&self) -> EPollEventType { 675 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 676 // kdebug!("tcp socket:poll, socket'len={}",self.handle.len()); 677 678 let socket = socket_set_guard 679 .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); 680 return SocketPollMethod::tcp_poll( 681 socket, 682 HANDLE_MAP 683 .read_irqsave() 684 .get(&self.socket_handle()) 685 .unwrap() 686 .shutdown_type(), 687 ); 688 } 689 690 fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { 691 let mut sockets = SOCKET_SET.lock_irqsave(); 692 // kdebug!("tcp socket:connect, socket'len={}",self.handle.len()); 693 694 let socket = 695 sockets.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); 696 697 if let Endpoint::Ip(Some(ip)) = endpoint { 698 let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?; 699 // 检测端口是否被占用 700 PORT_MANAGER.bind_port(self.metadata.socket_type, temp_port, self.clone())?; 701 702 // kdebug!("temp_port: {}", temp_port); 703 let iface: Arc<dyn NetDevice> = NET_DEVICES.write_irqsave().get(&0).unwrap().clone(); 704 let mut inner_iface = iface.inner_iface().lock(); 705 // kdebug!("to connect: {ip:?}"); 706 707 match socket.connect(inner_iface.context(), ip, temp_port) { 708 Ok(()) => { 709 // avoid deadlock 710 drop(inner_iface); 711 drop(iface); 712 drop(sockets); 713 loop { 714 poll_ifaces(); 715 let mut sockets = SOCKET_SET.lock_irqsave(); 716 let socket = sockets.get_mut::<tcp::Socket>( 717 self.handles.get(0).unwrap().smoltcp_handle().unwrap(), 718 ); 719 720 match socket.state() { 721 tcp::State::Established => { 722 return Ok(()); 723 } 724 tcp::State::SynSent => { 725 drop(sockets); 726 SocketHandleItem::sleep( 727 self.socket_handle(), 728 Self::CAN_CONNECT, 729 HANDLE_MAP.read_irqsave(), 730 ); 731 } 732 _ => { 733 return Err(SystemError::ECONNREFUSED); 734 } 735 } 736 } 737 } 738 Err(e) => { 739 // kerror!("Tcp Socket Connect Error {e:?}"); 740 match e { 741 tcp::ConnectError::InvalidState => return Err(SystemError::EISCONN), 742 tcp::ConnectError::Unaddressable => return Err(SystemError::EADDRNOTAVAIL), 743 } 744 } 745 } 746 } else { 747 return Err(SystemError::EINVAL); 748 } 749 } 750 751 /// @brief tcp socket 监听 local_endpoint 端口 752 /// 753 /// @param backlog 未处理的连接队列的最大长度. 由于smoltcp不支持backlog,所以这个参数目前无效 754 fn listen(&mut self, backlog: usize) -> Result<(), SystemError> { 755 if self.is_listening { 756 return Ok(()); 757 } 758 759 let local_endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?; 760 let mut sockets = SOCKET_SET.lock_irqsave(); 761 // 获取handle的数量 762 let handlen = self.handles.len(); 763 let backlog = handlen.max(backlog); 764 765 // 添加剩余需要构建的socket 766 // kdebug!("tcp socket:before listen, socket'len={}",self.handle.len()); 767 let mut handle_guard = HANDLE_MAP.write_irqsave(); 768 self.handles.extend((handlen..backlog).map(|_| { 769 let socket = Self::create_new_socket(); 770 let handle = GlobalSocketHandle::new_smoltcp_handle(sockets.add(socket)); 771 let handle_item = SocketHandleItem::new(); 772 handle_guard.insert(handle, handle_item); 773 handle 774 })); 775 // kdebug!("tcp socket:listen, socket'len={}",self.handle.len()); 776 // kdebug!("tcp socket:listen, backlog={backlog}"); 777 778 // 监听所有的socket 779 for i in 0..backlog { 780 let handle = self.handles.get(i).unwrap(); 781 782 let socket = sockets.get_mut::<tcp::Socket>(handle.smoltcp_handle().unwrap()); 783 784 if !socket.is_listening() { 785 // kdebug!("Tcp Socket is already listening on {local_endpoint}"); 786 self.do_listen(socket, local_endpoint)?; 787 } 788 // kdebug!("Tcp Socket before listen, open={}", socket.is_open()); 789 } 790 return Ok(()); 791 } 792 793 fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { 794 if let Endpoint::Ip(Some(mut ip)) = endpoint { 795 if ip.port == 0 { 796 ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?; 797 } 798 799 // 检测端口是否已被占用 800 PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port, self.clone())?; 801 // kdebug!("tcp socket:bind, socket'len={}",self.handle.len()); 802 803 self.local_endpoint = Some(ip); 804 self.is_listening = false; 805 return Ok(()); 806 } 807 return Err(SystemError::EINVAL); 808 } 809 810 fn shutdown(&mut self, shutdown_type: super::ShutdownType) -> Result<(), SystemError> { 811 // TODO:目前只是在表层判断,对端不知晓,后续需使用tcp实现 812 HANDLE_MAP 813 .write_irqsave() 814 .get_mut(&self.socket_handle()) 815 .unwrap() 816 .shutdown_type = RwLock::new(shutdown_type); 817 return Ok(()); 818 } 819 820 fn accept(&mut self) -> Result<(Box<dyn Socket>, Endpoint), SystemError> { 821 let endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?; 822 loop { 823 // kdebug!("tcp accept: poll_ifaces()"); 824 poll_ifaces(); 825 // kdebug!("tcp socket:accept, socket'len={}",self.handle.len()); 826 827 let mut sockets = SOCKET_SET.lock_irqsave(); 828 829 // 随机获取访问的socket的handle 830 let index: usize = rand() % self.handles.len(); 831 let handle = self.handles.get(index).unwrap(); 832 833 let socket = sockets 834 .iter_mut() 835 .find(|y| { 836 tcp::Socket::downcast(y.1) 837 .map(|y| y.is_active()) 838 .unwrap_or(false) 839 }) 840 .map(|y| tcp::Socket::downcast_mut(y.1).unwrap()); 841 if let Some(socket) = socket { 842 if socket.is_active() { 843 // kdebug!("tcp accept: socket.is_active()"); 844 let remote_ep = socket.remote_endpoint().ok_or(SystemError::ENOTCONN)?; 845 846 let new_socket = { 847 // The new TCP socket used for sending and receiving data. 848 let mut tcp_socket = Self::create_new_socket(); 849 self.do_listen(&mut tcp_socket, endpoint) 850 .expect("do_listen failed"); 851 852 // tcp_socket.listen(endpoint).unwrap(); 853 854 // 之所以把old_handle存入new_socket, 是因为当前时刻,smoltcp已经把old_handle对应的socket与远程的endpoint关联起来了 855 // 因此需要再为当前的socket分配一个新的handle 856 let new_handle = 857 GlobalSocketHandle::new_smoltcp_handle(sockets.add(tcp_socket)); 858 let old_handle = ::core::mem::replace( 859 &mut *self.handles.get_mut(index).unwrap(), 860 new_handle, 861 ); 862 863 let metadata = SocketMetadata::new( 864 SocketType::Tcp, 865 Self::DEFAULT_TX_BUF_SIZE, 866 Self::DEFAULT_RX_BUF_SIZE, 867 Self::DEFAULT_METADATA_BUF_SIZE, 868 self.metadata.options, 869 ); 870 871 let new_socket = Box::new(TcpSocket { 872 handles: vec![old_handle], 873 local_endpoint: self.local_endpoint, 874 is_listening: false, 875 metadata, 876 }); 877 // kdebug!("tcp socket:after accept, socket'len={}",new_socket.handle.len()); 878 879 // 更新端口与 socket 的绑定 880 if let Some(Endpoint::Ip(Some(ip))) = self.endpoint() { 881 PORT_MANAGER.unbind_port(self.metadata.socket_type, ip.port)?; 882 PORT_MANAGER.bind_port( 883 self.metadata.socket_type, 884 ip.port, 885 *new_socket.clone(), 886 )?; 887 } 888 889 // 更新handle表 890 let mut handle_guard = HANDLE_MAP.write_irqsave(); 891 // 先删除原来的 892 893 let item = handle_guard.remove(&old_handle).unwrap(); 894 895 // 按照smoltcp行为,将新的handle绑定到原来的item 896 handle_guard.insert(new_handle, item); 897 let new_item = SocketHandleItem::new(); 898 899 // 插入新的item 900 handle_guard.insert(old_handle, new_item); 901 902 new_socket 903 }; 904 // kdebug!("tcp accept: new socket: {:?}", new_socket); 905 drop(sockets); 906 poll_ifaces(); 907 908 return Ok((new_socket, Endpoint::Ip(Some(remote_ep)))); 909 } 910 } 911 // kdebug!("tcp socket:before sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len()); 912 913 drop(sockets); 914 SocketHandleItem::sleep(*handle, Self::CAN_ACCPET, HANDLE_MAP.read_irqsave()); 915 // kdebug!("tcp socket:after sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len()); 916 } 917 } 918 919 fn endpoint(&self) -> Option<Endpoint> { 920 let mut result: Option<Endpoint> = self.local_endpoint.map(|x| Endpoint::Ip(Some(x))); 921 922 if result.is_none() { 923 let sockets = SOCKET_SET.lock_irqsave(); 924 // kdebug!("tcp socket:endpoint, socket'len={}",self.handle.len()); 925 926 let socket = 927 sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); 928 if let Some(ep) = socket.local_endpoint() { 929 result = Some(Endpoint::Ip(Some(ep))); 930 } 931 } 932 return result; 933 } 934 935 fn peer_endpoint(&self) -> Option<Endpoint> { 936 let sockets = SOCKET_SET.lock_irqsave(); 937 // kdebug!("tcp socket:peer_endpoint, socket'len={}",self.handle.len()); 938 939 let socket = 940 sockets.get::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); 941 return socket.remote_endpoint().map(|x| Endpoint::Ip(Some(x))); 942 } 943 944 fn metadata(&self) -> SocketMetadata { 945 self.metadata.clone() 946 } 947 948 fn box_clone(&self) -> Box<dyn Socket> { 949 Box::new(self.clone()) 950 } 951 952 fn socket_handle(&self) -> GlobalSocketHandle { 953 // kdebug!("tcp socket:socket_handle, socket'len={}",self.handle.len()); 954 955 *self.handles.get(0).unwrap() 956 } 957 958 fn as_any_ref(&self) -> &dyn core::any::Any { 959 self 960 } 961 962 fn as_any_mut(&mut self) -> &mut dyn core::any::Any { 963 self 964 } 965 } 966