# Source code for graphy.plotting

from __future__ import division, print_function, absolute_import
import six
range = six.moves.range

import collections

import numpy as np
import networkx as nx
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib.cbook as cbook


def plot_graph(G, pos=None, colors=None, node_labels=None, node_size=0.04,
               edgescale=1.0, nodeopts=None, labelopts=None, arrowopts=None,
               bidir_arrows=True, cmap='Paired', vmin=None, vmax=None,
               ):
    """Plot a graph. Supports both directed and undirected graphs.

    Undirected edges are drawn as lines while directed edges are drawn as
    arrows. Self-loops are drawn as small circles attached to the node.

    For example:

    .. plot::
        :include-source:

        >>> from graphy import plotting
        >>> import networkx as nx
        >>> G=nx.karate_club_graph()
        >>> plotting.plot_graph(G, pos=nx.spring_layout(G), colors=range(G.number_of_nodes())) # doctest: +SKIP

    Parameters
    ----------
    G : networkx Graph object or 2-d np.array
        Graph to plot, either instance of networkx Graph or a 2-d connectivity
        matrix (which is converted to a ``networkx.DiGraph``).
    pos : dict
        Dict specifying positions of nodes, as in ``{node: (x, y)}``. If not
        provided, nodes are arranged along a circle. Positions are rescaled
        to the unit square (then stretched by the axes aspect ratio), so only
        their relative placement matters. Integer coordinates are accepted.
    colors : list of ints or list of RGBA values (default None)
        Color(s) to use for node faces, if desired. Scalar values are mapped
        through ``cmap``; an N x 3 or N x 4 array is used as RGB(A) directly.
    node_labels : list of strings (default None)
        Labels to use for node labels, if desired, indexed in ``G.nodes()``
        order.
    node_size : float (default 0.04)
        Size (radius) of nodes, in the rescaled coordinate system.
    edgescale : float (default 1.0)
        Controls thickness of edges between nodes. Multiplies each edge's
        ``'weight'`` attribute (1.0 when absent).
    nodeopts : dict (default None)
        Extra options to pass into plt.Circle call for plotting nodes. If it
        contains ``'color'``, the default face/edge colors (from ``colors``)
        are not applied.
    labelopts : dict (default None)
        Extra options to pass into plt.text call for plotting labels. Could be
        used to specify fontsize, for example.
    arrowopts : dict (default None)
        Extra options to pass into plt.arrow call for plotting edges.
    bidir_arrows : bool (default True)
        Whether to draw arrowheads when graph is directed and two nodes are
        connected bidirectionally.
    cmap : string (default 'Paired')
        Name of colormap to use.
    vmin : float (default is minimum of colors)
        Starting value to use for colormap.
    vmax : float (default is maximum of colors)
        Ending value to use for colormap.

    Raises
    ------
    ValueError
        If ``G`` is neither a numpy array nor a networkx graph.
    """

    class MplColorHelper:
        """Map scalar values to RGBA colors through a named colormap."""

        def __init__(self, cmap_name, start_val, stop_val):
            self.cmap_name = cmap_name
            self.cmap = plt.get_cmap(cmap_name)
            self.norm = mpl.colors.Normalize(vmin=start_val, vmax=stop_val)
            self.scalarMap = cm.ScalarMappable(norm=self.norm, cmap=self.cmap)

        def get_rgb(self, val):
            return self.scalarMap.to_rgba(val)

    def intersect_two_circles(xy1, r1, xy2, r2):
        """Return one of the intersection points of two circles.

        Assumes the circles actually intersect and have distinct centres
        (the distance between centres is used as a divisor).
        """
        d = np.linalg.norm(xy2 - xy1)
        a = (r1**2 - r2**2 + d**2) / (2 * d)
        h = np.sqrt(r1**2 - a**2)
        # Foot of the chord on the line joining the two centres.
        xy3 = xy1 + a * (xy2 - xy1) / d
        ix = xy3[0] + h * (xy2[1] - xy1[1]) / d
        iy = xy3[1] - h * (xy2[0] - xy1[0]) / d
        return np.array([ix, iy])

    if isinstance(G, (np.ndarray, np.generic)):
        G = nx.DiGraph(np.array(G))
    elif not isinstance(G, nx.Graph):
        raise ValueError('Unknown type of graph: %s' % str(type(G)))

    # Canonicalize property aliases so the key checks below see the full
    # names. `or {}` replaces the former shared mutable `{}` defaults.
    arrowopts = cbook.normalize_kwargs(arrowopts or {}, mpl.patches.Arrow)
    nodeopts = cbook.normalize_kwargs(nodeopts or {}, plt.Circle)
    labelopts = cbook.normalize_kwargs(labelopts or {}, mpl.text.Text)

    if G.number_of_nodes() == 0:
        # Nothing to draw; the min/max computations below need >= 1 node.
        plt.axis('off')
        return

    if pos is None:
        pos = nx.circular_layout(G)

    if colors is None:
        colors = 1

    # A single (non-iterable) color value is broadcast to every node.
    try:
        iter(colors)
    except TypeError:
        colors = [colors, ] * G.number_of_nodes()

    colors = np.asarray(colors)
    if colors.ndim == 1 or colors.shape[1] not in (3, 4):
        # Scalar color values: map them through the colormap.
        if vmin is None:
            vmin = colors.min()
        if vmax is None:
            vmax = colors.max()
        cmap_helper = MplColorHelper(cmap, vmin, vmax)
        colors = np.asarray([cmap_helper.get_rgb(c) for c in colors])

    bbox = plt.gca().get_window_extent()
    asp_ratio = bbox.width / bbox.height

    # Rescale positions into [0, 1] per axis, then stretch x by the aspect
    # ratio so circles stay round. dtype=float keeps the in-place division
    # valid when the caller supplies integer coordinates.
    node_map = {n: ndx for ndx, n in enumerate(G.nodes())}
    xys = np.array([pos[n] for n in G.nodes()], dtype=float)
    xys -= xys.min(axis=0)
    maxvalues = xys.max(axis=0)
    maxvalues[maxvalues == 0] = 1  # avoid 0/0 when all nodes share a coordinate
    xys /= maxvalues
    xys[:, 0] *= asp_ratio
    layout_center = xys.mean(axis=0)

    plt.xlim([-(3 * node_size) * asp_ratio, (1 + (3 * node_size)) * asp_ratio])
    plt.ylim([-(3 * node_size), 1 + (3 * node_size)])
    plt.axis('off')

    edge_weights = nx.get_edge_attributes(G, 'weight')

    for edge in G.edges():
        startxy = xys[node_map[edge[0]], :].copy()
        endxy = xys[node_map[edge[1]], :].copy()

        arrowdict = dict(length_includes_head=True, shape='full',
                         linewidth=edge_weights.get(edge, 1.0) * edgescale)
        if 'edgecolor' in G.edges[edge]:
            ec = G.edges[edge]['edgecolor']
            arrowdict.update({'edgecolor': ec, 'facecolor': ec})
        elif 'color' not in arrowopts:
            arrowdict.update({'edgecolor': 'k', 'facecolor': 'k'})
        arrowdict.update(arrowopts)

        has_reverse = G.has_edge(edge[1], edge[0])
        if G.is_directed() and (not has_reverse or bidir_arrows):
            head_scale = node_size * 0.5
        else:
            head_scale = 0  # undirected, or bidirectional heads suppressed
        arrowdict.setdefault('head_length', head_scale)
        arrowdict.setdefault('head_width', head_scale)

        if edge[0] == edge[1]:
            # Self-loop: a small circle pushed away from the layout centre,
            # plus an arrowhead where the loop meets the node boundary.
            direction = np.sign(startxy - layout_center)
            if not direction.any():
                # Node sits exactly on the centre (e.g. a single-node graph):
                # a zero direction would stack the loop on the node and make
                # intersect_two_circles divide by zero. Place it above.
                direction = np.array([0.0, 1.0])
            loopcenter = startxy + direction * node_size * 1.05
            # Match the loop outline to the edge/arrowhead color.
            loopcolor = arrowdict.get('edgecolor', arrowdict.get('color', 'k'))
            cloop = plt.Circle(loopcenter, radius=node_size * 0.65,
                               edgecolor=loopcolor, fill=False,
                               linewidth=arrowdict['linewidth'])
            plt.gca().add_artist(cloop)
            arrowloc = intersect_two_circles(startxy, node_size,
                                             loopcenter, node_size * 0.7)
            # A tiny outward step gives plt.arrow a direction for its head.
            arrowlocstart = arrowloc + (arrowloc - startxy) * 1e-5
            plt.arrow(arrowlocstart[0], arrowlocstart[1],
                      arrowloc[0] - arrowlocstart[0],
                      arrowloc[1] - arrowlocstart[1], **arrowdict)
            continue

        # Trim the segment so it starts and ends on the node boundaries.
        angle = np.arctan2(endxy[1] - startxy[1], endxy[0] - startxy[0])
        offset = np.array([np.cos(angle), np.sin(angle)]) * node_size
        startxy += offset
        endxy -= offset
        if G.is_directed() and has_reverse:
            # Bidirectional pair: (u, v) and (v, u) share the same midpoint,
            # so each direction draws only the half toward its own target.
            # Together they form one line with a head at each end, each half
            # in its own edge's style, with no duplicated artists.
            midxy = (startxy + endxy) / 2.0
            plt.arrow(midxy[0], midxy[1], endxy[0] - midxy[0],
                      endxy[1] - midxy[1], **arrowdict)
        else:
            plt.arrow(startxy[0], startxy[1], endxy[0] - startxy[0],
                      endxy[1] - startxy[1], **arrowdict)

    clabelopts = {'horizontalalignment': 'center',
                  'verticalalignment': 'center'}
    clabelopts.update(labelopts)
    for ndx, xy in enumerate(xys):
        cnodeopts = {'radius': node_size}
        # Only apply the default colors when the caller did not choose one.
        if 'color' not in nodeopts:
            cnodeopts.update({'edgecolor': 'none', 'facecolor': colors[ndx]})
        cnodeopts.update(nodeopts)
        cnode = plt.Circle((xy[0], xy[1]), **cnodeopts)
        plt.gca().add_artist(cnode)
        if node_labels is not None:
            plt.text(xy[0], xy[1], node_labels[ndx], **clabelopts)