optimize stn_buf

This commit is contained in:
fh0 2021-07-21 11:34:45 +08:00
parent 5315f24cdf
commit 47adcc28b5
4 changed files with 44 additions and 134 deletions

View File

@ -2,7 +2,7 @@ use super::*;
use crate::misc::{build_socket_listener, socketaddr_to_string};
use log::*;
use std::{sync::Arc, time::Duration};
use stn_buf::Buf;
use stn_buf::VecBuf;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream, UdpSocket},
@ -95,7 +95,7 @@ impl In {
mut client: TcpStream,
saddr: String,
) -> Result<(), Box<dyn std::error::Error>> {
let mut buf = Buf::new(TCP_LEN);
let mut buf = Vec::with_capacity(TCP_LEN);
// +----+----------+----------+
// |VER | NMETHODS | METHODS |
@ -110,34 +110,32 @@ impl In {
// o X'FF' NO ACCEPTABLE METHODS
// recv
let nread = timeout(self.tcp_timeout, client.read(unsafe { buf.get_unused() })).await??;
buf.add_len(nread);
let nread = timeout(self.tcp_timeout, client.read(unsafe { buf.remain_mut() })).await??;
unsafe { buf.add_len(nread) }
// check length
if buf.len() < 2 || buf.len() < 2 + buf.get_used()[1] as usize {
if buf.len() < 2 || buf.len() < 2 + buf[1] as usize {
Err(format!(
"{} {} buf.len() < 2 || buf.len() < 2 + buf.get_used()[1] as usize",
"{} {} buf.len() < 2 || buf.len() < 2 + buf[1] as usize",
self.tag, saddr
))?
}
// check version
if buf.get_used()[0] != 5 {
if buf[0] != 5 {
Err(format!(
"{} {} unsupport socks version:{}",
self.tag,
saddr,
buf.get_used()[0]
self.tag, saddr, buf[0]
))?
}
// check methods
if !&buf.get_used()[2..2 + buf.get_used()[1] as usize].contains(&0) {
if !&buf[2..2 + buf[1] as usize].contains(&0) {
Err(format!(
"{} {} unsupport methods:{:?}",
self.tag,
saddr,
&buf.get_used()[2..2 + buf.get_used()[1] as usize]
&buf[2..2 + buf[1] as usize]
))?
}
let header_len = 2 + buf.get_used()[1] as usize;
let header_len = 2 + buf[1] as usize;
buf.drain(..header_len);
// +----+--------+
@ -177,35 +175,28 @@ impl In {
// recv
while buf.len() < 4
|| buf.len()
< 4 + match buf.get_used()[3] {
< 4 + match buf[3] {
ATYP_IPV4 => 4,
ATYP_DOMAIN => {
if buf.len() < 5 {
1
} else {
1 + buf.get_used()[4] as usize
1 + buf[4] as usize
}
}
ATYP_IPV6 => 16,
_ => Err(format!(
"{} {} unsupport ATYP:{}",
self.tag,
saddr,
buf.get_used()[3]
))?,
_ => Err(format!("{} {} unsupport ATYP:{}", self.tag, saddr, buf[3]))?,
} + 2
{
let nread =
timeout(self.tcp_timeout, client.read(unsafe { buf.get_unused() })).await??;
buf.add_len(nread);
timeout(self.tcp_timeout, client.read(unsafe { buf.remain_mut() })).await??;
unsafe { buf.add_len(nread) }
}
// check version
if buf.get_used()[0] != 5 {
if buf[0] != 5 {
Err(format!(
"{} {} unsupport socks version:{}",
self.tag,
saddr,
buf.get_used()[0]
self.tag, saddr, buf[0]
))?
}
@ -242,19 +233,14 @@ impl In {
.await??;
// read CMD
match buf.get_used()[1] {
match buf[1] {
CMD_CONNECT => {
if let Err(e) = self.clone().handle_tcp(client, saddr.clone(), buf).await {
warn!("{} {} {}", self.tag, saddr, e);
}
}
CMD_UDP_ASSOCIATE => self.handle_udp(client).await,
_ => Err(format!(
"{} {} unsupport CMD:{}",
self.tag,
saddr,
buf.get_used()[1]
))?,
_ => Err(format!("{} {} unsupport CMD:{}", self.tag, saddr, buf[1]))?,
}
Ok(())

View File

@ -1,7 +1,7 @@
use super::*;
use log::*;
use std::sync::Arc;
use stn_buf::Buf;
use stn_buf::VecBuf;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
@ -13,7 +13,7 @@ impl super::In {
self: Arc<Self>,
client: TcpStream,
saddr: String,
mut buf: Buf,
mut buf: Vec<u8>,
) -> Result<(), Box<dyn std::error::Error>> {
// +----+-----+-------+------+----------+----------+
// |VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT |
@ -35,7 +35,7 @@ impl super::In {
// order
// get daddr
let (daddr, daddr_len) = get_daddr(&buf.get_used()[3..])?;
let (daddr, daddr_len) = get_daddr(&buf[3..])?;
buf.drain(..4 + daddr_len + 2);
// connect
@ -52,19 +52,16 @@ impl super::In {
// write server, buf.len() may not 0, so write first
if buf.len() != 0 {
debug!("{} {} -> {} {}", self.tag, saddr, daddr, buf.len());
server_tx
.send(buf.get_used().to_vec())
.await
.or(Err("close"))?;
server_tx.send(buf.to_vec()).await.or(Err("close"))?;
buf.drain(..);
}
// read client
let nread = client_rx.read(unsafe { buf.get_unused() }).await?;
let nread = client_rx.read(unsafe { buf.remain_mut() }).await?;
if nread == 0 {
Err("close")?
}
buf.add_len(nread);
unsafe { buf.add_len(nread) }
},
{
// read server

View File

@ -1,89 +0,0 @@
use std::{
ops::{Bound, RangeBounds},
ptr,
};
#[derive(Clone, Debug)]
pub struct Buf {
buf: Box<[u8]>,
len: usize,
}
impl Buf {
pub fn new(capacity: usize) -> Self {
Self {
buf: vec![0u8; capacity].into_boxed_slice(),
len: 0,
}
}
pub fn len(&self) -> usize {
self.len
}
pub fn add_len(&mut self, len: usize) {
self.len += len;
}
pub fn get_used(&self) -> &[u8] {
&self.buf[..self.len]
}
pub fn get_used_mut(&mut self) -> &mut [u8] {
&mut self.buf[..self.len]
}
/// write only
pub unsafe fn get_unused(&mut self) -> &mut [u8] {
&mut self.buf[self.len..]
}
pub fn drain<R>(&mut self, range: R)
where
R: RangeBounds<usize>,
{
let start = match range.start_bound() {
Bound::Unbounded => 0,
Bound::Included(&n) => n,
Bound::Excluded(&n) => n.saturating_add(1),
};
let end = match range.end_bound() {
Bound::Unbounded => self.len,
Bound::Included(&n) => n.saturating_add(1),
Bound::Excluded(&n) => n,
};
assert!(start <= end, "start({}) <= end({})", start, end);
assert!(end <= self.len, "end({}) <= self.len({})", end, self.len);
unsafe {
ptr::copy(
self.buf.as_ptr().offset(end as _),
self.buf.as_mut_ptr().offset(start as _),
self.len - end,
);
}
self.len -= end - start;
}
}
#[test]
fn t1() {
use std::io::Write;
let mut buf = Buf::new(10);
let nwrite = unsafe { buf.get_unused() }.write(&[1, 1, 1]).unwrap();
buf.add_len(nwrite);
assert!(buf.get_used() == &[1, 1, 1]);
buf.drain(1..3);
assert!(buf.get_used() == &[1]);
let nwrite = unsafe { buf.get_unused() }.write(&[1, 1, 1]).unwrap();
buf.add_len(nwrite);
assert!(buf.get_used() == &[1]);
buf.drain(..);
assert!(buf.get_used() == &[]);
}

View File

@ -1,3 +1,19 @@
mod buf;
use std::slice;
pub use buf::*;
pub trait VecBuf {
unsafe fn remain_mut(&mut self) -> &mut [u8];
unsafe fn add_len(&mut self, len: usize);
}
impl VecBuf for Vec<u8> {
unsafe fn remain_mut(&mut self) -> &mut [u8] {
slice::from_raw_parts_mut(
self.as_mut_ptr().add(self.len()) as *mut u8,
self.capacity() - self.len(),
)
}
unsafe fn add_len(&mut self, len: usize) {
self.set_len(self.len() + len)
}
}