658 lines
28 KiB
Python
658 lines
28 KiB
Python
from fastapi import Depends, FastAPI, HTTPException, status, Response
|
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
from dotenv import dotenv_values
|
|
from loguru import logger
|
|
from datetime import datetime, timezone
|
|
from typing import Optional
|
|
from pathlib import Path
|
|
import secrets
|
|
import string
|
|
import psycopg
|
|
import asyncio
|
|
import base64
|
|
import json
|
|
|
|
from ecdsa import Ed25519, SigningKey, VerifyingKey
|
|
|
|
DEBUG = False  # Reveal debug tools such as /docs swagger UI

ENV = dotenv_values(".env")  # key/value pairs parsed from the .env file
VERSION = "v1.0.2"  # version number
ALPHABET = string.ascii_uppercase + string.digits  # uppercase letters and digits for token generation


def _env_positive_int(name: str, default: int) -> int:
    """Return ENV[name] parsed as a positive integer, or *default*.

    The default is used when the variable is missing, has no value, is not
    an integer, or is smaller than 1. A chunk count or chunk length below 1
    would make every generated license key identical, so such values are
    rejected instead of being applied.
    """
    raw = ENV.get(name)
    if raw is None:
        return default
    try:
        value = int(raw)
    except ValueError:
        logger.warning("{} is not a valid integer; using default {}.", name, default)
        return default
    if value < 1:
        logger.warning("{} must be at least 1; using default {}.", name, default)
        return default
    return value


# Number of chunks in newly generated license keys.
LICENSE_KEY_PARTS = _env_positive_int("NUM_KEY_CHUNKS", 5)
# Number of characters in each chunk.
LICENSE_KEY_PART_LENGTH = _env_positive_int("KEY_CHUNK_LENGTH", 5)

# Response signing toggle: SIGN_KEY takes precedence, SIGN_KEYS is the fallback name.
sign_keys_raw = ENV.get("SIGN_KEY")
if sign_keys_raw is None:
    sign_keys_raw = ENV.get("SIGN_KEYS")
SIGN_KEYS = (
    sign_keys_raw is not None
    and sign_keys_raw.strip().lower() in {"true", "1", "yes", "on"}
)

KEYS_DIR = Path("keys")
PRIVATE_KEY_PATH = KEYS_DIR / "ed25519_private.pem"
PUBLIC_KEY_PATH = KEYS_DIR / "ed25519_public.pem"

# Populated by _load_or_create_signing_keys() when signing is enabled.
SIGNING_PRIVATE_KEY: Optional[SigningKey] = None
SIGNING_PUBLIC_KEY_B64: Optional[str] = None

api_key = ""
# auto_error=False: a missing Authorization header reaches the endpoint as None
# so each endpoint can answer with its own 401.
security = HTTPBearer(auto_error=False)

if "POSTGRESQL_PASSWORD" in ENV:
    logger.debug("POSTGRESQL_PASSWORD successfully read.")
    # Inside a single-quoted conninfo value, backslashes and single quotes
    # must be backslash-escaped or the connection string is malformed.
    _escaped_password = (
        (ENV["POSTGRESQL_PASSWORD"] or "").replace("\\", "\\\\").replace("'", "\\'")
    )
    connect_statement = (
        "dbname=postgres user=postgres host=postgresql password='{password}'".format(
            password=_escaped_password
        )
    )
else:
    logger.error("POSTGRESQL_PASSWORD not found in .env file... cannot continue.")
    # Non-zero status so supervisors see the startup failure.
    raise SystemExit(1)

if ENV.get("API_KEY"):
    logger.debug("API_KEY successfully read.")
    api_key = ENV["API_KEY"]
elif "API_KEY" in ENV:
    # An empty key would let an empty bearer token authenticate.
    logger.error("API_KEY is empty in .env file... cannot continue.")
    raise SystemExit(1)
else:
    logger.error("API_KEY not found in .env file... cannot continue.")
    raise SystemExit(1)

