// Listing metadata captured with this file: 418 lines, 13 KiB, TypeScript.
import { existsSync } from "node:fs";
import { readFile, writeFile, readdir, rm } from 'node:fs/promises';
import { join, basename, extname } from "node:path";

import type { Task, Helpers } from "graphile-worker";
import { withAccelerate } from "@prisma/extension-accelerate";
import yaml from 'js-yaml';
import spawn from 'nano-spawn';
import { nanoid } from "nanoid";
import { string } from "zod";

import { PrismaClient } from "../../generated/prisma";
import { env } from "../config/env";
import { getOrDownloadAsset } from "../utils/cache";
import { getS3Client, uploadFile } from "../utils/s3";

// Module-wide Prisma client, extended with the Accelerate client extension.
const baseClient = new PrismaClient();
const prisma = baseClient.$extends(withAccelerate());

/** Job payload accepted by the `createFunscript` task. */
interface Payload {
  /** Id of the Vod row to process (matched against `vod.id`). */
  vodId: string;
}

/** A run of frames attributed to a single detected class. */
interface Detection {
  /** First frame of the run (inclusive). */
  startFrame: number;
  /** Last frame of the run (inclusive). */
  endFrame: number;
  /** Class name resolved through `DataYaml.names`. */
  className: string;
}

/** Parsed contents of a YOLO-style `data.yaml` file. */
interface DataYaml {
  path: string;
  train: string;
  val: string;
  /** Class index (as a string key) to class name. */
  names: Record<string, string>;
}

/** One point of a funscript timeline. */
interface FunscriptAction {
  /** Timestamp in milliseconds from the start of the video. */
  at: number;
  /** Stroke position; the funscript format uses a 0–100 range. */
  pos: number;
}

/** The JSON document written to a `.funscript` file. */
interface Funscript {
  version: string;
  actions: FunscriptAction[];
}

/**
 * Class name to a fixed position, or `'pattern'` for classes whose motion is
 * generated over time instead of held at one value.
 */
type ClassPositionMap = Record<string, number | 'pattern'>;

/**
 * Runtime guard for the task payload.
 *
 * The payload is untyped when it reaches the task, so this accepts `unknown`
 * and only narrows it to {@link Payload} after checking its shape.
 *
 * @param payload - Raw job payload.
 * @throws Error when the payload is not a non-null object, or when `vodId` is
 *   missing, not a string, or empty.
 */
function assertPayload(payload: unknown): asserts payload is Payload {
  if (typeof payload !== "object" || payload === null) {
    throw new Error("invalid payload-- was not an object.");
  }
  const { vodId } = payload as { vodId?: unknown };
  // An empty id can never identify a Vod, so treat it as missing.
  if (typeof vodId !== "string" || vodId.length === 0) {
    throw new Error("invalid payload-- was missing vodId");
  }
}

/**
 * Read a `data.yaml` file and return its parsed contents.
 *
 * The parsed value is cast to {@link DataYaml}, not validated; callers must
 * check the fields they rely on.
 *
 * @param yamlPath - Path of the YAML file (read as UTF-8).
 */
async function loadDataYaml(yamlPath: string): Promise<DataYaml> {
  const parsed = yaml.load(await readFile(yamlPath, 'utf8'));
  return parsed as DataYaml;
}

/**
 * Ensure a Python virtual environment exists at `<VIBEUI_DIR>/venv`.
 *
 * If the directory is missing, `python -m venv venv` is run with
 * `VIBEUI_DIR` as the working directory. This only creates the venv; it does
 * not install any packages into it.
 *
 * @param helpers - graphile-worker helpers, used for logging.
 */
async function preparePython(helpers: Helpers): Promise<void> {
  const venvPath = join(env.VIBEUI_DIR, 'venv');

  if (existsSync(venvPath)) {
    helpers.logger.info("Using existing Python venv.");
    return;
  }

  helpers.logger.info("Python venv not found. Creating one...");
  await spawn("python", ["-m", "venv", "venv"], {
    cwd: env.VIBEUI_DIR,
  });
}

