/**********************************************************************
 *
 * Name:     cpl_hash_set.cpp
 * Project:  CPL - Common Portability Library
 * Purpose:  Hash set functions.
 * Author:   Even Rouault, <even dot rouault at mines dash paris dot org>
 *
 **********************************************************************
 * Copyright (c) 2008-2009, Even Rouault <even dot rouault at mines-paris dot org>
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included
 * in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 * DEALINGS IN THE SOFTWARE.
 ****************************************************************************/

#include "cpl_hash_set.h"

#include <cstring>

#include "cpl_conv.h"
#include "cpl_error.h"
#include "cpl_list.h"

// Version-control identification string for this source file. CPL_CVSID is a
// macro from the CPL headers (presumably it embeds the id in the binary --
// confirm in cpl_port.h).
CPL_CVSID("$Id: cpl_hash_set.cpp 0846c4df38348216396587449b9cef818856b36c 2018-01-10 16:19:40Z Kurt Schwehr $")

/* Internal state of a CPLHashSet: a separately chained hash table whose
 * bucket chains are CPLList nodes and whose bucket count is taken from
 * anPrimes[] (53, the initial count, is anPrimes[0]). */
struct _CPLHashSet
{
    // Hashes an element; never NULL (CPLHashSetNew substitutes a default).
    CPLHashSetHashFunc    fnHashFunc;
    // Tests two elements for equality; never NULL (default substituted).
    CPLHashSetEqualFunc   fnEqualFunc;
    // Frees an element when it is removed/replaced/cleared; may be NULL.
    CPLHashSetFreeEltFunc fnFreeEltFunc;
    // Array of nAllocatedSize bucket chain heads.
    CPLList**             tabList;
    // Number of elements currently stored.
    int                   nSize;
    // Index into anPrimes used by the next CPLHashSetRehash(). While a
    // rehash is deferred it may not correspond to nAllocatedSize.
    int                   nIndiceAllocatedSize;
    // Current number of buckets in tabList.
    int                   nAllocatedSize;
    // Spare list nodes kept for reuse (CPLHashSetReturnListElt keeps at most 128).
    CPLList              *psRecyclingList;
    // Number of nodes currently in psRecyclingList.
    int                   nRecyclingListSize;
    // Set when CPLHashSetRemoveDeferRehash() postponed a rehash.
    bool                  bRehash;
#ifdef HASH_DEBUG
    // Number of insertions that landed in an already non-empty bucket.
    int                   nCollisions;
#endif
};

/* Successive bucket counts for the hash table. CPLHashSetRehash() resizes the
 * table to anPrimes[nIndiceAllocatedSize]. Each entry is roughly twice the
 * previous one; prime sizes are a common choice so that "hash % size"
 * spreads hash values across buckets. */
constexpr int anPrimes[] =
{
    53, 97, 193, 389, 769, 1543, 3079, 6151,
    12289, 24593, 49157, 98317, 196613, 393241,
    786433, 1572869, 3145739, 6291469, 12582917,
    25165843, 50331653, 100663319, 201326611,
    402653189, 805306457, 1610612741
};

/************************************************************************/
/*                          CPLHashSetNew()                             */
/************************************************************************/

/**
 * Creates a new, empty hash set.
 *
 * The hash function must return a hash value for the elements to insert.
 * When fnHashFunc is NULL, CPLHashSetHashPointer is used instead.
 *
 * The equal function must tell whether two elements are equal.
 * When fnEqualFunc is NULL, CPLHashSetEqualPointer is used instead.
 *
 * The free function is called on elements when the hash set is destroyed
 * or cleared, and when elements are removed or replaced.
 * When fnFreeEltFunc is NULL, elements are never freed by the hash set.
 *
 * @param fnHashFunc hash function. May be NULL.
 * @param fnEqualFunc equal function. May be NULL.
 * @param fnFreeEltFunc element free function. May be NULL.
 *
 * @return a new hash set
 */

