import type { Task, Helpers } from "graphile-worker";
import { PrismaClient } from "../../generated/prisma";
import { withAccelerate } from "@prisma/extension-accelerate";
import { getOrDownloadAsset } from "../utils/cache";
import { env } from "../config/env";
import { getS3Client, uploadFile } from "../utils/s3";
import { nanoid } from "nanoid";
import { existsSync, rmSync } from "node:fs";
import { join, basename, extname, resolve } from "node:path";
import { readFile, writeFile, readdir } from 'node:fs/promises';
import yaml from 'js-yaml';
import { getNanoSpawn } from "../utils/nanoSpawn";
import which from "which";
// NOTE(review): `ort` is not referenced anywhere in this file.
import * as ort from 'onnxruntime-node';

/** Job payload accepted by the `createFunscript` task. */
interface Payload {
  vodId: string;
}

/** An inclusive range of frames attributed to a single detected class. */
interface Detection {
  startFrame: number;
  endFrame: number;
  className: string;
}

/** The fields of the YOLO dataset config (data.yaml) that this module reads. */
interface DataYaml {
  path: string;
  train: string;
  val: string;
  /** Class index -> class name. */
  names: Record<number, string>;
}

/** One funscript point: `at` is a timestamp in ms, `pos` a position (kept within 0..100 here). */
interface FunscriptAction {
  at: number;
  pos: number;
}

interface Funscript {
  version: string;
  actions: FunscriptAction[];
}

/** Class name -> fixed position, or 'pattern' for classes rendered by generatePatternPositions(). */
interface ClassPositionMap {
  [className: string]: number | 'pattern';
}

/** Spacing (ms) between generated funscript actions. */
const ACTION_INTERVAL_MS = 100;

/** Label lines below this confidence are ignored. */
const MIN_CONFIDENCE = 0.7;

/** Bounds applied to every generated position. */
const MIN_POS = 0;
const MAX_POS = 100;

/** Captures the trailing number of a label filename, e.g. `clip_123.txt` -> "123". */
const LABEL_FRAME_RE = /(\d+)\.txt$/;

/** Positions for known classes; any other class name found in data.yaml defaults to 0. */
const BASE_CLASS_POSITIONS: ClassPositionMap = {
  ControlledByTipper: 50,
  ControlledByTipperHigh: 80,
  ControlledByTipperLow: 20,
  ControlledByTipperMedium: 50,
  ControlledByTipperUltrahigh: 95,
  Ring1: 30,
  Ring2: 40,
  Ring3: 50,
  Ring4: 60,
  Earthquake: 'pattern',
  Fireworks: 'pattern',
  Pulse: 'pattern',
  Wave: 'pattern',
  Pause: 0,
  RandomTime: 70,
  HighLevel: 80,
  LowLevel: 20,
  MediumLevel: 50,
  UltraHighLevel: 95,
};

const prisma = new PrismaClient().$extends(withAccelerate());

/** Formats an unknown thrown value for log output. */
function describeError(error: unknown): string {
  return error instanceof Error ? error.message : 'Unknown error';
}

/** Clamps a position into [MIN_POS, MAX_POS]. */
function clampPosition(pos: number): number {
  return Math.max(MIN_POS, Math.min(MAX_POS, pos));
}

/** Narrows an untyped job payload to `Payload`, throwing on anything else. */
function assertPayload(payload: unknown): asserts payload is Payload {
  if (typeof payload !== "object" || !payload) throw new Error("invalid payload-- was not an object.");
  if (typeof (payload as { vodId?: unknown }).vodId !== "string") throw new Error("invalid payload-- was missing vodId");
}

/**
 * Reads and parses a data.yaml file.
 * The result is cast, not validated; loadClassPositionMap() checks the `names` field.
 */
async function loadDataYaml(yamlPath: string): Promise<DataYaml> {
  const yamlContent = await readFile(yamlPath, 'utf8');
  return yaml.load(yamlContent) as DataYaml;
}

/**
 * Ensures a Python venv exists at `<VIBEUI_DIR>/venv`, creating it with `python3 -m venv` if missing.
 * A partially created venv is removed when creation fails.
 *
 * NOTE(review): this only creates the venv; nothing here installs the `yolo` CLI that
 * inference() runs from ./venv/bin — confirm dependencies are installed elsewhere.
 *
 * @throws if python3 is not on PATH or venv creation fails.
 */
