#include "dowa.h"

// --- Arena --- //
/*
 * Creates an arena whose storage is one heap buffer of `capacity` bytes,
 * with the bump offset starting at 0.
 * Returns NULL (after a perror("malloc") diagnostic) if either the arena
 * struct or its buffer cannot be allocated; nothing is leaked in that case.
 */
Dowa_Arena *Dowa_Arena_Create(size_t capacity)
{
  Dowa_Arena *p_new = malloc(sizeof *p_new);
  void *p_storage = (p_new != NULL) ? malloc(capacity) : NULL;

  if (p_storage == NULL)
  {
    perror("malloc");
    if (p_new != NULL)
      Dowa_Free(p_new);
    return NULL;
  }

  p_new->buffer = p_storage;
  p_new->offset = 0;
  p_new->capacity = capacity;
  return p_new;
}

/*
 * Bump-allocates `size` bytes from the arena using the default alignment of
 * two pointer widths. Returns NULL under the same conditions as
 * Dowa_Arena_Allocate_Aligned.
 */
void *Dowa_Arena_Allocate(Dowa_Arena *p_arena, size_t size)
{
  const size_t default_alignment = 2 * sizeof(void *);
  void *p_block = Dowa_Arena_Allocate_Aligned(p_arena, size, default_alignment);
  return p_block;
}

/*
 * Releases an arena created by Dowa_Arena_Create: its buffer (when present)
 * and then the arena struct itself. A NULL arena is ignored.
 */
void Dowa_Arena_Free(Dowa_Arena *p_arena)
{
  if (p_arena != NULL)
  {
    if (p_arena->buffer != NULL)
      Dowa_Free(p_arena->buffer);
    Dowa_Free(p_arena);
  }
}

/*
 * Duplicates `size` bytes from `src` into freshly bump-allocated arena memory
 * (default alignment). Returns the copy, or NULL when the arena or source is
 * NULL, `size` is 0, or the arena has no room left.
 */
void *Dowa_Arena_Copy(Dowa_Arena *p_arena, const void *src, size_t size)
{
  void *p_copy = NULL;

  if (p_arena != NULL && src != NULL && size != 0)
  {
    p_copy = Dowa_Arena_Allocate(p_arena, size);
    if (p_copy != NULL)
      memcpy(p_copy, src, size);
  }

  return p_copy;
}

/*
 * Rewinds the arena's bump offset to 0 so the whole buffer can be reused.
 * The buffer is kept; previously returned pointers must no longer be used.
 */
void Dowa_Arena_Reset(Dowa_Arena *p_arena)
{
  if (p_arena != NULL)
    p_arena->offset = 0;
}

/* Bytes consumed so far (including alignment padding); 0 for a NULL arena. */
size_t Dowa_Arena_Get_Used(Dowa_Arena *p_arena)
{
  return (p_arena != NULL) ? p_arena->offset : 0;
}

/*
 * Bytes still available past the current offset; 0 for a NULL arena.
 * A later aligned allocation may need some of these bytes for padding.
 */
size_t Dowa_Arena_Get_Remaining(Dowa_Arena *p_arena)
{
  if (p_arena == NULL)
    return 0;

  size_t used = p_arena->offset;
  return p_arena->capacity - used;
}

/*
 * Bump-allocates `size` bytes from the arena. Before allocating, it pads the
 * offset so the returned address is a multiple of `alignment`.
 *
 * Any non-zero alignment is accepted. Padding is computed with `%`, so
 * non-power-of-two values (which a bit-mask would silently mis-align) are
 * handled correctly. The capacity check is written against the remaining
 * space, so a huge `size` cannot wrap around and slip past it.
 *
 * Returns NULL, leaving the arena untouched, when:
 *   - the arena or its buffer is NULL,
 *   - `size` or `alignment` is 0, or
 *   - the request plus padding does not fit.
 */
