mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-27 11:22:03 +00:00
feat(pets): prompt → atlas sprite-generation engine
Turn a text prompt into a petdex-spec spritesheet (8×9 grid of 192×208 cells), grounded so every animation row stays the same creature: - orchestrate: base drafts (distinct variation nudges) → per-row grounded generation → atlas compose; one image call per row, rows fan out in parallel. - atlas: frame-perfect registration in normalize_cells — 1-D cross-correlation of each frame's column-mass profile locks the body (robust to limbs/cape), one shared per-state scale, bottom-anchored; plus alpha-hole repair, gutter severing, and interior-seeded chroma-pocket clearing. - prompts: pixel-art-by-default style hints + registration constraints. - store: local pet write (register_local_pet), slugify/unique_slug, export_pet, slug-realigning rename_pet, createdBy provenance.
This commit is contained in:
parent
35e9c63d89
commit
32f837add1
8 changed files with 1972 additions and 1 deletions
29
agent/pet/generate/__init__.py
Normal file
29
agent/pet/generate/__init__.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
"""Pet generation — base-draft → hatch pipeline.
|
||||
|
||||
Public surface used by the gateway RPCs, the CLI ``hermes pets generate``
|
||||
command, and tests:
|
||||
|
||||
- :func:`generate_base_drafts` / :func:`hatch_pet` — the two-step flow.
|
||||
- :class:`HatchResult`, :class:`GenerationError`.
|
||||
- :mod:`atlas` — deterministic frame extraction + atlas composition/validation.
|
||||
|
||||
Image generation is delegated to the active reference-capable
|
||||
:class:`~agent.image_gen_provider.ImageGenProvider` (OpenAI gpt-image-2 or Krea);
|
||||
atlas assembly is fully deterministic so it's testable without any API calls.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from agent.pet.generate.imagegen import GenerationError
|
||||
from agent.pet.generate.orchestrate import (
|
||||
HatchResult,
|
||||
generate_base_drafts,
|
||||
hatch_pet,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"GenerationError",
|
||||
"HatchResult",
|
||||
"generate_base_drafts",
|
||||
"hatch_pet",
|
||||
]
|
||||
724
agent/pet/generate/atlas.py
Normal file
724
agent/pet/generate/atlas.py
Normal file
|
|
@ -0,0 +1,724 @@
|
|||
"""Deterministic spritesheet assembly — generated row strips → Hermes atlas.
|
||||
|
||||
Image-generation models are good at *drawing* a row of poses but bad at exact
|
||||
grid geometry, so the model never owns the atlas layout: it produces one loose
|
||||
horizontal strip per state, and these deterministic ops slice that strip into
|
||||
clean, centered, transparent ``192x208`` cells and pack them into the sheet our
|
||||
renderer reads.
|
||||
|
||||
The atlas follows the **petdex/Codex standard**: 8 columns x 9 rows of
|
||||
``192x208`` cells (``1536x1872``), with the row order + per-row frame counts
|
||||
from OpenAI's ``hatch-pet`` skill. Our renderer (:mod:`agent.pet.render`) keys
|
||||
frames as ``rows = states, cols = frames`` via
|
||||
:data:`agent.pet.constants.CODEX_STATE_ROWS`, and a pet built here is a valid
|
||||
``petdex submit`` spritesheet. Rows shorter than 8 columns leave the trailing
|
||||
cells fully transparent.
|
||||
|
||||
Note ``running`` is the *working* state (in-place processing), NOT locomotion —
|
||||
``running-right`` / ``running-left`` are the actual directional walk cycles.
|
||||
|
||||
The frame-segmentation, fit-to-cell, and transparency-residue logic is adapted
|
||||
from OpenAI's ``hatch-pet`` skill (openai/skills, Apache-2.0).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
from agent.pet.constants import FRAME_H, FRAME_W
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CELL_WIDTH = FRAME_W
|
||||
CELL_HEIGHT = FRAME_H
|
||||
|
||||
# (state, row index, frame count). Order/row indices MUST match
|
||||
# ``constants.CODEX_STATE_ROWS`` so the renderer crops the right row for each
|
||||
# driven state, and the per-row frame counts mirror the petdex/Codex
|
||||
# ``hatch-pet`` ``animation-rows`` spec. The renderer trims trailing blank
|
||||
# columns, so rows shorter than ``COLUMNS`` (8) just leave the tail transparent.
|
||||
# One ``(state, row index, frame count)`` entry per animation row, top to bottom.
ROW_SPECS: list[tuple[str, int, int]] = [
    ("idle", 0, 6),
    ("running-right", 1, 8),
    ("running-left", 2, 8),
    ("waving", 3, 4),
    ("jumping", 4, 5),
    ("failed", 5, 8),
    ("waiting", 6, 6),
    ("running", 7, 6),
    ("review", 8, 6),
]

# Grid shape: one row per state, as many columns as the longest animation.
ROWS = len(ROW_SPECS)
COLUMNS = max(spec[2] for spec in ROW_SPECS)

# Full sheet size in pixels.
ATLAS_WIDTH = COLUMNS * CELL_WIDTH
ATLAS_HEIGHT = ROWS * CELL_HEIGHT

# State name -> number of frames that state's row is expected to hold.
FRAME_COUNTS: dict[str, int] = {spec[0]: spec[2] for spec in ROW_SPECS}
|
||||
|
||||
# Alpha at/below which a pixel is "background" for component detection.
|
||||
_ALPHA_FLOOR = 16
|
||||
# Cell padding kept around a fitted sprite so poses never touch the edge.
|
||||
_CELL_PAD = 10
|
||||
# Margin for the normalized pass — small, to fill the cell like real petdex pets
|
||||
# (they sit ~5px from the edges); the width clamp, not the pad, prevents clipping.
|
||||
_NORMALIZE_PAD = 14
|
||||
# Side-lobe cutoff for fitted frames. Adjacent-pose bleed usually appears as a
|
||||
# small separated horizontal lobe beside the real subject; keep sizeable lobes so
|
||||
# we don't punish a legitimate wide pose.
|
||||
_SIDE_LOBE_RATIO = 0.18
|
||||
|
||||
|
||||
# ───────────────────────── background removal ─────────────────────────
|
||||
|
||||
|
||||
def _color_distance(r: int, g: int, b: int, key: tuple[int, int, int]) -> float:
|
||||
return math.sqrt((r - key[0]) ** 2 + (g - key[1]) ** 2 + (b - key[2]) ** 2)
|
||||
|
||||
|
||||
def _has_transparency(image) -> bool:
    """Report whether *image* already has a genuinely transparent background.

    Two conditions must hold: the lowest alpha value is at or below
    ``_ALPHA_FLOOR``, and more than 5% of all pixels have alpha at or below
    that floor.
    """
    alpha = image.getchannel("A")
    lowest, _highest = alpha.getextrema()
    if lowest > _ALPHA_FLOOR:
        return False
    # Histogram buckets 0.._ALPHA_FLOOR count the pixels treated as background.
    see_through = sum(alpha.histogram()[: _ALPHA_FLOOR + 1])
    pixel_count = image.width * image.height
    return see_through > pixel_count * 0.05
|
||||
|
||||
|
||||
def _dominant_corner_color(image) -> tuple[int, int, int]:
    """Return the most frequent opaque RGB value among the four corner pixels.

    Corners whose alpha is at or below ``_ALPHA_FLOOR`` are ignored. If no
    corner is opaque, pure green ``(0, 255, 0)`` is returned.
    """
    from collections import Counter

    right = image.width - 1
    bottom = image.height - 1
    pixels = image.load()
    tally: Counter = Counter()
    for corner in ((0, 0), (right, 0), (0, bottom), (right, bottom)):
        red, green, blue, alpha = pixels[corner]
        if alpha > _ALPHA_FLOOR:
            tally[(red, green, blue)] += 1
    if tally:
        return tally.most_common(1)[0][0]
    return (0, 255, 0)
|
||||
|
||||
|
||||
def _near_key_mask(image, key: tuple[int, int, int], tol: int = 48):
    """Build an ``L`` mask that is 255 where every channel is within *tol* of *key*.

    The tolerance is deliberately tight so that only near-pure backdrop pixels
    are marked (used to seed the flood in trapped chroma pockets), while
    character pixels that are merely tinted by the key stay unmarked. The mask
    is assembled from per-channel lookup tables and ``ImageChops.darker``
    (a per-pixel minimum), so no Python code runs per pixel.
    """
    from PIL import ImageChops

    red, green, blue, _alpha = image.split()
    key_r, key_g, key_b = key

    def _within(band, target):
        # 255 where this channel is close enough to the key, else 0.
        return band.point(lambda v: 255 if abs(v - target) <= tol else 0)

    red_and_green = ImageChops.darker(_within(red, key_r), _within(green, key_g))
    return ImageChops.darker(red_and_green, _within(blue, key_b))
|
||||
|
||||
|
||||
def remove_background(image, *, chroma_key: tuple[int, int, int] | None = None, threshold: float = 90.0):
    """Return an RGBA copy of *image* with its flat backdrop made transparent.

    An image that already has a transparent background is not keyed; it is only
    passed through :func:`_repair_internal_alpha_holes`. Otherwise the backdrop
    colour is *chroma_key*, or the dominant corner colour when no key is given,
    and a pixel counts as backdrop when it is opaque and within *threshold*
    (Euclidean RGB distance) of that colour.

    Removal is a flood fill rather than a global colour match: only backdrop
    pixels connected to an image edge are cleared, so an interior highlight
    that happens to match the backdrop (a light belly on a near-white
    background) is kept because the flood cannot reach it without crossing
    non-backdrop pixels. For a saturated key the flood is also seeded from
    interior near-key pixels, so backdrop trapped inside the character (the gap
    between an arm and the body) is cleared too.
    """
    from collections import deque

    rgba = image.convert("RGBA")
    if _has_transparency(rgba):
        return _repair_internal_alpha_holes(rgba)

    key = chroma_key or _dominant_corner_color(rgba)
    width, height = rgba.width, rgba.height
    pixels = rgba.load()
    seen = bytearray(width * height)
    pending: deque[tuple[int, int]] = deque()

    def _matches_key(x: int, y: int) -> bool:
        red, green, blue, alpha = pixels[x, y]
        if alpha <= _ALPHA_FLOOR:
            return False
        return _color_distance(red, green, blue, key) <= threshold

    def _seed_if_backdrop(x: int, y: int) -> None:
        offset = y * width + x
        if not seen[offset] and _matches_key(x, y):
            seen[offset] = 1
            pending.append((x, y))

    # Seed from the top/bottom rows, then the left/right columns.
    for x in range(width):
        _seed_if_backdrop(x, 0)
        _seed_if_backdrop(x, height - 1)
    for y in range(height):
        _seed_if_backdrop(0, y)
        _seed_if_backdrop(width - 1, y)

    # Enclosed pockets are not reachable from the border, so seed them directly.
    # This only happens for a saturated key (channel spread >= 120): with a
    # near-white or gray key it would punch holes in a character that shares
    # the backdrop colour, which is exactly what the border flood avoids.
    if max(key) - min(key) >= 120:
        near_key = _near_key_mask(rgba, key).getdata()
        for offset, flagged in enumerate(near_key):
            if flagged and not seen[offset]:
                seen[offset] = 1
                row, col = divmod(offset, width)
                pending.append((col, row))

    while pending:
        x, y = pending.popleft()
        pixels[x, y] = (0, 0, 0, 0)
        for nx, ny in ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)):
            if not (0 <= nx < width and 0 <= ny < height):
                continue
            offset = ny * width + nx
            if seen[offset]:
                continue
            # Mark before testing so a non-backdrop pixel is examined only once.
            seen[offset] = 1
            if _matches_key(nx, ny):
                pending.append((nx, ny))
    return rgba
