Using MPI in Python
Environment Setup
# For Ubuntu
# python3-pip provides the pip3 command used on the next line.
apt-get install python3 python3-dev python3-pip libopenmpi-dev openmpi-bin mpi-default-bin
pip3 install mpi4py numpy matplotlib
# If pip refuses because the system Python is "externally managed",
# install the distro packages instead:
#   apt-get install python3-mpi4py python3-numpy python3-matplotlib
# For CentOS
# yum-utils provides yum-config-manager / yum-builddep; python-pip comes from EPEL.
yum install epel-release yum-utils
yum-config-manager --disable source
yum install python python-devel python-pip python-matplotlib python-numpy mpich mpich-devel
echo 'export PATH="/usr/lib64/mpich/bin/:$PATH"' >> ~/.bashrc
echo 'export C_INCLUDE_PATH="/usr/include/mpich-x86_64/:$C_INCLUDE_PATH"' >> ~/.bashrc
echo 'export LD_LIBRARY_PATH="/usr/lib64/mpich/lib/:$LD_LIBRARY_PATH"' >> ~/.bashrc
# Load the new variables into this shell so mpicc is found when building mpi4py.
source ~/.bashrc
pip install mpi4py
# CentOS install matplotlib (optional: newer than the distro package)
yum-builddep python-matplotlib
pip install --upgrade matplotlib
Example Code: Mandelbrot Set (Sequential and MPI Versions)
# file: mandelbrot-seq.py
import numpy as np
from matplotlib import pyplot as plt
def mandelbrot(c, maxit):
    """Return the escape iteration count of the complex point ``c``.

    Iterates ``z -> z*z + c`` starting from ``z = c`` for at most ``maxit``
    steps and returns the first step ``n`` at which ``abs(z) > 2``.

    Returns 0 when the orbit does not escape within ``maxit`` steps (the
    point is treated as belonging to the set). Points with ``abs(c) > 2``
    also return 0, because they already exceed the bound at step 0.
    """
    z = c
    for n in range(maxit):
        if abs(z) > 2:
            return n
        z = z*z + c
    return 0
xmin, xmax = -2.0, 1.0
ymin, ymax = -1.0, 1.0
width, height = 320, 200
maxit = 127

# Sample points along each axis. C[w, h] holds the escape count of the
# point xlin[w] + 1j*ylin[h], so the first axis of C runs along x.
xlin = np.linspace(xmin, xmax, width)
ylin = np.linspace(ymin, ymax, height)
C = np.empty((width, height), np.int64)
for w in range(width):
    for h in range(height):
        C[w, h] = mandelbrot(xlin[w] + 1j*ylin[h], maxit)
# C is indexed [x, y], but imshow draws the first array axis vertically.
# Transpose it so the real axis is horizontal, and put ymin at the bottom.
# plt.spectral() no longer exists in Matplotlib; 'nipy_spectral' is the
# colormap it referred to.
plt.imshow(C.T, aspect='equal', origin='lower', cmap='nipy_spectral',
           extent=(xmin, xmax, ymin, ymax))
plt.show()
# file: mandelbrot-mpi.py
import numpy as np
from mpi4py import MPI
from matplotlib import pyplot as plt
def mandelbrot(c, maxit):
    """Return the escape iteration count of the complex point ``c``.

    Iterates ``z -> z*z + c`` starting from ``z = c`` for at most ``maxit``
    steps and returns the first step ``n`` at which ``abs(z) > 2``.

    Returns 0 when the orbit does not escape within ``maxit`` steps (the
    point is treated as belonging to the set). Points with ``abs(c) > 2``
    also return 0, because they already exceed the bound at step 0.
    """
    z = c
    for n in range(maxit):
        if abs(z) > 2:
            return n
        z = z*z + c
    return 0
xmin, xmax = -2.0, 1.0
ymin, ymax = -1.0, 1.0
width, height = 320, 200
maxit = 127

comm = MPI.COMM_WORLD
comm_size = comm.Get_size()
comm_rank = comm.Get_rank()

# Split the `width` columns into contiguous strips. Every rank gets
# width // comm_size columns, and the first width % comm_size ranks get one
# extra (the comparison adds 0 or 1).
ncols = width // comm_size + (width % comm_size > comm_rank)
# comm.scan is an inclusive prefix sum over ranks. Subtracting our own count
# gives the global index of this rank's first column.
col_start = comm.scan(ncols) - ncols

xlin = np.linspace(xmin, xmax, width)
ylin = np.linspace(ymin, ymax, height)
# C_local[w, h] is global column col_start + w, laid out like the
# sequential C[x, y].
C_local = np.empty((ncols, height), np.int64)
for w in range(ncols):
    for h in range(height):
        C_local[w, h] = mandelbrot(xlin[w + col_start] + 1j*ylin[h], maxit)
# Gather results at rank 0.
# Per-rank column counts; comm.gather returns the list only on root (None elsewhere).
comm_gather_num = comm.gather(ncols, root=0)
C = None
if comm_rank == 0:
    C = np.empty((width, height), np.int64)

# One image column (a row of C / C_local) is `height` contiguous int64
# values. Receiving in units of this datatype lets the per-rank column counts
# be used directly as Gatherv counts. With displacements left as None,
# mpi4py packs the blocks in rank order, which matches col_start from the scan.
rowtype = MPI.INT64_T.Create_contiguous(height)
rowtype.Commit()
comm.Gatherv(sendbuf=[C_local, MPI.INT64_T],
             recvbuf=[C, (comm_gather_num, None), rowtype],
             root=0)
rowtype.Free()

if comm_rank == 0:
    # C is indexed [x, y], but imshow draws the first array axis vertically.
    # Transpose it so the real axis is horizontal, with ymin at the bottom.
    # plt.spectral() no longer exists in Matplotlib; 'nipy_spectral' is the
    # colormap it referred to.
    plt.imshow(C.T, aspect='equal', origin='lower', cmap='nipy_spectral',
               extent=(xmin, xmax, ymin, ymax))
    plt.show()
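Set up your hosts (for example, with an MPI hostfile), then launch the parallel version with `mpiexec -n 4 python3 mandelbrot-mpi.py` (use `python` instead of `python3` for the CentOS setup). Have fun!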
Comments
Comments powered by Disqus