Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_returns_resultset_of_models(self):
    """A ``returns``-decorated method produces a ResultSet of models.

    Also checks that ``_more`` on a returned result set yields a further
    result set, and that each item obtained by iterating the result set
    carries the endpoint that produced it.
    """
    def named(*labels):
        # Build a fresh list of {"name": ...} payload dicts on every call so no
        # list object is shared between the endpoint call and the expectation.
        return [{"name": label} for label in labels]

    class ModelExample(schemas.ResultSet):
        name = schemas.StringType()

    class SchemaExample(schemas.ResultSet):
        results = schemas.ResultType(ModelExample)

    class EndpointExample(BaseEndpoint):
        @decorators.returns(SchemaExample)
        def func(self, **kwargs):
            return kwargs

    endpoint = EndpointExample(None)

    results = endpoint.func(results=named("item1", "item2"))
    self.assertEqual(results, SchemaExample({"results": named("item1", "item2")}))

    further = endpoint.func()._more(results=named("item2", "item4"))
    self.assertEqual(further, SchemaExample({'results': named("item2", "item4")}))

    for item in results:
        self.assertEqual(item._endpoint, endpoint)
def test_returns_resultset_of_models_equality(self):
    """A ``returns``-decorated method's output equals the same ResultSet built directly.

    Named distinctly from ``test_returns_resultset_of_models``: two methods with
    the same name in one class shadow each other, so only the last definition
    would ever be collected and run.
    """
    class ModelExample(schemas.ResultSet):
        name = schemas.StringType()

    class SchemaExample(schemas.ResultSet):
        results = schemas.ResultType(ModelExample)

    class EndpointExample(BaseEndpoint):
        @decorators.returns(SchemaExample)
        def func(self, **kwargs):
            return kwargs

    endpoint = EndpointExample(None)
    results = endpoint.func(results=[{"name": "item1"}, {"name": "item2"}])
    self.assertEqual(results, SchemaExample({"results": [{"name": "item1"}, {"name": "item2"}]}))
# Test of ResultSet paging against a fake endpoint.
# NOTE(review): this test's body appears truncated in this excerpt (no
# assertions are visible after the endpoint is created).
def test_resultset():
class ResultExample(schemas.Model):
value = schemas.IntType()
class ResultSetExample(schemas.ResultSet):
results = schemas.ResultType(ResultExample)
class EndpointExample(BaseEndpoint):
# Fake paginated endpoint. "count" is always 9, and each page returns three
# consecutive values: page 1 -> 1..3, page 2 -> 4..6, page 3 -> 7..9.
# "next" is set only while page < 3, and "previous" only while page > 1.
@decorators.returns(ResultSetExample)
def load_page(self, page):
page = int(page)
return {
"count": 9,
"next": "http://example.org/?page={}".format(page + 1) if page < 3 else None,
"previous": "http://example.org/?page={}".format(page - 1) if page > 1 else None,
"results": [{"value": 1 + (3 * (page - 1))}, {"value": 2 + (3 * (page - 1))}, {"value": 3 + (3 * (page - 1))}]
}
endpoint = EndpointExample(None)
# NOTE(review): these field declarations have no visible class header in this
# excerpt. They resemble an extended ImpactDay model (date/count/impact plus
# rank-level and category breakdowns). Confirm which class owns them.
date = DateType()
count = IntType()
impact = IntType()
# export_level=NONEMPTY presumably drops these dicts from exported output when
# they are empty (schematics export levels). Confirm against the installed
# schematics version.
rank_levels = DictType(IntType, export_level=NONEMPTY)
rank_levels_impact = DictType(IntType, export_level=NONEMPTY)
aviation_rank_levels = DictType(IntType, export_level=NONEMPTY)
aviation_rank_levels_impact = DictType(IntType, export_level=NONEMPTY)
categories = DictType(IntType)
categories_impact = DictType(IntType)
class ImpactResultSet(ResultSet):
    """Result set whose ``results`` entries are ImpactDay models."""

    results = ResultType(ImpactDay)
# Wrapper installed by the enclosing ``returns`` decorator. It closes over that
# decorator's ``f`` and ``schema_class``. It calls the wrapped endpoint method,
# then loads the returned raw data into a schema instance bound to the endpoint.
# NOTE(review): the except-handler body and the wrapper's return statement are
# not visible in this excerpt.
def wrapper(endpoint, *args, **kwargs):
# endpoint.Meta may carry an attribute named after the wrapped method. That
# attribute is used dict-style, and its "returns" entry overrides the
# decorator's schema_class.
schema = getattr(endpoint.Meta, f.__name__, {}).get("returns", schema_class)
data = f(endpoint, *args, **kwargs)
try:
model = schema()
model._endpoint = endpoint
# if schema class is a ResultSet, tell it how to load more results
# NOTE(review): this tests the decorator's ``schema_class``, not the
# possibly Meta-overridden ``schema`` instantiated above. An override that
# differs in ResultSet-ness is therefore not honoured. This likely should be
# issubclass(schema, ResultSet); confirm.
# _more re-runs this same wrapper for this endpoint with new arguments.
if issubclass(schema_class, ResultSet):
model._more = functools.partial(wrapper, endpoint)
# if results are of type Model, make sure to set the endpoint on each item
# NOTE(review): model._fields['results'] is indexed whenever the *data* has a
# 'results' key. A schema without a 'results' field would raise KeyError,
# which is not caught by the SchematicsDataError handler.
if data is not None and 'results' in data \
and hasattr(model._fields['results'], 'model_class') \
and issubclass(model._fields['results'].model_class, Model):
def initialize_result_type(item_data):
item = model._fields['results'].model_class(item_data, strict=False)
item._endpoint = endpoint
return item
# Use generator so results are not iterated over more than necessary
# (this rebinds data['results'] in place on the dict returned by f).
data['results'] = (initialize_result_type(item_data) for item_data in data['results'])
model.import_data(data, strict=False)
model.validate()
except SchematicsDataError as e:
class CountAnalysisComponent(Model):
    """Analysis component holding an integer ``count`` and float ``expected``/``excess`` values."""

    count = IntType()
    expected = FloatType()
    excess = FloatType()
