package WeBWorK::Authen::Saml2;
use Mojo::Base 'WeBWorK::Authen', -signatures;

use Mojo::File qw(path tempfile);
use Mojo::JSON qw(encode_json);
use Mojo::UserAgent;
use Net::SAML2::IdP;
use Net::SAML2::SP;
use URN::OASIS::SAML2 qw(BINDING_HTTP_POST BINDING_HTTP_REDIRECT);
use Net::SAML2::Binding::POST;
use Net::SAML2::Protocol::Assertion;

use WeBWorK::Debug qw(debug);
use WeBWorK::Authen::LTIAdvanced::Nonce;

=head1 NAME

WeBWorK::Authen::Saml2 - Authenticate using a SAML2 identity provider.

=cut

# Decide whether this module should try to authenticate the current request.  Returns 0 (so that the next
# authentication module is tried) when the course environment names a bypass query parameter in
# $ce->{saml2}{bypass_query} and that parameter is set on the request.  Otherwise returns 1.
sub request_has_data_for_this_verification_module ($self) {
	my $c           = $self->{c};
	my $bypassParam = $c->ce->{saml2}{bypass_query};

	# No bypass configured, or the bypass parameter is absent: this module handles the request.
	return 1 unless $bypassParam && $c->param($bypassParam);

	debug('Saml2 authen module bypass detected. Going to next authentication module.');
	return 0;
}

# Run the base class verification.  On the assertion consumer service route (saml2_acs) additionally:
#   - copy the Saml2 NameID and session index from the stash into the webwork session, since they are needed to log
#     out of the identity provider if that is configured, and
#   - if a redirect is pending and two factor verification is needed, move the two factor flag aside so that it is
#     restored after the redirect to the course route, and report success for this request.
sub verify ($self) {
	my $result = $self->SUPER::verify;
	my $c      = $self->{c};

	return $result unless $c->current_route eq 'saml2_acs';

	my $stash = $c->stash;

	for my $field (qw(saml2_nameid saml2_session)) {
		$self->session->{$field} = $stash->{$field} if $stash->{$field};
	}

	# Two factor verification is deferred until the user lands on the course route (do_verify restores the flag).
	if ($stash->{saml2_redirect} && $self->session->{two_factor_verification_needed}) {
		$self->session->{two_factor_verification_needed_after_redirect} =
			delete $self->session->{two_factor_verification_needed};
		return 1;
	}

	return $result;
}

# Perform Saml2 verification for the current request.
#
# On the assertion consumer service route (saml2_acs), the identity provider's response is verified and decoded.
# It is checked against the authentication request id stored before redirecting, and the user id is looked up.
# Control is then handed to the base class.  On success the post-login redirect URL and the Saml2 logout data are
# saved in the stash.
#
# On any other route, an existing webwork2 session (cookie or key) is handed to the base class.  If there is no
# session, the user is redirected to the identity provider (except on the logout route).
#
# Returns 1 on successful verification and 0 otherwise.  Failures set $c->stash->{authen_error}.
sub do_verify ($self) {
	my $c  = $self->{c};
	my $ce = $c->ce;

	# NOTE(review): presumably this tells the base class that the user authenticated externally so that two factor
	# handling follows the twoFAOnlyWithBypass setting -- confirm against WeBWorK::Authen.
	$self->{external_auth} = 1 if $ce->two_factor_authentication_enabled && $ce->{saml2}{twoFAOnlyWithBypass};

	if ($c->current_route eq 'saml2_acs') {
		debug('Verifying Saml2 assertion');

		my $idpCertificateFile = $self->idp(1);
		unless ($idpCertificateFile) {
			$c->stash->{authen_error} = $c->maketext(
				'An internal server error occured. Please contact the system administrator for assistance.');
			return 0;
		}

		# Verify that the response is signed by the identity provider, decode it, and parse the assertion.  These
		# calls die if the response is malformed, the signature check fails, or the assertion cannot be decrypted.
		# Trap that so that a bad response is a failed login instead of an unhandled exception.
		my $assertion = eval {
			my $decodedXml = Net::SAML2::Binding::POST->new(cacert => $idpCertificateFile->to_string)
				->handle_response($c->stash->{saml2}{samlResp});
			Net::SAML2::Protocol::Assertion->new_from_xml(
				xml      => $decodedXml,
				key_file => $self->spKeyFile->to_string
			);
		};
		unless ($assertion) {
			$c->stash->{authen_error} = $c->maketext('Invalid user ID or password.');
			debug('Unable to verify or decode the Saml2 response: ' . ($@ || 'unknown error'));
			return 0;
		}

		# Get the database key containing the authReqId that was generated before redirecting to the identity provider.
		# A response without an InResponseTo value was not solicited by webwork2 and is rejected the same way.
		my $inResponseTo = $assertion->in_response_to;
		my $authReqIdKey = defined $inResponseTo ? $c->db->getKey($inResponseTo) : undef;
		unless ($authReqIdKey) {
			$c->stash->{authen_error} = $c->maketext('Invalid user ID or password.');
			debug('Invalid request id in response.  Possible CSFR.');
			return 0;
		}
		eval { $c->db->deleteKey($authReqIdKey->user_id) };    # Delete the key to avoid replay.

		# Verify that the response has the same authReqId which means it's responding to the authentication request
		# generated by webwork2. This also checks that timestamps are valid.
		my $valid = $assertion->valid($ce->{saml2}{sp}{entity_id}, $authReqIdKey->user_id);
		unless ($valid) {
			$c->stash->{authen_error} = $c->maketext('Invalid user ID or password.');
			debug('Bad timestamp or issuer');
			return 0;
		}

		debug('Got valid response and looking for username.');
		my $userId = $self->getUserId($ce->{saml2}{sp}{attributes}, $assertion);
		if ($userId) {
			debug("Got username $userId");

			# Picked up by get_credentials when the base class verifies the user.
			$c->authen->{saml2UserId} = $userId;
			if ($self->SUPER::do_verify) {
				# The user and key need to be set before systemLink is called.  They are only used if
				# $session_management_via is 'key'.
				$c->param('user', $userId);
				$c->param('key',  $self->{session_key});
				$c->stash->{saml2_redirect} = $c->systemLink($c->url_for($c->stash->{saml2}{relayState}{url}));

				# Save these in the stash for now.  They will be transferred to the session after it has been created.
				$c->stash->{saml2_nameid}  = $assertion->nameid;
				$c->stash->{saml2_session} = $assertion->{session};

				return 1;
			}
		}
		$c->stash->{authen_error} = $c->maketext('User not found in course.');
		debug('Unauthorized - User not found in ' . $c->stash->{courseID});
		return 0;
	}

	# If there is an existing session, then control will be passed to the authen base class.
	if ($ce->{session_management_via} eq 'session_cookie') {
		my ($cookieUser) = $self->fetchCookie;
		$self->{isLoggedIn} = 1 if defined $cookieUser;
	} elsif ($c->param('user')) {
		my $key = $c->db->getKey($c->param('user'));
		$self->{isLoggedIn} = 1 if $key;
	}

	if ($self->{isLoggedIn}) {
		debug('User signed in or was previously signed in.  Saml2 passing control back to the authen base class.');

		# There was a successful saml response or the user was already logged in.
		# So hand off to the authen base class to verify the user and manage the session.
		my $result = $self->SUPER::do_verify;

		# Restore the two factor flag that verify moved aside on the saml2_acs route.
		$self->session->{two_factor_verification_needed} =
			delete $self->session->{two_factor_verification_needed_after_redirect}
			if $self->session->{two_factor_verification_needed_after_redirect};

		return $result;
	}

	# This occurs if the user clicks the logout button when the identity provider session has timed out, but the
	# webwork2 session is still active.  In this case return 0 so that the logged out page is shown anyway.
	return 0 if $c->current_route eq 'logout';

	# The user doesn't have an existing session, so redirect to the identity provider for login.
	$self->sendLoginRequest;

	return 0;
}

# Return the Net::SAML2::SP object that describes webwork2 as a service provider.
#
# The object is built on first use and cached in the stash under the "sp" key.  A relative certificate file path is
# resolved against the application home directory.  The single logout service is only included when
# $ce->{saml2}{sp}{enable_sp_initiated_logout} is true.
sub sp ($self) {
	my $c = $self->{c};

	unless ($c->stash->{sp}) {
		my $ce      = $c->ce;
		my $rootURL = $ce->{server_root_url};

		my $certificate = path($ce->{saml2}{sp}{certificate_file});
		$certificate = $c->app->home->child($certificate) if !$certificate->is_abs;

		my @acs = ({
			Binding   => BINDING_HTTP_POST,
			Location  => $rootURL . $c->url_for('saml2_acs'),
			isDefault => 'true',
		});

		# Only advertise a logout endpoint when service provider initiated logout is enabled.
		my @logoutServiceOption;
		if ($ce->{saml2}{sp}{enable_sp_initiated_logout}) {
			@logoutServiceOption = (
				single_logout_service =>
					[ { Binding => BINDING_HTTP_POST, Location => $rootURL . $c->url_for('saml2_logout') } ]
			);
		}

		$c->stash->{sp} = Net::SAML2::SP->new(
			issuer                     => $ce->{saml2}{sp}{entity_id},
			url                        => $rootURL . $c->url_for('root'),
			error_url                  => $rootURL . $c->url_for('saml2_error'),
			cert                       => $certificate->to_string,
			key                        => $self->spKeyFile->to_string,
			org_contact                => $ce->{saml2}{sp}{org}{contact},
			org_name                   => $ce->{saml2}{sp}{org}{name},
			org_url                    => $ce->{saml2}{sp}{org}{url},
			org_display_name           => $ce->{saml2}{sp}{org}{display_name},
			assertion_consumer_service => \@acs,
			@logoutServiceOption
		);
	}

	return $c->stash->{sp};
}

# Return the Net::SAML2::IdP object for the active identity provider ($ce->{saml2}{active_idp}).  If $certificateOnly
# is true, return the Mojo::File path of that identity provider's signing certificate instead.  An undefined value is
# returned if the identity provider metadata could not be retrieved.
#
# The first time this method is executed for a given identity provider, the metadata file is retrieved from the metadata
# URL.  It is then saved in the $ce->{saml2}{active_idp} subdirectory of $ce->{webworkDirs}{DATA}/Saml2IDPs together
# with the identity provider's signing certificate, which is extracted from the retrieved metadata.  On later requests
# the metadata and certificate are used from the saved files.  This prevents the need to retrieve the metadata on every
# login request.
sub idp ($self, $certificateOnly = 0) {
	if (!$self->{idp_certificate_file} || !$self->{idp}) {
		my $ce        = $self->{c}->ce;
		my $activeIdP = $ce->{saml2}{active_idp};

		my $saml2IDPDir = path("$ce->{webworkDirs}{DATA}/Saml2IDPs")->child($activeIdP);
		$saml2IDPDir->make_path;

		my $metadataXMLFile = $saml2IDPDir->child('metadata.xml');
		my $certificateFile = $saml2IDPDir->child('cacert.crt');

		if (-r $metadataXMLFile && -r $certificateFile) {
			$self->{idp} =
				Net::SAML2::IdP->new_from_xml(xml => $metadataXMLFile->slurp, cacert => $certificateFile->to_string);
			$self->{idp_certificate_file} = $certificateFile;
		} else {
			my $metadataURL = $ce->{saml2}{idps}{$activeIdP};

			# Mojo::Transaction::result dies on connection errors.  Trap that so that callers get an undefined
			# value and can report the failure instead of the request ending in an unhandled exception.
			my $response   = eval { Mojo::UserAgent->new->get($metadataURL)->result };
			my $fetchError = $@;

			if ($response && $response->is_success) {
				my $metadataXML = $response->body;
				$metadataXMLFile->spew($metadataXML);
				$self->{idp} = Net::SAML2::IdP->new_from_xml(xml => $metadataXML);
				$certificateFile->spew($self->{idp}->cert('signing')->[0]);
				$self->{idp_certificate_file} = $certificateFile;
			} else {
				debug("Unable to retrieve metadata from identity provider $activeIdP with "
						. 'metadata URL '
						. ($metadataURL // '')
						. ($fetchError ? ": $fetchError" : ''));
			}
		}
	}

	return $self->{idp_certificate_file} if $certificateOnly;
	return $self->{idp};
}

# Return the Mojo::File path of the service provider's private key ($ce->{saml2}{sp}{private_key_file}).  A relative
# path is resolved against the application home directory.  The result is cached on the object.
sub spKeyFile ($self) {
	unless ($self->{spKeyFile}) {
		my $c       = $self->{c};
		my $keyFile = path($c->ce->{saml2}{sp}{private_key_file});
		$keyFile = $c->app->home->child($keyFile) unless $keyFile->is_abs;
		$self->{spKeyFile} = $keyFile;
	}
	return $self->{spKeyFile};
}

# Start a service provider initiated login.
#
# An authentication request is built and its id is stored in the key table so that the identity provider's response
# can be matched against it later.  The signed redirect URL to the identity provider is stored in $self->{redirect}.
# If the identity provider metadata is unavailable, an error message is set in the stash and 0 is returned.
sub sendLoginRequest ($self) {
	my $c  = $self->{c};
	my $ce = $c->ce;

	my $idp = $self->idp;
	if (!$idp) {
		$c->stash->{authen_error} =
			$c->maketext('An internal server error occured. Please contact the system administrator for assistance.');
		return 0;
	}

	my $sp        = $self->sp;
	my $authReq   = $sp->authn_request($idp->sso_url(BINDING_HTTP_REDIRECT));
	my $requestId = $authReq->id;

	# Clear out stale request ids first.  The purge logic is borrowed from WeBWorK::Authen::LTIAdvanced::Nonce.
	WeBWorK::Authen::LTIAdvanced::Nonce->new($c, '', 0)->maybe_purge_nonces;

	# Store the request id in the key table (the "nonce" trick) so that the response can be verified.
	my $requestKey = $c->db->newKey({ user_id => $requestId, timestamp => time, key => 'nonce' });
	eval { $c->db->deleteKey($requestId) };
	eval { $c->db->addKey($requestKey) };

	debug('Redirecting user to the identity provider');

	# The identity provider relays this state back.  It is used to send the user to the right place after login.
	my $relayState = encode_json({ course => $ce->{courseName}, url => $c->req->url->to_string });
	$self->{redirect} = $sp->sso_redirect_binding($idp, 'SAMLRequest')->sign($authReq->as_xml, $relayState);
	return;
}

# If service provider initiated logout is enabled and the session holds the Saml2 NameID and session index saved at
# login, build a signed logout request and store the redirect to the identity provider in $self->{redirect}.
# Otherwise this does nothing.  Always returns an empty list.
sub logout_user ($self) {
	my $ce = $self->{c}->ce;

	return unless $ce->{saml2}{sp}{enable_sp_initiated_logout};
	return unless defined $self->session->{saml2_nameid} && defined $self->session->{saml2_session};

	my $idp = $self->idp or return;

	my $logoutRequest = $self->sp->logout_request(
		$idp->slo_url(BINDING_HTTP_REDIRECT),
		$self->session->{saml2_nameid},
		$idp->format || undef,
		$self->session->{saml2_session}
	);

	debug('Redirecting user to the identity provider for logout');
	my $binding = $self->sp->slo_redirect_binding($idp, 'SAMLRequest');
	$self->{redirect} = $binding->sign($logoutRequest->as_xml, encode_json({ course => $ce->{courseName} }));

	return;
}

# Determine the webwork2 user id for the user described by a verified Saml2 assertion.
#
# Each attribute name in $attributeKeys (an array reference, or undef to skip this step) is tried in order.  The first
# value of the first attribute that names an existing user in the course database is used.  If none match, the
# assertion's NameID is used if it names an existing user.  Returns the user id, or an empty return if no match.
sub getUserId ($self, $attributeKeys, $assertion) {
	my $db = $self->{c}->db;

	if ($attributeKeys) {
		my $attributes = $assertion->attributes // {};
		for my $key (@$attributeKeys) {
			debug("Trying attribute $key for username");
			# Check for the attribute before indexing into it so that attributes the identity provider did not send
			# are not autovivified into the assertion's attribute hash.
			my $possibleUserId = $attributes->{$key} ? $attributes->{$key}[0] : undef;
			next unless $possibleUserId;
			if ($db->getUser($possibleUserId)) {
				debug("Using attribute value for username: $possibleUserId");
				return $possibleUserId;
			}
		}
	}

	debug('No username match in attributes. Trying NameID fallback');
	my $nameId = $assertion->nameid;
	if (defined $nameId && $nameId ne '' && $db->getUser($nameId)) {
		debug("Using NameID for username: $nameId");
		return $nameId;
	}
	debug('NameID fallback failed. No username found.');
	return;
}

# Provide the credentials for the base class.  When do_verify stored a user id from the identity provider in
# $self->{saml2UserId}, that user is marked as a fresh "SAML2" login.  Otherwise the base class credentials are used
# for an existing session, and 0 is returned if there is none.
sub get_credentials ($self) {
	unless ($self->{saml2UserId}) {
		return $self->{isLoggedIn} ? $self->SUPER::get_credentials : 0;
	}

	@$self{qw(user_id login_type credential_source initial_login)} = ($self->{saml2UserId}, 'normal', 'SAML2', 1);
	debug('credential source: "SAML2", user: "', $self->{user_id}, '"');
	return 1;
}

# Password checking is not done here: the identity provider performs the authentication, so this always succeeds.
sub authenticate ($self) { return 1 }

1;