|
||||
|
||||
|
||||
def _repair_internal_alpha_holes(image):
    """Return an RGBA copy of *image* with enclosed transparent islands filled.

    Some providers return "transparent" PNGs with swiss-cheese alpha inside the
    character. There is no opaque backdrop to key in that case, so the alpha
    mask itself is repaired: a transparent region (alpha at or below
    ``_ALPHA_FLOOR``) that touches an image edge stays background, while a
    transparent region fully enclosed by the sprite is filled, fully opaque,
    with the average colour of the opaque pixels bordering it.
    """
    from collections import deque

    rgba = image.convert("RGBA")
    width, height = rgba.size
    pixels = rgba.load()
    seen = bytearray(width * height)
    steps = ((1, 0), (-1, 0), (0, 1), (0, -1))

    def _clear(x: int, y: int) -> bool:
        return pixels[x, y][3] <= _ALPHA_FLOOR

    def _flood(sx: int, sy: int) -> list[tuple[int, int]]:
        # Mark, and return, the 4-connected transparent region around (sx, sy).
        seen[sy * width + sx] = 1
        frontier: deque[tuple[int, int]] = deque([(sx, sy)])
        members: list[tuple[int, int]] = []
        while frontier:
            x, y = frontier.popleft()
            members.append((x, y))
            for dx, dy in steps:
                nx, ny = x + dx, y + dy
                if 0 <= nx < width and 0 <= ny < height:
                    offset = ny * width + nx
                    if not seen[offset] and _clear(nx, ny):
                        seen[offset] = 1
                        frontier.append((nx, ny))
        return members

    def _rim_average(hole: list[tuple[int, int]]) -> tuple[int, int, int, int]:
        # Average the opaque pixels that directly border the hole.
        inside = set(hole)
        total_r = total_g = total_b = count = 0
        for x, y in hole:
            for dx, dy in steps:
                nx, ny = x + dx, y + dy
                if 0 <= nx < width and 0 <= ny < height and (nx, ny) not in inside:
                    red, green, blue, alpha = pixels[nx, ny]
                    if alpha > _ALPHA_FLOOR:
                        total_r += red
                        total_g += green
                        total_b += blue
                        count += 1
        if not count:
            return (0, 0, 0, 255)
        return (
            round(total_r / count),
            round(total_g / count),
            round(total_b / count),
            255,
        )

    # Pass 1: everything transparent and reachable from an edge is background.
    for x in range(width):
        for y in (0, height - 1):
            if not seen[y * width + x] and _clear(x, y):
                _flood(x, y)
    for y in range(height):
        for x in (0, width - 1):
            if not seen[y * width + x] and _clear(x, y):
                _flood(x, y)

    # Pass 2: any transparent pixel still unmarked belongs to an enclosed hole.
    for offset in range(width * height):
        if seen[offset]:
            continue
        y, x = divmod(offset, width)
        if not _clear(x, y):
            continue
        hole = _flood(x, y)
        fill = _rim_average(hole)
        for hx, hy in hole:
            pixels[hx, hy] = fill
    return rgba
