Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def do_nothing(self):
    """Choose ``ON CONFLICT ... DO NOTHING`` as the conflict handler.

    Raises:
        QueryException: if DO UPDATE assignments were already registered,
            because only one conflict handler may be configured.
    """
    has_pending_updates = bool(self._on_conflict_updates)
    if has_pending_updates:
        raise QueryException("Can not have two conflict handlers")
    self._on_conflict_do_nothing = True
def _validate_returning_term(self, term):
    """Check that ``term`` may appear in this query's RETURNING clause.

    Every field referenced by ``term`` (as yielded by ``term.fields()``) is
    inspected in turn.

    Raises:
        QueryException: "Returning can't be used in this query" when none of
            the insert table, update table or delete flag is set; "You can't
            return from other tables" when a field's table is neither the
            insert nor the update table and ``term`` is not found in
            ``self._from``.
    """
    for field in term.fields():
        # The query-type check sits inside the loop, so a term that yields
        # no fields is accepted without this check ever running.
        if not any([self._insert_table, self._update_table, self._delete_from]):
            raise QueryException("Returning can't be used in this query")
        # NOTE(review): this tests the whole ``term`` (not ``field.table``) for
        # membership in ``self._from``. The result therefore depends on how the
        # term and table types implement ``__eq__``. Confirm this is intended.
        if (
            field.table not in {self._insert_table, self._update_table}
            and term not in self._from
        ):
            raise QueryException("You can't return from other tables")
def on_conflict(self, target_field):
    """Set the conflict target column used by ``ON CONFLICT (...)``.

    Args:
        target_field: a column name, resolved through
            ``self._conflict_field_str``, or a ``Field`` used as-is.

    Raises:
        QueryException: if this is not an INSERT query.
    """
    if not self._insert_table:
        raise QueryException("On conflict only applies to insert query")

    if isinstance(target_field, str):
        resolved = self._conflict_field_str(target_field)
    elif isinstance(target_field, Field):
        resolved = target_field
    else:
        # Any other type is ignored and the previously stored target is kept.
        return
    self._on_conflict_field = resolved
def _on_conflict_sql(self, **kwargs):
    """Render the ``ON CONFLICT`` clause for an INSERT query.

    Output shapes:
      * ``""`` when neither a conflict target nor a handler is configured;
      * ``" ON CONFLICT [(<target>)] DO NOTHING"``;
      * ``" ON CONFLICT (<target>) DO UPDATE SET a=x,b=y"``.

    Args:
        **kwargs: forwarded to every ``get_sql`` call (target, update
            columns and update values).

    Returns:
        The clause text, with a leading space when non-empty.

    Raises:
        QueryException: if a conflict target is set but no handler
            (DO NOTHING / DO UPDATE) was chosen, or if DO UPDATE was requested
            without a conflict target.
    """
    if not self._on_conflict_do_nothing and not self._on_conflict_updates:
        if not self._on_conflict_field:
            return ""
        raise QueryException("No handler defined for on conflict")

    conflict_query = " ON CONFLICT"
    if self._on_conflict_field:
        conflict_query += (
            " ("
            + self._on_conflict_field.get_sql(with_alias=True, **kwargs)
            + ")"
        )

    if self._on_conflict_do_nothing:
        conflict_query += " DO NOTHING"
    elif self._on_conflict_updates:
        # DO UPDATE needs an explicit conflict target.
        if not self._on_conflict_field:
            raise QueryException("Can not have fieldless on conflict do update")
        updates = ",".join(
            "{field}={value}".format(
                field=field.get_sql(**kwargs),
                value=value.get_sql(**kwargs),
            )
            for field, value in self._on_conflict_updates
        )
        conflict_query += " DO UPDATE SET {updates}".format(updates=updates)
    return conflict_query
def do_update(self, update_field, update_value):
    """Register a ``SET <field> = <value>`` assignment for ``ON CONFLICT DO UPDATE``.

    Args:
        update_field: a column name, resolved through
            ``self._conflict_field_str``, or a ``Field`` used as-is.
        update_value: the value to assign; it is wrapped in ``ValueWrapper``.

    Raises:
        QueryException: if DO NOTHING was already chosen as the conflict
            handler, or if ``update_field`` is neither a ``str`` nor a ``Field``.
    """
    if self._on_conflict_do_nothing:
        raise QueryException("Can not have two conflict handlers")

    if isinstance(update_field, str):
        field = self._conflict_field_str(update_field)
    elif isinstance(update_field, Field):
        field = update_field
    else:
        # Previously this fell through and crashed with UnboundLocalError on
        # ``field``. Report the unsupported input explicitly instead.
        raise QueryException(
            "Unsupported update_field type: {}".format(type(update_field).__name__)
        )
    self._on_conflict_updates.append((field, ValueWrapper(update_value)))
def _select_field_str(self, term):
    """Add a column, given by name, to the SELECT list.

    ``"*"`` switches the query to ``SELECT *`` and replaces any earlier
    selections. Any other name becomes a ``Field`` bound to the first FROM
    table.

    Raises:
        QueryException: if no FROM table has been set yet.
    """
    if not self._from:
        raise QueryException(
            "Cannot select {term}, no FROM table specified.".format(term=term)
        )

    if term != "*":
        first_table = self._from[0]
        self._select_field(Field(term, table=first_table))
        return

    self._select_star = True
    self._selects = [Star()]
def _return_field_str(self, term):
    """Add a column, given by name, to the RETURNING list.

    ``"*"`` switches RETURNING to a star. Any other name becomes a ``Field``
    bound to the table the statement writes to. That table is the insert
    table, else the update table, else (for DELETE) the first FROM table.

    Raises:
        QueryException: if the query is not an INSERT, UPDATE or DELETE.
    """
    if term == "*":
        self._set_returns_for_star()
        self._returns.append(Star())
        return

    if self._insert_table:
        owner = self._insert_table
    elif self._update_table:
        owner = self._update_table
    elif self._delete_from:
        owner = self._from[0]
    else:
        raise QueryException("Returning can't be used in this query")
    self._return_field(Field(term, table=owner))
def returning(self, *terms):
    """Append terms to the RETURNING clause.

    Dispatch per term:
      * ``Field`` -> ``self._return_field``;
      * ``str`` -> ``self._return_field_str`` (column name or ``"*"``);
      * ``ArithmeticExpression`` -> ``self._return_other``;
      * ``Function`` -> rejected (every function, aggregate or not);
      * anything else -> wrapped via ``self.wrap_constant`` and passed to
        ``self._return_other``.

    Raises:
        QueryException: if any term is a ``Function``.
    """
    for candidate in terms:
        if isinstance(candidate, Field):
            self._return_field(candidate)
            continue
        if isinstance(candidate, str):
            self._return_field_str(candidate)
            continue
        if isinstance(candidate, ArithmeticExpression):
            self._return_other(candidate)
            continue
        if isinstance(candidate, Function):
            raise QueryException("Aggregate functions are not allowed in returning")
        wrapped = self.wrap_constant(candidate, self._wrapper_cls)
        self._return_other(wrapped)