fp/services/our/src/tasks/createFunscript.ts
CJ_Clippy c386e48dcf
Some checks failed
ci / build (push) Failing after 1s
ci / Tests & Checks (push) Failing after 1s
add buttplug vjs plugin
2025-07-13 01:04:45 -08:00

418 lines
13 KiB
TypeScript

import type { Task, Helpers } from "graphile-worker";
import { PrismaClient } from "../../generated/prisma";
import { withAccelerate } from "@prisma/extension-accelerate";
import { getOrDownloadAsset } from "../utils/cache";
import { env } from "../config/env";
import { getS3Client, uploadFile } from "../utils/s3";
import { nanoid } from "nanoid";
import { existsSync } from "node:fs";
import spawn from 'nano-spawn';
import { join, basename, extname } from "node:path";
import { readFile, writeFile, readdir } from 'node:fs/promises';
import yaml from 'js-yaml';
import { string } from "zod";
// Prisma client with the Accelerate extension; created once, when this module is loaded.
const prisma = new PrismaClient().$extends(withAccelerate());
/** Job payload for the `createFunscript` task; checked at runtime by `assertPayload`. */
interface Payload {
// ID of the vod row to generate a funscript for.
vodId: string;
}
/**
 * A run of frames attributed to a single detected class (built by `processLabelFiles`).
 * For a one-frame detection, startFrame === endFrame.
 */
interface Detection {
startFrame: number;
endFrame: number;
className: string;
}
/**
 * Shape assumed for the YAML file read by `loadDataYaml`. The parse result is cast, not validated.
 * NOTE(review): path/train/val look like Ultralytics dataset-config keys; they are not read in this file.
 */
interface DataYaml {
path: string;
train: string;
val: string;
// Class index (as a string key) -> class name; `processLabelFiles` looks it up via `selectedClassIndex.toString()`.
names: Record<string, string>;
}
/** One funscript point: `at` is a timestamp in milliseconds, `pos` a position (funscript convention: 0-100). */
interface FunscriptAction {
at: number;
pos: number;
}
/** Top-level JSON document written by `writeFunscript`. */
interface Funscript {
version: string;
actions: FunscriptAction[];
}
/**
 * Class name -> fixed position, or 'pattern' for classes whose positions are produced by
 * `generatePatternPositions` instead of a constant.
 */
interface ClassPositionMap {
[className: string]: number | 'pattern';
}
/**
 * Runtime guard for the task payload.
 *
 * @throws Error when `payload` is not a non-null object, or when its `vodId` is not a string.
 */
function assertPayload(payload: any): asserts payload is Payload {
  const isNonNullObject = typeof payload === "object" && payload !== null;
  if (!isNonNullObject) {
    throw new Error("invalid payload-- was not an object.");
  }
  const hasVodId = typeof payload.vodId === "string";
  if (!hasVodId) {
    throw new Error("invalid payload-- was missing vodId");
  }
}
/**
 * Read a UTF-8 YAML file and parse it.
 *
 * NOTE(review): the parsed value is cast to DataYaml, not validated; `loadClassPositionMap`
 * performs the checks on `names`.
 *
 * @param yamlPath Path of the YAML file to read.
 */
async function loadDataYaml(yamlPath: string): Promise<DataYaml> {
  const parsed = yaml.load(await readFile(yamlPath, 'utf8'));
  return parsed as DataYaml;
}
/**
 * Ensure a Python virtual environment exists at `<VIBEUI_DIR>/venv`, creating it when missing.
 *
 * Only the venv itself is created; no packages are installed here.
 * NOTE(review): `inference` later runs `./venv/bin/yolo`, so the venv is presumably provisioned
 * with Ultralytics elsewhere. Confirm.
 *
 * @param helpers Task helpers; only `logger.info` is used. It is typed structurally (instead of the
 *   previously implicit `any`), so any graphile-worker `Helpers` object is accepted.
 */
