#!/usr/bin/env python3
"""
GPT Image 1 / 1.5 完整 Demo

根据 OpenAI 官方文档，覆盖全部参数：
  size, quality, n, background, output_format, output_compression,
  moderation, response_format(b64_json), 图片编辑, mask 编辑

用法：
    NEWAPI_BASE_URL=https://your-api-gateway.example.com/v1 \
    NEWAPI_API_KEY=sk-xxx \
    python3 demo_gpt_image.py

可选参数：
    --model gpt-image-1.5-vvip   指定模型（默认 gpt-image-1.5-vvip）
    --skip-edit                   跳过编辑测试
"""

import base64
import json
import os
import sys
import urllib.error
import urllib.request
from pathlib import Path

# Gateway base URL, e.g. "https://host/v1". Surrounding whitespace and any
# trailing "/" are stripped so f"{BASE_URL}/images/..." never yields "//".
BASE_URL = os.environ.get("NEWAPI_BASE_URL", "").strip().rstrip("/")
# API key sent as a Bearer token; stripped because a stray newline from a
# copy-paste makes urllib reject the Authorization header.
API_KEY = os.environ.get("NEWAPI_API_KEY", "").strip()
# Default model; main() may replace it with the value of --model.
MODEL = "gpt-image-1.5-vvip"
# Output directory, relative to the current working directory.
OUTPUT_DIR = Path("gpt_image_output")


def api_post(path, data):
    """POST *data* as JSON to ``BASE_URL + path`` and return the decoded JSON reply.

    Args:
        path: API path appended to BASE_URL, e.g. "/images/generations".
        data: JSON-serialisable request payload.

    Returns:
        The parsed JSON response body.

    Raises:
        urllib.error.HTTPError: the server answered with an error status. The
            exception message is extended with the (truncated) response body,
            which is where the API explains what went wrong.
        urllib.error.URLError: network-level failure (DNS, refused, timeout, ...).
    """
    req = urllib.request.Request(
        f"{BASE_URL}{path}",
        data=json.dumps(data).encode("utf-8"),
        headers={
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json",
        },
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=300) as resp:
            return json.loads(resp.read().decode("utf-8"))
    except urllib.error.HTTPError as err:
        # urllib's default message ("HTTP Error 400: Bad Request") drops the
        # body; read it so the summary in main() shows the real cause.
        try:
            detail = err.read().decode("utf-8", errors="replace").strip()
        except Exception:  # body unreadable -- keep the original error
            raise err
        # Re-raise the same exception type so callers' handling is unchanged.
        raise urllib.error.HTTPError(
            err.url, err.code, f"{err.reason}: {detail[:500]}", err.headers, None
        ) from err


def save(data_item, filename):
    """Write one image from an Images API ``data`` entry to ``OUTPUT_DIR / filename``.

    Both response shapes are supported: inline base64 (``b64_json``) and a
    downloadable ``url``. If both are present, ``b64_json`` is used.

    Returns:
        The number of bytes written.

    Raises:
        ValueError: the entry carries neither a non-empty ``b64_json`` nor a
            ``url``. Raising lets the caller report a failure instead of a
            misleading "✅ 0 bytes".
    """
    if data_item.get("b64_json"):
        img = base64.b64decode(data_item["b64_json"])
    elif data_item.get("url"):
        with urllib.request.urlopen(data_item["url"], timeout=60) as r:
            img = r.read()
    else:
        raise ValueError(
            f"response item contains no image data (keys: {sorted(data_item)})"
        )
    (OUTPUT_DIR / filename).write_bytes(img)
    return len(img)


def print_usage(resp):
    """Print the token counts from a response's ``usage`` block, if any.

    Both naming schemes are accepted: input_tokens / output_tokens, and the
    chat-style prompt_tokens / completion_tokens as a fallback. A count that
    is missing under both names is shown as "?". Nothing is printed when
    ``usage`` is absent or empty.
    """
    usage = resp.get("usage", {})
    if not usage:
        return

    def pick(primary, fallback):
        # A present primary key wins even when its value is falsy (0, None).
        return usage[primary] if primary in usage else usage.get(fallback, "?")

    inp = pick("input_tokens", "prompt_tokens")
    out = pick("output_tokens", "completion_tokens")
    print(f"    tokens: input={inp} output={out}")


# ============================================================
#  1. 基础文生图
# ============================================================
def test_basic():
    """Test 1: plain text-to-image with only ``model`` and ``prompt``.

    Returns:
        A URL from which the edit tests can re-read the generated image.
        This is the server-provided ``url`` when there is one. Otherwise it is
        a ``file://`` URI of the saved 1_basic.png, because gpt-image models
        return only ``b64_json``, which used to leave the edit tests without a
        source. Returns "" if nothing was saved.
    """
    print("\n=== 1. 基础文生图 ===")
    resp = api_post("/images/generations", {
        "model": MODEL,
        "prompt": "A cute orange cat wearing sunglasses at the beach, photorealistic",
    })
    print_usage(resp)
    item = resp["data"][0]
    sz = save(item, "1_basic.png")
    print(f"  ✅ {sz} bytes → 1_basic.png")

    if item.get("url"):
        return item["url"]
    # urllib.request.urlopen handles file:// URIs, so test_edit and
    # test_mask_edit can read the local copy unchanged.
    return (OUTPUT_DIR / "1_basic.png").resolve().as_uri() if sz else ""


# ============================================================
#  2. 不同 quality (low / medium / high)
# ============================================================
def test_quality():
    """Test 2: render the same prompt once per quality level (low, medium, high)."""
    print("\n=== 2. Quality 参数 ===")
    levels = ("low", "medium", "high")
    for level in levels:
        out_name = f"2_quality_{level}.png"
        print(f"  quality={level}...")
        payload = {
            "model": MODEL,
            "prompt": "a red rose on white background",
            "size": "1024x1024",
            "quality": level,
        }
        reply = api_post("/images/generations", payload)
        print_usage(reply)
        nbytes = save(reply["data"][0], out_name)
        print(f"  ✅ {nbytes} bytes → {out_name}")


# ============================================================
#  3. 不同尺寸
# ============================================================
def test_sizes():
    """Test 3: generate one image per size preset, from 1K up to 4K."""
    print("\n=== 3. 不同尺寸 (1K / 2K / 4K) ===")
    presets = (
        ("1024x1024", "1K square"),
        ("1536x1024", "1K landscape"),
        ("1024x1536", "1K portrait"),
        ("2048x2048", "2K square"),
        ("2048x1152", "2K landscape"),
        ("3840x2160", "4K landscape"),
    )
    for dims, label in presets:
        target = f"3_size_{dims}.png"
        print(f"  {label} ({dims})...")
        payload = {
            "model": MODEL,
            "prompt": "a mountain landscape at sunset",
            "size": dims,
        }
        reply = api_post("/images/generations", payload)
        print_usage(reply)
        save(reply["data"][0], target)
        print(f"  ✅ → {target}")


# ============================================================
#  4. 多张生成 (n=3)
# ============================================================
def test_batch():
    """Test 4: ask for three images in a single request (n=3) and save each one."""
    print("\n=== 4. 多张生成 n=3 ===")
    payload = {
        "model": MODEL,
        "prompt": "a colorful butterfly",
        "size": "1024x1024",
        "quality": "medium",
        "n": 3,
    }
    reply = api_post("/images/generations", payload)
    print_usage(reply)
    images = reply["data"]
    total = len(images)
    for idx, item in enumerate(images, start=1):
        nbytes = save(item, f"4_batch_{idx}.png")
        print(f"  [{idx}/{total}] ✅ {nbytes} bytes")


# ============================================================
#  5. 透明背景
# ============================================================
def test_transparent():
    """Test 5: request a transparent background with PNG output."""
    print("\n=== 5. 透明背景 ===")
    target = "5_transparent.png"
    payload = {
        "model": MODEL,
        "prompt": "a sneaker product photo, isolated, studio lighting",
        "size": "1024x1024",
        "quality": "high",
        "background": "transparent",
        "output_format": "png",
    }
    reply = api_post("/images/generations", payload)
    print_usage(reply)
    nbytes = save(reply["data"][0], target)
    print(f"  ✅ {nbytes} bytes → {target}")


# ============================================================
#  6. 输出格式 (png / jpeg / webp) + 压缩
# ============================================================
def test_formats():
    """Test 6: generate once per output format (png, jpeg, webp).

    ``output_compression=50`` is sent only for the lossy formats (jpeg, webp).
    Each file is saved as ``6_format.<fmt>``. Token usage is reported like in
    the other generation tests.
    """
    print("\n=== 6. 输出格式 + 压缩 ===")
    for fmt in ["png", "jpeg", "webp"]:
        params = {
            "model": MODEL,
            "prompt": "a green apple on a table",
            "size": "1024x1024",
            "output_format": fmt,
        }
        if fmt in ("jpeg", "webp"):
            # Compression is only meaningful for the lossy formats.
            params["output_compression"] = 50
        resp = api_post("/images/generations", params)
        # Previously missing here, unlike every other generation test.
        print_usage(resp)
        sz = save(resp["data"][0], f"6_format.{fmt}")
        print(f"  ✅ {fmt} → {sz} bytes")


