optimize stn_buf
This commit is contained in:
parent
5315f24cdf
commit
47adcc28b5
|
@ -2,7 +2,7 @@ use super::*;
|
|||
use crate::misc::{build_socket_listener, socketaddr_to_string};
|
||||
use log::*;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use stn_buf::Buf;
|
||||
use stn_buf::VecBuf;
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
net::{TcpListener, TcpStream, UdpSocket},
|
||||
|
@ -95,7 +95,7 @@ impl In {
|
|||
mut client: TcpStream,
|
||||
saddr: String,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let mut buf = Buf::new(TCP_LEN);
|
||||
let mut buf = Vec::with_capacity(TCP_LEN);
|
||||
|
||||
// +----+----------+----------+
|
||||
// |VER | NMETHODS | METHODS |
|
||||
|
@ -110,34 +110,32 @@ impl In {
|
|||
// o X'FF' NO ACCEPTABLE METHODS
|
||||
|
||||
// recv
|
||||
let nread = timeout(self.tcp_timeout, client.read(unsafe { buf.get_unused() })).await??;
|
||||
buf.add_len(nread);
|
||||
let nread = timeout(self.tcp_timeout, client.read(unsafe { buf.remain_mut() })).await??;
|
||||
unsafe { buf.add_len(nread) }
|
||||
// check length
|
||||
if buf.len() < 2 || buf.len() < 2 + buf.get_used()[1] as usize {
|
||||
if buf.len() < 2 || buf.len() < 2 + buf[1] as usize {
|
||||
Err(format!(
|
||||
"{} {} buf.len() < 2 || buf.len() < 2 + buf.get_used()[1] as usize",
|
||||
"{} {} buf.len() < 2 || buf.len() < 2 + buf[1] as usize",
|
||||
self.tag, saddr
|
||||
))?
|
||||
}
|
||||
// check version
|
||||
if buf.get_used()[0] != 5 {
|
||||
if buf[0] != 5 {
|
||||
Err(format!(
|
||||
"{} {} unsupport socks version:{}",
|
||||
self.tag,
|
||||
saddr,
|
||||
buf.get_used()[0]
|
||||
self.tag, saddr, buf[0]
|
||||
))?
|
||||
}
|
||||
// check methods
|
||||
if !&buf.get_used()[2..2 + buf.get_used()[1] as usize].contains(&0) {
|
||||
if !&buf[2..2 + buf[1] as usize].contains(&0) {
|
||||
Err(format!(
|
||||
"{} {} unsupport methods:{:?}",
|
||||
self.tag,
|
||||
saddr,
|
||||
&buf.get_used()[2..2 + buf.get_used()[1] as usize]
|
||||
&buf[2..2 + buf[1] as usize]
|
||||
))?
|
||||
}
|
||||
let header_len = 2 + buf.get_used()[1] as usize;
|
||||
let header_len = 2 + buf[1] as usize;
|
||||
buf.drain(..header_len);
|
||||
|
||||
// +----+--------+
|
||||
|
@ -177,35 +175,28 @@ impl In {
|
|||
// recv
|
||||
while buf.len() < 4
|
||||
|| buf.len()
|
||||
< 4 + match buf.get_used()[3] {
|
||||
< 4 + match buf[3] {
|
||||
ATYP_IPV4 => 4,
|
||||
ATYP_DOMAIN => {
|
||||
if buf.len() < 5 {
|
||||
1
|
||||
} else {
|
||||
1 + buf.get_used()[4] as usize
|
||||
1 + buf[4] as usize
|
||||
}
|
||||
}
|
||||
ATYP_IPV6 => 16,
|
||||
_ => Err(format!(
|
||||
"{} {} unsupport ATYP:{}",
|
||||
self.tag,
|
||||
saddr,
|
||||
buf.get_used()[3]
|
||||
))?,
|
||||
_ => Err(format!("{} {} unsupport ATYP:{}", self.tag, saddr, buf[3]))?,
|
||||
} + 2
|
||||
{
|
||||
let nread =
|
||||
timeout(self.tcp_timeout, client.read(unsafe { buf.get_unused() })).await??;
|
||||
buf.add_len(nread);
|
||||
timeout(self.tcp_timeout, client.read(unsafe { buf.remain_mut() })).await??;
|
||||
unsafe { buf.add_len(nread) }
|
||||
}
|
||||
// check version
|
||||
if buf.get_used()[0] != 5 {
|
||||
if buf[0] != 5 {
|
||||
Err(format!(
|
||||
"{} {} unsupport socks version:{}",
|
||||
self.tag,
|
||||
saddr,
|
||||
buf.get_used()[0]
|
||||
self.tag, saddr, buf[0]
|
||||
))?
|
||||
}
|
||||
|
||||
|
@ -242,19 +233,14 @@ impl In {
|
|||
.await??;
|
||||
|
||||
// read CMD
|
||||
match buf.get_used()[1] {
|
||||
match buf[1] {
|
||||
CMD_CONNECT => {
|
||||
if let Err(e) = self.clone().handle_tcp(client, saddr.clone(), buf).await {
|
||||
warn!("{} {} {}", self.tag, saddr, e);
|
||||
}
|
||||
}
|
||||
CMD_UDP_ASSOCIATE => self.handle_udp(client).await,
|
||||
_ => Err(format!(
|
||||
"{} {} unsupport CMD:{}",
|
||||
self.tag,
|
||||
saddr,
|
||||
buf.get_used()[1]
|
||||
))?,
|
||||
_ => Err(format!("{} {} unsupport CMD:{}", self.tag, saddr, buf[1]))?,
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use super::*;
|
||||
use log::*;
|
||||
use std::sync::Arc;
|
||||
use stn_buf::Buf;
|
||||
use stn_buf::VecBuf;
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
net::TcpStream,
|
||||
|
@ -13,7 +13,7 @@ impl super::In {
|
|||
self: Arc<Self>,
|
||||
client: TcpStream,
|
||||
saddr: String,
|
||||
mut buf: Buf,
|
||||
mut buf: Vec<u8>,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// +----+-----+-------+------+----------+----------+
|
||||
// |VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT |
|
||||
|
@ -35,7 +35,7 @@ impl super::In {
|
|||
// order
|
||||
|
||||
// get daddr
|
||||
let (daddr, daddr_len) = get_daddr(&buf.get_used()[3..])?;
|
||||
let (daddr, daddr_len) = get_daddr(&buf[3..])?;
|
||||
buf.drain(..4 + daddr_len + 2);
|
||||
|
||||
// connect
|
||||
|
@ -52,19 +52,16 @@ impl super::In {
|
|||
// write server, buf.len() may not 0, so write first
|
||||
if buf.len() != 0 {
|
||||
debug!("{} {} -> {} {}", self.tag, saddr, daddr, buf.len());
|
||||
server_tx
|
||||
.send(buf.get_used().to_vec())
|
||||
.await
|
||||
.or(Err("close"))?;
|
||||
server_tx.send(buf.to_vec()).await.or(Err("close"))?;
|
||||
buf.drain(..);
|
||||
}
|
||||
|
||||
// read client
|
||||
let nread = client_rx.read(unsafe { buf.get_unused() }).await?;
|
||||
let nread = client_rx.read(unsafe { buf.remain_mut() }).await?;
|
||||
if nread == 0 {
|
||||
Err("close")?
|
||||
}
|
||||
buf.add_len(nread);
|
||||
unsafe { buf.add_len(nread) }
|
||||
},
|
||||
{
|
||||
// read server
|
||||
|
|
|
@ -1,89 +0,0 @@
|
|||
use std::{
|
||||
ops::{Bound, RangeBounds},
|
||||
ptr,
|
||||
};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Buf {
|
||||
buf: Box<[u8]>,
|
||||
len: usize,
|
||||
}
|
||||
|
||||
impl Buf {
|
||||
pub fn new(capacity: usize) -> Self {
|
||||
Self {
|
||||
buf: vec![0u8; capacity].into_boxed_slice(),
|
||||
len: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.len
|
||||
}
|
||||
|
||||
pub fn add_len(&mut self, len: usize) {
|
||||
self.len += len;
|
||||
}
|
||||
|
||||
pub fn get_used(&self) -> &[u8] {
|
||||
&self.buf[..self.len]
|
||||
}
|
||||
|
||||
pub fn get_used_mut(&mut self) -> &mut [u8] {
|
||||
&mut self.buf[..self.len]
|
||||
}
|
||||
|
||||
/// write only
|
||||
pub unsafe fn get_unused(&mut self) -> &mut [u8] {
|
||||
&mut self.buf[self.len..]
|
||||
}
|
||||
|
||||
pub fn drain<R>(&mut self, range: R)
|
||||
where
|
||||
R: RangeBounds<usize>,
|
||||
{
|
||||
let start = match range.start_bound() {
|
||||
Bound::Unbounded => 0,
|
||||
Bound::Included(&n) => n,
|
||||
Bound::Excluded(&n) => n.saturating_add(1),
|
||||
};
|
||||
let end = match range.end_bound() {
|
||||
Bound::Unbounded => self.len,
|
||||
Bound::Included(&n) => n.saturating_add(1),
|
||||
Bound::Excluded(&n) => n,
|
||||
};
|
||||
|
||||
assert!(start <= end, "start({}) <= end({})", start, end);
|
||||
assert!(end <= self.len, "end({}) <= self.len({})", end, self.len);
|
||||
|
||||
unsafe {
|
||||
ptr::copy(
|
||||
self.buf.as_ptr().offset(end as _),
|
||||
self.buf.as_mut_ptr().offset(start as _),
|
||||
self.len - end,
|
||||
);
|
||||
}
|
||||
self.len -= end - start;
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn t1() {
|
||||
use std::io::Write;
|
||||
|
||||
let mut buf = Buf::new(10);
|
||||
|
||||
let nwrite = unsafe { buf.get_unused() }.write(&[1, 1, 1]).unwrap();
|
||||
buf.add_len(nwrite);
|
||||
assert!(buf.get_used() == &[1, 1, 1]);
|
||||
|
||||
buf.drain(1..3);
|
||||
assert!(buf.get_used() == &[1]);
|
||||
|
||||
let nwrite = unsafe { buf.get_unused() }.write(&[1, 1, 1]).unwrap();
|
||||
buf.add_len(nwrite);
|
||||
assert!(buf.get_used() == &[1]);
|
||||
|
||||
buf.drain(..);
|
||||
assert!(buf.get_used() == &[]);
|
||||
}
|
|
@ -1,3 +1,19 @@
|
|||
mod buf;
|
||||
use std::slice;
|
||||
|
||||
pub use buf::*;
|
||||
pub trait VecBuf {
|
||||
unsafe fn remain_mut(&mut self) -> &mut [u8];
|
||||
unsafe fn add_len(&mut self, len: usize);
|
||||
}
|
||||
|
||||
impl VecBuf for Vec<u8> {
|
||||
unsafe fn remain_mut(&mut self) -> &mut [u8] {
|
||||
slice::from_raw_parts_mut(
|
||||
self.as_mut_ptr().add(self.len()) as *mut u8,
|
||||
self.capacity() - self.len(),
|
||||
)
|
||||
}
|
||||
|
||||
unsafe fn add_len(&mut self, len: usize) {
|
||||
self.set_len(self.len() + len)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue