当前位置: 首页>>代码示例>>Python>>正文


Python generator.Generator方法代码示例

本文整理汇总了Python中generator.Generator方法的典型用法代码示例。如果您正苦于以下问题:Python generator.Generator方法的具体用法?Python generator.Generator怎么用?Python generator.Generator使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在generator的用法示例。


下文共展示了 generator.Generator 类的 15 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的 Python 代码示例。

示例1: __init__

# 需要导入模块: import generator [as 别名]
# 或者: from generator import Generator [as 别名]
def __init__(self, width=28, height=28, channels=1, latent_size=100,
             epochs=50000, batch=32, checkpoint=50, model_type=-1):
        """Configure an MNIST GAN trainer, build its networks and load data.

        The image geometry, training settings and latent size are stored as
        upper-case attributes. A ``Generator`` and a ``Discriminator`` are
        built for that geometry, and their inner models (``.Generator`` and
        ``.Discriminator``, not the wrapper objects) are combined into a
        ``GAN``. Finally ``self.load_MNIST()`` is called.
        """
        self.W, self.H, self.C = width, height, channels
        self.EPOCHS, self.BATCH, self.CHECKPOINT = epochs, batch, checkpoint
        self.model_type = model_type

        self.LATENT_SPACE_SIZE = latent_size

        geometry = dict(height=self.H, width=self.W, channels=self.C)
        self.generator = Generator(**geometry, latent_size=self.LATENT_SPACE_SIZE)
        self.discriminator = Discriminator(**geometry)
        self.gan = GAN(generator=self.generator.Generator,
                       discriminator=self.discriminator.Discriminator)

        self.load_MNIST()
开发者ID:PacktPublishing,项目名称:Generative-Adversarial-Networks-Cookbook,代码行数:18,代码来源:train.py

示例2: plot_checkpoint

# 需要导入模块: import generator [as 别名]
# 或者: from generator import Generator [as 别名]
def plot_checkpoint(self,e):
        """Save a 4x4 grid of generated samples for checkpoint ``e``.

        Sixteen latent vectors are drawn with ``self.sample_latent_space``,
        passed through the generator model, and each result is reshaped to
        ``(H, W)`` and shown in grayscale. The grid goes to
        ``/data/sample_<e>.png`` and all figures are closed afterwards.
        """
        filename = "/data/sample_" + str(e) + ".png"
        images = self.generator.Generator.predict(self.sample_latent_space(16))

        plt.figure(figsize=(10, 10))
        for idx, img in enumerate(images):
            plt.subplot(4, 4, idx + 1)
            plt.imshow(np.reshape(img[:, :, :], [self.H, self.W]), cmap='gray')
            plt.axis('off')
        plt.tight_layout()
        plt.savefig(filename)
        plt.close('all')
开发者ID:PacktPublishing,项目名称:Generative-Adversarial-Networks-Cookbook,代码行数:19,代码来源:train.py

示例3: __init__

# 需要导入模块: import generator [as 别名]
# 或者: from generator import Generator [as 别名]
def __init__(self, width=28, height=28, channels=1, latent_size=100,
             epochs=50000, batch=32, checkpoint=50, model_type=-1,
             data_path=''):
        """Configure a DCGAN trainer, build its networks and load its data.

        Geometry and training settings are stored as attributes. Both networks
        are always built with ``model_type='DCGAN'``; the ``model_type``
        argument is only stored on the instance. The inner models are combined
        into a ``GAN``, and training data is loaded with
        ``self.load_npy(data_path)``.
        """
        self.W, self.H, self.C = width, height, channels
        self.EPOCHS, self.BATCH, self.CHECKPOINT = epochs, batch, checkpoint
        self.model_type = model_type

        self.LATENT_SPACE_SIZE = latent_size

        geometry = dict(height=self.H, width=self.W, channels=self.C)
        self.generator = Generator(**geometry, latent_size=self.LATENT_SPACE_SIZE,
                                   model_type='DCGAN')
        self.discriminator = Discriminator(**geometry, model_type='DCGAN')
        self.gan = GAN(generator=self.generator.Generator,
                       discriminator=self.discriminator.Discriminator)

        self.load_npy(data_path)
开发者ID:PacktPublishing,项目名称:Generative-Adversarial-Networks-Cookbook,代码行数:19,代码来源:train.py

示例4: main

