
lvrcek / gnnome-assembly


Learning to untangle genome assembly with graph neural networks.

License: MIT License

Python 98.53% Shell 1.47%
denovo-assembly genome-assembly graph-algorithms graph-neural-networks

gnnome-assembly's Introduction

Hi there, I'm Lovro 👋

  • 📚 Interested in Deep Learning 🧠 and Computational Genomics 🧬
  • 🔭 Currently working on applying Graph Neural Networks to de novo Genome Assembly
  • 👨‍💻 My go-to languages and frameworks: Python, PyTorch, DGL
  • 📫 How to reach me: Send an email to vrcek dot lovro at gmail dot com
  • ⚡ Hobbies: Rock Climbing 🧗‍♂️, Swimming 🏊‍♂️, Reading 📖

gnnome-assembly's People

Contributors

lvrcek, xbresson


gnnome-assembly's Issues

Read simulation and Raven issues

I have not been able to get the pipeline past the Raven step. The seqrequester program gives me the following error every time I run pipeline.py:

"ERROR: Don't know how long to make the reads. Set -distribution or -length"

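The (commented-out) seqrequester call in the pipeline.py listing further down this page runs `seqrequester simulate -genome ... -genomesize ... -coverage 40 -distribution {chr_dist_path}`, where `chr_dist_path` is `<refs>/lengths/<chr>.txt` and `--refs` defaults to `data/references`. Since the error asks for `-distribution` or `-length`, one thing worth ruling out is that this read-length distribution file is missing or empty; whether that is what produces this exact message is an assumption. A minimal sketch of a hypothetical helper that checks the file before building the command:

```python
import os
from typing import List


def seqrequester_simulate_args(seqrequester: str, genome_fasta: str, genome_size: int,
                               coverage: int, dist_path: str) -> List[str]:
    """Build the argument list for `seqrequester simulate` as used in pipeline.py.

    Fails early if the read-length distribution file is missing or empty, so the
    problem surfaces in Python rather than inside seqrequester. (Assumption: a
    missing distribution file is one possible trigger of the seqrequester error.)
    """
    if not os.path.isfile(dist_path) or os.path.getsize(dist_path) == 0:
        raise FileNotFoundError(f'read-length distribution missing or empty: {dist_path}')
    # Same flags as the shell command in pipeline.py, minus the '>' redirect
    return [seqrequester, 'simulate',
            '-genome', genome_fasta,
            '-genomesize', str(genome_size),
            '-coverage', str(coverage),
            '-distribution', dist_path]
```

The list can be passed to `subprocess.run(args, stdout=out_handle, check=True)` with `out_handle` opened for writing on the target `raw/<idx>.fasta`, which plays the role of the shell redirect `>` in pipeline.py and raises on a non-zero exit code instead of continuing silently.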
If I simulate reads with another program and put them in data/simulated/chr1/raw, the pipeline gets a bit further, but it then crashes right after Raven finishes, with a non-specific error:

"Step 0: generating graphs for reads in 0.fasta
Path to the reads: /ocean/projects/bio220033p/crj39/Assemblies/GNNome-assembly/data/simulated/chr1/raw/0.fasta
Starting raven at: /ocean/projects/bio220033p/crj39/Assemblies/GNNome-assembly/vendor/raven/build/bin/raven
Parameters: --identity 0.99 -k29 -w9 -t32 -p0
Assembly output: assembly.fasta

[raven::] loaded 731 sequences 0.042683s
[raven::Graph::Construct] minimized 0 - 731 / 731 0.272052s
[raven::Graph::Construct] mapped sequences 0.120643s
[raven::Graph::Construct] annotated piles 0.001718s
[raven::Graph::Construct] filtered overlaps 0.232881s
[raven::Graph::Construct] removed contained sequences 0.000014s
[raven::Graph::Construct] removed chimeric sequences 0.001704s
[raven::Graph::Construct] reached checkpoint 0.004957s
[raven::Graph::Construct] minimized 0 - 463 / 463 0.326245s
[raven::Graph::Construct] mapped valid sequences 0.336586s
[raven::Graph::Construct] updated overlaps 0.000010s
[raven::Graph::Construct] removed false overlaps 0.010328s
[raven::Graph::Construct] stored 926 nodes 0.011106s
[raven::Graph::Construct] stored 0 edges 0.000000s
[raven::Graph::Construct] reached checkpoint 0.014007s
[raven::Graph::Construct] 1.334490s
[raven::Graph::Assemble] removed transitive edges 0.000028s
[raven::Graph::Assemble] reached checkpoint 0.013977s
[raven::Graph::Assemble] removed tips and bubbles 0.000023s
[raven::Graph::Assemble] reached checkpoint 0.013770s
[raven::Graph::Assemble] removed long edges 0.002454s
[raven::Graph::Assemble] reached checkpoint 0.014359s
[raven::Graph::Assemble] 0.046142s
[raven::] 1.525321s

Raven generated the graph! Processing...
Traceback (most recent call last):
  File "./pipeline.py", line 347, in <module>
    generate_graphs(data_path, all_chr)
  File "./pipeline.py", line 148, in generate_graphs
    graph_dataset.AssemblyGraphDataset(chr_sim_path, nb_pos_enc=None, specs=specs, generate=True)
  File "/ocean/projects/bio220033p/crj39/Assemblies/GNNome-assembly/graph_dataset.py", line 66, in __init__
    super().__init__(name='assembly_graphs', raw_dir=raw_dir, save_dir=save_dir)
  File "/jet/home/crj39/.local/lib/python3.8/site-packages/dgl/data/dgl_dataset.py", line 113, in __init__
    self._load()
  File "/jet/home/crj39/.local/lib/python3.8/site-packages/dgl/data/dgl_dataset.py", line 204, in _load
    self.process()
  File "/ocean/projects/bio220033p/crj39/Assemblies/GNNome-assembly/graph_dataset.py", line 126, in process
    graph, pred, succ, reads, edges, labels = graph_parser.from_csv(os.path.join(self.tmp_dir, f'{idx}_graph_1.csv'), reads_path)
  File "/ocean/projects/bio220033p/crj39/Assemblies/GNNome-assembly/graph_parser.py", line 297, in from_csv
    graph_dgl = dgl.from_networkx(graph_nx,
  File "/jet/home/crj39/.local/lib/python3.8/site-packages/dgl/convert.py", line 1260, in from_networkx
    g.ndata[attr] = F.copy_to(_batcher(attr_dict[attr]), g.device)
  File "/jet/home/crj39/.local/lib/python3.8/site-packages/dgl/convert.py", line 1249, in _batcher
    if F.is_tensor(lst[0]):
IndexError: list index out of range"

This seems to happen both when using my own data and when running the pipeline as prepared with example data.

Here is my command:
python ./pipeline.py --data ./data --out mito_test1
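The traceback ends in `dgl.from_networkx` at `lst[0]`, i.e. DGL received an empty attribute list, and the Raven log just before it reports `stored 0 edges`. An empty graph reaching the parser is therefore a plausible explanation, though not a confirmed one. A minimal sketch of a guard that reads Raven's printed log and stops with a clear message before the DGL conversion; the function names are hypothetical, and obtaining the log text (e.g. from the captured output of the Raven subprocess) is an assumption about how the pipeline would be wired:

```python
import re
from typing import Dict

# Matches Raven lines such as "[raven::Graph::Construct] stored 926 nodes 0.011106s"
_STORED = re.compile(r'\[raven::Graph::Construct\] stored (\d+) (nodes|edges)')


def raven_graph_size(log_text: str) -> Dict[str, int]:
    """Return the node and edge counts Raven reports while constructing the graph."""
    return {kind: int(count) for count, kind in _STORED.findall(log_text)}


def check_raven_graph(log_text: str) -> Dict[str, int]:
    """Raise a descriptive error if Raven stored no edges; otherwise return the counts."""
    counts = raven_graph_size(log_text)
    if counts.get('edges', 0) == 0:
        raise ValueError(f'Raven stored no overlap edges ({counts}); '
                         'there is no graph for the DGL conversion to work with')
    return counts
```

On the log above this reports `{'nodes': 926, 'edges': 0}` and raises, instead of failing later with an `IndexError` deep inside DGL.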

Hifiasm error correction question

Hi, author,
I would like to ask how the hifiasm error-corrected reads are obtained from the bin file.
Which software is used?
The original sentence in the article:
We first error-correct reads with hifiasm [16], thus reducing amount of mismatches, insertions, and
deletions in the reads.

I am looking forward to your reply. Thanks.

Data formatting question for new analysis

Thank you for this tool and the manuscript; it is very interesting work. I am currently trying to configure the pipeline.py script to run on my own plant genome (n=7). Can pipeline.py be adapted simply by removing steps "-1" and "0" and arranging my own data in the structure specified in the manuscript (the Code section)?
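In the pipeline.py listed further down this page, steps -1 and 0 only create directories, fetch the real dataset, and download and split the CHM13 reference, so skipping them is mostly a matter of having the directory tree in place. A minimal sketch, assuming the per-chromosome layout from the commented-out `create_chr_dirs()` in that listing (`raw`, `processed`, `info`, `raven_output`, `graphia`); whether `graph_dataset.py` needs all five subdirectories is an assumption, as is placing the reads in `raw/` as `0.fasta`, `1.fasta`, ..., which is inferred from the paths in the logs on this page. The helper names are hypothetical:

```python
import os
from typing import Dict, List

# Subdirectories the commented-out create_chr_dirs() makes for every chromosome
EXPECTED_SUBDIRS = ('raw', 'processed', 'info', 'raven_output', 'graphia')


def missing_dirs(data_path: str, chromosomes: List[str], kind: str = 'real') -> Dict[str, List[str]]:
    """Map each chromosome to the expected subdirectories that do not exist yet."""
    missing = {}
    for chr_name in chromosomes:
        chr_dir = os.path.join(data_path, kind, chr_name)
        absent = [sub for sub in EXPECTED_SUBDIRS
                  if not os.path.isdir(os.path.join(chr_dir, sub))]
        if absent:
            missing[chr_name] = absent
    return missing


def create_missing_dirs(data_path: str, chromosomes: List[str], kind: str = 'real') -> None:
    """Create any expected subdirectory that is not there yet; existing ones are left alone."""
    for chr_name, absent in missing_dirs(data_path, chromosomes, kind).items():
        for sub in absent:
            os.makedirs(os.path.join(data_path, kind, chr_name, sub), exist_ok=True)
```

Running `missing_dirs` first and only then `create_missing_dirs` makes it visible which parts of the layout were absent, which is useful when the data was arranged by hand.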

Generating real graphs skipped

I am running my own analysis and am having some issues. The simulated graphs are generated without problems (I needed 5 graphs for chr1, and all of them produced .gfa and .dgl files), but after that my edited pipeline.py seems to go straight to the train/valid/test split, skipping the "Generate the real_graphs" step. As a result, the split cannot find the .dgl file for the real reads, because those reads were never assembled with Raven. Below is my log:

Raven generated the graph! Processing...
Parsed Raven output! Saving files...
Processing of graph 4 generated from 4.fasta done!

SETUP::split
SETUP::split:: Copying 3 graphs of chr1 into /ocean/projects/bio220033p/crj39/Software/GNNome-assembly/data/experiments/train_Phyfe1
Copying /ocean/projects/bio220033p/crj39/Software/GNNome-assembly/data/simulated/chr1/processed/0.dgl into /ocean/projects/bio220033p/crj39/Software/GNNome-assembly/data/experiments/train_Phyfe1/processed/0.dgl
Copying /ocean/projects/bio220033p/crj39/Software/GNNome-assembly/data/simulated/chr1/processed/1.dgl into /ocean/projects/bio220033p/crj39/Software/GNNome-assembly/data/experiments/train_Phyfe1/processed/1.dgl
Copying /ocean/projects/bio220033p/crj39/Software/GNNome-assembly/data/simulated/chr1/processed/2.dgl into /ocean/projects/bio220033p/crj39/Software/GNNome-assembly/data/experiments/train_Phyfe1/processed/2.dgl
SETUP::split:: Copying 2 graphs of chr1 into /ocean/projects/bio220033p/crj39/Software/GNNome-assembly/data/experiments/valid_Phyfe1
Copying /ocean/projects/bio220033p/crj39/Software/GNNome-assembly/data/simulated/chr1/processed/3.dgl into /ocean/projects/bio220033p/crj39/Software/GNNome-assembly/data/experiments/valid_Phyfe1/processed/0.dgl
Copying /ocean/projects/bio220033p/crj39/Software/GNNome-assembly/data/simulated/chr1/processed/4.dgl into /ocean/projects/bio220033p/crj39/Software/GNNome-assembly/data/experiments/valid_Phyfe1/processed/1.dgl
SETUP::split:: Copying 1 graphs of chr8_r into /ocean/projects/bio220033p/crj39/Software/GNNome-assembly/data/experiments/test_Phyfe1
Copying /ocean/projects/bio220033p/crj39/Software/GNNome-assembly/data/real/chr8/processed/0.dgl into /ocean/projects/bio220033p/crj39/Software/GNNome-assembly/data/experiments/test_Phyfe1/processed/0.dgl
cp: cannot stat '/ocean/projects/bio220033p/crj39/Software/GNNome-assembly/data/real/chr8/processed/0.dgl': No such file or directory
cp: cannot stat '/ocean/projects/bio220033p/crj39/Software/GNNome-assembly/data/real/chr8/info/0_succ.pkl': No such file or directory
cp: cannot stat '/ocean/projects/bio220033p/crj39/Software/GNNome-assembly/data/real/chr8/info/0_pred.pkl': No such file or directory
cp: cannot stat '/ocean/projects/bio220033p/crj39/Software/GNNome-assembly/data/real/chr8/info/0_edges.pkl': No such file or directory
cp: cannot stat '/ocean/projects/bio220033p/crj39/Software/GNNome-assembly/data/real/chr8/info/0_reads.pkl': No such file or directory
SETUP::train
DGL graph idx=0 info:
Graph(num_nodes=83704, num_edges=828890,
ndata_schemes={'read_idx': Scheme(shape=(), dtype=torch.int64), 'read_length': Scheme(shape=(), dtype=torch.int64), 'read_trim_end': Scheme(shape=(), dtype=torch.int64), 'read_strand': Scheme(shape=(), dtype=torch.int64), 'read_start': Scheme(shape=(), dtype=torch.int64), 'read_end': Scheme(shape=(), dtype=torch.int64), 'read_trim_start': Scheme(shape=(), dtype=torch.int64), 'x': Scheme(shape=(1,), dtype=torch.float32), 'in_deg': Scheme(shape=(), dtype=torch.float32), 'out_deg': Scheme(shape=(), dtype=torch.float32), 'pe': Scheme(shape=(16,), dtype=torch.float32)}
edata_schemes={'y': Scheme(shape=(), dtype=torch.float32), 'overlap_length': Scheme(shape=(), dtype=torch.int64), 'prefix_length': Scheme(shape=(), dtype=torch.int64), 'overlap_similarity': Scheme(shape=(), dtype=torch.float32), 'e': Scheme(shape=(2,), dtype=torch.float32)})
DGL graph idx=1 info:
Graph(num_nodes=83886, num_edges=834224,
ndata_schemes={'read_idx': Scheme(shape=(), dtype=torch.int64), 'read_length': Scheme(shape=(), dtype=torch.int64), 'read_trim_end': Scheme(shape=(), dtype=torch.int64), 'read_strand': Scheme(shape=(), dtype=torch.int64), 'read_start': Scheme(shape=(), dtype=torch.int64), 'read_end': Scheme(shape=(), dtype=torch.int64), 'read_trim_start': Scheme(shape=(), dtype=torch.int64), 'x': Scheme(shape=(1,), dtype=torch.float32), 'in_deg': Scheme(shape=(), dtype=torch.float32), 'out_deg': Scheme(shape=(), dtype=torch.float32), 'pe': Scheme(shape=(16,), dtype=torch.float32)}
edata_schemes={'y': Scheme(shape=(), dtype=torch.float32), 'overlap_length': Scheme(shape=(), dtype=torch.int64), 'prefix_length': Scheme(shape=(), dtype=torch.int64), 'overlap_similarity': Scheme(shape=(), dtype=torch.float32), 'e': Scheme(shape=(2,), dtype=torch.float32)})
DGL graph idx=2 info:
Graph(num_nodes=84162, num_edges=831336,
ndata_schemes={'read_idx': Scheme(shape=(), dtype=torch.int64), 'read_length': Scheme(shape=(), dtype=torch.int64), 'read_trim_end': Scheme(shape=(), dtype=torch.int64), 'read_strand': Scheme(shape=(), dtype=torch.int64), 'read_start': Scheme(shape=(), dtype=torch.int64), 'read_end': Scheme(shape=(), dtype=torch.int64), 'read_trim_start': Scheme(shape=(), dtype=torch.int64), 'x': Scheme(shape=(1,), dtype=torch.float32), 'in_deg': Scheme(shape=(), dtype=torch.float32), 'out_deg': Scheme(shape=(), dtype=torch.float32), 'pe': Scheme(shape=(16,), dtype=torch.float32)}
edata_schemes={'y': Scheme(shape=(), dtype=torch.float32), 'overlap_length': Scheme(shape=(), dtype=torch.int64), 'prefix_length': Scheme(shape=(), dtype=torch.int64), 'overlap_similarity': Scheme(shape=(), dtype=torch.float32), 'e': Scheme(shape=(2,), dtype=torch.float32)})
DGL graph idx=0 info:
Graph(num_nodes=83884, num_edges=830630,
ndata_schemes={'read_idx': Scheme(shape=(), dtype=torch.int64), 'read_length': Scheme(shape=(), dtype=torch.int64), 'read_trim_end': Scheme(shape=(), dtype=torch.int64), 'read_strand': Scheme(shape=(), dtype=torch.int64), 'read_start': Scheme(shape=(), dtype=torch.int64), 'read_end': Scheme(shape=(), dtype=torch.int64), 'read_trim_start': Scheme(shape=(), dtype=torch.int64), 'x': Scheme(shape=(1,), dtype=torch.float32), 'in_deg': Scheme(shape=(), dtype=torch.float32), 'out_deg': Scheme(shape=(), dtype=torch.float32), 'pe': Scheme(shape=(16,), dtype=torch.float32)}
edata_schemes={'y': Scheme(shape=(), dtype=torch.float32), 'overlap_length': Scheme(shape=(), dtype=torch.int64), 'prefix_length': Scheme(shape=(), dtype=torch.int64), 'overlap_similarity': Scheme(shape=(), dtype=torch.float32), 'e': Scheme(shape=(2,), dtype=torch.float32)})
DGL graph idx=1 info:
Graph(num_nodes=83910, num_edges=832352,
ndata_schemes={'read_idx': Scheme(shape=(), dtype=torch.int64), 'read_length': Scheme(shape=(), dtype=torch.int64), 'read_trim_end': Scheme(shape=(), dtype=torch.int64), 'read_strand': Scheme(shape=(), dtype=torch.int64), 'read_start': Scheme(shape=(), dtype=torch.int64), 'read_end': Scheme(shape=(), dtype=torch.int64), 'read_trim_start': Scheme(shape=(), dtype=torch.int64), 'x': Scheme(shape=(1,), dtype=torch.float32), 'in_deg': Scheme(shape=(), dtype=torch.float32), 'out_deg': Scheme(shape=(), dtype=torch.float32), 'pe': Scheme(shape=(16,), dtype=torch.float32)}
edata_schemes={'y': Scheme(shape=(), dtype=torch.float32), 'overlap_length': Scheme(shape=(), dtype=torch.int64), 'prefix_length': Scheme(shape=(), dtype=torch.int64), 'overlap_similarity': Scheme(shape=(), dtype=torch.float32), 'e': Scheme(shape=(2,), dtype=torch.float32)})
Traceback (most recent call last):
  File "./pipeline.py", line 390, in <module>
    train_model(train_path, valid_path, out, overfit)
  File "./pipeline.py", line 322, in train_model
    train.train(train_path, valid_path, out, overfit)
  File "/ocean/projects/bio220033p/crj39/Software/GNNome-assembly/train.py", line 198, in train
    model.to(device)
  File "/jet/home/crj39/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 989, in to
    return self._apply(convert)
  File "/jet/home/crj39/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 641, in _apply
    module._apply(fn)
  File "/jet/home/crj39/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 664, in _apply
    param_applied = fn(param)
  File "/jet/home/crj39/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 987, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
RuntimeError: CUDA error: invalid device ordinal
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
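The training run fails in `model.to(device)` with `CUDA error: invalid device ordinal`, which CUDA reports when the requested device index does not exist among the visible GPUs (for example, asking for `cuda:1` on a node that exposes only one GPU). How train.py chooses its device is not shown in this issue, so the helper below is hypothetical: a minimal sketch that validates a requested device string against the number of visible GPUs.

```python
def select_device(requested: str, n_visible_gpus: int) -> str:
    """Return a device string that exists on this machine.

    An out-of-range CUDA index falls back to cuda:0, and any CUDA request
    falls back to cpu when no GPU is visible.
    """
    if not requested.startswith('cuda'):
        return requested  # 'cpu' and other non-CUDA devices pass through unchanged
    if n_visible_gpus == 0:
        return 'cpu'
    _, _, index_text = requested.partition(':')
    index = int(index_text) if index_text else 0  # plain 'cuda' means device 0
    if index >= n_visible_gpus:
        return 'cuda:0'
    return requested
```

In practice `n_visible_gpus` would come from `torch.cuda.device_count()`, which counts only the GPUs visible through `CUDA_VISIBLE_DEVICES`; a scheduler that exposes a single GPU therefore makes `cuda:1` invalid even on a multi-GPU machine. Falling back is a design choice; raising a clear error instead would avoid silently training on a different device.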

Below is my config.py:

################################################################################

# Edit these three dictionaries to specify graphs to train/validation/test
# Assemblies will be constructed only for the graphs in the test_dict

# To train/validate/test on multiple chromosomes, put them as separate
# entries in the dictionaries
# E.g., to train on 1 chr19 graph and 2 chr20 graphs:
# _train_dict = {'chr19': 1, 'chr20': 2}

# To test on a real chromosome, add the "_r" suffix. Don't set a value higher than 1,
# since there is only 1 real HiFi dataset for each chromosome
# E.g., to test on real chr21:
# _test_dict = {'chr21_r': 1}

_train_dict = {'chr1': 3}
_valid_dict = {'chr1': 2}
_test_dict = {'chr8_r': 1}

################################################################################

def get_config():
    return {
        'train_dict': _train_dict,
        'valid_dict': _valid_dict,
        'test_dict' : _test_dict
    }
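The config convention above marks real chromosomes with an `_r` suffix, and the split later looks them up under `data/real/<chr>` with the suffix stripped. A minimal sketch of a hypothetical helper that separates the two kinds of entries; it tests `endswith('_r')`, which is slightly stricter than a plain `'_r' in name` substring check:

```python
from typing import Dict, List, Tuple


def split_simulated_and_real(chr_dict: Dict[str, int]) -> Tuple[Dict[str, int], List[str]]:
    """Separate simulated chromosomes from real ones marked with the '_r' suffix.

    Real names come back without the suffix ('chr8_r' -> 'chr8'), i.e. the name
    of the chromosome directory under data/real/.
    """
    simulated: Dict[str, int] = {}
    real: List[str] = []
    for chr_name, n_graphs in chr_dict.items():
        if chr_name.endswith('_r'):
            real.append(chr_name[:-2])  # at most one real graph exists per chromosome
        else:
            simulated[chr_name] = n_graphs
    return simulated, real
```

With this config, summing the three dictionaries gives `{'chr1': 5, 'chr8_r': 1}`, so the helper returns `({'chr1': 5}, ['chr8'])`. One way to use it (an assumption, not the repository's code) is to run a real-graph generation step on the returned list before the split, so that `data/real/chr8/processed/0.dgl` exists by the time it is copied.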

And finally, my custom pipeline.py, which I suspect may be the issue:

import argparse
import gzip
import os
import pickle
import subprocess
from datetime import datetime

from tqdm import tqdm
import requests
from Bio import SeqIO

import graph_dataset
import train
import inference
import evaluate
import config

chr_lens = {
    'chr1' : 41585911,
    'chr2' : 45593668,
    'chr3' : 35595883,
    'chr4' : 43868759,
    'chr5' : 42432271,
    'chr6' : 34633677,
    'chr7' : 39460530,
    'chr8' : 60396661,
    'chr9' : 50334075,
    'chr10' : 34936972,
    'chr11' : 34623285,
    'chr12' : 30656727,
    'chr13' : 62553722,
}

def change_description(file_path):
    new_fasta = []
    for record in SeqIO.parse(file_path, file_path[-5:]):  # 'fasta' for FASTA file, 'fastq' for FASTQ file
        des = record.description.split(",")
        id = des[0][5:]
        if des[1] == "forward":
            strand = '+'
        else:
            strand = '-'
        position = des[2][9:].split("-")
        start = position[0]
        end = position[1]
        record.id = id
        record.description = f'strand={strand}, start={start}, end={end}'
        new_fasta.append(record)
    SeqIO.write(new_fasta, file_path, "fasta")

#def create_chr_dirs(pth):
#    for i in range(1, 24):
#        if i == 23:
#            i = 'X'
#        subprocess.run(f'mkdir chr{i}', shell=True, cwd=pth)
#        subprocess.run(f'mkdir raw processed info raven_output graphia', shell=True, cwd=os.path.join(pth, f'chr{i}'))

def merge_dicts(d1, d2, d3={}):
    keys = {*d1, *d2, *d3}
    merged = {key: d1.get(key, 0) + d2.get(key, 0) + d3.get(key, 0) for key in keys}
    return merged

# -1. Set up the data file structure
def file_structure_setup(data_path, ref_path):
    print(f'SETUP::filesystem:: Create directories for storing data')
    if not os.path.isdir(data_path):
        os.makedirs(data_path)

    # if 'CHM13' not in os.listdir(ref_path):
    #     os.mkdir(os.path.join(ref_path, 'CHM13'))
    if 'chromosomes' not in os.listdir(ref_path):
        os.mkdir(os.path.join(ref_path, 'chromosomes'))

    if 'simulated' not in os.listdir(data_path):
        os.mkdir(os.path.join(data_path, 'simulated'))
        create_chr_dirs(os.path.join(data_path, 'simulated'))
    if 'real' not in os.listdir(data_path):
        subprocess.run(f'bash download_dataset.sh {data_path}', shell=True)
        # os.mkdir(os.path.join(data_path, 'real'))
        # create_chr_dirs(os.path.join(data_path, 'real'))
    if 'experiments' not in os.listdir(data_path):
        os.mkdir(os.path.join(data_path, 'experiments'))

# 0. Download the CHM13 if necessary
#def download_reference(ref_path):
# chm_path = os.path.join(ref_path, 'CHM13')
chr_path = os.path.join(ref_path, 'chromosomes')
# chm13_url = 'https://s3-us-west-2.amazonaws.com/human-pangenomics/T2T/CHM13/assemblies/chm13.draft_v1.1.fasta.gz'
# chm13_path = os.path.join(chm_path, 'chm13.draft_v1.1.fasta.gz')

# if len(os.listdir(chm_path)) == 0:
# # Download the CHM13 reference
# # Code for tqdm from: https://stackoverflow.com/questions/37573483/progress-bar-while-download-file-over-http-with-requests
# print(f'SETUP::download:: CHM13 not found! Downloading...')
# response = requests.get(chm13_url, stream=True)
# total_size_in_bytes= int(response.headers.get('content-length', 0))
# block_size = 1024 #1 Kibibyte
# progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)

# with open(chm13_path, 'wb') as file:
# for data in response.iter_content(block_size):
# progress_bar.update(len(data))
# file.write(data)
# progress_bar.close()
# if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
# print("ERROR, something went wrong")

# if len(os.listdir(chr_path)) == 0:
# # Parse the CHM13 into individual chromosomes
# print(f'SETUP::download:: Split CHM13 per chromosome')
# with gzip.open(chm13_path, 'rt') as f:
# for record in SeqIO.parse(f, 'fasta'):
# SeqIO.write(record, os.path.join(chr_path, f'{record.id}.fasta'), 'fasta')

# 1. Simulate the sequences
def simulate_reads(data_path, ref_path, chr_dict):
    # Dict saying how much of simulated datasets for each chromosome do we need
    # E.g., {'chr1': 4, 'chr6': 2, 'chrX': 4}

    # print(f'SETUP::simulate')
    # if 'vendor' not in os.listdir():
    #     os.mkdir('vendor')
    # if 'seqrequester' not in os.listdir('vendor'):
    #     print(f'SETUP::simulate:: Download seqrequester')
    #     subprocess.run(f'git clone https://github.com/marbl/seqrequester.git', shell=True, cwd='vendor')
    #     subprocess.run(f'make', shell=True, cwd='vendor/seqrequester/src')

    data_path = os.path.abspath(data_path)
    chr_path = os.path.join(ref_path, 'chromosomes')
    len_path = os.path.join(ref_path, 'lengths')
    sim_path = os.path.join(data_path, 'simulated')
    for chrN, n_need in chr_dict.items():
        if '_r' in chrN:
            continue
        chr_raw_path = os.path.join(sim_path, f'{chrN}/raw')
        n_have = len(os.listdir(chr_raw_path))
        if n_need <= n_have:
            continue
        else:
            n_diff = n_need - n_have

            # print(f'SETUP::simulate:: Simulate {n_diff} datasets for {chrN}')
            # # Simulate reads for chrN n_diff times
            chr_seq_path = os.path.join(chr_path, f'{chrN}.fasta')
            chr_dist_path = os.path.join(len_path, f'{chrN}.txt')
            chr_len = chr_lens[chrN]
            for i in range(n_diff):
                idx = n_have + i
                chr_save_path = os.path.join(chr_raw_path, f'{idx}.fasta')
                # print(f'\nStep {i}: Simulating reads {chr_save_path}')
                # subprocess.run(f'./vendor/seqrequester/build/bin/seqrequester simulate -genome {chr_seq_path} '
                #                f'-genomesize {chr_len} -coverage 40 -distribution {chr_dist_path} > {chr_save_path}',
                #                shell=True)
                change_description(chr_save_path)

# 2. Generate the graphs
def generate_graphs(data_path, chr_dict):
    print(f'SETUP::generate')

    if 'raven' not in os.listdir('vendor'):
        print(f'SETUP::generate:: Download Raven')
        subprocess.run(f'git clone -b print_graphs https://github.com/lbcb-sci/raven', shell=True, cwd='vendor')
        subprocess.run(f'cmake -S ./ -B./build -DRAVEN_BUILD_EXE=1 -DCMAKE_BUILD_TYPE=Release', shell=True, cwd='vendor/raven')
        subprocess.run(f'cmake --build build', shell=True, cwd='vendor/raven')

    data_path = os.path.abspath(data_path)

    for chrN, n_need in chr_dict.items():
        if '_r' in chrN:
            continue
        chr_sim_path = os.path.join(data_path, 'simulated', f'{chrN}')
        chr_raw_path = os.path.join(chr_sim_path, 'raw')
        chr_prc_path = os.path.join(chr_sim_path, 'processed')
        n_raw = len(os.listdir(chr_raw_path))
        n_prc = len(os.listdir(chr_prc_path))
        n_diff = n_raw - n_prc
        print(f'SETUP::generate:: Generate {n_diff} graphs for {chrN}')
        specs = {
            'threads': 32,
            'filter': 0.99,
            'out': 'assembly.fasta'
        }
        graph_dataset.AssemblyGraphDataset(chr_sim_path, nb_pos_enc=None, specs=specs, generate=True)

# 2.1. Generate the real_graphs
def generate_graphs_real(data_path, chr_real_list):
    print(f'SETUP::generate')

    if 'raven' not in os.listdir('vendor'):
        print(f'SETUP::generate:: Download Raven')
        subprocess.run(f'git clone -b print_graphs https://github.com/lbcb-sci/raven', shell=True, cwd='vendor')
        subprocess.run(f'cmake -S ./ -B./build -DRAVEN_BUILD_EXE=1 -DCMAKE_BUILD_TYPE=Release', shell=True, cwd='vendor/raven')
        subprocess.run(f'cmake --build build', shell=True, cwd='vendor/raven')

    data_path = os.path.abspath(data_path)
    for chrN in chr_real_list:
        chr_sim_path = os.path.abspath(data_path, 'real', f'{chrN}')
        chr_raw_path = os.path.join(chr_sim_path, 'raw')
        chr_prc_path = os.path.join(chr_sim_path, 'processed')
        n_raw = len(os.listdir(chr_raw_path))
        n_prc = len(os.listdir(chr_prc_path))
        n_diff = n_raw - n_prc
        print(f'SETUP::generate:: Generate {n_diff} graphs for {chrN}')
        specs = {
            'threads': 32,
            'filter': 0.99,
            'out': 'assembly.fasta'
        }
        graph_dataset.AssemblyGraphDataset(chr_sim_path, nb_pos_enc=None, specs=specs, generate=True)

# 2.5 Train-valid-test split
def train_valid_split(data_path, train_dict, valid_dict, test_dict={}, out=None):
    print(f'SETUP::split')
    data_path = os.path.abspath(data_path)
    sim_path = os.path.join(data_path, 'simulated')
    real_path = os.path.join(data_path, 'real')
    exp_path = os.path.join(data_path, 'experiments')

    if out is None:
        train_path = os.path.join(exp_path, f'train')
        valid_path = os.path.join(exp_path, f'valid')
        test_path  = os.path.join(exp_path, f'test')
    else:
        train_path = os.path.join(exp_path, f'train_{out}')
        valid_path = os.path.join(exp_path, f'valid_{out}')
        test_path  = os.path.join(exp_path, f'test_{out}')
    if not os.path.isdir(train_path):
        os.makedirs(train_path)
        subprocess.run(f'mkdir raw processed info', shell=True, cwd=train_path)
    if not os.path.isdir(valid_path):
        os.makedirs(valid_path)
        subprocess.run(f'mkdir raw processed info', shell=True, cwd=valid_path)
    if not os.path.isdir(test_path) and len(test_dict) > 0:
        os.makedirs(test_path)
        subprocess.run(f'mkdir raw processed info', shell=True, cwd=test_path)

    train_g_to_chr = {}  # Remember chromosomes for each graph in the dataset
    train_g_to_org_g = {}  # Remember index of the graph in the master dataset for each graph in this dataset
    n_have = 0
    for chrN, n_need in train_dict.items():
        # copy n_need datasets from chrN into train dict
        print(f'SETUP::split:: Copying {n_need} graphs of {chrN} into {train_path}')
        for i in range(n_need):
            train_g_to_chr[n_have] = chrN
            chr_sim_path = os.path.join(sim_path, chrN)
            print(f'Copying {chr_sim_path}/processed/{i}.dgl into {train_path}/processed/{n_have}.dgl')
            subprocess.run(f'cp {chr_sim_path}/processed/{i}.dgl {train_path}/processed/{n_have}.dgl', shell=True)
            subprocess.run(f'cp {chr_sim_path}/info/{i}_succ.pkl {train_path}/info/{n_have}_succ.pkl', shell=True)
            subprocess.run(f'cp {chr_sim_path}/info/{i}_pred.pkl {train_path}/info/{n_have}_pred.pkl', shell=True)
            subprocess.run(f'cp {chr_sim_path}/info/{i}_edges.pkl {train_path}/info/{n_have}_edges.pkl', shell=True)
            subprocess.run(f'cp {chr_sim_path}/info/{i}_reads.pkl {train_path}/info/{n_have}_reads.pkl', shell=True)
            train_g_to_org_g[n_have] = i
            n_have += 1
    pickle.dump(train_g_to_chr, open(f'{train_path}/info/g_to_chr.pkl', 'wb'))
    pickle.dump(train_g_to_org_g, open(f'{train_path}/info/g_to_org_g.pkl', 'wb'))

    valid_g_to_chr = {}
    valid_g_to_org_g = {}
    n_have = 0
    for chrN, n_need in valid_dict.items():
        # copy n_need datasets from chrN into train dict
        print(f'SETUP::split:: Copying {n_need} graphs of {chrN} into {valid_path}')
        for i in range(n_need):
            valid_g_to_chr[n_have] = chrN
            j = i + train_dict.get(chrN, 0)
            chr_sim_path = os.path.join(sim_path, chrN)
            print(f'Copying {chr_sim_path}/processed/{j}.dgl into {valid_path}/processed/{n_have}.dgl')
            subprocess.run(f'cp {chr_sim_path}/processed/{j}.dgl {valid_path}/processed/{n_have}.dgl', shell=True)
            subprocess.run(f'cp {chr_sim_path}/info/{j}_succ.pkl {valid_path}/info/{n_have}_succ.pkl', shell=True)
            subprocess.run(f'cp {chr_sim_path}/info/{j}_pred.pkl {valid_path}/info/{n_have}_pred.pkl', shell=True)
            subprocess.run(f'cp {chr_sim_path}/info/{j}_edges.pkl {valid_path}/info/{n_have}_edges.pkl', shell=True)
            subprocess.run(f'cp {chr_sim_path}/info/{j}_reads.pkl {valid_path}/info/{n_have}_reads.pkl', shell=True)
            valid_g_to_org_g[n_have] = j
            n_have += 1
    pickle.dump(valid_g_to_chr, open(f'{valid_path}/info/g_to_chr.pkl', 'wb'))
    pickle.dump(valid_g_to_org_g, open(f'{valid_path}/info/g_to_org_g.pkl', 'wb'))

    if test_dict:
        test_g_to_chr = {}
        test_g_to_org_g = {}
        n_have = 0
        for chrN, n_need in test_dict.items():
            # copy n_need datasets from chrN into train dict
            if '_r' in chrN and n_need > 1:
                print(f'SETUP::split::WARNING Cannot copy more than one graph for real data: {chrN}')
                n_need = 1
            print(f'SETUP::split:: Copying {n_need} graphs of {chrN} into {test_path}')
            for i in range(n_need):
                if '_r' in chrN:
                    chrN = chrN[:-2]
                    chr_sim_path = os.path.join(real_path, chrN)
                    k = 0
                else:
                    chr_sim_path = os.path.join(sim_path, chrN)
                    k = i + train_dict.get(chrN, 0) + valid_dict.get(chrN, 0)
                test_g_to_chr[n_have] = chrN
                print(f'Copying {chr_sim_path}/processed/{k}.dgl into {test_path}/processed/{n_have}.dgl')
                subprocess.run(f'cp {chr_sim_path}/processed/{k}.dgl {test_path}/processed/{n_have}.dgl', shell=True)
                subprocess.run(f'cp {chr_sim_path}/info/{k}_succ.pkl {test_path}/info/{n_have}_succ.pkl', shell=True)
                subprocess.run(f'cp {chr_sim_path}/info/{k}_pred.pkl {test_path}/info/{n_have}_pred.pkl', shell=True)
                subprocess.run(f'cp {chr_sim_path}/info/{k}_edges.pkl {test_path}/info/{n_have}_edges.pkl', shell=True)
                subprocess.run(f'cp {chr_sim_path}/info/{k}_reads.pkl {test_path}/info/{n_have}_reads.pkl', shell=True)
                n_have += 1
                test_g_to_org_g[n_have] = k
        pickle.dump(test_g_to_chr, open(f'{test_path}/info/g_to_chr.pkl', 'wb'))
        pickle.dump(test_g_to_org_g, open(f'{test_path}/info/g_to_org_g.pkl', 'wb'))

    return train_path, valid_path, test_path

# 3. Train the model
def train_model(train_path, valid_path, out, overfit):
    print(f'SETUP::train')
    train.train(train_path, valid_path, out, overfit)

# 4. Inference - get the results
def predict(test_path, out, model_path=None, device='cpu'):
    if model_path is None:
        model_path = os.path.abspath(f'pretrained/model_{out}.pt')
    walks_per_graph, contigs_per_graph = inference.inference(test_path, model_path, device)
    g_to_chr = pickle.load(open(f'{test_path}/info/g_to_chr.pkl', 'rb'))

    for idx, contigs in enumerate(contigs_per_graph):
        chrN = g_to_chr[idx]
        num_contigs, longest_contig, reconstructed, n50, ng50 = evaluate.quick_evaluation(contigs, chrN)
        evaluate.print_summary(test_path, idx, chrN, num_contigs, longest_contig, reconstructed, n50, ng50)

def predict_baselines(test_path, out, model_path=None, device='cpu'):
    if model_path is None:
        model_path = os.path.abspath(f'pretrained/model_{out}.pt')
    walks_and_contigs = inference.inferencei_baselines(test_path, model_path, device)
    walks_per_graph, contigs_per_graph = walks_and_contigs[0], walks_and_contigs[1]
    walks_per_graph_ol_len, contigs_per_graph_ol_len = walks_and_contigs[2], walks_and_contigs[3]
    walks_per_graph_ol_sim, contigs_per_graph_ol_sim = walks_and_contigs[4], walks_and_contigs[5]
    g_to_chr = pickle.load(open(f'{test_path}/info/g_to_chr.pkl', 'rb'))

    for idx, (contigs, contigs_ol_len, contigs_ol_sim) in enumerate(zip(contigs_per_graph, contigs_per_graph_ol_len, contigs_per_graph_ol_sim)):
        chrN = g_to_chr[idx]
        print(f'GNN: Scores')
        num_contigs, longest_contig, reconstructed, n50, ng50 = evaluate.quick_evaluation(contigs, chrN)
        evaluate.print_summary(test_path, idx, chrN, num_contigs, longest_contig, reconstructed, n50, ng50)
        print(f'Baseline: Overlap lengths')
        num_contigs, longest_contig, reconstructed, n50, ng50 = evaluate.quick_evaluation(contigs_ol_len, chrN)
        evaluate.print_summary(test_path, idx, chrN, num_contigs, longest_contig, reconstructed, n50, ng50)
        print(f'Baseline: Overlap similarities')
        num_contigs, longest_contig, reconstructed, n50, ng50 = evaluate.quick_evaluation(contigs_ol_sim, chrN)
        evaluate.print_summary(test_path, idx, chrN, num_contigs, longest_contig, reconstructed, n50, ng50)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, default='data', help='Path to directory with simulated and real data')
    parser.add_argument('--refs', type=str, default='data/references', help='Path to directory with reference information')
    parser.add_argument('--out', type=str, default=None, help='Output name for figures and models')
    parser.add_argument('--overfit', action='store_true', default=False, help='Overfit on the chromosomes in the train directory')
    args = parser.parse_args()

    data_path = args.data
    ref_path = args.refs
    out = args.out
    overfit = args.overfit

    time_start = datetime.now()
    timestamp = time_start.strftime('%Y-%b-%d-%H-%M-%S')
    if out is None:
        out = timestamp

    dicts = config.get_config()
    train_dict = dicts['train_dict']
    valid_dict = dicts['valid_dict']
    test_dict = dicts['test_dict']

    all_chr = merge_dicts(train_dict, valid_dict, test_dict)

    # file_structure_setup(data_path, ref_path)
    # download_reference(ref_path)
    # simulate_reads(data_path, ref_path, all_chr)
    generate_graphs(data_path, all_chr)
    train_path, valid_path, test_path = train_valid_split(data_path, train_dict, valid_dict, test_dict, out)
    train_model(train_path, valid_path, out, overfit)
    predict(test_path, out, device='cpu')
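`change_description()` slices `des[0][5:]` and `des[2][9:]`, which suggests comma-separated read headers whose first and third fields are `key=value` pairs such as `read=12` and `position=1000-21000`; that header layout is inferred from the slice offsets, not taken from seqrequester's documentation. A minimal sketch that pulls the string handling into a hypothetical pure function, so it can be checked on a single header before any FASTA file is rewritten:

```python
from typing import Tuple


def rewrite_description(description: str) -> Tuple[str, str]:
    """Return (read_id, new_description) for one simulated read header.

    Mirrors change_description(): field 0 carries the read id, field 1 the
    orientation ('forward' or anything else), field 2 a 'start-end' position.
    Assumes fields 0 and 2 are 'key=value' pairs (e.g. 'read=12', 'position=1000-21000').
    """
    fields = description.split(',')
    read_id = fields[0].partition('=')[2]
    strand = '+' if fields[1] == 'forward' else '-'
    start, end = fields[2].partition('=')[2].split('-')
    return read_id, f'strand={strand}, start={start}, end={end}'
```

One side effect of `change_description()` as listed: it parses with the format taken from the file extension (`file_path[-5:]`) but always writes `"fasta"` back to the same path, so a `.fastq` input would end up holding FASTA records under a `.fastq` name.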

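`train_valid_split()` decides which simulated graph goes where purely through index offsets: valid graphs start after the train graphs of the same chromosome, test graphs after train plus valid, and a real (`_r`) chromosome always contributes graph 0 from `data/real/<chr>`. A minimal sketch that reproduces only that bookkeeping, without copying files (function names hypothetical). With the config from this issue it yields chr1 graphs 0 to 2 for train, 3 and 4 for valid (matching the copy log), and graph 0 of `data/real/chr8` for test, which is exactly the file that did not exist yet. Building both pickled maps from one `enumerate` also keeps their keys aligned; in the listing, `test_g_to_org_g[n_have] = k` runs after `n_have += 1`, so its keys start at 1 while those of `test_g_to_chr` start at 0.

```python
from typing import Dict, List, Tuple

Plan = Dict[str, List[Tuple[str, int]]]


def plan_split(train: Dict[str, int], valid: Dict[str, int], test: Dict[str, int]) -> Plan:
    """List (chromosome, source graph index) pairs per split, in copy order."""
    plan: Plan = {'train': [], 'valid': [], 'test': []}
    for chr_name, n in train.items():
        plan['train'] += [(chr_name, i) for i in range(n)]
    for chr_name, n in valid.items():
        offset = train.get(chr_name, 0)  # skip the graphs already used for training
        plan['valid'] += [(chr_name, offset + i) for i in range(n)]
    for chr_name, n in test.items():
        if chr_name.endswith('_r'):
            plan['test'].append((chr_name[:-2], 0))  # only one real graph per chromosome
        else:
            offset = train.get(chr_name, 0) + valid.get(chr_name, 0)
            plan['test'] += [(chr_name, offset + i) for i in range(n)]
    return plan


def index_maps(pairs: List[Tuple[str, int]]) -> Tuple[Dict[int, str], Dict[int, int]]:
    """Build g_to_chr and g_to_org_g from the same enumeration so their keys always match."""
    g_to_chr = {new_idx: chr_name for new_idx, (chr_name, _) in enumerate(pairs)}
    g_to_org_g = {new_idx: org_idx for new_idx, (_, org_idx) in enumerate(pairs)}
    return g_to_chr, g_to_org_g
```

Checking the planned test entries for existence (for example `data/real/chr8/processed/0.dgl`) before any copying starts would turn the `cp: cannot stat` messages into a single early error.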