void *Dowa_Arena_Allocate_Aligned(Dowa_Arena *p_arena, size_t size, size_t alignment)
{
  if (!p_arena || !p_arena->buffer || size == 0 || alignment == 0)
    return NULL;

  size_t current_address = (size_t)((char *)p_arena->buffer + p_arena->offset);
  size_t misalignment = current_address % alignment;
  size_t padding = (misalignment != 0) ? alignment - misalignment : 0;

  // Compare against what is left rather than summing, to avoid size_t overflow.
  size_t remaining = p_arena->capacity - p_arena->offset;
  if (padding > remaining || size > remaining - padding)
    return NULL;

  p_arena->offset += padding;
  void *p_result = (char *)p_arena->buffer + p_arena->offset;
  p_arena->offset += size;
  return p_result;
}

// --- NEW stb_ds-style Array Implementation --- //

/*
 * Ensures the array `p_array` can hold enough elements of `element_size`
 * bytes. `p_array` is the data pointer; its Dowa_Array_Header sits directly
 * before it. The target is `minimum_capacity` elements, or length + 1 when
 * `minimum_capacity` is 0. Passing NULL creates a new, empty array.
 *
 * Growth doubles the capacity, with a minimum of 4, or uses the requested
 * capacity if that is larger. With `p_arena`, the new block is carved from the
 * arena and the elements are copied. Without it, the block is malloc'd or
 * realloc'd.
 *
 * The header's p_hash pointer (used by the hashmap functions for their index)
 * is carried over to the grown block on both paths.
 *
 * Returns the (possibly moved) data pointer. On failure the original
 * `p_array` is returned unchanged, so callers must re-check capacity.
 * Failure means one of:
 *   - the allocation failed,
 *   - the byte size would overflow size_t, or
 *   - an arena-backed array was grown with a different arena.
 */
void* dowa__array_grow(void* p_array, size_t element_size, size_t minimum_capacity, Dowa_Arena* p_arena)
{
  const size_t size_max = (size_t)-1;
  Dowa_Array_Header* p_header = NULL;
  size_t current_capacity = 0;

  if (p_array)
  {
    p_header = dowa__header(p_array);
    current_capacity = p_header->capacity;

    if (p_header->allocator_type == DOWA_ALLOCATOR_ARENA && p_header->p_arena != p_arena)
    {
      fprintf(stderr, "Error: Cannot mix arena allocators\n");
      return p_array;
    }
  }

  // A minimum_capacity of 0 means "room for at least one more element".
  size_t needed_capacity = minimum_capacity;
  if (p_header && needed_capacity == 0)
    needed_capacity = p_header->length + 1;

  if (p_header && current_capacity >= needed_capacity)
    return p_array;

  // Double (saturating instead of wrapping), but never below 4 or the request.
  size_t new_capacity = (current_capacity <= size_max / 2) ? current_capacity * 2 : size_max;
  if (new_capacity < 4)
    new_capacity = 4;
  if (new_capacity < needed_capacity)
    new_capacity = needed_capacity;

  // Refuse capacities whose byte size would wrap around size_t.
  if (element_size != 0 && new_capacity > (size_max - sizeof(Dowa_Array_Header)) / element_size)
    return p_array;

  size_t total_size = sizeof(Dowa_Array_Header) + element_size * new_capacity;
  Dowa_Array_Header* p_new_header;

  if (p_arena)
  {
    // Use a fixed 16-byte alignment. The header's own fields need natural
    // alignment, and deriving the value from element_size could yield 0 or a
    // non-power-of-two such as 3.
    p_new_header = Dowa_Arena_Allocate_Aligned(p_arena, total_size, 16);
    if (!p_new_header)
      return p_array;

    if (p_header)
    {
      memcpy((char*)p_new_header + sizeof(Dowa_Array_Header), p_array, element_size * p_header->length);
      p_new_header->length = p_header->length;
      p_new_header->p_hash = p_header->p_hash;

      // A heap-backed array has just been moved into the arena. Its old block
      // is no longer reachable, so release it instead of leaking it.
      if (p_header->allocator_type == DOWA_ALLOCATOR_MALLOC)
        free(p_header);
    }
    else
    {
      p_new_header->length = 0;
      p_new_header->p_hash = NULL;
    }

    p_new_header->allocator_type = DOWA_ALLOCATOR_ARENA;
    p_new_header->p_arena = p_arena;
  }
  else
  {
    if (p_header)
    {
      // realloc keeps the existing header contents, including length and p_hash.
      // Only arena-free (malloc) headers reach here: arena-backed ones are
      // rejected by the mixing check above.
      p_new_header = realloc(p_header, total_size);
    }
    else
    {
      p_new_header = malloc(total_size);
      if (p_new_header)
      {
        p_new_header->length = 0;
        p_new_header->p_hash = NULL;
      }
    }

    if (!p_new_header)
      return p_array;

    p_new_header->allocator_type = DOWA_ALLOCATOR_MALLOC;
    p_new_header->p_arena = NULL;
  }

  p_new_header->capacity = new_capacity;
  return (char*)p_new_header + sizeof(Dowa_Array_Header);
}

