use std::time::Duration;

use crate::peer_connection::with_timeout;
use anyhow::Context;
use buffers::ByteBuf;
use peer_binary_protocol::{
    Handshake, MessageBorrowed, MessageDeserializeError, PIECE_MESSAGE_DEFAULT_LEN,
};
use tokio::io::AsyncReadExt;

/// Reusable read buffer for a single peer connection.
///
/// Bytes are read from the connection into `buf` and deserialized in place, so the
/// messages returned by [`ReadBuf::read_message`] borrow directly from this buffer.
///
/// Intended invariant: `processed <= filled <= buf.len()`.
pub struct ReadBuf {
    /// Backing storage for bytes read from the connection.
    buf: Vec<u8>,
    /// Number of bytes at the front of `buf` already consumed by successful deserialization.
    processed: usize,
    /// Number of bytes of `buf` holding data read from the connection; new reads go past this.
    filled: usize,
}

impl ReadBuf {
    pub fn new() -> Self {
        Self {
            buf: vec![0; PIECE_MESSAGE_DEFAULT_LEN * 2],
            filled: 0,
            processed: 0,
        }
    }

    fn prepare_for_read(&mut self, need_additional_bytes: usize) {
        // Ensure the buffer starts from the to-be-deserialized message.
        if self.processed > 0 {
            if self.filled > self.processed {
                self.buf.copy_within(self.processed..self.filled, 0);
            }
            self.filled -= self.processed;
            self.processed = 0;
        }

        // Ensure we have enough capacity to deserialize the message.
        if self.buf.len() < self.filled + need_additional_bytes {
            self.buf.reserve(need_additional_bytes);
            self.buf.resize(self.buf.capacity(), 0);
        }
    }

    // Read the BT handshake.
    // This MUST be run as the first operation on the buffer.
    pub async fn read_handshake(
        &mut self,
        mut conn: impl AsyncReadExt + Unpin,
        timeout: Duration,
    ) -> anyhow::Result<Handshake<ByteBuf<'_>>> {
        self.filled = with_timeout(timeout, conn.read(&mut self.buf))
            .await
            .context("error reading handshake")?;
        if self.filled == 0 {
            anyhow::bail!("peer disconnected while reading handshake");
        }
        let (h, size) = Handshake::deserialize(&self.buf[..self.filled]).map_err(|e| {
            anyhow::anyhow!(
                "error deserializing handshake: {:?} hadshake data {:?}",
                e,
                &self.buf[..self.filled.min(19)]
            )
        })?;
        self.processed = size;
        Ok(h)
    }

    // Read a message into the buffer, try to deserialize it and call the callback on it.
    pub async fn read_message(
        &mut self,
        mut conn: impl AsyncReadExt + Unpin,
        timeout: Duration,
    ) -> anyhow::Result<MessageBorrowed<'_>> {
        loop {
            let need_additional_bytes =
                match MessageBorrowed::deserialize(&self.buf[self.processed..self.filled]) {
                    Err(MessageDeserializeError::NotEnoughData(d, _)) => d,
                    Ok((msg, size)) => {
                        self.processed += size;

                        // Rust's borrow checker can't do this early return so resort to unsafe.
                        // This erases the lifetime so that it's happy.
                        let msg: MessageBorrowed<'_> =
                            unsafe { std::mem::transmute(msg as MessageBorrowed<'_>) };
                        return Ok(msg);
                    }
                    Err(e) => return Err(e.into()),
                };
            self.prepare_for_read(need_additional_bytes);
            debug_assert!(!self.buf[self.filled..].is_empty());
            let size = with_timeout(timeout, conn.read(&mut self.buf[self.filled..]))
                .await
                .context("error reading from peer")?;
            if size == 0 {
                anyhow::bail!("disconnected while reading, read so far: {}", self.filled)
            }
            self.filled += size;
        }
    }
}
