-
Notifications
You must be signed in to change notification settings - Fork 215
Expand file tree
/
Copy pathmodel_bl.py
More file actions
21 lines (16 loc) · 737 Bytes
/
model_bl.py
File metadata and controls
21 lines (16 loc) · 737 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
import torch.nn as nn
class D_VECTOR(nn.Module):
    """d vector speaker embedding.

    A stacked LSTM reads a sequence of feature frames; the top layer's output
    at the final time step is projected by a linear layer and scaled to unit
    L2 norm.

    Args:
        num_layers: Number of stacked LSTM layers.
        dim_input: Size of each input feature frame.
        dim_cell: Hidden size of every LSTM layer.
        dim_emb: Size of the returned embedding.
    """

    def __init__(self, num_layers: int = 3, dim_input: int = 40,
                 dim_cell: int = 256, dim_emb: int = 64) -> None:
        super().__init__()
        # batch_first=True: inputs/outputs are laid out (batch, time, feature).
        self.lstm = nn.LSTM(input_size=dim_input, hidden_size=dim_cell,
                            num_layers=num_layers, batch_first=True)
        self.embedding = nn.Linear(dim_cell, dim_emb)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch of sequences.

        Args:
            x: Tensor of shape (batch, time, dim_input).

        Returns:
            Tensor of shape (batch, dim_emb) whose rows have unit L2 norm.
            A row whose projection is exactly zero is returned as zeros
            rather than NaN.
        """
        # Re-compact the LSTM weights into one contiguous chunk of memory.
        self.lstm.flatten_parameters()
        lstm_out, _ = self.lstm(x)
        # Only the last time step of the top layer feeds the projection.
        embeds = self.embedding(lstm_out[:, -1, :])
        # normalize() divides by max(norm, eps): identical to a plain
        # division for any non-degenerate norm, but a zero vector yields
        # zeros instead of 0/0 = NaN.
        return nn.functional.normalize(embeds, p=2, dim=-1, eps=1e-12)