當前位置: 首頁>>代碼示例>>Python>>正文


Python nn.DistributedDataParallel方法代碼示例

本文整理匯總了Python中torch.nn.DistributedDataParallel方法的典型用法代碼示例。如果您正苦於以下問題:Python nn.DistributedDataParallel方法的具體用法?Python nn.DistributedDataParallel怎麼用?Python nn.DistributedDataParallel使用的例子?那麼, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch.nn的用法示例。


在下文中一共展示了nn.DistributedDataParallel方法的2個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。(注意:這兩個示例本身並未直接調用 DistributedDataParallel,僅在 `parallel` 參數的註釋中提及它。)

示例1: __init__

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import DistributedDataParallel [as 別名]
def __init__(self,
            use_cuda=None, # use cuda or not
            use_tsm=False, # use the Temporal Shift module or not
            use_nl=False, # use the Non-local module or not
            use_tc=False, # use the Timeception module or not
            use_lstm=False, # use LSTM module or not
            freeze_i3d=False, # freeze i3d layers when training Timeception
            batch_size_train=10, # size for each batch for training
            batch_size_test=50, # size for each batch for testing
            batch_size_extract_features=40, # size for each batch for extracting features
            max_steps=2000, # total number of steps for training
            num_steps_per_update=2, # gradient accumulation (for large batch size that does not fit into memory)
            init_lr=0.1, # initial learning rate
            weight_decay=0.000001, # L2 regularization
            momentum=0.9, # SGD parameters
            milestones=None, # MultiStepLR parameters; None means [500, 1500]
            gamma=0.1, # MultiStepLR parameters
            num_of_action_classes=2, # currently we only have two classes (0 and 1, which means no and yes)
            num_steps_per_check=50, # the number of steps to save a model and log information
            parallel=True, # use nn.DistributedDataParallel or not
            augment=True, # use data augmentation or not
            num_workers=12, # number of workers for the dataloader
            mode="rgb", # can be "rgb" or "flow" or "rgbd"
            p_frame="../data/rgb/", # path to load video frames
            code_testing=False # a special flag for testing if the code works
            ):
        """Record the I3D learner's configuration on the instance.

        Each keyword argument is stored as an attribute of the same name.
        ``use_cuda`` is additionally forwarded to the base-class constructor.
        The inline comment next to each parameter describes its meaning.

        ``milestones`` defaults to ``None``. In that case the instance gets
        its own new ``[500, 1500]`` list, rather than a single list object
        shared by every instance built with the default. An explicitly
        supplied value is stored unchanged.

        If ``code_testing`` is truthy, ``max_steps`` is overridden to 10.
        """
        super().__init__(use_cuda=use_cuda)

        self.use_tsm = use_tsm
        self.use_nl = use_nl
        self.use_tc = use_tc
        self.use_lstm = use_lstm
        self.freeze_i3d = freeze_i3d
        self.batch_size_train = batch_size_train
        self.batch_size_test = batch_size_test
        self.batch_size_extract_features = batch_size_extract_features
        self.max_steps = max_steps
        self.num_steps_per_update = num_steps_per_update
        self.init_lr = init_lr
        self.weight_decay = weight_decay
        self.momentum = momentum
        # Build a fresh list per instance; a list literal as the default
        # argument would be created once and shared by all instances.
        self.milestones = [500, 1500] if milestones is None else milestones
        self.gamma = gamma
        self.num_of_action_classes = num_of_action_classes
        self.num_steps_per_check = num_steps_per_check
        self.parallel = parallel
        self.augment = augment
        self.num_workers = num_workers
        self.mode = mode
        self.p_frame = p_frame

        # Internal parameters
        self.image_size = 224 # 224 is the input for the i3d network structure
        # NOTE(review): always starts as False here; presumably updated
        # elsewhere in the class when parallel execution is possible - confirm.
        self.can_parallel = False

        # Code testing mode: cap the number of training steps at 10
        self.code_testing = code_testing
        if code_testing:
            self.max_steps = 10
開發者ID:CMU-CREATE-Lab,項目名稱:deep-smoke-machine,代碼行數:61,代碼來源:i3d_learner.py

示例2: __init__

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import DistributedDataParallel [as 別名]
def __init__(self,
            use_cuda=None, # use cuda or not
            batch_size_train=6, # size for each batch for training
            batch_size_test=40, # size for each batch for testing
            batch_size_extract_features=40, # size for each batch for extracting features
            max_steps=2000, # total number of steps for training
            num_steps_per_update=2, # gradient accumulation (for large batch size that does not fit into memory)
            init_lr=0.01, # initial learning rate
            weight_decay=0.000001, # L2 regularization
            momentum=0.9, # SGD parameters
            milestones=None, # MultiStepLR parameters; None means [500, 1500]
            gamma=0.1, # MultiStepLR parameters
            num_of_action_classes=2, # currently we only have two classes (0 and 1, which means no and yes)
            num_steps_per_check=50, # the number of steps to save a model and log information
            parallel=True, # use nn.DistributedDataParallel or not
            augment=True, # use data augmentation or not
            num_workers=12, # number of workers for the dataloader
            mode="rgb", # can be "rgb" or "flow"
            p_frame="../data/rgb/", # path to load video frames
            method="cnn", # the method for the model
            freeze_cnn=False, # freeze the CNN model while training or not
            code_testing=False # a special flag for testing if the code works
            ):
        """Record the CNN learner's configuration on the instance.

        Each keyword argument is stored as an attribute of the same name.
        ``use_cuda`` is additionally forwarded to the base-class constructor.
        The inline comment next to each parameter describes its meaning.

        ``milestones`` defaults to ``None``. In that case the instance gets
        its own new ``[500, 1500]`` list, rather than a single list object
        shared by every instance built with the default. An explicitly
        supplied value is stored unchanged.

        If ``code_testing`` is truthy, ``max_steps`` is overridden to 10.
        """
        super().__init__(use_cuda=use_cuda)

        self.batch_size_train = batch_size_train
        self.batch_size_test = batch_size_test
        self.batch_size_extract_features = batch_size_extract_features
        self.max_steps = max_steps
        self.num_steps_per_update = num_steps_per_update
        self.init_lr = init_lr
        self.weight_decay = weight_decay
        self.momentum = momentum
        # Build a fresh list per instance; a list literal as the default
        # argument would be created once and shared by all instances.
        self.milestones = [500, 1500] if milestones is None else milestones
        self.gamma = gamma
        self.num_of_action_classes = num_of_action_classes
        self.num_steps_per_check = num_steps_per_check
        self.parallel = parallel
        self.augment = augment
        self.num_workers = num_workers
        self.mode = mode
        self.p_frame = p_frame
        self.method = method
        self.freeze_cnn = freeze_cnn

        # Internal parameters
        self.image_size = 224 # 224 is the input for the ResNet18 network structure
        # NOTE(review): always starts as False here; presumably updated
        # elsewhere in the class when parallel execution is possible - confirm.
        self.can_parallel = False

        # Code testing mode: cap the number of training steps at 10
        self.code_testing = code_testing
        if code_testing:
            self.max_steps = 10
開發者ID:CMU-CREATE-Lab,項目名稱:deep-smoke-machine,代碼行數:55,代碼來源:cnn_learner.py


注:本文中的torch.nn.DistributedDataParallel方法示例由純淨天空整理自GitHub/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。