# 需要导入模块: import generator [as 别名]
# 或者: from generator import Generator [as 别名]
def main(m_path, img_path, out_dir, light=False):
    """Translate one image with the latest checkpoint found in ``m_path``.

    Args:
        m_path: Directory passed to ``tf.train.latest_checkpoint``.
        img_path: Input image; it is converted to RGB before inference.
        out_dir: Output directory, created if missing; ``""`` means the
            current directory. The output keeps the input's file name.
        light: Build the light ``Generator`` variant. It must match how the
            checkpoint was trained.

    Exits with status 1 if building the generator or loading its weights
    raises ``ValueError``.
    """
    logger = get_logger("inference")
    logger.info(f"generating image from {img_path}")
    try:
        g = Generator(light=light)
        g.load_weights(tf.train.latest_checkpoint(m_path))
    except ValueError as e:
        logger.error(e)
        logger.error("Failed to load specified weight.")
        logger.error("If you trained your model with --light, "
                     "consider adding --light when executing this script; otherwise, "
                     "do not add --light when executing this script.")
        # `exit()` is an interactive helper injected by `site`; it is not
        # guaranteed to exist in scripts. SystemExit works everywhere.
        raise SystemExit(1)

    # Use a context manager so the image file handle is released promptly.
    with Image.open(img_path) as pil_img:
        img = np.array(pil_img.convert("RGB"))
    # uint8 [0, 255] -> float32 [-1, 1], plus a leading batch axis.
    img = np.expand_dims(img, 0).astype(np.float32) / 127.5 - 1
    # Inverse of the input scaling; clip so out-of-range values cannot wrap
    # around when cast to uint8.
    out = np.clip((g(img).numpy().squeeze() + 1) * 127.5, 0, 255).astype(np.uint8)

    out_dir = out_dir or "."
    # exist_ok avoids the check-then-create race of an isdir() pre-check.
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, os.path.basename(img_path))
    imwrite(out_path, out)
    logger.info(f"generated image saved to {out_path}")
开发者ID:mnicnc404,项目名称:CartoonGan-tensorflow,代码行数:25,代码来源:inference_with_ckpt.py

示例5: train

# 需要导入模块: import generator [as 别名]
# 或者: from generator import Generator [as 别名]
def train(self):
        """Alternate discriminator and generator updates for ``self.EPOCHS`` steps.

        Each step does three things:
          1. Trains the discriminator on one batch of ``self.BATCH`` samples.
             ``BATCH // 2`` of them are a contiguous random slice of
             ``self.X_train`` labelled 1. The rest are generator outputs
             labelled 0.
          2. Trains the combined GAN model on ``self.BATCH`` latent samples,
             all labelled 1.
          3. Prints both losses and, every ``self.CHECKPOINT`` steps, calls
             ``self.plot_checkpoint(e)``.
        """
        count_real_images = self.BATCH // 2
        # Give the remainder to the generated half so images and labels always
        # have the same length, even when BATCH is odd.
        count_generated_images = self.BATCH - count_real_images

        for e in range(self.EPOCHS):
            # Real half: contiguous slice starting at a random index.
            starting_index = randint(0, len(self.X_train) - count_real_images)
            real_images_raw = self.X_train[starting_index:starting_index + count_real_images]
            # NOTE(review): reshaped as (W, H, C). This matches the common
            # (H, W, C) layout only for square images; confirm against the
            # generator's output shape before using non-square data.
            x_real_images = real_images_raw.reshape(count_real_images, self.W, self.H, self.C)
            y_real_labels = np.ones([count_real_images, 1])

            # Generated half.
            latent_space_samples = self.sample_latent_space(count_generated_images)
            x_generated_images = self.generator.Generator.predict(latent_space_samples)
            y_generated_labels = np.zeros([count_generated_images, 1])

            x_batch = np.concatenate([x_real_images, x_generated_images])
            y_batch = np.concatenate([y_real_labels, y_generated_labels])
            discriminator_loss = self.discriminator.Discriminator.train_on_batch(x_batch, y_batch)[0]

            # Generator step through the combined model: label noise as "real".
            x_latent_space_samples = self.sample_latent_space(self.BATCH)
            y_generated_labels = np.ones([self.BATCH, 1])
            generator_loss = self.gan.gan_model.train_on_batch(x_latent_space_samples, y_generated_labels)

            print('Epoch: ' + str(int(e)) + ', [Discriminator :: Loss: ' + str(discriminator_loss)
                  + '], [ Generator :: Loss: ' + str(generator_loss) + ']')

            if e % self.CHECKPOINT == 0:
                self.plot_checkpoint(e)
        return