/**
 * Probe the first video stream of `videoPath` for its frame rate and frame count.
 *
 * `-count_frames` makes ffprobe decode the stream to count frames, which is
 * exact but slow on long videos.
 *
 * @param videoPath - Local path of the video file.
 * @returns `fps` (frames per second) and `frames` (decoded frame count).
 * @throws Error when the output does not yield a positive finite frame rate
 *   and a non-negative frame count (e.g. `N/A`, `0/0`, or missing lines).
 *   Without this, NaN/Infinity would propagate into the action generator.
 */
async function ffprobe(videoPath: string): Promise<{ fps: number; frames: number }> {
  const { stdout } = await spawn('ffprobe', [
    '-v', 'error',
    '-select_streams', 'v:0',
    '-count_frames',
    '-show_entries', 'stream=nb_read_frames,r_frame_rate',
    '-of', 'default=nokey=1:noprint_wrappers=1',
    videoPath,
  ]);

  // Parsing assumes r_frame_rate is printed before nb_read_frames: ffprobe
  // emits fields in its own order, not the order given to -show_entries.
  const [frameRateStr = '', frameCountStr = ''] = stdout
    .trim()
    .split(/\r?\n/)
    .map((line) => line.trim());

  // r_frame_rate is a fraction such as "30000/1001"; tolerate a bare number.
  const [num = NaN, denom = 1] = frameRateStr.split('/').map(Number);
  const fps = num / denom;
  const frames = parseInt(frameCountStr, 10);

  if (!Number.isFinite(fps) || fps <= 0 || !Number.isFinite(frames) || frames < 0) {
    throw new Error(
      `ffprobe returned unusable metadata for ${videoPath}: ` +
        `r_frame_rate="${frameRateStr}", nb_read_frames="${frameCountStr}"`,
    );
  }

  return { fps, frames };
}

/**
 * Turn YOLO prediction output for a video into a `.funscript` file.
 *
 * Steps:
 * 1. Load the class names from `<VIBEUI_DIR>/data.yaml` and map each class to
 *    a position or pattern.
 * 2. Probe the video for fps and frame count.
 * 3. Merge the per-frame label files into class segments.
 * 4. Render those segments as actions and write them as JSON.
 *
 * @param helpers - graphile-worker helpers, used for logging.
 * @param predictionOutput - Directory returned by `inference`; its `labels/`
 *   subdirectory holds the per-frame label files.
 * @param videoPath - The video the predictions were made on.
 * @returns Path of the written funscript, a uniquely named file under
 *   `process.env.CACHE_ROOT` (default `/tmp`).
 * @throws Re-throws any failure after logging it.
 */
export async function buildFunscript(
  helpers: Helpers,
  predictionOutput: string,
  videoPath: string,
): Promise<string> {
  const labelDir = join(predictionOutput, 'labels');
  const outputPath = join(process.env.CACHE_ROOT ?? '/tmp', `${nanoid()}.funscript`);
  helpers.logger.info('Starting Funscript generation');

  try {
    // Class names come from the VIBEUI_DIR config, not from predictionOutput.
    const data = await loadDataYaml(join(env.VIBEUI_DIR, 'data.yaml'));
    const classPositionMap = await loadClassPositionMap(data, helpers);
    const { fps, totalFrames } = await loadVideoMetadata(videoPath, helpers);
    const detectionSegments = await processLabelFiles(labelDir, helpers, data);

    const totalDurationMs = Math.floor((totalFrames / fps) * 1000);
    const actions = generateActions(totalDurationMs, fps, detectionSegments, classPositionMap);

    await writeFunscript(outputPath, actions, helpers);
    return outputPath;
  } catch (error) {
    // Non-Error throwables are stringified rather than reported as "Unknown error".
    helpers.logger.error(`Error generating Funscript: ${error instanceof Error ? error.message : String(error)}`);
    throw error;
  }
}

/**
 * Run YOLO prediction on a video with the trained `vibeui` weights.
 *
 * The run:
 * - writes label files with confidences (`save_txt`, `save_conf`);
 * - does not save annotated media (`save=False`);
 * - streams YOLO's output to this process's stdio.
 *
 * Each run uses a fresh `<videoName>-<nanoid>` name under the `runs` project
 * directory.
 *
 * @param helpers - graphile-worker helpers (not used here).
 * @param videoFilePath - Local path of the source video.
 * @returns `<VIBEUI_DIR>/runs/<runName>`, which is expected to contain the
 *   `labels/` folder.
 */
export async function inference(helpers: Helpers, videoFilePath: string): Promise<string> {
  const weightsPath = join(env.VIBEUI_DIR, 'runs/detect/vibeui/weights/best.pt');

  // The run name is the video's file name without extension, plus a unique suffix.
  const projectDir = 'runs';
  const runName = `${basename(videoFilePath, extname(videoFilePath))}-${nanoid()}`;
  const resultDir = join(env.VIBEUI_DIR, projectDir, runName);

  const yoloArgs = [
    'predict',
    `model=${weightsPath}`,
    `source=${videoFilePath}`,
    'save=False',
    'save_txt=True',
    'save_conf=True',
    `project=${projectDir}`,
    `name=${runName}`,
  ];

  // Run from VIBEUI_DIR so both ./venv/bin/yolo and the relative project dir resolve there.
  await spawn('./venv/bin/yolo', yoloArgs, {
    cwd: env.VIBEUI_DIR,
    stdio: 'inherit',
  });

  return resultDir;
}

/**
 * Built-in class-to-position table.
 *
 * Values are either fixed positions or `'pattern'`; pattern classes are
 * rendered over time rather than held at one value.
 */
const DEFAULT_CLASS_POSITIONS: Readonly<ClassPositionMap> = {
  ControlledByTipper: 50,
  ControlledByTipperHigh: 80,
  ControlledByTipperLow: 20,
  ControlledByTipperMedium: 50,
  ControlledByTipperUltrahigh: 95,
  Ring1: 30,
  Ring2: 40,
  Ring3: 50,
  Ring4: 60,
  Earthquake: 'pattern',
  Fireworks: 'pattern',
  Pulse: 'pattern',
  Wave: 'pattern',
  Pause: 0,
  RandomTime: 70,
  HighLevel: 80,
  LowLevel: 20,
  MediumLevel: 50,
  UltraHighLevel: 95,
};

/**
 * Build the class-to-position map for the classes declared in `data.yaml`.
 *
 * Starts from a copy of {@link DEFAULT_CLASS_POSITIONS}. Every class name in
 * `data.names` without a built-in entry gets position 0. Entries that are not
 * strings, or are blank, are logged and skipped.
 *
 * @param data - Parsed data.yaml; `names` must be a non-empty object.
 * @param helpers - graphile-worker helpers, used for logging.
 * @returns A fresh map; the built-in defaults are never mutated.
 * @throws Error when `names` is missing, not an object, or empty.
 */
async function loadClassPositionMap(data: DataYaml, helpers: Helpers): Promise<ClassPositionMap> {
  try {
    if (
      !data ||
      typeof data !== 'object' ||
      !('names' in data) ||
      typeof data.names !== 'object' ||
      data.names === null ||
      Object.keys(data.names).length === 0
    ) {
      throw new Error('Invalid data.yaml: "names" field is missing, not an object, or empty');
    }

    const positionMap: ClassPositionMap = { ...DEFAULT_CLASS_POSITIONS };

    for (const name of Object.values(data.names)) {
      if (typeof name !== 'string' || name.trim() === '') {
        helpers.logger.info(`Skipping invalid class name: ${name}`);
        continue;
      }
      // Own-property check: `name in positionMap` would also match inherited
      // Object.prototype keys such as "toString" or "constructor".
      if (!Object.prototype.hasOwnProperty.call(positionMap, name)) {
        helpers.logger.info(`No position mapping for class "${name}", defaulting to 0`);
        positionMap[name] = 0;
      }
    }

    helpers.logger.info(`Loaded class position map: ${JSON.stringify(positionMap)}`);
    return positionMap;
  } catch (error) {
    helpers.logger.error(`Error loading data.yaml: ${error instanceof Error ? error.message : 'Unknown error'}`);
    throw error;
  }
}

