// -*- Mode: C++ -*-
// vim:ft=cpp
/**
 * \file
 * \brief  Region handling
 */
/*
 * (c) 2008-2009 Adam Lackorzynski <adam@os.inf.tu-dresden.de>,
 *               Alexander Warg <warg@os.inf.tu-dresden.de>,
 *               Björn Döbel <doebel@os.inf.tu-dresden.de>
 *     economic rights: Technische Universität Dresden (Germany)
 *
 * License: see LICENSE.spdx (in this directory or the directories above)
 */

#pragma once

#include <l4/cxx/avl_map>
#include <l4/sys/types.h>
#include <l4/re/rm>


namespace L4Re { namespace Util {
/**
 * \brief Inclusive address range `[start, end]`, used as key of the region
 *        and area trees.
 *
 * Ordering is defined on disjoint ranges only: a region is less than another
 * one if it ends before the other one starts. Two overlapping regions are
 * therefore neither less nor greater than each other.
 *
 * If `CONFIG_L4RE_REGION_INFO` is set, a region additionally stores a short
 * debug name and a backing offset; otherwise fixed placeholder values are
 * reported by the corresponding accessors.
 */
class Region
{
private:
  l4_addr_t _start, _end;
#ifdef CONFIG_L4RE_REGION_INFO
  char _dbg_name[40]; // Not a 0-terminated string, see _dbg_name_len
  unsigned char _dbg_name_len = 0;
  // The length must be representable in _dbg_name_len.
  static_assert(sizeof(_dbg_name) < 256);
  Rm::Offset _dbg_backing_offset = 0;
#endif

public:
  /// Create an invalid region (start and end are both `~0UL`).
  Region() noexcept : _start(~0UL), _end(~0UL) {}

  /// Create a region covering the single address `addr`.
  Region(l4_addr_t addr) noexcept : _start(addr), _end(addr) {}

  /// Create the region `[start, end]` (both addresses inclusive).
  Region(l4_addr_t start, l4_addr_t end) noexcept
    : _start(start), _end(end) {}

  /**
   * \brief Create the region `[start, end]` with debug information.
   *
   * \param start           First address of the region.
   * \param end             Last address of the region (inclusive).
   * \param name            Name of the region, not necessarily 0-terminated.
   *                        May be null, in which case the name is empty.
   * \param name_len        Number of characters in `name`; names longer than
   *                        the internal buffer are truncated.
   * \param backing_offset  Backing offset to record for the region.
   *
   * Without `CONFIG_L4RE_REGION_INFO` the last three arguments are ignored.
   */
  Region(l4_addr_t start, l4_addr_t end,
         char const *name, unsigned name_len,
         Rm::Offset backing_offset) noexcept
    : _start(start), _end(end)
  {
#ifdef CONFIG_L4RE_REGION_INFO
    // A null name has no characters, whatever length was passed along.
    if (!name)
      name_len = 0;

    _dbg_name_len = name_len > sizeof(_dbg_name)
                    ? sizeof(_dbg_name) : name_len;
    for (unsigned i = 0; i < _dbg_name_len; ++i)
      _dbg_name[i] = name[i];

    _dbg_backing_offset = backing_offset;
#else
    (void)name;
    (void)name_len;
    (void)backing_offset;
#endif
  }

  /// First address of the region.
  l4_addr_t start() const noexcept { return _start; }
  /// Last address of the region (inclusive).
  l4_addr_t end() const noexcept { return _end; }
  /// Number of addresses covered, i.e. `end - start + 1`.
  unsigned long size() const noexcept { return end() - start() + 1; }
  /// True if start and end are both `~0UL`, as set by the default constructor.
  bool invalid() const noexcept { return _start == ~0UL && _end == ~0UL; }

