Commit 96854679 by PLN (Algolia)

perf(clap): encode the text tower once, audio-only forward per folder

The ~88 fine descriptors were re-encoded through RoBERTa on every folder's
forward — a fixed cost that made batch size irrelevant. Now cache the normalized
text embeddings + logit scale at load; per-folder forwards run only the audio
tower (get_audio_features) and a matmul against the cached text embeds. ~1.8x
(14s->7.7s/folder; text was ~45% of per-folder cost).

API note (transformers 5.10.2): get_text/audio_features return a model-output
object whose .pooler_output IS the projected 512-d joint embedding — verified
identical to a full ClapModel forward to 1.6e-7. No classification change.
parent 3250dcbb
...@@ -108,14 +108,25 @@ def _clap(): ...@@ -108,14 +108,25 @@ def _clap():
model = ClapModel.from_pretrained(MODEL).eval() model = ClapModel.from_pretrained(MODEL).eval()
proc = ClapProcessor.from_pretrained(MODEL) proc = ClapProcessor.from_pretrained(MODEL)
prompts, fams = active_prompts() prompts, fams = active_prompts()
# text tower inputs computed once; reused for every audio (cheap to re-run) # ENCODE THE TEXT TOWER ONCE. The prompts (≈88 fine descriptors) never change,
# so embedding them per folder re-ran RoBERTa every time — a fixed cost that
# dominated and made batch size irrelevant. We cache the normalized text embeds
# + logit scale; per-folder forwards then run only the audio tower (∝ n_files).
ti = proc(text=prompts, return_tensors="pt", padding=True) ti = proc(text=prompts, return_tensors="pt", padding=True)
with torch.no_grad():
# get_text_features returns a model-output object whose .pooler_output IS
# the projected 512-d joint embedding (verified identical to a full forward
# to 1e-7); normalize it the same way forward() does.
te = model.get_text_features(input_ids=ti["input_ids"],
attention_mask=ti["attention_mask"]).pooler_output
te = te / te.norm(p=2, dim=-1, keepdim=True)
# index → family, for marginalizing per-descriptor probs up to families # index → family, for marginalizing per-descriptor probs up to families
fam_keys = list(dict.fromkeys(fams)) fam_keys = list(dict.fromkeys(fams))
idx = torch.tensor([fam_keys.index(f) for f in fams]) idx = torch.tensor([fam_keys.index(f) for f in fams])
onehot = torch.zeros(len(fams), len(fam_keys)).scatter_(1, idx[:, None], 1.0) onehot = torch.zeros(len(fams), len(fam_keys)).scatter_(1, idx[:, None], 1.0)
_CLAP.update(torch=torch, model=model, proc=proc, ids=ti["input_ids"], _CLAP.update(torch=torch, model=model, proc=proc, text_embeds=te,
mask=ti["attention_mask"], onehot=onehot, fam_keys=fam_keys) logit_scale=model.logit_scale_a.exp().detach(),
onehot=onehot, fam_keys=fam_keys)
return _CLAP return _CLAP
...@@ -141,9 +152,10 @@ def clap_vectors(paths): ...@@ -141,9 +152,10 @@ def clap_vectors(paths):
with torch.no_grad(): with torch.no_grad():
ai = C["proc"](audio=[a for _, a in audios], sampling_rate=SR, ai = C["proc"](audio=[a for _, a in audios], sampling_rate=SR,
return_tensors="pt", padding=True) return_tensors="pt", padding=True)
res = C["model"](input_ids=C["ids"], attention_mask=C["mask"], ae = C["model"].get_audio_features(input_features=ai["input_features"]).pooler_output
input_features=ai["input_features"]) ae = ae / ae.norm(p=2, dim=-1, keepdim=True)
probs = torch.softmax(res.logits_per_audio, dim=-1) # (n_audio, n_text) logits = C["logit_scale"] * ae @ C["text_embeds"].t() # (n_audio, n_text)
probs = torch.softmax(logits, dim=-1)
fam = (probs @ C["onehot"]).tolist() # (n_audio, n_fam) fam = (probs @ C["onehot"]).tolist() # (n_audio, n_fam)
for (i, _), row in zip(audios, fam): for (i, _), row in zip(audios, fam):
out[i] = dict(zip(C["fam_keys"], row)) out[i] = dict(zip(C["fam_keys"], row))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment