Skip to content

Commit ed3d66a

Browse files
seanmonstarcarllerche
authored andcommitted
add TcpStream::peek
1 parent 2345956 commit ed3d66a

File tree

6 files changed

+116
-23
lines changed

6 files changed

+116
-23
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Unreleased
22

3+
* Add `TcpStream::peek` function (#773)
34
* Raise minimum Rust version to 1.18.0
45
* `Poll`: retry select() when interrupted by a signal
56

src/net/tcp.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,16 @@ impl TcpStream {
343343
self.sys.take_error()
344344
}
345345

346+
/// Receives data on the socket from the remote address to which it is
347+
/// connected, without removing that data from the queue. On success,
348+
/// returns the number of bytes peeked.
349+
///
350+
/// Successive calls return the same data. This is accomplished by passing
351+
/// `MSG_PEEK` as a flag to the underlying recv system call.
352+
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
353+
self.sys.peek(buf)
354+
}
355+
346356
/// Read in a list of buffers all at once.
347357
///
348358
/// This operation will attempt to read bytes from this socket and place

src/sys/fuchsia/net.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,10 @@ impl TcpStream {
128128
self.io.take_error()
129129
}
130130

131+
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
132+
self.io.peek(buf)
133+
}
134+
131135
pub fn readv(&self, bufs: &mut [&mut IoVec]) -> io::Result<usize> {
132136
unsafe {
133137
let slice = iovec::as_os_slice_mut(bufs);

src/sys/unix/tcp.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,10 @@ impl TcpStream {
126126
self.inner.take_error()
127127
}
128128

129+
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
130+
self.inner.peek(buf)
131+
}
132+
129133
pub fn readv(&self, bufs: &mut [&mut IoVec]) -> io::Result<usize> {
130134
unsafe {
131135
let slice = iovec::as_os_slice_mut(bufs);

src/sys/windows/tcp.rs

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -230,29 +230,7 @@ impl TcpStream {
230230
self.imp.inner()
231231
}
232232

233-
fn post_register(&self, interest: Ready, me: &mut StreamInner) {
234-
if interest.is_readable() {
235-
self.imp.schedule_read(me);
236-
}
237-
238-
// At least with epoll, if a socket is registered with an interest in
239-
// writing and it's immediately writable then a writable event is
240-
// generated immediately, so do so here.
241-
if interest.is_writable() {
242-
if let State::Empty = me.write {
243-
self.imp.add_readiness(me, Ready::writable());
244-
}
245-
}
246-
}
247-
248-
pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
249-
match IoVec::from_bytes_mut(buf) {
250-
Some(vec) => self.readv(&mut [vec]),
251-
None => Ok(0),
252-
}
253-
}
254-
255-
pub fn readv(&self, bufs: &mut [&mut IoVec]) -> io::Result<usize> {
233+
fn before_read(&self) -> io::Result<MutexGuard<StreamInner>> {
256234
let mut me = self.inner();
257235

258236
match me.read {
@@ -280,6 +258,47 @@ impl TcpStream {
280258
State::Ready(()) => {}
281259
}
282260

261+
Ok(me)
262+
}
263+
264+
fn post_register(&self, interest: Ready, me: &mut StreamInner) {
265+
if interest.is_readable() {
266+
self.imp.schedule_read(me);
267+
}
268+
269+
// At least with epoll, if a socket is registered with an interest in
270+
// writing and it's immediately writable then a writable event is
271+
// generated immediately, so do so here.
272+
if interest.is_writable() {
273+
if let State::Empty = me.write {
274+
self.imp.add_readiness(me, Ready::writable());
275+
}
276+
}
277+
}
278+
279+
pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
280+
match IoVec::from_bytes_mut(buf) {
281+
Some(vec) => self.readv(&mut [vec]),
282+
None => Ok(0),
283+
}
284+
}
285+
286+
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
287+
let mut me = self.before_read()?;
288+
289+
match (&self.imp.inner.socket).peek(buf) {
290+
Ok(n) => Ok(n),
291+
Err(e) => {
292+
me.read = State::Empty;
293+
self.imp.schedule_read(&mut me);
294+
Err(e)
295+
}
296+
}
297+
}
298+
299+
pub fn readv(&self, bufs: &mut [&mut IoVec]) -> io::Result<usize> {
300+
let mut me = self.before_read()?;
301+
283302
// TODO: Does WSARecv work on a nonblocking sockets? We ideally want to
284303
// call that instead of looping over all the buffers and calling
285304
// `recv` on each buffer. I'm not sure though if an overlapped

test/test_tcp.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,61 @@ fn read() {
154154
t.join().unwrap();
155155
}
156156

157+
#[test]
158+
fn peek() {
159+
const N: usize = 16 * 1024 * 1024;
160+
struct H { amt: usize, socket: TcpStream, shutdown: bool }
161+
162+
let l = net::TcpListener::bind("127.0.0.1:0").unwrap();
163+
let addr = l.local_addr().unwrap();
164+
165+
let t = thread::spawn(move || {
166+
let mut s = l.accept().unwrap().0;
167+
let b = [0; 1024];
168+
let mut amt = 0;
169+
while amt < N {
170+
amt += s.write(&b).unwrap();
171+
}
172+
});
173+
174+
let poll = Poll::new().unwrap();
175+
let s = TcpStream::connect(&addr).unwrap();
176+
177+
poll.register(&s, Token(1), Ready::readable(), PollOpt::edge()).unwrap();
178+
179+
let mut events = Events::with_capacity(128);
180+
181+
let mut h = H { amt: 0, socket: s, shutdown: false };
182+
while !h.shutdown {
183+
poll.poll(&mut events, None).unwrap();
184+
185+
for event in &events {
186+
assert_eq!(event.token(), Token(1));
187+
let mut b = [0; 1024];
188+
match h.socket.peek(&mut b) {
189+
Ok(_) => (),
190+
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
191+
continue
192+
},
193+
Err(e) => panic!("unexpected error: {:?}", e),
194+
}
195+
196+
loop {
197+
if let Some(amt) = h.socket.try_read(&mut b).unwrap() {
198+
h.amt += amt;
199+
} else {
200+
break
201+
}
202+
if h.amt >= N {
203+
h.shutdown = true;
204+
break
205+
}
206+
}
207+
}
208+
}
209+
t.join().unwrap();
210+
}
211+
157212
#[test]
158213
fn read_bufs() {
159214
const N: usize = 16 * 1024 * 1024;

0 commit comments

Comments
 (0)