# Source code for pywasp._cfg

import warnings
import json
from logging import getLogger
import os
import time
import uuid
from pathlib import Path
from typing import Annotated, Union, Literal
import pandas as pd

from pydantic import (
    BaseModel,
    Field,
    ValidationError,
    field_serializer,
    model_validator,
    Discriminator,
    ConfigDict,
)
from pydantic_settings import (
    BaseSettings,
    PydanticBaseSettingsSource,
    SettingsConfigDict,
    TomlConfigSettingsSource,
)

import requests
from platformdirs import PlatformDirs, user_config_path
import toml

from ._errors import PywaspError, LicenseError

logger = getLogger(__name__)

# Constants
APP_NAME = "pywasp"
APP_AUTHOR = "DTU Wind Energy"
# Default port of a local license server. Note: stored as a string here,
# whereas LicensingLocalCfg.port is an int field.
DEFAULT_PORT = "34523"
# NOTE(review): not referenced in this section of the module; presumably the
# address of a DTU-hosted license server — confirm before relying on it.
DEFAULT_DTU_HOST = "130.226.48.178"
# Keygen account ID (part of the license validation URL) and product ID (sent
# as the validation scope) used by LicensingRemoteCfg.
ACCOUNT = "0244d026-758b-442b-9668-811ce125a7e3"
PRODUCT_ID = "f0c12bfb-c1f5-4fc2-bfb1-bd20fbdae727"

# Maximum number of points allowed in one calculation, keyed by the Keygen
# policy ID found in the license validation response.
MAX_PTS = {
    "b32d02b9-3d7e-43a4-80a3-4b8aa13083c6": 25_000,  # Tier 1
    "87fcb9e4-09df-4673-bb32-e2748c49a183": 125_000,  # Tier 2
    "819d348f-b49e-4ddd-aeac-c1074de9c971": 625_000,  # Tier 3
}

# Per-user configuration directory; created at import time so the config file
# can always be written.
# Needs roaming=True to be consistent with rusts app_dirs2 package.
_CONFIG_DIR = user_config_path(appname=APP_NAME, appauthor=APP_AUTHOR, roaming=True)
_CONFIG_DIR.mkdir(parents=True, exist_ok=True)


# TOML configuration file, also read by the Rust side. pathlib collapses the
# leading "./", so this is simply <_CONFIG_DIR>/pywasp_config.toml.
_CONFIG_FILE = _CONFIG_DIR / "./pywasp_config.toml"


class DownloadCfg(BaseModel):
    """Preferences controlling pywasp's data downloads.

    Attributes
    ----------
    download_prompt : bool
        Whether download prompts are enabled. Defaults to True.
    download_global_nc_files : bool
        Whether global NC files are downloaded automatically. Defaults to True.
    """

    download_prompt: bool = Field(default=True)
    download_global_nc_files: bool = Field(default=True)


def _check_url_connectivity(test_url):
    """
    Attempt a GET request to ``test_url`` and fail fast if it is not usable.

    Redirects are not followed, and a 4xx/5xx status counts as a failure.

    Parameters
    ----------
    test_url : str
        The URL to test connectivity against.

    Raises
    ------
    LicenseError
        ``LicenseError(123)`` when TLS certificate verification fails.
    PywaspError
        For any other SSL error.
    ValueError
        When the host is unreachable, the request times out, the server
        answers with an error status, or any other unexpected error occurs.
        The underlying exception is chained as ``__cause__``.
    """
    logger.debug("Attempting to connect to: %s", test_url)
    try:
        response = requests.get(test_url, timeout=10, allow_redirects=False)
        # Raise an HTTPError for bad responses (4xx or 5xx)
        response.raise_for_status()
    except requests.exceptions.SSLError as e:
        if "SSL: CERTIFICATE_VERIFY_FAILED" in str(e):
            raise LicenseError(123) from e
        raise PywaspError(
            f"Problem with ssl verification, try adding a root_ca_filepath to your Licensing configuration: {e}"
        ) from e
    except requests.exceptions.ConnectionError as e:
        # requests' ConnectTimeout subclasses ConnectionError, so it lands here.
        raise ValueError(
            f"Failed to connect to license server at {test_url}. "
            f"Host might be unreachable or port incorrect. Error: {e}"
        ) from e
    except requests.exceptions.Timeout as e:
        raise ValueError(
            f"Connection to license server at {test_url} timed out. "
            f"Server might be slow or unresponsive. Error: {e}"
        ) from e
    except requests.exceptions.RequestException as e:
        # Catches other requests-related errors, including HTTPError from raise_for_status()
        raise ValueError(
            f"An error occurred while checking license server at {test_url}. Error: {e}"
        ) from e
    except Exception as e:
        # Catch any other unexpected errors
        raise ValueError(
            f"An unexpected error occurred during URL check for {test_url}. Error: {e}"
        ) from e
    else:
        logger.debug(
            "Successfully connected to %s (Status: %s)", test_url, response.status_code
        )


class LicensingLocalCfg(BaseModel):
    """Licensing configuration for a license server reached over plain HTTP.

    Creating (validating) an instance sends a test request to the server and
    raises if it cannot be reached.
    """

    # Discriminator value used by Settings.licensing.
    license_type: Literal["local"] = "local"
    host: str
    # Same value as the module-level DEFAULT_PORT constant.
    port: int = Field(int(DEFAULT_PORT), ge=0, le=65535)

    @property
    def _url(self) -> str:
        """Base HTTP URL of the license server."""
        host = self.host
        # A bare IPv6 literal (e.g. "::1") must be bracketed inside a URL,
        # otherwise its colons are confused with the port separator.
        if ":" in host and not host.startswith("["):
            host = f"[{host}]"
        return f"http://{host}:{self.port}"

    @model_validator(mode="after")
    def check_url_connectivity(self):
        """Ensure the license server answers before accepting the config."""
        _check_url_connectivity(f"{self._url}")
        return self


class LicensingRemoteCfg(BaseModel):
    """Licensing configuration validated against the Keygen cloud API.

    Creating (validating) an instance sets ``REQUESTS_CA_BUNDLE`` when
    ``root_ca_filepath`` is given, then checks that ``<host>/v1/ping`` answers.
    """

    # Discriminator value used by Settings.licensing.
    license_type: Literal["keygen_cloud"] = "keygen_cloud"
    # Host used instead of api.keygen.sh (e.g. license.windenergy.dtu.dk).
    redirect_host: str | None = None
    license_id: uuid.UUID
    activation_token: str
    # Optional CA bundle for TLS verification (exported as REQUESTS_CA_BUNDLE).
    root_ca_filepath: Path | None = None

    # NOTE(review): UUID and Path are natively supported by pydantic, so this
    # may be unnecessary; kept to avoid changing validation behavior.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @property
    def _url(self) -> str:
        """Construct the base URL for the Keygen API. Handles redirection if specified."""
        if self.redirect_host is None:
            host = "api.keygen.sh"
        else:
            host = self.redirect_host
        return f"https://{host}"

    @model_validator(mode="after")
    def setup_env_and_check_url_connectivity(self):
        """Set up SSL certificate in environment variables and check URL connectivity.

        Note: this mutates the process-wide ``os.environ``.
        """
        if self.root_ca_filepath is not None:
            env_var_name = "REQUESTS_CA_BUNDLE"
            os.environ[env_var_name] = str(self.root_ca_filepath)
            logger.info(
                f"Environment variable '{env_var_name}' set to: {self.root_ca_filepath}"
            )
        else:
            logger.info("No 'root_ca_filepath' provided, environment variable not set.")

        _check_url_connectivity(f"{self._url}/v1/ping")
        return self

    @field_serializer("license_id")
    def serialize_license_id(self, license_id: uuid.UUID) -> str:
        """Serialize the UUID license ID to a string for API requests."""
        return str(license_id)

    @field_serializer("root_ca_filepath")
    def serialize_root_ca_filepath(self, root_ca_filepath: Path | None) -> str | None:
        """Serialize the root CA file path to a string.

        Raises ValueError if the path no longer exists.
        """
        if root_ca_filepath is None:
            return None
        if not root_ca_filepath.exists():
            raise ValueError(
                f"Root CA file does not exist at specified path: {root_ca_filepath}"
            )
        return str(root_ca_filepath)

    def _request_validation(self) -> dict:
        """Validate the license with the Keygen API and return the JSON response.

        The response includes information about the license status, such as the
        number of runs used, the runs limit, and the expiry date.

        Transient failures (network errors, unparsable bodies, 408/429/5xx and
        other non-200 non-4xx responses) are retried, up to 5 attempts in
        total, with exponential backoff (0.5 s, 1 s, 2 s, 4 s between attempts).

        Raises
        ------
        ValueError
            If the license is not valid, or the activation token is rejected (401).
        LicenseError
            If TLS certificate verification fails.
        PywaspError
            For any other SSL error.
        Exception
            The last error seen (typically ``requests.RequestException`` /
            ``requests.HTTPError``) if every attempt failed, or immediately
            for a non-retryable client error (4xx other than 401/408/429).
        """
        max_attempts = 5
        sleep_s = 0.5
        last_exception: Exception | None = None

        # The request is identical on every attempt, so build it once.
        uri_validate = f"{self._url}/v1/accounts/{ACCOUNT}/licenses/{self.license_id}/actions/validate"
        headers_validate = {
            "Content-Type": "application/vnd.api+json",
            "Accept": "application/vnd.api+json",
            "Authorization": f"Bearer {self.activation_token}",
        }
        data_validate = json.dumps({"meta": {"scope": {"product": f"{PRODUCT_ID}"}}})

        for attempt in range(max_attempts):
            try:
                res = requests.post(
                    uri_validate,
                    headers=headers_validate,
                    data=data_validate,
                    timeout=30,
                )
            except requests.exceptions.SSLError as e:
                if "SSL: CERTIFICATE_VERIFY_FAILED" in str(e):
                    raise LicenseError(123) from e
                raise PywaspError(f"Problem with connection: {e}") from e
            except requests.exceptions.RequestException as e:
                last_exception = e
            else:
                if res.status_code == 200:
                    try:
                        res_json = res.json()
                    except ValueError as e:
                        # Unparsable body: treat it as transient and retry.
                        last_exception = e
                    else:
                        try:
                            constant = res_json["meta"]["constant"]
                        except (KeyError, TypeError):
                            constant = None
                        if constant == "VALID":
                            return res_json
                        raise ValueError(f"Unable to validate license: {constant}")
                elif res.status_code == 401:
                    raise ValueError(
                        "Activation token is incorrect, please double check and try again."
                    )
                else:
                    last_exception = requests.HTTPError(
                        f"HTTP {res.status_code}: {res.text}", response=res
                    )
                    # Other client errors won't change on retry; fail fast.
                    if 400 <= res.status_code < 500 and res.status_code not in (408, 429):
                        raise last_exception

            # Exponential backoff between attempts (not after the last one).
            if attempt < max_attempts - 1:
                time.sleep(sleep_s)
                sleep_s *= 2

        # Raise the last exception if all retries failed
        if last_exception is not None:
            raise last_exception
        raise PywaspError(f"License validation failed after {max_attempts} attempts")

    def get_points_limit(self) -> int:
        """Get the maximum points allowed by the license.

        Raises
        ------
        PywaspError
            If the response has no policy ID or the policy ID is not in MAX_PTS.
            Errors from the validation request itself propagate unchanged.
        """
        response_dict = self._request_validation()
        try:
            policy_id = response_dict["data"]["relationships"]["policy"]["data"]["id"]
            return MAX_PTS[policy_id]
        except (KeyError, TypeError) as e:
            raise PywaspError(f"Unknown policy ID in license response: {e}") from e

    def get_license_status(self) -> dict:
        """Get the status of the license.

        Returns
        -------
        dict:
            A dictionary containing the license status, including whether it is active,
            the number of runs used, the runs limit, and the expiry date.
            ``runs_limit`` / ``expiry`` are None when the response leaves
            ``maxUses`` / ``expiry`` null. If any ``Exception`` occurs, it is
            logged and every value in the dict is None.
        """
        try:
            attributes = self._request_validation()["data"]["attributes"]
            max_uses = attributes.get("maxUses")
            expiry = attributes.get("expiry")
            return {
                # _request_validation only returns for VALID licenses.
                "active": True,
                "runs_used": int(attributes["uses"]),
                "runs_limit": None if max_uses is None else int(max_uses),
                "expiry": None if expiry is None else pd.to_datetime(expiry).date(),
            }
        except Exception as e:
            logger.error("Error getting license status: %s", e)
            return {
                "active": None,
                "runs_used": None,
                "runs_limit": None,
                "expiry": None,
            }


