
sdr's Introduction

Self-Supervised Document Similarity Ranking (SDR) via Contextualized Language Models and Hierarchical Inference

This repo contains the implementation of SDR.

 

Tested environment

  • Python 3.7
  • PyTorch 1.7
  • CUDA 11.0

Lower CUDA and PyTorch versions should work as well.

 

Contents

License, security, support, and code of conduct specifications are in the instructions directory.

Installation

Run

bash instructions/installation.sh 

 

Datasets

The published datasets are:

  • Video games
    • 21,935 articles
    • Expert-annotated test set: 90 articles with 12 ground-truth recommendations.
    • Examples:
      • Grand Theft Auto - Mafia
      • Burnout Paradise - Forza Horizon 3
  • Wines
    • 1,635 articles
    • Crafted by a human sommelier: 92 articles with ~10 ground-truth recommendations.
    • Examples:
      • Pinot Meunier - Chardonnay
      • Dom Pérignon - Moët & Chandon

For more details and direct downloads, see Wines and Video Games.
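The ground-truth recommendations are what a ranking is scored against. As a rough illustration only (hit ratio at k and reciprocal rank are generic ranking measures chosen for this sketch, not necessarily the exact metrics reported in the paper), a ranked list of titles can be compared with a ground-truth set like this:

```python
from typing import Sequence, Set


def hit_ratio_at_k(ranked: Sequence[str], ground_truth: Set[str], k: int) -> float:
    """Fraction of ground-truth titles that appear in the top-k of a ranking."""
    if not ground_truth:
        raise ValueError("ground_truth must not be empty")
    top_k = set(ranked[:k])
    return len(top_k & ground_truth) / len(ground_truth)


def reciprocal_rank(ranked: Sequence[str], ground_truth: Set[str]) -> float:
    """1 / (1-based position of the first relevant title), or 0.0 if none is ranked."""
    for position, title in enumerate(ranked, start=1):
        if title in ground_truth:
            return 1.0 / position
    return 0.0
```

For example, if Burnout Paradise's ground truth contained only Forza Horizon 3, a ranking that put Forza Horizon 3 first would get hit_ratio_at_k(..., k=1) == 1.0 and a reciprocal rank of 1.0.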

 

Training

The training process downloads the datasets automatically.

python sdr_main.py --dataset_name video_games

The code is based on PyTorch Lightning, so all PyTorch Lightning (PL) Trainer arguments are supported (e.g., limit_train/val/test_batches, check_val_every_n_epoch).

Tensorboard support

All metrics are logged automatically and stored in

SDR/output/document_similarity/SDR/arch_SDR/dataset_name_<dataset>/<time_of_run>

Run tensorboard --logdir=<path> to view the logs.
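Picking the newest run for --logdir is easy to get wrong, because <time_of_run> starts with the day, so sorting directory names as strings can rank 22 Nov 2021 after 1 Dec 2021. The helper below is a hypothetical sketch that parses the names instead. It assumes the day_month_year-hour_minute_second format seen in a log line further down this page (22_11_2021-20_22_55) and searches recursively, since that log also shows an extra test_only_False level not listed in the path above.

```python
from datetime import datetime
from pathlib import Path
from typing import Optional

# Assumed run-directory format, inferred from an example name like "22_11_2021-20_22_55".
RUN_DIR_FORMAT = "%d_%m_%Y-%H_%M_%S"


def latest_run_dir(root: str) -> Optional[Path]:
    """Return the newest run directory anywhere under `root`, or None if there is none."""
    newest: Optional[Path] = None
    newest_time: Optional[datetime] = None
    for path in Path(root).rglob("*"):
        if not path.is_dir():
            continue
        try:
            run_time = datetime.strptime(path.name, RUN_DIR_FORMAT)
        except ValueError:
            continue  # not a run directory (e.g. arch_SDR, dataset_name_<dataset>)
        if newest_time is None or run_time > newest_time:
            newest, newest_time = path, run_time
    return newest
```

The returned path can then be passed as <path> to tensorboard --logdir.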

 

Inference

The hierarchical inference described in the paper is implemented as a stand-alone service and can be used with any backbone algorithm (models/reco/hierarchical_reco.py).

 

python sdr_main.py --dataset_name <name> --resume_from_checkpoint <checkpoint> --test_only
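The actual algorithm lives in models/reco/hierarchical_reco.py and is described in the paper. The sketch below is only a simplified, hypothetical illustration of the general idea of scoring documents hierarchically (sentences within paragraphs, paragraphs within documents). It assumes sentence embeddings are already available as PyTorch tensors, and the max-then-mean aggregation at each level is a choice made for this sketch, not necessarily what SDR does.

```python
from typing import List

import torch
import torch.nn.functional as F

# Hypothetical representation: a document is a list of paragraphs, and each
# paragraph is a (num_sentences, dim) tensor of precomputed sentence embeddings.
Document = List[torch.Tensor]


def paragraph_score(src: torch.Tensor, cand: torch.Tensor) -> torch.Tensor:
    """Mean over source sentences of their best cosine match in the candidate paragraph."""
    sim = F.normalize(src, dim=-1) @ F.normalize(cand, dim=-1).T  # (n_src, n_cand)
    return sim.max(dim=1).values.mean()


def document_score(source: Document, candidate: Document) -> torch.Tensor:
    """Mean over source paragraphs of their best-matching candidate paragraph score."""
    best_per_paragraph = [
        torch.stack([paragraph_score(p, q) for q in candidate]).max() for p in source
    ]
    return torch.stack(best_per_paragraph).mean()


def rank_candidates(source: Document, candidates: List[Document]) -> List[int]:
    """Indices of `candidates`, most similar first."""
    scores = torch.stack([document_score(source, c) for c in candidates])
    return torch.argsort(scores, descending=True).tolist()
```

Because the scoring only consumes embeddings, any backbone that produces sentence vectors could feed it, which is the same property the README claims for the real implementation.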

Results

Citing & Authors

If you find this repository or the annotated datasets helpful, feel free to cite our publication:

SDR: Self-Supervised Document-to-Document Similarity Ranking via Contextualized Language Models and Hierarchical Inference

 @misc{ginzburg2021selfsupervised,
     title={Self-Supervised Document Similarity Ranking via Contextualized Language Models and Hierarchical Inference}, 
     author={Dvir Ginzburg and Itzik Malkiel and Oren Barkan and Avi Caciularu and Noam Koenigstein},
     year={2021},
     eprint={2106.01186},
     archivePrefix={arXiv},
     primaryClass={cs.CL}
}

Contact: Dvir Ginzburg, Itzik Malkiel.

sdr's People

Contributors

dvirginz, microsoft-github-policy-service[bot], microsoftopensource


sdr's Issues

Error when running the training command on the video games dataset

When I run

python sdr_main.py --dataset_name video_games

I get the following error:

Traceback (most recent call last):
  File "sdr_main.py", line 80, in <module>
    main()
  File "sdr_main.py", line 28, in main
    main_train(model_class_pointer, hyperparams,parser)
  File "sdr_main.py", line 72, in main_train
    trainer.fit(model)
  File "/home/j20200059/miniconda3/envs/SDR/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 510, in fit
    results = self.accelerator_backend.train()
  File "/home/j20200059/miniconda3/envs/SDR/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 57, in train
    return self.train_or_test()
  File "/home/j20200059/miniconda3/envs/SDR/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 74, in train_or_test
    results = self.trainer.train()
  File "/home/j20200059/miniconda3/envs/SDR/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 532, in train
    self.run_sanity_check(self.get_model())
  File "/home/j20200059/miniconda3/envs/SDR/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 721, in run_sanity_check
    self.reset_val_dataloader(ref_model)
  File "/home/j20200059/miniconda3/envs/SDR/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py", line 287, in reset_val_dataloader
    self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val')
  File "/home/j20200059/miniconda3/envs/SDR/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py", line 207, in _reset_eval_dataloader
    dataloaders = self.request_dataloader(getattr(model, loader_name))
  File "/home/j20200059/miniconda3/envs/SDR/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py", line 310, in request_dataloader
    dataloader = dataloader_fx()
  File "/scratch/j20200059/SDR/models/doc_similarity_pl_template.py", line 182, in val_dataloader
    return self.dataloader(mode="val")
  File "/scratch/j20200059/SDR/models/SDR/SDR.py", line 171, in dataloader
    batch_size=self.hparams.val_batch_size,
  File "/scratch/j20200059/SDR/models/SDR/SDR_utils.py", line 16, in __init__
    super(MPerClassSamplerDeter, self).__init__(labels, m, batch_size, int(length_before_new_iter))
  File "/home/j20200059/miniconda3/envs/SDR/lib/python3.7/site-packages/pytorch_metric_learning/samplers/m_per_class_sampler.py", line 32, in __init__
    ), "m * (number of unique labels) must be >= batch_size"
AssertionError: m * (number of unique labels) must be >= batch_size
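The assertion comes from pytorch_metric_learning's MPerClassSampler, which draws m samples per label, so one batch can hold at most m times the number of distinct labels. Here it fires while building the validation dataloader with self.hparams.val_batch_size, which suggests the validation split has fewer distinct labels than that batch size requires. The helper below is a hypothetical pre-check (not part of SDR or pytorch_metric_learning) that reproduces the same condition, so a configuration can be validated before training starts:

```python
from typing import Hashable, Sequence


def max_batch_size(labels: Sequence[Hashable], m: int) -> int:
    """Largest batch an m-per-class sampler can fill: m * (number of unique labels)."""
    return m * len(set(labels))


def check_m_per_class(labels: Sequence[Hashable], m: int, batch_size: int) -> None:
    """Raise early, with the numbers involved, if the sampler's assertion would fail."""
    capacity = max_batch_size(labels, m)
    if capacity < batch_size:
        raise ValueError(
            f"m={m} x {len(set(labels))} unique labels = {capacity} < batch_size={batch_size}; "
            "lower the batch size or increase m"
        )
```

In practice this points to either lowering the validation batch size or making sure the validation split contains enough distinct labels.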

TypeError: __init__() got an unexpected keyword argument 'filepath'

I executed the training command:

%cd /home/ec2-user/SageMaker/SDR
!python sdr_main.py --dataset_name video_games

The error was:

/home/ec2-user/SageMaker/SDR
[nltk_data] Downloading package punkt to /home/ec2-user/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
WARNING:PML:The pytorch-metric-learning testing module requires faiss. You can install the GPU version with the command 'conda install faiss-gpu -c pytorch'
                        or the CPU version with 'conda install faiss-cpu -c pytorch'. Learn more at https://github.com/facebookresearch/faiss/blob/master/INSTALL.md
Global seed set to 42
INFO:lightning:Global seed set to 42
/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/cryptography/hazmat/backends/openssl/x509.py:18: CryptographyDeprecationWarning: This version of cryptography contains a temporary pyOpenSSL fallback path. Upgrade pyOpenSSL now.
  utils.DeprecatedIn35,
Some weights of SimilarityModeling were not initialized from the model checkpoint at roberta-large and are newly initialized: ['lm_head.decoder.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Log directory:
/home/ec2-user/SageMaker/SDR/output/document_similarity/arch_SDR/dataset_name_video_games/test_only_False/22_11_2021-20_22_55

Traceback (most recent call last):
  File "sdr_main.py", line 80, in <module>
    main()
  File "sdr_main.py", line 28, in main
    main_train(model_class_pointer, hyperparams,parser)
  File "sdr_main.py", line 55, in main_train
    verbose=True,
TypeError: __init__() got an unexpected keyword argument 'filepath'
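The failing call at sdr_main.py line 55 is most likely a ModelCheckpoint(...) constructor, and this TypeError usually means the installed PyTorch Lightning is newer than the one the code was written for: newer ModelCheckpoint versions no longer accept filepath and take dirpath and filename instead. Besides installing an older PL release, one option is to translate the old argument. The helper below is a hypothetical sketch of that translation; it assumes an existing directory maps to dirpath alone, and that newer PL versions append the .ckpt extension themselves, so it strips it from the filename.

```python
import os
from typing import Optional, Tuple


def split_checkpoint_filepath(filepath: str) -> Tuple[str, Optional[str]]:
    """Map an old-style ModelCheckpoint `filepath` to (dirpath, filename)."""
    if os.path.isdir(filepath):
        # An existing directory: save checkpoints there with default file names.
        return filepath, None
    dirpath, filename = os.path.split(filepath)
    if filename.endswith(".ckpt"):
        # Assumption: newer PL appends ".ckpt" itself, so drop it to avoid "x.ckpt.ckpt".
        filename = filename[: -len(".ckpt")]
    return dirpath, filename or None
```

Usage: dirpath, filename = split_checkpoint_filepath(old_filepath), then ModelCheckpoint(dirpath=dirpath, filename=filename, verbose=True, ...), keeping the other arguments the script already passes (old_filepath stands for whatever value sdr_main.py currently uses).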

Using padded tokens when creating averaged sentence embeddings

When calculating the similarity loss between two sentences, it looks like we are using the averaged word embeddings per sentence. Within models.SDR.similarity_modeling.SimilarityModeling we have the following:

...
non_masked_outputs = self.roberta(
    non_masked_input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids,
    position_ids=position_ids,
    head_mask=head_mask,
    inputs_embeds=inputs_embeds,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
)
non_masked_seq_out = non_masked_outputs[0]

meaned_sentences = non_masked_seq_out.mean(1)
miner_output = list(self.miner_func(meaned_sentences, sample_labels))

sim_loss = self.similarity_loss_func(meaned_sentences, sample_labels, miner_output)
...

It appears that the embeddings of the padded tokens are included in the average, since sentence lengths (i.e., the attention mask) are not taken into account. Was this done by design?
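A minimal sketch of the alternative, assuming the attention_mask passed to self.roberta follows the usual Hugging Face convention (1 for real tokens, 0 for padding): average only over unmasked positions. The function name is hypothetical, and whether this changes SDR's results has not been checked here.

```python
import torch


def masked_mean(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Mean-pool token embeddings while ignoring padded positions.

    hidden_states: (batch, seq_len, dim); attention_mask: (batch, seq_len), 1 = real token.
    """
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(dim=1)                    # (batch, dim)
    counts = mask.sum(dim=1).clamp(min=1.0)                       # avoid division by zero
    return summed / counts
```

In the snippet above, meaned_sentences = masked_mean(non_masked_seq_out, attention_mask) would replace non_masked_seq_out.mean(1); without padding the two give the same result.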

Training with custom dataset?

@dvirginz
What part of the code should I refer to if I were to train the model on my custom dataset? Also, is it necessary to perform the MLM training along with the contrastive loss? (would using the contrastive loss alone degrade performance by a lot?)

RuntimeError: CUDA out of memory.

Train command:

%cd /home/ec2-user/SageMaker/SDR
!python sdr_main.py --dataset_name wines

Stacktrace:

Traceback (most recent call last):
  File "sdr_main.py", line 80, in <module>
    main()
  File "sdr_main.py", line 28, in main
    main_train(model_class_pointer, hyperparams,parser)
  File "sdr_main.py", line 72, in main_train
    trainer.fit(model)
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 510, in fit
    results = self.accelerator_backend.train()
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py", line 57, in train
    return self.train_or_test()
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py", line 74, in train_or_test
    results = self.trainer.train()
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 561, in train
    self.train_loop.run_training_epoch()
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 550, in run_training_epoch
    batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 692, in run_training_batch
    self.trainer.hiddens)
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 806, in training_step_and_backward
    result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 319, in training_step
    training_step_output = self.trainer.accelerator_backend.training_step(args)
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/pytorch_lightning/accelerators/dp_accelerator.py", line 117, in training_step
    return self._step(args)
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/pytorch_lightning/accelerators/dp_accelerator.py", line 113, in _step
    output = self.trainer.model(*args)
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/pytorch_lightning/overrides/data_parallel.py", line 93, in forward
    return self.module.training_step(*inputs[0], **kwargs[0])
  File "/home/ec2-user/SageMaker/SDR/models/doc_similarity_pl_template.py", line 49, in training_step
    batch = self(batch)
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ec2-user/SageMaker/SDR/models/SDR/SDR.py", line 78, in forward
    eval(f"self.forward_{self.hparams.mode}")(batch)
  File "/home/ec2-user/SageMaker/SDR/models/SDR/SDR.py", line 48, in forward_train
    run_mlm=True,
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ec2-user/SageMaker/SDR/models/SDR/similarity_modeling.py", line 129, in forward
    return_dict=return_dict,
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/transformers/modeling_bert.py", line 835, in forward
    return_dict=return_dict,
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/transformers/modeling_bert.py", line 490, in forward
    output_attentions,
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/transformers/modeling_bert.py", line 433, in forward
    self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/transformers/modeling_utils.py", line 1597, in apply_chunking_to_forward
    return forward_fn(*input_tensors)
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/transformers/modeling_bert.py", line 439, in feed_forward_chunk
    intermediate_output = self.intermediate(attention_output)
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/transformers/modeling_bert.py", line 367, in forward
    hidden_states = self.intermediate_act_fn(hidden_states)
  File "/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/torch/nn/functional.py", line 1556, in gelu
    return torch._C._nn.gelu(input)
RuntimeError: CUDA out of memory. Tried to allocate 32.00 MiB (GPU 0; 14.76 GiB total capacity; 11.17 GiB already allocated; 14.75 MiB free; 11.40 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
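Two observations on this log. The max_split_size_mb hint targets fragmentation (reserved memory much larger than allocated), but here 11.40 GiB reserved and 11.17 GiB allocated are close, so the model and batch most likely just do not fit on this ~15 GiB GPU. A common workaround is to use smaller micro-batches and accumulate gradients; PyTorch Lightning exposes this as the Trainer's accumulate_grad_batches argument, which may be settable from the command line together with a smaller batch size if sdr_main.py forwards PL flags as the README states. The plain-PyTorch sketch below (function name hypothetical) shows what accumulation does. One caveat: for SDR's in-batch similarity loss, smaller micro-batches also give the miner fewer pairs, so accumulation is not exactly equivalent to a larger batch there.

```python
import torch
from torch import nn


def accumulated_step(model: nn.Module, loss_fn, optimizer: torch.optim.Optimizer,
                     inputs: torch.Tensor, targets: torch.Tensor,
                     accumulation_steps: int) -> None:
    """One optimizer step computed over `accumulation_steps` smaller micro-batches.

    Peak activation memory scales with the micro-batch size, not the full batch.
    Gradients match a full-batch mean loss when the chunks are equally sized.
    """
    optimizer.zero_grad()
    for x, y in zip(inputs.chunk(accumulation_steps), targets.chunk(accumulation_steps)):
        loss = loss_fn(model(x), y) / accumulation_steps  # scale so the sum equals the full-batch mean
        loss.backward()                                   # gradients add up in each .grad
    optimizer.step()
```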

How to use the proposed dataset

Hi team, thank you for the great work. I want to use your proposed dataset (wines) for my study, but I could not find the ground truth: the file wines.txt contains only the titles and sections. How is the ground truth arranged in this file?
Thank you, team! I hope you reply soon!

SBERT v performance?

Hi!

The results state that an experiment with SBERT v was conducted. Where can I see the performance of SBERT v?

Thanks.
