Commit acf6a7c1 by PLN (Algolia)

perf(sample-classify): batch CLAP per folder + parallel decode + progress

One forward pass per folder (text tower computed once) instead of per-file —
~Nx fewer forwards; ffmpeg decode on a thread pool. validate prints a live
[i/N] running-accuracy line. source field now reflects method:mode.
parent bb70f394
...@@ -24,6 +24,7 @@ import json ...@@ -24,6 +24,7 @@ import json
import subprocess import subprocess
import sys import sys
from collections import Counter from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
...@@ -121,18 +122,32 @@ def _clap(): ...@@ -121,18 +122,32 @@ def _clap():
def clap_vector(path): def clap_vector(path):
"""CLAP family-probability vector {family: prob} for one sample, or None. """CLAP family-probability vector {family: prob} for one sample, or None.
logits_per_audio → softmax over descriptors → marginalize (sum) to families.""" logits_per_audio → softmax over descriptors → marginalize (sum) to families."""
a = load_audio(path) return clap_vectors([path])[0]
if a is None:
return None
def clap_vectors(paths):
"""Batched CLAP: ONE forward pass for all decodable files (the text tower is
computed once, so batching N audios is ~Nx fewer forwards than per-file). Decode
runs on a small thread pool (ffmpeg is I/O-bound). Returns a list aligned to
`paths`, with None where decode failed."""
out = [None] * len(paths)
with ThreadPoolExecutor(max_workers=4) as ex:
decoded = list(ex.map(load_audio, paths))
audios = [(i, a) for i, a in enumerate(decoded) if a is not None]
if not audios:
return out
C = _clap() C = _clap()
torch = C["torch"] torch = C["torch"]
with torch.no_grad(): with torch.no_grad():
ai = C["proc"](audio=a, sampling_rate=SR, return_tensors="pt") ai = C["proc"](audio=[a for _, a in audios], sampling_rate=SR,
out = C["model"](input_ids=C["ids"], attention_mask=C["mask"], return_tensors="pt", padding=True)
res = C["model"](input_ids=C["ids"], attention_mask=C["mask"],
input_features=ai["input_features"]) input_features=ai["input_features"])
probs = torch.softmax(out.logits_per_audio.squeeze(0), dim=-1) probs = torch.softmax(res.logits_per_audio, dim=-1) # (n_audio, n_text)
fam_probs = (probs @ C["onehot"]).tolist() fam = (probs @ C["onehot"]).tolist() # (n_audio, n_fam)
return dict(zip(C["fam_keys"], fam_probs)) for (i, _), row in zip(audios, fam):
out[i] = dict(zip(C["fam_keys"], row))
return out
def classify_file(path): def classify_file(path):
...@@ -157,16 +172,30 @@ def classify_file(path): ...@@ -157,16 +172,30 @@ def classify_file(path):
def classify_folder(name): def classify_folder(name):
"""Aggregate per-file predictions for a folder → distribution + dominant.""" """Aggregate per-file predictions for a folder → distribution + dominant.
CLAP is batched across the folder's files (one forward); PANNs stays per-file."""
files = folder_files(name) files = folder_files(name)
if not files: if not files:
return None return None
files = files[:MAX_FILES]
clap_vs = clap_vectors(files) if METHOD in ("clap", "ensemble") else [None] * len(files)
if METHOD in ("panns", "ensemble"):
import sample_panns
dist, confs = Counter(), [] dist, confs = Counter(), []
for f in files[:MAX_FILES]: for i, f in enumerate(files):
r = classify_file(f) vecs = []
if r: if clap_vs[i]:
dist[r[0]] += 1 vecs.append(clap_vs[i])
confs.append(r[1]) if METHOD in ("panns", "ensemble"):
pv = sample_panns.family_vector(f)
if pv:
vecs.append(pv)
if not vecs:
continue
avg = {k: sum(v.get(k, 0.0) for v in vecs) / len(vecs) for k in ONT.FAMILIES}
fam = max(avg, key=avg.get)
dist[fam] += 1
confs.append(round(avg[fam], 3))
n = sum(dist.values()) n = sum(dist.values())
if not n: if not n:
return None return None
...@@ -178,7 +207,7 @@ def classify_folder(name): ...@@ -178,7 +207,7 @@ def classify_folder(name):
"conf": round(dn / n, 3), # fraction of files agreeing "conf": round(dn / n, 3), # fraction of files agreeing
"mean_file_conf": round(float(np.mean(confs)), 3), "mean_file_conf": round(float(np.mean(confs)), 3),
"homogeneous": dn / n >= 0.6, # else it's a kit / mixed "homogeneous": dn / n >= 0.6, # else it's a kit / mixed
"source": "clap:" + MODEL, "source": f"{METHOD}:{MODE}" + (f":{MODEL}" if METHOD != "panns" else ""),
} }
...@@ -202,7 +231,7 @@ def cmd_validate(): ...@@ -202,7 +231,7 @@ def cmd_validate():
print(f"⛵ validate — {len(gt)} name-confident folders with audio present\n") print(f"⛵ validate — {len(gt)} name-confident folders with audio present\n")
ok, total, per = 0, 0, Counter() ok, total, per = 0, 0, Counter()
confmat = Counter() confmat = Counter()
for s, fam in gt: for i, (s, fam) in enumerate(gt, 1):
r = classify_folder(s) r = classify_folder(s)
if not r: if not r:
continue continue
...@@ -212,8 +241,9 @@ def cmd_validate(): ...@@ -212,8 +241,9 @@ def cmd_validate():
per[fam + (":✓" if hit else ":✗")] += 1 per[fam + (":✓" if hit else ":✗")] += 1
if not hit: if not hit:
confmat[f"{fam}→{r['dominant']}"] += 1 confmat[f"{fam}→{r['dominant']}"] += 1
print(f" {'✓' if hit else '✗'} {s:<22} name={fam:<6} clap={r['dominant']:<6} " print(f" [{i:>2}/{len(gt)}] {'✓' if hit else '✗'} {s:<22} name={fam:<6} "
f"conf={r['conf']} ({r['n']} files)") f"got={r['dominant']:<6} conf={r['conf']} ({r['n']} files) "
f"[running {ok}/{total}={ok/total*100:.0f}%]", flush=True)
print(f"\n top-1 accuracy: {ok}/{total} = {ok/total*100:.0f}%" if total else " no data") print(f"\n top-1 accuracy: {ok}/{total} = {ok/total*100:.0f}%" if total else " no data")
if confmat: if confmat:
print(" confusions:", dict(confmat.most_common())) print(" confusions:", dict(confmat.most_common()))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment