// vi:ft=cpp
/* SPDX-License-Identifier: MIT */
/*
 * Copyright (C) 2024 Kernkonzept GmbH.
 * Author(s): Christian Pötzsch <christian.poetzsch@kernkonzept.com>
 *
 * License: see LICENSE.spdx (in this directory or the directories above)
 */
#pragma once

#include <l4/cxx/bitmap>
#include <l4/l4virtio/l4virtio>
#include <l4/l4virtio/server/l4virtio>
#include <l4/l4virtio/server/virtio>
#include <l4/re/util/object_registry>

#include <map>
#include <vector>

namespace L4virtio { namespace Svr { namespace Scmi {

/// SCMI protocol version, reported by the protocol handlers in this file in
/// their `Protocol_version` replies.
/// NOTE(review): presumably encodes major version 2 in the upper and minor
/// version 0 in the lower 16 bits -- confirm against the SCMI specification.
enum
{
  Version = 0x20000
};

/// SCMI error codes.
///
/// These values are sent as the signed 32-bit status word of a reply.
/// `Success` is zero, all failure codes are negative.
enum
{
  Success = 0,
  Not_supported = -1,
  Invalid_parameters = -2,
  Denied = -3,
  Not_found = -4,
  Out_of_range = -5,
  Busy = -6,
  Comms_error = -7,
  Generic_error = -8,
  Hardware_error = -9,
  Protocol_error = -10
};

/// SCMI message header.
///
/// A single 32-bit word that is read at the start of every request and is
/// written back at the start of every reply by the code in this file.
struct Scmi_hdr_t
{
  /// Raw header word, the accessors below are views onto it.
  l4_uint32_t hdr_raw = 0;
  /// Bits 18-27: token.
  CXX_BITFIELD_MEMBER(18, 27, token, hdr_raw);
  /// Bits 10-17: protocol id, used to select the protocol handler.
  CXX_BITFIELD_MEMBER(10, 17, protocol_id, hdr_raw);
  /// Bits 8-9: message type.
  CXX_BITFIELD_MEMBER( 8,  9, message_type, hdr_raw);
  /// Bits 0-7: message id, used by the protocol handlers to select the command.
  CXX_BITFIELD_MEMBER( 0,  7, message_id, hdr_raw);
};

/// SCMI protocol ids.
///
/// Values of the `protocol_id` header field. Handlers are registered per id
/// with `Scmi_dev::add_proto()`; this file only ships handler base classes for
/// the base and the performance domain management protocol.
enum
{
  Base_protocol = 0x10,
  Power_domain_management_protocol = 0x11,
  System_power_management_protocol = 0x12,
  Performance_domain_management_protocol = 0x13,
  Clock_management_protocol = 0x14,
  Sensor_management_protocol = 0x15,
  Reset_domain_management_protocol = 0x16,
  Voltage_domain_management_protocol = 0x17
};

/// SCMI common protocol message ids.
///
/// Values of the `message_id` header field that are handled by both protocol
/// base classes in this file.
enum
{
  Protocol_version = 0x0,
  Protocol_attributes = 0x1,
  Protocol_message_attributes = 0x2,
};

/// SCMI base protocol message ids.
///
/// Protocol specific values of the `message_id` header field; they continue
/// after the common message ids. `Base_proto` answers the ids up to and
/// including `Base_discover_list_protocols`, all others are answered with
/// `Not_supported`.
enum
{
  Base_discover_vendor = 0x3,
  Base_discover_sub_vendor = 0x4,
  Base_discover_implementation_version = 0x5,
  Base_discover_list_protocols = 0x6,
  Base_discover_agent = 0x7,
  Base_notify_errors = 0x8,
  Base_set_device_permissions = 0x9,
  Base_set_protocol_permissions = 0xa,
  Base_reset_agent_configuration = 0xb
};

/// SCMI base protocol attributes.
///
/// Payload of the `Protocol_attributes` reply of the base protocol. It is
/// filled in by `Base_proto::fill_attr()` and sent as one 32-bit word.
struct Base_attr_t
{
  /// Raw attribute word, the accessors below are views onto it.
  l4_uint32_t attr_raw = 0;
  /// Bits 8-15: number of agents.
  CXX_BITFIELD_MEMBER(8, 15, nagents, attr_raw);
  /// Bits 0-7: number of supported protocols.
  CXX_BITFIELD_MEMBER(0,  7, nprots, attr_raw);
};

/// SCMI performance domain management protocol message ids.
///
/// Protocol specific values of the `message_id` header field; they continue
/// after the common message ids. `Perf_proto` handles the domain attributes,
/// describe levels, level set and level get messages, all others are answered
/// with `Not_supported`.
enum
{
  Performance_domain_attributes = 0x3,
  Performance_describe_levels = 0x4,
  Performance_limits_set = 0x5,
  Performance_limits_get = 0x6,
  Performance_level_set = 0x7,
  Performance_level_get = 0x8,
  Performance_notify_limits = 0x9,
  Performance_notify_level = 0xa,
  Performance_describe_fastchannel = 0xb,
};

/// SCMI performance protocol attributes.
///
/// Payload of the `Protocol_attributes` reply of the performance protocol. It
/// is filled in by `Perf_proto::fill_attr()` and sent as is, so the member
/// order and sizes define the wire layout.
struct Performance_attr_t
{
  /// Raw attribute word, the accessors below are views onto it.
  l4_uint32_t attr_raw = 0;
  /// Bit 16: power flag.
  /// NOTE(review): presumably selects the unit of the reported power costs --
  /// confirm against the SCMI specification.
  CXX_BITFIELD_MEMBER(16, 16, power, attr_raw);
  /// Bits 0-15: number of performance domains.
  CXX_BITFIELD_MEMBER( 0, 15, domains, attr_raw);

  // NOTE(review): the following three words presumably describe address and
  // length of a statistics shared memory region -- confirm against the SCMI
  // specification. They default to zero.
  l4_uint32_t stat_addr_low = 0;
  l4_uint32_t stat_addr_high = 0;
  l4_uint32_t stat_len = 0;
};

/// SCMI performance domain protocol attributes.
///
/// Payload of the `Performance_domain_attributes` reply. It is filled in by
/// `Perf_proto::fill_domain_attr()` and sent as is, so the member order and
/// sizes define the wire layout.
struct Performance_domain_attr_t
{
  /// Raw attribute word, the accessors below are views onto it.
  l4_uint32_t attr_raw = 0;
  /// Bit 31: setting the performance limits is allowed.
  CXX_BITFIELD_MEMBER(31, 31, set_limits, attr_raw);
  /// Bit 30: setting the performance level is allowed.
  CXX_BITFIELD_MEMBER(30, 30, set_perf_level, attr_raw);
  /// Bit 29: performance limits change notification flag.
  CXX_BITFIELD_MEMBER(29, 29, perf_limits_change_notify, attr_raw);
  /// Bit 28: performance level change notification flag.
  CXX_BITFIELD_MEMBER(28, 28, perf_level_change_notify, attr_raw);
  /// Bit 27: fast channel flag.
  CXX_BITFIELD_MEMBER(27, 27, fast_channel, attr_raw);

  /// Raw rate limit word, `rate_limit` is a view onto it.
  l4_uint32_t rate_limit_raw = 0;
  /// Bits 0-19: rate limit.
  CXX_BITFIELD_MEMBER( 0, 19, rate_limit, rate_limit_raw);

  l4_uint32_t sustained_freq = 0;
  l4_uint32_t sustained_perf_level = 0;
  /// Name of the domain, a fixed 16 byte field that is zero-filled by default.
  char name[16] = { 0 };
};

/// SCMI performance describe levels numbers.
///
/// Count word of the `Performance_describe_levels` reply. It is filled in by
/// `Perf_proto::fill_describe_levels_n()`; `nperf_levels` determines how many
/// `Performance_describe_level_t` entries follow in the reply.
struct Performance_describe_levels_n_t
{
  /// Raw count word, the accessors below are views onto it.
  l4_uint32_t num_levels_raw = 0;
  /// Bits 16-31: number of remaining performance levels.
  CXX_BITFIELD_MEMBER(16, 31, nremain_perf_levels, num_levels_raw);
  /// Bits 0-11: number of performance levels contained in this reply.
  CXX_BITFIELD_MEMBER( 0, 11, nperf_levels, num_levels_raw);
};

/// SCMI performance describe level.
///
/// One entry of the level array in the `Performance_describe_levels` reply.
/// The entries are filled in by `Perf_proto::fill_describe_levels()` and sent
/// as is, so the member order and sizes define the wire layout.
struct Performance_describe_level_t
{
  l4_uint32_t perf_level = 0;
  l4_uint32_t power_cost = 0;
  l4_uint16_t trans_latency = 0;
  /// Padding that fills the entry up to a multiple of 32 bits, kept zero.
  l4_uint16_t res0 = 0;
};

/**
 * Worker that processes the requests of one virtqueue.
 *
 * Each request starts with an SCMI header. The worker reads that header, looks
 * up the handler registered for the protocol id at the device and lets the
 * handler consume the parameters and produce the response. Requests for
 * protocols without a handler are answered with `Not_supported`.
 *
 * \tparam OBSERV  Type of the device the queue belongs to. It has to provide
 *                 `mem_info()` and `proto()` and is handed to
 *                 `Virtqueue::finish()` when a request is completed.
 */
template<typename OBSERV>
struct Queue_worker : Request_processor
{
  /**
   * Create a worker.
   *
   * \param o      Device the queue belongs to.
   * \param queue  Queue to take the requests from.
   */
  Queue_worker(OBSERV *o, Virtqueue *queue)
  : o(o), q(queue)
  {}

  /**
   * Fetch the next available request from the queue and start processing its
   * descriptor chain.
   *
   * \retval true   A request was fetched, `head` and `req` refer to it.
   * \retval false  The queue has no available request.
   */
  bool init_queue()
  {
    auto r = q->next_avail();

    if (L4_UNLIKELY(!r))
      return false;

    head = start(o->mem_info(), r, &req);

    return true;
  }

  /// Advance `req` to the next descriptor of the current request.
  bool next()
  { return Request_processor::next(o->mem_info(), &req); }

  /// Complete the current request with `total` bytes written to it.
  void finish(l4_uint32_t total)
  { q->finish(head, o, total); }

  /**
   * Copy data from the current request into local memory.
   *
   * \param buf   Staging buffer, it is re-targeted to `data`.
   * \param data  Destination of the copy.
   * \param s     Number of bytes to copy.
   *
   * \return `s` on success, -1 if the request provides less than `s` bytes.
   */
  template<typename T>
  l4_ssize_t read(Data_buffer *buf, T *data, l4_size_t s = sizeof(T))
  {
    buf->pos = reinterpret_cast<char *>(data);
    buf->left = s;
    l4_size_t chunk = 0;
    for (;;)
      {
        chunk += req.copy_to(buf);
        // Step to the next descriptor as soon as the current one is used up,
        // so that a following read()/write() continues there.
        bool const more = !req.done() || next();
        // Stop when everything is copied. Also stop when the descriptor chain
        // is exhausted, otherwise a too short request would make this loop
        // spin forever without making progress.
        if (!buf->left || !more)
          break;
      }
    if (chunk != s)
      return -1;
    return chunk;
  }

  /**
   * Copy data from local memory into the current request.
   *
   * \param buf   Staging buffer, it is re-targeted to `data`.
   * \param data  Source of the copy.
   * \param s     Number of bytes to copy.
   *
   * \return `s` on success, -1 if the request has room for less than `s`
   *         bytes.
   */
  template<typename T>
  l4_ssize_t write(Data_buffer *buf, T *data, l4_size_t s = sizeof(T))
  {
    buf->pos = reinterpret_cast<char *>(data);
    buf->left = s;
    l4_size_t chunk = 0;
    for (;;)
      {
        chunk += buf->copy_to(&req);
        // Step to the next descriptor as soon as the current one is used up,
        // so that a following read()/write() continues there.
        bool const more = !req.done() || next();
        // Stop when everything is copied. Also stop when the descriptor chain
        // is exhausted, otherwise a too short request would make this loop
        // spin forever without making progress.
        if (!buf->left || !more)
          break;
      }
    if (chunk != s)
      return -1;
    return chunk;
  }

