import { AuthenticationApiFactory } from "@docket/consumer-app-client-typescript-axios";
import { jwtDecode } from "jwt-decode";
import { Tokens } from "../../models/Interfaces";
import { ITokenProvider } from "../../models/TokenProvider";
import Bugsnag from "@bugsnag/js";

/**
 * Token provider that refreshes the stored tokens when the access token is no
 * longer valid.
 *
 * Tokens are read through `tokenSupplier`, refreshed through
 * `authApi.refreshTokens`, and written back through `tokenStorer`. Concurrent
 * callers share a single in-flight refresh (including all of its retries)
 * instead of each starting their own.
 */
export class RefreshingTokenProvider implements ITokenProvider {
  /** The refresh currently in flight, or `null` when none is running. */
  #curRefresh: Promise<Tokens> | null;
  /** Upper bound for the exponential part of the retry back-off, in ms. */
  readonly #maxBackoffMs: number;
  /** Total number of refresh attempts made before the last error is surfaced. */
  readonly #maxRetries: number;

  /**
   * @param authApi - API client whose `refreshTokens` call exchanges a refresh token for new tokens.
   * @param tokenSupplier - Reads the currently stored tokens; resolves to `null` when there are none.
   * @param tokenStorer - Persists freshly obtained tokens.
   */
  constructor(
    private readonly authApi: ReturnType<typeof AuthenticationApiFactory>,
    private readonly tokenSupplier: () => Promise<Tokens | null>,
    private readonly tokenStorer: (tokens: Tokens) => Promise<void>
  ) {
    this.#curRefresh = null;
    this.#maxBackoffMs = 1000 * 60; // 1 minute
    this.#maxRetries = 30;
  }

  /**
   * Reads the stored tokens and insists that they exist.
   *
   * @returns The tokens returned by the supplier.
   * @throws Error when the supplier resolves to a falsy value; the error is
   *   also reported to Bugsnag before being thrown.
   */
  async tokenSupplierWithAssertion(): Promise<Tokens> {
    const tokens = await this.tokenSupplier();
    if (!tokens) {
      const err = Error(
        "Attempted to fetch tokens but the supplier returned none. Was it initialized?"
      );
      Bugsnag.notify(err);
      throw err;
    }
    return tokens;
  }

  /**
   * Returns usable tokens, refreshing them first when the stored access token
   * is reported as expired.
   *
   * @throws Error when no tokens are stored, or when the refresh ultimately fails.
   */
  async tokens(): Promise<Tokens> {
    let tokens = await this.tokenSupplierWithAssertion();
    if (this.tokenExpired(tokens.access)) {
      tokens = await this.refreshTokens();
    }
    return tokens;
  }

  /**
   * Refreshes the tokens, joining an already running refresh if there is one.
   *
   * The in-flight slot is released only after the whole retry chain has
   * settled, so callers arriving while a retry is backing off still share the
   * same refresh rather than racing it with a second one.
   *
   * @returns The refreshed tokens.
   * @throws Error when every refresh attempt failed.
   */
  async refreshTokens(): Promise<Tokens> {
    if (!this.#curRefresh) {
      this.#curRefresh = (async () => {
        try {
          return await this.internalRefreshTokens(0);
        } finally {
          // Runs after at least one await, i.e. after the assignment above.
          this.#curRefresh = null;
        }
      })();
    }
    return this.#curRefresh;
  }

  /**
   * Performs the refresh, retrying with back-off until it succeeds or the
   * attempt budget is exhausted.
   *
   * @param refreshAttempt - Zero-based index of the first attempt to make.
   * @returns The tokens returned by the auth API, after they have been stored.
   * @throws Error wrapping the last failure (available as `cause`) once
   *   attempt `#maxRetries - 1` has failed.
   */
  private async internalRefreshTokens(refreshAttempt: number): Promise<Tokens> {
    for (let attempt = refreshAttempt; ; attempt++) {
      const sleepTime = this.calculateBackoff(attempt);
      if (sleepTime > 0) {
        // We skip this block entirely on 0, since 0 is a yield.
        await new Promise<void>((resolve) => {
          setTimeout(resolve, sleepTime);
        });
      }

      try {
        const current = await this.tokenSupplierWithAssertion();
        const updatedTokens = await this.authApi.refreshTokens({
          refreshToken: current.refresh,
        });
        await this.tokenStorer(updatedTokens.data);
        return updatedTokens.data;
      } catch (err: unknown) {
        if (attempt >= this.#maxRetries - 1) {
          // Same Error type and message as before, but keep the original
          // failure reachable instead of flattening it to a string only.
          throw Object.assign(new Error(String(err)), { cause: err });
        }
        // Otherwise fall through to the next attempt.
      }
    }
  }

  /**
   * Computes how long to wait before the given attempt.
   *
   * @param refreshAttempt - Zero-based attempt index.
   * @returns 0 for the first attempt; otherwise `2^attempt * 10` ms capped at
   *   `#maxBackoffMs`, plus up to 10 ms of random jitter.
   */
  private calculateBackoff(refreshAttempt: number): number {
    // Math.pow won't overflow; it'll go to Infinity instead.
    return refreshAttempt === 0
      ? 0
      : Math.min(this.#maxBackoffMs, Math.pow(2, refreshAttempt) * 10) + Math.random() * 10;
  }

  /**
   * Delegates to the module-level `tokenExpired` helper.
   *
   * @param token - Encoded access token.
   * @returns Whatever the module-level helper reports for this token.
   */
  tokenExpired(token: string): boolean {
    return tokenExpired(token);
  }
}

/**
 * Reports whether a JWT is currently unusable.
 *
 * @param token - Encoded JWT; an empty value is treated as expired.
 * @returns `true` when the token is empty, its `exp` lies in the past (a
 *   missing `exp` counts as the epoch), or its `nbf` lies in the future.
 * @throws Error when decoding yields no claims object.
 */
export function tokenExpired(token: string): boolean {
  // Nothing to decode: an absent token can never be used.
  if (!token) {
    return true;
  }

  const payload = jwtDecode(token);
  if (!payload) {
    throw new Error("Invalid token; could not parse to JwtPayload");
  }

  const currentMs = new Date().getTime();

  // exp and nbf are unix epochs in seconds; `Date` works in milliseconds.
  const expMs = (payload.exp || 0) * 1000;
  const nbfMs = (payload.nbf || 0) * 1000;

  // A missing/zero exp falls back to the epoch; a missing/zero nbf falls back to "now".
  const expiresAtMs = new Date(expMs || 0).getTime();
  const validFromMs = nbfMs ? new Date(nbfMs).getTime() : currentMs;

  const alreadyExpired = expiresAtMs < currentMs;
  const notYetValid = currentMs < validFromMs;
  return alreadyExpired || notYetValid;
}