  /// True if this region ends before `o` starts.
  bool operator < (Region const &o) const noexcept
  { return end() < o.start(); }
  /// True if `o` lies completely within this region.
  bool contains(Region const &o) const noexcept
  { return o.start() >= start() && o.end() <= end(); }
  /// True if both regions have the same start and the same end.
  bool operator == (Region const &o) const noexcept
  { return o.start() == start() && o.end() == end(); }
  ~Region() noexcept {}

#ifdef CONFIG_L4RE_REGION_INFO
  /// Debug name; NOT 0-terminated, use name_len() for its length.
  char const *name() const noexcept { return _dbg_name; }
  /// Number of valid characters in name().
  unsigned char name_len() const noexcept { return _dbg_name_len; }
  /// Backing offset recorded at construction.
  Rm::Offset backing_offset() const noexcept { return _dbg_backing_offset; }
#else
  /// Placeholder name used when region info is not compiled in.
  char const *name() const noexcept { return "N/A"; }
  /// Length of the placeholder name.
  unsigned char name_len() const noexcept { return 3; }
  /// Always 0 when region info is not compiled in.
  Rm::Offset backing_offset() const noexcept { return 0; }
#endif
};

/**
 * \brief Per-region record: which dataspace backs a region, and how.
 *
 * \tparam DS   Type used to reference the backing dataspace.
 * \tparam OPS  Policy class providing the static operations `free`, `map`
 *              and `map_info` as well as the type `Map_result`.
 *
 * A handler stores the dataspace reference, an offset, a client capability
 * index and the region flags. All actual work is forwarded to `OPS`, which
 * receives a pointer to the handler as first argument.
 */
template< typename DS, typename OPS >
class Region_handler
{
public:
  using Dataspace = DS;
  using Ops = OPS;
  using Map_result = typename OPS::Map_result;

  /// Create a handler with offset 0, default dataspace and default flags.
  Region_handler() noexcept
    : _offs(0), _mem(), _flags()
  {}

  /**
   * \brief Create a handler for the given dataspace.
   *
   * \param mem         Backing dataspace.
   * \param client_cap  Client capability index to store.
   * \param offset      Offset to store, defaults to 0.
   * \param flags       Region flags, defaults to no flags.
   */
  Region_handler(Dataspace const &mem, l4_cap_idx_t client_cap,
                 L4Re::Rm::Offset offset = 0,
                 L4Re::Rm::Region_flags flags = L4Re::Rm::Region_flags(0)) noexcept
    : _offs(offset), _mem(mem), _client_cap(client_cap), _flags(flags)
  {}

  /// Backing dataspace.
  Dataspace const &memory() const noexcept { return _mem; }

  /// Client capability index given at construction.
  l4_cap_idx_t client_cap_idx() const noexcept { return _client_cap; }

  /// Stored offset.
  L4Re::Rm::Offset offset() const noexcept { return _offs; }

  /// True if the `W` flag is not set.
  constexpr bool is_ro() const noexcept
  { return !(_flags & L4Re::Rm::F::W); }

  /// The flags masked with `Caching_mask`.
  L4Re::Rm::Region_flags caching() const noexcept
  { return _flags & L4Re::Rm::F::Caching_mask; }

  /// All region flags.
  L4Re::Rm::Region_flags flags() const noexcept { return _flags; }

  /// Copy of this handler with `offset` added to the stored offset.
  Region_handler operator + (l4_int64_t offset) const noexcept
  {
    Region_handler shifted(*this);
    shifted._offs += offset;
    return shifted;
  }

  /// Forward to `Ops::free()` with this handler as first argument.
  void free(l4_addr_t start, unsigned long size) const noexcept
  { Ops::free(this, start, size); }

  /// Forward to `Ops::map()` and return its result.
  int map(l4_addr_t addr, Region const &r, bool writable,
          Map_result *result) const
  { return Ops::map(this, addr, r, writable, result); }