  /**
   * Process all requests that are currently available in the queue.
   *
   * \return 0 when the queue is drained, a negative value when a request
   *         could not be read or answered, or the error value of a
   *         `Bad_descriptor` exception raised while walking the descriptors.
   */
  l4_ssize_t handle_request()
  {
    try
      {
        if (!head && L4_UNLIKELY(!init_queue()))
          return 0;

        for (;;)
          {
            l4_ssize_t total = 0;
            l4_ssize_t res = 0;
            Scmi_hdr_t hdr;
            Data_buffer buf = Data_buffer(&hdr);
            if ((res = read(&buf, &hdr)) < 0)
              return res;

            // Search/execute handler for given protocol
            auto proto = o->proto(hdr.protocol_id());
            if (proto)
              {
                if ((res = proto->handle_request(hdr, buf, this)) < 0)
                  return res;
                total += res;
              }
            else
              {
                // Nobody handles this protocol: echo the header and report
                // that the protocol is not supported.
                if ((res = write(&buf, &hdr)) < 0)
                  return res;
                total += res;

                l4_int32_t status = Not_supported;
                if ((res = write(&buf, &status)) < 0)
                  return res;
                total += res;
              }

            finish(total);

            // Forget the completed request before fetching the next one.
            head = Virtqueue::Head_desc();
            if (L4_UNLIKELY(!init_queue()))
              return 0;
          }
      }
    catch (L4virtio::Svr::Bad_descriptor const &e)
      {
        // NOTE(review): the caller in this file only treats negative results
        // as an error -- confirm that `e.error` is negative.
        return e.error;
      }
    return 0;
  }

private:
  /// Data buffer that is set up from a descriptor of the request.
  struct Buffer : Data_buffer
  {
    Buffer() = default;
    Buffer(L4virtio::Svr::Driver_mem_region const *r,
           Virtqueue::Desc const &d, Request_processor const *)
    {
      pos = static_cast<char *>(r->local(d.addr));
      left = d.len;
    }
  };

  /// Current head
  Virtqueue::Head_desc head;
  /// Descriptor of the current request that is read from or written to next
  Buffer req;

  /// Pointer to the device the end point belongs to
  OBSERV *o;
  /// Queue the requests are taken from
  Virtqueue *q;
};

/**
 * Base class for all protocols.
 *
 * Defines an interface for processing the virtio buffers for the implemented
 * protocol.
 */
template<typename OBSERV>
struct Proto
{
  /// Protocol handlers are used through pointers to this base class, so make
  /// destruction through such a pointer well-defined.
  virtual ~Proto() = default;

  /**
   * Handle one request for this protocol.
   *
   * The header of the request has already been consumed by the caller. The
   * implementation reads the parameters of the message (if any) and writes
   * the complete response, including the header.
   *
   * \param hdr  Header of the request.
   * \param buf  Staging buffer to be used with `qw->read()` and `qw->write()`.
   * \param qw   Queue worker that gives access to the current request.
   *
   * \return Number of bytes written as response, or a negative value if the
   *         request could not be read or answered.
   */
  virtual l4_ssize_t handle_request(Scmi_hdr_t &hdr,
                                    Data_buffer &buf,
                                    Queue_worker<OBSERV> *qw) = 0;
};

/**
 * A server implementation of the virtio-scmi protocol.
 *
 * Use this class as a base to implement your own specific SCMI device.
 *
 * SCMI defines multiple protocols which can be optionally handled. This server
 * implementation is flexible enough to handle any combination of them. The
 * user of this server has to derive from the provided Proto classes (for the
 * protocols they want to handle) and needs to implement the required callbacks.
 *
 * Right now, support for the base and the performance protocol is provided.
 *
 * The base protocol is mandatory.
 *
 * If you want to use this from a Uvmm Linux guest, the device tree needs to
 * look something like this:
 *
 * \code
 *  firmware {
 *    scmi {
 *      compatible = "arm,scmi-virtio";
 *
 *      #address-cells = <1>;
 *      #size-cells = <0>;
 *
 *      // ... supported protocols ...
 *    };
 *  };
 * \endcode
 */
class Scmi_dev : public L4virtio::Svr::Device
{
  /// Feature bits offered by the device.
  struct Features : L4virtio::Svr::Dev_config::Features
  {
    Features() = default;
    Features(l4_uint32_t raw) : L4virtio::Svr::Dev_config::Features(raw) {}
  };

  /// IRQ end point that is triggered by the driver to notify the device.
  struct Host_irq : public L4::Irqep_t<Host_irq>
  {
    Scmi_dev *c;
    explicit Host_irq(Scmi_dev *c) : c(c) {}
    void handle_irq() { c->kick(); }
  };

