//! P2P 连接模块 use anyhow::Result; use bytes::Bytes; use easyremote_common::protocol::{FrameData, FrameFormat, InputEvent}; use std::net::SocketAddr; use std::sync::Arc; use tokio::net::UdpSocket; use tokio::sync::{mpsc, RwLock}; /// 连接状态 #[derive(Debug, Clone, PartialEq)] pub enum ConnectionState { /// 断开连接 Disconnected, /// 正在连接 Connecting, /// 已连接 Connected, /// 连接失败 Failed(String), } /// P2P 连接管理器 pub struct P2PConnection { state: Arc>, local_addr: Option, remote_addr: Option, socket: Option>, frame_tx: Option>, input_tx: Option>, } impl P2PConnection { pub fn new() -> Self { Self { state: Arc::new(RwLock::new(ConnectionState::Disconnected)), local_addr: None, remote_addr: None, socket: None, frame_tx: None, input_tx: None, } } /// 获取连接状态 pub async fn state(&self) -> ConnectionState { self.state.read().await.clone() } /// 初始化本地套接字 pub async fn init_socket(&mut self) -> Result { let socket = UdpSocket::bind("0.0.0.0:0").await?; let local_addr = socket.local_addr()?; self.local_addr = Some(local_addr); self.socket = Some(Arc::new(socket)); Ok(local_addr) } /// 尝试连接到远程地址 pub async fn connect(&mut self, remote_addr: SocketAddr) -> Result<()> { *self.state.write().await = ConnectionState::Connecting; if let Some(socket) = &self.socket { socket.connect(remote_addr).await?; self.remote_addr = Some(remote_addr); *self.state.write().await = ConnectionState::Connected; } else { *self.state.write().await = ConnectionState::Failed("Socket not initialized".into()); anyhow::bail!("Socket not initialized"); } Ok(()) } /// 发送帧数据 pub async fn send_frame(&self, frame: &FrameData) -> Result<()> { if let Some(socket) = &self.socket { let data = bincode::serialize(frame)?; // 如果数据太大,需要分片发送 if data.len() > 65000 { self.send_fragmented(socket, &data).await?; } else { socket.send(&data).await?; } } Ok(()) } /// 发送输入事件 pub async fn send_input(&self, event: &InputEvent) -> Result<()> { if let Some(socket) = &self.socket { let data = bincode::serialize(event)?; socket.send(&data).await?; } Ok(()) } /// 分片发送大数据 async fn send_fragmented(&self, socket: &UdpSocket, data: &[u8]) -> Result<()> { const MAX_FRAGMENT_SIZE: usize = 60000; let total_fragments = (data.len() + MAX_FRAGMENT_SIZE - 1) / MAX_FRAGMENT_SIZE; for (i, chunk) in data.chunks(MAX_FRAGMENT_SIZE).enumerate() { let fragment = FragmentHeader { fragment_id: i as u16, total_fragments: total_fragments as u16, data: chunk.to_vec(), }; let fragment_data = bincode::serialize(&fragment)?; socket.send(&fragment_data).await?; } Ok(()) } /// 开始接收数据 pub async fn start_receiving( &self, on_frame: impl Fn(FrameData) + Send + 'static, on_input: impl Fn(InputEvent) + Send + 'static, ) -> Result<()> { let socket = self.socket.clone().ok_or_else(|| anyhow::anyhow!("Socket not initialized"))?; let state = self.state.clone(); tokio::spawn(async move { let mut buf = vec![0u8; 100000]; let mut fragment_buffer: Vec>> = Vec::new(); let mut expected_fragments = 0u16; loop { if *state.read().await != ConnectionState::Connected { break; } match socket.recv(&mut buf).await { Ok(len) => { let data = &buf[..len]; // 尝试解析为分片 if let Ok(fragment) = bincode::deserialize::(data) { if fragment.total_fragments > 1 { // 处理分片 if fragment_buffer.is_empty() { fragment_buffer = vec![None; fragment.total_fragments as usize]; expected_fragments = fragment.total_fragments; } if fragment.fragment_id < expected_fragments { fragment_buffer[fragment.fragment_id as usize] = Some(fragment.data); } // 检查是否收到所有分片 if fragment_buffer.iter().all(|f| f.is_some()) { let complete_data: Vec = fragment_buffer .iter() .filter_map(|f| f.clone()) .flatten() .collect(); if let Ok(frame) = bincode::deserialize::(&complete_data) { on_frame(frame); } fragment_buffer.clear(); } continue; } } // 尝试解析为帧数据 if let Ok(frame) = bincode::deserialize::(data) { on_frame(frame); continue; } // 尝试解析为输入事件 if let Ok(input) = bincode::deserialize::(data) { on_input(input); } } Err(e) => { tracing::warn!("Receive error: {}", e); } } } }); Ok(()) } /// 断开连接 pub async fn disconnect(&mut self) { *self.state.write().await = ConnectionState::Disconnected; self.socket = None; self.remote_addr = None; } } /// 分片头 #[derive(Debug, serde::Serialize, serde::Deserialize)] struct FragmentHeader { fragment_id: u16, total_fragments: u16, data: Vec, } /// ICE 服务器配置 #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct IceServersConfig { pub stun_servers: Vec, pub turn_server: Option, } /// TURN 服务器配置 #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct TurnConfig { pub url: String, pub username: String, pub credential: String, } /// NAT穿透辅助 pub struct NatTraversal; impl NatTraversal { /// 获取所有本地 IP 地址 pub fn get_local_ips() -> Vec { let mut ips = Vec::new(); // 尝试使用 socket 连接来获取本地 IP if let Ok(socket) = std::net::UdpSocket::bind("0.0.0.0:0") { // 连接到公网地址来获取本机实际使用的 IP if let Ok(()) = socket.connect("8.8.8.8:80") { if let Ok(addr) = socket.local_addr() { let ip = addr.ip().to_string(); if !ip.starts_with("127.") && ip != "0.0.0.0" { ips.push(ip); } } } } // 如果没有获取到,添加 localhost if ips.is_empty() { ips.push("127.0.0.1".to_string()); } ips } /// 从服务器获取 ICE 服务器配置 pub async fn fetch_ice_servers(server_url: &str) -> Result { // 将 ws:// 或 wss:// 转换为 http:// 或 https:// let http_url = server_url .replace("ws://", "http://") .replace("wss://", "https://"); let url = format!("{}/api/ice-servers", http_url.trim_end_matches('/')); let response = reqwest::get(&url).await?; let config: IceServersConfig = response.json().await?; Ok(config) } /// 获取本地候选地址 pub async fn get_local_candidates() -> Result> { let mut candidates = Vec::new(); // 获取所有本地网络接口 if let Ok(socket) = UdpSocket::bind("0.0.0.0:0").await { if let Ok(addr) = socket.local_addr() { candidates.push(format!("host {}", addr)); } } Ok(candidates) } /// 获取完整的 ICE 候选(包括本地和公网地址) pub async fn get_all_candidates(stun_servers: &[String]) -> Result> { let mut candidates = Vec::new(); // 添加本地地址候选 let local_candidates = Self::get_local_candidates().await?; for candidate in local_candidates { candidates.push(IceCandidate::new(candidate)); } // 使用 STUN 服务器获取公网地址 for stun_url in stun_servers { match Self::get_public_addr_from_url(stun_url).await { Ok(addr) => { candidates.push(IceCandidate::new(format!("srflx {}", addr))); tracing::info!("Got public address from STUN {}: {}", stun_url, addr); break; // 成功获取一个公网地址即可 } Err(e) => { tracing::warn!("Failed to get public address from {}: {}", stun_url, e); } } } Ok(candidates) } /// 从 STUN URL 解析并获取公网地址 pub async fn get_public_addr_from_url(stun_url: &str) -> Result { // 解析 STUN URL,格式: stun:host:port 或 stun:host let url = stun_url.trim_start_matches("stun:"); let parts: Vec<&str> = url.split(':').collect(); let (host, port) = match parts.len() { 1 => (parts[0], 3478u16), // 默认 STUN 端口 2 => (parts[0], parts[1].parse().unwrap_or(3478)), _ => anyhow::bail!("Invalid STUN URL format: {}", stun_url), }; Self::get_public_addr(host, port).await } /// 使用 STUN 服务器获取公网地址 pub async fn get_public_addr(stun_host: &str, stun_port: u16) -> Result { let socket = UdpSocket::bind("0.0.0.0:0").await?; // 解析 STUN 服务器地址 let stun_addr = tokio::net::lookup_host(format!("{}:{}", stun_host, stun_port)) .await? .next() .ok_or_else(|| anyhow::anyhow!("Cannot resolve STUN server: {}", stun_host))?; // 设置超时 socket.connect(stun_addr).await?; // 发送 STUN Binding 请求 let binding_request = create_stun_binding_request(); socket.send(&binding_request).await?; // 接收响应(带超时) let mut buf = [0u8; 1024]; let recv_result = tokio::time::timeout( std::time::Duration::from_secs(5), socket.recv(&mut buf) ).await; match recv_result { Ok(Ok(len)) => parse_stun_response(&buf[..len]), Ok(Err(e)) => anyhow::bail!("STUN receive error: {}", e), Err(_) => anyhow::bail!("STUN request timeout"), } } } /// STUN Magic Cookie (RFC 5389) const STUN_MAGIC_COOKIE: u32 = 0x2112A442; /// STUN Attribute Types const ATTR_XOR_MAPPED_ADDRESS: u16 = 0x0020; fn create_stun_binding_request() -> Vec { use rand::Rng; let mut request = Vec::with_capacity(20); // Message Type: Binding Request (0x0001) request.extend_from_slice(&0x0001u16.to_be_bytes()); // Message Length: 0 (no attributes) request.extend_from_slice(&0x0000u16.to_be_bytes()); // Magic Cookie (RFC 5389) request.extend_from_slice(&STUN_MAGIC_COOKIE.to_be_bytes()); // Transaction ID (12 bytes random) let mut rng = rand::thread_rng(); let transaction_id: [u8; 12] = rng.gen(); request.extend_from_slice(&transaction_id); request } fn parse_stun_response(data: &[u8]) -> Result { // 验证最小长度 (20 bytes header) if data.len() < 20 { anyhow::bail!("Invalid STUN response: too short"); } // 验证 Magic Cookie let magic = u32::from_be_bytes([data[4], data[5], data[6], data[7]]); if magic != STUN_MAGIC_COOKIE { anyhow::bail!("Invalid STUN magic cookie"); } // 解析消息长度 let msg_len = u16::from_be_bytes([data[2], data[3]]) as usize; if data.len() < 20 + msg_len { anyhow::bail!("Invalid STUN response: truncated"); } // 遍历属性查找 XOR-MAPPED-ADDRESS let mut offset = 20; while offset + 4 <= 20 + msg_len { let attr_type = u16::from_be_bytes([data[offset], data[offset + 1]]); let attr_len = u16::from_be_bytes([data[offset + 2], data[offset + 3]]) as usize; if attr_type == ATTR_XOR_MAPPED_ADDRESS { // XOR-MAPPED-ADDRESS 格式: // 1 byte: reserved // 1 byte: family (0x01=IPv4, 0x02=IPv6) // 2 bytes: XOR'd port // 4 bytes (IPv4) or 16 bytes (IPv6): XOR'd address if offset + 4 + attr_len > data.len() { anyhow::bail!("Invalid XOR-MAPPED-ADDRESS attribute"); } let family = data[offset + 5]; // XOR'd port let xor_port = u16::from_be_bytes([data[offset + 6], data[offset + 7]]); let port = xor_port ^ ((STUN_MAGIC_COOKIE >> 16) as u16); match family { 0x01 => { // IPv4 let magic_bytes = STUN_MAGIC_COOKIE.to_be_bytes(); let ip = std::net::Ipv4Addr::new( data[offset + 8] ^ magic_bytes[0], data[offset + 9] ^ magic_bytes[1], data[offset + 10] ^ magic_bytes[2], data[offset + 11] ^ magic_bytes[3], ); return Ok(SocketAddr::new(std::net::IpAddr::V4(ip), port)); } 0x02 => { // IPv6 (需要 Transaction ID 进行 XOR) let magic_bytes = STUN_MAGIC_COOKIE.to_be_bytes(); let transaction_id = &data[8..20]; let mut ip_bytes = [0u8; 16]; // 前 4 bytes 与 magic cookie XOR for i in 0..4 { ip_bytes[i] = data[offset + 8 + i] ^ magic_bytes[i]; } // 后 12 bytes 与 transaction ID XOR for i in 4..16 { ip_bytes[i] = data[offset + 8 + i] ^ transaction_id[i - 4]; } let ip = std::net::Ipv6Addr::from(ip_bytes); return Ok(SocketAddr::new(std::net::IpAddr::V6(ip), port)); } _ => { anyhow::bail!("Unknown address family: {}", family); } } } // 移动到下一个属性 (4字节对齐) offset += 4 + attr_len + (4 - (attr_len % 4)) % 4; } anyhow::bail!("XOR-MAPPED-ADDRESS not found in STUN response") } /// ICE 候选 #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct IceCandidate { pub candidate: String, pub sdp_mid: Option, pub sdp_mline_index: Option, } impl IceCandidate { pub fn new(candidate: String) -> Self { Self { candidate, sdp_mid: None, sdp_mline_index: Some(0), } } }