Add main.py
This commit is contained in:
144
main.py
Normal file
144
main.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
CORS proxy for Unsloth Studio API.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python cors_proxy.py # proxies to http://127.0.0.1:8888 on port 8080
|
||||||
|
python cors_proxy.py --target 8000 # proxies to http://127.0.0.1:8000
|
||||||
|
python cors_proxy.py --listen 9090 # listens on port 9090
|
||||||
|
python cors_proxy.py --target 10.0.0.5:8000 --listen 9000
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||||
|
from urllib.request import Request, urlopen, URLError
|
||||||
|
|
||||||
|
CORS_HEADERS = {
|
||||||
|
"Access-Control-Allow-Origin": "*",
|
||||||
|
"Access-Control-Allow-Methods": "POST, OPTIONS, GET",
|
||||||
|
"Access-Control-Allow-Headers": "Content-Type, Authorization",
|
||||||
|
"Access-Control-Max-Age": "86400",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ProxyHandler(BaseHTTPRequestHandler):
|
||||||
|
target_host = "127.0.0.1"
|
||||||
|
target_port = 8888
|
||||||
|
|
||||||
|
def _target_url(self):
|
||||||
|
return f"http://{self.target_host}:{self.target_port}{self.path}"
|
||||||
|
|
||||||
|
def _set_cors(self, status=200):
|
||||||
|
self.send_response(status)
|
||||||
|
for k, v in CORS_HEADERS.items():
|
||||||
|
self.send_header(k, v)
|
||||||
|
|
||||||
|
def do_OPTIONS(self):
|
||||||
|
self._set_cors(204)
|
||||||
|
self.end_headers()
|
||||||
|
|
||||||
|
def do_POST(self):
|
||||||
|
content_length = int(self.headers.get("Content-Length", 0))
|
||||||
|
body = self.rfile.read(content_length)
|
||||||
|
|
||||||
|
target = self._target_url()
|
||||||
|
req = Request(target, data=body, method="POST")
|
||||||
|
|
||||||
|
for h in ("Content-Type", "Authorization"):
|
||||||
|
val = self.headers.get(h)
|
||||||
|
if val:
|
||||||
|
req.add_header(h, val)
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = urlopen(req, timeout=60)
|
||||||
|
resp_body = resp.read()
|
||||||
|
status = resp.status
|
||||||
|
except URLError as e:
|
||||||
|
detail = f"Upstream server at {self.target_host}:{self.target_port} is unreachable."
|
||||||
|
if hasattr(e, "reason"):
|
||||||
|
detail += f" Reason: {e.reason}"
|
||||||
|
self._set_cors(502)
|
||||||
|
self.send_header("Content-Type", "application/json")
|
||||||
|
self.end_headers()
|
||||||
|
self.wfile.write(json.dumps({
|
||||||
|
"error": {"message": detail, "type": "proxy_error"}
|
||||||
|
}).encode())
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
self._set_cors(502)
|
||||||
|
self.send_header("Content-Type", "application/json")
|
||||||
|
self.end_headers()
|
||||||
|
self.wfile.write(json.dumps({
|
||||||
|
"error": {"message": f"Proxy error: {e}", "type": "proxy_error"}
|
||||||
|
}).encode())
|
||||||
|
return
|
||||||
|
|
||||||
|
self._set_cors(status)
|
||||||
|
ctype = resp.headers.get("Content-Type", "application/json")
|
||||||
|
self.send_header("Content-Type", ctype)
|
||||||
|
self.end_headers()
|
||||||
|
self.wfile.write(resp_body)
|
||||||
|
|
||||||
|
def log_message(self, fmt, *args):
|
||||||
|
method = self.command
|
||||||
|
path = self.path
|
||||||
|
print(f" [{method}] {path} -> {self._target_url()}")
|
||||||
|
|
||||||
|
def do_GET(self):
|
||||||
|
if self.path == "/health":
|
||||||
|
self._set_cors(200)
|
||||||
|
self.send_header("Content-Type", "application/json")
|
||||||
|
self.end_headers()
|
||||||
|
self.wfile.write(json.dumps({
|
||||||
|
"status": "ok",
|
||||||
|
"proxy_to": f"{self.target_host}:{self.target_port}",
|
||||||
|
}).encode())
|
||||||
|
return
|
||||||
|
self._set_cors(405)
|
||||||
|
self.end_headers()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="CORS proxy for Unsloth Studio")
|
||||||
|
parser.add_argument(
|
||||||
|
"--target", default="127.0.0.1:8888",
|
||||||
|
help="Unsloth Studio address (default: 127.0.0.1:8888)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--listen", default=8080, type=int,
|
||||||
|
help="Port to listen on (default: 8080)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
host, port_str = args.target, "8888"
|
||||||
|
if ":" in args.target:
|
||||||
|
host, port_str = args.target.rsplit(":", 1)
|
||||||
|
try:
|
||||||
|
port = int(port_str)
|
||||||
|
except ValueError:
|
||||||
|
print(f"Invalid target port: {port_str}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
ProxyHandler.target_host = host
|
||||||
|
ProxyHandler.target_port = port
|
||||||
|
|
||||||
|
server = HTTPServer(("0.0.0.0", args.listen), ProxyHandler)
|
||||||
|
|
||||||
|
print(f" Unsloth Studio CORS Proxy")
|
||||||
|
print(f" ─────────────────────────")
|
||||||
|
print(f" Listening on: http://127.0.0.1:{args.listen}")
|
||||||
|
print(f" Forwarding to: http://{host}:{port}")
|
||||||
|
print(f" Plugin API URL: http://127.0.0.1:{args.listen}")
|
||||||
|
print(f" Health check: http://127.0.0.1:{args.listen}/health")
|
||||||
|
print()
|
||||||
|
|
||||||
|
try:
|
||||||
|
server.serve_forever()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nShutting down.")
|
||||||
|
server.server_close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user