#pragma once

/// NOTE(review): this header looks damaged — everything between '<' and '>' appears to have been stripped.
/// The #include lines below have no targets, `template` keywords have no parameter lists, and
/// std::vector / std::optional / std::shared_ptr / std::pair / std::function / std::is_same_v / typeid_cast
/// are used without template arguments. Restore the angle-bracket contents from version control;
/// the comments below deliberately do not guess at the missing types.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace Poco { class Logger; }
namespace Poco::Net { class IPAddress; }

namespace DB
{
struct User;
class Credentials;
class ExternalAuthenticators;
enum class AuthenticationType : uint8_t;
class BackupEntriesCollector;
class RestorerFromBackup;

/// Result of a successful authentication.
struct AuthResult
{
    /// ID of the authenticated user.
    UUID user_id;
    /// Session settings received from the authentication server (if any).
    SettingsChanges settings{};
    /// Authentication data attached to the result (value-initialized by default).
    AuthenticationData authentication_data {};
};

/// Contains entities, i.e. instances of classes derived from IAccessEntity.
/// The implementations of this class MUST be thread-safe.
class IAccessStorage : public boost::noncopyable
{
public:
    explicit IAccessStorage(const String & storage_name_) : storage_name(storage_name_) {}
    virtual ~IAccessStorage() = default;

    /// If the AccessStorage has to do some complicated work when destroying, it should do it in advance here.
    /// For example, if the AccessStorage runs background threads, this is where they are asked to finish and are waited for.
    /// By default, does nothing.
    virtual void shutdown() {}

    /// Returns the name of this storage (as passed to the constructor).
    const String & getStorageName() const { return storage_name; }

    /// Returns the type name of this storage; every implementation must provide it.
    virtual const char * getStorageType() const = 0;

    /// Returns a JSON with the parameters of the storage. It's up to the storage type to fill the JSON.
    /// The default implementation returns an empty JSON object "{}".
    virtual String getStorageParamsJSON() const { return "{}"; }

    /// Returns true if this storage is readonly. Default: false.
    virtual bool isReadOnly() const { return false; }

    /// Returns true if the entity with the given ID is readonly.
    /// By default, an entity is readonly exactly when the whole storage is.
    virtual bool isReadOnly(const UUID &) const { return isReadOnly(); }

    /// Returns true if this storage is replicated. Default: false.
    virtual bool isReplicated() const { return false; }

    /// Returns the replication ID of this storage; empty by default.
    virtual String getReplicationID() const { return ""; }

    /// Starts periodic reloading and updating of entities in this storage. Does nothing by default.
    virtual void startPeriodicReloading() {}

    /// Stops periodic reloading and updating of entities in this storage. Does nothing by default.
    virtual void stopPeriodicReloading() {}

    /// Selects which access storages a reload() request applies to.
    enum class ReloadMode
    {
        /// Try to reload all access storages (including users.xml, local (disk) access storage, and replicated (in ZooKeeper) access storage).
        /// This mode is invoked by the SYSTEM RELOAD USERS command.
        ALL,

        /// Only reloads users.xml.
        /// This mode is invoked by the SYSTEM RELOAD CONFIG command.
        USERS_CONFIG_ONLY,
    };

    /// Makes this storage reload and update access entities right now. Does nothing by default.
    virtual void reload(ReloadMode /* reload_mode */) {}

    /// Returns the identifiers of all the entities of a specified type contained in the storage.
    std::vector findAll(AccessEntityType type) const;

    /// Returns the identifiers of all the entities in the storage
    /// (or only those of EntityClassT::TYPE — see the out-of-class definition below).
    template std::vector findAll() const;

    /// Searches for an entity with specified type and name. Returns std::nullopt if not found.
    std::optional find(AccessEntityType type, const String & name) const;

    /// Same as above, with the type taken from EntityClassT::TYPE.
    template std::optional find(const String & name) const { return find(EntityClassT::TYPE, name); }

    /// Searches for several entities of the specified type by name.
    std::vector find(AccessEntityType type, const Strings & names) const;

    /// Same as above, with the type taken from EntityClassT::TYPE.
    template std::vector find(const Strings & names) const { return find(EntityClassT::TYPE, names); }

    /// Searches for an entity with specified name and type. Throws an exception if not found.
    UUID getID(AccessEntityType type, const String & name) const;

    /// Same as above, with the type taken from EntityClassT::TYPE.
    template UUID getID(const String & name) const { return getID(EntityClassT::TYPE, name); }

