You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
73 lines
2.5 KiB
73 lines
2.5 KiB
import argparse |
|
import pickle |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
import whisper |
|
from sklearn.pipeline import Pipeline |
|
|
|
# Mean and standard deviation of each demographic feature.
# main() draws a random placeholder value for every entry from a normal
# distribution with these parameters, iterating in the order written here.
# NOTE(review): presumably computed on the cohort used to fit the score
# regressors -- confirm the source of these numbers.
DEMOG_STATS = {
    'DEMOG_AGE': {
        'mean': 52.971585365853656,
        'std': 10.597689763329704,
    },
    'DEMOG_SEX': {
        'mean': 1.5365853658536586,
        'std': 0.5017284187715013,
    },
    'HIST_CAG_NBALLEL2': {
        'mean': 43.5609756097561,
        'std': 2.867861770588018,
    },
}
|
|
|
# Scores predicted by main(). Each key is the name embedded in the pickled
# regressor's file name (models/reg_lin_<key>.pkl); each value is the short
# label printed next to the prediction.
PREDICTED_SCORES = {
    'cUHDRS': 'cUHDRS',
    'UHDRS_SCORE_MOTEUR': 'TMS',
    'UHDRS_TFC_SCORE': 'TFC',
}
|
|
|
|
|
def extract_logprob_feats(audio_path: Path, model_name: str = "base"):
    """Transcribe an audio file with Whisper and summarise segment confidence.

    Args:
        audio_path: Path to the audio file to transcribe.
        model_name: Name of the Whisper checkpoint handed to
            ``whisper.load_model``. Defaults to ``"base"``, the value this
            function used before the parameter existed.

    Returns:
        A ``(mean_logprob, std_logprob)`` tuple of floats: the mean and the
        standard deviation (NumPy default, ``ddof=0``) of the ``avg_logprob``
        reported for each transcribed segment.

    Raises:
        ValueError: If the transcription contains no segments, in which case
            no statistics can be computed.
    """
    model = whisper.load_model(model_name)
    result = model.transcribe(str(audio_path))

    logprobs = [segment["avg_logprob"] for segment in result["segments"]]
    if not logprobs:
        # np.mean/np.std of an empty list only emit a RuntimeWarning and
        # return nan, which then fails much later with an unrelated message.
        raise ValueError(f"No segments were transcribed from: {audio_path}")

    return float(np.mean(logprobs)), float(np.std(logprobs))


def main():
    """Command-line entry point.

    Parses the arguments, extracts the log-probability features from the
    given audio file, completes them with randomly sampled placeholder
    demographic features, and prints the prediction of every pickled
    regressor listed in ``PREDICTED_SCORES``.

    Raises:
        FileNotFoundError: If the audio file does not exist.
    """
    parser = argparse.ArgumentParser(description="Analyze audio transcription perplexity")
    parser.add_argument("audio_file", type=str, help="Path to the audio file to analyze")
    parser.add_argument("--model", type=str, default="large",
                        choices=["tiny", "base", "small", "medium", "large"],
                        help="Size of the whisper model to use")
    args = parser.parse_args()
    audio_path = Path(args.audio_file)

    if not audio_path.exists():
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    # Honour --model: it used to be parsed and then ignored, so "base" was
    # always loaded regardless of what the user asked for.
    # NOTE(review): the pickled regressors were presumably fitted on features
    # from one specific Whisper size -- confirm it matches the CLI default.
    mean_logprob, std_logprob = extract_logprob_feats(audio_path, model_name=args.model)
    print(f"Mean log probability: {mean_logprob:.4f}")
    print(f"Standard deviation: {std_logprob:.4f}")

    feats = {
        'mean_logprob': mean_logprob,
        'std_logprob': std_logprob,
    }

    # Sample a random (fake) value for each demographic feature; the output
    # is therefore not reproducible from one run to the next.
    for demog_name, demog_stats in DEMOG_STATS.items():
        feats[demog_name] = np.random.normal(demog_stats['mean'], demog_stats['std'])

    # Column order fed to the regressors; note it differs from the
    # iteration order of DEMOG_STATS.
    feat_names_order = ['mean_logprob', 'std_logprob', 'DEMOG_AGE', 'HIST_CAG_NBALLEL2', 'DEMOG_SEX']
    feat_array = np.array([feats[feat_name] for feat_name in feat_names_order])
    # predict() expects a 2-D (n_samples, n_features) input; build the
    # single-sample row once instead of on every loop iteration.
    feat_row = feat_array.reshape(1, -1)

    # For each score, load the corresponding model and predict the score.
    for score_real_name, score_name in PREDICTED_SCORES.items():
        # NOTE: unpickling executes code stored in the file; only load model
        # files that come from a trusted source.
        with open(f"models/reg_lin_{score_real_name}.pkl", 'rb') as f:
            score_model: Pipeline = pickle.load(f)
        predicted_score = score_model.predict(feat_row)
        print(f"Predicted {score_name}: {predicted_score}")
|
|
|
|
|
# Run the command-line entry point only when executed as a script,
# not when the module is imported.
if __name__ == "__main__":
    main()
|
|
|