diff --git a/main.py b/main.py new file mode 100644 index 0000000..67ebdcc --- /dev/null +++ b/main.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +""" +CORS proxy for Unsloth Studio API. + +Usage: + python cors_proxy.py # proxies to http://127.0.0.1:8888 on port 8080 + python cors_proxy.py --target 8000 # proxies to http://127.0.0.1:8000 + python cors_proxy.py --listen 9090 # listens on port 9090 + python cors_proxy.py --target 10.0.0.5:8000 --listen 9000 +""" + +import argparse +import json +from http.server import HTTPServer, BaseHTTPRequestHandler +from urllib.request import Request, urlopen, URLError + +CORS_HEADERS = { + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "POST, OPTIONS, GET", + "Access-Control-Allow-Headers": "Content-Type, Authorization", + "Access-Control-Max-Age": "86400", +} + + +class ProxyHandler(BaseHTTPRequestHandler): + target_host = "127.0.0.1" + target_port = 8888 + + def _target_url(self): + return f"http://{self.target_host}:{self.target_port}{self.path}" + + def _set_cors(self, status=200): + self.send_response(status) + for k, v in CORS_HEADERS.items(): + self.send_header(k, v) + + def do_OPTIONS(self): + self._set_cors(204) + self.end_headers() + + def do_POST(self): + content_length = int(self.headers.get("Content-Length", 0)) + body = self.rfile.read(content_length) + + target = self._target_url() + req = Request(target, data=body, method="POST") + + for h in ("Content-Type", "Authorization"): + val = self.headers.get(h) + if val: + req.add_header(h, val) + + try: + resp = urlopen(req, timeout=60) + resp_body = resp.read() + status = resp.status + except URLError as e: + detail = f"Upstream server at {self.target_host}:{self.target_port} is unreachable." + if hasattr(e, "reason"): + detail += f" Reason: {e.reason}" + self._set_cors(502) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps({ + "error": {"message": detail, "type": "proxy_error"} + }).encode()) + return + except Exception as e: + self._set_cors(502) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps({ + "error": {"message": f"Proxy error: {e}", "type": "proxy_error"} + }).encode()) + return + + self._set_cors(status) + ctype = resp.headers.get("Content-Type", "application/json") + self.send_header("Content-Type", ctype) + self.end_headers() + self.wfile.write(resp_body) + + def log_message(self, fmt, *args): + method = self.command + path = self.path + print(f" [{method}] {path} -> {self._target_url()}") + + def do_GET(self): + if self.path == "/health": + self._set_cors(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps({ + "status": "ok", + "proxy_to": f"{self.target_host}:{self.target_port}", + }).encode()) + return + self._set_cors(405) + self.end_headers() + + +def main(): + parser = argparse.ArgumentParser(description="CORS proxy for Unsloth Studio") + parser.add_argument( + "--target", default="127.0.0.1:8888", + help="Unsloth Studio address (default: 127.0.0.1:8888)", + ) + parser.add_argument( + "--listen", default=8080, type=int, + help="Port to listen on (default: 8080)", + ) + args = parser.parse_args() + + host, port_str = args.target, "8888" + if ":" in args.target: + host, port_str = args.target.rsplit(":", 1) + try: + port = int(port_str) + except ValueError: + print(f"Invalid target port: {port_str}") + sys.exit(1) + + ProxyHandler.target_host = host + ProxyHandler.target_port = port + + server = HTTPServer(("0.0.0.0", args.listen), ProxyHandler) + + print(f" Unsloth Studio CORS Proxy") + print(f" ─────────────────────────") + print(f" Listening on: http://127.0.0.1:{args.listen}") + print(f" Forwarding to: http://{host}:{port}") + print(f" Plugin API URL: http://127.0.0.1:{args.listen}") + print(f" Health check: http://127.0.0.1:{args.listen}/health") + print() + + try: + server.serve_forever() + except KeyboardInterrupt: + print("\nShutting down.") + server.server_close() + + +if __name__ == "__main__": + main()