  /// Forward to `Ops::map_info()` and return its result.
  int map_info(l4_addr_t *start_addr, l4_addr_t *end_addr) const
  { return Ops::map_info(this, start_addr, end_addr); }

private:
  L4Re::Rm::Offset _offs;
  DS _mem;
  l4_cap_idx_t _client_cap = L4_INVALID_CAP;
  L4Re::Rm::Region_flags _flags;
};


/**
 * \brief Map of attached regions and reserved areas within an address window.
 *
 * \tparam Hdlr   Handler type stored with each region and area. It must
 *                provide `Dataspace`, `flags()`, `offset()`, `free()`,
 *                `map_info()` and `operator +`, as used below.
 * \tparam Alloc  Allocator template handed to the AVL trees.
 *
 * Regions and areas live in two separate trees keyed by Region. Since
 * Region::operator< orders disjoint ranges only, a lookup with a Region key
 * returns a node whose range overlaps the key.
 */
template< typename Hdlr, template<typename T> class Alloc >
class Region_map
{
protected:
  typedef cxx::Avl_map< Region, Hdlr, cxx::Lt_functor, Alloc > Tree;
  Tree _rm; ///< Region Map
  Tree _am; ///< Area Map

private:
  l4_addr_t _start; ///< Lower address limit, see min_addr()
  l4_addr_t _end;   ///< Upper address limit, see max_addr()

protected:
  /// Replace the address limits reported by min_addr() and max_addr().
  void set_limits(l4_addr_t start, l4_addr_t end) noexcept
  {
    _start = start;
    _end = end;
  }

public:
  typedef typename Tree::Item_type  Item;
  typedef typename Tree::Node       Node;
  typedef typename Tree::Key_type   Key_type;
  typedef Hdlr Region_handler;

  typedef typename Tree::Iterator Iterator;
  typedef typename Tree::Const_iterator Const_iterator;
  typedef typename Tree::Rev_iterator Rev_iterator;
  typedef typename Tree::Const_rev_iterator Const_rev_iterator;

  // Iteration over the attached regions.
  Iterator begin() noexcept { return _rm.begin(); }
  Const_iterator begin() const noexcept { return _rm.begin(); }
  Iterator end() noexcept { return _rm.end(); }
  Const_iterator end() const noexcept { return _rm.end(); }

  // Iteration over and lookup in the reserved areas.
  Iterator area_begin() noexcept { return _am.begin(); }
  Const_iterator area_begin() const noexcept { return _am.begin(); }
  Iterator area_end() noexcept { return _am.end(); }
  Const_iterator area_end() const noexcept { return _am.end(); }
  Node area_find(Key_type const &c) const noexcept { return _am.find_node(c); }

  /// Lower address limit of the map.
  l4_addr_t min_addr() const noexcept { return _start; }
  /// Upper address limit of the map.
  l4_addr_t max_addr() const noexcept { return _end; }


  /// Create an empty map with the address limits `start` and `end`.
  Region_map(l4_addr_t start, l4_addr_t end) noexcept : _start(start), _end(end) {}

  /**
   * \brief Find a region overlapping with `key`.
   *
   * \return The node found in the region tree, or an invalid node.
   *
   * 'find' should find any region overlapping with the searched one, the
   * caller should check for further requirements (for example that the
   * region found fully contains `key`).
   */
  Node find(Key_type const &key) const noexcept
  {
    Node n = _rm.find_node(key);
    if (!n)
      return Node();

    return n;
  }

  /// Lower-bound lookup of `key` in the region tree.
  Node lower_bound(Key_type const &key) const noexcept
  {
    Node n = _rm.lower_bound_node(key);
    return n;
  }

  /// Lower-bound lookup of `key` in the area tree.
  Node lower_bound_area(Key_type const &key) const noexcept
  {
    Node n = _am.lower_bound_node(key);
    return n;
  }

