"""DataStore class for WorkBench."""
import pymongo
import gridfs
import hashlib
import datetime
import bson
import time
import random
class DataStore(object):
    """DataStore for Workbench.
    Currently tied to MongoDB, but making this class 'abstract' should be
    straightforward, and we could then consider another backend.
    """
    def __init__(self, uri='mongodb://localhost/workbench', database='workbench', worker_cap=0, samples_cap=0):
        """Initialization for the Workbench data store class.
        Args:
            uri: Connection string for the DataStore backend.
            database: Name of the database.
            worker_cap: Size cap (in MB) for each worker output (capped) collection.
            samples_cap: Maximum size (in MB) of samples to be stored.
        """
self.sample_collection = 'samples'
self.worker_cap = worker_cap
self.samples_cap = samples_cap
# Get connection to mongo
self.database_name = database
        self.uri = uri
        self.mongo = pymongo.MongoClient(self.uri, use_greenlets=True)
        self.database = self.mongo[self.database_name]
# Get the gridfs handle
self.gridfs_handle = gridfs.GridFS(self.database)
        # Run the periodic operations (seed last_ops_run so this first call actually runs)
        self.last_ops_run = time.time() - 30
        self.periodic_ops()
print '\t- WorkBench DataStore connected: %s:%s' % (self.uri, self.database_name)
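    # Example usage (illustrative sketch, not part of the API; assumes a
    # MongoDB instance is reachable at the given uri):
    #
    #     store = DataStore(uri='mongodb://localhost/workbench',
    #                       worker_cap=10, samples_cap=200)
    #     print store.get_uri()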
    def get_uri(self):
        """Return the URI of the data store."""
return self.uri
    def store_sample(self, sample_bytes, filename, type_tag):
        """Store a sample into the datastore.
        Args:
            sample_bytes: Actual bytes of the sample.
            filename: Name of the file.
            type_tag: Type of the sample ('exe', 'pcap', 'pdf', 'json', 'swf', or ...).
        Returns:
            The md5 digest of the sample.
        """
        # Temp sanity check for old clients that swapped the argument order
        if len(filename) > 1000:
            raise ValueError('switched bytes/filename... %s %s' % (sample_bytes[:100], filename[:100]))
sample_info = {}
# Compute the MD5 hash
sample_info['md5'] = hashlib.md5(sample_bytes).hexdigest()
# Check if sample already exists
if self.has_sample(sample_info['md5']):
return sample_info['md5']
# Run the periodic operations
self.periodic_ops()
# Check if we need to expire anything
self.expire_data()
# Okay start populating the sample for adding to the data store
# Filename, length, import time and type_tag
sample_info['filename'] = filename
sample_info['length'] = len(sample_bytes)
sample_info['import_time'] = datetime.datetime.utcnow()
sample_info['type_tag'] = type_tag
        # Random customer for now
        sample_info['customer'] = random.choice(['Mega Corp', 'Huge Inc', 'BearTron', 'Dorseys Mom'])
# Push the file into the MongoDB GridFS
sample_info['__grid_fs'] = self.gridfs_handle.put(sample_bytes)
self.database[self.sample_collection].insert(sample_info)
# Print info
print 'Sample Storage: %.2f out of %.2f MB' % (self.sample_storage_size(), self.samples_cap)
# Return the sample md5
return sample_info['md5']
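    # Example usage (illustrative sketch; 'evil.exe' is a hypothetical file):
    #
    #     with open('evil.exe', 'rb') as sample_file:
    #         md5 = store.store_sample(sample_file.read(), 'evil.exe', 'exe')
    #     print 'Stored sample: %s' % md5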
    def sample_storage_size(self):
        """Get the storage size (in MB) of the GridFS collection backing sample storage."""
try:
coll_stats = self.database.command('collStats', 'fs.chunks')
sample_storage_size = coll_stats['size']/1024.0/1024.0
return sample_storage_size
except pymongo.errors.OperationFailure:
return 0
    def expire_data(self):
"""Expire data within the samples collection."""
# Do we need to start deleting stuff?
while self.sample_storage_size() > self.samples_cap:
# This should return the 'oldest' record in samples
            record = self.database[self.sample_collection].find().sort('import_time', pymongo.ASCENDING).limit(1)[0]
self.remove_sample(record['md5'])
    def remove_sample(self, md5):
        """Delete a specific sample."""
# Grab the sample
record = self.database[self.sample_collection].find_one({'md5': md5})
if not record:
return
# Delete it
print 'Deleting sample: %s (%.2f MB)...' % (record['md5'], record['length']/1024.0/1024.0)
self.database[self.sample_collection].remove({'md5': record['md5']})
self.gridfs_handle.delete(record['__grid_fs'])
# Print info
print 'Sample Storage: %.2f out of %.2f MB' % (self.sample_storage_size(), self.samples_cap)
    def clean_for_serialization(self, data):
        """Clean data in preparation for serialization.
        Deletes keys that start with '__' or whose values are BSON ObjectIds,
        converts datetimes to ISO 8601 strings, and recursively cleans nested
        dicts and lists.
        Args:
            data: Sample data to be serialized.
        Returns:
            Cleaned data dictionary.
        """
if isinstance(data, dict):
for k in data.keys():
                if k.startswith('__'):
del data[k]
elif isinstance(data[k], bson.objectid.ObjectId):
del data[k]
elif isinstance(data[k], datetime.datetime):
data[k] = data[k].isoformat()+'Z'
elif isinstance(data[k], dict):
data[k] = self.clean_for_serialization(data[k])
elif isinstance(data[k], list):
data[k] = [self.clean_for_serialization(item) for item in data[k]]
return data
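    # Illustrative sketch of the cleaning rules above (grid_id and the md5
    # value are made up):
    #
    #     data = {'md5': 'd41d8...', '__grid_fs': grid_id,
    #             'oid': bson.objectid.ObjectId(),
    #             'import_time': datetime.datetime.utcnow()}
    #     store.clean_for_serialization(data)
    #     # => {'md5': 'd41d8...', 'import_time': '2014-06-01T12:00:00.000000Z'}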
    def clean_for_storage(self, data):
        """Clean data in preparation for storage.
        Converts all strings to unicode, deletes the '_id' key, renames keys
        containing a '.' (replaced with '_'), and recursively cleans nested
        dicts and lists.
        Args:
            data: Sample data dictionary to be cleaned.
        Returns:
            Cleaned data dictionary.
        """
data = self.data_to_unicode(data)
if isinstance(data, dict):
for k in dict(data).keys():
if k == '_id':
del data[k]
continue
if '.' in k:
new_k = k.replace('.', '_')
data[new_k] = data[k]
del data[k]
k = new_k
if isinstance(data[k], dict):
data[k] = self.clean_for_storage(data[k])
elif isinstance(data[k], list):
data[k] = [self.clean_for_storage(item) for item in data[k]]
return data
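    # Illustrative sketch of the storage cleaning rules (keys/values are made up):
    #
    #     data = {'_id': 12345, 'file.name': 'bad.exe', 'md5': 'd41d8...'}
    #     store.clean_for_storage(data)
    #     # => {u'file_name': u'bad.exe', u'md5': u'd41d8...'}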
    def get_full_md5(self, partial_md5, collection):
        """Support partial/short md5s; return the full md5 with this method."""
        print 'Notice: Performing slow md5 search...'
        starts_with = '^%s' % partial_md5
sample_info = self.database[collection].find_one({'md5': {'$regex' : starts_with}},{'md5':1})
return sample_info['md5'] if sample_info else None
    def get_sample(self, md5):
        """Get the sample from the data store.
        This method fetches the data from the datastore, cleans it for
        serialization, and then adds the 'raw_bytes' item.
        Args:
            md5: The md5 digest of the sample to be fetched from the datastore.
        Returns:
            The sample dictionary or None.
        """
# Support 'short' md5s but don't waste performance if the full md5 is provided
if len(md5) < 32:
md5 = self.get_full_md5(md5, self.sample_collection)
# Grab the sample
sample_info = self.database[self.sample_collection].find_one({'md5': md5})
if not sample_info:
return None
# Get the raw bytes from GridFS (note: this could fail)
try:
grid_fs_id = sample_info['__grid_fs']
sample_info = self.clean_for_serialization(sample_info)
sample_info.update({'raw_bytes':self.gridfs_handle.get(grid_fs_id).read()})
return sample_info
        except gridfs.errors.CorruptGridFile:
            # If we don't have the gridfs file, delete the entry from samples
            self.database[self.sample_collection].remove({'md5': md5})
            return None
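    # Example usage (illustrative sketch; the md5 value is made up):
    #
    #     sample = store.get_sample('d41d8cd98f00b204e9800998ecf8427e')
    #     if sample:
    #         print sample['filename'], len(sample['raw_bytes'])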
    def get_sample_window(self, type_tag, size=10):
        """Get a window of samples not to exceed size (in MB).
        Args:
            type_tag: Type of the sample ('exe', 'pcap', 'pdf', 'json', 'swf', or ...).
            size: Size of the window in MB.
        Returns:
            A list of md5s.
        """
        # Convert size from MB to bytes
        size = size * 1024 * 1024
# Grab all the samples of type=type_tag, sort by import_time (newest to oldest)
        cursor = self.database[self.sample_collection].find(
            {'type_tag': type_tag}, {'md5': 1, 'length': 1}).sort('import_time', pymongo.DESCENDING)
        total_size = 0
        md5_list = []
        for item in cursor:
            # Stop before the window would exceed the requested size
            if total_size + item['length'] > size:
                return md5_list
            md5_list.append(item['md5'])
            total_size += item['length']
        # If you get this far you don't have 'size' amount of data,
        # so just return what you've got
        return md5_list
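    # Example usage (illustrative sketch): grab up to 20 MB of the most
    # recently imported pcap samples.
    #
    #     for md5 in store.get_sample_window('pcap', size=20):
    #         sample = store.get_sample(md5)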
    def has_sample(self, md5):
        """Checks if the data store has this sample.
        Args:
            md5: The md5 digest of the required sample.
        Returns:
            True if a sample with this md5 is present, else False.
        """
        # The easiest thing is to simply get the sample and if that
        # succeeds then return True, else return False
        sample = self.get_sample(md5)
        return True if sample else False
def _list_samples(self, predicate=None):
"""List all samples that meet the predicate or all if predicate is not specified.
Args:
predicate: Match samples against this predicate (or all if not specified)
Returns:
List of the md5s for the matching samples
"""
cursor = self.database[self.sample_collection].find(predicate, {'_id':0, 'md5':1})
return [item['md5'] for item in cursor]
    def tag_match(self, tags=None):
        """List all samples that match the tags, or all if tags are not specified.
        Args:
            tags: Match samples against these tags (or all if not specified).
        Returns:
            List of the md5s for the matching samples.
        """
        if 'tags' not in self.database.collection_names():
            print 'Warning: Searching on non-existent tags collection'
            return None
if not tags:
cursor = self.database['tags'].find({}, {'_id':0, 'md5':1})
else:
cursor = self.database['tags'].find({'tags': {'$in': tags}}, {'_id':0, 'md5':1})
return [item['md5'] for item in cursor]
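    # Example usage (illustrative sketch; the tag names are made up):
    #
    #     md5s = store.tag_match(tags=['malware', 'packed'])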
    def store_work_results(self, results, collection, md5):
        """Store the output results of the worker.
        Args:
            results: A dictionary.
            collection: The database collection to store the results in.
            md5: The md5 of the sample data to be updated.
        """
        # Make sure the md5 and time stamp are on the data before storing
        results['md5'] = md5
        results['__time_stamp'] = datetime.datetime.utcnow()
        # If the data doesn't have a 'mod_time' field add one now
        if 'mod_time' not in results:
            results['mod_time'] = results['__time_stamp']
# Fixme: Occasionally a capped collection will not let you update with a
# larger object, if you have MongoDB 2.6 or above this shouldn't
# really happen, so for now just kinda punting and giving a message.
try:
self.database[collection].update({'md5':md5}, self.clean_for_storage(results), True)
        except pymongo.errors.OperationFailure:
            print 'Could not update existing object in capped collection, punting...'
            print 'collection: %s md5: %s' % (collection, md5)
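    # Example usage (illustrative sketch; 'meta' and the results fields are
    # made-up worker output):
    #
    #     results = {'file_size': 1024, 'file_type': 'PE32 executable'}
    #     store.store_work_results(results, 'meta', md5)
    #     print store.get_work_results('meta', md5)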
    def get_work_results(self, collection, md5):
"""Get the results of the worker.
Args:
collection: the database collection storing the results.
md5: the md5 digest of the data.
Returns:
Dictionary of the worker result.
"""
return self.database[collection].find_one({'md5':md5})
    def all_sample_md5s(self, type_tag=None):
        """Return a list of all md5s matching the type_tag ('exe', 'pdf', etc).
Args:
type_tag: the type of sample.
Returns:
a list of matching samples.
"""
if type_tag:
cursor = self.database[self.sample_collection].find({'type_tag': type_tag}, {'md5': 1, '_id': 0})
else:
cursor = self.database[self.sample_collection].find({}, {'md5': 1, '_id': 0})
        return [match['md5'] for match in cursor]
    def clear_worker_output(self):
"""Drops all of the worker output collections"""
print 'Dropping all of the worker output collections... Whee!'
# Get all the collections in the workbench database
all_c = self.database.collection_names()
        # Remove the collections that we don't want to drop
        for collection_name in ['system.indexes', 'fs.chunks', 'fs.files',
                                'sample_set', 'tags', self.sample_collection]:
            if collection_name in all_c:
                all_c.remove(collection_name)
for collection in all_c:
self.database.drop_collection(collection)
    def clear_db(self):
"""Drops the entire workbench database."""
print 'Dropping the entire workbench database... Whee!'
self.mongo.drop_database(self.database_name)
    def periodic_ops(self):
        """Run periodic operations on the data store.
Operations like making sure collections are capped and indexes are set up.
"""
# Only run every 30 seconds
if (time.time() - self.last_ops_run) < 30:
return
# Reset last ops run
self.last_ops_run = time.time()
print 'Running Periodic Ops'
# Get all the collections in the workbench database
all_c = self.database.collection_names()
        # Remove the collections that we don't want to cap
        for collection_name in ['system.indexes', 'fs.chunks', 'fs.files',
                                'info', 'tags', self.sample_collection]:
            if collection_name in all_c:
                all_c.remove(collection_name)
# Convert collections to capped if desired
try:
if self.worker_cap:
size = self.worker_cap * pow(1024, 2) # MegaBytes per collection
for collection in all_c:
self.database.command('convertToCapped', collection, size=size)
        except pymongo.errors.AutoReconnect as e:
print 'Warning: MongoDB raised an AutoReconnect...', e
time.sleep(5)
# Loop through all collections ensuring they have an index on MD5s
for collection in all_c:
self.database[collection].ensure_index('md5')
# Add required indexes for samples collection
self.database[self.sample_collection].create_index('import_time')
# Create an index on tags
self.database['tags'].create_index('tags')
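    # Illustrative sketch: with worker_cap=10 each worker output collection is
    # converted to a 10 MB capped collection; the call below is a no-op if the
    # last run was less than 30 seconds ago.
    #
    #     store.periodic_ops()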
    # Helper functions
    def to_unicode(self, s):
        """Convert an elementary datatype to unicode.
        Args:
            s: The value to be converted to unicode.
        Returns:
            The unicode value.
        """
        # Fixme: this lossy conversion (errors='ignore') needs a better approach
if isinstance(s, unicode):
return s
if isinstance(s, str):
return unicode(s, errors='ignore')
# Just return the original object
return s
    def data_to_unicode(self, data):
        """Recursively convert a list or dictionary to unicode.
        Args:
            data: The data to be converted.
        Returns:
            Unicoded data.
        """
        if isinstance(data, dict):
            return {self.to_unicode(k): self.data_to_unicode(v) for k, v in data.iteritems()}
        if isinstance(data, list):
            return [self.data_to_unicode(item) for item in data]
        return self.to_unicode(data)
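    # Illustrative sketch of the recursive conversion (values are made up):
    #
    #     store.data_to_unicode({'name': 'bad.exe', 'tags': ['malware', 'packed']})
    #     # => {u'name': u'bad.exe', u'tags': [u'malware', u'packed']}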