From 3d2db27b720cf7eaa94d536e4368ac8a686a54c2 Mon Sep 17 00:00:00 2001 From: swapnull7 Date: Wed, 4 Dec 2019 16:57:35 -0500 Subject: [PATCH 01/14] Add text style transfer (#1) * initial commit * bug fixes and adjusting conv inputs * separate forward function for Discriminator and Generator and disable Gen training for debugging * remove debugger statement * bug fix * detaching stuff before accumulating * refactor and add component as optional parameter * Add optimizer for and backprop against encoder * Add in README --- examples/README.md | 4 + examples/text_style_transfer/README.md | 108 +++++++++ examples/text_style_transfer/config.py | 107 +++++++++ .../text_style_transfer/ctrl_gen_model.py | 214 ++++++++++++++++++ examples/text_style_transfer/main.py | 196 ++++++++++++++++ examples/text_style_transfer/prepare_data.py | 39 ++++ texar/torch/modules/networks/conv_networks.py | 2 +- texar/torch/utils/__init__.py | 1 + texar/torch/utils/variables.py | 68 ++++++ 9 files changed, 738 insertions(+), 1 deletion(-) create mode 100644 examples/text_style_transfer/README.md create mode 100644 examples/text_style_transfer/config.py create mode 100644 examples/text_style_transfer/ctrl_gen_model.py create mode 100644 examples/text_style_transfer/main.py create mode 100644 examples/text_style_transfer/prepare_data.py create mode 100644 texar/torch/utils/variables.py diff --git a/examples/README.md b/examples/README.md index 72fe20c39..be889deba 100644 --- a/examples/README.md +++ b/examples/README.md @@ -22,6 +22,10 @@ More examples are continuously added... * [vae_text](./vae_text): VAE language model +### GANs / Discriminiator-supervision ### + +* [text_style_transfer](./text_style_transfer): Discriminator supervision for controlled text generation + ### Classifier / Sequence Prediction ### * [bert](./bert): Pre-trained BERT model for text representation diff --git a/examples/text_style_transfer/README.md b/examples/text_style_transfer/README.md new file mode 100644 index 000000000..2f9dd0dbb --- /dev/null +++ b/examples/text_style_transfer/README.md @@ -0,0 +1,108 @@ +# Text Style Transfer # + +This example implements a simplified variant of the `ctrl-gen` model from + +[Toward Controlled Generation of Text](https://arxiv.org/pdf/1703.00955.pdf) +*Zhiting Hu, Zichao Yang, Xiaodan Liang, Ruslan Salakhutdinov, Eric Xing; ICML 2017* + +The model roughly has an architecture of `Encoder--Decoder--Classifier`. Compared to the paper, following simplications are made: + + * Replaces the base Variational Autoencoder (VAE) model with an attentional Autoencoder (AE) -- VAE is not necessary in the text style transfer setting since we do not need to interpolate the latent space as in the paper. + * Attribute classifier (i.e., discriminator) is trained with real data only. Samples generated by the decoder are not used. + * Independency constraint is omitted. + +## Usage ## + +### Dataset ### +Download the yelp sentiment dataset with the following cmd: +``` +python prepare_data.py +``` + +### Train the model ### + +Train the model on the above data to do sentiment transfer. +``` +python main.py --config config +``` + +[config.py](./config.py) contains the data and mode configurations. + +* The model will first be pre-trained for a few epochs (specified in `config.py`). During pre-training, the `Encoder-Decoder` part is trained as an autoencoder, while the `Classifier` part is trained with the classification labels. +* Full-training is then performed for another few epochs. 
During full-training, the `Classifier` part is fixed, and the `Encoder-Decoder` part is trained to fit the classifier, along with continuing to minimize the autoencoding loss. + +(**Note:** When using your own dataset, make sure to set `max_decoding_length_train` and `max_decoding_length_infer` in [config.py](https://github.com/asyml/texar/blob/master/examples/text_style_transfer/config.py#L85-L86).) + +Training log is printed as below: +``` +gamma: 1.0, lambda_g: 0.0 +step: 1, loss_d: 0.6903 accu_d: 0.5625 +step: 1, loss_g_clas: 0.6991 loss_g: 9.1452 accu_g: 0.2812 loss_g_ae: 9.1452 accu_g_gdy: 0.2969 +step: 500, loss_d: 0.0989 accu_d: 0.9688 +step: 500, loss_g_clas: 0.2985 loss_g: 3.9696 accu_g: 0.8891 loss_g_ae: 3.9696 accu_g_gdy: 0.7734 +... +step: 6500, loss_d: 0.0806 accu_d: 0.9703 +step: 6500, loss_g_clas: 5.7137 loss_g: 0.2887 accu_g: 0.0844 loss_g_ae: 0.2887 accu_g_gdy: 0.0625 +epoch: 1, loss_d: 0.0876 accu_d: 0.9719 +epoch: 1, loss_g_clas: 6.7360 loss_g: 0.2195 accu_g: 0.0627 loss_g_ae: 0.2195 accu_g_gdy: 0.0642 +val: accu_g: 0.0445 loss_g_ae: 0.1302 accu_d: 0.9774 bleu: 90.7896 loss_g: 0.1302 loss_d: 0.0666 loss_g_clas: 7.0310 accu_g_gdy: 0.0482 +... + +``` +where: +- `loss_d` and `accu_d` are the classification loss/accuracy of the `Classifier` part. +- `loss_g_clas` is the classification loss of the generated sentences. +- `loss_g_ae` is the autoencoding loss. +- `loss_g` is the joint loss `= loss_g_ae + lambda_g * loss_g_clas`. +- `accu_g` is the classification accuracy of the generated sentences with soft represetations (i.e., Gumbel-softmax). +- `accu_g_gdy` is the classification accuracy of the generated sentences with greedy decoding. +- `bleu` is the BLEU score between the generated and input sentences. + +## Results ## + +Text style transfer has two primary goals: +1. The generated sentence should have desired attribute (e.g., positive/negative sentiment) +2. The generated sentence should keep the content of the original one + +We use automatic metrics to evaluate both: +* For (1), we can use a pre-trained classifier to classify the generated sentences and evaluate the accuracy (the higher the better). In this code we have not implemented a stand-alone classifier for evaluation, which could be very easy though. The `Classifier` part in the model gives a reasonably good estimation (i.e., `accu_g_gdy` in the above) of the accuracy. +* For (2), we evaluate the BLEU score between the generated sentences and the original sentences, i.e., `bleu` in the above (the higher the better) (See [Yang et al., 2018](https://arxiv.org/pdf/1805.11749.pdf) for more details.) + +The implementation here gives the following performance after 10 epochs of pre-training and 2 epochs of full-training: + +| Accuracy (by the `Classifier` part) | BLEU (with the original sentence) | +| -------------------------------------| ----------------------------------| +| 0.92 | 54.0 | + +Also refer to the following papers that used this code and compared to other text style transfer approaches: + +* [Unsupervised Text Style Transfer using Language Models as Discriminators](https://papers.nips.cc/paper/7959-unsupervised-text-style-transfer-using-language-models-as-discriminators.pdf). Zichao Yang, Zhiting Hu, Chris Dyer, Eric Xing, Taylor Berg-Kirkpatrick. NeurIPS 2018 +* [Structured Content Preservation for Unsupervised Text Style Transfer](https://arxiv.org/pdf/1810.06526.pdf). Youzhi Tian, Zhiting Hu, Zhou Yu. 2018 + +### Samples ### +Here are some randomly-picked samples. 
In each pair, the first sentence is the original sentence and the second is the generated. +``` +go to place for client visits with gorgeous views . +go to place for client visits with lacking views . + +there was lots of people but they still managed to provide great service . +there was lots of people but they still managed to provide careless service . + +this was the best dining experience i have ever had . +this was the worst dining experience i have ever had . + +needless to say , we skipped desert . +gentle to say , we edgy desert . + +the first time i was missing an entire sandwich and a side of fries . +the first time i was beautifully an entire sandwich and a side of fries . + +her boutique has a fabulous selection of designer brands ! +her annoying has a sketchy selection of bland warned ! + +service is pretty good . +service is trashy rude . + +ok nothing new . +exceptional impressed new . +``` diff --git a/examples/text_style_transfer/config.py b/examples/text_style_transfer/config.py new file mode 100644 index 000000000..94e9f9592 --- /dev/null +++ b/examples/text_style_transfer/config.py @@ -0,0 +1,107 @@ +"""Config +""" +# pylint: disable=invalid-name + +import copy + +# Total number of training epochs (including pre-train and full-train) +max_nepochs = 12 +pretrain_nepochs = 10 # Number of pre-train epochs (training as autoencoder) +display = 500 # Display the training results every N training steps. +# Display the dev results every N training steps (set to a +# very large value to disable it). +display_eval = 1e10 + +sample_path = './samples' +checkpoint_path = './checkpoints' +restore = '' # Model snapshot to restore from + +lambda_g = 0.1 # Weight of the classification loss +gamma_decay = 0.5 # Gumbel-softmax temperature anneal rate + +max_seq_length = 16 # Maximum sequence length in dataset w/o BOS token + +train_data = { + 'batch_size': 64, + # 'seed': 123, + 'datasets': [ + { + 'files': './data/yelp/sentiment.train.text', + 'vocab_file': './data/yelp/vocab', + 'data_name': '' + }, + { + 'files': './data/yelp/sentiment.train.labels', + 'data_type': 'int', + 'data_name': 'labels' + } + ], + 'name': 'train' +} + +val_data = copy.deepcopy(train_data) +val_data['datasets'][0]['files'] = './data/yelp/sentiment.dev.text' +val_data['datasets'][1]['files'] = './data/yelp/sentiment.dev.labels' + +test_data = copy.deepcopy(train_data) +test_data['datasets'][0]['files'] = './data/yelp/sentiment.test.text' +test_data['datasets'][1]['files'] = './data/yelp/sentiment.test.labels' + +model = { + 'dim_c': 200, + 'dim_z': 500, + 'embedder': { + 'dim': 100, + }, + 'max_seq_length': max_seq_length, + 'encoder': { + 'rnn_cell': { + 'type': 'GRUCell', + 'kwargs': { + 'num_units': 700 + }, + 'dropout': { + 'input_keep_prob': 0.5 + } + } + }, + 'decoder': { + 'rnn_cell': { + 'type': 'GRUCell', + 'kwargs': { + 'num_units': 700, + }, + 'dropout': { + 'input_keep_prob': 0.5, + 'output_keep_prob': 0.5 + }, + }, + 'attention': { + 'type': 'BahdanauAttention', + 'kwargs': { + 'num_units': 700, + }, + 'attention_layer_size': 700, + }, + 'max_decoding_length_train': 21, + 'max_decoding_length_infer': 20, + }, + 'classifier': { + 'kernel_size': [3, 4, 5], + 'out_channels': 128, + 'data_format': 'channels_last', + 'other_conv_kwargs': [[{'padding': 1}, {'padding': 2}, {'padding': 2}]], + 'dropout_conv': [1], + 'dropout_rate': 0.5, + 'num_dense_layers': 0, + 'num_classes': 1 + }, + 'opt': { + 'optimizer': { + 'type': 'Adam', + 'kwargs': { + 'lr': 5e-4, + }, + }, + }, +} diff --git 
a/examples/text_style_transfer/ctrl_gen_model.py b/examples/text_style_transfer/ctrl_gen_model.py new file mode 100644 index 000000000..5a87f8bdf --- /dev/null +++ b/examples/text_style_transfer/ctrl_gen_model.py @@ -0,0 +1,214 @@ +# Copyright 2018 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Text style transfer +""" + +# pylint: disable=invalid-name, too-many-locals + +import torch +import torch.nn as nn + +import texar.torch as tx +from texar.torch.modules import WordEmbedder, UnidirectionalRNNEncoder, \ + MLPTransformConnector, AttentionRNNDecoder, \ + GumbelSoftmaxEmbeddingHelper, Conv1DClassifier +from texar.torch.utils import get_batch_size, collect_trainable_variables + + +class CtrlGenModel(nn.Module): + """Control + """ + def __init__(self, vocab: tx.data.Vocab, hparams=None): + super().__init__() + self.vocab = vocab + + self._hparams = tx.HParams(hparams, None) + + self.embedder = WordEmbedder(vocab_size=self.vocab.size, + hparams=self._hparams.embedder) + + self.encoder = UnidirectionalRNNEncoder(input_size=self.embedder.dim, + hparams=self._hparams.encoder) + + # Encodes label + self.label_connector = MLPTransformConnector(output_size=self._hparams.dim_c, + linear_layer_dim=1) + + # Teacher-force decoding and the auto-encoding loss for G + self.decoder = AttentionRNNDecoder( + input_size=self.embedder.dim, + encoder_output_size=(self.encoder.cell.hidden_size), + vocab_size=self.vocab.size, + token_embedder=self.embedder, + hparams=self._hparams.decoder) + + self.connector = MLPTransformConnector(output_size=self.decoder.output_size, + linear_layer_dim=(self._hparams.dim_c + + self._hparams.dim_z)) + + self.classifier = Conv1DClassifier(in_channels=self.embedder.dim, + in_features=self._hparams.max_seq_length, + hparams=self._hparams.classifier) + + self.class_embedder = WordEmbedder(vocab_size=self.vocab.size, + hparams=self._hparams.embedder) + + # Creates optimizers + self.g_vars = collect_trainable_variables( + [self.embedder, self.encoder, self.label_connector, + self.connector, self.decoder]) + + self.d_vars = collect_trainable_variables( + [self.class_embedder, self.classifier]) + + def forward_D(self, inputs, f_labels, mode): + + # Classification loss for the classifier + # Get inputs in correct format, [batch_size, channels, seq_length] + class_inputs = self.class_embedder(ids=inputs['text_ids'][:, 1:]) + class_logits, class_preds = self.classifier( + input=class_inputs, + sequence_length=inputs['length'] - 1) + + sig_ce_logits_loss = nn.BCEWithLogitsLoss() + + loss_d = sig_ce_logits_loss(class_logits, f_labels) + accu_d = tx.evals.accuracy(labels=f_labels, + preds=class_preds) + return { + "loss_d": loss_d, + "accu_d": accu_d + } + + def forward_G(self, inputs, f_labels, gamma, lambda_g, mode): + + # text_ids for encoder, with BOS token removed + enc_text_ids = inputs['text_ids'][:, 1:].long() + enc_inputs = self.embedder(enc_text_ids) + enc_outputs, final_state = self.encoder(enc_inputs, + sequence_length=inputs['length'] - 
1) + z = final_state[:, self._hparams.dim_c:] + + labels = inputs['labels'].view(-1, 1).float() + + c = self.label_connector(labels) + c_ = self.label_connector(1 - labels) + h = torch.cat([c, z], dim=1) + h_ = torch.cat([c_, z], dim=1) + + g_outputs, _, _ = self.decoder( + memory=enc_outputs, + memory_sequence_length=inputs['length'] - 1, + initial_state=self.connector(h), + inputs=inputs['text_ids'], + embedding=self.embedder, + sequence_length=inputs['length'] - 1 + ) + + loss_g_ae = tx.losses.sequence_sparse_softmax_cross_entropy( + labels=inputs['text_ids'][:, 1:], + logits=g_outputs.logits, + sequence_length=inputs['length'] - 1, + average_across_timesteps=True, + sum_over_timesteps=False + ) + + # Gumbel-softmax decoding, used in training + start_tokens = torch.ones_like(inputs['labels'].long()) * self.vocab.bos_token_id + end_token = self.vocab.eos_token_id + + gumbel_helper = GumbelSoftmaxEmbeddingHelper(start_tokens=start_tokens, + end_token=end_token, + tau=gamma) + + soft_outputs_, _, soft_length_, = self.decoder( + memory=enc_outputs, + memory_sequence_length=inputs['length'] - 1, + helper=gumbel_helper, + initial_state=self.connector(h_)) + + # Greedy decoding, used in eval + outputs_, _, length_ = self.decoder( + memory=enc_outputs, + memory_sequence_length=inputs['length'] - 1, + decoding_strategy='infer_greedy', initial_state=self.connector(h_), + embedding=self.embedder, start_tokens=start_tokens, end_token=end_token) + + # Get inputs in correct format, [batch_size, channels, seq_length] + soft_inputs = self.class_embedder(soft_ids=soft_outputs_.sample_id) + soft_logits, soft_preds = self.classifier( + input=soft_inputs, + sequence_length=soft_length_) + + sig_ce_logits_loss = nn.BCEWithLogitsLoss() + + loss_g_class = sig_ce_logits_loss(soft_logits, (1 - f_labels)) + + # Accuracy on greedy-decoded samples, for training progress monitoring + # greedy_inputs = self.class_embedder(ids=outputs_.sample_id) + _, gdy_preds = self.classifier( + input=self.class_embedder(ids=outputs_.sample_id), + sequence_length=length_) + + accu_g_gdy = tx.evals.accuracy( + labels=1 - f_labels, preds=gdy_preds) + + # Accuracy on soft samples, for training progress monitoring + accu_g = tx.evals.accuracy(labels=1 - f_labels, + preds=soft_preds) + loss_g = loss_g_ae + lambda_g * loss_g_class + ret = { + "loss_g": loss_g, + "loss_g_ae": loss_g_ae, + "loss_g_class": loss_g_class, + "accu_g": accu_g, + "accu_g_gdy": accu_g_gdy, + } + if mode == 'eval': + ret.update({'outputs': outputs_}) + return ret + + def forward(self, inputs, gamma, lambda_g, mode, component=None): + + f_labels = inputs['labels'].float() + if mode == 'train': + if component == 'D': + ret_d = self.forward_D(inputs, f_labels, mode) + return ret_d + + elif component == 'G': + ret_g = self.forward_G(inputs, f_labels, gamma, lambda_g, mode) + return ret_g + + else: + ret_d = self.forward_D(inputs, f_labels, mode) + ret_g = self.forward_G(inputs, f_labels, gamma, lambda_g, mode) + rets = { + "batch_size": get_batch_size(inputs['text_ids']), + "loss_g": ret_g['loss_g'], + "loss_g_ae": ret_g['loss_g_ae'], + "loss_g_clas": ret_g['loss_g_class'], + "loss_d": ret_d['loss_d_class'], + "accu_d": ret_d['accu_d'], + "accu_g": ret_g['accu_g'], + "accu_g_gdy": ret_g['accu_g_gdy'] + } + samples = { + "original": inputs['text_ids'][:, 1:], + "transferred": rets['outputs'].sample_id + } + return rets, samples + + + diff --git a/examples/text_style_transfer/main.py b/examples/text_style_transfer/main.py new file mode 100644 index 000000000..ace295416 
--- /dev/null +++ b/examples/text_style_transfer/main.py @@ -0,0 +1,196 @@ +# Copyright 2018 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Text style transfer + +This is a simplified implementation of: + +Toward Controlled Generation of Text, ICML2017 +Zhiting Hu, Zichao Yang, Xiaodan Liang, Ruslan Salakhutdinov, Eric Xing + +Download the data with the cmd: + +$ python prepare_data.py + +Train the model with the cmd: + +$ python main.py --config config +""" + +# pylint: disable=invalid-name, too-many-locals, too-many-arguments, no-member + +import os +import importlib +import numpy as np +import torch +import argparse +import texar.torch as tx + +from ctrl_gen_model import CtrlGenModel + +parser = argparse.ArgumentParser() + +parser.add_argument('--config', default='config', help="The config to use.") + +args = parser.parse_args() + +config = importlib.import_module(args.config) + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def _main(): + # Data + train_data = tx.data.MultiAlignedData(hparams=config.train_data, device=device) + val_data = tx.data.MultiAlignedData(hparams=config.val_data, device=device) + test_data = tx.data.MultiAlignedData(hparams=config.test_data, device=device) + vocab = train_data.vocab(0) + + # Each training batch is used twice: once for updating the generator and + # once for updating the discriminator. Feedable data iterator is used for + # such case. + iterator = tx.data.DataIterator( + {'train_g': train_data, 'train_d': train_data, + 'val': val_data, 'test': test_data}) + + # Model + gamma_ = 1. + lambda_g_ = 0. 
+ + # Model + model = CtrlGenModel(vocab, hparams=config.model) + model.to(device) + + # create optimizers + train_op_d = tx.core.get_train_op( + params=model.d_vars, + hparams=config.model['opt'] + ) + + train_op_g = tx.core.get_train_op( + params=model.g_vars, + hparams=config.model['opt'] + ) + + train_op_g_ae = tx.core.get_train_op( + params=model.g_vars, + hparams=config.model['opt'] + ) + + def _train_epoch(gamma_, lambda_g_, epoch, verbose=True): + model.train() + avg_meters_d = tx.utils.AverageRecorder(size=10) + avg_meters_g = tx.utils.AverageRecorder(size=10) + iterator.switch_to_dataset("train_g") + step = 0 + for batch in iterator: + step += 1 + + vals_d = model(batch, gamma_, lambda_g_, mode="train", component="D") + loss_d = vals_d['loss_d'] + loss_d.backward() + train_op_d() + recorder_d = {key: value.detach().cpu().data for (key, value) in vals_d.items()} + avg_meters_d.add(recorder_d) + + vals_g = model(batch, gamma_, lambda_g_, mode="train", component="G") + loss_g = vals_g['loss_g'] + loss_g_ae = vals_g['loss_g_ae'] + loss_g_ae.backward(retain_graph=True) + loss_g.backward() + train_op_g_ae() + train_op_g() + + recorder_g = {key: value.detach().cpu().data for (key, value) in vals_g.items()} + avg_meters_g.add(recorder_g) + + if verbose and (step == 1 or step % config.display == 0): + print('step: {}, {}'.format(step, avg_meters_d.to_str(4))) + print('step: {}, {}'.format(step, avg_meters_g.to_str(4))) + + if verbose and step % config.display_eval == 0: + iterator.switch_to_dataset("val") + _eval_epoch(gamma_, lambda_g_, epoch) + + print('epoch: {}, {}'.format(epoch, avg_meters_d.to_str(4))) + print('epoch: {}, {}'.format(epoch, avg_meters_g.to_str(4))) + + @torch.no_grad() + def _eval_epoch(gamma_, lambda_g_, epoch, val_or_test='val'): + model.eval() + avg_meters = tx.utils.AverageRecorder() + iterator.switch_to_dataset(val_or_test) + for batch in iterator: + vals, samples = model(batch, gamma_, lambda_g_, mode='eval') + + batch_size = vals.pop('batch_size') + + # Computes BLEU + hyps = tx.utils.map_ids_to_strs(samples['transferred'], vocab) + + refs = tx.utils.map_ids_to_strs(samples['original'], vocab) + refs = np.expand_dims(refs, axis=1) + + bleu = tx.evals.corpus_bleu_moses(refs, hyps) + vals['bleu'] = bleu + + avg_meters.add(vals, weight=batch_size) + + # Writes samples + tx.utils.write_paired_text( + refs.squeeze(), hyps, + os.path.join(config.sample_path, 'val.%d' % epoch), + append=True, mode='v') + + print('{}: {}'.format( + val_or_test, avg_meters.to_str(precision=4))) + + return avg_meters.avg() + + os.makedirs(config.sample_path, exist_ok=True) + os.makedirs(config.checkpoint_path, exist_ok=True) + + # Runs the logics + if config.restore: + print('Restore from: {}'.format(config.restore)) + ckpt = torch.load(args.restore) + model.load_state_dict(ckpt['model']) + # train_op_d.load_state_dict(ckpt['optimizer_d']) + # train_op_g.load_state_dict(ckpt['optimizer_g']) + + for epoch in range(1, config.max_nepochs + 1): + if epoch > config.pretrain_nepochs: + # Anneals the gumbel-softmax temperature + gamma_ = max(0.001, gamma_ * config.gamma_decay) + lambda_g_ = config.lambda_g + print('gamma: {}, lambda_g: {}'.format(gamma_, lambda_g_)) + + # Train + _train_epoch(gamma_, lambda_g_, epoch) + + # Val + _eval_epoch(gamma_, lambda_g_, epoch, 'val') + + states = { + 'model': model.state_dict(), + 'optimizer_d': train_op_d.state_dict(), + 'optimizer_g': train_op_g.state_dict() + } + torch.save(states, os.path.join(config.checkpoint_path, 'ckpt')) + + # Test + 
_eval_epoch(gamma_, lambda_g_, epoch, 'test') + + +if __name__ == '__main__': + _main() diff --git a/examples/text_style_transfer/prepare_data.py b/examples/text_style_transfer/prepare_data.py new file mode 100644 index 000000000..84a81e2ed --- /dev/null +++ b/examples/text_style_transfer/prepare_data.py @@ -0,0 +1,39 @@ +# Copyright 2018 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Downloads data. +""" +import texar.torch as tx + +# pylint: disable=invalid-name + + +def prepare_data(): + """Downloads data. + """ + tx.data.maybe_download( + urls='https://drive.google.com/file/d/' + '1HaUKEYDBEk6GlJGmXwqYteB-4rS9q8Lg/view?usp=sharing', + path='./', + filenames='yelp.zip', + extract=True) + + +def main(): + """Entrypoint. + """ + prepare_data() + + +if __name__ == '__main__': + main() diff --git a/texar/torch/modules/networks/conv_networks.py b/texar/torch/modules/networks/conv_networks.py index fc6b80678..79fe8dd84 100644 --- a/texar/torch/modules/networks/conv_networks.py +++ b/texar/torch/modules/networks/conv_networks.py @@ -421,7 +421,7 @@ def _activation_hparams(name, kwargs=None): other_kwargs[i] = _to_list(other_kwargs[i], "other_kwargs[i]", len(kernel_size[i])) elif (isinstance(other_kwargs[i], (list, tuple)) - and len(other_kwargs[i]) != kernel_size[i]): + and len(other_kwargs[i]) != len(kernel_size[i])): raise ValueError("The length of hparams['other_conv_kwargs'][i]" " must be equal to the length of " "hparams['kernel_size'][i]") diff --git a/texar/torch/utils/__init__.py b/texar/torch/utils/__init__.py index d33081eb5..d67a31aca 100644 --- a/texar/torch/utils/__init__.py +++ b/texar/torch/utils/__init__.py @@ -21,3 +21,4 @@ from texar.torch.utils.shapes import * from texar.torch.utils.utils import * from texar.torch.utils.utils_io import * +from texar.torch.utils.variables import * diff --git a/texar/torch/utils/variables.py b/texar/torch/utils/variables.py new file mode 100644 index 000000000..80072d961 --- /dev/null +++ b/texar/torch/utils/variables.py @@ -0,0 +1,68 @@ +# Copyright 2018 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utility functions related to variables. +""" + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +# pylint: disable=invalid-name + +import tensorflow as tf + +__all__ = [ + "add_variable", + "collect_trainable_variables" +] + + +def add_variable(variable, var_list): + """Adds variable to a given list. 
+ + Args: + variable: A (list of) variable(s). + var_list (list): The list where the :attr:`variable` are added to. + """ + if isinstance(variable, (list, tuple)): + for var in variable: + add_variable(var, var_list) + else: + # Checking uniqueness gives error + # if variable in var_list: + var_list.append(variable) + + +def collect_trainable_variables(modules): + """Collects all trainable variables of modules. + + Trainable variables included in multiple modules occur only once in the + returned list. + + Args: + modules: A (list of) instance of the subclasses of + :class:`~texar.tf.modules.ModuleBase`. + + Returns: + A list of trainable variables in the modules. + """ + if not isinstance(modules, (list, tuple)): + modules = [modules] + + var_list = [] + for mod in modules: + add_variable(mod.trainable_variables, var_list) + + return var_list From e401929f99f95952ef844d496a8865b6647860bc Mon Sep 17 00:00:00 2001 From: swapnull7 Date: Thu, 5 Dec 2019 13:16:05 -0500 Subject: [PATCH 02/14] Add text style transfer with improvements (#2) * initial commit * bug fixes and adjusting conv inputs * separate forward function for Discriminator and Generator and disable Gen training for debugging * remove debugger statement * bug fix * detaching stuff before accumulating * refactor and add component as optional parameter * Add optimizer for and backprop against encoder * Add in README * more fixes to eval mode * create optimizers so that they can be saved * fix typo --- .../text_style_transfer/ctrl_gen_model.py | 45 ++++++++++--------- examples/text_style_transfer/main.py | 19 ++++---- 2 files changed, 36 insertions(+), 28 deletions(-) diff --git a/examples/text_style_transfer/ctrl_gen_model.py b/examples/text_style_transfer/ctrl_gen_model.py index 5a87f8bdf..7bf39ed3c 100644 --- a/examples/text_style_transfer/ctrl_gen_model.py +++ b/examples/text_style_transfer/ctrl_gen_model.py @@ -107,27 +107,32 @@ def forward_G(self, inputs, f_labels, gamma, lambda_g, mode): h = torch.cat([c, z], dim=1) h_ = torch.cat([c_, z], dim=1) - g_outputs, _, _ = self.decoder( - memory=enc_outputs, - memory_sequence_length=inputs['length'] - 1, - initial_state=self.connector(h), - inputs=inputs['text_ids'], - embedding=self.embedder, - sequence_length=inputs['length'] - 1 - ) - - loss_g_ae = tx.losses.sequence_sparse_softmax_cross_entropy( - labels=inputs['text_ids'][:, 1:], - logits=g_outputs.logits, - sequence_length=inputs['length'] - 1, - average_across_timesteps=True, - sum_over_timesteps=False - ) - # Gumbel-softmax decoding, used in training start_tokens = torch.ones_like(inputs['labels'].long()) * self.vocab.bos_token_id end_token = self.vocab.eos_token_id + if mode == 'train': + g_outputs, _, _ = self.decoder( + memory=enc_outputs, + memory_sequence_length=inputs['length'] - 1, + initial_state=self.connector(h), + inputs=inputs['text_ids'], + embedding=self.embedder, + sequence_length=inputs['length'] - 1 + ) + + loss_g_ae = tx.losses.sequence_sparse_softmax_cross_entropy( + labels=inputs['text_ids'][:, 1:], + logits=g_outputs.logits, + sequence_length=inputs['length'] - 1, + average_across_timesteps=True, + sum_over_timesteps=False + ) + + else: + # for eval, there is no loss + loss_g_ae = 0 + gumbel_helper = GumbelSoftmaxEmbeddingHelper(start_tokens=start_tokens, end_token=end_token, tau=gamma) @@ -198,15 +203,15 @@ def forward(self, inputs, gamma, lambda_g, mode, component=None): "batch_size": get_batch_size(inputs['text_ids']), "loss_g": ret_g['loss_g'], "loss_g_ae": ret_g['loss_g_ae'], - "loss_g_clas": 
ret_g['loss_g_class'], - "loss_d": ret_d['loss_d_class'], + "loss_g_class": ret_g['loss_g_class'], + "loss_d": ret_d['loss_d'], "accu_d": ret_d['accu_d'], "accu_g": ret_g['accu_g'], "accu_g_gdy": ret_g['accu_g_gdy'] } samples = { "original": inputs['text_ids'][:, 1:], - "transferred": rets['outputs'].sample_id + "transferred": ret_g['outputs'].sample_id } return rets, samples diff --git a/examples/text_style_transfer/main.py b/examples/text_style_transfer/main.py index ace295416..19984d87b 100644 --- a/examples/text_style_transfer/main.py +++ b/examples/text_style_transfer/main.py @@ -72,17 +72,17 @@ def _main(): model.to(device) # create optimizers - train_op_d = tx.core.get_train_op( + train_op_d = tx.core.get_optimizer( params=model.d_vars, hparams=config.model['opt'] ) - train_op_g = tx.core.get_train_op( + train_op_g = tx.core.get_optimizer( params=model.g_vars, hparams=config.model['opt'] ) - train_op_g_ae = tx.core.get_train_op( + train_op_g_ae = tx.core.get_optimizer( params=model.g_vars, hparams=config.model['opt'] ) @@ -94,12 +94,15 @@ def _train_epoch(gamma_, lambda_g_, epoch, verbose=True): iterator.switch_to_dataset("train_g") step = 0 for batch in iterator: + train_op_d.zero_grad() + train_op_g_ae.zero_grad() + train_op_g.zero_grad() step += 1 vals_d = model(batch, gamma_, lambda_g_, mode="train", component="D") loss_d = vals_d['loss_d'] loss_d.backward() - train_op_d() + train_op_d.step() recorder_d = {key: value.detach().cpu().data for (key, value) in vals_d.items()} avg_meters_d.add(recorder_d) @@ -108,8 +111,8 @@ def _train_epoch(gamma_, lambda_g_, epoch, verbose=True): loss_g_ae = vals_g['loss_g_ae'] loss_g_ae.backward(retain_graph=True) loss_g.backward() - train_op_g_ae() - train_op_g() + train_op_g_ae.step() + train_op_g.step() recorder_g = {key: value.detach().cpu().data for (key, value) in vals_g.items()} avg_meters_g.add(recorder_g) @@ -136,9 +139,9 @@ def _eval_epoch(gamma_, lambda_g_, epoch, val_or_test='val'): batch_size = vals.pop('batch_size') # Computes BLEU - hyps = tx.utils.map_ids_to_strs(samples['transferred'], vocab) + hyps = tx.data.map_ids_to_strs(samples['transferred'].cpu(), vocab) - refs = tx.utils.map_ids_to_strs(samples['original'], vocab) + refs = tx.data.map_ids_to_strs(samples['original'].cpu(), vocab) refs = np.expand_dims(refs, axis=1) bleu = tx.evals.corpus_bleu_moses(refs, hyps) From 7bb76b759071a6a9f36786a27df4ecbae1772f0c Mon Sep 17 00:00:00 2001 From: Swapnil Singhavi Date: Thu, 5 Dec 2019 13:24:41 -0500 Subject: [PATCH 03/14] restore optimizers --- examples/text_style_transfer/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text_style_transfer/main.py b/examples/text_style_transfer/main.py index ace295416..e40e2a9e2 100644 --- a/examples/text_style_transfer/main.py +++ b/examples/text_style_transfer/main.py @@ -165,8 +165,8 @@ def _eval_epoch(gamma_, lambda_g_, epoch, val_or_test='val'): print('Restore from: {}'.format(config.restore)) ckpt = torch.load(args.restore) model.load_state_dict(ckpt['model']) - # train_op_d.load_state_dict(ckpt['optimizer_d']) - # train_op_g.load_state_dict(ckpt['optimizer_g']) + train_op_d.load_state_dict(ckpt['optimizer_d']) + train_op_g.load_state_dict(ckpt['optimizer_g']) for epoch in range(1, config.max_nepochs + 1): if epoch > config.pretrain_nepochs: From 5999f69f0c305779142abf7858b158f768032c15 Mon Sep 17 00:00:00 2001 From: Swapnil Singhavi Date: Thu, 5 Dec 2019 13:40:41 -0500 Subject: [PATCH 04/14] Update ctrl_gen_model.py --- 
examples/text_style_transfer/ctrl_gen_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text_style_transfer/ctrl_gen_model.py b/examples/text_style_transfer/ctrl_gen_model.py index 7bf39ed3c..ea2f118d9 100644 --- a/examples/text_style_transfer/ctrl_gen_model.py +++ b/examples/text_style_transfer/ctrl_gen_model.py @@ -161,9 +161,9 @@ def forward_G(self, inputs, f_labels, gamma, lambda_g, mode): loss_g_class = sig_ce_logits_loss(soft_logits, (1 - f_labels)) # Accuracy on greedy-decoded samples, for training progress monitoring - # greedy_inputs = self.class_embedder(ids=outputs_.sample_id) + greedy_inputs = self.class_embedder(ids=outputs_.sample_id) _, gdy_preds = self.classifier( - input=self.class_embedder(ids=outputs_.sample_id), + input=greedy_inputs, sequence_length=length_) accu_g_gdy = tx.evals.accuracy( From 9f0ac5dad6f0e65ed2e44d83fa1e4f0605c220c2 Mon Sep 17 00:00:00 2001 From: Swapnil Singhavi Date: Thu, 5 Dec 2019 14:44:36 -0500 Subject: [PATCH 05/14] remove tensorflow import --- texar/torch/utils/variables.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/texar/torch/utils/variables.py b/texar/torch/utils/variables.py index 80072d961..27925db26 100644 --- a/texar/torch/utils/variables.py +++ b/texar/torch/utils/variables.py @@ -15,14 +15,8 @@ Utility functions related to variables. """ -from __future__ import absolute_import -from __future__ import print_function -from __future__ import division - # pylint: disable=invalid-name -import tensorflow as tf - __all__ = [ "add_variable", "collect_trainable_variables" From c62e3b708eb0d93bb84671ae8100560e9108765c Mon Sep 17 00:00:00 2001 From: swapnull7 Date: Thu, 5 Dec 2019 18:49:19 -0500 Subject: [PATCH 06/14] Add text style transfer (#3) --- examples/text_style_transfer/config.py | 3 +- .../text_style_transfer/ctrl_gen_model.py | 54 ++++++++++--------- examples/text_style_transfer/main.py | 24 ++++++--- 3 files changed, 48 insertions(+), 33 deletions(-) diff --git a/examples/text_style_transfer/config.py b/examples/text_style_transfer/config.py index 94e9f9592..2af91831e 100644 --- a/examples/text_style_transfer/config.py +++ b/examples/text_style_transfer/config.py @@ -3,6 +3,7 @@ # pylint: disable=invalid-name import copy +from typing import Dict, Any # Total number of training epochs (including pre-train and full-train) max_nepochs = 12 @@ -21,7 +22,7 @@ max_seq_length = 16 # Maximum sequence length in dataset w/o BOS token -train_data = { +train_data: Dict[str, Any] = { 'batch_size': 64, # 'seed': 123, 'datasets': [ diff --git a/examples/text_style_transfer/ctrl_gen_model.py b/examples/text_style_transfer/ctrl_gen_model.py index ea2f118d9..af74ebd62 100644 --- a/examples/text_style_transfer/ctrl_gen_model.py +++ b/examples/text_style_transfer/ctrl_gen_model.py @@ -38,12 +38,14 @@ def __init__(self, vocab: tx.data.Vocab, hparams=None): self.embedder = WordEmbedder(vocab_size=self.vocab.size, hparams=self._hparams.embedder) - self.encoder = UnidirectionalRNNEncoder(input_size=self.embedder.dim, - hparams=self._hparams.encoder) + self.encoder = UnidirectionalRNNEncoder( + input_size=self.embedder.dim, + hparams=self._hparams.encoder) # type: UnidirectionalRNNEncoder # Encodes label - self.label_connector = MLPTransformConnector(output_size=self._hparams.dim_c, - linear_layer_dim=1) + self.label_connector = MLPTransformConnector( + output_size=self._hparams.dim_c, + linear_layer_dim=1) # Teacher-force decoding and the auto-encoding loss for G self.decoder = AttentionRNNDecoder( @@ -53,13 
+55,14 @@ def __init__(self, vocab: tx.data.Vocab, hparams=None): token_embedder=self.embedder, hparams=self._hparams.decoder) - self.connector = MLPTransformConnector(output_size=self.decoder.output_size, - linear_layer_dim=(self._hparams.dim_c + - self._hparams.dim_z)) + self.connector = MLPTransformConnector( + output_size=self.decoder.output_size, + linear_layer_dim=(self._hparams.dim_c + self._hparams.dim_z)) - self.classifier = Conv1DClassifier(in_channels=self.embedder.dim, - in_features=self._hparams.max_seq_length, - hparams=self._hparams.classifier) + self.classifier = Conv1DClassifier( + in_channels=self.embedder.dim, + in_features=self._hparams.max_seq_length, + hparams=self._hparams.classifier) self.class_embedder = WordEmbedder(vocab_size=self.vocab.size, hparams=self._hparams.embedder) @@ -72,7 +75,7 @@ def __init__(self, vocab: tx.data.Vocab, hparams=None): self.d_vars = collect_trainable_variables( [self.class_embedder, self.classifier]) - def forward_D(self, inputs, f_labels, mode): + def forward_D(self, inputs, f_labels): # Classification loss for the classifier # Get inputs in correct format, [batch_size, channels, seq_length] @@ -96,8 +99,9 @@ def forward_G(self, inputs, f_labels, gamma, lambda_g, mode): # text_ids for encoder, with BOS token removed enc_text_ids = inputs['text_ids'][:, 1:].long() enc_inputs = self.embedder(enc_text_ids) - enc_outputs, final_state = self.encoder(enc_inputs, - sequence_length=inputs['length'] - 1) + enc_outputs, final_state = self.encoder( + enc_inputs, + sequence_length=inputs['length'] - 1) z = final_state[:, self._hparams.dim_c:] labels = inputs['labels'].view(-1, 1).float() @@ -108,7 +112,8 @@ def forward_G(self, inputs, f_labels, gamma, lambda_g, mode): h_ = torch.cat([c_, z], dim=1) # Gumbel-softmax decoding, used in training - start_tokens = torch.ones_like(inputs['labels'].long()) * self.vocab.bos_token_id + start_tokens = torch.ones_like(inputs['labels'].long()) * \ + self.vocab.bos_token_id end_token = self.vocab.eos_token_id if mode == 'train': @@ -133,9 +138,10 @@ def forward_G(self, inputs, f_labels, gamma, lambda_g, mode): # for eval, there is no loss loss_g_ae = 0 - gumbel_helper = GumbelSoftmaxEmbeddingHelper(start_tokens=start_tokens, - end_token=end_token, - tau=gamma) + gumbel_helper = GumbelSoftmaxEmbeddingHelper( + start_tokens=start_tokens, + end_token=end_token, + tau=gamma) soft_outputs_, _, soft_length_, = self.decoder( memory=enc_outputs, @@ -147,8 +153,11 @@ def forward_G(self, inputs, f_labels, gamma, lambda_g, mode): outputs_, _, length_ = self.decoder( memory=enc_outputs, memory_sequence_length=inputs['length'] - 1, - decoding_strategy='infer_greedy', initial_state=self.connector(h_), - embedding=self.embedder, start_tokens=start_tokens, end_token=end_token) + decoding_strategy='infer_greedy', + initial_state=self.connector(h_), + embedding=self.embedder, + start_tokens=start_tokens, + end_token=end_token) # Get inputs in correct format, [batch_size, channels, seq_length] soft_inputs = self.class_embedder(soft_ids=soft_outputs_.sample_id) @@ -189,7 +198,7 @@ def forward(self, inputs, gamma, lambda_g, mode, component=None): f_labels = inputs['labels'].float() if mode == 'train': if component == 'D': - ret_d = self.forward_D(inputs, f_labels, mode) + ret_d = self.forward_D(inputs, f_labels) return ret_d elif component == 'G': @@ -197,7 +206,7 @@ def forward(self, inputs, gamma, lambda_g, mode, component=None): return ret_g else: - ret_d = self.forward_D(inputs, f_labels, mode) + ret_d = self.forward_D(inputs, 
f_labels) ret_g = self.forward_G(inputs, f_labels, gamma, lambda_g, mode) rets = { "batch_size": get_batch_size(inputs['text_ids']), @@ -214,6 +223,3 @@ def forward(self, inputs, gamma, lambda_g, mode, component=None): "transferred": ret_g['outputs'].sample_id } return rets, samples - - - diff --git a/examples/text_style_transfer/main.py b/examples/text_style_transfer/main.py index 1581fade6..6eb2dee3d 100644 --- a/examples/text_style_transfer/main.py +++ b/examples/text_style_transfer/main.py @@ -31,9 +31,10 @@ import os import importlib +import argparse import numpy as np import torch -import argparse + import texar.torch as tx from ctrl_gen_model import CtrlGenModel @@ -51,9 +52,12 @@ def _main(): # Data - train_data = tx.data.MultiAlignedData(hparams=config.train_data, device=device) - val_data = tx.data.MultiAlignedData(hparams=config.val_data, device=device) - test_data = tx.data.MultiAlignedData(hparams=config.test_data, device=device) + train_data = tx.data.MultiAlignedData(hparams=config.train_data, + device=device) + val_data = tx.data.MultiAlignedData(hparams=config.val_data, + device=device) + test_data = tx.data.MultiAlignedData(hparams=config.test_data, + device=device) vocab = train_data.vocab(0) # Each training batch is used twice: once for updating the generator and @@ -99,14 +103,17 @@ def _train_epoch(gamma_, lambda_g_, epoch, verbose=True): train_op_g.zero_grad() step += 1 - vals_d = model(batch, gamma_, lambda_g_, mode="train", component="D") + vals_d = model(batch, gamma_, lambda_g_, mode="train", + component="D") loss_d = vals_d['loss_d'] loss_d.backward() train_op_d.step() - recorder_d = {key: value.detach().cpu().data for (key, value) in vals_d.items()} + recorder_d = {key: value.detach().cpu().data + for (key, value) in vals_d.items()} avg_meters_d.add(recorder_d) - vals_g = model(batch, gamma_, lambda_g_, mode="train", component="G") + vals_g = model(batch, gamma_, lambda_g_, mode="train", + component="G") loss_g = vals_g['loss_g'] loss_g_ae = vals_g['loss_g_ae'] loss_g_ae.backward(retain_graph=True) @@ -114,7 +121,8 @@ def _train_epoch(gamma_, lambda_g_, epoch, verbose=True): train_op_g_ae.step() train_op_g.step() - recorder_g = {key: value.detach().cpu().data for (key, value) in vals_g.items()} + recorder_g = {key: value.detach().cpu().data + for (key, value) in vals_g.items()} avg_meters_g.add(recorder_g) if verbose and (step == 1 or step % config.display == 0): From 6c5b81fb0b99fad14af7aa479e5f937c068fe538 Mon Sep 17 00:00:00 2001 From: swapnull7 Date: Sun, 8 Dec 2019 00:58:36 -0500 Subject: [PATCH 07/14] Add text style transfer (#4) * initial commit * bug fixes and adjusting conv inputs * separate forward function for Discriminator and Generator and disable Gen training for debugging * remove debugger statement * bug fix * detaching stuff before accumulating * refactor and add component as optional parameter * Add optimizer for and backprop against encoder * Add in README * more fixes to eval mode * create optimizers so that they can be saved * fix typo * linting issues * add type annotation for encoder * fix linting * Isolate AE in training * works after changing the learning rate * remove debugger --- examples/text_style_transfer/config.py | 2 +- examples/text_style_transfer/ctrl_gen_model.py | 13 ++++++++----- examples/text_style_transfer/main.py | 17 ++++++++++------- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/examples/text_style_transfer/config.py b/examples/text_style_transfer/config.py index 2af91831e..bb9f1dc13 100644 --- 
a/examples/text_style_transfer/config.py +++ b/examples/text_style_transfer/config.py @@ -101,7 +101,7 @@ 'optimizer': { 'type': 'Adam', 'kwargs': { - 'lr': 5e-4, + 'lr': 3e-4, }, }, }, diff --git a/examples/text_style_transfer/ctrl_gen_model.py b/examples/text_style_transfer/ctrl_gen_model.py index af74ebd62..9d5ba9863 100644 --- a/examples/text_style_transfer/ctrl_gen_model.py +++ b/examples/text_style_transfer/ctrl_gen_model.py @@ -50,7 +50,7 @@ def __init__(self, vocab: tx.data.Vocab, hparams=None): # Teacher-force decoding and the auto-encoding loss for G self.decoder = AttentionRNNDecoder( input_size=self.embedder.dim, - encoder_output_size=(self.encoder.cell.hidden_size), + encoder_output_size=self.encoder.cell.hidden_size, vocab_size=self.vocab.size, token_embedder=self.embedder, hparams=self._hparams.decoder) @@ -69,8 +69,8 @@ def __init__(self, vocab: tx.data.Vocab, hparams=None): # Creates optimizers self.g_vars = collect_trainable_variables( - [self.embedder, self.encoder, self.label_connector, - self.connector, self.decoder]) + [self.decoder, self.connector, self.label_connector, + self.encoder, self.embedder]) self.d_vars = collect_trainable_variables( [self.class_embedder, self.classifier]) @@ -122,7 +122,6 @@ def forward_G(self, inputs, f_labels, gamma, lambda_g, mode): memory_sequence_length=inputs['length'] - 1, initial_state=self.connector(h), inputs=inputs['text_ids'], - embedding=self.embedder, sequence_length=inputs['length'] - 1 ) @@ -133,6 +132,11 @@ def forward_G(self, inputs, f_labels, gamma, lambda_g, mode): average_across_timesteps=True, sum_over_timesteps=False ) + if lambda_g == 0: + ret = { + "loss_g_ae": loss_g_ae, + } + return ret else: # for eval, there is no loss @@ -155,7 +159,6 @@ def forward_G(self, inputs, f_labels, gamma, lambda_g, mode): memory_sequence_length=inputs['length'] - 1, decoding_strategy='infer_greedy', initial_state=self.connector(h_), - embedding=self.embedder, start_tokens=start_tokens, end_token=end_token) diff --git a/examples/text_style_transfer/main.py b/examples/text_style_transfer/main.py index 6eb2dee3d..6ca76ccb7 100644 --- a/examples/text_style_transfer/main.py +++ b/examples/text_style_transfer/main.py @@ -114,18 +114,21 @@ def _train_epoch(gamma_, lambda_g_, epoch, verbose=True): vals_g = model(batch, gamma_, lambda_g_, mode="train", component="G") - loss_g = vals_g['loss_g'] - loss_g_ae = vals_g['loss_g_ae'] - loss_g_ae.backward(retain_graph=True) - loss_g.backward() - train_op_g_ae.step() - train_op_g.step() + + if epoch <= config.pretrain_nepochs: + loss_g_ae = vals_g['loss_g_ae'] + loss_g_ae.backward() + train_op_g_ae.step() + else: + loss_g = vals_g['loss_g'] + loss_g.backward() + train_op_g.step() recorder_g = {key: value.detach().cpu().data for (key, value) in vals_g.items()} avg_meters_g.add(recorder_g) - if verbose and (step == 1 or step % config.display == 0): + if verbose and (step == 1 or step % config.display >= 0): print('step: {}, {}'.format(step, avg_meters_d.to_str(4))) print('step: {}, {}'.format(step, avg_meters_g.to_str(4))) From 9ce07e5fd5eb5fd1a99a633241d07c285d4ea065 Mon Sep 17 00:00:00 2001 From: swapnull7 Date: Mon, 9 Dec 2019 15:32:11 -0500 Subject: [PATCH 08/14] Add text style transfer (#5) * Reviewed changes * linting --- .codecov.yml | 27 --------- codecov.yml | 7 +++ docs/code/utils.rst | 15 +++++ docs/examples.md | 9 +++ examples/README.md | 4 ++ examples/text_style_transfer/README.md | 58 +++++++++---------- .../text_style_transfer/ctrl_gen_model.py | 13 ++--- 
examples/text_style_transfer/main.py | 11 ++-- examples/text_style_transfer/prepare_data.py | 4 +- texar/torch/utils/variables.py | 22 +++---- 10 files changed, 87 insertions(+), 83 deletions(-) delete mode 100644 .codecov.yml create mode 100644 codecov.yml diff --git a/.codecov.yml b/.codecov.yml deleted file mode 100644 index ef5e6772a..000000000 --- a/.codecov.yml +++ /dev/null @@ -1,27 +0,0 @@ -codecov: - require_ci_to_pass: yes - -coverage: - precision: 2 - round: down - range: "70...100" - - status: - project: - default: - threshold: 1% - patch: off - changes: no - -parsers: - gcov: - branch_detection: - conditional: yes - loop: yes - method: no - macro: no - -comment: - layout: "reach,diff,flags,tree" - behavior: default - require_changes: no diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 000000000..b82018c1c --- /dev/null +++ b/codecov.yml @@ -0,0 +1,7 @@ +coverage: + status: + project: + default: + threshold: 1% + + patch: false diff --git a/docs/code/utils.rst b/docs/code/utils.rst index 5c14af6b0..14f351fe4 100644 --- a/docs/code/utils.rst +++ b/docs/code/utils.rst @@ -12,6 +12,10 @@ Frequent Use .. autoclass:: texar.torch.utils.AverageRecorder :members: +:hidden:`collect_trainable_variables` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: texar.torch.utils.collect_trainable_variables + :hidden:`compat_as_text` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: texar.torch.utils.compat_as_text @@ -20,6 +24,17 @@ Frequent Use ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: texar.torch.utils.write_paired_text +Variables +========= + +:hidden:`collect_trainable_variables` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: texar.torch.utils.collect_trainable_variables + +:hidden:`add_variable` +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: texar.torch.utils.add_variable + IO === diff --git a/docs/examples.md b/docs/examples.md index 429d3b02b..3002f81fd 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -22,6 +22,10 @@ More examples are continuously added... * [bert](https://github.com/asyml/texar-pytorch/tree/master/examples/bert): Pre-trained BERT model for text representation * [xlnet](https://github.com/asyml/texar-pytorch/tree/master/examples/xlnet): Pre-trained XLNet model for text representation +### GANs / Discriminiator-supervision ### + +* [text_style_transfer](https://github.com/asyml/texar-pytorch/tree/master/examples/text_style_transfer): Discriminator supervision for controlled text generation + --- ## Examples by Tasks @@ -35,6 +39,11 @@ More examples are continuously added... * [seq2seq_attn](https://github.com/asyml/texar-pytorch/tree/master/examples/seq2seq_attn): Attentional seq2seq * [transformer](https://github.com/asyml/texar-pytorch/tree/master/examples/transformer): Transformer for machine translation +### Text Style Transfer ### + +* [text_style_transfer](https://github.com/asyml/texar-pytorch/tree/master/examples/text_style_transfer): Discriminator supervision for controlled text generation + + ### Classification ### * [bert](https://github.com/asyml/texar-pytorch/tree/master/examples/bert): Pre-trained BERT model for text representation diff --git a/examples/README.md b/examples/README.md index be889deba..790b232d0 100644 --- a/examples/README.md +++ b/examples/README.md @@ -47,6 +47,10 @@ More examples are continuously added... 
* [seq2seq_attn](./seq2seq_attn): Attentional seq2seq * [transformer](./transformer): Transformer for machine translation +### Text Style Transfer ### + +* [text_style_transfer](./text_style_transfer): Discriminator supervision for controlled text generation + ### Classification ### * [bert](./bert): Pre-trained BERT model for text representation diff --git a/examples/text_style_transfer/README.md b/examples/text_style_transfer/README.md index 2f9dd0dbb..94c500adb 100644 --- a/examples/text_style_transfer/README.md +++ b/examples/text_style_transfer/README.md @@ -14,7 +14,7 @@ The model roughly has an architecture of `Encoder--Decoder--Classifier`. Compare ## Usage ## ### Dataset ### -Download the yelp sentiment dataset with the following cmd: +Download the yelp sentiment dataset with the following command: ``` python prepare_data.py ``` @@ -36,24 +36,25 @@ python main.py --config config Training log is printed as below: ``` gamma: 1.0, lambda_g: 0.0 -step: 1, loss_d: 0.6903 accu_d: 0.5625 -step: 1, loss_g_clas: 0.6991 loss_g: 9.1452 accu_g: 0.2812 loss_g_ae: 9.1452 accu_g_gdy: 0.2969 -step: 500, loss_d: 0.0989 accu_d: 0.9688 -step: 500, loss_g_clas: 0.2985 loss_g: 3.9696 accu_g: 0.8891 loss_g_ae: 3.9696 accu_g_gdy: 0.7734 +step: 1, loss_d: 0.6934 accu_d: 0.4844 +step: 1, loss_g_ae: 9.1392 +step: 500, loss_d: 0.1488 accu_d: 0.9484 +step: 500, loss_g_ae: 4.2884 +step: 1000, loss_d: 0.1215 accu_d: 0.9625 +step: 1000, loss_g_ae: 2.6201 ... -step: 6500, loss_d: 0.0806 accu_d: 0.9703 -step: 6500, loss_g_clas: 5.7137 loss_g: 0.2887 accu_g: 0.0844 loss_g_ae: 0.2887 accu_g_gdy: 0.0625 -epoch: 1, loss_d: 0.0876 accu_d: 0.9719 -epoch: 1, loss_g_clas: 6.7360 loss_g: 0.2195 accu_g: 0.0627 loss_g_ae: 0.2195 accu_g_gdy: 0.0642 -val: accu_g: 0.0445 loss_g_ae: 0.1302 accu_d: 0.9774 bleu: 90.7896 loss_g: 0.1302 loss_d: 0.0666 loss_g_clas: 7.0310 accu_g_gdy: 0.0482 +epoch: 1, loss_d: 0.0750 accu_d: 0.9688 +epoch: 1, loss_g_ae: 0.8832 +val: loss_g: 0.0000 loss_g_ae: 0.0000 loss_g_class: 3.2949 loss_d: 0.0702 accu_d: 0.9744 accu_g: 0.3022 accu_g_gdy: 0.2732 bleu: 60.8234 +test: loss_g: 0.0000 loss_g_ae: 0.0000 loss_g_class: 3.2359 loss_d: 0.0746 accu_d: 0.9733 accu_g: 0.3076 accu_g_gdy: 0.2791 bleu: 60.1810993 accu_g_gdy: 0.5993 bleu: 63.6671 ... ``` where: - `loss_d` and `accu_d` are the classification loss/accuracy of the `Classifier` part. -- `loss_g_clas` is the classification loss of the generated sentences. +- `loss_g_class` is the classification loss of the generated sentences. - `loss_g_ae` is the autoencoding loss. -- `loss_g` is the joint loss `= loss_g_ae + lambda_g * loss_g_clas`. +- `loss_g` is the joint loss `= loss_g_ae + lambda_g * loss_g_class`. - `accu_g` is the classification accuracy of the generated sentences with soft represetations (i.e., Gumbel-softmax). - `accu_g_gdy` is the classification accuracy of the generated sentences with greedy decoding. - `bleu` is the BLEU score between the generated and input sentences. @@ -72,7 +73,7 @@ The implementation here gives the following performance after 10 epochs of pre-t | Accuracy (by the `Classifier` part) | BLEU (with the original sentence) | | -------------------------------------| ----------------------------------| -| 0.92 | 54.0 | +| 0.96 | 52.0 | Also refer to the following papers that used this code and compared to other text style transfer approaches: @@ -82,27 +83,24 @@ Also refer to the following papers that used this code and compared to other tex ### Samples ### Here are some randomly-picked samples. 
In each pair, the first sentence is the original sentence and the second is the generated. ``` -go to place for client visits with gorgeous views . -go to place for client visits with lacking views . +love , love love . +poor , poor poor . -there was lots of people but they still managed to provide great service . -there was lots of people but they still managed to provide careless service . +good atmosphere . +disgusted atmosphere . -this was the best dining experience i have ever had . -this was the worst dining experience i have ever had . +the donuts are good sized and very well priced . +the donuts are disgusted sized and very _num_ priced . -needless to say , we skipped desert . -gentle to say , we edgy desert . +it is always clean and the staff is super friendly . +it is nasty overpriced and the staff is super cold . -the first time i was missing an entire sandwich and a side of fries . -the first time i was beautifully an entire sandwich and a side of fries . +super sweet place . +super plain place . -her boutique has a fabulous selection of designer brands ! -her annoying has a sketchy selection of bland warned ! +highly recommended . +horrible horrible . -service is pretty good . -service is trashy rude . - -ok nothing new . -exceptional impressed new . +very good ingredients . +very disgusted ingredients . ``` diff --git a/examples/text_style_transfer/ctrl_gen_model.py b/examples/text_style_transfer/ctrl_gen_model.py index 9d5ba9863..aee2eaebf 100644 --- a/examples/text_style_transfer/ctrl_gen_model.py +++ b/examples/text_style_transfer/ctrl_gen_model.py @@ -1,4 +1,4 @@ -# Copyright 2018 The Texar Authors. All Rights Reserved. +# Copyright 2019 The Texar Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,6 +18,8 @@ import torch import torch.nn as nn +from torch.nn import functional as F + import texar.torch as tx from texar.torch.modules import WordEmbedder, UnidirectionalRNNEncoder, \ @@ -84,9 +86,7 @@ def forward_D(self, inputs, f_labels): input=class_inputs, sequence_length=inputs['length'] - 1) - sig_ce_logits_loss = nn.BCEWithLogitsLoss() - - loss_d = sig_ce_logits_loss(class_logits, f_labels) + loss_d = F.binary_cross_entropy_with_logits(class_logits, f_labels) accu_d = tx.evals.accuracy(labels=f_labels, preds=class_preds) return { @@ -168,9 +168,8 @@ def forward_G(self, inputs, f_labels, gamma, lambda_g, mode): input=soft_inputs, sequence_length=soft_length_) - sig_ce_logits_loss = nn.BCEWithLogitsLoss() - - loss_g_class = sig_ce_logits_loss(soft_logits, (1 - f_labels)) + loss_g_class = F.binary_cross_entropy_with_logits(soft_logits, + (1 - f_labels)) # Accuracy on greedy-decoded samples, for training progress monitoring greedy_inputs = self.class_embedder(ids=outputs_.sample_id) diff --git a/examples/text_style_transfer/main.py b/examples/text_style_transfer/main.py index 6ca76ccb7..6488d9f77 100644 --- a/examples/text_style_transfer/main.py +++ b/examples/text_style_transfer/main.py @@ -1,4 +1,4 @@ -# Copyright 2018 The Texar Authors. All Rights Reserved. +# Copyright 2019 The Texar Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -50,7 +50,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -def _main(): +def main(): # Data train_data = tx.data.MultiAlignedData(hparams=config.train_data, device=device) @@ -64,7 +64,7 @@ def _main(): # once for updating the discriminator. Feedable data iterator is used for # such case. iterator = tx.data.DataIterator( - {'train_g': train_data, 'train_d': train_data, + {'train': train_data, 'val': val_data, 'test': test_data}) # Model @@ -95,7 +95,7 @@ def _train_epoch(gamma_, lambda_g_, epoch, verbose=True): model.train() avg_meters_d = tx.utils.AverageRecorder(size=10) avg_meters_g = tx.utils.AverageRecorder(size=10) - iterator.switch_to_dataset("train_g") + iterator.switch_to_dataset("train") step = 0 for batch in iterator: train_op_d.zero_grad() @@ -133,7 +133,6 @@ def _train_epoch(gamma_, lambda_g_, epoch, verbose=True): print('step: {}, {}'.format(step, avg_meters_g.to_str(4))) if verbose and step % config.display_eval == 0: - iterator.switch_to_dataset("val") _eval_epoch(gamma_, lambda_g_, epoch) print('epoch: {}, {}'.format(epoch, avg_meters_d.to_str(4))) @@ -207,4 +206,4 @@ def _eval_epoch(gamma_, lambda_g_, epoch, val_or_test='val'): if __name__ == '__main__': - _main() + main() diff --git a/examples/text_style_transfer/prepare_data.py b/examples/text_style_transfer/prepare_data.py index 84a81e2ed..c01da312e 100644 --- a/examples/text_style_transfer/prepare_data.py +++ b/examples/text_style_transfer/prepare_data.py @@ -1,4 +1,4 @@ -# Copyright 2018 The Texar Authors. All Rights Reserved. +# Copyright 2019 The Texar Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,8 +15,6 @@ """ import texar.torch as tx -# pylint: disable=invalid-name - def prepare_data(): """Downloads data. diff --git a/texar/torch/utils/variables.py b/texar/torch/utils/variables.py index 27925db26..94714c4b1 100644 --- a/texar/torch/utils/variables.py +++ b/texar/torch/utils/variables.py @@ -1,4 +1,4 @@ -# Copyright 2018 The Texar Authors. All Rights Reserved. +# Copyright 2019 The Texar Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ Utility functions related to variables. """ -# pylint: disable=invalid-name +from typing import Any, List, Tuple, Union __all__ = [ "add_variable", @@ -23,8 +23,10 @@ ] -def add_variable(variable, var_list): - """Adds variable to a given list. +def add_variable( + variable: Union[List[Any], Tuple[Any]], + var_list: List[Any]): + r"""Adds variable to a given list. Args: variable: A (list of) variable(s). @@ -34,20 +36,20 @@ def add_variable(variable, var_list): for var in variable: add_variable(var, var_list) else: - # Checking uniqueness gives error - # if variable in var_list: var_list.append(variable) -def collect_trainable_variables(modules): - """Collects all trainable variables of modules. +def collect_trainable_variables( + modules: Union[Any, List[Any]] +): + r"""Collects all trainable variables of modules. Trainable variables included in multiple modules occur only once in the returned list. Args: modules: A (list of) instance of the subclasses of - :class:`~texar.tf.modules.ModuleBase`. + :class:`~texar.torch.modules.ModuleBase`. Returns: A list of trainable variables in the modules. 
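(Editor's note: the example's `main.py` relies on `collect_trainable_variables` to gather parameters for its separate generator and discriminator optimizers. A rough usage sketch; the modules and learning rate below are illustrative, not the example's actual configuration:)

```
import torch
import texar.torch as tx

# Two hypothetical Texar modules (ModuleBase subclasses) standing in for
# generator-side components; the sizes here are arbitrary.
embedder = tx.modules.WordEmbedder(vocab_size=100, hparams={'dim': 16})
encoder = tx.modules.UnidirectionalRNNEncoder(input_size=16)

# Trainable parameters are collected across both modules, each at most once.
g_vars = tx.utils.collect_trainable_variables([embedder, encoder])

# The resulting collection can be handed straight to an optimizer.
optimizer_g = torch.optim.Adam(g_vars, lr=1e-3)
```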
@@ -55,7 +57,7 @@ def collect_trainable_variables(modules): if not isinstance(modules, (list, tuple)): modules = [modules] - var_list = [] + var_list: List[Any] = [] for mod in modules: add_variable(mod.trainable_variables, var_list) From 60550758d85a004fd82e852784eceae030ee50f1 Mon Sep 17 00:00:00 2001 From: swapnull7 Date: Mon, 9 Dec 2019 17:16:19 -0500 Subject: [PATCH 09/14] Add text style transfer (#6) * initial commit * linting --- texar/torch/utils/variables.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/texar/torch/utils/variables.py b/texar/torch/utils/variables.py index 94714c4b1..27bb48379 100644 --- a/texar/torch/utils/variables.py +++ b/texar/torch/utils/variables.py @@ -15,7 +15,10 @@ Utility functions related to variables. """ -from typing import Any, List, Tuple, Union +from typing import List, Tuple, Union +import torch.nn as nn + +from texar.torch.module_base import ModuleBase __all__ = [ "add_variable", @@ -24,8 +27,8 @@ def add_variable( - variable: Union[List[Any], Tuple[Any]], - var_list: List[Any]): + variable: Union[List[nn.Parameter], Tuple[nn.Parameter]], + var_list: List[nn.Parameter]): r"""Adds variable to a given list. Args: @@ -40,7 +43,7 @@ def add_variable( def collect_trainable_variables( - modules: Union[Any, List[Any]] + modules: Union[ModuleBase, List[ModuleBase]] ): r"""Collects all trainable variables of modules. @@ -57,7 +60,7 @@ def collect_trainable_variables( if not isinstance(modules, (list, tuple)): modules = [modules] - var_list: List[Any] = [] + var_list: List[nn.Parameter] = [] for mod in modules: add_variable(mod.trainable_variables, var_list) From 6d7a0cdc15374099b58a788bfb57a4f21775631f Mon Sep 17 00:00:00 2001 From: Swapnil Singhavi Date: Wed, 18 Dec 2019 12:14:13 -0500 Subject: [PATCH 10/14] Fix docs build issue --- docs/code/utils.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/code/utils.rst b/docs/code/utils.rst index 14f351fe4..21aeb0cbf 100644 --- a/docs/code/utils.rst +++ b/docs/code/utils.rst @@ -13,7 +13,7 @@ Frequent Use :members: :hidden:`collect_trainable_variables` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: texar.torch.utils.collect_trainable_variables :hidden:`compat_as_text` From a6a047220240e09ccbe1c0276f44b43a12b65762 Mon Sep 17 00:00:00 2001 From: Swapnil Singhavi Date: Wed, 18 Dec 2019 12:55:37 -0500 Subject: [PATCH 11/14] Fix typo --- docs/examples.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/examples.md b/docs/examples.md index 3002f81fd..b10bcd3db 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -22,7 +22,7 @@ More examples are continuously added... 
* [bert](https://github.com/asyml/texar-pytorch/tree/master/examples/bert): Pre-trained BERT model for text representation * [xlnet](https://github.com/asyml/texar-pytorch/tree/master/examples/xlnet): Pre-trained XLNet model for text representation -### GANs / Discriminiator-supervision ### +### GANs / Discriminator-supervision ### * [text_style_transfer](https://github.com/asyml/texar-pytorch/tree/master/examples/text_style_transfer): Discriminator supervision for controlled text generation From 6f63f28d7073a2a598aae5c1b7ec0806d26dbb45 Mon Sep 17 00:00:00 2001 From: Swapnil Singhavi Date: Mon, 23 Dec 2019 17:20:59 -0500 Subject: [PATCH 12/14] Make sure all variables are appended only once --- texar/torch/utils/variables.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/texar/torch/utils/variables.py b/texar/torch/utils/variables.py index 27bb48379..1188bad79 100644 --- a/texar/torch/utils/variables.py +++ b/texar/torch/utils/variables.py @@ -15,7 +15,7 @@ Utility functions related to variables. """ -from typing import List, Tuple, Union +from typing import List, Tuple, Union, Set import torch.nn as nn from texar.torch.module_base import ModuleBase @@ -28,18 +28,19 @@ def add_variable( variable: Union[List[nn.Parameter], Tuple[nn.Parameter]], - var_list: List[nn.Parameter]): + var_list: Set[nn.Parameter]): r"""Adds variable to a given list. Args: variable: A (list of) variable(s). - var_list (list): The list where the :attr:`variable` are added to. + var_list (set): The set where the trainable params are added to. """ if isinstance(variable, (list, tuple)): for var in variable: add_variable(var, var_list) else: - var_list.append(variable) + if variable not in var_list: + var_list.add(variable) def collect_trainable_variables( @@ -60,8 +61,8 @@ def collect_trainable_variables( if not isinstance(modules, (list, tuple)): modules = [modules] - var_list: List[nn.Parameter] = [] + var_list: Set[nn.Parameter] = set() for mod in modules: add_variable(mod.trainable_variables, var_list) - return var_list + return list(var_list) From 89270393129760602872679e1c113e28de89ed9e Mon Sep 17 00:00:00 2001 From: Swapnil Singhavi Date: Tue, 24 Dec 2019 15:20:01 -0500 Subject: [PATCH 13/14] Update main.py --- examples/text_style_transfer/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_style_transfer/main.py b/examples/text_style_transfer/main.py index 6488d9f77..16a5eff54 100644 --- a/examples/text_style_transfer/main.py +++ b/examples/text_style_transfer/main.py @@ -128,7 +128,7 @@ def _train_epoch(gamma_, lambda_g_, epoch, verbose=True): for (key, value) in vals_g.items()} avg_meters_g.add(recorder_g) - if verbose and (step == 1 or step % config.display >= 0): + if verbose and (step == 1 or step % config.display == 0): print('step: {}, {}'.format(step, avg_meters_d.to_str(4))) print('step: {}, {}'.format(step, avg_meters_g.to_str(4))) From ca88d144c5e068bab90c4820b6875147107f60be Mon Sep 17 00:00:00 2001 From: Swapnil Singhavi Date: Thu, 26 Dec 2019 13:22:56 -0500 Subject: [PATCH 14/14] fix docstrings --- texar/torch/utils/variables.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/texar/torch/utils/variables.py b/texar/torch/utils/variables.py index 1188bad79..687ed1c40 100644 --- a/texar/torch/utils/variables.py +++ b/texar/torch/utils/variables.py @@ -27,13 +27,13 @@ def add_variable( - variable: Union[List[nn.Parameter], Tuple[nn.Parameter]], + variable: Union[List[nn.Parameter], Tuple[nn.Parameter], 
nn.Parameter], var_list: Set[nn.Parameter]): r"""Adds variable to a given list. Args: variable: A (list of) variable(s). - var_list (set): The set where the trainable params are added to. + var_list (set): The set where the trainable parameters are added to. """ if isinstance(variable, (list, tuple)): for var in variable:
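(Editor's note: a quick way to sanity-check the set-based deduplication introduced in the patches above is to collect the same module twice and confirm no parameter is returned more than once. A small illustrative sketch; the toy embedder is hypothetical:)

```
import texar.torch as tx

embedder = tx.modules.WordEmbedder(vocab_size=50, hparams={'dim': 8})

# Each parameter is reachable through both list entries, but the set-based
# add_variable records it only once, so the two collections match in size.
params_once = tx.utils.collect_trainable_variables(embedder)
params_twice = tx.utils.collect_trainable_variables([embedder, embedder])
assert len(params_twice) == len(params_once)
```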