CPLHashSet* CPLHashSetNew( CPLHashSetHashFunc fnHashFunc,
                           CPLHashSetEqualFunc fnEqualFunc,
                           CPLHashSetFreeEltFunc fnFreeEltFunc )
{
    // Initial bucket count; equal to the first entry of the prime table.
    constexpr int nInitialBuckets = 53;

    CPLHashSet* psSet =
        static_cast<CPLHashSet *>(CPLMalloc(sizeof(CPLHashSet)));

    // Fall back to pointer-identity semantics when no callbacks are given.
    psSet->fnHashFunc =
        (fnHashFunc != nullptr) ? fnHashFunc : CPLHashSetHashPointer;
    psSet->fnEqualFunc =
        (fnEqualFunc != nullptr) ? fnEqualFunc : CPLHashSetEqualPointer;
    psSet->fnFreeEltFunc = fnFreeEltFunc;

    // Zero-filled bucket array: every chain starts empty.
    psSet->tabList = static_cast<CPLList**>(
        CPLCalloc(sizeof(CPLList*), nInitialBuckets));
    psSet->nAllocatedSize = nInitialBuckets;
    psSet->nIndiceAllocatedSize = 0;
    psSet->nSize = 0;

    psSet->psRecyclingList = nullptr;
    psSet->nRecyclingListSize = 0;
    psSet->bRehash = false;
#ifdef HASH_DEBUG
    psSet->nCollisions = 0;
#endif
    return psSet;
}

/************************************************************************/
/*                          CPLHashSetSize()                            */
/************************************************************************/

/**
 * Returns how many elements are stored in the hash set.
 *
 * Note: this is the element count, not the number of internal buckets.
 *
 * @param set the hash set
 *
 * @return the number of elements in the hash set
 */

int CPLHashSetSize( const CPLHashSet* set )
{
    CPLAssert(set != nullptr);
    const int nElementCount = set->nSize;
    return nElementCount;
}

/************************************************************************/
/*                       CPLHashSetGetNewListElt()                      */
/************************************************************************/

/* Returns a list node for a new element: a recycled node (with pData reset
 * to NULL) when the recycling list is non-empty, otherwise a freshly
 * CPLMalloc'ed, uninitialized node. */
static CPLList* CPLHashSetGetNewListElt( CPLHashSet* set )
{
    CPLList* psElt = set->psRecyclingList;
    if( psElt == nullptr )
        return static_cast<CPLList *>(CPLMalloc(sizeof(CPLList)));

    // Pop the head of the recycling list.
    psElt->pData = nullptr;
    set->psRecyclingList = psElt->psNext;
    set->nRecyclingListSize--;
    return psElt;
}

/************************************************************************/
/*                       CPLHashSetReturnListElt()                      */
/************************************************************************/

/* Gives a no-longer-used list node back to the set. Up to 128 nodes are kept
 * on the recycling list for reuse by CPLHashSetGetNewListElt(); any node
 * beyond that is released with CPLFree(). The node's pData is not freed. */
static void CPLHashSetReturnListElt( CPLHashSet* set, CPLList* psList )
{
    if( set->nRecyclingListSize >= 128 )
    {
        CPLFree(psList);
        return;
    }

    // Push onto the head of the recycling list.
    psList->psNext = set->psRecyclingList;
    set->psRecyclingList = psList;
    set->nRecyclingListSize++;
}

/************************************************************************/
/*                   CPLHashSetClearInternal()                          */
/************************************************************************/

/* Empties every bucket of the set, calling fnFreeEltFunc (if any) on each
 * element. With bFinalize the list nodes are CPLFree'd directly; otherwise
 * they are handed to CPLHashSetReturnListElt() for recycling. The bucket
 * array itself, nSize and nAllocatedSize are left untouched; the pending
 * deferred-rehash flag is cleared. */
