# canaille.app.configuration — configuration loading, validation, network checks and TOML export.

import os
import re
import smtplib
import socket
import textwrap
from dataclasses import dataclass

from pydantic import BaseModel as PydanticBaseModel
from pydantic import ValidationError
from pydantic import create_model
from pydantic_core import PydanticUndefined
from pydantic_settings import BaseSettings
from pydantic_settings import SettingsConfigDict

try:
    import tomlkit

    HAS_TOMLKIT = True
except ImportError:
    HAS_TOMLKIT = False
# Directory two levels above this module's file.
ROOT = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__)),
)

# Configuration file name, resolved against the current working directory,
# used by ``setup_config`` when the ``CONFIG`` environment variable is not set.
DEFAULT_CONFIG_FILE = "canaille.toml"


class BaseModel(PydanticBaseModel):
    # Take each field's description from the string literal written right
    # after the field, so documentation lives next to the field definition.
    model_config = SettingsConfigDict(use_attribute_docstrings=True)


class RootSettings(BaseSettings):
    """The top-level namespace contains the configuration settings unrelated to Canaille.

    The configuration parameters from the following libraries can be used:

    - :doc:`Flask <flask:config>`
    - :doc:`Flask-WTF <flask-wtf:config>`
    - :doc:`Flask-Babel <flask-babel:index>`
    - :doc:`Authlib <authlib:flask/2/authorization-server>`

    .. code-block:: toml
        :caption: config.toml

        SECRET_KEY = "very-secret"
        SERVER_NAME = "auth.mydomain.example"
        PREFERRED_URL_SCHEME = "https"
        DEBUG = false

        [CANAILLE]
        NAME = "My organization"
        ...
    """

    model_config = SettingsConfigDict(
        # Unknown top-level keys are kept, so any Flask/extension setting can be passed.
        extra="allow",
        # Nested sections can be set from the environment, e.g. CANAILLE__NAME.
        env_nested_delimiter="__",
        case_sensitive=True,
        use_attribute_docstrings=True,
    )

    SECRET_KEY: str | None = None
    """The Flask :external:py:data:`SECRET_KEY` configuration setting.

    You MUST set a value before deploying in production.
    """

    SERVER_NAME: str | None = None
    """The Flask :external:py:data:`SERVER_NAME` configuration setting.

    This sets the domain name on which canaille will be served.
    """

    PREFERRED_URL_SCHEME: str = "https"
    """The Flask :external:py:data:`PREFERRED_URL_SCHEME` configuration setting.

    This sets the URL scheme by which canaille will be served.
    """

    DEBUG: bool = False
    """The Flask :external:py:data:`DEBUG` configuration setting.

    This enables debug options.

    .. danger::

        This is useful for development but should be absolutely avoided in production environments.
    """
def _section_requested(key, config, all_options, env_prefix=""):
    """Tell whether the optional ``key`` settings section must be declared.

    A section is declared when every option is requested, when the key is present
    in the configuration mapping, or when a nested environment variable targets
    it (``{key}__...``, with or without ``env_prefix``).
    """
    if all_options or key in config:
        return True
    # The un-prefixed form keeps the historical behavior; the prefixed form is the
    # one pydantic-settings actually reads when ``env_prefix`` is set.
    prefixes = (f"{key}__", f"{env_prefix}{key}__")
    return any(var.startswith(prefixes) for var in os.environ)


