import warnings
import json
from logging import getLogger
import os
import time
import uuid
from pathlib import Path
from typing import Annotated, Union, Literal
import pandas as pd
from pydantic import (
BaseModel,
Field,
ValidationError,
field_serializer,
model_validator,
Discriminator,
ConfigDict,
)
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
TomlConfigSettingsSource,
)
import requests
from platformdirs import PlatformDirs, user_config_path
import toml
from ._errors import PywaspError, LicenseError
logger = getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
APP_NAME = "pywasp"
APP_AUTHOR = "DTU Wind Energy"

# Default license-server host/port values. They are not referenced in this
# module; presumably used elsewhere in the package (verify before changing).
DEFAULT_PORT = "34523"
DEFAULT_DTU_HOST = "130.226.48.178"

# Keygen identifiers: ACCOUNT is part of the license validation URL and
# PRODUCT_ID is the product scope sent in the validation payload.
ACCOUNT = "0244d026-758b-442b-9668-811ce125a7e3"
PRODUCT_ID = "f0c12bfb-c1f5-4fc2-bfb1-bd20fbdae727"

# Maximum number of points per calculation, keyed by Keygen policy ID.
MAX_PTS = {
    "b32d02b9-3d7e-43a4-80a3-4b8aa13083c6": 25_000,  # Tier 1
    "87fcb9e4-09df-4673-bb32-e2748c49a183": 125_000,  # Tier 2
    "819d348f-b49e-4ddd-aeac-c1074de9c971": 625_000,  # Tier 3
}

# Configuration directory, created eagerly at import time.
# roaming=True keeps the location consistent with Rust's app_dirs2 package.
_CONFIG_DIR = user_config_path(appname=APP_NAME, appauthor=APP_AUTHOR, roaming=True)
_CONFIG_DIR.mkdir(exist_ok=True, parents=True)
_CONFIG_FILE = _CONFIG_DIR.joinpath("./pywasp_config.toml")
class DownloadCfg(BaseModel):
    """Download preferences stored under the ``download`` section of the config.

    Both flags default to True.
    """

    # Whether the user is prompted about downloads ("Enable download prompts?").
    download_prompt: bool = Field(default=True)
    # Whether global NC files are downloaded automatically.
    download_global_nc_files: bool = Field(default=True)
def _check_url_connectivity(test_url):
    """
    Attempts to connect to the given URL to ensure it's reachable.

    Fails fast if the host is unreachable, the connection is refused, the
    request times out, or the server answers with a 4xx/5xx status. Redirects
    are not followed.

    Parameters
    ----------
    test_url : str
        The URL to test connectivity against.

    Raises
    ------
    LicenseError
        If SSL certificate verification failed (error code 123).
    PywaspError
        For any other SSL error.
    ValueError
        For connection failures, timeouts, HTTP error statuses and any other
        unexpected error. When raised from a pydantic validator, pydantic
        converts it into a ``ValidationError``.
    """
    logger.debug(f"Attempting to connect to: {test_url}")
    try:
        response = requests.get(test_url, timeout=10, allow_redirects=False)
        # Raise an HTTPError for bad responses (4xx or 5xx).
        response.raise_for_status()
    except requests.exceptions.SSLError as e:
        # SSLError subclasses ConnectionError, so it must be handled first.
        if "SSL: CERTIFICATE_VERIFY_FAILED" in str(e):
            raise LicenseError(123) from e
        raise PywaspError(
            f"Problem with ssl verification, try adding a root_ca_filepath to your Licensing configuration: {e}"
        ) from e
    except requests.exceptions.Timeout as e:
        # Checked before ConnectionError: ConnectTimeout subclasses both and
        # should be reported as a timeout, not as an unreachable host.
        raise ValueError(
            f"Connection to license server at {test_url} timed out. "
            f"Server might be slow or unresponsive. Error: {e}"
        ) from e
    except requests.exceptions.ConnectionError as e:
        raise ValueError(
            f"Failed to connect to license server at {test_url}. "
            f"Host might be unreachable or port incorrect. Error: {e}"
        ) from e
    except requests.exceptions.RequestException as e:
        # Catches other requests-related errors, including HTTPError from raise_for_status()
        raise ValueError(
            f"An error occurred while checking license server at {test_url}. Error: {e}"
        ) from e
    except Exception as e:
        # Catch any other unexpected errors
        raise ValueError(
            f"An unexpected error occurred during URL check for {test_url}. Error: {e}"
        ) from e
    else:
        logger.debug(
            f"Successfully connected to {test_url} (Status: {response.status_code})"
        )
