You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

73 lines
2.5 KiB

import argparse
import pickle
from pathlib import Path
import numpy as np
import whisper
from sklearn.pipeline import Pipeline
# Per-feature normal-distribution statistics (mean / std) for the demographic
# features fed to the regression pipelines; main() samples a placeholder value
# for each from these distributions.
# NOTE(review): values look like training-cohort statistics (age in years,
# sex coded ~1/2, CAG repeat count on allele 2) -- confirm their provenance.
DEMOG_STATS = {
'DEMOG_AGE': {
'mean': 52.971585365853656,
'std': 10.597689763329704},
'DEMOG_SEX': {
'mean': 1.5365853658536586,
'std': 0.5017284187715013},
'HIST_CAG_NBALLEL2': {
'mean': 43.5609756097561,
'std': 2.867861770588018, }}
# Maps the score identifier used in the model filename
# (models/reg_lin_<key>.pkl) to the short display name printed to the user.
PREDICTED_SCORES = {'cUHDRS': 'cUHDRS',
'UHDRS_SCORE_MOTEUR': 'TMS',
'UHDRS_TFC_SCORE': 'TFC'}
def extract_logprob_feats(audio_path: Path, model_name: str = "base") -> tuple[float, float]:
    """Transcribe an audio file with Whisper and summarise segment confidence.

    Parameters
    ----------
    audio_path : Path
        Path to the audio file to transcribe.
    model_name : str
        Whisper model size to load.  Defaults to ``"base"`` (the original
        hard-coded value) so existing callers are unaffected, but callers can
        now forward the CLI ``--model`` choice instead of it being ignored.

    Returns
    -------
    tuple[float, float]
        Mean and standard deviation of the per-segment ``avg_logprob``
        values reported by Whisper.  ``(nan, nan)`` if no segments were
        produced.
    """
    model = whisper.load_model(model_name)
    result = model.transcribe(str(audio_path))
    logprobs = [segment["avg_logprob"] for segment in result["segments"]]
    # Guard the empty-transcription case: np.mean([]) emits a RuntimeWarning
    # and yields nan anyway -- return the nan explicitly and silently.
    if not logprobs:
        return float("nan"), float("nan")
    return float(np.mean(logprobs)), float(np.std(logprobs))
def main():
    """CLI entry point.

    Transcribes the given audio file, derives log-probability features from
    the Whisper segments, combines them with sampled (fake) demographic
    values, then loads one pickled linear-regression pipeline per clinical
    score and prints its prediction.

    Raises
    ------
    FileNotFoundError
        If the audio file does not exist.
    """
    parser = argparse.ArgumentParser(description="Analyze audio transcription perplexity")
    parser.add_argument("audio_file", type=str, help="Path to the audio file to analyze")
    parser.add_argument("--model", type=str, default="large",
                        choices=["tiny", "base", "small", "medium", "large"],
                        help="Size of the whisper model to use")
    args = parser.parse_args()

    audio_path = Path(args.audio_file)
    if not audio_path.exists():
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    # TODO(review): args.model is parsed but never forwarded --
    # extract_logprob_feats() chooses the Whisper model size itself.
    mean_logprob, std_logprob = extract_logprob_feats(audio_path)
    print(f"Mean log probability: {mean_logprob:.4f}")
    print(f"Standard deviation: {std_logprob:.4f}")

    feats = {
        'mean_logprob': mean_logprob,
        'std_logprob': std_logprob,
    }
    # Sample a random (fake) value for each demographic feature; real values
    # are not available at inference time.
    for demog_name, demog_stats in DEMOG_STATS.items():
        feats[demog_name] = np.random.normal(demog_stats['mean'], demog_stats['std'])

    # Feature order must match the order the pickled pipelines were fit with.
    feat_names_order = ['mean_logprob', 'std_logprob', 'DEMOG_AGE', 'HIST_CAG_NBALLEL2', 'DEMOG_SEX']
    feat_array = np.array([feats[feat_name] for feat_name in feat_names_order])

    # For each score, load the corresponding model and predict the score.
    # NOTE(review): pickle.load executes arbitrary code from the model file --
    # only ship/load models from a trusted source.
    for score_real_name, score_name in PREDICTED_SCORES.items():
        with open(f"models/reg_lin_{score_real_name}.pkl", 'rb') as f:
            score_model: Pipeline = pickle.load(f)
        predicted_score = score_model.predict(feat_array.reshape(1, -1))
        # predict() returns a 1-element array; print the scalar value rather
        # than the array repr (the original printed e.g. "[12.3]").
        print(f"Predicted {score_name}: {predicted_score[0]}")
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()