static void CPLHashSetClearInternal( CPLHashSet* set, bool bFinalize )
{
    CPLAssert(set != nullptr);
    for( int iBucket = 0; iBucket < set->nAllocatedSize; ++iBucket )
    {
        for( CPLList* psNode = set->tabList[iBucket]; psNode != nullptr; )
        {
            if( set->fnFreeEltFunc != nullptr )
                set->fnFreeEltFunc(psNode->pData);

            // Read the successor before the node is freed or recycled.
            CPLList* const psFollowing = psNode->psNext;
            if( !bFinalize )
                CPLHashSetReturnListElt(set, psNode);
            else
                CPLFree(psNode);
            psNode = psFollowing;
        }
        set->tabList[iBucket] = nullptr;
    }
    set->bRehash = false;
}

/************************************************************************/
/*                        CPLHashSetDestroy()                           */
/************************************************************************/

/**
 * Destroys an allocated hash set.
 *
 * This function also frees the elements if a free function was
 * provided at the creation of the hash set.
 *
 * Passing NULL is a no-op, as with free().
 *
 * @param set the hash set (may be NULL)
 */

void CPLHashSetDestroy( CPLHashSet* set )
{
    // CPLHashSetClearInternal() only asserts non-NULL, and CPLAssert is
    // typically compiled out in release builds: guard explicitly so that
    // destroying a NULL set does not dereference it.
    if( set == nullptr )
        return;

    // Free every element and every list node outright (no recycling).
    CPLHashSetClearInternal(set, true);
    CPLFree(set->tabList);
    // Release the spare nodes kept for reuse.
    CPLListDestroy(set->psRecyclingList);
    CPLFree(set);
}

/************************************************************************/
/*                        CPLHashSetClear()                             */
/************************************************************************/

/**
 * Removes all elements from a hash set, leaving it empty and reusable.
 *
 * This function also frees the elements if a free function was
 * provided at the creation of the hash set. The internal bucket array is
 * shrunk back to its initial size.
 *
 * @param set the hash set
 * @since GDAL 2.1
 */

void CPLHashSetClear( CPLHashSet* set )
{
    // Free the elements and recycle the list nodes; this also nulls every
    // bucket, so the slots kept by the realloc below are all empty.
    CPLHashSetClearInternal(set, false);

    constexpr int nInitialBuckets = 53;
    set->tabList = static_cast<CPLList**>(
        CPLRealloc(set->tabList, nInitialBuckets * sizeof(CPLList*)));
    set->nAllocatedSize = nInitialBuckets;
    set->nIndiceAllocatedSize = 0;
    set->nSize = 0;
#ifdef HASH_DEBUG
    set->nCollisions = 0;
#endif
}

/************************************************************************/
/*                       CPLHashSetForeach()                            */
/************************************************************************/

/**
 * Walks through the hash set and calls the provided function on each
 * element.
 *
 * fnIterFunc is called with each element and with the user_data argument
 * given to CPLHashSetForeach. It must return TRUE to continue the walk, or
 * FALSE to stop it.
 *
 * Note: the structure of the hash set must *NOT* be modified during the
 * walk.
 *
 * @param set the hash set.
 * @param fnIterFunc the function called on each element. If NULL, nothing
 *                   is done.
 * @param user_data the user data provided to the function.
 */

void CPLHashSetForeach( CPLHashSet* set,
                        CPLHashSetIterEltFunc fnIterFunc,
                        void* user_data )
{
    CPLAssert(set != nullptr);
    if( fnIterFunc == nullptr )
        return;

    for( int iBucket = 0; iBucket < set->nAllocatedSize; ++iBucket )
    {
        for( CPLList* psNode = set->tabList[iBucket]; psNode != nullptr;
             psNode = psNode->psNext )
        {
            // A FALSE return from the callback aborts the whole walk.
            if( !fnIterFunc(psNode->pData, user_data) )
                return;
        }
    }
}

/************************************************************************/
/*                        CPLHashSetRehash()                            */
/************************************************************************/

/* Rebuilds the bucket array with anPrimes[set->nIndiceAllocatedSize]
 * buckets. The existing list nodes are relinked into the new array (no node
 * is allocated or freed), and any pending deferred rehash is cleared. */