class LicensingLocalCfg(BaseModel):
    """Licensing configuration that points at a local license server.

    After field validation, ``http://<host>:<port>`` is probed with
    ``_check_url_connectivity``. Any error that probe raises aborts model
    validation.
    """

    # Discriminator value used by ``Settings.licensing``.
    license_type: Literal["local"] = "local"
    host: str
    port: int = Field(34523, ge=0, le=65535)

    @property
    def _url(self) -> str:
        """Base HTTP URL of the license server."""
        return "http://{}:{}".format(self.host, self.port)

    @model_validator(mode="after")
    def check_url_connectivity(self):
        """Probe the server once all fields are valid; return the model unchanged."""
        base_url = self._url
        _check_url_connectivity(base_url)
        return self
class LicensingRemoteCfg(BaseModel):
    """Licensing configuration for a Keygen cloud license.

    Validating the model sets ``REQUESTS_CA_BUNDLE`` in the process
    environment when ``root_ca_filepath`` is given. It then pings ``/v1/ping``
    on the Keygen API host (or ``redirect_host``) to check reachability.
    """

    # Discriminator value used by ``Settings.licensing``.
    license_type: Literal["keygen_cloud"] = "keygen_cloud"
    # Optional host used instead of api.keygen.sh for all API requests.
    redirect_host: str | None = None
    license_id: uuid.UUID
    # Sent as a Bearer token when validating the license.
    activation_token: str
    # Optional CA bundle, exported as REQUESTS_CA_BUNDLE during validation.
    root_ca_filepath: Path | None = None

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @property
    def _url(self) -> str:
        """Construct the base URL for the Keygen API. Handles redirection if specified."""
        if self.redirect_host is None:
            host = "api.keygen.sh"
        else:
            host = self.redirect_host
        return f"https://{host}"

    @model_validator(mode="after")
    def setup_env_and_check_url_connectivity(self):
        """Set up SSL certificate in environment variables and check URL connectivity.

        Note: this mutates ``os.environ`` for the whole process.
        """
        if self.root_ca_filepath is not None:
            env_var_name = "REQUESTS_CA_BUNDLE"
            os.environ[env_var_name] = str(self.root_ca_filepath)
            logger.info(
                f"Environment variable '{env_var_name}' set to: {self.root_ca_filepath}"
            )
        else:
            logger.info("No 'root_ca_filepath' provided, environment variable not set.")
        _check_url_connectivity(f"{self._url}/v1/ping")
        return self

    @field_serializer("license_id")
    def serialize_license_id(self, license_id: uuid.UUID) -> str:
        """Serialize the UUID license ID to a string for API requests."""
        return str(license_id)

    @field_serializer("root_ca_filepath")
    def serialize_root_ca_filepath(self, root_ca_filepath: Path | None) -> str | None:
        """Serialize the root CA file path to a string.

        Raises ValueError if the path is set but the file does not exist.
        """
        if root_ca_filepath is None:
            return None
        if not root_ca_filepath.exists():
            raise ValueError(
                f"Root CA file does not exist at specified path: {root_ca_filepath}"
            )
        return str(root_ca_filepath)

    def _request_validation(self) -> dict:
        """Send a request to the license server to validate the license.

        Up to 5 attempts are made. Non-200/401 responses and transport errors
        are retried with exponential back-off (0.5, 1, 2, 4 s between
        attempts). There is no sleep after the final attempt.

        Returns
        -------
        dict
            The parsed JSON response of a "VALID" validation. It includes
            license information such as run usage, the runs limit, expiry and
            the policy.

        Raises
        ------
        ValueError
            If the license is not valid, or the activation token is rejected
            (HTTP 401).
        KeyError
            If a 200 response lacks ``meta.constant``.
        LicenseError
            If SSL certificate verification failed.
        PywaspError
            For other SSL errors.
        requests.HTTPError or requests.RequestException
            The last error seen when every attempt failed.
        """
        max_attempts = 5
        sleep_s = 0.5
        last_exception = None
        # The request is identical on every attempt, so build it once.
        uri_validate = f"{self._url}/v1/accounts/{ACCOUNT}/licenses/{self.license_id}/actions/validate"
        headers_validate = {
            "Content-Type": "application/vnd.api+json",
            "Accept": "application/vnd.api+json",
            "Authorization": f"Bearer {self.activation_token}",
        }
        data_validate = json.dumps({"meta": {"scope": {"product": f"{PRODUCT_ID}"}}})
        for attempt in range(max_attempts):
            try:
                res = requests.post(
                    uri_validate,
                    headers=headers_validate,
                    data=data_validate,
                    timeout=30,
                )
                if res.status_code == 200:
                    res_json = res.json()
                    constant = res_json["meta"]["constant"]
                    if constant == "VALID":
                        return res_json
                    raise ValueError(f"Unable to validate license: {constant}")
                if res.status_code == 401:
                    raise ValueError(
                        "Activation token is incorrect, please double check and try again."
                    )
                # Store the response for a potential final error.
                last_exception = requests.HTTPError(
                    f"HTTP {res.status_code}: {res.text}", response=res
                )
            except requests.exceptions.SSLError as e:
                if "SSL: CERTIFICATE_VERIFY_FAILED" in str(e):
                    raise LicenseError(123) from e
                raise PywaspError(f"Problem with connection: {e}") from e
            except requests.exceptions.RequestException as e:
                # Includes JSON decode errors from res.json(); retried.
                last_exception = e
            # Exponential back-off between attempts (not after the last one).
            if attempt < max_attempts - 1:
                time.sleep(sleep_s)
                sleep_s *= 2
        # Raise the last exception if all retries failed.
        if last_exception is not None:
            raise last_exception
        raise PywaspError(f"License validation failed after {max_attempts} attempts")

    def get_points_limit(self) -> int:
        """Get the maximum points allowed by the license.

        Raises
        ------
        PywaspError
            If the validation response is missing an expected key, or its
            policy ID is not one of the known tiers in ``MAX_PTS``.
        """
        try:
            response_dict = self._request_validation()
            policy_id = response_dict["data"]["relationships"]["policy"]["data"]["id"]
        except KeyError as e:
            raise PywaspError(
                f"Malformed license validation response, missing key: {e}"
            ) from e
        try:
            return MAX_PTS[policy_id]
        except KeyError as e:
            raise PywaspError(f"Unknown policy ID in license response: {e}") from e

    def get_license_status(self) -> dict:
        """Get the status of the license.

        Returns
        -------
        dict:
            A dictionary with keys ``active``, ``runs_used``, ``runs_limit`` and
            ``expiry`` (a ``datetime.date``). ``runs_limit``/``expiry`` are None
            when the response carries no value for them. Presumably this
            happens for unlimited or non-expiring licenses; confirm against
            the Keygen API. If any error occurs, it is logged and every value
            is None.
        """
        try:
            response_dict = self._request_validation()
            attributes = response_dict["data"]["attributes"]
            max_uses = attributes.get("maxUses")
            expiry = attributes.get("expiry")
            return {
                "active": True,
                "runs_used": int(attributes["uses"]),
                # A null value must not turn the whole status into Nones.
                "runs_limit": None if max_uses is None else int(max_uses),
                "expiry": None if expiry is None else pd.to_datetime(expiry).date(),
            }
        except Exception as e:
            logger.error(f"Error getting license status: {e}")
            return {
                "active": None,
                "runs_used": None,
                "runs_limit": None,
                "expiry": None,
            }
