How to: finite differences with Python

I would like to really understand how to implement a solver using finite differences with the Crank-Nicolson method.

I did this in 1D and in 2D by writing four classes in Python using scipy.sparse. This code is very simple and I am not even sure it is the right way to do it, but it could be useful to others, so here it is.

First, the one-dimensional case. In the file operator1d.py, we define the class OperatorFiniteDiff1DPeriodic:

import numpy as np
import scipy.sparse as sparse


class OperatorFiniteDiff1DPeriodic(object):
    """Centered finite-difference operators on a 1D periodic grid.

    The periodic domain ``[0, Lx)`` is discretized with ``nx`` points spaced
    by ``dx = Lx/nx``.  Because ``x = Lx`` is the same point as ``x = 0``, it
    is not part of the grid.

    Parameters
    ----------
    shape : sequence of int
        ``[nx]``, the number of grid points.
    lengths : sequence of float, optional
        ``[Lx]``, the length of the domain (default ``1.``).

    Attributes
    ----------
    xs : ndarray
        Grid point coordinates ``0, dx, ..., Lx - dx``.
    sparse_px, sparse_pxx : scipy.sparse matrix
        ``nx x nx`` second-order centered first and second derivatives,
        with periodic wrap-around.
    """

    def __init__(self, shape, lengths=None):
        if lengths is None:
            Lx = 1.
        else:
            Lx = float(lengths[0])
        nx = int(shape[0])
        self.nx = nx
        self.size = nx
        self.shape = [nx]
        self.Lx = Lx
        self.dx = Lx/nx
        dx = self.dx

        # nx points spaced by exactly dx.  np.linspace(0, Lx, nx) would
        # contain both ends of the periodic domain (the same point) and use
        # a spacing Lx/(nx-1) inconsistent with dx.
        self.xs = np.arange(nx)*dx

        rows = np.arange(nx)
        # Column indices of the right/left neighbours, wrapping around.
        right = (rows + 1) % nx
        left = (rows - 1) % nx

        def stencil(entries):
            """Sparse matrix with ``coef`` at (i, cols[i]) for each entry."""
            data = np.repeat([coef for _, coef in entries], nx).astype(float)
            i0s = np.tile(rows, len(entries))
            i1s = np.concatenate([cols for cols, _ in entries])
            # Duplicated (row, column) pairs (which occur when nx < 3) are
            # summed by the CSR conversion, giving the correct periodic
            # stencil where sparse.diags would fail on coinciding offsets.
            return sparse.coo_matrix(
                (data, (i0s, i1s)), shape=(nx, nx)).tocsr()

        # (a[i+1] - a[i-1]) / (2 dx)
        self.sparse_px = stencil([(right, 1.), (left, -1.)])/(2*dx)
        # (a[i+1] - 2 a[i] + a[i-1]) / dx**2
        self.sparse_pxx = stencil(
            [(rows, -2.), (right, 1.), (left, 1.)])/dx**2

    def px(self, a):
        """Return the centered x-derivative of ``a`` as a flat array."""
        return self.sparse_px.dot(np.ravel(a))

    def pxx(self, a):
        """Return the centered second x-derivative of ``a`` as a flat array."""
        return self.sparse_pxx.dot(np.ravel(a))

    def identity(self):
        """Return the sparse identity matrix matching the grid size."""
        return sparse.identity(self.size)


if __name__ == '__main__':
    # Small grid so the operators can be inspected interactively
    # (e.g. after `run operator1d` in IPython).
    nx = 4
    a = np.arange(0, nx)
    oper = OperatorFiniteDiff1DPeriodic(shape=[nx], lengths=[0.5*nx])

In the file solver1d.py, we define the class Solver1D:

"""Finite difference solver 1D
==============================

This module provides a class Solver1D to solve a very simple equation,
the periodic 1D advection-diffusion equation ds/dt = -U ds/dx + nu d^2s/dx^2,
using finite differences with a centered difference scheme in space and
the Crank-Nicolson method in time.

You can run it in IPython with `run solver1d`.

This class can also be used to solve other 1D linear equations.
"""

import numpy as np
from scipy.sparse.linalg import spsolve
import matplotlib.pyplot as plt

plt.ion()

from operator1d import OperatorFiniteDiff1DPeriodic


class Solver1D(object):
    """Crank-Nicolson solver for ``ds/dt = L s`` on a 1D periodic grid.

    ``L`` is returned by :meth:`linear_operator`.  Here it is the
    advection-diffusion operator ``-U d/dx + nu d^2/dx^2`` discretized with
    centered finite differences.  Each time step solves

        (I - dt/2 L) s_new = (I + dt/2 L) s_old

    Subclasses change the geometry through ``_Operator`` and the equation
    through :meth:`linear_operator`.

    Parameters
    ----------
    dt : float
        Time step.
    nu : float
        Diffusion coefficient.
    U : float
        Advection velocity.
    shape, lengths :
        Passed to ``_Operator``.
    """
    _Operator = OperatorFiniteDiff1DPeriodic

    def __init__(self, dt, nu, U, shape, lengths=None):
        self.dt = float(dt)
        self.U = float(U)
        self.nu = float(nu)

        self.oper = self._Operator(shape, lengths=lengths)

        self.L = self.linear_operator()
        # Left-hand-side matrix of the Crank-Nicolson scheme.
        self.A = self.oper.identity() - self.dt/2*self.L
        # Cached LU solver for A, built on first use (see _solve_lhs).
        self._factorized_A = None
        self._solve_A = None

        # initial condition
        self.t = 0.
        self._init_field()

        self._init_plot()

    def _init_field(self):
        """Set the initial condition: a Gaussian bump centred at Lx/2."""
        self.s = np.exp(-(10*(self.oper.xs-self.oper.Lx/2))**2)

    def linear_operator(self):
        """Return the sparse matrix of ``-U d/dx + nu d^2/dx^2``."""
        return -self.U*self.oper.sparse_px + self.nu*self.oper.sparse_pxx

    def right_hand_side(self, s=None):
        """Return ``(I + dt/2 L) s`` as a flat array (``s`` defaults to self.s)."""
        if s is None:
            s = self.s
        s_flat = np.ravel(s)
        return s_flat + self.dt/2*self.L.dot(s_flat)

    def _solve_lhs(self, rhs):
        """Solve ``self.A x = rhs`` and return ``x`` as a flat array.

        ``A`` does not change between time steps, so it is LU-factorized
        once and the factorization is reused.  (``spsolve`` would refactorize
        it at every step.)  The factorization is redone if ``self.A`` is
        replaced by another matrix.
        """
        from scipy.sparse.linalg import factorized
        if getattr(self, '_factorized_A', None) is not self.A:
            self._solve_A = factorized(self.A.tocsc())
            self._factorized_A = self.A
        return self._solve_A(rhs)

    def one_time_step(self):
        """Advance the solution ``self.s`` and time ``self.t`` by ``dt``."""
        self.s = self._solve_lhs(self.right_hand_side())
        self.s = self.s.reshape(self.oper.shape)
        self.t += self.dt

    def start(self, t_end=1.):
        """Integrate until ``self.t`` reaches ``t_end``, updating the plot.

        The number of steps is computed up front.  Testing ``self.t < t_end``
        after repeated ``self.t += dt`` accumulates rounding errors and can
        run one step too many.  For example, ten additions of 0.1 give
        0.9999999999999999.  The small relative tolerance absorbs that
        rounding in ``(t_end - t)/dt``.
        """
        nb_steps = int(np.ceil((t_end - self.t)/self.dt*(1 - 1e-12)))
        for _ in range(nb_steps):
            self.one_time_step()
            self._update_plot()

    def _init_plot(self):
        """Create the figure showing ``s`` as a function of ``x``."""
        plt.figure()
        ax = plt.gca()
        ax.set_xlabel('x')
        ax.set_ylabel('s')
        ax.set_ylim(-0.1, 1)
        self.ax = ax
        self.line, = ax.plot(self.oper.xs, self.s)
        plt.show()

    def _update_plot(self):
        """Redraw the curve with the current ``s``."""
        self.line.set_data(self.oper.xs, self.s)
        self.ax.figure.canvas.draw()
        # Let interactive backends process pending events so the window
        # refreshes while the time loop is running.
        self.ax.figure.canvas.flush_events()

if __name__ == '__main__':
    # Advect the initial Gaussian bump at unit speed, without diffusion,
    # over one unit of time.
    nu = 0.
    U = 1.
    dt = 0.01
    nx = 400
    Lx = 1.
    sim = Solver1D(dt=dt, nu=nu, U=U, shape=[nx], lengths=[Lx])
    sim.start(t_end=1)

Now the 2D case, with the file operator2d.py defining the class OperatorFiniteDiff2DPeriodic:

import numpy as np
import scipy.sparse as sparse

from operator1d import OperatorFiniteDiff1DPeriodic


class OperatorFiniteDiff2DPeriodic(OperatorFiniteDiff1DPeriodic):
    """Centered finite-difference operators on a 2D doubly periodic grid.

    Fields have shape ``[ny, nx]`` (axis 0 is y, axis 1 is x) and are
    flattened in C order, so point ``(iy, ix)`` has the flat index
    ``iy*nx + ix``.  ``px``, ``pxx`` and ``identity`` are inherited and
    apply the matrices built here.

    Parameters
    ----------
    shape : sequence of int
        ``[ny, nx]``.
    lengths : sequence of float, optional
        ``[Ly, Lx]`` (default ``[1., 1.]``).
    """

    def __init__(self, shape, lengths=None):
        if lengths is None:
            Lx = 1.
            Ly = 1.
        else:
            Lx = float(lengths[1])
            Ly = float(lengths[0])
        nx = int(shape[1])
        ny = int(shape[0])
        self.nx = nx
        self.ny = ny
        self.shape = [ny, nx]
        size = nx*ny
        self.size = size
        self.Lx = Lx
        self.Ly = Ly
        self.dx = Lx/nx
        self.dy = Ly/ny
        dx = self.dx
        dy = self.dy

        # Periodic grid: x = Lx coincides with x = 0, so the nx points are
        # spaced by exactly dx (likewise in y with dy).
        self.xs = np.arange(nx)*dx
        self.ys = np.arange(ny)*dy

        # Flat index of every point and of its four periodic neighbours.
        flat = np.arange(size)
        iy = flat // nx
        ix = flat % nx
        east = iy*nx + (ix + 1) % nx
        west = iy*nx + (ix - 1) % nx
        north = ((iy + 1) % ny)*nx + ix
        south = ((iy - 1) % ny)*nx + ix

        def neighbours(*columns):
            """Build a func_i1_mat returning ``columns[iv][i0_mat]``."""
            def func_i1_mat(i0_mat, iv):
                return columns[iv][i0_mat]
            return func_i1_mat

        self.sparse_px = self._create_sparse(
            np.array([1., -1.])/(2*dx), neighbours(east, west))
        self.sparse_pxx = self._create_sparse(
            np.array([-2., 1., 1.])/dx**2, neighbours(flat, east, west))
        self.sparse_py = self._create_sparse(
            np.array([1., -1.])/(2*dy), neighbours(north, south))
        # The second y-derivative is scaled by dy**2 (not dx**2).
        self.sparse_pyy = self._create_sparse(
            np.array([-2., 1., 1.])/dy**2, neighbours(flat, north, south))

    def _create_sparse(self, values, func_i1_mat):
        """Return a ``size x size`` COO matrix with the same stencil per row.

        Row ``i0`` holds ``values[iv]`` in column ``func_i1_mat(i0, iv)`` for
        each ``iv``.  ``func_i1_mat`` is called once per ``iv`` with the
        integer array of all row indices, so it must work element-wise on
        NumPy arrays.  Repeated (row, column) pairs are summed, as usual for
        COO matrices.
        """
        size = self.size
        nb_values = len(values)
        rows = np.arange(size)
        # Block iv of every array corresponds to values[iv].
        data = np.repeat(np.asarray(values, dtype=float), size)
        i0s = np.tile(rows, nb_values)
        i1s = np.empty(size*nb_values, dtype=int)
        for iv in range(nb_values):
            i1s[iv*size:(iv+1)*size] = func_i1_mat(rows, iv)
        return sparse.coo_matrix(
            (data, (i0s, i1s)), shape=(size, size))

    def py(self, a):
        """Return the centered y-derivative of ``a`` as a flat array."""
        return self.sparse_py.dot(np.ravel(a))

    def pyy(self, a):
        """Return the centered second y-derivative of ``a`` as a flat array."""
        return self.sparse_pyy.dot(np.ravel(a))


if __name__ == '__main__':
    # Tiny 3x3 grid, small enough to inspect the sparse matrices
    # interactively (e.g. after `run operator2d` in IPython).
    ny = nx = 3
    oper = OperatorFiniteDiff2DPeriodic(shape=[ny, nx])
    a = np.arange(ny*nx).reshape([ny, nx])

And finally the file solver2d.py defining the class Solver2D:

"""Finite difference solver 2D
==============================

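This module provides a class Solver2D to solve a very simple equation,
the doubly periodic 2D advection-diffusion equation
ds/dt = -U ds/dx + nu (d^2s/dx^2 + d^2s/dy^2),
using finite differences with a centered difference scheme in space and
the Crank-Nicolson method in time.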

You can run it in IPython with `run solver2d`.

This class can also be used to solve other 2D linear equations.

"""

import numpy as np
import matplotlib.pyplot as plt

plt.ion()

from operator2d import OperatorFiniteDiff2DPeriodic
from solver1d import Solver1D


class Solver2D(Solver1D):
    """Crank-Nicolson solver for 2D advection-diffusion on a periodic grid.

    Solves ``ds/dt = -U ds/dx + nu (d^2s/dx^2 + d^2s/dy^2)``.  The time
    stepping is inherited from :class:`Solver1D`.  The field ``s`` has shape
    ``oper.shape == [ny, nx]`` (axis 0 is y, axis 1 is x).
    """
    _Operator = OperatorFiniteDiff2DPeriodic

    def _init_field(self):
        """Set the initial condition: a Gaussian bump centred in the domain."""
        oper = self.oper
        # Default 'xy' indexing returns arrays of shape [ny, nx], matching
        # oper.shape: x varies along axis 1 and y along axis 0.
        xx, yy = np.meshgrid(oper.xs-oper.Lx/2, oper.ys-oper.Ly/2)
        rr2 = xx**2 + yy**2
        self.s = np.exp(-100*rr2)

    def linear_operator(self):
        """Return the sparse matrix of ``-U d/dx + nu (d^2/dx^2 + d^2/dy^2)``."""
        return (- self.U*self.oper.sparse_px
                + self.nu*(self.oper.sparse_pxx+self.oper.sparse_pyy))

    @staticmethod
    def _cell_edges(centers, length):
        """Return the ``len(centers)+1`` edges of uniform cells around ``centers``.

        The cell width is the spacing of ``centers``.  For a single point,
        the width is ``length``.
        """
        centers = np.asarray(centers, dtype=float)
        if centers.size > 1:
            step = centers[1] - centers[0]
        else:
            step = float(length)
        return np.append(centers - step/2, centers[-1] + step/2)

    def _init_plot(self):
        """Create the figure showing ``s`` as a colour map over (x, y)."""
        oper = self.oper
        plt.figure()
        ax = plt.gca()
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_xlim([0, oper.Lx])
        ax.set_ylim([0, oper.Ly])
        self.ax = ax
        # pcolormesh expects cell edges.  Derive them from the grid points
        # so that each cell is centred on its grid point.
        xs = self._cell_edges(oper.xs, oper.Lx)
        ys = self._cell_edges(oper.ys, oper.Ly)
        self.quad = ax.pcolormesh(xs, ys, self.s)
        plt.show()

    def _update_plot(self):
        """Redraw the colour map with the current ``s``."""
        self.quad.set_array(self.s.ravel())
        self.ax.figure.canvas.draw()
        # Let interactive backends process pending events so the window
        # refreshes while the time loop is running.
        self.ax.figure.canvas.flush_events()

if __name__ == '__main__':
    # Advect the initial Gaussian bump along x at unit speed, without
    # diffusion, on an N x N grid over one unit of time.
    nu = 0.
    U = 1.
    dt = 0.01
    N = 200
    Lx = 1.
    sim = Solver2D(dt=dt, nu=nu, U=U, shape=[N, N], lengths=[Lx, Lx])
    sim.start(t_end=1)

Let's hope there are no bugs!