void dowa__array_free(void* p_array)
{
  if (!p_array)
    return;

  Dowa_Array_Header* p_header = dowa__header(p_array);

  if (p_header->allocator_type == DOWA_ALLOCATOR_MALLOC)
    free(p_header);
}

// --- NEW stb_ds-style HashMap Implementation --- //

// Sentinel values stored in the hash/index slots of Dowa_Hash_Bucket.
// dowa__hash_bytes never returns either value, so a real hash cannot be
// confused with them.
//   DOWA_HASH_EMPTY     - slot never used; lookups stop probing when they reach one.
//   DOWA_HASH_TOMBSTONE - slot whose entry was deleted; lookups probe past it and
//                         inserts may reuse it.
#define DOWA_HASH_EMPTY 0xFFFFFFFF
#define DOWA_HASH_TOMBSTONE 0xFFFFFFFE

uint32 dowa__hash_bytes(void* p_key, size_t key_size)
{
  uint32 hash = HASH_KEY_NUMBER;
  uint8* p_bytes = (uint8*)p_key;

  for (size_t i = 0; i < key_size; i++)
    hash = ((hash << 5) + hash) + p_bytes[i];

  if (hash == DOWA_HASH_EMPTY || hash == DOWA_HASH_TOMBSTONE)
    hash = HASH_KEY_NUMBER;

  return hash;
}

/* Returns the hash index stored in the map's array header, or NULL for a NULL map. */
static Dowa_Hash_Index* dowa__hashmap_get_index(void* p_map)
{
  return (p_map != NULL) ? (Dowa_Hash_Index*)dowa__header(p_map)->p_hash : NULL;
}

/*
 * Looks up the element whose key matches `p_key` (`key_size` bytes, hash `hash`).
 *
 * Probe sequence: buckets at (hash + step*step) & (bucket_count - 1), for step
 * 0 through bucket_count. Within each bucket the slots are scanned in order.
 *   - An EMPTY slot ends the search.
 *   - Deleted (tombstoned) slots are skipped.
 *   - Slots whose stored hash matches have their key compared with memcmp. The
 *     key is read through the `char*` stored at the start of each element.
 *
 * Returns the matching element and sets *p_found to TRUE, or returns NULL with
 * *p_found set to FALSE. NULL/FALSE is also the result when the map has no index.
 */
static void* dowa__hashmap_find_slot(void* p_map, size_t element_size, void* p_key, size_t key_size, uint32 hash, boolean* p_found)
{
  *p_found = FALSE;

  Dowa_Hash_Index* p_index = dowa__hashmap_get_index(p_map);
  if (p_index == NULL || p_index->p_buckets == NULL)
    return NULL;

  const size_t mask = p_index->bucket_count - 1;

  for (size_t step = 0; step <= p_index->bucket_count; step++)
  {
    const Dowa_Hash_Bucket* p_bucket = &p_index->p_buckets[(hash + step * step) & mask];

    for (size_t slot = 0; slot < DOWA_HASH_BUCKET_SIZE; slot++)
    {
      uint32 slot_hash = p_bucket->hash[slot];
      if (slot_hash == DOWA_HASH_EMPTY)
        return NULL;

      if (slot_hash != hash || p_bucket->index[slot] == DOWA_HASH_TOMBSTONE)
        continue;

      char* p_element = (char*)p_map + (p_bucket->index[slot] * element_size);
      if (memcmp(*(char**)p_element, p_key, key_size) == 0)
      {
        *p_found = TRUE;
        return p_element;
      }
    }
  }

  return NULL;
}