async function preparePython(helpers: { logger: { info: (message: string) => void } }): Promise<void> {
  const venvPath = join(env.VIBEUI_DIR, 'venv');
  if (existsSync(venvPath)) {
    helpers.logger.info("Using existing Python venv.");
    return;
  }
  helpers.logger.info("Python venv not found. Creating one...");
  // The relative "venv" with cwd=VIBEUI_DIR targets the same directory as venvPath,
  // even when VIBEUI_DIR itself is a relative path.
  await spawn("python", ["-m", "venv", "venv"], {
    cwd: env.VIBEUI_DIR,
  });
}
/**
 * Probe the first video stream of `videoPath` for its frame rate and frame count.
 *
 * `-count_frames` makes ffprobe read the whole stream to fill `nb_read_frames`, so this can be
 * slow on long videos, but the count is measured rather than taken from container metadata.
 *
 * @param videoPath Local path of the video file.
 * @returns `fps` (r_frame_rate as a number, e.g. 30000/1001) and `frames` (nb_read_frames).
 * @throws Error when ffprobe fails, or when its output does not yield a positive finite fps and
 *   a non-negative integer frame count. Previously such output silently became NaN, which led to
 *   an empty funscript.
 */
async function ffprobe(videoPath: string): Promise<{ fps: number; frames: number }> {
  const { stdout } = await spawn('ffprobe', [
    '-v', 'error',
    '-select_streams', 'v:0',
    '-count_frames',
    '-show_entries', 'stream=nb_read_frames,r_frame_rate',
    '-of', 'default=nokey=1:noprint_wrappers=1',
    videoPath,
  ]);
  // NOTE: relies on ffprobe printing r_frame_rate before nb_read_frames (its own field order;
  // -show_entries only selects fields). Blank lines and CRLF endings are tolerated.
  const lines = stdout
    .split(/\r?\n/)
    .map((line) => line.trim())
    .filter((line) => line.length > 0);
  const [frameRateStr = '', frameCountStr = ''] = lines;

  // r_frame_rate is a rational such as "30000/1001"; a bare number is treated as "<n>/1".
  const [num, denom = 1] = frameRateStr.split('/').map(Number);
  const fps = num / denom;
  if (!Number.isFinite(fps) || fps <= 0) {
    throw new Error(`ffprobe returned an invalid frame rate "${frameRateStr}" for ${videoPath}`);
  }

  // nb_read_frames can be "N/A" when ffprobe could not count frames.
  const frames = parseInt(frameCountStr, 10);
  if (!Number.isInteger(frames) || frames < 0) {
    throw new Error(`ffprobe returned an invalid frame count "${frameCountStr}" for ${videoPath}`);
  }

  return { fps, frames };
}
/**
 * Build a .funscript file from a YOLO prediction run over `videoPath`.
 *
 * Steps:
 * 1. Load class names from `<VIBEUI_DIR>/data.yaml` and map them to positions.
 * 2. Probe the video for fps and frame count, and collapse per-frame label files into segments.
 * 3. Sample the segments into actions and write them as JSON.
 *
 * @param helpers graphile-worker helpers, used for logging.
 * @param predictionOutput Directory returned by `inference`; label files are read from its `labels/` subfolder.
 * @param videoPath The video the predictions were made on.
 * @returns Path of the written file: `<CACHE_ROOT or /tmp>/<nanoid>.funscript`.
 * @throws Whatever any step throws; the error is logged first and then rethrown.
 */
