1 use alloc::{boxed::Box, sync::Arc, vec::Vec}; 2 use smoltcp::{ 3 iface::SocketHandle, 4 socket::{raw, tcp, udp}, 5 wire, 6 }; 7 use system_error::SystemError; 8 9 use crate::{ 10 driver::net::NetDriver, 11 kerror, kwarn, 12 libs::{rwlock::RwLock, spinlock::SpinLock}, 13 net::{ 14 event_poll::EPollEventType, net_core::poll_ifaces, Endpoint, Protocol, ShutdownType, 15 NET_DRIVERS, 16 }, 17 }; 18 19 use super::{ 20 GlobalSocketHandle, Socket, SocketHandleItem, SocketMetadata, SocketOptions, SocketPollMethod, 21 SocketType, SocketpairOps, HANDLE_MAP, PORT_MANAGER, SOCKET_SET, 22 }; 23 24 /// @brief 表示原始的socket。原始套接字绕过传输层协议(如 TCP 或 UDP)并提供对网络层协议(如 IP)的直接访问。 25 /// 26 /// ref: https://man7.org/linux/man-pages/man7/raw.7.html 27 #[derive(Debug, Clone)] 28 pub struct RawSocket { 29 handle: Arc<GlobalSocketHandle>, 30 /// 用户发送的数据包是否包含了IP头. 31 /// 如果是true,用户发送的数据包,必须包含IP头。(即用户要自行设置IP头+数据) 32 /// 如果是false,用户发送的数据包,不包含IP头。(即用户只要设置数据) 33 header_included: bool, 34 /// socket的metadata 35 metadata: SocketMetadata, 36 } 37 38 impl RawSocket { 39 /// 元数据的缓冲区的大小 40 pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; 41 /// 默认的接收缓冲区的大小 receive 42 pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024; 43 /// 默认的发送缓冲区的大小 transmiss 44 pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024; 45 46 /// @brief 创建一个原始的socket 47 /// 48 /// @param protocol 协议号 49 /// @param options socket的选项 50 /// 51 /// @return 返回创建的原始的socket new(protocol: Protocol, options: SocketOptions) -> Self52 pub fn new(protocol: Protocol, options: SocketOptions) -> Self { 53 let rx_buffer = raw::PacketBuffer::new( 54 vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], 55 vec![0; Self::DEFAULT_RX_BUF_SIZE], 56 ); 57 let tx_buffer = raw::PacketBuffer::new( 58 vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], 59 vec![0; Self::DEFAULT_TX_BUF_SIZE], 60 ); 61 let protocol: u8 = protocol.into(); 62 let socket = raw::Socket::new( 63 wire::IpVersion::Ipv4, 64 wire::IpProtocol::from(protocol), 65 rx_buffer, 66 tx_buffer, 67 ); 68 69 // 把socket添加到socket集合中,并得到socket的句柄 70 let handle: Arc<GlobalSocketHandle> = 71 GlobalSocketHandle::new(SOCKET_SET.lock_irqsave().add(socket)); 72 73 let metadata = SocketMetadata::new( 74 SocketType::RawSocket, 75 Self::DEFAULT_RX_BUF_SIZE, 76 Self::DEFAULT_TX_BUF_SIZE, 77 Self::DEFAULT_METADATA_BUF_SIZE, 78 options, 79 ); 80 81 return Self { 82 handle, 83 header_included: false, 84 metadata, 85 }; 86 } 87 } 88 89 impl Socket for RawSocket { as_any_ref(&self) -> &dyn core::any::Any90 fn as_any_ref(&self) -> &dyn core::any::Any { 91 self 92 } 93 as_any_mut(&mut self) -> &mut dyn core::any::Any94 fn as_any_mut(&mut self) -> &mut dyn core::any::Any { 95 self 96 } 97 read(&mut self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint)98 fn read(&mut self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) { 99 poll_ifaces(); 100 loop { 101 // 如何优化这里? 102 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 103 let socket = socket_set_guard.get_mut::<raw::Socket>(self.handle.0); 104 105 match socket.recv_slice(buf) { 106 Ok(len) => { 107 let packet = wire::Ipv4Packet::new_unchecked(buf); 108 return ( 109 Ok(len), 110 Endpoint::Ip(Some(wire::IpEndpoint { 111 addr: wire::IpAddress::Ipv4(packet.src_addr()), 112 port: 0, 113 })), 114 ); 115 } 116 Err(raw::RecvError::Exhausted) => { 117 if !self.metadata.options.contains(SocketOptions::BLOCK) { 118 // 如果是非阻塞的socket,就返回错误 119 return (Err(SystemError::EAGAIN_OR_EWOULDBLOCK), Endpoint::Ip(None)); 120 } 121 } 122 } 123 drop(socket_set_guard); 124 SocketHandleItem::sleep( 125 self.socket_handle(), 126 EPollEventType::EPOLLIN.bits() as u64, 127 HANDLE_MAP.read_irqsave(), 128 ); 129 } 130 } 131 write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError>132 fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError> { 133 // 如果用户发送的数据包,包含IP头,则直接发送 134 if self.header_included { 135 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 136 let socket = socket_set_guard.get_mut::<raw::Socket>(self.handle.0); 137 match socket.send_slice(buf) { 138 Ok(_) => { 139 return Ok(buf.len()); 140 } 141 Err(raw::SendError::BufferFull) => { 142 return Err(SystemError::ENOBUFS); 143 } 144 } 145 } else { 146 // 如果用户发送的数据包,不包含IP头,则需要自己构造IP头 147 148 if let Some(Endpoint::Ip(Some(endpoint))) = to { 149 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 150 let socket: &mut raw::Socket = 151 socket_set_guard.get_mut::<raw::Socket>(self.handle.0); 152 153 // 暴力解决方案:只考虑0号网卡。 TODO:考虑多网卡的情况!!! 154 let iface = NET_DRIVERS.read_irqsave().get(&0).unwrap().clone(); 155 156 // 构造IP头 157 let ipv4_src_addr: Option<wire::Ipv4Address> = 158 iface.inner_iface().lock().ipv4_addr(); 159 if ipv4_src_addr.is_none() { 160 return Err(SystemError::ENETUNREACH); 161 } 162 let ipv4_src_addr = ipv4_src_addr.unwrap(); 163 164 if let wire::IpAddress::Ipv4(ipv4_dst) = endpoint.addr { 165 let len = buf.len(); 166 167 // 创建20字节的IPv4头部 168 let mut buffer: Vec<u8> = vec![0u8; len + 20]; 169 let mut packet: wire::Ipv4Packet<&mut Vec<u8>> = 170 wire::Ipv4Packet::new_unchecked(&mut buffer); 171 172 // 封装ipv4 header 173 packet.set_version(4); 174 packet.set_header_len(20); 175 packet.set_total_len((20 + len) as u16); 176 packet.set_src_addr(ipv4_src_addr); 177 packet.set_dst_addr(ipv4_dst); 178 179 // 设置ipv4 header的protocol字段 180 packet.set_next_header(socket.ip_protocol().into()); 181 182 // 获取IP数据包的负载字段 183 let payload: &mut [u8] = packet.payload_mut(); 184 payload.copy_from_slice(buf); 185 186 // 填充checksum字段 187 packet.fill_checksum(); 188 189 // 发送数据包 190 socket.send_slice(&buffer).unwrap(); 191 192 iface.poll(&mut socket_set_guard).ok(); 193 194 drop(socket_set_guard); 195 return Ok(len); 196 } else { 197 kwarn!("Unsupport Ip protocol type!"); 198 return Err(SystemError::EINVAL); 199 } 200 } else { 201 // 如果没有指定目的地址,则返回错误 202 return Err(SystemError::ENOTCONN); 203 } 204 } 205 } 206 connect(&mut self, _endpoint: Endpoint) -> Result<(), SystemError>207 fn connect(&mut self, _endpoint: Endpoint) -> Result<(), SystemError> { 208 Ok(()) 209 } 210 metadata(&self) -> Result<SocketMetadata, SystemError>211 fn metadata(&self) -> Result<SocketMetadata, SystemError> { 212 Ok(self.metadata.clone()) 213 } 214 box_clone(&self) -> Box<dyn Socket>215 fn box_clone(&self) -> Box<dyn Socket> { 216 return Box::new(self.clone()); 217 } 218 socket_handle(&self) -> SocketHandle219 fn socket_handle(&self) -> SocketHandle { 220 self.handle.0 221 } 222 } 223 224 /// @brief 表示udp socket 225 /// 226 /// https://man7.org/linux/man-pages/man7/udp.7.html 227 #[derive(Debug, Clone)] 228 pub struct UdpSocket { 229 pub handle: Arc<GlobalSocketHandle>, 230 remote_endpoint: Option<Endpoint>, // 记录远程endpoint提供给connect(), 应该使用IP地址。 231 metadata: SocketMetadata, 232 } 233 234 impl UdpSocket { 235 /// 元数据的缓冲区的大小 236 pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; 237 /// 默认的接收缓冲区的大小 receive 238 pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024; 239 /// 默认的发送缓冲区的大小 transmiss 240 pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024; 241 242 /// @brief 创建一个udp的socket 243 /// 244 /// @param options socket的选项 245 /// 246 /// @return 返回创建的udp的socket new(options: SocketOptions) -> Self247 pub fn new(options: SocketOptions) -> Self { 248 let rx_buffer = udp::PacketBuffer::new( 249 vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], 250 vec![0; Self::DEFAULT_RX_BUF_SIZE], 251 ); 252 let tx_buffer = udp::PacketBuffer::new( 253 vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], 254 vec![0; Self::DEFAULT_TX_BUF_SIZE], 255 ); 256 let socket = udp::Socket::new(rx_buffer, tx_buffer); 257 258 // 把socket添加到socket集合中,并得到socket的句柄 259 let handle: Arc<GlobalSocketHandle> = 260 GlobalSocketHandle::new(SOCKET_SET.lock_irqsave().add(socket)); 261 262 let metadata = SocketMetadata::new( 263 SocketType::UdpSocket, 264 Self::DEFAULT_RX_BUF_SIZE, 265 Self::DEFAULT_TX_BUF_SIZE, 266 Self::DEFAULT_METADATA_BUF_SIZE, 267 options, 268 ); 269 270 return Self { 271 handle, 272 remote_endpoint: None, 273 metadata, 274 }; 275 } 276 do_bind(&self, socket: &mut udp::Socket, endpoint: Endpoint) -> Result<(), SystemError>277 fn do_bind(&self, socket: &mut udp::Socket, endpoint: Endpoint) -> Result<(), SystemError> { 278 if let Endpoint::Ip(Some(ip)) = endpoint { 279 // 检测端口是否已被占用 280 PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port, self.handle.clone())?; 281 282 let bind_res = if ip.addr.is_unspecified() { 283 socket.bind(ip.port) 284 } else { 285 socket.bind(ip) 286 }; 287 288 match bind_res { 289 Ok(()) => return Ok(()), 290 Err(_) => return Err(SystemError::EINVAL), 291 } 292 } else { 293 return Err(SystemError::EINVAL); 294 } 295 } 296 } 297 298 impl Socket for UdpSocket { as_any_ref(&self) -> &dyn core::any::Any299 fn as_any_ref(&self) -> &dyn core::any::Any { 300 self 301 } 302 as_any_mut(&mut self) -> &mut dyn core::any::Any303 fn as_any_mut(&mut self) -> &mut dyn core::any::Any { 304 self 305 } 306 307 /// @brief 在read函数执行之前,请先bind到本地的指定端口 read(&mut self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint)308 fn read(&mut self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) { 309 loop { 310 // kdebug!("Wait22 to Read"); 311 poll_ifaces(); 312 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 313 let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.0); 314 315 // kdebug!("Wait to Read"); 316 317 if socket.can_recv() { 318 if let Ok((size, remote_endpoint)) = socket.recv_slice(buf) { 319 drop(socket_set_guard); 320 poll_ifaces(); 321 return (Ok(size), Endpoint::Ip(Some(remote_endpoint))); 322 } 323 } else { 324 // 如果socket没有连接,则忙等 325 // return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 326 } 327 drop(socket_set_guard); 328 SocketHandleItem::sleep( 329 self.socket_handle(), 330 EPollEventType::EPOLLIN.bits() as u64, 331 HANDLE_MAP.read_irqsave(), 332 ); 333 } 334 } 335 write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError>336 fn write(&self, buf: &[u8], to: Option<Endpoint>) -> Result<usize, SystemError> { 337 // kdebug!("udp to send: {:?}, len={}", to, buf.len()); 338 let remote_endpoint: &wire::IpEndpoint = { 339 if let Some(Endpoint::Ip(Some(ref endpoint))) = to { 340 endpoint 341 } else if let Some(Endpoint::Ip(Some(ref endpoint))) = self.remote_endpoint { 342 endpoint 343 } else { 344 return Err(SystemError::ENOTCONN); 345 } 346 }; 347 // kdebug!("udp write: remote = {:?}", remote_endpoint); 348 349 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 350 let socket = socket_set_guard.get_mut::<udp::Socket>(self.handle.0); 351 // kdebug!("is open()={}", socket.is_open()); 352 // kdebug!("socket endpoint={:?}", socket.endpoint()); 353 if socket.endpoint().port == 0 { 354 let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?; 355 356 let local_ep = match remote_endpoint.addr { 357 // 远程remote endpoint使用什么协议,发送的时候使用的协议是一样的吧 358 // 否则就用 self.endpoint().addr.unwrap() 359 wire::IpAddress::Ipv4(_) => Endpoint::Ip(Some(wire::IpEndpoint::new( 360 wire::IpAddress::Ipv4(wire::Ipv4Address::UNSPECIFIED), 361 temp_port, 362 ))), 363 wire::IpAddress::Ipv6(_) => Endpoint::Ip(Some(wire::IpEndpoint::new( 364 wire::IpAddress::Ipv6(wire::Ipv6Address::UNSPECIFIED), 365 temp_port, 366 ))), 367 }; 368 // kdebug!("udp write: local_ep = {:?}", local_ep); 369 self.do_bind(socket, local_ep)?; 370 } 371 // kdebug!("is open()={}", socket.is_open()); 372 if socket.can_send() { 373 // kdebug!("udp write: can send"); 374 match socket.send_slice(&buf, *remote_endpoint) { 375 Ok(()) => { 376 // kdebug!("udp write: send ok"); 377 drop(socket_set_guard); 378 poll_ifaces(); 379 return Ok(buf.len()); 380 } 381 Err(_) => { 382 // kdebug!("udp write: send err"); 383 return Err(SystemError::ENOBUFS); 384 } 385 } 386 } else { 387 // kdebug!("udp write: can not send"); 388 return Err(SystemError::ENOBUFS); 389 }; 390 } 391 bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError>392 fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { 393 let mut sockets = SOCKET_SET.lock_irqsave(); 394 let socket = sockets.get_mut::<udp::Socket>(self.handle.0); 395 // kdebug!("UDP Bind to {:?}", endpoint); 396 return self.do_bind(socket, endpoint); 397 } 398 poll(&self) -> EPollEventType399 fn poll(&self) -> EPollEventType { 400 let sockets = SOCKET_SET.lock_irqsave(); 401 let socket = sockets.get::<udp::Socket>(self.handle.0); 402 403 return SocketPollMethod::udp_poll( 404 socket, 405 HANDLE_MAP 406 .read_irqsave() 407 .get(&self.socket_handle()) 408 .unwrap() 409 .shutdown_type(), 410 ); 411 } 412 connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError>413 fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { 414 if let Endpoint::Ip(_) = endpoint { 415 self.remote_endpoint = Some(endpoint); 416 return Ok(()); 417 } else { 418 return Err(SystemError::EINVAL); 419 }; 420 } 421 ioctl( &self, _cmd: usize, _arg0: usize, _arg1: usize, _arg2: usize, ) -> Result<usize, SystemError>422 fn ioctl( 423 &self, 424 _cmd: usize, 425 _arg0: usize, 426 _arg1: usize, 427 _arg2: usize, 428 ) -> Result<usize, SystemError> { 429 todo!() 430 } 431 metadata(&self) -> Result<SocketMetadata, SystemError>432 fn metadata(&self) -> Result<SocketMetadata, SystemError> { 433 Ok(self.metadata.clone()) 434 } 435 box_clone(&self) -> Box<dyn Socket>436 fn box_clone(&self) -> Box<dyn Socket> { 437 return Box::new(self.clone()); 438 } 439 endpoint(&self) -> Option<Endpoint>440 fn endpoint(&self) -> Option<Endpoint> { 441 let sockets = SOCKET_SET.lock_irqsave(); 442 let socket = sockets.get::<udp::Socket>(self.handle.0); 443 let listen_endpoint = socket.endpoint(); 444 445 if listen_endpoint.port == 0 { 446 return None; 447 } else { 448 // 如果listen_endpoint的address是None,意味着“监听所有的地址”。 449 // 这里假设所有的地址都是ipv4 450 // TODO: 支持ipv6 451 let result = wire::IpEndpoint::new( 452 listen_endpoint 453 .addr 454 .unwrap_or(wire::IpAddress::v4(0, 0, 0, 0)), 455 listen_endpoint.port, 456 ); 457 return Some(Endpoint::Ip(Some(result))); 458 } 459 } 460 peer_endpoint(&self) -> Option<Endpoint>461 fn peer_endpoint(&self) -> Option<Endpoint> { 462 return self.remote_endpoint.clone(); 463 } 464 socket_handle(&self) -> SocketHandle465 fn socket_handle(&self) -> SocketHandle { 466 self.handle.0 467 } 468 } 469 470 /// @brief 表示 tcp socket 471 /// 472 /// https://man7.org/linux/man-pages/man7/tcp.7.html 473 #[derive(Debug, Clone)] 474 pub struct TcpSocket { 475 handle: Arc<GlobalSocketHandle>, 476 local_endpoint: Option<wire::IpEndpoint>, // save local endpoint for bind() 477 is_listening: bool, 478 metadata: SocketMetadata, 479 } 480 481 impl TcpSocket { 482 /// 元数据的缓冲区的大小 483 pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; 484 /// 默认的接收缓冲区的大小 receive 485 pub const DEFAULT_RX_BUF_SIZE: usize = 512 * 1024; 486 /// 默认的发送缓冲区的大小 transmiss 487 pub const DEFAULT_TX_BUF_SIZE: usize = 512 * 1024; 488 489 /// TcpSocket的特殊事件,用于在事件等待队列上sleep 490 pub const CAN_CONNECT: u64 = 1u64 << 63; 491 pub const CAN_ACCPET: u64 = 1u64 << 62; 492 493 /// @brief 创建一个tcp的socket 494 /// 495 /// @param options socket的选项 496 /// 497 /// @return 返回创建的tcp的socket new(options: SocketOptions) -> Self498 pub fn new(options: SocketOptions) -> Self { 499 let rx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_RX_BUF_SIZE]); 500 let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]); 501 let socket = tcp::Socket::new(rx_buffer, tx_buffer); 502 503 // 把socket添加到socket集合中,并得到socket的句柄 504 let handle: Arc<GlobalSocketHandle> = 505 GlobalSocketHandle::new(SOCKET_SET.lock_irqsave().add(socket)); 506 507 let metadata = SocketMetadata::new( 508 SocketType::TcpSocket, 509 Self::DEFAULT_RX_BUF_SIZE, 510 Self::DEFAULT_TX_BUF_SIZE, 511 Self::DEFAULT_METADATA_BUF_SIZE, 512 options, 513 ); 514 515 return Self { 516 handle, 517 local_endpoint: None, 518 is_listening: false, 519 metadata, 520 }; 521 } do_listen( &mut self, socket: &mut tcp::Socket, local_endpoint: wire::IpEndpoint, ) -> Result<(), SystemError>522 fn do_listen( 523 &mut self, 524 socket: &mut tcp::Socket, 525 local_endpoint: wire::IpEndpoint, 526 ) -> Result<(), SystemError> { 527 let listen_result = if local_endpoint.addr.is_unspecified() { 528 // kdebug!("Tcp Socket Listen on port {}", local_endpoint.port); 529 socket.listen(local_endpoint.port) 530 } else { 531 // kdebug!("Tcp Socket Listen on {local_endpoint}"); 532 socket.listen(local_endpoint) 533 }; 534 // TODO: 增加端口占用检查 535 return match listen_result { 536 Ok(()) => { 537 // kdebug!( 538 // "Tcp Socket Listen on {local_endpoint}, open?:{}", 539 // socket.is_open() 540 // ); 541 self.is_listening = true; 542 543 Ok(()) 544 } 545 Err(_) => Err(SystemError::EINVAL), 546 }; 547 } 548 } 549 550 impl Socket for TcpSocket { as_any_ref(&self) -> &dyn core::any::Any551 fn as_any_ref(&self) -> &dyn core::any::Any { 552 self 553 } 554 as_any_mut(&mut self) -> &mut dyn core::any::Any555 fn as_any_mut(&mut self) -> &mut dyn core::any::Any { 556 self 557 } 558 read(&mut self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint)559 fn read(&mut self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) { 560 if HANDLE_MAP 561 .read_irqsave() 562 .get(&self.socket_handle()) 563 .unwrap() 564 .shutdown_type() 565 .contains(ShutdownType::RCV_SHUTDOWN) 566 { 567 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 568 } 569 // kdebug!("tcp socket: read, buf len={}", buf.len()); 570 571 loop { 572 poll_ifaces(); 573 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 574 let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handle.0); 575 576 // 如果socket已经关闭,返回错误 577 if !socket.is_active() { 578 // kdebug!("Tcp Socket Read Error, socket is closed"); 579 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 580 } 581 582 if socket.may_recv() { 583 let recv_res = socket.recv_slice(buf); 584 585 if let Ok(size) = recv_res { 586 if size > 0 { 587 let endpoint = if let Some(p) = socket.remote_endpoint() { 588 p 589 } else { 590 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 591 }; 592 593 drop(socket_set_guard); 594 poll_ifaces(); 595 return (Ok(size), Endpoint::Ip(Some(endpoint))); 596 } 597 } else { 598 let err = recv_res.unwrap_err(); 599 match err { 600 tcp::RecvError::InvalidState => { 601 kwarn!("Tcp Socket Read Error, InvalidState"); 602 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 603 } 604 tcp::RecvError::Finished => { 605 // 对端写端已关闭,我们应该关闭读端 606 HANDLE_MAP 607 .write_irqsave() 608 .get_mut(&self.socket_handle()) 609 .unwrap() 610 .shutdown_type_writer() 611 .insert(ShutdownType::RCV_SHUTDOWN); 612 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 613 } 614 } 615 } 616 } else { 617 return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); 618 } 619 drop(socket_set_guard); 620 SocketHandleItem::sleep( 621 self.socket_handle(), 622 EPollEventType::EPOLLIN.bits() as u64, 623 HANDLE_MAP.read_irqsave(), 624 ); 625 } 626 } 627 write(&self, buf: &[u8], _to: Option<Endpoint>) -> Result<usize, SystemError>628 fn write(&self, buf: &[u8], _to: Option<Endpoint>) -> Result<usize, SystemError> { 629 if HANDLE_MAP 630 .read_irqsave() 631 .get(&self.socket_handle()) 632 .unwrap() 633 .shutdown_type() 634 .contains(ShutdownType::RCV_SHUTDOWN) 635 { 636 return Err(SystemError::ENOTCONN); 637 } 638 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 639 let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handle.0); 640 641 if socket.is_open() { 642 if socket.can_send() { 643 match socket.send_slice(buf) { 644 Ok(size) => { 645 drop(socket_set_guard); 646 poll_ifaces(); 647 return Ok(size); 648 } 649 Err(e) => { 650 kerror!("Tcp Socket Write Error {e:?}"); 651 return Err(SystemError::ENOBUFS); 652 } 653 } 654 } else { 655 return Err(SystemError::ENOBUFS); 656 } 657 } 658 659 return Err(SystemError::ENOTCONN); 660 } 661 poll(&self) -> EPollEventType662 fn poll(&self) -> EPollEventType { 663 let mut socket_set_guard = SOCKET_SET.lock_irqsave(); 664 let socket = socket_set_guard.get_mut::<tcp::Socket>(self.handle.0); 665 666 return SocketPollMethod::tcp_poll( 667 socket, 668 HANDLE_MAP 669 .read_irqsave() 670 .get(&self.socket_handle()) 671 .unwrap() 672 .shutdown_type(), 673 ); 674 } 675 connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError>676 fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { 677 let mut sockets = SOCKET_SET.lock_irqsave(); 678 let socket = sockets.get_mut::<tcp::Socket>(self.handle.0); 679 680 if let Endpoint::Ip(Some(ip)) = endpoint { 681 let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?; 682 // 检测端口是否被占用 683 PORT_MANAGER.bind_port(self.metadata.socket_type, temp_port, self.handle.clone())?; 684 685 // kdebug!("temp_port: {}", temp_port); 686 let iface: Arc<dyn NetDriver> = NET_DRIVERS.write_irqsave().get(&0).unwrap().clone(); 687 let mut inner_iface = iface.inner_iface().lock(); 688 // kdebug!("to connect: {ip:?}"); 689 690 match socket.connect(&mut inner_iface.context(), ip, temp_port) { 691 Ok(()) => { 692 // avoid deadlock 693 drop(inner_iface); 694 drop(iface); 695 drop(sockets); 696 loop { 697 poll_ifaces(); 698 let mut sockets = SOCKET_SET.lock_irqsave(); 699 let socket = sockets.get_mut::<tcp::Socket>(self.handle.0); 700 701 match socket.state() { 702 tcp::State::Established => { 703 return Ok(()); 704 } 705 tcp::State::SynSent => { 706 drop(sockets); 707 SocketHandleItem::sleep( 708 self.socket_handle(), 709 Self::CAN_CONNECT, 710 HANDLE_MAP.read_irqsave(), 711 ); 712 } 713 _ => { 714 return Err(SystemError::ECONNREFUSED); 715 } 716 } 717 } 718 } 719 Err(e) => { 720 // kerror!("Tcp Socket Connect Error {e:?}"); 721 match e { 722 tcp::ConnectError::InvalidState => return Err(SystemError::EISCONN), 723 tcp::ConnectError::Unaddressable => return Err(SystemError::EADDRNOTAVAIL), 724 } 725 } 726 } 727 } else { 728 return Err(SystemError::EINVAL); 729 } 730 } 731 732 /// @brief tcp socket 监听 local_endpoint 端口 733 /// 734 /// @param backlog 未处理的连接队列的最大长度. 由于smoltcp不支持backlog,所以这个参数目前无效 listen(&mut self, _backlog: usize) -> Result<(), SystemError>735 fn listen(&mut self, _backlog: usize) -> Result<(), SystemError> { 736 if self.is_listening { 737 return Ok(()); 738 } 739 740 let local_endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?; 741 let mut sockets = SOCKET_SET.lock_irqsave(); 742 let socket = sockets.get_mut::<tcp::Socket>(self.handle.0); 743 744 if socket.is_listening() { 745 // kdebug!("Tcp Socket is already listening on {local_endpoint}"); 746 return Ok(()); 747 } 748 // kdebug!("Tcp Socket before listen, open={}", socket.is_open()); 749 return self.do_listen(socket, local_endpoint); 750 } 751 bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError>752 fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { 753 if let Endpoint::Ip(Some(mut ip)) = endpoint { 754 if ip.port == 0 { 755 ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?; 756 } 757 758 // 检测端口是否已被占用 759 PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port, self.handle.clone())?; 760 761 self.local_endpoint = Some(ip); 762 self.is_listening = false; 763 return Ok(()); 764 } 765 return Err(SystemError::EINVAL); 766 } 767 shutdown(&mut self, shutdown_type: super::ShutdownType) -> Result<(), SystemError>768 fn shutdown(&mut self, shutdown_type: super::ShutdownType) -> Result<(), SystemError> { 769 // TODO:目前只是在表层判断,对端不知晓,后续需使用tcp实现 770 HANDLE_MAP 771 .write_irqsave() 772 .get_mut(&self.socket_handle()) 773 .unwrap() 774 .shutdown_type = RwLock::new(shutdown_type); 775 return Ok(()); 776 } 777 accept(&mut self) -> Result<(Box<dyn Socket>, Endpoint), SystemError>778 fn accept(&mut self) -> Result<(Box<dyn Socket>, Endpoint), SystemError> { 779 let endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?; 780 loop { 781 // kdebug!("tcp accept: poll_ifaces()"); 782 poll_ifaces(); 783 784 let mut sockets = SOCKET_SET.lock_irqsave(); 785 786 let socket = sockets.get_mut::<tcp::Socket>(self.handle.0); 787 788 if socket.is_active() { 789 // kdebug!("tcp accept: socket.is_active()"); 790 let remote_ep = socket.remote_endpoint().ok_or(SystemError::ENOTCONN)?; 791 792 let new_socket = { 793 // Initialize the TCP socket's buffers. 794 let rx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_RX_BUF_SIZE]); 795 let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]); 796 // The new TCP socket used for sending and receiving data. 797 let mut tcp_socket = tcp::Socket::new(rx_buffer, tx_buffer); 798 self.do_listen(&mut tcp_socket, endpoint) 799 .expect("do_listen failed"); 800 801 // tcp_socket.listen(endpoint).unwrap(); 802 803 // 之所以把old_handle存入new_socket, 是因为当前时刻,smoltcp已经把old_handle对应的socket与远程的endpoint关联起来了 804 // 因此需要再为当前的socket分配一个新的handle 805 let new_handle = GlobalSocketHandle::new(sockets.add(tcp_socket)); 806 let old_handle = ::core::mem::replace(&mut self.handle, new_handle.clone()); 807 808 // 更新端口与 handle 的绑定 809 if let Some(Endpoint::Ip(Some(ip))) = self.endpoint() { 810 PORT_MANAGER.unbind_port(self.metadata.socket_type, ip.port)?; 811 PORT_MANAGER.bind_port( 812 self.metadata.socket_type, 813 ip.port, 814 new_handle.clone(), 815 )?; 816 } 817 818 let metadata = SocketMetadata::new( 819 SocketType::TcpSocket, 820 Self::DEFAULT_TX_BUF_SIZE, 821 Self::DEFAULT_RX_BUF_SIZE, 822 Self::DEFAULT_METADATA_BUF_SIZE, 823 self.metadata.options, 824 ); 825 826 let new_socket = Box::new(TcpSocket { 827 handle: old_handle.clone(), 828 local_endpoint: self.local_endpoint, 829 is_listening: false, 830 metadata, 831 }); 832 833 // 更新handle表 834 let mut handle_guard = HANDLE_MAP.write_irqsave(); 835 // 先删除原来的 836 let item = handle_guard.remove(&old_handle.0).unwrap(); 837 // 按照smoltcp行为,将新的handle绑定到原来的item 838 handle_guard.insert(new_handle.0, item); 839 let new_item = SocketHandleItem::from_socket(&new_socket); 840 // 插入新的item 841 handle_guard.insert(old_handle.0, new_item); 842 843 new_socket 844 }; 845 // kdebug!("tcp accept: new socket: {:?}", new_socket); 846 drop(sockets); 847 poll_ifaces(); 848 849 return Ok((new_socket, Endpoint::Ip(Some(remote_ep)))); 850 } 851 drop(sockets); 852 853 SocketHandleItem::sleep( 854 self.socket_handle(), 855 Self::CAN_ACCPET, 856 HANDLE_MAP.read_irqsave(), 857 ); 858 } 859 } 860 endpoint(&self) -> Option<Endpoint>861 fn endpoint(&self) -> Option<Endpoint> { 862 let mut result: Option<Endpoint> = 863 self.local_endpoint.clone().map(|x| Endpoint::Ip(Some(x))); 864 865 if result.is_none() { 866 let sockets = SOCKET_SET.lock_irqsave(); 867 let socket = sockets.get::<tcp::Socket>(self.handle.0); 868 if let Some(ep) = socket.local_endpoint() { 869 result = Some(Endpoint::Ip(Some(ep))); 870 } 871 } 872 return result; 873 } 874 peer_endpoint(&self) -> Option<Endpoint>875 fn peer_endpoint(&self) -> Option<Endpoint> { 876 let sockets = SOCKET_SET.lock_irqsave(); 877 let socket = sockets.get::<tcp::Socket>(self.handle.0); 878 return socket.remote_endpoint().map(|x| Endpoint::Ip(Some(x))); 879 } 880 metadata(&self) -> Result<SocketMetadata, SystemError>881 fn metadata(&self) -> Result<SocketMetadata, SystemError> { 882 Ok(self.metadata.clone()) 883 } 884 box_clone(&self) -> Box<dyn Socket>885 fn box_clone(&self) -> Box<dyn Socket> { 886 return Box::new(self.clone()); 887 } 888 socket_handle(&self) -> SocketHandle889 fn socket_handle(&self) -> SocketHandle { 890 self.handle.0 891 } 892 } 893 894 /// # 表示 seqpacket socket 895 #[derive(Debug, Clone)] 896 #[cast_to(Socket)] 897 pub struct SeqpacketSocket { 898 metadata: SocketMetadata, 899 buffer: Arc<SpinLock<Vec<u8>>>, 900 peer_buffer: Option<Arc<SpinLock<Vec<u8>>>>, 901 } 902 903 impl SeqpacketSocket { 904 /// 默认的元数据缓冲区大小 905 pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; 906 /// 默认的缓冲区大小 907 pub const DEFAULT_BUF_SIZE: usize = 64 * 1024; 908 909 /// # 创建一个seqpacket的socket 910 /// 911 /// ## 参数 912 /// - `options`: socket的选项 new(options: SocketOptions) -> Self913 pub fn new(options: SocketOptions) -> Self { 914 let buffer = Vec::with_capacity(Self::DEFAULT_BUF_SIZE); 915 916 let metadata = SocketMetadata::new( 917 SocketType::SeqpacketSocket, 918 Self::DEFAULT_BUF_SIZE, 919 0, 920 Self::DEFAULT_METADATA_BUF_SIZE, 921 options, 922 ); 923 924 return Self { 925 metadata, 926 buffer: Arc::new(SpinLock::new(buffer)), 927 peer_buffer: None, 928 }; 929 } 930 buffer(&self) -> Arc<SpinLock<Vec<u8>>>931 fn buffer(&self) -> Arc<SpinLock<Vec<u8>>> { 932 self.buffer.clone() 933 } 934 set_peer_buffer(&mut self, peer_buffer: Arc<SpinLock<Vec<u8>>>)935 fn set_peer_buffer(&mut self, peer_buffer: Arc<SpinLock<Vec<u8>>>) { 936 self.peer_buffer = Some(peer_buffer); 937 } 938 } 939 940 impl Socket for SeqpacketSocket { as_any_ref(&self) -> &dyn core::any::Any941 fn as_any_ref(&self) -> &dyn core::any::Any { 942 self 943 } 944 as_any_mut(&mut self) -> &mut dyn core::any::Any945 fn as_any_mut(&mut self) -> &mut dyn core::any::Any { 946 self 947 } 948 read(&mut self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint)949 fn read(&mut self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) { 950 let buffer = self.buffer.lock_irqsave(); 951 952 let len = core::cmp::min(buf.len(), buffer.len()); 953 buf[..len].copy_from_slice(&buffer[..len]); 954 955 (Ok(len), Endpoint::Unused) 956 } 957 write(&self, buf: &[u8], _to: Option<Endpoint>) -> Result<usize, SystemError>958 fn write(&self, buf: &[u8], _to: Option<Endpoint>) -> Result<usize, SystemError> { 959 if self.peer_buffer.is_none() { 960 kwarn!("SeqpacketSocket is now just for socketpair"); 961 return Err(SystemError::ENOSYS); 962 } 963 964 let binding = self.peer_buffer.clone().unwrap(); 965 let mut peer_buffer = binding.lock_irqsave(); 966 967 let len = buf.len(); 968 if peer_buffer.capacity() - peer_buffer.len() < len { 969 return Err(SystemError::ENOBUFS); 970 } 971 peer_buffer[..len].copy_from_slice(buf); 972 973 Ok(len) 974 } 975 socketpair_ops(&self) -> Option<&'static dyn SocketpairOps>976 fn socketpair_ops(&self) -> Option<&'static dyn SocketpairOps> { 977 Some(&SeqpacketSocketpairOps) 978 } 979 metadata(&self) -> Result<SocketMetadata, SystemError>980 fn metadata(&self) -> Result<SocketMetadata, SystemError> { 981 Ok(self.metadata.clone()) 982 } 983 box_clone(&self) -> Box<dyn Socket>984 fn box_clone(&self) -> Box<dyn Socket> { 985 Box::new(self.clone()) 986 } 987 } 988 989 struct SeqpacketSocketpairOps; 990 991 impl SocketpairOps for SeqpacketSocketpairOps { socketpair(&self, socket0: &mut Box<dyn Socket>, socket1: &mut Box<dyn Socket>)992 fn socketpair(&self, socket0: &mut Box<dyn Socket>, socket1: &mut Box<dyn Socket>) { 993 let pair0 = socket0 994 .as_mut() 995 .as_any_mut() 996 .downcast_mut::<SeqpacketSocket>() 997 .unwrap(); 998 999 let pair1 = socket1 1000 .as_mut() 1001 .as_any_mut() 1002 .downcast_mut::<SeqpacketSocket>() 1003 .unwrap(); 1004 pair0.set_peer_buffer(pair1.buffer()); 1005 pair1.set_peer_buffer(pair0.buffer()); 1006 } 1007 } 1008