开发者ID:PacktPublishing,项目名称:Generative-Adversarial-Networks-Cookbook,代码行数:35,代码来源:train.py

示例6: __init__

# 需要导入模块: import generator [as 别名]
# 或者: from generator import Generator [as 别名]
def __init__(self, height=64, width=64, epochs=50000, batch=32, checkpoint=50,
             train_data_path_A='', train_data_path_B='',
             test_data_path_A='', test_data_path_B='',
             lambda_cycle=10.0, lambda_id=1.0):
        """Load both image domains and wire up the CycleGAN training graph.

        Train/test data for domains A and B are loaded first; each
        ``load_data`` result is unpacked as (images, height, width, channels).
        The A->B and B->A generators are built, then one discriminator per
        domain. Finally the combined ``GAN`` is assembled. Its outputs, in
        order, are:

        - the discriminator verdicts on fake A and fake B,
        - the two cycle reconstructions,
        - the two identity mappings.

        ``lambda_cycle`` and ``lambda_id`` are forwarded to ``GAN``.
        """
        self.EPOCHS = epochs
        self.BATCH = batch
        self.RESIZE_HEIGHT, self.RESIZE_WIDTH = height, width
        self.CHECKPOINT = checkpoint

        self.X_train_A, self.H_A, self.W_A, self.C_A = self.load_data(train_data_path_A)
        self.X_train_B, self.H_B, self.W_B, self.C_B = self.load_data(train_data_path_B)
        self.X_test_A, self.H_A_test, self.W_A_test, self.C_A_test = self.load_data(test_data_path_A)
        self.X_test_B, self.H_B_test, self.W_B_test, self.C_B_test = self.load_data(test_data_path_B)

        self.generator_A_to_B = Generator(height=self.H_A, width=self.W_A, channels=self.C_A)
        self.generator_B_to_A = Generator(height=self.H_B, width=self.W_B, channels=self.C_B)

        # NOTE(review): Input shape is (width, height, channels). This equals
        # (height, width, channels) only for square images; confirm.
        self.orig_A = Input(shape=(self.W_A, self.H_A, self.C_A))
        self.orig_B = Input(shape=(self.W_B, self.H_B, self.C_B))

        g_ab = self.generator_A_to_B.Generator
        g_ba = self.generator_B_to_A.Generator
        # Translations, then cycle reconstructions, then identity mappings.
        self.fake_B = g_ab(self.orig_A)
        self.fake_A = g_ba(self.orig_B)
        self.reconstructed_A = g_ba(self.fake_B)
        self.reconstructed_B = g_ab(self.fake_A)
        self.id_A = g_ba(self.orig_A)
        self.id_B = g_ab(self.orig_B)

        self.discriminator_A = Discriminator(height=self.H_A, width=self.W_A, channels=self.C_A)
        self.discriminator_B = Discriminator(height=self.H_B, width=self.W_B, channels=self.C_B)
        # NOTE(review): `trainable` is set on the wrapper objects, not on the
        # inner `.Discriminator` models used below. Unless the wrapper forwards
        # it, the discriminators stay trainable inside the combined model.
        self.discriminator_A.trainable = False
        self.discriminator_B.trainable = False
        self.valid_A = self.discriminator_A.Discriminator(self.fake_A)
        self.valid_B = self.discriminator_B.Discriminator(self.fake_B)

        self.gan = GAN(
            model_inputs=[self.orig_A, self.orig_B],
            model_outputs=[self.valid_A, self.valid_B,
                           self.reconstructed_A, self.reconstructed_B,
                           self.id_A, self.id_B],
            lambda_cycle=lambda_cycle,
            lambda_id=lambda_id,
        )
开发者ID:PacktPublishing,项目名称:Generative-Adversarial-Networks-Cookbook,代码行数:38,代码来源:train.py

示例7: plot_checkpoint