export async function buildFunscript(
  helpers: Helpers,
  predictionOutput: string,
  videoPath: string
): Promise<string> {
  const labelDir = join(predictionOutput, 'labels');
  const outputPath = join(process.env.CACHE_ROOT ?? '/tmp', `${nanoid()}.funscript`);
  helpers.logger.info('Starting Funscript generation');
  try {
    // Class names come from the model's config in VIBEUI_DIR, not from the prediction run directory.
    const data = await loadDataYaml(join(env.VIBEUI_DIR, 'data.yaml'));
    const classPositionMap = await loadClassPositionMap(data, helpers);
    // Probing the video (ffprobe with -count_frames, which can be slow on long videos) and parsing
    // the label files are independent of each other, so run them concurrently.
    const [{ fps, totalFrames }, detectionSegments] = await Promise.all([
      loadVideoMetadata(videoPath, helpers),
      processLabelFiles(labelDir, helpers, data),
    ]);
    const totalDurationMs = Math.floor((totalFrames / fps) * 1000);
    const actions = generateActions(totalDurationMs, fps, detectionSegments, classPositionMap);
    await writeFunscript(outputPath, actions, helpers);
    return outputPath;
  } catch (error) {
    helpers.logger.error(`Error generating Funscript: ${error instanceof Error ? error.message : 'Unknown error'}`);
    throw error;
  }
}
/**
 * Run the trained YOLO model (`runs/detect/vibeui/weights/best.pt` under VIBEUI_DIR) over a video.
 *
 * Invokes `./venv/bin/yolo predict` with cwd=VIBEUI_DIR. It writes per-frame label text files
 * including confidences (`save_txt`, `save_conf`) but no annotated media (`save=False`). Output goes
 * to a fresh run directory named `<video stem>-<nanoid>` under `runs/`. The child's stdio is
 * inherited, so YOLO's progress output goes to this process's console.
 *
 * @param helpers Unused; kept for call-site compatibility.
 * @param videoFilePath Video to run inference on.
 * @returns `<VIBEUI_DIR>/runs/<video stem>-<nanoid>`, the directory `buildFunscript` reads `labels/` from.
 */
export async function inference(helpers: Helpers, videoFilePath: string): Promise<string> {
  const projectDir = 'runs';
  const stem = basename(videoFilePath, extname(videoFilePath));
  const runName = `${stem}-${nanoid()}`;
  const weightsPath = join(env.VIBEUI_DIR, 'runs/detect/vibeui/weights/best.pt');

  const yoloArgs = [
    'predict',
    `model=${weightsPath}`,
    `source=${videoFilePath}`,
    'save=False',
    'save_txt=True',
    'save_conf=True',
    `project=${projectDir}`,
    `name=${runName}`,
  ];
  await spawn('./venv/bin/yolo', yoloArgs, {
    cwd: env.VIBEUI_DIR,
    stdio: 'inherit',
  });

  return join(env.VIBEUI_DIR, projectDir, runName);
}
/**
 * Build the class-name -> position map used to turn detections into funscript positions.
 *
 * Starts from a built-in table of known classes; numbers are fixed positions, and 'pattern'
 * marks classes rendered by `generatePatternPositions`. Every non-empty class name in
 * `data.names` that is missing from the table is added with position 0. Empty or non-string
 * names are skipped with a log line.
 *
 * @throws Error when `data.names` is missing, not an object, or empty (logged, then rethrown).
 */
async function loadClassPositionMap(data: DataYaml, helpers: Helpers): Promise<ClassPositionMap> {
  try {
    const hasUsableNames =
      !!data &&
      typeof data === 'object' &&
      'names' in data &&
      typeof data.names === 'object' &&
      data.names !== null &&
      Object.keys(data.names).length !== 0;
    if (!hasUsableNames) {
      throw new Error('Invalid data.yaml: "names" field is missing, not an object, or empty');
    }

    const positionMap: ClassPositionMap = {
      ControlledByTipper: 50,
      ControlledByTipperHigh: 80,
      ControlledByTipperLow: 20,
      ControlledByTipperMedium: 50,
      ControlledByTipperUltrahigh: 95,
      Ring1: 30,
      Ring2: 40,
      Ring3: 50,
      Ring4: 60,
      Earthquake: 'pattern',
      Fireworks: 'pattern',
      Pulse: 'pattern',
      Wave: 'pattern',
      Pause: 0,
      RandomTime: 70,
      HighLevel: 80,
      LowLevel: 20,
      MediumLevel: 50,
      UltraHighLevel: 95
    };

    for (const className of Object.values(data.names)) {
      const isUsableName = typeof className === 'string' && className.trim() !== '';
      if (!isUsableName) {
        helpers.logger.info(`Skipping invalid class name: ${className}`);
      } else if (!(className in positionMap)) {
        helpers.logger.info(`No position mapping for class "${className}", defaulting to 0`);
        positionMap[className] = 0;
      }
    }

    helpers.logger.info(`Loaded class position map: ${JSON.stringify(positionMap)}`);
    return positionMap;
  } catch (error) {
    const reason = error instanceof Error ? error.message : 'Unknown error';
    helpers.logger.error(`Error loading data.yaml: ${reason}`);
    throw error;
  }
}
/**
 * Generate actions for a pattern-type class, one every 100 ms over [startMs, startMs + durationMs).
 *
 * With progress p running from 0 to 1 across the segment:
 * - Pulse: 50·sin(2πp), with the negative half clamped to 0 (it previously produced out-of-range
 *   negative positions)
 * - Wave: 50 + 50·sin(2πp), a full 0–100 sweep
 * - Fireworks: random toggle between 80 and 0
 * - Earthquake: 90·sin(4πp) plus up to ±5 random jitter, clamped to 0–90
 * - any other class: 0
 *
 * @param startMs Absolute timestamp of the first action.
 * @param durationMs Length of the segment; 0 or less yields no actions.
 * @param className Pattern to render.
 * @param fps Currently unused; kept for call-site compatibility.
 * @param random Uniform [0, 1) source for the noisy patterns (defaults to Math.random; inject a
 *   fixed function for deterministic output).
 * @returns Actions whose `pos` always lies within the funscript range 0–100.
 */