  enum
  {
    /// Size used when (re-)configuring the queue
    Queue_size = 0x10
  };

public:
  /**
   * Create the device.
   *
   * \param registry  Object registry the notification IRQ of the device is
   *                  registered with.
   */
  Scmi_dev(L4Re::Util::Object_registry *registry)
  : L4virtio::Svr::Device(&_dev_config),
   _dev_config(L4VIRTIO_VENDOR_KK, L4VIRTIO_ID_SCMI, 1),
   _host_irq(this),
   _request_worker(this, &_q[0])
  {
    init_mem_info(2);

    L4Re::chkcap(registry->register_irq_obj(&_host_irq),
                 "Register irq object");

    Features hf(0);
    hf.ring_indirect_desc() = true;

    _dev_config.host_features(0) = hf.raw;

    _dev_config.set_host_feature(L4VIRTIO_FEATURE_VERSION_1);
    _dev_config.reset_hdr();

    reset();
  }

  /**
   * Add an actual protocol implementation with the given id to the server.
   *
   * The device stores the pointer only; the handler has to stay valid as long
   * as the device may use it. If a handler is already registered for `id`, it
   * is kept and `proto` is ignored.
   */
  void add_proto(l4_uint32_t id, Proto<Scmi_dev> *proto)
  { _protos.insert({id, proto}); }

  /**
   * Look up the handler for a protocol.
   *
   * \param id  Protocol id.
   *
   * \return The handler registered for `id`, `nullptr` if there is none.
   */
  Proto<Scmi_dev> *proto(l4_uint32_t id) const
  {
    // One lookup is enough, the iterator already refers to the handler.
    auto const it = _protos.find(id);
    return it != _protos.end() ? it->second : nullptr;
  }

  /// Notify the driver about new entries in the used ring of `queue`, unless
  /// the driver disabled notifications for it.
  void notify_queue(L4virtio::Virtqueue *queue)
  {
    if (queue->no_notify_guest())
      return;

    _dev_config.add_irq_status(L4VIRTIO_IRQ_STATUS_VRING);
    _kick_guest_irq->trigger();
  }

private:
  L4::Cap<L4::Irq> device_notify_irq() const override
  { return L4::cap_cast<L4::Irq>(_host_irq.obj_cap()); }

  void register_single_driver_irq() override
  {
    // Take over the IRQ capability sent by the driver ...
    _kick_guest_irq = L4Re::Util::Unique_cap<L4::Irq>(
      L4Re::chkcap(server_iface()->template rcv_cap<L4::Irq>(0)));

    // ... and provide a fresh receive slot for the next capability.
    L4Re::chksys(server_iface()->realloc_rcv_cap(0));
  }

  /// Process the pending requests, called when the driver notified the device.
  void kick()
  {
    if (_request_worker.handle_request() < 0)
      device_error();
  }

  void reset() override
  {
    for (Virtqueue &q : _q)
      q.disable();

    for (l4_uint32_t i = 0; i < _dev_config.num_queues(); i++)
      reset_queue_config(i, Queue_size);
  }

  /// The queues are always reported as usable, no check is performed here.
  bool check_queues() override
  {
    return true;
  }

  int reconfig_queue(unsigned idx) override
  {
    if (idx >= sizeof(_q) / sizeof(_q[0]))
      return -L4_ERANGE;

    return setup_queue(_q + idx, idx, Queue_size);
  }

  void trigger_driver_config_irq() override
  {
    _dev_config.add_irq_status(L4VIRTIO_IRQ_STATUS_CONFIG);
    _kick_guest_irq->trigger();
  }

  L4virtio::Svr::Dev_config_t<L4virtio::Svr::No_custom_data> _dev_config;
  Host_irq _host_irq;
  L4Re::Util::Unique_cap<L4::Irq> _kick_guest_irq;
  Virtqueue _q[1];
  Queue_worker<Scmi_dev> _request_worker;
  /// Registered protocol handlers by protocol id, the handlers are not owned.
  std::map<l4_uint32_t, Proto<Scmi_dev> *> _protos;
};

/**
 * Base class for the SCMI base protocol.
 *
 * Use this class as a base to implement the base protocol.
 *
 * The class answers the common messages and the discover messages up to and
 * including `Base_discover_list_protocols`. All other messages are answered
 * with `Not_supported`. A derived class has to implement fill_attr() and
 * prots().
 */
class Base_proto : public Proto<Scmi_dev>
{
  /// Return the base protocol attributes, like the number of supported
  /// protocols.
  virtual l4_int32_t fill_attr(Base_attr_t *attr) const = 0;

  /// Return a list of supported protocols.
  virtual std::vector<l4_uint32_t> prots() const = 0;