# ============================================================
#  7. Base64 输出
# ============================================================
def test_b64():
    """Test 7: explicitly request ``response_format=b64_json`` and save the result."""
    print("\n=== 7. Base64 输出 ===")
    target = "7_b64.png"
    payload = {
        "model": MODEL,
        "prompt": "a simple blue circle",
        "size": "1024x1024",
        "response_format": "b64_json",
    }
    reply = api_post("/images/generations", payload)
    print_usage(reply)
    nbytes = save(reply["data"][0], target)
    print(f"  ✅ {nbytes} bytes → {target} (from base64)")


# ============================================================
#  8. Moderation=low
# ============================================================
def test_moderation():
    """Test 8: generate with ``moderation="low"``."""
    print("\n=== 8. Moderation=low ===")
    target = "8_moderation_low.png"
    payload = {
        "model": MODEL,
        "prompt": "a warrior in armor, dramatic lighting",
        "size": "1024x1024",
        "moderation": "low",
    }
    reply = api_post("/images/generations", payload)
    print_usage(reply)
    nbytes = save(reply["data"][0], target)
    print(f"  ✅ {nbytes} bytes → {target}")


# ============================================================
#  9. 全参数组合
# ============================================================
def test_combo():
    """Test 9: combine quality, transparent background, webp, compression and low moderation in one request."""
    print("\n=== 9. 全参数组合 (high + transparent + webp + compression + low moderation) ===")
    target = "9_combo.webp"
    payload = {
        "model": MODEL,
        "prompt": "a glass perfume bottle, isolated on transparent background, professional product photography",
        "size": "1536x1024",
        "quality": "high",
        "background": "transparent",
        "output_format": "webp",
        "output_compression": 30,
        "moderation": "low",
        "n": 1,
    }
    reply = api_post("/images/generations", payload)
    print_usage(reply)
    nbytes = save(reply["data"][0], target)
    print(f"  ✅ {nbytes} bytes → {target}")


# ============================================================
#  10. 图片编辑
# ============================================================
def test_edit(source_url):
    """Test 10: edit an existing image through the multipart /images/edits endpoint.

    The image at *source_url* is downloaded and uploaded as ``image``,
    together with the model, an edit prompt and ``n=1``. The result is
    written to 10_edit.png.
    """
    print("\n=== 10. 图片编辑 ===")
    with urllib.request.urlopen(source_url, timeout=60) as src:
        original = src.read()
    print(f"  源图: {len(original)} bytes")

    sep = "----FormBoundary7MA4YWxkTrZu0gW"
    fields = (("model", MODEL), ("prompt", "Add a tiny red hat on the cat"), ("n", "1"))
    chunks = [
        f'--{sep}\r\nContent-Disposition: form-data; name="{key}"\r\n\r\n{val}\r\n'.encode()
        for key, val in fields
    ]
    file_header = (
        f'--{sep}\r\nContent-Disposition: form-data; name="image"; filename="source.png"\r\n'
        "Content-Type: image/png\r\n\r\n"
    )
    chunks += [file_header.encode(), original, b"\r\n", f"--{sep}--\r\n".encode()]
    payload = b"".join(chunks)

    request = urllib.request.Request(
        f"{BASE_URL}/images/edits",
        data=payload,
        headers={
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": f"multipart/form-data; boundary={sep}",
        },
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=300) as reply:
        parsed = json.loads(reply.read().decode("utf-8"))

    nbytes = save(parsed["data"][0], "10_edit.png")
    print(f"  ✅ {nbytes} bytes → 10_edit.png")


