
save time snapshots · deepwave · CLOSED

ar4 commented on August 29, 2024
save time snapshots


Comments (10)


pavane commented on August 29, 2024

Thank you so much. I would like to help with testing 3D boundary conditions and maybe improving that piece of this code.
Please help me get started with the compiled propagator.


ar4 commented on August 29, 2024

I have written some code to call the compiled propagator directly so that we can access the wavefields at arbitrary time steps:

import torch
import numpy as np
import scipy.signal
import deepwave
import deepwave.base.propagator
from deepwave.scalar import scalar

class SteppingPropagator(deepwave.base.propagator.Propagator):
    """PyTorch Module for scalar wave propagator.

    See deepwave.base.propagator.Propagator for description.
    """

    def __init__(self, model, dx,
                 source_amplitudes, source_locations, receiver_locations, dt,
                 pml_width=None, survey_pad=None, vpmax=None):
        if list(model.keys()) != ["vp"]:
            raise RuntimeError(
                "Model must only contain vp, but contains {}".format(
                    list(model.keys())
                )
            )
        super(SteppingPropagator, self).__init__(
            SteppingPropagatorFunction,
            model,
            dx,
            fd_width=4,  # also in Pml
            pml_width=pml_width,
            survey_pad=survey_pad,
        )
        self.model.extra_info["vpmax"] = vpmax
        if model["vp"].min() <= 0.0:
            raise RuntimeError(
                "vp must be > 0, but min is {}".format(model["vp"].min())
            )
            
        (source_amplitudes,
         source_locations,
         receiver_locations,
         dt,
         model,
         property_names,
         vp) = self.forward(source_amplitudes, source_locations, receiver_locations, dt)
        
        if property_names != ["vp"]:
            raise RuntimeError(
                "Model must only contain vp, but contains {}".format(
                    property_names
                )
            )
        if vp.min() <= 0.0:
            raise RuntimeError(
                "vp must be > 0, but min is {}".format(vp.min())
            )
        device = model.device
        dtype = model.dtype
        num_steps, num_shots, num_sources_per_shot = source_amplitudes.shape
        num_receivers_per_shot = receiver_locations.shape[1]

        if model.extra_info["vpmax"] is None:
            max_vel = vp.max().item()
        else:
            max_vel = model.extra_info["vpmax"]
        timestep = scalar.Timestep(dt, model.dx, max_vel)
        model.add_properties(
            {
                "vp2dt2": vp ** 2 * timestep.inner_dt ** 2,
                "scaling": 2 / vp ** 3,
            }
        )
        source_model_locations = model.get_locations(source_locations)
        receiver_model_locations = model.get_locations(receiver_locations)
        scalar_wrapper = scalar._select_propagator(model.ndim, vp.dtype, vp.is_cuda)
        wavefield_save_strategy = scalar._set_wavefield_save_strategy(
            False, dt, timestep.inner_dt, scalar_wrapper
        )
        fd1, fd2 = scalar._set_finite_diff_coeffs(model.ndim, model.dx, device, dtype)
        wavefield, saved_wavefields = scalar._allocate_wavefields(
            wavefield_save_strategy,
            scalar_wrapper,
            model,
            num_steps,
            num_shots,
        )
        receiver_amplitudes = torch.zeros(
            num_steps,
            num_shots,
            num_receivers_per_shot,
            device=device,
            dtype=dtype,
        )
        inner_dt = torch.tensor([timestep.inner_dt]).to(dtype)
        pml = scalar.Pml(model, num_shots, max_vel)
        source_amplitudes_resampled = scipy.signal.resample(
            source_amplitudes.detach().cpu().numpy(),
            num_steps * timestep.step_ratio,
        )
        source_amplitudes_resampled = (
            torch.tensor(source_amplitudes_resampled)
            .to(dtype)
            .to(source_amplitudes.device)
        )
        source_amplitudes_resampled.requires_grad = (
            source_amplitudes.requires_grad
        )
        
        self.scalar_wrapper = scalar_wrapper
        self.wavefield = wavefield
        self.pml = pml
        self.receiver_amplitudes = receiver_amplitudes
        self.saved_wavefields = saved_wavefields
        self.model = model
        self.fd1 = fd1
        self.fd2 = fd2
        self.source_amplitudes_resampled = source_amplitudes_resampled
        self.source_model_locations = source_model_locations
        self.receiver_model_locations = receiver_model_locations
        self.inner_dt = inner_dt
        self.timestep = timestep
        self.num_shots = num_shots
        self.num_sources_per_shot = num_sources_per_shot
        self.num_receivers_per_shot = num_receivers_per_shot
        self.wavefield_save_strategy = wavefield_save_strategy
        self.dtype = dtype
        
        self.total_num_steps = num_steps
        self.current_step = 0
        
    def step(self, num_steps):
        
        assert self.current_step + num_steps <= self.total_num_steps
        
        source_amplitudes_resampled_steps = \
            self.source_amplitudes_resampled[self.current_step*self.timestep.step_ratio:
                                             (self.current_step+num_steps)*self.timestep.step_ratio]

        # Call compiled C code to do forward modeling
        self.scalar_wrapper.forward(
            self.wavefield.to(self.dtype).contiguous(),
            self.pml.aux.to(self.dtype).contiguous(),
            self.receiver_amplitudes.to(self.dtype).contiguous(),
            self.saved_wavefields.to(self.dtype).contiguous(),
            self.pml.sigma.to(self.dtype).contiguous(),
            self.model.properties["vp2dt2"].to(self.dtype).contiguous(),
            self.fd1.to(self.dtype).contiguous(),
            self.fd2.to(self.dtype).contiguous(),
            source_amplitudes_resampled_steps.to(self.dtype).contiguous(),
            self.source_model_locations.long().contiguous(),
            self.receiver_model_locations.long().contiguous(),
            self.model.shape.contiguous(),
            self.pml.pml_width.long().contiguous(),
            self.inner_dt,
            num_steps,
            self.timestep.step_ratio,
            self.num_shots,
            self.num_sources_per_shot,
            self.num_receivers_per_shot,
            self.wavefield_save_strategy,
        )
        
        self.current_step += num_steps 
        
        if num_steps * self.timestep.step_ratio % 3 != 0:
            # Swap the wavefield arrays so that they are in the correct order
            wf_idxs = [0, 1, 2]
            for stepidx in range(num_steps * self.timestep.step_ratio):
                wf_idxs = [wf_idxs[2], wf_idxs[0], wf_idxs[1]]
            self.wavefield[0], self.wavefield[1], self.wavefield[2] = \
                (self.wavefield[wf_idxs[0]],
                 self.wavefield[wf_idxs[1]],
                 self.wavefield[wf_idxs[2]])
        
        if num_steps * self.timestep.step_ratio % 2 != 0:
            # Swap the aux arrays so that they are in the correct order
            ndim = self.model.ndim
            if ndim == 1:
                aux_size = 1
            elif ndim == 2:
                aux_size = 2
            else:
                aux_size = 4
            assert len(self.pml.aux) == 2 * aux_size
            self.pml.aux[:aux_size], self.pml.aux[aux_size:] = \
                self.pml.aux[aux_size:], self.pml.aux[:aux_size]
        
        return self.wavefield[1]

        
class SteppingPropagatorFunction(torch.autograd.Function):
    """Forward modeling and backpropagation functions. Not called by users."""

    @staticmethod
    def forward(
        ctx,
        source_amplitudes,
        source_locations,
        receiver_locations,
        dt,
        model,
        property_names,
        vp,
    ):
        return (
            source_amplitudes,
            source_locations,
            receiver_locations,
            dt,
            model,
            property_names,
            vp,
        )

It is a bit hacky: it runs the setup for a regular propagator and then extracts the variables that are passed to the propagator's forward method. It then uses these to run all of the code in the usual forward propagator up to the point where the compiled propagator gets called, and saves the arguments so that they can be used when you actually want to run forward time steps. The benefit of doing all of this setup is that the actual stepping part is then quite easy: we just take the portion of the source wavelet covering the desired steps, run the compiled propagator, and then swap some memory around if necessary so that everything ends up in the right place.
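The memory swap at the end of step concerns the three wavefield buffers that the compiled kernel cycles through on each inner time step. As a standalone illustration of that index rotation (a sketch of the logic in step above, assuming the same three-buffer scheme; not part of Deepwave's API):

def rotated_indices(num_inner_steps):
    # Buffer roles advance by one position per inner step, so after
    # num_inner_steps the Python-side view must be rotated by
    # num_inner_steps mod 3 to restore the ordering that step expects.
    idxs = [0, 1, 2]
    for _ in range(num_inner_steps % 3):
        idxs = [idxs[2], idxs[0], idxs[1]]
    return idxs

assert rotated_indices(3) == [0, 1, 2]  # a multiple of 3 leaves the order unchanged
assert rotated_indices(4) == [2, 0, 1]  # otherwise the buffers have rotated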

Here is an example of how to use it:

import matplotlib.pyplot as plt

dx = 5.0 # 5m in each dimension
dt = 0.004 # 4ms
nz = 200
ny = 400
nt = int(5 / dt) # 5s
peak_freq = 4 
peak_source_time = 1/peak_freq

# constant 1500m/s model
model = torch.ones(nz, ny) * 1500

# one source and receiver at the same location
x_s = torch.Tensor([[[0, 20 * dx]]])
x_r = x_s.clone()

source_amplitudes = deepwave.wavelets.ricker(peak_freq, nt, dt,
                                             peak_source_time).reshape(-1, 1, 1)

prop = SteppingPropagator({'vp': model}, dx, source_amplitudes, x_s, x_r, dt)
wavefield1 = prop.step(100).detach().numpy().copy()
wavefield2 = prop.step(100).detach().numpy().copy()
wavefield3 = prop.step(100).detach().numpy().copy()

_, ax = plt.subplots(1,3,sharex=True,sharey=True)
ax[0].imshow(wavefield1[0,:,:,0], aspect='auto')
ax[1].imshow(wavefield2[0,:,:,0], aspect='auto')
ax[2].imshow(wavefield3[0,:,:,0], aspect='auto')
plt.show()
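If you want snapshots at regular intervals over the whole simulation, the same idea extends to a loop. A sketch reusing the names from the example above (the 50-sample interval is just an illustrative choice that divides nt evenly):

prop = SteppingPropagator({'vp': model}, dx, source_amplitudes, x_s, x_r, dt)
snapshot_interval = 50  # illustrative; any step count that fits within nt works
snapshots = []
for _ in range(nt // snapshot_interval):
    snapshots.append(prop.step(snapshot_interval).detach().numpy().copy())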

The CPU implementation of propagation in 3D is here. If I remember correctly, I used the same PML as PySIT.


pavane commented on August 29, 2024

Thank you so much for getting me started. I will update you on my progress.


pavane commented on August 29, 2024

The code fails with the following error:
"TypeError: SteppingPropagatorFunctionBackward.forward: expected Tensor or tuple of Tensor (got float) for return value 3"



ar4 commented on August 29, 2024

From the message that you got, it sounds like your version of PyTorch is complaining about some of the return values from the forward function in SteppingPropagatorFunction not being Tensors. Perhaps you could try this version of the code instead:

import torch
import numpy as np
import scipy.signal
import deepwave
import deepwave.base.propagator
from deepwave.scalar import scalar
from deepwave.base.propagator import _check_locations_with_model

class SteppingPropagator(deepwave.base.propagator.Propagator):
    """PyTorch Module for scalar wave propagator.

    See deepwave.base.propagator.Propagator for description.
    """

    def __init__(self, model, dx,
                 source_amplitudes, source_locations, receiver_locations, dt,
                 pml_width=None, survey_pad=None, vpmax=None):
        if list(model.keys()) != ["vp"]:
            raise RuntimeError(
                "Model must only contain vp, but contains {}".format(
                    list(model.keys())
                )
            )
        super(SteppingPropagator, self).__init__(
            SteppingPropagatorFunction,
            model,
            dx,
            fd_width=4,  # also in Pml
            pml_width=pml_width,
            survey_pad=survey_pad,
        )
        self.model.extra_info["vpmax"] = vpmax
        if model["vp"].min() <= 0.0:
            raise RuntimeError(
                "vp must be > 0, but min is {}".format(model["vp"].min())
            )

        # Check dt
        if not isinstance(dt, float):
            raise RuntimeError('dt must be a float, but has type {}'
                               .format(type(dt)))
        if dt <= 0.0:
            raise RuntimeError('dt must be > 0, but is {}'.format(dt))

        # Check same device as model
        if not (self.model.device == source_amplitudes.device ==
                source_locations.device == receiver_locations.device):
            raise RuntimeError('model, source amplitudes, source_locations, '
                               'and receiver_locations must all have the same '
                               'device, but got {} {} {} {}'
                               .format(self.model.device,
                                       source_amplitudes.device,
                                       source_locations.device,
                                       receiver_locations.device))

        # Check shapes
        if source_amplitudes.dim() != 3:
            raise RuntimeError('source_amplitude must have shape '
                               '[nt, num_shots, num_sources_per_shot]')

        if source_locations.dim() != 3:
            raise RuntimeError('source_locations must have shape '
                               '[num_shots, num_sources_per_shot, num_dims]')

        if receiver_locations.dim() != 3:
            raise RuntimeError('receiver_locations must have shape '
                               '[num_shots, num_receivers_per_shot, num_dims]')

        if not (source_amplitudes.shape[1] == source_locations.shape[0] ==
                receiver_locations.shape[0]):
            raise RuntimeError('Shape mismatch, expected '
                               'source_amplitudes.shape[1] '
                               '== source_locations.shape[0] '
                               '== receiver_locations.shape[0], but got '
                               '{} {} {}'.format(source_amplitudes.shape[1],
                                                 source_locations.shape[0],
                                                 receiver_locations.shape[0]))

        if not (source_amplitudes.shape[2] == source_locations.shape[1]):
            raise RuntimeError('Shape mismatch, expected '
                               'source_amplitudes.shape[2] '
                               '== source_locations.shape[1], but got '
                               '{} {}'.format(source_amplitudes.shape[2],
                                              source_locations.shape[1]))

        if not (self.model.ndim == source_locations.shape[2] ==
                receiver_locations.shape[2]):
            raise RuntimeError('Shape mismatch, expected '
                               'model num dims == source_locations.shape[2] '
                               '== receiver_locations.shape[2], but got '
                               '{} {} {}'.format(self.model.ndim,
                                                 source_locations.shape[2],
                                                 receiver_locations.shape[2]))

        # Check src/rec locations within model
        _check_locations_with_model(self.model, source_locations, 'source')
        _check_locations_with_model(self.model, receiver_locations, 'receiver')

        # Extract a region of the model around the sources/receivers
        model = self.extract(self.model, source_locations, receiver_locations)

        # Apply padding for the spatial finite difference and for the PML
        model = self.pad(model)

        property_names = list(model.properties.keys())
        vp = model.properties["vp"]
        
        if property_names != ["vp"]:
            raise RuntimeError(
                "Model must only contain vp, but contains {}".format(
                    property_names
                )
            )
        if vp.min() <= 0.0:
            raise RuntimeError(
                "vp must be > 0, but min is {}".format(vp.min())
            )
        device = model.device
        dtype = model.dtype
        num_steps, num_shots, num_sources_per_shot = source_amplitudes.shape
        num_receivers_per_shot = receiver_locations.shape[1]

        if model.extra_info["vpmax"] is None:
            max_vel = vp.max().item()
        else:
            max_vel = model.extra_info["vpmax"]
        timestep = scalar.Timestep(dt, model.dx, max_vel)
        model.add_properties(
            {
                "vp2dt2": vp ** 2 * timestep.inner_dt ** 2,
                "scaling": 2 / vp ** 3,
            }
        )
        source_model_locations = model.get_locations(source_locations)
        receiver_model_locations = model.get_locations(receiver_locations)
        scalar_wrapper = scalar._select_propagator(model.ndim, vp.dtype, vp.is_cuda)
        wavefield_save_strategy = scalar._set_wavefield_save_strategy(
            False, dt, timestep.inner_dt, scalar_wrapper
        )
        fd1, fd2 = scalar._set_finite_diff_coeffs(model.ndim, model.dx, device, dtype)
        wavefield, saved_wavefields = scalar._allocate_wavefields(
            wavefield_save_strategy,
            scalar_wrapper,
            model,
            num_steps,
            num_shots,
        )
        receiver_amplitudes = torch.zeros(
            num_steps,
            num_shots,
            num_receivers_per_shot,
            device=device,
            dtype=dtype,
        )
        inner_dt = torch.tensor([timestep.inner_dt]).to(dtype)
        pml = scalar.Pml(model, num_shots, max_vel)
        source_amplitudes_resampled = scipy.signal.resample(
            source_amplitudes.detach().cpu().numpy(),
            num_steps * timestep.step_ratio,
        )
        source_amplitudes_resampled = (
            torch.tensor(source_amplitudes_resampled)
            .to(dtype)
            .to(source_amplitudes.device)
        )
        source_amplitudes_resampled.requires_grad = (
            source_amplitudes.requires_grad
        )
        
        self.dtype = dtype
        self.scalar_wrapper = scalar_wrapper
        self.wavefield = wavefield.to(self.dtype).contiguous()
        self.pml = pml
        self.pml.aux = self.pml.aux.to(self.dtype).contiguous()
        self.pml.sigma = self.pml.sigma.to(self.dtype).contiguous()
        self.pml.pml_width = self.pml.pml_width.long().contiguous()
        self.receiver_amplitudes = receiver_amplitudes.to(self.dtype).contiguous()
        self.saved_wavefields = saved_wavefields.to(self.dtype).contiguous()
        self.model = model
        self.model.properties["vp2dt2"] = self.model.properties["vp2dt2"].to(self.dtype).contiguous()
        self.fd1 = fd1.to(self.dtype).contiguous()
        self.fd2 = fd2.to(self.dtype).contiguous()
        self.source_amplitudes_resampled = source_amplitudes_resampled
        self.source_model_locations = source_model_locations.long().contiguous()
        self.receiver_model_locations = receiver_model_locations.long().contiguous()
        self.inner_dt = inner_dt
        self.timestep = timestep
        self.num_shots = num_shots
        self.num_sources_per_shot = num_sources_per_shot
        self.num_receivers_per_shot = num_receivers_per_shot
        self.wavefield_save_strategy = wavefield_save_strategy
        
        self.total_num_steps = num_steps
        self.current_step = 0
        
    def step(self, num_steps):
        
        assert self.current_step + num_steps <= self.total_num_steps
        
        source_amplitudes_resampled_steps = \
            self.source_amplitudes_resampled[self.current_step*self.timestep.step_ratio:
                                             (self.current_step+num_steps)*self.timestep.step_ratio]

        # Call compiled C code to do forward modeling
        self.scalar_wrapper.forward(
            self.wavefield,
            self.pml.aux,
            self.receiver_amplitudes,
            self.saved_wavefields,
            self.pml.sigma,
            self.model.properties["vp2dt2"],
            self.fd1,
            self.fd2,
            source_amplitudes_resampled_steps.to(self.dtype).contiguous(),
            self.source_model_locations,
            self.receiver_model_locations,
            self.model.shape.contiguous(),
            self.pml.pml_width,
            self.inner_dt,
            num_steps,
            self.timestep.step_ratio,
            self.num_shots,
            self.num_sources_per_shot,
            self.num_receivers_per_shot,
            self.wavefield_save_strategy,
        )
        
        self.current_step += num_steps 

        if num_steps * self.timestep.step_ratio % 3 != 0:
            # Swap the wavefield arrays so that they are in the correct order
            wf_idxs = [0, 1, 2]
            for stepidx in range(num_steps * self.timestep.step_ratio):
                wf_idxs = [wf_idxs[2], wf_idxs[0], wf_idxs[1]]
            self.wavefield[0], self.wavefield[1], self.wavefield[2] = \
                (self.wavefield[wf_idxs[0]],
                 self.wavefield[wf_idxs[1]],
                 self.wavefield[wf_idxs[2]])
        
        if num_steps * self.timestep.step_ratio % 2 != 0:
            # Swap the aux arrays so that they are in the correct order
            ndim = self.model.ndim
            if ndim == 1:
                aux_size = 1
            elif ndim == 2:
                aux_size = 2
            else:
                aux_size = 4
            assert len(self.pml.aux) == 2 * aux_size
            self.pml.aux[:aux_size], self.pml.aux[aux_size:] = \
                self.pml.aux[aux_size:], self.pml.aux[:aux_size]
        
        return self.wavefield[1]

        
class SteppingPropagatorFunction(torch.autograd.Function):
    """Forward modeling and backpropagation functions. Not called by users."""

    @staticmethod
    def forward(
        ctx,
        source_amplitudes,
        source_locations,
        receiver_locations,
        dt,
        model,
        property_names,
        vp,
    ):
        return vp

The example code to run the propagator should be the same.
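For reference, the underlying constraint is that forward in a custom torch.autograd.Function is expected to return a Tensor or a tuple of Tensors, at least in the PyTorch versions that enforce this check. A minimal sketch of that pattern, using a hypothetical PassThrough function purely for illustration:

import torch

class PassThrough(torch.autograd.Function):
    # Sketch only: non-Tensor inputs such as dt are used inside forward,
    # but only the Tensor result is returned.
    @staticmethod
    def forward(ctx, x, dt):
        return x * dt

x = torch.ones(3, requires_grad=True)
y = PassThrough.apply(x, 0.004)  # the single return value is a Tensor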


pavane commented on August 29, 2024

This code works. Thank you so much. I will keep you posted.