  /**
   * Handle one base protocol request.
   *
   * Every response consists of the echoed header, a status word and the
   * message specific payload.
   *
   * \param hdr  Header of the request.
   * \param buf  Staging buffer used for reading and writing.
   * \param qw   Queue worker that gives access to the current request.
   *
   * \return Number of bytes written as response, or a negative value if the
   *         request could not be read or answered.
   */
  l4_ssize_t handle_request(Scmi_hdr_t &hdr, Data_buffer &buf,
                            Queue_worker<Scmi_dev> *qw) override
  {
    l4_ssize_t total = 0;
    l4_ssize_t res = 0;
    switch (hdr.message_id())
      {
      case Protocol_version:
        {
          if ((res = qw->write(&buf, &hdr)) < 0)
            return res;
          total += res;

          struct
          {
            l4_int32_t status = Success;
            l4_uint32_t version = Version;
          } version;
          if ((res = qw->write(&buf, &version)) < 0)
            return res;
          total += res;
          break;
        }
      case Protocol_attributes:
        {
          if ((res = qw->write(&buf, &hdr)) < 0)
            return res;
          total += res;

          Base_attr_t ba;
          l4_int32_t status = fill_attr(&ba);
          if ((res = qw->write(&buf, &status)) < 0)
            return res;
          total += res;
          // The attributes are only part of a successful response.
          if (status == Success)
            {
              if ((res = qw->write(&buf, &ba)) < 0)
                return res;
              total += res;
            }
          break;
        }
      case Protocol_message_attributes:
        {
          l4_uint32_t msg_id = 0;
          if ((res = qw->read(&buf, &msg_id)) < 0)
            return res;

          if ((res = qw->write(&buf, &hdr)) < 0)
            return res;
          total += res;

          struct
          {
            l4_int32_t status = Not_found;
            l4_uint32_t attr = 0;
          } attr;
          // All message ids from Protocol_version (0) up to
          // Base_discover_list_protocols are implemented. The id is unsigned,
          // so checking the upper bound is sufficient.
          if (msg_id <= Base_discover_list_protocols)
            attr.status = Success;

          if ((res = qw->write(&buf, &attr)) < 0)
            return res;
          total += res;
          break;
        }
      case Base_discover_vendor:
        {
          if ((res = qw->write(&buf, &hdr)) < 0)
            return res;
          total += res;

          struct
          {
            l4_int32_t status = Success;
            l4_uint8_t vendor_identifier[16] = { "L4Re" };
          } vendor;
          if ((res = qw->write(&buf, &vendor)) < 0)
            return res;
          total += res;
          break;
        }
      case Base_discover_sub_vendor:
        {
          if ((res = qw->write(&buf, &hdr)) < 0)
            return res;
          total += res;

          struct
          {
            l4_int32_t status = Success;
            l4_uint8_t vendor_identifier[16] = { "Scmi" };
          } vendor;
          if ((res = qw->write(&buf, &vendor)) < 0)
            return res;
          total += res;
          break;
        }
      case Base_discover_implementation_version:
        {
          if ((res = qw->write(&buf, &hdr)) < 0)
            return res;
          total += res;

          struct
          {
            l4_int32_t status = Success;
            l4_uint32_t version = 1;
          } version;
          if ((res = qw->write(&buf, &version)) < 0)
            return res;
          total += res;
          break;
        }
      case Base_discover_list_protocols:
        {
          l4_uint32_t skip = 0;
          if ((res = qw->read(&buf, &skip)) < 0)
            return res;

          if ((res = qw->write(&buf, &hdr)) < 0)
            return res;
          total += res;

          auto const p = prots();

          // The agent passes the number of list entries it already knows.
          // Return the entries after those only; nothing is returned if the
          // agent skips the whole list.
          l4_size_t const first = skip < p.size() ? skip : p.size();
          l4_size_t const count = p.size() - first;

          struct
          {
            l4_int32_t status = Success;
            l4_uint32_t num = 0;
          } proto;
          proto.num = count;
          if ((res = qw->write(&buf, &proto)) < 0)
            return res;
          total += res;

          if (count > 0)
            {
              // Array of uint32 where 4 protos are packed into one uint32. So
              // round up to 4 bytes and fill the array byte by byte. Unused
              // bytes of the last word stay zero.
              std::vector<l4_uint8_t> parr((count + 3) / 4 * 4, 0);
              for (l4_size_t i = 0; i < count; i++)
                parr[i] = static_cast<l4_uint8_t>(p[first + i]);

              if ((res = qw->write(&buf, parr.data(), parr.size())) < 0)
                return res;
              total += res;
            }
          break;
        }
      default:
        {
          if ((res = qw->write(&buf, &hdr)) < 0)
            return res;
          total += res;

          l4_int32_t status = Not_supported;
          if ((res = qw->write(&buf, &status)) < 0)
            return res;
          total += res;
          break;
        }
      }

    return total;
  }
};

/**
 * Base class for the SCMI performance protocol.
 *
 * Use this class as a base to implement the performance protocol.
 *
 * The class answers the common messages as well as the domain attributes,
 * describe levels, level set and level get messages. All other messages are
 * answered with `Not_supported`. A derived class has to implement the
 * callbacks declared below.
 *
 * If you want to use this from a Uvmm Linux guest, the device tree needs to
 * look something like this:
 *
 * \code
 *  firmware {
 *    scmi {
 *      compatible = "arm,scmi-virtio";
 *
 *      #address-cells = <1>;
 *      #size-cells = <0>;
 *
 *      cpufreq: protocol@13 {
 *        reg = <0x13>;
 *        #clock-cells = <1>;
 *      };
 *    };
 *  };
 *  ....
 *
 *  cpu@0 {
 *    device_type = "cpu";
 *    reg = <0x0>;
 *    clocks = <&cpufreq 0>; // domain_id
 *  };
 * \endcode
 */
class Perf_proto : public Proto<Scmi_dev>
{
  /// Return the performance protocol attributes, like the number of supported
  /// domains.
  virtual l4_int32_t fill_attr(Performance_attr_t *attr) const = 0;