  /**
   * \brief Reserve an area of `size` bytes.
   *
   * \param addr   Requested start address, or the address to start searching
   *               from if `Search_addr` is set in `flags`.
   * \param size   Size of the area, must be at least 2.
   * \param flags  Attach flags; the region flags are stored with the area.
   * \param align  Alignment (log2) used when searching for an address.
   *
   * \return Start address of the new area, or `L4_INVALID_ADDR` if the size
   *         is too small, the range wraps around the address space, the range
   *         overlaps an existing area, no free range was found, or the
   *         insertion failed.
   */
  l4_addr_t attach_area(l4_addr_t addr, unsigned long size,
                        L4Re::Rm::Flags flags = L4Re::Rm::Flags(0),
                        unsigned char align = L4_PAGESHIFT) noexcept
  {
    if (size < 2)
      return L4_INVALID_ADDR;


    Region c;

    if (!(flags & L4Re::Rm::F::Search_addr))
      {
        // A fixed range must not wrap around the end of the address space,
        // otherwise 'addr + size - 1' would yield a region with end < start.
        if (size - 1 > ~0UL - addr)
          return L4_INVALID_ADDR;

        c = Region(addr, addr + size - 1);
        Node r = _am.find_node(c);
        if (r)
          return L4_INVALID_ADDR;
      }

    while (flags &  L4Re::Rm::F::Search_addr)
      {
        if (addr < min_addr() || (addr + size - 1) > max_addr())
          addr = min_addr();
        addr = find_free(addr, max_addr(), size, align, flags);
        if (addr == L4_INVALID_ADDR)
          return L4_INVALID_ADDR;

        c = Region(addr, addr + size - 1);
        Node r = _am.find_node(c);
        if (!r)
          break;

        // The candidate overlaps an area: continue the search behind it.
        if (r->first.end() >= max_addr())
          return L4_INVALID_ADDR;

        addr = r->first.end() + 1;
      }

    if (_am.insert(c, Hdlr(typename Hdlr::Dataspace(), 0, 0, flags.region_flags())).second == 0)
      return addr;

    return L4_INVALID_ADDR;
  }

  /**
   * \brief Remove the area that covers `addr`.
   *
   * \retval true   The area was removed.
   * \retval false  Removal from the area tree failed.
   */
  bool detach_area(l4_addr_t addr) noexcept
  {
    if (_am.remove(addr))
      return false;

    return true;
  }

  /**
   * \brief Attach a region of `size` bytes handled by `hdlr`.
   *
   * \param addr            Requested start address, or search start if
   *                        `Search_addr` is set.
   * \param size            Size of the region, must be at least 2.
   * \param hdlr            Handler to store with the region. Its `map_info()`
   *                        may dictate the mapping address (result > 0),
   *                        leave it open (0) or report an error (< 0).
   * \param attach_flags    `Search_addr` and `In_area` are evaluated here.
   * \param align           Alignment (log2) used when searching.
   * \param name            Optional debug name of the region.
   * \param name_len        Length of `name`.
   * \param backing_offset  Backing offset recorded with the region.
   *
   * \return Start address of the attached region or `L4_INVALID_PTR`.
   */
  void *attach(void *addr, unsigned long size, Hdlr const &hdlr,
               L4Re::Rm::Flags attach_flags = L4Re::Rm::Flags(0),
               unsigned char align = L4_PAGESHIFT,
               char const *name = nullptr, unsigned name_len = 0,
               L4Re::Rm::Offset backing_offset = 0) noexcept
  {
    if (size < 2)
      return L4_INVALID_PTR;

    l4_addr_t beg, end;
    int err = hdlr.map_info(&beg, &end);
    if (err < 0)
      return L4_INVALID_PTR;

    if (err > 0)
      {
        // Mapping address determined by underlying dataspace. Make sure we
        // prevent any additional alignment. We already know the place!
        beg += hdlr.offset();
        end = beg + size - 1U;
        align = L4_PAGESHIFT;

        // In case of exact mappings, the supplied address must match because
        // we cannot remap.
        if (!(attach_flags & L4Re::Rm::F::Search_addr)
            && reinterpret_cast<l4_addr_t>(addr) != beg)
          return L4_INVALID_PTR;

        // When searching for a suitable address, the start must cover the
        // dataspace beginning to "find" the right spot.
        if ((attach_flags & L4Re::Rm::F::Search_addr)
            && reinterpret_cast<l4_addr_t>(addr) > beg)
          return L4_INVALID_PTR;
      }
    else
      {
        beg = reinterpret_cast<l4_addr_t>(addr);
        end = max_addr();
      }

    if (attach_flags & L4Re::Rm::F::In_area)
      {
        // The region must go into an existing, non-reserved area and must
        // not extend beyond the end of that area.
        Node r = _am.find_node(Region(beg, beg + size - 1));
        if (!r || (r->second.flags() & L4Re::Rm::F::Reserved))
          return L4_INVALID_PTR;

        end = r->first.end();
      }

    if (attach_flags & L4Re::Rm::F::Search_addr)
      {
        beg = find_free(beg, end, size, align, attach_flags);
        if (beg == L4_INVALID_ADDR)
          return L4_INVALID_PTR;
      }

    // A fixed address outside of an area must not overlap any area.
    if (!(attach_flags & (L4Re::Rm::F::Search_addr | L4Re::Rm::F::In_area))
        && _am.find_node(Region(beg, beg + size - 1)))
      return L4_INVALID_PTR;

    // Check that [beg, beg + size - 1] lies within [min_addr(), end]. The
    // comparison is done on 'end - beg' so that it cannot be fooled by
    // 'beg + size - 1' wrapping around the end of the address space.
    if (beg < min_addr() || beg > end || size - 1 > end - beg)
      return L4_INVALID_PTR;

    if (_rm.insert(Region(beg, beg + size - 1,
                          name, name_len, backing_offset), hdlr).second
        == 0)
      return reinterpret_cast<void*>(beg);

    return L4_INVALID_PTR;
  }

