|
class SBERT: |
|
def __init__(self, config): |
|
self.loss = 0 |
|
self.metrics = [] |
|
self.inputs = [] |
|
self.config = config |
|
self.build() |
|
|
|
def build(self):
    """Build the encoder body, every enabled head, then compile the model.

    The body is always built first. The paraphrase, toxic and tagger heads
    follow, in that order, whenever the matching ``use_*_head`` config flag
    is truthy. ``compile_model`` runs last.
    """
    self.saver_dict = {}
    self.build_body()

    # (config flag, builder) pairs, listed in the order the heads are built.
    optional_heads = (
        ("use_par_head", self.build_nli_head),
        ("use_toxic_head", self.build_toxic_head),
        ("use_ner_head", self.build_tag_head),
    )
    for flag_name, build_head in optional_heads:
        # Each flag is read only when its turn comes, right before building.
        if getattr(self.config, flag_name):
            build_head()

    self.compile_model()
|
|
|
def compile_model(self):
    """Wrap the accumulated inputs and loss in a Keras model and compile it.

    The model's only output is ``self.loss``. It is compiled with an Adam
    optimizer using ``config.lr`` as learning rate, ``average_loss`` as the
    loss function and the collected ``self.metrics``. The compiled model is
    kept on ``self.train_model``.
    """
    log_this("Compiling")

    # The summed head losses are exposed as the single output of the model.
    training_model = tf.keras.models.Model(
        inputs=self.inputs,
        outputs=[self.loss],
    )
    self.train_model = training_model

    training_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=self.config.lr),
        loss=average_loss,
        metrics=self.metrics,
    )
    log_this("The model is built")
|
|
|
def build_body(self):
    """Create the ``BertLayer`` encoder and keep it on ``self.nlu_encoder``.

    The module path and context length are passed positionally; the number
    of tuned layers, embedding tuning and trainability come from the config,
    while preprocessing is switched on and pooling is fixed to ``'mean'``.
    """
    cfg = self.config
    module_path, context_length = cfg.module_path, cfg.ctx_len

    encoder_options = {
        "n_tune_layers": cfg.n_tune,
        "do_preprocessing": True,
        "pooling": 'mean',
        "tune_embeddings": cfg.tune_embs,
        "trainable": cfg.train_bert,
    }
    self.nlu_encoder = BertLayer(module_path, context_length, **encoder_options)
|
|
|
def build_tag_head(self):
    """Attach the token tagger head to the encoder.

    Creates a string input and a float label input of shape
    ``(ctx_len, n_tags)``, feeds the encoder's ``'token_output'`` through a
    time-distributed MLP and adds the categorical cross-entropy, scaled by
    ``config.tagger_loss_weight``, to ``self.loss``.

    Side effects:
        * ``self.tag_model`` maps the string input to the tag predictions.
        * ``self.inputs`` gains the text input and the label input.
        * ``self.loss`` gains the weighted tagging loss.
    """
    log_this("Building tagger head")

    tag_input = layers.Input(shape=(1, ), dtype=tf.string)
    tag_label = layers.Input(
        shape=(self.config.ctx_len, self.config.n_tags,), dtype=tf.float32)

    # The encoder object is also used by the other heads, which call it with
    # ``as_dict`` off. Turn the flag on for this one call only and always
    # turn it off again, even if the call raises, so that a failure here
    # cannot leave the encoder in the wrong mode.
    self.nlu_encoder.as_dict = True
    try:
        inp_tok_encoded = self.nlu_encoder(tag_input)['token_output']
    finally:
        self.nlu_encoder.as_dict = False

    # NOTE(review): the meaning of the positional build_mlp arguments
    # (2, dim, dim, n_tags) is defined by build_mlp, not visible here.
    tag_mlp = self.build_mlp(
        2, self.config.dim, self.config.dim, self.config.n_tags,
        name="ner", dropout_rate=self.config.head_dropout_rate)
    # Apply the same MLP at every position of the token output.
    tag_pred = tf.keras.layers.TimeDistributed(tag_mlp)(inp_tok_encoded)
    # NOTE(review): the toxic head reshapes its loss to (-1, 1) before
    # accumulating it; this head does not. Confirm the shapes combine as
    # intended when several heads are enabled.
    tag_loss = tf.keras.losses.categorical_crossentropy(tag_label, tag_pred)

    self.tag_model = tf.keras.models.Model(
        inputs=[tag_input], outputs=[tag_pred], name='tagger_model')
    self.inputs += [tag_input, tag_label]
    self.loss += self.config.tagger_loss_weight * tag_loss