  /// Return the performance protocol domain attributes for a given domain,
  /// like if setting the performance level or limits is allowed.
  virtual l4_int32_t fill_domain_attr(l4_uint32_t domain_id,
                                      Performance_domain_attr_t *attr) const = 0;

  /// Return the amount of supported performance levels for a given domain.
  virtual l4_int32_t fill_describe_levels_n(l4_uint32_t domain_id,
                                            l4_uint32_t level_idx,
                                            Performance_describe_levels_n_t *attr) const = 0;

  /// Return a list of supported performance levels with their attributes for a
  /// given domain. `attr` points to an array with room for `num` entries.
  virtual l4_int32_t fill_describe_levels(l4_uint32_t domain_id,
                                          l4_uint32_t level_idx,
                                          l4_uint32_t num,
                                          Performance_describe_level_t *attr) const = 0;

  /// Set the performance level for a given domain.
  virtual l4_int32_t level_set(l4_uint32_t domain_id,
                               l4_uint32_t perf_level) = 0;

  /// Get the performance level for a given domain.
  virtual l4_int32_t level_get(l4_uint32_t domain_id,
                               l4_uint32_t *perf_level) const = 0;

  /**
   * Handle one performance protocol request.
   *
   * Every response consists of the echoed header, a status word and, if the
   * status is `Success`, the message specific payload.
   *
   * \param hdr  Header of the request.
   * \param buf  Staging buffer used for reading and writing.
   * \param qw   Queue worker that gives access to the current request.
   *
   * \return Number of bytes written as response, or a negative value if the
   *         request could not be read or answered.
   */
  l4_ssize_t handle_request(Scmi_hdr_t &hdr, Data_buffer &buf,
                            Queue_worker<Scmi_dev> *qw) override
  {
    l4_ssize_t total = 0;
    l4_ssize_t res = 0;
    switch (hdr.message_id())
      {
      case Protocol_version:
        {
          if ((res = qw->write(&buf, &hdr)) < 0)
            return res;
          total += res;

          struct
          {
            l4_int32_t status = Success;
            l4_uint32_t version = Version;
          } version;
          if ((res = qw->write(&buf, &version)) < 0)
            return res;
          total += res;
          break;
        }
      case Protocol_attributes:
        {
          if ((res = qw->write(&buf, &hdr)) < 0)
            return res;
          total += res;

          Performance_attr_t pa;
          l4_int32_t status = fill_attr(&pa);
          if ((res = qw->write(&buf, &status)) < 0)
            return res;
          total += res;
          if (status == Success)
            {
              if ((res = qw->write(&buf, &pa)) < 0)
                return res;
              total += res;
            }
          break;
        }
      case Protocol_message_attributes:
        {
          l4_uint32_t msg_id = 0;
          if ((res = qw->read(&buf, &msg_id)) < 0)
            return res;

          if ((res = qw->write(&buf, &hdr)) < 0)
            return res;
          total += res;

          struct
          {
            l4_int32_t status = Not_found;

            l4_uint32_t attr_raw = 0;
            CXX_BITFIELD_MEMBER(0, 0, fast_channel, attr_raw); // ignored
          } attr;
          // Implemented are all message ids from Protocol_version (0) up to
          // Performance_describe_levels plus level set/get. The id is
          // unsigned, so the first range needs no lower bound check.
          if (msg_id <= Performance_describe_levels ||
              (msg_id >= Performance_level_set &&
               msg_id <= Performance_level_get))
            attr.status = Success;

          if ((res = qw->write(&buf, &attr)) < 0)
            return res;
          total += res;
          break;
        }
      case Performance_domain_attributes:
        {
          l4_uint32_t domain_id = 0;
          if ((res = qw->read(&buf, &domain_id)) < 0)
            return res;

          if ((res = qw->write(&buf, &hdr)) < 0)
            return res;
          total += res;

          Performance_domain_attr_t attr;
          l4_int32_t status = fill_domain_attr(domain_id, &attr);
          if ((res = qw->write(&buf, &status)) < 0)
            return res;
          total += res;
          if (status == Success)
            {
              if ((res = qw->write(&buf, &attr)) < 0)
                return res;
              total += res;
            }
          break;
        }
      case Performance_describe_levels:
        {
          struct
          {
            l4_uint32_t domain_id = 0;
            l4_uint32_t level_idx = 0;
          } param;
          if ((res = qw->read(&buf, &param)) < 0)
            return res;

          if ((res = qw->write(&buf, &hdr)) < 0)
            return res;
          total += res;

          // First figure out how many levels we support
          Performance_describe_levels_n_t attr;
          l4_int32_t status = fill_describe_levels_n(param.domain_id, param.level_idx,
                                                     &attr);
          if (status != Success)
            {
              // On error bail out early
              if ((res = qw->write(&buf, &status)) < 0)
                return res;
              total += res;
              break;
            }

          // Now fetch the actual levels. The number of levels is only known
          // at run time and may be zero, so keep them in a vector instead of
          // an array on the stack.
          l4_uint32_t const nlevels = attr.nperf_levels().get();
          std::vector<Performance_describe_level_t> levels(nlevels);
          status = fill_describe_levels(param.domain_id, param.level_idx,
                                        nlevels, levels.data());
          if ((res = qw->write(&buf, &status)) < 0)
            return res;
          total += res;
          if (status == Success)
            {
              // Write both answers to the client
              if ((res = qw->write(&buf, &attr)) < 0)
                return res;
              total += res;
              if (!levels.empty())
                {
                  l4_size_t const bytes =
                    levels.size() * sizeof(Performance_describe_level_t);
                  if ((res = qw->write(&buf, levels.data(), bytes)) < 0)
                    return res;
                  total += res;
                }
            }
          break;
        }
      case Performance_level_set:
        {
          struct
          {
            l4_uint32_t domain_id = 0;
            l4_uint32_t perf_level = 0;
          } param;
          if ((res = qw->read(&buf, &param)) < 0)
            return res;

          if ((res = qw->write(&buf, &hdr)) < 0)
            return res;
          total += res;

          l4_int32_t status = level_set(param.domain_id, param.perf_level);
          if ((res = qw->write(&buf, &status)) < 0)
            return res;
          total += res;
          break;
        }
      case Performance_level_get:
        {
          l4_uint32_t domain_id = 0;
          if ((res = qw->read(&buf, &domain_id)) < 0)
            return res;

          if ((res = qw->write(&buf, &hdr)) < 0)
            return res;
          total += res;

          l4_uint32_t perf_level = 0;
          l4_int32_t status = level_get(domain_id, &perf_level);
          if ((res = qw->write(&buf, &status)) < 0)
            return res;
          total += res;
          // The level is only part of a successful response.
          if (status == Success)
            {
              if ((res = qw->write(&buf, &perf_level)) < 0)
                return res;
              total += res;
            }
          break;
        }
      default:
        {
          if ((res = qw->write(&buf, &hdr)) < 0)
            return res;
          total += res;

          l4_int32_t status = Not_supported;
          if ((res = qw->write(&buf, &status)) < 0)
            return res;
          total += res;
          break;
        }
      }
    return total;
  }
};

} /* Scmi */ } /* Svr */ } /* L4virtio */