# ============================================================
#  11. Mask 蒙版编辑
# ============================================================
def test_mask_edit(source_url):
    print("\n=== 11. Mask 蒙版编辑 ===")
    import struct
    import zlib

    with urllib.request.urlopen(source_url, timeout=60) as r:
        source_bytes = r.read()

    # 生成 256x256 mask PNG（中心 128x128 区域为编辑区）
    w, h = 256, 256
    rows = []
    for y in range(h):
        row = b"\x00"
        for x in range(w):
            if 64 <= x < 192 and 64 <= y < 192:
                row += b"\xff\xff\xff\xff"  # 白色不透明 = 编辑区
            else:
                row += b"\x00\x00\x00\x00"  # 透明 = 保留区
        rows.append(row)
    raw = b"".join(rows)

    def png_chunk(ctype, data):
        c = ctype + data
        return struct.pack(">I", len(data)) + c + struct.pack(">I", zlib.crc32(c) & 0xFFFFFFFF)

    ihdr = struct.pack(">IIBBBBB", w, h, 8, 6, 0, 0, 0)
    mask_bytes = b"\x89PNG\r\n\x1a\n"
    mask_bytes += png_chunk(b"IHDR", ihdr)
    mask_bytes += png_chunk(b"IDAT", zlib.compress(raw))
    mask_bytes += png_chunk(b"IEND", b"")
    print(f"  mask: {w}x{h} PNG ({len(mask_bytes)} bytes)")

    boundary = "----FormBoundary7MA4YWxkTrZu0gW"
    body = b""
    for name, value in [("model", MODEL), ("prompt", "Place a golden crown in the center"), ("n", "1")]:
        body += f"--{boundary}\r\nContent-Disposition: form-data; name=\"{name}\"\r\n\r\n{value}\r\n".encode()

    body += f"--{boundary}\r\nContent-Disposition: form-data; name=\"image\"; filename=\"source.png\"\r\nContent-Type: image/png\r\n\r\n".encode()
    body += source_bytes + b"\r\n"
    body += f"--{boundary}\r\nContent-Disposition: form-data; name=\"mask\"; filename=\"mask.png\"\r\nContent-Type: image/png\r\n\r\n".encode()
    body += mask_bytes + b"\r\n"
    body += f"--{boundary}--\r\n".encode()

    req = urllib.request.Request(
        f"{BASE_URL}/images/edits",
        data=body,
        headers={
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": f"multipart/form-data; boundary={boundary}",
        },
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=300) as resp:
        result = json.loads(resp.read().decode("utf-8"))

    sz = save(result["data"][0], "11_mask_edit.png")
    print(f"  ✅ {sz} bytes → 11_mask_edit.png")


# ============================================================
#  主流程
# ============================================================
def main():
    """Run every demo test in order and print a pass/fail summary.

    Gateway settings come from NEWAPI_BASE_URL / NEWAPI_API_KEY; if either is
    empty, usage is printed to stderr and the process exits with status 1.

    Command-line flags:
        --model NAME  use NAME instead of the default MODEL
        --skip-edit   skip tests 10 and 11 (image edit / mask edit)

    Each test runs in its own try/except so one failure does not stop the
    rest; the exception text appears in the summary.
    """
    global MODEL

    if not API_KEY or not BASE_URL:
        print("请设置 NEWAPI_BASE_URL 和 NEWAPI_API_KEY", file=sys.stderr)
        print(
            "示例: NEWAPI_BASE_URL='https://your-api-gateway.example.com/v1' "
            "NEWAPI_API_KEY='sk-xxx' python3 demo_gpt_image.py",
            file=sys.stderr,
        )
        sys.exit(1)

    args = sys.argv[1:]
    skip_edit = "--skip-edit" in args
    if "--model" in args:
        idx = args.index("--model")
        if idx + 1 < len(args):
            MODEL = args[idx + 1]

    OUTPUT_DIR.mkdir(exist_ok=True)
    print(f"模型: {MODEL}")
    print(f"Base URL: {BASE_URL}")
    print(f"输出目录: {OUTPUT_DIR}")

    results = {}
    # Image source for the edit tests; filled from test_basic's return value.
    source_url = ""

    tests = [
        ("1_basic", test_basic),
        ("2_quality", test_quality),
        ("3_sizes", test_sizes),
        ("4_batch", test_batch),
        ("5_transparent", test_transparent),
        ("6_formats", test_formats),
        ("7_b64", test_b64),
        ("8_moderation", test_moderation),
        ("9_combo", test_combo),
    ]

    for name, fn in tests:
        try:
            ret = fn()
            if name == "1_basic" and ret:
                source_url = ret
            results[name] = "✅"
        except Exception as e:
            results[name] = f"❌ {e}"

    if not skip_edit and source_url:
        for name, fn in [("10_edit", test_edit), ("11_mask_edit", test_mask_edit)]:
            try:
                fn(source_url)
                results[name] = "✅"
            except Exception as e:
                results[name] = f"❌ {e}"
    elif not skip_edit:
        results["10_edit"] = "⏭️ 跳过（无源图）"
        results["11_mask_edit"] = "⏭️ 跳过（无源图）"

    # Summary. `results` keeps insertion order (1 ... 11); sorting the keys as
    # strings would list "10_edit" and "11_mask_edit" before "1_basic".
    print(f"\n{'='*50}")
    print(f"  {MODEL} 测试结果")
    print(f"{'='*50}")
    passed = sum(1 for v in results.values() if v == "✅")
    for k, v in results.items():
        print(f"  {k}: {v}")
    print(f"\n  通过: {passed}/{len(results)}")


# Script entry point: run the demo suite only when executed directly, not on import.
if __name__ == "__main__":
    main()
