progress on license server
This commit is contained in:
500
src/main.py
Normal file
500
src/main.py
Normal file
@@ -0,0 +1,500 @@
|
||||
from fastapi import Depends, FastAPI, HTTPException, status, Response
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from dotenv import dotenv_values
|
||||
from loguru import logger
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
import secrets
|
||||
import string
|
||||
import psycopg
|
||||
import asyncio
|
||||
|
||||
ENV = dotenv_values(".env") # .env file
|
||||
VERSION = "v0.0.1" # version number
|
||||
ALPHABET = string.ascii_uppercase + string.digits #alphabet and numbers for token generation
|
||||
LICENSE_KEY_PARTS = 5 if "NUM_KEY_CHUNKS" not in ENV else ENV["NUM_KEY_CHUNKS"] #number of chunks in the new generated license keys
|
||||
LICENSE_KEY_PART_LENGTH = 5 if "KEY_CHUNK_LENGTH" not in ENV else ENV["KEY_CHUNK_LENGTH"] #number of characters in each chunk
|
||||
|
||||
api_key = ""
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
if "POSTGRESQL_PASSWORD" in ENV:
|
||||
logger.debug("POSTGRESQL_PASSWORD successfully read.")
|
||||
connect_statement = "dbname=postgres user=postgres host=postgresql password='{password}'".format(password = ENV['POSTGRESQL_PASSWORD'])
|
||||
else:
|
||||
logger.error("POSTGRESQL_PASSWORD not found in .env file... cannot continue.")
|
||||
exit()
|
||||
|
||||
if "API_KEY" in ENV:
|
||||
logger.debug("API_KEY successfully read.")
|
||||
api_key = ENV['API_KEY']
|
||||
else:
|
||||
logger.error("API_KEY not found in .env file... cannot continue.")
|
||||
exit()
|
||||
|
||||
with psycopg.connect(connect_statement) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS license_keys (
|
||||
key TEXT PRIMARY KEY,
|
||||
issue_timestamp TIMESTAMPTZ NOT NULL,
|
||||
expiration_timestamp TIMESTAMPTZ,
|
||||
last_used_timestamp TIMESTAMPTZ,
|
||||
is_active BOOLEAN NOT NULL DEFAULT TRUE
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS history (
|
||||
action TEXT NOT NULL,
|
||||
timestamp TIMESTAMPTZ NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
#app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)
|
||||
app = FastAPI()
|
||||
logger.info("Started FastAPI.")
|
||||
|
||||
|
||||
|
||||
def _generate_license_key() -> str:
|
||||
"""
|
||||
Generate a cryptographically secure token
|
||||
"""
|
||||
return "-".join(
|
||||
"".join(secrets.choice(ALPHABET) for _ in range(LICENSE_KEY_PART_LENGTH))
|
||||
for _ in range(LICENSE_KEY_PARTS)
|
||||
)
|
||||
|
||||
|
||||
@app.post("/license", status_code=status.HTTP_201_CREATED)
|
||||
async def create_license_key(
|
||||
is_active: bool = True,
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
expiration_date: Optional[datetime] = None,
|
||||
):
|
||||
if (
|
||||
credentials is None
|
||||
or credentials.scheme.lower() != "bearer"
|
||||
or credentials.credentials != api_key
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing API key.",
|
||||
)
|
||||
|
||||
license_key = _generate_license_key()
|
||||
issued_at = datetime.now(timezone.utc)
|
||||
expiration_ts: Optional[datetime] = None
|
||||
if expiration_date is not None:
|
||||
expiration_ts = expiration_date
|
||||
if expiration_ts.tzinfo is None:
|
||||
expiration_ts = expiration_ts.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
expiration_ts = expiration_ts.astimezone(timezone.utc)
|
||||
|
||||
def _persist_license():
|
||||
with psycopg.connect(connect_statement) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO license_keys (
|
||||
key,
|
||||
issue_timestamp,
|
||||
expiration_timestamp,
|
||||
last_used_timestamp,
|
||||
is_active
|
||||
)
|
||||
VALUES (%s, %s, %s, %s, %s)
|
||||
""",
|
||||
(
|
||||
license_key,
|
||||
issued_at,
|
||||
expiration_ts,
|
||||
None,
|
||||
is_active,
|
||||
),
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO history (action, timestamp)
|
||||
VALUES (%s, %s)
|
||||
""",
|
||||
(
|
||||
f"create_license_key key={license_key} active={is_active} expiration={expiration_ts.isoformat() if expiration_ts else 'none'}",
|
||||
issued_at,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(_persist_license)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to create license key.")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create license key.",
|
||||
) from exc
|
||||
|
||||
return {
|
||||
"license_key": license_key,
|
||||
"expiration_timestamp": expiration_ts.isoformat() if expiration_ts else None,
|
||||
"is_active": is_active,
|
||||
}
|
||||
|
||||
@app.get("/is_valid")
|
||||
async def is_license_key_valid(license_key: str) -> bool:
|
||||
"""Validate the supplied license key against the database."""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
def _lookup() -> bool:
|
||||
try:
|
||||
with psycopg.connect(connect_statement) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE license_keys
|
||||
SET last_used_timestamp = %s
|
||||
WHERE key = %s
|
||||
AND is_active = TRUE
|
||||
AND (expiration_timestamp IS NULL OR expiration_timestamp > %s)
|
||||
RETURNING 1
|
||||
""",
|
||||
(now, license_key, now),
|
||||
)
|
||||
if cur.fetchone() is None:
|
||||
return False
|
||||
conn.commit()
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Failed validating license key.")
|
||||
return False
|
||||
|
||||
return await asyncio.to_thread(_lookup)
|
||||
|
||||
@app.post("/license/{license_key}/disable", status_code=status.HTTP_200_OK)
|
||||
async def disable_license_key(
|
||||
license_key: str,
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
):
|
||||
if (
|
||||
credentials is None
|
||||
or credentials.scheme.lower() != "bearer"
|
||||
or credentials.credentials != api_key
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing API key.",
|
||||
)
|
||||
|
||||
disabled_at = datetime.now(timezone.utc)
|
||||
|
||||
def _disable() -> bool:
|
||||
with psycopg.connect(connect_statement) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE license_keys
|
||||
SET is_active = FALSE
|
||||
WHERE key = %s AND is_active = TRUE
|
||||
""",
|
||||
(license_key,),
|
||||
)
|
||||
if cur.rowcount == 0:
|
||||
return False
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO history (action, timestamp)
|
||||
VALUES (%s, %s)
|
||||
""",
|
||||
(f"disable_license_key key={license_key}", disabled_at),
|
||||
)
|
||||
conn.commit()
|
||||
return True
|
||||
|
||||
try:
|
||||
disabled = await asyncio.to_thread(_disable)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed disabling license key.")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to disable license key.",
|
||||
) from exc
|
||||
|
||||
if not disabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="License key not found or already disabled.",
|
||||
)
|
||||
|
||||
return {"license_key": license_key, "is_active": False}
|
||||
|
||||
@app.post("/license/{license_key}/enable", status_code=status.HTTP_200_OK)
|
||||
async def enable_license_key(
|
||||
license_key: str,
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
):
|
||||
if (
|
||||
credentials is None
|
||||
or credentials.scheme.lower() != "bearer"
|
||||
or credentials.credentials != api_key
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing API key.",
|
||||
)
|
||||
|
||||
enabled_at = datetime.now(timezone.utc)
|
||||
|
||||
def _enable() -> bool:
|
||||
with psycopg.connect(connect_statement) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE license_keys
|
||||
SET is_active = TRUE
|
||||
WHERE key = %s AND is_active = FALSE
|
||||
""",
|
||||
(license_key,),
|
||||
)
|
||||
if cur.rowcount == 0:
|
||||
return False
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO history (action, timestamp)
|
||||
VALUES (%s, %s)
|
||||
""",
|
||||
(f"enable_license_key key={license_key}", enabled_at),
|
||||
)
|
||||
conn.commit()
|
||||
return True
|
||||
|
||||
try:
|
||||
enabled = await asyncio.to_thread(_enable)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed enabling license key.")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to enable license key.",
|
||||
) from exc
|
||||
|
||||
if not enabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="License key not found or already enabled.",
|
||||
)
|
||||
|
||||
return {"license_key": license_key, "is_active": True}
|
||||
|
||||
@app.post("/license/{license_key}/expiration", status_code=status.HTTP_200_OK)
|
||||
async def update_license_expiration(
|
||||
license_key: str,
|
||||
expiration_date: Optional[datetime] = None,
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
):
|
||||
if (
|
||||
credentials is None
|
||||
or credentials.scheme.lower() != "bearer"
|
||||
or credentials.credentials != api_key
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing API key.",
|
||||
)
|
||||
|
||||
expiration_ts: Optional[datetime] = None
|
||||
if expiration_date is not None:
|
||||
expiration_ts = expiration_date
|
||||
if expiration_ts.tzinfo is None:
|
||||
expiration_ts = expiration_ts.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
expiration_ts = expiration_ts.astimezone(timezone.utc)
|
||||
|
||||
updated_at = datetime.now(timezone.utc)
|
||||
|
||||
def _update():
|
||||
with psycopg.connect(connect_statement) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE license_keys
|
||||
SET expiration_timestamp = %s
|
||||
WHERE key = %s
|
||||
RETURNING expiration_timestamp
|
||||
""",
|
||||
(expiration_ts, license_key),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row is None:
|
||||
return False, None
|
||||
new_expiration = row[0]
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO history (action, timestamp)
|
||||
VALUES (%s, %s)
|
||||
""",
|
||||
(
|
||||
(
|
||||
"update_expiration license_key={key} expiration={expiration}"
|
||||
.format(
|
||||
key=license_key,
|
||||
expiration=new_expiration.isoformat()
|
||||
if new_expiration
|
||||
else "cleared",
|
||||
)
|
||||
),
|
||||
updated_at,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
return True, new_expiration
|
||||
|
||||
try:
|
||||
success, new_expiration = await asyncio.to_thread(_update)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed updating license expiration.")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update license expiration.",
|
||||
) from exc
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="License key not found.",
|
||||
)
|
||||
|
||||
return {
|
||||
"license_key": license_key,
|
||||
"expiration_timestamp": new_expiration.isoformat()
|
||||
if new_expiration
|
||||
else None,
|
||||
}
|
||||
|
||||
@app.get("/license/export", status_code=status.HTTP_200_OK)
|
||||
async def export_license_keys(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
):
|
||||
if (
|
||||
credentials is None
|
||||
or credentials.scheme.lower() != "bearer"
|
||||
or credentials.credentials != api_key
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing API key.",
|
||||
)
|
||||
|
||||
def _fetch():
|
||||
with psycopg.connect(connect_statement) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT key, issue_timestamp, expiration_timestamp, is_active
|
||||
FROM license_keys
|
||||
ORDER BY issue_timestamp ASC
|
||||
"""
|
||||
)
|
||||
return cur.fetchall()
|
||||
|
||||
try:
|
||||
rows = await asyncio.to_thread(_fetch)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed exporting license keys.")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to export license keys.",
|
||||
) from exc
|
||||
|
||||
lines = ["license_key,issue_timestamp,expiration_timestamp,is_active"]
|
||||
for license_key, issue_ts, expiration_ts, is_active in rows:
|
||||
lines.append(
|
||||
"{key},{issue},{expiration},{active}".format(
|
||||
key=license_key,
|
||||
issue=issue_ts.isoformat(),
|
||||
expiration=expiration_ts.isoformat()
|
||||
if expiration_ts
|
||||
else "",
|
||||
active="true" if is_active else "false",
|
||||
)
|
||||
)
|
||||
csv_content = "\n".join(lines)
|
||||
return Response(
|
||||
content=csv_content,
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": 'attachment; filename="license_export.csv"'},
|
||||
)
|
||||
|
||||
@app.get("/history/export", status_code=status.HTTP_200_OK)
|
||||
async def export_history(
|
||||
token: Optional[str] = None,
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
):
|
||||
if (
|
||||
credentials is None
|
||||
or credentials.scheme.lower() != "bearer"
|
||||
or credentials.credentials != api_key
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing API key.",
|
||||
)
|
||||
|
||||
def _fetch():
|
||||
with psycopg.connect(connect_statement) as conn:
|
||||
with conn.cursor() as cur:
|
||||
if token:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT action, timestamp
|
||||
FROM history
|
||||
WHERE action LIKE %s
|
||||
ORDER BY timestamp ASC
|
||||
""",
|
||||
(f"%{token}%",),
|
||||
)
|
||||
else:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT action, timestamp
|
||||
FROM history
|
||||
ORDER BY timestamp ASC
|
||||
"""
|
||||
)
|
||||
return cur.fetchall()
|
||||
|
||||
try:
|
||||
rows = await asyncio.to_thread(_fetch)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed exporting history.")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to export history.",
|
||||
) from exc
|
||||
|
||||
lines = ["action,timestamp"]
|
||||
for action, ts in rows:
|
||||
lines.append(f"{action},{ts.isoformat()}")
|
||||
csv_content = "\n".join(lines)
|
||||
filename_token = token.replace(" ", "_") if token else "history"
|
||||
return Response(
|
||||
content=csv_content,
|
||||
media_type="text/csv",
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{filename_token}_export.csv"'
|
||||
},
|
||||
)
|
||||
|
||||
@app.get("/")
|
||||
async def get_server_info():
|
||||
"""
|
||||
Return the server name and version
|
||||
"""
|
||||
info = {
|
||||
"version":"Micro License Server {version}".format(version=VERSION)
|
||||
}
|
||||
return info
|
||||
Reference in New Issue
Block a user