228 lines
7.5 KiB
Python
228 lines
7.5 KiB
Python
#!/usr/bin/env python3
|
|
"""CORS proxy for Unsloth Studio API.
|
|
|
|
A lightweight reverse-proxy that adds Cross-Origin Resource Sharing (CORS)
|
|
headers, allowing browser-based frontend apps to call the Unsloth Studio API
|
|
running on localhost without being blocked by the Same-Origin Policy.
|
|
|
|
Usage:
|
|
python main.py # listen :8080, forward to :8888
|
|
python main.py --target 8000 # forward to http://127.0.0.1:8000
|
|
python main.py --listen 9090 # listen on port 9090
|
|
python main.py --target 10.0.0.5 # forward to http://10.0.0.5:8888
|
|
python main.py --target 10.0.0.5:8000 --listen 9000
|
|
|
|
Systemd:
|
|
sudo cp services/corsproxy.service /etc/systemd/system/
|
|
sudo systemctl daemon-reload
|
|
sudo systemctl enable --now corsproxy.service
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
import json
import sys
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Constants
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Headers attached to every response so browsers permit cross-origin calls.
# Insertion order is the order they are emitted on the wire.
CORS_HEADERS: dict[str, str] = dict(
    [
        # Any origin may call the proxy.
        ("Access-Control-Allow-Origin", "*"),
        ("Access-Control-Allow-Methods", "POST, OPTIONS, GET"),
        # Only these request headers may be sent cross-origin.
        ("Access-Control-Allow-Headers", "Content-Type, Authorization"),
        # Preflight cache lifetime in seconds (24 h); some browsers clamp it lower.
        ("Access-Control-Max-Age", "86400"),
    ]
)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Utilities
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def parse_target(target: str) -> tuple[str, int]:
    """Split an upstream address into ``(host, port)``.

    Accepted forms:

    * ``host:port`` -- e.g. ``10.0.0.5:8000``
    * ``host``      -- the port defaults to ``8888``
    * ``port``      -- a bare number, e.g. ``8000``; host defaults to ``127.0.0.1``
    * ``:port``     -- host defaults to ``127.0.0.1``

    The split is on the *last* colon, so a bracketed IPv6 literal such as
    ``[::1]:8000`` yields host ``[::1]``.

    Prints an error to stderr and exits with status 1 if the port is not an
    integer in the range 1-65535.
    """
    default_host = "127.0.0.1"
    default_port = "8888"

    target = target.strip()
    if target.isdigit():
        # A bare number is a local port (documented usage: ``--target 8000``).
        host, port_str = default_host, target
    elif ":" in target:
        host, port_str = target.rsplit(":", 1)
    else:
        host, port_str = target, default_port

    # ``:9000`` or an empty string would otherwise yield ``http://:9000``.
    host = host or default_host

    try:
        port = int(port_str)
    except ValueError:
        port = -1  # forces the range check below to reject it
    if not 0 < port <= 65535:
        print(f"Error: Invalid target port: {port_str!r}", file=sys.stderr)
        sys.exit(1)

    return host, port
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Request handler
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class ProxyHandler(BaseHTTPRequestHandler):
    """Handle individual proxy requests.

    * ``OPTIONS``     -- answered locally as a CORS preflight (204).
    * ``POST``        -- forwarded to the upstream server; the upstream
      status, Content-Type and body are relayed back with CORS headers.
    * ``GET /health`` -- answered locally with a JSON status document;
      any other ``GET`` gets 405.

    Class attributes can be mutated at runtime (e.g. by ``main``) to change
    the upstream target before the server starts listening.
    """

    target_host: str = "127.0.0.1"
    target_port: int = 8888

    # Seconds passed to ``urlopen`` as its timeout for upstream requests.
    upstream_timeout: float = 60

    # ---- helpers ----------------------------------------------------------

    def _target_url(self) -> str:
        """Return the upstream URL for the current request path."""
        return f"http://{self.target_host}:{self.target_port}{self.path}"

    def _set_cors(self, status: int = 200) -> None:
        """Send the status line for *status* plus the standard CORS headers.

        Callers add any further headers and must call ``end_headers()``.
        """
        self.send_response(status)
        for key, value in CORS_HEADERS.items():
            self.send_header(key, value)

    def _send_body(self, status: int, body: bytes, content_type: str) -> None:
        """Send a complete response: CORS headers, type, length and *body*."""
        self._set_cors(status)
        self.send_header("Content-Type", content_type)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _json_response(self, status: int, payload: dict) -> None:
        """Write *payload* as a JSON response body with the given *status*."""
        self._send_body(status, json.dumps(payload).encode(), "application/json")

    def _error_response(self, status: int, message: str) -> None:
        """Send an error in the ``{"error": {"message", "type"}}`` shape."""
        self._json_response(
            status, {"error": {"message": message, "type": "proxy_error"}}
        )

    def _fetch_upstream(self, req: Request) -> tuple[int, str, bytes]:
        """Send *req* upstream and return ``(status, content_type, body)``.

        ``urlopen`` raises ``HTTPError`` for 4xx/5xx replies. Such a reply
        still came from the upstream, so it is returned like any other
        response instead of being treated as a connection failure. Other
        ``URLError``s and I/O errors propagate to the caller.
        """
        try:
            resp = urlopen(req, timeout=self.upstream_timeout)
        except HTTPError as exc:
            resp = exc  # HTTPError doubles as a file-like response object
        try:
            headers = resp.headers
            content_type = (
                headers.get("Content-Type", "application/json")
                if headers is not None
                else "application/json"
            )
            return resp.getcode(), content_type, resp.read()
        finally:
            # Release the upstream connection even if reading the body fails.
            resp.close()

    # ---- HTTP methods -----------------------------------------------------

    def do_OPTIONS(self) -> None:
        """Respond to CORS preflight requests with a 204 No Content."""
        self._set_cors(204)
        self.end_headers()

    def do_POST(self) -> None:
        """Forward a POST request to the upstream server and relay the reply.

        Only ``Content-Type`` and ``Authorization`` are copied from the
        incoming request. Upstream error statuses (4xx/5xx) are relayed as-is;
        502 is returned only when the upstream exchange itself fails.
        """
        raw_length = self.headers.get("Content-Length", 0)
        try:
            content_length = int(raw_length)
        except (TypeError, ValueError):
            content_length = -1
        if content_length < 0:
            self._error_response(
                400, f"Invalid Content-Length header: {raw_length!r}"
            )
            return
        body = self.rfile.read(content_length)

        req = Request(self._target_url(), data=body, method="POST")

        # Forward Content-Type and Authorization headers
        for header_name in ("Content-Type", "Authorization"):
            value = self.headers.get(header_name)
            if value:
                req.add_header(header_name, value)

        try:
            status, content_type, resp_body = self._fetch_upstream(req)
        except URLError as exc:
            # HTTPError is handled inside _fetch_upstream, so this is a
            # genuine connection-level failure.
            self._error_response(
                502,
                f"Upstream server at {self.target_host}:{self.target_port} "
                f"is unreachable. Reason: {exc.reason}",
            )
            return
        except Exception as exc:
            # E.g. a timeout while reading the body or a malformed response.
            self._error_response(502, f"Proxy error: {exc}")
            return

        self._send_body(status, resp_body, content_type)

    def do_GET(self) -> None:
        """Serve a health-check endpoint; everything else returns 405."""
        if self.path == "/health":
            self._json_response(200, {
                "status": "ok",
                "proxy_to": f"{self.target_host}:{self.target_port}",
            })
            return
        self._set_cors(405)
        # A 405 response must list the methods that are allowed here.
        self.send_header("Allow", "POST, OPTIONS")
        self.send_header("Content-Length", "0")
        self.end_headers()

    # ---- logging ----------------------------------------------------------

    def log_message(self, format: str, *args) -> None:
        """Print a one-line summary of each request to stdout.

        The ``format``/``args`` supplied by ``BaseHTTPRequestHandler`` are
        ignored. The ``-> upstream`` suffix is shown only for POST, the one
        method that is actually forwarded.
        """
        method = getattr(self, "command", None) or "-"
        # ``path`` is never set when the request line could not be parsed.
        path = getattr(self, "path", "")
        line = f" [{method}] {path}"
        if method == "POST":
            line += f" -> {self._target_url()}"
        # stdout is block-buffered when it is not a terminal (e.g. when run
        # as a service), so flush to make each line visible immediately.
        print(line, flush=True)

    def log_error(self, format: str, *args) -> None:
        """Silence error reports.

        The base implementation forwards to ``log_message``, which would
        print a second request line for every ``send_error`` call.
        """
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Entry point
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def main() -> None:
    """Parse CLI arguments, start the HTTP server, and block until Ctrl+C.

    Requests are served by ``ThreadingHTTPServer`` (one thread per
    connection) so that a slow upstream POST does not stall CORS preflights
    or ``/health`` checks from other clients.
    """
    # Function-scope import keeps this entry point self-contained;
    # ThreadingHTTPServer is in the standard library since Python 3.7.
    from http.server import ThreadingHTTPServer

    parser = argparse.ArgumentParser(
        description="CORS proxy for Unsloth Studio",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=(
            "Examples:\n"
            " python main.py # :8080 -> :8888\n"
            " python main.py --target 10.0.0.5:8000 --listen 9090\n"
        ),
    )
    parser.add_argument(
        "--target",
        default="127.0.0.1:8888",
        help="Upstream Unsloth Studio address (default: 127.0.0.1:8888)",
    )
    parser.add_argument(
        "--listen",
        default=8080,
        type=int,
        help="Port to listen on (default: 8080)",
    )
    parser.add_argument(
        "--bind",
        default="0.0.0.0",
        help="Interface address to bind (default: 0.0.0.0, all interfaces)",
    )
    args = parser.parse_args()

    if not 0 < args.listen <= 65535:
        # parser.error prints usage and exits with status 2.
        parser.error(f"--listen must be between 1 and 65535, got {args.listen}")

    # Resolve upstream host and port
    host, port = parse_target(args.target)

    # Configure handler and start the server
    ProxyHandler.target_host = host
    ProxyHandler.target_port = port

    try:
        server = ThreadingHTTPServer((args.bind, args.listen), ProxyHandler)
    except OSError as exc:
        # E.g. the port is already in use or the address is not local.
        print(
            f"Error: cannot listen on {args.bind}:{args.listen}: {exc}",
            file=sys.stderr,
        )
        sys.exit(1)

    # A wildcard bind is reachable via loopback; otherwise show the bound address.
    shown_host = "127.0.0.1" if args.bind in ("", "0.0.0.0") else args.bind

    banner = (
        " Unsloth Studio CORS Proxy\n"
        " ─────────────────────────\n"
        f" Listening on: http://{shown_host}:{args.listen}\n"
        f" Forwarding to: http://{host}:{port}\n"
        f" Plugin API URL: http://{shown_host}:{args.listen}\n"
        f" Health check: http://{shown_host}:{args.listen}/health"
    )
    # Flush so the banner appears promptly even when stdout is not a terminal.
    print(banner, end="\n\n", flush=True)

    try:
        server.serve_forever()
    except KeyboardInterrupt:
        print("\nShutting down.")
    finally:
        # Close the listening socket however serve_forever exits.
        server.server_close()
|
|
|
|
|
|
# Start the proxy only when this file is executed directly; importing the
# module (e.g. to reuse ProxyHandler or parse_target) has no side effects.
if __name__ == "__main__":
    main()
|