/*
 * Returns a pointer to the whole element (key pointer + value) stored under
 * `p_key`, or NULL when the map is NULL or the key is absent.
 */
void* dowa__hashmap_get_ptr(void* p_map, size_t element_size, void* p_key, size_t key_size)
{
  if (p_map == NULL)
    return NULL;

  boolean found = FALSE;
  void* p_slot = dowa__hashmap_find_slot(p_map, element_size, p_key, key_size,
                                         dowa__hash_bytes(p_key, key_size), &found);
  if (!found)
    return NULL;
  return p_slot;
}

/*
 * Looks up `p_key` and returns a pointer to the value part of the matching
 * element, or NULL if the map is NULL or the key is absent.
 *
 * Every map element begins with the `char*` key pointer written by
 * dowa__hashmap_push, so the value starts right after that pointer.
 * Offsetting by `key_size` (the byte length of the key data) would only be
 * correct for keys that happen to be exactly sizeof(char*) bytes long.
 *
 * NOTE(review): this assumes the element's value member directly follows the
 * key pointer with no extra padding. That holds when the value's alignment is
 * no stricter than a pointer's.
 */
void* dowa__hashmap_get(void* p_map, size_t element_size, void* p_key, size_t key_size)
{
  char* p_kv = dowa__hashmap_get_ptr(p_map, element_size, p_key, key_size);
  if (!p_kv)
    return NULL;

  return p_kv + sizeof(char*);
}

/* TRUE when `p_key` is present in the map, FALSE otherwise (including a NULL map). */
boolean dowa__hashmap_has_key(void* p_map, size_t element_size, void* p_key, size_t key_size)
{
  void* p_entry = dowa__hashmap_get_ptr(p_map, element_size, p_key, key_size);
  return p_entry != NULL;
}

/*
 * Rebuilds the hash index of `p_map` with `new_bucket_count` buckets. The count
 * is expected to be a power of two, because buckets are chosen with
 * `& (count - 1)`.
 *
 * Entries are taken from the live slots of the current index, reusing the hash
 * stored there. Walking the old index instead of the element array matters:
 *   - dowa__hashmap_delete only tombstones a slot (and frees the key copy of
 *     heap-backed maps) while leaving the element in the array. Re-hashing the
 *     array would resurrect deleted keys and read freed key memory.
 *   - Reusing the stored hash keeps slots consistent with dowa__hashmap_push,
 *     which hashes exactly `key_size` bytes rather than strlen(key) + 1.
 *
 * The new buckets and index come from `p_arena` when it is non-NULL, otherwise
 * from the heap. The old index is released only if it was heap-allocated. On
 * allocation failure the map keeps its old index.
 *
 * The element array itself is never moved, so the returned pointer always
 * equals `p_map`. `element_size` is unused and kept only for interface
 * compatibility.
 */