    /// Looks up the IDs of several entities of the specified type by name.
    std::vector getIDs(AccessEntityType type, const Strings & names) const;

    /// Same as above, with the type taken from EntityClassT::TYPE.
    template std::vector getIDs(const Strings & names) const { return getIDs(EntityClassT::TYPE, names); }

    /// Returns whether there is an entity with such identifier in the storage.
    virtual bool exists(const UUID & id) const = 0;

    /// Returns whether entities with all/any of these identifiers exist — TODO confirm semantics in the .cpp.
    bool exists(const std::vector & ids) const;

    /// Reads an entity by ID. `throw_if_not_exists` is forwarded to readImpl(); when it does not throw,
    /// a missing entity yields nullptr. Throws (via throwBadCast) if the entity exists but has the wrong type.
    template std::shared_ptr read(const UUID & id, bool throw_if_not_exists = true) const;

    /// Reads an entity by name. Throws (via throwNotFound) if the name is unknown and `throw_if_not_exists` is true;
    /// otherwise returns nullptr in that case.
    template std::shared_ptr read(const String & name, bool throw_if_not_exists = true) const;

    /// Reads several entities by ID, in the order of `ids`.
    template std::vector read(const std::vector & ids, bool throw_if_not_exists = true) const;

    /// Reads an entity by ID. Returns nullptr if not found.
    template std::shared_ptr tryRead(const UUID & id) const;

    /// Reads an entity by name. Returns nullptr if not found.
    template std::shared_ptr tryRead(const String & name) const;

    /// Reads only the name of an entity.
    String readName(const UUID & id) const;
    std::optional readName(const UUID & id, bool throw_if_not_exists) const;
    Strings readNames(const std::vector & ids, bool throw_if_not_exists = true) const;
    std::optional tryReadName(const UUID & id) const;
    Strings tryReadNames(const std::vector & ids) const;

    /// Reads the name and the type of an entity.
    std::pair readNameWithType(const UUID & id) const;
    std::optional> readNameWithType(const UUID & id, bool throw_if_not_exists) const;
    std::optional> tryReadNameWithType(const UUID & id) const;

    /// Reads all entities and returns them with their IDs.
    /// Entities that cannot be read when visited (tryRead returns null) are skipped.
    template std::vector>> readAllWithIDs() const;

    /// Reads all entities of the specified type and returns them with their IDs.
    std::vector> readAllWithIDs(AccessEntityType type) const;

    /// Inserts an entity into the storage. Returns the ID of the new entry in the storage.
    /// Throws an exception if the specified name already exists.
    UUID insert(const AccessEntityPtr & entity);

    /// Flag-controlled variants of insert(). If non-null, `conflicting_id` presumably receives the ID of a
    /// conflicting existing entity — TODO confirm in the .cpp.
    std::optional insert(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists, UUID * conflicting_id = nullptr);
    bool insert(const UUID & id, const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists, UUID * conflicting_id = nullptr);

    /// Inserts several entities at once, optionally with caller-supplied IDs.
    std::vector insert(const std::vector & multiple_entities, bool replace_if_exists = false, bool throw_if_exists = true);
    std::vector insert(const std::vector & multiple_entities, const std::vector & ids, bool replace_if_exists = false, bool throw_if_exists = true);

    /// Tries to insert an entity into the storage. Returns the ID of the new entry; the result is optional,
    /// presumably empty when the entity couldn't be inserted — TODO confirm in the .cpp.
    std::optional tryInsert(const AccessEntityPtr & entity);
    std::vector tryInsert(const std::vector & multiple_entities);

    /// Inserts an entity into the storage. Returns the ID of the new entry in the storage.
    /// Replaces an existing entry in the storage if the specified name already exists.
    UUID insertOrReplace(const AccessEntityPtr & entity);
    std::vector insertOrReplace(const std::vector & multiple_entities);

    /// Removes an entity from the storage. Throws an exception if it couldn't be removed
    /// (when `throw_if_not_exists` is true).
    bool remove(const UUID & id, bool throw_if_not_exists = true);
    std::vector remove(const std::vector & ids, bool throw_if_not_exists = true);

    /// Removes an entity from the storage. Returns false if it couldn't be removed.
    bool tryRemove(const UUID & id);

    /// Removes multiple entities from the storage. Returns the list of successfully removed entities.
    std::vector tryRemove(const std::vector & ids);