|
||||
|
||||
|
||||
# ───────────────────────── frame extraction ─────────────────────────
|
||||
|
||||
|
||||
def _fit_to_cell(image):
    """Return a transparent cell with *image*'s content cropped, fitted and centered.

    Side bleed is trimmed first, then the content's bounding box is cropped out.
    The sprite is shrunk (never enlarged) so it fits inside the cell minus
    ``_CELL_PAD``, and is centered on both axes. An image with no content
    yields an empty cell.
    """
    from PIL import Image

    cell = Image.new("RGBA", (CELL_WIDTH, CELL_HEIGHT), (0, 0, 0, 0))
    trimmed = _drop_side_bleed(image)
    bounds = trimmed.getbbox()
    if bounds is None:
        return cell

    sprite = trimmed.crop(bounds)
    room_w = CELL_WIDTH - _CELL_PAD
    room_h = CELL_HEIGHT - _CELL_PAD
    # Capped at 1.0 so small sprites keep their native size.
    factor = min(room_w / sprite.width, room_h / sprite.height, 1.0)
    if factor != 1.0:
        new_size = (
            max(1, round(sprite.width * factor)),
            max(1, round(sprite.height * factor)),
        )
        sprite = sprite.resize(new_size, Image.Resampling.LANCZOS)

    offset = ((CELL_WIDTH - sprite.width) // 2, (CELL_HEIGHT - sprite.height) // 2)
    cell.alpha_composite(sprite, offset)
    return cell
|
||||
|
||||
|
||||
def _drop_side_bleed(image):
    """Return an RGBA copy of *image* with low-mass separated column lobes erased.

    A frame can contain a good pose plus a thin sliver of the neighbouring
    pose. The sliver shows up in the per-column alpha profile as a separate
    run of columns whose total mass is small compared with the main
    silhouette. Runs of columns with a profile value above 2 are collected;
    any run whose mass is below ``_SIDE_LOBE_RATIO`` times the heaviest run is
    cleared to transparent. Sizeable runs are kept, so wide poses and real
    limbs survive. With fewer than two runs, or nothing to drop, the converted
    image is returned unchanged.
    """
    rgba = image.convert("RGBA")
    width, height = rgba.size
    profile = _column_profile(rgba)

    # Collect (left, right, mass) for every contiguous run of occupied columns.
    lobes: list[tuple[int, int, int]] = []
    run_start = None
    run_mass = 0
    for x, level in enumerate(profile):
        if level > 2:
            if run_start is None:
                run_start, run_mass = x, 0
            run_mass += level
        elif run_start is not None:
            lobes.append((run_start, x, run_mass))
            run_start = None
    if run_start is not None:
        # A run that reaches the right edge closes at the profile's end.
        lobes.append((run_start, len(profile), run_mass))

    if len(lobes) < 2:
        return rgba
    mass_floor = max(lobe[2] for lobe in lobes) * _SIDE_LOBE_RATIO
    survivors = [(lo, hi) for lo, hi, mass in lobes if mass >= mass_floor]
    if len(survivors) == len(lobes):
        return rgba

    # Blank every column band outside the surviving runs with one box fill each.
    cleaned = rgba.copy()
    clear = (0, 0, 0, 0)
    cursor = 0
    for lo, hi in survivors:
        if lo > cursor:
            cleaned.paste(clear, (cursor, 0, lo, height))
        cursor = hi
    if cursor < width:
        cleaned.paste(clear, (cursor, 0, width, height))
    return cleaned
|
||||
|
||||
|
||||
def _connected_components(image) -> list[dict]:
    """Label the 4-connected blobs of pixels whose alpha exceeds ``_ALPHA_FLOOR``.

    Returns one dict per blob, in scan order of each blob's first pixel, with:
    ``pixels`` (flat ``y * width + x`` indices), ``area`` (pixel count),
    ``bbox`` (``left, top, right, bottom`` with exclusive right/bottom) and
    ``center_x`` (horizontal midpoint of the bbox).
    """
    width, height = image.size
    alpha = image.getchannel("A").tobytes()
    seen = bytearray(width * height)
    blobs: list[dict] = []

    for seed, value in enumerate(alpha):
        if seen[seed] or value <= _ALPHA_FLOOR:
            continue
        seen[seed] = 1
        todo = [seed]
        members: list[int] = []
        x_lo, y_lo, x_hi, y_hi = width, height, 0, 0
        while todo:
            here = todo.pop()
            members.append(here)
            y, x = divmod(here, width)
            if x < x_lo:
                x_lo = x
            if y < y_lo:
                y_lo = y
            if x > x_hi:
                x_hi = x
            if y > y_hi:
                y_hi = y
            # Neighbours in left, right, up, down order; edges are skipped.
            neighbours = []
            if x > 0:
                neighbours.append(here - 1)
            if x + 1 < width:
                neighbours.append(here + 1)
            if y > 0:
                neighbours.append(here - width)
            if y + 1 < height:
                neighbours.append(here + width)
            for nb in neighbours:
                if not seen[nb] and alpha[nb] > _ALPHA_FLOOR:
                    seen[nb] = 1
                    todo.append(nb)
        blobs.append(
            {
                "pixels": members,
                "area": len(members),
                "bbox": (x_lo, y_lo, x_hi + 1, y_hi + 1),
                "center_x": (x_lo + x_hi + 1) / 2,
            }
        )
    return blobs
|
||||
|
||||
|
||||
def _sever_expected_gutters(strip, frame_count: int):
    """Cut thin transparent bands at the expected frame boundaries.

    Generated rows often have a shared shadow, glow, motion smear, or 1px
    bridge that connects neighbouring poses, which makes connected-component
    labeling see one giant blob. The requested frame count is known, so a
    narrow band is made transparent at each of the ``frame_count - 1`` interior
    boundaries of an evenly divided strip. If a pose truly overlaps a boundary,
    losing a few pixels is better than exporting merged frames.

    Each band is centered on ``round(i * width / frame_count)`` and extends
    ``half`` columns to either side, where ``half`` is 2% of a slot width
    clamped to the range 2..8. Only alpha is zeroed; RGB values are left as
    they were.

    Args:
        strip: Image with an alpha channel (RGBA in practice).
        frame_count: Number of frames the strip is expected to hold.

    Returns:
        *strip* itself when ``frame_count <= 1``, otherwise a modified copy.
    """
    if frame_count <= 1:
        return strip

    out = strip.copy()
    alpha = out.getchannel("A")
    slot = out.width / frame_count
    half = max(2, min(8, round(slot * 0.02)))
    for i in range(1, frame_count):
        x = round(i * slot)
        left = max(0, x - half)
        right = min(out.width, x + half + 1)
        if right > left:
            # One box fill per gutter instead of a Python loop over every pixel.
            alpha.paste(0, (left, 0, right, out.height))
    out.putalpha(alpha)
    return out
|
||||
|
||||
|
||||
def _segmentable(strip, frame_count: int) -> bool:
    """Report whether *strip* contains at least *frame_count* sizeable blobs.

    A blob is sizeable when its area is at least 120 pixels and at least 20%
    of the largest blob's area. This is only a quality gate: a row that cannot
    show that many separable poses is treated as a bad generation by the
    caller rather than being trusted as cleanly separated.
    """
    areas = [blob["area"] for blob in _connected_components(strip)]
    if not areas:
        return False
    cutoff = max(120, max(areas) * 0.20)
    sizeable = [area for area in areas if area >= cutoff]
    return len(sizeable) >= frame_count
|
||||
|
||||
|
||||
def _slot_crops(strip, frame_count: int) -> list:
    """Cut *strip* into *frame_count* equal-width columns, left to right.

    Every column has the same width (``strip.width // frame_count``, at least
    1) and the full strip height, so all frames share one coordinate space.
    That lets :func:`normalize_cells` apply a shared crop and placement later
    without re-centering each frame, which is what makes a pet visibly slide.
    Each column has neighbour side-bleed trimmed before it is returned.
    """
    slot_w = max(1, strip.width // frame_count)
    height = strip.height
    crops = []
    for index in range(frame_count):
        x0 = index * slot_w
        column = strip.crop((x0, 0, x0 + slot_w, height))
        crops.append(_drop_side_bleed(column))
    return crops
|
||||
|
||||
|
||||
def extract_strip_frames(
    strip,
    frame_count: int,
    *,
    chroma_key: tuple[int, int, int] | None = None,
    method: str = "auto",
    fit: bool = True,
) -> list:
    """Turn one generated row strip into *frame_count* frames.

    The background is keyed out, the expected frame gutters are severed, and
    the strip is sliced into equal columns. Connected components are used only
    to *validate* that the row holds *frame_count* separable poses.

    Args:
        strip: A path (``str`` or ``Path``) to an image file, or an image
            object; either way it is converted to RGBA.
        frame_count: Number of frames to cut; must be at least 1.
        chroma_key: Backdrop colour to key out; when omitted the dominant
            corner colour is used.
        method: ``"components"`` raises if the row is not separable. Any other
            value (``"auto"`` by default) falls back to slicing the un-severed
            strip in that case.
        fit: When true (default) each frame is fitted and centered into a
            192x208 cell, the standalone contract for callers that do not
            normalize. Hatching passes ``fit=False`` to keep raw,
            coordinate-aligned columns for :func:`normalize_cells`, which lays
            one shared scale and baseline across the whole state.

    Returns:
        A list of *frame_count* RGBA frames.

    Raises:
        ValueError: If *frame_count* is less than 1, or if *method* is
            ``"components"`` and the strip cannot be segmented.
    """
    from PIL import Image

    # Reject a nonsensical count before the expensive keying pass; otherwise 0
    # surfaces as a ZeroDivisionError deep inside the slicing helper.
    if frame_count < 1:
        raise ValueError(f"frame_count must be at least 1, got {frame_count}")

    if isinstance(strip, (str, Path)):
        with Image.open(strip) as opened:
            strip = opened.convert("RGBA")
    else:
        strip = strip.convert("RGBA")

    strip = remove_background(strip, chroma_key=chroma_key)
    severed = _sever_expected_gutters(strip, frame_count)
    segmentable = _segmentable(severed, frame_count)
    if method == "components" and not segmentable:
        raise ValueError(f"could not segment {frame_count} sprites from strip")

    # Only trust the gutter cuts when they produced cleanly separable poses.
    frames = _slot_crops(severed if segmentable else strip, frame_count)
    return [_fit_to_cell(f) for f in frames] if fit else frames
|
||||
|
||||
|
||||
def _column_profile(image) -> list[int]:
    """Return one alpha value per column of *image*.

    The alpha channel is collapsed to a strip one pixel tall with a bilinear
    resize, so the reduction runs inside Pillow rather than in Python. The
    result has ``image.width`` entries.
    """
    from PIL import Image

    alpha = image.getchannel("A")
    squashed = alpha.resize((image.width, 1), Image.BILINEAR)
    return list(squashed.getdata())
|
||||
|
||||
|
||||
def _best_shift(ref: list[int], prof: list[int], window: int) -> int:
|
||||
"""Integer dx that best aligns *prof* onto *ref* by cross-correlation.
|
||||
|
||||
This is 1-D phase correlation: the body is the dominant mass in the column
|
||||
profile, so the peak overlap locks onto the body and a flipping arm/cape (a
|
||||
small secondary bump) doesn't move the match. Proven on the jitter case to
|
||||
cut body drift from ~9px to ~1px where a centroid/bbox anchor cannot.
|
||||
"""
|
||||
n = len(ref)
|
||||
best_score: float | None = None
|
||||
best = 0
|
||||
for d in range(-window, window + 1):
|
||||
score = 0
|
||||
for x in range(max(0, d), min(n, n + d)):
|
||||
score += ref[x] * prof[x - d]
|
||||
if best_score is None or score > best_score:
|
||||
best_score = score
|
||||
best = d
|
||||
return best
|
||||
|
||||
|
||||
def normalize_cells(frames_by_state: dict[str, list], *, pad: int = _NORMALIZE_PAD) -> dict[str, list]:
    """Register every frame of every state into a 192x208 cell without jitter.

    A per-frame crop/scale/center pipeline jitters: a moving limb or cape
    shifts the bounding box (or even the centroid), and a per-frame scale
    makes the size pulse. Each state is instead handled as a set:

    1. **Cross-correlate** each frame's column profile against the per-state
       *median* profile to find the integer shift that locks the body in
       place; the body dominates the profile, so limbs and capes do not move
       the match.
    2. **Union-crop** the registered frames through one shared window, apply
       **one shared scale**, center horizontally and anchor to the bottom, so
       size and baseline are uniform and vertical motion within the state
       (a jump's lift) is preserved.

    Args:
        frames_by_state: Mapping of state name to that state's frames.
        pad: Total margin kept free inside the cell on each axis.

    Returns:
        Mapping of state name to 192x208 RGBA cells, one per input frame. A
        state whose frames are all empty yields blank cells.
    """
    from PIL import Image

    clear = (0, 0, 0, 0)

    def _empty_cell():
        return Image.new("RGBA", (CELL_WIDTH, CELL_HEIGHT), clear)

    result: dict[str, list] = {}
    for state, frames in frames_by_state.items():
        sources = [frame.convert("RGBA") for frame in frames]
        if all(src.getbbox() is None for src in sources):
            result[state] = [_empty_cell() for _ in frames]
            continue

        # Put every frame on a common canvas so column profiles line up.
        canvas_w = max(src.width for src in sources)
        canvas_h = max(src.height for src in sources)
        padded = []
        for src in sources:
            if src.size == (canvas_w, canvas_h):
                padded.append(src)
                continue
            sheet = Image.new("RGBA", (canvas_w, canvas_h), clear)
            sheet.alpha_composite(src, (0, 0))
            padded.append(sheet)

        # Horizontal registration against the per-column median profile.
        profiles = [_column_profile(sheet) for sheet in padded]
        middle = len(profiles) // 2
        reference = [sorted(column)[middle] for column in zip(*profiles)]
        window = max(8, canvas_w // 5)
        margin = window  # room on both sides so a shifted frame is never clipped
        registered = []
        for sheet, profile in zip(padded, profiles):
            wide = Image.new("RGBA", (canvas_w + 2 * margin, canvas_h), clear)
            dx = _best_shift(reference, profile, window)
            wide.alpha_composite(sheet, (margin + dx, 0))
            registered.append(wide)

        # One crop window and one scale for the whole state.
        extents = [box for box in (img.getbbox() for img in registered) if box]
        left = min(box[0] for box in extents)
        top = min(box[1] for box in extents)
        right = max(box[2] for box in extents)
        bottom = max(box[3] for box in extents)
        union_w = right - left
        union_h = bottom - top
        scale = min((CELL_WIDTH - pad) / union_w, (CELL_HEIGHT - pad) / union_h)
        scaled_w = max(1, round(union_w * scale))
        scaled_h = max(1, round(union_h * scale))
        paste_x = round((CELL_WIDTH - scaled_w) / 2)
        paste_y = round((CELL_HEIGHT - pad // 2) - scaled_h)

        cells = []
        for img in registered:
            piece = img.crop((left, top, right, bottom))
            if piece.size != (scaled_w, scaled_h):
                piece = piece.resize((scaled_w, scaled_h), Image.Resampling.LANCZOS)
            cell = _empty_cell()
            cell.alpha_composite(piece, (paste_x, paste_y))
            cells.append(cell)
        result[state] = cells
    return result
|
||||
|
||||
|
||||
# ───────────────────────── atlas composition ─────────────────────────
|
||||
|
||||
|
||||
def single_frame(image, *, fit: bool = True):
    """Produce one frame from a standalone image such as the base look.

    Serves as an idle fallback so a pet still renders when the idle row could
    not be generated. *image* may be a path (``str`` or ``Path``) or an image
    object. With *fit* the result is a finished 192x208 cell; with
    ``fit=False`` it is the raw keyed sprite (side bleed trimmed) for
    :func:`normalize_cells` to place alongside the other frames.
    """
    from PIL import Image

    source = image
    if isinstance(source, (str, Path)):
        with Image.open(source) as opened:
            source = opened.convert("RGBA")

    keyed = remove_background(source)
    if fit:
        return _fit_to_cell(keyed)
    return _drop_side_bleed(keyed)
|
||||
|
||||
|
||||
def _clear_transparent_rgb(image):
    """Return an RGBA copy of *image* with RGB zeroed wherever alpha is 0.

    Fully transparent pixels can still carry colour, which shows up as a
    coloured halo in some renderers and encoders. Pixels with any non-zero
    alpha are copied unchanged.

    The copy is done as a masked paste onto an all-zero canvas, so the work
    happens inside Pillow rather than in a Python loop over every pixel of the
    atlas.
    """
    from PIL import Image

    rgba = image.convert("RGBA")
    # 255 where the pixel has any alpha (copy as is), 0 where fully transparent.
    keep = rgba.getchannel("A").point(lambda a: 255 if a else 0)
    cleared = Image.new("RGBA", rgba.size, (0, 0, 0, 0))
    cleared.paste(rgba, (0, 0), keep)
    return cleared
|
||||
|
||||
|
||||
def mirror_frames(frames: list) -> list:
    """Return horizontally flipped RGBA copies of *frames*, in the same order.

    Used to derive ``running-left`` from an approved ``running-right`` row.
    Each frame is flipped within its own bounds and the list order is kept, so
    the leftward loop preserves the rightward loop's frame order and timing.
    This is NOT a whole-strip reverse, which would play the animation
    backwards; it matches the petdex/Codex mirror rule. The input frames are
    not modified.
    """
    from PIL import Image

    # Use the ``Image.Transpose`` enum when it exists, else the module constant.
    flip = getattr(Image, "Transpose", Image).FLIP_LEFT_RIGHT
    mirrored = []
    for frame in frames:
        mirrored.append(frame.convert("RGBA").transpose(flip))
    return mirrored
|
||||
|
||||
|
||||
def compose_atlas(frames_by_state: dict[str, list]):
    """Pack per-state frame lists into the full atlas image.

    Each state in ``ROW_SPECS`` owns one row; its frames fill that row from the
    left. A missing or short state leaves its remaining cells transparent, and
    frames beyond the state's specified count are ignored. A frame that is not
    already cell-sized is fitted into a cell first. The finished sheet is RGBA
    with the RGB of fully transparent pixels cleared.
    """
    from PIL import Image

    sheet = Image.new("RGBA", (ATLAS_WIDTH, ATLAS_HEIGHT), (0, 0, 0, 0))
    cell_size = (CELL_WIDTH, CELL_HEIGHT)
    for state, row, count in ROW_SPECS:
        supplied = frames_by_state.get(state) or []
        y = row * CELL_HEIGHT
        # zip() stops at whichever runs out first: the spec or the frames.
        for col, frame in zip(range(count), supplied):
            tile = frame.convert("RGBA")
            if tile.size != cell_size:
                tile = _fit_to_cell(tile)
            sheet.alpha_composite(tile, (col * CELL_WIDTH, y))
    return _clear_transparent_rgb(sheet)
|
||||
|
||||
|
||||
def atlas_to_webp_bytes(atlas) -> bytes:
    """Serialize *atlas* as lossless WebP and return the encoded bytes.

    Lossless WebP is the on-disk pet format. ``exact=True`` is passed to the
    encoder along with ``quality=100`` and ``method=6``.
    """
    encode_options = {
        "format": "WEBP",
        "lossless": True,
        "quality": 100,
        "method": 6,
        "exact": True,
    }
    with io.BytesIO() as sink:
        atlas.save(sink, **encode_options)
        return sink.getvalue()
|
||||
|
||||
|
||||
def validate_atlas(atlas) -> dict:
    """Check an atlas's geometry, row occupancy, and transparency hygiene.

    *atlas* may be a path (``str`` or ``Path``) or an image object.

    Returns ``{ok, width, height, errors, warnings, filled_states}``:

    - **errors** (blockers): wrong overall size (reported immediately, with no
      further checks), no state containing any visible pixel, or fully
      transparent pixels that still carry non-zero RGB.
    - **warnings** (soft): a state whose whole row is blank, which usually
      means generation dropped that row.
    - **filled_states**: states with at least one visible pixel in the cells
      their spec uses.

    ``ok`` is true exactly when ``errors`` is empty.
    """
    from PIL import Image

    if isinstance(atlas, (str, Path)):
        with Image.open(atlas) as opened:
            sheet = opened.convert("RGBA")
    else:
        sheet = atlas.convert("RGBA")

    errors: list[str] = []
    warnings: list[str] = []

    def _report(filled: list[str]) -> dict:
        return {
            "ok": not errors,
            "width": sheet.width,
            "height": sheet.height,
            "errors": errors,
            "warnings": warnings,
            "filled_states": filled,
        }

    if sheet.size != (ATLAS_WIDTH, ATLAS_HEIGHT):
        errors.append(f"expected {ATLAS_WIDTH}x{ATLAS_HEIGHT}, got {sheet.width}x{sheet.height}")
        return _report([])

    filled_states: list[str] = []
    for state, row, count in ROW_SPECS:
        top = row * CELL_HEIGHT
        # Only the first ``count`` cells of the row belong to this state.
        used = sheet.crop((0, top, count * CELL_WIDTH, top + CELL_HEIGHT))
        visible = sum(used.getchannel("A").histogram()[1:])
        if visible > 0:
            filled_states.append(state)
        else:
            warnings.append(f"state '{state}' has no frames")

    if not filled_states:
        errors.append("atlas is empty — no state produced any frames")

    # Fully transparent pixels must not retain colour (no halo residue).
    raw = sheet.tobytes()
    channels = zip(raw[0::4], raw[1::4], raw[2::4], raw[3::4])
    residue = sum(1 for red, green, blue, alpha in channels if alpha == 0 and (red or green or blue))
    if residue:
        errors.append(f"{residue} transparent pixels retain RGB residue")

    return _report(filled_states)
|
||||
176
agent/pet/generate/imagegen.py
Normal file
176
agent/pet/generate/imagegen.py
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
"""Thin image-generation layer for pet sprites.
|
||||
|
||||
Wraps the active :class:`~agent.image_gen_provider.ImageGenProvider` with the
|
||||
two things sprite generation needs that the agent-facing ``image_generate`` tool
|
||||
doesn't expose: **N variants** (loop) and **reference-image grounding** (so each
|
||||
animation row stays the same character as the chosen base).
|
||||
|
||||
Reference grounding only works on providers that support it — currently OpenAI
|
||||
``gpt-image-2`` (image edits), Krea (style references), and OpenRouter / Nous (reference images over the chat-completions image protocol). We resolve to one of
|
||||
those and surface a clear, actionable error otherwise rather than silently
|
||||
producing an ungrounded, drifting pet.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Names of the providers that can ground a generation on a reference image.
# "openrouter" and "nous" are included because they reach Gemini Flash Image
# (and friends) over the OpenRouter-compatible chat-completions image protocol,
# which accepts reference images; Nous Portal proxies OpenRouter, so both
# qualify. resolve_provider() probes these names in the order listed here.
_REF_CAPABLE = (
    "openai",
    "openai-codex",
    "krea",
    "openrouter",
    "nous",
)
|
||||
|
||||
|
||||
class GenerationError(RuntimeError):
    """Signal that image generation failed.

    Used for every failure mode of this layer: no usable provider, a provider
    API error, or an I/O problem while saving the result.
    """
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SpriteProvider:
    """An image provider resolved for sprite work, tagged with its grounding ability."""

    # Registry name the provider was resolved under.
    name: str
    # The provider instance; ``generate`` calls ``provider.generate(prompt, **kwargs)``.
    provider: object
    # True when the provider can ground generation on reference images.
    supports_references: bool
|
||||
|
||||
|
||||
def _discover() -> None:
    """Trigger image-gen plugin discovery, best-effort.

    Any failure (``hermes_cli`` not importable, discovery itself raising) is
    logged at debug level and swallowed so provider resolution can continue.
    """
    try:
        from hermes_cli import plugins as cli_plugins

        cli_plugins._ensure_plugins_discovered()
    except Exception as err:  # noqa: BLE001 - discovery is best-effort
        logger.debug("image-gen plugin discovery failed: %s", err)
|
||||
|
||||
|
||||
def resolve_provider(*, require_references: bool = True) -> SpriteProvider:
    """Choose the image provider to use for sprite work.

    Resolution order:

    1. The configured (active) provider, when its name is in ``_REF_CAPABLE``
       and it reports itself available.
    2. The first available provider from ``_REF_CAPABLE``, probed in order.
    3. Only when *require_references* is false: the active provider even though
       it cannot take references (used for prompt-only base drafts).

    Raises:
        GenerationError: when none of the above yields a provider.
    """
    _discover()
    from agent.image_gen_registry import get_active_provider, get_provider

    # Step 1 — the configured provider; a lookup failure just means "none".
    try:
        configured = get_active_provider()
    except Exception:  # noqa: BLE001
        configured = None

    if configured is not None:
        configured_name = getattr(configured, "name", "")
        if configured_name in _REF_CAPABLE and configured.is_available():
            return SpriteProvider(name=configured_name, provider=configured, supports_references=True)

    # Step 2 — any reference-capable provider that is available.
    for candidate_name in _REF_CAPABLE:
        candidate = get_provider(candidate_name)
        if candidate is None or not candidate.is_available():
            continue
        return SpriteProvider(name=candidate_name, provider=candidate, supports_references=True)

    # Step 3 — ungrounded fallback to the active provider, if the caller allows it.
    if configured is not None and not require_references and configured.is_available():
        fallback_name = getattr(configured, "name", "unknown")
        return SpriteProvider(name=fallback_name, provider=configured, supports_references=False)

    raise GenerationError(
        "Pet generation needs an image backend that supports reference images. "
        "Open `hermes tools` → Image Generation and configure OpenRouter, Nous "
        "Portal, or OpenAI (gpt-image-2) with an API key."
    )
|
||||
|
||||
|
||||
def _save_local(image_ref: str, *, prefix: str) -> Path:
    """Resolve *image_ref* to a local path.

    An ``http://`` / ``https://`` reference is downloaded through
    ``save_url_image`` (using *prefix* for the saved file); anything else is
    treated as an already-local path and wrapped in :class:`Path` unchanged.
    """
    is_remote = image_ref.startswith("http://") or image_ref.startswith("https://")
    if not is_remote:
        return Path(image_ref)

    from agent.image_gen_provider import save_url_image

    return Path(save_url_image(image_ref, prefix=prefix))
|
||||
|
||||
|
||||
def _rejected_background(error: str) -> bool:
    """Tell whether a provider error is specifically about the ``background`` param.

    Transparent backgrounds are a per-model capability (some gpt-image tiers
    reject ``background=transparent`` outright). Spotting that one rejection lets
    the caller retry without the flag instead of failing the whole pet — the
    chroma-key pass makes the result transparent regardless.

    The match is case-insensitive: the text must mention "background" together
    with either "not supported" or "transparent". A falsy *error* is never a match.
    """
    text = (error or "").lower()
    if "background" not in text:
        return False
    return "not supported" in text or "transparent" in text
|
||||
|
||||
|
||||
def generate(
    prompt: str,
    *,
    n: int = 1,
    reference_images: list[Path] | None = None,
    provider: SpriteProvider | None = None,
    prefix: str = "pet_gen",
) -> list[Path]:
    """Generate up to *n* square sprite images and return their local paths.

    Variants are produced one provider call at a time. A variant that fails is
    skipped, so the returned list may be shorter than *n*; only a completely
    empty result is an error.

    *reference_images* grounds the output on a base image (required for rows).
    We *ask* for a transparent background, but fall back to an opaque generation
    (cleaned up downstream by the chroma-key pass) on models that reject the
    flag.

    Args:
        prompt: Text prompt forwarded to the provider.
        n: Number of variants to attempt (values below 1 are treated as 1).
        reference_images: Optional images to ground the generation on.
        provider: Pre-resolved provider; resolved on demand when omitted.
        prefix: Filename prefix used when a remote result is downloaded.

    Raises:
        GenerationError: if references were supplied to a backend that cannot
            use them, or if no variant produced a usable image.
    """
    sprite = provider or resolve_provider(require_references=bool(reference_images))
    if reference_images and not sprite.supports_references:
        raise GenerationError(
            f"image backend '{sprite.name}' cannot use reference images; "
            "configure OpenAI gpt-image-2 or Krea for pet generation"
        )

    refs = [str(p) for p in (reference_images or [])]

    def _run(extra: dict) -> tuple[Path | None, str]:
        """Make one provider call; return ``(path, "")`` or ``(None, error_text)``."""
        kwargs: dict = {"aspect_ratio": "square", **extra}
        if refs:
            # Providers disagree on the ref kwarg name: our OpenRouter/Nous
            # backends read ``reference_images``, OpenAI's gpt-image-2 reads
            # ``reference_image_urls``. Send both; each ignores the other.
            kwargs["reference_images"] = refs
            kwargs["reference_image_urls"] = refs
        try:
            result = sprite.provider.generate(prompt, **kwargs)
        except Exception as exc:  # noqa: BLE001 - normalize provider crashes
            logger.debug("provider.generate crashed: %s", exc)
            return None, str(exc)
        if not isinstance(result, dict):
            return None, "no result"
        if not result.get("success"):
            raw_error = result.get("error", "unknown error")
            if isinstance(raw_error, str):
                return None, raw_error
            # A structured payload (dict, exception object, ...) must become text:
            # the background-rejection check below calls ``.lower()`` on it.
            return None, str(raw_error) if raw_error else ""
        image_ref = result.get("image")
        if not image_ref:
            return None, "provider returned no image"
        try:
            return _save_local(str(image_ref), prefix=prefix), ""
        except Exception as exc:  # noqa: BLE001
            return None, f"could not save generated image: {exc}"

    out: list[Path] = []
    last_error = ""
    allow_transparent = True
    for _ in range(max(1, n)):
        path, err = _run({"background": "transparent"} if allow_transparent else {})
        # Model doesn't support the transparent flag → drop it for this and every
        # remaining variant (no point re-probing a capability we just disproved).
        if path is None and allow_transparent and _rejected_background(err):
            allow_transparent = False
            path, err = _run({})
        if path is not None:
            out.append(path)
        else:
            last_error = err

    if not out:
        raise GenerationError(last_error or "image generation produced no output")
    return out
|
||||
292
agent/pet/generate/orchestrate.py
Normal file
292
agent/pet/generate/orchestrate.py
Normal file
|
|
@ -0,0 +1,292 @@
|
|||
"""Pet generation orchestration — the base-draft → hatch flow.
|
||||
|
||||
Two steps, mirroring the UX across every surface:
|
||||
|
||||
1. :func:`generate_base_drafts` — a handful of prompt-only "what should this pet
|
||||
look like" variants. Cheap; the user picks one (or retries for a fresh set).
|
||||
2. :func:`hatch_pet` — takes the chosen base and generates one grounded row
|
||||
strip per Hermes state, slices each into frames, composes the atlas, validates
|
||||
it, and writes the pet into the store.
|
||||
|
||||
Splitting it this way bounds cost (4 cheap base calls per round; the ~8 row
|
||||
calls happen once, on the pet you actually keep) and gives each UI a natural
|
||||
preview/loading point.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from agent.pet.generate import atlas, imagegen, prompts
|
||||
from agent.pet.generate.imagegen import GenerationError, SpriteProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Progress callback: (event, detail) — e.g. ("row", "idle:1:9"),
# ("row", "idle-fallback"), ("compose", ""), ("save", "<slug>").
ProgressFn = Callable[[str, str], None]

# Image generations are independent network calls, so we fan them out instead of
# blocking on each in turn — a hatch is ~8 row calls that would otherwise run
# back-to-back and routinely blow past the client's RPC timeout. Capped so we
# don't hammer the provider's rate limit (one cold call can still be slow).
_MAX_PARALLEL_GENERATIONS = 4

# A hatch is rejected when fewer than this many animation rows come out non-empty.
_MIN_FILLED_STATES = 6

# Rows a hatch cannot ship without; a missing one fails the hatch.
_REQUIRED_STATES = frozenset(("idle", "running-right", "waving"))
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class HatchResult:
    """What a successful :func:`hatch_pet` hands back to its caller."""

    # Slug the pet was stored under (as reported by the store).
    slug: str
    # Human-readable name of the stored pet.
    display_name: str
    # Location of the written spritesheet.
    spritesheet: Path
    # Names of the animation states that ended up with frames.
    states: list[str]
    # The full report returned by the atlas validator.
    validation: dict
|
||||
|
||||
|
||||
def _harden_transparency(path: Path) -> Path:
    """Turn a generated draft into a clean RGBA PNG cutout.

    ``background=transparent`` is requested on every call, but image models honor
    it inconsistently — some still paint a flat (often near-white) backdrop. This
    runs ``atlas.remove_background`` on the image so every base draft the user
    picks between (and the reference the rows are grounded on) is a clean cutout.

    Returns the path of the written ``.png`` (same name, ``.png`` suffix). The
    step is best-effort: on any failure the problem is logged at debug level and
    the original *path* is returned instead.
    """
    from PIL import Image

    try:
        with Image.open(path) as opened:
            rgba = opened.convert("RGBA")
            # Key out the backdrop, then zero the RGB of leftover transparent
            # pixels so the draft has no colored halo on the dark UI.
            cutout = atlas._clear_transparent_rgb(atlas.remove_background(rgba))
            target = path.with_suffix(".png")
            cutout.save(target, format="PNG")
    except Exception as exc:  # noqa: BLE001 - cosmetic; fall back to the raw image
        logger.debug("base draft transparency hardening failed for %s: %s", path, exc)
        return path
    return target
|
||||
|
||||
|
||||
def generate_base_drafts(
    concept: str,
    *,
    n: int = 4,
    style: str = "auto",
    provider: SpriteProvider | None = None,
    on_draft: Callable[[int, Path], None] | None = None,
    is_cancelled: Callable[[], bool] | None = None,
) -> list[Path]:
    """Produce *n* candidate base looks for *concept* and return their paths.

    Every draft is its own one-shot generation with a distinct variation nudge,
    run on a thread pool, then hardened to a transparent cutout (see
    :func:`_harden_transparency`). One failed draft does not sink the set.

    *on_draft(index, path)* is invoked on the calling thread as each draft
    completes, so callers can stream previews; an exception from the callback
    is logged and ignored.

    *is_cancelled* is polled cooperatively: a draft that has not started is
    skipped, and once it reports true no further drafts are collected or
    streamed and queued work is cancelled (calls already in flight cannot be
    killed, but their results are dropped).

    Returns:
        The finished drafts ordered by draft index.

    Raises:
        GenerationError: when no draft succeeded and the run was not cancelled.
    """
    sprite = provider or imagegen.resolve_provider(require_references=False)
    cancelled = is_cancelled or (lambda: False)

    logger.info("pet generate: drafting %d base looks for %r (style=%s)", n, concept, style)

    def _one(index: int) -> tuple[int, Path | None]:
        """Generate draft *index*; the path is ``None`` when skipped or failed."""
        if cancelled():
            return index, None
        started = time.monotonic()
        nudges = prompts.BASE_VARIATIONS
        # Cycle through the nudges so the options aren't near-duplicates.
        prompt = prompts.build_base_prompt(concept, style=style, variation=nudges[index % len(nudges)])
        try:
            produced = imagegen.generate(prompt, n=1, provider=sprite, prefix="pet_base")
        except Exception as exc:  # noqa: BLE001 - tolerate a single failed draft
            logger.warning("pet generate: draft %d failed after %.1fs: %s", index, time.monotonic() - started, exc)
            return index, None
        if not produced:
            logger.warning("pet generate: draft %d produced no image", index)
            return index, None
        logger.info("pet generate: draft %d ready in %.1fs", index, time.monotonic() - started)
        return index, _harden_transparency(produced[0])

    pool_size = max(1, min(n, _MAX_PARALLEL_GENERATIONS))
    finished: dict[int, Path] = {}
    with ThreadPoolExecutor(max_workers=pool_size) as pool:
        jobs = [pool.submit(_one, i) for i in range(n)]
        # as_completed runs in *this* (the caller's) thread, so on_draft — and any
        # gateway event it emits — inherits the request's bound transport, unlike
        # the worker threads above.
        for job in as_completed(jobs):
            if cancelled():
                logger.info("pet generate: cancelled — dropping remaining drafts")
                for queued in jobs:
                    queued.cancel()
                break
            index, path = job.result()
            if path is None:
                continue
            finished[index] = path
            if on_draft is None:
                continue
            try:
                on_draft(index, path)
            except Exception as exc:  # noqa: BLE001 - progress is best-effort
                logger.debug("on_draft callback failed: %s", exc)

    drafts = [finished[index] for index in sorted(finished)]
    if drafts:
        return drafts
    if cancelled():
        return drafts
    raise GenerationError("image generation produced no usable drafts")
|
||||
|
||||
|
||||
def hatch_pet(
    *,
    base_image: str | Path,
    slug: str,
    display_name: str = "",
    description: str = "",
    concept: str = "",
    style: str = "auto",
    on_progress: ProgressFn | None = None,
    provider: SpriteProvider | None = None,
    is_cancelled: Callable[[], bool] | None = None,
) -> HatchResult:
    """Turn an approved base image into a full, installed Hermes pet.

    Generates a grounded row strip per state, extracts frames, composes +
    validates the atlas, and registers it. ``running-left`` is mirrored from
    ``running-right`` rather than generated, and the idle row falls back to the
    base look so the pet always renders.

    *on_progress(event, detail)* is best-effort: an exception raised by the
    callback is logged at debug level and ignored, so a broken progress sink
    cannot throw away rows that were already generated.

    *is_cancelled*, when supplied, is polled cooperatively: rows that haven't
    started are skipped, queued rows are cancelled, and once every row is done we
    abort (raising :class:`GenerationError`) before composing/saving so a stopped
    hatch never writes a half-built pet.

    Raises:
        GenerationError: the base image is missing, the hatch was cancelled,
            atlas validation failed, a required row is missing, or too few rows
            were usable.
    """
    base = Path(base_image)
    if not base.is_file():
        raise GenerationError(f"base image not found: {base}")

    sprite = provider or imagegen.resolve_provider(require_references=True)
    cancelled = is_cancelled or (lambda: False)
    label = concept or display_name or slug

    def progress(event: str, detail: str) -> None:
        """Forward a progress event to *on_progress*, swallowing callback errors."""
        if on_progress is None:
            return
        try:
            on_progress(event, detail)
        except Exception as exc:  # noqa: BLE001 - progress is best-effort
            logger.debug("on_progress callback failed: %s", exc)

    frames_by_state: dict[str, list] = {}
    total_rows = len(atlas.ROW_SPECS)
    logger.info("pet hatch %r: generating %d animation rows", slug, total_rows)

    # Generate every state's row strip concurrently — they're independent
    # grounded calls, so the hatch waits for the slowest row, not their sum. A
    # single row failing is tolerated (idle is guaranteed below).
    def _gen_row(spec: tuple[str, int, int]) -> tuple[str, list | None]:
        state, _row, count = spec
        if cancelled():
            return state, None
        t0 = time.monotonic()
        try:
            strips = imagegen.generate(
                prompts.build_row_prompt(state, count, label, style=style),
                n=1,
                reference_images=[base],
                provider=sprite,
                prefix=f"pet_row_{state}",
            )
            # One image call per row (the expensive part). ``auto`` validates by
            # connected components with an equal-slot fallback; raw (fit=False) so
            # normalize_cells registers the whole pet at once. We deliberately do
            # NOT re-generate a ragged row — the registration pass salvages it far
            # cheaper than another image-model round-trip.
            frames = atlas.extract_strip_frames(strips[0], count, method="auto", fit=False)
            logger.info("pet hatch %r: row %r ready in %.1fs", slug, state, time.monotonic() - t0)
            return state, frames
        except Exception as exc:  # noqa: BLE001 - one bad row is tolerated (idle guaranteed)
            logger.warning("pet hatch %r: row %r failed after %.1fs: %s", slug, state, time.monotonic() - t0, exc)
            return state, None

    # running-left is derived by mirroring running-right (guaranteed-consistent
    # and one fewer generation), so we don't generate it directly.
    generated_specs = [spec for spec in atlas.ROW_SPECS if spec[0] != "running-left"]

    workers = max(1, min(len(generated_specs), _MAX_PARALLEL_GENERATIONS))
    done = 0
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(_gen_row, spec) for spec in generated_specs]
        # as_completed runs on the caller (request) thread, so progress events
        # emitted here inherit the request transport — unlike the worker threads.
        for fut in as_completed(futures):
            if cancelled():
                logger.info("pet hatch %r: cancelled — dropping remaining rows", slug)
                for pending in futures:
                    pending.cancel()
                break
            state, frames = fut.result()
            done += 1
            progress("row", f"{state}:{done}:{total_rows}")
            if frames:
                frames_by_state[state] = frames

    if cancelled():
        raise GenerationError("hatch cancelled")

    # Derive running-left from the approved running-right row (per-frame mirror,
    # preserving order/timing). Missing running-right is rejected below; a pet
    # without its canonical walk cycle is a failed hatch, not a shippable mascot.
    right = frames_by_state.get("running-right")
    if right:
        done += 1
        progress("row", f"running-left:{done}:{total_rows}")
        frames_by_state["running-left"] = atlas.mirror_frames(right)
        logger.info("pet hatch %r: row 'running-left' mirrored from running-right", slug)
    else:
        logger.warning("pet hatch %r: no running-right to mirror; left walk left empty", slug)

    # Idle is the resting state the renderer falls back to — guarantee it.
    if not frames_by_state.get("idle"):
        progress("row", "idle-fallback")
        frames_by_state["idle"] = [atlas.single_frame(base, fit=False)]

    progress("compose", "")
    logger.info("pet hatch %r: composing atlas from %d states", slug, len(frames_by_state))
    # One shared scale + baseline across every state so the pet never slides or
    # pulses size between frames; compose just packs the normalized cells.
    sheet = atlas.compose_atlas(atlas.normalize_cells(frames_by_state))
    validation = atlas.validate_atlas(sheet)
    if not validation["ok"]:
        raise GenerationError("; ".join(validation["errors"]) or "atlas validation failed")
    filled_states = set(validation["filled_states"])
    missing_required = sorted(_REQUIRED_STATES - filled_states)
    if missing_required:
        raise GenerationError(f"missing required animation row(s): {', '.join(missing_required)}")
    if len(filled_states) < _MIN_FILLED_STATES:
        raise GenerationError(
            f"only {len(filled_states)}/{len(atlas.ROW_SPECS)} animation rows were usable; regenerate"
        )

    from agent.pet import store

    progress("save", slug)
    logger.info("pet hatch %r: saving pet", slug)
    pet = store.register_local_pet(
        sheet,
        slug=slug,
        display_name=display_name or slug,
        description=description,
    )
    return HatchResult(
        slug=pet.slug,
        display_name=pet.display_name,
        spritesheet=pet.spritesheet,
        states=validation["filled_states"],
        validation=validation,
    )
|
||||
140
agent/pet/generate/prompts.py
Normal file
140
agent/pet/generate/prompts.py
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
"""Prompt builders for pet generation.
|
||||
|
||||
Two prompt shapes: a *base* prompt (prompt-only, produces the canonical look the
|
||||
user picks between) and per-*state* *row* prompts (grounded on the chosen base,
|
||||
produce one horizontal strip of N poses). Prompts stay concise and
|
||||
sprite-production oriented; the identity lock and "one transparent row" framing
|
||||
matter more than flowery description.
|
||||
|
||||
We generate the full petdex/Codex nine-state set (see
|
||||
:data:`agent.pet.generate.atlas.ROW_SPECS`) so a hatched pet is a valid
|
||||
``petdex submit`` spritesheet.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# What each petdex/Codex state should depict. Entries are kept short because
# they are interpolated straight into the row prompt, and are phrased to avoid
# the common sprite-gen failure modes (detached effects, motion lines, shadows).
# Critical distinction: ``running`` is the *working* state (in place), while
# ``running-right`` / ``running-left`` are the actual directional walk/run cycles.
STATE_ACTIONS: dict[str, str] = {
    # Resting loop.
    "idle": "a calm idle loop: subtle breathing, a tiny blink or gentle bob, no big gestures",
    # Directional locomotion.
    "running-right": (
        "a sideways walk/run locomotion cycle moving to the RIGHT: the character "
        "faces and travels right with clear directional steps, a smooth gait loop"
    ),
    "running-left": (
        "a sideways walk/run locomotion cycle moving to the LEFT: the character "
        "faces and travels left with clear directional steps (the mirror of the "
        "right-facing run)"
    ),
    # Reactions.
    "waving": "a friendly greeting: raising a paw/hand/limb to wave, clear up-and-down gesture",
    "jumping": "a happy celebration jump: anticipation, lift off the ground, peak, and land",
    "failed": "a sad or deflated reaction: slumped, dejected, small frown — readable but not noisy",
    "waiting": (
        "an expectant 'waiting on you' pose: looking up/out as if asking for input "
        "or approval — distinct from idle and review"
    ),
    # Working states (in place).
    "running": (
        "focused active work, staying IN PLACE (NOT walking or foot-running): "
        "leaning in, concentrating, busy 'thinking / processing / typing' energy"
    ),
    "review": "careful inspection: a focused lean, head tilt, studying something intently",
}
|
||||
|
||||
# Prompt suffix per style key. Every value starts with a space because it is
# appended directly after the background instructions.
_STYLE_HINTS: dict[str, str] = {
    # Default to the popular petdex look: crisp 16-bit PIXEL ART, not the smooth
    # 2D illustration (let alone 3D render) gpt-image reaches for by default.
    "auto": (
        " Style: crisp 16-bit PIXEL-ART game sprite — visible square pixels, a small "
        "limited palette, clean dark outline, flat cel shading, chunky chibi "
        "proportions, like a classic SNES/JRPG party member or a petdex.dev mascot. "
        "Absolutely NOT 3D-rendered, NOT a smooth painted or vector illustration, "
        "NOT photorealistic — no soft gradients, no realistic lighting, no figurine look."
    ),
    # Explicit single-sentence styles.
    "pixel": " Render in clean 16-bit pixel-art style with visible square pixels and a limited palette.",
    "plush": " Render as a soft plush toy.",
    "clay": " Render as a claymation / soft 3D clay figure.",
    "sticker": " Render as a glossy die-cut sticker.",
    "flat-vector": " Render in flat vector mascot style.",
    "3d-toy": " Render as a glossy 3D toy.",
    "painterly": " Render in a soft painterly style.",
}
|
||||
|
||||
# Shared background instructions appended to both the base and the row prompts:
# one flat chroma-key color that surrounds the character and can be keyed out.
_BACKGROUND = (
    "Center one full-body character on a flat, uniform, high-contrast chroma-key "
    "background (prefer pure hot magenta #FF00FF unless that color appears on "
    "the character). The background must completely surround the character: one "
    "even color with NO gradient, vignette, texture, pattern, scenery, shadow, "
    "ground line, frame, or border, so it keys out cleanly. The background color "
    "must not appear anywhere on the character itself. No text, no labels."
)
|
||||
|
||||
|
||||
def style_hint(style: str | None) -> str:
    """Return the prompt suffix for *style*.

    The key is matched case-insensitively after trimming whitespace. ``None``,
    an empty string, or a whitespace-only string all select the ``"auto"``
    hint; an unrecognized style yields ``""`` (no hint).
    """
    # Trim before applying the default so a blank "   " falls back to "auto"
    # instead of looking up the empty key and silently dropping the hint.
    key = (style or "").strip().lower() or "auto"
    return _STYLE_HINTS.get(key, "")
|
||||
|
||||
|
||||
# Per-draft nudges so the base options are actually distinct — gpt-image returns
# near-duplicates for a single prompt. Each nudge varies the *look* (palette,
# build, expression, accents), NOT the pose, so the chosen base still grounds
# clean, consistent animation rows. The first entry is empty: that draft uses
# the un-nudged prompt.
BASE_VARIATIONS: tuple[str, ...] = (
    "",
    "a distinctly different colour palette and markings",
    "rounder, chunkier chibi proportions and a bigger head",
    "a different face and expression, with unique accent/accessory details",
    "a leaner, taller build and an alternate colour scheme",
    "bolder, more saturated colours and a playful expression",
)
|
||||
|
||||
|
||||
def build_base_prompt(concept: str, *, style: str | None = "auto", variation: str = "") -> str:
    """Build the prompt for the base look: a single, clean, centered full-body mascot.

    Args:
        concept: What the pet is. ``None``, empty, or whitespace-only falls back
            to a generic friendly mascot.
        style: Style key resolved through :func:`style_hint`.
        variation: Optional nudge that differentiates one draft from the next
            (see :data:`BASE_VARIATIONS`); omitted from the prompt when empty.

    Returns:
        The full prompt text: subject, neutral pose, optional nudge, background
        instructions, then the style hint.
    """
    # Strip first, then default: a blank "   " must not yield "mascot pet: .".
    concept = (concept or "").strip() or "a cute friendly mascot creature"
    nudge = f" Make this design distinct: {variation}." if variation else ""
    return (
        f"A cute, characterful mascot pet: {concept}. "
        "Compact, whole-body silhouette that reads clearly at small size, "
        "appealing face, simple consistent palette. "
        # A neutral, symmetric, at-rest stance makes the cleanest identity anchor
        "Neutral front-facing standing pose, upright and symmetric, arms/limbs "
        "relaxed at the sides, feet together on the ground, any cape/accessories "
        "hanging straight and still."
        f"{nudge} "
        f"{_BACKGROUND}{style_hint(style)}"
    )
|
||||
|
||||
|
||||
def build_row_prompt(state: str, frame_count: int, concept: str, *, style: str | None = "auto") -> str:
    """Build the prompt for a row strip: *frame_count* poses of the SAME character, left→right.

    The attached base image is the identity source of truth; the prompt locks
    species, palette, face, and props to it.

    Args:
        state: Animation state; its action text comes from :data:`STATE_ACTIONS`
            (an unknown state falls back to a simple idle pose).
        frame_count: Number of frames the strip must contain.
        concept: Accepted for interface compatibility. It is not interpolated
            into the prompt text — identity comes from the reference image.
        style: Style key resolved through :func:`style_hint`.
    """
    action = STATE_ACTIONS.get(state, "a simple idle pose")
    return (
        "Using the attached reference image as the exact same character "
        "(same species, face, colors, markings, proportions, and props), "
        f"draw a single horizontal strip of {frame_count} animation frames showing {action}. "
        f"The {frame_count} poses must be evenly spaced left to right, each fully separated "
        "by clear empty chroma-key gutters; silhouettes must NEVER touch, overlap, "
        "share a shadow, share a ground line, share motion trails, or merge into "
        "one connected shape. "
        # Registration: a clean sprite sheet keeps the character locked in place
        # so only the action moves — this is what stops the loop sliding/pulsing.
        "REGISTRATION (critical): the character is the SAME height and SAME width "
        "in every frame, drawn at the SAME scale, centered over the SAME point, "
        "with all feet resting on ONE shared horizontal ground line across the "
        "whole strip. Keep the body's center, size, and stance fixed frame to "
        "frame — ONLY the limbs/features the action needs may move. Capes, cloaks, "
        "bags, and scarves stay in the SAME place and shape every frame (no "
        "swinging, flowing, or drifting) unless the action itself requires it. No "
        "pose is cropped at the strip edges. "
        f"{_BACKGROUND}{style_hint(style)}"
    )
|
||||
|
|
@ -21,6 +21,7 @@ Read-only and unauthenticated; no credentials involved.
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
|
@ -28,7 +29,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
MANIFEST_URL = "https://petdex.dev/api/manifest"
|
||||
|
||||
_DEFAULT_TIMEOUT = 20.0
|
||||
_DEFAULT_TIMEOUT = 10.0
|
||||
|
||||
# In-process cache for the (large, slow, identical-per-call) manifest. The list
|
||||
# is a static CDN object that barely changes, yet a single session can ask for
|
||||
|
|
@ -38,6 +39,9 @@ _DEFAULT_TIMEOUT = 20.0
|
|||
_MANIFEST_TTL = 300.0
|
||||
_cache: tuple[float, list[ManifestEntry]] | None = None
|
||||
|
||||
_prefetch_lock = threading.Lock()
|
||||
_prefetching = False
|
||||
|
||||
|
||||
def clear_cache() -> None:
|
||||
"""Drop the cached manifest (forces the next fetch to hit the network)."""
|
||||
|
|
@ -45,6 +49,39 @@ def clear_cache() -> None:
|
|||
_cache = None
|
||||
|
||||
|
||||
def _cache_is_warm() -> bool:
    """Return True when a cached manifest exists and is younger than ``_MANIFEST_TTL``."""
    # Read the module global once: another thread may reset it to None between
    # the None check and the index, which would raise TypeError on ``None[0]``.
    cached = _cache
    return cached is not None and time.monotonic() - cached[0] < _MANIFEST_TTL
|
||||
|
||||
|
||||
def prefetch(*, timeout: float = _DEFAULT_TIMEOUT) -> None:
    """Warm the manifest cache in a daemon thread — idempotent, never blocks.

    The desktop picker calls this when it loads the (instant) local-only gallery
    so the full petdex catalog is usually cached by the time it's requested,
    without ever holding up the user's own pets on a network round-trip.

    Does nothing when the cache is already warm or a prefetch is in flight. The
    warm-up is best-effort: a failed fetch, or a failure to start the thread, is
    logged at debug level and otherwise ignored.

    Args:
        timeout: Timeout forwarded to ``fetch_manifest``.
    """
    global _prefetching

    if _cache_is_warm():
        return

    # Claim the single in-flight slot under the lock so concurrent callers
    # start at most one worker.
    with _prefetch_lock:
        if _prefetching:
            return
        _prefetching = True

    def _run() -> None:
        global _prefetching
        try:
            fetch_manifest(timeout=timeout)
        except Exception as exc:  # noqa: BLE001 - best-effort warm
            logger.debug("petdex manifest prefetch failed: %s", exc)
        finally:
            _prefetching = False

    try:
        threading.Thread(target=_run, name="petdex-prefetch", daemon=True).start()
    except Exception as exc:  # noqa: BLE001 - e.g. RuntimeError: can't start new thread
        # The worker never ran, so its ``finally`` will not release the slot;
        # release it here or every later prefetch would be skipped.
        _prefetching = False
        logger.debug("petdex manifest prefetch could not start: %s", exc)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ManifestEntry:
|
||||
"""A single pet's row in the manifest."""
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from __future__ import annotations
|
|||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -41,11 +42,16 @@ class InstalledPet:
|
|||
description: str
|
||||
directory: Path
|
||||
spritesheet: Path
|
||||
created_by: str = "" # "generator" for pets hatched locally; "" for petdex installs
|
||||
|
||||
@property
|
||||
def exists(self) -> bool:
|
||||
return self.spritesheet.is_file()
|
||||
|
||||
@property
|
||||
def generated(self) -> bool:
|
||||
return self.created_by == "generator"
|
||||
|
||||
|
||||
def pets_dir() -> Path:
|
||||
"""Return the profile-scoped pets directory (created on demand)."""
|
||||
|
|
@ -113,6 +119,7 @@ def load_pet(slug: str) -> InstalledPet | None:
|
|||
description=str(meta.get("description", "") or ""),
|
||||
directory=directory,
|
||||
spritesheet=_resolve_spritesheet(directory, meta),
|
||||
created_by=str(meta.get("createdBy", "") or ""),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -197,6 +204,101 @@ def install_pet(slug: str, *, force: bool = False, timeout: float = _DOWNLOAD_TI
|
|||
return pet
|
||||
|
||||
|
||||
def slugify(name: str) -> str:
    """Turn a display name into a filesystem-safe slug.

    The name is trimmed and lowercased, every run of characters outside
    ``a-z0-9`` becomes a single hyphen, and leading/trailing hyphens are
    dropped. When nothing usable remains the slug is ``"pet"``.
    """
    lowered = (name or "").strip().lower()
    hyphenated = re.sub(r"[^a-z0-9]+", "-", lowered)
    trimmed = hyphenated.strip("-")
    if not trimmed:
        return "pet"
    return trimmed
|
||||
|
||||
|
||||
def unique_slug(name: str) -> str:
    """Return a :func:`slugify` slug that has no existing entry under the pets dir.

    The plain slug is tried first; on collision ``-2``, ``-3``, … are appended
    until a path that does not exist yet is found.
    """
    base = slugify(name)
    candidate = base
    suffix = 2
    while True:
        if not (pets_dir() / candidate).exists():
            return candidate
        candidate = f"{base}-{suffix}"
        suffix += 1
|
||||
|
||||
|
||||
def _write_spritesheet(source, dest: Path) -> None:
    """Write *source* (PIL image, bytes, or path) as a WebP file at *dest*.

    * ``bytes`` / ``bytearray`` that already hold a WebP container (``RIFF`` …
      ``WEBP`` header) are written through untouched, so a pre-encoded atlas is
      never re-compressed.
    * Any other byte payload (e.g. PNG) is decoded and re-encoded as lossless
      WebP, so the ``.webp`` file on disk really is WebP. Undecodable or empty
      payloads raise instead of being stored silently.
    * A ``str`` / ``Path`` is opened with Pillow and re-encoded.
    * Anything else is treated as a PIL image (must provide ``convert``).

    Raises whatever Pillow or the filesystem raises; callers normalize it.
    """
    payload = None
    if isinstance(source, (bytes, bytearray)):
        payload = bytes(source)
        # WebP is a RIFF container whose form type at offset 8 is "WEBP".
        if payload[:4] == b"RIFF" and payload[8:12] == b"WEBP":
            dest.write_bytes(payload)
            return

    from PIL import Image

    if payload is not None:
        import io

        with Image.open(io.BytesIO(payload)) as opened:
            image = opened.convert("RGBA")
    elif isinstance(source, (str, Path)):
        with Image.open(source) as opened:
            image = opened.convert("RGBA")
    else:
        image = source.convert("RGBA")
    # exact=True keeps RGB under fully transparent pixels as-is (no residue drift).
    image.save(dest, format="WEBP", lossless=True, quality=100, method=6, exact=True)
|
||||
|
||||
|
||||
def register_local_pet(
    spritesheet,
    *,
    slug: str,
    display_name: str = "",
    description: str = "",
) -> InstalledPet:
    """Write a locally-generated pet into the store and return it.

    *spritesheet* may be a PIL image, raw WebP/PNG bytes, or a path. The pet
    appears in :func:`installed_pets` immediately, and because :func:`install_pet`
    returns an already-on-disk pet before consulting the manifest, it can be
    adopted (``pet.select`` / ``/pet <slug>``) without a manifest entry.

    *slug* is passed through :func:`slugify` first. ``pet.json`` is written with
    ``createdBy: "generator"`` so the pet is recognisable as locally hatched.

    Raises :class:`PetStoreError` when the spritesheet cannot be written or the
    pet cannot be loaded back afterwards. If this call created the pet
    directory and the spritesheet write fails, the directory is removed again
    so a failed hatch doesn't permanently occupy the slug.
    """
    slug = slugify(slug)
    directory = pets_dir() / slug
    newly_created = not directory.exists()
    directory.mkdir(parents=True, exist_ok=True)
    sprite_path = directory / "spritesheet.webp"
    try:
        _write_spritesheet(spritesheet, sprite_path)
    except Exception as exc:  # noqa: BLE001 - normalize to one error type
        if newly_created:
            import shutil

            # Don't leave a half-written pet behind; only touch dirs we made.
            shutil.rmtree(directory, ignore_errors=True)
        raise PetStoreError(f"could not write spritesheet for '{slug}': {exc}") from exc

    # A thumbnail cached for a previous pet under this slug would now be stale
    # (it lives outside the pet dir), so drop it. Best-effort: it's only a cache.
    try:
        (_thumbs_dir() / f"{slug}.png").unlink(missing_ok=True)
    except OSError:
        pass

    meta = {
        "id": slug,
        "displayName": display_name or slug,
        "description": description or "",
        "spritesheetPath": sprite_path.name,
        "createdBy": "generator",
    }
    (directory / "pet.json").write_text(json.dumps(meta, indent=2), encoding="utf-8")

    pet = load_pet(slug)
    if pet is None or not pet.exists:
        raise PetStoreError(f"register of generated pet '{slug}' did not produce a spritesheet")
    return pet
|
||||
|
||||
|
||||
def export_pet(slug: str) -> tuple[str, bytes]:
    """Zip an installed pet's folder (pet.json + spritesheet) → (filename, bytes).

    Dotfiles (cached thumbs, backups) are skipped so the archive is a clean,
    re-importable pet package. Only regular files directly inside the pet
    directory are archived, each stored under ``<dirname>/``.

    Raises :class:`PetStoreError` if *slug* does not name a pet directory: it is
    empty/``None``, escapes the pets dir, is missing, or is a hidden entry such
    as the ``.thumbs`` cache.
    """
    import io
    import zipfile

    root = pets_dir()
    directory = root / (slug or "").strip()
    name = directory.name
    # Guard against traversal: the target must be a direct child of pets_dir.
    # Hidden children (e.g. the .thumbs cache) are internal, never pets.
    if (
        directory.resolve().parent != root.resolve()
        or name.startswith(".")
        or not directory.is_dir()
    ):
        raise PetStoreError(f"pet '{slug}' is not installed")

    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as archive:
        # Sorted so the archive layout is deterministic for a given folder.
        for path in sorted(directory.iterdir()):
            if path.is_file() and not path.name.startswith("."):
                archive.write(path, f"{name}/{path.name}")
    return f"{name}.zip", buf.getvalue()
|
||||
|
||||
|
||||
_THUMB_FRAME_W = 192
|
||||
_THUMB_FRAME_H = 208
|
||||
_THUMB_W = 96 # rendered ~40px; 2x+ keeps it crisp on HiDPI
|
||||
|
|
@ -301,6 +403,15 @@ def remove_pet(slug: str) -> bool:
|
|||
slug = _safe_slug(slug)
|
||||
if not slug:
|
||||
return False
|
||||
|
||||
# The cached thumbnail lives in pets/.thumbs/<slug>.png — OUTSIDE the pet
|
||||
# dir, so rmtree won't catch it. Drop it too, or a later pet that reuses this
|
||||
# slug renders this one's stale thumbnail.
|
||||
try:
|
||||
(_thumbs_dir() / f"{slug}.png").unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
directory = pets_dir() / slug
|
||||
if not directory.is_dir():
|
||||
return False
|
||||
|
|
@ -308,6 +419,55 @@ def remove_pet(slug: str) -> bool:
|
|||
return not directory.exists()
|
||||
|
||||
|
||||
def rename_pet(slug: str, display_name: str) -> str | None:
    """Rename a pet's ``displayName`` AND realign its slug/dir to match.

    Generated pets are hatched under a provisional, prompt-derived slug; when
    the user names the pet on the reveal screen we make that name the real
    identity so lists/subtitles show what they typed, not the prompt. The dir is
    renamed to ``slugify(name)`` (and the cached thumbnail moved alongside it)
    whenever that yields a free, different slug — otherwise the slug is left as
    is. Returns the resulting slug on success, or ``None`` on failure.

    Failure means: empty slug/name, no ``pet.json`` for the pet, or the updated
    ``pet.json`` could not be written. In the last case a directory move that
    already happened is rolled back (best-effort) so the slug the caller still
    holds keeps pointing at the pet.
    """
    slug = _safe_slug(slug)
    display_name = (display_name or "").strip()
    if not slug or not display_name:
        return None
    root = pets_dir()
    directory = root / slug
    pet_json = directory / "pet.json"
    if not pet_json.is_file():
        return None
    try:
        meta = json.loads(pet_json.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        meta = {}
    if not isinstance(meta, dict):
        meta = {}
    meta["displayName"] = display_name

    def _move_thumb(old: str, new: str) -> None:
        """Move the cached thumbnail ``old`` → ``new``; never raises OSError."""
        try:
            thumbs = _thumbs_dir()
            try:
                # replace() overwrites an orphaned target on every platform.
                (thumbs / f"{old}.png").replace(thumbs / f"{new}.png")
            except OSError:
                # Nothing could be moved: make sure a leftover thumbnail from
                # an earlier pet doesn't get shown for this one.
                (thumbs / f"{new}.png").unlink(missing_ok=True)
        except OSError:
            pass

    new_slug = slug
    desired = slugify(display_name)
    target = root / desired
    moved = False
    if desired and desired != slug and not target.exists():
        try:
            directory.rename(target)
        except OSError:
            new_slug = slug  # keep the provisional slug if the move fails
        else:
            moved = True
            _move_thumb(slug, desired)
            directory = target
            pet_json = directory / "pet.json"
            new_slug = desired
            meta["id"] = new_slug

    try:
        pet_json.write_text(json.dumps(meta, indent=2), encoding="utf-8")
    except OSError:
        if moved:
            # Undo the move so "failure" really means "nothing changed".
            try:
                target.rename(root / slug)
            except OSError:
                pass
            else:
                _move_thumb(desired, slug)
        return None
    return new_slug
|
||||
|
||||
|
||||
def _download(url: str, dest: Path, *, timeout: float) -> None:
|
||||
import httpx
|
||||
|
||||
|
|
|
|||
413
tests/agent/test_pet_generate.py
Normal file
413
tests/agent/test_pet_generate.py
Normal file
|
|
@ -0,0 +1,413 @@
|
|||
"""Tests for pet generation: deterministic atlas ops, store register, orchestration.
|
||||
|
||||
No network/API calls — image generation is mocked with synthetic strips so the
|
||||
whole pipeline (segmentation → compose → validate → register → adopt) is
|
||||
exercised hermetically.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.pet.generate import atlas
|
||||
|
||||
PIL = pytest.importorskip("PIL")
|
||||
from PIL import Image, ImageDraw # noqa: E402
|
||||
|
||||
|
||||
def _strip(n_blobs: int, *, transparent: bool = True, bg=(0, 255, 0, 255), size=(208, 208)) -> Image.Image:
    """Build a horizontal strip holding *n_blobs* well-separated colored ellipses.

    Each blob sits centered in its own ``size``-wide slot. The canvas is fully
    transparent unless ``transparent=False``, in which case it is filled with *bg*.
    """
    slot_w, slot_h = size[0], size[1]
    backdrop = (0, 0, 0, 0) if transparent else bg
    canvas = Image.new("RGBA", (slot_w * n_blobs, slot_h), backdrop)
    pen = ImageDraw.Draw(canvas)
    radius = slot_w // 3
    mid_y = slot_h // 2
    for index in range(n_blobs):
        mid_x = index * slot_w + slot_w // 2
        fill = (40 + index * 30 % 200, 80, 200 - index * 20 % 180, 255)
        box = (mid_x - radius, mid_y - radius, mid_x + radius, mid_y + radius)
        pen.ellipse(box, fill=fill)
    return canvas
|
||||
|
||||
|
||||
# ───────────────────────── frame extraction ─────────────────────────
|
||||
|
||||
|
||||
def test_extract_strip_frames_transparent_returns_centered_cells():
    """Six blobs on a transparent strip come back as six cell-sized, non-empty frames."""
    cells = atlas.extract_strip_frames(_strip(6), 6)
    assert len(cells) == 6
    cell_size = (atlas.CELL_WIDTH, atlas.CELL_HEIGHT)
    for cell in cells:
        assert cell.size == cell_size
        # Background corners must be transparent.
        corner_alpha = cell.getpixel((0, 0))[3]
        assert corner_alpha == 0
        # Something is drawn.
        peak_alpha = cell.getchannel("A").getextrema()[1]
        assert peak_alpha > 0
|
||||
|
||||
|
||||
def test_extract_strip_frames_keys_out_solid_background():
    """An opaque green backdrop is keyed away during extraction."""
    opaque_strip = _strip(4, transparent=False)
    cells = atlas.extract_strip_frames(opaque_strip, 4)
    assert len(cells) == 4
    # The green backdrop must be gone (corner transparent).
    first_corner = cells[0].getpixel((0, 0))
    assert first_corner[3] == 0
|
||||
|
||||
|
||||
def test_remove_background_clears_trapped_chroma_pocket():
    """A backdrop-colored pocket fully enclosed by the body is still cleared."""
    # Green body enclosing a magenta pocket (the "pink between the arm" case):
    # the pocket isn't border-reachable, so it must be cleared by interior seeding.
    magenta = (255, 0, 255, 255)
    canvas = Image.new("RGBA", (200, 200), magenta)  # magenta backdrop
    pen = ImageDraw.Draw(canvas)
    pen.ellipse((40, 40, 160, 160), fill=(40, 200, 60, 255))  # body
    pen.ellipse((85, 85, 115, 115), fill=magenta)  # trapped pocket

    result = atlas.remove_background(canvas)

    pocket_alpha = result.getpixel((100, 100))[3]
    body_alpha = result.getpixel((100, 50))[3]
    border_alpha = result.getpixel((2, 2))[3]
    assert pocket_alpha == 0  # pocket cleared
    assert body_alpha > 0  # body still opaque
    assert border_alpha == 0  # border cleared
|
||||
|
||||
|
||||
def test_extract_strip_frames_repairs_provider_alpha_holes():
    """A transparent hole punched into the blob's middle is filled back in."""
    strip = _strip(1)
    mid_x, mid_y = strip.width // 2, strip.height // 2
    hole = (mid_x - 16, mid_y - 16, mid_x + 16, mid_y + 16)
    ImageDraw.Draw(strip).ellipse(hole, fill=(0, 0, 0, 0))

    cell = atlas.extract_strip_frames(strip, 1, method="components")[0]
    center = (atlas.CELL_WIDTH // 2, atlas.CELL_HEIGHT // 2)
    assert cell.getpixel(center)[3] > 0
|
||||
|
||||
|
||||
def test_extract_strip_frames_severs_thin_bridges_between_frames():
    """A 1px line joining every blob must not merge the row into one subject."""
    # AI strips often connect poses with a 1px shadow/glow bridge. Strict
    # component extraction must still find each frame instead of treating the row
    # as one merged subject.
    strip = _strip(4)
    mid_y = strip.height // 2
    bridge = (20, mid_y, strip.width - 20, mid_y)
    ImageDraw.Draw(strip).line(bridge, fill=(255, 255, 255, 255), width=1)

    cells = atlas.extract_strip_frames(strip, 4, method="components")
    assert len(cells) == 4
    for cell in cells:
        assert cell.getchannel("A").getextrema()[1] > 0
|
||||
|
||||
|
||||
def test_extract_strip_frames_drops_small_side_lobes_from_adjacent_frames():
    """Detached slivers left and right of the main pose don't reach the fitted cell."""
    # Frogger regression: a real pose plus a small separated side lobe from a
    # neighbouring pose. The side lobe should not survive into the fitted cell.
    green = (70, 190, 70, 255)
    canvas = Image.new("RGBA", (atlas.CELL_WIDTH, atlas.CELL_HEIGHT), (0, 0, 0, 0))
    pen = ImageDraw.Draw(canvas)
    pen.ellipse((52, 34, 150, 188), fill=green)
    pen.rectangle((4, 70, 24, 160), fill=green)
    pen.rectangle((168, 82, 186, 150), fill=green)

    frame = atlas.extract_strip_frames(canvas, 1, method="components")[0]
    alpha = frame.getchannel("A")

    def _visible_pixels(columns) -> int:
        # Count pixels whose alpha exceeds 16 inside the given column range.
        return sum(1 for x in columns for y in range(frame.height) if alpha.getpixel((x, y)) > 16)

    left_edge_mass = _visible_pixels(range(0, 36))
    right_edge_mass = _visible_pixels(range(frame.width - 36, frame.width))
    assert left_edge_mass == 0
    assert right_edge_mass == 0
|
||||
|
||||
|
||||
def test_extract_strip_frames_slot_fallback_when_unsegmentable():
    """``method="auto"`` still returns the requested frame count for one long smear."""
    # A single connected smear can't be split into 5 components → slot fallback.
    total_width = 200 * 5
    smear = Image.new("RGBA", (total_width, 208), (0, 0, 0, 0))
    ImageDraw.Draw(smear).rectangle((0, 80, total_width - 1, 120), fill=(200, 50, 50, 255))
    cells = atlas.extract_strip_frames(smear, 5, method="auto")
    assert len(cells) == 5
|
||||
|
||||
|
||||
def test_extract_components_method_raises_when_too_few():
    """Strict component mode raises ``ValueError`` when one blob must supply six frames."""
    lonely = Image.new("RGBA", (400, 208), (0, 0, 0, 0))
    pen = ImageDraw.Draw(lonely)
    pen.ellipse((10, 10, 100, 100), fill=(255, 0, 0, 255))
    with pytest.raises(ValueError):
        atlas.extract_strip_frames(lonely, 6, method="components")
|
||||
|
||||
|
||||
# ───────────────────────── atlas compose / validate ─────────────────────────
|
||||
|
||||
|
||||
def _frames_for_all_states() -> dict[str, list]:
    """Extract a synthetic frame list for every state listed in ``atlas.ROW_SPECS``."""
    return {
        state: atlas.extract_strip_frames(_strip(count), count)
        for state, _row, count in atlas.ROW_SPECS
    }
|
||||
|
||||
|
||||
def test_compose_atlas_geometry_and_validation():
    """A fully populated atlas has the spec size, validates, and reports every state."""
    sheet = atlas.compose_atlas(_frames_for_all_states())
    assert sheet.size == (atlas.ATLAS_WIDTH, atlas.ATLAS_HEIGHT)

    report = atlas.validate_atlas(sheet)
    assert report["ok"], report["errors"]

    expected_states = {state for state, _, _ in atlas.ROW_SPECS}
    assert set(report["filled_states"]) == expected_states
|
||||
|
||||
|
||||
def test_compose_atlas_leaves_unused_tail_transparent():
    """Cells past a row's frame count stay fully transparent."""
    # waving has 4 frames; columns 4 and 5 of its row must be transparent.
    sheet = atlas.compose_atlas(_frames_for_all_states())
    wave_row = next(row for state, row, _ in atlas.ROW_SPECS if state == "waving")
    y0 = wave_row * atlas.CELL_HEIGHT
    y1 = y0 + atlas.CELL_HEIGHT
    for col in (4, 5):
        x0 = col * atlas.CELL_WIDTH
        tail_cell = sheet.crop((x0, y0, x0 + atlas.CELL_WIDTH, y1))
        _, max_alpha = tail_cell.getchannel("A").getextrema()
        assert max_alpha == 0
|
||||
|
||||
|
||||
def test_validate_atlas_rejects_wrong_size():
    """A 100×100 image fails validation with an error mentioning ``expected``."""
    undersized = Image.new("RGBA", (100, 100), (0, 0, 0, 0))
    report = atlas.validate_atlas(undersized)
    assert not report["ok"]
    size_errors = [err for err in report["errors"] if "expected" in err]
    assert any(size_errors)
|
||||
|
||||
|
||||
def test_validate_atlas_rejects_rgb_residue():
    """Non-zero RGB under a fully transparent pixel is reported as residue."""
    sheet = atlas.compose_atlas(_frames_for_all_states())
    # Poke a fully-transparent pixel with non-zero RGB.
    sheet.putpixel((0, 0), (120, 0, 0, 0))

    report = atlas.validate_atlas(sheet)
    assert not report["ok"]
    residue_errors = [err for err in report["errors"] if "residue" in err]
    assert any(residue_errors)
|
||||
|
||||
|
||||
def test_validate_atlas_warns_on_empty_state():
    """An empty ``jumping`` row produces a warning but the atlas still validates."""
    per_state = _frames_for_all_states()
    per_state["jumping"] = []
    report = atlas.validate_atlas(atlas.compose_atlas(per_state))
    assert report["ok"]  # one empty row is a warning, not an error
    jumping_warnings = [warning for warning in report["warnings"] if "jumping" in warning]
    assert any(jumping_warnings)
|
||||
|
||||
|
||||
def test_single_frame_fits_cell():
    """``single_frame`` returns one cell-sized frame with visible content."""
    cell = atlas.single_frame(_strip(1))
    expected_size = (atlas.CELL_WIDTH, atlas.CELL_HEIGHT)
    assert cell.size == expected_size
    _, max_alpha = cell.getchannel("A").getextrema()
    assert max_alpha > 0
|
||||
|
||||
|
||||
# ───────────────────────── store register / adopt ─────────────────────────
|
||||
|
||||
|
||||
def test_slugify_and_unique_slug():
    """``slugify`` normalizes names; ``unique_slug`` suffixes ``-2`` once the slug dir exists."""
    from agent.pet import store

    assert store.slugify("My Cool Pet!") == "my-cool-pet"
    assert store.slugify("   ") == "pet"

    taken = store.unique_slug("Robo")
    taken_dir = store.pets_dir() / taken
    taken_dir.mkdir(parents=True)
    assert store.unique_slug("Robo") == "robo-2"
|
||||
|
||||
|
||||
def test_register_local_pet_appears_and_is_adoptable():
    """A registered pet is listed as installed and ``install_pet`` returns it."""
    from agent.pet import store

    sheet = atlas.compose_atlas(_frames_for_all_states())
    registered = store.register_local_pet(
        sheet,
        slug="Sparky",
        display_name="Sparky",
        description="zappy",
    )
    assert registered.slug == "sparky"
    assert registered.exists
    installed_slugs = [entry.slug for entry in store.installed_pets()]
    assert "sparky" in installed_slugs

    # install_pet returns the on-disk pet without ever hitting the manifest.
    adopted = store.install_pet("sparky")
    assert adopted.slug == "sparky"
    assert adopted.display_name == "Sparky"
|
||||
|
||||
|
||||
def test_register_local_pet_is_generated_and_exports_zip():
    """A registered pet is flagged as generated and exports as ``<slug>.zip``.

    The archive must contain ``pet.json`` and a spritesheet under ``zippy/``.
    """
    import io
    import zipfile

    from agent.pet import store

    sheet = atlas.compose_atlas(_frames_for_all_states())
    store.register_local_pet(sheet, slug="zippy", display_name="Zippy")
    assert store.load_pet("zippy").generated is True  # createdBy=generator

    filename, data = store.export_pet("zippy")
    assert filename == "zippy.zip"
    # Context manager so the archive handle is closed even if an assert fails.
    with zipfile.ZipFile(io.BytesIO(data)) as archive:
        names = archive.namelist()
    assert "zippy/pet.json" in names
    assert any(n.startswith("zippy/spritesheet") for n in names)
|
||||
|
||||
|
||||
def test_export_pet_rejects_unknown_and_traversal():
    """Unknown slugs and ``../`` traversal attempts both raise ``PetStoreError``."""
    from agent.pet import store

    for bad_slug in ("does-not-exist", "../secrets"):
        with pytest.raises(store.PetStoreError):
            store.export_pet(bad_slug)
|
||||
|
||||
|
||||
def test_register_local_pet_accepts_bytes():
    """Bytes from ``atlas_to_webp_bytes`` are accepted as the spritesheet."""
    from agent.pet import store

    encoded = atlas.atlas_to_webp_bytes(atlas.compose_atlas(_frames_for_all_states()))
    registered = store.register_local_pet(encoded, slug="bytey")
    assert registered.exists
|
||||
|
||||
|
||||
# ───────────────────────── orchestration (mocked imagegen) ─────────────────────────
|
||||
|
||||
|
||||
def test_generate_base_drafts_returns_n(monkeypatch, tmp_path):
    """Requesting four drafts yields four results with image generation stubbed out."""
    from agent.pet.generate import imagegen, orchestrate

    calls = {"n": 0}

    def fake_generate(prompt, *, n=1, reference_images=None, provider=None, prefix="pet"):
        # Save one single-blob image per requested variant, numbered globally.
        written = []
        for _ in range(n):
            calls["n"] += 1
            target = tmp_path / f"{prefix}_{calls['n']}.png"
            _strip(1).save(target)
            written.append(target)
        return written

    monkeypatch.setattr(imagegen, "resolve_provider", lambda **_: object())
    monkeypatch.setattr(imagegen, "generate", fake_generate)

    drafts = orchestrate.generate_base_drafts("a fox", n=4)
    assert len(drafts) == 4
|
||||
|
||||
|
||||
def test_generate_base_drafts_hardens_opaque_background(monkeypatch, tmp_path):
    """A provider that ignores background=transparent still yields a cutout."""
    from agent.pet.generate import imagegen, orchestrate

    def fake_generate(prompt, *, n=1, reference_images=None, provider=None, prefix="pet"):
        # Solid-green backdrop with a blob — i.e. the provider painted a backdrop.
        target = tmp_path / f"{prefix}_opaque.png"
        _strip(1, transparent=False, bg=(0, 255, 0, 255)).save(target)
        return [target]

    monkeypatch.setattr(imagegen, "resolve_provider", lambda **_: object())
    monkeypatch.setattr(imagegen, "generate", fake_generate)

    drafts = orchestrate.generate_base_drafts("a fox", n=1)
    assert len(drafts) == 1

    with Image.open(drafts[0]) as opened:
        rgba = opened.convert("RGBA")
    corner_alpha = rgba.getpixel((0, 0))[3]
    center_alpha = rgba.getpixel((rgba.width // 2, rgba.height // 2))[3]
    # The keyed backdrop is now transparent (corner pixel fully see-through).
    assert corner_alpha == 0
    # The pet blob in the center is still opaque.
    assert center_alpha > 0
|
||||
|
||||
|
||||
def test_hatch_pet_end_to_end(monkeypatch, tmp_path):
    """Hatching with stubbed row generation validates, fills every state, and lands on disk."""
    from agent.pet import store
    from agent.pet.generate import atlas as atlas_mod
    from agent.pet.generate import imagegen, orchestrate

    base = tmp_path / "base.png"
    _strip(1).save(base)

    def fake_generate(prompt, *, n=1, reference_images=None, provider=None, prefix="pet"):
        # Return a synthetic row strip; frame count is inferable from the spec.
        state_name = prefix.replace("pet_row_", "")
        frame_count = atlas_mod.FRAME_COUNTS.get(state_name, 6)
        target = tmp_path / f"{prefix}.png"
        _strip(frame_count).save(target)
        return [target]

    monkeypatch.setattr(imagegen, "resolve_provider", lambda **_: object())
    monkeypatch.setattr(imagegen, "generate", fake_generate)

    events: list[tuple[str, str]] = []

    def record(ev, detail):
        events.append((ev, detail))

    result = orchestrate.hatch_pet(
        base_image=base,
        slug="mocky",
        display_name="Mocky",
        description="a test pet",
        concept="a fox",
        on_progress=record,
    )

    assert result.slug == "mocky"
    assert result.validation["ok"]
    every_state = {state for state, _, _ in atlas_mod.ROW_SPECS}
    assert set(result.states) == every_state
    assert ("compose", "") in events
    # The pet is on disk and adoptable.
    assert store.load_pet("mocky").exists
|
||||
|
||||
|
||||
def test_hatch_pet_idle_fallback_when_row_fails(monkeypatch, tmp_path):
    """When generating the idle row raises, ``idle`` still appears in the result states."""
    from agent.pet.generate import atlas as atlas_mod
    from agent.pet.generate import imagegen, orchestrate
    from agent.pet.generate.imagegen import GenerationError

    base = tmp_path / "base.png"
    _strip(1).save(base)

    def fake_generate(prompt, *, n=1, reference_images=None, provider=None, prefix="pet"):
        # Only the idle row blows up; every other row gets a synthetic strip.
        if prefix == "pet_row_idle":
            raise GenerationError("boom")
        state_name = prefix.replace("pet_row_", "")
        frame_count = atlas_mod.FRAME_COUNTS.get(state_name, 6)
        target = tmp_path / f"{prefix}.png"
        _strip(frame_count).save(target)
        return [target]

    monkeypatch.setattr(imagegen, "resolve_provider", lambda **_: object())
    monkeypatch.setattr(imagegen, "generate", fake_generate)

    result = orchestrate.hatch_pet(base_image=base, slug="fallbacky", concept="a fox")
    assert "idle" in result.states  # filled by the base-image fallback
|
||||
|
||||
|
||||
def test_hatch_pet_rejects_missing_required_animation_rows(monkeypatch, tmp_path):
    """A failing ``running-right`` row makes ``hatch_pet`` raise, naming that row."""
    from agent.pet.generate import atlas as atlas_mod
    from agent.pet.generate import imagegen, orchestrate
    from agent.pet.generate.imagegen import GenerationError

    base = tmp_path / "base.png"
    _strip(1).save(base)

    def fake_generate(prompt, *, n=1, reference_images=None, provider=None, prefix="pet"):
        # Only the running-right row blows up; other rows get a synthetic strip.
        if prefix == "pet_row_running-right":
            raise GenerationError("bad row")
        state_name = prefix.replace("pet_row_", "")
        frame_count = atlas_mod.FRAME_COUNTS.get(state_name, 6)
        target = tmp_path / f"{prefix}.png"
        _strip(frame_count).save(target)
        return [target]

    monkeypatch.setattr(imagegen, "resolve_provider", lambda **_: object())
    monkeypatch.setattr(imagegen, "generate", fake_generate)

    with pytest.raises(GenerationError, match="running-right"):
        orchestrate.hatch_pet(base_image=base, slug="broken", concept="a fox")
|
||||
|
||||
|
||||
def test_resolve_provider_errors_without_backend(monkeypatch):
    """With every provider lookup stubbed to ``None``, resolution raises ``GenerationError``."""
    from agent.pet.generate import imagegen

    def no_active_provider():
        return None

    def no_named_provider(name):
        return None

    monkeypatch.setattr(imagegen, "_discover", lambda: None)
    monkeypatch.setattr("agent.image_gen_registry.get_active_provider", no_active_provider)
    monkeypatch.setattr("agent.image_gen_registry.get_provider", no_named_provider)

    with pytest.raises(imagegen.GenerationError):
        imagegen.resolve_provider(require_references=True)
|
||||
|
||||
|
||||
def test_generate_retries_without_transparent_background(monkeypatch, tmp_path):
    """A model that rejects background=transparent still produces images."""
    from agent.pet.generate import imagegen

    saved = tmp_path / "img.png"
    _strip(1).save(saved)
    calls: list[dict] = []

    class FakeProvider:
        def generate(self, prompt, **kwargs):
            # Record every call; refuse only the transparent-background request.
            calls.append(kwargs)
            if kwargs.get("background") != "transparent":
                return {"success": True, "image": str(saved)}
            return {"success": False, "error": "Transparent background is not supported for this model."}

    sprite = imagegen.SpriteProvider(name="openai", provider=FakeProvider(), supports_references=False)

    out = imagegen.generate("a fox", n=2, provider=sprite)
    assert len(out) == 2
    # First variant probes transparent (rejected) then retries opaque; the second
    # variant skips the transparent probe entirely.
    backgrounds = [call.get("background") for call in calls]
    assert backgrounds == ["transparent", None, None]
|
||||
Loading…
Add table
Add a link
Reference in a new issue