def settings_factory(
    config=None,
    env_file=None,
    env_prefix="",
    all_options=False,
    init_with_examples=False,
):
    """Build the dynamic ``Settings`` model and instantiate it.

    Backend and feature specific settings (SQL, LDAP, OIDC, SCIM) are imported
    lazily and only pushed into the model when needed, in order to break the
    dependency against backend libraries like python-ldap or sqlalchemy.

    :param config: Configuration mapping, e.g. a loaded TOML document.
    :param env_file: Forwarded to pydantic-settings as ``_env_file``.
    :param env_prefix: Forwarded to pydantic-settings as ``_env_prefix``.
    :param all_options: Declare every optional section regardless of ``config``
        and of the environment.
    :param init_with_examples: Initialize the sections with their fields'
        examples instead of their defaults.
    :raises pydantic.ValidationError: If the configuration is invalid.
    """
    from canaille.core.configuration import CoreSettings

    config = config or {}
    default = example_settings(CoreSettings) if init_with_examples else CoreSettings()
    attributes = {"CANAILLE": (CoreSettings, default)}

    if _section_requested("CANAILLE_SQL", config, all_options, env_prefix):
        from canaille.backends.sql.configuration import SQLSettings

        default = example_settings(SQLSettings) if init_with_examples else None
        attributes["CANAILLE_SQL"] = ((SQLSettings | None), default)

    if _section_requested("CANAILLE_LDAP", config, all_options, env_prefix):
        from canaille.backends.ldap.configuration import LDAPSettings

        default = example_settings(LDAPSettings) if init_with_examples else None
        attributes["CANAILLE_LDAP"] = ((LDAPSettings | None), default)

    if _section_requested("CANAILLE_OIDC", config, all_options, env_prefix):
        from canaille.oidc.configuration import OIDCSettings

        default = example_settings(OIDCSettings) if init_with_examples else None
        attributes["CANAILLE_OIDC"] = ((OIDCSettings | None), default)

    if _section_requested("CANAILLE_SCIM", config, all_options, env_prefix):
        from canaille.scim.configuration import SCIMSettings

        default = example_settings(SCIMSettings) if init_with_examples else None
        attributes["CANAILLE_SCIM"] = ((SCIMSettings | None), default)

    Settings = create_model(
        "Settings",
        __base__=RootSettings,
        **attributes,
    )

    return Settings(
        **config,
        _secrets_dir=os.environ.get("SECRETS_DIR"),
        _env_file=env_file,
        _env_prefix=env_prefix,
    )


class ConfigurationException(Exception):
    pass


@dataclass
class CheckResult:
    # Human readable outcome of the check.
    message: str
    # True/False for a performed check; None when nothing was checked
    # (e.g. "No SMTP server configured").
    success: bool | None = None


def setup_config(app, config=None, env_file=None, env_prefix=""):
    """Load and validate the configuration, then push it into ``app.config``.

    When no ``config`` mapping is given and tomlkit is available, the TOML file
    named by the ``CONFIG`` environment variable is loaded, falling back to
    ``canaille.toml`` in the current directory if it exists.

    :return: ``False`` if the configuration is invalid, ``True`` otherwise.
    """
    from canaille.oidc.installation import install

    app.config.from_mapping(
        {
            # https://flask.palletsprojects.com/en/stable/config/#SESSION_COOKIE_NAME
            "SESSION_COOKIE_NAME": "canaille",
        }
    )
    if HAS_TOMLKIT and not config:
        config_file = os.environ.get("CONFIG")
        if config_file is None and os.path.exists(DEFAULT_CONFIG_FILE):
            config_file = DEFAULT_CONFIG_FILE

        if config_file is not None:
            # TOML documents are UTF-8 by specification: do not rely on the locale.
            with open(config_file, encoding="utf-8") as fd:
                config = tomlkit.load(fd)
            app.logger.info(f"Loading configuration from {config_file}")

    env_file = env_file or os.getenv("ENV_FILE")
    try:
        config_obj = settings_factory(
            config or {}, env_file=env_file, env_prefix=env_prefix
        )
    except ValidationError as exc:  # pragma: no cover
        app.logger.critical(str(exc))
        return False

    config_dict = config_obj.model_dump()
    app.no_secret_key = config_dict["SECRET_KEY"] is None
    app.config.from_mapping(config_dict)

    if app.debug:
        install(app.config, debug=True)

    return True


def check_network_config(config):
    """Perform network connections to the services described in the configuration file.

    :return: A list of :class:`CheckResult`: the backend check, then SMTP, then SMPP.
    """
    from canaille.backends import Backend

    results = [Backend.instance.check_network_config(config)]
    if smtp_config := config["CANAILLE"]["SMTP"]:
        results.append(check_smtp_connection(smtp_config))
    else:
        results.append(CheckResult(message="No SMTP server configured"))

    if smpp_config := config["CANAILLE"]["SMPP"]:
        results.append(check_smpp_connection(smpp_config))
    else:
        results.append(CheckResult(message="No SMPP server configured"))

    return results