/**
 * Sample a time-varying pattern for one detection segment every 100 ms.
 *
 * Patterns, as functions of `progress` = elapsed / duration:
 * - `Pulse`: 50·sin(2π·progress). The negative half is clamped to 0, giving
 *   one 0→50→0 pulse followed by a rest.
 * - `Wave`: 50 + 50·sin(2π·progress), i.e. 0–100.
 * - `Fireworks`: a random choice of 80 or 0 per sample.
 * - `Earthquake`: 90·sin(4π·progress) plus ±5 random jitter, clamped to 0–90.
 * - Any other class: 0.
 *
 * Every position is clamped to 0–100. Before the clamp, `Pulse` emitted
 * negative values, which are invalid in a funscript.
 *
 * @param startMs - Timestamp of the segment start.
 * @param durationMs - Segment length; no actions are produced when it is <= 0.
 * @param className - Pattern class name.
 * @param fps - Kept for interface compatibility; sampling is time-based.
 * @returns Actions at `startMs + k * 100` for every k with `k * 100 < durationMs`.
 */
function generatePatternPositions(startMs: number, durationMs: number, className: string, fps: number): FunscriptAction[] {
  const intervalMs = 100;
  const actions: FunscriptAction[] = [];

  for (let timeMs = 0; timeMs < durationMs; timeMs += intervalMs) {
    const progress = timeMs / durationMs;
    let pos = 0;

    switch (className) {
      case 'Pulse':
        pos = Math.round(50 * Math.sin(progress * 2 * Math.PI));
        break;
      case 'Wave':
        pos = Math.round(50 + 50 * Math.sin(progress * 2 * Math.PI));
        break;
      case 'Fireworks':
        pos = Math.random() > 0.5 ? 80 : 0;
        break;
      case 'Earthquake':
        pos = Math.round(90 * Math.sin(progress * 4 * Math.PI) + (Math.random() - 0.5) * 10);
        pos = Math.max(0, Math.min(90, pos));
        break;
      default:
        // Unknown pattern: stay at rest.
        pos = 0;
    }

    actions.push({ at: startMs + timeMs, pos: Math.max(0, Math.min(100, pos)) });
  }

  return actions;
}

/**
 * Probe `videoPath` via `ffprobe` and log what was found.
 *
 * @returns The stream's frame rate as `fps` and its frame count as `totalFrames`.
 */
async function loadVideoMetadata(videoPath: string, helpers: Helpers) {
  const probe = await ffprobe(videoPath);
  helpers.logger.info(`Video metadata: fps=${probe.fps}, frames=${probe.frames}`);
  return { fps: probe.fps, totalFrames: probe.frames };
}

/**
 * Read per-frame YOLO label files and merge them into class segments.
 *
 * Input format:
 * - Each `*.txt` file in `labelDir` describes one frame.
 * - The frame index is the trailing number of the file name (`vid_42.txt` → 42).
 * - Lines with fewer than six fields are ignored. Otherwise field 0 is the
 *   class index and field 5 the confidence.
 *
 * Per frame, the highest-confidence line at or above 0.7 picks the class,
 * resolved via `data.names`.
 *
 * Merging:
 * - Frames are merged in ascending numeric frame order.
 * - Consecutive detected frames with the same class extend one segment, even
 *   across frames with no qualifying detection.
 * - A new segment starts only when a different class appears.
 *
 * @param labelDir - Directory containing the label files.
 * @param helpers - graphile-worker helpers, used for logging.
 * @param data - Parsed data.yaml supplying the class names.
 * @returns Non-overlapping segments in ascending frame order, with inclusive bounds.
 */
