import { DocumentData, DocumentReference, PartialWithFieldValue, SetOptions, Transaction } from "firebase/firestore";

/**
 * Wraps a Firestore {@link Transaction} so that reads are cached per document path and
 * writes are queued until {@link ManagedTransaction.commit} is called.
 *
 * Queued writes are reflected in the cache, so {@link ManagedTransaction.getCached} returns
 * the document as it will look once the queued writes are applied. Deferring the writes
 * presumably exists so callers can keep reading after they have logically "written", since
 * Firestore transactions require every read to precede every write.
 */
export class ManagedTransaction {
  /** Writes queued by `deferSet`, in call order; flushed (and emptied) by `commit`. */
  private readonly deferredSets: [documentRef: DocumentReference<any>, data: PartialWithFieldValue<any>, options: SetOptions][] = [];
  /** Document data keyed by document path; `undefined` is stored for documents that do not exist. */
  private readonly cachedDocs = new Map<string, unknown>();

  constructor(readonly tx: Transaction) { }

  /**
   * Returns the document's data, reading it through the transaction at most once per path.
   *
   * Any writes already queued for this path via `deferSet` are applied on top of the freshly
   * read data before it is cached, so the result always includes pending writes.
   *
   * @param documentRef - Document to read.
   * @returns The (possibly locally updated) document data, or `undefined` if the document
   *   does not exist and no queued write created it.
   */
  async getCached<AppModelType, DbModelType extends DocumentData>(documentRef: DocumentReference<AppModelType, DbModelType>): Promise<AppModelType | undefined> {
    const path = documentRef.path;
    if (!this.cachedDocs.has(path)) {
      const snapshot = await this.tx.get(documentRef);
      // A concurrent getCached for the same path may have filled the cache (and had later
      // deferSet calls applied to it) while we were awaiting; don't clobber that entry.
      if (!this.cachedDocs.has(path)) {
        let doc: unknown = snapshot.data();
        // Replay writes queued before this first read so the cache reflects them too.
        for (const [ref, data, options] of this.deferredSets) {
          if (ref.path === path) {
            doc = ManagedTransaction.applySet(doc, data, options);
          }
        }
        this.cachedDocs.set(path, doc);
      }
    }
    return this.cachedDocs.get(path) as AppModelType | undefined;
  }

  /**
   * Queues a `set` of `data` on `documentRef` to be issued when `commit()` runs, and updates
   * the cached copy (if the document has already been read) to match.
   *
   * The cached copy follows `set` semantics: without merge options the document is replaced;
   * with merge options nested plain objects are merged field by field.
   *
   * @param documentRef - Document to write.
   * @param data - Data to write; shallow-copied so later top-level mutations by the caller do
   *   not leak into the queued write.
   * @param options - Set options, forwarded unchanged to `Transaction.set` on commit.
   */
  deferSet<AppModelType, DbModelType extends DocumentData>(documentRef: DocumentReference<AppModelType, DbModelType>, data: AppModelType, options: SetOptions) {
    const copy = { ...data };
    this.deferredSets.push([documentRef, copy, options]);
    const path = documentRef.path;
    if (this.cachedDocs.has(path)) {
      this.cachedDocs.set(path, ManagedTransaction.applySet(this.cachedDocs.get(path), copy, options));
    }
  }

  /**
   * Issues every queued write on the underlying transaction in the order it was queued, then
   * leaves the queue empty so a repeated call cannot issue the same writes twice.
   *
   * NOTE(review): Firestore rejects transaction reads issued after writes, so `getCached`
   * is presumably not expected to be called after this.
   */
  commit(): void {
    const pending = this.deferredSets.splice(0);
    for (const [documentRef, data, options] of pending) {
      this.tx.set(documentRef, data, options);
    }
  }

  /** Returns the document data that results from applying one `set` of `data` to `current`. */
  private static applySet(current: unknown, data: unknown, options: SetOptions): Record<string, unknown> {
    const patch = data as Record<string, unknown>;
    if (!ManagedTransaction.isMerge(options)) {
      // Without merge options, `set` replaces the whole document.
      return { ...patch };
    }
    // NOTE(review): `mergeFields` writes only the listed field paths; here it is approximated
    // by merging every field in `data`. Confirm callers only pass fields they list.
    return ManagedTransaction.mergeMaps(current, patch);
  }

  /** True when `options` requests a merge (`merge: true` or a `mergeFields` list). */
  private static isMerge(options: SetOptions): boolean {
    const { merge, mergeFields } = options as { merge?: boolean; mergeFields?: unknown };
    return Boolean(merge) || mergeFields !== undefined;
  }

  /**
   * Merges `patch` into a shallow copy of `base`, recursing only into nested plain objects.
   * Arrays and non-plain values replace the existing value outright. An empty nested object
   * also replaces rather than merges, matching how Firestore treats an empty map in a merge.
   * Non-plain values such as FieldValue sentinels (e.g. `increment()`) are stored verbatim,
   * not evaluated. Neither argument is mutated.
   */
  private static mergeMaps(base: unknown, patch: Record<string, unknown>): Record<string, unknown> {
    const result: Record<string, unknown> = { ...((base ?? {}) as object) };
    for (const [key, value] of Object.entries(patch)) {
      const existing = result[key];
      if (
        ManagedTransaction.isPlainObject(value) &&
        Object.keys(value).length > 0 &&
        ManagedTransaction.isPlainObject(existing)
      ) {
        result[key] = ManagedTransaction.mergeMaps(existing, value);
      } else {
        result[key] = value;
      }
    }
    return result;
  }

  /** True for object literals and `Object.create(null)` objects (not arrays or class instances). */
  private static isPlainObject(value: unknown): value is Record<string, unknown> {
    if (value === null || typeof value !== "object") return false;
    const proto = Object.getPrototypeOf(value);
    return proto === Object.prototype || proto === null;
  }
}
