linzzzzzz / chinese-text-classification-pytorch-tuning

This project is forked from 649453932/chinese-text-classification-pytorch

27.0 27.0 5.0 59.74 MB

Chinese text classification with Bert, TextCNN, TextRNN, FastText, TextRCNN, BiLSTM_Attention, DPCNN, and Transformer. Based on PyTorch and ready to use out of the box.

License: MIT License

Python 100.00%

chinese-text-classification-pytorch-tuning's People

Contributors

649453932, linzzzzzz

chinese-text-classification-pytorch-tuning's Issues

bug

Line 95 of train_eval.py references test_acc, which does not exist; it should be changed to test_metric.

How can I resolve RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x768 and 208x10)?

I modified run.py as follows:
search_space = {
    'learning_rate': tune.loguniform(1e-5, 1e-2),
    'num_epochs': tune.randint(2, 3),
    'dropout': tune.uniform(0, 0.5),
    'hidden_size': tune.randint(32, 257),
    'num_layers': tune.randint(1, 3)
}

if __name__ == '__main__':
    dataset = 'THUCNews'  # dataset

    model_grouping = {
        'bert': 2,
        'TextRCNN': 1,
        'TextCNN': 1,
        'TextRNN': 1,
        'FastText': 1,
    }

    # model_group = model_grouping[args.model]
    model_group = model_grouping['bert']  # use bert directly as the training model here

    # Sogou News: embedding_SougouNews.npz, Tencent: embedding_Tencent.npz, random initialization: random
    embedding = 'embedding_SougouNews.npz'
    model_name = 'bert'
    if model_group == 2:
        from utils_bert import build_dataset, build_iterator, get_time_dif

    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True  # make results the same on every run

    def experiment(tune_config):
        x = import_module('models.' + model_name)
        if model_group == 1:
            config = x.Config(dataset, embedding)
        elif model_group == 2:
            config = x.Config(dataset)

        # Overwrite config attributes with the values sampled by Ray Tune
        if tune_config:
            for param in tune_config:
                setattr(config, param, tune_config[param])

        start_time = time.time()
        print("Loading data...")
        if model_group == 1:
            vocab, train_data, dev_data, test_data = build_dataset(config, args.word)
        elif model_group == 2:
            train_data, dev_data, test_data = build_dataset(config)
        train_iter = build_iterator(train_data, config)
        dev_iter = build_iterator(dev_data, config)
        test_iter = build_iterator(test_data, config)
        time_dif = get_time_dif(start_time)
        print("Time usage:", time_dif)

        # train
        if model_group == 1:
            config.n_vocab = len(vocab)
            model = x.Model(config).to(config.device)
            if model_name != 'Transformer':
                init_network(model)
            print(model.parameters)
        elif model_group == 2:
            model = x.Model(config).to(config.device)

        if tune_config:
            res = train(config, model, train_iter, dev_iter, test_iter, model_group=model_group, tune_param=True)
            tune.report(metric=res)
        else:
            train(config, model, train_iter, dev_iter, test_iter, model_group=model_group, tune_param=False)

    print('tune param: ', True)

    # if tune parameters
    if True:
        scheduler = ASHAScheduler(metric='metric', mode="max") if args.tune_asha else None

        analysis = tune.run(experiment, num_samples=50, config=search_space, resources_per_trial={'gpu': int(True)},
            scheduler=scheduler,
            verbose=3)
        analysis.results_df.to_csv('tune_results_' + args.tune_file + '.csv')
    # if not tune parameters
    else:
        experiment(tune_config=None)

Running it then produced the following error:
ray.exceptions.RayTaskError(RuntimeError): ray::ImplicitFunc.train() (pid=22632, ip=127.0.0.1, repr=experiment)
File "python\ray\_raylet.pyx", line 877, in ray._raylet.execute_task
File "python\ray\_raylet.pyx", line 881, in ray._raylet.execute_task
File "python\ray\_raylet.pyx", line 821, in ray._raylet.execute_task.function_executor
File "G:\anaconda\envs\bertCTP-F\lib\site-packages\ray\_private\function_manager.py", line 670, in actor_method_executor
return method(__ray_actor, *args, **kwargs)
File "G:\anaconda\envs\bertCTP-F\lib\site-packages\ray\util\tracing\tracing_helper.py", line 460, in _resume_span
return method(self, *_args, **_kwargs)
File "G:\anaconda\envs\bertCTP-F\lib\site-packages\ray\tune\trainable\trainable.py", line 384, in train
raise skipped from exception_cause(skipped)
File "G:\anaconda\envs\bertCTP-F\lib\site-packages\ray\tune\trainable\function_trainable.py", line 339, in entrypoint
self._status_reporter.get_checkpoint(),
File "G:\anaconda\envs\bertCTP-F\lib\site-packages\ray\util\tracing\tracing_helper.py", line 460, in _resume_span
return method(self, *_args, **_kwargs)
File "G:\anaconda\envs\bertCTP-F\lib\site-packages\ray\tune\trainable\function_trainable.py", line 653, in _trainable_func
output = fn()
File "E:/w_learning/bertCTP-F/run.py", line 112, in experiment
res = train(config, model, train_iter, dev_iter, test_iter, model_group=model_group, tune_param=True)
File "E:\w_learning\bertCTP-F\train_eval.py", line 60, in train
outputs = model(trains)
File "G:\anaconda\envs\bertCTP-F\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "E:\w_learning\bertCTP-F\models\bert.py", line 47, in forward
out = self.fc(pooled)
File "G:\anaconda\envs\bertCTP-F\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "G:\anaconda\envs\bertCTP-F\lib\site-packages\torch\nn\modules\linear.py", line 103, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x768 and 208x10)
Could someone tell me where this is going wrong?
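
A likely cause, judging only from the shapes in the error: the 128x768 operand looks like a batch of 128 pooled BERT outputs of width 768, while the 208x10 operand looks like the weight of a classifier layer with 208 input features and 10 classes. The value 208 falls inside the range of 'hidden_size': tune.randint(32, 257), and the setattr loop in experiment() copies every sampled value onto config. Assuming models/bert.py builds self.fc from config.hidden_size (the post does not show that file), the sampled hidden_size would replace the 768 that the pretrained encoder requires. A minimal sketch of one possible workaround is to drop 'hidden_size' from the search space when the model group is BERT. The helper names below are hypothetical, and plain strings stand in for the tune.* samplers so the sketch runs without Ray:

```python
# Hypothetical helper names; only the key 'hidden_size' and the shapes come from the post.
# Assumption: for model_group == 2 (bert), hidden_size must stay at the pretrained width.
FIXED_FOR_BERT = {"hidden_size"}


def filter_search_space(search_space, model_group):
    """Return a copy of search_space without keys that BERT should not have overridden."""
    if model_group == 2:
        return {k: v for k, v in search_space.items() if k not in FIXED_FOR_BERT}
    return dict(search_space)


def can_multiply(mat1_shape, mat2_shape):
    """True when mat1 @ mat2 is defined, i.e. the inner dimensions agree."""
    return mat1_shape[1] == mat2_shape[0]


# Placeholder strings stand in for the tune.* samplers used in run.py.
search_space = {
    "learning_rate": "loguniform(1e-5, 1e-2)",
    "num_epochs": "randint(2, 3)",
    "dropout": "uniform(0, 0.5)",
    "hidden_size": "randint(32, 257)",
    "num_layers": "randint(1, 3)",
}

bert_space = filter_search_space(search_space, model_group=2)
print(sorted(bert_space))

# Shapes from the error message: the inner dimensions 768 and 208 disagree.
print(can_multiply((128, 768), (208, 10)))
# With hidden_size left at 768 the classifier weight would line up.
print(can_multiply((128, 768), (768, 10)))
```

In run.py this would amount to passing the filtered dictionary as config to tune.run when model_group is 2. Whether 'num_layers' has any effect on the BERT model depends on models/bert.py, which is not shown here.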
