Skip to content

Commit e11d52e

Browse files
committed
Add htsengine module
1 parent ac47215 commit e11d52e

File tree

4 files changed

+208
-66
lines changed

4 files changed

+208
-66
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
pyopenjtalk/version.py
22
pyopenjtalk/openjtalk.cpp
3+
pyopenjtalk/htsengine.cpp
34
pyopenjtalk/open_jtalk_dic_utf_8-1.10/
45
pyopenjtalk/open_jtalk_dic_utf_8-1.11/
56
docs/generated

pyopenjtalk/htsengine.pyx

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# coding: utf-8
2+
# cython: boundscheck=True, wraparound=True
3+
# cython: c_string_type=unicode, c_string_encoding=ascii
4+
5+
import numpy as np
6+
7+
cimport numpy as np
8+
np.import_array()
9+
10+
cimport cython
11+
from libc.stdlib cimport malloc, free
12+
13+
from htsengine cimport HTS_Engine
from htsengine cimport (
    HTS_Engine_initialize, HTS_Engine_load, HTS_Engine_clear, HTS_Engine_refresh,
    HTS_Engine_get_sampling_frequency, HTS_Engine_get_fperiod,
    HTS_Engine_get_fullcontext_label_format,
    HTS_Engine_synthesize_from_strings,
    HTS_Engine_get_generated_speech, HTS_Engine_get_nsamples
)
20+
21+
cdef class HTSEngine(object):
    """Wrapper that owns a single ``HTS_Engine`` instance.

    The engine is allocated and initialized on construction and one voice
    file is loaded into it. The engine is cleared and freed when the object
    is deallocated.

    Args:
        voice (bytes): Path to the HTS voice file to load.

    Raises:
        RuntimeError: If ``HTS_Engine_load`` does not report success.
    """
    cdef HTS_Engine* engine

    def __cinit__(self, voice=b"htsvoice/mei_normal.htsvoice"):
        self.engine = new HTS_Engine()

        HTS_Engine_initialize(self.engine)

        if self.load(voice) != 1:
            self.clear()
            raise RuntimeError("Failed to initalize HTS_Engine")

    def load(self, bytes voice):
        """Load a single voice file into the engine.

        Args:
            voice (bytes): Path to the voice file.

        Returns:
            int: Status returned by ``HTS_Engine_load`` (1 is treated as
            success by the constructor).
        """
        # `voices` borrows the buffer of `voice`, which stays alive for the
        # duration of this call.
        cdef char* voices = voice
        cdef char ret
        ret = HTS_Engine_load(self.engine, &voices, 1)
        return ret

    def get_sampling_frequency(self):
        """Get sampling frequency
        """
        return HTS_Engine_get_sampling_frequency(self.engine)

    def get_fperiod(self):
        """Get frame period"""
        return HTS_Engine_get_fperiod(self.engine)

    def synthesize(self, list labels):
        """Synthesize waveform from list of full-context labels

        Args:
            labels: full context labels

        Returns:
            np.ndarray: speech waveform
        """
        self.synthesize_from_strings(labels)
        try:
            return self.get_generated_speech()
        finally:
            # Always reset the engine once synthesis has run, even if copying
            # the waveform out fails (e.g. allocation error).
            self.refresh()

    def synthesize_from_strings(self, list labels):
        """Synthesize from strings

        Args:
            labels (list): Label lines; each item must be convertible to a
                C string.

        Raises:
            MemoryError: If the temporary pointer array cannot be allocated.
            RuntimeError: If ``HTS_Engine_synthesize_from_strings`` does not
                return 1.
        """
        cdef size_t num_lines = len(labels)
        cdef size_t n
        cdef char ret
        cdef char **lines = <char**> malloc((num_lines + 1) * sizeof(char*))
        if lines == NULL:
            raise MemoryError()
        try:
            # The pointers borrow from the items of `labels`; the list keeps
            # them alive until the engine call below has returned.
            for n in range(num_lines):
                lines[n] = <char*>labels[n]
            ret = HTS_Engine_synthesize_from_strings(self.engine, lines, num_lines)
        finally:
            # Release the pointer array on every path, including a failed
            # conversion of one of the labels.
            free(lines)
        if ret != 1:
            raise RuntimeError("Failed to run synthesize_from_strings")

    def get_generated_speech(self):
        """Get generated speech

        Returns:
            np.ndarray: float64 array with one entry per generated sample.
        """
        cdef size_t nsamples = HTS_Engine_get_nsamples(self.engine)
        # Typed buffer so the copy loop below writes straight into the array
        # instead of going through Python-object indexing for every sample.
        cdef np.ndarray[np.float64_t, ndim=1] speech = np.zeros([nsamples], dtype=np.float64)
        cdef size_t index
        for index in range(nsamples):
            speech[index] = HTS_Engine_get_generated_speech(self.engine, index)
        return speech

    def get_fullcontext_label_format(self):
        """Get full-context label format"""
        return (<bytes>HTS_Engine_get_fullcontext_label_format(self.engine)).decode("utf-8")

    def refresh(self):
        """Call ``HTS_Engine_refresh`` on the wrapped engine."""
        HTS_Engine_refresh(self.engine)

    def clear(self):
        """Call ``HTS_Engine_clear`` on the wrapped engine, if one exists."""
        if self.engine != NULL:
            HTS_Engine_clear(self.engine)

    def __dealloc__(self):
        # `engine` is still NULL when construction failed before allocation
        # (e.g. wrong argument count), so guard before touching it. The C
        # function is called directly: Python-level method calls are unsafe
        # while the object is being torn down.
        if self.engine != NULL:
            HTS_Engine_clear(self.engine)
            del self.engine
            self.engine = NULL