# Fail fast at import time if the database cannot be reached.
try:
    with psycopg.connect(connect_statement) as conn:
        with conn.cursor() as cur:
            cur.execute("SELECT 1")
except Exception as error:
    logger.error("Failed to connect to PostgreSQL: {}", error)
    raise SystemExit(1)
|
|
|
|
def _load_or_create_signing_keys() -> None:
    """Load the Ed25519 signing key pair from KEYS_DIR, creating it if absent.

    On success the module globals SIGNING_PRIVATE_KEY and
    SIGNING_PUBLIC_KEY_B64 (base64 of the raw public key bytes) are set.

    Behaviour:
      * No private key file: a new pair is generated and both PEM files are
        written. The private key file is created with owner-only permissions.
      * Private key file present: it is loaded and the public key is derived
        from it. The public key file is rewritten when it is missing,
        unparsable, or does not correspond to the private key, so the
        published key always verifies signatures made by this server.

    Raises:
        Exception: whatever the key parsing or file I/O raised; a failure to
            read the private key is logged before being re-raised.
    """
    global SIGNING_PRIVATE_KEY, SIGNING_PUBLIC_KEY_B64

    KEYS_DIR.mkdir(parents=True, exist_ok=True)

    if not PRIVATE_KEY_PATH.exists():
        signing_key = SigningKey.generate(curve=Ed25519)
        verifying_key = signing_key.verifying_key
        # Create the file with restrictive permissions before any key
        # material is written to it.
        PRIVATE_KEY_PATH.touch(mode=0o600)
        PRIVATE_KEY_PATH.write_bytes(signing_key.to_pem(format="pkcs8"))
        PUBLIC_KEY_PATH.write_bytes(verifying_key.to_pem())
        SIGNING_PRIVATE_KEY = signing_key
        SIGNING_PUBLIC_KEY_B64 = base64.b64encode(verifying_key.to_string()).decode("ascii")
        logger.info("Generated new Ed25519 signing key pair at {}.", KEYS_DIR)
        return

    try:
        SIGNING_PRIVATE_KEY = SigningKey.from_pem(PRIVATE_KEY_PATH.read_bytes())
    except Exception as exc:
        logger.error("Failed to read signing private key: {}", exc)
        raise

    # The private key is authoritative; the stored public key is only trusted
    # when it matches the key derived from it.
    derived_key = SIGNING_PRIVATE_KEY.verifying_key
    rewrite_public_key = True
    if PUBLIC_KEY_PATH.exists():
        try:
            stored_key = VerifyingKey.from_pem(PUBLIC_KEY_PATH.read_bytes())
        except Exception as exc:
            logger.warning(
                "Existing public key was invalid; regenerating from private key: {}",
                exc,
            )
        else:
            if stored_key.to_string() == derived_key.to_string():
                rewrite_public_key = False
            else:
                logger.warning(
                    "Existing public key does not match the private key; "
                    "regenerating from private key."
                )

    if rewrite_public_key:
        PUBLIC_KEY_PATH.write_bytes(derived_key.to_pem())

    SIGNING_PUBLIC_KEY_B64 = base64.b64encode(derived_key.to_string()).decode("ascii")
    logger.debug("Signing key pair loaded from {}.", KEYS_DIR)
|
|
|
|
# Create the tables on first start; both statements are no-ops when the
# tables already exist.
with psycopg.connect(connect_statement) as conn:
    with conn.cursor() as cur:
        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS license_keys (
                key TEXT PRIMARY KEY,
                issue_timestamp TIMESTAMPTZ NOT NULL,
                expiration_timestamp TIMESTAMPTZ,
                last_used_timestamp TIMESTAMPTZ,
                info TEXT,
                is_active BOOLEAN NOT NULL DEFAULT TRUE
            )
            """
        )
        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS history (
                action TEXT NOT NULL,
                timestamp TIMESTAMPTZ NOT NULL
            )
            """
        )
    conn.commit()

