kaslanarian / pydynet


A NumPy implementation of a PyTorch-like dynamic computational graph and neural networks (DNN, CNN, RNN)

License: MIT License

Python 100.00%
autograd cnn pytorch-implementation rnn numpy pytorch cuda cupy python deep-learning-framework

pydynet's Introduction

😄 Hi! Nai elen siluva lyenna.

In Quenya, "Nai elen siluva lyenna" means "May the stars shine on you".

Who am I?

Welt Xing, a first-year graduate student in the School of Artificial Intelligence, Nanjing University. In German, "Welt" means "world". My interests include:

  • Machine learning models and their implementation.
  • Deep learning methods and frameworks.
  • Machine learning theory.
  • Graph learning.

In my blog, I record what I've learned:

My stats

Your stars, forks, and issues are most welcome and appreciated 🥳.

Some finished and ongoing work

  • libsvm-sc-reading: A Chinese-language manual explaining how LIBSVM works.
  • PySVM: A NumPy implementation of SVM based on the SMO algorithm.
  • PyDyNet: A deep learning framework implemented in NumPy, based on automatic differentiation.
  • SAGOD: A library for anomaly detection on static attributed graphs.

pydynet's People

Contributors

kaslanarian

pydynet's Issues

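Some bugs related to GPU acceleration

Hello, I have recently been studying how dynamic computational graphs are implemented. When running your code on a GPU, I found that the accuracy in the cuCNN example does not change across iterations (it works correctly on CPU). The network parameters are in fact not changing, possibly because they were never moved to the GPU. After I added .to('cuda') to every layer, the parameters were updated and the accuracy behaved normally. However, GPU memory now keeps growing as training proceeds (before this change, memory usage was normal and stayed constant). I do not know the exact cause. Could you take a look?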

import numpy as np
from pydynet.tensor import Tensor
import pydynet.nn.functional as F
import pydynet.nn as nn
from pydynet.optim import Adam, SGD
from pydynet.data import DataLoader, Dataset
from tqdm import tqdm


dev = ['cpu', 'cuda'][1]
np.random.seed(42)



from scipy.io import loadmat
data = loadmat('../mnist_uint8.mat')
train_x = np.reshape(data['train_x'], (60000, 1, 28, 28)) / 255.0
train_y = data['train_y']
test_x = np.reshape(data['test_x'], (10000, 1, 28, 28)) / 255.0
test_y = data['test_y']



class mnist_dataset(Dataset):
    def __init__(self, X, y) -> None:
        super().__init__()
        self.data = X
        self.label = y

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return len(self.data)

train_loader = DataLoader(mnist_dataset(train_x, train_y), 32, True)
test_loader = DataLoader(mnist_dataset(test_x, test_y), 32, False)




class CNN2d(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # With .to(dev) added to each layer the parameters update correctly, but GPU memory keeps growing
        self.conv1 = nn.Conv2d(1, 1, 3, padding=1).to(dev)
        self.fc1 = nn.Linear(49, 128).to(dev)
        self.fc2 = nn.Linear(128, 10).to(dev)

    def forward(self, x):
        x = self.conv1(x)
        x = F.max_pool2d(x, 4, 4)
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        x = F.leaky_relu(x, 0.1)
        return self.fc2(x)


net3 = CNN2d().to(dev)

optim3 = Adam(net3.parameters(), lr=0.01)
loss = nn.CrossEntropyLoss().to(dev)
EPOCHES = 50
BATCH_SIZE = 32


from time import time


t = time()
for epoch in range(EPOCHES):

    net3.train()
    train_out = []
    for batch_X, batch_y in tqdm(train_loader):
        batch_X, batch_y = Tensor(batch_X).to(dev), Tensor(batch_y).to(dev)
        # print(data)
        output3 = net3(batch_X)
        l3 = loss(output3, batch_y)
        optim3.zero_grad()
        l3.backward()
        optim3.step()

        acc = np.argmax(output3.numpy(), axis=1) == np.argmax(batch_y.numpy(), axis=1)
        train_out.append(acc)
        # mp.free_all_blocks()
        # pmp.free_all_blocks()
    train_out = np.concatenate(train_out)
    train_out = np.mean(train_out)

    net3.eval()
    test_out = []
    # test_label
    for batch_X, batch_y in tqdm(test_loader):
        node_y = Tensor(batch_y).to(dev)

        data = Tensor(batch_X).to(dev)
        # print(data)
        output3 = net3(data)
        l3 = loss(output3, node_y)
        # Use a separate name so the start time stored in `t` is not overwritten
        pred = output3.numpy()
        acc = np.argmax(pred, axis=1) == np.argmax(batch_y, axis=1)
        test_out.append(acc)
        # del data
    test_out = np.concatenate(test_out)
    test_out = np.mean(test_out)


    print("Epoch {:2d}:".format(epoch + 1))

    print('train acc: {}, test acc: {}'.format(train_out, test_out))
print(time() - t)
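
The first symptom above (parameters that never change on GPU) is what one would expect if moving the model does not reach the parameters of nested layers, or if the move replaces parameter objects after the optimizer has already captured the old ones. The following is only a minimal sketch of that idea, not PyDyNet's actual implementation: `Param`, `Module`, `Linear`, `Net`, and `devices` are hypothetical stand-ins, and the "device" is just a string tag rather than a real array transfer.

```python
import numpy as np


class Param:
    """Hypothetical parameter: an array plus a device tag."""

    def __init__(self, data, device="cpu"):
        self.data = np.asarray(data, dtype=float)
        self.device = device


class Module:
    def parameters(self):
        """Yield every Param held directly or by nested sub-modules."""
        for value in vars(self).values():
            if isinstance(value, Param):
                yield value
            elif isinstance(value, Module):
                yield from value.parameters()

    def to(self, device):
        """Move all parameters in place so existing references stay valid."""
        for p in self.parameters():
            # A real backend would also copy p.data to the device here.
            p.device = device
        return self


class Linear(Module):
    def __init__(self, n_in, n_out):
        self.weight = Param(np.zeros((n_in, n_out)))
        self.bias = Param(np.zeros(n_out))


class Net(Module):
    def __init__(self):
        self.fc1 = Linear(49, 128)
        self.fc2 = Linear(128, 10)


def devices(module):
    """Return the set of device tags found among a module's parameters."""
    return {p.device for p in module.parameters()}


net = Net()
params_before = [id(p) for p in net.parameters()]
net.to("cuda")
print(devices(net))  # {'cuda'}
```

A check like `devices(net)` run just before the optimizer is constructed would show whether a single `.to(dev)` on the top-level model reached every layer. Because the move here happens in place, an optimizer built from `net.parameters()` before the move would still hold the same objects.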

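For the second symptom (memory that grows with every iteration), a first step is usually to measure after each training step and confirm that the growth is steady rather than a one-time allocation. The sketch below is an illustration under assumptions, not a diagnosis of PyDyNet: `track_growth`, `grows_steadily`, `leaky_step`, and `retained_bytes` are hypothetical names, and the "leak" is simulated in plain Python by a list that retains one buffer per step. If the CUDA backend allocates through CuPy's default memory pool (an assumption suggested by the commented-out `free_all_blocks()` calls), `cupy.get_default_memory_pool().used_bytes` could be passed as `read_bytes` instead.

```python
def track_growth(read_bytes, step, n_steps):
    """Run `step` n_steps times and record the bytes in use after each call."""
    readings = []
    for _ in range(n_steps):
        step()
        readings.append(read_bytes())
    return readings


def grows_steadily(readings):
    """True if every reading is strictly larger than the one before it."""
    return all(later > earlier for earlier, later in zip(readings, readings[1:]))


# Stand-in for a leaking training step: a list keeps one buffer alive per call.
retained = []


def leaky_step():
    retained.append(bytearray(1024))


def retained_bytes():
    return sum(len(buf) for buf in retained)


readings = track_growth(retained_bytes, leaky_step, 5)
print(readings)                  # [1024, 2048, 3072, 4096, 5120]
print(grows_steadily(readings))  # True
```

Steady growth of this kind generally means that something created in each iteration is still referenced afterwards. Calling `free_all_blocks()` on a pool only returns blocks that are no longer in use, so it would not help in that situation.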
