// SPDX-License-Identifier: MIT
pragma solidity >=0.8.4 <0.9.0;

import {IDiamondCut} from "../interfaces/IDiamondCut.sol";

/// @title LibDiamond
/// @notice Storage layout and internal helpers for managing a diamond's
///         selector -> facet routing table and its contract owner.
library LibDiamond {
    /// @dev Storage slot at which `DiamondStorage` is placed (see `diamondStorage`).
    bytes32 internal constant DIAMOND_STORAGE_POSITION =
        keccak256("diamond.standard.diamond.storage");

    // NOTE: field order of the structs below defines the storage layout;
    // do not reorder, insert, or remove fields.

    struct FacetAddressAndPosition {
        address facetAddress;
        // position in facetFunctionSelectors.functionSelectors array
        uint96 functionSelectorPosition;
    }

    struct FacetFunctionSelectors {
        bytes4[] functionSelectors;
        // position of facetAddress in facetAddresses array
        uint256 facetAddressPosition;
    }

    struct DiamondStorage {
        // maps function selector to the facet address and
        // the position of the selector in the facetFunctionSelectors.selectors array
        mapping(bytes4 => FacetAddressAndPosition) selectorToFacetAndPosition;
        // maps facet addresses to function selectors
        mapping(address => FacetFunctionSelectors) facetFunctionSelectors;
        // facet addresses
        address[] facetAddresses;
        // Used to query if a contract implements an interface.
        // Used to implement ERC-165.
        mapping(bytes4 => bool) supportedInterfaces;
        // owner of the contract
        address contractOwner;
    }

    /// @notice Returns a storage pointer to the `DiamondStorage` struct that
    ///         lives at slot `DIAMOND_STORAGE_POSITION`.
    /// @return ds Storage pointer to the diamond's shared state.
    function diamondStorage() internal pure returns (DiamondStorage storage ds) {
        bytes32 position = DIAMOND_STORAGE_POSITION;
        // solhint-disable-next-line no-inline-assembly
        assembly {
            ds.slot := position
        }
    }

    event OwnershipTransferred(
        address indexed previousOwner,
        address indexed newOwner
    );

    /// @notice Records `_newOwner` as the contract owner and emits `OwnershipTransferred`.
    /// @dev No access control and no zero-address check are performed here;
    ///      callers must enforce those themselves.
    /// @param _newOwner Address to store as the new owner.
    function setContractOwner(address _newOwner) internal {
        DiamondStorage storage ds = diamondStorage();
        address previousOwner = ds.contractOwner;
        ds.contractOwner = _newOwner;
        emit OwnershipTransferred(previousOwner, _newOwner);
    }

    /// @notice Returns the currently stored contract owner.
    /// @return contractOwner_ The owner address read from diamond storage.
    function contractOwner() internal view returns (address contractOwner_) {
        contractOwner_ = diamondStorage().contractOwner;
    }

    /// @notice Reverts unless `msg.sender` is the stored contract owner.
    function enforceIsContractOwner() internal view {
        require(
            msg.sender == diamondStorage().contractOwner,
            "LibDiamond: Must be contract owner"
        );
    }

    event DiamondCut(
        IDiamondCut.FacetCut[] _diamondCut,
        address _init,
        bytes _calldata
    );

    /// @notice Internal function version of diamondCut: applies each facet cut
    ///         in order, emits `DiamondCut`, then runs the optional initializer.
    /// @param _diamondCut Facet cuts (Add / Replace / Remove) to apply, in order.
    /// @param _init Address to delegatecall for initialization, or address(0) for none.
    /// @param _calldata Calldata for the `_init` delegatecall; must be empty iff `_init` is address(0).
    function diamondCut(
        IDiamondCut.FacetCut[] memory _diamondCut,
        address _init,
        bytes memory _calldata
    ) internal {
        for (uint256 facetIndex; facetIndex < _diamondCut.length; facetIndex++) {
            // Cache the current cut instead of re-indexing the array per field.
            IDiamondCut.FacetCut memory cut = _diamondCut[facetIndex];
            IDiamondCut.FacetCutAction action = cut.action;
            if (action == IDiamondCut.FacetCutAction.Add) {
                addFunctions(cut.facetAddress, cut.functionSelectors);
            } else if (action == IDiamondCut.FacetCutAction.Replace) {
                replaceFunctions(cut.facetAddress, cut.functionSelectors);
            } else if (action == IDiamondCut.FacetCutAction.Remove) {
                removeFunctions(cut.facetAddress, cut.functionSelectors);
            } else {
                revert("LibDiamondCut: Incorrect FacetCutAction");
            }
        }
        emit DiamondCut(_diamondCut, _init, _calldata);
        initializeDiamondCut(_init, _calldata);
    }

    /// @notice Routes each selector in `_functionSelectors` to `_facetAddress`.
    /// @dev Reverts if the list is empty, the facet is address(0), or any
    ///      selector is already routed to some facet.
    function addFunctions(
        address _facetAddress,
        bytes4[] memory _functionSelectors
    ) internal {
        require(
            _functionSelectors.length > 0,
            "LibDiamondCut: No selectors in facet to cut"
        );
        DiamondStorage storage ds = diamondStorage();
        require(
            _facetAddress != address(0),
            "LibDiamondCut: Add facet can't be address(0)"
        );
        // New selectors are appended after the facet's existing ones.
        uint96 selectorPosition = uint96(
            ds.facetFunctionSelectors[_facetAddress].functionSelectors.length
        );
        // add new facet address if it does not exist
        if (selectorPosition == 0) {
            addFacet(ds, _facetAddress);
        }
        for (
            uint256 selectorIndex;
            selectorIndex < _functionSelectors.length;
            selectorIndex++
        ) {
            bytes4 selector = _functionSelectors[selectorIndex];
            address oldFacetAddress = ds
                .selectorToFacetAndPosition[selector]
                .facetAddress;
            require(
                oldFacetAddress == address(0),
                "LibDiamondCut: Can't add function that already exists"
            );
            addFunction(ds, selector, selectorPosition, _facetAddress);
            selectorPosition++;
        }
    }

    /// @notice Re-routes each selector in `_functionSelectors` from its current
    ///         facet to `_facetAddress`.
    /// @dev Reverts if the list is empty, the facet is address(0), a selector is
    ///      already routed to `_facetAddress`, or (via `removeFunction`) a
    ///      selector is not currently routed or is immutable.
    function replaceFunctions(
        address _facetAddress,
        bytes4[] memory _functionSelectors
    ) internal {
        require(
            _functionSelectors.length > 0,
            "LibDiamondCut: No selectors in facet to cut"
        );
        DiamondStorage storage ds = diamondStorage();
        require(
            _facetAddress != address(0),
            "LibDiamondCut: Replace facet can't be address(0)"
        );
        uint96 selectorPosition = uint96(
            ds.facetFunctionSelectors[_facetAddress].functionSelectors.length
        );
        // add new facet address if it does not exist
        if (selectorPosition == 0) {
            addFacet(ds, _facetAddress);
        }
        for (
            uint256 selectorIndex;
            selectorIndex < _functionSelectors.length;
            selectorIndex++
        ) {
            bytes4 selector = _functionSelectors[selectorIndex];
            address oldFacetAddress = ds
                .selectorToFacetAndPosition[selector]
                .facetAddress;
            require(
                oldFacetAddress != _facetAddress,
                "LibDiamondCut: Can't replace function with same function"
            );
            removeFunction(ds, oldFacetAddress, selector);
            addFunction(ds, selector, selectorPosition, _facetAddress);
            selectorPosition++;
        }
    }

    /// @notice Unroutes each selector in `_functionSelectors`.
    /// @dev `_facetAddress` must be address(0). Each selector must currently be
    ///      routed to a facet other than this contract, otherwise `removeFunction`
    ///      reverts.
    function removeFunctions(
        address _facetAddress,
        bytes4[] memory _functionSelectors
    ) internal {
        require(
            _functionSelectors.length > 0,
            "LibDiamondCut: No selectors in facet to cut"
        );
        DiamondStorage storage ds = diamondStorage();
        // Remove cuts carry no target facet; the zero address is required by convention.
        require(
            _facetAddress == address(0),
            "LibDiamondCut: Remove facet address must be address(0)"
        );
        for (
            uint256 selectorIndex;
            selectorIndex < _functionSelectors.length;
            selectorIndex++
        ) {
            bytes4 selector = _functionSelectors[selectorIndex];
            address oldFacetAddress = ds
                .selectorToFacetAndPosition[selector]
                .facetAddress;
            // Reverts if `selector` is not routed to any facet.
            removeFunction(ds, oldFacetAddress, selector);
        }
    }

    /// @notice Appends `_facetAddress` to `facetAddresses` and records its index.
    /// @dev Reverts if `_facetAddress` has no deployed code.
    function addFacet(DiamondStorage storage ds, address _facetAddress) internal {
        enforceHasContractCode(
            _facetAddress,
            "LibDiamondCut: New facet has no code"
        );
        ds.facetFunctionSelectors[_facetAddress].facetAddressPosition = ds
            .facetAddresses
            .length;
        ds.facetAddresses.push(_facetAddress);
    }

    /// @notice Routes `_selector` to `_facetAddress` and appends it to that
    ///         facet's selector list at index `_selectorPosition`.
    /// @dev Assumes `_selectorPosition` equals the facet's current selector count
    ///      (callers pass the pre-push length) — the stored index is not checked.
    function addFunction(
        DiamondStorage storage ds,
        bytes4 _selector,
        uint96 _selectorPosition,
        address _facetAddress
    ) internal {
        ds
            .selectorToFacetAndPosition[_selector]
            .functionSelectorPosition = _selectorPosition;
        ds.facetFunctionSelectors[_facetAddress].functionSelectors.push(
            _selector
        );
        ds.selectorToFacetAndPosition[_selector].facetAddress = _facetAddress;
    }

    /// @notice Unroutes `_selector` from `_facetAddress`.
    /// @dev Uses swap-with-last-and-pop on the facet's selector list. When the
    ///      facet's last selector is removed, the facet itself is removed from
    ///      `facetAddresses` the same way.
    ///      Reverts if `_facetAddress` is address(0) (selector not routed) or is
    ///      this contract (immutable function).
    function removeFunction(
        DiamondStorage storage ds,
        address _facetAddress,
        bytes4 _selector
    ) internal {
        require(
            _facetAddress != address(0),
            "LibDiamondCut: Can't remove function that doesn't exist"
        );
        // an immutable function is a function defined directly in a diamond
        require(
            _facetAddress != address(this),
            "LibDiamondCut: Can't remove immutable function"
        );
        // replace selector with last selector, then delete last selector
        uint256 selectorPosition = ds
            .selectorToFacetAndPosition[_selector]
            .functionSelectorPosition;
        uint256 lastSelectorPosition = ds
            .facetFunctionSelectors[_facetAddress]
            .functionSelectors
            .length - 1;
        // if not the same then replace _selector with lastSelector
        if (selectorPosition != lastSelectorPosition) {
            bytes4 lastSelector = ds
                .facetFunctionSelectors[_facetAddress]
                .functionSelectors[lastSelectorPosition];
            ds.facetFunctionSelectors[_facetAddress].functionSelectors[
                selectorPosition
            ] = lastSelector;
            ds
                .selectorToFacetAndPosition[lastSelector]
                .functionSelectorPosition = uint96(selectorPosition);
        }
        // delete the last selector
        ds.facetFunctionSelectors[_facetAddress].functionSelectors.pop();
        delete ds.selectorToFacetAndPosition[_selector];

        // if no more selectors for facet address then delete the facet address
        if (lastSelectorPosition == 0) {
            // replace facet address with last facet address and delete last facet address
            uint256 lastFacetAddressPosition = ds.facetAddresses.length - 1;
            uint256 facetAddressPosition = ds
                .facetFunctionSelectors[_facetAddress]
                .facetAddressPosition;
            if (facetAddressPosition != lastFacetAddressPosition) {
                address lastFacetAddress = ds.facetAddresses[
                    lastFacetAddressPosition
                ];
                ds.facetAddresses[facetAddressPosition] = lastFacetAddress;
                ds
                    .facetFunctionSelectors[lastFacetAddress]
                    .facetAddressPosition = facetAddressPosition;
            }
            ds.facetAddresses.pop();
            delete ds
                .facetFunctionSelectors[_facetAddress]
                .facetAddressPosition;
        }
    }

    /// @notice Optionally delegatecalls `_init` with `_calldata` after a cut.
    /// @dev `_init == address(0)` requires empty `_calldata` and does nothing.
    ///      Otherwise `_calldata` must be non-empty, `_init` must have code
    ///      (unless it is this contract), and the delegatecall runs in this
    ///      contract's storage context. If the call fails, its revert data is
    ///      re-raised unchanged so the original reason / custom error survives.
    ///      If there is no revert data, a fixed message is used.
    function initializeDiamondCut(address _init, bytes memory _calldata) internal {
        if (_init == address(0)) {
            require(
                _calldata.length == 0,
                "LibDiamondCut: _init is address(0) but_calldata is not empty"
            );
        } else {
            require(
                _calldata.length > 0,
                "LibDiamondCut: _calldata is empty but _init is not address(0)"
            );
            if (_init != address(this)) {
                enforceHasContractCode(
                    _init,
                    "LibDiamondCut: _init address has no code"
                );
            }
            // solhint-disable-next-line avoid-low-level-calls
            (bool success, bytes memory returnData) = _init.delegatecall(_calldata);
            if (!success) {
                if (returnData.length > 0) {
                    // Bubble up the raw revert data. Casting it to `string` and
                    // calling revert(string) would wrap the already ABI-encoded
                    // Error(string)/custom-error payload in a second Error(string),
                    // making the original reason undecodable.
                    // solhint-disable-next-line no-inline-assembly
                    assembly {
                        revert(add(returnData, 32), mload(returnData))
                    }
                } else {
                    revert("LibDiamondCut: _init function reverted");
                }
            }
        }
    }

    /// @notice Reverts with `_errorMessage` if `_contract` has no deployed code.
    /// @dev Code size is zero for EOAs and for contracts still in their constructor.
    function enforceHasContractCode(
        address _contract,
        string memory _errorMessage
    ) internal view {
        require(_contract.code.length > 0, _errorMessage);
    }
}