class Settings(BaseSettings):
    """Top-level PyWAsP configuration.

    Values are resolved from, in decreasing priority: init arguments,
    environment variables (prefix ``pywasp_``, nested delimiter ``__``), a
    ``.env`` file, secret files, and finally the TOML file ``_CONFIG_FILE``.
    """

    download: DownloadCfg
    # Tagged union: the ``license_type`` field selects the concrete model.
    licensing: Annotated[
        Union[LicensingLocalCfg, LicensingRemoteCfg],
        Discriminator("license_type"),
    ]

    model_config = SettingsConfigDict(
        toml_file=_CONFIG_FILE,
        env_prefix="pywasp_",
        extra="ignore",
        env_nested_delimiter="__",
        env_file=".env",
    )

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        """Add the TOML file after the default 4 sources (lowest priority)."""
        return (
            init_settings,
            env_settings,
            dotenv_settings,
            file_secret_settings,
            TomlConfigSettingsSource(settings_cls),
        )

    def save(self):
        """Write the settings to ``_CONFIG_FILE`` as TOML, omitting None values.

        The file is always written as UTF-8, as TOML requires, rather than in
        the platform's default encoding (e.g. cp1252 on Windows). This keeps
        non-ASCII paths readable when the file is loaded again.
        """
        with _CONFIG_FILE.open("w", encoding="utf-8") as f:
            toml.dump(self.model_dump(exclude_none=True), f)
        logger.info(f"Configuration saved to {_CONFIG_FILE}.")

    def get_license_status(self) -> dict | None:
        """Get the status of the license.

        Returns
        -------
        dict or None:
            For a remote license, a dictionary containing the license status,
            including whether it is active, the number of runs used, the runs
            limit, and the expiry date. None for a 'local' license.

        Raises
        ------
        PywaspError:
            If the licensing type is unknown.
        """
        if isinstance(self.licensing, LicensingLocalCfg):
            return None
        elif isinstance(self.licensing, LicensingRemoteCfg):
            return self.licensing.get_license_status()
        else:
            raise PywaspError(
                f"Unknown license type: {self.licensing.license_type}. Cannot get license status."
            )

    def get_points_limit(self) -> int | None:
        """Get the maximum number of points allowed in one calculation.

        Returns
        -------
        num_points : int or None
            The license limit of number of points allowed in one calculation.
            If the license type is 'local', returns None (since there is no limit).

        Raises
        ------
        PywaspError:
            If the licensing type is unknown, or the remote license's policy
            cannot be determined. Errors from the remote validation request
            itself propagate unchanged.
        """
        if isinstance(self.licensing, LicensingLocalCfg):
            return None  # Represents no limit for local licenses
        elif isinstance(self.licensing, LicensingRemoteCfg):
            return self.licensing.get_points_limit()
        else:
            raise PywaspError(
                f"Unknown license type: {self.licensing.license_type}. Cannot check point limit."
            )

    def check_points_within_limit(self, n_points: int) -> bool:
        """Check if the number of points is within the license limit.

        Parameters
        ----------
        n_points : int
            The number of points to check.

        Returns
        -------
        bool
            True if the number of points is within the limit (always True for
            a 'local' license), False otherwise.

        Raises
        ------
        PywaspError:
            Under the same conditions as ``get_points_limit``.
        """
        n_limit = self.get_points_limit()
        # None means the license imposes no point limit.
        return n_limit is None or n_points <= n_limit