class Settings(BaseSettings):
    """Top-level pywasp configuration: download preferences and licensing.

    Sources, from highest to lowest priority: init kwargs, ``pywasp_``-prefixed
    environment variables (nested keys separated by ``__``), a ``.env`` file,
    secret files, and finally the TOML file at ``_CONFIG_FILE``.
    """

    download: DownloadCfg
    # The "license_type" field selects which licensing model is validated.
    licensing: Annotated[
        Union[LicensingLocalCfg, LicensingRemoteCfg],
        Discriminator("license_type"),
    ]

    model_config = SettingsConfigDict(
        toml_file=_CONFIG_FILE,
        env_prefix="pywasp_",
        extra="ignore",
        env_nested_delimiter="__",
        env_file=".env",
    )

    # Add toml settings to end of default 4 sources (i.e. lowest priority)
    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        return (
            init_settings,
            env_settings,
            dotenv_settings,
            file_secret_settings,
            TomlConfigSettingsSource(settings_cls),
        )

    def save(self):
        """Write the settings to ``_CONFIG_FILE`` as UTF-8 TOML, omitting None values.

        The settings are serialized *before* the file is opened for writing,
        so a serialization error (e.g. a root CA file that no longer exists)
        leaves the existing file intact instead of truncating it.
        """
        data = self.model_dump(exclude_none=True)
        # TOML is defined as UTF-8; don't depend on the platform's locale encoding.
        with _CONFIG_FILE.open("w", encoding="utf-8") as f:
            toml.dump(data, f)
        logger.info("Configuration saved to %s.", _CONFIG_FILE)

    def get_license_status(self) -> dict | None:
        """Get the status of the license.

        Returns
        -------
        dict or None
            None for local licenses. For Keygen cloud licenses, a dictionary
            with whether the license is active, the number of runs used, the
            runs limit, and the expiry date.

        Raises
        ------
        PywaspError
            If the license type is unknown.
        """
        if isinstance(self.licensing, LicensingLocalCfg):
            return None
        elif isinstance(self.licensing, LicensingRemoteCfg):
            return self.licensing.get_license_status()
        else:
            raise PywaspError(
                f"Unknown license type: {self.licensing.license_type}. Cannot get license status."
            )

    def get_points_limit(self) -> int | None:
        """Return the maximum number of points allowed in one calculation.

        Returns
        -------
        int or None
            The license limit on the number of points allowed in one
            calculation, or None for 'local' licenses (no limit).

        Raises
        ------
        PywaspError
            If the license type is unknown. Errors from the remote license
            lookup propagate unchanged.
        """
        if isinstance(self.licensing, LicensingLocalCfg):
            return None  # Represents no limit for local licenses
        elif isinstance(self.licensing, LicensingRemoteCfg):
            return self.licensing.get_points_limit()
        else:
            raise PywaspError(
                f"Unknown license type: {self.licensing.license_type}. Cannot check point limit."
            )

    def check_points_within_limit(self, n_points: int) -> bool:
        """Check if the number of points is within the license limit.

        Parameters
        ----------
        n_points : int
            The number of points to check.

        Returns
        -------
        bool
            True if the number of points is within the limit (always True for
            local licenses), False otherwise.

        Raises
        ------
        PywaspError
            If the license type is unknown. Errors from the remote license
            lookup propagate unchanged.
        """
        n_limit = self.get_points_limit()
        # None means "no limit" (local licenses).
        return n_limit is None or n_points <= n_limit