async function processLabelFiles(labelDir: string, helpers: Helpers, data: DataYaml): Promise<Detection[]> {
  const minConfidence = 0.7;
  const labelFiles = (await readdir(labelDir)).filter((file) => file.endsWith('.txt'));

  // Frame index -> winning class name. If two files share a frame index, the
  // one read later overwrites the earlier one.
  const frameClasses = new Map<number, string>();

  for (const file of labelFiles) {
    const match = file.match(/(\d+)\.txt$/);
    if (!match) {
      helpers.logger.info(`Skipping invalid filename: ${file}`);
      continue;
    }
    const frameIndex = parseInt(match[1] ?? '', 10);
    // \d+ always parses, but an absurdly long digit run can overflow to Infinity.
    if (!Number.isSafeInteger(frameIndex)) {
      helpers.logger.info(`Skipping invalid frame index from filename: ${file}`);
      continue;
    }

    const content = await readFile(join(labelDir, file), 'utf8');
    let maxConfidence = 0;
    let selectedClassIndex = -1;

    for (const line of content.trim().split('\n')) {
      const parts = line.trim().split(/\s+/);
      if (parts.length < 6) continue;

      const classIndex = parseInt(parts[0] ?? '', 10);
      const confidence = parseFloat(parts[5] ?? '');
      if (isNaN(classIndex) || isNaN(confidence)) continue;

      if (confidence >= minConfidence && confidence > maxConfidence) {
        maxConfidence = confidence;
        selectedClassIndex = classIndex;
      }
    }

    if (maxConfidence > 0) {
      const className = data.names[String(selectedClassIndex)];
      if (className) frameClasses.set(frameIndex, className);
    }
  }

  // Merge in numeric frame order. readdir order is unspecified, and even a
  // lexical sort would put "vid_10" before "vid_2".
  const orderedFrames = [...frameClasses].sort(([a], [b]) => a - b);

  const segments: Detection[] = [];
  let current: Detection | null = null;
  for (const [frameIndex, className] of orderedFrames) {
    if (current !== null && current.className === className) {
      current.endFrame = frameIndex;
    } else {
      if (current !== null) segments.push(current);
      current = { startFrame: frameIndex, endFrame: frameIndex, className };
    }
  }
  if (current !== null) segments.push(current);

  return segments;
}

/**
 * Render detection segments as a time-sorted list of funscript actions.
 *
 * Static actions:
 * - Emitted every 100 ms from 0 to `totalDurationMs` inclusive.
 * - Each takes the numeric position of the first segment covering that
 *   timestamp's frame, or 0 when no numeric segment covers it.
 * - Timestamps covered only by a `'pattern'` segment get no static action.
 *
 * Each pattern segment is filled from `generatePatternPositions`. After a
 * stable sort by time, only the first action at each timestamp is kept.
 *
 * @param totalDurationMs - Video length in ms; must be finite and >= 0.
 * @param fps - Video frame rate; must be finite and > 0.
 * @param detectionSegments - Segments with inclusive frame bounds.
 * @param classPositionMap - Class name to position or `'pattern'`.
 * @throws RangeError for invalid `fps` or `totalDurationMs`. Otherwise the
 *   sampling loop either never ends (Infinity) or silently yields nothing (NaN).
 */
function generateActions(totalDurationMs: number, fps: number, detectionSegments: Detection[], classPositionMap: ClassPositionMap): FunscriptAction[] {
  if (!Number.isFinite(fps) || fps <= 0) {
    throw new RangeError(`Invalid fps for funscript generation: ${fps}`);
  }
  if (!Number.isFinite(totalDurationMs) || totalDurationMs < 0) {
    throw new RangeError(`Invalid duration for funscript generation: ${totalDurationMs}ms`);
  }

  const intervalMs = 100;
  const actions: FunscriptAction[] = [];

  // Static position actions on a fixed grid.
  for (let timeMs = 0; timeMs <= totalDurationMs; timeMs += intervalMs) {
    const frameIndex = Math.floor((timeMs / 1000) * fps);
    let position: number | undefined;
    let coveredByPattern = false;

    for (const segment of detectionSegments) {
      if (frameIndex < segment.startFrame || frameIndex > segment.endFrame) continue;
      const mapping = classPositionMap[segment.className];
      if (typeof mapping === 'number') {
        position = mapping;
        break;
      }
      if (mapping === 'pattern') coveredByPattern = true;
    }

    // Leave pattern spans to the overlay below. A static 0 here would either
    // win the dedup at a shared timestamp or interleave zeros between pattern points.
    if (position === undefined && coveredByPattern) continue;

    actions.push({ at: timeMs, pos: position ?? 0 });
  }

  // Overlay pattern-based actions.
  for (const segment of detectionSegments) {
    if (classPositionMap[segment.className] !== 'pattern') continue;
    const startMs = Math.floor((segment.startFrame / fps) * 1000);
    const durationMs = Math.floor(((segment.endFrame - segment.startFrame + 1) / fps) * 1000);
    // Push one by one: spreading a very long pattern can exceed the argument limit.
    for (const action of generatePatternPositions(startMs, durationMs, segment.className, fps)) {
      actions.push(action);
    }
  }

  // Stable sort by time, then keep the first action at each timestamp.
  actions.sort((a, b) => a.at - b.at);
  const uniqueActions: FunscriptAction[] = [];
  let lastTime = -1;
  for (const action of actions) {
    if (action.at !== lastTime) {
      uniqueActions.push(action);
      lastTime = action.at;
    }
  }

  return uniqueActions;
}

