// vi:ft=cpp
/* SPDX-License-Identifier: MIT */
/*
 * Copyright (C) 2014-2020, 2023-2024 Kernkonzept GmbH.
 * Author(s): Alexander Warg <alexander.warg@kernkonzept.com>
 *
 */
#pragma once

#include <l4/sys/types.h>
#include <l4/cxx/bitfield>
#include <l4/cxx/minmax>
#include <l4/cxx/utils>

#include <limits.h>
#include <string.h>
#include <stdio.h>

#include "../virtqueue"

/**
 * \ingroup l4virtio_transport
 *
 * L4-VIRTIO Transport C++ API
 */
namespace L4virtio {
namespace Svr {

/**
 * \brief Type of the device status register.
 *
 * Wraps the raw status byte and provides named accessors for its bits.
 */
struct Dev_status
{
  unsigned char raw; ///< Raw value of the VIRTIO device status register.

  Dev_status() = default;

  /// Make Status from raw value; `v` is converted to `unsigned char`.
  explicit Dev_status(l4_uint32_t v) : raw(static_cast<unsigned char>(v)) {}

  // Accessors for single bits of `raw`.
  CXX_BITFIELD_MEMBER(0,  0, acked, raw);
  CXX_BITFIELD_MEMBER(1,  1, driver, raw);
  CXX_BITFIELD_MEMBER(2,  2, driver_ok, raw);
  CXX_BITFIELD_MEMBER(3,  3, features_ok, raw);
  // fail_state covers bits 6..7, i.e. the same bits that device_needs_reset
  // (bit 6) and failed (bit 7) expose individually.
  CXX_BITFIELD_MEMBER(6,  7, fail_state, raw);
  CXX_BITFIELD_MEMBER(6,  6, device_needs_reset, raw);
  CXX_BITFIELD_MEMBER(7,  7, failed, raw);

  /**
   * Check if the device is in running state.
   *
   * \return true if the device is in running state.
   *
   * The device is in running state when acked(), driver(), features_ok(), and
   * driver_ok() return true, and device_needs_reset() and failed() return
   * false. This is the case exactly when `raw` equals 0xf (bits 0..3 set and
   * every other bit clear).
   */
  bool running() const
  { return raw == 0xf; }
};

/**
 * \brief Type for device feature bitmap
 */
struct Dev_features
{
  l4_uint32_t raw;  ///< The raw value of the features bitmap
  Dev_features() = default;

  /// Make Features from a raw bitmap.
  explicit Dev_features(l4_uint32_t v) : raw(v) {}

  CXX_BITFIELD_MEMBER(28, 28, ring_indirect_desc, raw);
  CXX_BITFIELD_MEMBER(29, 29, ring_event_idx, raw);
};


/**
 * Virtqueue implementation for the device
 *
 * This class represents a single virtqueue, with a local running available
 * index.
 *
 * \note The Virtqueue implementation is not thread-safe.
 */
class Virtqueue : public L4virtio::Virtqueue
{
public:
  /**
   * VIRTIO request, essentially a descriptor from the available ring.
   */
  class Head_desc
  {
    friend class Virtqueue;
  private:
    Virtqueue::Desc const *_d;
    Head_desc(Virtqueue *r, unsigned i) : _d(r->desc(i)) {}

  public:
    /// Make invalid (NULL) request.
    Head_desc() : _d(0) {}

    /// \return True if the request is valid (not NULL).
    bool valid() const { return _d; }

    /// \return True if the request is valid (not NULL).
    explicit operator bool () const
    { return valid(); }

    /// \return Pointer to the head descriptor of the request.
    Desc const *desc() const
    { return _d; }
  };

  struct Request : Head_desc
  {
    Virtqueue *ring = nullptr;
    Request() = default;
  private:
    friend class Virtqueue;
    Request(Virtqueue *r, unsigned i) : Head_desc(r, i), ring(r) {}
  };


  /**
   * Get the next available descriptor from the available ring.
   *
   * \pre The queue must be in working state.
   * \return A Request for the next available descriptor, the Request is invalid
   *         if there are no descriptors in the available ring.
   * \note The return value must be checked even when a previous desc_avail()
   *       returned true.
   *
   */
  Request next_avail()
  {
    if (L4_LIKELY(_current_avail != _avail->idx))
      {
        rmb();
        unsigned head = _current_avail & _idx_mask;
        ++_current_avail;
        return Request(this, _avail->ring[head]);
      }
    return Request();
  }

  /**
   * Return unfinished descriptors to the available ring, i.e. reset the local
   * next index of the available ring to the given descriptor.
   *
   * \param  d  descriptor of the request that is to be marked as finished.
   *
   * \pre queue must be in working state.
   *
   * \pre `d` must be a valid request from this queue, obtained via
   *      next_avail(), that has not yet been finished, and in addition, no
   *      descriptors following it have been finished.
   */
  void rewind_avail(Head_desc const &d)
  {
    unsigned head_idx = d._d - _desc;
    // Calculate the distance between _current_avail and head_idx, taking into
    // account that _current_avail might have wrapped around with respect to
    // _idx_mask in the meantime.
    _current_avail -= (_current_avail - head_idx) & _idx_mask;
  }

  /**
   * Test for available descriptors.
   *
   * \return true if there are descriptors available, false if not.
   * \pre The queue must be in working state.
   */
  bool desc_avail() const
  {
    return _current_avail != _avail->idx;
  }

  /**
   * Put the given descriptor into the used ring.
   *
   * \param r    Request that shall be marked as finished.
   * \param len  The total number of bytes written.
   *
   * \pre queue must be in working state.
   *
   * \pre `r` must be a valid request from this queue.
   */
  void consumed(Head_desc const &r, l4_uint32_t len = 0)
  {
    l4_uint16_t i = _used->idx & _idx_mask;
    _used->ring[i] = Used_elem(r._d - _desc, len);
    wmb();
    ++_used->idx;
  }

  /**
   * Put multiple descriptors into the used ring.
   *
   * A range of descriptors, specified by `begin` and `end` iterators is
   * added. Each iterator points to a struct that has a `first` member that
   * is a `Head_desc` and a `second` member that is the corresponding number
   * of bytes written.
   *
   * \tparam ITER   The type of the iterator (inferred).
   * \param  begin  Iterator pointing to first new descriptor.
   * \param  end    Iterator pointing to one past last entry.
   *
   * \pre queue must be in working state.
   */
  template<typename ITER>
  void consumed(ITER const &begin, ITER const &end)
  {
    l4_uint16_t added = 0;
    l4_uint16_t idx = _used->idx;

    for (auto elem = begin ; elem != end; ++elem, ++added)
      _used->ring[(idx + added) & _idx_mask]
        = Used_elem(elem->first._d - _desc, elem->second);

    wmb();
    _used->idx += added;
  }

  /**
   * Add a descriptor to the used ring, and notify an observer.
   *
   * \tparam QUEUE_OBSERVER  The type of the observer (inferred).
   * \param  d               descriptor of the request that is to be marked as
   *                         finished.
   * \param  o               Pointer to the observer that is notified.
   * \param  len             Number of bytes written for this request.
   *
   * \pre queue must be in working state.
   *
   * \pre `d` must be a valid request from this queue.
   */
  template<typename QUEUE_OBSERVER>
  void finish(Head_desc &d, QUEUE_OBSERVER *o, l4_uint32_t len = 0)
  {
    consumed(d, len);
    o->notify_queue(this);
    d._d = 0;
  }

  /**
   * Add a range of descriptors to the used ring, and notify an observer once.
   *
   * The iterators are passed to consumed<ITER>(ITER const &, ITER const &),
   * and the requirements detailed there apply.
   *
   * \tparam ITER            type of the iterator (inferred)
   * \tparam QUEUE_OBSERVER  the type of the observer (inferred).
   * \param  begin           iterator pointing to first element.
   * \param  end             iterator pointing to one past last element.
   * \param  o               pointer to the observer that is notified.
   *
   * \pre queue must be in working state.
   */
  template<typename ITER, typename QUEUE_OBSERVER>
  void finish(ITER const &begin, ITER const &end, QUEUE_OBSERVER *o)
  {
    consumed(begin, end);
    o->notify_queue(this);
  }

  /**
   * Set the 'no notify' flag for this queue.
   *
   * This function may be called on a disabled queue.
   */
  void disable_notify()
  {
    if (L4_LIKELY(ready()))
      _used->flags.no_notify() = 1;
  }

  /**
   * Clear the 'no notify' flag for this queue.
   *
   * This function may be called on a disabled queue.
   */
  void enable_notify()
  {
    if (L4_LIKELY(ready()))
      _used->flags.no_notify() = 0;
  }

  /**
   * Get a descriptor from the descriptor list.
   *
   * \param idx  The index of the descriptor.
   *
   * \pre `idx` < `num`
   * \pre queue must be in working state
   */
  Desc const *desc(unsigned idx) const
  { return _desc + idx; }

};

/**
 * \brief Abstract data buffer.
 *
 * Describes a memory region by a current position and the number of bytes
 * remaining from that position.
 */
struct Data_buffer
{
  char *pos;         ///< Current buffer position
  l4_uint32_t left;  ///< Bytes left in buffer

  Data_buffer() = default;

  /**
   * \brief Create buffer for object `p`.
   *
   * \tparam T  Type of object (implicit)
   * \param  p  Pointer to object.
   *
   * The buffer shall point to the start of the object `p` and the size left
   * is sizeof(T).
   */
  template<typename T>
  explicit Data_buffer(T *p)
  { set(p); }

  /**
   * Set buffer for object `p`.
   *
   * \tparam T  Type of object (implicit)
   * \param  p  Pointer to object.
   *
   * The buffer shall point to the start of the object `p` and the size left
   * is sizeof(T).
   */
  template<typename T>
  void set(T *p)
  {
    left = sizeof(T);
    pos = reinterpret_cast<char *>(p);
  }

  /**
   * Copy contents from this buffer to the destination buffer.
   *
   * \param dst  Destination buffer.
   * \param max  (optional) Maximum number of bytes to copy.
   * \return the number of bytes copied.
   *
   * This function copies at most `max` bytes from this to `dst`. If
   * `max` is omitted, copies the maximum number of bytes available
   * that fit `dst`. Both buffers are advanced by the number of bytes copied.
   */
  l4_uint32_t copy_to(Data_buffer *dst, l4_uint32_t max = UINT_MAX)
  {
    // Limited by what is left here, what fits into `dst`, and by `max`.
    l4_uint32_t const fits = cxx::min(left, dst->left);
    l4_uint32_t const n = cxx::min(fits, max);

    memcpy(dst->pos, pos, n);

    pos += n;
    left -= n;
    dst->pos += n;
    dst->left -= n;
    return n;
  }

  /**
   * Skip given number of bytes in this buffer.
   *
   * \param bytes  Number of bytes that shall be skipped.
   * \return The number of bytes skipped.
   *
   * Try to skip the given number of bytes in this buffer. If fewer bytes are
   * left in the buffer than requested, only the bytes left are skipped and
   * that amount is returned.
   */
  l4_uint32_t skip(l4_uint32_t bytes)
  {
    l4_uint32_t const n = cxx::min(left, bytes);
    pos += n;
    left -= n;
    return n;
  }

  /**
   * Check if there are no more bytes left in the buffer.
   *
   * \return true if there are no more bytes left in the buffer.
   */
  bool done() const
  { return !left; }
};

class Request_processor;

/**
 * Exception used by Queue to indicate descriptor errors.
 *
 * Carries the request processor that detected the problem and an error code
 * that can be turned into a text with message().
 */
struct Bad_descriptor
{
  /// The error code
  enum Error
  {
    Bad_address, ///< Address cannot be translated
    Bad_rights,  ///< Missing access rights on memory
    Bad_flags,   ///< Invalid combination of descriptor flags
    Bad_next,    ///< Invalid next index
    Bad_size     ///< Invalid size of memory block
  };

  /// The processor that triggered the exception
  Request_processor const *proc;

  /// The error code
  Error error;

  /**
   * Make a bad descriptor exception.
   *
   * \param proc  The request processor causing the exception
   * \param e     The error code.
   */
  Bad_descriptor(Request_processor const *proc, Error e)
  : proc(proc), error(e)
  {}

  /**
   * Get a human readable description of the error code.
   *
   * \return  Message describing the error, or "Unknown error" if `error`
   *          holds a value that is not one of the enumerators of Error.
   */
  char const *message() const
  {
    // A switch is used instead of an array with C99 index designators, as
    // the latter are not part of ISO C++.
    switch (error)
      {
      case Bad_address:
        return "Descriptor address cannot be translated";
      case Bad_rights:
        return "Insufficient memory access rights";
      case Bad_flags:
        return "Invalid descriptor flags";
      case Bad_next:
        return "The descriptor's `next` index is invalid";
      case Bad_size:
        return "Invalid size of the memory block";
      }

    // Reached only for values outside the enumerator list.
    return "Unknown error";
  }
};


/**
 * Encapsulate the state for processing a VIRTIO request.
 *
 * A VIRTIO request is a possibly chained list of descriptors retrieved from
 * the available ring of a virtqueue, using Virtqueue::next_avail().
 *
 * The processor always works on a private copy of the descriptor it is
 * currently looking at (taken with cxx::access_once()).
 *
 * The descriptor processing depends on helper (DESC_MAN) for interpreting the
 * descriptors in the context of the device implementation.
 *
 * DESC_MAN has to provide the functionality to safely dereference a
 * descriptor from a descriptor list.
 *
 * The following methods must be provided by DESC_MAN:
 *   * \code DESC_MAN::load_desc(Virtqueue::Desc const &desc,
 *                      Request_processor const *proc,
 *                      Virtqueue::Desc const **table) \endcode
 *     This function is used to dereference `desc` as an indirect descriptor
 *     table, and must return a pointer to an indirect descriptor table.
 *   * \code DESC_MAN::load_desc(Virtqueue::Desc const &desc,
 *                      Request_processor const *proc, ...) \endcode
 *     This function is used to dereference a descriptor as a normal data
 *     buffer, and '...' are the arguments that are passed to start() and next().
 */
class Request_processor
{
private:
  /// pointer to descriptor table (may point to an indirect table)
  Virtqueue::Desc const *_table;

  /// currently processed descriptor
  Virtqueue::Desc _current;

  /// number of entries in the current descriptor table (_table)
  l4_uint16_t _num;

public:
  /**
   * Start processing a new request.
   *
   * If the head descriptor is flagged as indirect, the indirect table is
   * obtained from `dm` and processing continues with the first entry of that
   * table. Otherwise the descriptor table of `ring` is used.
   *
   * \tparam DESC_MAN   Type of descriptor manager (implicit).
   * \param  dm         Descriptor manager that is used to translate VIRTIO
   *                    descriptor addresses.
   * \param  ring       VIRTIO ring of the request.
   * \param  request    VIRTIO request from Virtqueue::next_avail()
   * \param  args       Extra arguments passed to dm->load_desc()
   *
   * \pre The given request must be valid.
   *
   * \throws Bad_descriptor  The descriptor has an invalid size or load_desc()
   *                         has thrown an exception by itself.
   */
  template<typename DESC_MAN, typename ...ARGS>
  void start(DESC_MAN *dm, Virtqueue *ring, Virtqueue::Head_desc const &request, ARGS... args)
  {
    // Work on a private copy of the head descriptor from here on.
    _current = cxx::access_once(request.desc());

    if (!_current.flags.indirect())
      {
        // Plain request: the chain lives in the descriptor table of the ring.
        _table = ring->desc(0);
        _num = ring->num();
      }
    else
      {
        // Let the descriptor manager resolve the indirect table.
        dm->load_desc(_current, this, &_table);

        // Number of complete descriptors covered by the length of the
        // indirect descriptor.
        // NOTE(review): the quotient is stored in the 16-bit _num; presumably
        // `len` is wider, so very large tables would be truncated -- confirm.
        _num = _current.len / sizeof(Virtqueue::Desc);
        if (L4_UNLIKELY(_num == 0))
          throw Bad_descriptor(this, Bad_descriptor::Bad_size);

        // Continue with a copy of the first entry of the indirect table.
        _current = cxx::access_once(_table);
      }

    dm->load_desc(_current, this, cxx::forward<ARGS>(args)...);
  }

  /**
   * Start processing a new request.
   *
   * Convenience overload that takes the ring from `request.ring`.
   *
   * \tparam DESC_MAN  Type of descriptor manager (implicit).
   * \param dm         Descriptor manager that is used to translate VIRTIO
   *                   descriptor addresses.
   * \param request    VIRTIO request from Virtqueue::next_avail()
   * \param args       Extra arguments passed to dm->load_desc()
   * \return The `request` that was passed in.
   * \pre The given request must be valid.
   */
  template<typename DESC_MAN, typename ...ARGS>
  Virtqueue::Request const &start(DESC_MAN *dm, Virtqueue::Request const &request, ARGS... args)
  {
    Virtqueue *const queue = request.ring;
    start(dm, queue, request, cxx::forward<ARGS>(args)...);
    return request;
  }

  /**
   * Get the flags of the currently processed descriptor.
   *
   * \return The flags of the currently processed descriptor.
   */
  Virtqueue::Desc::Flags current_flags() const
  { return _current.flags; }

  /**
   * Are there more chained descriptors?
   *
   * \return true if there are more chained descriptors in the current request.
   */
  bool has_more() const
  { return _current.flags.next(); }

  /**
   * Switch to the next descriptor in a descriptor chain.
   *
   * \tparam DESC_MAN   Type of descriptor manager (implicit).
   * \param  dm         Descriptor manager that is used to translate VIRTIO
   *                    descriptor addresses.
   * \param  args       Extra arguments passed to dm->load_desc()
   *
   * \retval true   A next descriptor is available.
   * \retval false  No descriptor available.
   *
   * \throws Bad_descriptor  The `next` index of this descriptor is invalid.
   */
  template<typename DESC_MAN, typename ...ARGS>
  bool next(DESC_MAN *dm, ARGS... args)
  {
    // End of the chain reached.
    if (!has_more())
      return false;

    // The link must stay within the current descriptor table.
    auto const link = _current.next;
    if (L4_UNLIKELY(link >= _num))
      throw Bad_descriptor(this, Bad_descriptor::Bad_next);

    _current = cxx::access_once(_table + link);

    // Deliberately disabled for performance reasons: rejecting an indirect
    // descriptor in the middle of a chain.
    if (0)
      if (L4_UNLIKELY(_current.flags.indirect()))
        throw Bad_descriptor(this, Bad_descriptor::Bad_flags);

    // must throw an exception in case of a bad descriptor
    dm->load_desc(_current, this, cxx::forward<ARGS>(args)...);
    return true;
  }
};

}
}