def _ensure_config_exists(interactive=False):
    """
    Report whether the configuration file exists, optionally offering to create it.

    Parameters
    ----------
    interactive : bool, default False
        When True and the file is missing, ask the user whether to run
        ``create_config_interactively()``. When False, nothing is prompted.

    Returns
    -------
    bool
        True if the file already existed, or if interactive creation returned
        a configuration object; False otherwise.
    """
    # Fast path: nothing to do when the file is already present.
    if _CONFIG_FILE.exists():
        return True
    if not interactive:
        return False

    print(f"PyWAsP configuration file not found at: {_CONFIG_FILE}")
    answer = input("Would you like to create it now? (Y/n): ").strip().lower()
    # Anything other than an explicit "no" counts as consent.
    if answer in ["n", "no", "false"]:
        return False
    return create_config_interactively() is not None


def _prompt_yes_no(prompt, default):
    """Ask a yes/no question on stdin and return the answer as a bool.

    With ``default=True`` any answer except "n"/"no"/"false" (including an
    empty one) means yes; with ``default=False`` only "y"/"yes"/"true" does.
    """
    answer = input(prompt).strip().lower()
    if default:
        return answer not in ("n", "no", "false")
    return answer in ("y", "yes", "true")


def _prompt_non_empty(prompt, error_message):
    """Re-prompt until the user enters a non-blank value; return it stripped."""
    while True:
        value = input(prompt).strip()
        if value:
            return value
        print(error_message)


