123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111 |
- #!/usr/bin/env python3
-
- import unittest
- import wave
-
- import rnnoise
-
# WAV sample opened (read-only) by the tests in this file.
INPUT = "samples/bonjour.wav"
-
class TestRNNoiseModule(unittest.TestCase):
    """ Testing the module-level functions of rnnoise """

    def test_frame_info(self):
        """ Check rnnoise module's methods """
        fsize = rnnoise.frame_size()
        bps = rnnoise.bytes_per_sample()
        # Dedicated assertions report the offending value on failure,
        # whereas assertTrue() can only say "False is not true".
        self.assertIsInstance(fsize, int)
        self.assertIsInstance(bps, int)
        self.assertGreater(fsize, 0)
        self.assertGreater(bps, 0)
-
class TestRNNoise(unittest.TestCase):
    """ Testing RNNoise.process_frame() method """

    def test_process_frame(self):
        """ Processing one frame read from INPUT returns a full frame of bytes """
        rnn = rnnoise.RNNoise()
        with wave.open(INPUT, 'rb') as fp:
            frame = fp.readframes(rnn.frame_size)
        res = rnn.process_frame(frame)
        self.assertIsInstance(res, bytes)
        self.assertEqual(
            len(res), rnn.frame_size * rnnoise.bytes_per_sample())

    def test_process_frame_invalid(self):
        """ Frames with an unexpected byte length must raise ValueError """
        rnn = rnnoise.RNNoise()
        bps = rnnoise.bytes_per_sample()
        # Byte length that test_process_frame expects back for one frame.
        full = rnn.frame_size * bps
        bad_frames = (
            # frame_size *bytes* rather than frame_size samples
            # (presumably too short whenever bytes_per_sample() > 1).
            bytes(rnn.frame_size),
            bytes(full - bps),  # one sample short
            bytes(full + bps),  # one sample too many
            bytes(full - 1),    # one byte short
            bytes(full + 1),    # one byte too many
        )
        for bad_frame in bad_frames:
            msg = "Testing with a bad frame"
            # subTest() keeps going after a failure, so every bad length is
            # reported; its context manager yields nothing worth binding.
            with self.subTest(msg, frame_len=len(bad_frame)):
                with self.assertRaises(ValueError):
                    rnn.process_frame(bad_frame)
-
- class TestRNNoiseIterator(unittest.TestCase):
- """ Testing RNNoise.iter_on() method """
-
- def setUp(self):
- rnn = rnnoise.RNNoise()
- self.orig = bytes()
- with wave.open(INPUT, 'rb') as fp:
- needed = rnn.frame_size * rnnoise.bytes_per_sample()
- while True:
- frame = fp.readframes(rnn.frame_size)
- if len(frame) == 0:
- break
- if len(frame) < needed:
- frame += bytes(needed-len(frame))
- newframe = rnn.process_frame(frame)
- newframe = newframe[:len(frame)]
- self.orig += newframe
-
- @staticmethod
- def in_iter(fsize):
- """ fsize is a callable returning expected frame size """
- with wave.open(INPUT, 'rb') as fp:
- while True:
- frame = fp.readframes(fsize())
- if len(frame):
- yield frame
- else:
- raise StopIteration()
-
- def test_iter_big_samples(self):
- """ Testing with iterator on 4096 samples """
- rnn = rnnoise.RNNoise()
- result = bytes()
- for frame in rnn.iter_on(self.in_iter(lambda: 4096)):
- result += frame
- self.assertEqual(result, self.orig)
-
- def test_byte_sample(self):
- """ Testing with iterator on 1 byte """
- rnn = rnnoise.RNNoise()
- def _cust_iter():
- for sample in self.in_iter(lambda: 1):
- yield sample[:1]
- yield sample[1:]
-
- result = bytes()
- for frame in rnn.iter_on(_cust_iter()):
- result += frame
- self.assertEqual(result, self.orig)
-
- def test_invalid(self):
- """ Testing with invalid iterator on odd number of bytes """
- rnn = rnnoise.RNNoise()
- def _cust_iter():
- data = self.in_iter(lambda: 4096)
- yield next(data)[:-1]
- for sample in data:
- yield sample
-
- with self.assertRaises(ValueError):
- result = bytes()
- for frame in rnn.iter_on(_cust_iter()):
- result += frame
-
# Run the whole suite when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|