static void* dowa__hashmap_rehash(void* p_map, size_t element_size, size_t new_bucket_count, Dowa_Arena* p_arena)
{
  (void)element_size;

  Dowa_Hash_Index* p_old_index = dowa__hashmap_get_index(p_map);
  if (!p_old_index)
    return p_map;

  size_t bucket_alloc_size = sizeof(Dowa_Hash_Bucket) * new_bucket_count;
  Dowa_Hash_Bucket* p_new_buckets = NULL;

  if (p_arena)
  {
    p_new_buckets = Dowa_Arena_Allocate_Aligned(p_arena, bucket_alloc_size, DOWA_HASH_CACHE_LINE_SIZE);
  }
  else
  {
    // Go through a void* so posix_memalign's out-parameter is not type-punned.
    void* p_raw = NULL;
    if (posix_memalign(&p_raw, DOWA_HASH_CACHE_LINE_SIZE, bucket_alloc_size) == 0)
      p_new_buckets = p_raw;
  }

  if (!p_new_buckets)
    return p_map;

  for (size_t i = 0; i < new_bucket_count; i++)
  {
    for (size_t j = 0; j < DOWA_HASH_BUCKET_SIZE; j++)
    {
      p_new_buckets[i].hash[j] = DOWA_HASH_EMPTY;
      p_new_buckets[i].index[j] = DOWA_HASH_EMPTY;
    }
  }

  Dowa_Hash_Index* p_new_index = p_arena
    ? Dowa_Arena_Allocate_Aligned(p_arena, sizeof(Dowa_Hash_Index), 16)
    : malloc(sizeof(Dowa_Hash_Index));

  if (!p_new_index)
  {
    if (!p_arena)
      free(p_new_buckets);
    return p_map;
  }

  p_new_index->bucket_count = new_bucket_count;
  p_new_index->item_count = 0;
  p_new_index->tombstone_count = 0;
  p_new_index->allocator_type = p_arena ? DOWA_ALLOCATOR_ARENA : DOWA_ALLOCATOR_MALLOC;
  p_new_index->p_arena = p_arena;
  p_new_index->p_buckets = p_new_buckets;

  size_t bucket_mask = new_bucket_count - 1;

  if (p_old_index->p_buckets)
  {
    for (size_t b = 0; b < p_old_index->bucket_count; b++)
    {
      const Dowa_Hash_Bucket* p_old_bucket = &p_old_index->p_buckets[b];

      for (size_t s = 0; s < DOWA_HASH_BUCKET_SIZE; s++)
      {
        uint32 hash = p_old_bucket->hash[s];
        if (hash == DOWA_HASH_EMPTY || hash == DOWA_HASH_TOMBSTONE)
          continue;

        // Same quadratic probe sequence that lookup and insert use.
        for (size_t probe_step = 0; probe_step <= new_bucket_count; probe_step++)
        {
          Dowa_Hash_Bucket* p_bucket = &p_new_buckets[(hash + probe_step * probe_step) & bucket_mask];

          size_t j = 0;
          while (j < DOWA_HASH_BUCKET_SIZE && p_bucket->hash[j] != DOWA_HASH_EMPTY)
            j++;

          if (j < DOWA_HASH_BUCKET_SIZE)
          {
            p_bucket->hash[j] = hash;
            p_bucket->index[j] = p_old_bucket->index[s];
            p_new_index->item_count++;
            break;
          }
        }
      }
    }
  }

  if (p_old_index->allocator_type == DOWA_ALLOCATOR_MALLOC)
  {
    free(p_old_index->p_buckets);
    free(p_old_index);
  }

  dowa__header(p_map)->p_hash = p_new_index;
  return p_map;
}