async function preparePython(helpers: Helpers): Promise<void> {
  const spawn = await getNanoSpawn();
  const venvPath = join(env.VIBEUI_DIR, "venv");

  // Determine Python executable
  let pythonCmd: string;
  try {
    pythonCmd = which.sync("python3");
  } catch {
    helpers.logger.error("Python is not installed or not in PATH.");
    throw new Error("Python not found in PATH.");
  }

  if (existsSync(venvPath)) {
    helpers.logger.info("Using existing Python venv.");
    return;
  }

  helpers.logger.info("Python venv not found. Creating one...");
  try {
    await spawn(pythonCmd, ["-m", "venv", "venv"], {
      cwd: env.VIBEUI_DIR,
    });
    helpers.logger.info("Python venv successfully created.");
  } catch (err) {
    helpers.logger.error(`Failed to create Python venv: ${describeError(err)}`);

    // Clean up partially created venv if needed
    try {
      if (existsSync(venvPath)) {
        rmSync(venvPath, { recursive: true, force: true });
        helpers.logger.warn("Removed broken venv directory.");
      }
    } catch (cleanupErr) {
      helpers.logger.error(`Error while cleaning up broken venv: ${describeError(cleanupErr)}`);
    }

    throw new Error("Python venv creation failed. Check if python3 and python3-venv are installed.");
  }
}

/**
 * Reads the frame rate and decoded frame count of the first video stream.
 * `-count_frames` makes ffprobe decode the stream to count frames.
 *
 * @throws if ffprobe fails, or reports a frame rate or frame count that is not a usable number.
 */
async function ffprobe(videoPath: string): Promise<{ fps: number; frames: number }> {
  const spawn = await getNanoSpawn();
  const { stdout } = await spawn('ffprobe', [
    '-v', 'error',
    '-select_streams', 'v:0',
    '-count_frames',
    '-show_entries', 'stream=nb_read_frames,r_frame_rate',
    // key=value lines, so fields are looked up by name instead of relying on output order.
    '-of', 'default=noprint_wrappers=1',
    videoPath,
  ]);

  const fields = new Map<string, string>();
  for (const line of stdout.split(/\r?\n/)) {
    const eq = line.indexOf('=');
    if (eq > 0) fields.set(line.slice(0, eq).trim(), line.slice(eq + 1).trim());
  }

  // r_frame_rate is a fraction such as "30000/1001"; a bare number is treated as "<n>/1".
  const rate = fields.get('r_frame_rate') ?? '';
  const [numStr = '', denomStr = '1'] = rate.split('/');
  const fps = Number(numStr) / Number(denomStr);
  const frames = Number.parseInt(fields.get('nb_read_frames') ?? '', 10);

  if (!Number.isFinite(fps) || fps <= 0) {
    throw new Error(`ffprobe reported an unusable frame rate "${rate}" for ${videoPath}`);
  }
  if (!Number.isFinite(frames) || frames < 0) {
    throw new Error(`ffprobe reported an unusable frame count for ${videoPath}`);
  }

  return { fps, frames };
}

/**
 * Builds a funscript file from a YOLO prediction run.
 *
 * @param helpers - graphile-worker helpers (used for logging).
 * @param predictionOutput - directory returned by inference(); its `labels/` subfolder is read.
 * @param videoPath - the video the predictions were made on; probed for fps and frame count.
 * @returns path of the written `.funscript` file under CACHE_ROOT (default `/tmp`).
 * @throws rethrows any failure after logging it.
 */
export async function buildFunscript(
  helpers: Helpers,
  predictionOutput: string,
  videoPath: string
): Promise<string> {
  const labelDir = join(predictionOutput, 'labels');
  const outputPath = join(process.env.CACHE_ROOT ?? '/tmp', `${nanoid()}.funscript`);

  helpers.logger.info('Starting Funscript generation');

  try {
    // Class names are read from VIBEUI_DIR/data.yaml, not from the prediction output folder.
    const data = await loadDataYaml(join(env.VIBEUI_DIR, 'data.yaml'));
    const classPositionMap = await loadClassPositionMap(data, helpers);
    const { fps, totalFrames } = await loadVideoMetadata(videoPath, helpers);
    const detectionSegments = await processLabelFiles(labelDir, helpers, data);

    const totalDurationMs = Math.floor((totalFrames / fps) * 1000);
    const actions = generateActions(totalDurationMs, fps, detectionSegments, classPositionMap);

    await writeFunscript(outputPath, actions, helpers);
    return outputPath;
  } catch (error) {
    helpers.logger.error(`Error generating Funscript: ${describeError(error)}`);
    throw error;
  }
}