# 需要导入模块: import generator [as 别名]
# 或者: from generator import Generator [as 别名]
def plot_checkpoint(self,b):
        """Save an original / translated / reconstructed strip for batch ``b``.

        Steps:
          1. Take test image index 5 of domain A (hard-coded; the test set
             must hold at least 6 images).
          2. Translate it with the A->B generator.
          3. Map the result back with the B->A generator.
        The three images are rescaled with ``0.5 * x + 0.5`` and written
        side by side to ``/data/batch_check_<b>.png``. Intermediate shapes
        are printed to stdout.
        """
        image_A = self.X_test_A[5]
        image_A = np.reshape(image_A, [self.W_A_test, self.H_A_test, self.C_A_test])
        print("Image_A shape: " +str(np.shape(image_A)))
        # The models are fed with the training geometry (W_A, H_A, C_A). This
        # reshape only succeeds if test and training images have the same
        # number of elements.
        fake_B = self.generator_A_to_B.Generator.predict(image_A.reshape(1, self.W_A, self.H_A, self.C_A))
        fake_B = np.reshape(fake_B, [self.W_A_test, self.H_A_test, self.C_A_test])
        print("fake_B shape: " +str(np.shape(fake_B)))
        reconstructed_A = self.generator_B_to_A.Generator.predict(fake_B.reshape(1, self.W_A, self.H_A, self.C_A))
        reconstructed_A = np.reshape(reconstructed_A, [self.W_A_test, self.H_A_test, self.C_A_test])
        print("reconstructed_A shape: " +str(np.shape(reconstructed_A)))

        checkpoint_images = np.array([image_A, fake_B, reconstructed_A])
        # Map [-1, 1] to [0, 1] for display.
        checkpoint_images = 0.5 * checkpoint_images + 0.5

        titles = ['Original', 'Translated', 'Reconstructed']
        fig, axes = plt.subplots(1, 3)
        for ax, image, title in zip(axes, checkpoint_images, titles):
            ax.imshow(np.reshape(image, [self.H_A_test, self.W_A_test, self.C_A_test]))
            ax.set_title(title)
            ax.axis('off')
        fig.savefig("/data/batch_check_"+str(b)+".png")
        plt.close('all')
        return
开发者ID:PacktPublishing,项目名称:Generative-Adversarial-Networks-Cookbook,代码行数:32,代码来源:train.py

示例8: __init__

# 需要导入模块: import generator [as 别名]
# 或者: from generator import Generator [as 别名]
def __init__(self, height=256, width=256, channels=3, epochs=50000, batch=1,
             checkpoint=50, train_data_path='', test_data_path=''):
        """Load paired A/B images and wire up the paired image-to-image GAN.

        ``load_data`` results are unpacked in ``(B, A)`` order for both the
        training and the test path. One generator maps ``orig_B`` to
        ``fake_A``, and the discriminator scores the pair
        ``[fake_A, orig_B]``. The combined ``GAN`` takes
        ``[orig_A, orig_B]`` and outputs ``[valid, fake_A]``.
        """
        self.EPOCHS = epochs
        self.BATCH = batch
        self.H, self.W, self.C = height, width, channels
        self.CHECKPOINT = checkpoint

        # Note the (B, A) unpacking order.
        self.X_train_B, self.X_train_A = self.load_data(train_data_path)
        self.X_test_B, self.X_test_A = self.load_data(test_data_path)

        self.generator = Generator(height=self.H, width=self.W, channels=self.C)

        # NOTE(review): Input shape is (width, height, channels). This equals
        # (height, width, channels) only for square images; confirm.
        image_shape = (self.W, self.H, self.C)
        self.orig_A = Input(shape=image_shape)
        self.orig_B = Input(shape=image_shape)

        self.fake_A = self.generator.Generator(self.orig_B)

        self.discriminator = Discriminator(height=self.H, width=self.W, channels=self.C)
        # NOTE(review): `trainable` is set on the wrapper, not on the inner
        # `.Discriminator` model; confirm the wrapper forwards it.
        self.discriminator.trainable = False
        self.valid = self.discriminator.Discriminator([self.fake_A, self.orig_B])

        self.gan = GAN(model_inputs=[self.orig_A, self.orig_B],
                       model_outputs=[self.valid, self.fake_A])
开发者ID:PacktPublishing,项目名称:Generative-Adversarial-Networks-Cookbook,代码行数:28,代码来源:train.py

示例9: plot_checkpoint