void* dowa__hashmap_push(void* p_map, size_t element_size, void* p_key, size_t key_size, Dowa_Arena* p_arena)
{
  uint32 hash = dowa__hash_bytes(p_key, key_size);

  if (p_map)
  {
    boolean found = FALSE;
    dowa__hashmap_find_slot(p_map, element_size, p_key, key_size, hash, &found);
    if (found)
      return p_map;
  }

  p_map = dowa__array_grow(p_map, element_size, 0, p_arena);

  Dowa_Hash_Index* p_index = dowa__hashmap_get_index(p_map);
  if (!p_index)
  {
    size_t initial_bucket_count = 4;
    size_t bucket_alloc_size = sizeof(Dowa_Hash_Bucket) * initial_bucket_count;
    Dowa_Hash_Bucket* p_buckets;

    if (p_arena)
    {
      p_buckets = (Dowa_Hash_Bucket*)Dowa_Arena_Allocate_Aligned(p_arena, bucket_alloc_size, DOWA_HASH_CACHE_LINE_SIZE);
    }
    else
    {
      if (posix_memalign((void**)&p_buckets, DOWA_HASH_CACHE_LINE_SIZE, bucket_alloc_size) != 0)
        p_buckets = NULL;
    }

    if (!p_buckets)
      return p_map;

    for (size_t i = 0; i < initial_bucket_count; i++)
    {
      for (size_t j = 0; j < DOWA_HASH_BUCKET_SIZE; j++)
      {
        p_buckets[i].hash[j] = DOWA_HASH_EMPTY;
        p_buckets[i].index[j] = DOWA_HASH_EMPTY;
      }
    }

    if (p_arena)
      p_index = (Dowa_Hash_Index*)Dowa_Arena_Allocate_Aligned(p_arena, sizeof(Dowa_Hash_Index), 16);
    else
      p_index = (Dowa_Hash_Index*)malloc(sizeof(Dowa_Hash_Index));

    if (!p_index)
    {
      if (!p_arena)
        free(p_buckets);
      return p_map;
    }

    p_index->bucket_count = initial_bucket_count;
    p_index->item_count = 0;
    p_index->tombstone_count = 0;
    p_index->allocator_type = p_arena ? DOWA_ALLOCATOR_ARENA : DOWA_ALLOCATOR_MALLOC;
    p_index->p_arena = p_arena;
    p_index->p_buckets = p_buckets;

    dowa__header(p_map)->p_hash = p_index;
  }

  // Rehash when load factor exceeds ~50% (bucket_count * BUCKET_SIZE * 0.5)
  // This ensures we rehash before probe chains get too long
  if (p_index->item_count + p_index->tombstone_count >= p_index->bucket_count * 4)
    p_map = dowa__hashmap_rehash(p_map, element_size, p_index->bucket_count * 2, p_arena);

  p_index = dowa__hashmap_get_index(p_map);

  Dowa_Array_Header* p_header = dowa__header(p_map);
  uint32 new_index = (uint32)p_header->length;
  char* p_new_element = (char*)p_map + (new_index * element_size);

  char* p_key_copy;
  if (p_arena)
    p_key_copy = (char*)Dowa_Arena_Allocate(p_arena, key_size);
  else
    p_key_copy = (char*)malloc(key_size);

  if (!p_key_copy)
    return p_map;

  memcpy(p_key_copy, p_key, key_size);

  // Store the pointer value (use memcpy to avoid alignment issues)
  memcpy(p_new_element, &p_key_copy, sizeof(p_key_copy));
  p_header->length++;

  // Try to insert into hash table - if it fails, rehash and retry
  boolean inserted = FALSE;
  for (int attempt = 0; attempt < 2 && !inserted; attempt++)
  {
    if (attempt == 1)
    {
      // First attempt failed - rehash to make more room
      p_header->length--;  // Undo the array length increment temporarily
      p_map = dowa__hashmap_rehash(p_map, element_size, p_index->bucket_count * 2, p_arena);
      p_index = dowa__hashmap_get_index(p_map);
      p_header = dowa__header(p_map);

      // Restore array state
      new_index = (uint32)p_header->length;
      p_new_element = (char*)p_map + (new_index * element_size);
      memcpy(p_new_element, &p_key_copy, sizeof(p_key_copy));
      p_header->length++;
    }

    size_t bucket_mask = p_index->bucket_count - 1;
    size_t bucket_index = hash & bucket_mask;
    size_t probe_step = 0;

    while (!inserted && probe_step <= p_index->bucket_count)
    {
      Dowa_Hash_Bucket* p_bucket = &p_index->p_buckets[bucket_index];

      for (size_t i = 0; i < DOWA_HASH_BUCKET_SIZE; i++)
      {
        if (p_bucket->hash[i] == DOWA_HASH_EMPTY || p_bucket->hash[i] == DOWA_HASH_TOMBSTONE)
        {
          if (p_bucket->hash[i] == DOWA_HASH_TOMBSTONE)
            p_index->tombstone_count--;

          p_bucket->hash[i] = hash;
          p_bucket->index[i] = new_index;
          p_index->item_count++;
          inserted = TRUE;
          break;
        }
      }

      if (!inserted)
      {
        probe_step++;
        bucket_index = (hash + (probe_step * probe_step)) & bucket_mask;
      }
    }
  }

  return p_map;
}

