easyremote/crates/client-core/src/connection.rs

482 lines
16 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! P2P 连接模块
use anyhow::Result;
use bytes::Bytes;
use easyremote_common::protocol::{FrameData, FrameFormat, InputEvent};
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::UdpSocket;
use tokio::sync::{mpsc, RwLock};
/// Lifecycle state of a [`P2PConnection`].
#[derive(Debug, Clone, PartialEq)]
pub enum ConnectionState {
    /// No link: the state of a freshly created manager and after `disconnect`.
    Disconnected,
    /// `connect` has started but not yet finished.
    Connecting,
    /// The socket is connected to the remote peer.
    Connected,
    /// The connection attempt failed; the payload describes why.
    Failed(String),
}
/// P2P 连接管理器
pub struct P2PConnection {
state: Arc<RwLock<ConnectionState>>,
local_addr: Option<SocketAddr>,
remote_addr: Option<SocketAddr>,
socket: Option<Arc<UdpSocket>>,
frame_tx: Option<mpsc::Sender<FrameData>>,
input_tx: Option<mpsc::Sender<InputEvent>>,
}
impl P2PConnection {
pub fn new() -> Self {
Self {
state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
local_addr: None,
remote_addr: None,
socket: None,
frame_tx: None,
input_tx: None,
}
}
/// 获取连接状态
pub async fn state(&self) -> ConnectionState {
self.state.read().await.clone()
}
/// 初始化本地套接字
pub async fn init_socket(&mut self) -> Result<SocketAddr> {
let socket = UdpSocket::bind("0.0.0.0:0").await?;
let local_addr = socket.local_addr()?;
self.local_addr = Some(local_addr);
self.socket = Some(Arc::new(socket));
Ok(local_addr)
}
/// 尝试连接到远程地址
pub async fn connect(&mut self, remote_addr: SocketAddr) -> Result<()> {
*self.state.write().await = ConnectionState::Connecting;
if let Some(socket) = &self.socket {
socket.connect(remote_addr).await?;
self.remote_addr = Some(remote_addr);
*self.state.write().await = ConnectionState::Connected;
} else {
*self.state.write().await = ConnectionState::Failed("Socket not initialized".into());
anyhow::bail!("Socket not initialized");
}
Ok(())
}
/// 发送帧数据
pub async fn send_frame(&self, frame: &FrameData) -> Result<()> {
if let Some(socket) = &self.socket {
let data = bincode::serialize(frame)?;
// 如果数据太大,需要分片发送
if data.len() > 65000 {
self.send_fragmented(socket, &data).await?;
} else {
socket.send(&data).await?;
}
}
Ok(())
}
/// 发送输入事件
pub async fn send_input(&self, event: &InputEvent) -> Result<()> {
if let Some(socket) = &self.socket {
let data = bincode::serialize(event)?;
socket.send(&data).await?;
}
Ok(())
}
/// 分片发送大数据
async fn send_fragmented(&self, socket: &UdpSocket, data: &[u8]) -> Result<()> {
const MAX_FRAGMENT_SIZE: usize = 60000;
let total_fragments = (data.len() + MAX_FRAGMENT_SIZE - 1) / MAX_FRAGMENT_SIZE;
for (i, chunk) in data.chunks(MAX_FRAGMENT_SIZE).enumerate() {
let fragment = FragmentHeader {
fragment_id: i as u16,
total_fragments: total_fragments as u16,
data: chunk.to_vec(),
};
let fragment_data = bincode::serialize(&fragment)?;
socket.send(&fragment_data).await?;
}
Ok(())
}
/// 开始接收数据
pub async fn start_receiving(
&self,
on_frame: impl Fn(FrameData) + Send + 'static,
on_input: impl Fn(InputEvent) + Send + 'static,
) -> Result<()> {
let socket = self.socket.clone().ok_or_else(|| anyhow::anyhow!("Socket not initialized"))?;
let state = self.state.clone();
tokio::spawn(async move {
let mut buf = vec![0u8; 100000];
let mut fragment_buffer: Vec<Option<Vec<u8>>> = Vec::new();
let mut expected_fragments = 0u16;
loop {
if *state.read().await != ConnectionState::Connected {
break;
}
match socket.recv(&mut buf).await {
Ok(len) => {
let data = &buf[..len];
// 尝试解析为分片
if let Ok(fragment) = bincode::deserialize::<FragmentHeader>(data) {
if fragment.total_fragments > 1 {
// 处理分片
if fragment_buffer.is_empty() {
fragment_buffer = vec![None; fragment.total_fragments as usize];
expected_fragments = fragment.total_fragments;
}
if fragment.fragment_id < expected_fragments {
fragment_buffer[fragment.fragment_id as usize] =
Some(fragment.data);
}
// 检查是否收到所有分片
if fragment_buffer.iter().all(|f| f.is_some()) {
let complete_data: Vec<u8> = fragment_buffer
.iter()
.filter_map(|f| f.clone())
.flatten()
.collect();
if let Ok(frame) = bincode::deserialize::<FrameData>(&complete_data) {
on_frame(frame);
}
fragment_buffer.clear();
}
continue;
}
}
// 尝试解析为帧数据
if let Ok(frame) = bincode::deserialize::<FrameData>(data) {
on_frame(frame);
continue;
}
// 尝试解析为输入事件
if let Ok(input) = bincode::deserialize::<InputEvent>(data) {
on_input(input);
}
}
Err(e) => {
tracing::warn!("Receive error: {}", e);
}
}
}
});
Ok(())
}
/// 断开连接
pub async fn disconnect(&mut self) {
*self.state.write().await = ConnectionState::Disconnected;
self.socket = None;
self.remote_addr = None;
}
}
/// Wire format of one chunk of an oversized, bincode-serialized frame.
///
/// `P2PConnection::send_fragmented` emits one of these per chunk, and the
/// receive loop collects chunks by `fragment_id` until `total_fragments`
/// have arrived. Field order is part of the bincode wire format.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct FragmentHeader {
    /// Zero-based position of this chunk within the payload.
    fragment_id: u16,
    /// Number of chunks the payload was split into.
    total_fragments: u16,
    /// Raw bytes of this chunk.
    data: Vec<u8>,
}
/// ICE server configuration, as returned by the server's `/api/ice-servers`
/// endpoint (see `NatTraversal::fetch_ice_servers`).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct IceServersConfig {
    /// STUN server URLs of the form `stun:host[:port]`.
    pub stun_servers: Vec<String>,
    /// Optional TURN relay, if the server provides one.
    pub turn_server: Option<TurnConfig>,
}
/// TURN relay server configuration.
///
/// `Debug` is implemented by hand so the credential never ends up in logs
/// or panic messages (the derived impl printed it verbatim).
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct TurnConfig {
    /// TURN server URL.
    pub url: String,
    /// User name for TURN authentication.
    pub username: String,
    /// Secret for `username`; redacted from `Debug` output.
    pub credential: String,
}