    /// Callback passed to update(); presumably maps the current entity to its updated version
    /// (the std::function signature is missing here — TODO restore).
    using UpdateFunc = std::function;

    /// Updates an entity stored in the storage. Throws an exception if it couldn't be updated
    /// (when `throw_if_not_exists` is true).
    bool update(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists = true);
    std::vector update(const std::vector & ids, const UpdateFunc & update_func, bool throw_if_not_exists = true);

    /// Updates an entity stored in the storage. Returns false if it couldn't be updated.
    bool tryUpdate(const UUID & id, const UpdateFunc & update_func);

    /// Updates multiple entities in the storage. Returns the list of successfully updated entities.
    std::vector tryUpdate(const std::vector & ids, const UpdateFunc & update_func);

    /// Finds a user, checks the provided credentials and returns the ID of the user if they are valid.
    /// Throws an exception if there is no such user or the credentials are invalid.
    AuthResult authenticate(
        const Credentials & credentials,
        const Poco::Net::IPAddress & address,
        const ExternalAuthenticators & external_authenticators,
        bool allow_no_password,
        bool allow_plaintext_password) const;

    /// Same as above, but with an optional result; `throw_if_user_not_exists` controls whether an unknown
    /// user raises an exception.
    std::optional authenticate(
        const Credentials & credentials,
        const Poco::Net::IPAddress & address,
        const ExternalAuthenticators & external_authenticators,
        bool throw_if_user_not_exists,
        bool allow_no_password,
        bool allow_plaintext_password) const;

    /// Returns true if this storage can be stored to a backup. Default: false.
    virtual bool isBackupAllowed() const { return false; }

    /// Returns true if this storage can be restored from a backup.
    /// By default, this requires backups to be allowed and the storage to be writable.
    virtual bool isRestoreAllowed() const { return isBackupAllowed() && !isReadOnly(); }

    /// Makes a backup of the entities of the given type in this access storage.
    virtual void backup(BackupEntriesCollector & backup_entries_collector, const String & data_path_in_backup, AccessEntityType type) const;

    /// Restores this access storage from a backup.
    virtual void restoreFromBackup(RestorerFromBackup & restorer, const String & data_path_in_backup);

protected:
    /// Implementation hooks for concrete storages. The public templates in this header call
    /// findAllImpl() / readImpl(); the other public wrappers presumably delegate to the matching *Impl — confirm in the .cpp.
    virtual std::optional findImpl(AccessEntityType type, const String & name) const = 0;
    virtual std::vector findAllImpl(AccessEntityType type) const = 0;
    /// Non-pure: has a default implementation (in the .cpp).
    virtual std::vector findAllImpl() const;
    virtual AccessEntityPtr readImpl(const UUID & id, bool throw_if_not_exists) const = 0;
    virtual std::optional> readNameWithTypeImpl(const UUID & id, bool throw_if_not_exists) const;
    virtual bool insertImpl(const UUID & id, const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists, UUID * conflicting_id);
    virtual bool removeImpl(const UUID & id, bool throw_if_not_exists);
    virtual bool updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists);
    virtual std::optional authenticateImpl(
        const Credentials & credentials,
        const Poco::Net::IPAddress & address,
        const ExternalAuthenticators & external_authenticators,
        bool throw_if_user_not_exists,
        bool allow_no_password,
        bool allow_plaintext_password) const;

    /// Checks `credentials` against the user's authentication method; may fill `settings`
    /// (non-const reference) with settings to apply to the session.
    virtual bool areCredentialsValid(
        const std::string & user_name,
        const AuthenticationData & authentication_method,
        const Credentials & credentials,
        const ExternalAuthenticators & external_authenticators,
        SettingsChanges & settings) const;

    /// Checks whether `user` may connect from `address`.
    virtual bool isAddressAllowed(const User & user, const Poco::Net::IPAddress & address) const;

    static UUID generateRandomID();
    LoggerPtr getLogger() const;

    /// Formats a name together with its entity type, delegating to AccessEntityTypeInfo.
    static String formatEntityTypeWithName(AccessEntityType type, const String & name) { return AccessEntityTypeInfo::get(type).formatEntityNameWithType(name); }

    static void clearConflictsInEntitiesList(std::vector> & entities, LoggerPtr log_);

    /// Returns false by default; overridden by storages that coordinate replicated restores.
    virtual bool acquireReplicatedRestore(RestorerFromBackup &) const { return false; }

