aboutsummaryrefslogtreecommitdiff
path: root/code/sunlab/common/plotting/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'code/sunlab/common/plotting/base.py')
-rw-r--r--code/sunlab/common/plotting/base.py270
1 files changed, 270 insertions, 0 deletions
diff --git a/code/sunlab/common/plotting/base.py b/code/sunlab/common/plotting/base.py
new file mode 100644
index 0000000..aaf4a94
--- /dev/null
+++ b/code/sunlab/common/plotting/base.py
@@ -0,0 +1,270 @@
+from matplotlib import pyplot as plt
+
+
+def blank_plot(_plt=None, _xticks=False, _yticks=False):
+ """# Remove Plot Labels"""
+ if _plt is None:
+ _plt = plt
+ _plt.xlabel("")
+ _plt.ylabel("")
+ _plt.title("")
+ tick_params = {
+ "which": "both",
+ "bottom": _xticks,
+ "left": _yticks,
+ "right": False,
+ "labelleft": False,
+ "labelbottom": False,
+ }
+ _plt.tick_params(**tick_params)
+ for child in plt.gcf().get_children():
+ if child._label == "<colorbar>":
+ child.set_yticks([])
+ axs = _plt.gcf().get_axes()
+ try:
+ axs = axs.flatten()
+ except:
+ ...
+ for ax in axs:
+ ax.set_xlabel("")
+ ax.set_ylabel("")
+ ax.set_title("")
+ ax.tick_params(**tick_params)
+
+
+def save_plot(name, _plt=None, _xticks=False, _yticks=False, tighten=True):
+ """# Save Plot in Multiple Formats"""
+ assert type(name) == str, "Name must be string"
+ from os.path import dirname
+ from os import makedirs
+
+ makedirs(dirname(name), exist_ok=True)
+ if _plt is None:
+ from matplotlib import pyplot as plt
+ _plt = plt
+ _plt.savefig(name + ".png", dpi=1000)
+ blank_plot(_plt, _xticks=_xticks, _yticks=_yticks)
+ if tighten:
+ from matplotlib import pyplot as plt
+ plt.tight_layout()
+ _plt.savefig(name + ".pdf")
+ _plt.savefig(name + ".svg")
+
+
+def scatter_2d(data_2d, labels=None, _plt=None, **matplotlib_kwargs):
+ """# Scatter 2d Data
+
+ - data_2d: 2d-dataset to plot
+ - labels: labels for each case
+ - _plt: Optional specific plot to transform"""
+ from .colors import Pcolor
+
+ if _plt is None:
+ _plt = plt
+
+ def _filter(data, labels=None, _filter_on=None):
+ if labels is None:
+ return data, False
+ else:
+ return data[labels == _filter_on], True
+
+ for _class in [2, 3, 1, 0]:
+ local_data, has_color = _filter(data_2d, labels, _class)
+ if has_color:
+ _plt.scatter(
+ local_data[:, 0],
+ local_data[:, 1],
+ color=Pcolor[_class],
+ **matplotlib_kwargs
+ )
+ else:
+ _plt.scatter(local_data[:, 0], local_data[:, 1], **matplotlib_kwargs)
+ break
+ return _plt
+
+
+def plot_contour(two_d_mask, color="black", color_map=None, raise_error=False):
+ """# Plot Contour of Mask"""
+ from matplotlib.pyplot import contour
+ from numpy import mgrid
+
+ z = two_d_mask
+ x, y = mgrid[: z.shape[1], : z.shape[0]]
+ if color_map is not None:
+ try:
+ color = color_map(color)
+ except Exception as e:
+ if raise_error:
+ raise e
+ try:
+ contour(x, y, z.T, colors=color, linewidth=0.5)
+ except Exception as e:
+ if raise_error:
+ raise e
+
+
+def gaussian_smooth_plot(
+ xy,
+ sigma=0.1,
+ extent=[-1, 1, -1, 1],
+ grid_n=100,
+ grid=None,
+ do_normalize=True,
+):
+ """# Plot Data with Gaussian Smoothening around point"""
+ from numpy import array, ndarray, linspace, meshgrid, zeros_like
+ from numpy import pi, sqrt, exp
+ from numpy.linalg import norm
+
+ if not type(xy) == ndarray:
+ xy = array(xy)
+ if grid is not None:
+ XY = grid
+ else:
+ X = linspace(extent[0], extent[1], grid_n)
+ Y = linspace(extent[2], extent[3], grid_n)
+ XY = array(meshgrid(X, Y)).T
+ smoothed = zeros_like(XY[:, :, 0])
+ factor = 1
+ if do_normalize:
+ factor = (sqrt(2 * pi) * sigma) ** 2.
+ if len(xy.shape) == 1:
+ smoothed = exp(-((norm(xy - XY, axis=-1) / (sqrt(2) * sigma)) ** 2.0)) / factor
+ else:
+ try:
+ from tqdm.notebook import tqdm
+ except Exception:
+
+ def tqdm(x):
+ return x
+
+ for i in tqdm(range(xy.shape[0])):
+ if xy.shape[-1] == 2:
+ smoothed += (
+ exp(-((norm((xy[i, :] - XY), axis=-1) / (sqrt(2) * sigma)) ** 2.0))
+ / factor
+ )
+ elif xy.shape[-1] == 3:
+ smoothed += (
+ exp(-((norm((xy[i, :2] - XY), axis=-1) / (sqrt(2) * sigma)) ** 2.0))
+ / factor
+ * xy[i, 2]
+ )
+ return smoothed, XY
+
+
+def plot_grid_data(xy_grid, data_grid, cbar=False, _plt=None, _cmap="RdBu", grid_min=1):
+ """# Plot Gridded Data"""
+ from numpy import nanmin, nanmax
+ from matplotlib.colors import TwoSlopeNorm
+
+ if _plt is None:
+ _plt = plt
+ norm = TwoSlopeNorm(
+ vmin=nanmin([-grid_min, nanmin(data_grid)]),
+ vcenter=0,
+ vmax=nanmax([grid_min, nanmax(data_grid)]),
+ )
+ _plt.pcolor(xy_grid[:, :, 0], xy_grid[:, :, 1], data_grid, cmap="RdBu", norm=norm)
+ if cbar:
+ _plt.colorbar()
+
+
+def plot_winding(xy_grid, winding_grid, cbar=False, _plt=None):
+ plot_grid_data(xy_grid, winding_grid, cbar=cbar, _plt=_plt, grid_min=1e-5)
+
+
+def plot_vorticity(xy_grid, vorticity_grid, cbar=False, save=False, _plt=None):
+ plot_grid_data(xy_grid, vorticity_grid, cbar=cbar, _plt=_plt, grid_min=1e-1)
+
+
+plt.blank = lambda: blank_plot(plt)
+plt.scatter2d = lambda data, labels=None, **matplotlib_kwargs: scatter_2d(
+ data, labels, plt, **matplotlib_kwargs
+)
+plt.save = save_plot
+
+
+def interpolate_points(df, columns=["Latent-0", "Latent-1"], kind="quadratic", N=500):
+ """# Interpolate points"""
+ from scipy.interpolate import interp1d
+ import numpy as np
+
+ points = df[columns].to_numpy()
+ distance = np.cumsum(np.sqrt(np.sum(np.diff(points, axis=0) ** 2, axis=1)))
+ distance = np.insert(distance, 0, 0) / distance[-1]
+ interpolator = interp1d(distance, points, kind=kind, axis=0)
+ alpha = np.linspace(0, 1, N)
+ interpolated_points = interpolator(alpha)
+ return interpolated_points.reshape(-1, 1, 2)
+
+
+def plot_trajectory(
+ df,
+ Fm=24,
+ FM=96,
+ interpolate=False,
+ interpolation_kind="quadratic",
+ interpolation_N=500,
+ columns=["Latent-0", "Latent-1"],
+ frame_column="Frames",
+ alpha=0.8,
+ lw=4,
+ _plt=None,
+ _z=None,
+):
+ """# Plot Trajectories
+
+ Interpolation possible"""
+ import numpy as np
+ from matplotlib.collections import LineCollection
+ import matplotlib as mpl
+
+ if _plt is None:
+ _plt = plt
+ if type(_plt) == mpl.axes._axes.Axes:
+ _ca = _plt
+ else:
+ try:
+ _ca = _plt.gca()
+ except:
+ _ca = _plt
+ X = df[columns[0]]
+ Y = df[columns[1]]
+ fm, fM = np.min(df[frame_column]), np.max(df[frame_column])
+
+ if interpolate:
+ if interpolation_kind == "cubic":
+ if len(df) <= 3:
+ return
+ elif interpolation_kind == "quadratic":
+ if len(df) <= 2:
+ return
+ else:
+ if len(df) <= 1:
+ return
+ points = interpolate_points(
+ df, kind=interpolation_kind, columns=columns, N=interpolation_N
+ )
+ else:
+ points = np.array([X, Y]).T.reshape(-1, 1, 2)
+
+ segments = np.concatenate([points[:-1], points[1:]], axis=1)
+ lc = LineCollection(
+ segments,
+ cmap=plt.get_cmap("plasma"),
+ norm=mpl.colors.Normalize(Fm, FM),
+ )
+ if _z is not None:
+ from mpl_toolkits.mplot3d.art3d import line_collection_2d_to_3d
+
+ if interpolate:
+ _z = _z # TODO: Interpolate
+ line_collection_2d_to_3d(lc, _z)
+ lc.set_array(np.linspace(fm, fM, points.shape[0]))
+ lc.set_linewidth(lw)
+ lc.set_alpha(alpha)
+ _ca.add_collection(lc)
+ _ca.autoscale()
+ _ca.margins(0.04)
+ return lc