  /**
   * \brief Detach the range `[addr, addr + sz - 1]` from one region.
   *
   * Only one region overlapping the range is processed per call; if
   * `Rm::Detach_again` is set in the result, further regions overlap the
   * range and the call should be repeated.
   *
   * \param addr       Start of the range to detach.
   * \param sz         Size of the range to detach.
   * \param flags      `Rm::Detach_overlap` removes the region found as a
   *                   whole even if it is only partly covered by the range;
   *                   `Rm::Detach_keep` suppresses freeing of the memory.
   * \param[out] reg   If not null, receives the range that was detached.
   * \param[out] hdlr  If not null, receives the handler of the affected
   *                   region if the region was removed or split, and a
   *                   default-constructed handler if it was only shrunk.
   *
   * \retval Rm::Detached_ds  The whole region was removed.
   * \retval Rm::Kept_ds      The region was shrunk at its start or end.
   * \retval Rm::Split_ds     The region was split into two.
   * \retval -L4_ENOENT       No matching region was found.
   * \retval <0               Error from inserting the tail during a split;
   *                          the region map is left unchanged in this case.
   */
  int detach(void *addr, unsigned long sz, unsigned flags,
             Region *reg, Hdlr *hdlr) noexcept
  {
    l4_addr_t a = reinterpret_cast<l4_addr_t>(addr);
    Region dr(a, a + sz - 1);

    Node r = find(dr);
    if (!r)
      return -L4_ENOENT;

    Region g = r->first;
    Hdlr const &h = r->second;

    // The detached memory is handed back to the handler unless the caller
    // asked to keep it or the region was not attached with Detach_free.
    bool const do_free = !(flags & L4Re::Rm::Detach_keep)
                         && (h.flags() & L4Re::Rm::F::Detach_free);

    if ((flags & L4Re::Rm::Detach_overlap) || dr.contains(g))
      {
        // successful removal of the AVL tree item also frees the node
        Hdlr h_copy = h;

        if (_rm.remove(g))
          return -L4_ENOENT;

        if (do_free)
          h_copy.free(0, g.size());

        if (hdlr)
          *hdlr = h_copy;
        if (reg)
          *reg = g;

        if (find(dr))
          return Rm::Detached_ds | Rm::Detach_again;
        else
          return Rm::Detached_ds;
      }
    else if (dr.start() <= g.start())
      {
        // move the start of a region

        // number of bytes cut off at the beginning of the region
        unsigned long const cut = dr.end() + 1 - g.start();

        if (do_free)
          h.free(0, cut);

        Item &cn = const_cast<Item &>(*r);
        cn.first = Region(dr.end() + 1, g.end(), g.name(), g.name_len(),
                          g.backing_offset() + cut);
        cn.second = cn.second + cut;
        if (hdlr)
          *hdlr = Hdlr();
        if (reg)
          *reg = Region(g.start(), dr.end());
        if (find(dr))
          return Rm::Kept_ds | Rm::Detach_again;
        else
          return Rm::Kept_ds;
      }
    else if (dr.end() >= g.end())
      {
        // move the end of a region

        if (do_free)
          h.free(dr.start() - g.start(), g.end() + 1 - dr.start());

        Item &cn = const_cast<Item &>(*r);
        cn.first = Region(g.start(), dr.start() - 1, g.name(), g.name_len(),
                          g.backing_offset());
        if (hdlr)
          *hdlr = Hdlr();
        if (reg)
          *reg = Region(dr.start(), g.end());

        if (find(dr))
          return Rm::Kept_ds | Rm::Detach_again;
        else
          return Rm::Kept_ds;
      }
    else if (g.contains(dr))
      {
        // split a single region that contains the new region

        // first move the end off the existing region before the new one,
        // the tail could not be inserted as long as the old key covers it
        Item &cn = const_cast<Item &>(*r);
        cn.first = Region(g.start(), dr.start() - 1, g.name(), g.name_len(),
                          g.backing_offset());

        // insert a second region for the remaining tail of
        // the old existing region
        auto tail_sz = dr.end() + 1 - g.start();
        int err = _rm.insert(Region(dr.end() + 1, g.end(),
                                    g.name(), g.name_len(),
                                    g.backing_offset() + tail_sz),
                             h + tail_sz).second;

        if (err)
          {
            // Undo the shrinking, otherwise the tail of the region would be
            // lost. Nothing has been freed up to this point.
            cn.first = g;
            return err;
          }

        // Free the memory only after the split can no longer fail.
        if (do_free)
          h.free(dr.start() - g.start(), dr.size());

        if (hdlr)
          *hdlr = h;
        if (reg)
          *reg = dr;
        return Rm::Split_ds;
      }
    return -L4_ENOENT;
  }

