You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

192 lines
6.6 KiB

import argparse
import struct
from dataclasses import dataclass, field
from io import SEEK_SET
from pathlib import Path
from struct import unpack_from, unpack
from typing import List, Tuple, BinaryIO, Dict
import numpy as np
# Default install location of MBROLA voices; the __main__ block looks the
# requested database up under this folder first.
MBROLA_VOICES_FOLDER = Path("/usr/share/mbrola/")

# Command line: a single positional argument naming the database to read.
argparser = argparse.ArgumentParser()
argparser.add_argument("mbrola_db", type=Path,
                       help="Mbrola db name or direct path")

PhonemeCode = int  # alias for integer phoneme codes
Diphone = Tuple[str, str]  # (left phone, right phone) name pair
def read_str(io: BinaryIO):
    """Read one NUL-terminated string from a binary stream.

    Bytes are consumed one at a time up to and including the terminating
    NUL byte, which is not part of the returned string. The collected bytes
    are decoded as latin-1 (one character per byte).

    Raises struct.error if the stream ends before a NUL byte is found.
    """
    # iter(callable, sentinel) keeps calling the reader until it yields NUL;
    # on EOF unpack_from receives b"" and raises struct.error, which
    # propagates out of the loop.
    collected = []
    for byte in iter(lambda: unpack_from("<c", io.read(1))[0], b"\x00"):
        collected.append(byte)
    return b"".join(collected).decode('latin1')
@dataclass
class DiphoneInfo:
    """One entry of the diphone index: names plus where its data lives."""

    left: str  # left phone name
    right: str  # right phone name
    pos_wave: int  # position in SPEECH_FILE
    halfseg: int  # position of center of diphone
    pos_pm: int  # index in PITCHMARK_FILE
    nb_frame: int  # Number of pitch markers

    @property
    def left_code(self):
        """Code of the left phone. Placeholder: not implemented, returns None."""
        return None

    @property
    def right_code(self):
        """Code of the right phone. Placeholder: not implemented, returns None."""
        return None
@dataclass
class FrameType:
    """Placeholder for a pitch-mark frame type; defines no fields yet."""