function generatePatternPositions(
  startMs: number,
  durationMs: number,
  className: string,
  fps: number,
  random: () => number = Math.random
): FunscriptAction[] {
  const intervalMs = 100;
  const actions: FunscriptAction[] = [];
  for (let timeMs = 0; timeMs < durationMs; timeMs += intervalMs) {
    const progress = timeMs / durationMs;
    let pos = 0;
    switch (className) {
      case 'Pulse':
        pos = Math.round(50 * Math.sin(progress * 2 * Math.PI));
        break;
      case 'Wave':
        pos = Math.round(50 + 50 * Math.sin(progress * 2 * Math.PI));
        break;
      case 'Fireworks':
        pos = random() > 0.5 ? 80 : 0;
        break;
      case 'Earthquake':
        pos = Math.round(90 * Math.sin(progress * 4 * Math.PI) + (random() - 0.5) * 10);
        pos = Math.max(0, Math.min(90, pos));
        break;
    }
    // Funscript positions must lie in 0–100; Pulse in particular swings negative.
    actions.push({ at: startMs + timeMs, pos: Math.max(0, Math.min(100, pos)) });
  }
  return actions;
}
/**
 * Probe `videoPath` via `ffprobe`, log the result, and return it under the names `buildFunscript` uses.
 *
 * @returns `{ fps, totalFrames }`, where totalFrames is ffprobe's `frames`.
 */
async function loadVideoMetadata(videoPath: string, helpers: Helpers) {
  const probe = await ffprobe(videoPath);
  const fps = probe.fps;
  const totalFrames = probe.frames;
  helpers.logger.info(`Video metadata: fps=${fps}, frames=${totalFrames}`);
  return { fps, totalFrames };
}
/**
 * Return the class name of the most confident detection in one label file, if any qualifies.
 *
 * Each line is expected to be `class x y w h conf` (YOLO `save_txt` + `save_conf`); lines with
 * fewer than 6 fields or unparsable class/confidence are ignored. Only detections with
 * confidence >= `minConfidence` count. Returns undefined when none does, or when the winning
 * class index has no entry in `names`.
 */
function pickFrameClass(content: string, names: Record<string, string>, minConfidence: number): string | undefined {
  let best: { classIndex: number; confidence: number } | undefined;
  for (const line of content.trim().split('\n')) {
    const parts = line.trim().split(/\s+/);
    if (parts.length < 6) continue;
    const classIndex = parseInt(parts[0], 10);
    const confidence = parseFloat(parts[5]);
    if (isNaN(classIndex) || isNaN(confidence)) continue;
    if (confidence >= minConfidence && (best === undefined || confidence > best.confidence)) {
      best = { classIndex, confidence };
    }
  }
  return best === undefined ? undefined : names[best.classIndex.toString()] || undefined;
}