# 需要导入模块: import generator [as 别名]
# 或者: from generator import Generator [as 别名]
def plot_checkpoint(self,b):
        """Save a grid comparing style inputs, generated and original images.

        ``c`` distinct random test indices are drawn. The grid has three rows:
        row 0 shows the domain-B inputs, row 1 the generator's outputs for
        them, and row 2 the matching domain-A images. Images are rescaled with
        ``0.5 * x + 0.5`` and the grid is written to
        ``/out/batch_check_<b>.png``.

        Raises:
            ValueError: from ``random.sample`` if the test set holds fewer
                than ``c`` images.
        """
        r, c = 3, 3  # rows: style / generated / original; columns: samples
        random_inds = random.sample(range(len(self.X_test_A)), c)
        imgs_A = self.X_test_A[random_inds].reshape(c, self.W, self.H, self.C)
        imgs_B = self.X_test_B[random_inds].reshape(c, self.W, self.H, self.C)
        fake_A = self.generator.Generator.predict(imgs_B)

        # Row-major order: this concatenation lines up with the subplot loop.
        gen_imgs = np.concatenate([imgs_B, fake_A, imgs_A])

        # Map [-1, 1] to [0, 1] for display.
        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Style', 'Generated', 'Original']
        fig, axs = plt.subplots(r, c)
        for i in range(r):
            for j in range(c):
                ax = axs[i, j]
                ax.imshow(gen_imgs[i * c + j])
                ax.set_title(titles[i])
                ax.axis('off')
        fig.savefig("/out/batch_check_"+str(b)+".png")
        plt.close('all')

        return
开发者ID:PacktPublishing,项目名称:Generative-Adversarial-Networks-Cookbook,代码行数:29,代码来源:train.py

示例10: plot_checkpoint

# 需要导入模块: import generator [as 别名]
# 或者: from generator import Generator [as 别名]
def plot_checkpoint(self,e,label):
        """Scatter-plot one generated volume for class ``label`` at epoch ``e``.

        Steps:
          1. Pick a random encoded test sample whose label equals ``label``.
          2. Reshape it to ``(1, 1, 1, 1, LATENT_SPACE_SIZE)`` and run it
             through the generator.
          3. Plot every element of the first 16x16x16 block of the first
             output whose mean value lies strictly between 0.25 and 0.75.
        The figure is written to ``/out/epoch_<e>_label_<label>.png``.
        """
        filename = "/out/epoch_"+str(e)+"_label_"+str(label)+".png"

        all_encoded_samples = self.X_test_2D_encoded[np.where(self.Y_test_2D == label)]
        index = randint(0, (len(all_encoded_samples) - 1))
        batch_encoded_samples = all_encoded_samples[index].reshape(1, 1, 1, 1, self.LATENT_SPACE_SIZE)

        volume = self.generator.Generator.predict(batch_encoded_samples)[0]
        xs, ys, zs, cs = [], [], [], []
        for i in range(16):
            for j in range(16):
                for k in range(16):
                    color = volume[i][j][k]
                    mean = np.mean(color)
                    if 0.25 < mean < 0.75:
                        xs.append(i)
                        ys.append(j)
                        zs.append(k)
                        cs.append(color)

        fig = plt.figure()
        # Figure.gca() no longer accepts keyword arguments (Matplotlib >= 3.6).
        # add_subplot(projection='3d') is the supported way to create 3D axes.
        ax = fig.add_subplot(projection='3d')
        ax.scatter(xs, ys, zs, alpha=0.1, c=cs)
        fig.savefig(filename)
        # Close the figure; otherwise one figure accumulates per checkpoint.
        plt.close(fig)

        return
开发者ID:PacktPublishing,项目名称:Generative-Adversarial-Networks-Cookbook,代码行数:31,代码来源:train.py

示例11: __init__