# Load (or generate) the signing keys up front when signing is enabled.
if SIGN_KEYS:
    try:
        _load_or_create_signing_keys()
    except Exception as error:
        logger.error("Failed to initialize signing keys: {}", error)
        # Exit with a non-zero status so the failure is visible to whatever
        # launched the process; a bare exit() would report success.
        raise SystemExit(1)

# The interactive documentation and the OpenAPI schema are only served in
# debug mode.
if DEBUG:
    app = FastAPI()
else:
    app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)
logger.info("Started FastAPI.")
|
|
|
|
def _generate_license_key() -> str:
    """
    Generate a cryptographically secure token.

    The key consists of LICENSE_KEY_PARTS chunks joined by hyphens; every
    chunk holds LICENSE_KEY_PART_LENGTH characters, each picked from
    ALPHABET with ``secrets.choice``.
    """
    chunks = []
    for _ in range(LICENSE_KEY_PARTS):
        characters = [secrets.choice(ALPHABET) for _ in range(LICENSE_KEY_PART_LENGTH)]
        chunks.append("".join(characters))
    return "-".join(chunks)
|
|
|
|
@app.post("/license", status_code=status.HTTP_201_CREATED)
async def create_license_key(
    is_active: bool = True,
    credentials: HTTPAuthorizationCredentials = Depends(security),
    expiration_date: Optional[datetime] = None,
    info: Optional[str] = None,
):
    """Create a new license key and record the creation in the history table.

    Args:
        is_active: Whether the new key is stored as active.
        credentials: Bearer credentials; must carry the configured API key.
        expiration_date: Optional expiry. Naive values are taken as UTC,
            aware values are converted to UTC.
        info: Optional free-form text stored with the key.

    Returns:
        A dict with the generated key, its expiration (ISO 8601 or None),
        its active flag and the info text.

    Raises:
        HTTPException: 401 when the API key is missing or wrong, 500 when
            the database write fails.
    """
    # compare_digest keeps the comparison time independent of how many
    # leading characters match; bytes are used because the str variant only
    # accepts ASCII.
    authorized = (
        credentials is not None
        and credentials.scheme.lower() == "bearer"
        and isinstance(api_key, str)
        and secrets.compare_digest(
            credentials.credentials.encode("utf-8"), api_key.encode("utf-8")
        )
    )
    if not authorized:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing API key.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    license_key = _generate_license_key()
    issued_at = datetime.now(timezone.utc)

    expiration_ts: Optional[datetime] = None
    if expiration_date is not None:
        if expiration_date.tzinfo is None:
            expiration_ts = expiration_date.replace(tzinfo=timezone.utc)
        else:
            expiration_ts = expiration_date.astimezone(timezone.utc)

    def _persist_license():
        # Runs in a worker thread: psycopg's synchronous API would otherwise
        # block the event loop.
        with psycopg.connect(connect_statement) as conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO license_keys (
                        key,
                        issue_timestamp,
                        expiration_timestamp,
                        last_used_timestamp,
                        info,
                        is_active
                    )
                    VALUES (%s, %s, %s, %s, %s, %s)
                    """,
                    (
                        license_key,
                        issued_at,
                        expiration_ts,
                        None,
                        info,
                        is_active,
                    ),
                )
                cur.execute(
                    """
                    INSERT INTO history (action, timestamp)
                    VALUES (%s, %s)
                    """,
                    (
                        (
                            "create_license_key key={license_key} active={is_active} expiration={expiration} info={info}"
                            .format(
                                license_key=license_key,
                                is_active=is_active,
                                expiration=expiration_ts.isoformat()
                                if expiration_ts
                                else "none",
                                info=info if info is not None else "none",
                            )
                        ),
                        issued_at,
                    ),
                )
            conn.commit()

    try:
        await asyncio.to_thread(_persist_license)
    except Exception as exc:
        logger.exception("Failed to create license key.")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create license key.",
        ) from exc

    return {
        "license_key": license_key,
        "expiration_timestamp": expiration_ts.isoformat() if expiration_ts else None,
        "is_active": is_active,
        "info": info,
    }
|
|
|
|
@app.get("/public-key")
async def get_public_key():
    """Return the base64-encoded public signing key.

    Responds with 404 when signing is disabled. If the key has not been
    loaded yet, a load is attempted first; a failure there is logged and
    reported as a 500.
    """
    if SIGN_KEYS:
        if SIGNING_PUBLIC_KEY_B64 is None:
            # Keys were not initialised earlier; try to load them on demand.
            try:
                _load_or_create_signing_keys()
            except Exception as load_error:
                logger.error("Failed to load public key: {}", load_error)
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail="Signing is enabled but keys are unavailable.",
                )
        return {"public_key": SIGNING_PUBLIC_KEY_B64}

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="License key signing is disabled.",
    )
|
|
|
|
@app.get("/is_valid")
async def is_license_key_valid(license_key: str):
    """Check a license key against the database and report whether it is usable.

    A key counts as valid when a row with that key is active and either has
    no expiration or expires after the current time; the same statement also
    stamps ``last_used_timestamp``. Database errors are logged and reported
    as an invalid key. When signing is enabled, the response additionally
    carries the signed payload and its base64 signature.
    """
    checked_at = datetime.now(timezone.utc)

    def _lookup() -> tuple[bool, Optional[datetime]]:
        # Executed in a worker thread so the blocking driver does not stall
        # the event loop.
        try:
            with psycopg.connect(connect_statement) as connection, connection.cursor() as cursor:
                cursor.execute(
                    """
                    UPDATE license_keys
                    SET last_used_timestamp = %s
                    WHERE key = %s
                    AND is_active = TRUE
                    AND (expiration_timestamp IS NULL OR expiration_timestamp > %s)
                    RETURNING expiration_timestamp
                    """,
                    (checked_at, license_key, checked_at),
                )
                match = cursor.fetchone()
                if match is not None:
                    connection.commit()
                    return True, match[0]
                return False, None
        except Exception:
            logger.exception("Failed validating license key.")
            return False, None

    valid, expiration_ts = await asyncio.to_thread(_lookup)

    # Without signing (or for an invalid key) the answer is just the flag.
    if not (valid and SIGN_KEYS):
        return {"valid": valid}

    if SIGNING_PRIVATE_KEY is None:
        try:
            _load_or_create_signing_keys()
        except Exception as load_error:
            logger.error("Failed to load signing key for validation: {}", load_error)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Signing is enabled but keys are unavailable.",
            )

    if expiration_ts:
        expiration_text = expiration_ts.isoformat()
    else:
        expiration_text = None
    payload = {
        "license_key": license_key,
        "expiration_timestamp": expiration_text,
    }
    # Sorted keys and compact separators give the signed message a fixed
    # serialisation.
    message = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8")
    raw_signature = SIGNING_PRIVATE_KEY.sign(message)
    signature = base64.b64encode(raw_signature).decode("ascii")
    return {"valid": True, "license": payload, "signature": signature}
|
|
|
|
@app.post("/license/{license_key}/disable", status_code=status.HTTP_200_OK)
async def disable_license_key(
    license_key: str,
    credentials: HTTPAuthorizationCredentials = Depends(security),
):
    """Mark an active license key as inactive and log the change to history.

    Args:
        license_key: The key to disable (path parameter).
        credentials: Bearer credentials; must carry the configured API key.

    Returns:
        A dict with the key and ``is_active`` set to False.

    Raises:
        HTTPException: 401 for a missing or wrong API key, 404 when no
            active row with that key exists, 500 when the database
            operation fails.
    """
    # compare_digest keeps the comparison time independent of how many
    # leading characters match; bytes are used because the str variant only
    # accepts ASCII.
    authorized = (
        credentials is not None
        and credentials.scheme.lower() == "bearer"
        and isinstance(api_key, str)
        and secrets.compare_digest(
            credentials.credentials.encode("utf-8"), api_key.encode("utf-8")
        )
    )
    if not authorized:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing API key.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    disabled_at = datetime.now(timezone.utc)

    def _disable() -> bool:
        # Runs in a worker thread because the psycopg calls block.
        with psycopg.connect(connect_statement) as conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    UPDATE license_keys
                    SET is_active = FALSE
                    WHERE key = %s AND is_active = TRUE
                    """,
                    (license_key,),
                )
                # No row changed: the key is unknown or already inactive.
                if cur.rowcount == 0:
                    return False
                cur.execute(
                    """
                    INSERT INTO history (action, timestamp)
                    VALUES (%s, %s)
                    """,
                    (f"disable_license_key key={license_key}", disabled_at),
                )
            conn.commit()
        return True

    try:
        disabled = await asyncio.to_thread(_disable)
    except Exception as exc:
        logger.exception("Failed disabling license key.")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to disable license key.",
        ) from exc

    if not disabled:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="License key not found or already disabled.",
        )

    return {"license_key": license_key, "is_active": False}
