Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def test_augment_data():
    """Smoke-test ``Model._augment_data`` on a tiny hand-built input.

    Builds a 2x3 integer feature array and a 3-element string label array
    and passes both to ``_augment_data``. Nothing is asserted, so the test
    only checks that the call completes without raising.
    """
    model = Model(10, labels=['Unknown'])
    features = np.array([[1, 2, 3], [4, 5, 6]])
    # NOTE(review): ``features`` has 2 rows but ``labels`` has 3 entries;
    # confirm this length mismatch is intentional for this test.
    labels = np.array(['label1', 'label2', 'label3'])
    model._augment_data(features, labels)
def train(self, data_dir, save_path):
    """Train a class-balanced random forest on ``data_dir`` and persist it.

    The estimator is wrapped in a ``Model`` configured with this object's
    ``duration`` and ``conf_labels``. That wrapper is trained and then saved.

    Args:
        data_dir: Training data location, passed unchanged to ``Model.train``.
        save_path: Destination passed unchanged to ``Model.save``.
    """
    # Fixed forest hyperparameters for this trainer.
    forest_params = {
        'n_estimators': 100,
        'min_samples_split': 5,
        'class_weight': 'balanced',
    }
    classifier = RandomForestClassifier(**forest_params)

    wrapper = Model(
        duration=self.duration,
        labels=self.conf_labels,
        model=classifier,
        model_type='randomforest',
    )
    wrapper.train(data_dir)
    wrapper.save(save_path)
def train(self, data_dir, save_path, m, algorithm):
    """Wrap the estimator ``m`` in a ``Model``, train it, and save it.

    Args:
        data_dir: Training data location, passed unchanged to ``Model.train``.
        save_path: Destination passed unchanged to ``Model.save``.
        m: Underlying estimator object handed to ``Model`` as ``model``.
        algorithm: Value forwarded to ``Model`` as ``model_type``.
    """
    # The remaining settings come from this object's configuration attributes.
    settings = dict(
        duration=self.duration,
        hidden_size=self.state_size,
        labels=self.conf_labels,
        model=m,
        model_type=algorithm,
        threshold_time=self.threshold,
    )
    trained = Model(**settings)
    trained.train(data_dir)
    trained.save(save_path)
time_const,
model_path='networkml/trained_models/onelayer/OneLayerModel.pkl',
label=None,
model_type='randomforest'
):
logger = logging.getLogger(__name__)
try:
if 'LOG_LEVEL' in os.environ and os.environ['LOG_LEVEL'] != '':
logger.setLevel(os.environ['LOG_LEVEL'])
except Exception as e: # pragma: no cover
logger.error(
'Unable to set logging level because: {0} defaulting to INFO.'.format(str(e)))
# Load the model
logger.debug('Loading model')
model = Model(duration=None, hidden_size=None, model_type=model_type)
model.load(model_path)
# Get all the pcaps in the training directory
logger.debug('Getting pcaps')
pcaps = []
try:
ext = os.path.splitext(data_dir)[-1]
if ext == '.pcap':
pcaps.append(data_dir)
except Exception as e: # pragma: no cover
logger.debug('Skipping {0} because: {1}'.format(data_dir, str(e)))
for dirpath, _, filenames in os.walk(data_dir):
for filename in filenames:
ext = os.path.splitext(filename)[-1]
if ext == '.pcap':
## Take arguments from command line
self.args = None
self.read_args()
## Take input from configuration file
self.get_config()
self.common = Common(config=self.config)
## Instantiate a logger to to leg messages to aid debugging
self.logger = Common().setup_logger(self.logger)
## Add network traffic files for parsing
self.get_files()
self.model_hash = None
self.model = Model(duration=self.duration, hidden_size=None,
model_type=self.args.algorithm)
def create_base_alg():
    """Build a ``BaseAlgorithm`` from the enclosing object's current state.

    ``self`` is a free variable captured from the enclosing method. Its
    attributes are therefore read when this function is called, not when it
    is defined.
    """
    algorithm_kwargs = {
        'files': self.files,
        'config': self.config,
        'model': self.model,
        'model_hash': self.model_hash,
        'model_path': self.args.trained_model,
    }
    return BaseAlgorithm(**algorithm_kwargs)
## Check whether operation is evaluation, train, or test
## Evaluation returns predictions that are useful for the deployment
## of networkml in an operational environment.
if self.args.operation == 'eval':
self.load_model()
if (self.args.algorithm == 'onelayer' or self.args.algorithm == 'randomforest'):
base_alg = create_base_alg()