class DailyAnalysis(Model):
    """Per-date analysis record.

    ``demand`` is a count-based component; ``lead`` and ``span`` are
    mean-based components.
    """

    date = DateType()
    demand = ModelType(CountAnalysisComponent)
    lead = ModelType(MeanAnalysisComponent)
    span = ModelType(MeanAnalysisComponent)
class AnalysisResultSet(ResultSet):
    """Result set whose ``results`` entries are DailyAnalysis records."""

    results = ResultType(DailyAnalysis)
class AnalysisParams(PaginatedMixin, SortableMixin, Model):
    """Parameter model for analysis calls; mixes in PaginatedMixin and SortableMixin."""

    class Options:
        # Fields whose value is None are left out when serializing.
        serialize_when_none = False

    id = StringType(required=True)

    # Date-time range filters.
    date = ModelType(DateTimeRange)
    initiated = ModelType(DateTimeRange)
    completed = ModelType(DateTimeRange)

    # "+"-separated list of Area values.
    within = StringListType(StringModelType(Area), separator="+")
    significance = FloatType(min_value=0, max_value=100)
    place = ModelType(Place)
    top_events = ModelType(TopEventsSearchParams)
class CalendarDay(Model):
    """Per-date calendar entry with counts, top rank, breakdowns and top events."""

    date = DateType()
    count = IntType()
    top_rank = FloatType()

    # Integer breakdowns keyed by rank level / category / label.
    rank_levels = DictType(IntType)
    categories = DictType(IntType)
    labels = DictType(IntType)

    top_events = ModelType(EventResultSet)
class CalendarResultSet(ResultSet):
    """Result set whose ``results`` entries are CalendarDay records."""

    results = ResultType(CalendarDay)
class ImpactParams(SearchParams):
    """SearchParams extended with top-events parameters and an impact-rank selector."""

    top_events = ModelType(TopEventsSearchParams)
    # Restricted to exactly 'rank' or 'aviation_rank'.
    impact_rank = StringType(choices=('rank', 'aviation_rank'))
class ImpactDay(Model):
    """Per-date record carrying an integer ``count`` and ``impact``."""

    date = DateType()
    count = IntType()
    impact = IntType()
class Place(Model):
    """Place record: identity, naming hierarchy, country codes and a GeoJSON point."""

    id = StringType()
    type = StringType()
    name = StringType()

    # Administrative hierarchy.
    county = StringType()
    region = StringType()
    country = StringType()
    country_alpha2 = StringType()
    country_alpha3 = StringType()

    location = GeoJSONPointType()
class PlaceResultSet(ResultSet):
    """Result set whose ``results`` entries are Place records."""

    results = ResultType(Place)
# NOTE(review): orphaned method-body statement; its enclosing def is not
# visible in this excerpt. It delegates to the bound endpoint's analysis()
# call, passing this object's id plus any extra keyword params.
return self._endpoint.analysis(id=self.id, **params)
class NewSignal(Signal):
    """Signal variant without the SignalID mixin that SavedSignal adds."""

    class Options(Signal.Options):
        # Inherits Signal's options unchanged.
        pass
class SavedSignal(SignalID, Signal):
    """Signal combined with the SignalID mixin."""

    class Options(Signal.Options):
        # Inherits Signal's options unchanged.
        pass
class SignalResultSet(ResultSet):
    """Result set whose ``results`` entries are SavedSignal records."""

    results = ResultType(SavedSignal)
class DataPoint(Model):
    """Data point with a required uid, timestamp and coordinates, plus optional start/end times."""

    uid = StringType(required=True)
    date = DateTimeType(required=True)

    # Coordinates bounded to [-90, 90] and [-180, 180] respectively.
    latitude = FloatType(min_value=-90, max_value=90, required=True)
    longitude = FloatType(min_value=-180, max_value=180, required=True)

    initiated = DateTimeType()
    completed = DateTimeType()
    # @todo: Support custom dimensions from signal
# NOTE(review): the fields listed under this header do not obviously match the
# class name. Fields such as phq_attendance, state, deleted_reason and
# duplicate_of_id look like event attributes and may have been spliced in from
# another model. Verify this class body against the source of truth.
class SignalDataPoints(Model):
aviation_rank = IntType()
phq_attendance = IntType()
entities = ListType(ModelType(Entities))
location = GeoJSONPointType()
# List of lists of strings.
place_hierarchies = ListType(ListType(StringType()))
scope = StringType()
relevance = FloatType()
state = StringType()
first_seen = DateTimeType()
updated = DateTimeType()
deleted_reason = StringType()
duplicate_of_id = StringType()
class EventResultSet(ResultSet):
    """Result set of Event records with an additional boolean ``overflow`` flag."""

    overflow = BooleanType()
    results = ResultType(Event)
class CountResultSet(Model):
    """Aggregate count response: a plain Model despite its name, not a ResultSet subclass."""

    count = IntType()
    top_rank = FloatType()

    # Integer breakdowns keyed by rank level / category / label.
    rank_levels = DictType(IntType)
    categories = DictType(IntType)
    labels = DictType(IntType)
# Model that mixes in SortableMixin.
# NOTE(review): the class body continues beyond this excerpt.
class TopEventsSearchParams(SortableMixin, Model):