/**
 * Parse YOLO per-frame label files and collapse them into per-class frame segments.
 *
 * Every `*.txt` file in `labelDir` must end in `<frame number>.txt`; other `.txt` names are
 * skipped with a log line. For each frame, only the single most confident detection with
 * confidence >= 0.7 is kept. Frames are then walked in ascending frame order. `readdir` gives
 * no ordering guarantee, which previously scrambled the segments. Consecutive frames of the same
 * class are merged into one segment. Frames without a qualifying detection do not end a segment:
 * a class seen at frames 3 and 11 with nothing in between yields one 3–11 segment.
 *
 * @param labelDir Directory holding the label files.
 * @param helpers Task helpers, used for logging skipped files.
 * @param data Provides `names`, the class index -> class name table.
 * @returns Segments ordered by startFrame.
 */
async function processLabelFiles(labelDir: string, helpers: Helpers, data: DataYaml): Promise<Detection[]> {
  const minConfidence = 0.7;
  const labelFiles = (await readdir(labelDir)).filter((file) => file.endsWith('.txt'));

  // frame index -> winning class; if two files map to the same frame, the later one wins.
  const classByFrame = new Map<number, string>();
  for (const file of labelFiles) {
    const match = file.match(/(\d+)\.txt$/);
    if (!match) {
      helpers.logger.info(`Skipping invalid filename: ${file}`);
      continue;
    }
    const frameIndex = parseInt(match[1], 10);
    if (isNaN(frameIndex)) {
      helpers.logger.info(`Skipping invalid frame index from filename: ${file}`);
      continue;
    }
    const content = await readFile(join(labelDir, file), 'utf8');
    const className = pickFrameClass(content, data.names, minConfidence);
    if (className) {
      classByFrame.set(frameIndex, className);
    }
  }

  // Merge runs of the same class, walking frames in ascending order.
  const orderedFrames = [...classByFrame.entries()].sort(([a], [b]) => a - b);
  const detectionSegments: Detection[] = [];
  let current: Detection | null = null;
  for (const [frameIndex, className] of orderedFrames) {
    if (current !== null && current.className === className) {
      current.endFrame = frameIndex;
    } else {
      if (current !== null) detectionSegments.push(current);
      current = { startFrame: frameIndex, endFrame: frameIndex, className };
    }
  }
  if (current !== null) detectionSegments.push(current);
  return detectionSegments;
}
/**
 * Sample detection segments into a time-sorted list of funscript actions.
 *
 * - Static pass: one point every 100 ms from 0 to `totalDurationMs` inclusive. Its position is the
 *   fixed position of the first covering segment whose class maps to a number, or 0 when no
 *   segment covers that moment.
 * - Grid points inside a 'pattern' segment (with no numeric segment covering them) are skipped.
 *   That segment's points come from `generatePatternPositions` instead. Without the skip, the
 *   static pass put a pos=0 point every 100 ms inside the pattern. Because the dedupe keeps the
 *   first action at a timestamp, it also replaced pattern points that landed on the grid.
 * - Finally actions are sorted by `at`, and only the first action per timestamp is kept.
 *
 * @param totalDurationMs Video length in milliseconds.
 * @param fps Frames per second, used to convert between time and frame indices.
 * @param detectionSegments Segments as produced by `processLabelFiles`.
 * @param classPositionMap Class -> fixed position or 'pattern'; unknown classes count as 0.
 */