def _prompt_port():
    """Re-prompt until a valid TCP port is entered; empty input selects DEFAULT_PORT."""
    default_port = int(DEFAULT_PORT)
    while True:
        port_input = input(f"License server port [{default_port}]: ").strip()
        if not port_input:
            return default_port
        try:
            port = int(port_input)
        except ValueError:
            print(f"Invalid port: {port_input!r} is not an integer.")
            continue
        if 0 <= port <= 65535:
            return port
        print("Invalid port: Port must be between 0 and 65535")


def _confirm_save_anyway():
    """Ask whether to keep an unvalidated configuration; print a notice if not."""
    if _prompt_yes_no("Save configuration anyway? (y/N): ", default=False):
        return True
    print("Configuration creation cancelled.")
    return False


def _prompt_local_licensing():
    """Interactively configure a local license server.

    Returns the validated LicensingLocalCfg. If the connectivity check fails
    and the user still wants to save, returns an unvalidated instance built
    with ``model_construct``; otherwise returns None.
    """
    print("\nConfiguring local license server...")
    host = _prompt_non_empty("License server host: ", "Please enter a valid host.")
    port = _prompt_port()

    print(f"Testing connection to {host}:{port}...")
    try:
        licensing_cfg = LicensingLocalCfg(host=host, port=port)
    except Exception as e:
        print(f"✗ Connection failed: {e}")
        if not _confirm_save_anyway():
            return None
        # Create without validation (no connectivity check) for offline scenarios.
        return LicensingLocalCfg.model_construct(host=host, port=port)
    print("✓ Connection successful!")
    return licensing_cfg


def _prompt_remote_licensing():
    """Interactively configure a Keygen cloud license.

    Returns the validated LicensingRemoteCfg. If validation fails and the user
    still wants to save, returns an unvalidated instance built with
    ``model_construct``; otherwise (including invalid input) returns None.
    """
    print("\nConfiguring remote license (Keygen cloud)...")
    license_id_input = input("License ID (UUID): ").strip()
    try:
        license_id = uuid.UUID(license_id_input)
    except ValueError:
        print("Invalid UUID format for license ID.")
        return None

    activation_token = input("Activation token: ").strip()
    if not activation_token:
        print("Activation token is required.")
        return None

    use_redirect = _prompt_yes_no(
        "Redirect host through license.windenergy.dtu.dk? (y/N): ", default=False
    )
    redirect_host = "license.windenergy.dtu.dk" if use_redirect else None

    root_ca_path = input("Root CA file path (optional, press Enter to skip): ").strip()
    root_ca_filepath = Path(root_ca_path) if root_ca_path else None
    if root_ca_filepath is not None and not root_ca_filepath.exists():
        print(f"Warning: CA file not found at {root_ca_filepath}")

    cfg_kwargs = {
        "license_id": license_id,
        "activation_token": activation_token,
        "redirect_host": redirect_host,
        "root_ca_filepath": root_ca_filepath,
    }
    print("Testing license validation...")
    try:
        licensing_cfg = LicensingRemoteCfg(**cfg_kwargs)
    except Exception as e:
        print(f"✗ License validation failed: {e}")
        if not _confirm_save_anyway():
            return None
        # Create without validation for offline scenarios.
        return LicensingRemoteCfg.model_construct(**cfg_kwargs)
    print("✓ License configuration validated!")
    return licensing_cfg


