by Matthew Baas
A major update to my simple-speaker-embedding repository with vastly improved training and a new pretrained model.
TL;DR: A brief summary of my update to the simple-speaker-embedding repository, how to use the new pytorch models, and how to train and fine-tune the models on your own dataset.
Speaker embedding vectors, also known as $d$-vectors, attempt to represent a particular speaker identity in a fixed-dimensional vector. These vectors capture the unique speaking style present in speech audio from a particular speaker, not the writing style of a particular author (which I imagine should more aptly be called author embeddings). Typically, speaker embeddings are computed in some way from a set of recordings of the desired speaker talking, and they should be computed in such a way that the vector is unique to the speaking style and not the particular words spoken.
In other words, the vector is trained to be unique to the speaker identity of the input utterance – so the returned vector should remain the same regardless of what words are spoken in the input utterance, and depend only on who is speaking in the input utterance. For example, if an input utterance saying “The quick brown fox” spoken by speaker A is fed into the model, the resulting embedding vector should be close (in terms of Euclidean/cosine distance) to the embedding of an utterance saying “I like speaker embeddings” also spoken by speaker A. Conversely, the embedding should be far away (in terms of Euclidean/cosine distance) from the embedding of an utterance saying the same “The quick brown fox” spoken by speaker B. Thus the embedding should be unique to the identity of the speaker of an input utterance, and not the linguistic content of the input utterance.
Having such a vector unique to different speakers is super useful in tasks such as voice conversion, speaker verification, conditioned text-to-speech, and other speech processing tasks. We can also compare embedding vectors to one another to assess how similar two speakers are numerically.
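For example, here is a minimal sketch of how one might compare two such embedding vectors with pytorch; the embeddings below are random placeholders standing in for actual model outputs:

```python
import torch
import torch.nn.functional as F

# Two hypothetical 256-dimensional speaker embeddings (in practice these would
# come from a speaker embedding model, not random initialization).
emb_a = torch.randn(256)
emb_b = torch.randn(256)

# Cosine similarity near 1 suggests the same speaker; values near 0 or below
# suggest different speakers.
similarity = F.cosine_similarity(emb_a, emb_b, dim=0).item()
print(f"cosine similarity: {similarity:.3f}")
```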
Previously I had released a fairly well-performing initial model: a 3-layer GRU trained on 22.05kHz log mel-scale spectrograms using fastai v1. It produced 256-dimensional speaker embeddings and was trained on several datasets.
A year on from the initial simple-speaker-embedding (SSE) release, it became clear that the fastai v1 training method was not the most reproducible technique, and that the mel-spectrogram intermediary was an unnecessary middle-man in obtaining a good speaker embedding. So, in an effort to make speaker embeddings even simpler, the new simple-speaker-embedding v1.0 is now released to address these issues, along with a number of new features.
The rest of this post goes into detail about these improvements and features.
The pretrained models are described here. All pretrained models are available through torchhub, so no additional python package needs to be installed to use them. The pretrained model card is given below, where the speaker embeddings $\mathbf{s}$ are of dimension $d$:
torchhub model | sample rate | training datasets | input format | $d$ | release |
---|---|---|---|---|---|
`gru_embedder` | 22.05kHz | VCC 2018, VCTK, Librispeech, CommonVoice English | log mel-scale spectrogram | 256 | Nov 2020 |
`convgru_embedder` | 16kHz | VCTK, Librispeech, voxceleb1, voxceleb2 | raw waveform | 256 | Dec 2021 |
The primary difference between the first `gru_embedder` model and the `convgru_embedder` model is that the latter has a convolutional encoder before the GRU network.
The convolutional encoder operates on raw waveforms (sequence of floating point numbers between -1.0 and 1.0) and is the same CNN encoder as that used by wav2vec 2.0.
Because the new `convgru_embedder` operates on raw waveforms, using it removes the dependency on log mel-scale spectrograms (and all the hyperparameters associated with them).
It also removes the need for the old Tacotron 2 code which computed the spectrograms, and removes a step in computing the speaker embeddings, thereby reducing code complexity.
The new model is also trained on 16kHz sampled audio, which is a common sampling rate in speech processing tasks.
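To give a concrete feel for this, here is a rough usage sketch via torch hub; the `RF5/simple-speaker-embedding` repo path and the exact input/output conventions are my assumptions here, and the repo readme remains authoritative:

```python
import torch
import torchaudio

# Load the pretrained 16kHz model through torch hub. The repo path below is an
# assumption -- substitute the actual "<user>/simple-speaker-embedding" path.
model = torch.hub.load('RF5/simple-speaker-embedding', 'convgru_embedder')
model.eval()

# The model expects raw 16kHz waveforms with float values in [-1.0, 1.0].
wav, sr = torchaudio.load('utterance.flac')  # wav: (channels, samples)
assert sr == 16000, "resample to 16kHz first, e.g. with torchaudio.functional.resample"

with torch.no_grad():
    # Assuming the model maps a (batch, samples) waveform to (batch, 256) embeddings.
    embedding = model(wav[:1])
print(embedding.shape)  # expected: torch.Size([1, 256])
```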
Both models are trained using the GE2E loss function, using the implementation provided by HarryVolek.
The new `convgru_embedder` model is trained on the datasets given in the table above for 2000 epochs with a batch size of 9 speakers and 6 utterances per speaker (i.e. every batch passes 6×9 = 54 utterances through the model).
Training uses the Adam optimizer with gradient clipping at 1.0 and FP16 autocasting.
The learning rate follows a piecewise linear schedule: it begins at 1e-8, is linearly increased to 4e-5 over the first 15% of the total training updates, and is then linearly decreased to 1e-7 by the final training update.
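The actual training script may implement this with a built-in pytorch scheduler, but the shape of the schedule is roughly the following sketch:

```python
def lr_at(step: int, total_steps: int, start_lr: float = 1e-8,
          peak_lr: float = 4e-5, final_lr: float = 1e-7,
          warmup_frac: float = 0.15) -> float:
    """Piecewise linear schedule: warm up from start_lr to peak_lr over the
    first warmup_frac of updates, then decay linearly to final_lr."""
    warmup_steps = max(int(warmup_frac * total_steps), 1)
    if step < warmup_steps:
        t = step / warmup_steps
        return start_lr + t * (peak_lr - start_lr)
    t = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    return peak_lr + t * (final_lr - peak_lr)

# Hypothetical usage inside a training loop:
# for group in optimizer.param_groups:
#     group['lr'] = lr_at(step, total_steps)
```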
For each batch, the utterance length is uniformly sampled from between 1s and 7.5s, with all utterances in the batch cropped/padded to that length.
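A rough sketch of what that batching step could look like (the helper below is hypothetical, not necessarily how the training script implements it):

```python
import random
import torch
import torch.nn.functional as F

def crop_or_pad(wav: torch.Tensor, target_len: int) -> torch.Tensor:
    """Randomly crop a segment of target_len samples, or right-pad with zeros."""
    if wav.shape[-1] >= target_len:
        start = random.randint(0, wav.shape[-1] - target_len)
        return wav[..., start:start + target_len]
    return F.pad(wav, (0, target_len - wav.shape[-1]))

# For the 16kHz model, one batch length drawn uniformly from 1s-7.5s:
sr = 16000
target_len = int(random.uniform(1.0, 7.5) * sr)
# batch = torch.stack([crop_or_pad(w, target_len) for w in utterances])
```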
The full details for usage of both the new 16kHz and older 22.05kHz models are available in the repo readme; I encourage those interested to read further there.
The new training script makes the following trade-off when compared to the old training script: it drops the fastai dependency, instead requiring `omegaconf` for config management, `fastprogress` for progress bars, and `pandas` for csv parsing (all except `omegaconf` were already requirements for fastai). The result of this is a much more lightweight environment with fewer overall pip and conda dependencies. Better yet, the new training script has the following features:
- `nccl` with pytorch's built-in DDP module for optional distributed training, enabled by setting the number of gpus to use in the config.
All of these features are enabled or disabled by command line arguments, the full roster of which one can see with `python train.py --help`. Further details for training options and fine-tuning are available in the repo readme.
The training script obtains its training and validation utterances from training and validation csv paths supplied on the command line. The csvs have the following format:
path,speaker
<path to wav>,<speaker id>
<path to wav>,<speaker id>
where `path` is a file path to a wav/flac/mp3 or other audio file, and `speaker` is the speaker name/id to which that waveform belongs.
For best results, each speaker should have at least 20 waveforms associated with them.
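If your data is not covered by the dataset splitting script mentioned below, a csv in this format can also be built by hand. Here is a sketch assuming a hypothetical `dataset_root/<speaker id>/<utterance>.wav` layout:

```python
from pathlib import Path
import pandas as pd

# Hypothetical layout: dataset_root/<speaker id>/<utterance>.wav
rows = [{'path': str(p), 'speaker': p.parent.name}
        for p in Path('dataset_root').rglob('*.wav')]
df = pd.DataFrame(rows)

# Keep only speakers with at least 20 utterances, per the recommendation above.
df = df.groupby('speaker').filter(lambda g: len(g) >= 20)
df.to_csv('train.csv', index=False)
```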
There is also a new dataset splitting script to generate such a train & validation csv for common speech datasets. Usage guidelines for it are given in the repo readme.
A new evaluation script is also available.
`eval.py` accepts a model name and an optional custom checkpoint path, together with a csv to evaluate over (and an optional evaluation seed).
The csv format is the same as that for `train.py` specified above.
The `eval.py` script will compute the equal error rate (EER) over the utterances in the csv, where the EER is defined as:
“This is the rate used to determine the threshold value for a [speaker verification] system when its false acceptance rate (FAR) and false rejection rate (FRR) are equal. FAR = FP/(FP+TN) and FRR = FN/(TP+FN)” where
- FN is the number of false negatives
- FP is the number of false positives
- TN is the number of true negatives
- TP is the number of true positives
Which, translated for this project, means that for a single utterance we compute its embedding and then the cosine similarity between this embedding and the embeddings of other utterances, both from the same speaker and from different speakers.
Then, using these cosine similarities over the full test dataset, we look at how often similarities between utterances from the same speaker are higher than similarities between utterances from different speakers. The cosine similarity threshold for deciding when two embeddings belong to the same speaker is what is optimized in the definition above. The end result is that the EER is a number between 0 and 1, with lower being better.
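`eval.py` may compute this differently internally, but a common way to obtain the EER from a set of pairwise cosine similarities is via the ROC curve, e.g. with scikit-learn:

```python
import numpy as np
from sklearn.metrics import roc_curve

def equal_error_rate(same_speaker: np.ndarray, similarity: np.ndarray) -> float:
    """same_speaker: 1 if an utterance pair shares a speaker, else 0.
    similarity: cosine similarity between the pair's embeddings."""
    far, tpr, _ = roc_curve(same_speaker, similarity)  # far = false acceptance rate
    frr = 1.0 - tpr                                    # frr = false rejection rate
    idx = np.nanargmin(np.abs(frr - far))              # threshold where FAR ~= FRR
    return float((far[idx] + frr[idx]) / 2.0)
```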
Using the new evaluation script, we benchmark both the old `gru_embedder` and new `convgru_embedder` models.
Concretely, the evaluation is performed on the Librispeech `test-clean` and `test-other` datasets.
These datasets consist of speakers completely unseen by all the models: no speaker included in this evaluation is seen during training or validation.
Furthermore, the test utterances are not cropped or reduced in length; rather, the model is applied directly to the full-length waveform. This means that the test results here are pessimistic. Better numbers can be obtained by cropping utterances to reasonable lengths (4-8s), taking the mean speaker embedding over multiple utterances and/or multiple parts of a single long utterance, and only using this mean embedding vector when comparing to embeddings of other speakers.
In fact, in the original GE2E paper, they compute several embeddings from a single utterance using a sliding window of 1.6s frames with 50% overlap between adjacent frames. They then average the embeddings generated from each frame and re-normalize the resulting vector to unit length. I do not do any of these optimizations for the sake of simplicity. The evaluation done here purely operates on full-length, unbutchered waveforms.
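If you did want to reproduce that kind of windowed averaging, a sketch might look like this; the window handling details are my assumptions, not the GE2E authors' exact recipe:

```python
import torch
import torch.nn.functional as F

def windowed_embedding(model, wav: torch.Tensor, sr: int = 16000,
                       win_s: float = 1.6, overlap: float = 0.5) -> torch.Tensor:
    """Average embeddings over sliding windows and re-normalize to unit length.
    wav: (samples,) raw waveform. Assumes model maps (batch, samples) -> (batch, d)."""
    win = int(win_s * sr)
    hop = max(int(win * (1 - overlap)), 1)
    starts = range(0, max(wav.shape[-1] - win, 0) + 1, hop)
    with torch.no_grad():
        embs = torch.stack([model(wav[s:s + win].unsqueeze(0)).squeeze(0)
                            for s in starts])
    return F.normalize(embs.mean(dim=0), dim=0)
```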
model | test-clean EER | test-other EER |
---|---|---|
`gru_embedder` | 0.0797 | 0.0766 |
`convgru_embedder` | 0.0295 | 0.0181 |
Below are the 2D UMAP plots for both pretrained models. In each scatter plot, each point corresponds to the embedding of an utterance, while the color of the point corresponds to the speaker identity of the utterance.
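(For reference, similar plots can be produced with the umap-learn package; the snippet below is not the exact plotting code used for these figures and substitutes placeholder data for real model embeddings.)

```python
import numpy as np
import umap
import matplotlib.pyplot as plt

# Placeholder data standing in for real model outputs: an (N, 256) array of
# utterance embeddings and a length-N array of speaker ids.
rng = np.random.default_rng(0)
speakers = np.repeat(np.arange(8), 30)            # 8 speakers, 30 utterances each
embeddings = rng.normal(size=(len(speakers), 256)) + speakers[:, None]

proj = umap.UMAP(n_components=2, random_state=0).fit_transform(embeddings)
for spk in np.unique(speakers):
    mask = speakers == spk
    plt.scatter(proj[mask, 0], proj[mask, 1], s=5, label=str(spk))
plt.title("UMAP projection of speaker embeddings")
plt.show()
```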
From both the objective metrics and the subjective quality of the clustering in the UMAP plots above, it is clear that the new model is substantially better than the old one.
So, if you are doing new speaker embedding tasks and need a simple speaker embedding, I recommend the 16kHz `convgru_embedder` model!
What is even better is that with tensorboard, I can share most of the logs generated during training online easily.
Namely, the tensorboard training logs for the `convgru_embedder` model are available to view at this link.
Pretty epic indeed!
It is still somewhat limited in that it does not upload projections yet, and Google's privacy policy is not great (they provide the free log hosting on tensorboard.dev).
But overall, still pretty neat.
So, if you wish to see the exact training and validation loss curves you can expect, head on over :).
Thanks for your interest in my simple-speaker-embedding project. I hope this updated release has (a) made you more confident in the robustness and performance of both the old and new models, and (b) convinced you that the new 16kHz model is quite a mean machine when it comes to getting high-performing speaker embeddings. And, if you do take an interest, I hope that fine-tuning, evaluating, and otherwise modifying my repo to work with your own models and custom datasets is now much easier.
I would also like to extend thanks to the managers of Stellenbosch University's High Performance Computing Cluster for the compute necessary to train the final `convgru_embedder` model.
As always, if you spot any errors I have made or would just like to ask a question, please get in contact with me from the About page, or raise an issue on the github repo. Have a good one!
tags: speaker embeddings - d-vectors - speaker verification - pytorch