function generateActions(totalDurationMs: number, fps: number, detectionSegments: Detection[], classPositionMap: ClassPositionMap): FunscriptAction[] {
  const intervalMs = 100;
  const actions: FunscriptAction[] = [];

  // Static position actions on a fixed grid.
  for (let timeMs = 0; timeMs <= totalDurationMs; timeMs += intervalMs) {
    const frameIndex = Math.floor((timeMs / 1000) * fps);
    let position = 0;
    let insidePattern = false;
    for (const segment of detectionSegments) {
      if (frameIndex < segment.startFrame || frameIndex > segment.endFrame) continue;
      const mapped = classPositionMap[segment.className];
      if (typeof mapped === 'number') {
        position = mapped;
        insidePattern = false;
        break;
      }
      if (mapped === 'pattern') insidePattern = true;
    }
    if (!insidePattern) {
      actions.push({ at: timeMs, pos: position });
    }
  }

  // Pattern-based actions for 'pattern' segments.
  for (const segment of detectionSegments) {
    const className = segment.className;
    if (classPositionMap[className] !== 'pattern') continue;
    const startMs = Math.floor((segment.startFrame / fps) * 1000);
    const durationMs = Math.floor(((segment.endFrame - segment.startFrame + 1) / fps) * 1000);
    // Push one at a time: spreading a very long pattern into push() can exceed the engine's argument/stack limits.
    for (const action of generatePatternPositions(startMs, durationMs, className, fps)) {
      actions.push(action);
    }
  }

  // Sort by time (stable), keeping only the first action at each timestamp.
  actions.sort((a, b) => a.at - b.at);
  const uniqueActions: FunscriptAction[] = [];
  let lastTime = -1;
  for (const action of actions) {
    if (action.at !== lastTime) {
      uniqueActions.push(action);
      lastTime = action.at;
    }
  }
  return uniqueActions;
}
/**
 * Serialize `actions` as a version "1.0" funscript (pretty-printed JSON, 2-space indent),
 * write it to `outputPath`, and log the path and action count.
 */
async function writeFunscript(outputPath: string, actions: FunscriptAction[], helpers: Helpers) {
  const script: Funscript = { version: '1.0', actions };
  const serialized = JSON.stringify(script, null, 2);
  await writeFile(outputPath, serialized);
  const count = actions.length;
  helpers.logger.info(`Funscript generated: ${outputPath} (${count} actions)`);
}
/**
 * graphile-worker task: generate a funscript for one vod and store it.
 *
 * - Returns without doing anything when the vod already has a funscript.
 * - Throws when the vod has no source video.
 * - Otherwise:
 *   1. Ensure the Python venv exists.
 *   2. Get the source video locally (via getOrDownloadAsset) and run YOLO inference on it.
 *   3. Build the funscript and upload it to `funscripts/<vodId>.funscript` in S3_BUCKET.
 *   4. Save that key on the vod row.
 *
 * @throws Error for an invalid payload, a missing vod (findFirstOrThrow) or a missing source
 *   video, plus anything thrown by inference, funscript generation, upload or the DB update.
 */
const createFunscript: Task = async (payload: any, helpers: Helpers) => {
  assertPayload(payload);
  const { vodId } = payload;
  const vod = await prisma.vod.findFirstOrThrow({ where: { id: vodId } });
  if (vod.funscript) {
    helpers.logger.info(`Doing nothing-- vod ${vodId} already has a funscript.`);
    return;
  }
  if (!vod.sourceVideo) {
    const msg = `Cannot create funscript: Vod ${vodId} is missing a source video.`;
    helpers.logger.warn(msg);
    throw new Error(msg);
  }
  // Set up Python only once we know there is work to do: a skipped vod should neither spawn a
  // process nor fail because Python is unavailable.
  await preparePython(helpers);
  const s3Client = getS3Client();
  const videoFilePath = await getOrDownloadAsset(s3Client, env.S3_BUCKET, vod.sourceVideo);
  helpers.logger.info(`Downloaded video to ${videoFilePath}`);
  helpers.logger.info(`Creating funscript for vod ${vodId}...`);
  const predictionOutput = await inference(helpers, videoFilePath);
  helpers.logger.info(`prediction output ${predictionOutput}`);
  const funscriptFilePath = await buildFunscript(helpers, predictionOutput, videoFilePath);
  const s3Key = `funscripts/${vodId}.funscript`;
  const s3Url = await uploadFile(s3Client, env.S3_BUCKET, s3Key, funscriptFilePath, "application/json");
  helpers.logger.info(`Uploaded funscript to S3: ${s3Url}`);
  await prisma.vod.update({
    where: { id: vodId },
    data: { funscript: s3Key }
  });
  helpers.logger.info(`Funscript saved to database for vod ${vodId}`);
};
export default createFunscript;