/**
 * Runs `yolo predict` (from VIBEUI_DIR's venv) on a video, saving per-frame label files
 * with confidences.
 *
 * @returns `<VIBEUI_DIR>/runs/<videoName>-<nanoid>`, the directory containing the `labels/` folder.
 */
export async function inference(helpers: Helpers, videoFilePath: string): Promise<string> {
  const spawn = await getNanoSpawn();

  // yolo is spawned with cwd = VIBEUI_DIR, so hand it absolute paths; a relative path
  // would otherwise be resolved against the wrong directory.
  const modelPath = resolve(env.VIBEUI_DIR, 'runs/detect/vibeui/weights/best.pt');
  const sourcePath = resolve(videoFilePath);

  // Unique run name = video name (without extension) + nanoid, so runs never collide.
  const videoExt = extname(videoFilePath);
  const videoName = basename(videoFilePath, videoExt);
  const uniqueName = `${videoName}-${nanoid()}`;

  const customProjectDir = 'runs'; // relative to cwd (VIBEUI_DIR)
  const outputPath = join(env.VIBEUI_DIR, customProjectDir, uniqueName);

  helpers.logger.info(`Running YOLO inference on ${sourcePath} (output: ${outputPath})`);

  await spawn('./venv/bin/yolo', [
    'predict',
    `model=${modelPath}`,
    `source=${sourcePath}`,
    'save=False',
    'save_txt=True',
    'save_conf=True',
    `project=${customProjectDir}`,
    `name=${uniqueName}`,
  ], {
    cwd: env.VIBEUI_DIR,
    stdio: 'inherit',
  });

  return outputPath;
}

/**
 * Builds the class -> position map: BASE_CLASS_POSITIONS plus a 0 entry for every class
 * named in data.yaml that has no explicit mapping.
 *
 * @throws if `names` is missing, not an object, or empty (logged, then rethrown).
 */
async function loadClassPositionMap(data: DataYaml, helpers: Helpers): Promise<ClassPositionMap> {
  try {
    if (
      !data ||
      typeof data !== 'object' ||
      !('names' in data) ||
      typeof data.names !== 'object' ||
      data.names === null ||
      Object.keys(data.names).length === 0
    ) {
      throw new Error('Invalid data.yaml: "names" field is missing, not an object, or empty');
    }

    // Fresh copy so callers can never mutate the shared defaults.
    const positionMap: ClassPositionMap = { ...BASE_CLASS_POSITIONS };

    for (const name of Object.values(data.names) as unknown[]) {
      if (typeof name !== 'string' || name.trim() === '') {
        helpers.logger.info(`Skipping invalid class name: ${String(name)}`);
        continue;
      }
      // Own-property check: `in` would also match inherited keys like "toString".
      if (!Object.prototype.hasOwnProperty.call(positionMap, name)) {
        helpers.logger.info(`No position mapping for class "${name}", defaulting to 0`);
        positionMap[name] = 0;
      }
    }

    helpers.logger.info(`Loaded class position map: ${JSON.stringify(positionMap)}`);
    return positionMap;
  } catch (error) {
    helpers.logger.error(`Error loading data.yaml: ${describeError(error)}`);
    throw error;
  }
}

/**
 * Generates pattern actions every ACTION_INTERVAL_MS over [startMs, startMs + durationMs).
 * Unrecognized class names yield position 0. Every position is clamped to 0..100; for
 * 'Pulse' this turns the negative half of each sine cycle into a rest at 0.
 * 'Fireworks' and 'Earthquake' use Math.random(), so their output is non-deterministic.
 *
 * @param fps - currently unused; kept for call-site compatibility.
 */
function generatePatternPositions(startMs: number, durationMs: number, className: string, fps: number): FunscriptAction[] {
  const actions: FunscriptAction[] = [];

  for (let timeMs = 0; timeMs < durationMs; timeMs += ACTION_INTERVAL_MS) {
    const progress = timeMs / durationMs;
    let pos = 0;

    switch (className) {
      case 'Pulse':
        pos = Math.round(50 * Math.sin(progress * 2 * Math.PI));
        break;
      case 'Wave':
        pos = Math.round(50 + 50 * Math.sin(progress * 2 * Math.PI));
        break;
      case 'Fireworks':
        pos = Math.random() > 0.5 ? 80 : 0;
        break;
      case 'Earthquake':
        pos = Math.round(90 * Math.sin(progress * 4 * Math.PI) + (Math.random() - 0.5) * 10);
        pos = Math.max(0, Math.min(90, pos));
        break;
    }

    actions.push({ at: startMs + timeMs, pos: clampPosition(pos) });
  }

  return actions;
}

