Source code for hal.charts.plotter

# -*- coding: utf-8 -*-

"""Show elegant plots in any dimension """

import abc

import matplotlib.pyplot as plt
import numpy
from matplotlib.widgets import Slider
from scipy import linspace

class Plotter(metaclass=abc.ABCMeta):
    """Plots something in N-dimensional space

    Abstract base class: concrete plotters must implement :meth:`scatter`,
    :meth:`param` and :meth:`plot` (``ABCMeta`` is required for
    ``@abc.abstractmethod`` to actually be enforced).
    """

    @abc.abstractmethod
    def scatter(self, vectors):
        """Plots scatter data

        :param vectors: list of vectors (x, y, ...)
        """

    @abc.abstractmethod
    def param(self, functions, min_val, max_val, points):
        """Plots parametric data

        :param functions: functions to plot (x, y ...)
        :param min_val: minimum value
        :param max_val: maximum value
        :param points: number of points to display
        """

    @abc.abstractmethod
    def plot(self, func, mins, maxs, points):
        """Plots function

        :param func: function to plot
        :param mins: minimum of values (x, y ...)
        :param maxs: maximum of values (x, y ...)
        :param points: points in axis (x, y ...)
        """

    @staticmethod
    def show_plot():
        """Adds a legend (when anything is labelled) and shows plot"""
        # plt.legend() warns when no artist has a label, so only add one
        # when there is something to put in it.
        _, labels = plt.gca().get_legend_handles_labels()
        if labels:
            plt.legend()
        plt.show()
class Plot2d(Plotter):
    """Plotter drawing on a flat (x, y) chart"""

    def scatter(self, vectors):
        """Draws the given points joined by a line with circle markers

        :param vectors: sequence whose first item holds the x values and
            whose second item holds the y values
        """
        plt.plot(vectors[0], vectors[1], "-o")
        self.show_plot()

    def param(self, functions, min_val, max_val, points):
        """Draws the parametric curve (x(t), y(t))

        :param functions: sequence whose first two items are x(t) and y(t)
        :param min_val: first value of the parameter t
        :param max_val: last value of the parameter t
        :param points: number of samples of t
        """
        x_of, y_of = functions[0], functions[1]
        samples = linspace(min_val, max_val, points)
        plt.plot(x_of(samples), y_of(samples))
        self.show_plot()

    def plot(self, func, mins, maxs, points):
        """Draws y = func(x) over the first axis range

        :param func: function of x, applied to the whole sample array
        :param mins: sequence whose first item is the lower x bound
        :param maxs: sequence whose first item is the upper x bound
        :param points: sequence whose first item is the number of samples
        """
        samples = linspace(mins[0], maxs[0], points[0])
        plt.plot(samples, func(samples))
        self.show_plot()
class Plot3d(Plotter):
    """3D plot"""

    def scatter(self, vectors):
        """Plots 3D scatter data

        :param vectors: list of vectors (x, y, z)
        """
        vector_x, vector_y, vector_z = vectors[0], vectors[1], vectors[2]
        # general settings
        fig = plt.figure()
        chart = fig.add_subplot(111, projection="3d")
        # plot
        chart.scatter(vector_x, vector_y, vector_z, c="r", marker="o")
        self.show_plot()

    def param(self, functions, min_val, max_val, points):
        """Plots the 3D parametric curve (x(t), y(t), z(t))

        :param functions: functions to plot (x, y, z)
        :param min_val: minimum value of the parameter
        :param max_val: maximum value of the parameter
        :param points: number of points to display
        """
        function_x, function_y, function_z = (
            functions[0], functions[1], functions[2]
        )
        # general settings; Figure.gca() no longer accepts a projection
        # keyword (removed in Matplotlib 3.6), add_subplot does.
        fig = plt.figure()
        chart = fig.add_subplot(111, projection="3d")
        # limits and plot
        theta = numpy.linspace(min_val, max_val, points)
        chart.plot(function_x(theta), function_y(theta), function_z(theta))
        # show (show_plot takes care of the legend on the current axes)
        self.show_plot()

    def plot(self, func, mins, maxs, points):
        """Plots the surface z = func(x, y)

        :param func: function to plot, called once with two 2D grids
        :param mins: minimum of values (x, y)
        :param maxs: maximum of values (x, y)
        :param points: points in axis (x, y)
        """
        min_x, min_y = mins[0], mins[1]
        max_x, max_y = maxs[0], maxs[1]
        points_x, points_y = points[0], points[1]
        # general settings
        chart = plt.axes(projection="3d")
        # "ij" indexing: x varies along axis 0 and y along axis 1, so both
        # grids share shape (points_x, points_y) even when the counts differ.
        x_axis, y_axis = numpy.meshgrid(
            numpy.linspace(min_x, max_x, points_x),
            numpy.linspace(min_y, max_y, points_y),
            indexing="ij",
        )
        z_axis = func(x_axis, y_axis)
        # plot
        chart.plot_surface(
            x_axis, y_axis, z_axis, rstride=1, cstride=1, linewidth=0
        )
        self.show_plot()