|
|
|
def build_nli_head(self):
    """Attach the paraphraser head built on anchor/positive/negative inputs.

    Anchor and positive sentences are always encoded. Only when
    ``config.train_bert`` is truthy is the negative sentence encoded too and
    ``softmax_loss`` over the three encodings, scaled by
    ``config.paraphrase_loss_weight``, added to ``self.loss``.

    Side effects:
        * ``self.nli_encoder_model`` maps the positive input to its encoding.
        * ``self.sim_model`` maps (anchor, positive) to ``cosine_similarity``.
        * ``self.inputs`` gains all three inputs, including the negative one
          even when it was not encoded.
    """
    log_this("Building paraphraser head")

    def sentence_input():
        # Every sentence slot is a single string per example.
        return layers.Input(shape=(1,), dtype=tf.string)

    anchor_in = sentence_input()
    positive_in = sentence_input()
    negative_in = sentence_input()

    encoder = self.nlu_encoder
    anchor_vec = encoder(anchor_in)
    positive_vec = encoder(positive_in)

    if self.config.train_bert:
        negative_vec = encoder(negative_in)
        loss_layer = tf.keras.layers.Lambda(softmax_loss)
        paraphrase_loss = loss_layer([anchor_vec, positive_vec, negative_vec])
        self.loss += self.config.paraphrase_loss_weight * paraphrase_loss

    self.nli_encoder_model = tf.keras.models.Model(
        inputs=[positive_in], outputs=[positive_vec])

    similarity_layer = tf.keras.layers.Lambda(cosine_similarity, name='similarity')
    similarity = similarity_layer([anchor_vec, positive_vec])
    self.sim_model = tf.keras.models.Model(
        inputs=[anchor_in, positive_in], outputs=[similarity])

    self.inputs.extend([anchor_in, positive_in, negative_in])
|
|
|
|
|
def build_toxic_head(self):
    """Attach the toxic head on top of the encoder output.

    Creates a string input and a float label input with
    ``config.n_toxic_tags`` entries, runs the encoded sentence through an
    MLP and adds the categorical cross-entropy, reshaped to one column and
    scaled by ``config.toxic_loss_weight``, to ``self.loss``.

    Side effects:
        * ``self.tox_model`` maps the string input to the MLP predictions.
        * ``self.inputs`` gains the text input and the label input.
        * ``self.loss`` gains the weighted toxic loss.
    """
    log_this("Building toxic head")

    cfg = self.config
    text_in = layers.Input(shape=(1, ), dtype=tf.string)
    label_in = layers.Input(shape=(cfg.n_toxic_tags, ), dtype=tf.float32)

    sentence_vec = self.nlu_encoder(text_in)

    # NOTE(review): the meaning of the positional build_mlp arguments
    # (2, dim, dim, n_toxic_tags) is defined by build_mlp, not visible here.
    classifier = self.build_mlp(
        2, cfg.dim, cfg.dim, cfg.n_toxic_tags,
        name="toxic", dropout_rate=cfg.head_dropout_rate)
    scores = classifier(sentence_vec)

    # Cross-entropy against the labels, laid out as a single column.
    column_loss = tf.reshape(
        tf.keras.losses.categorical_crossentropy(label_in, scores), (-1, 1))

    self.tox_model = tf.keras.models.Model(
        inputs=[text_in], outputs=[scores], name='toxic_model')
    self.inputs.extend([text_in, label_in])
    self.loss += cfg.toxic_loss_weight * column_loss