  /**
   * \brief Search for a free, aligned range of `size` bytes.
   *
   * \param start         Address to start searching from.
   * \param end           Last address the range may occupy.
   * \param size          Size of the range.
   * \param align         Alignment (log2) of the start address.
   * \param attach_flags  If `In_area` is set, areas are not treated as
   *                      occupied.
   *
   * \return Start address of a free range or `L4_INVALID_ADDR`.
   */
  l4_addr_t find_free(l4_addr_t start, l4_addr_t end, l4_addr_t size,
                      unsigned char align, L4Re::Rm::Flags attach_flags) const noexcept;

};


/**
 * Search upwards for a free, aligned range of `size` bytes that ends at or
 * before `end`.
 *
 * The search starts at `start`, or at min_addr() if `start` is `~0UL`, below
 * min_addr(), or not below `end`. A candidate range is occupied if it
 * overlaps a region or, unless `In_area` is set in `attach_flags`, an area.
 * In that case the search continues at the aligned address behind the
 * overlapping entry.
 *
 * \return Start address of the free range, or `L4_INVALID_ADDR` if no such
 *         range exists.
 */
template< typename Hdlr, template<typename T> class Alloc >
l4_addr_t
Region_map<Hdlr, Alloc>::find_free(l4_addr_t start, l4_addr_t end,
    unsigned long size, unsigned char align, L4Re::Rm::Flags attach_flags) const noexcept
{
  // An empty range, or one longer than [0, end], can never be placed. Checking
  // this first also keeps the subtraction below from wrapping around.
  if (size == 0 || size - 1 > end)
    return L4_INVALID_ADDR;

  // Highest start address for which [addr, addr + size - 1] still ends at or
  // before 'end'.
  l4_addr_t const last_start = end - (size - 1);

  l4_addr_t addr = start;

  if (addr == ~0UL || addr < min_addr() || addr >= end)
    addr = min_addr();

  addr = l4_round_size(addr, align);

  for (;;)
    {
      if (addr > last_start)
        return L4_INVALID_ADDR;

      Region c(addr, addr + size - 1);
      Node r = _rm.find_node(c);

      // Areas block the range only when not attaching into an area.
      if (!r && !(attach_flags & L4Re::Rm::F::In_area))
        r = _am.find_node(c);

      if (!r)
        return addr;

      // No room left behind the overlapping entry.
      if (r->first.end() >= last_start)
        return L4_INVALID_ADDR;

      addr = l4_round_size(r->first.end() + 1, align);

      // Rounding up wrapped around the end of the address space; without this
      // check the search would start over from the bottom and never finish.
      if (addr <= r->first.end())
        return L4_INVALID_ADDR;
    }
}

}}