static void CPLHashSetRehash( CPLHashSet* set )
{
    constexpr int nPrimeCount =
        static_cast<int>(sizeof(anPrimes) / sizeof(anPrimes[0]));

    // CPLHashSetInsert() increments nIndiceAllocatedSize before calling this
    // function without checking it against the size of anPrimes. Clamp to the
    // largest prime instead of reading past the end of the table.
    if( set->nIndiceAllocatedSize >= nPrimeCount )
        set->nIndiceAllocatedSize = nPrimeCount - 1;

    const int nNewAllocatedSize = anPrimes[set->nIndiceAllocatedSize];
    if( nNewAllocatedSize == set->nAllocatedSize )
    {
        // Same bucket count: every element would stay in its current bucket,
        // so there is nothing to rebuild.
        set->bRehash = false;
        return;
    }

    CPLList** newTabList = static_cast<CPLList **>(
        CPLCalloc(sizeof(CPLList*), nNewAllocatedSize));
#ifdef HASH_DEBUG
    CPLDebug("CPLHASH", "hashSet=%p, nSize=%d, nCollisions=%d, "
             "fCollisionRate=%.02f",
             set, set->nSize, set->nCollisions,
             set->nSize ? set->nCollisions * 100.0 / set->nSize : 0.0);
    set->nCollisions = 0;
#endif
    for( int i = 0; i < set->nAllocatedSize; i++ )
    {
        CPLList* cur = set->tabList[i];
        while( cur )
        {
            const unsigned long nNewHashVal =
                set->fnHashFunc(cur->pData) % nNewAllocatedSize;
#ifdef HASH_DEBUG
            if( newTabList[nNewHashVal] )
                set->nCollisions++;
#endif
            // Move the node to the head of its new bucket chain.
            CPLList* psNext = cur->psNext;
            cur->psNext = newTabList[nNewHashVal];
            newTabList[nNewHashVal] = cur;
            cur = psNext;
        }
    }
    CPLFree(set->tabList);
    set->tabList = newTabList;
    set->nAllocatedSize = nNewAllocatedSize;
    set->bRehash = false;
}

/************************************************************************/
/*                        CPLHashSetFindPtr()                           */
/************************************************************************/

/* Looks for an element equal to elt (per fnEqualFunc) in the bucket selected
 * by fnHashFunc. Returns the address of the matching node's pData slot, so
 * callers can read or replace the stored element, or NULL if none matches. */
static void** CPLHashSetFindPtr( CPLHashSet* set, const void* elt )
{
    const unsigned long nBucket = set->fnHashFunc(elt) % set->nAllocatedSize;
    for( CPLList* psNode = set->tabList[nBucket]; psNode != nullptr;
         psNode = psNode->psNext )
    {
        if( set->fnEqualFunc(psNode->pData, elt) )
            return &psNode->pData;
    }
    return nullptr;
}

/************************************************************************/
/*                         CPLHashSetInsert()                           */
/************************************************************************/

/**
 * Inserts an element into a hash set.
 *
 * If an equal element was already inserted in the hash set, the previous
 * element is replaced by the new element. If a free function was provided,
 * it is used to free the previously inserted element, unless that element
 * is the very same pointer as the new one.
 *
 * @param set the hash set
 * @param elt the new element to insert in the hash set
 *
 * @return TRUE if the element was not already in the hash set
 */

int CPLHashSetInsert( CPLHashSet* set, void* elt )
{
    CPLAssert(set != nullptr);
    void** pElt = CPLHashSetFindPtr(set, elt);
    if( pElt )
    {
        // Re-inserting the same pointer must not free it: the set would
        // otherwise keep storing a pointer to released memory.
        if( set->fnFreeEltFunc && *pElt != elt )
            set->fnFreeEltFunc(*pElt);

        *pElt = elt;
        return FALSE;
    }

    // Grow once the load factor reaches 2/3. The threshold is computed in
    // 64-bit arithmetic because 2 * nAllocatedSize overflows int (undefined
    // behaviour) for the largest bucket count, 1610612741.
    const long long nGrowThreshold = 2LL * set->nAllocatedSize / 3;
    // Also rebuild when CPLHashSetRemoveDeferRehash() left a rehash pending
    // and the table is at most half full.
    // NOTE(review): this increments nIndiceAllocatedSize, which the deferred
    // removal had decremented, so the rebuilt table may keep the same bucket
    // count -- confirm the intended shrink policy.
    const bool bPendingRehash = set->bRehash &&
                                set->nIndiceAllocatedSize > 0 &&
                                set->nSize <= set->nAllocatedSize / 2;
    if( set->nSize >= nGrowThreshold || bPendingRehash )
    {
        set->nIndiceAllocatedSize++;
        CPLHashSetRehash(set);
    }

    const unsigned long nHashVal = set->fnHashFunc(elt) % set->nAllocatedSize;
#ifdef HASH_DEBUG
    if( set->tabList[nHashVal] )
        set->nCollisions++;
#endif

    // Prepend the new node to its bucket chain.
    CPLList* new_elt = CPLHashSetGetNewListElt(set);
    new_elt->pData = elt;
    new_elt->psNext = set->tabList[nHashVal];
    set->tabList[nHashVal] = new_elt;
    set->nSize++;

    return TRUE;
}

/************************************************************************/
/*                        CPLHashSetLookup()                            */
/************************************************************************/

/**
 * Returns the element stored in the hash set that is equal (according to
 * the set's equal function) to the element to look up.
 * The element must not be modified.
 *
 * @param set the hash set
 * @param elt the element to look up in the hash set
 *
 * @return the element found in the hash set or NULL
 */

void* CPLHashSetLookup( CPLHashSet* set, const void* elt )
{
    CPLAssert(set != nullptr);
    void** const ppSlot = CPLHashSetFindPtr(set, elt);
    return (ppSlot != nullptr) ? *ppSlot : nullptr;
}

/************************************************************************/
/*                     CPLHashSetRemoveInternal()                       */
/************************************************************************/

/* Removes the element equal to elt, freeing it with fnFreeEltFunc (if any)
 * and recycling its list node. Before the lookup, and whether or not elt is
 * present, the table is shrunk one step when it is at most half full: either
 * immediately, or (bDeferRehash) by flagging a rehash for later.
 * Returns true if an element was removed. */
static
bool CPLHashSetRemoveInternal( CPLHashSet* set, const void* elt,
                               bool bDeferRehash )
{
    CPLAssert(set != nullptr);

    const bool bShouldShrink = set->nIndiceAllocatedSize > 0 &&
                               set->nSize <= set->nAllocatedSize / 2;
    if( bShouldShrink )
    {
        // NOTE(review): in deferred mode nAllocatedSize is not updated here,
        // so consecutive deferred removals keep satisfying this condition
        // and keep decrementing nIndiceAllocatedSize -- confirm intent.
        set->nIndiceAllocatedSize--;
        if( !bDeferRehash )
            CPLHashSetRehash(set);
        else
            set->bRehash = true;
    }

    const int nHashVal =
        static_cast<int>(set->fnHashFunc(elt) % set->nAllocatedSize);

    // ppsLink points at whichever pointer references the current node: the
    // bucket head or the previous node's psNext.
    CPLList** ppsLink = &set->tabList[nHashVal];
    while( *ppsLink != nullptr )
    {
        CPLList* const psNode = *ppsLink;
        if( !set->fnEqualFunc(psNode->pData, elt) )
        {
            ppsLink = &psNode->psNext;
            continue;
        }

        // Unlink before recycling: recycling overwrites psNode->psNext.
        *ppsLink = psNode->psNext;

        if( set->fnFreeEltFunc )
            set->fnFreeEltFunc(psNode->pData);

        CPLHashSetReturnListElt(set, psNode);
#ifdef HASH_DEBUG
        if( set->tabList[nHashVal] )
            set->nCollisions--;
#endif
        set->nSize--;
        return true;
    }
    return false;
}