def _ensure_config_exists(interactive=False):
    """
    Report whether the configuration file exists, optionally offering to create it.

    Parameters
    ----------
    interactive : bool, default False
        When True and the file is missing, ask the user whether to create it
        and, unless they decline, run ``create_config_interactively()``.
        When False, only the existence check is performed.

    Returns
    -------
    bool
        True if the config file already exists, or if
        ``create_config_interactively()`` returned a non-None value; False otherwise.
    """
    if _CONFIG_FILE.exists():
        return True
    if not interactive:
        return False

    print(f"PyWAsP configuration file not found at: {_CONFIG_FILE}")
    answer = input("Would you like to create it now? (Y/n): ").strip().lower()
    declined = answer in ("n", "no", "false")
    if declined:
        return False
    return create_config_interactively() is not None
def create_config_interactively():
    """
    Interactively prompt the user to create a pywasp_config.toml file.

    This function guides the user through setting up their PyWAsP configuration
    by asking for download and licensing preferences. The licensing settings
    are validated, which contacts the license server. If validation fails, the
    user may still save an unvalidated configuration. On success the
    configuration is saved to the default config file location, and the
    module-level ``_config`` is updated.

    Returns
    -------
    Settings or None
        The created configuration settings object. None is returned if the
        user cancelled, entered invalid input, or creating/saving the
        configuration failed (the error is printed).

    Raises
    ------
    KeyboardInterrupt
        Only if the user interrupts the initial "Overwrite?" question. Later
        interruptions and errors are reported and turned into a None return.
    """
    global _config
    print(f"Creating PyWAsP configuration file at: {_CONFIG_FILE}")
    print("=" * 60)
    # Check if config file already exists
    if _CONFIG_FILE.exists():
        overwrite = (
            input(
                f"Configuration file already exists at {_CONFIG_FILE}. Overwrite? (y/N): "
            )
            .strip()
            .lower()
        )
        if overwrite not in ["y", "yes"]:
            print("Configuration creation cancelled.")
            return None
    try:
        # Create directory if it doesn't exist
        _CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True)

        download_cfg = _prompt_download_cfg()

        # Licensing settings
        print("\n2. Licensing Settings")
        print("-" * 20)
        print("Choose license type:")
        print("1. Remote license (Keygen cloud)")
        print("2. Local license server")
        while True:
            license_choice = input("Enter choice (1 or 2): ").strip()
            if license_choice in ["1", "2"]:
                break
            print("Please enter 1 or 2.")
        if license_choice == "2":
            licensing_cfg = _prompt_local_licensing_cfg()
        else:
            licensing_cfg = _prompt_remote_licensing_cfg()
        if licensing_cfg is None:
            # The helper already told the user why.
            return None

        # Create and save settings
        config = Settings(download=download_cfg, licensing=licensing_cfg)
        config.save()
        # Update the global _config variable
        _config = config

        print(f"\n✓ Configuration saved successfully to: {_CONFIG_FILE}")
        print("\nConfiguration summary:")
        print(f"  Download prompts: {download_cfg.download_prompt}")
        print(f"  Auto-download global files: {download_cfg.download_global_nc_files}")
        print(f"  License type: {licensing_cfg.license_type}")
        if isinstance(licensing_cfg, LicensingLocalCfg):
            print(f"  License server: {licensing_cfg.host}:{licensing_cfg.port}")
        else:
            print(f"  License ID: {licensing_cfg.license_id}")
            print(f"  Redirect host: {licensing_cfg.redirect_host or 'None'}")
    except KeyboardInterrupt:
        print("\n\nConfiguration creation cancelled by user.")
        return None
    except Exception as e:
        print(f"\nError creating configuration: {e}")
        return None
    # Return the created settings, as documented; callers such as
    # _ensure_config_exists rely on a non-None result to detect success.
    return config