|
|
|
|
@app.post("/license/{license_key}/enable", status_code=status.HTTP_200_OK)
async def enable_license_key(
    license_key: str,
    credentials: HTTPAuthorizationCredentials = Depends(security),
):
    """Mark an inactive license key as active and log the change to history.

    Args:
        license_key: The key to enable (path parameter).
        credentials: Bearer credentials; must carry the configured API key.

    Returns:
        A dict with the key and ``is_active`` set to True.

    Raises:
        HTTPException: 401 for a missing or wrong API key, 404 when no
            inactive row with that key exists, 500 when the database
            operation fails.
    """
    # compare_digest keeps the comparison time independent of how many
    # leading characters match; bytes are used because the str variant only
    # accepts ASCII.
    authorized = (
        credentials is not None
        and credentials.scheme.lower() == "bearer"
        and isinstance(api_key, str)
        and secrets.compare_digest(
            credentials.credentials.encode("utf-8"), api_key.encode("utf-8")
        )
    )
    if not authorized:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing API key.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    enabled_at = datetime.now(timezone.utc)

    def _enable() -> bool:
        # Runs in a worker thread because the psycopg calls block.
        with psycopg.connect(connect_statement) as conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    UPDATE license_keys
                    SET is_active = TRUE
                    WHERE key = %s AND is_active = FALSE
                    """,
                    (license_key,),
                )
                # No row changed: the key is unknown or already active.
                if cur.rowcount == 0:
                    return False
                cur.execute(
                    """
                    INSERT INTO history (action, timestamp)
                    VALUES (%s, %s)
                    """,
                    (f"enable_license_key key={license_key}", enabled_at),
                )
            conn.commit()
        return True

    try:
        enabled = await asyncio.to_thread(_enable)
    except Exception as exc:
        logger.exception("Failed enabling license key.")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to enable license key.",
        ) from exc

    if not enabled:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="License key not found or already enabled.",
        )

    return {"license_key": license_key, "is_active": True}
|
|
|
|
@app.post("/license/{license_key}/expiration", status_code=status.HTTP_200_OK)
async def update_license_expiration(
    license_key: str,
    expiration_date: Optional[datetime] = None,
    credentials: HTTPAuthorizationCredentials = Depends(security),
):
    """Set or clear the expiration of a license key and log it to history.

    Args:
        license_key: The key to update (path parameter).
        expiration_date: New expiry; omit it to clear the expiration. Naive
            values are taken as UTC, aware values are converted to UTC.
        credentials: Bearer credentials; must carry the configured API key.

    Returns:
        A dict with the key and the stored expiration (ISO 8601 or None).

    Raises:
        HTTPException: 401 for a missing or wrong API key, 404 when the key
            does not exist, 500 when the database operation fails.
    """
    # compare_digest keeps the comparison time independent of how many
    # leading characters match; bytes are used because the str variant only
    # accepts ASCII.
    authorized = (
        credentials is not None
        and credentials.scheme.lower() == "bearer"
        and isinstance(api_key, str)
        and secrets.compare_digest(
            credentials.credentials.encode("utf-8"), api_key.encode("utf-8")
        )
    )
    if not authorized:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing API key.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    expiration_ts: Optional[datetime] = None
    if expiration_date is not None:
        if expiration_date.tzinfo is None:
            expiration_ts = expiration_date.replace(tzinfo=timezone.utc)
        else:
            expiration_ts = expiration_date.astimezone(timezone.utc)

    updated_at = datetime.now(timezone.utc)

    def _update():
        # Runs in a worker thread because the psycopg calls block.
        with psycopg.connect(connect_statement) as conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    UPDATE license_keys
                    SET expiration_timestamp = %s
                    WHERE key = %s
                    RETURNING expiration_timestamp
                    """,
                    (expiration_ts, license_key),
                )
                row = cur.fetchone()
                # No row returned: the key does not exist.
                if row is None:
                    return False, None
                new_expiration = row[0]
                cur.execute(
                    """
                    INSERT INTO history (action, timestamp)
                    VALUES (%s, %s)
                    """,
                    (
                        (
                            "update_expiration license_key={key} expiration={expiration}"
                            .format(
                                key=license_key,
                                expiration=new_expiration.isoformat()
                                if new_expiration
                                else "cleared",
                            )
                        ),
                        updated_at,
                    ),
                )
            conn.commit()
        return True, new_expiration

    try:
        success, new_expiration = await asyncio.to_thread(_update)
    except Exception as exc:
        logger.exception("Failed updating license expiration.")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to update license expiration.",
        ) from exc

    if not success:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="License key not found.",
        )

    return {
        "license_key": license_key,
        "expiration_timestamp": new_expiration.isoformat()
        if new_expiration
        else None,
    }