impl std::fmt::Debug for TurnConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TurnConfig")
            .field("url", &self.url)
            .field("username", &self.username)
            .field("credential", &format_args!("<redacted>"))
            .finish()
    }
}
/// NAT traversal helpers: local address discovery, ICE server configuration
/// retrieval, and STUN queries for the public (server-reflexive) address.
pub struct NatTraversal;

impl NatTraversal {
    /// Port used when a STUN URL omits one (RFC 5389 §9).
    const DEFAULT_STUN_PORT: u16 = 3478;
    /// Waits between STUN request transmissions. The request is (re)sent before
    /// each wait, per RFC 5389 §7.2.1; the waits add up to 5 s in total.
    const STUN_RETRANSMIT_WAITS_MS: [u64; 4] = [500, 1000, 1500, 2000];

    /// Returns the local IPv4 address the OS would use to reach the internet.
    ///
    /// `connect` on a UDP socket only selects a route; no packet is sent to
    /// 8.8.8.8. Falls back to `["127.0.0.1"]` when that yields no usable
    /// (non-loopback, non-unspecified) address.
    pub fn get_local_ips() -> Vec<String> {
        let routed_ip = std::net::UdpSocket::bind("0.0.0.0:0")
            .and_then(|socket| {
                socket.connect("8.8.8.8:80")?;
                socket.local_addr()
            })
            .ok()
            .map(|addr| addr.ip())
            .filter(|ip| !ip.is_loopback() && !ip.is_unspecified());
        match routed_ip {
            Some(ip) => vec![ip.to_string()],
            None => vec!["127.0.0.1".to_string()],
        }
    }

