当前位置: 首页>>代码示例>>Python>>正文


Python cm.spectral函数代码示例

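本文整理汇总了Python中matplotlib.cm.spectral函数的典型用法代码示例。如果您正苦于以下问题:Python spectral函数的具体用法?Python spectral怎么用?Python spectral使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。注意:cm.spectral 已被弃用并在 matplotlib 2.2 中移除;它与 cm.nipy_spectral 使用相同的颜色数据,在新版本 matplotlib 中请改用 cm.nipy_spectral。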


在下文中一共展示了spectral函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: bar_graph

def bar_graph(data, bar_names, x_label='', y_label='', title='', axis=None, colors=None, legend_place='lower right'):
    """Create a horizontal grouped bar chart from lists of data values.

    Plots a bar chart given a dictionary of *data* with a type as key, and a
    sequence of values corresponding to elements in *bar_names* as value.
    Every key becomes one series; the series are drawn side by side inside
    each group of bars.

    Parameters
    ----------
    data : dict
        Maps a series label to a sequence of values, one per entry of
        *bar_names*. The first series determines the number of groups.
    bar_names : sequence of str
        Tick label of each group of bars.
    x_label, y_label, title : str
        Axis labels and axes title.
    axis : ignored
        Accepted for backward compatibility; not used.
    colors : sequence of colours, optional
        Colours cycled over the series. When ``None`` they are sampled evenly
        from the ``nipy_spectral`` colormap.
    legend_place : str or None
        A matplotlib legend location such as ``'lower right'`` or
        ``'center left'``, or ``None`` to omit the legend.

    Raises
    ------
    IndexError
        If *data* is empty.
    """
    from matplotlib import cm

    # Materialise the dict views once: they are not indexable on Python 3, and
    # reusing the same lists keeps legend labels aligned with the series.
    labels = list(data.keys())
    series = list(data.values())

    num_groups = len(series[0])
    group_size = len(series)
    yvals = np.arange(num_groups)
    width = 0.8 / group_size

    # Create the axes before labelling it: calling plt.xlabel() first and then
    # fig.add_subplot(111) yields two overlapping axes on recent matplotlib.
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(title)

    ps = []
    for i, vals in enumerate(series):
        if colors is None:
            # cm.spectral was removed in matplotlib 2.2; nipy_spectral holds the same colours.
            color = cm.nipy_spectral(1. * i / group_size)
        else:
            color = colors[i % len(colors)]
        p = ax.barh(yvals + width * i, vals, width, color=color, align='center')
        ps.append(p[0])

    # Series i is centred at yvals + width*i, so each group's middle lies at
    # yvals + width*(group_size - 1)/2.
    ax.set_yticks(yvals + width * (group_size - 1) / 2.0)
    ax.set_yticklabels(bar_names)
    if legend_place is not None:
        ax.legend(ps, labels, loc=legend_place)

    plt.show()
开发者ID:himanshusapra9,项目名称:TextNet,代码行数:35,代码来源:plotter.py

示例2: plot_board

	def plot_board(self, custom_text=''):
		"""Plot the points (and the clusters, when available) and save a PNG.

		If both ``self.mu`` and ``self.clusters`` are set, each entry of
		``self.clusters`` is drawn in its own colour together with its
		``self.mu`` point marked by a large star; otherwise the raw points of
		``self.X`` are drawn. The view is fitted to the extent of ``self.X``,
		the title shows N, K and the initialisation method (``self.method``),
		and the figure is saved as ``kpp<custom_text>_N<N>_K<K>.png``.
		"""
		X = self.X
		# zip(*X) is a one-shot iterator on Python 3 and cannot be indexed, so
		# split the points into x and y coordinate sequences once.
		coords = list(zip(*X))
		xs, ys = coords[0], coords[1]
		plt.figure(figsize=(5, 5))
		if self.mu and self.clusters:
			mu = self.mu
			for m, points in self.clusters.items():
				# cm.spectral was removed in matplotlib 2.2; nipy_spectral holds the same colours.
				cs = cm.nipy_spectral(1. * m / self.K)
				plt.plot(mu[m][0], mu[m][1], '*', markersize=12, color=cs)
				if len(points):
					cluster_coords = list(zip(*points))
					plt.plot(cluster_coords[0], cluster_coords[1], '.',
							 markersize=8, color=cs, alpha=0.5)
		else:
			plt.plot(xs, ys, '.', alpha=0.5)
		if self.method == '++':
			tit = 'K-means++'
		else:
			tit = 'K-means with random initialization'
		# Fit the view to the extent of the data.
		plt.xlim([min(xs), max(xs)])
		plt.ylim([min(ys), max(ys)])

		pars = 'N=%s, K=%s' % (str(self.N), str(self.K))
		plt.title('\n'.join([pars, tit]), fontsize=16)
		plt.savefig('kpp%s_N%s_K%s.png' % (custom_text, str(self.N), str(self.K)),
					bbox_inches='tight', dpi=200)
开发者ID:tuliof,项目名称:aprendizado-de-maquina,代码行数:31,代码来源:ex.py

示例3: elbow_clustering_analysis

def elbow_clustering_analysis():
    """Plot the k-means elbow curve for the admissions data and save it as EPS.

    Clusters the feature matrix returned by ``readData()`` with ``kmeans``
    for k = 1..19, plots the percentage of variance explained
    (between-cluster sum of squares / total sum of squares) against k,
    circles the point k = KK[kIdx], saves the figure to
    ``admissions_elbow_klustering_analysis.eps`` and shows it.
    """
    institution_info, X = readData()
    X = np.array(X)
    KK = range(1, 20)
    KM = [kmeans(X, k) for k in KK]
    centroids = [cent for (cent, var) in KM]
    D_k = [cdist(X, cent, 'euclidean') for cent in centroids]
    # Distance from every observation to its nearest centroid, for each k.
    dist = [np.min(D, axis=1) for D in D_k]
    tot_withinss = np.array([np.sum(d ** 2) for d in dist])  # Total within-cluster sum of squares
    totss = np.sum(pdist(X) ** 2) / X.shape[0]                # The total sum of squares
    betweenss = totss - tot_withinss                           # The between-cluster sum of squares
    explained = betweenss / totss * 100
    # Index into KK of the highlighted point: KK[3] is K=4. An older comment
    # said K=6 (which would be index 5). NOTE(review): confirm the intended K.
    kIdx = 3
    # elbow curve
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(KK, explained, 'b*-')
    ax.plot(KK[kIdx], explained[kIdx], marker='o', markersize=12,
            markeredgewidth=2, markeredgecolor='r', markerfacecolor='None')
    ax.set_ylim((0, 100))
    plt.grid(True)
    plt.xlabel('Number of clusters')
    plt.ylabel('Percentage of variance explained (%)')
    plt.title('Elbow for KMeans clustering')
    plt.savefig('admissions_elbow_klustering_analysis.eps')
    plt.show()
开发者ID:zzw922cn,项目名称:mcm2016,代码行数:27,代码来源:admissions_kmeans.py

示例4: get_colors

    def get_colors(self, qty):
        """Return RGBA colours for the values of the array *qty*.

        The values are divided by their maximum, contrast-adjusted with the
        exponent ``1 / CONTRAST`` and mapped through the colormap selected by
        the module constant ``COLORMAP`` (0-9) with opacity ``ALPHA``.

        Raises
        ------
        ValueError
            If ``COLORMAP`` is not one of the supported indices 0-9.
        """
        peak = qty.max()
        # Skip the normalisation when the maximum is 0: 0/0 would yield NaN.
        if peak != 0:
            qty = qty / peak
        qty = np.power(qty, 1.0 / CONTRAST)

        colormap_names = {
            0: 'gray',
            1: 'afmhot',
            2: 'hot',
            3: 'gist_heat',
            4: 'copper',
            5: 'gnuplot2',
            6: 'gnuplot',
            7: 'gist_stern',
            8: 'gist_earth',
            # cm.spectral was removed in matplotlib 2.2; nipy_spectral holds the same colours.
            9: 'nipy_spectral',
        }
        name = colormap_names.get(COLORMAP)
        if name is None:
            raise ValueError('unsupported COLORMAP index: %r' % (COLORMAP,))
        return getattr(cm, name)(qty, alpha=ALPHA)
开发者ID:schnorr,项目名称:tupan,代码行数:26,代码来源:glviewer.py

示例5: setup_figure

def setup_figure():
    """Create figure 1 and the initially empty artists used by the animation.

    Clears figure 1, adds one square-aspect axes spanning [-rho-1, rho+1] on
    both axes and creates:
      * ``cells``   - N circles of radius 0.5 parked at the origin,
      * ``springs`` - one empty line per entry of ``pairs`` (if ``plot_springs``),
      * ``borders`` - one empty black line per row of ``pairs2`` (if ``plot_voronoi``),
      * ``ang_mom`` - a red arrow patch drawn below the other artists.

    Returns
    -------
    tuple
        ``(fig, cells, springs, borders, ang_mom)``
    """
    fig = plt.figure(1)
    plt.clf()
    ax = fig.add_subplot(1, 1, 1)
    ax.set_xlim([-rho - 1, rho + 1])
    ax.set_ylim([-rho - 1, rho + 1])
    ax.set_aspect('equal')

    cells = [ax.add_artist(plt.Circle((0, 0), 0.5, color=cm.copper(0)))
             for _ in range(N)]

    springs = []
    if plot_springs:
        # cm.spectral was removed in matplotlib 2.2; nipy_spectral holds the same colours.
        spring_color = cm.nipy_spectral(0)
        for _ in range(len(pairs)):
            springs += ax.plot([], [], color=spring_color)

    borders = []
    if plot_voronoi:
        for _ in range(pairs2.shape[0]):
            borders += ax.plot([], [], color='k')

    ang_mom = ax.add_patch(FancyArrowPatch((0, 0), (1, 1), ec='r', fc='r', zorder=0,
                                           arrowstyle=u'simple,head_width=20, head_length=10'))

    return (fig, cells, springs, borders, ang_mom)
开发者ID:epolimpio,项目名称:sphere_sim,代码行数:26,代码来源:showSim_sphere_fancy.py

示例6: __call__

	def __call__(self, event):
		"""Update the info panels for the data point nearest to a mouse click.

		Ignores clicks outside any axes, and clicks outside ``self.axis`` when
		that attribute is set. Otherwise finds the entry of ``self.data``
		closest to the click according to ``self.distance``, colours the
		annotation ``a`` by that entry's cluster, writes the entry's values and
		the mean values of its whole cluster into the module-level text
		artists, and redraws ``figsrc``.

		Each ``self.data`` entry is indexed as
		``(x, y, cluster, dist, dur, pace, cal, fuel)``.
		"""
		# Indented with tabs only: mixing tabs and spaces inconsistently is a
		# TabError on Python 3.
		if not event.inaxes:
			return
		if not (self.axis is None or self.axis == event.inaxes):
			return
		if len(self.data) == 0:
			return  # nothing to pick
		clickX = event.xdata
		clickY = event.ydata
		# First entry with the smallest distance (the same tie-breaking as a
		# strict "<" scan), with no arbitrary upper bound on the distance.
		closest_i = min(range(len(self.data)),
						key=lambda i: self.distance(clickX, self.data[i][0], clickY, self.data[i][1]))
		closest = self.data[closest_i]
		cluster_num = closest[2]
		di, du, pa, cal, fu = closest[3:8]
		# cm.spectral was removed in matplotlib 2.2; nipy_spectral holds the same colours.
		a.set_bbox(dict(facecolor=cm.nipy_spectral(float(cluster_num) / n_clusters, 1), alpha=.5))
		dist_text.set_text("DIST (km) = %.3f" % di)
		dur_text.set_text("DUR (min) = %.3f" % du)
		pace_text.set_text("PACE (min/mi) = %.3f" % pa)
		cal_text.set_text("CAL = %.3f" % cal)
		fuel_text.set_text("FUEL = %.3f" % fu)

		# Average every metric over the clicked point's cluster; the clicked
		# point is itself a member, so the cluster is never empty.
		members = [item for item in self.data if item[2] == cluster_num]
		num = float(len(members))
		clust_di = sum(item[3] for item in members) / num
		clust_du = sum(item[4] for item in members) / num
		clust_pa = sum(item[5] for item in members) / num
		clust_cal = sum(item[6] for item in members) / num
		clust_fu = sum(item[7] for item in members) / num

		clust_dist_text.set_text("DIST (km) = %.3f" % clust_di)
		clust_dur_text.set_text("DUR (min) = %.3f" % clust_du)
		clust_pace_text.set_text("PACE (min/mi) = %.3f" % clust_pa)
		clust_cal_text.set_text("CAL = %.3f" % clust_cal)
		clust_fuel_text.set_text("FUEL = %.3f" % clust_fu)

		figsrc.canvas.draw()
开发者ID:jbold569,项目名称:489_Project,代码行数:57,代码来源:zoom__and_click_window.py

示例7: plot_silhouette

def plot_silhouette(sample_silhouette_values, cluster_labels):
    """
    Generate silhouette plot to elucidate number of clusters in data

    Source: http://scikit-learn.org/stable/auto_examples/cluster/plot_kmeans_silhouette_analysis.html

    Arguments
    =========
    sample_silhouette_values - silhouette value for every observation
    cluster_labels - integer cluster label of every observation; labels are
                     assumed to be sequential (min..max) but need not start at 0

    Returns
    =========
    None - the plot is drawn on a new figure

    """
    # Boolean masking below needs arrays; accept plain lists as well.
    sample_silhouette_values = np.asarray(sample_silhouette_values)
    cluster_labels = np.asarray(cluster_labels)
    # Initialise variables
    first_label = int(cluster_labels.min())
    last_label = int(cluster_labels.max())
    n_clusters = last_label - first_label + 1  # assume cluster numbers are sequential
    xMin = sample_silhouette_values.min()
    xMax = 1
    # One new figure with a single axes.
    plt.figure()
    ax1 = plt.gca()
    ax1.set_xlabel("The silhouette coefficient values")
    ax1.set_ylabel("Cluster label")
    ax1.set_title('Silhouette Plot (k=%d)' % n_clusters)
    ax1.set_xlim([xMin, xMax])
    # The (n_clusters+1)*10 is for inserting blank space between silhouette
    # plots of individual clusters, to demarcate them clearly.
    ax1.set_ylim([0, len(cluster_labels) + (n_clusters + 1) * 10])
    y_lower = 10
    # Iterate over the actual label values so that every cluster is drawn even
    # when the labels do not start at 0.
    for offset, label in enumerate(range(first_label, last_label + 1)):
        # Silhouette scores of this cluster, sorted.
        ith_cluster_silhouette_values = np.sort(sample_silhouette_values[cluster_labels == label])
        size_cluster_i = ith_cluster_silhouette_values.shape[0]
        y_upper = y_lower + size_cluster_i

        # cm.spectral was removed in matplotlib 2.2; nipy_spectral holds the same colours.
        color = cm.nipy_spectral(float(offset) / n_clusters)
        ax1.fill_betweenx(np.arange(y_lower, y_upper),
                          0, ith_cluster_silhouette_values,
                          facecolor=color, edgecolor=color, alpha=0.7)

        # Label the silhouette plots with their cluster numbers at the middle
        ax1.text(-0.05, y_lower + 0.5 * size_cluster_i, str(label))

        # Compute the new y_lower for next plot
        y_lower = y_upper + 10  # 10 for the 0 samples
    silhouette_avg = sample_silhouette_values.mean()
    ax1.axvline(x=silhouette_avg, color="red", linestyle="--")  # average line
开发者ID:jjvalletta,项目名称:MiceMicroArray,代码行数:56,代码来源:script.py

示例8: __init__

 def __init__(self):
     """Create the GLUT window, the shader program and its colormap texture.

     The window's display and mouse callbacks are bound to ``self.display``
     and ``self.mouse``. The shader receives a 256-entry 1-D colormap texture
     plus the initial view bounds ``minval``/``maxval``. A full-screen quad is
     created, and the view history starts empty.
     """
     self.window = GlutWindow(double=True, multisample=True)
     self.window.display_callback = self.display
     self.window.mouse_callback = self.mouse
     self.shader = ShaderProgram(vertex=vertex_shader, fragment=fragment_shader)
     # cm.spectral was removed in matplotlib 2.2; nipy_spectral holds the same colours.
     self.shader.colormap = Texture1D(cm.nipy_spectral(linspace(0, 1, 256)), wrap_s="MIRRORED_REPEAT")
     self.shader.minval = (-2.5, -1.75)
     self.shader.maxval = (1.0, 1.75)
     self.vao = get_fullscreen_quad()
     self.history = []
开发者ID:adamlwgriffiths,项目名称:glitter,代码行数:10,代码来源:mandelbrot.py

示例9: silhouette_analysis

 def silhouette_analysis(self):
     """Plot KMeans silhouette analyses for 2 to 9 clusters on the PCA-reduced data.

     Calls ``self.pc_analysis()`` first when ``self.pca_reduced`` holds no
     data. For each cluster count a new figure is created. The left panel
     shows per-cluster silhouette coefficients, with the average as a dashed
     red line. The right panel scatters the first two reduced features,
     coloured by cluster, with the numbered KMeans centres. The average
     silhouette score of each clustering is printed. Figures are left open.
     """
     reduced = self.pca_reduced
     if isinstance(reduced, np.ndarray):
         # `not array` raises ValueError for arrays with more than one element.
         needs_pca = reduced.size == 0
     else:
         needs_pca = not reduced
     if needs_pca:
         self.pc_analysis()
     data = self.pca_reduced
     for n_clusters in range(2, 10):
         fig, (ax1, ax2) = plt.subplots(1, 2)
         fig.set_size_inches(18, 7)
         ax1.set_xlim([-0.1, 1])
         ax1.set_ylim([0, len(data) + (n_clusters + 1) * 10])
         clusterer = KMeans(n_clusters=n_clusters, random_state=10)
         cluster_labels = clusterer.fit_predict(data)
         silhouette_avg = silhouette_score(data, cluster_labels)
         print("For n_clusters =", n_clusters, "the average silhouette_score is :", silhouette_avg)
         sample_silhouette_values = silhouette_samples(data, cluster_labels)
         y_lower = 10
         for i in range(n_clusters):
             ith_cluster_silhouette_values = np.sort(sample_silhouette_values[cluster_labels == i])
             size_cluster_i = ith_cluster_silhouette_values.shape[0]
             y_upper = y_lower + size_cluster_i
             # cm.spectral was removed in matplotlib 2.2; nipy_spectral holds the same colours.
             color = cm.nipy_spectral(float(i) / n_clusters)
             ax1.fill_betweenx(np.arange(y_lower, y_upper), 0, ith_cluster_silhouette_values,
                               facecolor=color, edgecolor=color, alpha=0.7)
             ax1.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))
             # Leave a 10-unit gap before the next cluster.
             y_lower = y_upper + 10
         ax1.set_title("The silhouette plot for the various clusters.")
         ax1.set_xlabel("The silhouette coefficient values")
         ax1.set_ylabel("Cluster label")
         ax1.axvline(x=silhouette_avg, color="red", linestyle="--")
         ax1.set_yticks([])
         ax1.set_xticks([-0.1, 0, 0.2, 0.4, 0.6, 0.8, 1])
         colors = cm.nipy_spectral(cluster_labels.astype(float) / n_clusters)
         ax2.scatter(data[:, 0], data[:, 1], marker='.', s=30, lw=0, alpha=0.7, c=colors)
         centers = clusterer.cluster_centers_
         ax2.scatter(centers[:, 0], centers[:, 1], marker='o', c="white", alpha=1, s=200)
         for i, c in enumerate(centers):
             ax2.scatter(c[0], c[1], marker='$%d$' % i, alpha=1, s=50)
         ax2.set_title("The visualization of the clustered data.")
         ax2.set_xlabel("Feature space for the 1st feature")
         ax2.set_ylabel("Feature space for the 2nd feature")
         plt.suptitle(("Silhouette analysis for KMeans clustering on sample data "
                       "with n_clusters = %d" % n_clusters),
                      fontsize=14, fontweight='bold')
开发者ID:vladislive,项目名称:transferparser,代码行数:43,代码来源:analysis.py

示例10: intraday_exec_curve

def intraday_exec_curve(data=None,step_sec=60*30,group_var='strategy_name_mapped'):
    """
    intraday_exec_curve :
    Plot the daily exec curve in turnover cross by group_var

    Deals in *data* are bucketed into ``step_sec``-second slots (via
    ``st_data.gridTime`` with ``out_mode='ceil'``) and by the column
    *group_var*. Each bucket's turnover, sum(rate_to_euro * price * volume)
    in millions of euros, is drawn as a rectangle ending at the slot time.
    Buckets of the same slot are stacked, and each group gets one colour and
    one legend entry.

    Raises NameError if *data* is None.
    """
    ##############################################################
    # input handling
    ##############################################################
    if data is None:
        raise NameError('plot:intraday_exec_curve - data is missing')

    ##############################################################
    # aggregate data
    ##############################################################
    grouped = data.groupby([st_data.gridTime(date=data.index, step_sec=step_sec, out_mode='ceil'), group_var])
    grouped_data = pd.DataFrame([{'date': k[0], group_var: k[1],
                                  'mturnover_euro': np.sum(v.rate_to_euro * v.price * v.volume) * 1e-6}
                                 for k, v in grouped])
    grouped_data = grouped_data.set_index('date')
    # Sort on a string rendering of the timestamp (the author found it would
    # not sort otherwise), then by group. sort_index(by=...) was removed from
    # pandas; sort_values(by=...) replaces it.
    grouped_data['tmpindex'] = [x.strftime('%Y%m%d-%H:%M:%S.%f') for x in grouped_data.index]
    grouped_data = grouped_data.sort_values(by=['tmpindex', group_var]).drop(['tmpindex'], axis=1)

    ##############################################################
    # plot
    ##############################################################
    uni_strat = np.unique(grouped_data[group_var].values)  # sorted unique groups
    # cm.spectral was removed in matplotlib 2.2; nipy_spectral holds the same colours.
    colors_strat = cm.nipy_spectral(np.linspace(0, 1.0, len(uni_strat)))
    uni_strat_islabeled = np.array([False] * len(uni_strat))
    slot = timedelta(seconds=step_sec)
    # Positional access via arrays: DataFrame.ix no longer exists in pandas.
    strategies = grouped_data[group_var].values
    turnovers = grouped_data['mturnover_euro'].values

    # plt.hold() was removed from matplotlib; artists always accumulate now.
    plt.figure()
    ax = plt.gca()
    prev_date = None
    prev_date_cum = 0
    for i in range(grouped_data.shape[0]):
        date = grouped_data.index[i].to_pydatetime()
        idx_uni_strat = np.nonzero(uni_strat == strategies[i])[0][0]
        # A new slot starts its stack at 0; further groups of the same slot
        # are stacked on top of the running total.
        base = prev_date_cum if date == prev_date else 0
        top = base + turnovers[i]
        kwargs = {'facecolor': colors_strat[idx_uni_strat], 'alpha': 0.5}
        # Label only the first rectangle of each group so the legend shows one
        # correctly coloured entry per group. Passing the names to legend()
        # afterwards would attach them to the first rectangles drawn,
        # whatever their group.
        if not uni_strat_islabeled[idx_uni_strat]:
            kwargs['label'] = uni_strat[idx_uni_strat]
            uni_strat_islabeled[idx_uni_strat] = True
        ax.fill([date - slot, date, date, date - slot], [base, base, top, top], **kwargs)
        prev_date_cum = top
        prev_date = date

    ax.legend()
    plt.show()
开发者ID:okrane,项目名称:framework,代码行数:55,代码来源:plot.py

示例11: plot_intraday_exec_curve

    def plot_intraday_exec_curve(self, duration = "", step_sec=60*30, group_var='strategy_name_mapped'):
        """
        intraday_exec_curve :
        Plot the daily exec curve in turnover cross by group_var

        Refreshes ``self.data_agg_deals`` via ``self.get_agg_deals(step_sec=...)``.
        For every row it then draws a rectangle spanning ``step_sec`` seconds
        up to the row's timestamp, with the row's ``mturnover_euro`` as its
        height; rows sharing a timestamp are stacked. Each *group_var* value
        gets one colour and one legend entry. *duration* is appended to the
        title.

        Returns the created matplotlib figure.
        """
        self.get_agg_deals(step_sec=step_sec)
        agg = self.data_agg_deals

        ##############################################################
        # plot
        ##############################################################
        uni_strat = np.unique(agg[group_var].values)  # sorted unique groups
        # cm.spectral was removed in matplotlib 2.2; nipy_spectral holds the same colours.
        colors_strat = cm.nipy_spectral(np.linspace(0, 1.0, len(uni_strat)))
        uni_strat_islabeled = np.array([False] * len(uni_strat))
        slot = timedelta(seconds=step_sec)
        # Positional access via arrays: DataFrame.ix no longer exists in pandas.
        strategies = agg[group_var].values
        turnovers = agg['mturnover_euro'].values

        # plt.hold() was removed from matplotlib; artists always accumulate now.
        h = plt.figure(figsize=DEFAULT_FIGSIZE)
        axes = plt.gca()
        axes.grid(True)

        prev_date = None
        prev_date_cum = 0
        for i in range(agg.shape[0]):
            date = agg.index[i].to_pydatetime()
            idx_uni_strat = np.nonzero(uni_strat == strategies[i])[0][0]
            # A new timestamp starts its stack at 0; further groups with the
            # same timestamp are stacked on top of the running total.
            base = prev_date_cum if date == prev_date else 0
            top = base + turnovers[i]
            kwargs = {'facecolor': colors_strat[idx_uni_strat], 'alpha': 0.85}
            # Label only the first rectangle of each group: one legend entry per group.
            if not uni_strat_islabeled[idx_uni_strat]:
                kwargs['label'] = uni_strat[idx_uni_strat]
                uni_strat_islabeled[idx_uni_strat] = True
            axes.fill([date - slot, date, date, date - slot], [base, base, top, top], **kwargs)
            prev_date_cum = top
            prev_date = date

        plt.ylabel('Turnover (,000,000) euros')
        plt.title('Intraday traded curve: ' + duration, size='large')
        plt.legend()

        return h
开发者ID:okrane,项目名称:framework,代码行数:51,代码来源:wrapper.py

示例12: Silhouette

def Silhouette(D,labels,k):
    """
    Silhouette plot for a clustering given a precomputed distance matrix.

    Taken from scikit-learn's KMeans silhouette-analysis example.
    D = distance matrix (passed to sklearn with metric='precomputed')
    labels = cluster label of every observation, expected in 0..k-1
    k = number of clusters
    """
    labels = np.asarray(labels)  # boolean masking below needs an array
    plt.ion()
    fig, ax1 = plt.subplots()
    fig.set_size_inches(18, 7)
    ax1.set_xlim([-0.1, 1])
    ax1.set_ylim([0, len(D) + (k + 1) * 10])

    sample_silhouette_values = metrics.silhouette_samples(D, labels, metric='precomputed')

    y_lower = 10
    for i in range(k):
        ith_cluster_silhouette_values = np.sort(sample_silhouette_values[labels == i])
        size_cluster_i = ith_cluster_silhouette_values.shape[0]
        y_upper = y_lower + size_cluster_i

        # cm.spectral was removed in matplotlib 2.2; nipy_spectral holds the same colours.
        color = cm.nipy_spectral(float(i) / k)
        ax1.fill_betweenx(np.arange(y_lower, y_upper),
                          0, ith_cluster_silhouette_values,
                          facecolor=color, edgecolor=color, alpha=0.7)

        ax1.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))

        # Leave a 10-unit gap before the next cluster.
        y_lower = y_upper + 10

    ax1.set_title("The silhouette plot for the various clusters.")
    ax1.set_xlabel("The silhouette coefficient values")
    ax1.set_ylabel("Cluster label")

    silhouette_avg = metrics.silhouette_score(D, labels, metric='precomputed')

    ax1.axvline(x=silhouette_avg, color="red", linestyle="--")

    ax1.set_yticks([])
    ax1.set_xticks([-0.1, 0, 0.2, 0.4, 0.6, 0.8, 1])

    # Build the title as a single string; passing a tuple rendered its repr.
    plt.suptitle("Silhouette analysis with n_clusters = %s and average = %.3f" % (k, silhouette_avg),
                 fontsize=14, fontweight='bold')

    plt.show()
开发者ID:fbr1,项目名称:textmining-eac,代码行数:50,代码来源:clustering.py

示例13: timedomain

def timedomain(ycsb, toplt):
    """Plot the per-run time series of one YCSB phase as a 3-D figure and save it.

    ``ycsb[toplt]`` is split with ``splitbyrecordcount``, and runs 1-8 are
    drawn as lines at depth z = run index, coloured along the
    ``nipy_spectral`` colormap. The z-axis is capped at the largest value over
    runs 1-8 of ``ycsb[0]``, ``ycsb[1]`` and ``ycsb[2]``, so every phase is
    plotted on the same scale. The figure is saved as a transparent 300-dpi
    PNG named by ``getfilename("timeseries", checktype)``.
    """
    arrays_k, arrays_v = splitbyrecordcount(ycsb[toplt])
    _, arrays_vu = splitbyrecordcount(ycsb[2])
    _, arrays_vr = splitbyrecordcount(ycsb[1])
    _, arrays_vv = splitbyrecordcount(ycsb[0])

    # Shared z-limit: the largest value over runs 1-8 of all three phases.
    maxheight = max(max(max(x) for x in arrays[1:9])
                    for arrays in (arrays_vu, arrays_vr, arrays_vv))

    K = list(arrays_k)
    V = list(arrays_v)

    # NOTE(review): index 0 is labelled "Update" here, while ycsb[0] feeds the
    # variable suffixed "v" above; confirm the intended phase order.
    checktype = ("Update", "Read", "Verification")[toplt]

    fig = plt.figure()
    # Integer subplot spec: recent matplotlib rejects the string form '111'.
    ax = fig.add_subplot(111, projection='3d')

    for z in np.arange(1, 9):
        xs = K[z]
        ys = V[z]
        # cm.spectral was removed in matplotlib 2.2; nipy_spectral holds the same colours.
        c = colmap.nipy_spectral(z / 9., 1)
        ax.plot(xs, z * np.ones(xs.shape), zs=ys, zdir='z', color=c, zorder=-z)

    # Plot formatting
    font = {'family': 'serif',
            'weight': 'normal',
            'size': 12}
    plt.rc('font', **font)

    ax.set_zlim3d(0, maxheight)
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('Test Run')
    ax.set_zlabel('Runtime')
    ax.tick_params(axis='both', labelsize=8)
    plt.savefig(getfilename("timeseries", checktype),
                format='png', dpi=300, bbox_inches='tight',
                transparent=True)
开发者ID:ashleyblackmore,项目名称:ycsb-log-parser,代码行数:50,代码来源:ycsb-interpreter.py

示例14: animate

def animate(k):
    """Draw frame *k* of the decision-space animation.

    Frame 1 sets the initial view, draws the first segment of every
    trajectory in ``ys`` (via ``plot2``) and four reference markers. For
    frames with 0 < int(k/3) < N, every trajectory is extended by one segment
    while the view azimuth advances by ANGLE1 * k / N / 3. Later frames hide
    the axes and keep rotating the view.
    """
    i = int(k / 3)
    if k == 1:
        ax.view_init(20, 215)
        n_traj = len(ys)
        for j, y in enumerate(ys):
            y_seg = y[0:2]
            # float(): under Python 2, j/len(ys) is integer division and gave
            # every trajectory the same colour. cm.spectral was removed in
            # matplotlib 2.2; nipy_spectral holds the same colours.
            plot2(y_seg, fig, cm.nipy_spectral(float(j) / n_traj))
        ax.scatter(0.16, 0.16, 0.16, c="g", alpha=0.4, s=500)
        ax.scatter(0.82, 0.17, 0.17, c="b", alpha=0.4, s=500)
        ax.scatter(0.17, 0.82, 0.17, c="r", alpha=0.4, s=500)
        ax.scatter(0.17, 0.17, 0.82, c="k", alpha=0.4, s=500)
        set_title("Decision Space")
    if 0 < i < N:
        ax.view_init(20, 215 + ANGLE1 * k / N / 3)
        n_traj = len(ys)
        for j, y in enumerate(ys):
            y_seg = y[i - 1:i + 1]
            plot2(y_seg, fig, cm.nipy_spectral(float(j) / n_traj))
        set_title("Decision Space")
    elif i >= N:
        ax.set_axis_off()
        j = k - 3 * N
        # Call form works on both Python 2 and 3 (the print statement is a
        # SyntaxError on Python 3).
        print("rotate" + str(j))
        ax.view_init(20, (215 + ANGLE1 + ANGLE2 * 3 * j / int(ANGLE2)) % 360)
开发者ID:anueinberg,项目名称:esmp2013_simulation,代码行数:23,代码来源:create_animation.py

示例15: plot_partial_factors

def plot_partial_factors(ds,sts,x=0,y=1,cmap=None,axes='off',
                        nude=False):
    """Plot each table's partial factor scores joined to the per-row mean.

    The samples of *ds* are reshaped (Fortran order) into
    ``(ncol, nrows, ntables)``, where ``ntables`` is the number of distinct
    ``ds.chunks``. The mean over tables of every row is scattered as a
    centre, and each table's score for components *x* and *y* is joined to
    it by a coloured line. Axis arrows, row annotations and eigenvalue and
    inertia labels are added.

    Parameters
    ----------
    ds : dataset exposing ``samples``, ``chunks``, ``shape`` and ``targets``
    sts : object exposing per-component ``eigv`` and ``inertia`` sequences
    x, y : int
        Components plotted on the horizontal and vertical axes.
    cmap : sequence of colours, optional
        Indexed per table (``cmap[t]``). Defaults to ``ntables`` colours
        sampled from the ``nipy_spectral`` colormap over [0.2, 0.85].
    axes : str
        Passed to ``plt.axis`` (e.g. ``'off'``).
    nude : bool
        When true, skip the row annotations and the eigenvalue/inertia labels.
    """
    mx = np.max(np.abs(ds.samples))
    xmx = mx * 1.1
    hw = .05 * xmx
    w = .01 * xmx
    # Horizontal and vertical axis arrows through the origin.
    plt.arrow(-xmx, 0, 2 * xmx, 0, color='gray', alpha=.7, width=w,
              head_width=hw, length_includes_head=True)
    plt.arrow(0, -mx, 0, 2 * mx, color='gray', alpha=.7, width=w,
              head_width=hw, length_includes_head=True)
    ntables = len(np.unique(ds.chunks))
    if cmap is None:
        # cm.spectral was removed in matplotlib 2.2; nipy_spectral holds the same colours.
        cmap = cm.nipy_spectral(np.linspace(.2, .85, ntables))
    m, ncol = ds.shape
    # Floor division: on Python 3, m / ntables is a float, which reshape()
    # and range() reject.
    nrows = m // ntables
    data = ds.samples.T.reshape((ncol, nrows, ntables), order='F')

    centers = np.mean(data, 2).T[:, [x, y]]
    plt.scatter(centers[:, 0], centers[:, 1])

    for t in range(ntables):
        tab = data[:, :, t].T[:, [x, y]]
        for r in range(nrows):
            a, b = centers[r, :]
            j, k = tab[r, :]
            plt.plot([a, j], [b, k], c=cmap[t], lw=2, alpha=.5)
    plt.axis('equal')
    plt.axis((-mx, mx, -mx, mx))
    plt.axis(axes)
    if not nude:
        for t in range(nrows):
            plt.annotate(ds.targets[t], xy=(centers[t, 0], centers[t, 1]))

        # Raw strings: '\l' and '\%' are invalid escape sequences in normal
        # literals (same runtime text).
        plt.text(-xmx, .05 * mx, r'$\lambda = %s$' % np.round(sts.eigv[x], 2))
        plt.text(mx * .05, mx * .9, r'$\lambda = %s$' % np.round(sts.eigv[y], 2))
        tau = '$\\tau = $'
        perc = r'$\%$'
        mpl.rcParams['text.usetex'] = False
        plt.text(-xmx, -.1 * mx, '%s $%s$%s' %
                 (tau, np.round(100 * sts.inertia[x], 0), perc))
        plt.text(xmx * .05, mx * .8, '%s $%s$%s' %
                 (tau, np.round(100 * sts.inertia[y], 0), perc))
        plt.text(-.15 * xmx, .8 * mx, '$%s$' % (y + 1), fontsize=20)
        plt.text(xmx * .85, -mx * .2, '$%s$' % (x + 1), fontsize=20)

    plt.axis('scaled')
    plt.axis([-xmx, xmx, -mx, mx])
开发者ID:mfalkiewicz,项目名称:PyMVPA,代码行数:49,代码来源:statis.py


注:本文中的matplotlib.cm.spectral函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。