本文整理汇总了Python中pymatgen.util.plotting.pretty_plot函数的典型用法代码示例。如果您正苦于以下问题：Python pretty_plot函数的具体用法？Python pretty_plot怎么用？Python pretty_plot使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了pretty_plot函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_xrd_plot
def get_xrd_plot(self, structure, two_theta_range=(0, 90),
                 annotate_peaks=True):
    """
    Returns the XRD plot as a matplotlib.pyplot.

    Every peak yielded by ``self.get_xrd_data`` is drawn as a vertical
    black line from 0 up to its intensity.

    Args:
        structure: Input structure
        two_theta_range ([float of length 2]): Tuple for range of
            two_thetas to calculate in degrees. Defaults to (0, 90). Set to
            None if you want all diffracted beams within the limiting
            sphere of radius 2 / wavelength; in that case no peak is
            filtered out here.
        annotate_peaks: Whether to annotate the peaks with plane
            information.

    Returns:
        (matplotlib.pyplot)
    """
    from pymatgen.util.plotting import pretty_plot
    plt = pretty_plot(16, 10)
    for two_theta, i, hkls, d_hkl in self.get_xrd_data(
            structure, two_theta_range=two_theta_range):
        # None is a documented value meaning "no angular window"; indexing
        # it would raise TypeError, so only filter when a range was given.
        if two_theta_range is not None and not (
                two_theta_range[0] <= two_theta <= two_theta_range[1]):
            continue
        label = ", ".join([str(hkl) for hkl in hkls.keys()])
        plt.plot([two_theta, two_theta], [0, i], color='k',
                 linewidth=3, label=label)
        if annotate_peaks:
            plt.annotate(label, xy=[two_theta, i],
                         xytext=[two_theta, i], fontsize=16)
    plt.xlabel(r"$2\theta$ ($^\circ$)")
    plt.ylabel("Intensities (scaled)")
    plt.tight_layout()
    return plt
示例2: get_plot
def get_plot(self, normalize_rxn_coordinate=True, label_barrier=True):
    """
    Build the NEB energy profile plot.

    The image energies (relative to ``self.energies[0]``) are drawn as red
    markers and ``self.spline`` sampled every 0.01 along the reaction
    coordinate is drawn as a black line. Both are multiplied by 1000 and
    the y axis is labelled in meV.

    Args:
        normalize_rxn_coordinate (bool): If True, x values are divided by
            ``self.r[-1]``. Defaults to True.
        label_barrier (bool): If True, a dashed line and a text label mark
            the highest sampled point of the spline.

    Returns:
        matplotlib.pyplot object.
    """
    plt = pretty_plot(12, 8)

    if normalize_rxn_coordinate:
        scale = 1 / self.r[-1]
    else:
        scale = 1

    x = np.arange(0, np.max(self.r), 0.01)
    y = self.spline(x) * 1000
    image_energies = (self.energies - self.energies[0]) * 1000

    plt.plot(self.r * scale, image_energies, 'ro',
             x * scale, y, 'k-', linewidth=2, markersize=10)
    plt.xlabel("Reaction coordinate")
    plt.ylabel("Energy (meV)")
    y_low = np.min(y) - 10
    y_high = np.max(y) * 1.02 + 20
    plt.ylim((y_low, y_high))

    if label_barrier:
        # First sampled point carrying the largest spline value.
        peak = max(zip(x * scale, y), key=lambda point: point[1])
        peak_x, peak_y = peak[0], peak[1]
        plt.plot([0, peak_x], [peak_y, peak_y], 'k--')
        text_position = (peak_x / 2, peak_y * 1.02)
        plt.annotate('%.0f meV' % peak_y,
                     xy=text_position,
                     xytext=text_position,
                     horizontalalignment='center')

    plt.tight_layout()
    return plt
