#!/usr/bin/env python3 """CORS proxy for Unsloth Studio API. A lightweight reverse-proxy that adds Cross-Origin Resource Sharing (CORS) headers, allowing browser-based frontend apps to call the Unsloth Studio API running on localhost without being blocked by the Same-Origin Policy. Usage: python main.py # listen :8080, forward to :8888 python main.py --target 8000 # forward to http://127.0.0.1:8000 python main.py --listen 9090 # listen on port 9090 python main.py --target 10.0.0.5 # forward to http://10.0.0.5:8888 python main.py --target 10.0.0.5:8000 --listen 9000 Systemd: sudo cp services/corsproxy.service /etc/systemd/system/ sudo systemctl daemon-reload sudo systemctl enable --now corsproxy.service """ from __future__ import annotations import argparse import json import sys from http.server import HTTPServer, BaseHTTPRequestHandler from urllib.request import Request, urlopen from urllib.error import URLError # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- CORS_HEADERS: dict[str, str] = { "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "POST, OPTIONS, GET", "Access-Control-Allow-Headers": "Content-Type, Authorization", "Access-Control-Max-Age": "86400", } # --------------------------------------------------------------------------- # Utilities # --------------------------------------------------------------------------- def parse_target(target: str) -> tuple[str, int]: """Split a ``host:port`` string into ``(host, port)``. If no colon is present the default port ``8888`` is assumed. """ if ":" in target: host, port_str = target.rsplit(":", 1) else: host = target port_str = "8888" try: port: int = int(port_str) except ValueError: print(f"Error: Invalid target port: {port_str!r}") sys.exit(1) return host, port # --------------------------------------------------------------------------- # Request handler # --------------------------------------------------------------------------- class ProxyHandler(BaseHTTPRequestHandler): """Handle individual proxy requests. Class attributes can be mutated at runtime (e.g. by ``main``) to change the upstream target before the server starts listening. """ target_host: str = "127.0.0.1" target_port: int = 8888 # ---- helpers ---------------------------------------------------------- def _target_url(self) -> str: return f"http://{self.target_host}:{self.target_port}{self.path}" def _set_cors(self, status: int = 200) -> None: """Send a 200 OK (or other *status*) with standard CORS headers.""" self.send_response(status) for key, value in CORS_HEADERS.items(): self.send_header(key, value) def _json_response(self, status: int, payload: dict) -> None: """Write a JSON-encoded response body.""" self._set_cors(status) self.send_header("Content-Type", "application/json") self.end_headers() self.wfile.write(json.dumps(payload).encode()) # ---- HTTP methods ----------------------------------------------------- def do_OPTIONS(self) -> None: """Respond to CORS preflight requests with a 204 No Content.""" self._set_cors(204) self.end_headers() def do_POST(self) -> None: """Forward a POST request to the upstream server.""" content_length = int(self.headers.get("Content-Length", 0)) body = self.rfile.read(content_length) target = self._target_url() req = Request(target, data=body, method="POST") # Forward Content-Type and Authorization headers for header_name in ("Content-Type", "Authorization"): value = self.headers.get(header_name) if value: req.add_header(header_name, value) try: resp = urlopen(req, timeout=60) resp_body = resp.read() status = resp.status except URLError as exc: detail = ( f"Upstream server at {self.target_host}:{self.target_port} " "is unreachable." ) if hasattr(exc, "reason"): detail += f" Reason: {exc.reason}" self._json_response(502, {"error": {"message": detail, "type": "proxy_error"}}) return except Exception as exc: self._json_response( 502, {"error": {"message": f"Proxy error: {exc}", "type": "proxy_error"}} ) return self._set_cors(status) content_type = resp.headers.get("Content-Type", "application/json") self.send_header("Content-Type", content_type) self.end_headers() self.wfile.write(resp_body) def do_GET(self) -> None: """Serve a health-check endpoint; everything else returns 405.""" if self.path == "/health": self._json_response(200, { "status": "ok", "proxy_to": f"{self.target_host}:{self.target_port}", }) return self._set_cors(405) self.end_headers() # ---- logging ---------------------------------------------------------- def log_message(self, format: str, *args) -> None: """Print a human-readable log line for every request.""" method = self.command path = self.path target = self._target_url() print(f" [{method}] {path} -> {target}") # Suppress the default server stderr output from BaseHTTPRequestHandler def log_error(self, format: str, *args) -> None: pass # --------------------------------------------------------------------------- # Entry point # --------------------------------------------------------------------------- def main() -> None: """Parse CLI arguments, start the HTTP server, and block.""" parser = argparse.ArgumentParser( description="CORS proxy for Unsloth Studio", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=( "Examples:\n" " python main.py # :8080 -> :8888\n" " python main.py --target 10.0.0.5:8000 --listen 9090\n" ), ) parser.add_argument( "--target", default="127.0.0.1:8888", help="Upstream Unsloth Studio address (default: 127.0.0.1:8888)", ) parser.add_argument( "--listen", default=8080, type=int, help="Port to listen on (default: 8080)", ) args = parser.parse_args() # Resolve upstream host and port host, port = parse_target(args.target) # Configure handler and start the server ProxyHandler.target_host = host ProxyHandler.target_port = port server = HTTPServer(("0.0.0.0", args.listen), ProxyHandler) banner = ( " Unsloth Studio CORS Proxy\n" " ─────────────────────────\n" f" Listening on: http://127.0.0.1:{args.listen}\n" f" Forwarding to: http://{host}:{port}\n" f" Plugin API URL: http://127.0.0.1:{args.listen}\n" f" Health check: http://127.0.0.1:{args.listen}/health" ) print(banner) print() try: server.serve_forever() except KeyboardInterrupt: print("\nShutting down.") server.server_close() if __name__ == "__main__": main()