/** Probes the video with ffprobe() and logs the result. */
async function loadVideoMetadata(videoPath: string, helpers: Helpers) {
  const { fps, frames: totalFrames } = await ffprobe(videoPath);
  helpers.logger.info(`Video metadata: fps=${fps}, frames=${totalFrames}`);
  return { fps, totalFrames };
}

/**
 * Returns the class index of the highest-confidence line in a label file whose confidence is
 * at least `minConfidence`, or undefined if no line qualifies. Ties keep the earlier line.
 * A line needs at least 6 whitespace-separated fields: field 0 is the class index and
 * field 5 the confidence.
 */
function pickBestClassIndex(content: string, minConfidence: number): number | undefined {
  let bestIndex: number | undefined;
  let bestConfidence = -Infinity;

  for (const line of content.trim().split('\n')) {
    const parts = line.trim().split(/\s+/);
    if (parts.length < 6) continue;

    const classIndex = parseInt(parts[0] ?? '', 10);
    const confidence = parseFloat(parts[5] ?? '');
    if (Number.isNaN(classIndex) || Number.isNaN(confidence)) continue;

    if (confidence >= minConfidence && confidence > bestConfidence) {
      bestIndex = classIndex;
      bestConfidence = confidence;
    }
  }

  return bestIndex;
}

/**
 * Reads every `*.txt` label file in `labelDir`, keeps one class per frame (see
 * pickBestClassIndex), and merges runs of the same class into segments.
 *
 * @returns segments in ascending frame order; they do not overlap.
 */
async function processLabelFiles(labelDir: string, helpers: Helpers, data: DataYaml): Promise<Detection[]> {
  const labelFiles = (await readdir(labelDir)).filter(file => file.endsWith('.txt'));
  const classByFrame = new Map<number, string>();

  for (const file of labelFiles) {
    const match = LABEL_FRAME_RE.exec(file);
    if (!match) {
      helpers.logger.info(`Skipping invalid filename: ${file}`);
      continue;
    }

    const frameIndex = parseInt(match[1] ?? '', 10);
    if (Number.isNaN(frameIndex)) {
      helpers.logger.info(`Skipping invalid frame index from filename: ${file}`);
      continue;
    }

    const content = await readFile(join(labelDir, file), 'utf8');
    const classIndex = pickBestClassIndex(content, MIN_CONFIDENCE);
    if (classIndex === undefined) continue;

    const className = data.names[classIndex];
    if (className) {
      classByFrame.set(frameIndex, className);
    }
  }

  // readdir() returns entries in no guaranteed order (and a lexical order would put
  // "x_10.txt" before "x_2.txt"); the merge below needs ascending frame order.
  const orderedFrames = [...classByFrame.entries()].sort(([a], [b]) => a - b);

  // Merge consecutive same-class frames into segments.
  // NOTE(review): frames without a label file do not end a segment, so a gap between two
  // detections of the same class is filled with that class — confirm this is intended.
  const detectionSegments: Detection[] = [];
  let current: Detection | null = null;

  for (const [frameIndex, className] of orderedFrames) {
    if (current && current.className === className) {
      current.endFrame = frameIndex;
    } else {
      if (current) detectionSegments.push(current);
      current = { startFrame: frameIndex, endFrame: frameIndex, className };
    }
  }
  if (current) detectionSegments.push(current);

  return detectionSegments;
}

/**
 * Produces the funscript actions for the whole video.
 *
 * Static positions are emitted every ACTION_INTERVAL_MS from 0 to totalDurationMs. A tick
 * is assigned the position of the first segment containing its frame, or 0 if none does.
 * Ticks inside a 'pattern' segment are skipped: the pattern actions generated below cover
 * that span. (Emitting a 0 there would interleave with the pattern, or displace it at equal
 * timestamps.) The result is sorted by time with one action per timestamp.
 *
 * NOTE(review): frameIndex here is 0-based; confirm label filenames use the same base.
 */