pyopenjtalk/htsengine/__init__.pxd

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# distutils: language = c++
2+
3+
4+
# Cython declarations for the part of the hts_engine_API C interface
# (header "HTS_engine.h") that htsengine.pyx uses.
cdef extern from "HTS_engine.h":
    # The engine is declared as an opaque type: no members are exposed here.
    # NOTE(review): presumably declared as a cppclass so that htsengine.pyx can
    # allocate and free it with `new` / `del` -- confirm against the header.
    cdef cppclass _HTS_Engine:
        pass
    ctypedef _HTS_Engine HTS_Engine

    # Lifecycle and voice loading. htsengine.pyx treats a return value of 1
    # from the `char`-returning functions as success.
    void HTS_Engine_initialize(HTS_Engine * engine)
    char HTS_Engine_load(HTS_Engine * engine, char **voices, size_t num_voices)
    # Engine properties.
    size_t HTS_Engine_get_sampling_frequency(HTS_Engine * engine)
    size_t HTS_Engine_get_fperiod(HTS_Engine * engine)
    # State reset / teardown.
    void HTS_Engine_refresh(HTS_Engine * engine)
    void HTS_Engine_clear(HTS_Engine * engine)
    const char *HTS_Engine_get_fullcontext_label_format(HTS_Engine * engine)
    # Synthesis entry points: from in-memory label lines or from a file name.
    char HTS_Engine_synthesize_from_strings(HTS_Engine * engine, char **lines, size_t num_lines)
    char HTS_Engine_synthesize_from_fn(HTS_Engine * engine, const char *fn)
    # Access to the generated waveform: one sample by index, and the sample count.
    double HTS_Engine_get_generated_speech(HTS_Engine * engine, size_t index)
    size_t HTS_Engine_get_nsamples(HTS_Engine * engine)

setup.py

+91-66
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,41 @@
1-
# coding: utf-8
2-
3-
from __future__ import with_statement, print_function, absolute_import
4-
5-
from setuptools import setup, find_packages, Extension
6-
import setuptools.command.develop
7-
import setuptools.command.build_py
8-
from distutils.version import LooseVersion
9-
import subprocess
10-
import numpy as np
111
import os
2+
import subprocess
3+
from distutils.version import LooseVersion
124
from glob import glob
13-
from os.path import join, exists
5+
from os.path import exists, join
146
from subprocess import run
157

16-
version = '0.0.3'
8+
import numpy as np
9+
import setuptools.command.build_py
10+
import setuptools.command.develop
11+
from setuptools import Extension, find_packages, setup
12+
13+
version = "0.0.3"
1714