def _prompt_download_cfg():
    """Prompt for the download preferences and return them as a ``DownloadCfg``."""
    print("\n1. Download Settings")
    print("-" * 20)
    download_prompt = input("Enable download prompts? (y/N): ").strip().lower()
    download_global = (
        input("Download global NC files automatically? (Y/n): ").strip().lower()
    )
    return DownloadCfg(
        download_prompt=download_prompt in ["y", "yes", "true"],
        download_global_nc_files=download_global not in ["n", "no", "false"],
    )


def _prompt_non_empty(prompt, error_message):
    """Repeat ``prompt`` until the user gives a non-blank answer; return it stripped."""
    while True:
        answer = input(prompt).strip()
        if answer:
            return answer
        print(error_message)


def _confirm_save_anyway():
    """Ask whether to keep an unvalidated configuration; print a cancel notice if not."""
    retry = input("Save configuration anyway? (y/N): ").strip().lower()
    if retry not in ["y", "yes"]:
        print("Configuration creation cancelled.")
        return False
    return True


def _prompt_local_licensing_cfg():
    """Prompt for local license-server settings; return the config, or None if cancelled."""
    print("\nConfiguring local license server...")
    host = _prompt_non_empty("License server host: ", "Please enter a valid host.")
    port_input = _prompt_non_empty("License server port:", "Please enter a valid port.")
    try:
        port = int(port_input)
        if not (0 <= port <= 65535):
            raise ValueError("Port must be between 0 and 65535")
    except ValueError as e:
        print(f"Invalid port: {e}")
        return None
    print(f"Testing connection to {host}:{port}...")
    try:
        licensing_cfg = LicensingLocalCfg(host=host, port=port)
    except Exception as e:
        print(f"✗ Connection failed: {e}")
        if not _confirm_save_anyway():
            return None
        # Create without validation for offline scenarios
        return LicensingLocalCfg.model_construct(host=host, port=port)
    print("✓ Connection successful!")
    return licensing_cfg