示例3: get_framework_rms_plot
def get_framework_rms_plot(self, plt=None, granularity=200,
                           matching_s=None):
    """
    Get the plot of rms framework displacement vs time. Useful for checking
    for melting, especially if framework atoms can move via paddle-wheel
    or similar mechanism (which would show up in max framework displacement
    but doesn't constitute melting).

    Args:
        plt (matplotlib.pyplot): If plt is supplied, changes will be made
            to an existing plot. Otherwise, a new plot will be created.
        granularity (int): Number of structures to match. A value below 2
            is treated as 2, and the sampling stride derived from it is
            never smaller than 1.
        matching_s (Structure): Optionally match to a disordered structure
            instead of the first structure in the analyzer. Required when
            a secondary mobile ion is present.

    Returns:
        The plot object obtained from pretty_plot, with one curve labelled
        "RMS" and one labelled "max".

    Notes:
        The method doesn't apply to NPT-AIMD simulation analysis.
    """
    from pymatgen.util.plotting import pretty_plot
    if self.lattices is not None and len(self.lattices) > 1:
        warnings.warn("Note the method doesn't apply to NPT-AIMD "
                      "simulation analysis!")
    plt = pretty_plot(12, 8, plt=plt)

    n_intervals = self.corrected_displacements.shape[1] - 1
    # Guard the divisor (granularity == 1 used to raise ZeroDivisionError)
    # and the result (a stride of 0 made max_dt below collapse to 0).
    step = max(1, n_intervals // max(1, granularity - 1))

    f = (matching_s or self.structure).copy()
    f.remove_species([self.specie])
    sm = StructureMatcher(primitive_cell=False, stol=0.6,
                          comparator=OrderDisorderElementComparator(),
                          allow_subset=True)
    rms = []
    for s in self.get_drift_corrected_structures(step=step):
        s.remove_species([self.specie])
        d = sm.get_rms_dist(f, s)
        # A falsy result from the matcher is recorded as the placeholder
        # pair (1, 1) so that the curves keep one point per structure.
        rms.append(d if d else (1, 1))

    max_dt = (len(rms) - 1) * step * self.step_skip * self.time_step
    if max_dt > 100000:
        plot_dt = np.linspace(0, max_dt / 1000, len(rms))
        unit = 'ps'
    else:
        plot_dt = np.linspace(0, max_dt, len(rms))
        unit = 'fs'
    rms = np.array(rms)
    plt.plot(plot_dt, rms[:, 0], label='RMS')
    plt.plot(plot_dt, rms[:, 1], label='max')
    plt.legend(loc='best')
    plt.xlabel("Timestep ({})".format(unit))
    plt.ylabel("normalized distance")
    plt.tight_layout()
    return plt
示例4: get_msd_plot
def get_msd_plot(self, plt=None, mode="specie"):
    """
    Plot the stored displacement data against ``self.dt``.

    The time axis is shown in ps (values divided by 1000) when the largest
    entry of ``self.dt`` exceeds 100000, otherwise in fs.

    Args:
        plt: A plot object. Defaults to None, which means one will be
            generated.
        mode (str): Selects what is drawn. "species" averages
            ``self.sq_disp_ions`` over the sites of each species, "sites"
            draws one curve per site and "mscd" draws ``self.mscd``. Any
            other value, including the default "specie", draws
            ``self.msd`` plus the three columns of
            ``self.msd_components`` (legend "Overall", "a", "b", "c").

    Returns:
        The plot object obtained from pretty_plot.
    """
    from pymatgen.util.plotting import pretty_plot
    plt = pretty_plot(12, 8, plt=plt)

    long_run = np.max(self.dt) > 100000
    if long_run:
        plot_dt, unit = self.dt / 1000, 'ps'
    else:
        plot_dt, unit = self.dt, 'fs'

    if mode == "species":
        for species in sorted(self.structure.composition.keys()):
            members = [idx for idx, site in enumerate(self.structure)
                       if site.specie == species]
            averaged = np.average(self.sq_disp_ions[members, :], axis=0)
            plt.plot(plot_dt, averaged, label=str(species))
        plt.legend(loc=2, prop={"size": 20})
    elif mode == "sites":
        for idx, site in enumerate(self.structure):
            site_label = "%s - %d" % (str(site.specie), idx)
            plt.plot(plot_dt, self.sq_disp_ions[idx, :], label=site_label)
        plt.legend(loc=2, prop={"size": 20})
    elif mode == "mscd":
        plt.plot(plot_dt, self.mscd, 'r')
        plt.legend(["Overall"], loc=2, prop={"size": 20})
    else:
        # Every other mode value ends up here: overall MSD plus components.
        plt.plot(plot_dt, self.msd, 'k')
        for column, colour in enumerate(('r', 'g', 'b')):
            plt.plot(plot_dt, self.msd_components[:, column], colour)
        plt.legend(["Overall", "a", "b", "c"], loc=2, prop={"size": 20})

    plt.xlabel("Timestep ({})".format(unit))
    y_label = "MSCD ($\\AA^2$)" if mode == "mscd" else "MSD ($\\AA^2$)"
    plt.ylabel(y_label)
    plt.tight_layout()
    return plt
示例5: get_pourbaix_plot
def get_pourbaix_plot(self, limits=None, title="",
                      label_domains=True, plt=None):
    """
    Draw the Pourbaix diagram held in ``self._pd``.

    Four reference lines are drawn first (two dashed red lines with slope
    ``-PREFAC``, the second offset by 1.23; a vertical line at x = 7; a
    horizontal line at y = 0), followed by the closed outline of every
    stable domain.

    Args:
        limits: 2D list containing limits of the Pourbaix diagram
            of the form [[xlo, xhi], [ylo, yhi]]. None selects
            [[-2, 16], [-3, 3]].
        title (str): Title to display on plot
        label_domains (bool): whether to label pourbaix domains
        plt (pyplot): Pyplot instance for plotting; a new one is created
            with pretty_plot(16) when this is falsy.

    Returns:
        plt (pyplot) - matplotlib plot object with pourbaix diagram
    """
    if limits is None:
        limits = [[-2, 16], [-3, 3]]
    plt = plt or pretty_plot(16)

    xlim = limits[0]
    ylim = limits[1]
    x_lo, x_hi = xlim[0], xlim[1]
    y_lo, y_hi = ylim[0], ylim[1]

    h_line = np.transpose([[x_lo, -x_lo * PREFAC],
                           [x_hi, -x_hi * PREFAC]])
    o_line = np.transpose([[x_lo, -x_lo * PREFAC + 1.23],
                           [x_hi, -x_hi * PREFAC + 1.23]])
    neutral_line = np.transpose([[7, y_lo], [7, y_hi]])
    V0_line = np.transpose([[x_lo, 0], [x_hi, 0]])

    axes = plt.gca()
    axes.set_xlim(xlim)
    axes.set_ylim(ylim)

    lw = 3
    reference_lines = (
        (h_line, "r--"),
        (o_line, "r--"),
        (neutral_line, "k-."),
        (V0_line, "k-."),
    )
    for line, fmt in reference_lines:
        plt.plot(line[0], line[1], fmt, linewidth=lw)

    for entry, vertices in self._pd._stable_domain_vertices.items():
        center = np.average(vertices, axis=0)
        # Repeat the first vertex at the end so the outline is closed.
        outline = np.vstack([vertices, vertices[0]])
        xs, ys = np.transpose(outline)
        plt.plot(xs, ys, 'k-', linewidth=lw)
        if label_domains:
            plt.annotate(generate_entry_label(entry), center, ha='center',
                         va='center', fontsize=20, color="b")

    plt.xlabel("pH")
    plt.ylabel("E (V)")
    plt.title(title, fontsize=20, fontweight='bold')
    return plt
示例6: plot
def plot(self, width=8, height=None, plt=None, dpi=None, **kwargs):
    """
    Plot the equation of state.

    The input (volume, energy) points are drawn as markers and
    ``self.func`` is evaluated on 100 evenly spaced volumes spanning the
    data range widened by 1% on each side. A text box with the fit
    parameters is placed inside the axes.

    Args:
        width (float): Width forwarded to pretty_plot. Defaults to 8.
        height (float): Height forwarded to pretty_plot. Defaults to None.
        plt (matplotlib.pyplot): Forwarded to pretty_plot as ``plt``.
        dpi: Forwarded to pretty_plot as ``dpi``.
        kwargs (dict): Optional overrides. The keys read here are
            "color" (default "r"), "label" (default "<class name> fit")
            and "text" (default: the fit-parameter summary).

    Returns:
        Matplotlib plot object.
    """
    plt = pretty_plot(width=width, height=height, plt=plt, dpi=dpi)

    eos_name = self.__class__.__name__
    color = kwargs.get("color", "r")
    label = kwargs.get("label", "{} fit".format(eos_name))

    # The default summary is always built, even if "text" overrides it.
    summary = "\n".join([
        "Equation of State: %s" % eos_name,
        "Minimum energy = %1.2f eV" % self.e0,
        "Minimum or reference volume = %1.2f Ang^3" % self.v0,
        "Bulk modulus = %1.2f eV/Ang^3 = %1.2f GPa" % (self.b0, self.b0_GPa),
        "Derivative of bulk modulus wrt pressure = %1.2f" % self.b1,
    ])
    text = kwargs.get("text", summary)

    # Input data points.
    plt.plot(self.volumes, self.energies, linestyle="None", marker="o",
             color=color)

    # Fitted curve over a slightly widened volume window.
    smallest, largest = min(self.volumes), max(self.volumes)
    vmin = smallest - 0.01 * abs(smallest)
    vmax = largest + 0.01 * abs(largest)
    vfit = np.linspace(vmin, vmax, 100)
    plt.plot(vfit, self.func(vfit), linestyle="dashed", color=color,
             label=label)

    plt.grid(True)
    plt.xlabel("Volume $\\AA^3$")
    plt.ylabel("Energy (eV)")
    plt.legend(loc="best", shadow=True)

    # Fit-parameter text, positioned in axes coordinates.
    plt.text(0.4, 0.5, text, transform=plt.gca().transAxes)
    return plt
示例7: get_plot
def get_plot(self, structure, two_theta_range=(0, 90),
             annotate_peaks=True, ax=None, with_labels=True,
             fontsize=16):
    """
    Returns the diffraction plot as a matplotlib.pyplot.

    Every peak of the pattern returned by ``self.get_pattern`` is drawn on
    ``ax`` as a vertical black line from 0 up to its intensity.

    Args:
        structure: Input structure
        two_theta_range ([float of length 2]): Tuple for range of
            two_thetas to calculate in degrees. Defaults to (0, 90). Set to
            None if you want all diffracted beams within the limiting
            sphere of radius 2 / wavelength; in that case no peak is
            filtered out here.
        annotate_peaks: Whether to annotate the peaks with plane
            information.
        ax: matplotlib :class:`Axes` or None if a new figure should be
            created.
        with_labels: True to add xlabels and ylabels to the plot.
        fontsize: (int) fontsize for peak labels.

    Returns:
        (matplotlib.pyplot)
    """
    if ax is None:
        from pymatgen.util.plotting import pretty_plot
        plt = pretty_plot(16, 10)
        ax = plt.gca()
    else:
        # This to maintain the type of the return value.
        import matplotlib.pyplot as plt

    xrd = self.get_pattern(structure, two_theta_range=two_theta_range)

    for two_theta, i, hkls, d_hkl in zip(xrd.x, xrd.y, xrd.hkls,
                                         xrd.d_hkls):
        # None is a documented value meaning "no angular window"; indexing
        # it would raise TypeError, so only filter when a range was given.
        if two_theta_range is not None and not (
                two_theta_range[0] <= two_theta <= two_theta_range[1]):
            continue
        label = ", ".join([str(hkl["hkl"]) for hkl in hkls])
        ax.plot([two_theta, two_theta], [0, i], color='k',
                linewidth=3, label=label)
        if annotate_peaks:
            ax.annotate(label, xy=[two_theta, i],
                        xytext=[two_theta, i], fontsize=fontsize)

    if with_labels:
        ax.set_xlabel(r"$2\theta$ ($^\circ$)")
        ax.set_ylabel("Intensities (scaled)")

    # Only invoked when the object passed as ax actually offers it.
    if hasattr(ax, "tight_layout"):
        ax.tight_layout()

    return plt
示例8: get_plot
def get_plot(self, ylim=None):
    """
    Get a matplotlib object for the bandstructure plot.

    One blue line is drawn per band and per entry of
    ``self.bs_plot_data()['distances']``; ticks are added through
    ``self._maketicks``.

    Args:
        ylim: Specify the y-axis (frequency) limits; by default None let
            the code choose.

    Returns:
        The plot object obtained from pretty_plot.
    """
    plt = pretty_plot(12, 8)
    from matplotlib import rc

    try:
        rc('text', usetex=True)
    except Exception:
        # Fall back on non Tex if errored. Catching Exception rather than
        # using a bare except lets KeyboardInterrupt/SystemExit propagate.
        rc('text', usetex=False)

    band_linewidth = 1

    data = self.bs_plot_data()
    for d, distances in enumerate(data['distances']):
        n_points = len(distances)
        for i in range(self._nb_bands):
            plt.plot(distances,
                     [data['frequency'][d][i][j] for j in range(n_points)],
                     'b-', linewidth=band_linewidth)

    self._maketicks(plt)

    # plot y=0 line
    plt.axhline(0, linewidth=1, color='k')

    # Main X and Y Labels
    plt.xlabel(r'$\mathrm{Wave\ Vector}$', fontsize=30)
    ylabel = r'$\mathrm{Frequency\ (THz)}$'
    plt.ylabel(ylabel, fontsize=30)

    # X range: from 0 to the last distance point
    x_max = data['distances'][-1][-1]
    plt.xlim(0, x_max)

    if ylim is not None:
        plt.ylim(ylim)

    plt.tight_layout()

    return plt
示例9: get_plot
def get_plot(self, ylim=None, units="thz"):
    """
    Build the band structure plot with frequencies converted to ``units``.

    Each frequency taken from ``self.bs_plot_data()`` is multiplied by the
    ``factor`` of the object returned by ``freq_units(units)``, and that
    object's ``label`` is used in the y-axis caption.

    Args:
        ylim: Specify the y-axis (frequency) limits; by default None let
            the code choose.
        units: units for the frequencies. Accepted values thz, ev, mev,
            ha, cm-1, cm^-1.

    Returns:
        The plot object obtained from pretty_plot.
    """
    u = freq_units(units)

    plt = pretty_plot(12, 8)

    data = self.bs_plot_data()
    for segment, distances in enumerate(data['distances']):
        point_indices = range(len(distances))
        for band in range(self._nb_bands):
            scaled = [data['frequency'][segment][band][j] * u.factor
                      for j in point_indices]
            plt.plot(distances, scaled, 'b-', linewidth=1)

    self._maketicks(plt)

    # Horizontal reference line at zero frequency.
    plt.axhline(0, linewidth=1, color='k')

    # Axis captions.
    plt.xlabel(r'$\mathrm{Wave\ Vector}$', fontsize=30)
    plt.ylabel(r'$\mathrm{{Frequencies\ ({})}}$'.format(u.label),
               fontsize=30)

    # The x axis runs from 0 to the last distance point.
    plt.xlim(0, data['distances'][-1][-1])

    if ylim is not None:
        plt.ylim(ylim)

    plt.tight_layout()

    return plt
示例10: get_arrhenius_plot
def get_arrhenius_plot(temps, diffusivities, diffusivity_errors=None,
                       **kwargs):
    """
    Returns an Arrhenius plot.

    The diffusivities are drawn as markers against 1000/T on a logarithmic
    y axis, together with the curve obtained from ``fit_arrhenius`` and a
    text label giving the fitted Ea in meV.

    Args:
        temps ([float]): A sequence of temperatures.
        diffusivities ([float]): A sequence of diffusivities (e.g.,
            from DiffusionAnalyzer.diffusivity).
        diffusivity_errors ([float]): A sequence of errors for the
            diffusivities. If None, no error bar is plotted.
        **kwargs:
            Any keyword args supported by matplotlib.pyplot.plot.

    Returns:
        A matplotlib.pyplot object. Do plt.show() to show the plot.
    """
    Ea, c, _ = fit_arrhenius(temps, diffusivities)

    from pymatgen.util.plotting import pretty_plot
    plt = pretty_plot(12, 8)

    temps_arr = np.array(temps)
    # Fitted curve c * exp(-Ea / (k_B * T)); const.k / const.e is presumably
    # the Boltzmann constant in eV/K (confirm what `const` is bound to).
    arr = c * np.exp(-Ea / (const.k / const.e * temps_arr))

    t_1 = 1000 / temps_arr

    plt.plot(t_1, diffusivities, 'ko', t_1, arr, 'k--', markersize=10,
             **kwargs)
    if diffusivity_errors is not None:
        n = len(diffusivity_errors)
        plt.errorbar(t_1[0:n], diffusivities[0:n], yerr=diffusivity_errors,
                     fmt='ko', ecolor='k', capthick=2, linewidth=2)

    # Use the axes that already hold the data. plt.axes() without arguments
    # adds a new Axes on current matplotlib, which would put the log scale
    # and the label on an empty overlay instead of the plot.
    ax = plt.gca()
    ax.set_yscale('log')
    plt.text(0.6, 0.85, "E$_a$ = {:.0f} meV".format(Ea * 1000),
             fontsize=30, transform=ax.transAxes)
    plt.ylabel("D (cm$^2$/s)")
    plt.xlabel("1000/T (K$^{-1}$)")
    plt.tight_layout()
    return plt
示例11: main
def main():
    """
    Command-line entry point: read an xmu file plus its feff.inp and show
    the two cross-section curves ("scross" and "across") against energy.
    """
    description = '''
Convenient DOS Plotter for Feff runs.
Author: Alan Dozier
Version: 1.0
Last updated: April, 2013'''
    parser = argparse.ArgumentParser(description=description)

    positional_args = (
        ('filename', 'xmu file to plot'),
        ('filename1', 'feff.inp filename to import'),
    )
    for arg_name, help_text in positional_args:
        parser.add_argument(arg_name, metavar=arg_name, type=str, nargs=1,
                            help=help_text)

    plt = pretty_plot(12, 8)
    color_order = ['r', 'b', 'g', 'c', 'k', 'm', 'y']

    args = parser.parse_args()
    xmu = Xmu.from_file(args.filename[0], args.filename1[0])

    data = xmu.to_dict

    plt.title(data['calc'] + ' Feff9.6 Calculation for ' + data['atom']
              + ' in ' + data['formula'] + ' unit cell')
    plt.xlabel('Energies (eV)')
    plt.ylabel('Absorption Cross-section')

    energies = data['energies']

    # First curve: second colour of the palette.
    single_atom = data['scross']
    single_label = 'Single ' + data['atom'] + ' ' + data['edge'] + ' edge'
    plt.plot(energies, single_atom, color_order[1], label=single_label)

    # Second curve: third colour of the palette.
    in_cell = data['across']
    cell_label = (data['atom'] + ' ' + data['edge'] + ' edge in '
                  + data['formula'])
    plt.plot(energies, in_cell, color_order[2], label=cell_label)

    plt.legend()
    # Resize every text entry of the legend that was just created.
    legend_texts = plt.gca().get_legend().get_texts()
    plt.setp(legend_texts, fontsize=15)
    plt.tight_layout()
    plt.show()
示例12: get_chgint_plot
def get_chgint_plot(args):
    """
    Plot ``Chgcar.get_integrated_diff`` against radius for selected atoms.

    Args:
        args: Object providing ``chgcar_file`` (path read with
            Chgcar.from_file), ``radius`` and ``inds``. When ``inds`` is
            truthy, ``inds[0]`` is split on commas into integer site
            indices. Otherwise the first site of every group in the
            symmetrized structure's ``equivalent_sites`` is used
            (SpacegroupAnalyzer with symprec=0.1).

    Returns:
        The plot object obtained from pretty_plot, one curve per atom.
    """
    chgcar = Chgcar.from_file(args.chgcar_file)
    structure = chgcar.structure

    if args.inds:
        atom_ind = list(map(int, args.inds[0].split(",")))
    else:
        analyzer = SpacegroupAnalyzer(structure, symprec=0.1)
        symmetrized = analyzer.get_symmetrized_structure()
        representatives = [group[0]
                           for group in symmetrized.equivalent_sites]
        atom_ind = [structure.sites.index(site)
                    for site in representatives]

    from pymatgen.util.plotting import pretty_plot
    plt = pretty_plot(12, 8)

    for idx in atom_ind:
        profile = chgcar.get_integrated_diff(idx, args.radius, 30)
        curve_label = "Atom {} - {}".format(
            idx, structure[idx].species_string)
        plt.plot(profile[:, 0], profile[:, 1], label=curve_label)

    plt.legend(loc="upper left")
    plt.xlabel("Radius (A)")
    plt.ylabel("Integrated charge (e)")
    plt.tight_layout()
    return plt
示例13: get_plot
def get_plot(self, xlim=None, ylim=None):
    """
    Draw every spectrum stored in ``self._spectra`` on one plot.

    The i-th spectrum is shifted upwards by ``self.yshift * i`` and
    coloured with ``self.colors[i]``. With ``self.stack`` set, the spectra
    are drawn with fill_between on top of a running baseline instead of as
    lines. Axis captions come from each spectrum's XLABEL / YLABEL.

    Args:
        xlim: Specifies the x-axis limits. Set to None for automatic
            determination.
        ylim: Specifies the y-axis limits.

    Returns:
        The plot object obtained from pretty_plot.
    """
    plt = pretty_plot(12, 8)
    base = 0.0
    for index, (key, spectrum) in enumerate(self._spectra.items()):
        if self.stack:
            plt.fill_between(spectrum.x, base,
                             spectrum.y + self.yshift * index,
                             color=self.colors[index],
                             label=str(key), linewidth=3)
            # Running baseline for the next filled region.
            base = spectrum.y + base
        else:
            plt.plot(spectrum.x, spectrum.y + self.yshift * index,
                     color=self.colors[index],
                     label=str(key), linewidth=3)

        plt.xlabel(spectrum.XLABEL)
        plt.ylabel(spectrum.YLABEL)

    if xlim:
        plt.xlim(xlim)
    if ylim:
        plt.ylim(ylim)

    plt.legend()
    # Resize every text entry of the legend that was just created.
    legend_texts = plt.gca().get_legend().get_texts()
    plt.setp(legend_texts, fontsize=30)
    plt.tight_layout()
    return plt
示例14: get_plot
def get_plot(self, width=8, height=8):
    """
    Plot one voltage curve per entry of ``self._electrodes``.

    Args:
        width: Width of the plot, forwarded to pretty_plot. Defaults to 8.
        height: Height of the plot, forwarded to pretty_plot. Defaults
            to 8.

    Returns:
        A matplotlib plot object.
    """
    plt = pretty_plot(width, height)

    for name, electrode in self._electrodes.items():
        x_values, y_values = self.get_plot_data(electrode)
        plt.plot(x_values, y_values, '-', linewidth=2, label=name)

    plt.legend()

    # The x caption follows the axis mode configured on the plotter.
    x_caption = 'Capacity (mAh/g)' if self.xaxis == "capacity" else 'Fraction'
    plt.xlabel(x_caption)
    plt.ylabel('Voltage (V)')
    plt.tight_layout()
    return plt
示例15: get_pourbaix_plot
def get_pourbaix_plot(self, limits=None, title="", label_domains=True):
    """
    Plot Pourbaix diagram.

    Four reference lines are drawn first (two dashed red lines with slope
    ``-PREFAC``, the second offset by 1.23; a vertical line at x = 7; a
    horizontal line at y = 0), then the boundary lines of every stable
    entry returned by ``self.pourbaix_plot_data``. Each entry is labelled
    at the mean of its distinct boundary points, provided that point lies
    strictly inside the plotted window.

    Args:
        limits: 2D list containing limits of the Pourbaix diagram
            of the form [[xlo, xhi], [ylo, yhi]]. When falsy,
            ``self._analyzer.chempot_limits`` is used instead.
        title (str): Title to display on plot
        label_domains (bool): whether to write the entry names

    Returns:
        plt:
            matplotlib plot object
    """
    plt = pretty_plot(16)
    (stable, unstable) = self.pourbaix_plot_data(limits)

    window = limits if limits else self._analyzer.chempot_limits
    xlim = window[0]
    ylim = window[1]
    x_lo, x_hi = xlim[0], xlim[1]
    y_lo, y_hi = ylim[0], ylim[1]

    h_line = np.transpose([[x_lo, -x_lo * PREFAC],
                           [x_hi, -x_hi * PREFAC]])
    o_line = np.transpose([[x_lo, -x_lo * PREFAC + 1.23],
                           [x_hi, -x_hi * PREFAC + 1.23]])
    neutral_line = np.transpose([[7, y_lo], [7, y_hi]])
    V0_line = np.transpose([[x_lo, 0], [x_hi, 0]])

    axes = plt.gca()
    axes.set_xlim(xlim)
    axes.set_ylim(ylim)

    lw = 3
    reference_lines = (
        (h_line, "r--"),
        (o_line, "r--"),
        (neutral_line, "k-."),
        (V0_line, "k-."),
    )
    for ref, fmt in reference_lines:
        plt.plot(ref[0], ref[1], fmt, linewidth=lw)

    for entry, lines in stable.items():
        sum_x = 0.0
        sum_y = 0.0
        n_points = 0.0
        seen = []
        for line in lines:
            (xs, ys) = line
            plt.plot(xs, ys, "k-", linewidth=lw)
            # Each distinct boundary point contributes once to the mean.
            for point in np.array(line).T:
                if in_coord_list(seen, point):
                    continue
                seen.append(point.tolist())
                sum_x += point[0]
                sum_y += point[1]
                n_points += 1.0

        # Avoid dividing by zero for an entry without boundary points.
        if n_points == 0.0:
            n_points = 1.0
        center_x = sum_x / n_points
        center_y = sum_y / n_points

        outside_window = ((center_x <= xlim[0]) | (center_x >= xlim[1]) |
                          (center_y <= ylim[0]) | (center_y >= ylim[1]))
        if outside_window:
            continue

        if label_domains:
            plt.annotate(self.print_name(entry), (center_x, center_y),
                         fontsize=20, color="b")

    plt.xlabel("pH")
    plt.ylabel("E (V)")
    plt.title(title, fontsize=20, fontweight='bold')
    return plt