Python wrapper for Xiph.org rnnoise (https://gitlab.xiph.org/xiph/rnnoise)
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

test_pyrnnoise.py 3.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. #!/usr/bin/env python3
  2. import unittest
  3. import wave
  4. import rnnoise
  5. INPUT="samples/bonjour.wav"
  6. class TestRNNoiseModule(unittest.TestCase):
  7. def test_frame_info(self):
  8. """ Check rnnoise module's methods """
  9. fsize = rnnoise.frame_size()
  10. bps = rnnoise.bytes_per_sample()
  11. self.assertTrue(isinstance(fsize, int))
  12. self.assertTrue(isinstance(bps, int))
  13. self.assertTrue(fsize > 0)
  14. self.assertTrue(bps > 0)
  15. class TestRNNoise(unittest.TestCase):
  16. def test_process_frame(self):
  17. rnn = rnnoise.RNNoise()
  18. with wave.open(INPUT, 'rb') as fp:
  19. frame = fp.readframes(rnn.frame_size)
  20. res = rnn.process_frame(frame)
  21. self.assertTrue(isinstance(res, bytes))
  22. self.assertEqual(len(res), rnn.frame_size*rnnoise.bytes_per_sample())
  23. def test_process_frame_invalid(self):
  24. rnn = rnnoise.RNNoise()
  25. bad_frames = (
  26. bytes(rnn.frame_size),
  27. bytes((rnn.frame_size-1)*rnnoise.bytes_per_sample()),
  28. bytes((rnn.frame_size+1)*rnnoise.bytes_per_sample()),
  29. bytes((rnn.frame_size*rnnoise.bytes_per_sample())-1),
  30. bytes((rnn.frame_size*rnnoise.bytes_per_sample())+1),
  31. )
  32. for bad_frame in bad_frames:
  33. msg = "Testing with a bad frame"
  34. with self.subTest(msg, frame_len=len(bad_frame)) as st:
  35. with self.assertRaises(ValueError):
  36. rnn.process_frame(bad_frame)
  37. class TestRNNoiseIterator(unittest.TestCase):
  38. """ Testing RNNoise.iter_on() method """
  39. def setUp(self):
  40. rnn = rnnoise.RNNoise()
  41. self.orig = bytes()
  42. with wave.open(INPUT, 'rb') as fp:
  43. needed = rnn.frame_size * rnnoise.bytes_per_sample()
  44. while True:
  45. frame = fp.readframes(rnn.frame_size)
  46. if len(frame) == 0:
  47. break
  48. if len(frame) < needed:
  49. frame += bytes(needed-len(frame))
  50. newframe = rnn.process_frame(frame)
  51. newframe = newframe[:len(frame)]
  52. self.orig += newframe
  53. @staticmethod
  54. def in_iter(fsize):
  55. """ fsize is a callable returning expected frame size """
  56. with wave.open(INPUT, 'rb') as fp:
  57. while True:
  58. frame = fp.readframes(fsize())
  59. if len(frame):
  60. yield frame
  61. else:
  62. raise StopIteration()
  63. def test_iter_big_samples(self):
  64. """ Testing with iterator on 4096 samples """
  65. rnn = rnnoise.RNNoise()
  66. result = bytes()
  67. for frame in rnn.iter_on(self.in_iter(lambda: 4096)):
  68. result += frame
  69. self.assertEqual(result, self.orig)
  70. def test_byte_sample(self):
  71. """ Testing with iterator on 1 byte """
  72. rnn = rnnoise.RNNoise()
  73. def _cust_iter():
  74. for sample in self.in_iter(lambda: 1):
  75. yield sample[:1]
  76. yield sample[1:]
  77. result = bytes()
  78. for frame in rnn.iter_on(_cust_iter()):
  79. result += frame
  80. self.assertEqual(result, self.orig)
  81. def test_invalid(self):
  82. """ Testing with invalid iterator on odd number of bytes """
  83. rnn = rnnoise.RNNoise()
  84. def _cust_iter():
  85. data = self.in_iter(lambda: 4096)
  86. yield next(data)[:-1]
  87. for sample in data:
  88. yield sample
  89. with self.assertRaises(ValueError):
  90. result = bytes()
  91. for frame in rnn.iter_on(_cust_iter()):
  92. result += frame
# Run every TestCase in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()