def create_config_interactively():
    """
    Interactively prompt the user to create a pywasp_config.toml file.

    This function guides the user through setting up their PyWAsP configuration
    by asking for licensing and download preferences. The configuration is saved
    to the default config file location and also stored in the module-level
    ``_config``.

    Returns
    -------
    Settings or None
        The created and saved configuration settings object, or None if the
        user cancelled (including with Ctrl+C), gave invalid input, or an
        error occurred. Errors are printed rather than raised.
    """
    global _config

    print(f"Creating PyWAsP configuration file at: {_CONFIG_FILE}")
    print("=" * 60)

    try:
        # Check if config file already exists
        if _CONFIG_FILE.exists() and not _prompt_yes_no(
            f"Configuration file already exists at {_CONFIG_FILE}. Overwrite? (y/N): ",
            default=False,
        ):
            print("Configuration creation cancelled.")
            return None

        # Create directory if it doesn't exist
        _CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)

        # Download settings
        print("\n1. Download Settings")
        print("-" * 20)
        download_prompt = _prompt_yes_no(
            "Enable download prompts? (y/N): ", default=False
        )
        download_global = _prompt_yes_no(
            "Download global NC files automatically? (Y/n): ", default=True
        )
        download_cfg = DownloadCfg(
            download_prompt=download_prompt, download_global_nc_files=download_global
        )

        # Licensing settings
        print("\n2. Licensing Settings")
        print("-" * 20)
        print("Choose license type:")
        print("1. Remote license (Keygen cloud)")
        print("2. Local license server")
        while True:
            license_choice = input("Enter choice (1 or 2): ").strip()
            if license_choice in ("1", "2"):
                break
            print("Please enter 1 or 2.")

        if license_choice == "2":
            licensing_cfg = _prompt_local_licensing()
        else:
            licensing_cfg = _prompt_remote_licensing()
        if licensing_cfg is None:
            return None

        # Create and save settings
        config = Settings(download=download_cfg, licensing=licensing_cfg)
        config.save()
        # Update the global _config variable
        _config = config

        print(f"\n✓ Configuration saved successfully to: {_CONFIG_FILE}")
        print("\nConfiguration summary:")
        print(f"  Download prompts: {download_cfg.download_prompt}")
        print(f"  Auto-download global files: {download_cfg.download_global_nc_files}")
        print(f"  License type: {licensing_cfg.license_type}")
        if isinstance(licensing_cfg, LicensingLocalCfg):
            print(f"  License server: {licensing_cfg.host}:{licensing_cfg.port}")
        else:
            print(f"  License ID: {licensing_cfg.license_id}")
            print(f"  Redirect host: {licensing_cfg.redirect_host or 'None'}")
        return config

    except KeyboardInterrupt:
        print("\n\nConfiguration creation cancelled by user.")
        return None
    except Exception as e:
        print(f"\nError creating configuration: {e}")
        return None
# Always run on import to load configuration. This is intentionally unguarded
# module-level code (no ``if __name__ == "__main__"``).
logger.info("Loading application configuration...")
# Must use roaming=True to be consistent with rusts app_dirs2 package.
app_dir = PlatformDirs(appname=APP_NAME, appauthor=APP_AUTHOR, roaming=True)
Path(app_dir.user_data_path).mkdir(parents=True, exist_ok=True)
try:
    _config = Settings.model_validate({})
    # Have to save settings so that rust can access them.
    _config.save()
    logger.info("Application configuration loaded successfully.")
except ValidationError as e:
    logger.error("%s", e)
    _config = None
    warnings.warn(
        "Error loading configuration: "
        f"Please ensure your configuration file is at '{_CONFIG_FILE}' and is correct, "
        "or set the correct environment variables as described in our documentation. "
        "You can also create a new configuration interactively by running `pywasp.user_config.create_config_interactively()` or use the CLI command `pywasp configure`.",
    )