Browse Source

Forgot to parse replacement diphones. Whoops.

master
hadware 6 years ago
parent
commit
30d9100d9b
  1. 66
      db_reader.py

66
db_reader.py

@ -1,4 +1,5 @@
import argparse
import struct
from dataclasses import dataclass, field
from io import SEEK_SET
from pathlib import Path
@ -20,10 +21,10 @@ def read_str(io: BinaryIO):
output_str = ""
while True:
char = unpack_from("<c", io.read(1))[0]
if char == "\x00": # null char at the end of string
if char == b"\x00": # null char at the end of string
break
else:
output_str += char
output_str += char.decode('latin1')
return output_str
@ -71,8 +72,15 @@ class MbrolaDatabase:
silence_phone: str = "_" # silence symbol in the database
diphone_table: Dict[Diphone, DiphoneInfo] = field(default_factory=dict)
remapped_diphones: Dict[Diphone, Diphone] = field(default_factory=dict)
pitch_marks: np.ndarray = None
def __getitem__(self, diph: Diphone):
    """Returns the entry of ``diphone_table`` for a diphone

    A diphone that is not a key of ``diphone_table`` is first translated
    through ``remapped_diphones`` and the entry of its replacement is
    returned instead. A KeyError propagates when neither table knows it.
    """
    try:
        info = self.diphone_table[diph]
    except KeyError:
        # not stored directly: fall back on the diphone it is remapped to
        replacement = self.remapped_diphones[diph]
        info = self.diphone_table[replacement]
    return info
def read_header(self, db_file: BinaryIO):
"""Reads the database header"""
self.magic_header = db_file.read(6).decode()
@ -113,6 +121,8 @@ class MbrolaDatabase:
halfseg=half_segment,
pos_pm=position_pm,
nb_frame=nb_frames)
if (left_phone, right_phone) in self.diphone_table:
print((left_phone, right_phone), " already in table")
self.diphone_table[(left_phone, right_phone)] = new_diph
# keep track of the phoneme with the biggest number of frames
@ -120,18 +130,42 @@ class MbrolaDatabase:
self.max_frame = nb_wframe
i += 1
# retrieving all phonemes remaps
while i < self.nb_diphone:
left_phone = read_str(db_file)
right_phone = read_str(db_file)
repl_diph = (left_phone, right_phone)
if repl_diph not in self.diphone_table:
print(f"({repl_diph[0]}, {repl_diph[1]}) were not present in the database!")
continue
left_phone = read_str(db_file)
right_phone = read_str(db_file)
remapped_diph = (left_phone, right_phone)
self.remapped_diphones[remapped_diph] = repl_diph
i += 1
# sanity check
assert len(self.diphone_table) + len(self.remapped_diphones) == self.nb_diphone
def read_pitchmarks(self, db_file: BinaryIO):
    """Reads the pitch mark bytes and records where the raw data starts

    Reads ``(size_mark + 3) // 4`` bytes from the current position of
    ``db_file``, stores them as an array of unsigned byte values in
    ``self.pitch_marks``, then saves the resulting stream position in
    ``self.raw_offset``.
    """
    # size_mark divided by 4, rounded up
    # NOTE(review): presumably four marks are packed in each byte - confirm
    round_size = (self.size_mark + 3) // 4
    # the bytes have to be decoded with struct before building the array:
    # np.array() cannot interpret a format string and a bytes buffer itself
    offsets = unpack(f"<{round_size}B", db_file.read(round_size))
    self.pitch_marks = np.array(offsets)
    # whatever follows the pitch marks begins at this position
    self.raw_offset = db_file.tell()
def read_info(self, db_file: BinaryIO):
    """Reads the info strings located after the raw data section

    Moves ``db_file`` to ``raw_offset + size_raw`` and stores every
    non-empty null-terminated string read from there in ``self.info``.
    Reading stops at the first EOFError or struct.error.
    """
    # whence is passed positionally: seek() on binary file objects
    # does not accept keyword arguments
    db_file.seek(self.raw_offset + self.size_raw, SEEK_SET)
    self.info = []
    while True:
        try:
            s = read_str(db_file)
        except (EOFError, struct.error):
            # read_str unpacks the result of a 1-byte read, which fails
            # with struct.error once the stream has nothing left
            break
        # empty strings are skipped
        if s:
            self.info.append(s)
def read_database(self):
@ -154,23 +188,5 @@ if __name__ == "__main__":
raise ValueError("Can't find database file")
mbr_db = MbrolaDatabase(db_path=db_path)
mbr_db.read_database()
with open(db_path, "rb") as db_file:
# should read "mbrola"
mbr_db.magic_header = db_file.read(6).decode()
mbr_db.version = db_file.read(5).decode()
mbr_db.nb_diphone = unpack_from("<h", db_file)[0]
old_size_mark = unpack_from("<H", db_file)[0]
if old_size_mark == 0:
mbr_db.size_mark = unpack_from("<i", db_file)
else:
mbr_db.size_mark = old_size_mark
mbr_db.size_raw = unpack_from("<i", db_file)
mbr_db.freq = unpack_from("<h", db_file)
mbr_db.mbr_period = unpack_from("<B", db_file)
mbr_db.coding = unpack_from("<B", db_file)
# TODO:
# ReadDatabaseIndex(dba) ||
# !ReadDatabasePitchMark(dba) ||
#  !ReadDatabaseInfo(dba)

Loading…
Cancel
Save