# 需要导入模块: import generator [as 别名]
# 或者: from generator import Generator [as 别名]
def __init__(self, height=55, width=35, channels=1, epochs=100, batch=16,
             checkpoint=50, sim_path='', real_path='', data_limit=0.001,
             generator_steps=2, discriminator_steps=1):
        """Load real and simulated images and wire up the SimGAN refiner graph.

        All settings, including ``DATA_LIMIT``, are stored before
        ``load_h5py`` runs (presumably it reads ``DATA_LIMIT``; confirm).
        The refiner (a ``Generator``) maps ``synthetic_image`` to
        ``refined_image``, and the discriminator scores that refined image.
        The combined ``GAN`` takes ``[synthetic_image]`` and outputs
        ``[refined_image, combined]``.
        """
        self.W, self.H, self.C = width, height, channels
        self.EPOCHS = epochs
        self.BATCH = batch
        self.CHECKPOINT = checkpoint
        self.DATA_LIMIT = data_limit
        self.GEN_STEPS = generator_steps
        self.DISC_STEPS = discriminator_steps

        self.X_real = self.load_h5py(real_path)
        self.X_sim = self.load_h5py(sim_path)

        geometry = dict(height=self.H, width=self.W, channels=self.C)
        self.refiner = Generator(**geometry)
        self.discriminator = Discriminator(**geometry)
        # NOTE(review): `trainable` is set on the wrapper, not on the inner
        # `.Discriminator` model; confirm the wrapper forwards it.
        self.discriminator.trainable = False

        image_shape = (self.H, self.W, self.C)
        self.synthetic_image = Input(shape=image_shape)
        self.real_or_fake = Input(shape=image_shape)

        self.refined_image = self.refiner.Generator(self.synthetic_image)
        # Built here, but not part of the combined GAN model below.
        self.discriminator_output = self.discriminator.Discriminator(self.real_or_fake)
        self.combined = self.discriminator.Discriminator(self.refined_image)

        self.gan = GAN(model_inputs=[self.synthetic_image],
                       model_outputs=[self.refined_image, self.combined])
开发者ID:PacktPublishing,项目名称:Generative-Adversarial-Networks-Cookbook,代码行数:31,代码来源:train.py

示例12: train

# 需要导入模块: import generator [as 别名]
# 或者: from generator import Generator [as 别名]
def train(self):
        """Train the SimGAN refiner and discriminator for ``self.EPOCHS`` epochs.

        Each epoch walks over ``self.X_real`` and ``self.X_sim``. It
        repeatedly picks a random offset, takes ``self.BATCH`` contiguous
        samples at that offset from both arrays, and removes them from the
        epoch's working copies. This continues until at most ``self.BATCH``
        samples remain. For every block:

        - the combined model is trained ``self.GEN_STEPS`` times;
        - the discriminator is trained ``self.DISC_STEPS`` times, once on real
          images and once on refined images per step.

        Losses are summed per epoch and printed. ``self.X_real`` and
        ``self.X_sim`` are not modified.
        """
        disc = self.discriminator.Discriminator
        gan_model = self.gan.gan_model
        count_real_images = int(self.BATCH)
        # One target pair per discriminator output position:
        # [1, 0] for real images, [0, 1] for simulated/refined ones.
        y_real = np.array([[[1.0, 0.0]] * disc.output_shape[1]] * self.BATCH)
        y_sim = np.array([[[0.0, 1.0]] * disc.output_shape[1]] * self.BATCH)

        for e in range(self.EPOCHS):
            X_real_temp = self.X_real
            X_sim_temp = self.X_sim
            combined_loss = np.zeros(shape=len(gan_model.metrics_names))
            discriminator_loss_real = np.zeros(shape=len(disc.metrics_names))
            discriminator_loss_sim = np.zeros(shape=len(disc.metrics_names))

            while min(len(X_real_temp), len(X_sim_temp)) > self.BATCH:
                starting_index = randint(0, (min(len(X_real_temp), len(X_sim_temp)) - count_real_images))
                block = slice(starting_index, starting_index + count_real_images)

                real_images = X_real_temp[block].reshape(count_real_images, self.H, self.W, self.C)
                sim_images = X_sim_temp[block].reshape(count_real_images, self.H, self.W, self.C)
                # Remove the consumed block so the epoch terminates. np.delete
                # returns a new array, so self.X_real / self.X_sim stay intact.
                X_real_temp = np.delete(X_real_temp, block, axis=0)
                X_sim_temp = np.delete(X_sim_temp, block, axis=0)

                for _ in range(self.GEN_STEPS):
                    combined_loss = np.add(gan_model.train_on_batch(sim_images, [sim_images, y_real]), combined_loss)

                for _ in range(self.DISC_STEPS):
                    improved_image_batch = self.refiner.Generator.predict_on_batch(sim_images)
                    discriminator_loss_real = np.add(disc.train_on_batch(real_images, y_real), discriminator_loss_real)
                    discriminator_loss_sim = np.add(disc.train_on_batch(improved_image_batch, y_sim), discriminator_loss_sim)

            print('Epoch: ' + str(int(e)) + ', [Real Discriminator :: Loss: ' + str(discriminator_loss_real)
                  + '], [ GAN :: Loss: ' + str(combined_loss) + ']')

        return
开发者ID:PacktPublishing,项目名称:Generative-Adversarial-Networks-Cookbook,代码行数:40,代码来源:train.py