function generateActions(totalDurationMs: number, fps: number, detectionSegments: Detection[], classPositionMap: ClassPositionMap): FunscriptAction[] {
  const actions: FunscriptAction[] = [];

  for (let timeMs = 0; timeMs <= totalDurationMs; timeMs += ACTION_INTERVAL_MS) {
    const frameIndex = Math.floor((timeMs / 1000) * fps);
    const segment = detectionSegments.find(
      (s) => frameIndex >= s.startFrame && frameIndex <= s.endFrame,
    );
    const mapped = segment ? classPositionMap[segment.className] : undefined;

    if (mapped === 'pattern') continue;
    actions.push({ at: timeMs, pos: typeof mapped === 'number' ? mapped : 0 });
  }

  for (const segment of detectionSegments) {
    if (classPositionMap[segment.className] !== 'pattern') continue;

    const startMs = Math.floor((segment.startFrame / fps) * 1000);
    const durationMs = Math.floor(((segment.endFrame - segment.startFrame + 1) / fps) * 1000);
    // Push one at a time: spreading a very long pattern into push() can exceed the
    // engine's argument-count limit.
    for (const action of generatePatternPositions(startMs, durationMs, segment.className, fps)) {
      actions.push(action);
    }
  }

  // Array.prototype.sort is stable, so at equal timestamps the earlier-pushed action is kept.
  actions.sort((a, b) => a.at - b.at);
  const uniqueActions: FunscriptAction[] = [];
  let lastTime = -1;
  for (const action of actions) {
    if (action.at !== lastTime) {
      uniqueActions.push(action);
      lastTime = action.at;
    }
  }

  return uniqueActions;
}

/** Writes `{ version: '1.0', actions }` as pretty-printed JSON to `outputPath`. */
async function writeFunscript(outputPath: string, actions: FunscriptAction[], helpers: Helpers) {
  const funscript: Funscript = {
    version: '1.0',
    actions,
  };

  await writeFile(outputPath, JSON.stringify(funscript, null, 2));
  helpers.logger.info(`Funscript generated: ${outputPath} (${actions.length} actions)`);
}

/** Best-effort recursive delete; failures are logged, never thrown. */
function removeQuietly(path: string, helpers: Helpers): void {
  try {
    rmSync(path, { recursive: true, force: true });
  } catch (error) {
    helpers.logger.warn(`Failed to clean up ${path}: ${describeError(error)}`);
  }
}

/**
 * graphile-worker task: generates a funscript for a VOD and stores its S3 key on the VOD row.
 *
 * It returns early if the VOD already has a funscript, and throws if the VOD has no source
 * video. Scratch outputs of the run (the temporary funscript file and the YOLO prediction
 * directory) are removed afterwards, whether the run succeeds or fails.
 */
const createFunscript: Task = async (payload: unknown, helpers: Helpers) => {
  assertPayload(payload);
  const { vodId } = payload;
  const vod = await prisma.vod.findFirstOrThrow({ where: { id: vodId } });

  if (vod.funscript) {
    helpers.logger.info(`Doing nothing-- vod ${vodId} already has a funscript.`);
    return;
  }

  if (!vod.sourceVideo) {
    const msg = `Cannot create funscript: Vod ${vodId} is missing a source video.`;
    helpers.logger.warn(msg);
    throw new Error(msg);
  }

  // Python setup is only needed once we know inference will actually run.
  await preparePython(helpers);

  const s3Client = getS3Client();
  const videoFilePath = await getOrDownloadAsset(s3Client, env.S3_BUCKET, vod.sourceVideo);
  helpers.logger.info(`Downloaded video to ${videoFilePath}`);

  helpers.logger.info(`Creating funscript for vod ${vodId}...`);

  let predictionOutput: string | undefined;
  let funscriptFilePath: string | undefined;
  try {
    predictionOutput = await inference(helpers, videoFilePath);
    helpers.logger.info(`prediction output ${predictionOutput}`);

    funscriptFilePath = await buildFunscript(helpers, predictionOutput, videoFilePath);

    const s3Key = `funscripts/${vodId}.funscript`;
    const s3Url = await uploadFile(s3Client, env.S3_BUCKET, s3Key, funscriptFilePath, "application/json");
    helpers.logger.info(`Uploaded funscript to S3: ${s3Url}`);

    await prisma.vod.update({
      where: { id: vodId },
      data: { funscript: s3Key },
    });

    helpers.logger.info(`Funscript saved to database for vod ${vodId}`);
  } finally {
    // Both paths are unique to this run; the downloaded source video is left untouched.
    if (funscriptFilePath) removeQuietly(funscriptFilePath, helpers);
    if (predictionOutput) removeQuietly(predictionOutput, helpers);
  }
};

export default createFunscript;