18-
min_cython_ver = '0.21.0'
15+
min_cython_ver = "0.21.0"
1916
try:
2017
import Cython
18+
2119
ver = Cython.__version__
2220
_CYTHON_INSTALLED = ver >= LooseVersion(min_cython_ver)
2321
except ImportError:
2422
_CYTHON_INSTALLED = False
2523

2624
try:
2725
if not _CYTHON_INSTALLED:
28-
raise ImportError('No supported version of Cython installed.')
29-
from Cython.Distutils import build_ext
26+
raise ImportError("No supported version of Cython installed.")
3027
from Cython.Build import cythonize
28+
from Cython.Distutils import build_ext
29+
3130
cython = True
3231
except ImportError:
3332
cython = False
3433

3534
if cython:
36-
ext = '.pyx'
37-
cmdclass = {'build_ext': build_ext}
35+
ext = ".pyx"
36+
cmdclass = {"build_ext": build_ext}
3837
else:
39-
ext = '.cpp'
38+
ext = ".cpp"
4039
cmdclass = {}
4140
if not os.path.exists(join("pyopenjtalk", "openjtalk" + ext)):
4241
raise RuntimeError("Cython is required to generate C++ code")
@@ -59,99 +58,125 @@
5958
all_src = []
6059
include_dirs = []
6160
for s in [
62-
"jpcommon", "mecab/src", "mecab2njd", "njd", "njd2jpcommon",
63-
"njd_set_accent_phrase", "njd_set_accent_type",
64-
"njd_set_digit", "njd_set_long_vowel", "njd_set_pronunciation",
65-
"njd_set_unvoiced_vowel", "text2mecab",
61+
"jpcommon",
62+
"mecab/src",
63+
"mecab2njd",
64+
"njd",
65+
"njd2jpcommon",
66+
"njd_set_accent_phrase",
67+
"njd_set_accent_type",
68+
"njd_set_digit",
69+
"njd_set_long_vowel",
70+
"njd_set_pronunciation",
71+
"njd_set_unvoiced_vowel",
72+
"text2mecab",
6673
]:
6774
all_src += glob(join(src_top, s, "*.c"))
6875
all_src += glob(join(src_top, s, "*.cpp"))
6976
include_dirs.append(join(os.getcwd(), src_top, s))
7077

71-
# define core cython module
72-
ext_modules = [Extension(
73-
name="pyopenjtalk.openjtalk",
74-
sources=[join("pyopenjtalk", "openjtalk" + ext)] + all_src,
75-
include_dirs=[np.get_include()] + include_dirs,
76-
extra_compile_args=[],
77-
extra_link_args=[],
78-
language="c++",
79-
define_macros=[
80-
("HAVE_CONFIG_H", None),
81-
("DIC_VERSION", 102), ("MECAB_DEFAULT_RC", "\"dummy\""),
82-
("PACKAGE", "\"open_jtalk\""),
83-
("VERSION", "\"1.10\""),
84-
("CHARSET_UTF_8", None),
85-
]
86-
)]
78+
# Extension for OpenJTalk frontend
79+
ext_modules = [
80+
Extension(
81+
name="pyopenjtalk.openjtalk",
82+
sources=[join("pyopenjtalk", "openjtalk" + ext)] + all_src,
83+
include_dirs=[np.get_include()] + include_dirs,
84+
extra_compile_args=[],
85+
extra_link_args=[],
86+
language="c++",
87+
define_macros=[
88+
("HAVE_CONFIG_H", None),
89+
("DIC_VERSION", 102),
90+
("MECAB_DEFAULT_RC", '"dummy"'),
91+
("PACKAGE", '"open_jtalk"'),
92+
("VERSION", '"1.10"'),
93+
("CHARSET_UTF_8", None),
94+
],
95+
)
96+
]
97+
98+
# Extension for HTSEngine backend
99+
htsengine_src_top = join("lib", "hts_engine_API", "src")
100+
all_htsengine_src = glob(join(htsengine_src_top, "lib", "*.c"))
101+
ext_modules += [
102+
Extension(
103+
name="pyopenjtalk.htsengine",
104+
sources=[join("pyopenjtalk", "htsengine" + ext)] + all_htsengine_src,
105+
include_dirs=[np.get_include(), join(htsengine_src_top, "include")],
106+
extra_compile_args=[],
107+
extra_link_args=[],
108+
language="c++",
109+
)
110+
]
87111

