

malawada commented on August 11, 2024

Hi, you can use `5_fold_271_carla_sequence_classification_example_model.pt` in the folder. It is the same model, just named differently. The old version should still work, but the new one has updated weights that may perform better.

from roadscene2vec.

xinxinlv commented on August 11, 2024

Hi, thank you for your reply.
I modified `use_case_2.py` to load your new pre-trained model, `5_fold_271_carla_sequence_classification_example_model.pt`, directly and use it to predict on `use_case_data/lanechange` without any training:
```python
# Imports and helpers (configuration, extract_seq, Scenegraph_Trainer,
# format_use_case_model_input) are the same as in the original use_case_2.py.

def risk_assess():
    # Create the scene-graph extraction config object
    scenegraph_extraction_config = configuration(r"use_case_2_scenegraph_extraction_config.yaml", from_function=True)
    # Extract scene graphs for each frame of the given sequence into a ScenegraphDataset
    extracted_scenegraphs = extract_seq(scenegraph_extraction_config)
    # Create the training config object
    training_config = configuration(r"use_case_2_learning_config.yaml", from_function=True)
    # Create the trainer object from the config
    trainer = Scenegraph_Trainer(training_config)
    # Training steps skipped; the pre-trained model is loaded instead
    # trainer.split_dataset()  # split the ScenegraphDataset into training and testing data
    # trainer.build_model()    # build the model specified in the learning config
    # trainer.learn()
    trainer.load_model()

    # Turn the original sequence's extracted ScenegraphDataset into model input
    model_input = format_use_case_model_input(extracted_scenegraphs, trainer)
    # Output the risk assessment for the original sequence
    output, _ = trainer.model.forward(*model_input)
    return output
```

The correct label for `use_case_data/lanechange/22_lanchange` is "risky", but the prediction from the pre-trained model is "safe", which is incorrect. I don't know what went wrong. Here is the output:

```
Model loaded from file.
/home/liuxx/anaconda3/envs/av/lib/python3.9/site-packages/torch_geometric/deprecation.py:12: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead
  warnings.warn(out)
tensor([-0.3435, -1.2354], device='cuda:0', grad_fn=<LogSoftmaxBackward0>)
```
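The printed tensor comes from a `LogSoftmax` layer, so its entries are log-probabilities rather than probabilities. A minimal sketch of how to read them, assuming class index 0 means "safe" and index 1 means "risky" (inferred from the reported "safe" result, not confirmed by the project), is to exponentiate each value and take the largest:

```python
import math

# Log-probabilities copied from the printed tensor (LogSoftmax output).
log_probs = [-0.3435, -1.2354]

# Assumption: index 0 = "safe", index 1 = "risky" (not confirmed by roadscene2vec).
CLASS_NAMES = ["safe", "risky"]

def to_probabilities(log_probs):
    """Exponentiate log-softmax outputs to recover class probabilities."""
    return [math.exp(lp) for lp in log_probs]

probs = to_probabilities(log_probs)
# The predicted class is the index with the highest probability.
predicted = max(range(len(probs)), key=lambda i: probs[i])

for name, p in zip(CLASS_NAMES, probs):
    print(f"{name}: {p:.3f}")
print("predicted:", CLASS_NAMES[predicted])
```

The two probabilities sum to about 1, and the model assigns roughly 0.71 to the first class, so this is a moderately confident prediction rather than a near-certain one.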


malawada commented on August 11, 2024

It is a machine learning model, so it will not always predict the correct result :). In our experiments, it predicts the correct result about 90% of the time.