    /// Fetches the ICE server list from `<server>/api/ice-servers`.
    ///
    /// `server_url` may be a WebSocket URL. A leading `ws://` / `wss://` is
    /// mapped to `http://` / `https://`; any other URL is used unchanged.
    ///
    /// # Errors
    /// Fails on network errors, a non-success HTTP status, or a body that does
    /// not decode as [`IceServersConfig`].
    pub async fn fetch_ice_servers(server_url: &str) -> Result<IceServersConfig> {
        let http_url = if let Some(rest) = server_url.strip_prefix("wss://") {
            format!("https://{}", rest)
        } else if let Some(rest) = server_url.strip_prefix("ws://") {
            format!("http://{}", rest)
        } else {
            server_url.to_string()
        };
        let url = format!("{}/api/ice-servers", http_url.trim_end_matches('/'));
        let config = reqwest::get(&url)
            .await?
            // Surface 4xx/5xx as such instead of as a confusing JSON decode error.
            .error_for_status()?
            .json::<IceServersConfig>()
            .await?;
        Ok(config)
    }

    /// Returns local "host" candidates formatted as `"host <ip>:<port>"`.
    ///
    /// NOTE(review): this binds a throwaway socket on `0.0.0.0`. The reported
    /// address is therefore the unspecified IPv4 address, and the port belongs
    /// to a socket dropped on return, so peers cannot reach it. It probably
    /// should combine [`Self::get_local_ips`] with the connection socket's
    /// port. Confirm against the signalling design.
    pub async fn get_local_candidates() -> Result<Vec<String>> {
        let mut candidates = Vec::new();
        if let Ok(socket) = UdpSocket::bind("0.0.0.0:0").await {
            if let Ok(addr) = socket.local_addr() {
                candidates.push(format!("host {}", addr));
            }
        }
        Ok(candidates)
    }

    /// Collects local host candidates plus at most one server-reflexive
    /// (`"srflx <addr>"`) candidate.
    ///
    /// The candidate comes from the first STUN server in `stun_servers` that
    /// answers. Failing STUN servers are logged and skipped.
    pub async fn get_all_candidates(stun_servers: &[String]) -> Result<Vec<IceCandidate>> {
        let mut candidates: Vec<IceCandidate> = Self::get_local_candidates()
            .await?
            .into_iter()
            .map(IceCandidate::new)
            .collect();
        for stun_url in stun_servers {
            match Self::get_public_addr_from_url(stun_url).await {
                Ok(addr) => {
                    candidates.push(IceCandidate::new(format!("srflx {}", addr)));
                    tracing::info!("Got public address from STUN {}: {}", stun_url, addr);
                    break; // One public address is enough.
                }
                Err(e) => {
                    tracing::warn!("Failed to get public address from {}: {}", stun_url, e);
                }
            }
        }
        Ok(candidates)
    }

    /// Parses a `stun:host[:port]` URL and queries that server for our public address.
    ///
    /// # Errors
    /// Fails on a malformed URL (including a non-numeric port) or on any
    /// error from [`Self::get_public_addr`].
    pub async fn get_public_addr_from_url(stun_url: &str) -> Result<SocketAddr> {
        let (host, port) = Self::parse_stun_url(stun_url)?;
        Self::get_public_addr(host, port).await
    }

    /// Splits a `stun:host[:port]` URL (RFC 7064) into host and port.
    ///
    /// The `stun:` scheme is optional, and IPv6 literals must be bracketed
    /// (`stun:[::1]:3478`). A missing port defaults to `DEFAULT_STUN_PORT`. An
    /// invalid port is an error rather than being silently replaced.
    fn parse_stun_url(stun_url: &str) -> Result<(&str, u16)> {
        let invalid = || anyhow::anyhow!("Invalid STUN URL format: {}", stun_url);
        let rest = stun_url.strip_prefix("stun:").unwrap_or(stun_url);
        let (host, port) = match rest.strip_prefix('[') {
            // Bracketed IPv6 literal: "[addr]" or "[addr]:port".
            Some(bracketed) => {
                let (host, tail) = bracketed.split_once(']').ok_or_else(invalid)?;
                if tail.is_empty() {
                    (host, None)
                } else {
                    (host, Some(tail.strip_prefix(':').ok_or_else(invalid)?))
                }
            }
            None => match rest.split_once(':') {
                Some((host, port)) if !port.contains(':') => (host, Some(port)),
                Some(_) => return Err(invalid()),
                None => (rest, None),
            },
        };
        if host.is_empty() {
            return Err(invalid());
        }
        let port = match port {
            Some(port) => port.parse::<u16>().map_err(|_| invalid())?,
            None => Self::DEFAULT_STUN_PORT,
        };
        Ok((host, port))
    }

