[DOC] Update mnist.py example (OPEN, 4 comments)

orion160 commented on August 25, 2024
[DOC] Update mnist.py example


Comments (4)

orion160 commented on August 25, 2024

Proposal:

diff --git a/mnist/main.py b/mnist/main.py
index 184dc47..a3cffd1 100644
--- a/mnist/main.py
+++ b/mnist/main.py
@@ -3,13 +3,14 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from torchvision import datasets, transforms
+from torchvision import datasets
+from torchvision.transforms import v2 as transforms
 from torch.optim.lr_scheduler import StepLR
 
 
 class Net(nn.Module):
     def __init__(self):
-        super(Net, self).__init__()
+        super().__init__()
         self.conv1 = nn.Conv2d(1, 32, 3, 1)
         self.conv2 = nn.Conv2d(32, 64, 3, 1)
         self.dropout1 = nn.Dropout(0.25)
@@ -33,19 +34,42 @@ class Net(nn.Module):
         return output
 
 
-def train(args, model, device, train_loader, optimizer, epoch):
+def train_amp(args, model, device, train_loader, opt, epoch, scaler):
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
-        data, target = data.to(device), target.to(device)
-        optimizer.zero_grad()
+        data, target = data.to(device, memory_format=torch.channels_last), target.to(
+            device
+        )
+        opt.zero_grad()
+        with torch.autocast(device_type=device.type):
+            output = model(data)
+            loss = F.nll_loss(output, target)
+        scaler.scale(loss).backward()
+        scaler.step(opt)
+        scaler.update()
+        if batch_idx % args.log_interval == 0:
+            print(
+                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100.0 * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
+            )
+            if args.dry_run:
+                break
+
+
+def train(args, model, device, train_loader, opt, epoch):
+    model.train()
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(device, memory_format=torch.channels_last), target.to(
+            device
+        )
+        opt.zero_grad()
         output = model(data)
         loss = F.nll_loss(output, target)
         loss.backward()
-        optimizer.step()
+        opt.step()
         if batch_idx % args.log_interval == 0:
-            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                epoch, batch_idx * len(data), len(train_loader.dataset),
-                100. * batch_idx / len(train_loader), loss.item()))
+            print(
+                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100.0 * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
+            )
             if args.dry_run:
                 break
 
@@ -58,43 +82,125 @@ def test(model, device, test_loader):
         for data, target in test_loader:
             data, target = data.to(device), target.to(device)
             output = model(data)
-            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
-            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
+            test_loss += F.nll_loss(
+                output, target, reduction="sum"
+            ).item()  # sum up batch loss
+            pred = output.argmax(
+                dim=1, keepdim=True
+            )  # get the index of the max log-probability
             correct += pred.eq(target.view_as(pred)).sum().item()
 
     test_loss /= len(test_loader.dataset)
 
-    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
-        test_loss, correct, len(test_loader.dataset),
-        100. * correct / len(test_loader.dataset)))
+    print(
+        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100.0 * correct / len(test_loader.dataset):.0f}%)\n"
+    )
 
 
-def main():
+def parse_args():
     # Training settings
-    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
-    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
-                        help='input batch size for training (default: 64)')
-    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
-                        help='input batch size for testing (default: 1000)')
-    parser.add_argument('--epochs', type=int, default=14, metavar='N',
-                        help='number of epochs to train (default: 14)')
-    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
-                        help='learning rate (default: 1.0)')
-    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
-                        help='Learning rate step gamma (default: 0.7)')
-    parser.add_argument('--no-cuda', action='store_true', default=False,
-                        help='disables CUDA training')
-    parser.add_argument('--no-mps', action='store_true', default=False,
-                        help='disables macOS GPU training')
-    parser.add_argument('--dry-run', action='store_true', default=False,
-                        help='quickly check a single pass')
-    parser.add_argument('--seed', type=int, default=1, metavar='S',
-                        help='random seed (default: 1)')
-    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
-                        help='how many batches to wait before logging training status')
-    parser.add_argument('--save-model', action='store_true', default=False,
-                        help='For Saving the current Model')
+    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=64,
+        metavar="N",
+        help="input batch size for training (default: 64)",
+    )
+    parser.add_argument(
+        "--test-batch-size",
+        type=int,
+        default=1000,
+        metavar="N",
+        help="input batch size for testing (default: 1000)",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=14,
+        metavar="N",
+        help="number of epochs to train (default: 14)",
+    )
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=1.0,
+        metavar="LR",
+        help="learning rate (default: 1.0)",
+    )
+    parser.add_argument(
+        "--gamma",
+        type=float,
+        default=0.7,
+        metavar="M",
+        help="Learning rate step gamma (default: 0.7)",
+    )
+    parser.add_argument(
+        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
+    )
+    parser.add_argument(
+        "--no-mps",
+        action="store_true",
+        default=False,
+        help="disables macOS GPU training",
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        default=False,
+        help="quickly check a single pass",
+    )
+    parser.add_argument(
+        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
+    )
+    parser.add_argument(
+        "--log-interval",
+        type=int,
+        default=10,
+        metavar="N",
+        help="how many batches to wait before logging training status",
+    )
+    parser.add_argument(
+        "--use-amp",
+        action="store_true",
+        default=False,
+        help="use automatic mixed precision",
+    )
+    parser.add_argument(
+        "--compile-backend",
+        type=str,
+        default="inductor",
+        metavar="BACKEND",
+        help="backend to compile the model with",
+    )
+    parser.add_argument(
+        "--compile-mode",
+        type=str,
+        default="default",
+        metavar="MODE",
+        help="compilation mode",
+    )
+    parser.add_argument(
+        "--save-model",
+        action="store_true",
+        default=False,
+        help="For Saving the current Model",
+    )
+    parser.add_argument(
+        "--data-dir",
+        type=str,
+        default="../data",
+        metavar="DIR",
+        help="path to the data directory",
+    )
     args = parser.parse_args()
+
+    return args
+
+
+def main():
+    args = parse_args()
+
     use_cuda = not args.no_cuda and torch.cuda.is_available()
     use_mps = not args.no_mps and torch.backends.mps.is_available()
 
@@ -107,32 +213,43 @@ def main():
     else:
         device = torch.device("cpu")
 
-    train_kwargs = {'batch_size': args.batch_size}
-    test_kwargs = {'batch_size': args.test_batch_size}
+    train_kwargs = {"batch_size": args.batch_size}
+    test_kwargs = {"batch_size": args.test_batch_size}
     if use_cuda:
-        cuda_kwargs = {'num_workers': 1,
-                       'pin_memory': True,
-                       'shuffle': True}
+        cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
         train_kwargs.update(cuda_kwargs)
         test_kwargs.update(cuda_kwargs)
 
-    transform=transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize((0.1307,), (0.3081,))
-        ])
-    dataset1 = datasets.MNIST('../data', train=True, download=True,
-                       transform=transform)
-    dataset2 = datasets.MNIST('../data', train=False,
-                       transform=transform)
-    train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
+    transform = transforms.Compose(
+        [
+            transforms.ToImage(),
+            transforms.ToDtype(torch.float32, scale=True),
+            transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
+        ]
+    )
+
+    data_dir = args.data_dir
+
+    dataset1 = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
+    dataset2 = datasets.MNIST(data_dir, train=False, transform=transform)
+    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
     test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
 
-    model = Net().to(device)
-    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
+    model = Net().to(device, memory_format=torch.channels_last)
+    model = torch.compile(model, backend=args.compile_backend, mode=args.compile_mode)
+    optimizer = optim.Adadelta(model.parameters(), lr=torch.tensor(args.lr))
 
     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+
+    scaler = None
+    if use_cuda and args.use_amp:
+        scaler = torch.amp.GradScaler(device=device.type)
+
     for epoch in range(1, args.epochs + 1):
-        train(args, model, device, train_loader, optimizer, epoch)
+        if scaler is None:
+            train(args, model, device, train_loader, optimizer, epoch)
+        else:
+            train_amp(args, model, device, train_loader, optimizer, epoch, scaler)
         test(model, device, test_loader)
         scheduler.step()
 
@@ -140,5 +257,5 @@ def main():
         torch.save(model.state_dict(), "mnist_cnn.pt")
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
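
The transforms change swaps `ToTensor()` for its v2 equivalent: `ToImage()` followed by `ToDtype(torch.float32, scale=True)`. A quick self-check of that equivalence, as a minimal sketch (assumes torchvision >= 0.16, which ships `transforms.v2`; the input image here is synthetic):

```python
import numpy as np
import torch
from PIL import Image
from torchvision import transforms as t1
from torchvision.transforms import v2 as t2

# Synthetic 28x28 grayscale image standing in for an MNIST digit.
img = Image.fromarray(np.random.randint(0, 256, (28, 28), dtype=np.uint8))

old = t1.Compose([t1.ToTensor(), t1.Normalize((0.1307,), (0.3081,))])
new = t2.Compose([
    t2.ToImage(),                           # PIL image -> uint8 CHW tensor
    t2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
    t2.Normalize(mean=(0.1307,), std=(0.3081,)),
])

assert torch.allclose(old(img), new(img))
```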

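For reference, the AMP path in `train_amp` reduces to the pattern below, shown as a self-contained sketch on a toy conv layer (the module, data, and loss are placeholders rather than the example's `Net`; assumes a CUDA build of PyTorch 2.x):

```python
import torch
import torch.nn as nn

device = torch.device("cuda")

# channels_last keeps NCHW tensors in NHWC memory order, which cuDNN prefers.
model = nn.Conv2d(1, 8, 3).to(device, memory_format=torch.channels_last)
opt = torch.optim.Adadelta(model.parameters(), lr=1.0)
scaler = torch.amp.GradScaler(device=device.type)

data = torch.randn(64, 1, 28, 28, device=device).to(memory_format=torch.channels_last)

opt.zero_grad()
with torch.autocast(device_type=device.type):  # ops run in half precision where safe
    loss = model(data).float().pow(2).mean()   # toy loss standing in for F.nll_loss
scaler.scale(loss).backward()                  # scale up to avoid fp16 gradient underflow
scaler.step(opt)                               # unscales grads; skips the step on inf/NaN
scaler.update()                                # adjusts the scale factor for the next step
```

With the patch applied, this path is enabled with `python main.py --use-amp`.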
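
One caveat on the `torch.compile` line: the compiled wrapper prefixes `state_dict` keys with `_orig_mod.`, so the checkpoint written at the end would not load directly into a plain `Net()`. A sketch of one way around it (stand-in module; assumes PyTorch 2.x):

```python
import torch
import torch.nn as nn

model = torch.compile(nn.Linear(4, 2), backend="inductor", mode="default")

# Saving the underlying module keeps the original parameter key names.
torch.save(model._orig_mod.state_dict(), "mnist_cnn.pt")
```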

orion160 commented on August 25, 2024

CC @svekars