88112
# Adapted from https://github.com/pytorch/pytorch
89113
cwd = os.path.dirname(os.path.abspath(__file__))
90-
if os.getenv('PYOPENJTALK_BUILD_VERSION'):
91-
version = os.getenv('PYOPENJTALK_BUILD_VERSION')
114+
if os.getenv("PYOPENJTALK_BUILD_VERSION"):
115+
version = os.getenv("PYOPENJTALK_BUILD_VERSION")
92116
else:
93117
try:
94-
sha = subprocess.check_output(
95-
['git', 'rev-parse', 'HEAD'], cwd=cwd).decode('ascii').strip()
96-
version += '+' + sha[:7]
118+
sha = (
119+
subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd)
120+
.decode("ascii")
121+
.strip()
122+
)
123+
version += "+" + sha[:7]
97124
except subprocess.CalledProcessError:
98125
pass
99126
except IOError: # FileNotFoundError for python 3
100127
pass
101128

102129

103130
class build_py(setuptools.command.build_py.build_py):
104-
105131
def run(self):
106132
self.create_version_file()
107133
setuptools.command.build_py.build_py.run(self)
108134

109135
@staticmethod
110136
def create_version_file():
111137
global version, cwd
112-
print('-- Building version ' + version)
113-
version_path = os.path.join(cwd, 'pyopenjtalk', 'version.py')
114-
with open(version_path, 'w') as f:
138+
print("-- Building version " + version)
139+
version_path = os.path.join(cwd, "pyopenjtalk", "version.py")
140+
with open(version_path, "w") as f:
115141
f.write("__version__ = '{}'\n".format(version))
116142

117143

118144
class develop(setuptools.command.develop.develop):
119-
120145
def run(self):
121146
build_py.create_version_file()
122147
setuptools.command.develop.develop.run(self)
123148

124149

125-
cmdclass['build_py'] = build_py
126-
cmdclass['develop'] = develop
150+
cmdclass["build_py"] = build_py
151+
cmdclass["develop"] = develop
127152

128153

129-
with open('README.md', 'r') as fd:
154+
with open("README.md", "r") as fd:
130155
long_description = fd.read()
131156

132157
setup(
133-
name='pyopenjtalk',
158+
name="pyopenjtalk",
134159
version=version,
135-
description='A python wrapper for OpenJTalk',
160+
description="A python wrapper for OpenJTalk",
136161
long_description=long_description,
137-
long_description_content_type='text/markdown',
138-
author='Ryuichi Yamamoto',
139-
author_email='zryuichi@gmail.com',
140-
url='https://github.com/r9y9/pyopenjtalk',
141-
license='MIT',
162+
long_description_content_type="text/markdown",
163+
author="Ryuichi Yamamoto",
164+
author_email="zryuichi@gmail.com",
165+
url="https://github.com/r9y9/pyopenjtalk",
166+
license="MIT",
142167
packages=find_packages(),
143-
package_data={'': ['htsvoice/*']},
168+
package_data={"": ["htsvoice/*"]},
144169
ext_modules=ext_modules,
145170
cmdclass=cmdclass,
146171
install_requires=[
147-
'numpy >= 1.8.0',
148-
'cython >= ' + min_cython_ver,
149-
'six',
172+
"numpy >= 1.8.0",
173+
"cython >= " + min_cython_ver,
174+
"six",
150175
],
151-
tests_require=['nose', 'coverage'],
176+
tests_require=["nose", "coverage"],
152177
extras_require={
153-
'docs': ['sphinx_rtd_theme'],
154-
'test': ['nose', 'scipy'],
178+
"docs": ["sphinx_rtd_theme"],
179+
"test": ["nose", "scipy"],
155180
},
156181
classifiers=[
157182
"Operating System :: POSIX",
@@ -168,5 +193,5 @@ def run(self):
168193
"Intended Audience :: Science/Research",
169194
"Intended Audience :: Developers",
170195
],
171-
keywords=["OpenJTalk", "Research"]
196+
keywords=["OpenJTalk", "Research"],
172197
)

0 commit comments

Comments
 (0)