def check_smtp_connection(config) -> CheckResult:
    """Try to connect to the SMTP server, with optional STARTTLS and login.

    :param config: The ``CANAILLE.SMTP`` mapping (``HOST``, ``PORT``, ``TLS``,
        ``LOGIN``, ``PASSWORD``).
    """
    host = config["HOST"]
    port = config["PORT"]
    try:
        with smtplib.SMTP(host=host, port=port) as smtp:
            if config["TLS"]:
                smtp.starttls()

            if config["LOGIN"]:
                smtp.login(
                    user=config["LOGIN"],
                    password=config["PASSWORD"],
                )
    except smtplib.SMTPAuthenticationError:
        return CheckResult(
            message=f"SMTP authentication failed with user '{config['LOGIN']}'",
            success=False,
        )
    except smtplib.SMTPException as exc:
        # Includes SMTPNotSupportedError. Must come before OSError, since
        # SMTPException is itself an OSError subclass.
        return CheckResult(
            message=str(exc),
            success=False,
        )
    except OSError:
        # DNS failures (socket.gaierror), refused connections, timeouts,
        # unreachable hosts...
        return CheckResult(
            message=f"Could not connect to the SMTP server '{host}' on port '{port}'",
            success=False,
        )

    return CheckResult(
        message="Successful SMTP connection",
        success=True,
    )


def check_smpp_connection(config) -> CheckResult:
    """Try to connect to the SMPP server, and bind as a transmitter if a login is set.

    :param config: The ``CANAILLE.SMPP`` mapping (``HOST``, ``PORT``, ``LOGIN``,
        ``PASSWORD``).
    """
    # Import the submodules explicitly: ``import smpplib`` alone does not
    # guarantee that ``smpplib.client`` and ``smpplib.exceptions`` are loaded.
    import smpplib.client
    import smpplib.exceptions

    host = config["HOST"]
    port = config["PORT"]
    try:
        with smpplib.client.Client(host, port, allow_unknown_opt_params=True) as client:
            client.connect()
            if config["LOGIN"]:
                client.bind_transmitter(
                    system_id=config["LOGIN"], password=config["PASSWORD"]
                )
    except smpplib.exceptions.ConnectionError:
        return CheckResult(
            success=False,
            message=f"Could not connect to the SMPP server '{host}' on port '{port}'",
        )
    except smpplib.exceptions.UnknownCommandError as exc:  # pragma: no cover
        return CheckResult(
            message=str(exc),
            success=False,
        )

    return CheckResult(
        message="Successful SMPP connection",
        success=True,
    )


def sanitize_rst_text(text: str) -> str:
    """Remove inline RST syntax."""
    # Replace :foo:`~bar.Baz` with Baz
    text = re.sub(r":[\w:-]+:`~[\w\.]+\.(\w+)`", r"\1", text)

    # Replace :foo:`bar` and :foo:`bar <anything>` with bar
    text = re.sub(r":[\w:-]+:`([^`<]+)(?: <[^`>]+>)?`", r"\1", text)

    # Replace `label <URL>`_ with label (URL)
    text = re.sub(r"`([^`<]+) <([^`>]+)>`_", r"\1 (\2)", text)

    # Replace ``foo`` with foo
    text = re.sub(r"``([^`]+)``", r"\1", text)

    # Remove RST directive headers such as ".. danger::" and the blank line after them
    text = re.sub(r"\.\. [\w-]+::( \w+)?\n\n", "", text)

    return text


def sanitize_comments(text: str, line_length: int) -> str:
    """Remove RST syntax and wrap the docstring so it displays well as TOML comments.

    Paragraphs are re-wrapped to ``line_length`` columns, except code blocks
    (every line indented) and bullet lists, which are kept verbatim.
    """

    def is_code_block(text: str) -> bool:
        return all(line.startswith(" ") for line in text.splitlines())

    def is_list(text: str) -> bool:
        return all(
            line.startswith("-") or line.startswith(" ") for line in text.splitlines()
        )

    text = sanitize_rst_text(text)
    paragraphs = text.split("\n\n")
    paragraphs = [
        textwrap.fill(paragraph, width=line_length)
        if not is_code_block(paragraph) and not is_list(paragraph)
        else paragraph
        for paragraph in paragraphs
    ]
    text = "\n\n".join(paragraphs)
    return text