@dataclass
class MbrolaDatabase:
    """In-memory index of an MBROLA diphone database file.

    Create it with ``db_path`` and call :meth:`read_database` to parse, in
    order, the header, the diphone index (including remapped diphone names),
    the packed pitch-mark table and the trailing info strings. Waveform
    samples are not loaded: only their offset (``raw_offset``) and size
    (``size_raw``) are recorded. Diphones are looked up with
    ``db[(left, right)]``.
    """

    # TODO : init default values as specified in the database_init
    db_path: Path  # path to the database file
    coding: int = 1  # database format. Might be useless
    freq: int = 0  # Sampling frequency of the database
    mbr_period: int = 0  # Period of the MBR analysis
    nb_diphone: int = 0  # Number of index entries (stored + remapped diphones)
    size_mark: int = 0  # Size of the pitchmark part (number of pitch marks)
    size_raw: int = 0  # Size of the wave part, in bytes
    raw_offset: int = 0  # Offset for raw samples in database
    max_frame: int = 0  # Largest nb_wframe seen for a diphone in the index
    max_samples: int = 0  # Size of the diphone buffer= 0 means let me manage it myself
    magic_header: str = "MBROLA"  # Magic header of the database
    version: str = "2.06"  # version of the database
    info: List[str] = None  # info strings; stays None until read_info() runs
    silence_phone: str = "_"  # silence symbol in the database
    diphone_table: Dict[Diphone, DiphoneInfo] = field(default_factory=dict)
    # alias diphone -> diphone whose data it reuses (a key of diphone_table)
    remapped_diphones: Dict[Diphone, Diphone] = field(default_factory=dict)
    pitch_marks: np.ndarray = None  # packed pitch-mark bytes; None until read_pitchmarks() runs

    def __getitem__(self, diph: Diphone) -> DiphoneInfo:
        """Return the info of ``diph``, following a remap if it is an alias.

        Raises KeyError(diph) if the diphone is neither stored nor remapped.
        """
        if diph in self.diphone_table:
            return self.diphone_table[diph]
        try:
            target = self.remapped_diphones[diph]
        except KeyError:
            # Report the requested diphone without the chained traceback of
            # the first failed lookup.
            raise KeyError(diph) from None
        return self.diphone_table[target]

    def read_header(self, db_file: BinaryIO) -> None:
        """Read the fixed-layout header at the current position of ``db_file``.

        All integers are little-endian. When the 16-bit pitch-mark count is
        0, the actual count is read from a following 32-bit field.
        """
        self.magic_header = db_file.read(6).decode()
        self.version = db_file.read(5).decode()
        self.nb_diphone = unpack_from("<h", db_file.read(2))[0]
        old_size_mark = unpack_from("<H", db_file.read(2))[0]
        if old_size_mark == 0:
            self.size_mark = unpack_from("<i", db_file.read(4))[0]
        else:
            self.size_mark = old_size_mark
        self.size_raw = unpack_from("<i", db_file.read(4))[0]
        self.freq = unpack_from("<h", db_file.read(2))[0]
        self.mbr_period = unpack_from("<B", db_file.read(1))[0]
        self.coding = unpack_from("<B", db_file.read(1))[0]

    def read_index(self, db_file: BinaryIO) -> None:
        """Read the diphone index table that follows the header.

        The index has two consecutive sections, ``nb_diphone`` entries in
        total:

        1. Stored diphones. Each entry is: left phone, right phone
           (NUL-terminated strings), half-segment (int16), number of pitch
           marks (uint8) and number of wave frames (uint8). The section ends
           once ``size_mark`` pitch marks have been accounted for.
        2. Remapped diphones. Each entry is two name pairs: an existing
           diphone followed by an alias that reuses its data.
        """
        i = 0
        pm_index = 0  # cumulative position in pitch mark vector
        wav_index = 0  # cumulative position in the waveform database
        while pm_index != self.size_mark and i < self.nb_diphone:
            left_phone = read_str(db_file)
            right_phone = read_str(db_file)
            half_segment = unpack_from("<h", db_file.read(2))[0]
            nb_frames = unpack_from("<B", db_file.read(1))[0]
            nb_wframe = unpack_from("<B", db_file.read(1))[0]

            position_pm = pm_index
            pm_index += nb_frames
            if pm_index == self.size_mark:
                # The entry that uses up the last pitch marks gives the
                # silence symbol.
                self.silence_phone = left_phone

            new_diph = DiphoneInfo(left=left_phone, right=right_phone,
                                   pos_wave=wav_index,
                                   halfseg=half_segment,
                                   pos_pm=position_pm,
                                   nb_frame=nb_frames)
            # wav_index is a running offset: add this diphone's size rather
            # than overwriting it, so the next diphone starts after this one.
            wav_index += nb_wframe * self.mbr_period

            if (left_phone, right_phone) in self.diphone_table:
                print((left_phone, right_phone), " already in table")
            self.diphone_table[(left_phone, right_phone)] = new_diph
            # keep track of the biggest number of wave frames of any diphone
            if self.max_frame < nb_wframe:
                self.max_frame = nb_wframe
            i += 1

        # retrieving all phonemes remaps
        skipped_remaps = 0
        while i < self.nb_diphone:
            # Always consume both name pairs, even for an entry that ends up
            # skipped, so the following entries stay aligned in the stream.
            repl_left = read_str(db_file)
            repl_right = read_str(db_file)
            alias_left = read_str(db_file)
            alias_right = read_str(db_file)
            repl_diph = (repl_left, repl_right)
            remapped_diph = (alias_left, alias_right)
            i += 1
            if repl_diph not in self.diphone_table:
                print(f"({repl_diph[0]}, {repl_diph[1]}) were not present in the database!")
                skipped_remaps += 1
                continue
            self.remapped_diphones[remapped_diph] = repl_diph

        # sanity check: every index entry was stored, remapped, or reported
        # as skipped above
        assert (len(self.diphone_table) + len(self.remapped_diphones)
                + skipped_remaps == self.nb_diphone)

    def read_pitchmarks(self, db_file: BinaryIO) -> None:
        """Read the pitch-mark table that follows the index.

        The table occupies ceil(size_mark / 4) bytes. They are stored as-is
        (still packed) in ``pitch_marks``. The stream position right after
        the table is recorded as ``raw_offset``, where the samples start.
        """
        round_size = (self.size_mark + 3) // 4
        offsets = unpack(f"<{round_size}B", db_file.read(round_size))
        self.pitch_marks = np.array(offsets)
        self.raw_offset = db_file.tell()

    def read_info(self, db_file: BinaryIO) -> None:
        """Collect the info strings stored after the waveform section.

        Seeks to ``raw_offset + size_raw`` and reads NUL-terminated strings
        until end of file. Empty strings are skipped, and a truncated final
        string is dropped.
        """
        db_file.seek(self.raw_offset + self.size_raw, SEEK_SET)
        self.info = []
        while True:
            try:
                s = read_str(db_file)
            except (EOFError, struct.error):
                # read_str signals end of stream with struct.error
                break
            if s:
                self.info.append(s)

    def read_database(self) -> None:
        """Parse the whole file at ``db_path`` into this object.

        The readers run in file order; each continues where the previous one
        stopped, except read_info, which seeks explicitly.
        """
        with open(self.db_path, "rb") as db_file:
            self.read_header(db_file)
            self.read_index(db_file)
            self.read_pitchmarks(db_file)
            self.read_info(db_file)
if __name__ == "__main__":
    cli_args = argparser.parse_args()
    requested = cli_args.mbrola_db
    # Candidate locations, most specific first: <voices>/<name>/<name>,
    # then <voices>/<name>, then the argument taken as a path on its own.
    candidates = (
        MBROLA_VOICES_FOLDER / requested / requested,
        MBROLA_VOICES_FOLDER / requested,
        requested,
    )
    # The generator stops at the first existing file, like the if/elif chain.
    db_path = next((candidate for candidate in candidates if candidate.is_file()), None)
    if db_path is None:
        raise ValueError("Can't find database file")
    mbr_db = MbrolaDatabase(db_path=db_path)
    mbr_db.read_database()