    /// Sends a STUN Binding request to `stun_host:stun_port` from a fresh UDP
    /// socket and returns the mapped address from the response.
    ///
    /// An IPv4 address of the server is preferred, and the local socket is
    /// bound in the same address family as the chosen server address. The
    /// request is retransmitted on silence, and the response must echo the
    /// request's transaction ID.
    ///
    /// NOTE(review): the query uses its own socket, so on most NATs the
    /// returned port is not the mapping used by the P2P connection socket.
    ///
    /// # Errors
    /// Fails if resolution, bind, connect, send or receive fails, on a
    /// transaction-ID mismatch, on a malformed response, or after ~5 s
    /// without an answer.
    pub async fn get_public_addr(stun_host: &str, stun_port: u16) -> Result<SocketAddr> {
        let resolved: Vec<SocketAddr> = tokio::net::lookup_host((stun_host, stun_port))
            .await?
            .collect();
        let stun_addr = resolved
            .iter()
            .find(|addr| addr.is_ipv4())
            .or_else(|| resolved.first())
            .copied()
            .ok_or_else(|| anyhow::anyhow!("Cannot resolve STUN server: {}", stun_host))?;
        // An IPv4-bound socket cannot connect to an IPv6 address, and vice versa.
        let bind_addr = if stun_addr.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
        let socket = UdpSocket::bind(bind_addr).await?;
        socket.connect(stun_addr).await?;

        let binding_request = create_stun_binding_request();
        let mut buf = [0u8; 1024];
        for wait_ms in Self::STUN_RETRANSMIT_WAITS_MS.iter().copied() {
            socket.send(&binding_request).await?;
            let wait = std::time::Duration::from_millis(wait_ms);
            match tokio::time::timeout(wait, socket.recv(&mut buf)).await {
                Ok(Ok(len)) => {
                    let response = &buf[..len];
                    // Bytes 8..20 of the header hold the transaction ID (RFC 5389 §6).
                    if response.len() >= 20
                        && response.get(8..20) != binding_request.get(8..20)
                    {
                        anyhow::bail!("STUN transaction ID mismatch");
                    }
                    return parse_stun_response(response);
                }
                Ok(Err(e)) => anyhow::bail!("STUN receive error: {}", e),
                // No answer within this window: resend and wait longer.
                Err(_) => {}
            }
        }
        anyhow::bail!("STUN request timeout")
    }
}
/// Fixed magic cookie in bytes 4..8 of every RFC 5389 STUN message header.
const STUN_MAGIC_COOKIE: u32 = 0x2112_A442;
/// Attribute type code of XOR-MAPPED-ADDRESS (RFC 5389 §15.2).
const ATTR_XOR_MAPPED_ADDRESS: u16 = 0x0020;
/// Builds a 20-byte STUN Binding Request with no attributes.
///
/// Layout: message type `0x0001`, message length `0`, the magic cookie, and a
/// random 12-byte transaction ID.
fn create_stun_binding_request() -> Vec<u8> {
    use rand::Rng;

    let transaction_id: [u8; 12] = rand::thread_rng().gen();
    [
        &0x0001u16.to_be_bytes()[..],          // message type: Binding Request
        &0x0000u16.to_be_bytes()[..],          // message length: no attributes
        &STUN_MAGIC_COOKIE.to_be_bytes()[..],  // magic cookie
        &transaction_id[..],                   // transaction ID
    ]
    .concat()
}
/// Extracts the server-reflexive address from a STUN Binding response.
///
/// The function walks the attributes of the message body and decodes the
/// first XOR-MAPPED-ADDRESS (RFC 5389 §15.2):
/// - the port is XOR-ed with the top 16 bits of the magic cookie;
/// - an IPv4 address is XOR-ed with the cookie;
/// - an IPv6 address is XOR-ed with the cookie followed by the 12-byte
///   transaction ID.
///
/// The input comes straight off the network, so every field is
/// bounds-checked. Malformed packets produce an error instead of an
/// index-out-of-bounds panic.
///
/// # Errors
/// Fails if:
/// - the message is shorter than the 20-byte header;
/// - it has the wrong magic cookie;
/// - it is shorter than its declared length;
/// - the XOR-MAPPED-ADDRESS is truncated or too short for its family;
/// - the address family is unknown;
/// - the attribute is absent.
fn parse_stun_response(data: &[u8]) -> Result<SocketAddr> {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    const HEADER_LEN: usize = 20;

    if data.len() < HEADER_LEN {
        anyhow::bail!("Invalid STUN response: too short");
    }
    let magic = u32::from_be_bytes([data[4], data[5], data[6], data[7]]);
    if magic != STUN_MAGIC_COOKIE {
        anyhow::bail!("Invalid STUN magic cookie");
    }
    let msg_len = u16::from_be_bytes([data[2], data[3]]) as usize;
    if data.len() < HEADER_LEN + msg_len {
        anyhow::bail!("Invalid STUN response: truncated");
    }
    let transaction_id = &data[8..HEADER_LEN];
    // Only the declared message body is parsed; any trailing bytes are ignored.
    let body = &data[HEADER_LEN..HEADER_LEN + msg_len];
    let invalid_attr = || anyhow::anyhow!("Invalid XOR-MAPPED-ADDRESS attribute");

    let mut offset = 0;
    while offset + 4 <= body.len() {
        let attr_type = u16::from_be_bytes([body[offset], body[offset + 1]]);
        let attr_len = u16::from_be_bytes([body[offset + 2], body[offset + 3]]) as usize;
        let value_start = offset + 4;

        if attr_type == ATTR_XOR_MAPPED_ADDRESS {
            // Value: reserved (1 byte), family (1), X-Port (2), X-Address (4 or 16).
            let value = body
                .get(value_start..value_start + attr_len)
                .ok_or_else(invalid_attr)?;
            if value.len() < 4 {
                return Err(invalid_attr());
            }
            let family = value[1];
            let port =
                u16::from_be_bytes([value[2], value[3]]) ^ (STUN_MAGIC_COOKIE >> 16) as u16;
            let magic_bytes = STUN_MAGIC_COOKIE.to_be_bytes();
            let ip = match family {
                0x01 => {
                    let x_addr = value.get(4..8).ok_or_else(invalid_attr)?;
                    let mut octets = [0u8; 4];
                    for ((out, byte), key) in octets.iter_mut().zip(x_addr).zip(&magic_bytes) {
                        *out = byte ^ key;
                    }
                    IpAddr::V4(Ipv4Addr::from(octets))
                }
                0x02 => {
                    let x_addr = value.get(4..20).ok_or_else(invalid_attr)?;
                    // XOR key: magic cookie followed by the transaction ID.
                    let keystream = magic_bytes.iter().chain(transaction_id);
                    let mut octets = [0u8; 16];
                    for ((out, byte), key) in octets.iter_mut().zip(x_addr).zip(keystream) {
                        *out = byte ^ key;
                    }
                    IpAddr::V6(Ipv6Addr::from(octets))
                }
                _ => anyhow::bail!("Unknown address family: {}", family),
            };
            return Ok(SocketAddr::new(ip, port));
        }

        // Skip this attribute's value plus padding to the next 4-byte boundary.
        offset = value_start + attr_len + (4 - attr_len % 4) % 4;
    }
    anyhow::bail!("XOR-MAPPED-ADDRESS not found in STUN response")
}
/// A single ICE candidate plus optional SDP media identifiers.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct IceCandidate {
    /// Candidate text; this module produces forms such as `host <addr>` and `srflx <addr>`.
    pub candidate: String,
    /// SDP media stream identification tag; left unset by [`IceCandidate::new`].
    pub sdp_mid: Option<String>,
    /// Index of the SDP media line; [`IceCandidate::new`] uses `Some(0)`.
    pub sdp_mline_index: Option<u32>,
}

impl IceCandidate {
    /// Wraps `candidate` with no `sdp_mid` and media-line index 0.
    pub fn new(candidate: String) -> Self {
        let sdp_mline_index = Some(0);
        Self {
            sdp_mid: None,
            sdp_mline_index,
            candidate,
        }
    }
}