Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
def ensure_subnet(vpc, availability_zone=None):
    """Return a subnet of *vpc*, creating one subnet per availability zone if none is usable.

    :param vpc: VPC resource whose subnets (``vpc.subnets.all()``) are searched.
    :param availability_zone: If given, only a subnet in this zone is acceptable; it must be
        one of ``availability_zones()``.
    :returns: An existing subnet of *vpc* (in *availability_zone* when given), or, if none
        exists, the newly created subnet for *availability_zone* (or the first one created).
    :raises AegeaException: If *availability_zone* is unknown, or if no subnet could be
        created for the request.
    """
    if availability_zone is not None and availability_zone not in availability_zones():
        msg = "Unknown availability zone {} (choose from {})"
        raise AegeaException(msg.format(availability_zone, list(availability_zones())))
    for subnet in vpc.subnets.all():
        if availability_zone is not None and subnet.availability_zone != availability_zone:
            continue
        break
    else:
        # No acceptable subnet exists: split the VPC CIDR configured for the current region
        # into blocks of the configured prefix length and create one subnet per zone.
        from ipaddress import ip_network
        from ... import config
        subnet_cidrs = ip_network(str(config.vpc.cidr[ARN.get_region()])).subnets(new_prefix=config.vpc.subnet_prefix)
        subnets = {}
        # zip() stops early if there are fewer CIDR blocks than zones, so some zones may get no subnet.
        for az, subnet_cidr in zip(availability_zones(), subnet_cidrs):
            logger.info("Creating subnet with CIDR %s in %s, %s", subnet_cidr, vpc, az)
            subnets[az] = resources.ec2.create_subnet(VpcId=vpc.id, CidrBlock=str(subnet_cidr), AvailabilityZone=az)
            clients.ec2.get_waiter("subnet_available").wait(SubnetIds=[subnets[az].id])
            add_tags(subnets[az], Name=__name__)
            clients.ec2.modify_subnet_attribute(SubnetId=subnets[az].id,
                                                MapPublicIpOnLaunch=dict(Value=config.vpc.map_public_ip_on_launch))
        # Hand back the subnet the caller asked for; dicts keep insertion order, so
        # "first" means the subnet created for the first availability zone.
        if availability_zone is not None:
            if availability_zone not in subnets:
                raise AegeaException("Unable to create subnet in {} for {}".format(availability_zone, vpc))
            subnet = subnets[availability_zone]
        elif subnets:
            subnet = next(iter(subnets.values()))
        else:
            raise AegeaException("Unable to create any subnets in {}".format(vpc))
    return subnet
def find_acm_cert(dns_name):
    """Return the ACM certificate whose subject alternative names cover *dns_name*.

    Each certificate from ``list_certificates`` is merged in place with the
    ``Certificate`` payload of ``describe_certificate``. A certificate matches when one
    of its ``SubjectAlternativeNames`` equals *dns_name* exactly, or equals the wildcard
    formed by replacing the first label of *dns_name* with ``*``.

    :raises AegeaException: if no certificate matches.
    """
    for candidate in paginate(clients.acm.get_paginator("list_certificates")):
        described = clients.acm.describe_certificate(CertificateArn=candidate["CertificateArn"])
        candidate.update(described["Certificate"])
        for alt_name in candidate["SubjectAlternativeNames"]:
            # "a.example.com" -> "*.example.com"
            wildcard = ".".join(["*"] + dns_name.split(".")[1:])
            if alt_name in (dns_name, wildcard):
                return candidate
    raise AegeaException("Unable to find ACM certificate for {}".format(dns_name))
def resolve_log_group(name):
    """Return the description of the CloudWatch Logs log group named exactly *name*.

    ``describe_log_groups`` is queried with *name* as a prefix, so results may include
    groups that merely start with *name*; only an exact name match is returned.

    :raises AegeaException: if no log group has exactly that name.
    """
    paginator = clients.logs.get_paginator("describe_log_groups")
    exact_matches = (group for group in paginate(paginator, logGroupNamePrefix=name)
                     if group["logGroupName"] == name)
    found = next(exact_matches, None)
    if found is None:
        raise AegeaException("Log group {} not found".format(name))
    return found
def get_target_group(alb_name, target_group_name):
    """Look up target group *target_group_name* attached to load balancer *alb_name*.

    :returns: A dict combining the load balancer description with the target group
        description (target group values win on duplicate keys).
    :raises AegeaException: if the load balancer has no target group with that name;
        the message lists the names that were found.
    """
    load_balancers = clients.elbv2.describe_load_balancers(Names=[alb_name])["LoadBalancers"]
    alb = load_balancers[0]
    response = clients.elbv2.describe_target_groups(LoadBalancerArn=alb["LoadBalancerArn"])
    groups = response["TargetGroups"]
    for group in groups:
        if group["TargetGroupName"] != target_group_name:
            continue
        return dict(alb, **group)
    template = "Target group {} not found in {} (target groups found: {})"
    found_names = ", ".join(group["TargetGroupName"] for group in groups)
    raise AegeaException(template.format(target_group_name, alb_name, found_names))
def resolve_instance_public_dns(name):
    """Return the public DNS name of the instance identified by *name*.

    If the instance is tagged with ``SSHHostPublicKeyPart1`` and/or
    ``SSHHostPublicKeyPart2``, their concatenation is passed to
    ``add_ssh_host_key_to_known_hosts`` as a "<public DNS> <key>" line.

    :raises AegeaException: if the instance has no public DNS name.
    """
    instance = get_instance(name)
    public_dns = getattr(instance, "public_dns_name", None)
    if not public_dns:
        state_name = getattr(instance, "state", {}).get("Name")
        raise AegeaException("Unable to resolve public DNS name for {} (state: {})".format(instance, state_name))
    tags = {}
    for tag in instance.tags or []:
        tags[tag["Key"]] = tag["Value"]
    # The host key is split across two tags (presumably due to tag value length limits).
    host_key = "".join(tags.get(part, "") for part in ("SSHHostPublicKeyPart1", "SSHHostPublicKeyPart2"))
    if host_key:
        # FIXME: this results in duplicates.
        # Use paramiko to detect if the key is already listed and not insert it then (or only insert if different)
        add_ssh_host_key_to_known_hosts("".join([public_dns, " ", host_key, "\n"]))
    return public_dns
ensure_bless_ssh_cert(ssh_key_name=ssh_key_name,
bless_config=bless_config,
use_kms_auth=use_kms_auth)
add_ssh_key_to_agent(ssh_key_name)
instance = get_instance(hostname)
if not username:
username = bless_config["client_config"]["remote_users"][0]
bastion_config = match_instance_to_bastion(instance=instance, bastions=bless_config["ssh_config"]["bastions"])
if bastion_config:
jump_host = bastion_config["user"] + "@" + bastion_config["pattern"]
return ["-o", "ProxyJump=" + jump_host], username + "@" + instance.private_ip_address
elif instance.public_dns_name:
logger.warn("No bastion host found for %s, trying direct connection", instance.private_ip_address)
return [], username + "@" + instance.public_dns_name
else:
raise AegeaException("No bastion host or public route found for {}".format(instance))
else:
if get_instance(hostname).key_name is not None:
add_ssh_key_to_agent(get_instance(hostname).key_name)
if not username:
username = get_linux_username()
return [], username + "@" + resolve_instance_public_dns(hostname)
def resolve_instance_id(name):
    """Resolve *name* to an EC2 instance ID.

    *name* may be an instance ID (returned unchanged), an EC2 public DNS name
    (looked up with the ``dns-name`` filter), or the value of an instance's ``Name`` tag.

    :raises AegeaException: if no instance matches *name*.
    """
    if name.startswith("i-"):
        return name
    # Public EC2 hostnames are ec2-<ip>.compute-1.amazonaws.com in us-east-1,
    # ec2-<ip>.<region>.compute.amazonaws.com in other regions, and end in
    # .compute.amazonaws.com.cn in China regions.
    public_dns_suffixes = ("compute.amazonaws.com", "compute-1.amazonaws.com", "compute.amazonaws.com.cn")
    is_public_dns = name.startswith("ec2") and name.endswith(public_dns_suffixes)
    filter_name = "dns-name" if is_public_dns else "tag:Name"
    desc = clients.ec2.describe_instances(Filters=[dict(Name=filter_name, Values=[name])])
    try:
        return desc["Reservations"][0]["Instances"][0]["InstanceId"]
    except IndexError:
        # An empty Reservations/Instances list means nothing matched.
        raise AegeaException('Could not resolve "{}" to a known instance'.format(name)) from None
aws_secret_access_key=assume_role_res['Credentials']['SecretAccessKey'],
aws_session_token=assume_role_res['Credentials']['SessionToken'])
bless_input = dict(bastion_user=iam.CurrentUser().user_name,
bastion_user_ip="0.0.0.0/0",
bastion_ips=",".join(bless_config["client_config"]["bastion_ips"]),
remote_usernames=",".join(bless_config["client_config"]["remote_users"]),
public_key_to_sign=get_public_key_from_pair(ssh_key),
command="*")
if use_kms_auth:
bless_input["kmsauth_token"] = get_kms_auth_token(session=session,
bless_config=bless_config,
lambda_regional_config=lambda_regional_config)
res = awslambda.invoke(FunctionName=bless_config["lambda_config"]["function_name"], Payload=json.dumps(bless_input))
bless_output = json.loads(res["Payload"].read().decode())
if "certificate" not in bless_output:
raise AegeaException("Error while requesting Bless SSH certificate: {}".format(bless_output))
with open(ssh_cert_filename, "w") as fh:
fh.write(bless_output["certificate"])
return ssh_cert_filename