    /// Error helpers: each one throws and never returns.
    [[noreturn]] void throwNotFound(const UUID & id) const;
    [[noreturn]] void throwNotFound(AccessEntityType type, const String & name) const;
    [[noreturn]] static void throwBadCast(const UUID & id, AccessEntityType type, const String & name, AccessEntityType required_type);
    [[noreturn]] void throwIDCollisionCannotInsert(
        const UUID & id, AccessEntityType type, const String & name, AccessEntityType existing_type, const String & existing_name) const;
    [[noreturn]] void throwNameCollisionCannotInsert(AccessEntityType type, const String & name) const;
    [[noreturn]] void throwNameCollisionCannotRename(AccessEntityType type, const String & old_name, const String & new_name) const;
    [[noreturn]] void throwReadonlyCannotInsert(AccessEntityType type, const String & name) const;
    [[noreturn]] void throwReadonlyCannotUpdate(AccessEntityType type, const String & name) const;
    [[noreturn]] void throwReadonlyCannotRemove(AccessEntityType type, const String & name) const;
    [[noreturn]] static void throwAddressNotAllowed(const Poco::Net::IPAddress & address);
    [[noreturn]] static void throwInvalidCredentials();
    [[noreturn]] void throwBackupNotAllowed() const;
    [[noreturn]] void throwRestoreNotAllowed() const;

private:
    const String storage_name;

    /// Logger state; `log` is mutable and guarded by a OnceFlag, so it is presumably
    /// initialized lazily by getLogger() — TODO confirm in the .cpp.
    mutable OnceFlag log_initialized;
    mutable LoggerPtr log = nullptr;
};


/// If EntityClassT is the generic entity type (the std::is_same_v arguments are missing here — presumably
/// IAccessEntity), returns all IDs; otherwise returns only the IDs of EntityClassT::TYPE.
template std::vector
IAccessStorage::findAll() const
{
    if constexpr (std::is_same_v)
        return findAllImpl();
    else
        return findAllImpl(EntityClassT::TYPE);
}


/// Reads the entity via readImpl() and, unless the generic entity type was requested, casts it to EntityClassT.
template std::shared_ptr IAccessStorage::read(const UUID & id, bool throw_if_not_exists) const
{
    auto entity = readImpl(id, throw_if_not_exists);
    if constexpr (std::is_same_v)
        return entity;
    else
    {
        if (!entity)
            return nullptr;
        if (auto ptr = typeid_cast>(entity))
            return ptr;
        // throwBadCast is [[noreturn]], so control never falls off the end of this function.
        throwBadCast(id, entity->getType(), entity->getName(), EntityClassT::TYPE);
    }
}


/// Resolves the name to an ID with find() and then reads by ID.
template std::shared_ptr IAccessStorage::read(const String & name, bool throw_if_not_exists) const
{
    if (auto id = find(name))
        return read(*id, throw_if_not_exists);
    // throwNotFound is [[noreturn]]; the else-branch covers the non-throwing mode.
    if (throw_if_not_exists)
        throwNotFound(EntityClassT::TYPE, name);
    else
        return nullptr;
}


/// Reads each ID in turn; the result has one element per input ID, in the same order.
template std::vector IAccessStorage::read(const std::vector & ids, bool throw_if_not_exists) const
{
    std::vector result;
    result.reserve(ids.size());
    for (const auto & id : ids)
        result.push_back(read(id, throw_if_not_exists));
    return result;
}


/// Non-throwing read by ID: equivalent to read(id, false).
template std::shared_ptr IAccessStorage::tryRead(const UUID & id) const
{
    return read(id, false);
}


/// Non-throwing read by name: equivalent to read(name, false).
template std::shared_ptr IAccessStorage::tryRead(const String & name) const
{
    return read(name, false);
}


/// Lists the IDs with findAll() and then reads each one with tryRead().
/// IDs whose entity can no longer be read by that point are skipped rather than reported.
template std::vector>> IAccessStorage::readAllWithIDs() const
{
    std::vector>> entities;
    for (const auto & id : findAll())
    {
        if (auto entity = tryRead(id))
            entities.emplace_back(id, entity);
    }
    return entities;
}

/// Parses an access storage name, accepted either as an identifier or as a string literal.
inline bool parseAccessStorageName(IParser::Pos & pos, Expected & expected, String & storage_name)
{
    return parseIdentifierOrStringLiteral(pos, expected, storage_name);
}

}