class Plot4d(Plotter):
    """4D plot generator with slider

    Draws a function of three variables ``func(x, a, b)`` as a 3D chart over
    a 2D grid ``(a, b)``; ``x`` is chosen interactively with a slider.
    """

    def scatter(self, vectors):
        """Not supported by 4D plots

        :param vectors: ignored
        :raises NotImplementedError: always
        """
        raise NotImplementedError("Plot4d does not support scatter plots")

    def param(self, functions, min_val, max_val, points):
        """Not supported by 4D plots

        :raises NotImplementedError: always
        """
        raise NotImplementedError("Plot4d does not support parametric plots")

    def plot(self, func, mins, maxs, points=None):
        """Plots function as slices selected by a slider

        :param func: function to plot, called as ``func(x, a, b)``
        :param mins: minimum of values (x, y, z)
        :param maxs: maximum of values (x, y, z)
        :param points: unused; kept for interface compatibility
        """
        self.plot_type(func, mins, maxs, 0.5, "slice")

    def plot_type(self, func, mins, maxs, precision, kind):
        """Plots function

        :param func: function to plot, called as ``func(x, a, b)`` where
            ``x`` is the slider value and ``a``, ``b`` are grid arrays
        :param mins: minimum of values (x, y, z); the first two bound the
            plotted grid (the slider also spans the first), the third bounds
            the vertical axis
        :param maxs: maximum of values (x, y, z), same layout as ``mins``
        :param precision: precision to plot (higher means denser sampling)
        :param kind: kind of plot, "slice" or "countour" ("contour" is
            accepted as an alias)
        :raises ValueError: if ``kind`` is not a known kind of plot
        """
        if kind == "slice":
            draw = self._plot_slice
        elif kind in ("countour", "contour"):
            draw = self._plot_contour
        else:
            raise ValueError(
                "Unknown plot kind {!r}; expected 'slice' or 'countour'"
                .format(kind)
            )

        slider = draw(func, mins, maxs, precision)
        self.show_plot()
        # The slider is deliberately kept referenced until show returns:
        # matplotlib widgets stop responding once garbage collected.
        del slider

    def _plot_slice(self, func, mins, maxs, precision):
        """Draws the surface w = func(x, a, b); returns the x slider"""
        min_x, min_y = mins[0], mins[1]
        max_x, max_y = maxs[0], maxs[1]
        chart = plt.axes(projection="3d")
        # "ij" indexing gives both grids shape (points_x, points_y).
        x_axis, y_axis = numpy.meshgrid(
            numpy.linspace(
                min_x, max_x, self._count_points(min_x, max_x, precision)
            ),
            numpy.linspace(
                min_y, max_y, self._count_points(min_y, max_y, precision)
            ),
            indexing="ij",
        )

        def update(val):
            """Redraws the surface for slider value ``val``"""
            chart.clear()
            z_axis = func(val, x_axis, y_axis)
            chart.plot_surface(
                x_axis, y_axis, z_axis, alpha=0.3, linewidth=2.0
            )
            # clear() resets labels and limits, so re-apply them each time
            self._set_labels(chart, "y", "z", "w")
            self._set_limits(chart, mins, maxs)

        slider = self._add_slider(min_x, max_x, update)
        update(slider.val)  # initial drawing, before any slider movement
        return slider

    def _plot_contour(self, func, mins, maxs, precision):
        """Draws 3D contours of func(x, a, b); returns the x slider"""
        min_x, min_y, min_z = mins[0], mins[1], mins[2]
        max_x, max_y = maxs[0], maxs[1]
        fig = plt.figure()
        chart = fig.add_subplot(111, projection="3d")
        x_axis, y_axis = numpy.meshgrid(
            numpy.arange(min_x, max_x, self._step(min_x, max_x, precision)),
            numpy.arange(min_y, max_y, self._step(min_y, max_y, precision)),
        )

        def update(val):
            """Redraws the contours for slider value ``val``"""
            chart.clear()
            # evaluated row by row, as func may only handle 1D arrays
            z_axis = numpy.array([
                func(val, row_x, row_y)
                for row_x, row_y in zip(x_axis, y_axis)
            ])
            # projections on the three walls, then the 3D contour itself
            chart.contour(x_axis, y_axis, z_axis, zdir="x", offset=min_x)
            chart.contour(x_axis, y_axis, z_axis, zdir="y", offset=min_y)
            chart.contour(x_axis, y_axis, z_axis, zdir="z", offset=min_z)
            chart.contour(x_axis, y_axis, z_axis, extend3d=True)
            self._set_labels(chart, "y", "z", "w")
            self._set_limits(chart, mins, maxs)

        slider = self._add_slider(min_x, max_x, update)
        update(slider.val)  # initial drawing, before any slider movement
        return slider

    @staticmethod
    def _add_slider(min_x, max_x, update):
        """Adds an "x" slider below the chart that calls ``update(val)``"""
        # ``axisbg`` was removed in Matplotlib 2.2; ``facecolor`` replaces it.
        axis_slider = plt.axes([0.12, 0.03, 0.78, 0.03], facecolor="white")
        slider = Slider(axis_slider, "x", min_x, max_x, valinit=min_x)
        slider.on_changed(update)
        return slider

    @staticmethod
    def _set_labels(graph, label_x, label_y, label_z):
        """Sets given labels to the x, y and z axes of graph"""
        graph.set_xlabel(label_x)
        graph.set_ylabel(label_y)
        graph.set_zlabel(label_z)

    @staticmethod
    def _set_limits(graph, mins, maxs):
        """Sets the (x, y, z) chart limits from ``mins`` / ``maxs``"""
        graph.set_xlim(mins[0], maxs[0])
        graph.set_ylim(mins[1], maxs[1])
        graph.set_zlim(mins[2], maxs[2])

    @staticmethod
    def _count_points(min_val, max_val, precision):
        """Number of samples along an axis (at least 2 for a surface)"""
        return max(2, int((max_val - min_val) * (1 + precision)))

    @staticmethod
    def _step(min_val, max_val, precision):
        """Sampling step along an axis for contour plots"""
        return float(max_val - min_val) / float(10 * precision)