Commit acf6a7c1 by PLN (Algolia)

perf(sample-classify): batch CLAP per folder + parallel decode + progress

One forward pass per folder (text tower computed once) instead of per-file —
~Nx fewer forwards; ffmpeg decode on a thread pool. validate prints a live
[i/N] running-accuracy line. source field now reflects method:mode.
parent bb70f394
......@@ -24,6 +24,7 @@ import json
import subprocess
import sys
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import numpy as np
......@@ -121,18 +122,32 @@ def _clap():
def clap_vector(path):
"""CLAP family-probability vector {family: prob} for one sample, or None.
logits_per_audio → softmax over descriptors → marginalize (sum) to families."""
a = load_audio(path)
if a is None:
return None
return clap_vectors([path])[0]
def clap_vectors(paths):
"""Batched CLAP: ONE forward pass for all decodable files (the text tower is
computed once, so batching N audios is ~Nx fewer forwards than per-file). Decode
runs on a small thread pool (ffmpeg is I/O-bound). Returns a list aligned to
`paths`, with None where decode failed."""
out = [None] * len(paths)
with ThreadPoolExecutor(max_workers=4) as ex:
decoded = list(ex.map(load_audio, paths))
audios = [(i, a) for i, a in enumerate(decoded) if a is not None]
if not audios:
return out
C = _clap()
torch = C["torch"]
with torch.no_grad():
ai = C["proc"](audio=a, sampling_rate=SR, return_tensors="pt")
out = C["model"](input_ids=C["ids"], attention_mask=C["mask"],
ai = C["proc"](audio=[a for _, a in audios], sampling_rate=SR,
return_tensors="pt", padding=True)
res = C["model"](input_ids=C["ids"], attention_mask=C["mask"],
input_features=ai["input_features"])
probs = torch.softmax(out.logits_per_audio.squeeze(0), dim=-1)
fam_probs = (probs @ C["onehot"]).tolist()
return dict(zip(C["fam_keys"], fam_probs))
probs = torch.softmax(res.logits_per_audio, dim=-1) # (n_audio, n_text)
fam = (probs @ C["onehot"]).tolist() # (n_audio, n_fam)
for (i, _), row in zip(audios, fam):
out[i] = dict(zip(C["fam_keys"], row))
return out
def classify_file(path):
......@@ -157,16 +172,30 @@ def classify_file(path):
def classify_folder(name):
"""Aggregate per-file predictions for a folder → distribution + dominant."""
"""Aggregate per-file predictions for a folder → distribution + dominant.
CLAP is batched across the folder's files (one forward); PANNs stays per-file."""
files = folder_files(name)
if not files:
return None
files = files[:MAX_FILES]
clap_vs = clap_vectors(files) if METHOD in ("clap", "ensemble") else [None] * len(files)
if METHOD in ("panns", "ensemble"):
import sample_panns
dist, confs = Counter(), []
for f in files[:MAX_FILES]:
r = classify_file(f)
if r:
dist[r[0]] += 1
confs.append(r[1])
for i, f in enumerate(files):
vecs = []
if clap_vs[i]:
vecs.append(clap_vs[i])
if METHOD in ("panns", "ensemble"):
pv = sample_panns.family_vector(f)
if pv:
vecs.append(pv)
if not vecs:
continue
avg = {k: sum(v.get(k, 0.0) for v in vecs) / len(vecs) for k in ONT.FAMILIES}
fam = max(avg, key=avg.get)
dist[fam] += 1
confs.append(round(avg[fam], 3))
n = sum(dist.values())
if not n:
return None
......@@ -178,7 +207,7 @@ def classify_folder(name):
"conf": round(dn / n, 3), # fraction of files agreeing
"mean_file_conf": round(float(np.mean(confs)), 3),
"homogeneous": dn / n >= 0.6, # else it's a kit / mixed
"source": "clap:" + MODEL,
"source": f"{METHOD}:{MODE}" + (f":{MODEL}" if METHOD != "panns" else ""),
}
......@@ -202,7 +231,7 @@ def cmd_validate():
print(f"⛵ validate — {len(gt)} name-confident folders with audio present\n")
ok, total, per = 0, 0, Counter()
confmat = Counter()
for s, fam in gt:
for i, (s, fam) in enumerate(gt, 1):
r = classify_folder(s)
if not r:
continue
......@@ -212,8 +241,9 @@ def cmd_validate():
per[fam + (":✓" if hit else ":✗")] += 1
if not hit:
confmat[f"{fam}→{r['dominant']}"] += 1
print(f" {'✓' if hit else '✗'} {s:<22} name={fam:<6} clap={r['dominant']:<6} "
f"conf={r['conf']} ({r['n']} files)")
print(f" [{i:>2}/{len(gt)}] {'✓' if hit else '✗'} {s:<22} name={fam:<6} "
f"got={r['dominant']:<6} conf={r['conf']} ({r['n']} files) "
f"[running {ok}/{total}={ok/total*100:.0f}%]", flush=True)
print(f"\n top-1 accuracy: {ok}/{total} = {ok/total*100:.0f}%" if total else " no data")
if confmat:
print(" confusions:", dict(confmat.most_common()))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment