482 lines
16 KiB
Rust
482 lines
16 KiB
Rust
//! P2P 连接模块
|
||
|
||
use anyhow::Result;
|
||
use bytes::Bytes;
|
||
use easyremote_common::protocol::{FrameData, FrameFormat, InputEvent};
|
||
use std::net::SocketAddr;
|
||
use std::sync::Arc;
|
||
use tokio::net::UdpSocket;
|
||
use tokio::sync::{mpsc, RwLock};
|
||
|
||
/// Lifecycle state of a P2P connection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionState {
    /// No active peer.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The UDP socket is connected to a remote peer.
    Connected,
    /// The last connection attempt failed; carries a human-readable reason.
    Failed(String),
}
|
||
|
||
/// P2P 连接管理器
|
||
pub struct P2PConnection {
|
||
state: Arc<RwLock<ConnectionState>>,
|
||
local_addr: Option<SocketAddr>,
|
||
remote_addr: Option<SocketAddr>,
|
||
socket: Option<Arc<UdpSocket>>,
|
||
frame_tx: Option<mpsc::Sender<FrameData>>,
|
||
input_tx: Option<mpsc::Sender<InputEvent>>,
|
||
}
|
||
|
||
impl P2PConnection {
|
||
pub fn new() -> Self {
|
||
Self {
|
||
state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
|
||
local_addr: None,
|
||
remote_addr: None,
|
||
socket: None,
|
||
frame_tx: None,
|
||
input_tx: None,
|
||
}
|
||
}
|
||
|
||
/// 获取连接状态
|
||
pub async fn state(&self) -> ConnectionState {
|
||
self.state.read().await.clone()
|
||
}
|
||
|
||
/// 初始化本地套接字
|
||
pub async fn init_socket(&mut self) -> Result<SocketAddr> {
|
||
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||
let local_addr = socket.local_addr()?;
|
||
self.local_addr = Some(local_addr);
|
||
self.socket = Some(Arc::new(socket));
|
||
Ok(local_addr)
|
||
}
|
||
|
||
/// 尝试连接到远程地址
|
||
pub async fn connect(&mut self, remote_addr: SocketAddr) -> Result<()> {
|
||
*self.state.write().await = ConnectionState::Connecting;
|
||
|
||
if let Some(socket) = &self.socket {
|
||
socket.connect(remote_addr).await?;
|
||
self.remote_addr = Some(remote_addr);
|
||
*self.state.write().await = ConnectionState::Connected;
|
||
} else {
|
||
*self.state.write().await = ConnectionState::Failed("Socket not initialized".into());
|
||
anyhow::bail!("Socket not initialized");
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 发送帧数据
|
||
pub async fn send_frame(&self, frame: &FrameData) -> Result<()> {
|
||
if let Some(socket) = &self.socket {
|
||
let data = bincode::serialize(frame)?;
|
||
// 如果数据太大,需要分片发送
|
||
if data.len() > 65000 {
|
||
self.send_fragmented(socket, &data).await?;
|
||
} else {
|
||
socket.send(&data).await?;
|
||
}
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
/// 发送输入事件
|
||
pub async fn send_input(&self, event: &InputEvent) -> Result<()> {
|
||
if let Some(socket) = &self.socket {
|
||
let data = bincode::serialize(event)?;
|
||
socket.send(&data).await?;
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
/// 分片发送大数据
|
||
async fn send_fragmented(&self, socket: &UdpSocket, data: &[u8]) -> Result<()> {
|
||
const MAX_FRAGMENT_SIZE: usize = 60000;
|
||
let total_fragments = (data.len() + MAX_FRAGMENT_SIZE - 1) / MAX_FRAGMENT_SIZE;
|
||
|
||
for (i, chunk) in data.chunks(MAX_FRAGMENT_SIZE).enumerate() {
|
||
let fragment = FragmentHeader {
|
||
fragment_id: i as u16,
|
||
total_fragments: total_fragments as u16,
|
||
data: chunk.to_vec(),
|
||
};
|
||
let fragment_data = bincode::serialize(&fragment)?;
|
||
socket.send(&fragment_data).await?;
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 开始接收数据
|
||
pub async fn start_receiving(
|
||
&self,
|
||
on_frame: impl Fn(FrameData) + Send + 'static,
|
||
on_input: impl Fn(InputEvent) + Send + 'static,
|
||
) -> Result<()> {
|
||
let socket = self.socket.clone().ok_or_else(|| anyhow::anyhow!("Socket not initialized"))?;
|
||
let state = self.state.clone();
|
||
|
||
tokio::spawn(async move {
|
||
let mut buf = vec![0u8; 100000];
|
||
let mut fragment_buffer: Vec<Option<Vec<u8>>> = Vec::new();
|
||
let mut expected_fragments = 0u16;
|
||
|
||
loop {
|
||
if *state.read().await != ConnectionState::Connected {
|
||
break;
|
||
}
|
||
|
||
match socket.recv(&mut buf).await {
|
||
Ok(len) => {
|
||
let data = &buf[..len];
|
||
|
||
// 尝试解析为分片
|
||
if let Ok(fragment) = bincode::deserialize::<FragmentHeader>(data) {
|
||
if fragment.total_fragments > 1 {
|
||
// 处理分片
|
||
if fragment_buffer.is_empty() {
|
||
fragment_buffer = vec![None; fragment.total_fragments as usize];
|
||
expected_fragments = fragment.total_fragments;
|
||
}
|
||
|
||
if fragment.fragment_id < expected_fragments {
|
||
fragment_buffer[fragment.fragment_id as usize] =
|
||
Some(fragment.data);
|
||
}
|
||
|
||
// 检查是否收到所有分片
|
||
if fragment_buffer.iter().all(|f| f.is_some()) {
|
||
let complete_data: Vec<u8> = fragment_buffer
|
||
.iter()
|
||
.filter_map(|f| f.clone())
|
||
.flatten()
|
||
.collect();
|
||
|
||
if let Ok(frame) = bincode::deserialize::<FrameData>(&complete_data) {
|
||
on_frame(frame);
|
||
}
|
||
|
||
fragment_buffer.clear();
|
||
}
|
||
continue;
|
||
}
|
||
}
|
||
|
||
// 尝试解析为帧数据
|
||
if let Ok(frame) = bincode::deserialize::<FrameData>(data) {
|
||
on_frame(frame);
|
||
continue;
|
||
}
|
||
|
||
// 尝试解析为输入事件
|
||
if let Ok(input) = bincode::deserialize::<InputEvent>(data) {
|
||
on_input(input);
|
||
}
|
||
}
|
||
Err(e) => {
|
||
tracing::warn!("Receive error: {}", e);
|
||
}
|
||
}
|
||
}
|
||
});
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 断开连接
|
||
pub async fn disconnect(&mut self) {
|
||
*self.state.write().await = ConnectionState::Disconnected;
|
||
self.socket = None;
|
||
self.remote_addr = None;
|
||
}
|
||
}
|
||
|
||
/// 分片头
|
||
#[derive(Debug, serde::Serialize, serde::Deserialize)]
|
||
struct FragmentHeader {
|
||
fragment_id: u16,
|
||
total_fragments: u16,
|
||
data: Vec<u8>,
|
||
}
|
||
|
||
/// ICE 服务器配置
|
||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||
pub struct IceServersConfig {
|
||
pub stun_servers: Vec<String>,
|
||
pub turn_server: Option<TurnConfig>,
|
||
}
|
||
|
||
/// TURN 服务器配置
|
||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||
pub struct TurnConfig {
|
||
pub url: String,
|
||
pub username: String,
|
||
pub credential: String,
|
||
}
|
||
|
||
/// Stateless namespace for NAT-traversal helpers.
///
/// Covers local IP discovery, STUN queries, and ICE candidate collection.
#[derive(Debug, Clone, Copy, Default)]
pub struct NatTraversal;
|
||
|
||
impl NatTraversal {
|
||
/// 获取所有本地 IP 地址
|
||
pub fn get_local_ips() -> Vec<String> {
|
||
let mut ips = Vec::new();
|
||
|
||
// 尝试使用 socket 连接来获取本地 IP
|
||
if let Ok(socket) = std::net::UdpSocket::bind("0.0.0.0:0") {
|
||
// 连接到公网地址来获取本机实际使用的 IP
|
||
if let Ok(()) = socket.connect("8.8.8.8:80") {
|
||
if let Ok(addr) = socket.local_addr() {
|
||
let ip = addr.ip().to_string();
|
||
if !ip.starts_with("127.") && ip != "0.0.0.0" {
|
||
ips.push(ip);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 如果没有获取到,添加 localhost
|
||
if ips.is_empty() {
|
||
ips.push("127.0.0.1".to_string());
|
||
}
|
||
|
||
ips
|
||
}
|
||
|
||
/// 从服务器获取 ICE 服务器配置
|
||
pub async fn fetch_ice_servers(server_url: &str) -> Result<IceServersConfig> {
|
||
// 将 ws:// 或 wss:// 转换为 http:// 或 https://
|
||
let http_url = server_url
|
||
.replace("ws://", "http://")
|
||
.replace("wss://", "https://");
|
||
|
||
let url = format!("{}/api/ice-servers", http_url.trim_end_matches('/'));
|
||
|
||
let response = reqwest::get(&url).await?;
|
||
let config: IceServersConfig = response.json().await?;
|
||
|
||
Ok(config)
|
||
}
|
||
|
||
/// 获取本地候选地址
|
||
pub async fn get_local_candidates() -> Result<Vec<String>> {
|
||
let mut candidates = Vec::new();
|
||
|
||
// 获取所有本地网络接口
|
||
if let Ok(socket) = UdpSocket::bind("0.0.0.0:0").await {
|
||
if let Ok(addr) = socket.local_addr() {
|
||
candidates.push(format!("host {}", addr));
|
||
}
|
||
}
|
||
|
||
Ok(candidates)
|
||
}
|
||
|
||
/// 获取完整的 ICE 候选(包括本地和公网地址)
|
||
pub async fn get_all_candidates(stun_servers: &[String]) -> Result<Vec<IceCandidate>> {
|
||
let mut candidates = Vec::new();
|
||
|
||
// 添加本地地址候选
|
||
let local_candidates = Self::get_local_candidates().await?;
|
||
for candidate in local_candidates {
|
||
candidates.push(IceCandidate::new(candidate));
|
||
}
|
||
|
||
// 使用 STUN 服务器获取公网地址
|
||
for stun_url in stun_servers {
|
||
match Self::get_public_addr_from_url(stun_url).await {
|
||
Ok(addr) => {
|
||
candidates.push(IceCandidate::new(format!("srflx {}", addr)));
|
||
tracing::info!("Got public address from STUN {}: {}", stun_url, addr);
|
||
break; // 成功获取一个公网地址即可
|
||
}
|
||
Err(e) => {
|
||
tracing::warn!("Failed to get public address from {}: {}", stun_url, e);
|
||
}
|
||
}
|
||
}
|
||
|
||
Ok(candidates)
|
||
}
|
||
|
||
/// 从 STUN URL 解析并获取公网地址
|
||
pub async fn get_public_addr_from_url(stun_url: &str) -> Result<SocketAddr> {
|
||
// 解析 STUN URL,格式: stun:host:port 或 stun:host
|
||
let url = stun_url.trim_start_matches("stun:");
|
||
let parts: Vec<&str> = url.split(':').collect();
|
||
|
||
let (host, port) = match parts.len() {
|
||
1 => (parts[0], 3478u16), // 默认 STUN 端口
|
||
2 => (parts[0], parts[1].parse().unwrap_or(3478)),
|
||
_ => anyhow::bail!("Invalid STUN URL format: {}", stun_url),
|
||
};
|
||
|
||
Self::get_public_addr(host, port).await
|
||
}
|
||
|
||
/// 使用 STUN 服务器获取公网地址
|
||
pub async fn get_public_addr(stun_host: &str, stun_port: u16) -> Result<SocketAddr> {
|
||
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||
|
||
// 解析 STUN 服务器地址
|
||
let stun_addr = tokio::net::lookup_host(format!("{}:{}", stun_host, stun_port))
|
||
.await?
|
||
.next()
|
||
.ok_or_else(|| anyhow::anyhow!("Cannot resolve STUN server: {}", stun_host))?;
|
||
|
||
// 设置超时
|
||
socket.connect(stun_addr).await?;
|
||
|
||
// 发送 STUN Binding 请求
|
||
let binding_request = create_stun_binding_request();
|
||
socket.send(&binding_request).await?;
|
||
|
||
// 接收响应(带超时)
|
||
let mut buf = [0u8; 1024];
|
||
let recv_result = tokio::time::timeout(
|
||
std::time::Duration::from_secs(5),
|
||
socket.recv(&mut buf)
|
||
).await;
|
||
|
||
match recv_result {
|
||
Ok(Ok(len)) => parse_stun_response(&buf[..len]),
|
||
Ok(Err(e)) => anyhow::bail!("STUN receive error: {}", e),
|
||
Err(_) => anyhow::bail!("STUN request timeout"),
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Fixed value carried in bytes 4..8 of every RFC 5389 STUN message header.
const STUN_MAGIC_COOKIE: u32 = 0x2112_A442;

/// Attribute type code of XOR-MAPPED-ADDRESS (RFC 5389 §15.2).
const ATTR_XOR_MAPPED_ADDRESS: u16 = 0x0020;
|
||
|
||
fn create_stun_binding_request() -> Vec<u8> {
|
||
use rand::Rng;
|
||
|
||
let mut request = Vec::with_capacity(20);
|
||
|
||
// Message Type: Binding Request (0x0001)
|
||
request.extend_from_slice(&0x0001u16.to_be_bytes());
|
||
// Message Length: 0 (no attributes)
|
||
request.extend_from_slice(&0x0000u16.to_be_bytes());
|
||
// Magic Cookie (RFC 5389)
|
||
request.extend_from_slice(&STUN_MAGIC_COOKIE.to_be_bytes());
|
||
// Transaction ID (12 bytes random)
|
||
let mut rng = rand::thread_rng();
|
||
let transaction_id: [u8; 12] = rng.gen();
|
||
request.extend_from_slice(&transaction_id);
|
||
|
||
request
|
||
}
|
||
|
||
fn parse_stun_response(data: &[u8]) -> Result<SocketAddr> {
|
||
// 验证最小长度 (20 bytes header)
|
||
if data.len() < 20 {
|
||
anyhow::bail!("Invalid STUN response: too short");
|
||
}
|
||
|
||
// 验证 Magic Cookie
|
||
let magic = u32::from_be_bytes([data[4], data[5], data[6], data[7]]);
|
||
if magic != STUN_MAGIC_COOKIE {
|
||
anyhow::bail!("Invalid STUN magic cookie");
|
||
}
|
||
|
||
// 解析消息长度
|
||
let msg_len = u16::from_be_bytes([data[2], data[3]]) as usize;
|
||
if data.len() < 20 + msg_len {
|
||
anyhow::bail!("Invalid STUN response: truncated");
|
||
}
|
||
|
||
// 遍历属性查找 XOR-MAPPED-ADDRESS
|
||
let mut offset = 20;
|
||
while offset + 4 <= 20 + msg_len {
|
||
let attr_type = u16::from_be_bytes([data[offset], data[offset + 1]]);
|
||
let attr_len = u16::from_be_bytes([data[offset + 2], data[offset + 3]]) as usize;
|
||
|
||
if attr_type == ATTR_XOR_MAPPED_ADDRESS {
|
||
// XOR-MAPPED-ADDRESS 格式:
|
||
// 1 byte: reserved
|
||
// 1 byte: family (0x01=IPv4, 0x02=IPv6)
|
||
// 2 bytes: XOR'd port
|
||
// 4 bytes (IPv4) or 16 bytes (IPv6): XOR'd address
|
||
|
||
if offset + 4 + attr_len > data.len() {
|
||
anyhow::bail!("Invalid XOR-MAPPED-ADDRESS attribute");
|
||
}
|
||
|
||
let family = data[offset + 5];
|
||
|
||
// XOR'd port
|
||
let xor_port = u16::from_be_bytes([data[offset + 6], data[offset + 7]]);
|
||
let port = xor_port ^ ((STUN_MAGIC_COOKIE >> 16) as u16);
|
||
|
||
match family {
|
||
0x01 => {
|
||
// IPv4
|
||
let magic_bytes = STUN_MAGIC_COOKIE.to_be_bytes();
|
||
let ip = std::net::Ipv4Addr::new(
|
||
data[offset + 8] ^ magic_bytes[0],
|
||
data[offset + 9] ^ magic_bytes[1],
|
||
data[offset + 10] ^ magic_bytes[2],
|
||
data[offset + 11] ^ magic_bytes[3],
|
||
);
|
||
return Ok(SocketAddr::new(std::net::IpAddr::V4(ip), port));
|
||
}
|
||
0x02 => {
|
||
// IPv6 (需要 Transaction ID 进行 XOR)
|
||
let magic_bytes = STUN_MAGIC_COOKIE.to_be_bytes();
|
||
let transaction_id = &data[8..20];
|
||
let mut ip_bytes = [0u8; 16];
|
||
|
||
// 前 4 bytes 与 magic cookie XOR
|
||
for i in 0..4 {
|
||
ip_bytes[i] = data[offset + 8 + i] ^ magic_bytes[i];
|
||
}
|
||
// 后 12 bytes 与 transaction ID XOR
|
||
for i in 4..16 {
|
||
ip_bytes[i] = data[offset + 8 + i] ^ transaction_id[i - 4];
|
||
}
|
||
|
||
let ip = std::net::Ipv6Addr::from(ip_bytes);
|
||
return Ok(SocketAddr::new(std::net::IpAddr::V6(ip), port));
|
||
}
|
||
_ => {
|
||
anyhow::bail!("Unknown address family: {}", family);
|
||
}
|
||
}
|
||
}
|
||
|
||
// 移动到下一个属性 (4字节对齐)
|
||
offset += 4 + attr_len + (4 - (attr_len % 4)) % 4;
|
||
}
|
||
|
||
anyhow::bail!("XOR-MAPPED-ADDRESS not found in STUN response")
|
||
}
|
||
|
||
/// ICE 候选
|
||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||
pub struct IceCandidate {
|
||
pub candidate: String,
|
||
pub sdp_mid: Option<String>,
|
||
pub sdp_mline_index: Option<u32>,
|
||
}
|
||
|
||
impl IceCandidate {
|
||
pub fn new(candidate: String) -> Self {
|
||
Self {
|
||
candidate,
|
||
sdp_mid: None,
|
||
sdp_mline_index: Some(0),
|
||
}
|
||
}
|
||
}
|