示例13: main

# 需要导入模块: import generator [as 别名]
# 或者: from generator import Generator [as 别名]
def main(m_path, out_dir, light):
    """Export the latest checkpoint as a SavedModel and a TF.js graph model.

    Args:
        m_path: Directory passed to ``tf.train.latest_checkpoint``.
        out_dir: Parent directory for both exports.
        light: Build the light ``Generator`` variant. This also switches the
            directory prefixes to ``SavedModelLight`` / ``tfjs_model_light``.

    The first suffix ``_NNNN`` for which neither target directory exists is
    used. Conversion runs ``tensorflowjs_converter`` as a subprocess, and its
    result is logged. Exits with status 1 if building or loading the
    generator raises ``ValueError``.
    """
    logger = get_logger("export")
    try:
        g = Generator(light=light)
        g.load_weights(tf.train.latest_checkpoint(m_path))
        # Call the model once on a symbolic input so its layers are built
        # before summary() and saving.
        t = tf.keras.Input(shape=[None, None, 3], batch_size=None)
        g(t, training=False)
        g.summary()
    except ValueError as e:
        logger.error(e)
        logger.error("Failed to load specified weight.")
        logger.error("If you trained your model with --light, "
                     "consider adding --light when executing this script; otherwise, "
                     "do not add --light when executing this script.")
        # `exit()` is an interactive helper injected by `site`; SystemExit is
        # the portable way to stop a script.
        raise SystemExit(1)

    smd = os.path.join(out_dir, "SavedModel")
    tfmd = os.path.join(out_dir, "tfjs_model")
    if light:
        smd += "Light"
        tfmd += "_light"

    # Find the first free suffix. Checking both targets keeps an existing
    # tfjs export from being overwritten.
    m_num = 0
    while True:
        saved_model_dir = f"{smd}_{m_num:04d}"
        tfjs_model_dir = f"{tfmd}_{m_num:04d}"
        if not (os.path.exists(saved_model_dir) or os.path.exists(tfjs_model_dir)):
            break
        m_num += 1

    tf.saved_model.save(g, saved_model_dir)
    cmd = ['tensorflowjs_converter', '--input_format', 'tf_saved_model',
           '--output_format', 'tfjs_graph_model', saved_model_dir, tfjs_model_dir]
    logger.info(" ".join(cmd))
    exit_code = Popen(cmd).wait()
    if exit_code == 0:
        logger.info(f"Model converted to {saved_model_dir} and {tfjs_model_dir} successfully")
    else:
        logger.error("tfjs model conversion failed")
开发者ID:mnicnc404,项目名称:CartoonGan-tensorflow,代码行数:38,代码来源:export.py

示例14: build_generator

# 需要导入模块: import generator [as 别名]
# 或者: from generator import Generator [as 别名]
def build_generator(self):
        """Create the GraphGAN generator under the "generator" variable scope.

        The generator is constructed from ``self.n_node`` and the initial
        embeddings ``self.node_embed_init_g``, and is stored on
        ``self.generator``.
        """
        scope = tf.variable_scope("generator")
        with scope:
            self.generator = generator.Generator(
                n_node=self.n_node,
                node_emd_init=self.node_embed_init_g,
            )
开发者ID:hwwang55,项目名称:GraphGAN,代码行数:7,代码来源:graph_gan.py

示例15: generator_wrapper

# 需要导入模块: import generator [as 别名]
# 或者: from generator import Generator [as 别名]
def generator_wrapper(self, cnt_round, cnt_gen, dic_path, dic_exp_conf, dic_agent_conf, dic_traffic_env_conf,
                          best_round=None):
        """Build a ``Generator`` for one round/generator index and run it.

        Every argument except ``self`` is forwarded unchanged to
        ``Generator``. Progress is printed before and after ``generate()``.
        The local is named ``gen`` so it does not shadow the ``generator``
        module.
        """
        gen = Generator(
            cnt_round=cnt_round,
            cnt_gen=cnt_gen,
            dic_path=dic_path,
            dic_exp_conf=dic_exp_conf,
            dic_agent_conf=dic_agent_conf,
            dic_traffic_env_conf=dic_traffic_env_conf,
            best_round=best_round,
        )
        print("make generator")
        gen.generate()
        print("generator_wrapper end")
开发者ID:multi-commander,项目名称:Multi-Commander,代码行数:16,代码来源:pipeline.py


注:本文中的generator.Generator方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。