jieyuz2 / wrench

[NeurIPS 2021] WRENCH: Weak supeRvision bENCHmark

Home Page: https://arxiv.org/abs/2109.11377

License: Apache License 2.0

Python 100.00%
benchmark-framework data-centric-ai data-programming dataset deep-learning machine-learning nlp robust-learning sequence-labeling weak-supervision weakly-supervised-learning

wrench's People

Contributors

andst, jeffreywpli, jieyuz2, polaris-73, rpryzant, wurenzhi, yinxiangshi

wrench's Issues

Snuba, GLaRA, and TALLOR labeling function generators

In your paper you mention that labeling functions (LFs) are critical to the overall performance of a WS system. Have you implemented these automated LF generators, so that they can be compared with other approaches to creating labeling functions automatically?

Using Multiple GPUs

Hi,

Is it possible to use multiple GPUs for the experiments, or is this planned for a future release? If it is not possible right now, it would be a nice feature to have.

Best regards.

Are there limitations on which datasets can be used with which algorithms?

First, thank you for building this awesome benchmark. When I try the examples with different datasets (e.g., Astra on the youtube dataset), I get errors like this:

    loss = cross_entropy_with_probs(predict_l, batch['labels'].to(device))
KeyError: 'labels'

Can this be fixed?

Question on train/val/test split when evaluating label model.

In the original paper, the label model also seems to be fitted on a training set and then evaluated on a test set. However, when weak supervision is used to generate labeled data, we care more about the quality of the generated labels than about how well a label model generalizes. For example, suppose a label model produces perfect labels on the training set (which it was fitted on, in an unsupervised way) but random labels on a test set (which it was not fitted on). That model is perfect for generating labeled data, yet it would be the worst label model in the benchmark. My question is:
For the purpose of generating labeled data (which is then used to train an end model), is a train/val/test split really necessary to evaluate the label model? Can we instead fit the unsupervised label model on the whole dataset and evaluate it on that same dataset?

I appreciate any explanations.
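The evaluation the question proposes (aggregate on the whole dataset, then score on the same points) can be sketched as below. This is a minimal illustration, not WRENCH code: majority vote stands in for an unsupervised label model because it needs no fitting, -1 as the abstain marker is an assumption, and `majority_vote` / `transductive_accuracy` are hypothetical names. Coverage is reported next to accuracy because a label model can look accurate simply by abstaining on hard points.

```python
from collections import Counter

ABSTAIN = -1  # assumed abstain marker in the weak-label matrix

def majority_vote(votes_row):
    """Most common non-abstain vote; ABSTAIN if no LF fires or the top classes tie."""
    counts = Counter(v for v in votes_row if v != ABSTAIN)
    if not counts:
        return ABSTAIN
    ranked = counts.most_common()
    if len(ranked) > 1 and ranked[0][1] == ranked[1][1]:
        return ABSTAIN
    return ranked[0][0]

def transductive_accuracy(weak_label_matrix, gold):
    """Aggregate labels on every example and score them on those same examples.

    Returns (accuracy on covered examples, coverage).
    """
    preds = [majority_vote(row) for row in weak_label_matrix]
    covered = [(p, y) for p, y in zip(preds, gold) if p != ABSTAIN]
    coverage = len(covered) / len(gold)
    accuracy = sum(p == y for p, y in covered) / len(covered) if covered else 0.0
    return accuracy, coverage

L = [[1, 1, -1], [0, -1, 0], [1, 0, -1], [-1, -1, -1]]
gold = [1, 0, 1, 0]
print(transductive_accuracy(L, gold))  # (1.0, 0.5)
```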

Tensors on different devices

Hi, I am trying to run run_meta_weight_net.py but I am getting the following RuntimeError:

Traceback (most recent call last):
  File "run_meta_weight_net.py", line 67, in <module>
    device=device
  File "/gpfs/space/home/wrench/wrench/metalearning/meta_weight_net.py", line 166, in fit
    meta_loss = cross_entropy_with_probs(outputs, meta_target)
  File "/gpfs/space/home/wrench/wrench/utils.py", line 146, in cross_entropy_with_probs
    return F.cross_entropy(input, target.long(), weight=weight, reduction=reduction)
  File "/gpfs/space/home/.conda/envs/wrench2/lib/python3.6/site-packages/torch/nn/functional.py", line 2824, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking arugment for argument target in method wrapper_nll_loss_forward)

I have tried reinstalling PyTorch without success. Is this an environment problem on my end, or an actual bug in the code (one tensor on the GPU and the other on the CPU when computing the cross-entropy loss)?

Thank you!

Recommended parameters for each algorithm and dataset

I've tried several combinations of algorithms and datasets, but I found it hard to get results similar to those in the paper.
I suspect this is due to inappropriate parameter settings, so I think it would be great if this repo could provide some recommended parameters. (Especially for the newly added algorithms, it's hard to judge whether they produce the right results.)

datasets.zip on Google Drive is out of date

Hi,

First of all, thank you for the awesome benchmark.

In the Google Drive datasets folder, datasets.zip seems to be out of date. It predates the last update of the datasets/classification folders and, for instance, is missing train.json for the commercial dataset.

I think it would be advisable to either update this archive or delete it, since the folders above can be downloaded directly from Google Drive.

Thank you

New Release

Hi! Love the repo, super useful so far and really easy interface to use. Thanks for putting it together!

I was wondering whether there are plans to cut another release soon? We use the v1.0 tag to keep the version consistent across multiple builds. I noticed a few bug fixes and QOL improvements since the last release, and it would be nice to have those marked with a new tag.

Questions on the use of ground-truth labels for validation

Thanks for putting up the benchmark! This is really great work!
It seems that both the label model and the end model use the ground-truth labels for validation.
For example, the base label model uses the ground-truth labels of the validation set to calculate the class balance weights:

y = np.array(dataset_valid.labels)
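Assuming "class balance" here means the empirical class prior (the exact WRENCH computation may differ), the line above amounts to something like the following hypothetical helper, which shows how the validation ground truth feeds into the label model:

```python
from collections import Counter

def class_balance(labels, n_class):
    """Empirical class prior P(y = c) from gold labels (hypothetical helper)."""
    counts = Counter(labels)
    return [counts.get(c, 0) / len(labels) for c in range(n_class)]

print(class_balance([0, 1, 1, 1], n_class=2))  # [0.25, 0.75]
```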

I have a few questions regarding this:
(1) A valid baseline for the label models would be a classifier trained on the validation set, using the LFs' weak labels as features and the ground-truth labels as targets. Since the validation set is not actually small for most datasets, I suspect such a model could be a fairly strong baseline compared with the unsupervised label models.
(2) Just as we aggregate the weak labels on the training set, we could also aggregate the weak labels on the validation set, and then use these aggregated labels, instead of the ground-truth labels, to validate the end model. Wouldn't this be a more realistic setting, especially since weak supervision is meant to replace human labeling with programmatic labeling?

I appreciate any explanations. Thanks!
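Baseline (1) can be sketched with categorical naive Bayes, chosen here only as one simple supervised classifier (the question does not name one). Each LF's output, including the abstain value (assumed to be -1), is treated as a categorical feature, and Laplace smoothing keeps an unseen output from zeroing out a class. `LFNaiveBayes` is a hypothetical name, not part of WRENCH.

```python
import math
from collections import Counter, defaultdict

class LFNaiveBayes:
    """Categorical naive Bayes over LF outputs (hypothetical baseline, not WRENCH code)."""

    def __init__(self, n_class, alpha=1.0):
        self.n_class = n_class
        self.alpha = alpha  # Laplace smoothing strength

    def fit(self, weak_label_matrix, gold):
        n_lf = len(weak_label_matrix[0])
        self.n = len(gold)
        self.class_counts = Counter(gold)
        # counts[j][c][v] = how often LF j output v (a class or -1) on class-c examples
        self.counts = [defaultdict(Counter) for _ in range(n_lf)]
        self.values = [set() for _ in range(n_lf)]
        for row, y in zip(weak_label_matrix, gold):
            for j, v in enumerate(row):
                self.counts[j][y][v] += 1
                self.values[j].add(v)
        return self

    def predict(self, row):
        best, best_score = None, -math.inf
        for c in range(self.n_class):
            n_c = self.class_counts.get(c, 0)
            # Smoothed log prior plus smoothed log likelihood of each LF output
            score = math.log((n_c + self.alpha) / (self.n + self.alpha * self.n_class))
            for j, v in enumerate(row):
                k = len(self.values[j]) + 1  # one extra slot for outputs unseen in fit
                score += math.log((self.counts[j][c][v] + self.alpha) / (n_c + self.alpha * k))
            if score > best_score:
                best, best_score = c, score
        return best

valid_L = [[1, 1], [1, -1], [0, 0], [0, -1]]
valid_y = [1, 1, 0, 0]
model = LFNaiveBayes(n_class=2).fit(valid_L, valid_y)
print(model.predict([1, 1]), model.predict([0, 0]))  # 1 0
```

In practice one would fit on the validation set's weak labels and ground truth, then compare test accuracy against the unsupervised label models; a library classifier such as logistic regression would be an equally reasonable choice.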

About the COSINE end model

Hi @JieyuZ2 and @yinxiangshi, I am trying to run the COSINE end model but am having trouble reproducing the results from the COSINE paper. Even with the suggested hyperparameters, I only get a marginal improvement in WRENCH, and I'm not sure what is wrong. Could you share the scripts you used to evaluate COSINE? Thanks.

Unable to install locally or on Google Colab

Thanks for the fantastic package and great research work.

I have tried both pip install ws-benchmark==1.1.2rc0 and pip install ws-benchmark on both my machine and Google Colab (fresh environment) but the installation fails each time.

A snippet from the error during installation:

Collecting scikit-learn<0.25.0,>=0.20.2 (from ws-benchmark==1.1.2rc0)
  Downloading scikit-learn-0.24.2.tar.gz (7.5 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.5/7.5 MB 4.0 MB/s eta 0:00:00
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  error: subprocess-exited-with-error
  
  × Preparing metadata (pyproject.toml) did not run successfully.
  │ exit code: 1
  ╰─> See above for output.
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
  Preparing metadata (pyproject.toml) ... error
error: metadata-generation-failed

× Encountered error while generating package metadata.
╰─> See above for output.

note: This is an issue with the package mentioned above, not pip.
hint: See above for details.

You can find the Colab notebook here.

Clarifying dataset download links

Great work on the benchmark!

Under the "Available Datasets" section of the main README, you provide two links for downloading the WRENCH datasets:

One point of confusion is that the expanded datasets at the Google Drive link differ from those in the direct-download zip file. For example, classification/youtube/train.json on Google Drive has 1686 instances, while the same file in the zip contains 1586, matching the statistics reported in the README. Could you make it unambiguous in the documentation which download is the correct one?
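A quick way to check which copy matches the README statistics is to count the examples in each split file. This sketch assumes the top-level JSON value is either an object keyed by example id or a list (len() counts either); the paths in the usage comment are placeholders.

```python
import json

def count_instances(json_path):
    """Number of examples in a split file (top-level JSON object or list)."""
    with open(json_path, encoding="utf-8") as f:
        return len(json.load(f))

def compare_copies(path_a, path_b):
    """Instance counts for two copies of the same split, e.g. Google Drive vs. zip."""
    a, b = count_instances(path_a), count_instances(path_b)
    return {"a": a, "b": b, "match": a == b}

# Usage (placeholder paths):
# compare_copies("drive/classification/youtube/train.json",
#                "zip/classification/youtube/train.json")
```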

Handling datasets using Hugging Face datasets and Hub

Hi,

Love this initiative, congrats!

Would it be possible to integrate the datasets into the Hugging Face Hub? Besides the technical effort, would there be any copyright or licensing issues? If not, I would be happy to help out with this.

ModuleNotFoundError: No module named 'tokenizations'

Hi, I ran into some problems when trying to install the library. I used pip install ws-benchmark==1.1.2rc0 as suggested in the documentation; the installation succeeded, but when I ran the code I got the error ModuleNotFoundError: No module named 'tokenizations'. I then cloned the repository and tried to create the environment with conda env create -f environment.yml, but the installation failed with the following error: FileNotFoundError: [Errno 2] No such file or directory: '/home/naiqing/miniconda3/envs/wrench/lib/python3.6/site-packages/huggingface_hub-0.0.16-py3.8.egg'. Do you have any idea what might cause the problem and how I can fix it?
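To narrow down which dependency is missing in a given environment, the importability of each module can be checked with the standard library before running any WRENCH script. The module list in the usage comment is illustrative. As far as we know, the `tokenizations` module is distributed on PyPI as `pytokenizations`, but that should be verified against the project's own requirements.

```python
import importlib.util

def missing_modules(names):
    """Return the module names that cannot be located in the current interpreter."""
    missing = []
    for name in names:
        try:
            found = importlib.util.find_spec(name) is not None
        except ModuleNotFoundError:  # dotted name whose parent package is absent
            found = False
        if not found:
            missing.append(name)
    return missing

# Usage (illustrative module list):
# print(missing_modules(["tokenizations", "torch", "transformers", "snorkel"]))
```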

No module named 'wrench.classification.self_training'

Hi, I am trying to run run_denoise.py but I am getting the following error:

Traceback (most recent call last):
  File "run_denoise.py", line 5, in <module>
    from wrench.classification import Denoise
  File "/gpfs/space/home/wrench/wrench/classification/__init__.py", line 4, in <module>
    from .self_training import LDSelfTrain, DDSelfTrain
ModuleNotFoundError: No module named 'wrench.classification.self_training'

Could you please add the LDSelfTrain and DDSelfTrain classes?

Python Package Installation Fails

Installing ws-benchmark python package fails due to dependency conflict (see stack trace below).

Tested on system:

  • OS: Ubuntu
  • Python: 3.8.13
  • Clean virtual environment

Command to replicate:

  • pip install ws-benchmark

Stack Trace:

ERROR: Cannot install ws-benchmark and ws-benchmark==1.1.1 because these package versions have conflicting dependencies.

The conflict is caused by:
    ws-benchmark 1.1.1 depends on networkx==2.7
    snorkel 0.9.7 depends on networkx<2.4 and >=2.2

To fix this you could try to:
1. loosen the range of package versions you've specified
2. remove package versions to allow pip attempt to solve the dependency conflict

ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts
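The resolver's message can be reproduced with a toy check: no networkx version satisfies both `==2.7` (ws-benchmark 1.1.1) and `<2.4,>=2.2` (snorkel 0.9.7). The helpers below are a simplified sketch that only handles numeric versions with up to three segments; the `packaging` library's `SpecifierSet` is the robust, PEP 440-compliant option.

```python
import operator

OPS = {"==": operator.eq, ">=": operator.ge, "<=": operator.le, "<": operator.lt, ">": operator.gt}

def parse_version(v):
    """'2.4' -> (2, 4, 0); assumes up to three purely numeric segments."""
    parts = [int(p) for p in v.split(".")]
    return tuple(parts + [0] * (3 - len(parts)))

def satisfies(version, spec):
    """True if version meets every comma-separated clause, e.g. '<2.4,>=2.2'."""
    v = parse_version(version)
    for clause in spec.split(","):
        clause = clause.strip()
        # Check two-character operators before their one-character prefixes.
        for op in ("==", ">=", "<=", "<", ">"):
            if clause.startswith(op):
                if not OPS[op](v, parse_version(clause[len(op):])):
                    return False
                break
        else:
            raise ValueError(f"unsupported clause: {clause!r}")
    return True

def compatible_versions(candidates, *specs):
    """Candidate versions that satisfy all specifiers at once."""
    return [c for c in candidates if all(satisfies(c, s) for s in specs)]

print(compatible_versions(["2.2", "2.3", "2.4", "2.7"], "==2.7", "<2.4,>=2.2"))  # []
```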

Running scripts

Hi, I am trying to run some models on the IMDB dataset.

MLP:

import logging
import torch
import numpy as np
from wrench.dataset import load_dataset
from wrench.labelmodel import Snorkel
from wrench.logging import LoggingHandler
from wrench.search import grid_search
from wrench.endmodel import EndClassifierModel

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

logger = logging.getLogger(__name__)

device = torch.device('cuda')

if __name__ == '__main__':
    #### Load dataset
    dataset_path = '../datasets/'
    data = "imdb"
    bert_model_name = "bert-base-cased"
    train_data, valid_data, test_data = load_dataset(
        dataset_path,
        data,
        extract_feature=True,
        extract_fn='bert',  # extract bert embedding
        model_name=bert_model_name,
        cache_name='bert',
        dataset_type="TextDataset"
    )

    #### Run label model: Snorkel
    label_model = Snorkel(
        lr=0.005,
        l2=0,
        n_epochs=200,
        seed=123
    )
    label_model.fit(
        dataset_train=train_data,
        dataset_valid=valid_data
    )

    acc = label_model.test(test_data, 'acc')
    logger.info(f'label model test acc: {acc}')

    #### Filter out uncovered training data
    aggregated_hard_labels = label_model.predict(train_data)
    aggregated_soft_labels = label_model.predict_proba(train_data)

    #### Search Space
    search_space = {
        'optimizer_lr': np.logspace(-5, -1, num=5, base=10),
        'optimizer_weight_decay': np.logspace(-5, -1, num=5, base=10),
    }

    #### Initialize the model: MLP
    model = EndClassifierModel(
        batch_size=8,
        real_batch_size=8,
        test_batch_size=8,
        backbone='MLP',
        optimizer='Adam'
    )

    #### Search best hyper-parameters using validation set in parallel
    n_trials = 20
    n_repeats = 1
    searched_paras = grid_search(
        model,
        dataset_train=train_data,
        y_train=aggregated_soft_labels,
        dataset_valid=valid_data,
        metric='acc',
        direction='auto',
        search_space=search_space,
        n_repeats=n_repeats,
        n_trials=n_trials,
        parallel=True,
        device=device,
    )


    #### Run end model: MLP
    model = EndClassifierModel(
        batch_size=8,
        real_batch_size=8,
        test_batch_size=8,
        backbone='MLP',
        optimizer='Adam',
        **searched_paras
    )
    model.fit(
        dataset_train=train_data,
        y_train=aggregated_soft_labels,
        dataset_valid=valid_data,
        metric='acc',
        device=device
    )

    logger.info(model.predict(test_data).tolist())

    acc = model.test(test_data, 'acc')
    logger.info(f'end model (MLP) test acc: {acc}')

for which I am getting the following output:

100%|██████████| 20000/20000 [00:00<00:00, 902651.16it/s]
100%|██████████| 2500/2500 [00:00<00:00, 852639.45it/s]
100%|██████████| 2500/2500 [00:00<00:00, 829503.99it/s]
Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 20000/20000 [1:42:45<00:00,  3.24it/s]  
Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 2500/2500 [13:24<00:00,  3.11it/s]
Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 2500/2500 [13:50<00:00,  3.01it/s]
[I 2021-10-23 22:24:36,807] A new study created in memory with name: no-name-9e4ad09c-ea4a-4ee8-80c2-7633429e4038
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
2021-10-23 20:14:19 - loading data from ../datasets/imdb/train.json
2021-10-23 20:14:19 - loading data from ../datasets/imdb/valid.json
2021-10-23 20:14:19 - loading data from ../datasets/imdb/test.json
2021-10-23 21:57:10 - saving features into ../datasets/imdb/train_bert.pkl
2021-10-23 22:10:40 - saving features into ../datasets/imdb/valid_bert.pkl
2021-10-23 22:24:36 - saving features into ../datasets/imdb/test_bert.pkl
2021-10-23 22:24:36 - label model test acc: 0.716
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
100%|██████████| 1/1 [00:37<00:00, 37.48s/it]
[I 2021-10-23 22:25:14,563] Trial 0 finished with value: 0.5012 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.0001}. Best is trial 0 with value: 0.5012.
100%|██████████| 1/1 [00:23<00:00, 23.70s/it]
[I 2021-10-23 22:25:38,448] Trial 1 finished with value: 0.496 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 0.1}. Best is trial 0 with value: 0.5012.
100%|██████████| 1/1 [00:14<00:00, 14.53s/it]
[I 2021-10-23 22:25:53,171] Trial 2 finished with value: 0.5004 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.001}. Best is trial 0 with value: 0.5012.
100%|██████████| 1/1 [00:43<00:00, 43.73s/it]
[I 2021-10-23 22:26:37,071] Trial 3 finished with value: 0.5088 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.001}. Best is trial 3 with value: 0.5088.
100%|██████████| 1/1 [00:18<00:00, 18.85s/it]
[I 2021-10-23 22:26:56,161] Trial 4 finished with value: 0.488 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.1}. Best is trial 3 with value: 0.5088.
100%|██████████| 1/1 [00:38<00:00, 38.81s/it]
[I 2021-10-23 22:27:35,214] Trial 5 finished with value: 0.4948 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.1}. Best is trial 3 with value: 0.5088.
100%|██████████| 1/1 [00:38<00:00, 38.15s/it]
[I 2021-10-23 22:28:13,614] Trial 6 finished with value: 0.5024 and parameters: {'optimizer_lr': 0.01, 'optimizer_weight_decay': 0.01}. Best is trial 3 with value: 0.5088.
100%|██████████| 1/1 [00:15<00:00, 15.47s/it]
[I 2021-10-23 22:28:29,335] Trial 7 finished with value: 0.4996 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 1e-05}. Best is trial 3 with value: 0.5088.
100%|██████████| 1/1 [00:22<00:00, 22.49s/it]
[I 2021-10-23 22:28:52,093] Trial 8 finished with value: 0.5008 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 1e-05}. Best is trial 3 with value: 0.5088.
100%|██████████| 1/1 [00:40<00:00, 40.25s/it]
[I 2021-10-23 22:29:32,594] Trial 9 finished with value: 0.5008 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.0001}. Best is trial 3 with value: 0.5088.
100%|██████████| 1/1 [00:39<00:00, 39.06s/it]
[I 2021-10-23 22:30:11,902] Trial 10 finished with value: 0.5116 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 1e-05}. Best is trial 10 with value: 0.5116.
100%|██████████| 1/1 [00:43<00:00, 43.46s/it]
[I 2021-10-23 22:30:55,531] Trial 11 finished with value: 0.4912 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 1e-05}. Best is trial 10 with value: 0.5116.
100%|██████████| 1/1 [00:23<00:00, 23.41s/it]
[I 2021-10-23 22:31:19,095] Trial 12 finished with value: 0.4956 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.01}. Best is trial 10 with value: 0.5116.
100%|██████████| 1/1 [00:22<00:00, 22.12s/it]
[I 2021-10-23 22:31:41,374] Trial 13 finished with value: 0.492 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 0.01}. Best is trial 10 with value: 0.5116.
100%|██████████| 1/1 [00:15<00:00, 15.78s/it]
[I 2021-10-23 22:31:57,283] Trial 14 finished with value: 0.5044 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.0001}. Best is trial 10 with value: 0.5116.
100%|██████████| 1/1 [00:37<00:00, 37.28s/it]
[I 2021-10-23 22:32:34,728] Trial 15 finished with value: 0.488 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 0.001}. Best is trial 10 with value: 0.5116.
100%|██████████| 1/1 [00:16<00:00, 16.04s/it]
[I 2021-10-23 22:32:50,934] Trial 16 finished with value: 0.4924 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.001}. Best is trial 10 with value: 0.5116.
100%|██████████| 1/1 [00:19<00:00, 19.65s/it]
[I 2021-10-23 22:33:10,753] Trial 17 finished with value: 0.5156 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.1}. Best is trial 17 with value: 0.5156.
100%|██████████| 1/1 [00:15<00:00, 15.41s/it]
[I 2021-10-23 22:33:26,345] Trial 18 finished with value: 0.5068 and parameters: {'optimizer_lr': 0.01, 'optimizer_weight_decay': 0.001}. Best is trial 17 with value: 0.5156.
100%|██████████| 1/1 [00:16<00:00, 16.75s/it]
[I 2021-10-23 22:33:43,222] Trial 19 finished with value: 0.498 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.01}. Best is trial 17 with value: 0.5156.
[TRAIN]:  15%|█████▌                               | 1499/10000 [00:21<02:04, 68.19steps/s, loss=4.02, val_acc=0.5, best_val_acc=0.508, best_step=500]
2021-10-23 22:33:43 - [END: BEST VAL / PARAMS] Best value: 0.5156, Best paras: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.1}
2021-10-23 22:33:43 - 
==========[hyper parameters]==========
{
    "batch_size": 8,
    "real_batch_size": 8,
    "test_batch_size": 8,
    "n_steps": 10000,
    "grad_norm": -1,
    "use_lr_scheduler": false,
    "binary_mode": false
}
==========[optimizer config]==========
{
    "name": "Adam",
    "paras": {
        "lr": 0.1,
        "weight_decay": 0.1
    }
}
==========[backbone config]==========
{
    "name": "MLP",
    "paras": {
        "hidden_size": 100,
        "dropout": 0.0
    }
}

2021-10-23 22:34:09 - [INFO] early stop @ step 1500!
2021-10-23 22:34:09 - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 
0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2021-10-23 22:34:09 - end model (MLP) test acc: 0.5004
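The logged predictions are almost all 0 and the test accuracy (0.5004) is near chance for a binary task, which is consistent with the end model collapsing to a single class; the grid-search trial values hovering around 0.5 point the same way. A small diagnostic like the sketch below, fed with the output of `model.predict(test_data)` and (assumed) `test_data.labels`, makes such a collapse easy to spot. `prediction_summary` is a hypothetical helper, not a WRENCH function.

```python
from collections import Counter

def prediction_summary(preds, gold=None):
    """Class frequencies of the predictions, plus accuracy when gold labels are given."""
    counts = Counter(preds)
    top_class, top_count = counts.most_common(1)[0]
    summary = {
        "counts": dict(counts),
        "majority_class": top_class,
        "majority_fraction": top_count / len(preds),
    }
    if gold is not None:
        summary["accuracy"] = sum(p == y for p, y in zip(preds, gold)) / len(gold)
    return summary
```

If `majority_fraction` is close to 1 while accuracy is close to that class's share of the gold labels, the model is predicting one class regardless of the input.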

COSINE:

import logging
import torch
from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.labelmodel import Snorkel
from wrench.endmodel import Cosine

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

logger = logging.getLogger(__name__)

device = torch.device('cuda')

if __name__ == '__main__':
    #### Load dataset
    dataset_path = '../datasets/'
    data = "imdb"
    bert_model_name = "bert-base-cased"
    train_data, valid_data, test_data = load_dataset(
        dataset_path,
        data,
        extract_feature=True,
        extract_fn='bert',  # extract bert embedding
        model_name=bert_model_name,
        cache_name='bert',
        dataset_type="TextDataset"
    )

    #### Run label model: Snorkel
    label_model = Snorkel(
        lr=0.005,
        l2=0,
        n_epochs=200,
        seed=123
    )
    label_model.fit(
        dataset_train=train_data,
        dataset_valid=valid_data
    )

    acc = label_model.test(test_data, 'acc')
    logger.info(f'label model test acc: {acc}')

    #### Filter out uncovered training data
    aggregated_hard_labels = label_model.predict(train_data)
    aggregated_soft_labels = label_model.predict_proba(train_data)


    # COSINE
    model = Cosine(
        teacher_update=100,
        margin=1.0,
        thresh=0.6,
        lr=1e-5,
        mu=1.0,
        lamda=0.05,
        backbone='BERT',
        backbone_model_name=bert_model_name,
        batch_size=8,
        real_batch_size=8,
        test_batch_size=8,
    )

    model.fit(dataset_train=train_data,
              dataset_valid=valid_data,
              y_train=aggregated_hard_labels,
              evaluation_step=10,
              metric='acc',
              patience=50,
              device=device)

    acc = model.test(test_data, 'acc')

    logger.info(model.predict(test_data))

    logger.info(f'end model (COSINE) test acc: {acc}')

For this script, I get the following output:

100%|██████████| 20000/20000 [00:00<00:00, 899119.81it/s]
100%|██████████| 2500/2500 [00:00<00:00, 423667.07it/s]
100%|██████████| 2500/2500 [00:00<00:00, 802645.44it/s]
Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 20000/20000 [1:47:44<00:00,  3.09it/s]  
Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 2500/2500 [14:22<00:00,  2.90it/s]
Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 2500/2500 [13:33<00:00,  3.07it/s] 
Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
[TRAIN] COSINE pretrain stage:   5%|▊               | 509/10000 [21:19<6:37:40,  2.51s/steps, loss=0.605, val_acc=0.5, best_val_acc=0.5, best_step=10]
[TRAIN] COSINE distillation stage:   0%|                                                                                 | 0/10000 [03:05<?, ?steps/s]
2021-10-23 20:14:13 - loading data from ../datasets/imdb/train.json
2021-10-23 20:14:13 - loading data from ../datasets/imdb/valid.json
2021-10-23 20:14:14 - loading data from ../datasets/imdb/test.json
2021-10-23 22:02:05 - saving features into ../datasets/imdb/train_bert.pkl
2021-10-23 22:16:34 - saving features into ../datasets/imdb/valid_bert.pkl
2021-10-23 22:30:14 - saving features into ../datasets/imdb/test_bert.pkl
2021-10-23 22:30:14 - label model test acc: 0.716
2021-10-23 22:30:17 - 
==========[hyper parameters]==========
{
    "teacher_update": 100,
    "margin": 1.0,
    "mu": 1.0,
    "thresh": 0.6,
    "lamda": 0.05,
    "batch_size": 8,
    "real_batch_size": 8,
    "test_batch_size": 8,
    "n_steps": 10000,
    "grad_norm": -1,
    "use_lr_scheduler": false,
    "binary_mode": false
}
==========[optimizer config]==========
{
    "name": "Adam",
    "paras": {
        "lr": 0.001,
        "weight_decay": 0.0
    }
}
==========[backbone config]==========
{
    "name": "BERT",
    "paras": {
        "model_name": "bert-base-cased",
        "max_tokens": 512,
        "fine_tune_layers": -1
    }
}
==========[label model_config config]==========
{
    "name": "MajorityVoting",
    "paras": {}
}

2021-10-23 22:51:52 - [INFO] early stop @ step 510!
2021-10-23 22:55:20 - early stop because all the data are filtered!
2021-10-23 22:56:06 - [1 1 1 ... 1 1 1]
2021-10-23 22:56:06 - end model (COSINE) test acc: 0.5

As can be seen, in both runs the label model reaches a test accuracy of 0.716, while the end model reaches only 0.5004 (MLP) and 0.5 (COSINE).
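One quick way to check whether an end model has collapsed to a single class is to compare its accuracy with the majority-class baseline and measure how concentrated its predictions are. The sketch below is plain Python and uses no wrench API. `describe_predictions` is a hypothetical helper, and it assumes `y_pred` and `y_true` are equal-length sequences of integer class labels. On a class-balanced binary test set (an assumption about the IMDB split), a model that always predicts one class scores exactly 0.5, which is consistent with the COSINE result and the `[1 1 1 ... 1 1 1]` predictions logged above.

```python
from collections import Counter


def describe_predictions(y_pred, y_true):
    """Compare accuracy with the majority-class baseline and measure prediction concentration."""
    n = len(y_true)
    if n == 0 or len(y_pred) != n:
        raise ValueError("y_pred and y_true must be non-empty and of equal length")
    accuracy = sum(p == t for p, t in zip(y_pred, y_true)) / n
    # Accuracy obtained by always predicting the most frequent true class.
    majority_baseline = Counter(y_true).most_common(1)[0][1] / n
    # Fraction of predictions that fall into the single most-predicted class.
    top_prediction_share = Counter(y_pred).most_common(1)[0][1] / n
    return {
        "accuracy": accuracy,
        "majority_baseline": majority_baseline,
        "top_prediction_share": top_prediction_share,
    }
```

If `top_prediction_share` is close to 1.0 and `accuracy` is close to `majority_baseline`, the end model is effectively ignoring its inputs. With wrench, `y_pred` could come from `model.predict(test_data)` as in the script; where the gold test labels live on the dataset object (for example `test_data.labels`) is an assumption to verify.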

Am I doing something completely wrong? Could you please tell me whether I am running the code correctly, or whether there is an issue with the hyperparameters?

I would greatly appreciate any advice. It would also be very helpful if you could include an example script for running the COSINE model.

Thanks for the benchmark, I really appreciate it!
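The log line `early stop because all the data are filtered!` can be read in light of how COSINE selects pseudo-labeled samples. The COSINE paper weights each sample by one minus the normalized entropy of its soft pseudo-label and keeps only samples whose weight exceeds a threshold (`thresh`, 0.6 in the config above). The sketch below implements that rule in plain Python. `cosine_style_weights` and `select_confident` are hypothetical names, and wrench's internal implementation may differ in its details. If the model outputs near-uniform probabilities, every weight is close to 0 and no sample survives the threshold.

```python
import math


def cosine_style_weights(probs):
    """Per-sample weight 1 - H(p) / log(C) for each row p of class probabilities (C >= 2)."""
    weights = []
    for p in probs:
        n_classes = len(p)
        # Shannon entropy in nats; zero-probability entries contribute nothing.
        entropy = -sum(q * math.log(q) for q in p if q > 0)
        weights.append(1.0 - entropy / math.log(n_classes))
    return weights


def select_confident(probs, thresh=0.6):
    """Indices of samples whose weight exceeds `thresh`; all other samples are filtered out."""
    return [i for i, w in enumerate(cosine_style_weights(probs)) if w > thresh]
```

Separately, note that the script passes `lr=1e-5` to `Cosine`, while the logged optimizer config shows `"lr": 0.001`, so the requested learning rate may not have reached the optimizer. Checking which keyword the end model expects for the optimizer learning rate would be a reasonable first step.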

Adding custom datasets

How can I go about adding my own datasets to the wrench framework so that I can perform experiments?

Thanks!
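A sketch of writing a custom dataset to disk, assuming the layout of the released WRENCH datasets: one folder per dataset containing `train.json`, `valid.json`, and `test.json` (the log above loads `../datasets/imdb/train.json`), where each file maps an example id to a record with `label`, `weak_labels`, and a `data` dict holding the raw `text`, plus a `label.json` mapping class ids to names. These field names are assumptions based on the released files, so compare them against a downloaded dataset before relying on them. `write_split` and `write_label_map` are hypothetical helpers.

```python
import json
from pathlib import Path

ABSTAIN = -1  # assumed vote value for "labeling function did not fire"


def write_split(path, texts, labels, weak_labels):
    """Write one split as {id: {"label", "weak_labels", "data": {"text"}}} and return its size."""
    records = {}
    for i, (text, label, votes) in enumerate(zip(texts, labels, weak_labels)):
        records[str(i)] = {
            "label": int(label),
            "weak_labels": [int(v) for v in votes],  # one vote per labeling function
            "data": {"text": text},
        }
    Path(path).write_text(json.dumps(records, indent=2), encoding="utf-8")
    return len(records)


def write_label_map(dataset_dir, id_to_name):
    """Write label.json mapping class ids (as strings) to class names."""
    mapping = {str(k): v for k, v in id_to_name.items()}
    (Path(dataset_dir) / "label.json").write_text(json.dumps(mapping, indent=2), encoding="utf-8")
```

After writing all three splits into `<dataset_path>/<name>/`, the dataset should be loadable with the same `load_dataset(dataset_path, name, ...)` call used in the COSINE script above, provided the assumed format matches.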

Releasing labeling functions?

Hi, thanks for putting this benchmark together! Would it be possible to also release the labeling functions for the datasets (rather than just the vote matrices)?
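For context, a vote matrix is what results from applying every labeling function to every example: one row per example, one column per LF, and a sentinel (here -1) where the LF abstains. The two keyword LFs below are hypothetical illustrations, not the LFs behind any WRENCH dataset, and `apply_lfs` is a hypothetical helper. Releasing the LFs themselves would let users regenerate or modify the matrix, which the matrix alone does not allow.

```python
ABSTAIN = -1  # sentinel used when a labeling function does not fire


def lf_mentions_great(text):
    """Hypothetical LF: vote positive (1) if the text contains 'great'."""
    return 1 if "great" in text.lower() else ABSTAIN


def lf_mentions_awful(text):
    """Hypothetical LF: vote negative (0) if the text contains 'awful'."""
    return 0 if "awful" in text.lower() else ABSTAIN


def apply_lfs(texts, lfs):
    """Build the vote matrix: one row per example, one column per labeling function."""
    return [[lf(text) for lf in lfs] for text in texts]
```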

COSINE for token classification?

Hi,

I would like to know whether the code for the COSINE weak-supervision method already supports token classification. If not, what changes would I need to make to build a weakly supervised training pipeline using weakly labeled and unlabeled datasets?

Numba 0.43 doesn't work with newer Python versions

The numba package 0.43, specified here, doesn't work with Python 3.9. Upgrading the package to the latest version (0.54) resolves the issue.
Traceback:

/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/llvmlite/llvmpy/__init__.py:3: UserWarning: The module `llvmlite.llvmpy` is deprecated and will be removed in the future.
  warnings.warn(
/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/llvmlite/llvmpy/core.py:8: UserWarning: The module `llvmlite.llvmpy.core` is deprecated and will be removed in the future. Equivalent functionality is provided by `llvmlite.ir`.
  warnings.warn(
Traceback (most recent call last):
  File "<frozen importlib._bootstrap>", line 1007, in _find_and_load
  File "<frozen importlib._bootstrap>", line 986, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 680, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 850, in exec_module
  File "<frozen importlib._bootstrap>", line 228, in _call_with_frames_removed
  File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/wrench/labelmodel/__init__.py", line 1, in <module>
    from .dawid_skene import DawidSkene
  File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/wrench/labelmodel/dawid_skene.py", line 6, in <module>
    from numba import njit, prange
  File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/__init__.py", line 25, in <module>
    from .decorators import autojit, cfunc, generated_jit, jit, njit, stencil
  File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/decorators.py", line 12, in <module>
    from .targets import registry
  File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/targets/registry.py", line 5, in <module>
    from . import cpu
  File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/targets/cpu.py", line 9, in <module>
    from numba import _dynfunc, config
ImportError: /home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/_dynfunc.cpython-39-x86_64-linux-gnu.so: undefined symbol: _PyObject_GC_UNTRACK
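A small runtime guard along these lines can surface the version problem with a clear message before `wrench.labelmodel` is imported. It uses only the standard library (`importlib.metadata`, Python 3.8+). The `(0, 54)` floor comes from the report above rather than from a verified compatibility matrix, and `parse_release` and `numba_is_recent_enough` are hypothetical helper names.

```python
import re
from importlib import metadata


def parse_release(version):
    """Turn '0.54.1' or '0.54.0rc1' into a tuple of ints such as (0, 54, 1) or (0, 54, 0)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.end() != len(piece):  # stop at suffixes such as 'rc1' or 'dev0'
            break
    return tuple(parts)


def numba_is_recent_enough(minimum=(0, 54)):
    """True if the installed numba release is at least `minimum`; False if older or not installed."""
    try:
        installed = metadata.version("numba")
    except metadata.PackageNotFoundError:
        return False
    return parse_release(installed) >= minimum
```

Calling `numba_is_recent_enough()` at the top of an experiment script and raising an error when it returns `False` replaces the opaque `undefined symbol: _PyObject_GC_UNTRACK` failure with an explicit upgrade hint.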

Wrong Data Statistics of the SemEval Dataset?

Dear authors,

thanks a lot for the nice repo!

I noticed that both the WRENCH paper and this repo report the following data statistics on the SemEval dataset:
number of samples in train/validation/test: 1749/200/692

However, the data at the link you provided (Google Drive) contain 1749, 178, and 600 samples for training, validation, and test, respectively. Maybe I misunderstood something?

Looking forward to hearing from you.
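To reproduce such counts, the number of examples per split can be read directly from the downloaded JSON files. The sketch assumes each split file is a single JSON object keyed by example id, which is an assumption about the released files; `split_sizes` is a hypothetical helper.

```python
import json
from pathlib import Path


def split_sizes(dataset_dir, splits=("train", "valid", "test")):
    """Return {split: number of top-level entries} for each split's JSON file."""
    sizes = {}
    for split in splits:
        path = Path(dataset_dir) / f"{split}.json"
        with path.open(encoding="utf-8") as f:
            sizes[split] = len(json.load(f))
    return sizes
```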

Multi-class with abstains

Thanks for adding your work on GitHub.

Question: I have a multiclass dataset with 4 classes and multiple weak labels, where the weak labels use '-1' to mark abstains. Do the weak supervision frameworks used within the wrench library (such as FlyingSquid and others) work for multiclass data with abstains?
If so, could you explain how the weak labels are aggregated for a multiclass dataset using FlyingSquid?
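FlyingSquid itself is not reproduced here. As a simple baseline for aggregating multiclass weak labels with '-1' abstains, the sketch below implements plain majority voting (wrench also lists a `MajorityVoting` label model, see the config printed in the COSINE issue above, but this is not its code). Abstains are ignored, each example gets the class with the most votes, and examples where every LF abstains or where the top classes tie are left as -1. The tie rule is an assumption; other implementations break ties differently.

```python
from collections import Counter

ABSTAIN = -1


def majority_vote(vote_matrix, abstain=ABSTAIN):
    """Aggregate each row of weak labels by plurality vote, ignoring abstains."""
    aggregated = []
    for row in vote_matrix:
        counts = Counter(v for v in row if v != abstain)
        if not counts:
            aggregated.append(abstain)  # every labeling function abstained
            continue
        ranked = counts.most_common()
        if len(ranked) > 1 and ranked[0][1] == ranked[1][1]:
            aggregated.append(abstain)  # tie between the top classes
        else:
            aggregated.append(ranked[0][0])
    return aggregated
```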