|
|
|
|
@app.get("/license/export", status_code=status.HTTP_200_OK)
async def export_license_keys(
    credentials: HTTPAuthorizationCredentials = Depends(security),
):
    """Export every license key as a CSV attachment, oldest issue date first.

    Columns: license_key, issue_timestamp, expiration_timestamp, info,
    is_active. Newlines and commas inside ``info`` are replaced by spaces;
    any remaining special character (a double quote) is handled by the csv
    writer's quoting so the row stays parseable.

    Args:
        credentials: Bearer credentials; must carry the configured API key.

    Raises:
        HTTPException: 401 for a missing or wrong API key, 500 when the
            database query fails.
    """
    # Local imports keep this endpoint self-contained.
    import csv
    import io

    # compare_digest keeps the comparison time independent of how many
    # leading characters match; bytes are used because the str variant only
    # accepts ASCII.
    authorized = (
        credentials is not None
        and credentials.scheme.lower() == "bearer"
        and isinstance(api_key, str)
        and secrets.compare_digest(
            credentials.credentials.encode("utf-8"), api_key.encode("utf-8")
        )
    )
    if not authorized:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing API key.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    def _fetch():
        # Runs in a worker thread because the psycopg calls block.
        with psycopg.connect(connect_statement) as conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT key, issue_timestamp, expiration_timestamp, info, is_active
                    FROM license_keys
                    ORDER BY issue_timestamp ASC
                    """
                )
                return cur.fetchall()

    try:
        rows = await asyncio.to_thread(_fetch)
    except Exception as exc:
        logger.exception("Failed exporting license keys.")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to export license keys.",
        ) from exc

    buffer = io.StringIO()
    writer = csv.writer(buffer, lineterminator="\n")
    writer.writerow(
        ["license_key", "issue_timestamp", "expiration_timestamp", "info", "is_active"]
    )
    for license_key, issue_ts, expiration_ts, info_value, is_active in rows:
        flattened_info = (
            (info_value or "").replace("\n", " ").replace("\r", " ").replace(",", " ")
        )
        writer.writerow(
            [
                license_key,
                issue_ts.isoformat(),
                expiration_ts.isoformat() if expiration_ts else "",
                flattened_info,
                "true" if is_active else "false",
            ]
        )
    # The export has never ended with a newline; keep it that way.
    csv_content = buffer.getvalue().removesuffix("\n")

    return Response(
        content=csv_content,
        media_type="text/csv",
        headers={"Content-Disposition": 'attachment; filename="license_export.csv"'},
    )
|
|
|
|
@app.get("/history/export", status_code=status.HTTP_200_OK)
async def export_history(
    token: Optional[str] = None,
    credentials: HTTPAuthorizationCredentials = Depends(security),
):
    """Export history entries as a CSV attachment, oldest first.

    Args:
        token: Optional substring filter; only actions containing this text
            are exported. The text is matched literally, so ``%`` and ``_``
            are not treated as wildcards.
        credentials: Bearer credentials; must carry the configured API key.

    Returns:
        A ``text/csv`` response with the columns ``action`` and
        ``timestamp``. Actions are written through the csv writer because
        they embed free-form text that may contain commas, quotes or
        newlines.

    Raises:
        HTTPException: 401 for a missing or wrong API key, 500 when the
            database query fails.
    """
    # Local imports keep this endpoint self-contained.
    import csv
    import io

    # compare_digest keeps the comparison time independent of how many
    # leading characters match; bytes are used because the str variant only
    # accepts ASCII.
    authorized = (
        credentials is not None
        and credentials.scheme.lower() == "bearer"
        and isinstance(api_key, str)
        and secrets.compare_digest(
            credentials.credentials.encode("utf-8"), api_key.encode("utf-8")
        )
    )
    if not authorized:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing API key.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    def _fetch():
        # Runs in a worker thread because the psycopg calls block.
        with psycopg.connect(connect_statement) as conn:
            with conn.cursor() as cur:
                if token:
                    # Backslash is LIKE's default escape character; escape it
                    # first, then the two wildcard characters.
                    literal_token = (
                        token.replace("\\", "\\\\")
                        .replace("%", "\\%")
                        .replace("_", "\\_")
                    )
                    cur.execute(
                        """
                        SELECT action, timestamp
                        FROM history
                        WHERE action LIKE %s
                        ORDER BY timestamp ASC
                        """,
                        (f"%{literal_token}%",),
                    )
                else:
                    cur.execute(
                        """
                        SELECT action, timestamp
                        FROM history
                        ORDER BY timestamp ASC
                        """
                    )
                return cur.fetchall()

    try:
        rows = await asyncio.to_thread(_fetch)
    except Exception as exc:
        logger.exception("Failed exporting history.")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to export history.",
        ) from exc

    buffer = io.StringIO()
    writer = csv.writer(buffer, lineterminator="\n")
    writer.writerow(["action", "timestamp"])
    for action, ts in rows:
        writer.writerow([action, ts.isoformat()])
    # The export has never ended with a newline; keep it that way.
    csv_content = buffer.getvalue().removesuffix("\n")

    if token:
        # Only ASCII letters, digits, '-', '_' and '.' reach the header;
        # everything else (spaces, quotes, control or non-ASCII characters)
        # becomes '_' so the header value cannot be broken or rejected.
        filename_token = "".join(
            ch if (ch.isascii() and ch.isalnum()) or ch in "-_." else "_"
            for ch in token
        )
    else:
        filename_token = "history"

    return Response(
        content=csv_content,
        media_type="text/csv",
        headers={
            "Content-Disposition": f'attachment; filename="{filename_token}_export.csv"'
        },
    )
|
|
|
|
@app.get("/")
async def get_server_info():
    """
    Return the server name and version.

    A trivial query is sent to the database first; if it fails, the error is
    logged and the request is answered with 503 instead.
    """
    def _ping_database() -> None:
        # Blocking driver calls, executed in a worker thread by the caller.
        with psycopg.connect(connect_statement) as connection, connection.cursor() as cursor:
            cursor.execute("SELECT 1")
            cursor.fetchone()

    try:
        await asyncio.to_thread(_ping_database)
    except Exception as exc:
        logger.exception("Database health check failed.")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Database unavailable.",
        ) from exc

    return {"version": f"Micro License Server {VERSION}"}
|