def export_object_to_toml(
    obj,
    with_comments: bool = True,
    with_defaults: bool = True,
    line_length: int = 80,
):
    """Create a tomlkit document from an object.

    Settings objects become a tomlkit document, other models become tables,
    lists become arrays and dicts become (inline) tables. Other values are
    returned unchanged.

    :param with_comments: Emit the field descriptions as TOML comments.
    :param with_defaults: Emit commented-out default values of non-complex fields.
    :param line_length: Width used to wrap the comments.
    """
    # Forwarded to every recursive call so nested sections honor the caller's choices.
    options = {
        "with_comments": with_comments,
        "with_defaults": with_defaults,
        "line_length": line_length,
    }

    def is_complex(obj) -> bool:
        return isinstance(obj, list | dict | BaseModel | BaseSettings)

    if isinstance(obj, BaseModel | BaseSettings):
        doc = tomlkit.document() if isinstance(obj, BaseSettings) else tomlkit.table()
        # Read the fields from the class: accessing model_fields on an instance is
        # deprecated since Pydantic 2.11.
        for field_name, field_info in type(obj).model_fields.items():
            field_value = getattr(obj, field_name)
            display_value = field_value is not None and (
                isinstance(field_value, BaseModel | BaseSettings)
                or field_value != field_info.default
            )
            display_comments = with_comments and field_info.description
            display_default_value = (
                with_defaults
                and not is_complex(field_info.default)
                and not is_complex(field_value)
            )

            if display_comments and (display_default_value or display_value):
                sanitized = sanitize_comments(field_info.description, line_length)
                for line in sanitized.splitlines():
                    doc.add(tomlkit.comment(line))

            if display_default_value:
                # Required fields (PydanticUndefined) and None defaults render as "FIELD =".
                parsed = (
                    tomlkit.item(field_info.default).as_string()
                    if field_info.default is not None
                    and field_info.default is not PydanticUndefined
                    else ""
                )
                doc.add(tomlkit.comment(f"{field_name} = {parsed}".strip()))

            if display_value:
                doc.add(field_name, export_object_to_toml(field_value, **options))

            doc.add(tomlkit.nl())

        return doc

    elif isinstance(obj, list):
        max_inline_items = 4
        is_multiline = len(obj) > max_inline_items or all(
            is_complex(item) for item in obj
        )
        doc = tomlkit.array().multiline(is_multiline)
        for item in obj:
            doc.append(export_object_to_toml(item, **options))
        return doc

    elif isinstance(obj, dict):
        inline = all(not is_complex(item) for item in obj.values())
        doc = tomlkit.inline_table() if inline else tomlkit.table()
        for key, value in obj.items():
            doc.add(key, export_object_to_toml(value, **options))
        return doc

    else:
        return obj


def export_config(model: BaseSettings, filename: str):
    """Write ``model`` as a commented TOML file at ``filename`` (UTF-8 encoded)."""
    doc = export_object_to_toml(model)
    content = tomlkit.dumps(doc)

    # Remove end-of-line spaces
    content = re.sub(r" +\n", "\n", content)

    # Remove multiple new-lines
    content = re.sub(r"\n\n+", "\n\n", content)

    # Remove end-of-file new-line
    content = re.sub(r"\n+\Z", "\n", content)

    # TOML files must be UTF-8; descriptions may contain non-ASCII characters.
    with open(filename, "w", encoding="utf-8") as fd:
        fd.write(content)


def example_settings(model: type[BaseModel]) -> BaseModel:
    """Init a pydantic BaseModel with values passed as Field 'examples'.

    Only the first example of each field is used. Fields without examples are
    not passed, so the model's own defaults (or validation rules) apply to them.
    """
    data = {
        field_name: field_info.examples[0]
        for field_name, field_info in model.model_fields.items()
        if field_info.examples
    }
    return model.model_validate(data)