/*
 * Removes `p_key` from the map's hash index, if present.
 *
 * The matching slot is turned into a tombstone (hash and index set to
 * DOWA_HASH_TOMBSTONE), and item_count / tombstone_count are adjusted. For
 * heap-backed maps (array allocator DOWA_ALLOCATOR_MALLOC) the key copy is
 * freed.
 *
 * The element itself stays in the array and its length is unchanged. Its key
 * pointer is left dangling when the key is freed.
 *
 * Probing follows the same sequence as lookups and stops at the first EMPTY
 * slot. A NULL map or a map without an index is ignored.
 */
void dowa__hashmap_delete(void* p_map, size_t element_size, void* p_key, size_t key_size)
{
  Dowa_Hash_Index* p_index = (p_map != NULL) ? dowa__hashmap_get_index(p_map) : NULL;
  if (p_index == NULL || p_index->p_buckets == NULL)
    return;

  const uint32 target_hash = dowa__hash_bytes(p_key, key_size);
  const size_t mask = p_index->bucket_count - 1;

  for (size_t step = 0; step <= p_index->bucket_count; step++)
  {
    Dowa_Hash_Bucket* p_bucket = &p_index->p_buckets[(target_hash + step * step) & mask];

    for (size_t slot = 0; slot < DOWA_HASH_BUCKET_SIZE; slot++)
    {
      uint32 slot_hash = p_bucket->hash[slot];
      if (slot_hash == DOWA_HASH_EMPTY)
        return;

      if (slot_hash != target_hash || p_bucket->index[slot] == DOWA_HASH_TOMBSTONE)
        continue;

      char* p_element = (char*)p_map + (p_bucket->index[slot] * element_size);
      char* p_stored_key = *(char**)p_element;
      if (memcmp(p_stored_key, p_key, key_size) != 0)
        continue;

      if (dowa__header(p_map)->allocator_type == DOWA_ALLOCATOR_MALLOC)
        free(p_stored_key);

      p_bucket->hash[slot] = DOWA_HASH_TOMBSTONE;
      p_bucket->index[slot] = DOWA_HASH_TOMBSTONE;
      p_index->item_count--;
      p_index->tombstone_count++;
      return;
    }
  }
}

void dowa__hashmap_clear(void* p_map, size_t element_size)
{
  if (!p_map)
    return;

  Dowa_Array_Header* p_header = dowa__header(p_map);
  p_header->length = 0;

  Dowa_Hash_Index* p_index = dowa__hashmap_get_index(p_map);
  if (p_index && p_index->p_buckets)
  {
    for (size_t i = 0; i < p_index->bucket_count; i++)
    {
      for (size_t j = 0; j < DOWA_HASH_BUCKET_SIZE; j++)
      {
        p_index->p_buckets[i].hash[j] = DOWA_HASH_EMPTY;
        p_index->p_buckets[i].index[j] = DOWA_HASH_EMPTY;
      }
    }
    p_index->item_count = 0;
    p_index->tombstone_count = 0;
  }
}

void dowa__hashmap_free(void* p_map)
{
  if (!p_map)
    return;

  Dowa_Hash_Index* p_index = dowa__hashmap_get_index(p_map);
  if (p_index && p_index->allocator_type == DOWA_ALLOCATOR_MALLOC)
  {
    if (p_index->p_buckets)
      free(p_index->p_buckets);
    free(p_index);
  }

  dowa__array_free(p_map);
}

/* Number of live (non-deleted) entries recorded in the index; 0 without a map or index. */
size_t dowa__hashmap_count(void* p_map)
{
  Dowa_Hash_Index* p_index = dowa__hashmap_get_index(p_map);
  if (p_index == NULL)
    return 0;

  return p_index->item_count;
}