/************************************************************************/
/*                         CPLHashSetRemove()                           */
/************************************************************************/

/**
 * Removes an element from a hash set.
 *
 * If a free function was provided at the creation of the hash set, it is
 * called on the removed element.
 *
 * @param set the hash set
 * @param elt the element to remove from the hash set
 *
 * @return TRUE if the element was in the hash set
 */

int CPLHashSetRemove( CPLHashSet* set, const void* elt )
{
    const bool bRemoved = CPLHashSetRemoveInternal(set, elt, false);
    return static_cast<int>(bRemoved);
}

/************************************************************************/
/*                     CPLHashSetRemoveDeferRehash()                    */
/************************************************************************/

/**
 * Removes an element from a hash set.
 *
 * This will defer potential rehashing of the set to later calls to
 * CPLHashSetInsert() or CPLHashSetRemove().
 *
 * @param set the hash set
 * @param elt the element to remove from the hash set
 *
 * @return TRUE if the element was in the hash set
 * @since GDAL 2.1
 */

int CPLHashSetRemoveDeferRehash( CPLHashSet* set, const void* elt )
{
    const bool bRemoved = CPLHashSetRemoveInternal(set, elt, true);
    return static_cast<int>(bRemoved);
}

/************************************************************************/
/*                    CPLHashSetHashPointer()                           */
/************************************************************************/

/**
 * Hash function for an arbitrary pointer
 *
 * @param elt the arbitrary pointer to hash
 *
 * @return the hash value of the pointer
 */

unsigned long CPLHashSetHashPointer( const void* elt )
{
    return static_cast<unsigned long>(
        reinterpret_cast<GUIntptr_t>(const_cast<void *>(elt)));
}

/************************************************************************/
/*                   CPLHashSetEqualPointer()                           */
/************************************************************************/

/**
 * Equality function for arbitrary pointers (compares the addresses only).
 *
 * @param elt1 the first arbitrary pointer to compare
 * @param elt2 the second arbitrary pointer to compare
 *
 * @return TRUE if the pointers are equal
 */

int CPLHashSetEqualPointer( const void* elt1, const void* elt2 )
{
    const bool bSameAddress = (elt1 == elt2);
    return static_cast<int>(bSameAddress);
}

/************************************************************************/
/*                        CPLHashSetHashStr()                           */
/************************************************************************/

/**
 * Hash function for a zero-terminated string.
 *
 * Each byte c (read as unsigned char) updates the hash as
 * hash = c + (hash << 6) + (hash << 16) - hash, starting from 0
 * (the sdbm recurrence). Unsigned wrap-around is expected.
 *
 * @param elt the string to hash. May be NULL, in which case 0 is returned.
 *
 * @return the hash value of the string
 */

CPL_NOSANITIZE_UNSIGNED_INT_OVERFLOW
unsigned long CPLHashSetHashStr( const void *elt )
{
    unsigned long nHash = 0;
    if( elt != nullptr )
    {
        for( const unsigned char* pabyIter =
                 static_cast<const unsigned char *>(elt);
             *pabyIter != '\0'; ++pabyIter )
        {
            nHash = *pabyIter + (nHash << 6) + (nHash << 16) - nHash;
        }
    }
    return nHash;
}

/************************************************************************/
/*                     CPLHashSetEqualStr()                             */
/************************************************************************/

/**
 * Equality function for zero-terminated strings.
 *
 * Two NULL strings are equal; a NULL string never equals a non-NULL one.
 *
 * @param elt1 the first string to compare. May be NULL.
 * @param elt2 the second string to compare. May be NULL.
 *
 * @return TRUE if the strings are equal
 */

int CPLHashSetEqualStr( const void* elt1, const void* elt2 )
{
    const char* pszLeft = static_cast<const char *>(elt1);
    const char* pszRight = static_cast<const char *>(elt2);

    // If either side is NULL, they are equal only when both are.
    if( pszLeft == nullptr || pszRight == nullptr )
        return pszLeft == pszRight;

    return strcmp(pszLeft, pszRight) == 0;
}
