当前位置: 首页>>代码示例>>Python>>正文


Python numpy.nmax函数代码示例

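本文整理汇总了Python中numpy.nmax函数的典型用法代码示例。如果您正苦于以下问题:Python nmax函数的具体用法?Python nmax怎么用?Python nmax使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。注意:NumPy 中并没有名为 nmax 的函数;在这些示例中,nmax 通常是通过 `from numpy import max as nmax` 导入的 numpy.max(即 numpy.amax)的别名。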


在下文中一共展示了nmax函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: geweke_plot

def geweke_plot(data, name, format='png', suffix='-diagnostic', path='./', fontmap=None,
    verbose=1):
    """Generate a Geweke (1992) diagnostic plot and save it to disk.

    :Arguments:
        data: sequence of (first iteration, z-score) pairs; it is transposed
            into x (iterations) and y (z-scores) before plotting.
        name: variable name, used in the y-axis label and in the file name.
        format: graphic output format / file extension (default 'png').
        suffix: file-name suffix (default '-diagnostic').
        path: output directory. It is created, including any missing parent
            directories, if it does not exist. An empty string means the
            current directory.
        fontmap: accepted for API compatibility; not used by this function.
        verbose: accepted for API compatibility; not used by this function.

    The figure is written to ``<path>/<name><suffix>.<format>``.
    """
    # Scatter plot of the z-scores against the first iteration
    figure()
    x, y = transpose(data)
    scatter(x.tolist(), y.tolist())

    # Plot options
    xlabel('First iteration', fontsize='x-small')
    ylabel('Z-score for %s' % name, fontsize='x-small')

    # Reference lines at +/- 2 sd from zero
    pyplot((nmin(x), nmax(x)), (2, 2), '--')
    pyplot((nmin(x), nmax(x)), (-2, -2), '--')

    # Always show at least the [-2.5, 2.5] band so both reference lines are visible
    ylim(min(-2.5, nmin(y)), max(2.5, nmax(y)))
    xlim(0, nmax(x))

    # Save to file; makedirs (unlike mkdir) also creates missing parent directories
    if path and not os.path.exists(path):
        os.makedirs(path)
    savefig(os.path.join(path, "%s%s.%s" % (name, suffix, format)))
开发者ID:CosmologyTaskForce,项目名称:pymc,代码行数:29,代码来源:Matplot.py

示例2: geweke_plot

def geweke_plot(data,
                name,
                format='png',
                suffix='-diagnostic',
                path='./',
                fontmap=None):
    '''
    Generate Geweke (1992) diagnostic plots.

    :Arguments:
        data: list
            List (or list of lists for vector-valued variables) of Geweke diagnostics, output
            from the `pymc.diagnostics.geweke` function. It is transposed into
            x (first iteration) and y (z-score) before plotting.

        name: string
            The name of the plot.

        format (optional): string
            Graphic output format (defaults to png).

        suffix (optional): string
            Filename suffix (defaults to "-diagnostic").

        path (optional): string
            Specifies location for saving plots (defaults to local directory).
            The directory, including missing parents, is created if needed.

        fontmap (optional): dict
            Accepted for API compatibility; not used by this function.

    The figure is written to ``<path>/<name><suffix>.<format>``.
    '''
    # Scatter plot of the z-scores against the first iteration
    figure()
    x, y = transpose(data)
    scatter(x.tolist(), y.tolist())

    # Plot options
    xlabel('First iteration', fontsize='x-small')
    ylabel('Z-score for %s' % name, fontsize='x-small')

    # Reference lines at +/- 2 sd from zero
    pyplot((nmin(x), nmax(x)), (2, 2), '--')
    pyplot((nmin(x), nmax(x)), (-2, -2), '--')

    # Always show at least the [-2.5, 2.5] band so both reference lines are visible
    ylim(min(-2.5, nmin(y)), max(2.5, nmax(y)))
    xlim(0, nmax(x))

    # Save to file; makedirs (unlike mkdir) also creates missing parent directories
    if path and not os.path.exists(path):
        os.makedirs(path)
    savefig(os.path.join(path, "%s%s.%s" % (name, suffix, format)))
开发者ID:shfengcj,项目名称:pymc,代码行数:57,代码来源:Matplot.py

示例3: csv_output

    def csv_output(self):
        """Write the results of a subscription test to a CSV file.

        One row is written per SIB key in ``self.results``: the key, every
        recorded time, then the mean, min, max and variance of those times,
        each rounded to 3 decimals. The file name is built from the test
        parameters and is relative, so the file lands in the current
        working directory.
        """

        # determine the file name
        csv_filename = "subscription-%s-%siter-%s-%s.csv" % (self.subscriptiontype,
                                                              self.iterations,
                                                              self.chart_type.lower(),
                                                              self.testdatetime)

        # the with-block guarantees the file is closed even if writing a row fails
        with open(csv_filename, "w") as csvfile_stream:
            csvfile_writer = csv.writer(csvfile_stream, delimiter=',', quoting=csv.QUOTE_MINIMAL)

            # one row per SIB: its name followed by all the measured times
            for sib, times in self.results.items():
                row = [sib] + list(times)

                # add the mean, min, max and variance of the times to the row
                row.append(round(nmean(times), 3))
                row.append(round(nmin(times), 3))
                row.append(round(nmax(times), 3))
                row.append(round(nvar(times), 3))

                csvfile_writer.writerow(row)
开发者ID:desmovalvo,项目名称:pes,代码行数:35,代码来源:subscription_test.py

示例4: get_network_extents

def get_network_extents(net):
    '''
    Compute the bounding envelope of every element in an Emme network.

    Node coordinates and the intermediate vertices of every link are both
    taken into account, and the envelope is padded by 1.0 on every side.

    Args:
        -net: An Emme Network Object

    Returns:
        (minx, miny, maxx, maxy) tuple
    '''
    # Gather (x, y) for every node, then for every link vertex, in that order
    coords = [(node.x, node.y) for node in net.nodes()]
    for link in net.links():
        coords.extend((vx, vy) for vx, vy in link.vertices)

    xs = array([c[0] for c in coords])
    ys = array([c[1] for c in coords])

    pad = 1.0
    return nmin(xs) - pad, nmin(ys) - pad, nmax(xs) + pad, nmax(ys) + pad
开发者ID:kamelisl,项目名称:TMGToolbox,代码行数:23,代码来源:spatial_index.py

示例5: discrepancy_plot

def discrepancy_plot(
    data, name="discrepancy", report_p=True, format="png", suffix="-gof", path="./", fontmap=None, verbose=1
):
    """Generate a goodness-of-fit scatter plot of observed vs. simulated deviates.

    :Arguments:
        data: observed/simulated deviates, either as N (observed, simulated)
            pairs or as two rows [observed, simulated].
        name: base name of the output file (default "discrepancy").
        report_p: if True, annotate the plot with the fraction of pairs whose
            simulated deviate exceeds the observed one ("p=...").
        format: graphic output format / file extension (default "png").
        suffix: file-name suffix (default "-gof").
        path: output directory, created (with missing parents) if needed.
        fontmap: accepted for API compatibility; not used by this function.
        verbose: if > 0, print a progress message.

    The figure is written to ``<path>/<name><suffix>.<format>``.
    """
    if verbose > 0:
        print_("Plotting", name + suffix)

    # Generate new scatter plot
    figure()
    try:
        x, y = transpose(data)
    except ValueError:
        # data was given as two rows [observed, simulated].
        # NOTE(review): a 2x2 input is always read as two (observed, simulated) pairs.
        x, y = data
    scatter(x, y)

    # Plot x=y line, extended by 10% of the data range on each side
    lo = nmin(ravel(data))
    hi = nmax(ravel(data))
    datarange = hi - lo
    lo -= 0.1 * datarange
    hi += 0.1 * datarange
    pyplot((lo, hi), (lo, hi))

    # Plot options
    xlabel("Observed deviates", fontsize="x-small")
    ylabel("Simulated deviates", fontsize="x-small")

    if report_p:
        # Proportion of pairs where the simulated deviate exceeds the observed one.
        # float() forces true division: under Python 2, int / int would truncate to 0 or 1.
        count = sum(s > o for o, s in zip(x, y))
        text(
            lo + 0.1 * datarange,
            hi - 0.1 * datarange,
            "p=%.3f" % (float(count) / len(x)),
            horizontalalignment="center",
            fontsize=10,
        )

    # Save to file; makedirs (unlike mkdir) also creates missing parent directories
    if path and not os.path.exists(path):
        os.makedirs(path)
    savefig(os.path.join(path, "%s%s.%s" % (name, suffix, format)))
开发者ID:roban,项目名称:pymc,代码行数:48,代码来源:Matplot.py

示例6: csv_output

    def csv_output(self):
        """Write the results of an update test to a CSV file.

        For every SIB key in ``self.results``, and every block length in
        ``self.results[sib]`` (visited in ascending integer order), one row is
        written: the SIB, the block length, every recorded time, then the
        mean, min, max and variance of those times rounded to 3 decimals.
        The file name is relative, so the file lands in the current working
        directory.
        """

        # determine the file name
        csv_filename = "update-%s-%sstep-%smax-%siter-%s-%s.csv" % (self.updatetype,
                                                                    self.step,
                                                                    self.limit,
                                                                    self.iterations,
                                                                    self.chart_type.lower(),
                                                                    self.testdatetime)

        # the with-block guarantees the file is closed even if writing a row fails
        with open(csv_filename, "w") as csvfile_stream:
            csvfile_writer = csv.writer(csvfile_stream, delimiter=',', quoting=csv.QUOTE_MINIMAL)

            for sib, results_by_length in self.results.items():

                # iterate over the possible block lengths, in numeric order
                for triple_length in sorted(results_by_length.keys(), key=int):
                    times = results_by_length[triple_length]

                    # SIB, block length, then all the measured times
                    row = [sib, triple_length] + list(times)

                    # add the mean, min, max and variance of the times to the row
                    row.append(round(nmean(times), 3))
                    row.append(round(nmin(times), 3))
                    row.append(round(nmax(times), 3))
                    row.append(round(nvar(times), 3))

                    csvfile_writer.writerow(row)
开发者ID:desmovalvo,项目名称:pes,代码行数:43,代码来源:update_test.py

示例7: _creer_cmap

    def _creer_cmap(self, seuils):
        """Build a piecewise-constant colormap whose colour changes at thresholds.

        The range [min(self._Z), max(self._Z)] is mapped affinely onto [0, 1].
        Each threshold lying strictly inside that range becomes a breakpoint.
        The interval between breakpoint k and breakpoint k+1 is filled with
        ``self.couleurs[k % len(self.couleurs)]``, so the colours are reused
        cyclically.

        :param seuils: iterable of threshold values, in the units of self._Z.
        :return: a ``LinearSegmentedColormap`` named 'seuils' with 256 levels.
        """
        zmax = nmax(self._Z)
        zmin = nmin(self._Z)
        delta = zmax - zmin
        # Keep only thresholds strictly inside (zmin, zmax) (NB: < and not <=).
        # This also means the division below never runs when delta == 0.
        # Sorting keeps the breakpoints increasing, as LinearSegmentedColormap requires.
        interieurs = sorted(z for z in seuils if zmin < z < zmax)
        # Bring them into [0, 1] by the affine map z -> (z - zmin) / delta
        points = [0] + [(z - zmin) / float(delta) for z in interieurs] + [1]

        cdict = {'red': [], 'green': [], 'blue': []}

        def add_col(val, color1, color2):
            # (x, value just below x, value just above x) for each channel
            cdict['red'].append((val, color1[0], color2[0]))
            cdict['green'].append((val, color1[1], color2[1]))
            cdict['blue'].append((val, color1[2], color2[2]))

        n = len(self.couleurs)
        for i, seuil in enumerate(points):
            # below breakpoint i: previous colour; above it: colour i
            add_col(seuil, self.couleurs[(i - 1) % n], self.couleurs[i % n])
        return LinearSegmentedColormap('seuils', cdict, 256)
开发者ID:wxgeo,项目名称:geophar,代码行数:20,代码来源:__init__.py

示例8: CompareFortran

def CompareFortran(**args):
	"""Propagate the config_compare_fortran.ini problem and compare with Fortran data.

	Prints |<psi(t)|psi(0)>|^2 at every time yielded by prop.Advance(5) and
	at the final propagated time. It then prints the largest absolute
	difference between |psi|^2 and the reference data loaded from
	fortran_propagation.dat (fdata[1:]).

	Keyword arguments are accepted but not used.

	Returns the propagated pyprop.Problem.
	"""
	conf = pyprop.Load("config_compare_fortran.ini")
	prop = pyprop.Problem(conf)
	prop.SetupStep()

	init = prop.psi.Copy()

	for t in prop.Advance(5):
		corr = abs(prop.psi.InnerProduct(init))**2
		print("Time = %f, initial state correlation = %f" % (t, corr))

	corr = abs(prop.psi.InnerProduct(init))**2
	t = prop.PropagatedTime
	print("Time = %f, initial state correlation = %f" % (t, corr))

	#Load fortran data and compare. The absolute value of the difference is
	#taken so that deviations in either direction are reported.
	fdata = pylab.load("fortran_propagation.dat")
	maxDiff = nmax(abs(abs(prop.psi.GetData())**2 - fdata[1:]))
	print("Max difference pyprop/fortran: %e" % maxDiff)

	return prop
开发者ID:AtomAleks,项目名称:PyProp,代码行数:20,代码来源:example.py

示例9: discrepancy_plot

def discrepancy_plot(data, name, report_p=True, format='png', suffix='-gof', path='./', fontmap=None, verbose=1):
    """Generate a goodness-of-fit scatter plot of observed vs. simulated deviates.

    :Arguments:
        data: observed/simulated deviates, either as N (observed, simulated)
            pairs or as two rows [observed, simulated].
        name: base name of the output file.
        report_p: if True, annotate the plot with the fraction of pairs whose
            simulated deviate exceeds the observed one ("p=...").
        format: graphic output format / file extension (default 'png').
        suffix: file-name suffix (default '-gof').
        path: output directory, created (with missing parents) if needed.
        fontmap: accepted for API compatibility; not used by this function.
            It defaults to None rather than a dict shared between calls.
        verbose: if > 0, print a progress message.

    The figure is written to ``<path>/<name><suffix>.<format>``.
    """
    if verbose > 0:
        print('Plotting %s' % (name + suffix))

    # Generate new scatter plot
    figure()
    try:
        x, y = transpose(data)
    except ValueError:
        # data was given as two rows [observed, simulated].
        # NOTE(review): a 2x2 input is always read as two (observed, simulated) pairs.
        x, y = data
    scatter(x, y)

    # Plot x=y line, extended by 10% of the data range on each side
    lo = nmin(ravel(data))
    hi = nmax(ravel(data))
    datarange = hi - lo
    lo -= 0.1 * datarange
    hi += 0.1 * datarange
    pyplot((lo, hi), (lo, hi))

    # Plot options
    xlabel('Observed deviates', fontsize='x-small')
    ylabel('Simulated deviates', fontsize='x-small')

    if report_p:
        # Proportion of pairs where the simulated deviate exceeds the observed one.
        # float() forces true division: under Python 2, int / int truncates to 0 or 1.
        count = sum(s > o for o, s in zip(x, y))
        text(lo + 0.1 * datarange, hi - 0.1 * datarange,
             'p=%.3f' % (float(count) / len(x)), horizontalalignment='center',
             fontsize=10)

    # Save to file; makedirs (unlike mkdir) also creates missing parent directories
    if path and not os.path.exists(path):
        os.makedirs(path)
    savefig(os.path.join(path, "%s%s.%s" % (name, suffix, format)))
开发者ID:along1x,项目名称:pymc,代码行数:38,代码来源:Matplot.py

示例10: discrepancy_plot

def discrepancy_plot(
    data, name='discrepancy', report_p=True, format='png', suffix='-gof', path='./',
        fontmap=None, verbose=1):
    '''
    Generate goodness-of-fit deviate scatter plot.

    :Arguments:
        data: list
            List (or list of lists for vector-valued variables) of discrepancy values, output
            from the `pymc.diagnostics.discrepancy` function. Accepted either as
            (observed, simulated) pairs or as two rows [observed, simulated].

        name: string
            The name of the plot.

        report_p: bool
            Flag for annotating the p-value to the plot.

        format (optional): string
            Graphic output format (defaults to png).

        suffix (optional): string
            Filename suffix (defaults to "-gof").

        path (optional): string
            Specifies location for saving plots (defaults to local directory).
            The directory, including missing parents, is created if needed.

        fontmap (optional): dict
            Accepted for API compatibility; not used by this function.

        verbose (optional): int
            If > 0, print a progress message (defaults to 1).

    '''

    # verbose was previously read here without being a parameter (NameError)
    if verbose > 0:
        print_('Plotting', name + suffix)

    # Generate new scatter plot
    figure()
    try:
        x, y = transpose(data)
    except ValueError:
        # data was given as two rows [observed, simulated].
        # NOTE(review): a 2x2 input is always read as two (observed, simulated) pairs.
        x, y = data
    scatter(x, y)

    # Plot x=y line, extended by 10% of the data range on each side
    lo = nmin(ravel(data))
    hi = nmax(ravel(data))
    datarange = hi - lo
    lo -= 0.1 * datarange
    hi += 0.1 * datarange
    pyplot((lo, hi), (lo, hi))

    # Plot options
    xlabel('Observed deviates', fontsize='x-small')
    ylabel('Simulated deviates', fontsize='x-small')

    if report_p:
        # Proportion of pairs where the simulated deviate exceeds the observed one.
        # float() forces true division: under Python 2, int / int truncates to 0 or 1.
        count = sum(s > o for o, s in zip(x, y))
        text(lo + 0.1 * datarange, hi - 0.1 * datarange,
             'p=%.3f' % (float(count) / len(x)), horizontalalignment='center',
             fontsize=10)

    # Save to file; makedirs (unlike mkdir) also creates missing parent directories
    if path and not os.path.exists(path):
        os.makedirs(path)
    savefig(os.path.join(path, "%s%s.%s" % (name, suffix, format)))
开发者ID:Gwill,项目名称:pymc,代码行数:70,代码来源:Matplot.py

示例11: plot

def plot(
    data, name, format='png', suffix='', path='./', common_scale=True, datarange=(None, None),
        new=True, last=True, rows=1, num=1, fontmap=None, verbose=1):
    """
    Generates summary plots for nodes of a given PyMC object.

    For one-dimensional data a trace plot, an autocorrelation plot and a
    histogram are drawn. Otherwise the first two axes of the data are
    swapped and each slice along the original second axis is plotted
    recursively.

    :Arguments:
        data: PyMC object, trace or array
            A trace from an MCMC sample or a PyMC object with one or more traces.

        name: string
            The name of the object.

        format (optional): string
            Graphic output format (defaults to png).

        suffix (optional): string
            Filename suffix.

        path (optional): string
            Specifies location for saving plots (defaults to local directory).

        common_scale (optional): bool
            Specifies whether plots of multivariate nodes should be on the same scale
            (defaults to True).

        datarange (optional): tuple
            (lower, upper) range passed to the trace and histogram plots of
            one-dimensional data. For multi-dimensional data the passed value
            is discarded and replaced by (None, None), or by the common range
            when common_scale is True.

        new (optional): bool
            For one-dimensional data, open a new 10x6 figure before plotting
            (defaults to True).

        last (optional): bool
            For one-dimensional data, save the figure to
            <path>/<name><suffix>.<format> after plotting (defaults to True).

        rows, num (optional): int
            Subplot layout: the number of variable rows in the figure and the
            1-based row of this variable; used to compute subplot indices.

        fontmap (optional): dict
            Passed through to the individual plots; defaults to
            {1: 10, 2: 8, 3: 6, 4: 5, 5: 4}. Presumably maps a number of rows
            to a font size -- confirm in trace/histogram.

        verbose (optional): int
            If > 0, print a progress message for one-dimensional data.

    """
    if fontmap is None:
        fontmap = {1: 10, 2: 8, 3: 6, 4: 5, 5: 4}

    # If there is only one data array, go ahead and plot it ...
    if ndim(data) == 1:

        if verbose > 0:
            print_('Plotting', name)

        # If new plot, generate new frame
        if new:

            figure(figsize=(10, 6))

        # Layout: trace and autocorrelation share the left column of a
        # (rows*2) x 2 grid, at positions 4*num-3 and 4*num-1. The histogram
        # uses the right column of a rows x 2 grid, at position 2*num.
        # NOTE(review): this assumes the helpers number subplots row-major,
        # like subplot(rows, columns, num) -- confirm.
        # Call trace
        trace(
            data,
            name,
            datarange=datarange,
            rows=rows * 2,
            columns=2,
            num=num + 3 * (num - 1),
            last=last,
            fontmap=fontmap)
        # Call autocorrelation
        autocorrelation(
            data,
            name,
            rows=rows * 2,
            columns=2,
            num=num + 3 * (
                num - 1) + 2,
            last=last,
            fontmap=fontmap)
        # Call histogram
        histogram(
            data,
            name,
            datarange=datarange,
            rows=rows,
            columns=2,
            num=num * 2,
            last=last,
            fontmap=fontmap)

        # Save the figure once its last subplot has been drawn.
        # Note that os.mkdir only creates the final path component.
        if last:
            if not os.path.exists(path):
                os.mkdir(path)
            if not path.endswith('/'):
                path += '/'
            savefig("%s%s%s.%s" % (path, name, suffix, format))

    else:
        # ... otherwise plot recursively
        tdata = swapaxes(data, 0, 1)

        # NOTE(review): this discards any caller-supplied datarange for
        # multi-dimensional data -- confirm that is intended.
        datarange = (None, None)
        # Determine common range for plots
        if common_scale:
            datarange = (nmin(tdata), nmax(tdata))

        # How many rows? (at most 4 components per figure)
        _rows = min(4, len(tdata))

        for i in range(len(tdata)):

            # New plot or adding to existing? (a new figure every _rows components)
            _new = not i % _rows
            # Current subplot number
            _num = i % _rows + 1
            # Final subplot of current figure?
            _last = (_num == _rows) or (i == len(tdata) - 1)

#.........这里部分代码省略.........
开发者ID:Gwill,项目名称:pymc,代码行数:101,代码来源:Matplot.py

示例12: summary_plot


#.........这里部分代码省略.........
            traces = [variable.trace(chain=chain)]
        else:
            # Collect the trace of every chain recorded in the trace database
            chains = variable.trace.db.chains
            traces = [variable.trace(chain=i) for i in range(chains)]

        if gs is None:
            # Initialize plot. When rhat is requested and there are several
            # chains, a second, narrower column (width ratio 3:1) is added,
            # presumably for the R-hat statistic.
            # NOTE(review): on the single-chain branch `chains` is not assigned
            # in the lines shown here; confirm it is set earlier in the function.
            if rhat and chains > 1:
                gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])

            else:

                gs = gridspec.GridSpec(1, 1)

            # Subplot for confidence intervals
            interval_plot = subplot(gs[0])

        # Get quantiles (one result per chain)
        data = [calc_quantiles(d, quantiles) for d in traces]
        if hpd:
            # Substitute HPD interval: the first and last requested quantiles
            # are overwritten with the bounds returned by calc_hpd(d, alpha)
            for i, d in enumerate(traces):
                hpd_interval = calc_hpd(d, alpha)
                data[i][quantiles[0]] = hpd_interval[0]
                data[i][quantiles[-1]] = hpd_interval[1]

        # For each chain, list the values in the order given by `quantiles`
        data = [[d[q] for q in quantiles] for d in data]
        # Ensure x-axis contains range of current interval
        # (widen the running plotrange so it covers every value in data)
        if plotrange:
            plotrange = [min(
                         plotrange[0],
                         nmin(data)),
                         max(plotrange[1],
                             nmax(data))]
        else:
            plotrange = [nmin(data), nmax(data)]

        try:
            # First try missing-value stochastic
            value = variable.get_stoch_value()
        except AttributeError:
            # All other variable types
            value = variable.value

        # Number of elements in current variable
        k = size(value)
        
        # Append variable name(s) to list. For multi-element variables, a
        # leading dimension of length 1 is skipped when generating element names.
        if k > 1:
            names = var_str(varname, shape(value)[int(shape(value)[0]==1):])
            labels += names
        else:
            labels.append(varname)
            # labels.append('\n'.join(varname.split('_')))

        # Add spacing for each chain, if more than one: 0 for the first chain,
        # then offsets of alternating sign for the others.
        # NOTE(review): (i + 2) / 2 is integer division on Python 2 but true
        # division on Python 3, so the offsets differ between versions.
        e = [0] + [(chain_spacing * ((i + 2) / 2)) * (
            -1) ** i for i in range(chains - 1)]

        # Loop over chains
        for j, quants in enumerate(data):

            # Deal with multivariate nodes: flatten each quantile's values
            if k > 1:
                ravelled_quants = list(map(ravel, quants))
                
开发者ID:Gwill,项目名称:pymc,代码行数:66,代码来源:Matplot.py

示例13: plot

    def plot( self, ax ):
        '''Draw the execution-time comparison onto the matplotlib axes ``ax``.

        If ``self.exec_time_arr`` has one row, that row is drawn as a single bar
        series. If it has two rows, row 0 ('C') and row 1 ('numpy') are drawn as
        side-by-side bars, the ratio row 1 / row 0 is plotted on a twin y-axis,
        and a twin x-axis on top labels each group with ``self.n_int_arr[0, :]``.
        Any other number of rows draws nothing.
        '''

        exec_time_arr = self.exec_time_arr
        n_int_arr = self.n_int_arr[0, :]

        # x positions 1..len(rand_list), one group of bars per entry
        rand_arr = arange( len( self.rand_list ) ) + 1
        width = 0.45

        if exec_time_arr.shape[0] == 1:
            shift = width / 2.0
            ax.bar( rand_arr - shift, exec_time_arr[0, :], width, color = 'lightgrey' )

        elif exec_time_arr.shape[0] == 2:
            max_exec_time = nmax( exec_time_arr )

            # raw strings: the TeX backslashes (\m, \, \;) are not Python escapes
            ax.set_ylabel( r'$\mathrm{execution \, time \, [sec]}$', size = 20 )
            ax.set_xlabel( r'$n_{\mathrm{rnd}}  \;-\; \mathrm{number \, of \, random \, parameters}$', size = 20 )

            ax.bar( rand_arr - width, exec_time_arr[0, :], width,
                    hatch = '/', color = 'white', label = 'C' )
            ax.bar( rand_arr, exec_time_arr[1, :], width,
                    color = 'lightgrey', label = 'numpy' )

            # leave 25 % head room above the highest bar / ratio point
            yscale = 1.25
            ax_xlim = rand_arr[-1] + 1
            ax_ylim = max_exec_time * yscale

            ax.set_xlim( 0, ax_xlim )
            ax.set_ylim( 0, ax_ylim )

            # numpy/C time ratio on a second y-axis sharing the x-axis
            ax2 = ax.twinx()
            ydata = exec_time_arr[1, :] / exec_time_arr[0, :]
            ax2.plot( rand_arr, ydata, '-o', color = 'black',
                      linewidth = 1, label = 'numpy/C' )

            # horizontal reference line at ratio 1
            ax2.plot( [rand_arr[0] - 1, rand_arr[-1] + 1], [1, 1], '-' )
            ax2.set_ylabel( r'$\mathrm{time}(  \mathsf{numpy}  ) / \mathrm{ time }(\mathsf{C}) \; [-]$', size = 20 )
            ax2_ylim = nmax( ydata ) * yscale
            ax2_xlim = rand_arr[-1] + 1
            ax2.set_ylim( 0, ax2_ylim )
            ax2.set_xlim( 0, ax2_xlim )

            ax.set_xticks( rand_arr )
            ax.set_xticklabels( rand_arr, size = 14 )
            # secondary x-axis on top showing n_int for each group
            xticks = [ '%.2g' % n_int for n_int in n_int_arr ]
            ax3 = ax.twiny()
            ax3.set_xlim( 0, rand_arr[-1] + 1 )
            ax3.set_xticks( rand_arr )
            ax3.set_xlabel( r'$n_{\mathrm{int}}$', size = 20 )
            ax3.set_xticklabels( xticks, rotation = 30 )

            # set the tick label size of the lower X axis
            X_lower_tick = 14
            for t in ax.get_xticklabels():
                t.set_fontsize( X_lower_tick )

            # set the tick label size of the upper X axis
            X_upper_tick = 12
            for t in ax3.get_xticklabels():
                t.set_fontsize( X_upper_tick )

            # set the tick label size of the Y axes
            Y_tick = 14
            for t in ax2.get_yticklabels() + ax.get_yticklabels():
                t.set_fontsize( Y_tick )

            # set the legend position and font size
            leg_fontsize = 16
            leg = ax.legend( loc = ( 0.02, 0.83 ) )
            for t in leg.get_texts():
                t.set_fontsize( leg_fontsize )
            leg = ax2.legend( loc = ( 0.705, 0.90 ) )
            for t in leg.get_texts():
                t.set_fontsize( leg_fontsize )
开发者ID:sarosh-quraishi,项目名称:simvisage,代码行数:78,代码来源:performance_db.py

示例14: summary_plot


#.........这里部分代码省略.........
        varname = variable.__name__

        # Retrieve trace(s): request chain 0, 1, 2, ... until the trace
        # lookup raises KeyError
        i = 0
        traces = []
        while True:
           try:
               #traces.append(pymc_obj.trace(varname, chain=i)[:])
               traces.append(variable.trace(chain=i))
               i+=1
           except KeyError:
               break
               
        chains = len(traces)
        
        if gs is None:
            # Initialize plot. When rhat is requested and there are several
            # chains, a second, narrower column (width ratio 3:1) is added,
            # presumably for the R-hat statistic.
            if rhat and chains>1:
                gs = gridspec.GridSpec(1, 2, width_ratios=[3,1])

            else:
                
                gs = gridspec.GridSpec(1, 1)
                
            # Subplot for confidence intervals
            interval_plot = subplot(gs[0])
                
        # Get quantiles, then list each chain's values in the order of `quantiles`
        data = [calc_quantiles(d, quantiles) for d in traces]
        data = [[d[q] for q in quantiles] for d in data]
        
        # Ensure x-axis contains range of current interval
        # (widen the running plotrange so it covers every value in data)
        if plotrange:
            plotrange = [min(plotrange[0], nmin(data)), max(plotrange[1], nmax(data))]
        else:
            plotrange = [nmin(data), nmax(data)]
        
        try:
            # First try missing-value stochastic
            value = variable.get_stoch_value()
        except AttributeError:
            # All other variable types
            value = variable.value

        # Number of elements in current variable
        k = size(value)
        
        # Append variable name(s) to list; for scalars, underscores in the
        # name become line breaks in the label
        if k>1:
            names = var_str(varname, shape(value))
            labels += names
        else:
            labels.append('\n'.join(varname.split('_')))
            
        # Add spacing for each chain, if more than one: 0 for the first chain,
        # then offsets of alternating sign for the others.
        # NOTE(review): (i+2)/2 is integer division on Python 2 but true
        # division on Python 3, so the offsets differ between versions.
        e = [0] + [(chain_spacing * ((i+2)/2))*(-1)**i for i in range(chains-1)]
        
        # Loop over chains
        for j,quants in enumerate(data):
            
            # Deal with multivariate nodes
            if k>1:

                for i,q in enumerate(transpose(quants)):
                    
                    # Y coordinate with jitter
开发者ID:along1x,项目名称:pymc,代码行数:67,代码来源:Matplot.py

示例15: PropagateWavePacket

def PropagateWavePacket(**args):
	"""Propagate a travelling wavepacket and plot it in position and momentum space.

	Keyword arguments are forwarded to SetupProblem. The initial state is
	conf.Wavepacket.function expanded in the B-spline basis and normalized.
	At every time yielded by prop.Advance(40):
	  * the norm and the overlap |<psi(t)|psi(0)>|^2 are printed,
	  * |psi(x)|^2 is plotted in the upper panel,
	  * the peak-normalized |FFT(psi)|^2, evaluated on an equispaced x grid,
	    is plotted on a log scale in the lower panel, restricted to |k| <= 5*k0.

	Returns the propagated problem object.
	"""
	#Set up problem
	prop = SetupProblem(**args)
	conf = prop.Config

	#Setup traveling wavepacket initial state
	f = lambda x: conf.Wavepacket.function(conf.Wavepacket, x)
	bspl = prop.psi.GetRepresentation().GetRepresentation(0).GetBSplineObject()
	c = bspl.ExpandFunctionInBSplines(f)
	prop.psi.GetData()[:] = c
	prop.psi.Normalize()
	initialPsi = prop.psi.Copy()

	#Get x-grid
	subProp = prop.Propagator.SubPropagators[0]
	subProp.InverseTransform()
	grid = prop.psi.GetRepresentation().GetLocalGrid(0)
	subProp.ForwardTransform()

	#Setup equispaced x grid
	x_min = conf.BSplineRepresentation.xmin
	x_max = conf.BSplineRepresentation.xmax
	grid_eq = linspace(x_min, x_max, grid.size)
	x_spacing = grid_eq[1] - grid_eq[0]

	#Set up fft grid: momenta conjugate to grid_eq, in the same order as the
	#output of fft.fft (zero, positive k, then negative k). fftfreq always
	#yields exactly grid_eq.size points. Building the two halves with
	#r_[start:stop:step] made their lengths depend on float rounding.
	k_max = pi / x_spacing
	k_min = -k_max
	k_spacing = (k_max - k_min) / grid_eq.size
	grid_fft = 2 * pi * fft.fftfreq(grid_eq.size, d=x_spacing)
	print("Momentum space resolution = %f a.u." % k_spacing)

	k0 = conf.Wavepacket.k0
	k0_trunk = 5 * k0
	#Indices of the momenta shown in the lower panel
	trunkIdx = list(nwhere(abs(grid_fft) <= k0_trunk)[0])

	rcParams['interactive'] = True
	figure()
	p1 = subplot(211)
	p2 = subplot(212)
	p1.hold(False)

	psi_eq = zeros((grid.size), dtype=complex)

	for t in prop.Advance(40):
		print("t = %f, norm = %.15f, P = %.15f " %
			(t, prop.psi.GetNorm(), abs(prop.psi.InnerProduct(initialPsi))**2))
		sys.stdout.flush()

		#Plot position space |psi|**2 on the propagation grid
		subProp.InverseTransform()
		p1.plot(grid, abs(prop.psi.GetData())**2)
		subProp.ForwardTransform()

		#Evaluate psi on the equispaced grid and take its power spectrum
		bspl.ConstructFunctionFromBSplineExpansion(prop.psi.GetData(), grid_eq, psi_eq)
		psi_fft = (abs(fft.fft(psi_eq))**2)
		psi_fft_max = nmax(psi_fft)
		psi_fft /= psi_fft_max

		#Plot momentum space |psi|**2 (the 1e-21 offset avoids log(0))
		#NOTE(review): psi_fft is normalized to a peak of 1, but the marker
		#lines and y-limit below still use the un-normalized psi_fft_max --
		#confirm intended.
		p2.hold(False)
		p2.semilogy(grid_fft[trunkIdx], psi_fft[trunkIdx] + 1e-21)
		p2.hold(True)
		p2.semilogy([0,0], [1e-20, psi_fft_max], 'r-')
		p2.semilogy([-k0,-k0], [1e-20, psi_fft_max], 'g--')
		p2.semilogy([k0,k0], [1e-20, psi_fft_max], 'g--')

		#Set subplot axis
		p1.axis([grid[0],grid[-1],0,0.1])
		p2.axis([-k0_trunk, k0_trunk, 1e-20, psi_fft_max])
		show()

	hold(True)

	return prop
开发者ID:AtomAleks,项目名称:PyProp,代码行数:77,代码来源:example.py


注:本文中的numpy.nmax函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。