MatGNN: Graph Neural Network Library for Materials Science

A Python library built on PyTorch Lightning for training graph neural networks (GNNs) for materials science research.

Installation

To create a virtual environment for MatGNN using conda or mamba, run the following command:

mamba env create -n matgnn -f environment.yml

Then add the following to your run script so that Python can find the package:

import sys
sys.path.append("/path/to/matgnn")
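
If the run script may be executed more than once in the same interpreter (for example, in a notebook), a small helper can keep sys.path free of duplicate entries. The following is a minimal sketch using only the standard library; add_repo_to_path is a hypothetical helper name, not part of MatGNN, and the path is the same placeholder as above.

```python
import sys
from pathlib import Path


def add_repo_to_path(repo_dir):
    """Append repo_dir to sys.path once and return the resolved path string."""
    resolved = str(Path(repo_dir).expanduser().resolve())
    if resolved not in sys.path:
        sys.path.append(resolved)
    return resolved


# Placeholder location; replace it with the directory that contains matgnn.
add_repo_to_path("/path/to/matgnn")
```

Calling the helper a second time with the same directory leaves sys.path unchanged.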

Usage

First, you need to define some global variables that will be used by both the data loader and the trainer:

PREC = "64"
BATCH_SIZE = 64
DEVICE = "gpu"
ASE_DB_LOC = "path/to/database"

Here, PREC can be "64", "32", or "16" to specify the floating-point precision. BATCH_SIZE defines the batch size for both the training and validation data loaders. DEVICE defines the device type, which can be either "cpu" or "gpu". ASE_DB_LOC indicates the location of the ASE database.
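
Because PREC is a string rather than an integer, a typo or a wrong type is easy to miss until training starts. As a hedged illustration, the allowed values listed above can be checked up front; check_globals is a hypothetical helper written for this example and is not part of MatGNN.

```python
VALID_PRECISIONS = ("64", "32", "16")
VALID_DEVICES = ("cpu", "gpu")


def check_globals(prec, batch_size, device):
    """Raise ValueError if a setting falls outside the documented options."""
    if prec not in VALID_PRECISIONS:
        raise ValueError(f"PREC must be one of {VALID_PRECISIONS}, got {prec!r}")
    if not isinstance(batch_size, int) or batch_size < 1:
        raise ValueError(f"BATCH_SIZE must be a positive integer, got {batch_size!r}")
    if device not in VALID_DEVICES:
        raise ValueError(f"DEVICE must be one of {VALID_DEVICES}, got {device!r}")


PREC = "64"
BATCH_SIZE = 64
DEVICE = "gpu"
check_globals(PREC, BATCH_SIZE, DEVICE)  # passes silently for valid values
```

Passing the integer 64 instead of the string "64" for PREC raises a ValueError in this sketch.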

After defining the global variables, we can now define the data parameters:

For AtomGraph module:

atomgraph_params = AtomGraphParameters(
    feature_type="atomic_number",
    self_loop=True,
    graph_radius=5,
    max_neighbors=5,
    edge_feature=True,
    edge_resolution=50,
    add_node_degree=True,
)

Here, feature_type can be "atomic_number" or "atomic_symbol" to specify the feature type. self_loop indicates whether to add self-loops. graph_radius sets the cutoff radius used by the graph constructor, which limits the number of neighbors. max_neighbors is the maximum number of neighbors within the graph radius to include. edge_feature indicates whether to add edge features. edge_resolution sets the fineness of the edge feature constructor. add_node_degree indicates whether to add the node degree to the node features.
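
To make the interplay of graph_radius, max_neighbors, and edge_resolution concrete, here is a small standard-library sketch. It is not MatGNN's implementation: it assumes non-periodic Cartesian coordinates, assumes that neighbors are kept nearest-first, and assumes that the edge feature is a Gaussian expansion of the interatomic distance with edge_resolution basis functions (the README only says that edge_resolution controls the fineness). The function names build_neighbor_list and gaussian_expand are hypothetical.

```python
import math


def build_neighbor_list(positions, graph_radius, max_neighbors):
    """Return {atom index: [(neighbor index, distance), ...]} sorted by distance.

    Assumptions for illustration only: non-periodic Cartesian coordinates
    and nearest-first truncation to max_neighbors.
    """
    neighbors = {}
    for i, pos_i in enumerate(positions):
        candidates = []
        for j, pos_j in enumerate(positions):
            if i == j:
                continue
            dist = math.dist(pos_i, pos_j)
            if dist <= graph_radius:
                candidates.append((j, dist))
        candidates.sort(key=lambda pair: pair[1])
        neighbors[i] = candidates[:max_neighbors]
    return neighbors


def gaussian_expand(distance, graph_radius, edge_resolution):
    """Expand a distance onto Gaussians evenly spaced over [0, graph_radius].

    Assumed form of the edge feature; requires edge_resolution >= 2.
    """
    step = graph_radius / (edge_resolution - 1)
    return [
        math.exp(-((distance - k * step) ** 2) / step**2)
        for k in range(edge_resolution)
    ]


# Four atoms on a line, spaced 2 apart (arbitrary length units).
positions = [(0.0, 0.0, 0.0), (2.0, 0.0, 0.0), (4.0, 0.0, 0.0), (6.0, 0.0, 0.0)]
graph = build_neighbor_list(positions, graph_radius=5, max_neighbors=2)
node_degree = {i: len(nbrs) for i, nbrs in graph.items()}
edge_feature = gaussian_expand(graph[0][0][1], graph_radius=5, edge_resolution=50)
```

In this toy chain, atom 1 has three atoms within the radius, but only the two nearest are kept because max_neighbors is 2. A larger edge_resolution gives a longer, finer-grained edge feature vector for the same distance.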

After defining the feature constructor parameters, we can now incorporate them into the dataset parameters.

ds_params = DatasetParameters(
    feature_type="AtomGraph",
    ase_db_loc=ASE_DB_LOC,
    target="hof",
    dtype=PREC,
    extra_parameters=atomgraph_params)

In the dataset parameters, feature_type can be "AtomGraph", "SOAP", "CM", or "SM" to specify the feature constructor type. target names the database column to use as the prediction target. extra_parameters takes the feature constructor parameters defined above.

Now, we can piece them all together to define the data module parameters, which will be used to create the training and validation data loaders.

dm_params = DataModuleParameters(
    in_memory=True,
    dataset_params=ds_params,
    batch_size=BATCH_SIZE,
)

Data can either reside in memory or on the file system (loaded on the fly). Now, we can instantiate the data module to get some data parameters that will be needed to define the model parameters.
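
The trade-off behind in_memory can be illustrated with a generic sketch: building every sample up front costs memory but makes later access cheap, while building on access keeps memory low at the cost of repeated work. This is not MatGNN's code; LazyDataset, make_dataset, and build_sample are hypothetical names used only for this example.

```python
class LazyDataset:
    """Build each sample on access instead of storing all samples up front."""

    def __init__(self, n_samples, build_sample):
        self.n_samples = n_samples
        self.build_sample = build_sample

    def __len__(self):
        return self.n_samples

    def __getitem__(self, index):
        if not 0 <= index < self.n_samples:
            raise IndexError(index)
        return self.build_sample(index)


def make_dataset(n_samples, build_sample, in_memory):
    """Return a list of prebuilt samples, or a lazy wrapper if in_memory is False."""
    if in_memory:
        return [build_sample(i) for i in range(n_samples)]
    return LazyDataset(n_samples, build_sample)


calls = []  # records which samples were built, and in what order


def build_sample(index):
    # Stand-in for reading one structure from disk and featurizing it.
    calls.append(index)
    return {"index": index, "target": float(index)}


eager = make_dataset(3, build_sample, in_memory=True)   # builds all 3 samples now
lazy = make_dataset(3, build_sample, in_memory=False)   # builds nothing yet
sample = lazy[1]                                        # builds only sample 1
```

After these calls, the calls list shows that the in-memory variant built every sample immediately, while the lazy variant built a sample only when it was indexed.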

dm = MaterialsGraphDataModule(dm_params)
dm.setup()
n_features = dm.dataset.num_features
n_edge_features = dm.dataset.num_edge_features

Now, we are ready to define the model parameters:

model_params = GraphConvolutionParameters(
    n_features=n_features,
    n_edge_features=n_edge_features,
    batch_size=BATCH_SIZE,
    pre_hidden_size=130,
    post_hidden_size=120,
    gcn_hidden_size=150,
    n_gcn=4,
    n_pre_gcn_layers=2,
    n_post_gcn_layers=2,
    gcn_type="schnet",
    pool="max",
    dtype=PREC,
    device=DEVICE,
)

Here, gcn_type can be "gcn", "cgcnn", or "schnet". To use a message-passing neural network instead, you can similarly define MPNNParameters.

We can then pass the model parameters to the MatGNN constructor:

mg_params = MatGNNParameters(
    model_params=model_params,
    optimizer="adam"
    )

mg = MatGNN(params=mg_params)

Finally, we can initialize the PyTorch Lightning trainer to train the model:

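# Assumed import: the snippet uses pl without showing where it comes from.
import pytorch_lightning as pl

trainer = pl.Trainer(accelerator=DEVICE,
                     max_epochs=200,
                     precision=PREC)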

trainer.fit(
    model=mg,
    train_dataloaders=dm.train_dataloader(),
    val_dataloaders=dm.val_dataloader()
    )

Running this script will print the training statistics on the console.

TODO

- [ ] MEGNET
- [ ] Test metrics
- [ ] logger
- [ ] saving and loading
- [ ] continue training from checkpoint
- [ ] Hyperparameter
- [ ] Weight initialization

Author

Contributors

mamunm