def _prompt_remote_licensing_cfg():
    """Prompt for Keygen cloud license settings; return the config, or None if cancelled."""
    print("\nConfiguring remote license (Keygen cloud)...")
    license_id_input = input("License ID (UUID): ").strip()
    try:
        license_id = uuid.UUID(license_id_input)
    except ValueError:
        print("Invalid UUID format for license ID.")
        return None
    activation_token = input("Activation token: ").strip()
    if not activation_token:
        print("Activation token is required.")
        return None
    redirect_answer = input(
        "Redirect host through license.windenergy.dtu.dk? (y/N): "
    ).strip()
    if redirect_answer.lower() in ["y", "yes", "true"]:
        redirect_host = "license.windenergy.dtu.dk"
    else:
        redirect_host = None
    root_ca_path = input(
        "Root CA file path (optional, press Enter to skip): "
    ).strip()
    root_ca_filepath = None
    if root_ca_path:
        root_ca_filepath = Path(root_ca_path)
        if not root_ca_filepath.exists():
            # NOTE(review): saving will later fail, because the root CA
            # serializer rejects missing files.
            print(f"Warning: CA file not found at {root_ca_filepath}")
    cfg_fields = dict(
        license_id=license_id,
        activation_token=activation_token,
        redirect_host=redirect_host,
        root_ca_filepath=root_ca_filepath,
    )
    print("Testing license validation...")
    try:
        licensing_cfg = LicensingRemoteCfg(**cfg_fields)
    except Exception as e:
        print(f"✗ License validation failed: {e}")
        if not _confirm_save_anyway():
            return None
        # Create without validation for offline scenarios
        return LicensingRemoteCfg.model_construct(**cfg_fields)
    print("✓ License configuration validated!")
    return licensing_cfg
if __name__ == __name__:  # Always true: deliberately runs on every import to load configuration
    logger.info("Loading application configuration...")
    # Must use roaming=True to be consistent with rusts app_dirs2 package.
    app_dir = PlatformDirs(appname=APP_NAME, appauthor=APP_AUTHOR, roaming=True)
    Path(app_dir.user_data_path).mkdir(parents=True, exist_ok=True)
    try:
        # Resolve settings from init args / env / .env / secrets / TOML file.
        _config = Settings.model_validate({})
        # Have to save settings so that rust can access them.
        _config.save()
        logger.info("Application configuration loaded successfully.")
    except ValidationError as e:
        # Lazy %-style args: formatting happens only if the record is emitted.
        logger.error("%s", e)
        _config = None
        warnings.warn(
            "Error loading configuration: "
            f"Please ensure your configuration file is at '{_CONFIG_FILE}' and is correct, "
            "or set the correct environment variables as described in our documentation. "
            "You can also create a new configuration interactively by running `pywasp.user_config.create_config_interactively()` or use the CLI command `pywasp configure`.",
        )