/**
 * Write `actions` to `outputPath` as a funscript JSON document, then log the
 * action count.
 *
 * The document is `{ version: '1.0', actions }`, pretty-printed with
 * two-space indentation.
 */
async function writeFunscript(outputPath: string, actions: FunscriptAction[], helpers: Helpers) {
  const script: Funscript = { version: '1.0', actions };
  const json = JSON.stringify(script, null, 2);
  await writeFile(outputPath, json);

  const count = actions.length;
  helpers.logger.info(`Funscript generated: ${outputPath} (${count} actions)`);
}

/**
 * Best-effort removal of a temporary file or directory.
 *
 * Does nothing for `undefined`. Failures are logged as warnings and never thrown.
 */
async function removeTempPath(path: string | undefined, helpers: Helpers): Promise<void> {
  if (!path) return;
  try {
    await rm(path, { recursive: true, force: true });
  } catch (error) {
    helpers.logger.warn(`Failed to clean up ${path}: ${error instanceof Error ? error.message : String(error)}`);
  }
}

/**
 * graphile-worker task: generate a funscript for a Vod and attach it.
 *
 * Steps:
 * 1. Skip Vods that already have a funscript.
 * 2. Fetch the source video through the asset cache and ensure the Python
 *    venv exists.
 * 3. Run YOLO inference and build the funscript.
 * 4. Upload it to S3 under `funscripts/<vodId>.funscript` and store that key
 *    on the Vod row.
 *
 * Each run creates a uniquely named prediction directory and funscript file
 * that are never reused. Both are removed afterwards, best-effort.
 *
 * @throws Error when the payload is invalid, no Vod matches `vodId`, the Vod
 *   has no source video, or any processing step fails.
 */
const createFunscript: Task = async (payload: unknown, helpers: Helpers) => {
  assertPayload(payload);
  const { vodId } = payload;

  const vod = await prisma.vod.findFirstOrThrow({ where: { id: vodId } });

  if (vod.funscript) {
    helpers.logger.info(`Doing nothing-- vod ${vodId} already has a funscript.`);
    return;
  }

  if (!vod.sourceVideo) {
    const msg = `Cannot create funscript: Vod ${vodId} is missing a source video.`;
    helpers.logger.warn(msg);
    throw new Error(msg);
  }

  // Only set up Python once we know there is work to do.
  await preparePython(helpers);

  const s3Client = getS3Client();
  const videoFilePath = await getOrDownloadAsset(s3Client, env.S3_BUCKET, vod.sourceVideo);
  helpers.logger.info(`Downloaded video to ${videoFilePath}`);

  helpers.logger.info(`Creating funscript for vod ${vodId}...`);

  let predictionOutput: string | undefined;
  let funscriptFilePath: string | undefined;
  try {
    predictionOutput = await inference(helpers, videoFilePath);
    helpers.logger.info(`prediction output ${predictionOutput}`);

    funscriptFilePath = await buildFunscript(helpers, predictionOutput, videoFilePath);

    const s3Key = `funscripts/${vodId}.funscript`;
    const s3Url = await uploadFile(s3Client, env.S3_BUCKET, s3Key, funscriptFilePath, "application/json");
    helpers.logger.info(`Uploaded funscript to S3: ${s3Url}`);

    await prisma.vod.update({
      where: { id: vodId },
      data: { funscript: s3Key },
    });
    helpers.logger.info(`Funscript saved to database for vod ${vodId}`);
  } finally {
    await removeTempPath(predictionOutput, helpers);
    await removeTempPath(funscriptFilePath, helpers);
  }
};

export default createFunscript;